backup wip

Alexander Soare 2024-04-05 18:46:30 +01:00
parent ecc7dd3b17
commit 8d2463f45b
5 changed files with 105 additions and 120 deletions

View File

@@ -4,79 +4,3 @@ import torch
 from torch import Tensor, nn
-
-
-class AbstractPolicy(nn.Module):
-    """Base policy which all policies should be derived from.
-
-    The forward method should generally not be overridden as it plays the role of handling multi-step policies. See
-    its documentation for more information.
-
-    Note:
-        When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
-            1. set the required class attributes:
-                - for classes inheriting from `AbstractDataset`: `available_datasets`
-                - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
-                - for classes inheriting from `AbstractPolicy`: `name`
-            2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-            3. update variables in `tests/test_available.py` by importing your new class
-    """
-
-    name: str | None = None  # same name should be used to instantiate the policy in factory.py
-
-    def __init__(self, n_action_steps: int | None):
-        """
-        n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
-            action is returned by `select_actions` and that it doesn't have a horizon dimension. The `forward` method
-            then adds that dimension.
-        """
-        super().__init__()
-        assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
-        self.n_action_steps = n_action_steps
-        self.clear_action_queue()
-
-    def update(self, replay_buffer, step):
-        """One step of the policy's learning algorithm."""
-        raise NotImplementedError("Abstract method")
-
-    def save(self, fp):
-        torch.save(self.state_dict(), fp)
-
-    def load(self, fp):
-        d = torch.load(fp)
-        self.load_state_dict(d)
-
-    def select_actions(self, observation) -> Tensor:
-        """Select an action (or trajectory of actions) based on an observation during rollout.
-
-        If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor
-        of actions. Otherwise, if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
-        """
-        raise NotImplementedError("Abstract method")
-
-    def clear_action_queue(self):
-        """This should be called whenever the environment is reset."""
-        if self.n_action_steps is not None:
-            self._action_queue = deque([], maxlen=self.n_action_steps)
-
-    def forward(self, *args, **kwargs) -> Tensor:
-        """Inference step that makes multi-step policies compatible with their single-step environments.
-
-        WARNING: In general, this should not be overridden.
-
-        Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
-        into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
-        observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the
-        environment observation, (3) repopulates the action queue when empty. This method handles the aforementioned
-        logic so that the subclass doesn't have to.
-
-        This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
-        1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H
-           is the action trajectory horizon and * is the action dimensions.
-        2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute defined.
-        """
-        if self.n_action_steps is None:
-            return self.select_actions(*args, **kwargs)
-        if len(self._action_queue) == 0:
-            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
-            # (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
-        return self._action_queue.popleft()
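
The action-queue behaviour documented above (and kept as a commented-out reference in the ACT policy file below) is easy to demonstrate in isolation. The following is a minimal standalone sketch of that logic rather than code from this commit; `DummyPolicy`, its fake trajectory, and the rollout loop are invented for illustration.

from collections import deque

import torch


class DummyPolicy:
    """Toy stand-in mimicking AbstractPolicy's queueing: `select_actions` plans a whole
    trajectory, while `forward` serves it one step at a time and replans once the queue is empty."""

    def __init__(self, n_action_steps: int):
        self.n_action_steps = n_action_steps
        self._action_queue = deque([], maxlen=n_action_steps)

    def select_actions(self) -> torch.Tensor:
        # Pretend to plan a (batch_size=2, n_action_steps, action_dim=3) trajectory.
        return torch.randn(2, self.n_action_steps, 3)

    def forward(self) -> torch.Tensor:
        if len(self._action_queue) == 0:
            # The queue holds per-step batches, i.e. (n_action_steps, batch_size, action_dim),
            # hence the transpose of the (batch_size, n_action_steps, action_dim) plan.
            self._action_queue.extend(self.select_actions().transpose(0, 1))
        return self._action_queue.popleft()


policy = DummyPolicy(n_action_steps=4)
for step in range(10):
    action = policy.forward()  # replans on steps 0, 4 and 8, pops from the queue otherwise
    assert action.shape == (2, 3)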

View File

@@ -3,7 +3,7 @@
 As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
 The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
 """
-
+from collections import deque
 import math
 import time
 from itertools import chain
@@ -22,6 +22,67 @@ from torchvision.ops.misc import FrozenBatchNorm2d
 from lerobot.common.utils import get_safe_torch_device
 
+# class AbstractPolicy(nn.Module):
+#     """Base policy which all policies should be derived from.
+#     The forward method should generally not be overridden as it plays the role of handling multi-step policies. See
+#     its documentation for more information.
+#     Note:
+#         When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
+#             1. set the required class attributes:
+#                 - for classes inheriting from `AbstractDataset`: `available_datasets`
+#                 - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
+#                 - for classes inheriting from `AbstractPolicy`: `name`
+#             2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
+#             3. update variables in `tests/test_available.py` by importing your new class
+#     """
+#     name: str | None = None  # same name should be used to instantiate the policy in factory.py
+
+#     def __init__(self, n_action_steps: int | None):
+#         """
+#         n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
+#             action is returned by `select_actions` and that it doesn't have a horizon dimension. The `forward`
+#             method then adds that dimension.
+#         """
+#         super().__init__()
+#         assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
+#         self.n_action_steps = n_action_steps
+#         self.clear_action_queue()
+
+#     def clear_action_queue(self):
+#         """This should be called whenever the environment is reset."""
+#         if self.n_action_steps is not None:
+#             self._action_queue = deque([], maxlen=self.n_action_steps)
+
+#     def forward(self, fn) -> Tensor:
+#         """Inference step that makes multi-step policies compatible with their single-step environments.
+#         WARNING: In general, this should not be overridden.
+#         Consider a "policy" that observes the environment then charts a course of N actions to take. To make this
+#         fit into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
+#         observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the
+#         environment observation, (3) repopulates the action queue when empty. This method handles the aforementioned
+#         logic so that the subclass doesn't have to.
+#         This method effectively wraps the `select_actions` method of the subclass. The following assumptions are
+#         made:
+#         1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size,
+#            H is the action trajectory horizon and * is the action dimensions.
+#         2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute
+#            defined.
+#         """
+#         if self.n_action_steps is None:
+#             return self.select_actions(*args, **kwargs)
+#         if len(self._action_queue) == 0:
+#             # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has
+#             # shape (n_action_steps, batch_size, *), hence the transpose.
+#             self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
+#         return self._action_queue.popleft()
+
+
 class ActionChunkingTransformerPolicy(nn.Module):
     """
     Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
@@ -168,14 +229,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
                 nn.init.xavier_uniform_(p)
 
     @torch.no_grad()
-    def select_actions(self, batch, *_):
+    def select_action(self, batch, *_):
         # TODO(now): Implement queueing mechanism.
         self.eval()
         self._preprocess_batch(batch)
 
         # TODO(now): What's up with this 0.182?
         action = self.forward(
-            robot_state=batch["observation.state"] * 0.182, image=batch["observation.images.top"]
+            robot_state=batch["observation.state"] * 0.182,
+            image=batch["observation.images.top"],
+            return_loss=False,
         )
 
         if self.cfg.temporal_agg:
@@ -226,7 +289,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         assert batch_size % self.cfg.horizon == 0
         assert batch_size % num_slices == 0
 
-        loss = self.compute_loss(batch)
+        loss = self.forward(batch, return_loss=True)["loss"]
         loss.backward()
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
@ -247,44 +310,38 @@ class ActionChunkingTransformerPolicy(nn.Module):
return info return info
def compute_loss(self, batch): def forward(self, batch: dict[str, Tensor], return_loss: bool = False):
loss_dict = self.forward(
robot_state=batch["observation.state"],
image=batch["observation.images.top"],
actions=batch["action"],
)
loss = loss_dict["loss"]
return loss
def forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
# TODO(now): Maybe this shouldn't be here? # TODO(now): Maybe this shouldn't be here?
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image = normalize(image) images = normalize(batch["observation.images.top"])
is_training = actions is not None if return_loss: # training time
if is_training: # training time actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
actions = actions[:, : self.horizon] batch["observation.state"], images, batch["action"]
)
a_hat, (mu, log_sigma_x2) = self._forward(robot_state, image, actions) l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none")
all_l1 = F.l1_loss(actions, a_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
l1 = all_l1.mean() ).mean()
loss_dict = {} loss_dict = {}
loss_dict["l1"] = l1 loss_dict["l1"] = l1_loss
if self.cfg.use_vae: if self.cfg.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total # each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch. # KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details). # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - (log_sigma_x2).exp())).sum(-1).mean() mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kl"] = mean_kld loss_dict["kl"] = mean_kld
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
else: else:
loss_dict["loss"] = loss_dict["l1"] loss_dict["loss"] = loss_dict["l1"]
return loss_dict return loss_dict
else: else:
action, _ = self._forward(robot_state, image) # no action, sample from prior action, _ = self._forward(batch["observation.state"], images)
return action return action
def _forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None): def _forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
@@ -321,7 +378,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
             # Forward pass through VAE encoder and sample the latent with the reparameterization trick.
             cls_token_out = self.vae_encoder(
                 vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
-            )[0]  # (B, D)
+            )[
+                0
+            ]  # (B, D)
             latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
             mu = latent_pdf_params[:, : self.latent_dim]
             # This is 2log(sigma). Done this way to match the original implementation.
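
For reference, the loss assembled in the new `forward(batch, return_loss=True)` path is a padding-masked L1 term plus, when the VAE is enabled, a KL term summed over the latent dimensions and averaged over the batch. The snippet below recomputes both on random tensors; the shapes and the `kl_weight` value are made up, and it only mirrors the formulas shown above rather than reproducing the module itself.

import torch
import torch.nn.functional as F

batch_size, horizon, action_dim, latent_dim = 8, 100, 14, 32  # illustrative shapes only

actions = torch.randn(batch_size, horizon, action_dim)      # ground-truth action chunk
actions_hat = torch.randn(batch_size, horizon, action_dim)  # predicted action chunk
action_is_pad = torch.zeros(batch_size, horizon, dtype=torch.bool)  # True where the chunk is padding

# L1 term: zero out padded timesteps before averaging.
l1_loss = (
    F.l1_loss(actions, actions_hat, reduction="none") * ~action_is_pad.unsqueeze(-1)
).mean()

# KL term: D_KL(N(mu, sigma^2) || N(0, 1)) per latent dimension, summed over the latent
# dimension and averaged over the batch (log_sigma_x2 is 2*log(sigma)).
mu_hat = torch.randn(batch_size, latent_dim)
log_sigma_x2_hat = torch.randn(batch_size, latent_dim)
mean_kld = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - log_sigma_x2_hat.exp())).sum(-1).mean()

kl_weight = 10.0  # stand-in for cfg.kl_weight
loss = l1_loss + kl_weight * mean_kld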

View File

@@ -251,7 +251,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     dataset = make_dataset(cfg, stats_path=stats_path)
 
     logging.info("Making environment.")
-    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
+    env = make_env(cfg, num_parallel_envs=cfg.rollout_batch_size)
 
     # when policy is None, rollout a random policy
     policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None

View File

@@ -148,7 +148,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # )
 
     logging.info("make_env")
-    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
+    # TODO(now): uncomment
+    #env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
 
     logging.info("make_policy")
     policy = make_policy(cfg)

View File

@@ -28,22 +28,23 @@ for to_remove in start_removals:
 # Replace keys based on what they start with.
 start_replacements = [
-    ("model.query_embed.weight", "model.pos_embed.weight"),
-    ("model.pos_table", "model.vae_encoder_pos_enc"),
-    ("model.pos_embed.weight", "model.decoder_pos_embed.weight"),
-    ("model.encoder.", "model.vae_encoder."),
-    ("model.encoder_action_proj.", "model.vae_encoder_action_input_proj."),
-    ("model.encoder_joint_proj.", "model.vae_encoder_robot_state_input_proj."),
-    ("model.latent_proj.", "model.vae_encoder_latent_output_proj."),
-    ("model.latent_proj.", "model.vae_encoder_latent_output_proj."),
-    ("model.input_proj.", "model.encoder_img_feat_input_proj."),
-    ("model.input_proj_robot_state", "model.encoder_robot_state_input_proj"),
-    ("model.latent_out_proj.", "model.encoder_latent_input_proj."),
-    ("model.transformer.encoder.", "model.encoder."),
-    ("model.transformer.decoder.", "model.decoder."),
-    ("model.backbones.0.0.body.", "model.backbone."),
-    ("model.additional_pos_embed.weight", "model.encoder_robot_and_latent_pos_embed.weight"),
-    ("model.cls_embed.weight", "model.vae_encoder_cls_embed.weight"),
+    ("model.", ""),
+    ("query_embed.weight", "pos_embed.weight"),
+    ("pos_table", "vae_encoder_pos_enc"),
+    ("pos_embed.weight", "decoder_pos_embed.weight"),
+    ("encoder.", "vae_encoder."),
+    ("encoder_action_proj.", "vae_encoder_action_input_proj."),
+    ("encoder_joint_proj.", "vae_encoder_robot_state_input_proj."),
+    ("latent_proj.", "vae_encoder_latent_output_proj."),
+    ("latent_proj.", "vae_encoder_latent_output_proj."),
+    ("input_proj.", "encoder_img_feat_input_proj."),
+    ("input_proj_robot_state", "encoder_robot_state_input_proj"),
+    ("latent_out_proj.", "encoder_latent_input_proj."),
+    ("transformer.encoder.", "encoder."),
+    ("transformer.decoder.", "decoder."),
+    ("backbones.0.0.body.", "backbone."),
+    ("additional_pos_embed.weight", "encoder_robot_and_latent_pos_embed.weight"),
+    ("cls_embed.weight", "vae_encoder_cls_embed.weight"),
 ]
 
 for to_replace, replace_with in start_replacements:
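
The hunk ends just before the body of the replacement loop, so for context: prefix renames like these are usually applied roughly as in the sketch below. This is an assumption for illustration only; the real loop body lives outside this hunk and may differ, and the `state_dict` contents here are stand-ins.

# Hypothetical application of the prefix renames to checkpoint keys (illustration only).
state_dict = {
    "model.query_embed.weight": "tensor...",
    "model.transformer.encoder.layers.0.linear1.weight": "tensor...",
}

start_replacements = [
    ("model.", ""),
    ("query_embed.weight", "pos_embed.weight"),
    ("transformer.encoder.", "encoder."),
]

for to_replace, replace_with in start_replacements:
    for key in list(state_dict):
        if key.startswith(to_replace):
            state_dict[replace_with + key[len(to_replace):]] = state_dict.pop(key)

# Because "model." is stripped first, the later rules match the shortened keys, which is why
# every entry in the new list drops the "model." prefix.
print(sorted(state_dict))  # ['encoder.layers.0.linear1.weight', 'pos_embed.weight']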