Solve conflicts + pre-commit run -a

Cadene 2024-02-29 23:31:32 +00:00
parent 0b9027f05e
commit ae050d2e94
8 changed files with 26 additions and 41 deletions

View File

@@ -108,7 +108,10 @@ eval_episodes=7
 **Style**
 ```
+# install if needed
 pre-commit install
+# apply style and linter checks before git commit
+pre-commit run -a
 ```
 **Tests**

View File

@@ -9,19 +9,14 @@ import pymunk
 import torch
 import torchrl
 import tqdm
-from tensordict import TensorDict
-from torchrl.data.datasets.utils import _get_root_dir
-from torchrl.data.replay_buffers.replay_buffers import (
-    TensorDictReplayBuffer,
-)
-from torchrl.data.replay_buffers.samplers import (
-    Sampler,
-)
-from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
-from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 from diffusion_policy.common.replay_buffer import ReplayBuffer
 from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
+from tensordict import TensorDict
+from torchrl.data.datasets.utils import _get_root_dir
+from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import Sampler
+from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
+from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 # as define in env
 SUCCESS_THRESHOLD = 0.95  # 95% coverage,

View File

@@ -8,9 +8,7 @@ import torchrl
 import tqdm
 from tensordict import TensorDict
 from torchrl.data.datasets.utils import _get_root_dir
-from torchrl.data.replay_buffers.replay_buffers import (
-    TensorDictReplayBuffer,
-)
+from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import (
     Sampler,
     SliceSampler,

View File

@@ -1,13 +1,12 @@
 import contextlib
-import datetime
 import os
 from pathlib import Path
 import numpy as np
-import pandas as pd
 from omegaconf import OmegaConf
 from termcolor import colored
 def make_dir(dir_path):
     """Create directory if it does not already exist."""
     with contextlib.suppress(OSError):

View File

@@ -5,7 +5,6 @@ import hydra
 import torch
 import torch.nn as nn
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from diffusion_policy.model.common.lr_scheduler import get_scheduler
 from diffusion_policy.model.vision.model_getter import get_resnet
 from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
@@ -128,12 +127,8 @@ class DiffusionPolicy(nn.Module):
         out = {
             "obs": {
-                "image": batch["observation", "image"].to(
-                    self.device, non_blocking=True
-                ),
-                "agent_pos": batch["observation", "state"].to(
-                    self.device, non_blocking=True
-                ),
+                "image": batch["observation", "image"].to(self.device, non_blocking=True),
+                "agent_pos": batch["observation", "state"].to(self.device, non_blocking=True),
             },
             "action": batch["action"].to(self.device, non_blocking=True),
         }
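The reformatting above only joins the wrapped `.to(...)` calls back onto single lines; the behaviour is unchanged: every tensor in the batch is copied to the policy's device with `non_blocking=True`. A minimal sketch of that transfer pattern, using illustrative shapes and names that are not taken from the repository:

```
import torch

# hypothetical CPU batch with the kind of keys the policy consumes
batch = {
    "image": torch.randn(8, 3, 96, 96),
    "agent_pos": torch.randn(8, 2),
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# non_blocking=True lets host-to-device copies overlap with other work,
# but it only takes effect when the source tensors live in pinned memory
obs = {key: value.to(device, non_blocking=True) for key, value in batch.items()}
```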

View File

@@ -33,7 +33,7 @@ def init_logging():
     logging.getLogger().addHandler(console_handler)
-def format_number_KMB(num):
+def format_big_number(num):
     suffixes = ["", "K", "M", "B", "T", "Q"]
     divisor = 1000.0
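Only the function's name changes in this hunk; the two visible locals (the suffix list and the 1000.0 divisor) point at the usual divide-until-small number formatter. A hedged sketch of what such a helper could look like — apart from those two lines, the body below is an assumption, not the repository's implementation:

```
def format_big_number(num):
    # suffix list and divisor are taken from the diff above;
    # the loop is a plausible reconstruction, not the actual code
    suffixes = ["", "K", "M", "B", "T", "Q"]
    divisor = 1000.0
    for suffix in suffixes:
        if abs(num) < divisor:
            return f"{num:.0f}{suffix}"
        num /= divisor
    return f"{num:.0f}{suffixes[-1]}"


print(format_big_number(3))          # 3
print(format_big_number(25_000))     # 25K
print(format_big_number(1_200_000))  # 1M (exact rounding is part of the sketch)
```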

View File

@@ -1,5 +1,4 @@
 import logging
-import time
 import hydra
 import numpy as np
@@ -13,7 +12,7 @@ from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import format_number_KMB, init_logging, set_seed
+from lerobot.common.utils import format_big_number, init_logging, set_seed
 from lerobot.scripts.eval import eval_policy
@@ -49,11 +48,11 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
     num_episodes = num_samples / avg_samples_per_ep
     num_epochs = num_samples / offline_buffer.num_samples
     log_items = [
-        f"step:{format_number_KMB(step)}",
+        f"step:{format_big_number(step)}",
         # number of samples seen during training
-        f"smpl:{format_number_KMB(num_samples)}",
+        f"smpl:{format_big_number(num_samples)}",
         # number of episodes seen during training
-        f"ep:{format_number_KMB(num_episodes)}",
+        f"ep:{format_big_number(num_episodes)}",
         # number of time all unique samples are seen
         f"epch:{num_epochs:.2f}",
         f"loss:{loss:.3f}",
@@ -86,11 +85,11 @@ def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline):
     num_episodes = num_samples / avg_samples_per_ep
     num_epochs = num_samples / offline_buffer.num_samples
     log_items = [
-        f"step:{format_number_KMB(step)}",
+        f"step:{format_big_number(step)}",
         # number of samples seen during training
-        f"smpl:{format_number_KMB(num_samples)}",
+        f"smpl:{format_big_number(num_samples)}",
         # number of episodes seen during training
-        f"ep:{format_number_KMB(num_episodes)}",
+        f"ep:{format_big_number(num_episodes)}",
         # number of time all unique samples are seen
         f"epch:{num_epochs:.2f}",
         f"∑rwrd:{avg_sum_reward:.3f}",
@@ -156,7 +155,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logger = Logger(out_dir, job_name, cfg)
-    online_ep_idx = 0
     step = 0  # number of policy update
     is_offline = True
@@ -200,7 +198,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 auto_cast_to_device=True,
             )
         assert len(rollout) <= cfg.env.episode_length
-        rollout["episode"] = torch.tensor([online_ep_idx] * len(rollout), dtype=torch.int)
+        # set same episode index for all time steps contained in this rollout
+        rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
         online_buffer.extend(rollout)
         ep_sum_reward = rollout["next", "reward"].sum()
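The replacement line tags every frame of the rollout with the current `env_step` as its episode index before the rollout is appended to the online buffer. A small illustration of the tensor that expression builds (the values are made up):

```
import torch

env_step = 3      # hypothetical value of the online-loop counter
rollout_len = 5   # hypothetical number of frames returned by the rollout

# every time step of this rollout gets the same episode index
episode = torch.tensor([env_step] * rollout_len, dtype=torch.int)
print(episode)  # tensor([3, 3, 3, 3, 3], dtype=torch.int32)
```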
@@ -210,12 +209,10 @@
             "avg_sum_reward": np.nanmean(ep_sum_reward),
             "avg_max_reward": np.nanmean(ep_max_reward),
             "pc_success": np.nanmean(ep_success) * 100,
-            "online_ep_idx": online_ep_idx,
+            "env_step": env_step,
             "ep_length": len(rollout),
         }
-        online_ep_idx += 1
         for _ in range(cfg.policy.utd):
             train_info = policy.update(
                 online_buffer,
@@ -233,7 +230,7 @@
                 num_episodes=cfg.eval_episodes,
                 return_first_video=True,
             )
-            log_eval_info(L, eval_info, step, cfg, offline_buffer, is_offline)
+            log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
             if cfg.wandb.enable:
                 logger.log_video(first_video, step, mode="eval")

View File

@@ -3,9 +3,7 @@ from pathlib import Path
 import hydra
 import imageio
 import torch
-from torchrl.data.replay_buffers import (
-    SliceSamplerWithoutReplacement,
-)
+from torchrl.data.replay_buffers import SliceSamplerWithoutReplacement
 from lerobot.common.datasets.factory import make_offline_buffer