backup wip
This commit is contained in:
parent ecc7dd3b17
commit 8d2463f45b
@@ -4,79 +4,3 @@ import torch

from torch import Tensor, nn


class AbstractPolicy(nn.Module):
    """Base policy which all policies should be derived from.

    The forward method should generally not be overridden as it plays the role of handling multi-step policies. See its
    documentation for more information.

    Note:
        When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
            1. set the required class attributes:
                - for classes inheriting from `AbstractDataset`: `available_datasets`
                - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
                - for classes inheriting from `AbstractPolicy`: `name`
            2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
            3. update variables in `tests/test_available.py` by importing your new class
    """

    name: str | None = None  # same name should be used to instantiate the policy in factory.py

    def __init__(self, n_action_steps: int | None):
        """
        n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
            action is returned by `select_actions` and that it doesn't have a horizon dimension. The `forward` method then
            adds that dimension.
        """
        super().__init__()
        assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
        self.n_action_steps = n_action_steps
        self.clear_action_queue()

    def update(self, replay_buffer, step):
        """One step of the policy's learning algorithm."""
        raise NotImplementedError("Abstract method")

    def save(self, fp):
        torch.save(self.state_dict(), fp)

    def load(self, fp):
        d = torch.load(fp)
        self.load_state_dict(d)

    def select_actions(self, observation) -> Tensor:
        """Select an action (or trajectory of actions) based on an observation during rollout.

        If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
        actions. Otherwise if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
        """
        raise NotImplementedError("Abstract method")

    def clear_action_queue(self):
        """This should be called whenever the environment is reset."""
        if self.n_action_steps is not None:
            self._action_queue = deque([], maxlen=self.n_action_steps)

    def forward(self, *args, **kwargs) -> Tensor:
        """Inference step that makes multi-step policies compatible with their single-step environments.

        WARNING: In general, this should not be overridden.

        Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
        into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
        observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
        observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
        the subclass doesn't have to.

        This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
        1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
           the action trajectory horizon and * is the action dimensions.
        2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute defined.
        """
        if self.n_action_steps is None:
            return self.select_actions(*args, **kwargs)
        if len(self._action_queue) == 0:
            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
            # (n_action_steps, batch_size, *), hence the transpose.
            self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
        return self._action_queue.popleft()

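For context, a minimal sketch of how the action queue in `forward` behaves during a rollout. `ToyPolicy`, its shapes, and its horizon are hypothetical, purely for illustration, and are not code from this commit:

from collections import deque

import torch
from torch import Tensor, nn


class ToyPolicy(nn.Module):
    # Hypothetical subclass, only to illustrate the queueing behaviour.
    name = "toy"

    def __init__(self, n_action_steps: int):
        super().__init__()
        self.n_action_steps = n_action_steps
        self._action_queue = deque([], maxlen=n_action_steps)

    def select_actions(self, observation: Tensor) -> Tensor:
        # Pretend to plan a (batch_size, n_action_steps, action_dim) trajectory from one observation.
        return torch.zeros(observation.shape[0], self.n_action_steps, 2)

    def forward(self, observation: Tensor) -> Tensor:
        if len(self._action_queue) == 0:
            # The queue stores (n_action_steps, batch_size, action_dim), hence the transpose.
            self._action_queue.extend(self.select_actions(observation).transpose(0, 1))
        return self._action_queue.popleft()


policy = ToyPolicy(n_action_steps=4)
observation = torch.zeros(1, 3)
for _ in range(5):
    # The first call plans 4 actions; the next 3 consume the queue without replanning.
    action = policy(observation)  # (batch_size, action_dim)
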
@@ -3,7 +3,7 @@
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
"""

from collections import deque
import math
import time
from itertools import chain

@@ -22,6 +22,67 @@ from torchvision.ops.misc import FrozenBatchNorm2d

from lerobot.common.utils import get_safe_torch_device


# class AbstractPolicy(nn.Module):
#     """Base policy which all policies should be derived from.

#     The forward method should generally not be overridden as it plays the role of handling multi-step policies. See its
#     documentation for more information.

#     Note:
#         When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
#             1. set the required class attributes:
#                 - for classes inheriting from `AbstractDataset`: `available_datasets`
#                 - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
#                 - for classes inheriting from `AbstractPolicy`: `name`
#             2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
#             3. update variables in `tests/test_available.py` by importing your new class
#     """

#     name: str | None = None  # same name should be used to instantiate the policy in factory.py

#     def __init__(self, n_action_steps: int | None):
#         """
#         n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
#             action is returned by `select_actions` and that it doesn't have a horizon dimension. The `forward` method then
#             adds that dimension.
#         """
#         super().__init__()
#         assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
#         self.n_action_steps = n_action_steps
#         self.clear_action_queue()

#     def clear_action_queue(self):
#         """This should be called whenever the environment is reset."""
#         if self.n_action_steps is not None:
#             self._action_queue = deque([], maxlen=self.n_action_steps)

#     def forward(self, fn) -> Tensor:
#         """Inference step that makes multi-step policies compatible with their single-step environments.

#         WARNING: In general, this should not be overridden.

#         Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
#         into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
#         observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
#         observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
#         the subclass doesn't have to.

#         This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
#         1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
#            the action trajectory horizon and * is the action dimensions.
#         2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute defined.
#         """
#         if self.n_action_steps is None:
#             return self.select_actions(*args, **kwargs)
#         if len(self._action_queue) == 0:
#             # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
#             # (n_action_steps, batch_size, *), hence the transpose.
#             self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
#         return self._action_queue.popleft()


class ActionChunkingTransformerPolicy(nn.Module):
    """
    Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost

@@ -168,14 +229,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
            nn.init.xavier_uniform_(p)

    @torch.no_grad()
    def select_actions(self, batch, *_):
    def select_action(self, batch, *_):
        # TODO(now): Implement queueing mechanism.
        self.eval()
        self._preprocess_batch(batch)

        # TODO(now): What's up with this 0.182?
        action = self.forward(
            robot_state=batch["observation.state"] * 0.182, image=batch["observation.images.top"]
            robot_state=batch["observation.state"] * 0.182,
            image=batch["observation.images.top"],
            return_loss=False,
        )

        if self.cfg.temporal_agg:

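The `temporal_agg` branch is truncated by the hunk; it refers to the temporal ensembling scheme described in the ACT paper, where the action chunks predicted at overlapping timesteps are averaged with exponentially decaying weights. A rough sketch of that aggregation, assuming the paper's weight constant m=0.01 (names and shapes here are illustrative, not the code from this commit):

import torch

def temporal_ensemble(actions_for_curr_step: torch.Tensor, m: float = 0.01) -> torch.Tensor:
    """Blend the actions that past chunks predicted for the current step.

    actions_for_curr_step: (num_chunks, action_dim), oldest prediction first.
    """
    # exp(-m * i) gives the oldest prediction (i = 0) the largest weight.
    weights = torch.exp(-m * torch.arange(actions_for_curr_step.shape[0], dtype=torch.float32))
    weights = weights / weights.sum()
    return (actions_for_curr_step * weights.unsqueeze(-1)).sum(dim=0)
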
@@ -226,7 +289,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
        assert batch_size % self.cfg.horizon == 0
        assert batch_size % num_slices == 0

        loss = self.compute_loss(batch)
        loss = self.forward(batch, return_loss=True)["loss"]
        loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(

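For reference, the update step above follows the usual pattern of loss from a forward pass, backward, gradient clipping, then an optimizer step. A generic sketch of that pattern; the optimizer and `max_norm` value are placeholders, not taken from the diff:

import torch

def training_step(policy, optimizer, batch, max_norm: float = 10.0) -> dict:
    # Compute the training loss through the unified forward path.
    loss = policy.forward(batch, return_loss=True)["loss"]
    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm; clip_grad_norm_ returns the pre-clipping norm.
    grad_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm)
    optimizer.step()
    return {"loss": loss.item(), "grad_norm": float(grad_norm)}
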
@@ -247,44 +310,38 @@ class ActionChunkingTransformerPolicy(nn.Module):

        return info

    def compute_loss(self, batch):
        loss_dict = self.forward(
            robot_state=batch["observation.state"],
            image=batch["observation.images.top"],
            actions=batch["action"],
        )
        loss = loss_dict["loss"]
        return loss

    def forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
    def forward(self, batch: dict[str, Tensor], return_loss: bool = False):
        # TODO(now): Maybe this shouldn't be here?
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        image = normalize(image)
        images = normalize(batch["observation.images.top"])

        is_training = actions is not None
        if is_training:  # training time
            actions = actions[:, : self.horizon]
        if return_loss:  # training time
            actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
                batch["observation.state"], images, batch["action"]
            )

            a_hat, (mu, log_sigma_x2) = self._forward(robot_state, image, actions)

            all_l1 = F.l1_loss(actions, a_hat, reduction="none")
            l1 = all_l1.mean()
            l1_loss = (
                F.l1_loss(batch["action"], actions_hat, reduction="none")
                * ~batch["action_is_pad"].unsqueeze(-1)
            ).mean()

            loss_dict = {}
            loss_dict["l1"] = l1
            loss_dict["l1"] = l1_loss
            if self.cfg.use_vae:
                # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
                # each dimension independently, we sum over the latent dimension to get the total
                # KL-divergence per batch element, then take the mean over the batch.
                # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
                mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - (log_sigma_x2).exp())).sum(-1).mean()
                mean_kld = (
                    (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
                )
                loss_dict["kl"] = mean_kld
                loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
            else:
                loss_dict["loss"] = loss_dict["l1"]
            return loss_dict
        else:
            action, _ = self._forward(robot_state, image)  # no action, sample from prior
            action, _ = self._forward(batch["observation.state"], images)
            return action

    def _forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):

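As a sanity check on the closed-form KL term used above (not part of the diff), the expression -0.5 * (1 + log σ² - μ² - σ²) summed over the latent dimension agrees with what `torch.distributions` computes for Dₖₗ(N(μ, σ²) ‖ N(0, 1)); a small hedged example with made-up shapes:

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(4, 32)            # hypothetical latent means, (batch, latent_dim)
log_sigma_x2 = torch.randn(4, 32)  # 2*log(sigma), matching the convention in the code above

closed_form = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - log_sigma_x2.exp())).sum(-1).mean()
reference = kl_divergence(
    Normal(mu, (0.5 * log_sigma_x2).exp()),
    Normal(torch.zeros_like(mu), torch.ones_like(mu)),
).sum(-1).mean()

assert torch.allclose(closed_form, reference, atol=1e-5)
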
@@ -321,7 +378,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
        # Forward pass through VAE encoder and sample the latent with the reparameterization trick.
        cls_token_out = self.vae_encoder(
            vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
        )[0]  # (B, D)
        )[
            0
        ]  # (B, D)
        latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
        mu = latent_pdf_params[:, : self.latent_dim]
        # This is 2log(sigma). Done this way to match the original implementation.

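The comment above mentions sampling the latent with the reparameterization trick. Given `mu` and `log_sigma_x2` (2·log σ) as produced here, that sampling typically looks like the sketch below; the function name is an assumption for illustration, not a line from the diff:

import torch

def sample_latent(mu: torch.Tensor, log_sigma_x2: torch.Tensor) -> torch.Tensor:
    # Reparameterization trick: z = mu + sigma * eps keeps the sample differentiable w.r.t. mu and sigma.
    sigma = (0.5 * log_sigma_x2).exp()
    return mu + sigma * torch.randn_like(mu)
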
@@ -251,7 +251,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
    dataset = make_dataset(cfg, stats_path=stats_path)

    logging.info("Making environment.")
    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
    env = make_env(cfg, num_parallel_envs=cfg.rollout_batch_size)

    # when policy is None, rollout a random policy
    policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None

@@ -148,7 +148,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
    # )

    logging.info("make_env")
    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
    # TODO(now): uncomment
    # env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

    logging.info("make_policy")
    policy = make_policy(cfg)

@@ -28,22 +28,23 @@ for to_remove in start_removals:
# Replace keys based on what they start with.

start_replacements = [
    ("model.query_embed.weight", "model.pos_embed.weight"),
    ("model.pos_table", "model.vae_encoder_pos_enc"),
    ("model.pos_embed.weight", "model.decoder_pos_embed.weight"),
    ("model.encoder.", "model.vae_encoder."),
    ("model.encoder_action_proj.", "model.vae_encoder_action_input_proj."),
    ("model.encoder_joint_proj.", "model.vae_encoder_robot_state_input_proj."),
    ("model.latent_proj.", "model.vae_encoder_latent_output_proj."),
    ("model.latent_proj.", "model.vae_encoder_latent_output_proj."),
    ("model.input_proj.", "model.encoder_img_feat_input_proj."),
    ("model.input_proj_robot_state", "model.encoder_robot_state_input_proj"),
    ("model.latent_out_proj.", "model.encoder_latent_input_proj."),
    ("model.transformer.encoder.", "model.encoder."),
    ("model.transformer.decoder.", "model.decoder."),
    ("model.backbones.0.0.body.", "model.backbone."),
    ("model.additional_pos_embed.weight", "model.encoder_robot_and_latent_pos_embed.weight"),
    ("model.cls_embed.weight", "model.vae_encoder_cls_embed.weight"),
    ("model.", ""),
    ("query_embed.weight", "pos_embed.weight"),
    ("pos_table", "vae_encoder_pos_enc"),
    ("pos_embed.weight", "decoder_pos_embed.weight"),
    ("encoder.", "vae_encoder."),
    ("encoder_action_proj.", "vae_encoder_action_input_proj."),
    ("encoder_joint_proj.", "vae_encoder_robot_state_input_proj."),
    ("latent_proj.", "vae_encoder_latent_output_proj."),
    ("latent_proj.", "vae_encoder_latent_output_proj."),
    ("input_proj.", "encoder_img_feat_input_proj."),
    ("input_proj_robot_state", "encoder_robot_state_input_proj"),
    ("latent_out_proj.", "encoder_latent_input_proj."),
    ("transformer.encoder.", "encoder."),
    ("transformer.decoder.", "decoder."),
    ("backbones.0.0.body.", "backbone."),
    ("additional_pos_embed.weight", "encoder_robot_and_latent_pos_embed.weight"),
    ("cls_embed.weight", "vae_encoder_cls_embed.weight"),
]

for to_replace, replace_with in start_replacements:

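The loop body is cut off by the hunk, but a renaming pass over a checkpoint's state dict driven by a prefix-replacement table like this one usually looks roughly as follows. This is a hedged sketch, not the code from the diff; the function name and loop ordering are assumptions:

def rename_state_dict_keys(state_dict: dict, start_replacements: list[tuple[str, str]]) -> dict:
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for to_replace, replace_with in start_replacements:
            # Only rewrite the prefix when the key actually starts with it.
            if new_key.startswith(to_replace):
                new_key = replace_with + new_key[len(to_replace):]
        renamed[new_key] = value
    return renamed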