diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index 83f94cfe..33742e14 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -27,14 +27,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
     if n_envs is not None and n_envs < 1:
         raise ValueError("`n_envs must be at least 1")
 
-    kwargs = {
-        "obs_type": "pixels_agent_pos",
-        "render_mode": "rgb_array",
-        "max_episode_steps": cfg.env.episode_length,
-        "visualization_width": 384,
-        "visualization_height": 384,
-    }
-
     package_name = f"gym_{cfg.env.name}"
 
     try:
@@ -46,12 +38,16 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
         raise e
 
     gym_handle = f"{package_name}/{cfg.env.task}"
+    gym_kwgs = cfg.env.get("gym", {})
+
+    if cfg.env.get("episode_length"):
+        gym_kwgs["max_episode_steps"] = cfg.env.episode_length
 
     # batched version of the env that returns an observation of shape (b, c)
     env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
     env = env_cls(
         [
-            lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
+            lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
             for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
         ]
     )
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index 9ae30784..f2238769 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -37,6 +37,8 @@ training:
   save_freq: ???
   log_freq: 250
   save_checkpoint: true
+  num_workers: 4
+  batch_size: ???
 
 eval:
   n_episodes: 1
diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml
index 95e4503d..d93afba7 100644
--- a/lerobot/configs/env/aloha.yaml
+++ b/lerobot/configs/env/aloha.yaml
@@ -5,10 +5,13 @@ fps: 50
 env:
   name: aloha
   task: AlohaInsertion-v0
-  from_pixels: True
-  pixels_only: False
   image_size: [3, 480, 640]
-  episode_length: 400
-  fps: ${fps}
   state_dim: 14
   action_dim: 14
+  fps: ${fps}
+  episode_length: 400
+  gym:
+    obs_type: pixels_agent_pos
+    render_mode: rgb_array
+    visualization_width: 384
+    visualization_height: 384
diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml
index 43e9d187..771fbbf4 100644
--- a/lerobot/configs/env/pusht.yaml
+++ b/lerobot/configs/env/pusht.yaml
@@ -5,10 +5,13 @@ fps: 10
 env:
   name: pusht
   task: PushT-v0
-  from_pixels: True
-  pixels_only: False
   image_size: 96
-  episode_length: 300
-  fps: ${fps}
   state_dim: 2
   action_dim: 2
+  fps: ${fps}
+  episode_length: 300
+  gym:
+    obs_type: pixels_agent_pos
+    render_mode: rgb_array
+    visualization_width: 384
+    visualization_height: 384
diff --git a/lerobot/configs/env/xarm.yaml b/lerobot/configs/env/xarm.yaml
index 098b0396..9dbb96f5 100644
--- a/lerobot/configs/env/xarm.yaml
+++ b/lerobot/configs/env/xarm.yaml
@@ -5,10 +5,13 @@ fps: 15
 env:
   name: xarm
   task: XarmLift-v0
-  from_pixels: True
-  pixels_only: False
   image_size: 84
-  episode_length: 25
-  fps: ${fps}
   state_dim: 4
   action_dim: 4
+  fps: ${fps}
+  episode_length: 25
+  gym:
+    obs_type: pixels_agent_pos
+    render_mode: rgb_array
+    visualization_width: 384
+    visualization_height: 384
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 5fb86f36..9bf49c05 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -281,8 +281,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     logging.info("make_dataset")
     offline_dataset = make_dataset(cfg)
 
-    logging.info("make_env")
-    eval_env = make_env(cfg)
+    # Create environment used for evaluating checkpoints during training on simulation data.
+    # On real-world data, there is no need to create an environment: evaluations are done outside
+    # train.py, using eval.py with the gym_dora environment and dora-rs.
+    if cfg.training.eval_freq > 0:
+        logging.info("make_env")
+        eval_env = make_env(cfg)
 
     logging.info("make_policy")
     policy = make_policy(
@@ -315,7 +319,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
     # Note: this helper will be used in offline and online training loops.
     def evaluate_and_checkpoint_if_needed(step):
-        if step % cfg.training.eval_freq == 0:
+        if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
                 eval_info = eval_policy(
@@ -349,7 +353,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     # create dataloader for offline training
     dataloader = torch.utils.data.DataLoader(
         offline_dataset,
-        num_workers=4,
+        num_workers=cfg.training.num_workers,
         batch_size=cfg.training.batch_size,
         shuffle=True,
         pin_memory=device.type != "cpu",
@@ -386,6 +390,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
         step += 1
 
+    logging.info("End of offline training")
+
+    if cfg.training.online_steps == 0:
+        if cfg.training.eval_freq > 0:
+            eval_env.close()
+        return
+
+    # create an env dedicated to online episodes collection from policy rollout
+    online_training_env = make_env(cfg, n_envs=1)
+
     # create an empty online dataset similar to offline dataset
     online_dataset = deepcopy(offline_dataset)
     online_dataset.hf_dataset = {}
@@ -406,8 +420,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         drop_last=False,
     )
 
+    logging.info("End of online training")
+
     eval_env.close()
-    logging.info("End of training")
+    online_training_env.close()
 
 
 @hydra.main(version_base="1.2", config_name="default", config_path="../configs")