backup wip

Alexander Soare 2024-03-19 18:50:04 +00:00
parent ea17f4ce50
commit 896a11f60e
16 changed files with 169 additions and 138 deletions

View File

@@ -50,7 +50,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
def stats_patterns(self) -> dict:
return {
("observation", "state"): "b c -> c",
("observation", "image"): "b c h w -> c",
("observation", "image"): "b c h w -> c 1 1",
("action",): "b c -> c",
}
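The revised einops pattern keeps singleton spatial dims on the reduced stats, so per-channel statistics broadcast directly against (c, h, w) image tensors without reshaping. A minimal sketch of the difference, assuming the stats are consumed with einops.reduce:

import torch
from einops import reduce

batch = torch.rand(8, 3, 96, 96)  # (b, c, h, w)
mean_flat = reduce(batch, "b c h w -> c", "mean")  # shape (3,): must be reshaped before normalizing images
mean_bcast = reduce(batch, "b c h w -> c 1 1", "mean")  # shape (3, 1, 1): broadcasts as-is
normalized = batch - mean_bcast  # per-channel normalization over h and w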

View File

@@ -117,7 +117,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
("action",): "b c -> c",
}
for cam in CAMERAS[self.dataset_id]:
d[("observation", "image", cam)] = "b c h w -> c"
d[("observation", "image", cam)] = "b c h w -> c 1 1"
return d
@property

View File

@@ -58,6 +58,7 @@ class AlohaEnv(AbstractEnv):
num_prev_obs=num_prev_obs,
num_prev_action=num_prev_action,
)
self._reset_warning_issued = False
def _make_env(self):
if not _has_gym:
@@ -120,47 +121,47 @@ class AlohaEnv(AbstractEnv):
return obs
def _reset(self, tensordict: Optional[TensorDict] = None):
td = tensordict
if td is None or td.is_empty():
# we need to handle seed iteration, since self._env.reset() relies on an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
if tensordict is not None and not self._reset_warning_issued:
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
self._reset_warning_issued = True
# TODO(rcadene): do not use global variable for this
if "sim_transfer_cube" in self.task:
BOX_POSE[0] = sample_box_pose() # used in sim reset
elif "sim_insertion" in self.task:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
# we need to handle seed iteration, since self._env.reset() relies on an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
raw_obs = self._env.reset()
# TODO(rcadene): add assert
# assert self._current_seed == self._env._seed
# TODO(rcadene): do not use global variable for this
if "sim_transfer_cube" in self.task:
BOX_POSE[0] = sample_box_pose() # used in sim reset
elif "sim_insertion" in self.task:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
obs = self._format_raw_obs(raw_obs.observation)
raw_obs = self._env.reset()
# TODO(rcadene): add assert
# assert self._current_seed == self._env._seed
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
obs = self._format_raw_obs(raw_obs.observation)
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
self.call_rendering_hooks()
return td
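The stacking logic above relies on a bounded deque that is pre-filled with copies of the first observation, so a full window of num_prev_obs + 1 frames exists from the very first step. A standalone sketch of the pattern:

from collections import deque
import torch

num_prev_obs = 2
first_frame = torch.zeros(3, 96, 96)  # stand-in for obs["image"]["top"]
queue = deque([first_frame] * (num_prev_obs + 1), maxlen=num_prev_obs + 1)
stacked = torch.stack(list(queue))  # (num_prev_obs + 1, 3, 96, 96)
# On subsequent steps the newest frame is appended and maxlen silently drops the oldest.
queue.append(torch.ones(3, 96, 96))
assert len(queue) == num_prev_obs + 1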

View File

@@ -1,31 +1,20 @@
from torchrl.envs import SerialEnv
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
def make_env(cfg, transform=None):
"""
Provide seed to override the seed in the cfg (useful for batched environments).
Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying
environments. The env therefore returns batches.
"""
# assert cfg.rollout_batch_size == 1, \
# """
# For the time being, rollout batch sizes of > 1 are not supported. This is because the SerialEnv rollout does not
# correctly handle terminated environments. If you really want to use a larger batch size, read on...
# When calling `EnvBase.rollout` with `break_when_any_done == True` all environments stop rolling out as soon as the
# first is terminated or truncated. This almost certainly results in incorrect success metrics, as all but the first
# environment are cut off before getting an opportunity to reach the goal. A possible workaround is to comment out
# `if any_done: break` in `EnvBase._rollout_stop_early`. One potential downside is that the environments' `step` function will continue
# to be called and the outputs will continue to be added to the rollout.
# When calling `EnvBase.rollout` with `break_when_any_done == False` environments are reset when done.
# """
kwargs = {
"frame_skip": cfg.env.action_repeat,
"from_pixels": cfg.env.from_pixels,
"pixels_only": cfg.env.pixels_only,
"image_size": cfg.env.image_size,
"num_prev_obs": cfg.n_obs_steps - 1,
"seed": cfg.seed,
"num_prev_obs": cfg.n_obs_steps - 1,
}
if cfg.env.name == "simxarm":
@@ -67,13 +56,14 @@ def make_env(cfg, transform=None):
return env
# return SerialEnv(
# cfg.rollout_batch_size,
# create_env_fn=_make_env,
# create_env_kwargs={
# "seed": env_seed for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
# },
# )
return SerialEnv(
cfg.rollout_batch_size,
create_env_fn=_make_env,
create_env_kwargs=[
{"seed": env_seed} # one kwargs dict per sub-env, so each gets its own seed
for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
],
)
# def make_env(env_name, frame_skip, device, is_test=False):
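A quick illustration of why the per-env kwargs above are built as a list of dicts rather than a dict comprehension: a dict comprehension with a constant key keeps only the last value (the bug ruff flags as B035), so every sub-env would have received the same seed.

seeds = range(1337, 1340)
collapsed = {"seed": s for s in seeds}  # noqa: B035 -- deliberately wrong, for illustration
assert collapsed == {"seed": 1339}  # a single dict: all workers share the last seed
per_env = [{"seed": s} for s in seeds]  # one kwargs dict per sub-env
assert per_env == [{"seed": 1337}, {"seed": 1338}, {"seed": 1339}]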

View File

@@ -1,4 +1,5 @@
import importlib
import logging
from collections import deque
from typing import Optional
@@ -42,6 +43,7 @@ class PushtEnv(AbstractEnv):
num_prev_obs=num_prev_obs,
num_prev_action=num_prev_action,
)
self._reset_warning_issued = False
def _make_env(self):
if not _has_gym:
@@ -79,39 +81,39 @@ class PushtEnv(AbstractEnv):
return obs
def _reset(self, tensordict: Optional[TensorDict] = None):
td = tensordict
if td is None or td.is_empty():
# we need to handle seed iteration, since self._env.reset() relies on an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
raw_obs = self._env.reset()
assert self._current_seed == self._env._seed
if tensordict is not None and not self._reset_warning_issued:
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
self._reset_warning_issued = True
obs = self._format_raw_obs(raw_obs)
# we need to handle seed iteration, since self._env.reset() relies on an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
raw_obs = self._env.reset()
assert self._current_seed == self._env._seed
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
obs = self._format_raw_obs(raw_obs)
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
self.call_rendering_hooks()
return td

View File

@@ -12,6 +12,17 @@ class AbstractPolicy(nn.Module, ABC):
documentation for more information.
"""
def __init__(self, n_action_steps: int | None):
"""
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that `select_actions`
returns a single action with no horizon dimension, and `forward` passes that action through directly.
"""
super().__init__()
self.n_action_steps = n_action_steps
if n_action_steps is not None:
self._action_queue = deque([], maxlen=n_action_steps)
@abstractmethod
def update(self, replay_buffer, step):
"""One step of the policy's learning algorithm."""
@@ -24,10 +35,11 @@ class AbstractPolicy(nn.Module, ABC):
self.load_state_dict(d)
@abstractmethod
def select_action(self, observation) -> Tensor:
def select_actions(self, observation) -> Tensor:
"""Select an action (or trajectory of actions) based on an observation during rollout.
Should return a (batch_size, n_action_steps, *) tensor of actions.
If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
actions. Otherwise, if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
"""
def forward(self, *args, **kwargs) -> Tensor:
@@ -41,18 +53,14 @@ class AbstractPolicy(nn.Module, ABC):
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
the subclass doesn't have to.
This method effectively wraps the `select_action` method of the subclass. The following assumptions are made:
1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
the action trajectory horizon and * is the action dimensions.
2. Prior to the `select_action` method being called, there is an `n_action_steps` instance attribute defined.
2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute defined.
"""
n_action_steps_attr = "n_action_steps"
if not hasattr(self, n_action_steps_attr):
raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
if not hasattr(self, "_action_queue"):
self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
if self.n_action_steps is None:
return self.select_actions(*args, **kwargs)
if len(self._action_queue) == 0:
# Each element in the queue has shape (B, *).
self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1))
self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
return self._action_queue.popleft()
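A toy illustration of the caching contract that forward relies on (the class below is hypothetical, not part of the codebase): select_actions plans a (B, n_action_steps, *) trajectory once, and subsequent forward calls drain the queue one timestep at a time.

from collections import deque

import torch

class ToyChunkingPolicy:
    def __init__(self, n_action_steps: int):
        self.n_action_steps = n_action_steps
        self._action_queue = deque([], maxlen=n_action_steps)
        self.n_plans = 0

    def select_actions(self) -> torch.Tensor:
        self.n_plans += 1
        # (B=1, H=n_action_steps, action_dim=1)
        return torch.arange(self.n_action_steps, dtype=torch.float32).reshape(1, -1, 1)

    def forward(self) -> torch.Tensor:
        if len(self._action_queue) == 0:
            # Transpose so each queued element is one timestep of shape (B, action_dim).
            self._action_queue.extend(self.select_actions().transpose(0, 1))
        return self._action_queue.popleft()

policy = ToyChunkingPolicy(n_action_steps=3)
actions = [policy.forward() for _ in range(6)]
assert policy.n_plans == 2  # the planner ran only twice for six environment steps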

View File

@@ -42,7 +42,7 @@ def kl_divergence(mu, logvar):
class ActionChunkingTransformerPolicy(AbstractPolicy):
def __init__(self, cfg, device, n_action_steps=1):
super().__init__()
super().__init__(n_action_steps)
self.cfg = cfg
self.n_action_steps = n_action_steps
self.device = device
@@ -147,7 +147,10 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
return loss
@torch.no_grad()
def select_action(self, observation, step_count):
def select_actions(self, observation, step_count):
if observation["image"].shape[0] != 1:
raise NotImplementedError("Batch size > 1 not handled")
# TODO(rcadene): remove unused step_count
del step_count

View File

@@ -34,7 +34,7 @@ class DiffusionPolicy(AbstractPolicy):
# parameters passed to step
**kwargs,
):
super().__init__()
super().__init__(n_action_steps)
self.cfg = cfg
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
@@ -44,7 +44,6 @@ class DiffusionPolicy(AbstractPolicy):
**cfg_obs_encoder,
)
self.n_action_steps = n_action_steps # needed for the parent class
self.diffusion = DiffusionUnetImagePolicy(
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
@@ -94,7 +93,7 @@ class DiffusionPolicy(AbstractPolicy):
)
@torch.no_grad()
def select_action(self, observation, step_count):
def select_actions(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count

View File

@@ -1,4 +1,7 @@
def make_policy(cfg):
if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
if cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPC

View File

@@ -90,7 +90,7 @@ class TDMPC(AbstractPolicy):
"""Implementation of TD-MPC learning + inference."""
def __init__(self, cfg, device):
super().__init__()
super().__init__(None)
self.action_dim = cfg.action_dim
self.cfg = cfg
@@ -125,7 +125,10 @@ class TDMPC(AbstractPolicy):
self.model_target.load_state_dict(d["model_target"])
@torch.no_grad()
def select_action(self, observation, step_count):
def select_actions(self, observation, step_count):
if observation["image"].shape[0] != 1:
raise NotImplementedError("Batch size > 1 not handled")
t0 = step_count.item() == 0
obs = {
@@ -133,7 +136,8 @@ class TDMPC(AbstractPolicy):
"rgb": observation["image"].contiguous(),
"state": observation["state"].contiguous(),
}
action = self.act(obs, t0=t0, step=self.step.item())
# Note: unsqueeze is needed because `act` still uses non-batched logic.
action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
return action
@torch.no_grad()
@@ -144,7 +148,7 @@ class TDMPC(AbstractPolicy):
if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step)
else:
a = self.model.pi(z, self.cfg.min_std * self.model.training)
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
return a
@torch.no_grad()
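The unsqueeze above, paired with the squeeze(0) inside `act`, adapts TDMPC's unbatched planning code to the (B, *) contract expected by AbstractPolicy.forward. With the batch size pinned to 1 it is just a leading-dim round trip, e.g.:

import torch

action_dim = 4  # illustrative value, not taken from the config
a = torch.zeros(1, action_dim).squeeze(0)  # what `act` produces: (action_dim,)
assert a.unsqueeze(0).shape == (1, action_dim)  # restored to the (B, *) contract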

View File

@@ -11,8 +11,7 @@ hydra:
seed: 1337
# batch size for TorchRL SerialEnv. Each underlying env gets seeded with seed + env_index
# NOTE: batch sizes of > 1 are not yet supported! This is just a placeholder for future support. See
# `lerobot.common.envs.factory.make_env` for more information.
# NOTE: only diffusion policy supports rollout_batch_size > 1
rollout_batch_size: 1
device: cuda # cpu
prefetch: 4
@@ -20,7 +19,7 @@ eval_freq: ???
save_freq: ???
eval_episodes: ???
save_video: false
save_model: true
save_model: false
save_buffer: false
train_steps: ???
fps: ???
@@ -33,7 +32,7 @@ env: ???
policy: ???
wandb:
enable: false
enable: true
# Set to true to disable saving an artifact despite save_model == True
disable_artifact: false
project: lerobot

View File

@@ -22,8 +22,8 @@ keypoint_visible_rate: 1.0
obs_as_global_cond: True
eval_episodes: 1
eval_freq: 5000
save_freq: 5000
eval_freq: 10000
save_freq: 100000
log_freq: 250
offline_steps: 1344000

View File

@@ -51,16 +51,25 @@ def eval_policy(
ep_frames.append(env.render()) # noqa: B023
with torch.inference_mode():
# TODO(alexander-soare): Due to `break_when_any_done == False` this rolls out for max_steps even when all
# envs are done the first time. But we only use each rollout up to the first done, so the rest is wasted compute.
rollout = env.rollout(
max_steps=max_steps,
policy=policy,
auto_cast_to_device=True,
callback=maybe_render_frame,
break_when_any_done=False,
)
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
batch_sum_reward = rollout["next", "reward"].flatten(start_dim=1).sum(dim=-1)
batch_max_reward = rollout["next", "reward"].flatten(start_dim=1).max(dim=-1)[0]
batch_success = rollout["next", "success"].flatten(start_dim=1).any(dim=-1)
# Figure out where in each rollout sequence the first done condition was encountered (results after this won't
# be included).
# Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
rollout_steps = rollout["next", "done"].shape[1]
done_indices = torch.argmax(rollout["next", "done"].to(int), dim=1) # (batch_size, 1)
mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1)
batch_sum_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).sum(dim=-1)
batch_max_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).max(dim=-1)[0]
batch_success = (rollout["next", "success"] * mask).flatten(start_dim=1).any(dim=-1)
sum_rewards.extend(batch_sum_reward.tolist())
max_rewards.extend(batch_max_reward.tolist())
successes.extend(batch_success.tolist())
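The masking logic can be checked in isolation: argmax over a 0/1 done tensor returns the index of the first done (first occurrence breaks ties), and rewards after that step are zeroed out. A self-contained sketch:

import torch

done = torch.tensor([[0, 0, 1, 0, 1]]).unsqueeze(-1)  # (batch_size=1, rollout_steps=5, 1)
reward = torch.tensor([[1.0, 1.0, 1.0, 9.0, 9.0]]).unsqueeze(-1)
rollout_steps = done.shape[1]
done_indices = torch.argmax(done.to(int), dim=1)  # (1, 1): first done at t=2
mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1)  # (1, 5, 1)
assert (reward * mask).flatten(start_dim=1).sum(dim=-1).item() == 3.0  # rewards after the first done are excluded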

Binary file not shown.

View File

@@ -1,4 +1,3 @@
from omegaconf import open_dict
import pytest
from tensordict import TensorDict
@@ -16,35 +15,50 @@ from .utils import DEVICE, init_config
@pytest.mark.parametrize(
"env_name,policy_name",
"env_name,policy_name,extra_overrides",
[
("simxarm", "tdmpc"),
("pusht", "tdmpc"),
("simxarm", "diffusion"),
("pusht", "diffusion"),
("simxarm", "tdmpc", ["policy.mpc=true"]),
("pusht", "tdmpc", ["policy.mpc=false"]),
("simxarm", "diffusion", []),
("pusht", "diffusion", []),
("aloha", "act", ["env.task=sim_insertion_scripted"]),
],
)
def test_factory(env_name, policy_name):
def test_concrete_policy(env_name, policy_name, extra_overrides):
"""
Tests:
- Making the policy object.
- Updating the policy.
- Using the policy to select actions at inference time.
"""
cfg = init_config(
overrides=[
f"env={env_name}",
f"policy={policy_name}",
f"device={DEVICE}",
]
+ extra_overrides
)
# Check that we can make the policy object.
policy = make_policy(cfg)
# Check that we run select_action and get the appropriate output.
# Check that we run select_actions and get the appropriate output.
if env_name == "simxarm":
# TODO(rcadene): Not implemented
return
if policy_name == "tdmpc":
# TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this.
with open_dict(cfg):
cfg['n_obs_steps'] = 1
cfg["n_obs_steps"] = 1
offline_buffer = make_offline_buffer(cfg)
env = make_env(cfg, transform=offline_buffer.transform)
policy.select_action(env.observation_spec.rand()['observation'].to(DEVICE), torch.tensor(0, device=DEVICE))
policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
action = policy(
env.observation_spec.rand()["observation"].to(DEVICE),
torch.tensor(0, device=DEVICE),
)
assert action.shape == env.action_spec.shape
def test_abstract_policy_forward():
@@ -90,21 +104,20 @@ def test_abstract_policy_forward():
def _set_seed(self, seed: int | None):
return
class StubPolicy(AbstractPolicy):
def __init__(self):
super().__init__()
self.n_action_steps = n_action_steps
super().__init__(n_action_steps)
self.n_policy_invocations = 0
def update(self):
pass
def select_action(self):
def select_actions(self):
self.n_policy_invocations += 1
return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0)
return torch.stack(
[torch.tensor([i]) for i in range(self.n_action_steps)]
).unsqueeze(0)
env = StubEnv()
policy = StubPolicy()
@@ -119,4 +132,4 @@ def test_abstract_policy_forward():
assert len(rollout) == terminate_at + 1 # +1 for the reset observation
assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
assert torch.equal(rollout['observation'].flatten(), torch.arange(terminate_at + 1))
assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))