make unit tests pass
This commit is contained in:
parent
42ed7bb670
commit
0660f71556
|
@ -44,7 +44,7 @@ from datasets import load_dataset
|
|||
# TODO(rcadene): list available datasets on lerobot page using `datasets`
|
||||
|
||||
# download/load hugging face dataset in pyarrow format
|
||||
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
|
||||
hf_dataset, fps = load_dataset("lerobot/pusht", split="train", revision="v1.1"), 10
|
||||
|
||||
# display name of dataset and its features
|
||||
# TODO(rcadene): update to make the print pretty
|
||||
|
|
|
@ -34,7 +34,7 @@ dataset = make_dataset(hydra_cfg)
|
|||
# If you're doing something different, you will likely need to change at least some of the defaults.
|
||||
cfg = DiffusionConfig()
|
||||
# TODO(alexander-soare): Remove LR scheduler from the policy.
|
||||
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
|
||||
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ import os
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
@ -43,7 +44,10 @@ def make_dataset(
|
|||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
for key, val in cfg.override_dataset_stats.items():
|
||||
dataset.stats[key] = torch.tensor(val)
|
||||
for key, stats_dict in cfg.override_dataset_stats.items():
|
||||
for stats_type, listconfig in stats_dict.items():
|
||||
# example of stats_type: min, max, mean, std
|
||||
stats = OmegaConf.to_container(listconfig, resolve=True)
|
||||
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
|
|
@ -22,7 +22,7 @@ def preprocess_observation(observation):
|
|||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w")
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -61,13 +61,17 @@ class ActionChunkingTransformerConfig:
|
|||
n_action_steps: int = 100
|
||||
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes: dict[str, str] = {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "mean_std",
|
||||
}
|
||||
unnormalize_output_modes: dict[str, str] = {
|
||||
"action": "mean_std",
|
||||
}
|
||||
normalize_input_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "mean_std",
|
||||
}
|
||||
)
|
||||
unnormalize_output_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": "mean_std",
|
||||
}
|
||||
)
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: str = "resnet18"
|
||||
|
|
|
@ -22,6 +22,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
|||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.utils import (
|
||||
normalize_inputs,
|
||||
to_buffer_dict,
|
||||
unnormalize_outputs,
|
||||
)
|
||||
|
||||
|
@ -75,7 +76,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
if cfg is None:
|
||||
cfg = ActionChunkingTransformerConfig()
|
||||
self.cfg = cfg
|
||||
self.register_buffer("dataset_stats", dataset_stats)
|
||||
self.dataset_stats = to_buffer_dict(dataset_stats)
|
||||
self.normalize_input_modes = cfg.normalize_input_modes
|
||||
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
||||
|
||||
|
@ -179,7 +180,12 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
|
||||
# has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
|
||||
actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
|
||||
# TODO(rcadene): make _forward return output dictionary?
|
||||
out_dict = {"action": actions}
|
||||
out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
|
||||
actions = out_dict["action"]
|
||||
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
|
@ -214,7 +220,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
|
||||
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
|
||||
loss_dict = self.forward(batch)
|
||||
# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
|
||||
loss = loss_dict["loss"]
|
||||
loss.backward()
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -70,13 +70,17 @@ class DiffusionConfig:
|
|||
n_action_steps: int = 8
|
||||
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes: dict[str, str] = {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
unnormalize_output_modes: dict[str, str] = {
|
||||
"action": "min_max",
|
||||
}
|
||||
normalize_input_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
unnormalize_output_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": "min_max",
|
||||
}
|
||||
)
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
|
|
|
@ -31,6 +31,7 @@ from lerobot.common.policies.utils import (
|
|||
get_dtype_from_parameters,
|
||||
normalize_inputs,
|
||||
populate_queues,
|
||||
to_buffer_dict,
|
||||
unnormalize_outputs,
|
||||
)
|
||||
|
||||
|
@ -57,7 +58,7 @@ class DiffusionPolicy(nn.Module):
|
|||
if cfg is None:
|
||||
cfg = DiffusionConfig()
|
||||
self.cfg = cfg
|
||||
self.register_buffer("dataset_stats", dataset_stats)
|
||||
self.dataset_stats = to_buffer_dict(dataset_stats)
|
||||
self.normalize_input_modes = cfg.normalize_input_modes
|
||||
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
||||
|
||||
|
@ -144,7 +145,11 @@ class DiffusionPolicy(nn.Module):
|
|||
else:
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
# TODO(rcadene): make above methods return output dictionary?
|
||||
out_dict = {"action": actions}
|
||||
out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
|
||||
actions = out_dict["action"]
|
||||
|
||||
self._queues["action"].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
|
@ -166,7 +171,7 @@ class DiffusionPolicy(nn.Module):
|
|||
loss = self.forward(batch)["loss"]
|
||||
loss.backward()
|
||||
|
||||
# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.diffusion.parameters(),
|
||||
|
|
|
@ -66,3 +66,20 @@ def unnormalize_outputs(batch, stats, unnormalize_output_modes):
|
|||
else:
|
||||
raise ValueError(mode)
|
||||
return batch
|
||||
|
||||
|
||||
def to_buffer_dict(dataset_stats):
|
||||
# TODO(rcadene): replace this function by `torch.BufferDict` when it exists
|
||||
# see: https://github.com/pytorch/pytorch/issues/37386
|
||||
# TODO(rcadene): make `to_buffer_dict` generic and add docstring
|
||||
if dataset_stats is None:
|
||||
return None
|
||||
|
||||
new_ds_stats = {}
|
||||
for key, stats_dict in dataset_stats.items():
|
||||
new_stats_dict = {}
|
||||
for stats_type, value in stats_dict.items():
|
||||
# set requires_grad=False to have the same behavior as a nn.Buffer
|
||||
new_stats_dict[stats_type] = nn.Parameter(value, requires_grad=False)
|
||||
new_ds_stats[key] = nn.ParameterDict(new_stats_dict)
|
||||
return nn.ParameterDict(new_ds_stats)
|
||||
|
|
|
@ -12,7 +12,8 @@ n_obs_steps: 1
|
|||
# when temporal_agg=False, n_action_steps=horizon
|
||||
|
||||
override_dataset_stats:
|
||||
observation.image:
|
||||
observation.images.top:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
|
@ -35,7 +36,7 @@ policy:
|
|||
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes:
|
||||
observation.image: mean_std
|
||||
observation.images.top: mean_std
|
||||
observation.state: mean_std
|
||||
unnormalize_output_modes:
|
||||
action: mean_std
|
||||
|
|
|
@ -19,9 +19,12 @@ online_steps: 0
|
|||
offline_prioritized_sampler: true
|
||||
|
||||
override_dataset_stats:
|
||||
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
|
||||
observation.image:
|
||||
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
|
||||
# from the original codebase, but we should remove these and train our own pretrained model
|
||||
observation.state:
|
||||
min: [13.456424, 32.938293]
|
||||
max: [496.14618, 510.9579]
|
||||
|
|
|
@ -50,11 +50,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
|||
log_output_dir(out_dir)
|
||||
|
||||
logging.info("make_dataset")
|
||||
dataset = make_dataset(
|
||||
cfg,
|
||||
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
|
||||
normalize=False,
|
||||
)
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
logging.info("Start rendering episodes from offline buffer")
|
||||
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
|
||||
|
|
|
@ -6,7 +6,6 @@ import torch
|
|||
from gymnasium.utils.env_checker import check_env
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
@ -38,12 +37,14 @@ def test_factory(env_name):
|
|||
overrides=[f"env={env_name}", f"device={DEVICE}"],
|
||||
)
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
env = make_env(cfg, num_parallel_envs=1)
|
||||
obs, _ = env.reset()
|
||||
obs = preprocess_observation(obs)
|
||||
for key in dataset.image_keys:
|
||||
|
||||
# test image keys are float32 in range [0,1]
|
||||
for key in obs:
|
||||
if "image" not in key:
|
||||
continue
|
||||
img = obs[key]
|
||||
assert img.dtype == torch.float32
|
||||
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
||||
|
|
Loading…
Reference in New Issue