ready for review

Alexander Soare 2024-04-08 13:10:19 +01:00
parent 1bab4a1dd5
commit 863f28ffd8
7 changed files with 92 additions and 168 deletions

View File

@@ -59,96 +59,10 @@ def make_dataset(
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_or_load_stats(stats_dataset)
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
# # TODO(now): These stats are needed to use their pretrained model for sim_transfer_cube_human.
# # (Pdb) stats['observation']['state']['mean']
# # tensor([-0.0071, -0.6293, 1.0351, -0.0517, -0.4642, -0.0754, 0.4751, -0.0373,
# # -0.3324, 0.9034, -0.2258, -0.3127, -0.2412, 0.6866])
# stats["observation", "state", "mean"] = torch.tensor(
# [
# -0.00740268,
# -0.63187766,
# 1.0356655,
# -0.05027218,
# -0.46199223,
# -0.07467502,
# 0.47467607,
# -0.03615446,
# -0.33203387,
# 0.9038929,
# -0.22060776,
# -0.31011587,
# -0.23484458,
# 0.6842416,
# ]
# )
# # (Pdb) stats['observation']['state']['std']
# # tensor([0.0022, 0.0520, 0.0291, 0.0092, 0.0267, 0.0145, 0.0563, 0.0179, 0.0494,
# # 0.0326, 0.0476, 0.0535, 0.0956, 0.0513])
# stats["observation", "state", "std"] = torch.tensor(
# [
# 0.01219023,
# 0.2975381,
# 0.16728032,
# 0.04733803,
# 0.1486037,
# 0.08788499,
# 0.31752336,
# 0.1049916,
# 0.27933604,
# 0.18094037,
# 0.26604933,
# 0.30466506,
# 0.5298686,
# 0.25505227,
# ]
# )
# # (Pdb) stats['action']['mean']
# # tensor([-0.0075, -0.6346, 1.0353, -0.0465, -0.4686, -0.0738, 0.3723, -0.0396,
# # -0.3184, 0.8991, -0.2065, -0.3182, -0.2338, 0.5593])
# stats["action"]["mean"] = torch.tensor(
# [
# -0.00756444,
# -0.6281845,
# 1.0312834,
# -0.04664314,
# -0.47211358,
# -0.074527,
# 0.37389806,
# -0.03718753,
# -0.3261143,
# 0.8997205,
# -0.21371077,
# -0.31840396,
# -0.23360962,
# 0.551947,
# ]
# )
# # (Pdb) stats['action']['std']
# # tensor([0.0023, 0.0514, 0.0290, 0.0086, 0.0263, 0.0143, 0.0593, 0.0185, 0.0510,
# # 0.0328, 0.0478, 0.0531, 0.0945, 0.0794])
# stats["action"]["std"] = torch.tensor(
# [
# 0.01252818,
# 0.2957442,
# 0.16701928,
# 0.04584508,
# 0.14833844,
# 0.08763024,
# 0.30665937,
# 0.10600077,
# 0.27572668,
# 0.1805853,
# 0.26304692,
# 0.30708534,
# 0.5305411,
# 0.38381037,
# ]
# )
# transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode)) # noqa: F821
transforms = v2.Compose(
[
# TODO(rcadene): we need to do something about image_keys
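For context on the `normalization_mode` switch above: `mean_std` standardizes features using precomputed dataset statistics, while `min_max` rescales them to a fixed range. A minimal sketch of the two modes (a hypothetical helper, not the repo's `NormalizeTransform`; the [-1, 1] output range for `min_max` is an assumption here):

import torch

def normalize(x: torch.Tensor, stats: dict[str, torch.Tensor], mode: str) -> torch.Tensor:
    if mode == "mean_std":
        # Zero mean, unit variance, using precomputed dataset statistics.
        return (x - stats["mean"]) / (stats["std"] + 1e-8)
    if mode == "min_max":
        # Rescale [min, max] -> [0, 1], then shift to [-1, 1].
        x = (x - stats["min"]) / (stats["max"] - stats["min"] + 1e-8)
        return x * 2 - 1
    raise ValueError(f"Unknown normalization mode: {mode}")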

View File

@@ -4,3 +4,79 @@ import torch
from torch import Tensor, nn
class AbstractPolicy(nn.Module):
"""Base policy which all policies should be derived from.
The forward method should generally not be overridden as it plays the role of handling multi-step policies. See its
documentation for more information.
Note:
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
1. set the required class attributes:
- for classes inheriting from `AbstractDataset`: `available_datasets`
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- for classes inheriting from `AbstractPolicy`: `name`
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
3. update variables in `tests/test_available.py` by importing your new class
"""
name: str | None = None # same name should be used to instantiate the policy in factory.py
def __init__(self, n_action_steps: int | None):
"""
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that
`select_actions` returns a single action with no horizon dimension, and `forward` passes it through
unchanged.
"""
super().__init__()
assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
self.n_action_steps = n_action_steps
self.clear_action_queue()
def update(self, replay_buffer, step):
"""One step of the policy's learning algorithm."""
raise NotImplementedError("Abstract method")
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
def select_actions(self, observation) -> Tensor:
"""Select an action (or trajectory of actions) based on an observation during rollout.
If `n_action_steps` was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
actions. Otherwise, if `n_action_steps` is None, this should return a (batch_size, *) tensor of actions.
"""
raise NotImplementedError("Abstract method")
def clear_action_queue(self):
"""This should be called whenever the environment is reset."""
if self.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.n_action_steps)
def forward(self, *args, **kwargs) -> Tensor:
"""Inference step that makes multi-step policies compatible with their single-step environments.
WARNING: In general, this should not be overridden.
Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
the subclass doesn't have to.
This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
the action trajectory horizon, and * is the action dimensions.
2. Prior to the `select_actions` method being called, an `n_action_steps` instance attribute is defined.
"""
if self.n_action_steps is None:
return self.select_actions(*args, **kwargs)
if len(self._action_queue) == 0:
# `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
# (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
return self._action_queue.popleft()
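To make the queueing in `forward` concrete, here is a hypothetical minimal subclass (`ToyPolicy` and its shapes are illustrative, not from the repo). With n_action_steps=3, the first call plans a whole trajectory, and the queue then serves one action per call until it drains:

import torch
from torch import Tensor

# Assumes `AbstractPolicy` from the module above is importable.
class ToyPolicy(AbstractPolicy):
    name = "toy"

    def select_actions(self, observation) -> Tensor:
        # Plan a full (batch_size, n_action_steps, action_dim) trajectory in one shot.
        return torch.randn(1, self.n_action_steps, 2)

policy = ToyPolicy(n_action_steps=3)
obs = torch.zeros(1, 4)
for _ in range(4):  # calls 1-3 drain the queue; call 4 triggers a re-plan
    action = policy(obs)  # shape (batch_size, action_dim) == (1, 2)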

View File

@@ -67,7 +67,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
def __init__(self, cfg, device, n_action_steps=1):
"""
TODO(alexander-soare): Add documentation for all parameters.
TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
"""
super().__init__()
if getattr(cfg, "n_obs_steps", 1) != 1:
@@ -109,6 +109,9 @@
)
# Backbone for image feature extraction.
self.image_normalizer = transforms.Normalize(
mean=cfg.image_normalization.mean, std=cfg.image_normalization.std
)
backbone_model = getattr(torchvision.models, cfg.backbone)(
replace_stride_with_dilation=[False, False, cfg.dilation],
pretrained=cfg.pretrained_backbone,
@@ -275,9 +278,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
return info
def forward(self, batch: dict[str, Tensor], return_loss: bool = False):
# TODO(now): Maybe this shouldn't be here?
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
images = normalize(batch["observation.images.top"])
images = self.image_normalizer(batch["observation.images.top"])
if return_loss: # training time
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(

View File

@@ -151,7 +151,6 @@ class DiffusionPolicy(nn.Module):
self.diffusion.train()
data_s = time.time() - start_time
loss = self.diffusion.compute_loss(batch)
loss.backward()
@@ -172,7 +171,6 @@ class DiffusionPolicy(nn.Module):
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": self.lr_scheduler.get_last_lr()[0],
"data_s": data_s,
"update_s": time.time() - start_time,
}
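The removal above drops the separate data-loading timer, leaving only the wall-clock timing of a single update. As a minimal sketch of that pattern (the names are hypothetical stand-ins for the real objects):

import time
import torch

start_time = time.time()
loss = diffusion.compute_loss(batch)  # `diffusion` and `batch` stand in for the real objects
loss.backward()
# The clipping threshold here is illustrative.
grad_norm = torch.nn.utils.clip_grad_norm_(diffusion.parameters(), max_norm=10.0)
optimizer.step()
optimizer.zero_grad()
info = {
    "loss": loss.item(),
    "grad_norm": float(grad_norm),
    "update_s": time.time() - start_time,  # wall-clock seconds for one update
}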

View File

@@ -1,6 +1,6 @@
# @package _global_
offline_steps: 2000
offline_steps: 80000
online_steps: 0
eval_episodes: 1
@@ -54,8 +54,12 @@ policy:
temporal_agg: false
state_dim: ???
action_dim: ???
state_dim: 14
action_dim: 14
image_normalization:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
delta_timestamps:
observation.images.top: [0.0]
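The `image_normalization` values added here are the standard ImageNet statistics, matching the pretrained torchvision backbone. A small sketch of how they are consumed, mirroring the `self.image_normalizer` construction in the policy diff above:

import torch
from torchvision import transforms

# ImageNet statistics, as in the config.
image_normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Expects float images scaled to [0, 1], shaped (..., 3, H, W).
images = torch.rand(2, 3, 224, 224)
images = image_normalizer(images)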

View File

@@ -86,7 +86,9 @@ def eval_policy(
def maybe_render_frame(env):
if save_video: # noqa: B023
if return_first_video:
visu = env.envs[0].render(mode="visualization")
# TODO(now): Put mode back in.
visu = env.envs[0].render()
# visu = env.envs[0].render(mode="visualization")
visu = visu[None, ...] # add batch dim
else:
# TODO(now): Put mode back in.

View File

@@ -1,71 +0,0 @@
import torch
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import init_hydra_config
cfg = init_hydra_config(
"/home/alexander/Projects/lerobot/outputs/train/act_aloha_sim_transfer_cube_human/.hydra/config.yaml"
)
policy = make_policy(cfg)
state_dict = torch.load("/home/alexander/Projects/act/outputs/sim_transfer_cube_human_vae/policy_last.ckpt")
# Remove keys based on what they start with.
start_removals = [
# There is a bug that means the pretrained model doesn't even use the final decoder layers.
*[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)],
"model.is_pad_head.",
]
for to_remove in start_removals:
for k in list(state_dict.keys()):
if k.startswith(to_remove):
del state_dict[k]
# Replace keys based on what they start with.
start_replacements = [
("model.", ""),
("query_embed.weight", "pos_embed.weight"),
("pos_table", "vae_encoder_pos_enc"),
("pos_embed.weight", "decoder_pos_embed.weight"),
("encoder.", "vae_encoder."),
("encoder_action_proj.", "vae_encoder_action_input_proj."),
("encoder_joint_proj.", "vae_encoder_robot_state_input_proj."),
("latent_proj.", "vae_encoder_latent_output_proj."),
("latent_proj.", "vae_encoder_latent_output_proj."),
("input_proj.", "encoder_img_feat_input_proj."),
("input_proj_robot_state", "encoder_robot_state_input_proj"),
("latent_out_proj.", "encoder_latent_input_proj."),
("transformer.encoder.", "encoder."),
("transformer.decoder.", "decoder."),
("backbones.0.0.body.", "backbone."),
("additional_pos_embed.weight", "encoder_robot_and_latent_pos_embed.weight"),
("cls_embed.weight", "vae_encoder_cls_embed.weight"),
]
for to_replace, replace_with in start_replacements:
for k in list(state_dict.keys()):
if k.startswith(to_replace):
k_ = replace_with + k.removeprefix(to_replace)
state_dict[k_] = state_dict[k]
del state_dict[k]
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
if len(missing_keys) != 0:
print("MISSING KEYS")
print(missing_keys)
if len(unexpected_keys) != 0:
print("UNEXPECTED KEYS")
print(unexpected_keys)
# if len(missing_keys) != 0 or len(unexpected_keys) != 0:
# print("Failed due to mismatch in state dicts.")
# exit()
policy.save("/tmp/weights.pth")
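One subtlety worth noting: the replacement rules run in order over whatever keys exist at that point, so renames can chain. For example, `query_embed.weight` is first renamed to `pos_embed.weight`, which a later rule then maps to `decoder_pos_embed.weight`. A toy demonstration with hypothetical keys:

sd = {"model.query_embed.weight": 0, "model.transformer.encoder.w": 1}
for to_replace, replace_with in start_replacements:
    for k in list(sd.keys()):
        if k.startswith(to_replace):
            sd[replace_with + k.removeprefix(to_replace)] = sd.pop(k)
# sd == {"decoder_pos_embed.weight": 0, "encoder.w": 1}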