From 896a11f60e3a0f9d7107642d692de1c20ee4fd48 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 19 Mar 2024 18:50:04 +0000 Subject: [PATCH] backup wip --- lerobot/common/datasets/abstract.py | 2 +- lerobot/common/datasets/aloha.py | 2 +- lerobot/common/envs/aloha/env.py | 73 +++++++++--------- lerobot/common/envs/factory.py | 34 +++----- lerobot/common/envs/pusht/env.py | 62 ++++++++------- lerobot/common/policies/abstract.py | 32 +++++--- lerobot/common/policies/act/policy.py | 7 +- lerobot/common/policies/diffusion/policy.py | 5 +- lerobot/common/policies/factory.py | 3 + lerobot/common/policies/tdmpc/policy.py | 12 ++- lerobot/configs/default.yaml | 7 +- lerobot/configs/policy/diffusion.yaml | 4 +- lerobot/scripts/eval.py | 17 +++- .../data/aloha_sim_insertion_human/stats.pth | Bin 4306 -> 4370 bytes tests/data/pusht/stats.pth | Bin 4242 -> 4306 bytes tests/test_policies.py | 47 +++++++---- 16 files changed, 169 insertions(+), 138 deletions(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 5db97497..4ce447bf 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -50,7 +50,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): def stats_patterns(self) -> dict: return { ("observation", "state"): "b c -> c", - ("observation", "image"): "b c h w -> c", + ("observation", "image"): "b c h w -> c 1 1", ("action",): "b c -> c", } diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 0637f8a3..b1a5806f 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -117,7 +117,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay): ("action",): "b c -> c", } for cam in CAMERAS[self.dataset_id]: - d[("observation", "image", cam)] = "b c h w -> c" + d[("observation", "image", cam)] = "b c h w -> c 1 1" return d @property diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index 7ef24f2d..001b2ba2 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -58,6 +58,7 @@ class AlohaEnv(AbstractEnv): num_prev_obs=num_prev_obs, num_prev_action=num_prev_action, ) + self._reset_warning_issued = False def _make_env(self): if not _has_gym: @@ -120,47 +121,47 @@ class AlohaEnv(AbstractEnv): return obs def _reset(self, tensordict: Optional[TensorDict] = None): - td = tensordict - if td is None or td.is_empty(): - # we need to handle seed iteration, since self._env.reset() rely an internal _seed. - self._current_seed += 1 - self.set_seed(self._current_seed) + if tensordict is not None and not self._reset_warning_issued: + logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.") + self._reset_warning_issued = True - # TODO(rcadene): do not use global variable for this - if "sim_transfer_cube" in self.task: - BOX_POSE[0] = sample_box_pose() # used in sim reset - elif "sim_insertion" in self.task: - BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset + # we need to handle seed iteration, since self._env.reset() rely an internal _seed. 
+ self._current_seed += 1 + self.set_seed(self._current_seed) - raw_obs = self._env.reset() - # TODO(rcadene): add assert - # assert self._current_seed == self._env._seed + # TODO(rcadene): do not use global variable for this + if "sim_transfer_cube" in self.task: + BOX_POSE[0] = sample_box_pose() # used in sim reset + elif "sim_insertion" in self.task: + BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset - obs = self._format_raw_obs(raw_obs.observation) + raw_obs = self._env.reset() + # TODO(rcadene): add assert + # assert self._current_seed == self._env._seed - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue = deque( - [obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} - if "state" in obs: - self._prev_obs_state_queue = deque( - [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs + obs = self._format_raw_obs(raw_obs.observation) - td = TensorDict( - { - "observation": TensorDict(obs, batch_size=[]), - "done": torch.tensor([False], dtype=torch.bool), - }, - batch_size=[], - ) - else: - raise NotImplementedError() + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue = deque( + [obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} + if "state" in obs: + self._prev_obs_state_queue = deque( + [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs + + td = TensorDict( + { + "observation": TensorDict(obs, batch_size=[]), + "done": torch.tensor([False], dtype=torch.bool), + }, + batch_size=[], + ) self.call_rendering_hooks() return td diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index de86b3ad..689f5869 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -1,31 +1,20 @@ +from torchrl.envs import SerialEnv from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv def make_env(cfg, transform=None): """ - Provide seed to override the seed in the cfg (useful for batched environments). + Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying + environments. The env therefore returns batches.` """ - # assert cfg.rollout_batch_size == 1, \ - # """ - # For the time being, rollout batch sizes of > 1 are not supported. This is because the SerialEnv rollout does not - # correctly handle terminated environments. If you really want to use a larger batch size, read on... - - # When calling `EnvBase.rollout` with `break_when_any_done == True` all environments stop rolling out as soon as the - # first is terminated or truncated. This almost certainly results in incorrect success metrics, as all but the first - # environment get an opportunity to reach the goal. A possible work around is to comment out `if any_done: break` - # inf `EnvBase._rollout_stop_early`. One potential downside is that the environments `step` function will continue - # to be called and the outputs will continue to be added to the rollout. - - # When calling `EnvBase.rollout` with `break_when_any_done == False` environments are reset when done. 
- # """ kwargs = { "frame_skip": cfg.env.action_repeat, "from_pixels": cfg.env.from_pixels, "pixels_only": cfg.env.pixels_only, "image_size": cfg.env.image_size, - "num_prev_obs": cfg.n_obs_steps - 1, "seed": cfg.seed, + "num_prev_obs": cfg.n_obs_steps - 1, } if cfg.env.name == "simxarm": @@ -67,13 +56,14 @@ def make_env(cfg, transform=None): return env - # return SerialEnv( - # cfg.rollout_batch_size, - # create_env_fn=_make_env, - # create_env_kwargs={ - # "seed": env_seed for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) - # }, - # ) + return SerialEnv( + cfg.rollout_batch_size, + create_env_fn=_make_env, + create_env_kwargs={ + "seed": env_seed # noqa: B035 + for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) + }, + ) # def make_env(env_name, frame_skip, device, is_test=False): diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py index 2fe05233..6c348cd6 100644 --- a/lerobot/common/envs/pusht/env.py +++ b/lerobot/common/envs/pusht/env.py @@ -1,4 +1,5 @@ import importlib +import logging from collections import deque from typing import Optional @@ -42,6 +43,7 @@ class PushtEnv(AbstractEnv): num_prev_obs=num_prev_obs, num_prev_action=num_prev_action, ) + self._reset_warning_issued = False def _make_env(self): if not _has_gym: @@ -79,39 +81,39 @@ class PushtEnv(AbstractEnv): return obs def _reset(self, tensordict: Optional[TensorDict] = None): - td = tensordict - if td is None or td.is_empty(): - # we need to handle seed iteration, since self._env.reset() rely an internal _seed. - self._current_seed += 1 - self.set_seed(self._current_seed) - raw_obs = self._env.reset() - assert self._current_seed == self._env._seed + if tensordict is not None and not self._reset_warning_issued: + logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.") + self._reset_warning_issued = True - obs = self._format_raw_obs(raw_obs) + # we need to handle seed iteration, since self._env.reset() rely an internal _seed. 
+ self._current_seed += 1 + self.set_seed(self._current_seed) + raw_obs = self._env.reset() + assert self._current_seed == self._env._seed - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue = deque( - [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) - if "state" in obs: - self._prev_obs_state_queue = deque( - [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs + obs = self._format_raw_obs(raw_obs) - td = TensorDict( - { - "observation": TensorDict(obs, batch_size=[]), - "done": torch.tensor([False], dtype=torch.bool), - }, - batch_size=[], - ) - else: - raise NotImplementedError() + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue = deque( + [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) + if "state" in obs: + self._prev_obs_state_queue = deque( + [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs + + td = TensorDict( + { + "observation": TensorDict(obs, batch_size=[]), + "done": torch.tensor([False], dtype=torch.bool), + }, + batch_size=[], + ) self.call_rendering_hooks() return td diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py index ca2d8570..9f16f5d7 100644 --- a/lerobot/common/policies/abstract.py +++ b/lerobot/common/policies/abstract.py @@ -12,6 +12,17 @@ class AbstractPolicy(nn.Module, ABC): documentation for more information. """ + def __init__(self, n_action_steps: int | None): + """ + n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single + action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then + adds that dimension. + """ + super().__init__() + self.n_action_steps = n_action_steps + if n_action_steps is not None: + self._action_queue = deque([], maxlen=n_action_steps) + @abstractmethod def update(self, replay_buffer, step): """One step of the policy's learning algorithm.""" @@ -24,10 +35,11 @@ class AbstractPolicy(nn.Module, ABC): self.load_state_dict(d) @abstractmethod - def select_action(self, observation) -> Tensor: + def select_actions(self, observation) -> Tensor: """Select an action (or trajectory of actions) based on an observation during rollout. - Should return a (batch_size, n_action_steps, *) tensor of actions. + If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of + actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions. """ def forward(self, *args, **kwargs) -> Tensor: @@ -41,18 +53,14 @@ class AbstractPolicy(nn.Module, ABC): observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that the subclass doesn't have to. - This method effectively wraps the `select_action` method of the subclass. The following assumptions are made: - 1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is + This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made: + 1. 
The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is the action trajectory horizon and * is the action dimensions. - 2. Prior to the `select_action` method being called, theres is an `n_action_steps` instance attribute defined. + 2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined. """ - n_action_steps_attr = "n_action_steps" - if not hasattr(self, n_action_steps_attr): - raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute") - if not hasattr(self, "_action_queue"): - self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr)) + if self.n_action_steps is None: + return self.select_actions(*args, **kwargs) if len(self._action_queue) == 0: # Each element in the queue has shape (B, *). - self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1)) - + self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1)) return self._action_queue.popleft() diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index e0499cdb..539cdcf5 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -42,7 +42,7 @@ def kl_divergence(mu, logvar): class ActionChunkingTransformerPolicy(AbstractPolicy): def __init__(self, cfg, device, n_action_steps=1): - super().__init__() + super().__init__(n_action_steps) self.cfg = cfg self.n_action_steps = n_action_steps self.device = device @@ -147,7 +147,10 @@ class ActionChunkingTransformerPolicy(AbstractPolicy): return loss @torch.no_grad() - def select_action(self, observation, step_count): + def select_actions(self, observation, step_count): + if observation["image"].shape[0] != 1: + raise NotImplementedError("Batch size > 1 not handled") + # TODO(rcadene): remove unused step_count del step_count diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index db004a71..2c47f172 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -34,7 +34,7 @@ class DiffusionPolicy(AbstractPolicy): # parameters passed to step **kwargs, ): - super().__init__() + super().__init__(n_action_steps) self.cfg = cfg noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) @@ -44,7 +44,6 @@ class DiffusionPolicy(AbstractPolicy): **cfg_obs_encoder, ) - self.n_action_steps = n_action_steps # needed for the parent class self.diffusion = DiffusionUnetImagePolicy( shape_meta=shape_meta, noise_scheduler=noise_scheduler, @@ -94,7 +93,7 @@ class DiffusionPolicy(AbstractPolicy): ) @torch.no_grad() - def select_action(self, observation, step_count): + def select_actions(self, observation, step_count): # TODO(rcadene): remove unused step_count del step_count diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index c5e45300..085baab5 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,4 +1,7 @@ def make_policy(cfg): + if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1: + raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.") + if cfg.policy.name == "tdmpc": from lerobot.common.policies.tdmpc.policy import TDMPC diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index 4c104bcd..320f6f2b 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ 
b/lerobot/common/policies/tdmpc/policy.py @@ -90,7 +90,7 @@ class TDMPC(AbstractPolicy): """Implementation of TD-MPC learning + inference.""" def __init__(self, cfg, device): - super().__init__() + super().__init__(None) self.action_dim = cfg.action_dim self.cfg = cfg @@ -125,7 +125,10 @@ class TDMPC(AbstractPolicy): self.model_target.load_state_dict(d["model_target"]) @torch.no_grad() - def select_action(self, observation, step_count): + def select_actions(self, observation, step_count): + if observation["image"].shape[0] != 1: + raise NotImplementedError("Batch size > 1 not handled") + t0 = step_count.item() == 0 obs = { @@ -133,7 +136,8 @@ class TDMPC(AbstractPolicy): "rgb": observation["image"].contiguous(), "state": observation["state"].contiguous(), } - action = self.act(obs, t0=t0, step=self.step.item()) + # Note: unsqueeze needed because `act` still uses non-batch logic. + action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0) return action @torch.no_grad() @@ -144,7 +148,7 @@ class TDMPC(AbstractPolicy): if self.cfg.mpc: a = self.plan(z, t0=t0, step=step) else: - a = self.model.pi(z, self.cfg.min_std * self.model.training) + a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0) return a @torch.no_grad() diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 27b75c88..52fd1d60 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -11,8 +11,7 @@ hydra: seed: 1337 # batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index -# NOTE: batch size of 1 is not yet supported! This is just a placeholder for future support. See -# `lerobot.common.envs.factory.make_env` for more information. +# NOTE: only diffusion policy supports rollout_batch_size > 1 rollout_batch_size: 1 device: cuda # cpu prefetch: 4 @@ -20,7 +19,7 @@ eval_freq: ??? save_freq: ??? eval_episodes: ??? save_video: false -save_model: true +save_model: false save_buffer: false train_steps: ??? fps: ??? @@ -33,7 +32,7 @@ env: ??? policy: ??? wandb: - enable: false + enable: true # Set to true to disable saving an artifact despite save_model == True disable_artifact: false project: lerobot diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index ce8acbd4..0dae5056 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -22,8 +22,8 @@ keypoint_visible_rate: 1.0 obs_as_global_cond: True eval_episodes: 1 -eval_freq: 5000 -save_freq: 5000 +eval_freq: 10000 +save_freq: 100000 log_freq: 250 offline_steps: 1344000 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index c0199c0c..2c564da0 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -51,16 +51,25 @@ def eval_policy( ep_frames.append(env.render()) # noqa: B023 with torch.inference_mode(): + # TODO(alexander-soare): Due the `break_when_any_done == False` this rolls out for max_steps even when all + # envs are done the first time. But we only use the first rollout. This is a waste of compute. 
rollout = env.rollout( max_steps=max_steps, policy=policy, auto_cast_to_device=True, callback=maybe_render_frame, + break_when_any_done=False, ) - # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()])) - batch_sum_reward = rollout["next", "reward"].flatten(start_dim=1).sum(dim=-1) - batch_max_reward = rollout["next", "reward"].flatten(start_dim=1).max(dim=-1)[0] - batch_success = rollout["next", "success"].flatten(start_dim=1).any(dim=-1) + # Figure out where in each rollout sequence the first done condition was encountered (results after this won't + # be included). + # Note: this assumes that the shape of the done key is (batch_size, max_steps, 1). + # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker. + rollout_steps = rollout["next", "done"].shape[1] + done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1) # (batch_size, rollout_steps) + mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1) + batch_sum_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).sum(dim=-1) + batch_max_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).max(dim=-1)[0] + batch_success = (rollout["next", "success"] * mask).flatten(start_dim=1).any(dim=-1) sum_rewards.extend(batch_sum_reward.tolist()) max_rewards.extend(batch_max_reward.tolist()) successes.extend(batch_success.tolist()) diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth index f909ed075ce48cf7f677cc82a8859615b15924c1..d41ac18cdfeb94b610369f116db8b267cf642af8 100644 GIT binary patch delta 754 zcmcblI7w;43no5qMsLRULT3;KWK4d?qze`Df(a=y!-N81LLqRW2&mBJcIGvVT#o|X zb@^Et82DKy3-HK@^D~$*6qh8H6zeBdmZatvrKA?QIT+aHvyI&< zz?+>zD$J5;vICoggIM8IYaPqOHj_o!Y!7A4w?3@1(Ds$&9BYLS5jMgp64p^%J8ZIZ z7;LBB{c3gd?+jb@hrBl1kG}vLk}lTSx%mK_1Cyli4@IX(Ulg1gvz47d#;;rRUV(>! 
zfq`c-KacEWGfpqb-7b4=uLbV2O$^^{3sO;EcHb1PV)HgmFGdH|wq<+n%<|qZal76= zM2UB=@2uChuOqnkZqf_4Q;eUtrzIe8zlTVb{e7!7Hn(`0><)KswheM)-2*c2qe9=y z$qw8O4ioy`TEB=nY9qOu(^f!uk@fmROKoqr&$Lc>^vx!cJXi9crBSY*e74)wFGjw`K*{Y zI3~OESpYc=eDW-yz?vM$FUG{dIe9mq91AFDKzzoS$zS*^8ShNiGEdsn5nSkNVz`+3`85qD`fmnWrA80wxWN`s`X^{IhfU+Pg15ySD dL6ZXo`~`o2Vgv{RyxG`6^oGeB1*F*^q5#o~@`eBa delta 716 zcmbQFbV+f;3nsSKLT7Kr*2&+Pbb(AS5Yvzu#0&&6bHL09AanC_<~595USWwPe}7e@J-4mE z-fPyoY}>85_C{{tw~Id*x+kSQbpQ4Jwe}((r`d45y<{u0b))U(Mb_H`yxBRX=c#X* z+`#7GuyIq5b@$B0HnKbK+PMGkvKHoTwGAq5wsuPPwh0g6v{tz=%ci66vdxBdpRH_@ zYHbe|ezM-dch3rBh_~M3$D7z4m?T3!DLTo&S8)1~s_X=kc(!oqY#s&%2A;|OJhGF^ zIK3n#?3dc!Ena48TE5s8q$2g3<#D)*&EGh^7#)5rUb^S6n)m)wy7l(9x_o=X4Sv|F zFXq|Hu|CwUF=_su%$EWCZxmPB%c`%nIT+4vSG{nH?LI@EJs{(F1qAaaH*h;RRH|OH z7PVY#!xQ(&=5#~1wM=A(ZR^5D>m7C*$kK|V_$$D7ZJ=>+@ac0LOr=K!BP z3n;86PvjS4I>9mdH=i5}C}coH4DpC?l z67_+uV`KsbHv@`2=dYi~!Jb?#bZ-^3ovpYXD_ISO%mF4umF86z~_^0g4bH S2=HcO2hjzSKMF{*K|}$D81YR2 diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth index 8846b8f65ae9b52dc74d369c239d64f42ff214fb..039d5db3d53a93a92163692e0702cc1c5de65298 100644 GIT binary patch delta 575 zcmbQFcu8@?NhUsTMsLRULVFMeWK2HKqze^tg9$M)!-RZcLZ)z`5U9}R4CY2gE)%a@ z=c6nP3`bceU*MJz=V!2EC@x7XDb`P_EJ@8TN=YqpbFyNXJdss@@&Z=w$quXy0((8D zJ3n@7bq?@m=Lq-i`#<>tYk|Q0$GncFqPHADf=rVx-`^a-R=^~AV12o>nMIXzwM`{h z`TeDL!+97O74p@Vs30=W^8O|WM*z^ zU=A`h@`JnOO@28Ru*-p_s?6bHfCh0&Vo9RDMSwRW6EH{^I5fRw?35D&xT_rTCJ5HJwj0SY=G2=HcO2hjzS3k9UvAff=0DxYrv delta 539 zcmcblI7xBCNhY?|LVIt<*2&kIbb(Aa5L1X5#PkI*eZkBSAaipEb0Z_yw$)qKE@fe0 zSjsY4fJa80pFxhHxFoTpSU;(h;Pw{dzg3dnDq;Vfd*3pOB}bEDy8 z0qz2U2P`~}0Ur;6?GxMkp?&iNZWcxv150yDO9NALQ&R&2Lvu@WV%v7EEr=Z>+xGM-Qk>^ z$uGwOb~(^gl^8AtXz->amL%$11b8zt0fU5rg9Ah|Fo1jmv7T%4U4D@DVgd@XARlM| qWkFa5qzn!&aWhQT2Zm;WfPr8IDCmG7z?+R7L>o-rC?L%S5d{DdfR>v8 diff --git a/tests/test_policies.py b/tests/test_policies.py index ee5abdb7..953684ed 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,4 +1,3 @@ - from omegaconf import open_dict import pytest from tensordict import TensorDict @@ -16,35 +15,50 @@ from .utils import DEVICE, init_config @pytest.mark.parametrize( - "env_name,policy_name", + "env_name,policy_name,extra_overrides", [ - ("simxarm", "tdmpc"), - ("pusht", "tdmpc"), - ("simxarm", "diffusion"), - ("pusht", "diffusion"), + ("simxarm", "tdmpc", ["policy.mpc=true"]), + ("pusht", "tdmpc", ["policy.mpc=false"]), + ("simxarm", "diffusion", []), + ("pusht", "diffusion", []), + ("aloha", "act", ["env.task=sim_insertion_scripted"]), ], ) -def test_factory(env_name, policy_name): +def test_concrete_policy(env_name, policy_name, extra_overrides): + """ + Tests: + - Making the policy object. + - Updating the policy. + - Using the policy to select actions at inference time. + """ cfg = init_config( overrides=[ f"env={env_name}", f"policy={policy_name}", f"device={DEVICE}", ] + + extra_overrides ) # Check that we can make the policy object. policy = make_policy(cfg) - # Check that we run select_action and get the appropriate output. + # Check that we run select_actions and get the appropriate output. if env_name == "simxarm": # TODO(rcadene): Not implemented return if policy_name == "tdmpc": # TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this. 
with open_dict(cfg): - cfg['n_obs_steps'] = 1 + cfg["n_obs_steps"] = 1 offline_buffer = make_offline_buffer(cfg) env = make_env(cfg, transform=offline_buffer.transform) - policy.select_action(env.observation_spec.rand()['observation'].to(DEVICE), torch.tensor(0, device=DEVICE)) + + policy.update(offline_buffer, torch.tensor(0, device=DEVICE)) + + action = policy( + env.observation_spec.rand()["observation"].to(DEVICE), + torch.tensor(0, device=DEVICE), + ) + assert action.shape == env.action_spec.shape def test_abstract_policy_forward(): @@ -90,21 +104,20 @@ def test_abstract_policy_forward(): def _set_seed(self, seed: int | None): return - class StubPolicy(AbstractPolicy): def __init__(self): - super().__init__() - self.n_action_steps = n_action_steps + super().__init__(n_action_steps) self.n_policy_invocations = 0 def update(self): pass - def select_action(self): + def select_actions(self): self.n_policy_invocations += 1 - return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0) - + return torch.stack( + [torch.tensor([i]) for i in range(self.n_action_steps)] + ).unsqueeze(0) env = StubEnv() policy = StubPolicy() @@ -119,4 +132,4 @@ def test_abstract_policy_forward(): assert len(rollout) == terminate_at + 1 # +1 for the reset observation assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1 - assert torch.equal(rollout['observation'].flatten(), torch.arange(terminate_at + 1)) + assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))
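
Note on the central refactor above: `AbstractPolicy` now owns `n_action_steps` and an action queue, so `forward` serves single-step actions from a cached trajectory and only calls `select_actions` when the cache runs dry; policies that return a single action (TDMPC in this patch) pass `n_action_steps=None` and bypass the queue. The following is a minimal, self-contained sketch of that mechanism; `ToyPolicy` and its fixed action shape are hypothetical stand-ins, not part of the patch.

from collections import deque

import torch


class ToyPolicy:
    """Standalone sketch of the action-queue logic added to AbstractPolicy."""

    def __init__(self, n_action_steps: int | None):
        self.n_action_steps = n_action_steps
        if n_action_steps is not None:
            # Caches one trajectory; each popleft() yields a (batch, action_dim) action.
            self._action_queue = deque([], maxlen=n_action_steps)
        self.n_select_calls = 0

    def select_actions(self, observation) -> torch.Tensor:
        # Hypothetical stand-in for a learned policy: returns (batch, n_action_steps, action_dim).
        self.n_select_calls += 1
        return torch.zeros(observation.shape[0], self.n_action_steps, 2)

    def forward(self, observation) -> torch.Tensor:
        if self.n_action_steps is None:
            # Single-step policies skip the queue entirely.
            return self.select_actions(observation)
        if len(self._action_queue) == 0:
            # (batch, horizon, *) -> horizon tensors of shape (batch, *) pushed onto the queue.
            self._action_queue.extend(self.select_actions(observation).transpose(0, 1))
        return self._action_queue.popleft()


policy = ToyPolicy(n_action_steps=4)
obs = torch.zeros(1, 8)  # made-up batched observation
for _ in range(6):
    action = policy.forward(obs)  # shape (1, 2)
assert policy.n_select_calls == 2  # select_actions only runs once every n_action_steps calls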
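
Note on the reward aggregation change in lerobot/scripts/eval.py: with break_when_any_done=False every env rolls out for max_steps, so the metrics must ignore everything after each env's first done flag. The patch does this by relying on argmax returning the index of the first maximal value, then building a per-step mask. A small self-contained check of that logic with made-up numbers; shapes follow the patch's comments, where done is (batch, steps, 1).

import torch

done = torch.tensor([[0, 0, 1, 0, 1],
                     [0, 1, 0, 0, 0]], dtype=torch.bool).unsqueeze(-1)  # (2, 5, 1)
reward = torch.ones(2, 5, 1)

# torch.argmax returns the index of the first maximal value, i.e. the first done step.
done_indices = torch.argmax(done.to(int), dim=1)        # (2, 1)
mask = (torch.arange(5) <= done_indices).unsqueeze(-1)  # (2, 5, 1)

sum_reward = (reward * mask).flatten(start_dim=1).sum(dim=-1)
print(sum_reward)  # tensor([3., 2.]): steps up to and including the first done are counted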
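
Note on the observation stacking in both `_reset` implementations: the previous-observation queue is primed with num_prev_obs + 1 copies of the first observation, so the stacked observation carries a full temporal window from the very first step. A toy version of that priming and eviction behaviour, with made-up shapes.

from collections import deque

import torch

num_prev_obs = 2
first_state = torch.zeros(4)  # made-up state dimension

queue = deque([first_state] * (num_prev_obs + 1), maxlen=num_prev_obs + 1)
stacked = torch.stack(list(queue))  # (num_prev_obs + 1, 4): identical copies right after reset

# On each subsequent step, appending the newest observation evicts the oldest copy.
queue.append(torch.ones(4))
stacked = torch.stack(list(queue))  # rows are now [zeros, zeros, ones]
assert stacked.shape == (num_prev_obs + 1, 4)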