make unit tests pass

parent 42ed7bb670
commit 0660f71556
@@ -44,7 +44,7 @@ from datasets import load_dataset
 # TODO(rcadene): list available datasets on lerobot page using `datasets`
 
 # download/load hugging face dataset in pyarrow format
-hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
+hf_dataset, fps = load_dataset("lerobot/pusht", split="train", revision="v1.1"), 10
 
 # display name of dataset and its features
 # TODO(rcadene): update to make the print pretty
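The only change here pins the dataset to a tagged revision so the example (and the tests that exercise it) stay reproducible when the hub dataset gains new commits. A minimal sketch of the same call, assuming network access and that the `lerobot/pusht` dataset repository carries a `v1.1` tag:

import torch  # noqa: F401  (the example scripts use torch downstream)
from datasets import load_dataset

# Without `revision`, load_dataset resolves the repo's default branch, so the
# data can silently change between runs; pinning a tag keeps tests stable.
hf_dataset = load_dataset("lerobot/pusht", split="train", revision="v1.1")
print(hf_dataset.features)  # column names and feature types of the pyarrow table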
@@ -34,7 +34,7 @@ dataset = make_dataset(hydra_cfg)
 # If you're doing something different, you will likely need to change at least some of the defaults.
 cfg = DiffusionConfig()
 # TODO(alexander-soare): Remove LR scheduler from the policy.
-policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
+policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
 policy.train()
 policy.to(device)
 
@@ -2,6 +2,7 @@ import os
 from pathlib import Path
 
 import torch
+from omegaconf import OmegaConf
 
 DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
 
@@ -43,7 +44,10 @@ def make_dataset(
     )
 
     if cfg.get("override_dataset_stats"):
-        for key, val in cfg.override_dataset_stats.items():
-            dataset.stats[key] = torch.tensor(val)
+        for key, stats_dict in cfg.override_dataset_stats.items():
+            for stats_type, listconfig in stats_dict.items():
+                # example of stats_type: min, max, mean, std
+                stats = OmegaConf.to_container(listconfig, resolve=True)
+                dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
 
     return dataset
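The loop now writes per-statistic tensors (min, max, mean, std) into the existing stats entry instead of replacing the whole entry. Note that Hydra hands override values over as OmegaConf ListConfig objects, which `torch.tensor` is not guaranteed to accept; resolving them into plain Python lists first is the point of the `to_container` call. A standalone sketch of that conversion:

import torch
from omegaconf import OmegaConf

# A "mean" override as it arrives from the yaml config: a nested ListConfig.
listconfig = OmegaConf.create([[[0.485]], [[0.456]], [[0.406]]])
stats = OmegaConf.to_container(listconfig, resolve=True)  # plain nested lists
tensor = torch.tensor(stats, dtype=torch.float32)
print(tensor.shape)  # torch.Size([3, 1, 1]), broadcastable over (c, h, w) images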
@@ -22,7 +22,7 @@ def preprocess_observation(observation):
         assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
 
         # convert to channel first of type float32 in range [0,1]
-        img = einops.rearrange(img, "b h w c -> b c h w")
+        img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
         img = img.type(torch.float32)
         img /= 255
 
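"b h w c -> b c h w" is a pure permutation, so einops returns a non-contiguous view; the added `.contiguous()` materializes the tensor into standard memory layout, which some downstream ops (e.g. `Tensor.view`) require. A quick illustration:

import einops
import torch

img = torch.zeros(1, 96, 96, 3, dtype=torch.uint8)  # channel-last, as from gym
chw = einops.rearrange(img, "b h w c -> b c h w")    # a permuted view
assert not chw.is_contiguous()                       # strides are out of order
assert chw.contiguous().is_contiguous()              # copied into dense layout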
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 
 @dataclass
@@ -61,13 +61,17 @@ class ActionChunkingTransformerConfig:
     n_action_steps: int = 100
 
     # Normalization / Unnormalization
-    normalize_input_modes: dict[str, str] = {
-        "observation.image": "mean_std",
-        "observation.state": "mean_std",
-    }
-    unnormalize_output_modes: dict[str, str] = {
-        "action": "mean_std",
-    }
+    normalize_input_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "observation.image": "mean_std",
+            "observation.state": "mean_std",
+        }
+    )
+    unnormalize_output_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "action": "mean_std",
+        }
+    )
     # Architecture.
     # Vision backbone.
     vision_backbone: str = "resnet18"
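This is the substance of the "make unit tests pass" fix for the config classes: a bare mutable default on a dataclass field raises ValueError when the @dataclass decorator runs, i.e. as soon as the module is imported, so the tests could not even construct the config. `field(default_factory=...)` is the standard remedy, and it also gives each instance its own dict. A minimal sketch:

from dataclasses import dataclass, field

@dataclass
class Config:
    # A bare `modes: dict[str, str] = {...}` would raise
    # "mutable default ... is not allowed: use default_factory" at import time.
    modes: dict[str, str] = field(default_factory=lambda: {"action": "mean_std"})

a, b = Config(), Config()
a.modes["action"] = "min_max"
assert b.modes["action"] == "mean_std"  # instances no longer share state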
@@ -22,6 +22,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
 from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
 from lerobot.common.policies.utils import (
     normalize_inputs,
+    to_buffer_dict,
     unnormalize_outputs,
 )
 
@@ -75,7 +76,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if cfg is None:
             cfg = ActionChunkingTransformerConfig()
         self.cfg = cfg
-        self.register_buffer("dataset_stats", dataset_stats)
+        self.dataset_stats = to_buffer_dict(dataset_stats)
         self.normalize_input_modes = cfg.normalize_input_modes
         self.unnormalize_output_modes = cfg.unnormalize_output_modes
 
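For context on the bug being fixed: `register_buffer` only accepts a Tensor (or None), so handing it the nested dict of dataset statistics fails as soon as the policy is constructed. A sketch reproducing the failure, assuming PyTorch's standard behavior:

import torch
from torch import nn

module = nn.Module()
try:
    # Buffers must be Tensors; a dict of {key: {stat: Tensor}} is rejected.
    module.register_buffer("dataset_stats", {"action": {"mean": torch.zeros(2)}})
except TypeError as err:
    print(err)  # cannot assign 'dict' object to buffer (Tensor or None required)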
@@ -179,7 +180,12 @@ class ActionChunkingTransformerPolicy(nn.Module):
             # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
             # has shape (n_action_steps, batch_size, *), hence the transpose.
             actions = self._forward(batch)[0][: self.cfg.n_action_steps]
-            actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+
+            # TODO(rcadene): make _forward return output dictionary?
+            out_dict = {"action": actions}
+            out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
+            actions = out_dict["action"]
 
             self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()
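`unnormalize_outputs` works on a dictionary of named outputs, keyed the same way as `unnormalize_output_modes`, which is why the raw action tensor is wrapped in `out_dict` first. A hedged sketch of that contract (the real helper lives in `lerobot.common.policies.utils`; this toy version only handles `mean_std`):

import torch

def unnormalize_outputs_sketch(batch, stats, modes):
    # Map each named output back to its original scale, in place.
    for key, mode in modes.items():
        if mode == "mean_std":
            batch[key] = batch[key] * stats[key]["std"] + stats[key]["mean"]
        else:
            raise ValueError(mode)
    return batch

stats = {"action": {"mean": torch.zeros(2), "std": 2 * torch.ones(2)}}
out = unnormalize_outputs_sketch({"action": torch.ones(1, 2)}, stats, {"action": "mean_std"})
assert torch.equal(out["action"], 2 * torch.ones(1, 2))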
@@ -214,7 +220,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
         batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
         loss_dict = self.forward(batch)
-        # TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+        # TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
         loss = loss_dict["loss"]
         loss.backward()
 
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 
 @dataclass
@@ -70,13 +70,17 @@ class DiffusionConfig:
     n_action_steps: int = 8
 
     # Normalization / Unnormalization
-    normalize_input_modes: dict[str, str] = {
-        "observation.image": "mean_std",
-        "observation.state": "min_max",
-    }
-    unnormalize_output_modes: dict[str, str] = {
-        "action": "min_max",
-    }
+    normalize_input_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "observation.image": "mean_std",
+            "observation.state": "min_max",
+        }
+    )
+    unnormalize_output_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "action": "min_max",
+        }
+    )
 
     # Architecture / modeling.
     # Vision backbone.
@@ -31,6 +31,7 @@ from lerobot.common.policies.utils import (
     get_dtype_from_parameters,
     normalize_inputs,
     populate_queues,
+    to_buffer_dict,
     unnormalize_outputs,
 )
 
@@ -57,7 +58,7 @@ class DiffusionPolicy(nn.Module):
         if cfg is None:
             cfg = DiffusionConfig()
         self.cfg = cfg
-        self.register_buffer("dataset_stats", dataset_stats)
+        self.dataset_stats = to_buffer_dict(dataset_stats)
         self.normalize_input_modes = cfg.normalize_input_modes
         self.unnormalize_output_modes = cfg.unnormalize_output_modes
 
@@ -144,7 +145,11 @@ class DiffusionPolicy(nn.Module):
             else:
                 actions = self.diffusion.generate_actions(batch)
-            actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+
+            # TODO(rcadene): make above methods return output dictionary?
+            out_dict = {"action": actions}
+            out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
+            actions = out_dict["action"]
 
             self._queues["action"].extend(actions.transpose(0, 1))
 
         action = self._queues["action"].popleft()
@@ -166,7 +171,7 @@ class DiffusionPolicy(nn.Module):
         loss = self.forward(batch)["loss"]
         loss.backward()
 
-        # TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+        # TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
             self.diffusion.parameters(),
@@ -66,3 +66,20 @@ def unnormalize_outputs(batch, stats, unnormalize_output_modes):
         else:
             raise ValueError(mode)
     return batch
+
+
+def to_buffer_dict(dataset_stats):
+    # TODO(rcadene): replace this function by `torch.BufferDict` when it exists
+    # see: https://github.com/pytorch/pytorch/issues/37386
+    # TODO(rcadene): make `to_buffer_dict` generic and add docstring
+    if dataset_stats is None:
+        return None
+
+    new_ds_stats = {}
+    for key, stats_dict in dataset_stats.items():
+        new_stats_dict = {}
+        for stats_type, value in stats_dict.items():
+            # set requires_grad=False to have the same behavior as a nn.Buffer
+            new_stats_dict[stats_type] = nn.Parameter(value, requires_grad=False)
+        new_ds_stats[key] = nn.ParameterDict(new_stats_dict)
+    return nn.ParameterDict(new_ds_stats)
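The `nn.Parameter(..., requires_grad=False)` entries inside `nn.ParameterDict` stand in for buffers: the stats then follow the module through `.to(device)` and show up in `state_dict()`, but never receive gradients. A usage sketch under those assumptions (dot-free keys are used here to keep the parameter names simple):

import torch
from torch import nn

class Policy(nn.Module):
    def __init__(self, stats):
        super().__init__()
        # Equivalent to to_buffer_dict(stats) above, inlined for the example.
        self.dataset_stats = nn.ParameterDict(
            {
                key: nn.ParameterDict(
                    {k: nn.Parameter(v, requires_grad=False) for k, v in d.items()}
                )
                for key, d in stats.items()
            }
        )

policy = Policy({"action": {"mean": torch.zeros(2), "std": torch.ones(2)}})
policy.to(torch.device("cpu"))     # stats tensors move with the module
print(sorted(policy.state_dict()))  # ['dataset_stats.action.mean', 'dataset_stats.action.std']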
@@ -12,7 +12,8 @@ n_obs_steps: 1
 # when temporal_agg=False, n_action_steps=horizon
 
 override_dataset_stats:
-  observation.image:
+  observation.images.top:
+    # stats from imagenet, since we use a pretrained vision model
     mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
     std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
 
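The (c,1,1) shape noted in the comments is what lets per-channel stats broadcast over (b,c,h,w) image batches during mean_std normalization. An illustrative sketch (the variable names here are not the repo's API):

import torch

mean = torch.tensor([[[0.485]], [[0.456]], [[0.406]]])  # (3,1,1) imagenet mean
std = torch.tensor([[[0.229]], [[0.224]], [[0.225]]])   # (3,1,1) imagenet std
images = torch.rand(8, 3, 96, 96)                       # (b,c,h,w) in [0,1]
normalized = (images - mean) / std                      # broadcasts over b, h, w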
@@ -35,7 +36,7 @@ policy:
 
   # Normalization / Unnormalization
   normalize_input_modes:
-    observation.image: mean_std
+    observation.images.top: mean_std
     observation.state: mean_std
   unnormalize_output_modes:
     action: mean_std
@@ -19,9 +19,12 @@ online_steps: 0
 offline_prioritized_sampler: true
 
 override_dataset_stats:
+  # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
   observation.image:
     mean: [[[0.5]], [[0.5]], [[0.5]]]  # (c,1,1)
     std: [[[0.5]], [[0.5]], [[0.5]]]  # (c,1,1)
+  # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
+  # from the original codebase, but we should remove these and train our own pretrained model
   observation.state:
     min: [13.456424, 32.938293]
     max: [496.14618, 510.9579]
@@ -50,11 +50,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
     log_output_dir(out_dir)
 
     logging.info("make_dataset")
-    dataset = make_dataset(
-        cfg,
-        # remove all transformations such as rescale images from [0,255] to [0,1] or normalization
-        normalize=False,
-    )
+    dataset = make_dataset(cfg)
 
     logging.info("Start rendering episodes from offline buffer")
     video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
@@ -6,7 +6,6 @@ import torch
 from gymnasium.utils.env_checker import check_env
 
 import lerobot
-from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import preprocess_observation
 from lerobot.common.utils.utils import init_hydra_config
@@ -38,12 +37,14 @@ def test_factory(env_name):
         overrides=[f"env={env_name}", f"device={DEVICE}"],
     )
 
-    dataset = make_dataset(cfg)
-
     env = make_env(cfg, num_parallel_envs=1)
     obs, _ = env.reset()
     obs = preprocess_observation(obs)
-    for key in dataset.image_keys:
+
+    # test image keys are float32 in range [0,1]
+    for key in obs:
+        if "image" not in key:
+            continue
         img = obs[key]
         assert img.dtype == torch.float32
         # TODO(rcadene): we assume for now that image normalization takes place in the model
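The rewritten test no longer needs `make_dataset`; it scans the preprocessed observation itself for image-like keys. The invariant it now checks, sketched standalone with made-up observation contents:

import torch

obs = {
    "observation.images.top": torch.rand(1, 3, 96, 96),
    "observation.state": torch.zeros(1, 2),
}
for key in obs:
    if "image" not in key:
        continue
    img = obs[key]
    # preprocess_observation converts uint8 [0,255] images to float32 [0,1]
    assert img.dtype == torch.float32
    assert img.min() >= 0.0 and img.max() <= 1.0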