Merge remote-tracking branch 'upstream/main'
Commit ca1c184cb1
@@ -145,10 +145,3 @@ class ACTConfig:
             raise ValueError(
                 f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
             )
-        # Check that there is only one image.
-        # TODO(alexander-soare): generalize this to multiple images.
-        if (
-            sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
-            or "observation.images.top" not in self.input_shapes
-        ):
-            raise ValueError('For now, only "observation.images.top" is accepted for an image input.')
@@ -62,6 +62,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         if config is None:
             config = ACTConfig()
         self.config = config
+
         self.normalize_inputs = Normalize(
             config.input_shapes, config.input_normalization_modes, dataset_stats
         )
@@ -71,8 +72,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         self.unnormalize_outputs = Unnormalize(
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )
 
         self.model = ACT(config)
+
+        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+
+        self.reset()
+
     def reset(self):
         """This should be called whenever the environment is reset."""
         if self.config.n_action_steps is not None:
@@ -86,13 +92,10 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
         queue is empty.
         """
-        assert "observation.images.top" in batch
-        assert "observation.state" in batch
-
         self.eval()
 
         batch = self.normalize_inputs(batch)
-        self._stack_images(batch)
+        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
 
         if len(self._action_queue) == 0:
             # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
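A standalone sketch (not part of the commit) of what the new stacking line does, using two hypothetical camera keys and made-up shapes; in the policy the key order comes from `config.input_shapes`:

import torch

batch = {
    "observation.images.top": torch.randn(8, 3, 96, 96),    # (B, C, H, W)
    "observation.images.wrist": torch.randn(8, 3, 96, 96),  # hypothetical second camera
}
expected_image_keys = sorted(batch)
# dim=-4 inserts the camera axis just before (C, H, W), giving (B, num_cams, C, H, W).
images = torch.stack([batch[k] for k in expected_image_keys], dim=-4)
print(images.shape)  # torch.Size([8, 2, 3, 96, 96])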
@@ -108,8 +111,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
+        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         batch = self.normalize_targets(batch)
-        self._stack_images(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
         l1_loss = (
@@ -132,21 +135,6 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
 
         return loss_dict
 
-    def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        """Stacks all the images in a batch and puts them in a new key: "observation.images".
-
-        This function expects `batch` to have (at least):
-        {
-            "observation.state": (B, state_dim) batch of robot states.
-            "observation.images.{name}": (B, C, H, W) tensor of images.
-        }
-        """
-        # Stack images in the order dictated by input_shapes.
-        batch["observation.images"] = torch.stack(
-            [batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")],
-            dim=-4,
-        )
-
 
 class ACT(nn.Module):
     """Action Chunking Transformer: The underlying neural network for ACTPolicy.
@@ -176,10 +164,10 @@ class ACT(nn.Module):
     │ encoder │ │ │ │Transf.│ │
     │ │ │ │ │encoder│ │
     └───▲─────┘ │ │ │ │ │
-    │ │ │ └───▲───┘ │
-    │ │ │ │ │
-    inputs └─────┼─────┘ │
-    │ │
+    │ │ │ └▲──▲─▲─┘ │
+    │ │ │ │ │ │ │
+    inputs └─────┼──┘ │ image emb. │
+    │ state emb. │
     └───────────────────────┘
     """
 
@@ -321,18 +309,18 @@ class ACT(nn.Module):
             all_cam_features.append(cam_features)
             all_cam_pos_embeds.append(cam_pos_embed)
         # Concatenate camera observation feature maps and positional embeddings along the width dimension.
-        encoder_in = torch.cat(all_cam_features, axis=3)
-        cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
+        encoder_in = torch.cat(all_cam_features, axis=-1)
+        cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
 
         # Get positional embeddings for robot state and latent.
-        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
-        latent_embed = self.encoder_latent_input_proj(latent_sample)
+        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
+        latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)
 
         # Stack encoder input and positional embeddings moving to (S, B, C).
         encoder_in = torch.cat(
             [
                 torch.stack([latent_embed, robot_state_embed], axis=0),
-                encoder_in.flatten(2).permute(2, 0, 1),
+                einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
             ]
         )
         pos_embed = torch.cat(
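A quick equivalence check (not part of the commit) showing that the einops rewrite is a drop-in replacement for the old flatten/permute: both map (B, C, H, W) to (H*W, B, C) with the same element order:

import einops
import torch

x = torch.randn(2, 8, 4, 5)  # (B, C, H, W)
old = x.flatten(2).permute(2, 0, 1)                # previous formulation
new = einops.rearrange(x, "b c h w -> (h w) b c")  # new formulation
assert old.shape == new.shape == (4 * 5, 2, 8)
assert torch.equal(old, new)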
@@ -148,14 +148,21 @@ class DiffusionConfig:
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
             )
+        # There should only be one image key.
+        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
+        if len(image_keys) != 1:
+            raise ValueError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        image_key = next(iter(image_keys))
         if (
-            self.crop_shape[0] > self.input_shapes["observation.image"][1]
-            or self.crop_shape[1] > self.input_shapes["observation.image"][2]
+            self.crop_shape[0] > self.input_shapes[image_key][1]
+            or self.crop_shape[1] > self.input_shapes[image_key][2]
         ):
             raise ValueError(
-                f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
-                f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
-                '`input_shapes["observation.image"]`.'
+                f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
+                f"for `crop_shape` and {self.input_shapes[image_key]} for "
+                "`input_shapes[{image_key}]`."
             )
         supported_prediction_types = ["epsilon", "sample"]
         if self.prediction_type not in supported_prediction_types:
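A rough illustration of the new validation with made-up shapes: any single key starting with "observation.image" is accepted, and `crop_shape` is checked against that key instead of a hard-coded "observation.image":

input_shapes = {
    "observation.images.top": [3, 480, 640],  # hypothetical camera key and shape
    "observation.state": [14],
}
image_keys = {k for k in input_shapes if k.startswith("observation.image")}
assert len(image_keys) == 1
image_key = next(iter(image_keys))

crop_shape = (84, 84)
# The crop must fit inside the (channels, height, width) shape registered for the image key.
assert crop_shape[0] <= input_shapes[image_key][1]
assert crop_shape[1] <= input_shapes[image_key][2]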
@@ -19,6 +19,7 @@
 TODO(alexander-soare):
   - Remove reliance on Robomimic for SpatialSoftmax.
   - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
+  - Make compatible with multiple image keys.
 """
 
 import math
@@ -83,10 +84,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
 
         self.diffusion = DiffusionModel(config)
 
+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        # Note: This check is covered in the post-init of the config but have a sanity check just in case.
+        if len(image_keys) != 1:
+            raise NotImplementedError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        self.input_image_key = image_keys[0]
+
+        self.reset()
+
     def reset(self):
-        """
-        Clear observation and action queues. Should be called on `env.reset()`
-        """
+        """Clear observation and action queues. Should be called on `env.reset()`"""
         self._queues = {
             "observation.image": deque(maxlen=self.config.n_obs_steps),
             "observation.state": deque(maxlen=self.config.n_obs_steps),
@@ -115,16 +124,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         "horizon" may not the best name to describe what the variable actually means, because this period is
         actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
         """
-        assert "observation.image" in batch
-        assert "observation.state" in batch
-
         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]
 
         self._queues = populate_queues(self._queues, batch)
 
         if len(self._queues["action"]) == 0:
             # stack n latest observations from the queue
-            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
             actions = self.diffusion.generate_actions(batch)
 
             # TODO(rcadene): make above methods return output dictionary?
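A sketch with toy values of why the comprehension now filters on the queues: after `batch["observation.image"] = batch[self.input_image_key]`, the batch carries both the dataset's image key and the canonical one, and only keys that actually have a queue should be stacked:

from collections import deque

import torch

queues = {
    "observation.image": deque([torch.zeros(1, 3, 96, 96)] * 2, maxlen=2),
    "observation.state": deque([torch.zeros(1, 2)] * 2, maxlen=2),
    "action": deque(maxlen=8),
}
batch = {
    "observation.images.top": torch.zeros(1, 3, 96, 96),  # dataset key (hypothetical)
    "observation.image": torch.zeros(1, 3, 96, 96),       # canonical key used internally
    "observation.state": torch.zeros(1, 2),
}
# Without the `if k in queues` guard, "observation.images.top" would raise a KeyError.
stacked = {k: torch.stack(list(queues[k]), dim=1) for k in batch if k in queues}
print(sorted(stacked))  # ['observation.image', 'observation.state']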
@@ -138,6 +145,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
@@ -215,13 +223,12 @@ class DiffusionModel(nn.Module):
 
     def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """
-        This function expects `batch` to have (at least):
+        This function expects `batch` to have:
         {
             "observation.state": (B, n_obs_steps, state_dim)
             "observation.image": (B, n_obs_steps, C, H, W)
         }
         """
-        assert set(batch).issuperset({"observation.state", "observation.image"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
         assert n_obs_steps == self.config.n_obs_steps
 
@@ -345,9 +352,12 @@ class DiffusionRgbEncoder(nn.Module):
 
         # Set up pooling and final layers.
         # Use a dry run to get the feature map shape.
-        # The dummy input should take the number of image channels from `config.input_shapes` and it should use the
-        # height and width from `config.crop_shape`.
-        dummy_input = torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape))
+        # The dummy input should take the number of image channels from `config.input_shapes` and it should
+        # use the height and width from `config.crop_shape`.
+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        assert len(image_keys) == 1
+        image_key = image_keys[0]
+        dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
         with torch.inference_mode():
             dummy_feature_map = self.backbone(dummy_input)
         feature_map_shape = tuple(dummy_feature_map.shape[1:])
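A toy version of the dry-run trick above, with a made-up backbone: run one dummy image through the network to discover the feature map shape that the pooling layer needs:

import torch
from torch import nn

backbone = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), nn.ReLU())
crop_shape = (84, 84)
dummy_input = torch.zeros(size=(1, 3, *crop_shape))
with torch.inference_mode():
    dummy_feature_map = backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])
print(feature_map_shape)  # (16, 42, 42)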
@@ -147,12 +147,18 @@ class TDMPCConfig:
 
     def __post_init__(self):
         """Input validation (not exhaustive)."""
-        if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
+        # There should only be one image key.
+        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
+        if len(image_keys) != 1:
+            raise ValueError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        image_key = next(iter(image_keys))
+        if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
             # TODO(alexander-soare): This limitation is solely because of code in the random shift
             # augmentation. It should be able to be removed.
             raise ValueError(
-                "Only square images are handled now. Got image shape "
-                f"{self.input_shapes['observation.image']}."
+                f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
             )
         if self.n_gaussian_samples <= 0:
             raise ValueError(
@@ -112,13 +112,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )
 
-    def save(self, fp):
-        """Save state dict of TOLD model to filepath."""
-        torch.save(self.state_dict(), fp)
+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        # Note: This check is covered in the post-init of the config but have a sanity check just in case.
+        assert len(image_keys) == 1
+        self.input_image_key = image_keys[0]
 
-    def load(self, fp):
-        """Load a saved state dict from filepath into current agent."""
-        self.load_state_dict(torch.load(fp))
+        self.reset()
 
     def reset(self):
         """
@@ -137,10 +136,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]):
         """Select a single action given environment observations."""
-        assert "observation.image" in batch
-        assert "observation.state" in batch
-
         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]
 
         self._queues = populate_queues(self._queues, batch)
 
@@ -319,13 +316,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         device = get_device_from_parameters(self)
 
         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]
         batch = self.normalize_targets(batch)
 
         info = {}
 
-        # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
-        batch_size = batch["index"].shape[0]
-
         # (b, t) -> (t, b)
         for key in batch:
             if batch[key].ndim > 1:
@@ -353,6 +348,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         # Run latent rollout using the latent dynamics model and policy model.
         # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
         # gives us a next `z`.
+        batch_size = batch["index"].shape[0]
         z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
         z_preds[0] = self.model.encode(current_observation)
         reward_preds = torch.empty_like(reward, device=device)
@@ -19,6 +19,10 @@ from torch import nn
 
 def populate_queues(queues, batch):
     for key in batch:
+        # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
+        # queues have the keys they want).
+        if key not in queues:
+            continue
         if len(queues[key]) != queues[key].maxlen:
             # initialize by copying the first observation several times until the queue is full
             while len(queues[key]) != queues[key].maxlen:
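For context, a self-contained sketch of how the full helper plausibly behaves with the new guard; only the guard and the initialization loop appear in the hunk above, so the `else` branch and the return are assumptions:

from collections import deque


def populate_queues(queues, batch):
    for key in batch:
        # Keys without a queue are skipped instead of raising a KeyError.
        if key not in queues:
            continue
        if len(queues[key]) != queues[key].maxlen:
            # Initialize by copying the first observation until the queue is full.
            while len(queues[key]) != queues[key].maxlen:
                queues[key].append(batch[key])
        else:
            queues[key].append(batch[key])
    return queues


queues = {"observation.state": deque(maxlen=2), "action": deque(maxlen=4)}
queues = populate_queues(queues, {"observation.state": 1.0, "observation.images.top": None})
print(list(queues["observation.state"]))  # [1.0, 1.0]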
@@ -23,6 +23,7 @@ import hydra
 import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
+from omegaconf import DictConfig
 
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -307,7 +308,7 @@ def add_episodes_inplace(
     sampler.num_samples = len(concat_dataset)
 
 
-def train(cfg: dict, out_dir=None, job_name=None):
+def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     if out_dir is None:
         raise NotImplementedError()
     if job_name is None:
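A minimal sketch (placeholder keys and values) of what the tightened signature implies: Hydra hands `train` an omegaconf `DictConfig` rather than a plain dict, which is what enables attribute-style access throughout the script:

from omegaconf import DictConfig, OmegaConf

cfg: DictConfig = OmegaConf.create({"seed": 1000, "training": {"offline_steps": 5000}})
assert isinstance(cfg, DictConfig)
print(cfg.training.offline_steps)  # 5000; a plain dict would not support cfg.training.offline_steps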
@@ -64,6 +64,14 @@ def test_get_policy_and_config_classes(policy_name: str):
             "act",
             ["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
         ),
+        # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
+        (
+            "aloha",
+            "diffusion",
+            ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"],
+        ),
+        # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
+        ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
     ],
 )
 @require_env
@@ -87,6 +95,31 @@ def test_policy(env_name, policy_name, extra_overrides):
         + extra_overrides,
     )
 
+    # Additional config override logic.
+    if env_name == "aloha" and policy_name == "diffusion":
+        for keys in [
+            ("training", "delta_timestamps"),
+            ("policy", "input_shapes"),
+            ("policy", "input_normalization_modes"),
+        ]:
+            dct = dict(cfg[keys[0]][keys[1]])
+            dct["observation.images.top"] = dct["observation.image"]
+            del dct["observation.image"]
+            cfg[keys[0]][keys[1]] = dct
+        cfg.override_dataset_stats = None
+
+    # Additional config override logic.
+    if env_name == "pusht" and policy_name == "act":
+        for keys in [
+            ("policy", "input_shapes"),
+            ("policy", "input_normalization_modes"),
+        ]:
+            dct = dict(cfg[keys[0]][keys[1]])
+            dct["observation.image"] = dct["observation.images.top"]
+            del dct["observation.images.top"]
+            cfg[keys[0]][keys[1]] = dct
+        cfg.override_dataset_stats = None
+
     # Check that we can make the policy object.
     dataset = make_dataset(cfg)
     policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
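The override blocks above all follow the same key-rename pattern; an isolated sketch with toy values:

cfg_section = {"observation.image": [3, 96, 96], "observation.state": [2]}
dct = dict(cfg_section)                                    # copy the nested mapping to a plain dict
dct["observation.images.top"] = dct["observation.image"]  # move the entry under the new key
del dct["observation.image"]
cfg_section = dct                                          # write it back
print(sorted(cfg_section))  # ['observation.images.top', 'observation.state']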