backup wip

This commit is contained in:
parent: ecc7dd3b17
commit: 8d2463f45b
@@ -4,79 +4,3 @@ import torch
 from torch import Tensor, nn
 
 
-class AbstractPolicy(nn.Module):
-    """Base policy which all policies should be derived from.
-
-    The forward method should generally not be overridden as it plays the role of handling multi-step policies. See its
-    documentation for more information.
-
-    Note:
-        When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
-            1. set the required class attributes:
-                - for classes inheriting from `AbstractDataset`: `available_datasets`
-                - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
-                - for classes inheriting from `AbstractPolicy`: `name`
-            2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-            3. update variables in `tests/test_available.py` by importing your new class
-    """
-
-    name: str | None = None  # same name should be used to instantiate the policy in factory.py
-
-    def __init__(self, n_action_steps: int | None):
-        """
-        n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
-            action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
-            adds that dimension.
-        """
-        super().__init__()
-        assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
-        self.n_action_steps = n_action_steps
-        self.clear_action_queue()
-
-    def update(self, replay_buffer, step):
-        """One step of the policy's learning algorithm."""
-        raise NotImplementedError("Abstract method")
-
-    def save(self, fp):
-        torch.save(self.state_dict(), fp)
-
-    def load(self, fp):
-        d = torch.load(fp)
-        self.load_state_dict(d)
-
-    def select_actions(self, observation) -> Tensor:
-        """Select an action (or trajectory of actions) based on an observation during rollout.
-
-        If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
-        actions. Otherwise if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
-        """
-        raise NotImplementedError("Abstract method")
-
-    def clear_action_queue(self):
-        """This should be called whenever the environment is reset."""
-        if self.n_action_steps is not None:
-            self._action_queue = deque([], maxlen=self.n_action_steps)
-
-    def forward(self, *args, **kwargs) -> Tensor:
-        """Inference step that makes multi-step policies compatible with their single-step environments.
-
-        WARNING: In general, this should not be overridden.
-
-        Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
-        into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
-        observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
-        observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
-        the subclass doesn't have to.
-
-        This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
-        1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
-           the action trajectory horizon and * is the action dimensions.
-        2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute defined.
-        """
-        if self.n_action_steps is None:
-            return self.select_actions(*args, **kwargs)
-        if len(self._action_queue) == 0:
-            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
-            # (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
-        return self._action_queue.popleft()
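For readers skimming the removal above, the queueing behaviour that `AbstractPolicy.forward` provided can be summarised with a small standalone sketch. Everything below is illustrative only: the class name `QueuedPolicy`, the dummy planner, and the shapes are assumptions, not lerobot code. It shows the same deque idea: replan only when the queue is empty, otherwise pop the next buffered action.

from collections import deque

import torch
from torch import Tensor, nn


class QueuedPolicy(nn.Module):
    # Hypothetical illustration of the removed AbstractPolicy.forward logic.

    def __init__(self, n_action_steps: int, action_dim: int):
        super().__init__()
        self.n_action_steps = n_action_steps
        self.action_dim = action_dim
        self._action_queue = deque([], maxlen=n_action_steps)

    def select_actions(self, observation: Tensor) -> Tensor:
        # Stand-in for a real planner: returns a chunk of shape (batch_size, n_action_steps, action_dim).
        return torch.zeros(observation.shape[0], self.n_action_steps, self.action_dim)

    def forward(self, observation: Tensor) -> Tensor:
        if len(self._action_queue) == 0:
            # The queue stores per-step batches, hence the transpose to (n_action_steps, batch_size, *).
            self._action_queue.extend(self.select_actions(observation).transpose(0, 1))
        return self._action_queue.popleft()


policy = QueuedPolicy(n_action_steps=5, action_dim=2)
obs = torch.zeros(1, 10)
first = policy(obs)   # replans and returns the first buffered action, shape (1, 2)
second = policy(obs)  # served from the queue, no replanning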
@@ -3,7 +3,7 @@
 As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
 The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
 """
+from collections import deque
 import math
 import time
 from itertools import chain
@@ -22,6 +22,67 @@ from torchvision.ops.misc import FrozenBatchNorm2d
 from lerobot.common.utils import get_safe_torch_device
 
 
+# class AbstractPolicy(nn.Module):
+#     """Base policy which all policies should be derived from.
+
+#     The forward method should generally not be overridden as it plays the role of handling multi-step policies. See its
+#     documentation for more information.
+
+#     Note:
+#         When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
+#             1. set the required class attributes:
+#                 - for classes inheriting from `AbstractDataset`: `available_datasets`
+#                 - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
+#                 - for classes inheriting from `AbstractPolicy`: `name`
+#             2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
+#             3. update variables in `tests/test_available.py` by importing your new class
+#     """
+
+#     name: str | None = None  # same name should be used to instantiate the policy in factory.py
+
+#     def __init__(self, n_action_steps: int | None):
+#         """
+#         n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
+#             action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
+#             adds that dimension.
+#         """
+#         super().__init__()
+#         assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
+#         self.n_action_steps = n_action_steps
+#         self.clear_action_queue()
+
+#     def clear_action_queue(self):
+#         """This should be called whenever the environment is reset."""
+#         if self.n_action_steps is not None:
+#             self._action_queue = deque([], maxlen=self.n_action_steps)
+
+#     def forward(self, fn) -> Tensor:
+#         """Inference step that makes multi-step policies compatible with their single-step environments.
+
+#         WARNING: In general, this should not be overridden.
+
+#         Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
+#         into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
+#         observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
+#         observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
+#         the subclass doesn't have to.
+
+#         This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
+#         1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
+#            the action trajectory horizon and * is the action dimensions.
+#         2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute defined.
+#         """
+#         if self.n_action_steps is None:
+#             return self.select_actions(*args, **kwargs)
+#         if len(self._action_queue) == 0:
+#             # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
+#             # (n_action_steps, batch_size, *), hence the transpose.
+#             self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
+#         return self._action_queue.popleft()
+
+
 class ActionChunkingTransformerPolicy(nn.Module):
     """
     Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
@@ -168,14 +229,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
                 nn.init.xavier_uniform_(p)
 
     @torch.no_grad()
-    def select_actions(self, batch, *_):
+    def select_action(self, batch, *_):
         # TODO(now): Implement queueing mechanism.
         self.eval()
         self._preprocess_batch(batch)
 
         # TODO(now): What's up with this 0.182?
         action = self.forward(
-            robot_state=batch["observation.state"] * 0.182, image=batch["observation.images.top"]
+            robot_state=batch["observation.state"] * 0.182,
+            image=batch["observation.images.top"],
+            return_loss=False,
         )
 
         if self.cfg.temporal_agg:
@@ -226,7 +289,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         assert batch_size % self.cfg.horizon == 0
         assert batch_size % num_slices == 0
 
-        loss = self.compute_loss(batch)
+        loss = self.forward(batch, return_loss=True)["loss"]
         loss.backward()
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
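With `compute_loss` folded into `forward`, one optimization step reduces to the pattern below. This is a hedged sketch of how the surrounding `update` presumably uses the new call; the optimizer handling, the gradient-clipping threshold, and the returned keys are assumptions, not the repository's exact code.

import torch


def training_step(policy, batch, optimizer, max_grad_norm: float = 10.0) -> dict:
    # One gradient step against the dual-mode forward (sketch; max_grad_norm value is assumed).
    loss = policy.forward(batch, return_loss=True)["loss"]
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()
    return {"loss": loss.item(), "grad_norm": float(grad_norm)}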
@@ -247,44 +310,38 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
         return info
 
-    def compute_loss(self, batch):
-        loss_dict = self.forward(
-            robot_state=batch["observation.state"],
-            image=batch["observation.images.top"],
-            actions=batch["action"],
-        )
-        loss = loss_dict["loss"]
-        return loss
-
-    def forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
+    def forward(self, batch: dict[str, Tensor], return_loss: bool = False):
         # TODO(now): Maybe this shouldn't be here?
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-        image = normalize(image)
+        images = normalize(batch["observation.images.top"])
 
-        is_training = actions is not None
-        if is_training:  # training time
-            actions = actions[:, : self.horizon]
-
-            a_hat, (mu, log_sigma_x2) = self._forward(robot_state, image, actions)
-
-            all_l1 = F.l1_loss(actions, a_hat, reduction="none")
-            l1 = all_l1.mean()
+        if return_loss:  # training time
+            actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
+                batch["observation.state"], images, batch["action"]
+            )
+
+            l1_loss = (
+                F.l1_loss(batch["action"], actions_hat, reduction="none")
+                * ~batch["action_is_pad"].unsqueeze(-1)
+            ).mean()
 
             loss_dict = {}
-            loss_dict["l1"] = l1
+            loss_dict["l1"] = l1_loss
             if self.cfg.use_vae:
                 # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
                 # each dimension independently, we sum over the latent dimension to get the total
                 # KL-divergence per batch element, then take the mean over the batch.
                 # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-                mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - (log_sigma_x2).exp())).sum(-1).mean()
+                mean_kld = (
+                    (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
+                )
                 loss_dict["kl"] = mean_kld
                 loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
             else:
                 loss_dict["loss"] = loss_dict["l1"]
             return loss_dict
         else:
-            action, _ = self._forward(robot_state, image)  # no action, sample from prior
+            action, _ = self._forward(batch["observation.state"], images)
             return action
 
     def _forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
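The reworked loss masks padded action steps via `batch["action_is_pad"]` before averaging. A standalone illustration of that masking follows, with made-up shapes and values; only the masking convention is taken from the diff.

import torch
import torch.nn.functional as F

batch_size, horizon, action_dim = 2, 4, 3
actions = torch.randn(batch_size, horizon, action_dim)      # ground-truth action chunk
actions_hat = torch.randn(batch_size, horizon, action_dim)  # policy prediction
# True where a step is padding past the end of the episode, as in batch["action_is_pad"].
action_is_pad = torch.tensor([[False, False, False, True],
                              [False, False, True, True]])

l1_loss = (
    F.l1_loss(actions, actions_hat, reduction="none")
    * ~action_is_pad.unsqueeze(-1)
).mean()

Note that `.mean()` still divides by the total element count, padded steps included: they contribute zero to the numerator but remain in the denominator, which matches the behaviour of the new code above.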
@@ -321,7 +378,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
             # Forward pass through VAE encoder and sample the latent with the reparameterization trick.
             cls_token_out = self.vae_encoder(
                 vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
-            )[0]  # (B, D)
+            )[
+                0
+            ]  # (B, D)
             latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
             mu = latent_pdf_params[:, : self.latent_dim]
             # This is 2log(sigma). Done this way to match the original implementation.
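The `2log(sigma)` parameterisation above is what the KL term in `forward` consumes, and it is also what a reparameterised latent sample is typically drawn from. The helper below shows the standard VAE reparameterization trick for reference; it is not code lifted from the repository.

import torch


def sample_latent(mu: torch.Tensor, log_sigma_x2: torch.Tensor) -> torch.Tensor:
    # Reparameterization trick: z = mu + sigma * eps, where log_sigma_x2 = 2 * log(sigma).
    sigma = (0.5 * log_sigma_x2).exp()
    return mu + sigma * torch.randn_like(mu)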
@@ -251,7 +251,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     dataset = make_dataset(cfg, stats_path=stats_path)
 
     logging.info("Making environment.")
-    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
+    env = make_env(cfg, num_parallel_envs=cfg.rollout_batch_size)
 
     # when policy is None, rollout a random policy
     policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None
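Tying `num_parallel_envs` to `cfg.rollout_batch_size` instead of `cfg.eval_episodes` means evaluation episodes are now collected in batches of parallel environments rather than all at once. A back-of-the-envelope sketch of the resulting rollout count, with both config values assumed for illustration:

import math

eval_episodes = 50       # assumed value of cfg.eval_episodes
rollout_batch_size = 10  # assumed value of cfg.rollout_batch_size

# Each vectorized rollout runs rollout_batch_size episodes in parallel.
num_rollouts = math.ceil(eval_episodes / rollout_batch_size)  # -> 5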
@@ -148,7 +148,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # )
 
     logging.info("make_env")
-    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
+    # TODO(now): uncomment
+    #env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
 
     logging.info("make_policy")
     policy = make_policy(cfg)
@@ -28,22 +28,23 @@ for to_remove in start_removals:
 # Replace keys based on what they start with.
 
 start_replacements = [
-    ("model.query_embed.weight", "model.pos_embed.weight"),
-    ("model.pos_table", "model.vae_encoder_pos_enc"),
-    ("model.pos_embed.weight", "model.decoder_pos_embed.weight"),
-    ("model.encoder.", "model.vae_encoder."),
-    ("model.encoder_action_proj.", "model.vae_encoder_action_input_proj."),
-    ("model.encoder_joint_proj.", "model.vae_encoder_robot_state_input_proj."),
-    ("model.latent_proj.", "model.vae_encoder_latent_output_proj."),
-    ("model.latent_proj.", "model.vae_encoder_latent_output_proj."),
-    ("model.input_proj.", "model.encoder_img_feat_input_proj."),
-    ("model.input_proj_robot_state", "model.encoder_robot_state_input_proj"),
-    ("model.latent_out_proj.", "model.encoder_latent_input_proj."),
-    ("model.transformer.encoder.", "model.encoder."),
-    ("model.transformer.decoder.", "model.decoder."),
-    ("model.backbones.0.0.body.", "model.backbone."),
-    ("model.additional_pos_embed.weight", "model.encoder_robot_and_latent_pos_embed.weight"),
-    ("model.cls_embed.weight", "model.vae_encoder_cls_embed.weight"),
+    ("model.", ""),
+    ("query_embed.weight", "pos_embed.weight"),
+    ("pos_table", "vae_encoder_pos_enc"),
+    ("pos_embed.weight", "decoder_pos_embed.weight"),
+    ("encoder.", "vae_encoder."),
+    ("encoder_action_proj.", "vae_encoder_action_input_proj."),
+    ("encoder_joint_proj.", "vae_encoder_robot_state_input_proj."),
+    ("latent_proj.", "vae_encoder_latent_output_proj."),
+    ("latent_proj.", "vae_encoder_latent_output_proj."),
+    ("input_proj.", "encoder_img_feat_input_proj."),
+    ("input_proj_robot_state", "encoder_robot_state_input_proj"),
+    ("latent_out_proj.", "encoder_latent_input_proj."),
+    ("transformer.encoder.", "encoder."),
+    ("transformer.decoder.", "decoder."),
+    ("backbones.0.0.body.", "backbone."),
+    ("additional_pos_embed.weight", "encoder_robot_and_latent_pos_embed.weight"),
+    ("cls_embed.weight", "vae_encoder_cls_embed.weight"),
 ]
 
 for to_replace, replace_with in start_replacements:
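The hunk only shows the replacement table; the sketch below shows how such prefix replacements are typically applied to a checkpoint's keys. The loop body is an assumption based on the `for to_replace, replace_with in start_replacements:` line, not the script's verbatim code.

def rename_keys(state_dict: dict, start_replacements: list[tuple[str, str]]) -> dict:
    # Apply prefix-based key renames in order (sketch; the real script may differ).
    for to_replace, replace_with in start_replacements:
        for key in list(state_dict.keys()):
            if key.startswith(to_replace):
                state_dict[replace_with + key[len(to_replace):]] = state_dict.pop(key)
    return state_dict

Because `("model.", "")` comes first in the new table, every later pattern matches keys that have already had the `model.` prefix stripped, which is why both sides of the remaining entries drop that prefix.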