diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index 5db97497..4ce447bf 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -50,7 +50,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
     def stats_patterns(self) -> dict:
         return {
             ("observation", "state"): "b c -> c",
-            ("observation", "image"): "b c h w -> c",
+            ("observation", "image"): "b c h w -> c 1 1",
             ("action",): "b c -> c",
         }
 
diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py
index 0637f8a3..b1a5806f 100644
--- a/lerobot/common/datasets/aloha.py
+++ b/lerobot/common/datasets/aloha.py
@@ -117,7 +117,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
             ("action",): "b c -> c",
         }
         for cam in CAMERAS[self.dataset_id]:
-            d[("observation", "image", cam)] = "b c h w -> c"
+            d[("observation", "image", cam)] = "b c h w -> c 1 1"
         return d
 
     @property
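Side note on the stats pattern change above (illustration only, not part of the patch): reducing image batches to "c 1 1" instead of "c" keeps singleton spatial dimensions on the computed statistics, so they broadcast directly against (c, h, w) image tensors at normalization time. A minimal sketch with made-up shapes:

    import torch
    from einops import reduce

    images = torch.rand(8, 3, 96, 96)  # fake batch: (b, c, h, w)

    mean_old = reduce(images, "b c h w -> c", "mean")      # shape (3,): needs reshaping before use
    mean_new = reduce(images, "b c h w -> c 1 1", "mean")  # shape (3, 1, 1): broadcasts as-is

    normalized = (images - mean_new) / 0.5  # broadcasts over h and w without any reshape
    print(mean_old.shape, mean_new.shape, normalized.shape)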
diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py
index 7ef24f2d..001b2ba2 100644
--- a/lerobot/common/envs/aloha/env.py
+++ b/lerobot/common/envs/aloha/env.py
@@ -58,6 +58,7 @@ class AlohaEnv(AbstractEnv):
             num_prev_obs=num_prev_obs,
             num_prev_action=num_prev_action,
         )
+        self._reset_warning_issued = False
 
     def _make_env(self):
         if not _has_gym:
@@ -120,47 +121,47 @@ class AlohaEnv(AbstractEnv):
         return obs
 
     def _reset(self, tensordict: Optional[TensorDict] = None):
-        td = tensordict
-        if td is None or td.is_empty():
-            # we need to handle seed iteration, since self._env.reset() rely an internal _seed.
-            self._current_seed += 1
-            self.set_seed(self._current_seed)
+        if tensordict is not None and not self._reset_warning_issued:
+            logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
+            self._reset_warning_issued = True
 
-            # TODO(rcadene): do not use global variable for this
-            if "sim_transfer_cube" in self.task:
-                BOX_POSE[0] = sample_box_pose()  # used in sim reset
-            elif "sim_insertion" in self.task:
-                BOX_POSE[0] = np.concatenate(sample_insertion_pose())  # used in sim reset
+        # we need to handle seed iteration, since self._env.reset() rely an internal _seed.
+        self._current_seed += 1
+        self.set_seed(self._current_seed)
 
-            raw_obs = self._env.reset()
-            # TODO(rcadene): add assert
-            # assert self._current_seed == self._env._seed
+        # TODO(rcadene): do not use global variable for this
+        if "sim_transfer_cube" in self.task:
+            BOX_POSE[0] = sample_box_pose()  # used in sim reset
+        elif "sim_insertion" in self.task:
+            BOX_POSE[0] = np.concatenate(sample_insertion_pose())  # used in sim reset
 
-            obs = self._format_raw_obs(raw_obs.observation)
+        raw_obs = self._env.reset()
+        # TODO(rcadene): add assert
+        # assert self._current_seed == self._env._seed
 
-            if self.num_prev_obs > 0:
-                stacked_obs = {}
-                if "image" in obs:
-                    self._prev_obs_image_queue = deque(
-                        [obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
-                    )
-                    stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
-                if "state" in obs:
-                    self._prev_obs_state_queue = deque(
-                        [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
-                    )
-                    stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
-                obs = stacked_obs
+        obs = self._format_raw_obs(raw_obs.observation)
 
-            td = TensorDict(
-                {
-                    "observation": TensorDict(obs, batch_size=[]),
-                    "done": torch.tensor([False], dtype=torch.bool),
-                },
-                batch_size=[],
-            )
-        else:
-            raise NotImplementedError()
+        if self.num_prev_obs > 0:
+            stacked_obs = {}
+            if "image" in obs:
+                self._prev_obs_image_queue = deque(
+                    [obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+                )
+                stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
+            if "state" in obs:
+                self._prev_obs_state_queue = deque(
+                    [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+                )
+                stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+            obs = stacked_obs
+
+        td = TensorDict(
+            {
+                "observation": TensorDict(obs, batch_size=[]),
+                "done": torch.tensor([False], dtype=torch.bool),
+            },
+            batch_size=[],
+        )
 
         self.call_rendering_hooks()
         return td
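Context for the _reset rework above (editorial illustration, not part of the patch): once the env is wrapped in a torchrl SerialEnv (see lerobot/common/envs/factory.py below), reset() hands each sub-env a tensordict, so the old `else: raise NotImplementedError()` branch would trigger on every batched reset. The new code ignores the argument and warns exactly once. A standalone sketch of that warn-once guard, using a hypothetical class:

    import logging
    from typing import Optional

    class ResetSketch:
        def __init__(self):
            self._reset_warning_issued = False

        def _reset(self, tensordict: Optional[dict] = None):
            # Warn the first time a tensordict is passed in, then stay quiet.
            if tensordict is not None and not self._reset_warning_issued:
                logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
                self._reset_warning_issued = True
            # ... proceed with the usual seeded reset regardless of the argument.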
diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index de86b3ad..689f5869 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -1,31 +1,20 @@
+from torchrl.envs import SerialEnv
 from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
 
 
 def make_env(cfg, transform=None):
     """
-    Provide seed to override the seed in the cfg (useful for batched environments).
+    Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying
+    environments. The env therefore returns batches.`
     """
-    # assert cfg.rollout_batch_size == 1, \
-    #     """
-    #     For the time being, rollout batch sizes of > 1 are not supported. This is because the SerialEnv rollout does not
-    #     correctly handle terminated environments. If you really want to use a larger batch size, read on...
-
-    #     When calling `EnvBase.rollout` with `break_when_any_done == True` all environments stop rolling out as soon as the
-    #     first is terminated or truncated. This almost certainly results in incorrect success metrics, as all but the first
-    #     environment get an opportunity to reach the goal. A possible work around is to comment out `if any_done: break`
-    #     inf `EnvBase._rollout_stop_early`. One potential downside is that the environments `step` function will continue
-    #     to be called and the outputs will continue to be added to the rollout.
-
-    #     When calling `EnvBase.rollout` with `break_when_any_done == False` environments are reset when done.
-    #     """
     kwargs = {
         "frame_skip": cfg.env.action_repeat,
         "from_pixels": cfg.env.from_pixels,
         "pixels_only": cfg.env.pixels_only,
         "image_size": cfg.env.image_size,
-        "num_prev_obs": cfg.n_obs_steps - 1,
         "seed": cfg.seed,
+        "num_prev_obs": cfg.n_obs_steps - 1,
     }
 
     if cfg.env.name == "simxarm":
@@ -67,13 +56,14 @@ def make_env(cfg, transform=None):
 
         return env
 
-    # return SerialEnv(
-    #     cfg.rollout_batch_size,
-    #     create_env_fn=_make_env,
-    #     create_env_kwargs={
-    #         "seed": env_seed for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
-    #     },
-    # )
+    return SerialEnv(
+        cfg.rollout_batch_size,
+        create_env_fn=_make_env,
+        create_env_kwargs={
+            "seed": env_seed  # noqa: B035
+            for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
+        },
+    )
 
 
 # def make_env(env_name, frame_skip, device, is_test=False):
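A note on the `# noqa: B035` above (illustration only, with hypothetical stand-in values for cfg.seed and cfg.rollout_batch_size): ruff's B035 flags a dict comprehension with a static key, because it collapses to a single entry holding only the last seed. If the intent is one distinct seed per sub-environment, a list of per-env kwargs dicts expresses that directly; as far as its documented signature goes, torchrl's SerialEnv also accepts create_env_kwargs as a list with one dict per env.

    seed, rollout_batch_size = 1337, 4  # stand-ins for cfg.seed / cfg.rollout_batch_size

    collapsed = {"seed": env_seed for env_seed in range(seed, seed + rollout_batch_size)}
    print(collapsed)  # {'seed': 1340} -- only the last seed survives

    per_env = [{"seed": s} for s in range(seed, seed + rollout_batch_size)]
    print(per_env)  # [{'seed': 1337}, {'seed': 1338}, {'seed': 1339}, {'seed': 1340}]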
diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py
index 2fe05233..6c348cd6 100644
--- a/lerobot/common/envs/pusht/env.py
+++ b/lerobot/common/envs/pusht/env.py
@@ -1,4 +1,5 @@
 import importlib
+import logging
 from collections import deque
 from typing import Optional
 
@@ -42,6 +43,7 @@ class PushtEnv(AbstractEnv):
             num_prev_obs=num_prev_obs,
             num_prev_action=num_prev_action,
         )
+        self._reset_warning_issued = False
 
     def _make_env(self):
         if not _has_gym:
@@ -79,39 +81,39 @@ class PushtEnv(AbstractEnv):
         return obs
 
     def _reset(self, tensordict: Optional[TensorDict] = None):
-        td = tensordict
-        if td is None or td.is_empty():
-            # we need to handle seed iteration, since self._env.reset() rely an internal _seed.
-            self._current_seed += 1
-            self.set_seed(self._current_seed)
-            raw_obs = self._env.reset()
-            assert self._current_seed == self._env._seed
+        if tensordict is not None and not self._reset_warning_issued:
+            logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
+            self._reset_warning_issued = True
 
-            obs = self._format_raw_obs(raw_obs)
+        # we need to handle seed iteration, since self._env.reset() rely an internal _seed.
+        self._current_seed += 1
+        self.set_seed(self._current_seed)
+        raw_obs = self._env.reset()
+        assert self._current_seed == self._env._seed
 
-            if self.num_prev_obs > 0:
-                stacked_obs = {}
-                if "image" in obs:
-                    self._prev_obs_image_queue = deque(
-                        [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
-                    )
-                    stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
-                if "state" in obs:
-                    self._prev_obs_state_queue = deque(
-                        [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
-                    )
-                    stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
-                obs = stacked_obs
+        obs = self._format_raw_obs(raw_obs)
 
-            td = TensorDict(
-                {
-                    "observation": TensorDict(obs, batch_size=[]),
-                    "done": torch.tensor([False], dtype=torch.bool),
-                },
-                batch_size=[],
-            )
-        else:
-            raise NotImplementedError()
+        if self.num_prev_obs > 0:
+            stacked_obs = {}
+            if "image" in obs:
+                self._prev_obs_image_queue = deque(
+                    [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+                )
+                stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
+            if "state" in obs:
+                self._prev_obs_state_queue = deque(
+                    [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+                )
+                stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+            obs = stacked_obs
+
+        td = TensorDict(
+            {
+                "observation": TensorDict(obs, batch_size=[]),
+                "done": torch.tensor([False], dtype=torch.bool),
+            },
+            batch_size=[],
+        )
 
         self.call_rendering_hooks()
         return td
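For reference (not part of the patch), the deque-based history stacking used in both _reset methods above, with made-up shapes:

    from collections import deque

    import torch

    num_prev_obs = 2
    first_frame = torch.zeros(3, 96, 96)  # hypothetical first observation after reset

    # On reset the queue is pre-filled with copies of the first frame so the policy
    # always sees a full (num_prev_obs + 1)-long history.
    queue = deque([first_frame] * (num_prev_obs + 1), maxlen=num_prev_obs + 1)
    print(torch.stack(list(queue)).shape)  # torch.Size([3, 3, 96, 96]) -> (history, c, h, w)

    # On each step, appending the newest frame evicts the oldest one.
    queue.append(torch.ones(3, 96, 96))
    print(torch.stack(list(queue)).shape)  # still torch.Size([3, 3, 96, 96])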
diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py
index ca2d8570..9f16f5d7 100644
--- a/lerobot/common/policies/abstract.py
+++ b/lerobot/common/policies/abstract.py
@@ -12,6 +12,17 @@ class AbstractPolicy(nn.Module, ABC):
     documentation for more information.
     """
 
+    def __init__(self, n_action_steps: int | None):
+        """
+        n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
+        action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
+        adds that dimension.
+        """
+        super().__init__()
+        self.n_action_steps = n_action_steps
+        if n_action_steps is not None:
+            self._action_queue = deque([], maxlen=n_action_steps)
+
     @abstractmethod
     def update(self, replay_buffer, step):
         """One step of the policy's learning algorithm."""
@@ -24,10 +35,11 @@
         self.load_state_dict(d)
 
     @abstractmethod
-    def select_action(self, observation) -> Tensor:
+    def select_actions(self, observation) -> Tensor:
        """Select an action (or trajectory of actions) based on an observation during rollout.
 
-        Should return a (batch_size, n_action_steps, *) tensor of actions.
+        If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
+        actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
         """
 
     def forward(self, *args, **kwargs) -> Tensor:
@@ -41,18 +53,14 @@
         observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
         the subclass doesn't have to.
 
-        This method effectively wraps the `select_action` method of the subclass. The following assumptions are made:
-        1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
+        This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
+        1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
            the action trajectory horizon and * is the action dimensions.
-        2. Prior to the `select_action` method being called, theres is an `n_action_steps` instance attribute defined.
+        2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined.
         """
-        n_action_steps_attr = "n_action_steps"
-        if not hasattr(self, n_action_steps_attr):
-            raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
-        if not hasattr(self, "_action_queue"):
-            self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
+        if self.n_action_steps is None:
+            return self.select_actions(*args, **kwargs)
         if len(self._action_queue) == 0:
             # Each element in the queue has shape (B, *).
-            self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1))
-
+            self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
         return self._action_queue.popleft()
diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py
index e0499cdb..539cdcf5 100644
--- a/lerobot/common/policies/act/policy.py
+++ b/lerobot/common/policies/act/policy.py
@@ -42,7 +42,7 @@ def kl_divergence(mu, logvar):
 
 class ActionChunkingTransformerPolicy(AbstractPolicy):
     def __init__(self, cfg, device, n_action_steps=1):
-        super().__init__()
+        super().__init__(n_action_steps)
         self.cfg = cfg
         self.n_action_steps = n_action_steps
         self.device = device
@@ -147,7 +147,10 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
         return loss
 
     @torch.no_grad()
-    def select_action(self, observation, step_count):
+    def select_actions(self, observation, step_count):
+        if observation["image"].shape[0] != 1:
+            raise NotImplementedError("Batch size > 1 not handled")
+
         # TODO(rcadene): remove unused step_count
         del step_count
 
diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py
index db004a71..2c47f172 100644
--- a/lerobot/common/policies/diffusion/policy.py
+++ b/lerobot/common/policies/diffusion/policy.py
@@ -34,7 +34,7 @@ class DiffusionPolicy(AbstractPolicy):
         # parameters passed to step
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(n_action_steps)
         self.cfg = cfg
 
         noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
@@ -44,7 +44,6 @@ class DiffusionPolicy(AbstractPolicy):
             **cfg_obs_encoder,
         )
 
-        self.n_action_steps = n_action_steps  # needed for the parent class
         self.diffusion = DiffusionUnetImagePolicy(
             shape_meta=shape_meta,
             noise_scheduler=noise_scheduler,
@@ -94,7 +93,7 @@ class DiffusionPolicy(AbstractPolicy):
         )
 
     @torch.no_grad()
-    def select_action(self, observation, step_count):
+    def select_actions(self, observation, step_count):
         # TODO(rcadene): remove unused step_count
         del step_count
 
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index c5e45300..085baab5 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -1,4 +1,7 @@
 def make_policy(cfg):
+    if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
+        raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
+
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPC
 
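For reference (illustration with made-up shapes, not part of the patch), this is the queueing behaviour that AbstractPolicy.forward implements after the change: select_actions returns a (B, H, *) trajectory, the horizon is unrolled into a deque, and one (B, *) slice is served per call until the queue needs refilling.

    from collections import deque

    import torch

    batch_size, n_action_steps, action_dim = 2, 4, 7

    def select_actions():
        # Stands in for the subclass's select_actions: (B, H, action_dim).
        return torch.randn(batch_size, n_action_steps, action_dim)

    action_queue = deque([], maxlen=n_action_steps)

    def forward():
        # Refill from a fresh trajectory only when the queue is empty; transpose(0, 1)
        # makes the horizon the leading dim so each element is a (B, action_dim) slice.
        if len(action_queue) == 0:
            action_queue.extend(select_actions().transpose(0, 1))
        return action_queue.popleft()

    for _ in range(6):
        assert forward().shape == (batch_size, action_dim)
    # select_actions ran twice: at step 0 and again at step 4, when the queue emptied.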
diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py
index 4c104bcd..320f6f2b 100644
--- a/lerobot/common/policies/tdmpc/policy.py
+++ b/lerobot/common/policies/tdmpc/policy.py
@@ -90,7 +90,7 @@ class TDMPC(AbstractPolicy):
     """Implementation of TD-MPC learning + inference."""
 
     def __init__(self, cfg, device):
-        super().__init__()
+        super().__init__(None)
         self.action_dim = cfg.action_dim
 
         self.cfg = cfg
@@ -125,7 +125,10 @@ class TDMPC(AbstractPolicy):
         self.model_target.load_state_dict(d["model_target"])
 
     @torch.no_grad()
-    def select_action(self, observation, step_count):
+    def select_actions(self, observation, step_count):
+        if observation["image"].shape[0] != 1:
+            raise NotImplementedError("Batch size > 1 not handled")
+
         t0 = step_count.item() == 0
 
         obs = {
@@ -133,7 +136,8 @@ class TDMPC(AbstractPolicy):
             "rgb": observation["image"].contiguous(),
             "state": observation["state"].contiguous(),
         }
-        action = self.act(obs, t0=t0, step=self.step.item())
+        # Note: unsqueeze needed because `act` still uses non-batch logic.
+        action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
         return action
 
     @torch.no_grad()
@@ -144,7 +148,7 @@ class TDMPC(AbstractPolicy):
         if self.cfg.mpc:
             a = self.plan(z, t0=t0, step=step)
         else:
-            a = self.model.pi(z, self.cfg.min_std * self.model.training)
+            a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
         return a
 
     @torch.no_grad()
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index 27b75c88..52fd1d60 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -11,8 +11,7 @@ hydra:
 seed: 1337
 
 # batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
-# NOTE: batch size of 1 is not yet supported! This is just a placeholder for future support. See
-# `lerobot.common.envs.factory.make_env` for more information.
+# NOTE: only diffusion policy supports rollout_batch_size > 1
 rollout_batch_size: 1
 device: cuda  # cpu
 prefetch: 4
@@ -20,7 +19,7 @@ eval_freq: ???
 save_freq: ???
 eval_episodes: ???
 save_video: false
-save_model: true
+save_model: false
 save_buffer: false
 train_steps: ???
 fps: ???
@@ -33,7 +32,7 @@ env: ???
 policy: ???
 
 wandb:
-  enable: false
+  enable: true
   # Set to true to disable saving an artifact despite save_model == True
   disable_artifact: false
   project: lerobot
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index ce8acbd4..0dae5056 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -22,8 +22,8 @@ keypoint_visible_rate: 1.0
 obs_as_global_cond: True
 
 eval_episodes: 1
-eval_freq: 5000
-save_freq: 5000
+eval_freq: 10000
+save_freq: 100000
 log_freq: 250
 
 offline_steps: 1344000
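Shape note on the TDMPC change above (illustration only): select_actions is now expected to return a leading batch dimension, while TDMPC's act path still works on a single unbatched observation, hence the unsqueeze(0) (and the matching squeeze(0) on model.pi). With n_action_steps=None passed to super().__init__, AbstractPolicy.forward returns this tensor directly instead of queueing a trajectory.

    import torch

    action_dim = 4  # hypothetical
    single_action = torch.randn(action_dim)      # what the unbatched act/plan path produces
    batched_action = single_action.unsqueeze(0)  # what select_actions returns: (1, action_dim)
    print(single_action.shape, batched_action.shape)  # torch.Size([4]) torch.Size([1, 4])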
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index c0199c0c..2c564da0 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -51,16 +51,25 @@ def eval_policy(
             ep_frames.append(env.render())  # noqa: B023
 
         with torch.inference_mode():
+            # TODO(alexander-soare): Due the `break_when_any_done == False` this rolls out for max_steps even when all
+            # envs are done the first time. But we only use the first rollout. This is a waste of compute.
            rollout = env.rollout(
                 max_steps=max_steps,
                 policy=policy,
                 auto_cast_to_device=True,
                 callback=maybe_render_frame,
+                break_when_any_done=False,
             )
-        # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
-        batch_sum_reward = rollout["next", "reward"].flatten(start_dim=1).sum(dim=-1)
-        batch_max_reward = rollout["next", "reward"].flatten(start_dim=1).max(dim=-1)[0]
-        batch_success = rollout["next", "success"].flatten(start_dim=1).any(dim=-1)
+        # Figure out where in each rollout sequence the first done condition was encountered (results after this won't
+        # be included).
+        # Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
+        # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
+        rollout_steps = rollout["next", "done"].shape[1]
+        done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1)  # (batch_size, rollout_steps)
+        mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1)  # (batch_size, rollout_steps, 1)
+        batch_sum_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).sum(dim=-1)
+        batch_max_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).max(dim=-1)[0]
+        batch_success = (rollout["next", "success"] * mask).flatten(start_dim=1).any(dim=-1)
         sum_rewards.extend(batch_sum_reward.tolist())
         max_rewards.extend(batch_max_reward.tolist())
         successes.extend(batch_success.tolist())
diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth
index f909ed07..d41ac18c 100644
Binary files a/tests/data/aloha_sim_insertion_human/stats.pth and b/tests/data/aloha_sim_insertion_human/stats.pth differ
diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth
index 8846b8f6..039d5db3 100644
Binary files a/tests/data/pusht/stats.pth and b/tests/data/pusht/stats.pth differ
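To make the masking logic added to lerobot/scripts/eval.py above concrete (illustration with made-up done flags): argmax over the step dimension returns the first step at which done is set, and the comparison against arange keeps results up to and including that step while zeroing anything produced after the env auto-resets.

    import torch

    # Toy done flags for 2 rollouts of 5 steps: env 0 finishes at step 2 (and again at
    # step 4 after auto-resetting), env 1 finishes at step 4.
    done = torch.tensor([[0, 0, 1, 0, 1],
                         [0, 0, 0, 0, 1]]).unsqueeze(-1)  # (batch_size, rollout_steps, 1)

    rollout_steps = done.shape[1]
    done_indices = torch.argmax(done.to(int), dim=1)  # dim= is the canonical spelling of the axis= alias used above
    mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1)
    print(mask.squeeze(-1).to(int))
    # tensor([[1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 1]])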
diff --git a/tests/test_policies.py b/tests/test_policies.py
index ee5abdb7..953684ed 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -1,4 +1,3 @@
-
 from omegaconf import open_dict
 import pytest
 from tensordict import TensorDict
@@ -16,35 +15,50 @@
 from .utils import DEVICE, init_config
 
 
 @pytest.mark.parametrize(
-    "env_name,policy_name",
+    "env_name,policy_name,extra_overrides",
     [
-        ("simxarm", "tdmpc"),
-        ("pusht", "tdmpc"),
-        ("simxarm", "diffusion"),
-        ("pusht", "diffusion"),
+        ("simxarm", "tdmpc", ["policy.mpc=true"]),
+        ("pusht", "tdmpc", ["policy.mpc=false"]),
+        ("simxarm", "diffusion", []),
+        ("pusht", "diffusion", []),
+        ("aloha", "act", ["env.task=sim_insertion_scripted"]),
     ],
 )
-def test_factory(env_name, policy_name):
+def test_concrete_policy(env_name, policy_name, extra_overrides):
+    """
+    Tests:
+    - Making the policy object.
+    - Updating the policy.
+    - Using the policy to select actions at inference time.
+    """
     cfg = init_config(
         overrides=[
             f"env={env_name}",
             f"policy={policy_name}",
             f"device={DEVICE}",
         ]
+        + extra_overrides
     )
     # Check that we can make the policy object.
     policy = make_policy(cfg)
-    # Check that we run select_action and get the appropriate output.
+    # Check that we run select_actions and get the appropriate output.
     if env_name == "simxarm":
         # TODO(rcadene): Not implemented
         return
     if policy_name == "tdmpc":
         # TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this.
         with open_dict(cfg):
-            cfg['n_obs_steps'] = 1
+            cfg["n_obs_steps"] = 1
     offline_buffer = make_offline_buffer(cfg)
     env = make_env(cfg, transform=offline_buffer.transform)
-    policy.select_action(env.observation_spec.rand()['observation'].to(DEVICE), torch.tensor(0, device=DEVICE))
+
+    policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
+
+    action = policy(
+        env.observation_spec.rand()["observation"].to(DEVICE),
+        torch.tensor(0, device=DEVICE),
+    )
+    assert action.shape == env.action_spec.shape
 
 
 def test_abstract_policy_forward():
@@ -90,21 +104,20 @@ def test_abstract_policy_forward():
         def _set_seed(self, seed: int | None):
             return
 
-
     class StubPolicy(AbstractPolicy):
         def __init__(self):
-            super().__init__()
-            self.n_action_steps = n_action_steps
+            super().__init__(n_action_steps)
             self.n_policy_invocations = 0
 
         def update(self):
             pass
 
-        def select_action(self):
+        def select_actions(self):
             self.n_policy_invocations += 1
-            return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0)
-
+            return torch.stack(
+                [torch.tensor([i]) for i in range(self.n_action_steps)]
+            ).unsqueeze(0)
 
     env = StubEnv()
     policy = StubPolicy()
@@ -119,4 +132,4 @@ def test_abstract_policy_forward():
     assert len(rollout) == terminate_at + 1  # +1 for the reset observation
     assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
 
-    assert torch.equal(rollout['observation'].flatten(), torch.arange(terminate_at + 1))
+    assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))