Add AbstractEnv, Refactor AlohaEnv, Add rendering_hook in env, Minor modifications (TODO: Refactor Pusht and Simxarm)

Cadene 2024-03-10 22:00:48 +00:00
parent b49f7b70e2
commit 7bf36cd413
11 changed files with 131 additions and 59 deletions

View File

@@ -124,9 +124,6 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
     def image_keys(self) -> list:
         return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
 
-    # def _is_downloaded(self) -> bool:
-    #     return False
-
     def _download_and_preproc(self):
         raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
         if not raw_dir.is_dir():

View File

@@ -0,0 +1,75 @@
+import abc
+from collections import deque
+from typing import Optional
+
+from tensordict import TensorDict
+from torchrl.envs import EnvBase
+
+
+class AbstractEnv(EnvBase):
+    def __init__(
+        self,
+        task,
+        frame_skip: int = 1,
+        from_pixels: bool = False,
+        pixels_only: bool = False,
+        image_size=None,
+        seed=1337,
+        device="cpu",
+        num_prev_obs=1,
+        num_prev_action=0,
+    ):
+        super().__init__(device=device, batch_size=[])
+        self.task = task
+        self.frame_skip = frame_skip
+        self.from_pixels = from_pixels
+        self.pixels_only = pixels_only
+        self.image_size = image_size
+        self.num_prev_obs = num_prev_obs
+        self.num_prev_action = num_prev_action
+        self._rendering_hooks = []
+
+        if pixels_only:
+            assert from_pixels
+        if from_pixels:
+            assert image_size
+
+        self._make_spec()
+        self._current_seed = self.set_seed(seed)
+
+        if self.num_prev_obs > 0:
+            self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
+            self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
+        if self.num_prev_action > 0:
+            raise NotImplementedError()
+            # self._prev_action_queue = deque(maxlen=self.num_prev_action)
+
+    def register_rendering_hook(self, func):
+        self._rendering_hooks.append(func)
+
+    def call_rendering_hooks(self):
+        for func in self._rendering_hooks:
+            func(self)
+
+    def reset_rendering_hooks(self):
+        self._rendering_hooks = []
+
+    @abc.abstractmethod
+    def render(self, mode="rgb_array", width=640, height=480):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _reset(self, tensordict: Optional[TensorDict] = None):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _step(self, tensordict: TensorDict):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _make_spec(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _set_seed(self, seed: Optional[int]):
+        raise NotImplementedError()
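The contract this base class imposes on concrete envs: a subclass supplies the five abstract methods, while the seeding order, the observation-history deques, and the hook plumbing all come from `__init__` above. A schematic sketch of that contract (the class name and method bodies are hypothetical, not part of this commit):

```python
from typing import Optional

from tensordict import TensorDict

from lerobot.common.envs.abstract import AbstractEnv


class MyEnv(AbstractEnv):
    def _make_spec(self):
        # Called by AbstractEnv.__init__ *before* set_seed: define the
        # observation/action/reward specs here.
        ...

    def _set_seed(self, seed: Optional[int]):
        ...  # seed the simulator and RNGs; return the seed actually used

    def _reset(self, tensordict: Optional[TensorDict] = None):
        ...  # reset the sim, refill the obs queues, then self.call_rendering_hooks()

    def _step(self, tensordict: TensorDict):
        ...  # apply the action frame_skip times, then self.call_rendering_hooks()

    def render(self, mode="rgb_array", width=640, height=480):
        ...  # return an RGB frame; this is what rendering hooks typically collect
```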

View File

@@ -15,8 +15,8 @@ from torchrl.data.tensor_specs import (
     DiscreteTensorSpec,
     UnboundedContinuousTensorSpec,
 )
-from torchrl.envs import EnvBase
 
+from lerobot.common.envs.abstract import AbstractEnv
 from lerobot.common.envs.aloha.constants import (
     ACTIONS,
     ASSETS_DIR,
@@ -28,14 +28,13 @@ from lerobot.common.envs.aloha.tasks.sim_end_effector import (
     InsertionEndEffectorTask,
     TransferCubeEndEffectorTask,
 )
-from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
-from lerobot.common.utils import set_seed
+from .utils import sample_box_pose, sample_insertion_pose
 
 _has_gym = importlib.util.find_spec("gym") is not None
 
 
-class AlohaEnv(EnvBase):
+class AlohaEnv(AbstractEnv):
     def __init__(
         self,
         task,
@@ -48,20 +47,17 @@ class AlohaEnv(EnvBase):
         num_prev_obs=1,
         num_prev_action=0,
     ):
-        super().__init__(device=device, batch_size=[])
-        self.task = task
-        self.frame_skip = frame_skip
-        self.from_pixels = from_pixels
-        self.pixels_only = pixels_only
-        self.image_size = image_size
-        self.num_prev_obs = num_prev_obs
-        self.num_prev_action = num_prev_action
-
-        if pixels_only:
-            assert from_pixels
-        if from_pixels:
-            assert image_size
+        super().__init__(
+            task=task,
+            frame_skip=frame_skip,
+            from_pixels=from_pixels,
+            pixels_only=pixels_only,
+            image_size=image_size,
+            seed=seed,
+            device=device,
+            num_prev_obs=num_prev_obs,
+            num_prev_action=num_prev_action,
+        )
 
         if not _has_gym:
             raise ImportError("Cannot import gym.")
@@ -70,16 +66,6 @@ class AlohaEnv(EnvBase):
         self._env = self._make_env_task(task)
 
         self._make_spec()
         self._current_seed = self.set_seed(seed)
 
-        if self.num_prev_obs > 0:
-            self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
-            self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
-
-        if self.num_prev_action > 0:
-            raise NotImplementedError()
-            # self._prev_action_queue = deque(maxlen=self.num_prev_action)
-
     def render(self, mode="rgb_array", width=640, height=480):
         # TODO(rcadene): render and visualize several cameras (e.g. angle, front_close)
         image = self._env.physics.render(height=height, width=width, camera_id="top")
@@ -172,6 +158,8 @@ class AlohaEnv(EnvBase):
             )
         else:
             raise NotImplementedError()
 
+        self.call_rendering_hooks()
+
         return td
 
     def _step(self, tensordict: TensorDict):
@@ -207,6 +195,8 @@ class AlohaEnv(EnvBase):
                 stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
             obs = stacked_obs
 
+        self.call_rendering_hooks()
+
         td = TensorDict(
             {
                 "observation": TensorDict(obs, batch_size=[]),

View File

@@ -27,7 +27,9 @@ def get_sinusoid_encoding_table(n_position, d_hid):
 class DETRVAE(nn.Module):
     """This is the DETR module that performs object detection"""
 
-    def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names):
+    def __init__(
+        self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae
+    ):
         """Initializes the model.
         Parameters:
             backbones: torch module of the backbone to be used. See backbone.py
@@ -42,6 +44,7 @@ class DETRVAE(nn.Module):
         self.camera_names = camera_names
         self.transformer = transformer
         self.encoder = encoder
+        self.vae = vae
         hidden_dim = transformer.d_model
         self.action_head = nn.Linear(hidden_dim, action_dim)
         self.is_pad_head = nn.Linear(hidden_dim, 1)
@@ -86,7 +89,7 @@ class DETRVAE(nn.Module):
         is_training = actions is not None  # train or val
         bs, _ = qpos.shape
         ### Obtain latent z from action sequence
-        if is_training:
+        if self.vae and is_training:
             # project action sequence to embedding dim, and concat with a CLS token
             action_embed = self.encoder_action_proj(actions)  # (bs, seq, hidden_dim)
             qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
@@ -200,6 +203,7 @@ def build(args):
         action_dim=args.action_dim,
         num_queries=args.num_queries,
         camera_names=args.camera_names,
+        vae=args.vae,
     )
 
     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
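With the new flag, DETRVAE infers the latent from the action sequence only when `self.vae` is set and actions are provided; otherwise it falls back to the prior, which in the upstream ACT code is an all-zeros latent. A distilled sketch of the gating (the function and dimensions are illustrative, not the module's real forward):

```python
import torch

def latent_for_decoder(vae: bool, actions, mu, logvar, bs=8, latent_dim=32):
    # Hypothetical condensation of DETRVAE's branch on `self.vae and is_training`.
    if vae and actions is not None:
        # reparameterized sample from the encoder posterior N(mu, exp(logvar))
        return mu + (0.5 * logvar).exp() * torch.randn_like(mu)
    # prior branch: deterministic all-zeros latent, as in the upstream ACT code
    return torch.zeros(bs, latent_dim)
```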

View File

@@ -11,7 +11,6 @@ from lerobot.common.policies.act.detr_vae import build
 def build_act_model_and_optimizer(cfg):
     model = build(cfg)
-    model.cuda()
 
     param_dicts = [
         {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
@@ -51,6 +50,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
         self.kl_weight = self.cfg.kl_weight
         logging.info(f"KL Weight {self.kl_weight}")
 
+        self.to(self.device)
+
     def update(self, replay_buffer, step):
         del step
@@ -192,20 +193,25 @@ class ActionChunkingTransformerPolicy(nn.Module):
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         image = normalize(image)
 
-        is_train_mode = actions is not None
-        if is_train_mode:  # training time
+        is_training = actions is not None
+        if is_training:  # training time
             actions = actions[:, : self.model.num_queries]
             if is_pad is not None:
                 is_pad = is_pad[:, : self.model.num_queries]
 
             a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
-            total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
-            loss_dict = {}
             all_l1 = F.l1_loss(actions, a_hat, reduction="none")
             l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
+
+            loss_dict = {}
             loss_dict["l1"] = l1
-            loss_dict["kl"] = total_kld[0]
-            loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
+            if self.cfg.vae:
+                total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
+                loss_dict["kl"] = total_kld[0]
+                loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
+            else:
+                loss_dict["loss"] = loss_dict["l1"]
             return loss_dict
         else:
             action, _, (_, _) = self.model(qpos, image, env_state)  # no action, sample from prior
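`kl_divergence(mu, logvar)` is indexed as `total_kld[0]` above, so it returns batched aggregates rather than a bare scalar. A sketch of the closed-form Gaussian KL that a helper of this shape typically computes (an assumption; not copied from the repo):

```python
import torch

def kl_divergence(mu, logvar):
    # KL( N(mu, sigma^2) || N(0, I) ), computed per latent dimension.
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, keepdim=True)  # 1-element tensor, hence total_kld[0]
    dim_wise_kld = klds.mean(0)                    # per-dimension KL, averaged over the batch
    mean_kld = klds.mean(1).mean(0, keepdim=True)  # mean over dims, then over the batch
    return total_kld, dim_wise_kld, mean_kld
```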

View File

@@ -17,7 +17,7 @@ env:
   pixels_only: False
   image_size: [3, 480, 640]
   action_repeat: 1
-  episode_length: 300
+  episode_length: 400
   fps: ${fps}
 
 policy:

View File

@@ -29,9 +29,10 @@ policy:
   kl_weight: 10
   hidden_dim: 512
   dim_feedforward: 3200
-  enc_layers: 7
-  dec_layers: 8
+  enc_layers: 4
+  dec_layers: 7
   nheads: 8
+  #camera_names: [top, front_close, left_pillar, right_pillar]
   camera_names: [top]
   position_embedding: sine
   masks: false
@@ -39,6 +40,8 @@ policy:
   dropout: 0.1
   pre_norm: false
+  vae: true
+
   batch_size: 8
   per_alpha: 0.6
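These fields reach the model through `build_act_model_and_optimizer(cfg)` and `build(args)` in the earlier hunks. A quick sanity check of the new values (the YAML path is a guess from the config's name, not verified against the repo layout):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("lerobot/configs/policy/act.yaml")  # hypothetical path
assert cfg.policy.vae is True
assert (cfg.policy.enc_layers, cfg.policy.dec_layers) == (4, 7)
assert cfg.policy.camera_names == ["top"]
```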

View File

@@ -38,27 +38,18 @@ def eval_policy(
     successes = []
     threads = []
     for i in tqdm.tqdm(range(num_episodes)):
-        tensordict = env.reset()
-
         ep_frames = []
         if save_video or (return_first_video and i == 0):
 
-            def rendering_callback(env, td=None):
+            def render_frame(env):
                 ep_frames.append(env.render())  # noqa: B023
 
-            # render first frame before rollout
-            rendering_callback(env)
-        else:
-            rendering_callback = None
+            env.register_rendering_hook(render_frame)
 
         with torch.inference_mode():
             rollout = env.rollout(
                 max_steps=max_steps,
                 policy=policy,
-                callback=rendering_callback,
-                auto_reset=False,
-                tensordict=tensordict,
                 auto_cast_to_device=True,
             )
         # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
@@ -85,6 +76,8 @@ def eval_policy(
         if return_first_video and i == 0:
             first_video = stacked_frames.transpose(0, 3, 1, 2)
 
+        env.reset_rendering_hooks()
+
     for thread in threads:
         thread.join()
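The hook replaces torchrl's per-step `callback=`, which never fired at reset; that is why the old code rendered the first frame manually before the rollout. The per-episode lifecycle now reads roughly like this (condensed; `env` and `policy` are assumed to be constructed as in the surrounding script):

```python
import torch

num_episodes, max_steps, save_video = 10, 300, False  # illustrative values

for i in range(num_episodes):
    ep_frames = []
    if save_video or i == 0:
        # The hook fires inside _reset and _step, so the first frame is captured too.
        env.register_rendering_hook(lambda env: ep_frames.append(env.render()))

    with torch.inference_mode():
        rollout = env.rollout(max_steps=max_steps, policy=policy, auto_cast_to_device=True)

    env.reset_rendering_hooks()  # detach hooks so the next episode starts clean
```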

View File

@@ -1,4 +1,5 @@
 import logging
+from pathlib import Path
 
 import hydra
 import numpy as np
@@ -192,6 +193,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 num_episodes=cfg.eval_episodes,
                 max_steps=cfg.env.episode_length // cfg.n_action_steps,
                 return_first_video=True,
+                video_dir=Path(out_dir) / "eval",
+                save_video=True,
             )
             log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
             if cfg.wandb.enable:

View File

@@ -17,6 +17,7 @@ apptainer exec --nv \
     ~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL
 
 source ~/.bashrc
-conda activate fowm
+#conda activate fowm
+conda activate lerobot
 
 srun $CMD

View File

@@ -12,10 +12,10 @@ from .utils import init_config
         # ("simxarm", "lift"),
         ("pusht", "pusht"),
         # TODO(aliberts): add aloha when dataset is available on hub
-        # ("aloha", "sim_insertion_human"),
-        # ("aloha", "sim_insertion_scripted"),
-        # ("aloha", "sim_transfer_cube_human"),
-        # ("aloha", "sim_transfer_cube_scripted"),
+        ("aloha", "sim_insertion_human"),
+        ("aloha", "sim_insertion_scripted"),
+        ("aloha", "sim_transfer_cube_human"),
+        ("aloha", "sim_transfer_cube_scripted"),
     ],
 )
 def test_factory(env_name, dataset_id):