From e69cd99f33f0e99a0a5d50e927cae14e1ab6cd64 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Tue, 23 Apr 2024 16:22:49 +0100
Subject: [PATCH] backup wip

---
 lerobot/common/datasets/factory.py | 21 +++++++++++----------
 lerobot/common/transforms.py       | 10 ++++++++++
 lerobot/configs/policy/tdmpc.yaml  |  9 ++++-----
 lerobot/scripts/eval.py            |  6 +++++-
 lerobot/scripts/train.py           | 20 ++++++++++----------
 5 files changed, 40 insertions(+), 26 deletions(-)

diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 0fbfff65..53a0660f 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -4,7 +4,7 @@ from pathlib import Path
 import torch
 from torchvision.transforms import v2
 
-from lerobot.common.transforms import NormalizeTransform
+from lerobot.common.transforms import IdentityTransform
 
 DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
 
@@ -38,7 +38,7 @@ def make_dataset(
     # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
     # min_max_from_spec
     # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
-    normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
+    # normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
 
     if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
         stats = {}
@@ -62,14 +62,15 @@ def make_dataset(
 
     transforms = v2.Compose(
         [
-            NormalizeTransform(
-                stats,
-                in_keys=[
-                    "observation.state",
-                    "action",
-                ],
-                mode=normalization_mode,
-            ),
+            # TODO(now): Use the transform
+            # NormalizeTransform(
+            #     in_keys=[
+            #         "observation.state",
+            #         "action",
+            #     ],
+            #     mode=normalization_mode,
+            # ),
+            IdentityTransform()
         ]
     )
 
diff --git a/lerobot/common/transforms.py b/lerobot/common/transforms.py
index fffa835a..f4a90f9e 100644
--- a/lerobot/common/transforms.py
+++ b/lerobot/common/transforms.py
@@ -63,3 +63,13 @@ class NormalizeTransform(Transform):
             item[outkey] = (item[inkey] + 1) / 2
             item[outkey] = item[outkey] * (max - min) + min
         return item
+
+
+class IdentityTransform(Transform):
+    invertible = True
+
+    def forward(self, item):
+        return item
+
+    def inverse_transform(self, item):
+        return item
diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml
index 4fd2b6bb..cf8bf21d 100644
--- a/lerobot/configs/policy/tdmpc.yaml
+++ b/lerobot/configs/policy/tdmpc.yaml
@@ -46,7 +46,6 @@ policy:
   kappa: 0.1
   lr: 3e-4
   std_schedule: ${policy.min_std}
-  horizon_schedule: ${policy.horizon}
   per: true
   per_alpha: 0.6
   per_beta: 0.4
@@ -79,7 +78,7 @@ policy:
   latent_dim: 50
 
 delta_timestamps:
-  observation.image: "[i / ${fps} for i in range(6)]"
-  observation.state: "[i / ${fps} for i in range(6)]"
-  action: "[i / ${fps} for i in range(5)]"
-  next.reward: "[i / ${fps} for i in range(5)]"
+  observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
+  observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
+  action: "[i / ${fps} for i in range(${policy.horizon})]"
+  next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 87cb58f7..cc0c8b74 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -157,7 +157,11 @@ def eval_policy(
 
         # get the next action for the environment
         with torch.inference_mode():
-            action = policy.select_action(observation, step=step)
+            # TODO(now): restore
+            observation["observation.image"] *= 255
+            # TODO(now): train_step
+            action = policy.select_action(observation)
+            observation["observation.image"] /= 255
 
         # Send action:
         while True:
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index bf4d274e..4568a1ff 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -270,7 +270,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # create dataloader for offline training
     dataloader = torch.utils.data.DataLoader(
         offline_dataset,
-        num_workers=4,
+        num_workers=32,
         batch_size=cfg.policy.batch_size,
         shuffle=True,
         pin_memory=cfg.device != "cpu",
@@ -316,7 +316,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     )
     dataloader = torch.utils.data.DataLoader(
         concat_dataset,
-        num_workers=4,
+        num_workers=32,
        batch_size=cfg.policy.batch_size,
         sampler=sampler,
         pin_memory=cfg.device != "cpu",
@@ -339,14 +339,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 seed=cfg.seed,
             )
 
-            add_episodes_inplace(
-                online_dataset,
-                concat_dataset,
-                sampler,
-                hf_dataset=eval_info["episodes"]["hf_dataset"],
-                episode_data_index=eval_info["episodes"]["episode_data_index"],
-                pc_online_samples=cfg.get("demo_schedule", 0.5),
-            )
+        add_episodes_inplace(
+            online_dataset,
+            concat_dataset,
+            sampler,
+            hf_dataset=eval_info["episodes"]["hf_dataset"],
+            episode_data_index=eval_info["episodes"]["episode_data_index"],
+            pc_online_samples=cfg.get("demo_schedule", 0.5),
+        )
 
         for _ in range(cfg.policy.utd):
             policy.train()