backup wip
This commit is contained in:
parent
ea17f4ce50
commit
896a11f60e
|
@ -50,7 +50,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
|||
def stats_patterns(self) -> dict:
|
||||
return {
|
||||
("observation", "state"): "b c -> c",
|
||||
("observation", "image"): "b c h w -> c",
|
||||
("observation", "image"): "b c h w -> c 1 1",
|
||||
("action",): "b c -> c",
|
||||
}
|
||||
|
||||
|
|
|
@ -117,7 +117,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
|
|||
("action",): "b c -> c",
|
||||
}
|
||||
for cam in CAMERAS[self.dataset_id]:
|
||||
d[("observation", "image", cam)] = "b c h w -> c"
|
||||
d[("observation", "image", cam)] = "b c h w -> c 1 1"
|
||||
return d
|
||||
|
||||
@property
|
||||
|
|
|
@ -58,6 +58,7 @@ class AlohaEnv(AbstractEnv):
|
|||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
self._reset_warning_issued = False
|
||||
|
||||
def _make_env(self):
|
||||
if not _has_gym:
|
||||
|
@ -120,47 +121,47 @@ class AlohaEnv(AbstractEnv):
|
|||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
td = tensordict
|
||||
if td is None or td.is_empty():
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
if tensordict is not None and not self._reset_warning_issued:
|
||||
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
||||
self._reset_warning_issued = True
|
||||
|
||||
# TODO(rcadene): do not use global variable for this
|
||||
if "sim_transfer_cube" in self.task:
|
||||
BOX_POSE[0] = sample_box_pose() # used in sim reset
|
||||
elif "sim_insertion" in self.task:
|
||||
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
|
||||
raw_obs = self._env.reset()
|
||||
# TODO(rcadene): add assert
|
||||
# assert self._current_seed == self._env._seed
|
||||
# TODO(rcadene): do not use global variable for this
|
||||
if "sim_transfer_cube" in self.task:
|
||||
BOX_POSE[0] = sample_box_pose() # used in sim reset
|
||||
elif "sim_insertion" in self.task:
|
||||
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
||||
|
||||
obs = self._format_raw_obs(raw_obs.observation)
|
||||
raw_obs = self._env.reset()
|
||||
# TODO(rcadene): add assert
|
||||
# assert self._current_seed == self._env._seed
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
obs = self._format_raw_obs(raw_obs.observation)
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
|
||||
self.call_rendering_hooks()
|
||||
return td
|
||||
|
|
|
@ -1,31 +1,20 @@
|
|||
from torchrl.envs import SerialEnv
|
||||
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
|
||||
|
||||
|
||||
def make_env(cfg, transform=None):
|
||||
"""
|
||||
Provide seed to override the seed in the cfg (useful for batched environments).
|
||||
Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying
|
||||
environments. The env therefore returns batches.`
|
||||
"""
|
||||
# assert cfg.rollout_batch_size == 1, \
|
||||
# """
|
||||
# For the time being, rollout batch sizes of > 1 are not supported. This is because the SerialEnv rollout does not
|
||||
# correctly handle terminated environments. If you really want to use a larger batch size, read on...
|
||||
|
||||
# When calling `EnvBase.rollout` with `break_when_any_done == True` all environments stop rolling out as soon as the
|
||||
# first is terminated or truncated. This almost certainly results in incorrect success metrics, as all but the first
|
||||
# environment get an opportunity to reach the goal. A possible work around is to comment out `if any_done: break`
|
||||
# inf `EnvBase._rollout_stop_early`. One potential downside is that the environments `step` function will continue
|
||||
# to be called and the outputs will continue to be added to the rollout.
|
||||
|
||||
# When calling `EnvBase.rollout` with `break_when_any_done == False` environments are reset when done.
|
||||
# """
|
||||
|
||||
kwargs = {
|
||||
"frame_skip": cfg.env.action_repeat,
|
||||
"from_pixels": cfg.env.from_pixels,
|
||||
"pixels_only": cfg.env.pixels_only,
|
||||
"image_size": cfg.env.image_size,
|
||||
"num_prev_obs": cfg.n_obs_steps - 1,
|
||||
"seed": cfg.seed,
|
||||
"num_prev_obs": cfg.n_obs_steps - 1,
|
||||
}
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
|
@ -67,13 +56,14 @@ def make_env(cfg, transform=None):
|
|||
|
||||
return env
|
||||
|
||||
# return SerialEnv(
|
||||
# cfg.rollout_batch_size,
|
||||
# create_env_fn=_make_env,
|
||||
# create_env_kwargs={
|
||||
# "seed": env_seed for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
|
||||
# },
|
||||
# )
|
||||
return SerialEnv(
|
||||
cfg.rollout_batch_size,
|
||||
create_env_fn=_make_env,
|
||||
create_env_kwargs={
|
||||
"seed": env_seed # noqa: B035
|
||||
for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# def make_env(env_name, frame_skip, device, is_test=False):
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import importlib
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
|
@ -42,6 +43,7 @@ class PushtEnv(AbstractEnv):
|
|||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
self._reset_warning_issued = False
|
||||
|
||||
def _make_env(self):
|
||||
if not _has_gym:
|
||||
|
@ -79,39 +81,39 @@ class PushtEnv(AbstractEnv):
|
|||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
td = tensordict
|
||||
if td is None or td.is_empty():
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
raw_obs = self._env.reset()
|
||||
assert self._current_seed == self._env._seed
|
||||
if tensordict is not None and not self._reset_warning_issued:
|
||||
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
||||
self._reset_warning_issued = True
|
||||
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
raw_obs = self._env.reset()
|
||||
assert self._current_seed == self._env._seed
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
|
||||
self.call_rendering_hooks()
|
||||
return td
|
||||
|
|
|
@ -12,6 +12,17 @@ class AbstractPolicy(nn.Module, ABC):
|
|||
documentation for more information.
|
||||
"""
|
||||
|
||||
def __init__(self, n_action_steps: int | None):
|
||||
"""
|
||||
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
|
||||
action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
|
||||
adds that dimension.
|
||||
"""
|
||||
super().__init__()
|
||||
self.n_action_steps = n_action_steps
|
||||
if n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=n_action_steps)
|
||||
|
||||
@abstractmethod
|
||||
def update(self, replay_buffer, step):
|
||||
"""One step of the policy's learning algorithm."""
|
||||
|
@ -24,10 +35,11 @@ class AbstractPolicy(nn.Module, ABC):
|
|||
self.load_state_dict(d)
|
||||
|
||||
@abstractmethod
|
||||
def select_action(self, observation) -> Tensor:
|
||||
def select_actions(self, observation) -> Tensor:
|
||||
"""Select an action (or trajectory of actions) based on an observation during rollout.
|
||||
|
||||
Should return a (batch_size, n_action_steps, *) tensor of actions.
|
||||
If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
|
||||
actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
|
||||
"""
|
||||
|
||||
def forward(self, *args, **kwargs) -> Tensor:
|
||||
|
@ -41,18 +53,14 @@ class AbstractPolicy(nn.Module, ABC):
|
|||
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
|
||||
the subclass doesn't have to.
|
||||
|
||||
This method effectively wraps the `select_action` method of the subclass. The following assumptions are made:
|
||||
1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
|
||||
This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
|
||||
1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
|
||||
the action trajectory horizon and * is the action dimensions.
|
||||
2. Prior to the `select_action` method being called, theres is an `n_action_steps` instance attribute defined.
|
||||
2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined.
|
||||
"""
|
||||
n_action_steps_attr = "n_action_steps"
|
||||
if not hasattr(self, n_action_steps_attr):
|
||||
raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
|
||||
if not hasattr(self, "_action_queue"):
|
||||
self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
|
||||
if self.n_action_steps is None:
|
||||
return self.select_actions(*args, **kwargs)
|
||||
if len(self._action_queue) == 0:
|
||||
# Each element in the queue has shape (B, *).
|
||||
self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1))
|
||||
|
||||
self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
|
|
@ -42,7 +42,7 @@ def kl_divergence(mu, logvar):
|
|||
|
||||
class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
super().__init__()
|
||||
super().__init__(n_action_steps)
|
||||
self.cfg = cfg
|
||||
self.n_action_steps = n_action_steps
|
||||
self.device = device
|
||||
|
@ -147,7 +147,10 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
|||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, observation, step_count):
|
||||
def select_actions(self, observation, step_count):
|
||||
if observation["image"].shape[0] != 1:
|
||||
raise NotImplementedError("Batch size > 1 not handled")
|
||||
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ class DiffusionPolicy(AbstractPolicy):
|
|||
# parameters passed to step
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(n_action_steps)
|
||||
self.cfg = cfg
|
||||
|
||||
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
||||
|
@ -44,7 +44,6 @@ class DiffusionPolicy(AbstractPolicy):
|
|||
**cfg_obs_encoder,
|
||||
)
|
||||
|
||||
self.n_action_steps = n_action_steps # needed for the parent class
|
||||
self.diffusion = DiffusionUnetImagePolicy(
|
||||
shape_meta=shape_meta,
|
||||
noise_scheduler=noise_scheduler,
|
||||
|
@ -94,7 +93,7 @@ class DiffusionPolicy(AbstractPolicy):
|
|||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, observation, step_count):
|
||||
def select_actions(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
def make_policy(cfg):
|
||||
if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
|
||||
raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPC
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ class TDMPC(AbstractPolicy):
|
|||
"""Implementation of TD-MPC learning + inference."""
|
||||
|
||||
def __init__(self, cfg, device):
|
||||
super().__init__()
|
||||
super().__init__(None)
|
||||
self.action_dim = cfg.action_dim
|
||||
|
||||
self.cfg = cfg
|
||||
|
@ -125,7 +125,10 @@ class TDMPC(AbstractPolicy):
|
|||
self.model_target.load_state_dict(d["model_target"])
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, observation, step_count):
|
||||
def select_actions(self, observation, step_count):
|
||||
if observation["image"].shape[0] != 1:
|
||||
raise NotImplementedError("Batch size > 1 not handled")
|
||||
|
||||
t0 = step_count.item() == 0
|
||||
|
||||
obs = {
|
||||
|
@ -133,7 +136,8 @@ class TDMPC(AbstractPolicy):
|
|||
"rgb": observation["image"].contiguous(),
|
||||
"state": observation["state"].contiguous(),
|
||||
}
|
||||
action = self.act(obs, t0=t0, step=self.step.item())
|
||||
# Note: unsqueeze needed because `act` still uses non-batch logic.
|
||||
action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
|
||||
return action
|
||||
|
||||
@torch.no_grad()
|
||||
|
@ -144,7 +148,7 @@ class TDMPC(AbstractPolicy):
|
|||
if self.cfg.mpc:
|
||||
a = self.plan(z, t0=t0, step=step)
|
||||
else:
|
||||
a = self.model.pi(z, self.cfg.min_std * self.model.training)
|
||||
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
|
||||
return a
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
@ -11,8 +11,7 @@ hydra:
|
|||
|
||||
seed: 1337
|
||||
# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
|
||||
# NOTE: batch size of 1 is not yet supported! This is just a placeholder for future support. See
|
||||
# `lerobot.common.envs.factory.make_env` for more information.
|
||||
# NOTE: only diffusion policy supports rollout_batch_size > 1
|
||||
rollout_batch_size: 1
|
||||
device: cuda # cpu
|
||||
prefetch: 4
|
||||
|
@ -20,7 +19,7 @@ eval_freq: ???
|
|||
save_freq: ???
|
||||
eval_episodes: ???
|
||||
save_video: false
|
||||
save_model: true
|
||||
save_model: false
|
||||
save_buffer: false
|
||||
train_steps: ???
|
||||
fps: ???
|
||||
|
@ -33,7 +32,7 @@ env: ???
|
|||
policy: ???
|
||||
|
||||
wandb:
|
||||
enable: false
|
||||
enable: true
|
||||
# Set to true to disable saving an artifact despite save_model == True
|
||||
disable_artifact: false
|
||||
project: lerobot
|
||||
|
|
|
@ -22,8 +22,8 @@ keypoint_visible_rate: 1.0
|
|||
obs_as_global_cond: True
|
||||
|
||||
eval_episodes: 1
|
||||
eval_freq: 5000
|
||||
save_freq: 5000
|
||||
eval_freq: 10000
|
||||
save_freq: 100000
|
||||
log_freq: 250
|
||||
|
||||
offline_steps: 1344000
|
||||
|
|
|
@ -51,16 +51,25 @@ def eval_policy(
|
|||
ep_frames.append(env.render()) # noqa: B023
|
||||
|
||||
with torch.inference_mode():
|
||||
# TODO(alexander-soare): Due the `break_when_any_done == False` this rolls out for max_steps even when all
|
||||
# envs are done the first time. But we only use the first rollout. This is a waste of compute.
|
||||
rollout = env.rollout(
|
||||
max_steps=max_steps,
|
||||
policy=policy,
|
||||
auto_cast_to_device=True,
|
||||
callback=maybe_render_frame,
|
||||
break_when_any_done=False,
|
||||
)
|
||||
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
|
||||
batch_sum_reward = rollout["next", "reward"].flatten(start_dim=1).sum(dim=-1)
|
||||
batch_max_reward = rollout["next", "reward"].flatten(start_dim=1).max(dim=-1)[0]
|
||||
batch_success = rollout["next", "success"].flatten(start_dim=1).any(dim=-1)
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after this won't
|
||||
# be included).
|
||||
# Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
|
||||
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
|
||||
rollout_steps = rollout["next", "done"].shape[1]
|
||||
done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1) # (batch_size, rollout_steps)
|
||||
mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1)
|
||||
batch_sum_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).sum(dim=-1)
|
||||
batch_max_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).max(dim=-1)[0]
|
||||
batch_success = (rollout["next", "success"] * mask).flatten(start_dim=1).any(dim=-1)
|
||||
sum_rewards.extend(batch_sum_reward.tolist())
|
||||
max_rewards.extend(batch_max_reward.tolist())
|
||||
successes.extend(batch_success.tolist())
|
||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -1,4 +1,3 @@
|
|||
|
||||
from omegaconf import open_dict
|
||||
import pytest
|
||||
from tensordict import TensorDict
|
||||
|
@ -16,35 +15,50 @@ from .utils import DEVICE, init_config
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name",
|
||||
"env_name,policy_name,extra_overrides",
|
||||
[
|
||||
("simxarm", "tdmpc"),
|
||||
("pusht", "tdmpc"),
|
||||
("simxarm", "diffusion"),
|
||||
("pusht", "diffusion"),
|
||||
("simxarm", "tdmpc", ["policy.mpc=true"]),
|
||||
("pusht", "tdmpc", ["policy.mpc=false"]),
|
||||
("simxarm", "diffusion", []),
|
||||
("pusht", "diffusion", []),
|
||||
("aloha", "act", ["env.task=sim_insertion_scripted"]),
|
||||
],
|
||||
)
|
||||
def test_factory(env_name, policy_name):
|
||||
def test_concrete_policy(env_name, policy_name, extra_overrides):
|
||||
"""
|
||||
Tests:
|
||||
- Making the policy object.
|
||||
- Updating the policy.
|
||||
- Using the policy to select actions at inference time.
|
||||
"""
|
||||
cfg = init_config(
|
||||
overrides=[
|
||||
f"env={env_name}",
|
||||
f"policy={policy_name}",
|
||||
f"device={DEVICE}",
|
||||
]
|
||||
+ extra_overrides
|
||||
)
|
||||
# Check that we can make the policy object.
|
||||
policy = make_policy(cfg)
|
||||
# Check that we run select_action and get the appropriate output.
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
if env_name == "simxarm":
|
||||
# TODO(rcadene): Not implemented
|
||||
return
|
||||
if policy_name == "tdmpc":
|
||||
# TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this.
|
||||
with open_dict(cfg):
|
||||
cfg['n_obs_steps'] = 1
|
||||
cfg["n_obs_steps"] = 1
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
policy.select_action(env.observation_spec.rand()['observation'].to(DEVICE), torch.tensor(0, device=DEVICE))
|
||||
|
||||
policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
|
||||
|
||||
action = policy(
|
||||
env.observation_spec.rand()["observation"].to(DEVICE),
|
||||
torch.tensor(0, device=DEVICE),
|
||||
)
|
||||
assert action.shape == env.action_spec.shape
|
||||
|
||||
|
||||
def test_abstract_policy_forward():
|
||||
|
@ -90,21 +104,20 @@ def test_abstract_policy_forward():
|
|||
|
||||
def _set_seed(self, seed: int | None):
|
||||
return
|
||||
|
||||
|
||||
class StubPolicy(AbstractPolicy):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.n_action_steps = n_action_steps
|
||||
super().__init__(n_action_steps)
|
||||
self.n_policy_invocations = 0
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
def select_action(self):
|
||||
def select_actions(self):
|
||||
self.n_policy_invocations += 1
|
||||
return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0)
|
||||
|
||||
return torch.stack(
|
||||
[torch.tensor([i]) for i in range(self.n_action_steps)]
|
||||
).unsqueeze(0)
|
||||
|
||||
env = StubEnv()
|
||||
policy = StubPolicy()
|
||||
|
@ -119,4 +132,4 @@ def test_abstract_policy_forward():
|
|||
|
||||
assert len(rollout) == terminate_at + 1 # +1 for the reset observation
|
||||
assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
|
||||
assert torch.equal(rollout['observation'].flatten(), torch.arange(terminate_at + 1))
|
||||
assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))
|
||||
|
|
Loading…
Reference in New Issue