ready for review

Parent: 1bab4a1dd5
Commit: 863f28ffd8
@@ -59,96 +59,10 @@ def make_dataset(
             transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
         )
         stats = compute_or_load_stats(stats_dataset)

         # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
         normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"

-        # # TODO(now): These stats are needed to use their pretrained model for sim_transfer_cube_human.
-        # # (Pdb) stats['observation']['state']['mean']
-        # # tensor([-0.0071, -0.6293,  1.0351, -0.0517, -0.4642, -0.0754,  0.4751, -0.0373,
-        # #         -0.3324,  0.9034, -0.2258, -0.3127, -0.2412,  0.6866])
-        # stats["observation", "state", "mean"] = torch.tensor(
-        #     [-0.00740268, -0.63187766, 1.0356655, -0.05027218, -0.46199223, -0.07467502, 0.47467607,
-        #      -0.03615446, -0.33203387, 0.9038929, -0.22060776, -0.31011587, -0.23484458, 0.6842416]
-        # )
-        # # (Pdb) stats['observation']['state']['std']
-        # # tensor([0.0022, 0.0520, 0.0291, 0.0092, 0.0267, 0.0145, 0.0563, 0.0179, 0.0494,
-        # #         0.0326, 0.0476, 0.0535, 0.0956, 0.0513])
-        # stats["observation", "state", "std"] = torch.tensor(
-        #     [0.01219023, 0.2975381, 0.16728032, 0.04733803, 0.1486037, 0.08788499, 0.31752336,
-        #      0.1049916, 0.27933604, 0.18094037, 0.26604933, 0.30466506, 0.5298686, 0.25505227]
-        # )
-        # # (Pdb) stats['action']['mean']
-        # # tensor([-0.0075, -0.6346,  1.0353, -0.0465, -0.4686, -0.0738,  0.3723, -0.0396,
-        # #         -0.3184,  0.8991, -0.2065, -0.3182, -0.2338,  0.5593])
-        # stats["action"]["mean"] = torch.tensor(
-        #     [-0.00756444, -0.6281845, 1.0312834, -0.04664314, -0.47211358, -0.074527, 0.37389806,
-        #      -0.03718753, -0.3261143, 0.8997205, -0.21371077, -0.31840396, -0.23360962, 0.551947]
-        # )
-        # # (Pdb) stats['action']['std']
-        # # tensor([0.0023, 0.0514, 0.0290, 0.0086, 0.0263, 0.0143, 0.0593, 0.0185, 0.0510,
-        # #         0.0328, 0.0478, 0.0531, 0.0945, 0.0794])
-        # stats["action"]["std"] = torch.tensor(
-        #     [0.01252818, 0.2957442, 0.16701928, 0.04584508, 0.14833844, 0.08763024, 0.30665937,
-        #      0.10600077, 0.27572668, 0.1805853, 0.26304692, 0.30708534, 0.5305411, 0.38381037]
-        # )
-        # transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))  # noqa: F821

         transforms = v2.Compose(
             [
                 # TODO(rcadene): we need to do something about image_keys
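Note on the hunk above: with the hard-coded pretrained-model statistics deleted, normalization relies entirely on `compute_or_load_stats` plus the per-environment `normalization_mode`. As a rough illustration of what the two modes conventionally compute (a minimal sketch; the `stats` key layout here is an assumption for illustration, not lerobot's internal format):

import torch


def normalize(x: torch.Tensor, stats: dict[str, torch.Tensor], mode: str) -> torch.Tensor:
    """Normalize a batch with dataset-level statistics (illustrative only)."""
    if mode == "mean_std":
        # Zero mean and unit variance per dimension.
        return (x - stats["mean"]) / (stats["std"] + 1e-8)
    if mode == "min_max":
        # Rescale each dimension to [-1, 1].
        return (x - stats["min"]) / (stats["max"] - stats["min"] + 1e-8) * 2.0 - 1.0
    raise ValueError(f"Unknown normalization mode: {mode}")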
@@ -4,3 +4,79 @@ import torch
 from torch import Tensor, nn


+class AbstractPolicy(nn.Module):
+    """Base policy which all policies should be derived from.
+
+    The forward method should generally not be overridden as it plays the role of handling multi-step policies. See
+    its documentation for more information.
+
+    Note:
+        When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
+            1. set the required class attributes:
+                - for classes inheriting from `AbstractDataset`: `available_datasets`
+                - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
+                - for classes inheriting from `AbstractPolicy`: `name`
+            2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
+            3. update variables in `tests/test_available.py` by importing your new class
+    """
+
+    name: str | None = None  # same name should be used to instantiate the policy in factory.py
+
+    def __init__(self, n_action_steps: int | None):
+        """
+        n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
+            action is returned by `select_actions` and that it doesn't have a horizon dimension. The `forward`
+            method then adds that dimension.
+        """
+        super().__init__()
+        assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
+        self.n_action_steps = n_action_steps
+        self.clear_action_queue()
+
+    def update(self, replay_buffer, step):
+        """One step of the policy's learning algorithm."""
+        raise NotImplementedError("Abstract method")
+
+    def save(self, fp):
+        torch.save(self.state_dict(), fp)
+
+    def load(self, fp):
+        d = torch.load(fp)
+        self.load_state_dict(d)
+
+    def select_actions(self, observation) -> Tensor:
+        """Select an action (or trajectory of actions) based on an observation during rollout.
+
+        If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *)
+        tensor of actions. Otherwise, if n_action_steps is None, this should return a (batch_size, *) tensor of
+        actions.
+        """
+        raise NotImplementedError("Abstract method")
+
+    def clear_action_queue(self):
+        """This should be called whenever the environment is reset."""
+        if self.n_action_steps is not None:
+            self._action_queue = deque([], maxlen=self.n_action_steps)
+
+    def forward(self, *args, **kwargs) -> Tensor:
+        """Inference step that makes multi-step policies compatible with their single-step environments.
+
+        WARNING: In general, this should not be overridden.
+
+        Consider a "policy" that observes the environment then charts a course of N actions to take. To make this
+        fit into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
+        observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the
+        environment observation, (3) repopulates the action queue when empty. This method handles the
+        aforementioned logic so that the subclass doesn't have to.
+
+        This method effectively wraps the `select_actions` method of the subclass. The following assumptions are
+        made:
+        1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size,
+           H is the action trajectory horizon and * is the action dimensions.
+        2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute
+           defined.
+        """
+        if self.n_action_steps is None:
+            return self.select_actions(*args, **kwargs)
+        if len(self._action_queue) == 0:
+            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has
+            # shape (n_action_steps, batch_size, *), hence the transpose.
+            self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
+        return self._action_queue.popleft()
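A sketch of how the queueing contract in `AbstractPolicy.forward` plays out for a subclass. Everything other than the `AbstractPolicy` members is made up for illustration (the trivial policy below just emits zeros):

import torch
from torch import Tensor


class ZeroActionPolicy(AbstractPolicy):  # AbstractPolicy as added in the hunk above
    name = "zero"  # required class attribute, checked in __init__

    def select_actions(self, observation) -> Tensor:
        # Return a whole (batch_size, n_action_steps, action_dim) trajectory in one shot.
        return torch.zeros(observation.shape[0], self.n_action_steps, 2)


policy = ZeroActionPolicy(n_action_steps=4)
obs = torch.randn(1, 14)
first = policy(obs)  # queue was empty: select_actions runs once, 4 actions are cached
later = policy(obs)  # served from the queue; select_actions is NOT called again
assert first.shape == later.shape == (1, 2)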
@@ -67,7 +67,7 @@ class ActionChunkingTransformerPolicy(nn.Module):

     def __init__(self, cfg, device, n_action_steps=1):
         """
-        TODO(alexander-soare): Add documentation for all parameters.
+        TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
         """
         super().__init__()
         if getattr(cfg, "n_obs_steps", 1) != 1:
@@ -109,6 +109,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
         )

         # Backbone for image feature extraction.
+        self.image_normalizer = transforms.Normalize(
+            mean=cfg.image_normalization.mean, std=cfg.image_normalization.std
+        )
         backbone_model = getattr(torchvision.models, cfg.backbone)(
             replace_stride_with_dilation=[False, False, cfg.dilation],
             pretrained=cfg.pretrained_backbone,
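For reference, `torchvision.transforms.Normalize` applies per-channel `(x - mean) / std` to float image tensors, so with the `cfg.image_normalization` values added to the config later in this diff it reproduces the usual ImageNet statistics. A quick standalone check (the shapes are arbitrary):

import torch
from torchvision import transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
images = torch.rand(2, 3, 224, 224)  # (batch, channels, height, width), values already in [0, 1]
out = normalize(images)  # same shape; each channel is shifted and scaled independently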
@@ -275,9 +278,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return info

     def forward(self, batch: dict[str, Tensor], return_loss: bool = False):
-        # TODO(now): Maybe this shouldn't be here?
-        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-        images = normalize(batch["observation.images.top"])
+        images = self.image_normalizer(batch["observation.images.top"])

         if return_loss:  # training time
             actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
@@ -151,7 +151,6 @@ class DiffusionPolicy(nn.Module):

         self.diffusion.train()

-        data_s = time.time() - start_time
         loss = self.diffusion.compute_loss(batch)
         loss.backward()
@@ -172,7 +171,6 @@ class DiffusionPolicy(nn.Module):
             "loss": loss.item(),
             "grad_norm": float(grad_norm),
             "lr": self.lr_scheduler.get_last_lr()[0],
-            "data_s": data_s,
             "update_s": time.time() - start_time,
         }
@@ -1,6 +1,6 @@
 # @package _global_

-offline_steps: 2000
+offline_steps: 80000
 online_steps: 0

 eval_episodes: 1
@@ -54,8 +54,12 @@ policy:

   temporal_agg: false

-  state_dim: ???
-  action_dim: ???
+  state_dim: 14
+  action_dim: 14
+
+  image_normalization:
+    mean: [0.485, 0.456, 0.406]
+    std: [0.229, 0.224, 0.225]

   delta_timestamps:
     observation.images.top: [0.0]
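On `delta_timestamps`: each key maps to a list of offsets in seconds relative to the current frame, so `[0.0]` loads only the current image. The sketch below shows the lookup idea under that reading; the exact nearest-frame/tolerance behavior is lerobot's and not shown here, and the `action` entry is a hypothetical example:

current_t = 2.0  # timestamp of the current step, in seconds
delta_timestamps = {"observation.images.top": [0.0], "action": [0.0, 0.02, 0.04]}
# For each key, the dataset would fetch the frames nearest these absolute times.
query_times = {k: [current_t + dt for dt in dts] for k, dts in delta_timestamps.items()}
# query_times["action"] -> [2.0, 2.02, 2.04]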
@@ -86,7 +86,9 @@ def eval_policy(
     def maybe_render_frame(env):
         if save_video:  # noqa: B023
             if return_first_video:
-                visu = env.envs[0].render(mode="visualization")
+                # TODO(now): Put mode back in.
+                visu = env.envs[0].render()
+                # visu = env.envs[0].render(mode="visualization")
                 visu = visu[None, ...]  # add batch dim
             else:
                 # TODO(now): Put mode back in.
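Context for the `render()` change above: newer gym/gymnasium-style APIs fix the render mode at construction time instead of accepting a `mode` argument per call, which would explain why `mode="visualization"` no longer works and is parked behind a TODO. A generic gymnasium example (the environment id is a stand-in, not one of this repo's envs):

import gymnasium as gym

env = gym.make("CartPole-v1", render_mode="rgb_array")  # render mode chosen once, here
env.reset()
frame = env.render()  # no `mode` argument; returns an RGB array because of render_mode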
@@ -1,71 +0,0 @@
-import torch
-
-from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import init_hydra_config
-
-cfg = init_hydra_config(
-    "/home/alexander/Projects/lerobot/outputs/train/act_aloha_sim_transfer_cube_human/.hydra/config.yaml"
-)
-
-policy = make_policy(cfg)
-
-state_dict = torch.load("/home/alexander/Projects/act/outputs/sim_transfer_cube_human_vae/policy_last.ckpt")
-
-# Remove keys based on what they start with.
-
-start_removals = [
-    # There is a bug that means the pretrained model doesn't even use the final decoder layers.
-    *[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)],
-    "model.is_pad_head.",
-]
-
-for to_remove in start_removals:
-    for k in list(state_dict.keys()):
-        if k.startswith(to_remove):
-            del state_dict[k]
-
-# Replace keys based on what they start with.
-
-start_replacements = [
-    ("model.", ""),
-    ("query_embed.weight", "pos_embed.weight"),
-    ("pos_table", "vae_encoder_pos_enc"),
-    ("pos_embed.weight", "decoder_pos_embed.weight"),
-    ("encoder.", "vae_encoder."),
-    ("encoder_action_proj.", "vae_encoder_action_input_proj."),
-    ("encoder_joint_proj.", "vae_encoder_robot_state_input_proj."),
-    ("latent_proj.", "vae_encoder_latent_output_proj."),
-    ("latent_proj.", "vae_encoder_latent_output_proj."),
-    ("input_proj.", "encoder_img_feat_input_proj."),
-    ("input_proj_robot_state", "encoder_robot_state_input_proj"),
-    ("latent_out_proj.", "encoder_latent_input_proj."),
-    ("transformer.encoder.", "encoder."),
-    ("transformer.decoder.", "decoder."),
-    ("backbones.0.0.body.", "backbone."),
-    ("additional_pos_embed.weight", "encoder_robot_and_latent_pos_embed.weight"),
-    ("cls_embed.weight", "vae_encoder_cls_embed.weight"),
-]
-
-for to_replace, replace_with in start_replacements:
-    for k in list(state_dict.keys()):
-        if k.startswith(to_replace):
-            k_ = replace_with + k.removeprefix(to_replace)
-            state_dict[k_] = state_dict[k]
-            del state_dict[k]
-
-missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
-
-if len(missing_keys) != 0:
-    print("MISSING KEYS")
-    print(missing_keys)
-if len(unexpected_keys) != 0:
-    print("UNEXPECTED KEYS")
-    print(unexpected_keys)
-
-# if len(missing_keys) != 0 or len(unexpected_keys) != 0:
-#     print("Failed due to mismatch in state dicts.")
-#     exit()
-
-policy.save("/tmp/weights.pth")
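The deleted script above relies on the return value of `load_state_dict(strict=False)` to report mismatches instead of raising. A self-contained reminder of what those two lists contain (toy module; the layer names are illustrative):

import torch
from torch import nn

model = nn.Linear(2, 2)  # parameters: "weight" and "bias"
state_dict = {"weight": torch.zeros(2, 2), "extra.bias": torch.zeros(2)}
result = model.load_state_dict(state_dict, strict=False)
print(result.missing_keys)     # ["bias"]: the model expects it, the dict lacks it
print(result.unexpected_keys)  # ["extra.bias"]: in the dict, no matching parameter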