Add AbstractEnv, Refactor AlohaEnv, Add rendering_hook in env, Minor modifications, (TODO: Refactor Pusht and Simxarm)
This commit is contained in:
parent
b49f7b70e2
commit
7bf36cd413
|
@ -124,9 +124,6 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
|
|||
def image_keys(self) -> list:
|
||||
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
|
||||
|
||||
# def _is_downloaded(self) -> bool:
|
||||
# return False
|
||||
|
||||
def _download_and_preproc(self):
|
||||
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
|
||||
if not raw_dir.is_dir():
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
import abc
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
from tensordict import TensorDict
|
||||
from torchrl.envs import EnvBase
|
||||
|
||||
|
||||
class AbstractEnv(EnvBase):
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=1,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.task = task
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
self.num_prev_obs = num_prev_obs
|
||||
self.num_prev_action = num_prev_action
|
||||
self._rendering_hooks = []
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
if from_pixels:
|
||||
assert image_size
|
||||
|
||||
self._make_spec()
|
||||
self._current_seed = self.set_seed(seed)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
||||
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
||||
if self.num_prev_action > 0:
|
||||
raise NotImplementedError()
|
||||
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||
|
||||
def register_rendering_hook(self, func):
|
||||
self._rendering_hooks.append(func)
|
||||
|
||||
def call_rendering_hooks(self):
|
||||
for func in self._rendering_hooks:
|
||||
func(self)
|
||||
|
||||
def reset_rendering_hooks(self):
|
||||
self._rendering_hooks = []
|
||||
|
||||
@abc.abstractmethod
|
||||
def render(self, mode="rgb_array", width=640, height=480):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _step(self, tensordict: TensorDict):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _make_spec(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
raise NotImplementedError()
|
|
@ -15,8 +15,8 @@ from torchrl.data.tensor_specs import (
|
|||
DiscreteTensorSpec,
|
||||
UnboundedContinuousTensorSpec,
|
||||
)
|
||||
from torchrl.envs import EnvBase
|
||||
|
||||
from lerobot.common.envs.abstract import AbstractEnv
|
||||
from lerobot.common.envs.aloha.constants import (
|
||||
ACTIONS,
|
||||
ASSETS_DIR,
|
||||
|
@ -28,14 +28,13 @@ from lerobot.common.envs.aloha.tasks.sim_end_effector import (
|
|||
InsertionEndEffectorTask,
|
||||
TransferCubeEndEffectorTask,
|
||||
)
|
||||
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
from .utils import sample_box_pose, sample_insertion_pose
|
||||
|
||||
_has_gym = importlib.util.find_spec("gym") is not None
|
||||
|
||||
|
||||
class AlohaEnv(EnvBase):
|
||||
class AlohaEnv(AbstractEnv):
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
|
@ -48,20 +47,17 @@ class AlohaEnv(EnvBase):
|
|||
num_prev_obs=1,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.task = task
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
self.num_prev_obs = num_prev_obs
|
||||
self.num_prev_action = num_prev_action
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
if from_pixels:
|
||||
assert image_size
|
||||
|
||||
super().__init__(
|
||||
task=task,
|
||||
frame_skip=frame_skip,
|
||||
from_pixels=from_pixels,
|
||||
pixels_only=pixels_only,
|
||||
image_size=image_size,
|
||||
seed=seed,
|
||||
device=device,
|
||||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
|
||||
|
@ -70,16 +66,6 @@ class AlohaEnv(EnvBase):
|
|||
|
||||
self._env = self._make_env_task(task)
|
||||
|
||||
self._make_spec()
|
||||
self._current_seed = self.set_seed(seed)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
||||
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
||||
if self.num_prev_action > 0:
|
||||
raise NotImplementedError()
|
||||
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||
|
||||
def render(self, mode="rgb_array", width=640, height=480):
|
||||
# TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
|
||||
image = self._env.physics.render(height=height, width=width, camera_id="top")
|
||||
|
@ -172,6 +158,8 @@ class AlohaEnv(EnvBase):
|
|||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.call_rendering_hooks()
|
||||
return td
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
|
@ -207,6 +195,8 @@ class AlohaEnv(EnvBase):
|
|||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
self.call_rendering_hooks()
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
|
|
|
@ -27,7 +27,9 @@ def get_sinusoid_encoding_table(n_position, d_hid):
|
|||
class DETRVAE(nn.Module):
|
||||
"""This is the DETR module that performs object detection"""
|
||||
|
||||
def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names):
|
||||
def __init__(
|
||||
self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae
|
||||
):
|
||||
"""Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
|
@ -42,6 +44,7 @@ class DETRVAE(nn.Module):
|
|||
self.camera_names = camera_names
|
||||
self.transformer = transformer
|
||||
self.encoder = encoder
|
||||
self.vae = vae
|
||||
hidden_dim = transformer.d_model
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
||||
|
@ -86,7 +89,7 @@ class DETRVAE(nn.Module):
|
|||
is_training = actions is not None # train or val
|
||||
bs, _ = qpos.shape
|
||||
### Obtain latent z from action sequence
|
||||
if is_training:
|
||||
if self.vae and is_training:
|
||||
# project action sequence to embedding dim, and concat with a CLS token
|
||||
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
|
||||
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
|
||||
|
@ -200,6 +203,7 @@ def build(args):
|
|||
action_dim=args.action_dim,
|
||||
num_queries=args.num_queries,
|
||||
camera_names=args.camera_names,
|
||||
vae=args.vae,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
|
|
@ -11,7 +11,6 @@ from lerobot.common.policies.act.detr_vae import build
|
|||
|
||||
def build_act_model_and_optimizer(cfg):
|
||||
model = build(cfg)
|
||||
model.cuda()
|
||||
|
||||
param_dicts = [
|
||||
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
||||
|
@ -51,6 +50,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
self.kl_weight = self.cfg.kl_weight
|
||||
logging.info(f"KL Weight {self.kl_weight}")
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def update(self, replay_buffer, step):
|
||||
del step
|
||||
|
||||
|
@ -192,20 +193,25 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
image = normalize(image)
|
||||
|
||||
is_train_mode = actions is not None
|
||||
if is_train_mode: # training time
|
||||
is_training = actions is not None
|
||||
if is_training: # training time
|
||||
actions = actions[:, : self.model.num_queries]
|
||||
if is_pad is not None:
|
||||
is_pad = is_pad[:, : self.model.num_queries]
|
||||
|
||||
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||
loss_dict = {}
|
||||
|
||||
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||
l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||
|
||||
loss_dict = {}
|
||||
loss_dict["l1"] = l1
|
||||
if self.cfg.vae:
|
||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||
loss_dict["kl"] = total_kld[0]
|
||||
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = loss_dict["l1"]
|
||||
return loss_dict
|
||||
else:
|
||||
action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
||||
|
|
|
@ -17,7 +17,7 @@ env:
|
|||
pixels_only: False
|
||||
image_size: [3, 480, 640]
|
||||
action_repeat: 1
|
||||
episode_length: 300
|
||||
episode_length: 400
|
||||
fps: ${fps}
|
||||
|
||||
policy:
|
||||
|
|
|
@ -29,9 +29,10 @@ policy:
|
|||
kl_weight: 10
|
||||
hidden_dim: 512
|
||||
dim_feedforward: 3200
|
||||
enc_layers: 7
|
||||
dec_layers: 8
|
||||
enc_layers: 4
|
||||
dec_layers: 7
|
||||
nheads: 8
|
||||
#camera_names: [top, front_close, left_pillar, right_pillar]
|
||||
camera_names: [top]
|
||||
position_embedding: sine
|
||||
masks: false
|
||||
|
@ -39,6 +40,8 @@ policy:
|
|||
dropout: 0.1
|
||||
pre_norm: false
|
||||
|
||||
vae: true
|
||||
|
||||
batch_size: 8
|
||||
|
||||
per_alpha: 0.6
|
||||
|
|
|
@ -38,27 +38,18 @@ def eval_policy(
|
|||
successes = []
|
||||
threads = []
|
||||
for i in tqdm.tqdm(range(num_episodes)):
|
||||
tensordict = env.reset()
|
||||
|
||||
ep_frames = []
|
||||
|
||||
if save_video or (return_first_video and i == 0):
|
||||
|
||||
def rendering_callback(env, td=None):
|
||||
def render_frame(env):
|
||||
ep_frames.append(env.render()) # noqa: B023
|
||||
|
||||
# render first frame before rollout
|
||||
rendering_callback(env)
|
||||
else:
|
||||
rendering_callback = None
|
||||
env.register_rendering_hook(render_frame)
|
||||
|
||||
with torch.inference_mode():
|
||||
rollout = env.rollout(
|
||||
max_steps=max_steps,
|
||||
policy=policy,
|
||||
callback=rendering_callback,
|
||||
auto_reset=False,
|
||||
tensordict=tensordict,
|
||||
auto_cast_to_device=True,
|
||||
)
|
||||
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
|
||||
|
@ -85,6 +76,8 @@ def eval_policy(
|
|||
if return_first_video and i == 0:
|
||||
first_video = stacked_frames.transpose(0, 3, 1, 2)
|
||||
|
||||
env.reset_rendering_hooks()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
|
@ -192,6 +193,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
num_episodes=cfg.eval_episodes,
|
||||
max_steps=cfg.env.episode_length // cfg.n_action_steps,
|
||||
return_first_video=True,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
save_video=True,
|
||||
)
|
||||
log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
|
||||
if cfg.wandb.enable:
|
||||
|
|
|
@ -17,6 +17,7 @@ apptainer exec --nv \
|
|||
~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL
|
||||
|
||||
source ~/.bashrc
|
||||
conda activate fowm
|
||||
#conda activate fowm
|
||||
conda activate lerobot
|
||||
|
||||
srun $CMD
|
||||
|
|
|
@ -12,10 +12,10 @@ from .utils import init_config
|
|||
# ("simxarm", "lift"),
|
||||
("pusht", "pusht"),
|
||||
# TODO(aliberts): add aloha when dataset is available on hub
|
||||
# ("aloha", "sim_insertion_human"),
|
||||
# ("aloha", "sim_insertion_scripted"),
|
||||
# ("aloha", "sim_transfer_cube_human"),
|
||||
# ("aloha", "sim_transfer_cube_scripted"),
|
||||
("aloha", "sim_insertion_human"),
|
||||
("aloha", "sim_insertion_scripted"),
|
||||
("aloha", "sim_transfer_cube_human"),
|
||||
("aloha", "sim_transfer_cube_scripted"),
|
||||
],
|
||||
)
|
||||
def test_factory(env_name, dataset_id):
|
||||
|
|
Loading…
Reference in New Issue