backup wip

parent ea17f4ce50
commit 896a11f60e

@@ -50,7 +50,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
     def stats_patterns(self) -> dict:
         return {
             ("observation", "state"): "b c -> c",
-            ("observation", "image"): "b c h w -> c",
+            ("observation", "image"): "b c h w -> c 1 1",
             ("action",): "b c -> c",
         }

@@ -117,7 +117,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
             ("action",): "b c -> c",
         }
         for cam in CAMERAS[self.dataset_id]:
-            d[("observation", "image", cam)] = "b c h w -> c"
+            d[("observation", "image", cam)] = "b c h w -> c 1 1"
         return d

     @property

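The einops pattern change in the two hunks above (reducing image stats to "c 1 1" instead of "c") keeps the per-channel statistics broadcastable against (b, c, h, w) image batches, e.g. for normalization. A minimal sketch of the idea, assuming einops and torch are available (the tensor names are illustrative, not from the repo):

import torch
from einops import reduce

images = torch.rand(8, 3, 96, 96)  # (b, c, h, w)

# Old pattern: per-channel mean with shape (c,); normalizing images requires manual reshaping.
mean_c = reduce(images, "b c h w -> c", "mean")

# New pattern: the same statistic kept as (c, 1, 1), which broadcasts directly over (b, c, h, w).
mean_c11 = reduce(images, "b c h w -> c 1 1", "mean")
normalized = images - mean_c11  # no reshape needed
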
@@ -58,6 +58,7 @@ class AlohaEnv(AbstractEnv):
             num_prev_obs=num_prev_obs,
             num_prev_action=num_prev_action,
         )
+        self._reset_warning_issued = False

     def _make_env(self):
         if not _has_gym:

@@ -120,8 +121,10 @@ class AlohaEnv(AbstractEnv):
         return obs

     def _reset(self, tensordict: Optional[TensorDict] = None):
-        td = tensordict
-        if td is None or td.is_empty():
+        if tensordict is not None and not self._reset_warning_issued:
+            logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
+            self._reset_warning_issued = True

         # we need to handle seed iteration, since self._env.reset() rely an internal _seed.
         self._current_seed += 1
         self.set_seed(self._current_seed)

@@ -159,8 +162,6 @@ class AlohaEnv(AbstractEnv):
                 },
                 batch_size=[],
             )
-        else:
-            raise NotImplementedError()

         self.call_rendering_hooks()
         return td

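For context, the _reset changes above replace the old tensordict-dependent branch with a warn-once-and-ignore pattern. A standalone sketch of that pattern using only the standard library (the class and attribute names here are illustrative):

import logging


class SomeEnv:
    def __init__(self):
        self._reset_warning_issued = False

    def _reset(self, tensordict=None):
        # Warn the first time a caller passes a tensordict, then silently ignore it afterwards.
        if tensordict is not None and not self._reset_warning_issued:
            logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
            self._reset_warning_issued = True
        # ... proceed with seeding and the usual reset logic ...
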
@@ -1,31 +1,20 @@
+from torchrl.envs import SerialEnv
 from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv


 def make_env(cfg, transform=None):
     """
-    Provide seed to override the seed in the cfg (useful for batched environments).
+    Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying
+    environments. The env therefore returns batches.
     """
-    # assert cfg.rollout_batch_size == 1, \
-    # """
-    # For the time being, rollout batch sizes of > 1 are not supported. This is because the SerialEnv rollout does not
-    # correctly handle terminated environments. If you really want to use a larger batch size, read on...
-
-    # When calling `EnvBase.rollout` with `break_when_any_done == True` all environments stop rolling out as soon as the
-    # first is terminated or truncated. This almost certainly results in incorrect success metrics, as all but the first
-    # environment get an opportunity to reach the goal. A possible work around is to comment out `if any_done: break`
-    # inf `EnvBase._rollout_stop_early`. One potential downside is that the environments `step` function will continue
-    # to be called and the outputs will continue to be added to the rollout.
-
-    # When calling `EnvBase.rollout` with `break_when_any_done == False` environments are reset when done.
-    # """

     kwargs = {
         "frame_skip": cfg.env.action_repeat,
         "from_pixels": cfg.env.from_pixels,
         "pixels_only": cfg.env.pixels_only,
         "image_size": cfg.env.image_size,
-        "num_prev_obs": cfg.n_obs_steps - 1,
         "seed": cfg.seed,
+        "num_prev_obs": cfg.n_obs_steps - 1,
     }

     if cfg.env.name == "simxarm":

@@ -67,13 +56,14 @@ def make_env(cfg, transform=None):

         return env

-    # return SerialEnv(
-    #     cfg.rollout_batch_size,
-    #     create_env_fn=_make_env,
-    #     create_env_kwargs={
-    #         "seed": env_seed for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
-    #     },
-    # )
+    return SerialEnv(
+        cfg.rollout_batch_size,
+        create_env_fn=_make_env,
+        create_env_kwargs={
+            "seed": env_seed  # noqa: B035
+            for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
+        },
+    )


 # def make_env(env_name, frame_skip, device, is_test=False):

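One note on the now-active SerialEnv call above: the expression guarded by "# noqa: B035" is a dict comprehension with a constant key, so it evaluates to a single-entry dict holding only the last seed. A quick illustration of that Python behaviour, plus a hedged guess at a per-env alternative (whether SerialEnv should instead receive one kwargs dict per sub-environment is left open by this WIP commit and not verified here):

# Constant-key dict comprehension: only the last value survives (this is what the B035 lint flags).
seeds = {"seed": env_seed for env_seed in range(1337, 1337 + 4)}
assert seeds == {"seed": 1340}

# If per-environment kwargs are intended, a list comprehension keeps every seed
# (assuming the SerialEnv constructor accepts a list of per-env kwargs dicts).
per_env_kwargs = [{"seed": env_seed} for env_seed in range(1337, 1337 + 4)]
assert per_env_kwargs == [{"seed": 1337}, {"seed": 1338}, {"seed": 1339}, {"seed": 1340}]
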
@@ -1,4 +1,5 @@
 import importlib
+import logging
 from collections import deque
 from typing import Optional

@@ -42,6 +43,7 @@ class PushtEnv(AbstractEnv):
             num_prev_obs=num_prev_obs,
             num_prev_action=num_prev_action,
         )
+        self._reset_warning_issued = False

     def _make_env(self):
         if not _has_gym:

@@ -79,8 +81,10 @@ class PushtEnv(AbstractEnv):
         return obs

     def _reset(self, tensordict: Optional[TensorDict] = None):
-        td = tensordict
-        if td is None or td.is_empty():
+        if tensordict is not None and not self._reset_warning_issued:
+            logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
+            self._reset_warning_issued = True

         # we need to handle seed iteration, since self._env.reset() rely an internal _seed.
         self._current_seed += 1
         self.set_seed(self._current_seed)

@@ -110,8 +114,6 @@ class PushtEnv(AbstractEnv):
                 },
                 batch_size=[],
             )
-        else:
-            raise NotImplementedError()

         self.call_rendering_hooks()
         return td

@@ -12,6 +12,17 @@ class AbstractPolicy(nn.Module, ABC):
     documentation for more information.
     """

+    def __init__(self, n_action_steps: int | None):
+        """
+        n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
+            action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
+            adds that dimension.
+        """
+        super().__init__()
+        self.n_action_steps = n_action_steps
+        if n_action_steps is not None:
+            self._action_queue = deque([], maxlen=n_action_steps)
+
     @abstractmethod
     def update(self, replay_buffer, step):
         """One step of the policy's learning algorithm."""

@@ -24,10 +35,11 @@ class AbstractPolicy(nn.Module, ABC):
         self.load_state_dict(d)

     @abstractmethod
-    def select_action(self, observation) -> Tensor:
+    def select_actions(self, observation) -> Tensor:
         """Select an action (or trajectory of actions) based on an observation during rollout.

-        Should return a (batch_size, n_action_steps, *) tensor of actions.
+        If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
+        actions. Otherwise if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
         """

     def forward(self, *args, **kwargs) -> Tensor:

@@ -41,18 +53,14 @@ class AbstractPolicy(nn.Module, ABC):
         observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
         the subclass doesn't have to.

-        This method effectively wraps the `select_action` method of the subclass. The following assumptions are made:
-        1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
+        This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
+        1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
            the action trajectory horizon and * is the action dimensions.
-        2. Prior to the `select_action` method being called, there is an `n_action_steps` instance attribute defined.
+        2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute defined.
         """
-        n_action_steps_attr = "n_action_steps"
-        if not hasattr(self, n_action_steps_attr):
-            raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
-        if not hasattr(self, "_action_queue"):
-            self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
+        if self.n_action_steps is None:
+            return self.select_actions(*args, **kwargs)
         if len(self._action_queue) == 0:
             # Each element in the queue has shape (B, *).
-            self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1))
+            self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))

         return self._action_queue.popleft()

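The refactored forward above leans on the queue created in __init__: select_actions returns a (batch, horizon, *) trajectory, transpose(0, 1) turns it into horizon-many (batch, *) slices, and popleft serves one slice per environment step until the queue runs dry. A self-contained sketch of that mechanism with dummy tensors (shapes and names are illustrative only):

from collections import deque

import torch

n_action_steps = 3  # horizon cached per policy call
batch_size, action_dim = 2, 4


def select_actions():
    # Stand-in for the real policy: a full trajectory of shape (batch, horizon, action_dim).
    return torch.rand(batch_size, n_action_steps, action_dim)


queue = deque([], maxlen=n_action_steps)
for step in range(5):
    if len(queue) == 0:
        # transpose(0, 1) yields horizon-many (batch, action_dim) slices to enqueue.
        queue.extend(select_actions().transpose(0, 1))
    action = queue.popleft()
    assert action.shape == (batch_size, action_dim)
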
@@ -42,7 +42,7 @@ def kl_divergence(mu, logvar):

 class ActionChunkingTransformerPolicy(AbstractPolicy):
     def __init__(self, cfg, device, n_action_steps=1):
-        super().__init__()
+        super().__init__(n_action_steps)
         self.cfg = cfg
         self.n_action_steps = n_action_steps
         self.device = device

@@ -147,7 +147,10 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
         return loss

     @torch.no_grad()
-    def select_action(self, observation, step_count):
+    def select_actions(self, observation, step_count):
+        if observation["image"].shape[0] != 1:
+            raise NotImplementedError("Batch size > 1 not handled")
+
         # TODO(rcadene): remove unused step_count
         del step_count

@@ -34,7 +34,7 @@ class DiffusionPolicy(AbstractPolicy):
         # parameters passed to step
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(n_action_steps)
         self.cfg = cfg

         noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)

@@ -44,7 +44,6 @@ class DiffusionPolicy(AbstractPolicy):
             **cfg_obs_encoder,
         )

-        self.n_action_steps = n_action_steps  # needed for the parent class
         self.diffusion = DiffusionUnetImagePolicy(
             shape_meta=shape_meta,
             noise_scheduler=noise_scheduler,

@@ -94,7 +93,7 @@ class DiffusionPolicy(AbstractPolicy):
         )

     @torch.no_grad()
-    def select_action(self, observation, step_count):
+    def select_actions(self, observation, step_count):
         # TODO(rcadene): remove unused step_count
         del step_count

@@ -1,4 +1,7 @@
 def make_policy(cfg):
+    if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
+        raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
+
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPC

@@ -90,7 +90,7 @@ class TDMPC(AbstractPolicy):
     """Implementation of TD-MPC learning + inference."""

     def __init__(self, cfg, device):
-        super().__init__()
+        super().__init__(None)
         self.action_dim = cfg.action_dim

         self.cfg = cfg

@@ -125,7 +125,10 @@ class TDMPC(AbstractPolicy):
         self.model_target.load_state_dict(d["model_target"])

     @torch.no_grad()
-    def select_action(self, observation, step_count):
+    def select_actions(self, observation, step_count):
+        if observation["image"].shape[0] != 1:
+            raise NotImplementedError("Batch size > 1 not handled")
+
         t0 = step_count.item() == 0

         obs = {

@@ -133,7 +136,8 @@ class TDMPC(AbstractPolicy):
             "rgb": observation["image"].contiguous(),
             "state": observation["state"].contiguous(),
         }
-        action = self.act(obs, t0=t0, step=self.step.item())
+        # Note: unsqueeze needed because `act` still uses non-batch logic.
+        action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
         return action

     @torch.no_grad()

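The unsqueeze(0) above (and the squeeze(0) in the next hunk) shim TDMPC's single-sample act/pi logic into the new (batch, *) contract of select_actions. A trivial shape check of that idea (the values are made up):

import torch

action_dim = 4
single_action = torch.rand(action_dim)      # what a non-batched `act` produces: (action_dim,)
batched = single_action.unsqueeze(0)        # what `select_actions` must return: (1, action_dim)
assert batched.shape == (1, action_dim)
assert batched.squeeze(0).shape == (action_dim,)  # squeeze(0) undoes the leading batch dim
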
@@ -144,7 +148,7 @@ class TDMPC(AbstractPolicy):
         if self.cfg.mpc:
             a = self.plan(z, t0=t0, step=step)
         else:
-            a = self.model.pi(z, self.cfg.min_std * self.model.training)
+            a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
         return a

     @torch.no_grad()

@@ -11,8 +11,7 @@ hydra:

 seed: 1337
 # batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
-# NOTE: batch size of 1 is not yet supported! This is just a placeholder for future support. See
-# `lerobot.common.envs.factory.make_env` for more information.
+# NOTE: only diffusion policy supports rollout_batch_size > 1
 rollout_batch_size: 1
 device: cuda  # cpu
 prefetch: 4

@@ -20,7 +19,7 @@ eval_freq: ???
 save_freq: ???
 eval_episodes: ???
 save_video: false
-save_model: true
+save_model: false
 save_buffer: false
 train_steps: ???
 fps: ???

@@ -33,7 +32,7 @@ env: ???
 policy: ???

 wandb:
-  enable: false
+  enable: true
   # Set to true to disable saving an artifact despite save_model == True
   disable_artifact: false
   project: lerobot

@@ -22,8 +22,8 @@ keypoint_visible_rate: 1.0
 obs_as_global_cond: True

 eval_episodes: 1
-eval_freq: 5000
-save_freq: 5000
+eval_freq: 10000
+save_freq: 100000
 log_freq: 250

 offline_steps: 1344000

@@ -51,16 +51,25 @@ def eval_policy(
                 ep_frames.append(env.render())  # noqa: B023

         with torch.inference_mode():
+            # TODO(alexander-soare): Due the `break_when_any_done == False` this rolls out for max_steps even when all
+            # envs are done the first time. But we only use the first rollout. This is a waste of compute.
             rollout = env.rollout(
                 max_steps=max_steps,
                 policy=policy,
                 auto_cast_to_device=True,
                 callback=maybe_render_frame,
+                break_when_any_done=False,
             )
-        # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
-        batch_sum_reward = rollout["next", "reward"].flatten(start_dim=1).sum(dim=-1)
-        batch_max_reward = rollout["next", "reward"].flatten(start_dim=1).max(dim=-1)[0]
-        batch_success = rollout["next", "success"].flatten(start_dim=1).any(dim=-1)
+        # Figure out where in each rollout sequence the first done condition was encountered (results after this won't
+        # be included).
+        # Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
+        # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
+        rollout_steps = rollout["next", "done"].shape[1]
+        done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1)  # (batch_size, rollout_steps)
+        mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1)  # (batch_size, rollout_steps, 1)
+        batch_sum_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).sum(dim=-1)
+        batch_max_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).max(dim=-1)[0]
+        batch_success = (rollout["next", "success"] * mask).flatten(start_dim=1).any(dim=-1)
         sum_rewards.extend(batch_sum_reward.tolist())
         max_rewards.extend(batch_max_reward.tolist())
         successes.extend(batch_success.tolist())

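The new reward aggregation above masks out everything after the first done flag in each rollout, relying on argmax returning the first maximal index. A small numeric check of that logic (the tensors are made up for illustration, and dim is used in place of the diff's axis alias):

import torch

# done flags for 2 rollouts of 5 steps, shaped like rollout["next", "done"]: (batch, steps, 1)
done = torch.tensor([[0, 0, 1, 0, 1],
                     [0, 1, 0, 0, 0]]).unsqueeze(-1)
rollout_steps = done.shape[1]

# argmax picks the first step where done == 1 (first occurrence wins ties).
done_indices = torch.argmax(done.to(int), dim=1)                    # tensor([[2], [1]])
mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1)  # (batch, steps, 1) of bools

reward = torch.ones(2, 5, 1)
batch_sum_reward = (reward * mask).flatten(start_dim=1).sum(dim=-1)
assert batch_sum_reward.tolist() == [3.0, 2.0]  # steps 0-2 and 0-1 are kept, respectively
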
Binary file not shown.
Binary file not shown.

@@ -1,4 +1,3 @@
-
 from omegaconf import open_dict
 import pytest
 from tensordict import TensorDict

@@ -16,35 +15,50 @@ from .utils import DEVICE, init_config


 @pytest.mark.parametrize(
-    "env_name,policy_name",
+    "env_name,policy_name,extra_overrides",
     [
-        ("simxarm", "tdmpc"),
-        ("pusht", "tdmpc"),
-        ("simxarm", "diffusion"),
-        ("pusht", "diffusion"),
+        ("simxarm", "tdmpc", ["policy.mpc=true"]),
+        ("pusht", "tdmpc", ["policy.mpc=false"]),
+        ("simxarm", "diffusion", []),
+        ("pusht", "diffusion", []),
+        ("aloha", "act", ["env.task=sim_insertion_scripted"]),
     ],
 )
-def test_factory(env_name, policy_name):
+def test_concrete_policy(env_name, policy_name, extra_overrides):
+    """
+    Tests:
+        - Making the policy object.
+        - Updating the policy.
+        - Using the policy to select actions at inference time.
+    """
     cfg = init_config(
         overrides=[
             f"env={env_name}",
             f"policy={policy_name}",
             f"device={DEVICE}",
         ]
+        + extra_overrides
     )
     # Check that we can make the policy object.
     policy = make_policy(cfg)
-    # Check that we run select_action and get the appropriate output.
+    # Check that we run select_actions and get the appropriate output.
     if env_name == "simxarm":
         # TODO(rcadene): Not implemented
         return
     if policy_name == "tdmpc":
         # TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this.
         with open_dict(cfg):
-            cfg['n_obs_steps'] = 1
+            cfg["n_obs_steps"] = 1
     offline_buffer = make_offline_buffer(cfg)
     env = make_env(cfg, transform=offline_buffer.transform)
-    policy.select_action(env.observation_spec.rand()['observation'].to(DEVICE), torch.tensor(0, device=DEVICE))
+    policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
+
+    action = policy(
+        env.observation_spec.rand()["observation"].to(DEVICE),
+        torch.tensor(0, device=DEVICE),
+    )
+    assert action.shape == env.action_spec.shape


 def test_abstract_policy_forward():

@@ -91,20 +105,19 @@ def test_abstract_policy_forward():
         def _set_seed(self, seed: int | None):
             return


     class StubPolicy(AbstractPolicy):
         def __init__(self):
-            super().__init__()
-            self.n_action_steps = n_action_steps
+            super().__init__(n_action_steps)
             self.n_policy_invocations = 0

         def update(self):
             pass

-        def select_action(self):
+        def select_actions(self):
             self.n_policy_invocations += 1
-            return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0)
+            return torch.stack(
+                [torch.tensor([i]) for i in range(self.n_action_steps)]
+            ).unsqueeze(0)

     env = StubEnv()
     policy = StubPolicy()

@@ -119,4 +132,4 @@ def test_abstract_policy_forward():

     assert len(rollout) == terminate_at + 1  # +1 for the reset observation
     assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
-    assert torch.equal(rollout['observation'].flatten(), torch.arange(terminate_at + 1))
+    assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))