Add AbstractEnv, Refactor AlohaEnv, Add rendering_hook in env, Minor modifications, (TODO: Refactor Pusht and Simxarm)
This commit is contained in:
parent
b49f7b70e2
commit
7bf36cd413
|
@ -124,9 +124,6 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
|
||||||
def image_keys(self) -> list:
|
def image_keys(self) -> list:
|
||||||
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
|
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
|
||||||
|
|
||||||
# def _is_downloaded(self) -> bool:
|
|
||||||
# return False
|
|
||||||
|
|
||||||
def _download_and_preproc(self):
|
def _download_and_preproc(self):
|
||||||
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
|
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
|
||||||
if not raw_dir.is_dir():
|
if not raw_dir.is_dir():
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
import abc
|
||||||
|
from collections import deque
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from tensordict import TensorDict
|
||||||
|
from torchrl.envs import EnvBase
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractEnv(EnvBase):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task,
|
||||||
|
frame_skip: int = 1,
|
||||||
|
from_pixels: bool = False,
|
||||||
|
pixels_only: bool = False,
|
||||||
|
image_size=None,
|
||||||
|
seed=1337,
|
||||||
|
device="cpu",
|
||||||
|
num_prev_obs=1,
|
||||||
|
num_prev_action=0,
|
||||||
|
):
|
||||||
|
super().__init__(device=device, batch_size=[])
|
||||||
|
self.task = task
|
||||||
|
self.frame_skip = frame_skip
|
||||||
|
self.from_pixels = from_pixels
|
||||||
|
self.pixels_only = pixels_only
|
||||||
|
self.image_size = image_size
|
||||||
|
self.num_prev_obs = num_prev_obs
|
||||||
|
self.num_prev_action = num_prev_action
|
||||||
|
self._rendering_hooks = []
|
||||||
|
|
||||||
|
if pixels_only:
|
||||||
|
assert from_pixels
|
||||||
|
if from_pixels:
|
||||||
|
assert image_size
|
||||||
|
|
||||||
|
self._make_spec()
|
||||||
|
self._current_seed = self.set_seed(seed)
|
||||||
|
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
||||||
|
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
||||||
|
if self.num_prev_action > 0:
|
||||||
|
raise NotImplementedError()
|
||||||
|
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||||
|
|
||||||
|
def register_rendering_hook(self, func):
|
||||||
|
self._rendering_hooks.append(func)
|
||||||
|
|
||||||
|
def call_rendering_hooks(self):
|
||||||
|
for func in self._rendering_hooks:
|
||||||
|
func(self)
|
||||||
|
|
||||||
|
def reset_rendering_hooks(self):
|
||||||
|
self._rendering_hooks = []
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def render(self, mode="rgb_array", width=640, height=480):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _step(self, tensordict: TensorDict):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _make_spec(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _set_seed(self, seed: Optional[int]):
|
||||||
|
raise NotImplementedError()
|
|
@ -15,8 +15,8 @@ from torchrl.data.tensor_specs import (
|
||||||
DiscreteTensorSpec,
|
DiscreteTensorSpec,
|
||||||
UnboundedContinuousTensorSpec,
|
UnboundedContinuousTensorSpec,
|
||||||
)
|
)
|
||||||
from torchrl.envs import EnvBase
|
|
||||||
|
|
||||||
|
from lerobot.common.envs.abstract import AbstractEnv
|
||||||
from lerobot.common.envs.aloha.constants import (
|
from lerobot.common.envs.aloha.constants import (
|
||||||
ACTIONS,
|
ACTIONS,
|
||||||
ASSETS_DIR,
|
ASSETS_DIR,
|
||||||
|
@ -28,14 +28,13 @@ from lerobot.common.envs.aloha.tasks.sim_end_effector import (
|
||||||
InsertionEndEffectorTask,
|
InsertionEndEffectorTask,
|
||||||
TransferCubeEndEffectorTask,
|
TransferCubeEndEffectorTask,
|
||||||
)
|
)
|
||||||
|
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
|
||||||
from lerobot.common.utils import set_seed
|
from lerobot.common.utils import set_seed
|
||||||
|
|
||||||
from .utils import sample_box_pose, sample_insertion_pose
|
|
||||||
|
|
||||||
_has_gym = importlib.util.find_spec("gym") is not None
|
_has_gym = importlib.util.find_spec("gym") is not None
|
||||||
|
|
||||||
|
|
||||||
class AlohaEnv(EnvBase):
|
class AlohaEnv(AbstractEnv):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
task,
|
task,
|
||||||
|
@ -48,20 +47,17 @@ class AlohaEnv(EnvBase):
|
||||||
num_prev_obs=1,
|
num_prev_obs=1,
|
||||||
num_prev_action=0,
|
num_prev_action=0,
|
||||||
):
|
):
|
||||||
super().__init__(device=device, batch_size=[])
|
super().__init__(
|
||||||
self.task = task
|
task=task,
|
||||||
self.frame_skip = frame_skip
|
frame_skip=frame_skip,
|
||||||
self.from_pixels = from_pixels
|
from_pixels=from_pixels,
|
||||||
self.pixels_only = pixels_only
|
pixels_only=pixels_only,
|
||||||
self.image_size = image_size
|
image_size=image_size,
|
||||||
self.num_prev_obs = num_prev_obs
|
seed=seed,
|
||||||
self.num_prev_action = num_prev_action
|
device=device,
|
||||||
|
num_prev_obs=num_prev_obs,
|
||||||
if pixels_only:
|
num_prev_action=num_prev_action,
|
||||||
assert from_pixels
|
)
|
||||||
if from_pixels:
|
|
||||||
assert image_size
|
|
||||||
|
|
||||||
if not _has_gym:
|
if not _has_gym:
|
||||||
raise ImportError("Cannot import gym.")
|
raise ImportError("Cannot import gym.")
|
||||||
|
|
||||||
|
@ -70,16 +66,6 @@ class AlohaEnv(EnvBase):
|
||||||
|
|
||||||
self._env = self._make_env_task(task)
|
self._env = self._make_env_task(task)
|
||||||
|
|
||||||
self._make_spec()
|
|
||||||
self._current_seed = self.set_seed(seed)
|
|
||||||
|
|
||||||
if self.num_prev_obs > 0:
|
|
||||||
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
|
||||||
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
|
||||||
if self.num_prev_action > 0:
|
|
||||||
raise NotImplementedError()
|
|
||||||
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
|
||||||
|
|
||||||
def render(self, mode="rgb_array", width=640, height=480):
|
def render(self, mode="rgb_array", width=640, height=480):
|
||||||
# TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
|
# TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
|
||||||
image = self._env.physics.render(height=height, width=width, camera_id="top")
|
image = self._env.physics.render(height=height, width=width, camera_id="top")
|
||||||
|
@ -172,6 +158,8 @@ class AlohaEnv(EnvBase):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
self.call_rendering_hooks()
|
||||||
return td
|
return td
|
||||||
|
|
||||||
def _step(self, tensordict: TensorDict):
|
def _step(self, tensordict: TensorDict):
|
||||||
|
@ -207,6 +195,8 @@ class AlohaEnv(EnvBase):
|
||||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||||
obs = stacked_obs
|
obs = stacked_obs
|
||||||
|
|
||||||
|
self.call_rendering_hooks()
|
||||||
|
|
||||||
td = TensorDict(
|
td = TensorDict(
|
||||||
{
|
{
|
||||||
"observation": TensorDict(obs, batch_size=[]),
|
"observation": TensorDict(obs, batch_size=[]),
|
||||||
|
|
|
@ -27,7 +27,9 @@ def get_sinusoid_encoding_table(n_position, d_hid):
|
||||||
class DETRVAE(nn.Module):
|
class DETRVAE(nn.Module):
|
||||||
"""This is the DETR module that performs object detection"""
|
"""This is the DETR module that performs object detection"""
|
||||||
|
|
||||||
def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names):
|
def __init__(
|
||||||
|
self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae
|
||||||
|
):
|
||||||
"""Initializes the model.
|
"""Initializes the model.
|
||||||
Parameters:
|
Parameters:
|
||||||
backbones: torch module of the backbone to be used. See backbone.py
|
backbones: torch module of the backbone to be used. See backbone.py
|
||||||
|
@ -42,6 +44,7 @@ class DETRVAE(nn.Module):
|
||||||
self.camera_names = camera_names
|
self.camera_names = camera_names
|
||||||
self.transformer = transformer
|
self.transformer = transformer
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
|
self.vae = vae
|
||||||
hidden_dim = transformer.d_model
|
hidden_dim = transformer.d_model
|
||||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||||
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
||||||
|
@ -86,7 +89,7 @@ class DETRVAE(nn.Module):
|
||||||
is_training = actions is not None # train or val
|
is_training = actions is not None # train or val
|
||||||
bs, _ = qpos.shape
|
bs, _ = qpos.shape
|
||||||
### Obtain latent z from action sequence
|
### Obtain latent z from action sequence
|
||||||
if is_training:
|
if self.vae and is_training:
|
||||||
# project action sequence to embedding dim, and concat with a CLS token
|
# project action sequence to embedding dim, and concat with a CLS token
|
||||||
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
|
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
|
||||||
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
|
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
|
||||||
|
@ -200,6 +203,7 @@ def build(args):
|
||||||
action_dim=args.action_dim,
|
action_dim=args.action_dim,
|
||||||
num_queries=args.num_queries,
|
num_queries=args.num_queries,
|
||||||
camera_names=args.camera_names,
|
camera_names=args.camera_names,
|
||||||
|
vae=args.vae,
|
||||||
)
|
)
|
||||||
|
|
||||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
|
@ -11,7 +11,6 @@ from lerobot.common.policies.act.detr_vae import build
|
||||||
|
|
||||||
def build_act_model_and_optimizer(cfg):
|
def build_act_model_and_optimizer(cfg):
|
||||||
model = build(cfg)
|
model = build(cfg)
|
||||||
model.cuda()
|
|
||||||
|
|
||||||
param_dicts = [
|
param_dicts = [
|
||||||
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
||||||
|
@ -51,6 +50,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
self.kl_weight = self.cfg.kl_weight
|
self.kl_weight = self.cfg.kl_weight
|
||||||
logging.info(f"KL Weight {self.kl_weight}")
|
logging.info(f"KL Weight {self.kl_weight}")
|
||||||
|
|
||||||
|
self.to(self.device)
|
||||||
|
|
||||||
def update(self, replay_buffer, step):
|
def update(self, replay_buffer, step):
|
||||||
del step
|
del step
|
||||||
|
|
||||||
|
@ -192,20 +193,25 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
image = normalize(image)
|
image = normalize(image)
|
||||||
|
|
||||||
is_train_mode = actions is not None
|
is_training = actions is not None
|
||||||
if is_train_mode: # training time
|
if is_training: # training time
|
||||||
actions = actions[:, : self.model.num_queries]
|
actions = actions[:, : self.model.num_queries]
|
||||||
if is_pad is not None:
|
if is_pad is not None:
|
||||||
is_pad = is_pad[:, : self.model.num_queries]
|
is_pad = is_pad[:, : self.model.num_queries]
|
||||||
|
|
||||||
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
||||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
|
||||||
loss_dict = {}
|
|
||||||
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||||
l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||||
|
|
||||||
|
loss_dict = {}
|
||||||
loss_dict["l1"] = l1
|
loss_dict["l1"] = l1
|
||||||
|
if self.cfg.vae:
|
||||||
|
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||||
loss_dict["kl"] = total_kld[0]
|
loss_dict["kl"] = total_kld[0]
|
||||||
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
||||||
|
else:
|
||||||
|
loss_dict["loss"] = loss_dict["l1"]
|
||||||
return loss_dict
|
return loss_dict
|
||||||
else:
|
else:
|
||||||
action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
||||||
|
|
|
@ -17,7 +17,7 @@ env:
|
||||||
pixels_only: False
|
pixels_only: False
|
||||||
image_size: [3, 480, 640]
|
image_size: [3, 480, 640]
|
||||||
action_repeat: 1
|
action_repeat: 1
|
||||||
episode_length: 300
|
episode_length: 400
|
||||||
fps: ${fps}
|
fps: ${fps}
|
||||||
|
|
||||||
policy:
|
policy:
|
||||||
|
|
|
@ -29,9 +29,10 @@ policy:
|
||||||
kl_weight: 10
|
kl_weight: 10
|
||||||
hidden_dim: 512
|
hidden_dim: 512
|
||||||
dim_feedforward: 3200
|
dim_feedforward: 3200
|
||||||
enc_layers: 7
|
enc_layers: 4
|
||||||
dec_layers: 8
|
dec_layers: 7
|
||||||
nheads: 8
|
nheads: 8
|
||||||
|
#camera_names: [top, front_close, left_pillar, right_pillar]
|
||||||
camera_names: [top]
|
camera_names: [top]
|
||||||
position_embedding: sine
|
position_embedding: sine
|
||||||
masks: false
|
masks: false
|
||||||
|
@ -39,6 +40,8 @@ policy:
|
||||||
dropout: 0.1
|
dropout: 0.1
|
||||||
pre_norm: false
|
pre_norm: false
|
||||||
|
|
||||||
|
vae: true
|
||||||
|
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
|
|
||||||
per_alpha: 0.6
|
per_alpha: 0.6
|
||||||
|
|
|
@ -38,27 +38,18 @@ def eval_policy(
|
||||||
successes = []
|
successes = []
|
||||||
threads = []
|
threads = []
|
||||||
for i in tqdm.tqdm(range(num_episodes)):
|
for i in tqdm.tqdm(range(num_episodes)):
|
||||||
tensordict = env.reset()
|
|
||||||
|
|
||||||
ep_frames = []
|
ep_frames = []
|
||||||
|
|
||||||
if save_video or (return_first_video and i == 0):
|
if save_video or (return_first_video and i == 0):
|
||||||
|
|
||||||
def rendering_callback(env, td=None):
|
def render_frame(env):
|
||||||
ep_frames.append(env.render()) # noqa: B023
|
ep_frames.append(env.render()) # noqa: B023
|
||||||
|
|
||||||
# render first frame before rollout
|
env.register_rendering_hook(render_frame)
|
||||||
rendering_callback(env)
|
|
||||||
else:
|
|
||||||
rendering_callback = None
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
rollout = env.rollout(
|
rollout = env.rollout(
|
||||||
max_steps=max_steps,
|
max_steps=max_steps,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
callback=rendering_callback,
|
|
||||||
auto_reset=False,
|
|
||||||
tensordict=tensordict,
|
|
||||||
auto_cast_to_device=True,
|
auto_cast_to_device=True,
|
||||||
)
|
)
|
||||||
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
|
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
|
||||||
|
@ -85,6 +76,8 @@ def eval_policy(
|
||||||
if return_first_video and i == 0:
|
if return_first_video and i == 0:
|
||||||
first_video = stacked_frames.transpose(0, 3, 1, 2)
|
first_video = stacked_frames.transpose(0, 3, 1, 2)
|
||||||
|
|
||||||
|
env.reset_rendering_hooks()
|
||||||
|
|
||||||
for thread in threads:
|
for thread in threads:
|
||||||
thread.join()
|
thread.join()
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -192,6 +193,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
num_episodes=cfg.eval_episodes,
|
num_episodes=cfg.eval_episodes,
|
||||||
max_steps=cfg.env.episode_length // cfg.n_action_steps,
|
max_steps=cfg.env.episode_length // cfg.n_action_steps,
|
||||||
return_first_video=True,
|
return_first_video=True,
|
||||||
|
video_dir=Path(out_dir) / "eval",
|
||||||
|
save_video=True,
|
||||||
)
|
)
|
||||||
log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
|
log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
|
||||||
if cfg.wandb.enable:
|
if cfg.wandb.enable:
|
||||||
|
|
|
@ -17,6 +17,7 @@ apptainer exec --nv \
|
||||||
~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL
|
~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL
|
||||||
|
|
||||||
source ~/.bashrc
|
source ~/.bashrc
|
||||||
conda activate fowm
|
#conda activate fowm
|
||||||
|
conda activate lerobot
|
||||||
|
|
||||||
srun $CMD
|
srun $CMD
|
||||||
|
|
|
@ -12,10 +12,10 @@ from .utils import init_config
|
||||||
# ("simxarm", "lift"),
|
# ("simxarm", "lift"),
|
||||||
("pusht", "pusht"),
|
("pusht", "pusht"),
|
||||||
# TODO(aliberts): add aloha when dataset is available on hub
|
# TODO(aliberts): add aloha when dataset is available on hub
|
||||||
# ("aloha", "sim_insertion_human"),
|
("aloha", "sim_insertion_human"),
|
||||||
# ("aloha", "sim_insertion_scripted"),
|
("aloha", "sim_insertion_scripted"),
|
||||||
# ("aloha", "sim_transfer_cube_human"),
|
("aloha", "sim_transfer_cube_human"),
|
||||||
# ("aloha", "sim_transfer_cube_scripted"),
|
("aloha", "sim_transfer_cube_scripted"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_factory(env_name, dataset_id):
|
def test_factory(env_name, dataset_id):
|
||||||
|
|
Loading…
Reference in New Issue