ready for review

commit 863f28ffd8 (parent 1bab4a1dd5)
@@ -59,96 +59,10 @@ def make_dataset(
        transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
    )
    stats = compute_or_load_stats(stats_dataset)

    # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
    normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"

    # # TODO(now): These stats are needed to use their pretrained model for sim_transfer_cube_human.
    # # (Pdb) stats['observation']['state']['mean']
    # # tensor([-0.0071, -0.6293,  1.0351, -0.0517, -0.4642, -0.0754,  0.4751, -0.0373,
    # #         -0.3324,  0.9034, -0.2258, -0.3127, -0.2412,  0.6866])
    # stats["observation", "state", "mean"] = torch.tensor(
    #     [-0.00740268, -0.63187766, 1.0356655, -0.05027218, -0.46199223, -0.07467502, 0.47467607,
    #      -0.03615446, -0.33203387, 0.9038929, -0.22060776, -0.31011587, -0.23484458, 0.6842416]
    # )
    # # (Pdb) stats['observation']['state']['std']
    # # tensor([0.0022, 0.0520, 0.0291, 0.0092, 0.0267, 0.0145, 0.0563, 0.0179, 0.0494,
    # #         0.0326, 0.0476, 0.0535, 0.0956, 0.0513])
    # stats["observation", "state", "std"] = torch.tensor(
    #     [0.01219023, 0.2975381, 0.16728032, 0.04733803, 0.1486037, 0.08788499, 0.31752336,
    #      0.1049916, 0.27933604, 0.18094037, 0.26604933, 0.30466506, 0.5298686, 0.25505227]
    # )
    # # (Pdb) stats['action']['mean']
    # # tensor([-0.0075, -0.6346,  1.0353, -0.0465, -0.4686, -0.0738,  0.3723, -0.0396,
    # #         -0.3184,  0.8991, -0.2065, -0.3182, -0.2338,  0.5593])
    # stats["action"]["mean"] = torch.tensor(
    #     [-0.00756444, -0.6281845, 1.0312834, -0.04664314, -0.47211358, -0.074527, 0.37389806,
    #      -0.03718753, -0.3261143, 0.8997205, -0.21371077, -0.31840396, -0.23360962, 0.551947]
    # )
    # # (Pdb) stats['action']['std']
    # # tensor([0.0023, 0.0514, 0.0290, 0.0086, 0.0263, 0.0143, 0.0593, 0.0185, 0.0510,
    # #         0.0328, 0.0478, 0.0531, 0.0945, 0.0794])
    # stats["action"]["std"] = torch.tensor(
    #     [0.01252818, 0.2957442, 0.16701928, 0.04584508, 0.14833844, 0.08763024, 0.30665937,
    #      0.10600077, 0.27572668, 0.1805853, 0.26304692, 0.30708534, 0.5305411, 0.38381037]
    # )
    # transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))  # noqa: F821

    transforms = v2.Compose(
        [
            # TODO(rcadene): we need to do something about image_keys
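The two normalization modes chosen above use the dataset statistics differently: "mean_std" standardizes features to zero mean and unit variance, while "min_max" rescales them to a fixed range. A rough standalone sketch of that distinction (this is not the repository's NormalizeTransform; the stats keys and the [-1, 1] target range are assumptions for illustration):

import torch

def normalize(x: torch.Tensor, stats: dict[str, torch.Tensor], mode: str) -> torch.Tensor:
    # Hypothetical helper, for illustration only.
    if mode == "mean_std":
        # Standardize: zero mean, unit variance.
        return (x - stats["mean"]) / (stats["std"] + 1e-8)
    if mode == "min_max":
        # Rescale to [-1, 1] using the dataset's per-dimension extrema.
        x = (x - stats["min"]) / (stats["max"] - stats["min"] + 1e-8)
        return x * 2 - 1
    raise ValueError(f"unknown normalization mode: {mode}")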
@@ -4,3 +4,79 @@ import torch
from torch import Tensor, nn


class AbstractPolicy(nn.Module):
    """Base policy which all policies should be derived from.

    The forward method should generally not be overridden as it plays the role of handling multi-step policies. See
    its documentation for more information.

    Note:
        When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
            1. set the required class attributes:
                - for classes inheriting from `AbstractDataset`: `available_datasets`
                - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
                - for classes inheriting from `AbstractPolicy`: `name`
            2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
            3. update variables in `tests/test_available.py` by importing your new class
    """

    name: str | None = None  # same name should be used to instantiate the policy in factory.py

    def __init__(self, n_action_steps: int | None):
        """
        n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
            action is returned by `select_actions` and that it doesn't have a horizon dimension. The `forward` method
            then adds that dimension.
        """
        super().__init__()
        assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
        self.n_action_steps = n_action_steps
        self.clear_action_queue()

    def update(self, replay_buffer, step):
        """One step of the policy's learning algorithm."""
        raise NotImplementedError("Abstract method")

    def save(self, fp):
        torch.save(self.state_dict(), fp)

    def load(self, fp):
        d = torch.load(fp)
        self.load_state_dict(d)

    def select_actions(self, observation) -> Tensor:
        """Select an action (or trajectory of actions) based on an observation during rollout.

        If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor
        of actions. Otherwise, if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
        """
        raise NotImplementedError("Abstract method")

    def clear_action_queue(self):
        """This should be called whenever the environment is reset."""
        if self.n_action_steps is not None:
            self._action_queue = deque([], maxlen=self.n_action_steps)

    def forward(self, *args, **kwargs) -> Tensor:
        """Inference step that makes multi-step policies compatible with their single-step environments.

        WARNING: In general, this should not be overridden.

        Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
        into the formalism of a TorchRL environment, we view it as effectively a policy that (1) makes an observation
        and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
        observation, and (3) repopulates the action queue when empty. This method handles that logic so that the
        subclass doesn't have to.

        This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
        1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H
           is the action trajectory horizon and * is the action dimensions.
        2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute defined.
        """
        if self.n_action_steps is None:
            return self.select_actions(*args, **kwargs)
        if len(self._action_queue) == 0:
            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
            # (n_action_steps, batch_size, *), hence the transpose.
            self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
        return self._action_queue.popleft()
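To make the action-queue mechanics above concrete, here is a minimal hypothetical subclass (not part of the codebase). `select_actions` plans a whole trajectory; `forward` hands it back one step at a time and only replans when the queue runs dry.

import torch
from torch import Tensor

class ConstantPolicy(AbstractPolicy):  # AbstractPolicy as defined above
    name = "constant"  # required class attribute

    def __init__(self):
        super().__init__(n_action_steps=3)

    def select_actions(self, observation) -> Tensor:
        # (batch_size, n_action_steps, action_dim) trajectory; zeros purely for illustration.
        return torch.zeros(observation.shape[0], self.n_action_steps, 14)

policy = ConstantPolicy()
obs = torch.zeros(2, 14)
# The first call plans a 3-step trajectory and returns step 0; the next two calls pop
# steps 1 and 2 from the queue without calling select_actions again.
for _ in range(3):
    action = policy(obs)  # shape (2, 14)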
@@ -67,7 +67,7 @@ class ActionChunkingTransformerPolicy(nn.Module):

    def __init__(self, cfg, device, n_action_steps=1):
        """
        TODO(alexander-soare): Add documentation for all parameters.
        TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
        """
        super().__init__()
        if getattr(cfg, "n_obs_steps", 1) != 1:
@@ -109,6 +109,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
        )

        # Backbone for image feature extraction.
        self.image_normalizer = transforms.Normalize(
            mean=cfg.image_normalization.mean, std=cfg.image_normalization.std
        )
        backbone_model = getattr(torchvision.models, cfg.backbone)(
            replace_stride_with_dilation=[False, False, cfg.dilation],
            pretrained=cfg.pretrained_backbone,
@@ -275,9 +278,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
        return info

    def forward(self, batch: dict[str, Tensor], return_loss: bool = False):
        # TODO(now): Maybe this shouldn't be here?
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        images = normalize(batch["observation.images.top"])
        images = self.image_normalizer(batch["observation.images.top"])

        if return_loss:  # training time
            actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
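Together with the __init__ change above, this replaces the ImageNet normalization constants hardcoded in forward with the config-driven self.image_normalizer. A minimal sketch of what that normalizer computes, assuming float images already scaled to [0, 1] (as done by the Prod(..., prod=1 / 255.0) transform earlier in this diff); the input shape is only an example:

import torch
from torchvision import transforms

# Same values as cfg.image_normalization in the aloha config below (the standard
# ImageNet statistics expected by the pretrained torchvision backbone).
image_normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

images = torch.rand(2, 3, 480, 640)    # (batch, channels, height, width), values in [0, 1]
normalized = image_normalizer(images)  # per-channel (x - mean) / std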
@@ -151,7 +151,6 @@ class DiffusionPolicy(nn.Module):

        self.diffusion.train()

        data_s = time.time() - start_time
        loss = self.diffusion.compute_loss(batch)
        loss.backward()
@@ -172,7 +171,6 @@ class DiffusionPolicy(nn.Module):
            "loss": loss.item(),
            "grad_norm": float(grad_norm),
            "lr": self.lr_scheduler.get_last_lr()[0],
            "data_s": data_s,
            "update_s": time.time() - start_time,
        }
@@ -1,6 +1,6 @@
# @package _global_

offline_steps: 2000
offline_steps: 80000
online_steps: 0

eval_episodes: 1
@@ -54,8 +54,12 @@ policy:

  temporal_agg: false

  state_dim: ???
  action_dim: ???
  state_dim: 14
  action_dim: 14

  image_normalization:
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]

  delta_timestamps:
    observation.images.top: [0.0]
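Here state_dim and action_dim are pinned to 14 (ALOHA's two 7-DoF arms), and the image normalization statistics now live in the config rather than in code. A hypothetical standalone snippet showing how such an entry reads back with OmegaConf (the library underlying Hydra, which this repo uses for configs):

from omegaconf import OmegaConf

# Hypothetical example; the repo itself loads its full config via init_hydra_config.
cfg = OmegaConf.create(
    {
        "image_normalization": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
        "state_dim": 14,
        "action_dim": 14,
    }
)
print(cfg.image_normalization.mean)  # [0.485, 0.456, 0.406]
print(cfg.state_dim, cfg.action_dim)  # 14 14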
@@ -86,7 +86,9 @@ def eval_policy(
    def maybe_render_frame(env):
        if save_video:  # noqa: B023
            if return_first_video:
                visu = env.envs[0].render(mode="visualization")
                # TODO(now): Put mode back in.
                visu = env.envs[0].render()
                # visu = env.envs[0].render(mode="visualization")
                visu = visu[None, ...]  # add batch dim
            else:
                # TODO(now): Put mode back in.
@@ -1,71 +0,0 @@
import torch

from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import init_hydra_config

cfg = init_hydra_config(
    "/home/alexander/Projects/lerobot/outputs/train/act_aloha_sim_transfer_cube_human/.hydra/config.yaml"
)

policy = make_policy(cfg)

state_dict = torch.load("/home/alexander/Projects/act/outputs/sim_transfer_cube_human_vae/policy_last.ckpt")

# Remove keys based on what they start with.

start_removals = [
    # There is a bug that means the pretrained model doesn't even use the final decoder layers.
    *[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)],
    "model.is_pad_head.",
]

for to_remove in start_removals:
    for k in list(state_dict.keys()):
        if k.startswith(to_remove):
            del state_dict[k]


# Replace keys based on what they start with.

start_replacements = [
    ("model.", ""),
    ("query_embed.weight", "pos_embed.weight"),
    ("pos_table", "vae_encoder_pos_enc"),
    ("pos_embed.weight", "decoder_pos_embed.weight"),
    ("encoder.", "vae_encoder."),
    ("encoder_action_proj.", "vae_encoder_action_input_proj."),
    ("encoder_joint_proj.", "vae_encoder_robot_state_input_proj."),
    ("latent_proj.", "vae_encoder_latent_output_proj."),
    ("latent_proj.", "vae_encoder_latent_output_proj."),
    ("input_proj.", "encoder_img_feat_input_proj."),
    ("input_proj_robot_state", "encoder_robot_state_input_proj"),
    ("latent_out_proj.", "encoder_latent_input_proj."),
    ("transformer.encoder.", "encoder."),
    ("transformer.decoder.", "decoder."),
    ("backbones.0.0.body.", "backbone."),
    ("additional_pos_embed.weight", "encoder_robot_and_latent_pos_embed.weight"),
    ("cls_embed.weight", "vae_encoder_cls_embed.weight"),
]

for to_replace, replace_with in start_replacements:
    for k in list(state_dict.keys()):
        if k.startswith(to_replace):
            k_ = replace_with + k.removeprefix(to_replace)
            state_dict[k_] = state_dict[k]
            del state_dict[k]


missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)

if len(missing_keys) != 0:
    print("MISSING KEYS")
    print(missing_keys)
if len(unexpected_keys) != 0:
    print("UNEXPECTED KEYS")
    print(unexpected_keys)

# if len(missing_keys) != 0 or len(unexpected_keys) != 0:
#     print("Failed due to mismatch in state dicts.")
#     exit()

policy.save("/tmp/weights.pth")
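The helper deleted above remapped a pretrained ACT checkpoint into this repository's parameter names by dropping and renaming key prefixes, then loading with strict=False. As a hypothetical follow-up (not part of the original script), a shape comparison against the freshly constructed policy can surface remapping mistakes that strict=False would otherwise let through:

import torch

# Hypothetical sanity check, assuming the conversion script above has just run.
converted = torch.load("/tmp/weights.pth")  # state dict written by policy.save(...)
reference = policy.state_dict()
for name, tensor in converted.items():
    if name not in reference:
        print(f"unexpected key after remapping: {name}")
    elif reference[name].shape != tensor.shape:
        print(f"shape mismatch for {name}: {tuple(tensor.shape)} vs {tuple(reference[name].shape)}")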