backup wip

Alexander Soare 2024-03-19 18:50:04 +00:00
parent ea17f4ce50
commit 896a11f60e
16 changed files with 169 additions and 138 deletions

View File

@@ -50,7 +50,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
     def stats_patterns(self) -> dict:
         return {
             ("observation", "state"): "b c -> c",
-            ("observation", "image"): "b c h w -> c",
+            ("observation", "image"): "b c h w -> c 1 1",
             ("action",): "b c -> c",
         }
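The substantive change here is the image reduce pattern: "b c h w -> c 1 1" keeps singleton spatial dimensions on the per-channel statistics so they broadcast directly against (c, h, w) image tensors. A minimal illustrative sketch of the effect (not part of the diff):

    # Illustrative only: per-channel stats with a broadcast-friendly shape.
    import torch
    from einops import reduce

    images = torch.rand(32, 3, 96, 96)                 # (b, c, h, w)
    mean = reduce(images, "b c h w -> c 1 1", "mean")  # shape (3, 1, 1) rather than (3,)
    normalized = images - mean                         # broadcasts over (b, c, h, w) with no reshaping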

View File

@@ -117,7 +117,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
             ("action",): "b c -> c",
         }
         for cam in CAMERAS[self.dataset_id]:
-            d[("observation", "image", cam)] = "b c h w -> c"
+            d[("observation", "image", cam)] = "b c h w -> c 1 1"
         return d

     @property

View File

@@ -58,6 +58,7 @@ class AlohaEnv(AbstractEnv):
             num_prev_obs=num_prev_obs,
             num_prev_action=num_prev_action,
         )
+        self._reset_warning_issued = False

     def _make_env(self):
         if not _has_gym:
@@ -120,8 +121,10 @@ class AlohaEnv(AbstractEnv):
         return obs

     def _reset(self, tensordict: Optional[TensorDict] = None):
-        td = tensordict
-        if td is None or td.is_empty():
+        if tensordict is not None and not self._reset_warning_issued:
+            logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
+            self._reset_warning_issued = True
         # we need to handle seed iteration, since self._env.reset() rely an internal _seed.
         self._current_seed += 1
         self.set_seed(self._current_seed)
@@ -159,8 +162,6 @@ class AlohaEnv(AbstractEnv):
             },
             batch_size=[],
         )
-        else:
-            raise NotImplementedError()

         self.call_rendering_hooks()
         return td

View File

@@ -1,31 +1,20 @@
+from torchrl.envs import SerialEnv
 from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv


 def make_env(cfg, transform=None):
     """
-    Provide seed to override the seed in the cfg (useful for batched environments).
+    Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying
+    environments. The env therefore returns batches.
     """
-    # assert cfg.rollout_batch_size == 1, \
-    #     """
-    #     For the time being, rollout batch sizes of > 1 are not supported. This is because the SerialEnv rollout does not
-    #     correctly handle terminated environments. If you really want to use a larger batch size, read on...
-    #     When calling `EnvBase.rollout` with `break_when_any_done == True` all environments stop rolling out as soon as the
-    #     first is terminated or truncated. This almost certainly results in incorrect success metrics, as all but the first
-    #     environment get an opportunity to reach the goal. A possible work around is to comment out `if any_done: break`
-    #     inf `EnvBase._rollout_stop_early`. One potential downside is that the environments `step` function will continue
-    #     to be called and the outputs will continue to be added to the rollout.
-    #     When calling `EnvBase.rollout` with `break_when_any_done == False` environments are reset when done.
-    #     """
     kwargs = {
         "frame_skip": cfg.env.action_repeat,
         "from_pixels": cfg.env.from_pixels,
         "pixels_only": cfg.env.pixels_only,
         "image_size": cfg.env.image_size,
-        "num_prev_obs": cfg.n_obs_steps - 1,
         "seed": cfg.seed,
+        "num_prev_obs": cfg.n_obs_steps - 1,
     }

     if cfg.env.name == "simxarm":
@@ -67,13 +56,14 @@ def make_env(cfg, transform=None):
         return env

-    # return SerialEnv(
-    #     cfg.rollout_batch_size,
-    #     create_env_fn=_make_env,
-    #     create_env_kwargs={
-    #         "seed": env_seed for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
-    #     },
-    # )
+    return SerialEnv(
+        cfg.rollout_batch_size,
+        create_env_fn=_make_env,
+        create_env_kwargs={
+            "seed": env_seed  # noqa: B035
+            for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
+        },
+    )


 # def make_env(env_name, frame_skip, device, is_test=False):
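For context, a hedged sketch (not from the repository) of the batched-environment pattern enabled above: SerialEnv runs cfg.rollout_batch_size sub-environments and, assuming torchrl's option of passing one kwargs dict per sub-environment, each one can receive its own seed. The factory name make_single_env is hypothetical.

    # Illustrative sketch only. Assumes a gym-style factory `make_single_env(seed=...)`
    # (hypothetical) and that SerialEnv accepts a list of per-env kwargs dicts.
    from torchrl.envs import GymEnv, SerialEnv

    rollout_batch_size = 4
    base_seed = 1337

    def make_single_env(seed: int = 0):
        env = GymEnv("Pendulum-v1")
        env.set_seed(seed)
        return env

    batched_env = SerialEnv(
        rollout_batch_size,
        create_env_fn=make_single_env,
        create_env_kwargs=[{"seed": base_seed + i} for i in range(rollout_batch_size)],
    )
    # Rollouts from batched_env now carry a leading batch dimension of size rollout_batch_size.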

View File

@@ -1,4 +1,5 @@
 import importlib
+import logging
 from collections import deque
 from typing import Optional
@@ -42,6 +43,7 @@ class PushtEnv(AbstractEnv):
             num_prev_obs=num_prev_obs,
             num_prev_action=num_prev_action,
         )
+        self._reset_warning_issued = False

     def _make_env(self):
         if not _has_gym:
@@ -79,8 +81,10 @@ class PushtEnv(AbstractEnv):
         return obs

     def _reset(self, tensordict: Optional[TensorDict] = None):
-        td = tensordict
-        if td is None or td.is_empty():
+        if tensordict is not None and not self._reset_warning_issued:
+            logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
+            self._reset_warning_issued = True
         # we need to handle seed iteration, since self._env.reset() rely an internal _seed.
         self._current_seed += 1
         self.set_seed(self._current_seed)
@@ -110,8 +114,6 @@ class PushtEnv(AbstractEnv):
             },
             batch_size=[],
         )
-        else:
-            raise NotImplementedError()

         self.call_rendering_hooks()
         return td

View File

@@ -12,6 +12,17 @@ class AbstractPolicy(nn.Module, ABC):
     documentation for more information.
     """

+    def __init__(self, n_action_steps: int | None):
+        """
+        n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
+        action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
+        adds that dimension.
+        """
+        super().__init__()
+        self.n_action_steps = n_action_steps
+        if n_action_steps is not None:
+            self._action_queue = deque([], maxlen=n_action_steps)
+
     @abstractmethod
     def update(self, replay_buffer, step):
         """One step of the policy's learning algorithm."""
@@ -24,10 +35,11 @@ class AbstractPolicy(nn.Module, ABC):
         self.load_state_dict(d)

     @abstractmethod
-    def select_action(self, observation) -> Tensor:
+    def select_actions(self, observation) -> Tensor:
         """Select an action (or trajectory of actions) based on an observation during rollout.

-        Should return a (batch_size, n_action_steps, *) tensor of actions.
+        If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
+        actions. Otherwise if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
         """

     def forward(self, *args, **kwargs) -> Tensor:
@@ -41,18 +53,14 @@ class AbstractPolicy(nn.Module, ABC):
        observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
        the subclass doesn't have to.

-        This method effectively wraps the `select_action` method of the subclass. The following assumptions are made:
-        1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
+        This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
+        1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
            the action trajectory horizon and * is the action dimensions.
-        2. Prior to the `select_action` method being called, theres is an `n_action_steps` instance attribute defined.
+        2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined.
        """
-        n_action_steps_attr = "n_action_steps"
-        if not hasattr(self, n_action_steps_attr):
-            raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
-        if not hasattr(self, "_action_queue"):
-            self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
+        if self.n_action_steps is None:
+            return self.select_actions(*args, **kwargs)
        if len(self._action_queue) == 0:
            # Each element in the queue has shape (B, *).
-            self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1))
+            self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
        return self._action_queue.popleft()
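To make the new control flow concrete, here is a small standalone sketch (illustrative, not from the repository) of how forward serves actions from the queue and only queries select_actions when the queue is empty:

    # Illustrative only: mimics AbstractPolicy.forward's queueing behaviour.
    from collections import deque

    import torch

    n_action_steps = 3
    queue = deque([], maxlen=n_action_steps)

    def select_actions():
        # Pretend the policy predicts a trajectory of 3 actions for a batch of 2.
        return torch.arange(2 * 3, dtype=torch.float32).reshape(2, 3, 1)

    def forward():
        if len(queue) == 0:
            # transpose(0, 1) turns (B, H, *) into H elements of shape (B, *).
            queue.extend(select_actions().transpose(0, 1))
        return queue.popleft()

    for step in range(4):
        print(step, forward().shape)  # torch.Size([2, 1]); select_actions runs at steps 0 and 3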

View File

@@ -42,7 +42,7 @@ def kl_divergence(mu, logvar):
 class ActionChunkingTransformerPolicy(AbstractPolicy):
     def __init__(self, cfg, device, n_action_steps=1):
-        super().__init__()
+        super().__init__(n_action_steps)
         self.cfg = cfg
         self.n_action_steps = n_action_steps
         self.device = device
@@ -147,7 +147,10 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
         return loss

     @torch.no_grad()
-    def select_action(self, observation, step_count):
+    def select_actions(self, observation, step_count):
+        if observation["image"].shape[0] != 1:
+            raise NotImplementedError("Batch size > 1 not handled")
+
         # TODO(rcadene): remove unused step_count
         del step_count

View File

@@ -34,7 +34,7 @@ class DiffusionPolicy(AbstractPolicy):
         # parameters passed to step
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(n_action_steps)
         self.cfg = cfg

         noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
@@ -44,7 +44,6 @@ class DiffusionPolicy(AbstractPolicy):
             **cfg_obs_encoder,
         )

-        self.n_action_steps = n_action_steps  # needed for the parent class
         self.diffusion = DiffusionUnetImagePolicy(
             shape_meta=shape_meta,
             noise_scheduler=noise_scheduler,
@@ -94,7 +93,7 @@ class DiffusionPolicy(AbstractPolicy):
         )

     @torch.no_grad()
-    def select_action(self, observation, step_count):
+    def select_actions(self, observation, step_count):
         # TODO(rcadene): remove unused step_count
         del step_count

View File

@@ -1,4 +1,7 @@
 def make_policy(cfg):
+    if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
+        raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
+
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPC

View File

@@ -90,7 +90,7 @@ class TDMPC(AbstractPolicy):
     """Implementation of TD-MPC learning + inference."""

     def __init__(self, cfg, device):
-        super().__init__()
+        super().__init__(None)
         self.action_dim = cfg.action_dim

         self.cfg = cfg
@@ -125,7 +125,10 @@ class TDMPC(AbstractPolicy):
         self.model_target.load_state_dict(d["model_target"])

     @torch.no_grad()
-    def select_action(self, observation, step_count):
+    def select_actions(self, observation, step_count):
+        if observation["image"].shape[0] != 1:
+            raise NotImplementedError("Batch size > 1 not handled")
+
         t0 = step_count.item() == 0

         obs = {
@@ -133,7 +136,8 @@ class TDMPC(AbstractPolicy):
             "rgb": observation["image"].contiguous(),
             "state": observation["state"].contiguous(),
         }
-        action = self.act(obs, t0=t0, step=self.step.item())
+        # Note: unsqueeze needed because `act` still uses non-batch logic.
+        action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
         return action

     @torch.no_grad()
@@ -144,7 +148,7 @@ class TDMPC(AbstractPolicy):
         if self.cfg.mpc:
             a = self.plan(z, t0=t0, step=step)
         else:
-            a = self.model.pi(z, self.cfg.min_std * self.model.training)
+            a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
         return a

     @torch.no_grad()
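Since TDMPC passes n_action_steps=None to the parent class, AbstractPolicy.forward returns the output of select_actions directly, so it must already carry a leading batch dimension. A tiny sketch of the shape bookkeeping behind the added unsqueeze(0)/squeeze(0) (illustrative; action_dim=4 is an assumption):

    # Illustrative only: shapes on the single-environment TDMPC path.
    import torch

    action_dim = 4
    single_action = torch.zeros(action_dim)      # what the unbatched `act` path produces: (action_dim,)
    batched_action = single_action.unsqueeze(0)  # (1, action_dim), the shape forward() hands back to the env
    assert batched_action.shape == (1, action_dim)
    assert batched_action.squeeze(0).shape == (action_dim,)  # squeeze(0) recovers the unbatched shape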

View File

@@ -11,8 +11,7 @@ hydra:
 seed: 1337
 # batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
-# NOTE: batch size of 1 is not yet supported! This is just a placeholder for future support. See
-# `lerobot.common.envs.factory.make_env` for more information.
+# NOTE: only diffusion policy supports rollout_batch_size > 1
 rollout_batch_size: 1
 device: cuda  # cpu
 prefetch: 4
@@ -20,7 +19,7 @@ eval_freq: ???
 save_freq: ???
 eval_episodes: ???
 save_video: false
-save_model: true
+save_model: false
 save_buffer: false
 train_steps: ???
 fps: ???
@@ -33,7 +32,7 @@ env: ???
 policy: ???

 wandb:
-  enable: false
+  enable: true
   # Set to true to disable saving an artifact despite save_model == True
   disable_artifact: false
   project: lerobot

View File

@@ -22,8 +22,8 @@ keypoint_visible_rate: 1.0
 obs_as_global_cond: True

 eval_episodes: 1
-eval_freq: 5000
-save_freq: 5000
+eval_freq: 10000
+save_freq: 100000
 log_freq: 250

 offline_steps: 1344000

View File

@@ -51,16 +51,25 @@ def eval_policy(
                 ep_frames.append(env.render())  # noqa: B023

         with torch.inference_mode():
+            # TODO(alexander-soare): Due the `break_when_any_done == False` this rolls out for max_steps even when all
+            # envs are done the first time. But we only use the first rollout. This is a waste of compute.
             rollout = env.rollout(
                 max_steps=max_steps,
                 policy=policy,
                 auto_cast_to_device=True,
                 callback=maybe_render_frame,
+                break_when_any_done=False,
             )
-        # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
-        batch_sum_reward = rollout["next", "reward"].flatten(start_dim=1).sum(dim=-1)
-        batch_max_reward = rollout["next", "reward"].flatten(start_dim=1).max(dim=-1)[0]
-        batch_success = rollout["next", "success"].flatten(start_dim=1).any(dim=-1)
+        # Figure out where in each rollout sequence the first done condition was encountered (results after this won't
+        # be included).
+        # Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
+        # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
+        rollout_steps = rollout["next", "done"].shape[1]
+        done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1)  # (batch_size, rollout_steps)
+        mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1)  # (batch_size, rollout_steps, 1)
+        batch_sum_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).sum(dim=-1)
+        batch_max_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).max(dim=-1)[0]
+        batch_success = (rollout["next", "success"] * mask).flatten(start_dim=1).any(dim=-1)
         sum_rewards.extend(batch_sum_reward.tolist())
         max_rewards.extend(batch_max_reward.tolist())
         successes.extend(batch_success.tolist())
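The first-done masking above can be seen in isolation with a toy example (illustrative only); it hinges on torch.argmax returning the index of the first occurrence of the maximum, i.e. the first step where done is 1:

    # Illustrative only: mask out everything after the first done flag per rollout.
    import torch

    # Fake done flags for 2 rollouts of 5 steps, shape (batch_size, steps, 1).
    done = torch.tensor([[0, 0, 1, 0, 1],
                         [0, 0, 0, 0, 1]]).unsqueeze(-1)
    reward = torch.ones(2, 5, 1)

    done_indices = torch.argmax(done.to(int), dim=1)         # first 1 per rollout: [[2], [4]]
    mask = (torch.arange(5) <= done_indices).unsqueeze(-1)   # (2, 5, 1) boolean mask
    print((reward * mask).flatten(start_dim=1).sum(dim=-1))  # tensor([3., 5.])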

Binary file not shown.

View File

@@ -1,4 +1,3 @@
-from omegaconf import open_dict
 import pytest
 from tensordict import TensorDict
@@ -16,35 +15,50 @@ from .utils import DEVICE, init_config

 @pytest.mark.parametrize(
-    "env_name,policy_name",
+    "env_name,policy_name,extra_overrides",
     [
-        ("simxarm", "tdmpc"),
-        ("pusht", "tdmpc"),
-        ("simxarm", "diffusion"),
-        ("pusht", "diffusion"),
+        ("simxarm", "tdmpc", ["policy.mpc=true"]),
+        ("pusht", "tdmpc", ["policy.mpc=false"]),
+        ("simxarm", "diffusion", []),
+        ("pusht", "diffusion", []),
+        ("aloha", "act", ["env.task=sim_insertion_scripted"]),
     ],
 )
-def test_factory(env_name, policy_name):
+def test_concrete_policy(env_name, policy_name, extra_overrides):
+    """
+    Tests:
+        - Making the policy object.
+        - Updating the policy.
+        - Using the policy to select actions at inference time.
+    """
     cfg = init_config(
         overrides=[
             f"env={env_name}",
             f"policy={policy_name}",
             f"device={DEVICE}",
         ]
+        + extra_overrides
     )
     # Check that we can make the policy object.
     policy = make_policy(cfg)
-    # Check that we run select_action and get the appropriate output.
+    # Check that we run select_actions and get the appropriate output.
     if env_name == "simxarm":
         # TODO(rcadene): Not implemented
         return
     if policy_name == "tdmpc":
         # TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this.
         with open_dict(cfg):
-            cfg['n_obs_steps'] = 1
+            cfg["n_obs_steps"] = 1
     offline_buffer = make_offline_buffer(cfg)
     env = make_env(cfg, transform=offline_buffer.transform)
-    policy.select_action(env.observation_spec.rand()['observation'].to(DEVICE), torch.tensor(0, device=DEVICE))
+
+    policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
+
+    action = policy(
+        env.observation_spec.rand()["observation"].to(DEVICE),
+        torch.tensor(0, device=DEVICE),
+    )
+    assert action.shape == env.action_spec.shape


 def test_abstract_policy_forward():
@@ -91,20 +105,19 @@ def test_abstract_policy_forward():
         def _set_seed(self, seed: int | None):
             return

     class StubPolicy(AbstractPolicy):
         def __init__(self):
-            super().__init__()
-            self.n_action_steps = n_action_steps
+            super().__init__(n_action_steps)
             self.n_policy_invocations = 0

         def update(self):
             pass

-        def select_action(self):
+        def select_actions(self):
             self.n_policy_invocations += 1
-            return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0)
+            return torch.stack(
+                [torch.tensor([i]) for i in range(self.n_action_steps)]
+            ).unsqueeze(0)

     env = StubEnv()
     policy = StubPolicy()
@@ -119,4 +132,4 @@ def test_abstract_policy_forward():
     assert len(rollout) == terminate_at + 1  # +1 for the reset observation
     assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
-    assert torch.equal(rollout['observation'].flatten(), torch.arange(terminate_at + 1))
+    assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))