Remove latency, tdmpc policy passes tests (TODO: make it work with online RL)
This commit is contained in:
parent
44656d2706
commit
4371a5570d
|
@ -37,7 +37,7 @@ policy = DiffusionPolicy(
|
|||
cfg_obs_encoder=cfg.obs_encoder,
|
||||
cfg_optimizer=cfg.optimizer,
|
||||
cfg_ema=cfg.ema,
|
||||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
n_action_steps=cfg.n_action_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
policy.train()
|
||||
|
|
|
@ -78,15 +78,11 @@ def make_dataset(
|
|||
]
|
||||
)
|
||||
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): implement delta_timestamps in config
|
||||
delta_timestamps = {
|
||||
"observation.image": [-0.1, 0],
|
||||
"observation.state": [-0.1, 0],
|
||||
"action": [-0.1] + [i / clsfunc.fps for i in range(15)],
|
||||
}
|
||||
else:
|
||||
delta_timestamps = None
|
||||
delta_timestamps = cfg.policy.get("delta_timestamps")
|
||||
if delta_timestamps is not None:
|
||||
for key in delta_timestamps:
|
||||
if isinstance(delta_timestamps[key], str):
|
||||
delta_timestamps[key] = eval(delta_timestamps[key])
|
||||
|
||||
dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
def make_policy(cfg):
|
||||
if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
|
||||
raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
|
||||
policy = TDMPCPolicy(cfg.policy, cfg.device)
|
||||
policy = TDMPCPolicy(
|
||||
cfg.policy, n_obs_steps=cfg.n_obs_steps, n_action_steps=cfg.n_action_steps, device=cfg.device
|
||||
)
|
||||
elif cfg.policy.name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
|
||||
|
||||
|
@ -17,14 +16,18 @@ def make_policy(cfg):
|
|||
cfg_obs_encoder=cfg.obs_encoder,
|
||||
cfg_optimizer=cfg.optimizer,
|
||||
cfg_ema=cfg.ema,
|
||||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
n_obs_steps=cfg.n_obs_steps,
|
||||
n_action_steps=cfg.n_action_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
elif cfg.policy.name == "act":
|
||||
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
|
||||
|
||||
policy = ActionChunkingTransformerPolicy(
|
||||
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
|
||||
cfg.policy,
|
||||
cfg.device,
|
||||
n_obs_steps=cfg.n_obs_steps,
|
||||
n_action_steps=cfg.n_action_steps,
|
||||
)
|
||||
else:
|
||||
raise ValueError(cfg.policy.name)
|
||||
|
|
|
@ -154,8 +154,14 @@ class TDMPCPolicy(nn.Module):
|
|||
if len(self._queues["action"]) == 0:
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
|
||||
if self.n_obs_steps == 1:
|
||||
# hack to remove the time dimension
|
||||
for key in batch:
|
||||
assert batch[key].shape[1] == 1
|
||||
batch[key] = batch[key][:, 0]
|
||||
|
||||
actions = []
|
||||
batch_size = batch["observation.image."].shape[0]
|
||||
batch_size = batch["observation.image"].shape[0]
|
||||
for i in range(batch_size):
|
||||
obs = {
|
||||
"rgb": batch["observation.image"][[i]],
|
||||
|
@ -166,6 +172,10 @@ class TDMPCPolicy(nn.Module):
|
|||
actions.append(action)
|
||||
action = torch.stack(actions)
|
||||
|
||||
# self.act returns an action for 1 timestep only, so we copy it over `n_action_steps` time
|
||||
if i in range(self.n_action_steps):
|
||||
self._queues["action"].append(action)
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
|
@ -410,22 +420,45 @@ class TDMPCPolicy(nn.Module):
|
|||
# idxs = torch.cat([idxs, demo_idxs])
|
||||
# weights = torch.cat([weights, demo_weights])
|
||||
|
||||
# TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
|
||||
# instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
|
||||
# batch size b = 256, time/horizon t = 5
|
||||
# b t ... -> t b ...
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
action = batch["action"]
|
||||
reward = batch["next.reward"][:, :, None] # add extra channel dimension
|
||||
# idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
|
||||
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
|
||||
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
|
||||
weights = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
|
||||
|
||||
obses = {
|
||||
"rgb": batch["observation.image"],
|
||||
"state": batch["observation.state"],
|
||||
}
|
||||
|
||||
shapes = {}
|
||||
for k in obses:
|
||||
shapes[k] = obses[k].shape
|
||||
obses[k] = einops.rearrange(obses[k], "t b ... -> (t b) ... ")
|
||||
|
||||
# Apply augmentations
|
||||
aug_tf = h.aug(self.cfg)
|
||||
obs = aug_tf(obs)
|
||||
obses = aug_tf(obses)
|
||||
|
||||
for k in next_obses:
|
||||
next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...")
|
||||
next_obses = aug_tf(next_obses)
|
||||
for k in next_obses:
|
||||
next_obses[k] = einops.rearrange(
|
||||
next_obses[k],
|
||||
"(h t) ... -> h t ...",
|
||||
h=self.cfg.horizon,
|
||||
t=self.cfg.batch_size,
|
||||
)
|
||||
for k in obses:
|
||||
t, b = shapes[k][:2]
|
||||
obses[k] = einops.rearrange(obses[k], "(t b) ... -> t b ... ", b=b, t=t)
|
||||
|
||||
horizon = self.cfg.horizon
|
||||
obs, next_obses = {}, {}
|
||||
for k in obses:
|
||||
obs[k] = obses[k][0]
|
||||
next_obses[k] = obses[k][1:].clone()
|
||||
|
||||
horizon = next_obses["rgb"].shape[0]
|
||||
loss_mask = torch.ones_like(mask, device=self.device)
|
||||
for t in range(1, horizon):
|
||||
loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])
|
||||
|
@ -497,19 +530,19 @@ class TDMPCPolicy(nn.Module):
|
|||
)
|
||||
self.optim.step()
|
||||
|
||||
if self.cfg.per:
|
||||
# Update priorities
|
||||
priorities = priority_loss.clamp(max=1e4).detach()
|
||||
has_nan = torch.isnan(priorities).any().item()
|
||||
if has_nan:
|
||||
print(f"priorities has nan: {priorities=}")
|
||||
else:
|
||||
replay_buffer.update_priority(
|
||||
idxs[:num_slices],
|
||||
priorities[:num_slices],
|
||||
)
|
||||
if demo_batch_size > 0:
|
||||
demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
|
||||
# if self.cfg.per:
|
||||
# # Update priorities
|
||||
# priorities = priority_loss.clamp(max=1e4).detach()
|
||||
# has_nan = torch.isnan(priorities).any().item()
|
||||
# if has_nan:
|
||||
# print(f"priorities has nan: {priorities=}")
|
||||
# else:
|
||||
# replay_buffer.update_priority(
|
||||
# idxs[:num_slices],
|
||||
# priorities[:num_slices],
|
||||
# )
|
||||
# if demo_batch_size > 0:
|
||||
# demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
|
||||
|
||||
# Update policy + target network
|
||||
_, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
|
||||
|
@ -532,7 +565,7 @@ class TDMPCPolicy(nn.Module):
|
|||
"data_s": data_s,
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
info["demo_batch_size"] = demo_batch_size
|
||||
# info["demo_batch_size"] = demo_batch_size
|
||||
info["expectile"] = expectile
|
||||
info.update(value_info)
|
||||
info.update(pi_update_info)
|
||||
|
|
|
@ -10,7 +10,6 @@ log_freq: 250
|
|||
|
||||
horizon: 100
|
||||
n_obs_steps: 1
|
||||
n_latency_steps: 0
|
||||
# when temporal_agg=False, n_action_steps=horizon
|
||||
n_action_steps: ${horizon}
|
||||
|
||||
|
|
|
@ -16,7 +16,6 @@ seed: 100000
|
|||
horizon: 16
|
||||
n_obs_steps: 2
|
||||
n_action_steps: 8
|
||||
n_latency_steps: 0
|
||||
dataset_obs_steps: ${n_obs_steps}
|
||||
past_action_visible: False
|
||||
keypoint_visible_rate: 1.0
|
||||
|
@ -38,7 +37,6 @@ policy:
|
|||
shape_meta: ${shape_meta}
|
||||
|
||||
horizon: ${horizon}
|
||||
# n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
||||
n_obs_steps: ${n_obs_steps}
|
||||
num_inference_steps: 100
|
||||
obs_as_global_cond: ${obs_as_global_cond}
|
||||
|
@ -64,6 +62,11 @@ policy:
|
|||
lr_warmup_steps: 500
|
||||
grad_clip_norm: 10
|
||||
|
||||
delta_timestamps:
|
||||
observation.image: [-.1, 0]
|
||||
observation.state: [-.1, 0]
|
||||
action: [-.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4]
|
||||
|
||||
noise_scheduler:
|
||||
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
||||
num_train_timesteps: 100
|
||||
|
|
|
@ -77,3 +77,9 @@ policy:
|
|||
num_q: 5
|
||||
mlp_dim: 512
|
||||
latent_dim: 50
|
||||
|
||||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(6)]"
|
||||
observation.state: "[i / ${fps} for i in range(6)]"
|
||||
action: "[i / ${fps} for i in range(5)]"
|
||||
next.reward: "[i / ${fps} for i in range(5)]"
|
||||
|
|
|
@ -1,14 +1,11 @@
|
|||
import pytest
|
||||
from tensordict import TensorDict
|
||||
from tensordict.nn import TensorDictModule
|
||||
import torch
|
||||
from torchrl.data import UnboundedContinuousTensorSpec
|
||||
from torchrl.envs import EnvBase
|
||||
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||
|
||||
|
@ -16,22 +13,23 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
|||
"env_name,policy_name,extra_overrides",
|
||||
[
|
||||
("simxarm", "tdmpc", ["policy.mpc=true"]),
|
||||
("pusht", "tdmpc", ["policy.mpc=false"]),
|
||||
#("pusht", "tdmpc", ["policy.mpc=false"]),
|
||||
("pusht", "diffusion", []),
|
||||
("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
|
||||
("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
|
||||
("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
|
||||
("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
|
||||
# ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
|
||||
#("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
|
||||
#("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
|
||||
#("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
|
||||
# TODO(aliberts): simxarm not working with diffusion
|
||||
# ("simxarm", "diffusion", []),
|
||||
],
|
||||
)
|
||||
def test_concrete_policy(env_name, policy_name, extra_overrides):
|
||||
def test_policy(env_name, policy_name, extra_overrides):
|
||||
"""
|
||||
Tests:
|
||||
- Making the policy object.
|
||||
- Updating the policy.
|
||||
- Using the policy to select actions at inference time.
|
||||
- Test the action can be applied to the policy
|
||||
"""
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
|
@ -46,91 +44,43 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
|
|||
policy = make_policy(cfg)
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
dataset = make_dataset(cfg)
|
||||
env = make_env(cfg, transform=dataset.transform)
|
||||
env = make_env(cfg, num_parallel_envs=2)
|
||||
|
||||
if env_name != "aloha":
|
||||
# TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
|
||||
# seq_length as a list is not supported for now.
|
||||
policy.update(dataset, torch.tensor(0, device=DEVICE))
|
||||
|
||||
action = policy(
|
||||
env.observation_spec.rand()["observation"].to(DEVICE),
|
||||
torch.tensor(0, device=DEVICE),
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=cfg.policy.batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=DEVICE != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
assert action.shape == env.action_spec.shape
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
batch = next(dl_iter)
|
||||
|
||||
def test_abstract_policy_forward():
|
||||
"""
|
||||
Given an underlying policy that produces an action trajectory with n_action_steps actions, checks that:
|
||||
- The policy is invoked the expected number of times during a rollout.
|
||||
- The environment's termination condition is respected even when part way through an action trajectory.
|
||||
- The observations are returned correctly.
|
||||
"""
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
||||
|
||||
n_action_steps = 8 # our test policy will output 8 action step horizons
|
||||
terminate_at = 10 # some number that is more than n_action_steps but not a multiple
|
||||
rollout_max_steps = terminate_at + 1 # some number greater than terminate_at
|
||||
# Test updating the policy
|
||||
policy(batch, step=0)
|
||||
|
||||
# A minimal environment for testing.
|
||||
class StubEnv(EnvBase):
|
||||
# reset the policy and environment
|
||||
policy.reset()
|
||||
observation, _ = env.reset(seed=cfg.seed)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.action_spec = UnboundedContinuousTensorSpec(shape=(1,))
|
||||
self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
|
||||
# apply transform to normalize the observations
|
||||
observation = preprocess_observation(observation, dataset.transform)
|
||||
|
||||
def _step(self, tensordict: TensorDict) -> TensorDict:
|
||||
self.invocation_count += 1
|
||||
return TensorDict(
|
||||
{
|
||||
"observation": torch.tensor([self.invocation_count]),
|
||||
"reward": torch.tensor([self.invocation_count]),
|
||||
"terminated": torch.tensor(
|
||||
tensordict["action"].item() == terminate_at
|
||||
),
|
||||
}
|
||||
)
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
|
||||
|
||||
def _reset(self, tensordict: TensorDict) -> TensorDict:
|
||||
self.invocation_count = 0
|
||||
return TensorDict(
|
||||
{
|
||||
"observation": torch.tensor([self.invocation_count]),
|
||||
"reward": torch.tensor([self.invocation_count]),
|
||||
}
|
||||
)
|
||||
# get the next action for the environment
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation, step=0)
|
||||
|
||||
def _set_seed(self, seed: int | None):
|
||||
return
|
||||
# apply inverse transform to unnormalize the action
|
||||
action = postprocess_action(action, dataset.transform)
|
||||
|
||||
class StubPolicy(AbstractPolicy):
|
||||
name = "stub"
|
||||
# Test step through policy
|
||||
env.step(action)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(n_action_steps)
|
||||
self.n_policy_invocations = 0
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
def select_actions(self):
|
||||
self.n_policy_invocations += 1
|
||||
return torch.stack(
|
||||
[torch.tensor([i]) for i in range(self.n_action_steps)]
|
||||
).unsqueeze(0)
|
||||
|
||||
env = StubEnv()
|
||||
policy = StubPolicy()
|
||||
policy = TensorDictModule(
|
||||
policy,
|
||||
in_keys=[],
|
||||
out_keys=["action"],
|
||||
)
|
||||
|
||||
# Keep track to make sure the policy is called the expected number of times
|
||||
rollout = env.rollout(rollout_max_steps, policy)
|
||||
|
||||
assert len(rollout) == terminate_at + 1 # +1 for the reset observation
|
||||
assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
|
||||
assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))
|
||||
|
|
Loading…
Reference in New Issue