Merge pull request from Cadene/pre-commit

Style & Formatting
Simon Alibert 2024-02-29 21:48:19 +01:00 committed by GitHub
commit cb7b375526
20 changed files with 296 additions and 237 deletions

.pre-commit-config.yaml (new file, 33 lines)

@ -0,0 +1,33 @@
exclude: ^(data/|tests/|diffusion_policy/)
default_language_version:
python: python3.10
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-added-large-files
- id: debug-statements
- id: check-merge-conflict
- id: check-case-conflict
- id: check-yaml
- id: check-toml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.1
hooks:
- id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.2
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/python-poetry/poetry
rev: 1.8.0
hooks:
- id: poetry-check
- id: poetry-lock
args:
- "--check"
- "--no-update"


@ -10,7 +10,7 @@ conda activate lerobot
[Install `poetry`](https://python-poetry.org/docs/#installation) (if you don't have it already) [Install `poetry`](https://python-poetry.org/docs/#installation) (if you don't have it already)
``` ```
curl -sSL https://install.python-poetry.org | python3 - curl -sSL https://install.python-poetry.org | python -
``` ```
Install dependencies Install dependencies
@ -26,6 +26,7 @@ export TMPDIR='~/tmp'
Install `diffusion_policy` #HACK Install `diffusion_policy` #HACK
``` ```
# from this directory
git clone https://github.com/real-stanford/diffusion_policy git clone https://github.com/real-stanford/diffusion_policy
cp -r diffusion_policy/diffusion_policy $(poetry env info -p)/lib/python3.10/site-packages/ cp -r diffusion_policy/diffusion_policy $(poetry env info -p)/lib/python3.10/site-packages/
``` ```
@ -107,11 +108,10 @@ eval_episodes=7
**Style** **Style**
``` ```
isort lerobot && isort tests && black lerobot && black tests pre-commit install
pylint lerobot && pylint tests # not enforce for now
``` ```
**Tests** **Tests**
``` ```
pytest -sx tests pytest -sx tests
``` ```
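
Since the Style step now goes through pre-commit rather than calling isort/black directly, a single hook can still be run by its id while iterating on a fix; a small sketch (hook ids as declared in `.pre-commit-config.yaml`):
```
# run only the ruff lint hook (with --fix, as configured) across the repository
pre-commit run ruff --all-files
# run only the formatter hook
pre-commit run ruff-format --all-files
```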


@ -70,6 +70,7 @@ def make_offline_buffer(cfg, sampler=None):
offline_buffer = PushtExperienceReplay( offline_buffer = PushtExperienceReplay(
"pusht", "pusht",
# download="force", # download="force",
# TODO(aliberts): automate download
download=False, download=False,
streaming=False, streaming=False,
root="data", root="data",


@ -1,7 +1,6 @@
import os import os
import pickle
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Tuple from typing import Callable
import einops import einops
import numpy as np import numpy as np
@ -10,25 +9,25 @@ import pymunk
import torch import torch
import torchrl import torchrl
import tqdm import tqdm
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import ( from torchrl.data.replay_buffers.replay_buffers import (
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer, TensorDictReplayBuffer,
) )
from torchrl.data.replay_buffers.samplers import ( from torchrl.data.replay_buffers.samplers import (
Sampler, Sampler,
SliceSampler,
SliceSamplerWithoutReplacement,
) )
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
# as define in env # as define in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage, SUCCESS_THRESHOLD = 0.95 # 95% coverage,
DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()
def get_goal_pose_body(pose): def get_goal_pose_body(pose):
mass = 1 mass = 1
@ -53,7 +52,7 @@ def add_tee(
angle, angle,
scale=30, scale=30,
color="LightSlateGray", color="LightSlateGray",
mask=pymunk.ShapeFilter.ALL_MASKS(), mask=DEFAULT_TEE_MASK,
): ):
mass = 1 mass = 1
length = 4 length = 4
@ -87,7 +86,6 @@ def add_tee(
class PushtExperienceReplay(TensorDictReplayBuffer): class PushtExperienceReplay(TensorDictReplayBuffer):
def __init__( def __init__(
self, self,
dataset_id, dataset_id,
@ -127,7 +125,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
if split_trajs: if split_trajs:
raise NotImplementedError raise NotImplementedError
if self.download == True: if self.download:
raise NotImplementedError() raise NotImplementedError()
if root is None: if root is None:
@ -193,18 +191,18 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
# TODO(rcadene) # TODO(rcadene)
# load # load
# TODO(aliberts): Dynamic paths
zarr_path = ( zarr_path = (
"/home/rcadene/code/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr" "/home/rcadene/code/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
# "/home/simon/build/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
) )
dataset_dict = ReplayBuffer.copy_from_path( dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
zarr_path
) # , keys=['img', 'state', 'action'])
episode_ids = dataset_dict.get_episode_idxs() episode_ids = dataset_dict.get_episode_idxs()
num_episodes = dataset_dict.meta["episode_ends"].shape[0] num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0] total_frames = dataset_dict["action"].shape[0]
assert len( assert len(
set([dataset_dict[key].shape[0] for key in dataset_dict.keys()]) {dataset_dict[key].shape[0] for key in dataset_dict}
), "Some data type dont have the same number of total frames." ), "Some data type dont have the same number of total frames."
# TODO: verify that goal pose is expected to be fixed # TODO: verify that goal pose is expected to be fixed
@ -245,9 +243,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
] ]
space.add(*walls) space.add(*walls)
block_body = add_tee( block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item())
space, block_pos[i].tolist(), block_angle[i].item()
)
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes) goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes) block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area intersection_area = goal_geom.intersection(block_geom).area
@ -278,11 +274,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
if episode_id == 0: if episode_id == 0:
# hack to initialize tensordict data structure to store episodes # hack to initialize tensordict data structure to store episodes
td_data = ( td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
episode[0]
.expand(total_frames)
.memmap_like(self.root / self.dataset_id)
)
td_data[idxtd : idxtd + len(episode)] = episode td_data[idxtd : idxtd + len(episode)] = episode


@ -1,7 +1,7 @@
import os import os
import pickle import pickle
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Tuple from typing import Callable
import torch import torch
import torchrl import torchrl
@ -9,7 +9,6 @@ import tqdm
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import ( from torchrl.data.replay_buffers.replay_buffers import (
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer, TensorDictReplayBuffer,
) )
from torchrl.data.replay_buffers.samplers import ( from torchrl.data.replay_buffers.samplers import (
@ -22,7 +21,6 @@ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
class SimxarmExperienceReplay(TensorDictReplayBuffer): class SimxarmExperienceReplay(TensorDictReplayBuffer):
available_datasets = [ available_datasets = [
"xarm_lift_medium", "xarm_lift_medium",
] ]
@ -77,15 +75,11 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
if num_slices is not None or slice_len is not None: if num_slices is not None or slice_len is not None:
if sampler is not None: if sampler is not None:
raise ValueError( raise ValueError("`num_slices` and `slice_len` are exclusive with the `sampler` argument.")
"`num_slices` and `slice_len` are exclusive with the `sampler` argument."
)
if replacement: if replacement:
if not self.shuffle: if not self.shuffle:
raise RuntimeError( raise RuntimeError("shuffle=False can only be used when replacement=False.")
"shuffle=False can only be used when replacement=False."
)
sampler = SliceSampler( sampler = SliceSampler(
num_slices=num_slices, num_slices=num_slices,
slice_len=slice_len, slice_len=slice_len,
@ -130,7 +124,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
# load # load
dataset_dir = Path("data") / self.dataset_id dataset_dir = Path("data") / self.dataset_id
dataset_path = dataset_dir / f"buffer.pkl" dataset_path = dataset_dir / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'") print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f: with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f) dataset_dict = pickle.load(f)
@ -150,12 +144,8 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1]) image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1]) state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
next_image = torch.tensor( next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
dataset_dict["next_observations"]["rgb"][idx0:idx1] next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
)
next_state = torch.tensor(
dataset_dict["next_observations"]["state"][idx0:idx1]
)
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1]) next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1]) next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
@ -176,11 +166,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
if episode_id == 0: if episode_id == 0:
# hack to initialize tensordict data structure to store episodes # hack to initialize tensordict data structure to store episodes
td_data = ( td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
episode[0]
.expand(total_frames)
.memmap_like(self.root / self.dataset_id)
)
td_data[idx0:idx1] = episode td_data[idx0:idx1] = episode


@ -1,7 +1,6 @@
import importlib import importlib
from typing import Optional from typing import Optional
import numpy as np
import torch import torch
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.data.tensor_specs import ( from torchrl.data.tensor_specs import (
@ -20,7 +19,6 @@ _has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _
class PushtEnv(EnvBase): class PushtEnv(EnvBase):
def __init__( def __init__(
self, self,
frame_skip: int = 1, frame_skip: int = 1,
@ -46,7 +44,8 @@ class PushtEnv(EnvBase):
if not _has_gym: if not _has_gym:
raise ImportError("Cannot import gym.") raise ImportError("Cannot import gym.")
from diffusion_policy.env.pusht.pusht_env import PushTEnv # TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
# from diffusion_policy.env.pusht.pusht_env import PushTEnv
if not from_pixels: if not from_pixels:
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv") raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
@ -71,14 +70,10 @@ class PushtEnv(EnvBase):
obs = {"image": torch.from_numpy(raw_obs["image"])} obs = {"image": torch.from_numpy(raw_obs["image"])}
if not self.pixels_only: if not self.pixels_only:
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type( obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)
torch.float32
)
else: else:
# TODO: # TODO:
obs = { obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)
}
obs = TensorDict(obs, batch_size=[]) obs = TensorDict(obs, batch_size=[])
return obs return obs
@ -109,7 +104,7 @@ class PushtEnv(EnvBase):
# step expects shape=(4,) so we pad if necessary # step expects shape=(4,) so we pad if necessary
# TODO(rcadene): add info["is_success"] and info["success"] ? # TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0 sum_reward = 0
for t in range(self.frame_skip): for _ in range(self.frame_skip):
raw_obs, reward, done, info = self._env.step(action) raw_obs, reward, done, info = self._env.step(action)
sum_reward += reward sum_reward += reward


@ -15,12 +15,13 @@ from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
from lerobot.common.utils import set_seed from lerobot.common.utils import set_seed
MAX_NUM_ACTIONS = 4
_has_gym = importlib.util.find_spec("gym") is not None _has_gym = importlib.util.find_spec("gym") is not None
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym _has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym
class SimxarmEnv(EnvBase): class SimxarmEnv(EnvBase):
def __init__( def __init__(
self, self,
task, task,
@ -52,18 +53,13 @@ class SimxarmEnv(EnvBase):
from simxarm import TASKS from simxarm import TASKS
if self.task not in TASKS: if self.task not in TASKS:
raise ValueError( raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}"
)
self._env = TASKS[self.task]["env"]() self._env = TASKS[self.task]["env"]()
MAX_NUM_ACTIONS = 4
num_actions = len(TASKS[self.task]["action_space"]) num_actions = len(TASKS[self.task]["action_space"])
self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,)) self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
self._action_padding = np.zeros( self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
(MAX_NUM_ACTIONS - num_actions), dtype=np.float32
)
if "w" not in TASKS[self.task]["action_space"]: if "w" not in TASKS[self.task]["action_space"]:
self._action_padding[-1] = 1.0 self._action_padding[-1] = 1.0
@ -75,9 +71,7 @@ class SimxarmEnv(EnvBase):
def _format_raw_obs(self, raw_obs): def _format_raw_obs(self, raw_obs):
if self.from_pixels: if self.from_pixels:
image = self.render( image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
mode="rgb_array", width=self.image_size, height=self.image_size
)
image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W) image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
image = torch.tensor(image.copy(), dtype=torch.uint8) image = torch.tensor(image.copy(), dtype=torch.uint8)
@ -114,7 +108,7 @@ class SimxarmEnv(EnvBase):
action = np.concatenate([action, self._action_padding]) action = np.concatenate([action, self._action_padding])
# TODO(rcadene): add info["is_success"] and info["success"] ? # TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0 sum_reward = 0
for t in range(self.frame_skip): for _ in range(self.frame_skip):
raw_obs, reward, done, info = self._env.step(action) raw_obs, reward, done, info = self._env.step(action)
sum_reward += reward sum_reward += reward


@ -5,7 +5,6 @@ from torchrl.envs.transforms import ObservationTransform
class Prod(ObservationTransform): class Prod(ObservationTransform):
def __init__(self, in_keys: Sequence[NestedKey], prod: float): def __init__(self, in_keys: Sequence[NestedKey], prod: float):
super().__init__() super().__init__()
self.in_keys = in_keys self.in_keys = in_keys


@ -1,6 +1,6 @@
import contextlib
import datetime import datetime
import os import os
import re
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
@ -29,10 +29,8 @@ AGENT_METRICS = [
def make_dir(dir_path): def make_dir(dir_path):
"""Create directory if it does not already exist.""" """Create directory if it does not already exist."""
try: with contextlib.suppress(OSError):
dir_path.mkdir(parents=True, exist_ok=True) dir_path.mkdir(parents=True, exist_ok=True)
except OSError:
pass
return dir_path return dir_path
@ -59,9 +57,7 @@ def print_run(cfg, reward=None):
# ('experiment', cfg.exp_name), # ('experiment', cfg.exp_name),
] ]
if reward is not None: if reward is not None:
kvs.append( kvs.append(("episode reward", colored(str(int(reward)), "white", attrs=["bold"])))
("episode reward", colored(str(int(reward)), "white", attrs=["bold"]))
)
w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21 w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21
div = "-" * w div = "-" * w
print(div) print(div)
@ -80,7 +76,7 @@ def cfg_to_group(cfg, return_list=False):
return lst if return_list else "-".join(lst) return lst if return_list else "-".join(lst)
class Logger(object): class Logger:
"""Primary logger object. Logs either locally or using wandb.""" """Primary logger object. Logs either locally or using wandb."""
def __init__(self, log_dir, job_name, cfg): def __init__(self, log_dir, job_name, cfg):
@ -183,7 +179,5 @@ class Logger(object):
if category == "eval": if category == "eval":
keys = ["step", "avg_sum_reward", "avg_max_reward", "pc_success"] keys = ["step", "avg_sum_reward", "avg_max_reward", "pc_success"]
self._eval.append(np.array([d[key] for key in keys])) self._eval.append(np.array([d[key] for key in keys]))
pd.DataFrame(np.array(self._eval)).to_csv( pd.DataFrame(np.array(self._eval)).to_csv(self._log_dir / "eval.log", header=keys, index=None)
self._log_dir / "eval.log", header=keys, index=None
)
self._print(d, category) self._print(d, category)


@ -3,16 +3,17 @@ import copy
import hydra import hydra
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.common.lr_scheduler import get_scheduler from diffusion_policy.model.common.lr_scheduler import get_scheduler
from diffusion_policy.model.vision.model_getter import get_resnet from diffusion_policy.model.vision.model_getter import get_resnet
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy
FIRST_ACTION = 0
class DiffusionPolicy(nn.Module): class DiffusionPolicy(nn.Module):
def __init__( def __init__(
self, self,
cfg, cfg,
@ -105,7 +106,6 @@ class DiffusionPolicy(nn.Module):
out = self.diffusion.predict_action(obs_dict) out = self.diffusion.predict_action(obs_dict)
# TODO(rcadene): add possibility to return >1 timestemps # TODO(rcadene): add possibility to return >1 timestemps
FIRST_ACTION = 0
action = out["action"].squeeze(0)[FIRST_ACTION] action = out["action"].squeeze(0)[FIRST_ACTION]
return action return action
@ -132,10 +132,7 @@ class DiffusionPolicy(nn.Module):
} }
return out return out
if self.cfg.balanced_sampling: batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
batch = replay_buffer.sample(batch_size)
else:
batch = replay_buffer.sample()
batch = process_batch(batch, self.cfg.horizon, num_slices) batch = process_batch(batch, self.cfg.horizon, num_slices)
loss = self.diffusion.compute_loss(batch) loss = self.diffusion.compute_loss(batch)


@ -1,3 +1,5 @@
# ruff: noqa: N806
from copy import deepcopy from copy import deepcopy
import einops import einops
@ -7,6 +9,8 @@ import torch.nn as nn
import lerobot.common.policies.tdmpc_helper as h import lerobot.common.policies.tdmpc_helper as h
FIRST_FRAME = 0
class TOLD(nn.Module): class TOLD(nn.Module):
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.""" """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
@ -17,9 +21,7 @@ class TOLD(nn.Module):
self.cfg = cfg self.cfg = cfg
self._encoder = h.enc(cfg) self._encoder = h.enc(cfg)
self._dynamics = h.dynamics( self._dynamics = h.dynamics(cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim)
cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim
)
self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1) self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1)
self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim) self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim)
self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)]) self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)])
@ -65,17 +67,17 @@ class TOLD(nn.Module):
return h.TruncatedNormal(mu, std).sample(clip=0.3) return h.TruncatedNormal(mu, std).sample(clip=0.3)
return mu return mu
def V(self, z): def V(self, z): # noqa: N802
"""Predict state value (V).""" """Predict state value (V)."""
return self._V(z) return self._V(z)
def Q(self, z, a, return_type): def Q(self, z, a, return_type): # noqa: N802
"""Predict state-action value (Q).""" """Predict state-action value (Q)."""
assert return_type in {"min", "avg", "all"} assert return_type in {"min", "avg", "all"}
x = torch.cat([z, a], dim=-1) x = torch.cat([z, a], dim=-1)
if return_type == "all": if return_type == "all":
return torch.stack(list(q(x) for q in self._Qs), dim=0) return torch.stack([q(x) for q in self._Qs], dim=0)
idxs = np.random.choice(self.cfg.num_q, 2, replace=False) idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x) Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
@ -160,11 +162,7 @@ class TDMPC(nn.Module):
pi = self.model.pi(z, self.cfg.min_std) pi = self.model.pi(z, self.cfg.min_std)
G += discount * self.model.Q(z, pi, return_type="min") G += discount * self.model.Q(z, pi, return_type="min")
if self.cfg.uncertainty_cost > 0: if self.cfg.uncertainty_cost > 0:
G -= ( G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0)
discount
* self.cfg.uncertainty_cost
* self.model.Q(z, pi, return_type="all").std(dim=0)
)
return G return G
@torch.no_grad() @torch.no_grad()
@ -180,19 +178,13 @@ class TDMPC(nn.Module):
assert step is not None assert step is not None
# Seed steps # Seed steps
if step < self.cfg.seed_steps and self.model.training: if step < self.cfg.seed_steps and self.model.training:
return torch.empty( return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1)
self.action_dim, dtype=torch.float32, device=self.device
).uniform_(-1, 1)
# Sample policy trajectories # Sample policy trajectories
horizon = int( horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))
)
num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples) num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
if num_pi_trajs > 0: if num_pi_trajs > 0:
pi_actions = torch.empty( pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device)
horizon, num_pi_trajs, self.action_dim, device=self.device
)
_z = z.repeat(num_pi_trajs, 1) _z = z.repeat(num_pi_trajs, 1)
for t in range(horizon): for t in range(horizon):
pi_actions[t] = self.model.pi(_z, self.cfg.min_std) pi_actions[t] = self.model.pi(_z, self.cfg.min_std)
@ -201,20 +193,16 @@ class TDMPC(nn.Module):
# Initialize state and parameters # Initialize state and parameters
z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1) z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1)
mean = torch.zeros(horizon, self.action_dim, device=self.device) mean = torch.zeros(horizon, self.action_dim, device=self.device)
std = self.cfg.max_std * torch.ones( std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device)
horizon, self.action_dim, device=self.device
)
if not t0 and hasattr(self, "_prev_mean"): if not t0 and hasattr(self, "_prev_mean"):
mean[:-1] = self._prev_mean[1:] mean[:-1] = self._prev_mean[1:]
# Iterate CEM # Iterate CEM
for i in range(self.cfg.iterations): for _ in range(self.cfg.iterations):
actions = torch.clamp( actions = torch.clamp(
mean.unsqueeze(1) mean.unsqueeze(1)
+ std.unsqueeze(1) + std.unsqueeze(1)
* torch.randn( * torch.randn(horizon, self.cfg.num_samples, self.action_dim, device=std.device),
horizon, self.cfg.num_samples, self.action_dim, device=std.device
),
-1, -1,
1, 1,
) )
@ -223,18 +211,14 @@ class TDMPC(nn.Module):
# Compute elite actions # Compute elite actions
value = self.estimate_value(z, actions, horizon).nan_to_num_(0) value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
elite_idxs = torch.topk( elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
value.squeeze(1), self.cfg.num_elites, dim=0
).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
# Update parameters # Update parameters
max_value = elite_value.max(0)[0] max_value = elite_value.max(0)[0]
score = torch.exp(self.cfg.temperature * (elite_value - max_value)) score = torch.exp(self.cfg.temperature * (elite_value - max_value))
score /= score.sum(0) score /= score.sum(0)
_mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / ( _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
score.sum(0) + 1e-9
)
_std = torch.sqrt( _std = torch.sqrt(
torch.sum( torch.sum(
score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2, score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2,
@ -331,7 +315,6 @@ class TDMPC(nn.Module):
batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous() batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
batch = batch.to(self.device) batch = batch.to(self.device)
FIRST_FRAME = 0
obs = { obs = {
"rgb": batch["observation", "image"][FIRST_FRAME].float(), "rgb": batch["observation", "image"][FIRST_FRAME].float(),
"state": batch["observation", "state"][FIRST_FRAME], "state": batch["observation", "state"][FIRST_FRAME],
@ -359,10 +342,7 @@ class TDMPC(nn.Module):
weights = batch["_weight"][FIRST_FRAME, :, None] weights = batch["_weight"][FIRST_FRAME, :, None]
return obs, action, next_obses, reward, mask, done, idxs, weights return obs, action, next_obses, reward, mask, done, idxs, weights
if self.cfg.balanced_sampling: batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
batch = replay_buffer.sample(batch_size)
else:
batch = replay_buffer.sample()
obs, action, next_obses, reward, mask, done, idxs, weights = process_batch( obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
batch, self.cfg.horizon, num_slices batch, self.cfg.horizon, num_slices
@ -384,10 +364,7 @@ class TDMPC(nn.Module):
if isinstance(obs, dict): if isinstance(obs, dict):
obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs} obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
next_obses = { next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1)
for k in next_obses
}
else: else:
obs = torch.cat([obs, demo_obs]) obs = torch.cat([obs, demo_obs])
next_obses = torch.cat([next_obses, demo_next_obses], dim=1) next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
@ -429,9 +406,7 @@ class TDMPC(nn.Module):
td_targets = self._td_target(next_z, reward, mask) td_targets = self._td_target(next_z, reward, mask)
# Latent rollout # Latent rollout
zs = torch.empty( zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device
)
reward_preds = torch.empty_like(reward, device=self.device) reward_preds = torch.empty_like(reward, device=self.device)
assert reward.shape[0] == horizon assert reward.shape[0] == horizon
z = self.model.encode(obs) z = self.model.encode(obs)
@ -452,12 +427,10 @@ class TDMPC(nn.Module):
value_info["V"] = v.mean().item() value_info["V"] = v.mean().item()
# Losses # Losses
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view( rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
-1, 1, 1 consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
dim=0
) )
consistency_loss = (
rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask
).sum(dim=0)
reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0) reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
q_value_loss, priority_loss = 0, 0 q_value_loss, priority_loss = 0, 0
for q in range(self.cfg.num_q): for q in range(self.cfg.num_q):
@ -465,9 +438,7 @@ class TDMPC(nn.Module):
priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0) priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
expectile = h.linear_schedule(self.cfg.expectile, step) expectile = h.linear_schedule(self.cfg.expectile, step)
v_value_loss = ( v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask
).sum(dim=0)
total_loss = ( total_loss = (
self.cfg.consistency_coef * consistency_loss self.cfg.consistency_coef * consistency_loss


@ -5,11 +5,15 @@ import re
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F # noqa: N812
from torch import distributions as pyd from torch import distributions as pyd
from torch.distributions.utils import _standard_normal from torch.distributions.utils import _standard_normal
__REDUCE__ = lambda b: "mean" if b else "none" DEFAULT_ACT_FN = nn.Mish()
def __REDUCE__(b): # noqa: N802, N807
return "mean" if b else "none"
def l1(pred, target, reduce=False): def l1(pred, target, reduce=False):
@ -36,11 +40,7 @@ def l2_expectile(diff, expectile=0.7, reduce=False):
def _get_out_shape(in_shape, layers): def _get_out_shape(in_shape, layers):
"""Utility function. Returns the output shape of a network for a given input shape.""" """Utility function. Returns the output shape of a network for a given input shape."""
x = torch.randn(*in_shape).unsqueeze(0) x = torch.randn(*in_shape).unsqueeze(0)
return ( return (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x).squeeze(0).shape
(nn.Sequential(*layers) if isinstance(layers, list) else layers)(x)
.squeeze(0)
.shape
)
def gaussian_logprob(eps, log_std): def gaussian_logprob(eps, log_std):
@ -73,7 +73,7 @@ def orthogonal_init(m):
def ema(m, m_target, tau): def ema(m, m_target, tau):
"""Update slow-moving average of online network (target network) at rate tau.""" """Update slow-moving average of online network (target network) at rate tau."""
with torch.no_grad(): with torch.no_grad():
for p, p_target in zip(m.parameters(), m_target.parameters()): for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False):
p_target.data.lerp_(p.data, tau) p_target.data.lerp_(p.data, tau)
@ -86,6 +86,8 @@ def set_requires_grad(net, value):
class TruncatedNormal(pyd.Normal): class TruncatedNormal(pyd.Normal):
"""Utility class implementing the truncated normal distribution.""" """Utility class implementing the truncated normal distribution."""
default_sample_shape = torch.Size()
def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
super().__init__(loc, scale, validate_args=False) super().__init__(loc, scale, validate_args=False)
self.low = low self.low = low
@ -97,7 +99,7 @@ class TruncatedNormal(pyd.Normal):
x = x - x.detach() + clamped_x.detach() x = x - x.detach() + clamped_x.detach()
return x return x
def sample(self, clip=None, sample_shape=torch.Size()): def sample(self, clip=None, sample_shape=default_sample_shape):
shape = self._extended_shape(sample_shape) shape = self._extended_shape(sample_shape)
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
eps *= self.scale eps *= self.scale
@ -136,7 +138,7 @@ def enc(cfg):
"""Returns a TOLD encoder.""" """Returns a TOLD encoder."""
pixels_enc_layers, state_enc_layers = None, None pixels_enc_layers, state_enc_layers = None, None
if cfg.modality in {"pixels", "all"}: if cfg.modality in {"pixels", "all"}:
C = int(3 * cfg.frame_stack) C = int(3 * cfg.frame_stack) # noqa: N806
pixels_enc_layers = [ pixels_enc_layers = [
NormalizeImg(), NormalizeImg(),
nn.Conv2d(C, cfg.num_channels, 7, stride=2), nn.Conv2d(C, cfg.num_channels, 7, stride=2),
@ -184,7 +186,7 @@ def enc(cfg):
return Multiplexer(nn.ModuleDict(encoders)) return Multiplexer(nn.ModuleDict(encoders))
def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()): def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
"""Returns an MLP.""" """Returns an MLP."""
if isinstance(mlp_dim, int): if isinstance(mlp_dim, int):
mlp_dim = [mlp_dim, mlp_dim] mlp_dim = [mlp_dim, mlp_dim]
@ -199,7 +201,7 @@ def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
) )
def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()): def dynamics(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
"""Returns a dynamics network.""" """Returns a dynamics network."""
return nn.Sequential( return nn.Sequential(
mlp(in_dim, mlp_dim, out_dim, act_fn), mlp(in_dim, mlp_dim, out_dim, act_fn),
@ -327,7 +329,7 @@ class RandomShiftsAug(nn.Module):
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
class Episode(object): class Episode:
"""Storage object for a single episode.""" """Storage object for a single episode."""
def __init__(self, cfg, init_obs): def __init__(self, cfg, init_obs):
@ -354,18 +356,10 @@ class Episode(object):
self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device) self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
else: else:
raise ValueError raise ValueError
self.actions = torch.empty( self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device)
(cfg.episode_length, action_dim), dtype=torch.float32, device=self.device self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
) self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device)
self.rewards = torch.empty( self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
(cfg.episode_length,), dtype=torch.float32, device=self.device
)
self.dones = torch.empty(
(cfg.episode_length,), dtype=torch.bool, device=self.device
)
self.masks = torch.empty(
(cfg.episode_length,), dtype=torch.float32, device=self.device
)
self.cumulative_reward = 0 self.cumulative_reward = 0
self.done = False self.done = False
self.success = False self.success = False
@ -380,23 +374,17 @@ class Episode(object):
if cfg.modality in {"pixels", "state"}: if cfg.modality in {"pixels", "state"}:
episode = cls(cfg, obses[0]) episode = cls(cfg, obses[0])
episode.obses[1:] = torch.tensor( episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
obses[1:], dtype=episode.obses.dtype, device=episode.device
)
elif cfg.modality == "all": elif cfg.modality == "all":
episode = cls(cfg, {k: v[0] for k, v in obses.items()}) episode = cls(cfg, {k: v[0] for k, v in obses.items()})
for k, v in obses.items(): for k in obses:
episode.obses[k][1:] = torch.tensor( episode.obses[k][1:] = torch.tensor(
obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
) )
else: else:
raise NotImplementedError raise NotImplementedError
episode.actions = torch.tensor( episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
actions, dtype=episode.actions.dtype, device=episode.device episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
)
episode.rewards = torch.tensor(
rewards, dtype=episode.rewards.dtype, device=episode.device
)
episode.dones = ( episode.dones = (
torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device) torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
if dones is not None if dones is not None
@ -428,9 +416,7 @@ class Episode(object):
v, dtype=self.obses[k].dtype, device=self.obses[k].device v, dtype=self.obses[k].dtype, device=self.obses[k].device
) )
else: else:
self.obses[self._idx + 1] = torch.tensor( self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device)
obs, dtype=self.obses.dtype, device=self.obses.device
)
self.actions[self._idx] = action self.actions[self._idx] = action
self.rewards[self._idx] = reward self.rewards[self._idx] = reward
self.dones[self._idx] = done self.dones[self._idx] = done
@ -453,7 +439,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False):
] ]
if cfg.task.startswith("xarm"): if cfg.task.startswith("xarm"):
dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl") dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
print(f"Using offline dataset '{dataset_path}'") print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f: with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f) dataset_dict = pickle.load(f)
@ -461,7 +447,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False):
if k not in dataset_dict and k[:-1] in dataset_dict: if k not in dataset_dict and k[:-1] in dataset_dict:
dataset_dict[k] = dataset_dict.pop(k[:-1]) dataset_dict[k] = dataset_dict.pop(k[:-1])
elif cfg.task.startswith("legged"): elif cfg.task.startswith("legged"):
dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl") dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
print(f"Using offline dataset '{dataset_path}'") print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f: with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f) dataset_dict = pickle.load(f)
@ -475,10 +461,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False):
for i in range(len(dones) - 1): for i in range(len(dones) - 1):
if ( if (
np.linalg.norm( np.linalg.norm(dataset_dict["observations"][i + 1] - dataset_dict["next_observations"][i])
dataset_dict["observations"][i + 1]
- dataset_dict["next_observations"][i]
)
> 1e-6 > 1e-6
or dataset_dict["terminals"][i] == 1.0 or dataset_dict["terminals"][i] == 1.0
): ):
@ -501,7 +484,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False):
dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"]) dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"])
for key in required_keys: for key in required_keys:
assert key in dataset_dict.keys(), f"Missing `{key}` in dataset." assert key in dataset_dict, f"Missing `{key}` in dataset."
if return_reward_normalizer: if return_reward_normalizer:
return dataset_dict, reward_normalizer return dataset_dict, reward_normalizer
@ -553,9 +536,7 @@ def get_reward_normalizer(cfg, dataset):
return lambda x: x - 1.0 return lambda x: x - 1.0
elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]: elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]:
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset) (_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
return ( return lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0
lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0
)
elif hasattr(cfg, "reward_scale"): elif hasattr(cfg, "reward_scale"):
return lambda x: x * cfg.reward_scale return lambda x: x * cfg.reward_scale
return lambda x: x return lambda x: x
@ -571,12 +552,12 @@ def linear_schedule(schdl, step):
except ValueError: except ValueError:
match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl) match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl)
if match: if match:
init, final, start, end = [float(g) for g in match.groups()] init, final, start, end = (float(g) for g in match.groups())
mix = np.clip((step - start) / (end - start), 0.0, 1.0) mix = np.clip((step - start) / (end - start), 0.0, 1.0)
return (1.0 - mix) * init + mix * final return (1.0 - mix) * init + mix * final
match = re.match(r"linear\((.+),(.+),(.+)\)", schdl) match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
if match: if match:
init, final, duration = [float(g) for g in match.groups()] init, final, duration = (float(g) for g in match.groups())
mix = np.clip(step / duration, 0.0, 1.0) mix = np.clip(step / duration, 0.0, 1.0)
return (1.0 - mix) * init + mix * final return (1.0 - mix) * init + mix * final
raise NotImplementedError(schdl) raise NotImplementedError(schdl)


@ -22,4 +22,4 @@ env:
policy: policy:
state_dim: 2 state_dim: 2
action_dim: 2 action_dim: 2


@ -21,4 +21,4 @@ env:
policy: policy:
state_dim: 4 state_dim: 4
action_dim: 4 action_dim: 4


@ -37,10 +37,11 @@ def eval_policy(
tensordict = env.reset() tensordict = env.reset()
ep_frames = [] ep_frames = []
if save_video or (return_first_video and i == 0): if save_video or (return_first_video and i == 0):
def rendering_callback(env, td=None): def rendering_callback(env, td=None):
ep_frames.append(env.render()) ep_frames.append(env.render()) # noqa: B023
# render first frame before rollout # render first frame before rollout
rendering_callback(env) rendering_callback(env)


@ -6,8 +6,6 @@ import torch
from tensordict.nn import TensorDictModule from tensordict.nn import TensorDictModule
from termcolor import colored from termcolor import colored
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.datasets.openx import OpenXExperienceReplay
from torchrl.data.replay_buffers import PrioritizedSliceSampler from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.datasets.factory import make_offline_buffer
@ -27,9 +25,7 @@ def train_cli(cfg: dict):
) )
def train_notebook( def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
out_dir=None, job_name=None, config_name="default", config_path="../configs"
):
from hydra import compose, initialize from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear() hydra.core.global_hydra.GlobalHydra.instance().clear()
@ -38,7 +34,7 @@ def train_notebook(
train(cfg, out_dir=out_dir, job_name=job_name) train(cfg, out_dir=out_dir, job_name=job_name)
def log_training_metrics(L, metrics, step, online_episode_idx, start_time, is_offline): def log_training_metrics(logger, metrics, step, online_episode_idx, start_time, is_offline):
common_metrics = { common_metrics = {
"episode": online_episode_idx, "episode": online_episode_idx,
"step": step, "step": step,
@ -46,12 +42,10 @@ def log_training_metrics(L, metrics, step, online_episode_idx, start_time, is_of
"is_offline": float(is_offline), "is_offline": float(is_offline),
} }
metrics.update(common_metrics) metrics.update(common_metrics)
L.log(metrics, category="train") logger.log(metrics, category="train")
def eval_policy_and_log( def eval_policy_and_log(env, td_policy, step, online_episode_idx, start_time, cfg, logger, is_offline):
env, td_policy, step, online_episode_idx, start_time, cfg, L, is_offline
):
common_metrics = { common_metrics = {
"episode": online_episode_idx, "episode": online_episode_idx,
"step": step, "step": step,
@ -65,11 +59,11 @@ def eval_policy_and_log(
return_first_video=True, return_first_video=True,
) )
metrics.update(common_metrics) metrics.update(common_metrics)
L.log(metrics, category="eval") logger.log(metrics, category="eval")
if cfg.wandb.enable: if cfg.wandb.enable:
eval_video = L._wandb.Video(first_video, fps=cfg.fps, format="mp4") eval_video = logger._wandb.Video(first_video, fps=cfg.fps, format="mp4")
L._wandb.log({"eval_video": eval_video}, step=step) logger._wandb.log({"eval_video": eval_video}, step=step)
def train(cfg: dict, out_dir=None, job_name=None): def train(cfg: dict, out_dir=None, job_name=None):
@ -116,7 +110,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
sampler=online_sampler, sampler=online_sampler,
) )
L = Logger(out_dir, job_name, cfg) logger = Logger(out_dir, job_name, cfg)
online_episode_idx = 0 online_episode_idx = 0
start_time = time.time() start_time = time.time()
@ -129,9 +123,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
metrics = policy.update(offline_buffer, step) metrics = policy.update(offline_buffer, step)
if step % cfg.log_freq == 0: if step % cfg.log_freq == 0:
log_training_metrics( log_training_metrics(logger, metrics, step, online_episode_idx, start_time, is_offline=False)
L, metrics, step, online_episode_idx, start_time, is_offline=False
)
if step > 0 and step % cfg.eval_freq == 0: if step > 0 and step % cfg.eval_freq == 0:
eval_policy_and_log( eval_policy_and_log(
@ -141,13 +133,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
online_episode_idx, online_episode_idx,
start_time, start_time,
cfg, cfg,
L, logger,
is_offline=True, is_offline=True,
) )
if step > 0 and cfg.save_model and step % cfg.save_freq == 0: if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
print(f"Checkpoint model at step {step}") print(f"Checkpoint model at step {step}")
L.save_model(policy, identifier=step) logger.save_model(policy, identifier=step)
step += 1 step += 1
@ -164,9 +156,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
auto_cast_to_device=True, auto_cast_to_device=True,
) )
assert len(rollout) <= cfg.env.episode_length assert len(rollout) <= cfg.env.episode_length
rollout["episode"] = torch.tensor( rollout["episode"] = torch.tensor([online_episode_idx] * len(rollout), dtype=torch.int)
[online_episode_idx] * len(rollout), dtype=torch.int
)
online_buffer.extend(rollout) online_buffer.extend(rollout)
ep_sum_reward = rollout["next", "reward"].sum() ep_sum_reward = rollout["next", "reward"].sum()
@ -188,9 +178,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
) )
metrics.update(train_metrics) metrics.update(train_metrics)
if step % cfg.log_freq == 0: if step % cfg.log_freq == 0:
log_training_metrics( log_training_metrics(logger, metrics, step, online_episode_idx, start_time, is_offline=False)
L, metrics, step, online_episode_idx, start_time, is_offline=False
)
if step > 0 and step % cfg.eval_freq == 0: if step > 0 and step % cfg.eval_freq == 0:
eval_policy_and_log( eval_policy_and_log(
@ -200,13 +188,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
online_episode_idx, online_episode_idx,
start_time, start_time,
cfg, cfg,
L, logger,
is_offline=False, is_offline=False,
) )
if step > 0 and cfg.save_model and step % cfg.save_freq == 0: if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
print(f"Checkpoint model at step {step}") print(f"Checkpoint model at step {step}")
L.save_model(policy, identifier=step) logger.save_model(policy, identifier=step)
step += 1 step += 1


@ -1,24 +1,22 @@
import pickle
from pathlib import Path from pathlib import Path
import hydra import hydra
import imageio import imageio
import simxarm
import torch import torch
from torchrl.data.replay_buffers import ( from torchrl.data.replay_buffers import (
SamplerWithoutReplacement,
SliceSampler,
SliceSamplerWithoutReplacement, SliceSamplerWithoutReplacement,
) )
from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.datasets.factory import make_offline_buffer
NUM_EPISODES_TO_RENDER = 10
MAX_NUM_STEPS = 1000
FIRST_FRAME = 0
@hydra.main(version_base=None, config_name="default", config_path="../configs") @hydra.main(version_base=None, config_name="default", config_path="../configs")
def visualize_dataset_cli(cfg: dict): def visualize_dataset_cli(cfg: dict):
visualize_dataset( visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
)
def visualize_dataset(cfg: dict, out_dir=None): def visualize_dataset(cfg: dict, out_dir=None):
@ -33,9 +31,6 @@ def visualize_dataset(cfg: dict, out_dir=None):
offline_buffer = make_offline_buffer(cfg, sampler) offline_buffer = make_offline_buffer(cfg, sampler)
NUM_EPISODES_TO_RENDER = 10
MAX_NUM_STEPS = 1000
FIRST_FRAME = 0
for _ in range(NUM_EPISODES_TO_RENDER): for _ in range(NUM_EPISODES_TO_RENDER):
episode = offline_buffer.sample(MAX_NUM_STEPS) episode = offline_buffer.sample(MAX_NUM_STEPS)
@ -57,9 +52,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
assert ep_frames.max().item() > 1, "Not mendatory, but sanity check" assert ep_frames.max().item() > 1, "Not mendatory, but sanity check"
assert ep_frames.max().item() <= 255 assert ep_frames.max().item() <= 255
ep_frames = ep_frames.type(torch.uint8) ep_frames = ep_frames.type(torch.uint8)
imageio.mimsave( imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps)
video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps
)
# ran out of episodes # ran out of episodes
if offline_buffer._sampler._sample_list.numel() == 0: if offline_buffer._sampler._sample_list.numel() == 0:

poetry.lock (generated)

@ -192,6 +192,17 @@ files = [
[package.dependencies] [package.dependencies]
pycparser = "*" pycparser = "*"
[[package]]
name = "cfgv"
version = "3.4.0"
description = "Validate configuration and produce human readable error messages."
optional = false
python-versions = ">=3.8"
files = [
{file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"},
{file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"},
]
[[package]] [[package]]
name = "charset-normalizer" name = "charset-normalizer"
version = "3.3.2" version = "3.3.2"
@ -420,6 +431,17 @@ url = "https://github.com/real-stanford/diffusion_policy"
reference = "HEAD" reference = "HEAD"
resolved_reference = "548a52bbb105518058e27bf34dcf90bf6f73681a" resolved_reference = "548a52bbb105518058e27bf34dcf90bf6f73681a"
[[package]]
name = "distlib"
version = "0.3.8"
description = "Distribution utilities"
optional = false
python-versions = "*"
files = [
{file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"},
{file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"},
]
[[package]] [[package]]
name = "dm-env" name = "dm-env"
version = "1.6" version = "1.6"
@ -741,6 +763,20 @@ antlr4-python3-runtime = "==4.9.*"
omegaconf = ">=2.2,<2.4" omegaconf = ">=2.2,<2.4"
packaging = "*" packaging = "*"
[[package]]
name = "identify"
version = "2.5.35"
description = "File identification library for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "identify-2.5.35-py2.py3-none-any.whl", hash = "sha256:c4de0081837b211594f8e877a6b4fad7ca32bbfc1a9307fdd61c28bfe923f13e"},
{file = "identify-2.5.35.tar.gz", hash = "sha256:10a7ca245cfcd756a554a7288159f72ff105ad233c7c4b9c6f0f4d108f5f6791"},
]
[package.extras]
license = ["ukkonen"]
[[package]] [[package]]
name = "idna" name = "idna"
version = "3.6" version = "3.6"
@ -1069,6 +1105,20 @@ doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.
extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"]
test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"]
[[package]]
name = "nodeenv"
version = "1.8.0"
description = "Node.js virtual environment builder"
optional = false
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
files = [
{file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
{file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
]
[package.dependencies]
setuptools = "*"
[[package]] [[package]]
name = "numba" name = "numba"
version = "0.59.0" version = "0.59.0"
@ -1537,6 +1587,39 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa
typing = ["typing-extensions"] typing = ["typing-extensions"]
xmp = ["defusedxml"] xmp = ["defusedxml"]
[[package]]
name = "platformdirs"
version = "4.2.0"
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
optional = false
python-versions = ">=3.8"
files = [
{file = "platformdirs-4.2.0-py3-none-any.whl", hash = "sha256:0614df2a2f37e1a662acbd8e2b25b92ccf8632929bc6d43467e17fe89c75e068"},
{file = "platformdirs-4.2.0.tar.gz", hash = "sha256:ef0cc731df711022c174543cb70a9b5bd22e5a9337c8624ef2c2ceb8ddad8768"},
]
[package.extras]
docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"]
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"]
[[package]]
name = "pre-commit"
version = "3.6.2"
description = "A framework for managing and maintaining multi-language pre-commit hooks."
optional = false
python-versions = ">=3.9"
files = [
{file = "pre_commit-3.6.2-py2.py3-none-any.whl", hash = "sha256:ba637c2d7a670c10daedc059f5c49b5bd0aadbccfcd7ec15592cf9665117532c"},
{file = "pre_commit-3.6.2.tar.gz", hash = "sha256:c3ef34f463045c88658c5b99f38c1e297abdcc0ff13f98d3370055fbbfabc67e"},
]
[package.dependencies]
cfgv = ">=2.0.0"
identify = ">=1.0.0"
nodeenv = ">=0.11.1"
pyyaml = ">=5.1"
virtualenv = ">=20.10.0"
[[package]] [[package]]
name = "proglog" name = "proglog"
version = "0.1.10" version = "0.1.10"
@ -2462,6 +2545,26 @@ h2 = ["h2 (>=4,<5)"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"] zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "virtualenv"
version = "20.25.1"
description = "Virtual Python Environment builder"
optional = false
python-versions = ">=3.7"
files = [
{file = "virtualenv-20.25.1-py3-none-any.whl", hash = "sha256:961c026ac520bac5f69acb8ea063e8a4f071bcc9457b9c1f28f6b085c511583a"},
{file = "virtualenv-20.25.1.tar.gz", hash = "sha256:e08e13ecdca7a0bd53798f356d5831434afa5b07b93f0abdf0797b7a06ffe197"},
]
[package.dependencies]
distlib = ">=0.3.7,<1"
filelock = ">=3.12.2,<4"
platformdirs = ">=3.9.1,<5"
[package.extras]
docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"]
test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"]
[[package]] [[package]]
name = "wandb" name = "wandb"
version = "0.16.3" version = "0.16.3"
@ -2538,4 +2641,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "4c34065f18b708f6663ce5740011d2062b2995d1eaefbcb664572870827efd7c" content-hash = "7878b7e80b73355d98402655a8bf51bab122444555cbe8ae5d0f9f1e2effe4b2"


@ -48,8 +48,39 @@ opencv-python = "^4.9.0.80"
diffusion-policy = {git = "https://github.com/real-stanford/diffusion_policy"} diffusion-policy = {git = "https://github.com/real-stanford/diffusion_policy"}
[tool.poetry.group.dev.dependencies]
pre-commit = "^3.6.2"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.ruff]
line-length = 110
target-version = "py310"
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".mypy_cache",
".nox",
".pants.d",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
]
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]


@ -18,5 +18,5 @@ apptainer exec --nv \
source ~/.bashrc source ~/.bashrc
conda activate fowm conda activate fowm
srun $CMD srun $CMD