make unit tests pass

Cadene 2024-04-23 21:39:39 +00:00
parent 42ed7bb670
commit 0660f71556
13 changed files with 79 additions and 38 deletions

View File

@@ -44,7 +44,7 @@ from datasets import load_dataset
 # TODO(rcadene): list available datasets on lerobot page using `datasets`
 # download/load hugging face dataset in pyarrow format
-hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
+hf_dataset, fps = load_dataset("lerobot/pusht", split="train", revision="v1.1"), 10
 # display name of dataset and its features
 # TODO(rcadene): update to make the print pretty

View File

@@ -34,7 +34,7 @@ dataset = make_dataset(hydra_cfg)
 # If you're doing something different, you will likely need to change at least some of the defaults.
 cfg = DiffusionConfig()
 # TODO(alexander-soare): Remove LR scheduler from the policy.
-policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
+policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
 policy.train()
 policy.to(device)

View File

@@ -2,6 +2,7 @@ import os
 from pathlib import Path

 import torch
+from omegaconf import OmegaConf

 DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
@@ -43,7 +44,10 @@ def make_dataset(
     )

     if cfg.get("override_dataset_stats"):
-        for key, val in cfg.override_dataset_stats.items():
-            dataset.stats[key] = torch.tensor(val)
+        for key, stats_dict in cfg.override_dataset_stats.items():
+            for stats_type, listconfig in stats_dict.items():
+                # example of stats_type: min, max, mean, std
+                stats = OmegaConf.to_container(listconfig, resolve=True)
+                dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)

     return dataset
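
Note: the detour through `OmegaConf.to_container` is needed because values read from a Hydra/OmegaConf config arrive as `ListConfig` objects rather than plain lists, and `resolve=True` also expands any `${...}` interpolations. A minimal standalone sketch of the conversion (the `mean` values are copied from the ACT config further down; everything else is standard `omegaconf`/`torch` usage):

import torch
from omegaconf import OmegaConf

cfg = OmegaConf.create({"mean": [[[0.485]], [[0.456]], [[0.406]]]})
stats = OmegaConf.to_container(cfg.mean, resolve=True)  # ListConfig -> nested Python lists
mean = torch.tensor(stats, dtype=torch.float32)
print(mean.shape)  # torch.Size([3, 1, 1]), ready to broadcast over (c, h, w) images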

View File

@@ -22,7 +22,7 @@ def preprocess_observation(observation):
     assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

     # convert to channel first of type float32 in range [0,1]
-    img = einops.rearrange(img, "b h w c -> b c h w")
+    img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
     img = img.type(torch.float32)
     img /= 255
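
Note: the added `.contiguous()` matters because a pure permutation such as "b h w c -> b c h w" is a stride trick: `einops.rearrange` returns a view over the same storage, and some downstream ops require (or run faster on) a contiguous layout. A standalone illustration using only `torch` and `einops`:

import einops
import torch

img = torch.zeros(1, 96, 96, 3, dtype=torch.uint8)  # (b, h, w, c)
view = einops.rearrange(img, "b h w c -> b c h w")
print(view.is_contiguous())               # False: same storage, permuted strides
print(view.contiguous().is_contiguous())  # True: data copied into channel-first layout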

View File

@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field


 @dataclass
@@ -61,13 +61,17 @@ class ActionChunkingTransformerConfig:
     n_action_steps: int = 100

     # Normalization / Unnormalization
-    normalize_input_modes: dict[str, str] = {
-        "observation.image": "mean_std",
-        "observation.state": "mean_std",
-    }
-    unnormalize_output_modes: dict[str, str] = {
-        "action": "mean_std",
-    }
+    normalize_input_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "observation.image": "mean_std",
+            "observation.state": "mean_std",
+        }
+    )
+    unnormalize_output_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "action": "mean_std",
+        }
+    )

     # Architecture.
     # Vision backbone.
     vision_backbone: str = "resnet18"
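
Note: the move to `field(default_factory=...)` is required rather than stylistic: `dataclasses` rejects mutable class-level defaults such as `dict`, since they would be shared across every instance. A quick standalone demonstration of both the failure and the fix:

from dataclasses import dataclass, field

try:
    @dataclass
    class Bad:
        modes: dict[str, str] = {"action": "mean_std"}  # mutable default
except ValueError as err:
    print(err)  # mutable default <class 'dict'> for field modes is not allowed: use default_factory

@dataclass
class Good:
    modes: dict[str, str] = field(default_factory=lambda: {"action": "mean_std"})

assert Good().modes is not Good().modes  # each instance now gets its own dict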

View File

@@ -22,6 +22,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
 from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
 from lerobot.common.policies.utils import (
     normalize_inputs,
+    to_buffer_dict,
     unnormalize_outputs,
 )
@@ -75,7 +76,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if cfg is None:
             cfg = ActionChunkingTransformerConfig()
         self.cfg = cfg
-        self.register_buffer("dataset_stats", dataset_stats)
+        self.dataset_stats = to_buffer_dict(dataset_stats)
         self.normalize_input_modes = cfg.normalize_input_modes
         self.unnormalize_output_modes = cfg.unnormalize_output_modes
@@ -179,7 +180,12 @@ class ActionChunkingTransformerPolicy(nn.Module):
             # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
             # has shape (n_action_steps, batch_size, *), hence the transpose.
             actions = self._forward(batch)[0][: self.cfg.n_action_steps]
-            actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+
+            # TODO(rcadene): make _forward return output dictionary?
+            out_dict = {"action": actions}
+            out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
+            actions = out_dict["action"]
+
             self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()
@@ -214,7 +220,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
         loss_dict = self.forward(batch)
-        # TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+        # TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
         loss = loss_dict["loss"]
         loss.backward()
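
Note: `unnormalize_outputs` now takes and returns a dict of tensors keyed like `unnormalize_output_modes`, instead of a bare actions tensor. A rough sketch of the helper's shape for the `mean_std` mode (simplified; the real function in `lerobot/common/policies/utils.py` also handles `min_max` and, as its diff below shows, raises `ValueError` on unknown modes):

def unnormalize_outputs(batch, stats, unnormalize_output_modes):
    # e.g. batch = {"action": tensor}, unnormalize_output_modes = {"action": "mean_std"}
    for key, mode in unnormalize_output_modes.items():
        if mode == "mean_std":
            batch[key] = batch[key] * stats[key]["std"] + stats[key]["mean"]
        else:
            raise ValueError(mode)
    return batch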

View File

@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field


 @dataclass
@@ -70,13 +70,17 @@ class DiffusionConfig:
     n_action_steps: int = 8

     # Normalization / Unnormalization
-    normalize_input_modes: dict[str, str] = {
-        "observation.image": "mean_std",
-        "observation.state": "min_max",
-    }
-    unnormalize_output_modes: dict[str, str] = {
-        "action": "min_max",
-    }
+    normalize_input_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "observation.image": "mean_std",
+            "observation.state": "min_max",
+        }
+    )
+    unnormalize_output_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "action": "min_max",
+        }
+    )

     # Architecture / modeling.
     # Vision backbone.

View File

@@ -31,6 +31,7 @@ from lerobot.common.policies.utils import (
     get_dtype_from_parameters,
     normalize_inputs,
     populate_queues,
+    to_buffer_dict,
     unnormalize_outputs,
 )
@@ -57,7 +58,7 @@ class DiffusionPolicy(nn.Module):
         if cfg is None:
             cfg = DiffusionConfig()
         self.cfg = cfg
-        self.register_buffer("dataset_stats", dataset_stats)
+        self.dataset_stats = to_buffer_dict(dataset_stats)
         self.normalize_input_modes = cfg.normalize_input_modes
         self.unnormalize_output_modes = cfg.unnormalize_output_modes
@@ -144,7 +145,11 @@ class DiffusionPolicy(nn.Module):
             else:
                 actions = self.diffusion.generate_actions(batch)

-            actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+            # TODO(rcadene): make above methods return output dictionary?
+            out_dict = {"action": actions}
+            out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
+            actions = out_dict["action"]

             self._queues["action"].extend(actions.transpose(0, 1))

         action = self._queues["action"].popleft()
@@ -166,7 +171,7 @@ class DiffusionPolicy(nn.Module):
         loss = self.forward(batch)["loss"]
         loss.backward()

-        # TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+        # TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)

         grad_norm = torch.nn.utils.clip_grad_norm_(
             self.diffusion.parameters(),

View File

@@ -66,3 +66,20 @@ def unnormalize_outputs(batch, stats, unnormalize_output_modes):
         else:
             raise ValueError(mode)
     return batch
+
+
+def to_buffer_dict(dataset_stats):
+    # TODO(rcadene): replace this function by `torch.BufferDict` when it exists
+    # see: https://github.com/pytorch/pytorch/issues/37386
+    # TODO(rcadene): make `to_buffer_dict` generic and add docstring
+    if dataset_stats is None:
+        return None
+    new_ds_stats = {}
+    for key, stats_dict in dataset_stats.items():
+        new_stats_dict = {}
+        for stats_type, value in stats_dict.items():
+            # set requires_grad=False to have the same behavior as a nn.Buffer
+            new_stats_dict[stats_type] = nn.Parameter(value, requires_grad=False)
+        new_ds_stats[key] = nn.ParameterDict(new_stats_dict)
+    return nn.ParameterDict(new_ds_stats)
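
Note: nesting `nn.ParameterDict`s (with `requires_grad=False`) instead of keeping a plain Python dict means the stats are registered on the module, so they follow `.to(device)` and dtype casts and land in `state_dict()` checkpoints. A standalone sketch of that behavior (the `Policy` module and stats values are illustrative, not from this repo):

import torch
from torch import nn

class Policy(nn.Module):
    def __init__(self):
        super().__init__()
        stats = {"action": {"mean": torch.zeros(2), "std": torch.ones(2)}}
        self.dataset_stats = nn.ParameterDict(
            {k: nn.ParameterDict({t: nn.Parameter(v, requires_grad=False) for t, v in d.items()})
             for k, d in stats.items()}
        )

policy = Policy()
print("dataset_stats.action.mean" in policy.state_dict())  # True: stats are checkpointed
# policy.to("cuda") would also move the stats along, unlike a plain dict of tensors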

View File

@@ -12,7 +12,8 @@ n_obs_steps: 1
 # when temporal_agg=False, n_action_steps=horizon

 override_dataset_stats:
-  observation.image:
+  observation.images.top:
+    # stats from imagenet, since we use a pretrained vision model
     mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
     std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
@@ -35,7 +36,7 @@ policy:
   # Normalization / Unnormalization
   normalize_input_modes:
-    observation.image: mean_std
+    observation.images.top: mean_std
     observation.state: mean_std
   unnormalize_output_modes:
     action: mean_std
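
Note: the `(c,1,1)` shapes in the YAML are chosen so the per-channel ImageNet statistics broadcast directly over `(c, h, w)` image tensors without any reshaping in the policy code. A quick standalone check of that broadcasting:

import torch

mean = torch.tensor([[[0.485]], [[0.456]], [[0.406]]])  # (3, 1, 1)
std = torch.tensor([[[0.229]], [[0.224]], [[0.225]]])   # (3, 1, 1)
img = torch.rand(3, 96, 96)  # float image already scaled to [0, 1]
normalized = (img - mean) / std  # per-channel, result keeps shape (3, 96, 96)
print(normalized.shape)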

View File

@@ -19,9 +19,12 @@ online_steps: 0
 offline_prioritized_sampler: true

 override_dataset_stats:
+  # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
   observation.image:
     mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
     std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
+  # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
+  # from the original codebase, but we should remove these and train our own pretrained model
   observation.state:
     min: [13.456424, 32.938293]
     max: [496.14618, 510.9579]

View File

@@ -50,11 +50,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
     log_output_dir(out_dir)

     logging.info("make_dataset")
-    dataset = make_dataset(
-        cfg,
-        # remove all transformations such as rescale images from [0,255] to [0,1] or normalization
-        normalize=False,
-    )
+    dataset = make_dataset(cfg)

     logging.info("Start rendering episodes from offline buffer")
     video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)

View File

@@ -6,7 +6,6 @@ import torch
 from gymnasium.utils.env_checker import check_env

 import lerobot
-from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import preprocess_observation
 from lerobot.common.utils.utils import init_hydra_config
@@ -38,12 +37,14 @@ def test_factory(env_name):
         overrides=[f"env={env_name}", f"device={DEVICE}"],
     )

-    dataset = make_dataset(cfg)

     env = make_env(cfg, num_parallel_envs=1)
     obs, _ = env.reset()
     obs = preprocess_observation(obs)

-    for key in dataset.image_keys:
+    # test image keys are float32 in range [0,1]
+    for key in obs:
+        if "image" not in key:
+            continue
         img = obs[key]
         assert img.dtype == torch.float32
         # TODO(rcadene): we assume for now that image normalization takes place in the model
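
Note: the loop checks the dtype but not the [0,1] range that `preprocess_observation` promises. A hedged sketch of an extra assertion one could add inside the same loop (not part of this commit):

assert img.min() >= 0.0 and img.max() <= 1.0, f"{key} is outside the expected [0, 1] range"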