Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl
This commit is contained in: e982c732f1
@@ -37,7 +37,7 @@ policy = DiffusionPolicy(
     cfg_obs_encoder=cfg.obs_encoder,
     cfg_optimizer=cfg.optimizer,
     cfg_ema=cfg.ema,
-    n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
+    n_action_steps=cfg.n_action_steps,
     **cfg.policy,
 )
 policy.train()
@@ -164,19 +164,11 @@ def make_dataset(
         ]
     )

-    if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
-        # TODO(rcadene): implement delta_timestamps in config
-        delta_timestamps = {
-            "observation.image": [-0.1, 0],
-            "observation.state": [-0.1, 0],
-            "action": [-0.1] + [i / clsfunc.fps for i in range(15)],
-        }
-    else:
-        delta_timestamps = {
-            "observation.images.top": [0],
-            "observation.state": [0],
-            "action": [i / clsfunc.fps for i in range(cfg.policy.horizon)],
-        }
+    delta_timestamps = cfg.policy.get("delta_timestamps")
+    if delta_timestamps is not None:
+        for key in delta_timestamps:
+            if isinstance(delta_timestamps[key], str):
+                delta_timestamps[key] = eval(delta_timestamps[key])

     dataset = clsfunc(
         dataset_id=cfg.dataset_id,
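Note on the new make_dataset logic above: delta_timestamps entries may arrive from the YAML configs either as lists or as strings (see the tdmpc config later in this diff), and string entries are turned into lists with eval after Hydra resolves ${fps}. A minimal sketch of that conversion, with a hypothetical fps value standing in for the resolved interpolation:

```python
# Sketch of the string-to-list conversion done in make_dataset (not repository code).
# `fps` stands in for the number Hydra substitutes for "${fps}" before eval runs.
fps = 10  # hypothetical frame rate

delta_timestamps = {
    "observation.image": "[i / fps for i in range(6)]",  # string form, as in the tdmpc yaml
    "action": [-0.1, 0.0, 0.1],                          # already a list: left untouched
}

for key in delta_timestamps:
    if isinstance(delta_timestamps[key], str):
        delta_timestamps[key] = eval(delta_timestamps[key])

print(delta_timestamps["observation.image"])  # [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
```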
@@ -6,9 +6,9 @@ import pygame
 import pymunk
 import torch
 import tqdm
+from gym_pusht.envs.pusht import pymunk_to_shapely

 from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
-from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
 from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer

 # as define in env
@@ -4,6 +4,9 @@ register(
     id="gym_aloha/AlohaInsertion-v0",
     entry_point="lerobot.common.envs.aloha.env:AlohaEnv",
     max_episode_steps=300,
+    # Even after seeding, the rendered observations are slightly different,
+    # so we set `nondeterministic=True` to pass `check_env` tests
+    nondeterministic=True,
     kwargs={"obs_type": "state", "task": "insertion"},
 )
@@ -11,5 +14,8 @@ register(
     id="gym_aloha/AlohaTransferCube-v0",
     entry_point="lerobot.common.envs.aloha.env:AlohaEnv",
     max_episode_steps=300,
+    # Even after seeding, the rendered observations are slightly different,
+    # so we set `nondeterministic=True` to pass `check_env` tests
+    nondeterministic=True,
     kwargs={"obs_type": "state", "task": "transfer_cube"},
 )
@@ -16,7 +16,6 @@ from lerobot.common.envs.aloha.tasks.sim_end_effector import (
     TransferCubeEndEffectorTask,
 )
 from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
-from lerobot.common.utils import set_global_seed


 class AlohaEnv(gym.Env):
@@ -49,21 +48,33 @@ class AlohaEnv(gym.Env):
                 dtype=np.float64,
             )
         elif self.obs_type == "pixels":
-            self.observation_space = spaces.Box(
-                low=0, high=255, shape=(self.observation_height, self.observation_width, 3), dtype=np.uint8
-            )
-        elif self.obs_type == "pixels_agent_pos":
             self.observation_space = spaces.Dict(
                 {
-                    "pixels": spaces.Box(
+                    "top": spaces.Box(
                         low=0,
                         high=255,
                         shape=(self.observation_height, self.observation_width, 3),
                         dtype=np.uint8,
+                    )
+                }
+            )
+        elif self.obs_type == "pixels_agent_pos":
+            self.observation_space = spaces.Dict(
+                {
+                    "pixels": spaces.Dict(
+                        {
+                            "top": spaces.Box(
+                                low=0,
+                                high=255,
+                                shape=(self.observation_height, self.observation_width, 3),
+                                dtype=np.uint8,
+                            )
+                        }
                     ),
                     "agent_pos": spaces.Box(
-                        low=np.array([-1] * len(JOINTS)),  # ???
-                        high=np.array([1] * len(JOINTS)),  # ???
+                        low=-np.inf,
+                        high=np.inf,
+                        shape=(len(JOINTS),),
                         dtype=np.float64,
                     ),
                 }
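For reference on the reworked observation space above: the image part is now a Dict keyed by camera name ("top"), and agent_pos becomes an unbounded Box over the joints. A standalone sketch of the same layout (placeholder image size and joint count, not the repository's constants):

```python
# Standalone sketch of the new AlohaEnv observation space layout (not repository code).
import numpy as np
from gymnasium import spaces

num_joints = 14  # placeholder standing in for len(JOINTS)

observation_space = spaces.Dict(
    {
        "pixels": spaces.Dict(
            {"top": spaces.Box(low=0, high=255, shape=(480, 640, 3), dtype=np.uint8)}
        ),
        "agent_pos": spaces.Box(low=-np.inf, high=np.inf, shape=(num_joints,), dtype=np.float64),
    }
)

sample = observation_space.sample()
print(sample["pixels"]["top"].shape, sample["agent_pos"].shape)  # (480, 640, 3) (14,)
```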
@@ -89,21 +100,21 @@ class AlohaEnv(gym.Env):
         if "transfer_cube" in task_name:
             xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml"
             physics = mujoco.Physics.from_xml_path(str(xml_path))
-            task = TransferCubeTask(random=False)
+            task = TransferCubeTask()
         elif "insertion" in task_name:
             xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml"
             physics = mujoco.Physics.from_xml_path(str(xml_path))
-            task = InsertionTask(random=False)
+            task = InsertionTask()
         elif "end_effector_transfer_cube" in task_name:
             raise NotImplementedError()
             xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml"
             physics = mujoco.Physics.from_xml_path(str(xml_path))
-            task = TransferCubeEndEffectorTask(random=False)
+            task = TransferCubeEndEffectorTask()
         elif "end_effector_insertion" in task_name:
             raise NotImplementedError()
             xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml"
             physics = mujoco.Physics.from_xml_path(str(xml_path))
-            task = InsertionEndEffectorTask(random=False)
+            task = InsertionEndEffectorTask()
         else:
             raise NotImplementedError(task_name)
@@ -116,10 +127,10 @@ class AlohaEnv(gym.Env):
         if self.obs_type == "state":
             raise NotImplementedError()
         elif self.obs_type == "pixels":
-            obs = raw_obs["images"]["top"].copy()
+            obs = {"top": raw_obs["images"]["top"].copy()}
         elif self.obs_type == "pixels_agent_pos":
             obs = {
-                "pixels": raw_obs["images"]["top"].copy(),
+                "pixels": {"top": raw_obs["images"]["top"].copy()},
                 "agent_pos": raw_obs["qpos"],
             }
         return obs
@@ -129,14 +140,14 @@ class AlohaEnv(gym.Env):

         # TODO(rcadene): how to seed the env?
         if seed is not None:
-            set_global_seed(seed)
             self._env.task.random.seed(seed)
+            self._env.task._random = np.random.RandomState(seed)

         # TODO(rcadene): do not use global variable for this
         if "transfer_cube" in self.task:
-            BOX_POSE[0] = sample_box_pose()  # used in sim reset
+            BOX_POSE[0] = sample_box_pose(seed)  # used in sim reset
         elif "insertion" in self.task:
-            BOX_POSE[0] = np.concatenate(sample_insertion_pose())  # used in sim reset
+            BOX_POSE[0] = np.concatenate(sample_insertion_pose(seed))  # used in sim reset
         else:
             raise ValueError(self.task)
@@ -1,26 +1,30 @@
 import numpy as np


-def sample_box_pose():
+def sample_box_pose(seed=None):
     x_range = [0.0, 0.2]
     y_range = [0.4, 0.6]
     z_range = [0.05, 0.05]

+    rng = np.random.RandomState(seed)
+
     ranges = np.vstack([x_range, y_range, z_range])
-    cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
+    cube_position = rng.uniform(ranges[:, 0], ranges[:, 1])

     cube_quat = np.array([1, 0, 0, 0])
     return np.concatenate([cube_position, cube_quat])


-def sample_insertion_pose():
+def sample_insertion_pose(seed=None):
     # Peg
     x_range = [0.1, 0.2]
     y_range = [0.4, 0.6]
     z_range = [0.05, 0.05]

+    rng = np.random.RandomState(seed)
+
     ranges = np.vstack([x_range, y_range, z_range])
-    peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
+    peg_position = rng.uniform(ranges[:, 0], ranges[:, 1])

     peg_quat = np.array([1, 0, 0, 0])
     peg_pose = np.concatenate([peg_position, peg_quat])
@@ -31,7 +35,7 @@ def sample_insertion_pose():
     z_range = [0.05, 0.05]

     ranges = np.vstack([x_range, y_range, z_range])
-    socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
+    socket_position = rng.uniform(ranges[:, 0], ranges[:, 1])

     socket_quat = np.array([1, 0, 0, 0])
     socket_pose = np.concatenate([socket_position, socket_quat])
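A quick check of the seeding change in the samplers above (a sketch, not part of the commit): with np.random.RandomState(seed) the sampled pose is reproducible for a fixed seed, while seed=None keeps the previous unseeded behaviour.

```python
# Sketch: the seeded sampler returns identical poses for the same seed.
import numpy as np


def sample_box_pose(seed=None):
    x_range = [0.0, 0.2]
    y_range = [0.4, 0.6]
    z_range = [0.05, 0.05]

    rng = np.random.RandomState(seed)

    ranges = np.vstack([x_range, y_range, z_range])
    cube_position = rng.uniform(ranges[:, 0], ranges[:, 1])

    cube_quat = np.array([1, 0, 0, 0])
    return np.concatenate([cube_position, cube_quat])


# Same seed -> same cube pose, which is what the env reset now relies on.
assert np.array_equal(sample_box_pose(seed=123), sample_box_pose(seed=123))
```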
@@ -30,7 +30,7 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
             **kwargs,
         )
     elif cfg.env.name == "aloha":
-        from lerobot.common.envs import aloha as gym_aloha  # noqa: F401
+        import gym_aloha  # noqa: F401

         kwargs["task"] = cfg.env.task
@@ -6,12 +6,20 @@ from lerobot.common.transforms import apply_inverse_transform

 def preprocess_observation(observation, transform=None):
     # map to expected inputs for the policy
-    obs = {
-        "observation.image": torch.from_numpy(observation["pixels"]).float(),
-        "observation.state": torch.from_numpy(observation["agent_pos"]).float(),
-    }
+    obs = {}
+
+    if isinstance(observation["pixels"], dict):
+        imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()}
+    else:
+        imgs = {"observation.image": observation["pixels"]}
+
+    for imgkey, img in imgs.items():
+        img = torch.from_numpy(img).float()
         # convert to (b c h w) torch format
-    obs["observation.image"] = einops.rearrange(obs["observation.image"], "b h w c -> b c h w")
+        img = einops.rearrange(img, "b h w c -> b c h w")
+        obs[imgkey] = img
+
+    obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()

     # apply same transforms as in training
     if transform is not None:
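To make the key mapping in the reworked preprocess_observation concrete, here is a rough sketch exercising the same logic on dummy data (shapes are made up, and the training-time transform is skipped):

```python
# Sketch: dict of camera images -> "observation.images.*" tensors in (b c h w) format.
import einops
import numpy as np
import torch

observation = {
    "pixels": {"top": np.zeros((1, 480, 640, 3), dtype=np.float32)},  # (b h w c) from the env
    "agent_pos": np.zeros((1, 14), dtype=np.float32),
}

obs = {}
if isinstance(observation["pixels"], dict):
    imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()}
else:
    imgs = {"observation.image": observation["pixels"]}

for imgkey, img in imgs.items():
    img = torch.from_numpy(img).float()
    img = einops.rearrange(img, "b h w c -> b c h w")  # convert to (b c h w) torch format
    obs[imgkey] = img

obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
print(obs["observation.images.top"].shape)  # torch.Size([1, 3, 480, 640])
```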
@@ -1,11 +1,10 @@
 def make_policy(cfg):
-    if cfg.policy.name not in ["diffusion", "act"] and cfg.rollout_batch_size > 1:
-        raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
-
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy

-        policy = TDMPCPolicy(cfg.policy, cfg.device)
+        policy = TDMPCPolicy(
+            cfg.policy, n_obs_steps=cfg.n_obs_steps, n_action_steps=cfg.n_action_steps, device=cfg.device
+        )
     elif cfg.policy.name == "diffusion":
         from lerobot.common.policies.diffusion.policy import DiffusionPolicy
@@ -17,14 +16,18 @@ def make_policy(cfg):
             cfg_obs_encoder=cfg.obs_encoder,
             cfg_optimizer=cfg.optimizer,
             cfg_ema=cfg.ema,
-            n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
+            n_obs_steps=cfg.n_obs_steps,
+            n_action_steps=cfg.n_action_steps,
             **cfg.policy,
         )
     elif cfg.policy.name == "act":
         from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy

         policy = ActionChunkingTransformerPolicy(
-            cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
+            cfg.policy,
+            cfg.device,
+            n_obs_steps=cfg.n_obs_steps,
+            n_action_steps=cfg.n_action_steps,
         )
     else:
         raise ValueError(cfg.policy.name)
@@ -154,8 +154,14 @@ class TDMPCPolicy(nn.Module):
         if len(self._queues["action"]) == 0:
             batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}

+            if self.n_obs_steps == 1:
+                # hack to remove the time dimension
+                for key in batch:
+                    assert batch[key].shape[1] == 1
+                    batch[key] = batch[key][:, 0]
+
             actions = []
-            batch_size = batch["observation.image."].shape[0]
+            batch_size = batch["observation.image"].shape[0]
             for i in range(batch_size):
                 obs = {
                     "rgb": batch["observation.image"][[i]],
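The "hack to remove the time dimension" above just drops the singleton time axis that queue stacking introduces when n_obs_steps == 1; a tiny illustration with a dummy tensor:

```python
# Sketch: (batch, 1, ...) -> (batch, ...) by indexing away the singleton time axis.
import torch

batch = {"observation.image": torch.zeros(256, 1, 3, 84, 84)}  # dummy shapes
for key in batch:
    assert batch[key].shape[1] == 1
    batch[key] = batch[key][:, 0]
print(batch["observation.image"].shape)  # torch.Size([256, 3, 84, 84])
```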
@@ -166,6 +172,10 @@ class TDMPCPolicy(nn.Module):
                 actions.append(action)
             action = torch.stack(actions)

+            # self.act returns an action for 1 timestep only, so we copy it over `n_action_steps` time
+            if i in range(self.n_action_steps):
+                self._queues["action"].append(action)
+
         action = self._queues["action"].popleft()
         return action
@@ -410,22 +420,45 @@ class TDMPCPolicy(nn.Module):
         # idxs = torch.cat([idxs, demo_idxs])
         # weights = torch.cat([weights, demo_weights])

+        # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
+        # instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
+        # batch size b = 256, time/horizon t = 5
+        # b t ... -> t b ...
+        for key in batch:
+            if batch[key].ndim > 1:
+                batch[key] = batch[key].transpose(1, 0)
+
+        action = batch["action"]
+        reward = batch["next.reward"][:, :, None]  # add extra channel dimension
+        # idxs = batch["index"]  # TODO(rcadene): use idxs to update sampling weights
+        done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
+        mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+        weights = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+
+        obses = {
+            "rgb": batch["observation.image"],
+            "state": batch["observation.state"],
+        }
+
+        shapes = {}
+        for k in obses:
+            shapes[k] = obses[k].shape
+            obses[k] = einops.rearrange(obses[k], "t b ... -> (t b) ... ")
+
         # Apply augmentations
         aug_tf = h.aug(self.cfg)
-        obs = aug_tf(obs)
+        obses = aug_tf(obses)

-        for k in next_obses:
-            next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...")
-        next_obses = aug_tf(next_obses)
-        for k in next_obses:
-            next_obses[k] = einops.rearrange(
-                next_obses[k],
-                "(h t) ... -> h t ...",
-                h=self.cfg.horizon,
-                t=self.cfg.batch_size,
-            )
+        for k in obses:
+            t, b = shapes[k][:2]
+            obses[k] = einops.rearrange(obses[k], "(t b) ... -> t b ... ", b=b, t=t)

-        horizon = self.cfg.horizon
+        obs, next_obses = {}, {}
+        for k in obses:
+            obs[k] = obses[k][0]
+            next_obses[k] = obses[k][1:].clone()
+
+        horizon = next_obses["rgb"].shape[0]
         loss_mask = torch.ones_like(mask, device=self.device)
         for t in range(1, horizon):
             loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])
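The rearrange round trip introduced above (flatten time and batch before the augmentation transform, then restore the layout) can be checked in isolation; a small sketch with random tensors of the commented shapes:

```python
# Sketch: (t, b, ...) -> (t*b, ...) for the augmentation pass, then back to (t, b, ...).
import einops
import torch

obses = {"rgb": torch.rand(5, 256, 3, 84, 84), "state": torch.rand(5, 256, 4)}  # dummy shapes

shapes = {}
for k in obses:
    shapes[k] = obses[k].shape
    obses[k] = einops.rearrange(obses[k], "t b ... -> (t b) ...")

# ... per-frame augmentations would run here on the flattened tensors ...

for k in obses:
    t, b = shapes[k][:2]
    obses[k] = einops.rearrange(obses[k], "(t b) ... -> t b ...", b=b, t=t)

assert obses["rgb"].shape == shapes["rgb"]
```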
@@ -497,19 +530,19 @@ class TDMPCPolicy(nn.Module):
         )
         self.optim.step()

-        if self.cfg.per:
-            # Update priorities
-            priorities = priority_loss.clamp(max=1e4).detach()
-            has_nan = torch.isnan(priorities).any().item()
-            if has_nan:
-                print(f"priorities has nan: {priorities=}")
-            else:
-                replay_buffer.update_priority(
-                    idxs[:num_slices],
-                    priorities[:num_slices],
-                )
-            if demo_batch_size > 0:
-                demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
+        # if self.cfg.per:
+        #     # Update priorities
+        #     priorities = priority_loss.clamp(max=1e4).detach()
+        #     has_nan = torch.isnan(priorities).any().item()
+        #     if has_nan:
+        #         print(f"priorities has nan: {priorities=}")
+        #     else:
+        #         replay_buffer.update_priority(
+        #             idxs[:num_slices],
+        #             priorities[:num_slices],
+        #         )
+        #     if demo_batch_size > 0:
+        #         demo_buffer.update_priority(demo_idxs, priorities[num_slices:])

         # Update policy + target network
         _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
@@ -532,7 +565,7 @@ class TDMPCPolicy(nn.Module):
             "data_s": data_s,
             "update_s": time.time() - start_time,
         }
-        info["demo_batch_size"] = demo_batch_size
+        # info["demo_batch_size"] = demo_batch_size
         info["expectile"] = expectile
         info.update(value_info)
         info.update(pi_update_info)
@@ -10,7 +10,6 @@ log_freq: 250

 horizon: 100
 n_obs_steps: 1
-n_latency_steps: 0
 # when temporal_agg=False, n_action_steps=horizon
 n_action_steps: ${horizon}
@@ -57,3 +56,8 @@ policy:

   state_dim: ???
   action_dim: ???
+
+  delta_timestamps:
+    observation.image: [0.0]
+    observation.state: [0.0]
+    action: [0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, 0.2, 0.22, 0.24, 0.26, 0.28, 0.3, 0.32, 0.34, 0.36, 0.38, 0.4, 0.42, 0.44, 0.46, 0.48, 0.5, 0.52, 0.54, 0.56, 0.58, 0.6, 0.62, 0.64, 0.66, 0.68, 0.70, 0.72, 0.74, 0.76, 0.78, 0.8, 0.82, 0.84, 0.86, 0.88, 0.9, 0.92, 0.94, 0.96, 0.98, 1.0, 1.02, 1.04, 1.06, 1.08, 1.1, 1.12, 1.14, 1.16, 1.18, 1.2, 1.22, 1.24, 1.26, 1.28, 1.3, 1.32, 1.34, 1.36, 1.38, 1.40, 1.42, 1.44, 1.46, 1.48, 1.5, 1.52, 1.54, 1.56, 1.58, 1.6, 1.62, 1.64, 1.66, 1.68, 1.7, 1.72, 1.74, 1.76, 1.78, 1.8, 1.82, 1.84, 1.86, 1.88, 1.90, 1.92, 1.94, 1.96, 1.98]
@@ -16,7 +16,6 @@ seed: 100000
 horizon: 16
 n_obs_steps: 2
 n_action_steps: 8
-n_latency_steps: 0
 dataset_obs_steps: ${n_obs_steps}
 past_action_visible: False
 keypoint_visible_rate: 1.0
@@ -38,7 +37,6 @@ policy:
   shape_meta: ${shape_meta}

   horizon: ${horizon}
-  # n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
   n_obs_steps: ${n_obs_steps}
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
@@ -64,6 +62,11 @@ policy:
   lr_warmup_steps: 500
   grad_clip_norm: 10

+  delta_timestamps:
+    observation.image: [-.1, 0]
+    observation.state: [-.1, 0]
+    action: [-.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4]
+
   noise_scheduler:
     _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
     num_train_timesteps: 100
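For reference on the lists added above (an observation about the values, not part of the diff): with n_obs_steps: 2 and horizon: 16 from this config, the entries are spaced 0.1 s apart and start one observation step in the past, so they can be regenerated programmatically; a sketch assuming that 0.1 s frame period:

```python
# Sketch: regenerate the diffusion delta_timestamps, assuming a 0.1 s frame period (1 / fps).
n_obs_steps = 2
horizon = 16
dt = 0.1  # assumed frame period

observation_image = [round((i - (n_obs_steps - 1)) * dt, 2) for i in range(n_obs_steps)]
action = [round((i - (n_obs_steps - 1)) * dt, 2) for i in range(horizon)]

print(observation_image)  # [-0.1, 0.0]
print(action)             # [-0.1, 0.0, 0.1, ..., 1.4]
```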
@@ -77,3 +77,9 @@ policy:
   num_q: 5
   mlp_dim: 512
   latent_dim: 50
+
+  delta_timestamps:
+    observation.image: "[i / ${fps} for i in range(6)]"
+    observation.state: "[i / ${fps} for i in range(6)]"
+    action: "[i / ${fps} for i in range(5)]"
+    next.reward: "[i / ${fps} for i in range(5)]"
@@ -148,8 +148,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # )

     logging.info("make_env")
-    # TODO(now): uncomment
-    #env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
+    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

     logging.info("make_policy")
     policy = make_policy(cfg)
@@ -880,6 +880,26 @@ files = [
 [package.extras]
 protobuf = ["grpcio-tools (>=1.62.1)"]

+[[package]]
+name = "gym-aloha"
+version = "0.1.0"
+description = "A gym environment for ALOHA"
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+dm-control = "1.0.14"
+gymnasium = "^0.29.1"
+mujoco = "^2.3.7"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-aloha.git"
+reference = "HEAD"
+resolved_reference = "ec7200831e36c14e343cf7d275c6b047f2fe9d11"
+
 [[package]]
 name = "gym-pusht"
 version = "0.1.0"
@@ -3714,10 +3734,11 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
 testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]

 [extras]
+aloha = ["gym_aloha"]
 pusht = ["gym_pusht"]
 xarm = ["gym_xarm"]

 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "c9524cdf000eaa755a2ab3be669118222b4f8b1c262013f103f6874cbd54eeb6"
+content-hash = "6ef509580cef6bc50e9fbb5095097cbf21218d293a2d171155ced4bbe1d3e151"
@@ -54,12 +54,13 @@ gymnasium = "^0.29.1"
 cmake = "^3.29.0.1"
 gym_pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
 gym_xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
-# gym_pusht = { path = "../gym-pusht", develop = true, optional = true}
-# gym_xarm = { path = "../gym-xarm", develop = true, optional = true}
+gym_aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true}

 [tool.poetry.extras]
 pusht = ["gym_pusht"]
 xarm = ["gym_xarm"]
+aloha = ["gym_aloha"]

 [tool.poetry.group.dev.dependencies]
 pre-commit = "^3.6.2"
@@ -15,50 +15,50 @@ Note:
 import pytest
 import lerobot

-from lerobot.common.envs.aloha.env import AlohaEnv
-from lerobot.common.envs.pusht.env import PushtEnv
-from lerobot.common.envs.simxarm.env import SimxarmEnv
+# from lerobot.common.envs.aloha.env import AlohaEnv
+# from gym_pusht.envs import PushtEnv
+# from gym_xarm.envs import SimxarmEnv

-from lerobot.common.datasets.simxarm import SimxarmDataset
-from lerobot.common.datasets.aloha import AlohaDataset
-from lerobot.common.datasets.pusht import PushtDataset
+# from lerobot.common.datasets.simxarm import SimxarmDataset
+# from lerobot.common.datasets.aloha import AlohaDataset
+# from lerobot.common.datasets.pusht import PushtDataset

-from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
-from lerobot.common.policies.diffusion.policy import DiffusionPolicy
-from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
+# from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
+# from lerobot.common.policies.diffusion.policy import DiffusionPolicy
+# from lerobot.common.policies.tdmpc.policy import TDMPCPolicy


-def test_available():
-    pol_classes = [
-        ActionChunkingTransformerPolicy,
-        DiffusionPolicy,
-        TDMPCPolicy,
-    ]
+# def test_available():
+#     pol_classes = [
+#         ActionChunkingTransformerPolicy,
+#         DiffusionPolicy,
+#         TDMPCPolicy,
+#     ]

-    env_classes = [
-        AlohaEnv,
-        PushtEnv,
-        SimxarmEnv,
-    ]
+#     env_classes = [
+#         AlohaEnv,
+#         PushtEnv,
+#         SimxarmEnv,
+#     ]

-    dat_classes = [
-        AlohaDataset,
-        PushtDataset,
-        SimxarmDataset,
-    ]
+#     dat_classes = [
+#         AlohaDataset,
+#         PushtDataset,
+#         SimxarmDataset,
+#     ]

-    policies = [pol_cls.name for pol_cls in pol_classes]
-    assert set(policies) == set(lerobot.available_policies)
+#     policies = [pol_cls.name for pol_cls in pol_classes]
+#     assert set(policies) == set(lerobot.available_policies)

-    envs = [env_cls.name for env_cls in env_classes]
-    assert set(envs) == set(lerobot.available_envs)
+#     envs = [env_cls.name for env_cls in env_classes]
+#     assert set(envs) == set(lerobot.available_envs)

-    tasks_per_env = {env_cls.name: env_cls.available_tasks for env_cls in env_classes}
-    for env in envs:
-        assert set(tasks_per_env[env]) == set(lerobot.available_tasks_per_env[env])
+#     tasks_per_env = {env_cls.name: env_cls.available_tasks for env_cls in env_classes}
+#     for env in envs:
+#         assert set(tasks_per_env[env]) == set(lerobot.available_tasks_per_env[env])

-    datasets_per_env = {env_cls.name: dat_cls.available_datasets for env_cls, dat_cls in zip(env_classes, dat_classes)}
-    for env in envs:
-        assert set(datasets_per_env[env]) == set(lerobot.available_datasets_per_env[env])
+#     datasets_per_env = {env_cls.name: dat_cls.available_datasets for env_cls, dat_cls in zip(env_classes, dat_classes)}
+#     for env in envs:
+#         assert set(datasets_per_env[env]) == set(lerobot.available_datasets_per_env[env])
@@ -9,38 +9,9 @@ from lerobot.common.utils import init_hydra_config

 from lerobot.common.envs.utils import preprocess_observation

-# import dmc_aloha  # noqa: F401
-
 from .utils import DEVICE, DEFAULT_CONFIG_PATH


-# def print_spec_rollout(env):
-#     print("observation_spec:", env.observation_spec)
-#     print("action_spec:", env.action_spec)
-#     print("reward_spec:", env.reward_spec)
-#     print("done_spec:", env.done_spec)
-
-#     td = env.reset()
-#     print("reset tensordict", td)
-
-#     td = env.rand_step(td)
-#     print("random step tensordict", td)
-
-# def simple_rollout(steps=100):
-#     # preallocate:
-#     data = TensorDict({}, [steps])
-#     # reset
-#     _data = env.reset()
-#     for i in range(steps):
-#         _data["action"] = env.action_spec.rand()
-#         _data = env.step(_data)
-#         data[i] = _data
-#         _data = step_mdp(_data, keep_other=True)
-#     return data
-
-# print("data from rollout:", simple_rollout(100))
-
-
 @pytest.mark.parametrize(
     "env_task, obs_type",
     [
@@ -54,7 +25,7 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
 def test_aloha(env_task, obs_type):
     from lerobot.common.envs import aloha as gym_aloha  # noqa: F401
     env = gym.make(f"gym_aloha/{env_task}", obs_type=obs_type)
-    check_env(env)
+    check_env(env.unwrapped)
@@ -70,7 +41,7 @@ def test_aloha(env_task, obs_type):
 def test_xarm(env_task, obs_type):
     import gym_xarm  # noqa: F401
     env = gym.make(f"gym_xarm/{env_task}", obs_type=obs_type)
-    check_env(env)
+    check_env(env.unwrapped)
@@ -85,7 +56,7 @@ def test_xarm(env_task, obs_type):
 def test_pusht(env_task, obs_type):
     import gym_pusht  # noqa: F401
     env = gym.make(f"gym_pusht/{env_task}", obs_type=obs_type)
-    check_env(env)
+    check_env(env.unwrapped)


 @pytest.mark.parametrize(
|
@ -93,7 +64,7 @@ def test_pusht(env_task, obs_type):
|
||||||
[
|
[
|
||||||
"pusht",
|
"pusht",
|
||||||
"simxarm",
|
"simxarm",
|
||||||
# "aloha",
|
"aloha",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_factory(env_name):
|
def test_factory(env_name):
|
||||||
|
@@ -104,9 +75,8 @@ def test_factory(env_name):

     dataset = make_dataset(cfg)

-    env = make_env(cfg)
+    env = make_env(cfg, num_parallel_envs=1)
     obs, info = env.reset()
-    obs = {key: obs[key][None, ...] for key in obs}
     obs = preprocess_observation(obs, transform=dataset.transform)
     for key in dataset.image_keys:
         img = obs[key]
@@ -1,14 +1,11 @@
 import pytest
-from tensordict import TensorDict
-from tensordict.nn import TensorDictModule
 import torch
-from torchrl.data import UnboundedContinuousTensorSpec
-from torchrl.envs import EnvBase

+from lerobot.common.datasets.utils import cycle
+from lerobot.common.envs.utils import postprocess_action, preprocess_observation
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.envs.factory import make_env
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.utils import init_hydra_config
 from .utils import DEVICE, DEFAULT_CONFIG_PATH
@@ -16,22 +13,23 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
     "env_name,policy_name,extra_overrides",
     [
         ("simxarm", "tdmpc", ["policy.mpc=true"]),
-        ("pusht", "tdmpc", ["policy.mpc=false"]),
+        #("pusht", "tdmpc", ["policy.mpc=false"]),
         ("pusht", "diffusion", []),
-        ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
-        ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
-        ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
-        ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
+        # ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
+        #("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
+        #("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
+        #("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
         # TODO(aliberts): simxarm not working with diffusion
         # ("simxarm", "diffusion", []),
     ],
 )
-def test_concrete_policy(env_name, policy_name, extra_overrides):
+def test_policy(env_name, policy_name, extra_overrides):
     """
     Tests:
     - Making the policy object.
     - Updating the policy.
     - Using the policy to select actions at inference time.
+    - Test the action can be applied to the policy
     """
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
@@ -46,91 +44,43 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
     policy = make_policy(cfg)
     # Check that we run select_actions and get the appropriate output.
     dataset = make_dataset(cfg)
-    env = make_env(cfg, transform=dataset.transform)
+    env = make_env(cfg, num_parallel_envs=2)

-    if env_name != "aloha":
-        # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
-        # seq_length as a list is not supported for now.
-        policy.update(dataset, torch.tensor(0, device=DEVICE))
-
-    action = policy(
-        env.observation_spec.rand()["observation"].to(DEVICE),
-        torch.tensor(0, device=DEVICE),
-    )
-    assert action.shape == env.action_spec.shape
-
-
-def test_abstract_policy_forward():
-    """
-    Given an underlying policy that produces an action trajectory with n_action_steps actions, checks that:
-    - The policy is invoked the expected number of times during a rollout.
-    - The environment's termination condition is respected even when part way through an action trajectory.
-    - The observations are returned correctly.
-    """
-
-    n_action_steps = 8  # our test policy will output 8 action step horizons
-    terminate_at = 10  # some number that is more than n_action_steps but not a multiple
-    rollout_max_steps = terminate_at + 1  # some number greater than terminate_at
-
-    # A minimal environment for testing.
-    class StubEnv(EnvBase):
-
-        def __init__(self):
-            super().__init__()
-            self.action_spec = UnboundedContinuousTensorSpec(shape=(1,))
-            self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
-
-        def _step(self, tensordict: TensorDict) -> TensorDict:
-            self.invocation_count += 1
-            return TensorDict(
-                {
-                    "observation": torch.tensor([self.invocation_count]),
-                    "reward": torch.tensor([self.invocation_count]),
-                    "terminated": torch.tensor(
-                        tensordict["action"].item() == terminate_at
-                    ),
-                }
-            )
-
-        def _reset(self, tensordict: TensorDict) -> TensorDict:
-            self.invocation_count = 0
-            return TensorDict(
-                {
-                    "observation": torch.tensor([self.invocation_count]),
-                    "reward": torch.tensor([self.invocation_count]),
-                }
-            )
-
-        def _set_seed(self, seed: int | None):
-            return
-
-    class StubPolicy(AbstractPolicy):
-        name = "stub"
-
-        def __init__(self):
-            super().__init__(n_action_steps)
-            self.n_policy_invocations = 0
-
-        def update(self):
-            pass
-
-        def select_actions(self):
-            self.n_policy_invocations += 1
-            return torch.stack(
-                [torch.tensor([i]) for i in range(self.n_action_steps)]
-            ).unsqueeze(0)
-
-    env = StubEnv()
-    policy = StubPolicy()
-    policy = TensorDictModule(
-        policy,
-        in_keys=[],
-        out_keys=["action"],
-    )
-
-    # Keep track to make sure the policy is called the expected number of times
-    rollout = env.rollout(rollout_max_steps, policy)
-
-    assert len(rollout) == terminate_at + 1  # +1 for the reset observation
-    assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
-    assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=4,
+        batch_size=cfg.policy.batch_size,
+        shuffle=True,
+        pin_memory=DEVICE != "cpu",
+        drop_last=True,
+    )
+    dl_iter = cycle(dataloader)
+
+    batch = next(dl_iter)
+
+    for key in batch:
+        batch[key] = batch[key].to(DEVICE, non_blocking=True)
+
+    # Test updating the policy
+    policy(batch, step=0)
+
+    # reset the policy and environment
+    policy.reset()
+    observation, _ = env.reset(seed=cfg.seed)
+
+    # apply transform to normalize the observations
+    observation = preprocess_observation(observation, dataset.transform)
+
+    # send observation to device/gpu
+    observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
+
+    # get the next action for the environment
+    with torch.inference_mode():
+        action = policy.select_action(observation, step=0)
+
+    # apply inverse transform to unnormalize the action
+    action = postprocess_action(action, dataset.transform)
+
+    # Test step through policy
+    env.step(action)