Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl

Alexander Soare 2024-04-08 09:25:45 +01:00
commit e982c732f1
19 changed files with 253 additions and 242 deletions

View File

@@ -37,7 +37,7 @@ policy = DiffusionPolicy(
     cfg_obs_encoder=cfg.obs_encoder,
     cfg_optimizer=cfg.optimizer,
     cfg_ema=cfg.ema,
-    n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
+    n_action_steps=cfg.n_action_steps,
     **cfg.policy,
 )
 policy.train()

View File

@@ -164,19 +164,11 @@ def make_dataset(
         ]
     )
-    if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
-        # TODO(rcadene): implement delta_timestamps in config
-        delta_timestamps = {
-            "observation.image": [-0.1, 0],
-            "observation.state": [-0.1, 0],
-            "action": [-0.1] + [i / clsfunc.fps for i in range(15)],
-        }
-    else:
-        delta_timestamps = {
-            "observation.images.top": [0],
-            "observation.state": [0],
-            "action": [i / clsfunc.fps for i in range(cfg.policy.horizon)],
-        }
+    delta_timestamps = cfg.policy.get("delta_timestamps")
+    if delta_timestamps is not None:
+        for key in delta_timestamps:
+            if isinstance(delta_timestamps[key], str):
+                delta_timestamps[key] = eval(delta_timestamps[key])
     dataset = clsfunc(
         dataset_id=cfg.dataset_id,
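
Note: delta_timestamps now comes straight from the policy config, and string values are evaluated into Python lists (the tdmpc config further down uses this string form). A minimal sketch of that resolution step, assuming a plain dict stands in for the Hydra config and ${fps} has already been interpolated to 50:

# Hypothetical stand-in for cfg.policy.delta_timestamps after Hydra interpolation.
delta_timestamps = {
    "observation.image": [-0.1, 0],            # lists pass through unchanged
    "action": "[i / 50 for i in range(5)]",    # strings are eval'd into lists
}
for key in delta_timestamps:
    if isinstance(delta_timestamps[key], str):
        delta_timestamps[key] = eval(delta_timestamps[key])
print(delta_timestamps["action"])  # [0.0, 0.02, 0.04, 0.06, 0.08]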

View File

@@ -6,9 +6,9 @@ import pygame
 import pymunk
 import torch
 import tqdm
+from gym_pusht.envs.pusht import pymunk_to_shapely
 from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
-from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
 from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
 # as define in env

View File

@@ -4,6 +4,9 @@ register(
     id="gym_aloha/AlohaInsertion-v0",
     entry_point="lerobot.common.envs.aloha.env:AlohaEnv",
     max_episode_steps=300,
+    # Even after seeding, the rendered observations are slightly different,
+    # so we set `nondeterministic=True` to pass `check_env` tests
+    nondeterministic=True,
     kwargs={"obs_type": "state", "task": "insertion"},
 )
@@ -11,5 +14,8 @@ register(
     id="gym_aloha/AlohaTransferCube-v0",
     entry_point="lerobot.common.envs.aloha.env:AlohaEnv",
     max_episode_steps=300,
+    # Even after seeding, the rendered observations are slightly different,
+    # so we set `nondeterministic=True` to pass `check_env` tests
+    nondeterministic=True,
     kwargs={"obs_type": "state", "task": "transfer_cube"},
 )
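
Note: once these register() calls run (they execute on import of lerobot.common.envs.aloha, as the tests below do), the environments can be built through gymnasium's registry. A small usage sketch; the obs_type and the "top" camera key follow the AlohaEnv changes in this commit, while the exact image resolution depends on the env defaults:

import gymnasium as gym
import lerobot.common.envs.aloha  # noqa: F401  # triggers the register() calls above

env = gym.make("gym_aloha/AlohaInsertion-v0", obs_type="pixels")
obs, info = env.reset(seed=0)
print(env.spec.nondeterministic)  # True; set to pass the check_env tests, per the comment above
print(obs["top"].shape)           # (observation_height, observation_width, 3)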

View File

@@ -16,7 +16,6 @@ from lerobot.common.envs.aloha.tasks.sim_end_effector import (
     TransferCubeEndEffectorTask,
 )
 from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
-from lerobot.common.utils import set_global_seed
 
 class AlohaEnv(gym.Env):
@@ -49,21 +48,33 @@ class AlohaEnv(gym.Env):
                 dtype=np.float64,
             )
         elif self.obs_type == "pixels":
-            self.observation_space = spaces.Box(
-                low=0, high=255, shape=(self.observation_height, self.observation_width, 3), dtype=np.uint8
-            )
-        elif self.obs_type == "pixels_agent_pos":
             self.observation_space = spaces.Dict(
                 {
-                    "pixels": spaces.Box(
+                    "top": spaces.Box(
                         low=0,
                         high=255,
                         shape=(self.observation_height, self.observation_width, 3),
                         dtype=np.uint8,
+                    )
+                }
+            )
+        elif self.obs_type == "pixels_agent_pos":
+            self.observation_space = spaces.Dict(
+                {
+                    "pixels": spaces.Dict(
+                        {
+                            "top": spaces.Box(
+                                low=0,
+                                high=255,
+                                shape=(self.observation_height, self.observation_width, 3),
+                                dtype=np.uint8,
+                            )
+                        }
                     ),
                     "agent_pos": spaces.Box(
-                        low=np.array([-1] * len(JOINTS)),  # ???
-                        high=np.array([1] * len(JOINTS)),  # ???
+                        low=-np.inf,
+                        high=np.inf,
+                        shape=(len(JOINTS),),
                         dtype=np.float64,
                     ),
                 }
@@ -89,21 +100,21 @@ class AlohaEnv(gym.Env):
         if "transfer_cube" in task_name:
             xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml"
             physics = mujoco.Physics.from_xml_path(str(xml_path))
-            task = TransferCubeTask(random=False)
+            task = TransferCubeTask()
         elif "insertion" in task_name:
             xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml"
             physics = mujoco.Physics.from_xml_path(str(xml_path))
-            task = InsertionTask(random=False)
+            task = InsertionTask()
         elif "end_effector_transfer_cube" in task_name:
             raise NotImplementedError()
             xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml"
             physics = mujoco.Physics.from_xml_path(str(xml_path))
-            task = TransferCubeEndEffectorTask(random=False)
+            task = TransferCubeEndEffectorTask()
         elif "end_effector_insertion" in task_name:
             raise NotImplementedError()
             xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml"
             physics = mujoco.Physics.from_xml_path(str(xml_path))
-            task = InsertionEndEffectorTask(random=False)
+            task = InsertionEndEffectorTask()
         else:
             raise NotImplementedError(task_name)
@@ -116,10 +127,10 @@ class AlohaEnv(gym.Env):
         if self.obs_type == "state":
             raise NotImplementedError()
         elif self.obs_type == "pixels":
-            obs = raw_obs["images"]["top"].copy()
+            obs = {"top": raw_obs["images"]["top"].copy()}
         elif self.obs_type == "pixels_agent_pos":
             obs = {
-                "pixels": raw_obs["images"]["top"].copy(),
+                "pixels": {"top": raw_obs["images"]["top"].copy()},
                 "agent_pos": raw_obs["qpos"],
             }
         return obs
@@ -129,14 +140,14 @@ class AlohaEnv(gym.Env):
         # TODO(rcadene): how to seed the env?
         if seed is not None:
-            set_global_seed(seed)
             self._env.task.random.seed(seed)
+            self._env.task._random = np.random.RandomState(seed)
         # TODO(rcadene): do not use global variable for this
         if "transfer_cube" in self.task:
-            BOX_POSE[0] = sample_box_pose()  # used in sim reset
+            BOX_POSE[0] = sample_box_pose(seed)  # used in sim reset
         elif "insertion" in self.task:
-            BOX_POSE[0] = np.concatenate(sample_insertion_pose())  # used in sim reset
+            BOX_POSE[0] = np.concatenate(sample_insertion_pose(seed))  # used in sim reset
         else:
             raise ValueError(self.task)
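
Note: the reset path now reseeds the dm_control task's RNG directly (task.random.seed plus swapping in a fresh np.random.RandomState) and forwards the seed to the pose samplers. A standalone sketch of the idea, using a toy stand-in for a dm_control task that keeps its RNG in _random:

import numpy as np

class ToyTask:
    """Hypothetical stand-in for a dm_control task that owns a numpy RandomState."""
    def __init__(self):
        self._random = np.random.RandomState()

    @property
    def random(self):
        return self._random

task = ToyTask()
# Same two-step reseeding as in AlohaEnv.reset(): seed the existing RNG in place,
# then replace it with a freshly seeded RandomState.
task.random.seed(0)
task._random = np.random.RandomState(0)
assert task.random.uniform() == np.random.RandomState(0).uniform()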

View File

@@ -1,26 +1,30 @@
 import numpy as np
 
-def sample_box_pose():
+def sample_box_pose(seed=None):
     x_range = [0.0, 0.2]
     y_range = [0.4, 0.6]
     z_range = [0.05, 0.05]
+    rng = np.random.RandomState(seed)
     ranges = np.vstack([x_range, y_range, z_range])
-    cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
+    cube_position = rng.uniform(ranges[:, 0], ranges[:, 1])
     cube_quat = np.array([1, 0, 0, 0])
     return np.concatenate([cube_position, cube_quat])
 
-def sample_insertion_pose():
+def sample_insertion_pose(seed=None):
     # Peg
     x_range = [0.1, 0.2]
     y_range = [0.4, 0.6]
     z_range = [0.05, 0.05]
+    rng = np.random.RandomState(seed)
     ranges = np.vstack([x_range, y_range, z_range])
-    peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
+    peg_position = rng.uniform(ranges[:, 0], ranges[:, 1])
     peg_quat = np.array([1, 0, 0, 0])
     peg_pose = np.concatenate([peg_position, peg_quat])
@@ -31,7 +35,7 @@ def sample_insertion_pose():
     z_range = [0.05, 0.05]
     ranges = np.vstack([x_range, y_range, z_range])
-    socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
+    socket_position = rng.uniform(ranges[:, 0], ranges[:, 1])
     socket_quat = np.array([1, 0, 0, 0])
     socket_pose = np.concatenate([socket_position, socket_quat])
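
Note: with the seed argument the pose samplers become reproducible, which is what AlohaEnv.reset() relies on above. A quick usage sketch; it assumes sample_insertion_pose returns the peg and socket poses as a pair, as the np.concatenate(...) call in the env suggests:

import numpy as np
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose

# Same seed, same pose; seed=None keeps the previous unseeded behaviour.
assert np.array_equal(sample_box_pose(seed=42), sample_box_pose(seed=42))

peg_pose, socket_pose = sample_insertion_pose(seed=42)
print(peg_pose.shape, socket_pose.shape)  # (7,) (7,): xyz position + wxyz quaternion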

View File

@@ -30,7 +30,7 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
             **kwargs,
         )
     elif cfg.env.name == "aloha":
-        from lerobot.common.envs import aloha as gym_aloha  # noqa: F401
+        import gym_aloha  # noqa: F401
         kwargs["task"] = cfg.env.task

View File

@@ -6,12 +6,20 @@ from lerobot.common.transforms import apply_inverse_transform
 def preprocess_observation(observation, transform=None):
     # map to expected inputs for the policy
-    obs = {
-        "observation.image": torch.from_numpy(observation["pixels"]).float(),
-        "observation.state": torch.from_numpy(observation["agent_pos"]).float(),
-    }
-    # convert to (b c h w) torch format
-    obs["observation.image"] = einops.rearrange(obs["observation.image"], "b h w c -> b c h w")
+    obs = {}
+    if isinstance(observation["pixels"], dict):
+        imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()}
+    else:
+        imgs = {"observation.image": observation["pixels"]}
+    for imgkey, img in imgs.items():
+        img = torch.from_numpy(img).float()
+        # convert to (b c h w) torch format
+        img = einops.rearrange(img, "b h w c -> b c h w")
+        obs[imgkey] = img
+    obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
     # apply same transforms as in training
     if transform is not None:
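
Note: preprocess_observation now accepts either a single image array under "pixels" (single-camera envs like pusht) or a dict of per-camera arrays (aloha). A small sketch of both call shapes; the batch size, image size and state dimensions below are made up for illustration:

import numpy as np
from lerobot.common.envs.utils import preprocess_observation

b, h, w = 2, 84, 84  # illustrative batch and image sizes

# Single camera: "pixels" is one (b, h, w, c) array.
single = {
    "pixels": np.zeros((b, h, w, 3), dtype=np.uint8),
    "agent_pos": np.zeros((b, 2), dtype=np.float64),
}
print(preprocess_observation(single)["observation.image"].shape)  # torch.Size([2, 3, 84, 84])

# Multiple cameras: "pixels" is a dict keyed by camera name.
multi = {
    "pixels": {"top": np.zeros((b, h, w, 3), dtype=np.uint8)},
    "agent_pos": np.zeros((b, 14), dtype=np.float64),
}
print(preprocess_observation(multi)["observation.images.top"].shape)  # torch.Size([2, 3, 84, 84])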

View File

@@ -1,11 +1,10 @@
 def make_policy(cfg):
-    if cfg.policy.name not in ["diffusion", "act"] and cfg.rollout_batch_size > 1:
-        raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
-        policy = TDMPCPolicy(cfg.policy, cfg.device)
+        policy = TDMPCPolicy(
+            cfg.policy, n_obs_steps=cfg.n_obs_steps, n_action_steps=cfg.n_action_steps, device=cfg.device
+        )
     elif cfg.policy.name == "diffusion":
         from lerobot.common.policies.diffusion.policy import DiffusionPolicy
@@ -17,14 +16,18 @@ def make_policy(cfg):
             cfg_obs_encoder=cfg.obs_encoder,
             cfg_optimizer=cfg.optimizer,
             cfg_ema=cfg.ema,
-            n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
+            n_obs_steps=cfg.n_obs_steps,
+            n_action_steps=cfg.n_action_steps,
             **cfg.policy,
         )
     elif cfg.policy.name == "act":
         from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
         policy = ActionChunkingTransformerPolicy(
-            cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
+            cfg.policy,
+            cfg.device,
+            n_obs_steps=cfg.n_obs_steps,
+            n_action_steps=cfg.n_action_steps,
         )
     else:
         raise ValueError(cfg.policy.name)
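
Note: the factory is driven entirely by the Hydra config, as exercised in tests/test_policies.py at the bottom of this commit. A minimal sketch of building a policy from the default config; the config path and override strings are assumptions modelled on those tests:

from lerobot.common.utils import init_hydra_config
from lerobot.common.policies.factory import make_policy

cfg = init_hydra_config(
    "lerobot/configs/default.yaml",  # assumed location of the default config
    overrides=["env=pusht", "policy=diffusion", "device=cpu"],
)
policy = make_policy(cfg)  # dispatches on cfg.policy.name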

View File

@@ -154,8 +154,14 @@ class TDMPCPolicy(nn.Module):
         if len(self._queues["action"]) == 0:
             batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+            if self.n_obs_steps == 1:
+                # hack to remove the time dimension
+                for key in batch:
+                    assert batch[key].shape[1] == 1
+                    batch[key] = batch[key][:, 0]
             actions = []
-            batch_size = batch["observation.image."].shape[0]
+            batch_size = batch["observation.image"].shape[0]
             for i in range(batch_size):
                 obs = {
                     "rgb": batch["observation.image"][[i]],
@@ -166,6 +172,10 @@ class TDMPCPolicy(nn.Module):
                 actions.append(action)
             action = torch.stack(actions)
+            # self.act returns an action for 1 timestep only, so we copy it over `n_action_steps` time
+            if i in range(self.n_action_steps):
+                self._queues["action"].append(action)
         action = self._queues["action"].popleft()
         return action
@@ -410,22 +420,45 @@ class TDMPCPolicy(nn.Module):
         # idxs = torch.cat([idxs, demo_idxs])
         # weights = torch.cat([weights, demo_weights])
+        # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
+        # instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
+        # batch size b = 256, time/horizon t = 5
+        # b t ... -> t b ...
+        for key in batch:
+            if batch[key].ndim > 1:
+                batch[key] = batch[key].transpose(1, 0)
+        action = batch["action"]
+        reward = batch["next.reward"][:, :, None]  # add extra channel dimension
+        # idxs = batch["index"]  # TODO(rcadene): use idxs to update sampling weights
+        done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
+        mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+        weights = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+        obses = {
+            "rgb": batch["observation.image"],
+            "state": batch["observation.state"],
+        }
+        shapes = {}
+        for k in obses:
+            shapes[k] = obses[k].shape
+            obses[k] = einops.rearrange(obses[k], "t b ... -> (t b) ... ")
         # Apply augmentations
         aug_tf = h.aug(self.cfg)
-        obs = aug_tf(obs)
-        for k in next_obses:
-            next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...")
-        next_obses = aug_tf(next_obses)
-        for k in next_obses:
-            next_obses[k] = einops.rearrange(
-                next_obses[k],
-                "(h t) ... -> h t ...",
-                h=self.cfg.horizon,
-                t=self.cfg.batch_size,
-            )
-        horizon = self.cfg.horizon
+        obses = aug_tf(obses)
+        for k in obses:
+            t, b = shapes[k][:2]
+            obses[k] = einops.rearrange(obses[k], "(t b) ... -> t b ... ", b=b, t=t)
+        obs, next_obses = {}, {}
+        for k in obses:
+            obs[k] = obses[k][0]
+            next_obses[k] = obses[k][1:].clone()
+        horizon = next_obses["rgb"].shape[0]
         loss_mask = torch.ones_like(mask, device=self.device)
         for t in range(1, horizon):
             loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])
@@ -497,19 +530,19 @@ class TDMPCPolicy(nn.Module):
         )
         self.optim.step()
-        if self.cfg.per:
-            # Update priorities
-            priorities = priority_loss.clamp(max=1e4).detach()
-            has_nan = torch.isnan(priorities).any().item()
-            if has_nan:
-                print(f"priorities has nan: {priorities=}")
-            else:
-                replay_buffer.update_priority(
-                    idxs[:num_slices],
-                    priorities[:num_slices],
-                )
-            if demo_batch_size > 0:
-                demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
+        # if self.cfg.per:
+        #     # Update priorities
+        #     priorities = priority_loss.clamp(max=1e4).detach()
+        #     has_nan = torch.isnan(priorities).any().item()
+        #     if has_nan:
+        #         print(f"priorities has nan: {priorities=}")
+        #     else:
+        #         replay_buffer.update_priority(
+        #             idxs[:num_slices],
+        #             priorities[:num_slices],
+        #         )
+        #     if demo_batch_size > 0:
+        #         demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
         # Update policy + target network
         _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
@@ -532,7 +565,7 @@ class TDMPCPolicy(nn.Module):
             "data_s": data_s,
             "update_s": time.time() - start_time,
         }
-        info["demo_batch_size"] = demo_batch_size
+        # info["demo_batch_size"] = demo_batch_size
         info["expectile"] = expectile
         info.update(value_info)
         info.update(pi_update_info)
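
Note: the update now folds time and batch into a single axis so the image augmentation sees one batch dimension, then splits them back apart and peels off the first timestep as obs and the rest as next_obses. A standalone sketch of that round trip with dummy tensors (the shapes are illustrative, not the real config values):

import einops
import torch

t, b, c, h, w = 5, 4, 3, 84, 84  # horizon, batch, channels, height, width
rgb = torch.rand(t, b, c, h, w)

flat = einops.rearrange(rgb, "t b ... -> (t b) ...")  # (20, 3, 84, 84): one batch axis
aug = flat + 0.0                                      # stand-in for h.aug(self.cfg)
restored = einops.rearrange(aug, "(t b) ... -> t b ...", t=t, b=b)

assert restored.shape == rgb.shape
obs = restored[0]                  # first timestep goes to the encoder
next_obses = restored[1:].clone()  # remaining timesteps supervise the latent rollout
horizon = next_obses.shape[0]      # 4, i.e. t - 1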

View File

@@ -10,7 +10,6 @@ log_freq: 250
 horizon: 100
 n_obs_steps: 1
-n_latency_steps: 0
 # when temporal_agg=False, n_action_steps=horizon
 n_action_steps: ${horizon}
@@ -57,3 +56,8 @@ policy:
   state_dim: ???
   action_dim: ???
+
+  delta_timestamps:
+    observation.image: [0.0]
+    observation.state: [0.0]
+    action: [0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, 0.2, 0.22, 0.24, 0.26, 0.28, 0.3, 0.32, 0.34, 0.36, 0.38, 0.4, 0.42, 0.44, 0.46, 0.48, 0.5, 0.52, 0.54, 0.56, 0.58, 0.6, 0.62, 0.64, 0.66, 0.68, 0.70, 0.72, 0.74, 0.76, 0.78, 0.8, 0.82, 0.84, 0.86, 0.88, 0.9, 0.92, 0.94, 0.96, 0.98, 1.0, 1.02, 1.04, 1.06, 1.08, 1.1, 1.12, 1.14, 1.16, 1.18, 1.2, 1.22, 1.24, 1.26, 1.28, 1.3, 1.32, 1.34, 1.36, 1.38, 1.40, 1.42, 1.44, 1.46, 1.48, 1.5, 1.52, 1.54, 1.56, 1.58, 1.6, 1.62, 1.64, 1.66, 1.68, 1.7, 1.72, 1.74, 1.76, 1.78, 1.8, 1.82, 1.84, 1.86, 1.88, 1.90, 1.92, 1.94, 1.96, 1.98]
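
Note: the long action list above is just i / fps over the 100-step horizon; only the result is hard-coded in the YAML. A quick check, assuming the ALOHA configs use fps = 50:

fps = 50       # assumed; the YAML hard-codes the resulting values
horizon = 100  # matches horizon: 100 at the top of this config
action_delta_timestamps = [i / fps for i in range(horizon)]
print(action_delta_timestamps[:3], action_delta_timestamps[-1])  # [0.0, 0.02, 0.04] 1.98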

View File

@@ -16,7 +16,6 @@ seed: 100000
 horizon: 16
 n_obs_steps: 2
 n_action_steps: 8
-n_latency_steps: 0
 dataset_obs_steps: ${n_obs_steps}
 past_action_visible: False
 keypoint_visible_rate: 1.0
@@ -38,7 +37,6 @@ policy:
   shape_meta: ${shape_meta}
   horizon: ${horizon}
-  # n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
   n_obs_steps: ${n_obs_steps}
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
@@ -64,6 +62,11 @@ policy:
   lr_warmup_steps: 500
   grad_clip_norm: 10
+
+  delta_timestamps:
+    observation.image: [-.1, 0]
+    observation.state: [-.1, 0]
+    action: [-.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4]
   noise_scheduler:
     _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
     num_train_timesteps: 100

View File

@@ -77,3 +77,9 @@ policy:
   num_q: 5
   mlp_dim: 512
   latent_dim: 50
+
+  delta_timestamps:
+    observation.image: "[i / ${fps} for i in range(6)]"
+    observation.state: "[i / ${fps} for i in range(6)]"
+    action: "[i / ${fps} for i in range(5)]"
+    next.reward: "[i / ${fps} for i in range(5)]"

View File

@@ -148,8 +148,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # )
     logging.info("make_env")
-    # TODO(now): uncomment
-    #env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
+    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
     logging.info("make_policy")
     policy = make_policy(cfg)

poetry.lock generated
View File

@@ -880,6 +880,26 @@ files = [
 [package.extras]
 protobuf = ["grpcio-tools (>=1.62.1)"]
 
+[[package]]
+name = "gym-aloha"
+version = "0.1.0"
+description = "A gym environment for ALOHA"
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+dm-control = "1.0.14"
+gymnasium = "^0.29.1"
+mujoco = "^2.3.7"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-aloha.git"
+reference = "HEAD"
+resolved_reference = "ec7200831e36c14e343cf7d275c6b047f2fe9d11"
+
 [[package]]
 name = "gym-pusht"
 version = "0.1.0"
@@ -3714,10 +3734,11 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
 testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
 
 [extras]
+aloha = ["gym_aloha"]
 pusht = ["gym_pusht"]
 xarm = ["gym_xarm"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "c9524cdf000eaa755a2ab3be669118222b4f8b1c262013f103f6874cbd54eeb6"
+content-hash = "6ef509580cef6bc50e9fbb5095097cbf21218d293a2d171155ced4bbe1d3e151"

View File

@@ -54,12 +54,13 @@ gymnasium = "^0.29.1"
 cmake = "^3.29.0.1"
 gym_pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
 gym_xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
-# gym_pusht = { path = "../gym-pusht", develop = true, optional = true}
-# gym_xarm = { path = "../gym-xarm", develop = true, optional = true}
+gym_aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true}
 
 [tool.poetry.extras]
 pusht = ["gym_pusht"]
 xarm = ["gym_xarm"]
+aloha = ["gym_aloha"]
 
 [tool.poetry.group.dev.dependencies]
 pre-commit = "^3.6.2"

View File

@@ -15,50 +15,50 @@ Note:
 import pytest
 
 import lerobot
-from lerobot.common.envs.aloha.env import AlohaEnv
-from lerobot.common.envs.pusht.env import PushtEnv
-from lerobot.common.envs.simxarm.env import SimxarmEnv
-from lerobot.common.datasets.simxarm import SimxarmDataset
-from lerobot.common.datasets.aloha import AlohaDataset
-from lerobot.common.datasets.pusht import PushtDataset
-from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
-from lerobot.common.policies.diffusion.policy import DiffusionPolicy
-from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
+# from lerobot.common.envs.aloha.env import AlohaEnv
+# from gym_pusht.envs import PushtEnv
+# from gym_xarm.envs import SimxarmEnv
+# from lerobot.common.datasets.simxarm import SimxarmDataset
+# from lerobot.common.datasets.aloha import AlohaDataset
+# from lerobot.common.datasets.pusht import PushtDataset
+# from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
+# from lerobot.common.policies.diffusion.policy import DiffusionPolicy
+# from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
 
-def test_available():
-    pol_classes = [
-        ActionChunkingTransformerPolicy,
-        DiffusionPolicy,
-        TDMPCPolicy,
-    ]
-    env_classes = [
-        AlohaEnv,
-        PushtEnv,
-        SimxarmEnv,
-    ]
-    dat_classes = [
-        AlohaDataset,
-        PushtDataset,
-        SimxarmDataset,
-    ]
-    policies = [pol_cls.name for pol_cls in pol_classes]
-    assert set(policies) == set(lerobot.available_policies)
-    envs = [env_cls.name for env_cls in env_classes]
-    assert set(envs) == set(lerobot.available_envs)
-    tasks_per_env = {env_cls.name: env_cls.available_tasks for env_cls in env_classes}
-    for env in envs:
-        assert set(tasks_per_env[env]) == set(lerobot.available_tasks_per_env[env])
-    datasets_per_env = {env_cls.name: dat_cls.available_datasets for env_cls, dat_cls in zip(env_classes, dat_classes)}
-    for env in envs:
-        assert set(datasets_per_env[env]) == set(lerobot.available_datasets_per_env[env])
+# def test_available():
+#     pol_classes = [
+#         ActionChunkingTransformerPolicy,
+#         DiffusionPolicy,
+#         TDMPCPolicy,
+#     ]
+#     env_classes = [
+#         AlohaEnv,
+#         PushtEnv,
+#         SimxarmEnv,
+#     ]
+#     dat_classes = [
+#         AlohaDataset,
+#         PushtDataset,
+#         SimxarmDataset,
+#     ]
+#     policies = [pol_cls.name for pol_cls in pol_classes]
+#     assert set(policies) == set(lerobot.available_policies)
+#     envs = [env_cls.name for env_cls in env_classes]
+#     assert set(envs) == set(lerobot.available_envs)
+#     tasks_per_env = {env_cls.name: env_cls.available_tasks for env_cls in env_classes}
+#     for env in envs:
+#         assert set(tasks_per_env[env]) == set(lerobot.available_tasks_per_env[env])
+#     datasets_per_env = {env_cls.name: dat_cls.available_datasets for env_cls, dat_cls in zip(env_classes, dat_classes)}
+#     for env in envs:
+#         assert set(datasets_per_env[env]) == set(lerobot.available_datasets_per_env[env])

View File

@@ -9,38 +9,9 @@ from lerobot.common.utils import init_hydra_config
 from lerobot.common.envs.utils import preprocess_observation
 
-# import dmc_aloha  # noqa: F401
-
 from .utils import DEVICE, DEFAULT_CONFIG_PATH
 
-# def print_spec_rollout(env):
-#     print("observation_spec:", env.observation_spec)
-#     print("action_spec:", env.action_spec)
-#     print("reward_spec:", env.reward_spec)
-#     print("done_spec:", env.done_spec)
-#     td = env.reset()
-#     print("reset tensordict", td)
-#     td = env.rand_step(td)
-#     print("random step tensordict", td)
-
-# def simple_rollout(steps=100):
-#     # preallocate:
-#     data = TensorDict({}, [steps])
-#     # reset
-#     _data = env.reset()
-#     for i in range(steps):
-#         _data["action"] = env.action_spec.rand()
-#         _data = env.step(_data)
-#         data[i] = _data
-#         _data = step_mdp(_data, keep_other=True)
-#     return data
-#     print("data from rollout:", simple_rollout(100))
 
 @pytest.mark.parametrize(
     "env_task, obs_type",
     [
@@ -54,7 +25,7 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
 def test_aloha(env_task, obs_type):
     from lerobot.common.envs import aloha as gym_aloha  # noqa: F401
     env = gym.make(f"gym_aloha/{env_task}", obs_type=obs_type)
-    check_env(env)
+    check_env(env.unwrapped)
@@ -70,7 +41,7 @@ def test_aloha(env_task, obs_type):
 def test_xarm(env_task, obs_type):
     import gym_xarm  # noqa: F401
     env = gym.make(f"gym_xarm/{env_task}", obs_type=obs_type)
-    check_env(env)
+    check_env(env.unwrapped)
@@ -85,7 +56,7 @@ def test_xarm(env_task, obs_type):
 def test_pusht(env_task, obs_type):
     import gym_pusht  # noqa: F401
     env = gym.make(f"gym_pusht/{env_task}", obs_type=obs_type)
-    check_env(env)
+    check_env(env.unwrapped)
 
 @pytest.mark.parametrize(
@@ -93,7 +64,7 @@ def test_pusht(env_task, obs_type):
     [
         "pusht",
         "simxarm",
-        # "aloha",
+        "aloha",
     ],
 )
 def test_factory(env_name):
@@ -104,9 +75,8 @@ def test_factory(env_name):
     dataset = make_dataset(cfg)
-    env = make_env(cfg)
+    env = make_env(cfg, num_parallel_envs=1)
     obs, info = env.reset()
-    obs = {key: obs[key][None, ...] for key in obs}
     obs = preprocess_observation(obs, transform=dataset.transform)
     for key in dataset.image_keys:
         img = obs[key]
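
Note: make_env(cfg, num_parallel_envs=1) now returns a vectorized env, so reset() already yields observations with a leading batch dimension and the manual obs[key][None, ...] expansion is gone. A sketch of the same effect with a plain gymnasium vector env; the gym_pusht env id and obs_type are assumptions based on the package used elsewhere in this commit:

import gymnasium as gym
import gym_pusht  # noqa: F401  # registers the PushT environments

env = gym.vector.SyncVectorEnv(
    [lambda: gym.make("gym_pusht/PushT-v0", obs_type="pixels_agent_pos")]
)
obs, info = env.reset(seed=0)
print(obs["pixels"].shape)     # (1, H, W, 3): the batch axis comes from the vector env
print(obs["agent_pos"].shape)  # (1, 2)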

View File

@@ -1,14 +1,11 @@
 import pytest
-from tensordict import TensorDict
-from tensordict.nn import TensorDictModule
 import torch
-from torchrl.data import UnboundedContinuousTensorSpec
-from torchrl.envs import EnvBase
 
+from lerobot.common.datasets.utils import cycle
+from lerobot.common.envs.utils import postprocess_action, preprocess_observation
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.envs.factory import make_env
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.utils import init_hydra_config
 
 from .utils import DEVICE, DEFAULT_CONFIG_PATH
@@ -16,22 +13,23 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
     "env_name,policy_name,extra_overrides",
     [
         ("simxarm", "tdmpc", ["policy.mpc=true"]),
-        ("pusht", "tdmpc", ["policy.mpc=false"]),
+        #("pusht", "tdmpc", ["policy.mpc=false"]),
         ("pusht", "diffusion", []),
-        ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
-        ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
-        ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
-        ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
+        # ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
+        #("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
+        #("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
+        #("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
         # TODO(aliberts): simxarm not working with diffusion
         # ("simxarm", "diffusion", []),
     ],
 )
-def test_concrete_policy(env_name, policy_name, extra_overrides):
+def test_policy(env_name, policy_name, extra_overrides):
     """
     Tests:
     - Making the policy object.
     - Updating the policy.
     - Using the policy to select actions at inference time.
+    - Test the action can be applied to the policy
     """
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
@@ -46,91 +44,43 @@ def test_policy(env_name, policy_name, extra_overrides):
     policy = make_policy(cfg)
     # Check that we run select_actions and get the appropriate output.
     dataset = make_dataset(cfg)
-    env = make_env(cfg, transform=dataset.transform)
-    if env_name != "aloha":
-        # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
-        # seq_length as a list is not supported for now.
-        policy.update(dataset, torch.tensor(0, device=DEVICE))
-    action = policy(
-        env.observation_spec.rand()["observation"].to(DEVICE),
-        torch.tensor(0, device=DEVICE),
+    env = make_env(cfg, num_parallel_envs=2)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=4,
+        batch_size=cfg.policy.batch_size,
+        shuffle=True,
+        pin_memory=DEVICE != "cpu",
+        drop_last=True,
     )
-    assert action.shape == env.action_spec.shape
-
-def test_abstract_policy_forward():
-    """
-    Given an underlying policy that produces an action trajectory with n_action_steps actions, checks that:
-    - The policy is invoked the expected number of times during a rollout.
-    - The environment's termination condition is respected even when part way through an action trajectory.
-    - The observations are returned correctly.
-    """
-    n_action_steps = 8  # our test policy will output 8 action step horizons
-    terminate_at = 10  # some number that is more than n_action_steps but not a multiple
-    rollout_max_steps = terminate_at + 1  # some number greater than terminate_at
-
-    # A minimal environment for testing.
-    class StubEnv(EnvBase):
-        def __init__(self):
-            super().__init__()
-            self.action_spec = UnboundedContinuousTensorSpec(shape=(1,))
-            self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
-
-        def _step(self, tensordict: TensorDict) -> TensorDict:
-            self.invocation_count += 1
-            return TensorDict(
-                {
-                    "observation": torch.tensor([self.invocation_count]),
-                    "reward": torch.tensor([self.invocation_count]),
-                    "terminated": torch.tensor(
-                        tensordict["action"].item() == terminate_at
-                    ),
-                }
-            )
-
-        def _reset(self, tensordict: TensorDict) -> TensorDict:
-            self.invocation_count = 0
-            return TensorDict(
-                {
-                    "observation": torch.tensor([self.invocation_count]),
-                    "reward": torch.tensor([self.invocation_count]),
-                }
-            )
-
-        def _set_seed(self, seed: int | None):
-            return
-
-    class StubPolicy(AbstractPolicy):
-        name = "stub"
-
-        def __init__(self):
-            super().__init__(n_action_steps)
-            self.n_policy_invocations = 0
-
-        def update(self):
-            pass
-
-        def select_actions(self):
-            self.n_policy_invocations += 1
-            return torch.stack(
-                [torch.tensor([i]) for i in range(self.n_action_steps)]
-            ).unsqueeze(0)
-
-    env = StubEnv()
-    policy = StubPolicy()
-    policy = TensorDictModule(
-        policy,
-        in_keys=[],
-        out_keys=["action"],
-    )
-
-    # Keep track to make sure the policy is called the expected number of times
-    rollout = env.rollout(rollout_max_steps, policy)
-
-    assert len(rollout) == terminate_at + 1  # +1 for the reset observation
-    assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
-    assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))
+    dl_iter = cycle(dataloader)
+    batch = next(dl_iter)
+    for key in batch:
+        batch[key] = batch[key].to(DEVICE, non_blocking=True)
+    # Test updating the policy
+    policy(batch, step=0)
+    # reset the policy and environment
+    policy.reset()
+    observation, _ = env.reset(seed=cfg.seed)
+    # apply transform to normalize the observations
+    observation = preprocess_observation(observation, dataset.transform)
+    # send observation to device/gpu
+    observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
+    # get the next action for the environment
+    with torch.inference_mode():
+        action = policy.select_action(observation, step=0)
+    # apply inverse transform to unnormalize the action
+    action = postprocess_action(action, dataset.transform)
+    # Test step through policy
+    env.step(action)