commit
cb7b375526
@@ -0,0 +1,33 @@
exclude: ^(data/|tests/|diffusion_policy/)
default_language_version:
  python: python3.10
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: check-added-large-files
      - id: debug-statements
      - id: check-merge-conflict
      - id: check-case-conflict
      - id: check-yaml
      - id: check-toml
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/asottile/pyupgrade
    rev: v3.15.1
    hooks:
      - id: pyupgrade
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.2.2
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
  - repo: https://github.com/python-poetry/poetry
    rev: 1.8.0
    hooks:
      - id: poetry-check
      - id: poetry-lock
        args:
          - "--check"
          - "--no-update"
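To exercise this hook configuration locally, the standard pre-commit workflow applies (a minimal sketch, assuming `pre-commit` is already installed in the active environment):
```
# register the hooks defined in .pre-commit-config.yaml with git
pre-commit install
# run every hook once against the whole repository
pre-commit run --all-files
```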
@@ -10,7 +10,7 @@ conda activate lerobot

[Install `poetry`](https://python-poetry.org/docs/#installation) (if you don't have it already)
```
curl -sSL https://install.python-poetry.org | python3 -
curl -sSL https://install.python-poetry.org | python -
```

Install dependencies
@@ -26,6 +26,7 @@ export TMPDIR='~/tmp'

Install `diffusion_policy` #HACK
```
# from this directory
git clone https://github.com/real-stanford/diffusion_policy
cp -r diffusion_policy/diffusion_policy $(poetry env info -p)/lib/python3.10/site-packages/
```
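A quick way to verify that the copied package is importable from the Poetry environment (a hypothetical sanity check, not part of the original instructions):
```
poetry run python -c "import diffusion_policy; print(diffusion_policy.__file__)"
```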
@@ -107,11 +108,10 @@ eval_episodes=7

**Style**
```
isort lerobot && isort tests && black lerobot && black tests
pylint lerobot && pylint tests # not enforce for now
pre-commit install
```

**Tests**
```
pytest -sx tests
```
```
@@ -70,6 +70,7 @@ def make_offline_buffer(cfg, sampler=None):
        offline_buffer = PushtExperienceReplay(
            "pusht",
            # download="force",
            # TODO(aliberts): automate download
            download=False,
            streaming=False,
            root="data",
@@ -1,7 +1,6 @@
import os
import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
from typing import Callable

import einops
import numpy as np

@@ -10,25 +9,25 @@ import pymunk
import torch
import torchrl
import tqdm
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
    TensorDictPrioritizedReplayBuffer,
    TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import (
    Sampler,
    SliceSampler,
    SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer

from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely

# as define in env
SUCCESS_THRESHOLD = 0.95  # 95% coverage,

DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()


def get_goal_pose_body(pose):
    mass = 1
@@ -53,7 +52,7 @@ add_tee(
    angle,
    scale=30,
    color="LightSlateGray",
    mask=pymunk.ShapeFilter.ALL_MASKS(),
    mask=DEFAULT_TEE_MASK,
):
    mass = 1
    length = 4

@@ -87,7 +86,6 @@ def add_tee(


class PushtExperienceReplay(TensorDictReplayBuffer):

    def __init__(
        self,
        dataset_id,

@@ -127,7 +125,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
        if split_trajs:
            raise NotImplementedError

        if self.download == True:
        if self.download:
            raise NotImplementedError()

        if root is None:
@@ -193,18 +191,18 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
        # TODO(rcadene)

        # load
        # TODO(aliberts): Dynamic paths
        zarr_path = (
            "/home/rcadene/code/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
            # "/home/simon/build/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
        )
        dataset_dict = ReplayBuffer.copy_from_path(
            zarr_path
        )  # , keys=['img', 'state', 'action'])
        dataset_dict = ReplayBuffer.copy_from_path(zarr_path)  # , keys=['img', 'state', 'action'])

        episode_ids = dataset_dict.get_episode_idxs()
        num_episodes = dataset_dict.meta["episode_ends"].shape[0]
        total_frames = dataset_dict["action"].shape[0]
        assert len(
            set([dataset_dict[key].shape[0] for key in dataset_dict.keys()])
            {dataset_dict[key].shape[0] for key in dataset_dict}
        ), "Some data type dont have the same number of total frames."

        # TODO: verify that goal pose is expected to be fixed
@@ -245,9 +243,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
            ]
            space.add(*walls)

            block_body = add_tee(
                space, block_pos[i].tolist(), block_angle[i].item()
            )
            block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item())
            goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
            block_geom = pymunk_to_shapely(block_body, block_body.shapes)
            intersection_area = goal_geom.intersection(block_geom).area
@@ -278,11 +274,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):

            if episode_id == 0:
                # hack to initialize tensordict data structure to store episodes
                td_data = (
                    episode[0]
                    .expand(total_frames)
                    .memmap_like(self.root / self.dataset_id)
                )
                td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)

            td_data[idxtd : idxtd + len(episode)] = episode
@@ -1,7 +1,7 @@
import os
import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
from typing import Callable

import torch
import torchrl

@@ -9,7 +9,6 @@ import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
    TensorDictPrioritizedReplayBuffer,
    TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import (

@@ -22,7 +21,6 @@ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer


class SimxarmExperienceReplay(TensorDictReplayBuffer):

    available_datasets = [
        "xarm_lift_medium",
    ]
@@ -77,15 +75,11 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):

        if num_slices is not None or slice_len is not None:
            if sampler is not None:
                raise ValueError(
                    "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
                )
                raise ValueError("`num_slices` and `slice_len` are exclusive with the `sampler` argument.")

            if replacement:
                if not self.shuffle:
                    raise RuntimeError(
                        "shuffle=False can only be used when replacement=False."
                    )
                    raise RuntimeError("shuffle=False can only be used when replacement=False.")
                sampler = SliceSampler(
                    num_slices=num_slices,
                    slice_len=slice_len,

@@ -130,7 +124,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):

        # load
        dataset_dir = Path("data") / self.dataset_id
        dataset_path = dataset_dir / f"buffer.pkl"
        dataset_path = dataset_dir / "buffer.pkl"
        print(f"Using offline dataset '{dataset_path}'")
        with open(dataset_path, "rb") as f:
            dataset_dict = pickle.load(f)
@@ -150,12 +144,8 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):

            image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
            state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
            next_image = torch.tensor(
                dataset_dict["next_observations"]["rgb"][idx0:idx1]
            )
            next_state = torch.tensor(
                dataset_dict["next_observations"]["state"][idx0:idx1]
            )
            next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
            next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
            next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
            next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])

@@ -176,11 +166,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):

            if episode_id == 0:
                # hack to initialize tensordict data structure to store episodes
                td_data = (
                    episode[0]
                    .expand(total_frames)
                    .memmap_like(self.root / self.dataset_id)
                )
                td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)

            td_data[idx0:idx1] = episode
@@ -1,7 +1,6 @@
import importlib
from typing import Optional

import numpy as np
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (

@@ -20,7 +19,6 @@ _has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _


class PushtEnv(EnvBase):

    def __init__(
        self,
        frame_skip: int = 1,

@@ -46,7 +44,8 @@ class PushtEnv(EnvBase):
        if not _has_gym:
            raise ImportError("Cannot import gym.")

        from diffusion_policy.env.pusht.pusht_env import PushTEnv
        # TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
        # from diffusion_policy.env.pusht.pusht_env import PushTEnv

        if not from_pixels:
            raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
@@ -71,14 +70,10 @@ class PushtEnv(EnvBase):
        obs = {"image": torch.from_numpy(raw_obs["image"])}

        if not self.pixels_only:
            obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(
                torch.float32
            )
            obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)
        else:
            # TODO:
            obs = {
                "state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)
            }
            obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}

        obs = TensorDict(obs, batch_size=[])
        return obs

@@ -109,7 +104,7 @@ class PushtEnv(EnvBase):
        # step expects shape=(4,) so we pad if necessary
        # TODO(rcadene): add info["is_success"] and info["success"] ?
        sum_reward = 0
        for t in range(self.frame_skip):
        for _ in range(self.frame_skip):
            raw_obs, reward, done, info = self._env.step(action)
            sum_reward += reward
@@ -15,12 +15,13 @@ from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform

from lerobot.common.utils import set_seed

MAX_NUM_ACTIONS = 4

_has_gym = importlib.util.find_spec("gym") is not None
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym


class SimxarmEnv(EnvBase):

    def __init__(
        self,
        task,

@@ -52,18 +53,13 @@ class SimxarmEnv(EnvBase):
        from simxarm import TASKS

        if self.task not in TASKS:
            raise ValueError(
                f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}"
            )
            raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")

        self._env = TASKS[self.task]["env"]()

        MAX_NUM_ACTIONS = 4
        num_actions = len(TASKS[self.task]["action_space"])
        self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
        self._action_padding = np.zeros(
            (MAX_NUM_ACTIONS - num_actions), dtype=np.float32
        )
        self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
        if "w" not in TASKS[self.task]["action_space"]:
            self._action_padding[-1] = 1.0
@@ -75,9 +71,7 @@ class SimxarmEnv(EnvBase):

    def _format_raw_obs(self, raw_obs):
        if self.from_pixels:
            image = self.render(
                mode="rgb_array", width=self.image_size, height=self.image_size
            )
            image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
            image = image.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
            image = torch.tensor(image.copy(), dtype=torch.uint8)

@@ -114,7 +108,7 @@ class SimxarmEnv(EnvBase):
        action = np.concatenate([action, self._action_padding])
        # TODO(rcadene): add info["is_success"] and info["success"] ?
        sum_reward = 0
        for t in range(self.frame_skip):
        for _ in range(self.frame_skip):
            raw_obs, reward, done, info = self._env.step(action)
            sum_reward += reward
@@ -5,7 +5,6 @@ from torchrl.envs.transforms import ObservationTransform


class Prod(ObservationTransform):

    def __init__(self, in_keys: Sequence[NestedKey], prod: float):
        super().__init__()
        self.in_keys = in_keys
@@ -1,6 +1,6 @@
import contextlib
import datetime
import os
import re
from pathlib import Path

import numpy as np

@@ -29,10 +29,8 @@ AGENT_METRICS = [

def make_dir(dir_path):
    """Create directory if it does not already exist."""
    try:
    with contextlib.suppress(OSError):
        dir_path.mkdir(parents=True, exist_ok=True)
    except OSError:
        pass
    return dir_path
@@ -59,9 +57,7 @@ def print_run(cfg, reward=None):
        # ('experiment', cfg.exp_name),
    ]
    if reward is not None:
        kvs.append(
            ("episode reward", colored(str(int(reward)), "white", attrs=["bold"]))
        )
        kvs.append(("episode reward", colored(str(int(reward)), "white", attrs=["bold"])))
    w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21
    div = "-" * w
    print(div)

@@ -80,7 +76,7 @@ def cfg_to_group(cfg, return_list=False):
    return lst if return_list else "-".join(lst)


class Logger(object):
class Logger:
    """Primary logger object. Logs either locally or using wandb."""

    def __init__(self, log_dir, job_name, cfg):

@@ -183,7 +179,5 @@ class Logger(object):
        if category == "eval":
            keys = ["step", "avg_sum_reward", "avg_max_reward", "pc_success"]
            self._eval.append(np.array([d[key] for key in keys]))
            pd.DataFrame(np.array(self._eval)).to_csv(
                self._log_dir / "eval.log", header=keys, index=None
            )
            pd.DataFrame(np.array(self._eval)).to_csv(self._log_dir / "eval.log", header=keys, index=None)
        self._print(d, category)
@@ -3,16 +3,17 @@ import copy
import hydra
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

from diffusion_policy.model.common.lr_scheduler import get_scheduler
from diffusion_policy.model.vision.model_getter import get_resnet
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy

FIRST_ACTION = 0


class DiffusionPolicy(nn.Module):

    def __init__(
        self,
        cfg,

@@ -105,7 +106,6 @@ class DiffusionPolicy(nn.Module):
        out = self.diffusion.predict_action(obs_dict)

        # TODO(rcadene): add possibility to return >1 timestemps
        FIRST_ACTION = 0
        action = out["action"].squeeze(0)[FIRST_ACTION]
        return action

@@ -132,10 +132,7 @@ class DiffusionPolicy(nn.Module):
            }
            return out

        if self.cfg.balanced_sampling:
            batch = replay_buffer.sample(batch_size)
        else:
            batch = replay_buffer.sample()
        batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
        batch = process_batch(batch, self.cfg.horizon, num_slices)

        loss = self.diffusion.compute_loss(batch)
@@ -1,3 +1,5 @@
# ruff: noqa: N806

from copy import deepcopy

import einops

@@ -7,6 +9,8 @@ import torch.nn as nn

import lerobot.common.policies.tdmpc_helper as h

FIRST_FRAME = 0


class TOLD(nn.Module):
    """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""

@@ -17,9 +21,7 @@ class TOLD(nn.Module):

        self.cfg = cfg
        self._encoder = h.enc(cfg)
        self._dynamics = h.dynamics(
            cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim
        )
        self._dynamics = h.dynamics(cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim)
        self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1)
        self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim)
        self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)])
@@ -65,17 +67,17 @@ class TOLD(nn.Module):
            return h.TruncatedNormal(mu, std).sample(clip=0.3)
        return mu

    def V(self, z):
    def V(self, z):  # noqa: N802
        """Predict state value (V)."""
        return self._V(z)

    def Q(self, z, a, return_type):
    def Q(self, z, a, return_type):  # noqa: N802
        """Predict state-action value (Q)."""
        assert return_type in {"min", "avg", "all"}
        x = torch.cat([z, a], dim=-1)

        if return_type == "all":
            return torch.stack(list(q(x) for q in self._Qs), dim=0)
            return torch.stack([q(x) for q in self._Qs], dim=0)

        idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
        Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)

@@ -160,11 +162,7 @@ class TDMPC(nn.Module):
            pi = self.model.pi(z, self.cfg.min_std)
            G += discount * self.model.Q(z, pi, return_type="min")
            if self.cfg.uncertainty_cost > 0:
                G -= (
                    discount
                    * self.cfg.uncertainty_cost
                    * self.model.Q(z, pi, return_type="all").std(dim=0)
                )
                G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0)
        return G

    @torch.no_grad()
@@ -180,19 +178,13 @@ class TDMPC(nn.Module):
        assert step is not None
        # Seed steps
        if step < self.cfg.seed_steps and self.model.training:
            return torch.empty(
                self.action_dim, dtype=torch.float32, device=self.device
            ).uniform_(-1, 1)
            return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1)

        # Sample policy trajectories
        horizon = int(
            min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))
        )
        horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
        num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
        if num_pi_trajs > 0:
            pi_actions = torch.empty(
                horizon, num_pi_trajs, self.action_dim, device=self.device
            )
            pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device)
            _z = z.repeat(num_pi_trajs, 1)
            for t in range(horizon):
                pi_actions[t] = self.model.pi(_z, self.cfg.min_std)

@@ -201,20 +193,16 @@ class TDMPC(nn.Module):
        # Initialize state and parameters
        z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1)
        mean = torch.zeros(horizon, self.action_dim, device=self.device)
        std = self.cfg.max_std * torch.ones(
            horizon, self.action_dim, device=self.device
        )
        std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device)
        if not t0 and hasattr(self, "_prev_mean"):
            mean[:-1] = self._prev_mean[1:]

        # Iterate CEM
        for i in range(self.cfg.iterations):
        for _ in range(self.cfg.iterations):
            actions = torch.clamp(
                mean.unsqueeze(1)
                + std.unsqueeze(1)
                * torch.randn(
                    horizon, self.cfg.num_samples, self.action_dim, device=std.device
                ),
                * torch.randn(horizon, self.cfg.num_samples, self.action_dim, device=std.device),
                -1,
                1,
            )

@@ -223,18 +211,14 @@ class TDMPC(nn.Module):

            # Compute elite actions
            value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
            elite_idxs = torch.topk(
                value.squeeze(1), self.cfg.num_elites, dim=0
            ).indices
            elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
            elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]

            # Update parameters
            max_value = elite_value.max(0)[0]
            score = torch.exp(self.cfg.temperature * (elite_value - max_value))
            score /= score.sum(0)
            _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (
                score.sum(0) + 1e-9
            )
            _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
            _std = torch.sqrt(
                torch.sum(
                    score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2,
@@ -331,7 +315,6 @@ class TDMPC(nn.Module):
            batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
            batch = batch.to(self.device)

            FIRST_FRAME = 0
            obs = {
                "rgb": batch["observation", "image"][FIRST_FRAME].float(),
                "state": batch["observation", "state"][FIRST_FRAME],

@@ -359,10 +342,7 @@ class TDMPC(nn.Module):
            weights = batch["_weight"][FIRST_FRAME, :, None]
            return obs, action, next_obses, reward, mask, done, idxs, weights

        if self.cfg.balanced_sampling:
            batch = replay_buffer.sample(batch_size)
        else:
            batch = replay_buffer.sample()
        batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()

        obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
            batch, self.cfg.horizon, num_slices

@@ -384,10 +364,7 @@ class TDMPC(nn.Module):

            if isinstance(obs, dict):
                obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
                next_obses = {
                    k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1)
                    for k in next_obses
                }
                next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
            else:
                obs = torch.cat([obs, demo_obs])
                next_obses = torch.cat([next_obses, demo_next_obses], dim=1)

@@ -429,9 +406,7 @@ class TDMPC(nn.Module):
            td_targets = self._td_target(next_z, reward, mask)

        # Latent rollout
        zs = torch.empty(
            horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device
        )
        zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
        reward_preds = torch.empty_like(reward, device=self.device)
        assert reward.shape[0] == horizon
        z = self.model.encode(obs)

@@ -452,12 +427,10 @@ class TDMPC(nn.Module):
            value_info["V"] = v.mean().item()

        # Losses
        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(
            -1, 1, 1
        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
            dim=0
        )
        consistency_loss = (
            rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask
        ).sum(dim=0)
        reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
        q_value_loss, priority_loss = 0, 0
        for q in range(self.cfg.num_q):

@@ -465,9 +438,7 @@ class TDMPC(nn.Module):
            priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)

        expectile = h.linear_schedule(self.cfg.expectile, step)
        v_value_loss = (
            rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask
        ).sum(dim=0)
        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)

        total_loss = (
            self.cfg.consistency_coef * consistency_loss
@@ -5,11 +5,15 @@ import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.functional as F  # noqa: N812
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal

__REDUCE__ = lambda b: "mean" if b else "none"
DEFAULT_ACT_FN = nn.Mish()


def __REDUCE__(b):  # noqa: N802, N807
    return "mean" if b else "none"


def l1(pred, target, reduce=False):

@@ -36,11 +40,7 @@ def l2_expectile(diff, expectile=0.7, reduce=False):
def _get_out_shape(in_shape, layers):
    """Utility function. Returns the output shape of a network for a given input shape."""
    x = torch.randn(*in_shape).unsqueeze(0)
    return (
        (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x)
        .squeeze(0)
        .shape
    )
    return (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x).squeeze(0).shape


def gaussian_logprob(eps, log_std):
@@ -73,7 +73,7 @@ def orthogonal_init(m):
def ema(m, m_target, tau):
    """Update slow-moving average of online network (target network) at rate tau."""
    with torch.no_grad():
        for p, p_target in zip(m.parameters(), m_target.parameters()):
        for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False):
            p_target.data.lerp_(p.data, tau)


@@ -86,6 +86,8 @@ def set_requires_grad(net, value):
class TruncatedNormal(pyd.Normal):
    """Utility class implementing the truncated normal distribution."""

    default_sample_shape = torch.Size()

    def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
        super().__init__(loc, scale, validate_args=False)
        self.low = low

@@ -97,7 +99,7 @@ class TruncatedNormal(pyd.Normal):
        x = x - x.detach() + clamped_x.detach()
        return x

    def sample(self, clip=None, sample_shape=torch.Size()):
    def sample(self, clip=None, sample_shape=default_sample_shape):
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        eps *= self.scale

@@ -136,7 +138,7 @@ def enc(cfg):
    """Returns a TOLD encoder."""
    pixels_enc_layers, state_enc_layers = None, None
    if cfg.modality in {"pixels", "all"}:
        C = int(3 * cfg.frame_stack)
        C = int(3 * cfg.frame_stack)  # noqa: N806
        pixels_enc_layers = [
            NormalizeImg(),
            nn.Conv2d(C, cfg.num_channels, 7, stride=2),

@@ -184,7 +186,7 @@ def enc(cfg):
    return Multiplexer(nn.ModuleDict(encoders))


def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
    """Returns an MLP."""
    if isinstance(mlp_dim, int):
        mlp_dim = [mlp_dim, mlp_dim]

@@ -199,7 +201,7 @@ def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
    )


def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
def dynamics(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
    """Returns a dynamics network."""
    return nn.Sequential(
        mlp(in_dim, mlp_dim, out_dim, act_fn),
@@ -327,7 +329,7 @@ class RandomShiftsAug(nn.Module):
        return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)


class Episode(object):
class Episode:
    """Storage object for a single episode."""

    def __init__(self, cfg, init_obs):

@@ -354,18 +356,10 @@ class Episode(object):
                self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
        else:
            raise ValueError
        self.actions = torch.empty(
            (cfg.episode_length, action_dim), dtype=torch.float32, device=self.device
        )
        self.rewards = torch.empty(
            (cfg.episode_length,), dtype=torch.float32, device=self.device
        )
        self.dones = torch.empty(
            (cfg.episode_length,), dtype=torch.bool, device=self.device
        )
        self.masks = torch.empty(
            (cfg.episode_length,), dtype=torch.float32, device=self.device
        )
        self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device)
        self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
        self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device)
        self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
        self.cumulative_reward = 0
        self.done = False
        self.success = False

@@ -380,23 +374,17 @@ class Episode(object):

        if cfg.modality in {"pixels", "state"}:
            episode = cls(cfg, obses[0])
            episode.obses[1:] = torch.tensor(
                obses[1:], dtype=episode.obses.dtype, device=episode.device
            )
            episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
        elif cfg.modality == "all":
            episode = cls(cfg, {k: v[0] for k, v in obses.items()})
            for k, v in obses.items():
            for k in obses:
                episode.obses[k][1:] = torch.tensor(
                    obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
                )
        else:
            raise NotImplementedError
        episode.actions = torch.tensor(
            actions, dtype=episode.actions.dtype, device=episode.device
        )
        episode.rewards = torch.tensor(
            rewards, dtype=episode.rewards.dtype, device=episode.device
        )
        episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
        episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
        episode.dones = (
            torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
            if dones is not None

@@ -428,9 +416,7 @@ class Episode(object):
                    v, dtype=self.obses[k].dtype, device=self.obses[k].device
                )
        else:
            self.obses[self._idx + 1] = torch.tensor(
                obs, dtype=self.obses.dtype, device=self.obses.device
            )
            self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device)
        self.actions[self._idx] = action
        self.rewards[self._idx] = reward
        self.dones[self._idx] = done
@@ -453,7 +439,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False):
    ]

    if cfg.task.startswith("xarm"):
        dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl")
        dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
        print(f"Using offline dataset '{dataset_path}'")
        with open(dataset_path, "rb") as f:
            dataset_dict = pickle.load(f)

@@ -461,7 +447,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False):
            if k not in dataset_dict and k[:-1] in dataset_dict:
                dataset_dict[k] = dataset_dict.pop(k[:-1])
    elif cfg.task.startswith("legged"):
        dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl")
        dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
        print(f"Using offline dataset '{dataset_path}'")
        with open(dataset_path, "rb") as f:
            dataset_dict = pickle.load(f)

@@ -475,10 +461,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False):

        for i in range(len(dones) - 1):
            if (
                np.linalg.norm(
                    dataset_dict["observations"][i + 1]
                    - dataset_dict["next_observations"][i]
                )
                np.linalg.norm(dataset_dict["observations"][i + 1] - dataset_dict["next_observations"][i])
                > 1e-6
                or dataset_dict["terminals"][i] == 1.0
            ):

@@ -501,7 +484,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False):
        dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"])

    for key in required_keys:
        assert key in dataset_dict.keys(), f"Missing `{key}` in dataset."
        assert key in dataset_dict, f"Missing `{key}` in dataset."

    if return_reward_normalizer:
        return dataset_dict, reward_normalizer

@@ -553,9 +536,7 @@ def get_reward_normalizer(cfg, dataset):
        return lambda x: x - 1.0
    elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]:
        (_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
        return (
            lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0
        )
        return lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0
    elif hasattr(cfg, "reward_scale"):
        return lambda x: x * cfg.reward_scale
    return lambda x: x

@@ -571,12 +552,12 @@ def linear_schedule(schdl, step):
    except ValueError:
        match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl)
        if match:
            init, final, start, end = [float(g) for g in match.groups()]
            init, final, start, end = (float(g) for g in match.groups())
            mix = np.clip((step - start) / (end - start), 0.0, 1.0)
            return (1.0 - mix) * init + mix * final
        match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
        if match:
            init, final, duration = [float(g) for g in match.groups()]
            init, final, duration = (float(g) for g in match.groups())
            mix = np.clip(step / duration, 0.0, 1.0)
            return (1.0 - mix) * init + mix * final
        raise NotImplementedError(schdl)
@@ -22,4 +22,4 @@ env:

policy:
  state_dim: 2
  action_dim: 2
  action_dim: 2
@@ -21,4 +21,4 @@ env:

policy:
  state_dim: 4
  action_dim: 4
  action_dim: 4
@@ -37,10 +37,11 @@ def eval_policy(
        tensordict = env.reset()

        ep_frames = []

        if save_video or (return_first_video and i == 0):

            def rendering_callback(env, td=None):
                ep_frames.append(env.render())
                ep_frames.append(env.render())  # noqa: B023

            # render first frame before rollout
            rendering_callback(env)
@@ -6,8 +6,6 @@ import torch
from tensordict.nn import TensorDictModule
from termcolor import colored
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.datasets.openx import OpenXExperienceReplay
from torchrl.data.replay_buffers import PrioritizedSliceSampler

from lerobot.common.datasets.factory import make_offline_buffer

@@ -27,9 +25,7 @@ def train_cli(cfg: dict):
    )


def train_notebook(
    out_dir=None, job_name=None, config_name="default", config_path="../configs"
):
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
    from hydra import compose, initialize

    hydra.core.global_hydra.GlobalHydra.instance().clear()

@@ -38,7 +34,7 @@ def train_notebook(
    train(cfg, out_dir=out_dir, job_name=job_name)


def log_training_metrics(L, metrics, step, online_episode_idx, start_time, is_offline):
def log_training_metrics(logger, metrics, step, online_episode_idx, start_time, is_offline):
    common_metrics = {
        "episode": online_episode_idx,
        "step": step,
@@ -46,12 +42,10 @@ def log_training_metrics(L, metrics, step, online_episode_idx, start_time, is_of
        "is_offline": float(is_offline),
    }
    metrics.update(common_metrics)
    L.log(metrics, category="train")
    logger.log(metrics, category="train")


def eval_policy_and_log(
    env, td_policy, step, online_episode_idx, start_time, cfg, L, is_offline
):
def eval_policy_and_log(env, td_policy, step, online_episode_idx, start_time, cfg, logger, is_offline):
    common_metrics = {
        "episode": online_episode_idx,
        "step": step,

@@ -65,11 +59,11 @@ def eval_policy_and_log(
        return_first_video=True,
    )
    metrics.update(common_metrics)
    L.log(metrics, category="eval")
    logger.log(metrics, category="eval")

    if cfg.wandb.enable:
        eval_video = L._wandb.Video(first_video, fps=cfg.fps, format="mp4")
        L._wandb.log({"eval_video": eval_video}, step=step)
        eval_video = logger._wandb.Video(first_video, fps=cfg.fps, format="mp4")
        logger._wandb.log({"eval_video": eval_video}, step=step)


def train(cfg: dict, out_dir=None, job_name=None):

@@ -116,7 +110,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
        sampler=online_sampler,
    )

    L = Logger(out_dir, job_name, cfg)
    logger = Logger(out_dir, job_name, cfg)

    online_episode_idx = 0
    start_time = time.time()
@@ -129,9 +123,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
        metrics = policy.update(offline_buffer, step)

        if step % cfg.log_freq == 0:
            log_training_metrics(
                L, metrics, step, online_episode_idx, start_time, is_offline=False
            )
            log_training_metrics(logger, metrics, step, online_episode_idx, start_time, is_offline=False)

        if step > 0 and step % cfg.eval_freq == 0:
            eval_policy_and_log(

@@ -141,13 +133,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
                online_episode_idx,
                start_time,
                cfg,
                L,
                logger,
                is_offline=True,
            )

        if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
            print(f"Checkpoint model at step {step}")
            L.save_model(policy, identifier=step)
            logger.save_model(policy, identifier=step)

        step += 1

@@ -164,9 +156,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
            auto_cast_to_device=True,
        )
        assert len(rollout) <= cfg.env.episode_length
        rollout["episode"] = torch.tensor(
            [online_episode_idx] * len(rollout), dtype=torch.int
        )
        rollout["episode"] = torch.tensor([online_episode_idx] * len(rollout), dtype=torch.int)
        online_buffer.extend(rollout)

        ep_sum_reward = rollout["next", "reward"].sum()

@@ -188,9 +178,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
            )
            metrics.update(train_metrics)
            if step % cfg.log_freq == 0:
                log_training_metrics(
                    L, metrics, step, online_episode_idx, start_time, is_offline=False
                )
                log_training_metrics(logger, metrics, step, online_episode_idx, start_time, is_offline=False)

            if step > 0 and step % cfg.eval_freq == 0:
                eval_policy_and_log(

@@ -200,13 +188,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
                    online_episode_idx,
                    start_time,
                    cfg,
                    L,
                    logger,
                    is_offline=False,
                )

            if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
                print(f"Checkpoint model at step {step}")
                L.save_model(policy, identifier=step)
                logger.save_model(policy, identifier=step)

            step += 1
@@ -1,24 +1,22 @@
import pickle
from pathlib import Path

import hydra
import imageio
import simxarm
import torch
from torchrl.data.replay_buffers import (
    SamplerWithoutReplacement,
    SliceSampler,
    SliceSamplerWithoutReplacement,
)

from lerobot.common.datasets.factory import make_offline_buffer

NUM_EPISODES_TO_RENDER = 10
MAX_NUM_STEPS = 1000
FIRST_FRAME = 0


@hydra.main(version_base=None, config_name="default", config_path="../configs")
def visualize_dataset_cli(cfg: dict):
    visualize_dataset(
        cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    )
    visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)


def visualize_dataset(cfg: dict, out_dir=None):

@@ -33,9 +31,6 @@ def visualize_dataset(cfg: dict, out_dir=None):

    offline_buffer = make_offline_buffer(cfg, sampler)

    NUM_EPISODES_TO_RENDER = 10
    MAX_NUM_STEPS = 1000
    FIRST_FRAME = 0
    for _ in range(NUM_EPISODES_TO_RENDER):
        episode = offline_buffer.sample(MAX_NUM_STEPS)

@@ -57,9 +52,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
        assert ep_frames.max().item() > 1, "Not mendatory, but sanity check"
        assert ep_frames.max().item() <= 255
        ep_frames = ep_frames.type(torch.uint8)
        imageio.mimsave(
            video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps
        )
        imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps)

        # ran out of episodes
        if offline_buffer._sampler._sample_list.numel() == 0:
@@ -192,6 +192,17 @@ files = [
[package.dependencies]
pycparser = "*"

[[package]]
name = "cfgv"
version = "3.4.0"
description = "Validate configuration and produce human readable error messages."
optional = false
python-versions = ">=3.8"
files = [
    {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"},
    {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"},
]

[[package]]
name = "charset-normalizer"
version = "3.3.2"

@@ -420,6 +431,17 @@ url = "https://github.com/real-stanford/diffusion_policy"
reference = "HEAD"
resolved_reference = "548a52bbb105518058e27bf34dcf90bf6f73681a"

[[package]]
name = "distlib"
version = "0.3.8"
description = "Distribution utilities"
optional = false
python-versions = "*"
files = [
    {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"},
    {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"},
]

[[package]]
name = "dm-env"
version = "1.6"
@@ -741,6 +763,20 @@ antlr4-python3-runtime = "==4.9.*"
omegaconf = ">=2.2,<2.4"
packaging = "*"

[[package]]
name = "identify"
version = "2.5.35"
description = "File identification library for Python"
optional = false
python-versions = ">=3.8"
files = [
    {file = "identify-2.5.35-py2.py3-none-any.whl", hash = "sha256:c4de0081837b211594f8e877a6b4fad7ca32bbfc1a9307fdd61c28bfe923f13e"},
    {file = "identify-2.5.35.tar.gz", hash = "sha256:10a7ca245cfcd756a554a7288159f72ff105ad233c7c4b9c6f0f4d108f5f6791"},
]

[package.extras]
license = ["ukkonen"]

[[package]]
name = "idna"
version = "3.6"

@@ -1069,6 +1105,20 @@ doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.
extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"]
test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"]

[[package]]
name = "nodeenv"
version = "1.8.0"
description = "Node.js virtual environment builder"
optional = false
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
files = [
    {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
    {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
]

[package.dependencies]
setuptools = "*"

[[package]]
name = "numba"
version = "0.59.0"
@@ -1537,6 +1587,39 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa
typing = ["typing-extensions"]
xmp = ["defusedxml"]

[[package]]
name = "platformdirs"
version = "4.2.0"
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
optional = false
python-versions = ">=3.8"
files = [
    {file = "platformdirs-4.2.0-py3-none-any.whl", hash = "sha256:0614df2a2f37e1a662acbd8e2b25b92ccf8632929bc6d43467e17fe89c75e068"},
    {file = "platformdirs-4.2.0.tar.gz", hash = "sha256:ef0cc731df711022c174543cb70a9b5bd22e5a9337c8624ef2c2ceb8ddad8768"},
]

[package.extras]
docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"]
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"]

[[package]]
name = "pre-commit"
version = "3.6.2"
description = "A framework for managing and maintaining multi-language pre-commit hooks."
optional = false
python-versions = ">=3.9"
files = [
    {file = "pre_commit-3.6.2-py2.py3-none-any.whl", hash = "sha256:ba637c2d7a670c10daedc059f5c49b5bd0aadbccfcd7ec15592cf9665117532c"},
    {file = "pre_commit-3.6.2.tar.gz", hash = "sha256:c3ef34f463045c88658c5b99f38c1e297abdcc0ff13f98d3370055fbbfabc67e"},
]

[package.dependencies]
cfgv = ">=2.0.0"
identify = ">=1.0.0"
nodeenv = ">=0.11.1"
pyyaml = ">=5.1"
virtualenv = ">=20.10.0"

[[package]]
name = "proglog"
version = "0.1.10"
@@ -2462,6 +2545,26 @@ h2 = ["h2 (>=4,<5)"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]

[[package]]
name = "virtualenv"
version = "20.25.1"
description = "Virtual Python Environment builder"
optional = false
python-versions = ">=3.7"
files = [
    {file = "virtualenv-20.25.1-py3-none-any.whl", hash = "sha256:961c026ac520bac5f69acb8ea063e8a4f071bcc9457b9c1f28f6b085c511583a"},
    {file = "virtualenv-20.25.1.tar.gz", hash = "sha256:e08e13ecdca7a0bd53798f356d5831434afa5b07b93f0abdf0797b7a06ffe197"},
]

[package.dependencies]
distlib = ">=0.3.7,<1"
filelock = ">=3.12.2,<4"
platformdirs = ">=3.9.1,<5"

[package.extras]
docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"]
test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"]

[[package]]
name = "wandb"
version = "0.16.3"

@@ -2538,4 +2641,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "4c34065f18b708f6663ce5740011d2062b2995d1eaefbcb664572870827efd7c"
content-hash = "7878b7e80b73355d98402655a8bf51bab122444555cbe8ae5d0f9f1e2effe4b2"
@@ -48,8 +48,39 @@ opencv-python = "^4.9.0.80"
diffusion-policy = {git = "https://github.com/real-stanford/diffusion_policy"}


[tool.poetry.group.dev.dependencies]
pre-commit = "^3.6.2"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.ruff]
line-length = 110
target-version = "py310"
exclude = [
    ".bzr",
    ".direnv",
    ".eggs",
    ".git",
    ".git-rewrite",
    ".hg",
    ".mypy_cache",
    ".nox",
    ".pants.d",
    ".pytype",
    ".ruff_cache",
    ".svn",
    ".tox",
    ".venv",
    "__pypackages__",
    "_build",
    "buck-out",
    "build",
    "dist",
    "node_modules",
    "venv",
]

[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
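With this configuration in place, the linter and formatter can also be invoked directly, mirroring the two ruff hooks configured earlier (a minimal sketch, assuming `ruff` 0.2.x is installed in the environment):
```
# lint with autofixes, using the line-length and rule selection from [tool.ruff]
ruff check --fix .
# apply ruff's formatter to the whole repository
ruff format .
```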