From ed80db2846f9855346d57884fbe1b07f26fffc6c Mon Sep 17 00:00:00 2001
From: Cadene
Date: Sun, 25 Feb 2024 12:02:29 +0000
Subject: [PATCH] Sanitize cfg.env

---
 lerobot/common/envs/factory.py | 20 ++++++++++----------
 lerobot/common/logger.py       |  8 ++++++--
 lerobot/configs/default.yaml   | 27 ++++++++++++---------------
 lerobot/configs/pusht.yaml     | 21 ++++++++++++---------
 lerobot/scripts/eval.py        |  4 ++--
 lerobot/scripts/train.py       |  8 ++++----
 6 files changed, 46 insertions(+), 42 deletions(-)

diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index 770ea392..fa094734 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -7,26 +7,26 @@ from lerobot.common.envs.transforms import Prod
 
 def make_env(cfg):
     kwargs = {
-        "frame_skip": cfg.action_repeat,
-        "from_pixels": cfg.from_pixels,
-        "pixels_only": cfg.pixels_only,
-        "image_size": cfg.image_size,
+        "frame_skip": cfg.env.action_repeat,
+        "from_pixels": cfg.env.from_pixels,
+        "pixels_only": cfg.env.pixels_only,
+        "image_size": cfg.env.image_size,
     }
 
-    if cfg.env == "simxarm":
-        kwargs["task"] = cfg.task
+    if cfg.env.name == "simxarm":
+        kwargs["task"] = cfg.env.task
         clsfunc = SimxarmEnv
-    elif cfg.env == "pusht":
+    elif cfg.env.name == "pusht":
         clsfunc = PushtEnv
     else:
-        raise ValueError(cfg.env)
+        raise ValueError(cfg.env.name)
 
     env = clsfunc(**kwargs)
 
     # limit rollout to max_steps
-    env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length))
+    env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
 
-    if cfg.env == "pusht":
+    if cfg.env.name == "pusht":
         # to ensure pusht is in [0,255] like simxarm
         env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
 
diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 306abd22..031e062d 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -51,7 +51,7 @@ def print_run(cfg, reward=None):
 
     kvs = [
         ("task", cfg.task),
-        ("train steps", f"{int(cfg.train_steps * cfg.action_repeat):,}"),
+        ("train steps", f"{int(cfg.train_steps * cfg.env.action_repeat):,}"),
         # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
         # ('actions', cfg.action_dim),
         # ('experiment', cfg.exp_name),
@@ -117,7 +117,11 @@ class VideoRecorder:
         if self.enabled:
             frames = np.stack(self.frames).transpose(0, 3, 1, 2)
             self._wandb.log(
-                {"eval_video": self._wandb.Video(frames, fps=self.fps, format="mp4")},
+                {
+                    "eval_video": self._wandb.Video(
+                        frames, fps=self.env.fps, format="mp4"
+                    )
+                },
                 step=step,
             )
 
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index ad3b859e..690d417f 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -13,21 +13,18 @@ eval_episodes: 20
 save_video: false
 save_model: false
 save_buffer: false
-
-# env
-env: simxarm
-task: lift
-from_pixels: True
-pixels_only: False
-image_size: 84
+train_steps: 50000
 
 fps: 15
-
-
-# xarm_lift
-action_repeat: 2
-episode_length: 25
-train_steps: 50000
+env:
+  name: simxarm
+  task: lift
+  from_pixels: True
+  pixels_only: False
+  image_size: 84
+  action_repeat: 2
+  episode_length: 25
+  fps: ${fps}
 
 
 policy:
@@ -37,14 +34,14 @@ policy:
 
   # xarm_lift
   train_steps: ${train_steps}
-  episode_length: ${episode_length}
+  episode_length: ${env.episode_length}
  discount: 0.9
   modality: 'all'
 
   # pixels
   frame_stack: 1
   num_channels: 32
-  img_size: ${image_size}
+  img_size: ${env.image_size}
   state_dim: 4
   action_dim: 4
 
diff --git a/lerobot/configs/pusht.yaml b/lerobot/configs/pusht.yaml
index eaee713a..4ca83192 100644
--- a/lerobot/configs/pusht.yaml
+++ b/lerobot/configs/pusht.yaml
@@ -6,16 +6,19 @@ hydra:
   job:
     name: pusht
 
-# env
-env: pusht
-task: pusht
-image_size: 96
-action_repeat: 1
-state_dim: 2
-action_dim: 2
-fps: 10
 eval_episodes: 50
-episode_length: 300
 
 eval_freq: 7500
 save_freq: 75000
+fps: 10
+env:
+  name: pusht
+  task: pusht
+  image_size: 96
+  fps: ${fps}
+  action_repeat: 1
+  episode_length: 300
+
+policy:
+  state_dim: 2
+  action_dim: 2
\ No newline at end of file
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 7b3357d9..077f2556 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -126,8 +126,8 @@ def eval(cfg: dict, out_dir=None):
         policy=policy,
         save_video=True,
         video_dir=Path(out_dir) / "eval",
-        fps=cfg.fps,
-        max_steps=cfg.episode_length,
+        fps=cfg.env.fps,
+        max_steps=cfg.env.episode_length,
         num_episodes=cfg.eval_episodes,
     )
     print(metrics)
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 6c91b83f..61cc1e63 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -88,7 +88,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     while step < cfg.train_steps:
         is_offline = True
-        num_updates = cfg.episode_length
+        num_updates = cfg.env.episode_length
         _step = step + num_updates
         rollout_metrics = {}
 
@@ -98,11 +98,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
             # TODO: use SyncDataCollector for that?
             with torch.no_grad():
                 rollout = env.rollout(
-                    max_steps=cfg.episode_length,
+                    max_steps=cfg.env.episode_length,
                     policy=td_policy,
                     auto_cast_to_device=True,
                 )
-            assert len(rollout) <= cfg.episode_length
+            assert len(rollout) <= cfg.env.episode_length
             rollout["episode"] = torch.tensor(
                 [online_episode_idx] * len(rollout), dtype=torch.int
            )
@@ -133,7 +133,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         )
 
         # Log training metrics
-        env_step = int(_step * cfg.action_repeat)
+        env_step = int(_step * cfg.env.action_repeat)
         common_metrics = {
             "episode": online_episode_idx,
             "step": _step,
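
For reference only, and not part of the patch itself: a minimal sketch of how the restructured config resolves once the environment keys are nested under `env`. It assumes OmegaConf (the library underlying Hydra configs such as these); the inline YAML mirrors the new layout of lerobot/configs/default.yaml, and the attribute accesses mirror the updated call sites in factory.py, train.py, and eval.py.

# Illustrative sketch, not the repository's code.
from omegaconf import OmegaConf

yaml_cfg = """
train_steps: 50000
fps: 15
env:
  name: simxarm
  task: lift
  image_size: 84
  action_repeat: 2
  episode_length: 25
  fps: ${fps}
"""
cfg = OmegaConf.create(yaml_cfg)

# Call sites now read nested keys, e.g. make_env(cfg) branches on cfg.env.name.
assert cfg.env.name == "simxarm"
# ${fps} interpolates against the top-level key, so the two stay in sync.
assert cfg.env.fps == 15
# train.py derives environment steps from the nested action_repeat.
env_step = int(cfg.env.episode_length * cfg.env.action_repeat)
print(env_step)  # 50

Grouping the environment keys under a single `env` namespace keeps shared top-level keys such as `fps` and `train_steps` common across configs, which is what the `${fps}` and `${train_steps}` interpolations in both YAML files rely on.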