Add online training with TD-MPC as proof of concept (#338)
This commit is contained in:
parent
abbb1d2367
commit
f8a6574698
24
Makefile
24
Makefile
|
@ -26,6 +26,7 @@ test-end-to-end:
|
|||
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train-with-online
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-default-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-pusht-tutorial
|
||||
|
@ -113,7 +114,6 @@ test-diffusion-ete-eval:
|
|||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
|
||||
# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
|
||||
test-tdmpc-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=tdmpc \
|
||||
|
@ -133,6 +133,28 @@ test-tdmpc-ete-train:
|
|||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/tdmpc/
|
||||
|
||||
test-tdmpc-ete-train-with-online:
|
||||
python lerobot/scripts/train.py \
|
||||
env=pusht \
|
||||
env.gym.obs_type=environment_state_agent_pos \
|
||||
policy=tdmpc_pusht_keypoints \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=10 \
|
||||
device=$(DEVICE) \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=20 \
|
||||
training.save_checkpoint=false \
|
||||
training.save_freq=10 \
|
||||
training.batch_size=2 \
|
||||
training.online_rollout_n_episodes=2 \
|
||||
training.online_rollout_batch_size=2 \
|
||||
training.online_steps_between_rollouts=10 \
|
||||
training.online_buffer_capacity=15 \
|
||||
eval.use_async_envs=true \
|
||||
hydra.run.dir=tests/outputs/tdmpc_online/
|
||||
|
||||
|
||||
test-tdmpc-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
-p tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||
|
|
|
@ -0,0 +1,384 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""An online buffer for the online training loop in train.py
|
||||
|
||||
Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should
|
||||
consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much
|
||||
faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it
|
||||
supports in-place slicing and mutation which is very handy for a dynamic buffer.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def _make_memmap_safe(**kwargs) -> np.memmap:
|
||||
"""Make a numpy memmap with checks on available disk space first.
|
||||
|
||||
Expected kwargs are: "filename", "dtype" (must by np.dtype), "mode" and "shape"
|
||||
|
||||
For information on dtypes:
|
||||
https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
|
||||
"""
|
||||
if kwargs["mode"].startswith("w"):
|
||||
required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"]) # bytes
|
||||
stats = os.statvfs(Path(kwargs["filename"]).parent)
|
||||
available_space = stats.f_bavail * stats.f_frsize # bytes
|
||||
if required_space >= available_space * 0.8:
|
||||
raise RuntimeError(
|
||||
f"You're about to take up {required_space} of {available_space} bytes available."
|
||||
)
|
||||
return np.memmap(**kwargs)
|
||||
|
||||
|
||||
class OnlineBuffer(torch.utils.data.Dataset):
|
||||
"""FIFO data buffer for the online training loop in train.py.
|
||||
|
||||
Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training
|
||||
loop in the same way that a LeRobotDataset would be used.
|
||||
|
||||
The underlying data structure will have data inserted in a circular fashion. Always insert after the
|
||||
last index, and when you reach the end, wrap around to the start.
|
||||
|
||||
The data is stored in a numpy memmap.
|
||||
"""
|
||||
|
||||
NEXT_INDEX_KEY = "_next_index"
|
||||
OCCUPANCY_MASK_KEY = "_occupancy_mask"
|
||||
INDEX_KEY = "index"
|
||||
FRAME_INDEX_KEY = "frame_index"
|
||||
EPISODE_INDEX_KEY = "episode_index"
|
||||
TIMESTAMP_KEY = "timestamp"
|
||||
IS_PAD_POSTFIX = "_is_pad"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
write_dir: str | Path,
|
||||
data_spec: dict[str, Any] | None,
|
||||
buffer_capacity: int | None,
|
||||
fps: float | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
|
||||
):
|
||||
"""
|
||||
The online buffer can be provided from scratch or you can load an existing online buffer by passing
|
||||
a `write_dir` associated with an existing buffer.
|
||||
|
||||
Args:
|
||||
write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key.
|
||||
Note that if the files already exist, they are opened in read-write mode (used for training
|
||||
resumption.)
|
||||
data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int],
|
||||
"dtype": np.dtype}}. This should include all the data that you wish to record into the buffer,
|
||||
but note that "index", "frame_index" and "episode_index" are already accounted for by this
|
||||
class, so you don't need to include them.
|
||||
buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your
|
||||
system's available disk space when choosing this.
|
||||
fps: Same as the fps concept in LeRobot dataset. Here it needs to be provided for the
|
||||
delta_timestamps logic. You can pass None if you are not using delta_timestamps.
|
||||
delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally
|
||||
converted to dict[str, np.ndarray] for optimization purposes.
|
||||
|
||||
"""
|
||||
self.set_delta_timestamps(delta_timestamps)
|
||||
self._fps = fps
|
||||
# Tolerance in seconds used to discard loaded frames when their timestamps are not close enough from
|
||||
# the requested frames. It is only used when `delta_timestamps` is provided.
|
||||
# minus 1e-4 to account for possible numerical error
|
||||
self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None
|
||||
self._buffer_capacity = buffer_capacity
|
||||
data_spec = self._make_data_spec(data_spec, buffer_capacity)
|
||||
Path(write_dir).mkdir(parents=True, exist_ok=True)
|
||||
self._data = {}
|
||||
for k, v in data_spec.items():
|
||||
self._data[k] = _make_memmap_safe(
|
||||
filename=Path(write_dir) / k,
|
||||
dtype=v["dtype"] if v is not None else None,
|
||||
mode="r+" if (Path(write_dir) / k).exists() else "w+",
|
||||
shape=tuple(v["shape"]) if v is not None else None,
|
||||
)
|
||||
|
||||
@property
|
||||
def delta_timestamps(self) -> dict[str, np.ndarray] | None:
|
||||
return self._delta_timestamps
|
||||
|
||||
def set_delta_timestamps(self, value: dict[str, list[float]] | None):
|
||||
"""Set delta_timestamps converting the values to numpy arrays.
|
||||
|
||||
The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays
|
||||
need to be converted into numpy arrays.
|
||||
"""
|
||||
if value is not None:
|
||||
self._delta_timestamps = {k: np.array(v) for k, v in value.items()}
|
||||
else:
|
||||
self._delta_timestamps = None
|
||||
|
||||
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
|
||||
"""Makes the data spec for np.memmap."""
|
||||
if any(k.startswith("_") for k in data_spec):
|
||||
raise ValueError(
|
||||
"data_spec keys should not start with '_'. This prefix is reserved for internal logic."
|
||||
)
|
||||
preset_keys = {
|
||||
OnlineBuffer.INDEX_KEY,
|
||||
OnlineBuffer.FRAME_INDEX_KEY,
|
||||
OnlineBuffer.EPISODE_INDEX_KEY,
|
||||
OnlineBuffer.TIMESTAMP_KEY,
|
||||
}
|
||||
if len(intersection := set(data_spec).intersection(preset_keys)) > 0:
|
||||
raise ValueError(
|
||||
f"data_spec should not contain any of {preset_keys} as these are handled internally. "
|
||||
f"The provided data_spec has {intersection}."
|
||||
)
|
||||
complete_data_spec = {
|
||||
# _next_index will be a pointer to the next index that we should start filling from when we add
|
||||
# more data.
|
||||
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
|
||||
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
|
||||
# with real data rather than the dummy initialization.
|
||||
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
|
||||
}
|
||||
for k, v in data_spec.items():
|
||||
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
|
||||
return complete_data_spec
|
||||
|
||||
def add_data(self, data: dict[str, np.ndarray]):
|
||||
"""Add new data to the buffer, which could potentially mean shifting old data out.
|
||||
|
||||
The new data should contain all the frames (in order) of any number of episodes. The indices should
|
||||
start from 0 (note to the developer: this can easily be generalized). See the `rollout` and
|
||||
`eval_policy` functions in `eval.py` for more information on how the data is constructed.
|
||||
|
||||
Shift the incoming data index and episode_index to continue on from the last frame. Note that this
|
||||
will be done in place!
|
||||
"""
|
||||
if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0:
|
||||
raise ValueError(f"Missing data keys: {missing_keys}")
|
||||
new_data_length = len(data[self.data_keys[0]])
|
||||
if not all(len(data[k]) == new_data_length for k in self.data_keys):
|
||||
raise ValueError("All data items should have the same length")
|
||||
|
||||
next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY]
|
||||
|
||||
# Sanity check to make sure that the new data indices start from 0.
|
||||
assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0
|
||||
assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
|
||||
|
||||
# Shift the incoming indices if necessary.
|
||||
if self.num_samples > 0:
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
|
||||
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
|
||||
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
|
||||
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
|
||||
|
||||
# Insert the new data starting from next_index. It may be necessary to wrap around to the start.
|
||||
n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index))
|
||||
for k in self.data_keys:
|
||||
if n_surplus == 0:
|
||||
slc = slice(next_index, next_index + new_data_length)
|
||||
self._data[k][slc] = data[k]
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True
|
||||
else:
|
||||
self._data[k][next_index:] = data[k][:-n_surplus]
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True
|
||||
self._data[k][:n_surplus] = data[k][-n_surplus:]
|
||||
if n_surplus == 0:
|
||||
self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length
|
||||
else:
|
||||
self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus
|
||||
|
||||
@property
|
||||
def data_keys(self) -> list[str]:
|
||||
keys = set(self._data)
|
||||
keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY)
|
||||
keys.remove(OnlineBuffer.NEXT_INDEX_KEY)
|
||||
return sorted(keys)
|
||||
|
||||
@property
|
||||
def fps(self) -> float | None:
|
||||
return self._fps
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(
|
||||
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def _item_to_tensors(self, item: dict) -> dict:
|
||||
item_ = {}
|
||||
for k, v in item.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
item_[k] = v
|
||||
elif isinstance(v, np.ndarray):
|
||||
item_[k] = torch.from_numpy(v)
|
||||
else:
|
||||
item_[k] = torch.tensor(v)
|
||||
return item_
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
if idx >= len(self) or idx < -len(self):
|
||||
raise IndexError
|
||||
|
||||
item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")}
|
||||
|
||||
if self.delta_timestamps is None:
|
||||
return self._item_to_tensors(item)
|
||||
|
||||
episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY]
|
||||
current_ts = item[OnlineBuffer.TIMESTAMP_KEY]
|
||||
episode_data_indices = np.where(
|
||||
np.bitwise_and(
|
||||
self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index,
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
|
||||
)
|
||||
)[0]
|
||||
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
|
||||
|
||||
for data_key in self.delta_timestamps:
|
||||
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
|
||||
# Get timestamps used as query to retrieve data of previous/future frames.
|
||||
query_ts = current_ts + self.delta_timestamps[data_key]
|
||||
|
||||
# Compute distances between each query timestamp and all timestamps of all the frames belonging to
|
||||
# the episode.
|
||||
dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
|
||||
argmin_ = np.argmin(dist, axis=1)
|
||||
min_ = dist[np.arange(dist.shape[0]), argmin_]
|
||||
|
||||
is_pad = min_ > self.tolerance_s
|
||||
|
||||
# Check violated query timestamps are all outside the episode range.
|
||||
assert (
|
||||
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
|
||||
).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
|
||||
") inside the episode range."
|
||||
)
|
||||
|
||||
# Load frames for this data key.
|
||||
item[data_key] = self._data[data_key][episode_data_indices[argmin_]]
|
||||
|
||||
item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad
|
||||
|
||||
return self._item_to_tensors(item)
|
||||
|
||||
def get_data_by_key(self, key: str) -> torch.Tensor:
|
||||
"""Returns all data for a given data key as a Tensor."""
|
||||
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
|
||||
|
||||
def compute_sampler_weights(
|
||||
offline_dataset: LeRobotDataset,
|
||||
offline_drop_n_last_frames: int = 0,
|
||||
online_dataset: OnlineBuffer | None = None,
|
||||
online_sampling_ratio: float | None = None,
|
||||
online_drop_n_last_frames: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Compute the sampling weights for the online training dataloader in train.py.
|
||||
|
||||
Args:
|
||||
offline_dataset: The LeRobotDataset used for offline pre-training.
|
||||
online_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
|
||||
online_dataset: The OnlineBuffer used in online training.
|
||||
online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
|
||||
online dataset is provided, this value must also be provided.
|
||||
online_drop_n_first_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
|
||||
dataset.
|
||||
Returns:
|
||||
Tensor of weights for [offline_dataset; online_dataset], normalized to 1.
|
||||
|
||||
Notes to maintainers:
|
||||
- This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach.
|
||||
- When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
|
||||
`EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature
|
||||
is the ability to turn shuffling off.
|
||||
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
|
||||
included here to avoid adding complexity.
|
||||
"""
|
||||
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
|
||||
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
|
||||
if (online_dataset is None) ^ (online_sampling_ratio is None):
|
||||
raise ValueError(
|
||||
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
|
||||
)
|
||||
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
|
||||
|
||||
weights = []
|
||||
|
||||
if len(offline_dataset) > 0:
|
||||
offline_data_mask_indices = []
|
||||
for start_index, end_index in zip(
|
||||
offline_dataset.episode_data_index["from"],
|
||||
offline_dataset.episode_data_index["to"],
|
||||
strict=True,
|
||||
):
|
||||
offline_data_mask_indices.extend(
|
||||
range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
|
||||
)
|
||||
offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
|
||||
offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
|
||||
weights.append(
|
||||
torch.full(
|
||||
size=(len(offline_dataset),),
|
||||
fill_value=offline_sampling_ratio / offline_data_mask.sum(),
|
||||
)
|
||||
* offline_data_mask
|
||||
)
|
||||
|
||||
if online_dataset is not None and len(online_dataset) > 0:
|
||||
online_data_mask_indices = []
|
||||
episode_indices = online_dataset.get_data_by_key("episode_index")
|
||||
for episode_idx in torch.unique(episode_indices):
|
||||
where_episode = torch.where(episode_indices == episode_idx)
|
||||
start_index = where_episode[0][0]
|
||||
end_index = where_episode[0][-1] + 1
|
||||
online_data_mask_indices.extend(
|
||||
range(start_index.item(), end_index.item() - online_drop_n_last_frames)
|
||||
)
|
||||
online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool)
|
||||
online_data_mask[torch.tensor(online_data_mask_indices)] = True
|
||||
weights.append(
|
||||
torch.full(
|
||||
size=(len(online_dataset),),
|
||||
fill_value=online_sampling_ratio / online_data_mask.sum(),
|
||||
)
|
||||
* online_data_mask
|
||||
)
|
||||
|
||||
weights = torch.cat(weights)
|
||||
|
||||
if weights.sum() == 0:
|
||||
weights += 1 / len(weights)
|
||||
else:
|
||||
weights /= weights.sum()
|
||||
|
||||
return weights
|
|
@ -25,12 +25,16 @@ class TDMPCConfig:
|
|||
camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift`.
|
||||
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
|
||||
|
||||
Args:
|
||||
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
|
||||
action repeats in Q-learning or ask your favorite chatbot)
|
||||
horizon: Horizon for model predictive control.
|
||||
n_action_steps: Number of action steps to take from the plan given by model predictive control. This
|
||||
is an alternative to using action repeats. If this is set to more than 1, then we require
|
||||
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
|
||||
approach of using multiple steps from the plan is not in the original implementation.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
|
@ -100,6 +104,7 @@ class TDMPCConfig:
|
|||
# Input / output structure.
|
||||
n_action_repeats: int = 2
|
||||
horizon: int = 5
|
||||
n_action_steps: int = 1
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
|
@ -158,17 +163,18 @@ class TDMPCConfig:
|
|||
"""Input validation (not exhaustive)."""
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
if len(image_keys) != 1:
|
||||
if len(image_keys) > 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
image_key = next(iter(image_keys))
|
||||
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(
|
||||
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
|
||||
f"{self.__class__.__name__} handles at most one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
if len(image_keys) > 0:
|
||||
image_key = next(iter(image_keys))
|
||||
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(
|
||||
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
|
||||
)
|
||||
if self.n_gaussian_samples <= 0:
|
||||
raise ValueError(
|
||||
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||
|
@ -179,3 +185,12 @@ class TDMPCConfig:
|
|||
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
|
||||
"information."
|
||||
)
|
||||
if self.n_action_steps > 1:
|
||||
if self.n_action_repeats != 1:
|
||||
raise ValueError(
|
||||
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
|
||||
)
|
||||
if not self.use_mpc:
|
||||
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
|
||||
if self.n_action_steps > self.horizon:
|
||||
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
|
||||
|
|
|
@ -19,14 +19,10 @@
|
|||
The comments in this code may sometimes refer to these references:
|
||||
TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955)
|
||||
FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029)
|
||||
|
||||
TODO(alexander-soare): Make rollout work for batch sizes larger than 1.
|
||||
TODO(alexander-soare): Use batch-first throughout.
|
||||
"""
|
||||
|
||||
# ruff: noqa: N806
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
@ -56,9 +52,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
process communication to use the xarm environment from FOWM. This is because our xarm
|
||||
environment uses newer dependencies and does not match the environment in FOWM. See
|
||||
https://github.com/huggingface/lerobot/pull/103 for implementation details.
|
||||
- We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
|
||||
- We have NOT checked that training on LeRobot reproduces the results from FOWM.
|
||||
- Nevertheless, we have verified that we can train TD-MPC for PushT. See
|
||||
`lerobot/configs/policy/tdmpc_pusht_keypoints.yaml`.
|
||||
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
|
||||
match our xarm environment.
|
||||
match our xarm environment.
|
||||
"""
|
||||
|
||||
name = "tdmpc"
|
||||
|
@ -74,22 +72,6 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
logging.warning(
|
||||
"""
|
||||
Please note several warnings for this policy.
|
||||
|
||||
- Evaluation of pretrained weights created with the original FOWM code
|
||||
(https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
|
||||
model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
|
||||
to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
|
||||
process communication to use the xarm environment from FOWM. This is because our xarm
|
||||
environment uses newer dependencies and does not match the environment in FOWM. See
|
||||
https://github.com/huggingface/lerobot/pull/103 for implementation details.
|
||||
- We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
|
||||
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
|
||||
match our xarm environment.
|
||||
"""
|
||||
)
|
||||
|
||||
if config is None:
|
||||
config = TDMPCConfig()
|
||||
|
@ -114,8 +96,14 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
assert len(image_keys) == 1
|
||||
self.input_image_key = image_keys[0]
|
||||
self._use_image = False
|
||||
self._use_env_state = False
|
||||
if len(image_keys) > 0:
|
||||
assert len(image_keys) == 1
|
||||
self._use_image = True
|
||||
self.input_image_key = image_keys[0]
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
|
||||
self.reset()
|
||||
|
||||
|
@ -125,10 +113,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
called on `env.reset()`
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.image": deque(maxlen=1),
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=self.config.n_action_repeats),
|
||||
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
|
||||
}
|
||||
if self._use_image:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
if self._use_env_state:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
||||
# CEM for the next step.
|
||||
self._prev_mean: torch.Tensor | None = None
|
||||
|
@ -137,8 +128,9 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
if self._use_image:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
|
@ -152,49 +144,57 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
batch[key] = batch[key][:, 0]
|
||||
|
||||
# NOTE: Order of observations matters here.
|
||||
z = self.model.encode({k: batch[k] for k in ["observation.image", "observation.state"]})
|
||||
if self.config.use_mpc:
|
||||
batch_size = batch["observation.image"].shape[0]
|
||||
# Batch processing is not handled in MPC mode, so process the batch in a loop.
|
||||
action = [] # will be a batch of actions for one step
|
||||
for i in range(batch_size):
|
||||
# Note: self.plan does not handle batches, hence the squeeze.
|
||||
action.append(self.plan(z[i]))
|
||||
action = torch.stack(action)
|
||||
encode_keys = []
|
||||
if self._use_image:
|
||||
encode_keys.append("observation.image")
|
||||
if self._use_env_state:
|
||||
encode_keys.append("observation.environment_state")
|
||||
encode_keys.append("observation.state")
|
||||
z = self.model.encode({k: batch[k] for k in encode_keys})
|
||||
if self.config.use_mpc: # noqa: SIM108
|
||||
actions = self.plan(z) # (horizon, batch, action_dim)
|
||||
else:
|
||||
# Plan with the policy (π) alone.
|
||||
action = self.model.pi(z)
|
||||
# Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
|
||||
# sequence dimension like in the MPC branch.
|
||||
actions = self.model.pi(z).unsqueeze(0)
|
||||
|
||||
self.unnormalize_outputs({"action": action})["action"]
|
||||
actions = torch.clamp(actions, -1, +1)
|
||||
|
||||
for _ in range(self.config.n_action_repeats):
|
||||
self._queues["action"].append(action)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.n_action_repeats > 1:
|
||||
for _ in range(self.config.n_action_repeats):
|
||||
self._queues["action"].append(actions[0])
|
||||
else:
|
||||
# Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action.
|
||||
self._queues["action"].extend(actions[: self.config.n_action_steps])
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
return torch.clamp(action, -1, 1)
|
||||
return action
|
||||
|
||||
@torch.no_grad()
|
||||
def plan(self, z: Tensor) -> Tensor:
|
||||
"""Plan next action using TD-MPC inference.
|
||||
"""Plan sequence of actions using TD-MPC inference.
|
||||
|
||||
Args:
|
||||
z: (latent_dim,) tensor for the initial state.
|
||||
z: (batch, latent_dim,) tensor for the initial state.
|
||||
Returns:
|
||||
(action_dim,) tensor for the next action.
|
||||
|
||||
TODO(alexander-soare) Extend this to be able to work with batches.
|
||||
(horizon, batch, action_dim,) tensor for the planned trajectory of actions.
|
||||
"""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch_size = z.shape[0]
|
||||
|
||||
# Sample Nπ trajectories from the policy.
|
||||
pi_actions = torch.empty(
|
||||
self.config.horizon,
|
||||
self.config.n_pi_samples,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
device=device,
|
||||
)
|
||||
if self.config.n_pi_samples > 0:
|
||||
_z = einops.repeat(z, "d -> n d", n=self.config.n_pi_samples)
|
||||
_z = einops.repeat(z, "b d -> n b d", n=self.config.n_pi_samples)
|
||||
for t in range(self.config.horizon):
|
||||
# Note: Adding a small amount of noise here doesn't hurt during inference and may even be
|
||||
# helpful for CEM.
|
||||
|
@ -203,12 +203,14 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
|
||||
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
|
||||
# trajectories.
|
||||
z = einops.repeat(z, "d -> n d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
|
||||
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
|
||||
|
||||
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
|
||||
# algorithm.
|
||||
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
||||
mean = torch.zeros(self.config.horizon, self.config.output_shapes["action"][0], device=device)
|
||||
mean = torch.zeros(
|
||||
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
|
||||
)
|
||||
# Maybe warm start CEM with the mean from the previous step.
|
||||
if self._prev_mean is not None:
|
||||
mean[:-1] = self._prev_mean[1:]
|
||||
|
@ -219,6 +221,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
std_normal_noise = torch.randn(
|
||||
self.config.horizon,
|
||||
self.config.n_gaussian_samples,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
device=std.device,
|
||||
)
|
||||
|
@ -227,21 +230,24 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
# Compute elite actions.
|
||||
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
|
||||
value = self.estimate_value(z, actions).nan_to_num_(0)
|
||||
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices
|
||||
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
||||
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
|
||||
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
|
||||
# (horizon, n_elites, batch, action_dim)
|
||||
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
|
||||
|
||||
# Update guassian PDF parameters to be the (weighted) mean and standard deviation of the elites.
|
||||
max_value = elite_value.max(0)[0]
|
||||
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
|
||||
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
|
||||
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
|
||||
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
|
||||
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
|
||||
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
|
||||
score /= score.sum()
|
||||
_mean = torch.sum(einops.rearrange(score, "n -> n 1") * elite_actions, dim=1)
|
||||
score /= score.sum(axis=0, keepdim=True)
|
||||
# (horizon, batch, action_dim)
|
||||
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
|
||||
_std = torch.sqrt(
|
||||
torch.sum(
|
||||
einops.rearrange(score, "n -> n 1")
|
||||
* (elite_actions - einops.rearrange(_mean, "h d -> h 1 d")) ** 2,
|
||||
einops.rearrange(score, "n b -> n b 1")
|
||||
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
|
@ -256,11 +262,9 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
|
||||
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
|
||||
# scores from the last iteration.
|
||||
actions = elite_actions[:, torch.multinomial(score, 1).item()]
|
||||
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
|
||||
|
||||
# Select only the first action
|
||||
action = actions[0]
|
||||
return action
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def estimate_value(self, z: Tensor, actions: Tensor):
|
||||
|
@ -312,13 +316,17 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
|
||||
return G
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss."""
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
"""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
if self._use_image:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
|
@ -328,12 +336,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
if batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
action = batch["action"] # (t, b)
|
||||
reward = batch["next.reward"] # (t,)
|
||||
action = batch["action"] # (t, b, action_dim)
|
||||
reward = batch["next.reward"] # (t, b)
|
||||
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
|
||||
# Apply random image augmentations.
|
||||
if self.config.max_random_shift_ratio > 0:
|
||||
if self._use_image and self.config.max_random_shift_ratio > 0:
|
||||
observations["observation.image"] = flatten_forward_unflatten(
|
||||
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
|
||||
observations["observation.image"],
|
||||
|
@ -345,7 +353,9 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
for k in observations:
|
||||
current_observation[k] = observations[k][0]
|
||||
next_observations[k] = observations[k][1:]
|
||||
horizon = next_observations["observation.image"].shape[0]
|
||||
horizon, batch_size = next_observations[
|
||||
"observation.image" if self._use_image else "observation.environment_state"
|
||||
].shape[:2]
|
||||
|
||||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
||||
|
@ -415,7 +425,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
q_value_loss = (
|
||||
(
|
||||
F.mse_loss(
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(
|
||||
q_preds_ensemble,
|
||||
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
|
||||
reduction="none",
|
||||
|
@ -464,10 +475,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
|
||||
# Calculate the MSE between the actions and the action predictions.
|
||||
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
|
||||
# gaussian) and sums over the action dimension. Computing the log probability amounts to multiplying
|
||||
# the MSE by 0.5 and adding a constant offset (the log(2*pi) term) . Here we drop the constant offset
|
||||
# as it doesn't change the optimization step, and we drop the 0.5 as we instead make a configuration
|
||||
# parameter for it (see below where we compute the total loss).
|
||||
# gaussian) and sums over the action dimension. Computing the (negative) log probability amounts to
|
||||
# multiplying the MSE by 0.5 and adding a constant offset (the log(2*pi)/2 term, times the action
|
||||
# dimension). Here we drop the constant offset as it doesn't change the optimization step, and we drop
|
||||
# the 0.5 as we instead make a configuration parameter for it (see below where we compute the total
|
||||
# loss).
|
||||
mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b)
|
||||
# NOTE: The original implementation does not take the sum over the temporal dimension like with the
|
||||
# other losses.
|
||||
|
@ -728,6 +740,16 @@ class TDMPCObservationEncoder(nn.Module):
|
|||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
|
||||
),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
|
@ -736,8 +758,11 @@ class TDMPCObservationEncoder(nn.Module):
|
|||
over all features.
|
||||
"""
|
||||
feat = []
|
||||
# NOTE: Order of observations matters here.
|
||||
if "observation.image" in self.config.input_shapes:
|
||||
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
|
|
@ -32,19 +32,54 @@ video_backend: pyav
|
|||
|
||||
training:
|
||||
offline_steps: ???
|
||||
# NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
|
||||
online_steps: ???
|
||||
online_steps_between_rollouts: ???
|
||||
online_sampling_ratio: 0.5
|
||||
# `online_env_seed` is used for environments for online training data rollouts.
|
||||
online_env_seed: ???
|
||||
|
||||
# Number of workers for the offline training dataloader.
|
||||
num_workers: 4
|
||||
|
||||
batch_size: ???
|
||||
|
||||
eval_freq: ???
|
||||
log_freq: 200
|
||||
save_checkpoint: true
|
||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||
save_freq: ???
|
||||
num_workers: 4
|
||||
batch_size: ???
|
||||
|
||||
# Online training. Note that the online training loop adopts most of the options above apart from the
|
||||
# dataloader options. Unless otherwise specified.
|
||||
# The online training look looks something like:
|
||||
#
|
||||
# for i in range(online_steps):
|
||||
# do_online_rollout_and_update_online_buffer()
|
||||
# for j in range(online_steps_between_rollouts):
|
||||
# batch = next(dataloader_with_offline_and_online_data)
|
||||
# loss = policy(batch)
|
||||
# loss.backward()
|
||||
# optimizer.step()
|
||||
#
|
||||
online_steps: ???
|
||||
# How many episodes to collect at once when we reach the online rollout part of the training loop.
|
||||
online_rollout_n_episodes: 1
|
||||
# The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
|
||||
# the policy. Ideally you should set this to by an even divisor or online_rollout_n_episodes.
|
||||
online_rollout_batch_size: 1
|
||||
# How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
|
||||
online_steps_between_rollouts: null
|
||||
# The proportion of online samples (vs offline samples) to include in the online training batches.
|
||||
online_sampling_ratio: 0.5
|
||||
# First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
|
||||
online_env_seed: null
|
||||
# Sets the maximum number of frames that are stored in the online buffer for online training. The buffer is
|
||||
# FIFO.
|
||||
online_buffer_capacity: null
|
||||
# The minimum number of frames to have in the online buffer before commencing online training.
|
||||
# If online_buffer_seed_size > online_rollout_n_episodes, the rollout will be run multiple times until the
|
||||
# seed size condition is satisfied.
|
||||
online_buffer_seed_size: 0
|
||||
# Whether to run the online rollouts asynchronously. This means we can run the online training steps in
|
||||
# parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training
|
||||
# + eval + environment rendering simultaneously.
|
||||
do_online_rollout_async: false
|
||||
|
||||
image_transforms:
|
||||
# These transforms are all using standard torchvision.transforms.v2
|
||||
# You can find out how these transformations affect images here:
|
||||
|
|
|
@ -9,7 +9,7 @@ env:
|
|||
state_dim: 4
|
||||
action_dim: 4
|
||||
fps: ${fps}
|
||||
episode_length: 25
|
||||
episode_length: 200
|
||||
gym:
|
||||
obs_type: pixels_agent_pos
|
||||
render_mode: rgb_array
|
||||
|
|
|
@ -4,19 +4,30 @@ seed: 1
|
|||
dataset_repo_id: lerobot/xarm_lift_medium
|
||||
|
||||
training:
|
||||
offline_steps: 25000
|
||||
# TODO(alexander-soare): uncomment when online training gets reinstated
|
||||
online_steps: 0 # 25000 not implemented yet
|
||||
eval_freq: 5000
|
||||
online_steps_between_rollouts: 1
|
||||
online_sampling_ratio: 0.5
|
||||
online_env_seed: 10000
|
||||
log_freq: 100
|
||||
offline_steps: 50000
|
||||
|
||||
num_workers: 4
|
||||
|
||||
batch_size: 256
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
eval_freq: 5000
|
||||
log_freq: 100
|
||||
|
||||
online_steps: 50000
|
||||
online_rollout_n_episodes: 1
|
||||
online_rollout_batch_size: 1
|
||||
# Note: in FOWM `online_steps_between_rollouts` is actually dynamically set to match exactly the length of
|
||||
# the last sampled episode.
|
||||
online_steps_between_rollouts: 50
|
||||
online_sampling_ratio: 0.5
|
||||
online_env_seed: 10000
|
||||
# FOWM Push uses 10000 for `online_buffer_capacity`. Given that their maximum episode length for this task
|
||||
# is 25, 10000 is approx 400 of their episodes worth. Since our episodes are about 8 times longer, we'll use
|
||||
# 80000.
|
||||
online_buffer_capacity: 80000
|
||||
|
||||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
|
@ -31,6 +42,7 @@ policy:
|
|||
# Input / output structure.
|
||||
n_action_repeats: 2
|
||||
horizon: 5
|
||||
n_action_steps: 1
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
# @package _global_
|
||||
|
||||
# Train with:
|
||||
#
|
||||
# python lerobot/scripts/train.py \
|
||||
# env=pusht \
|
||||
# env.gym.obs_type=environment_state_agent_pos \
|
||||
# policy=tdmpc_pusht_keypoints \
|
||||
# eval.batch_size=50 \
|
||||
# eval.n_episodes=50 \
|
||||
# eval.use_async_envs=true \
|
||||
# device=cuda \
|
||||
# use_amp=true
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: lerobot/pusht_keypoints
|
||||
|
||||
training:
|
||||
offline_steps: 0
|
||||
|
||||
# Offline training dataloader
|
||||
num_workers: 4
|
||||
|
||||
batch_size: 256
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
eval_freq: 10000
|
||||
log_freq: 500
|
||||
save_freq: 50000
|
||||
|
||||
online_steps: 1000000
|
||||
online_rollout_n_episodes: 10
|
||||
online_rollout_batch_size: 10
|
||||
online_steps_between_rollouts: 1000
|
||||
online_sampling_ratio: 1.0
|
||||
online_env_seed: 10000
|
||||
online_buffer_capacity: 40000
|
||||
online_buffer_seed_size: 0
|
||||
do_online_rollout_async: false
|
||||
|
||||
delta_timestamps:
|
||||
observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
|
||||
policy:
|
||||
name: tdmpc
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: 1
|
||||
horizon: 5
|
||||
n_action_steps: 5
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.environment_state: [16]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.environment_state: min_max
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
image_encoder_hidden_dim: 32
|
||||
state_encoder_hidden_dim: 256
|
||||
latent_dim: 50
|
||||
q_ensemble_size: 5
|
||||
mlp_dim: 512
|
||||
# Reinforcement learning.
|
||||
discount: 0.98
|
||||
|
||||
# Inference.
|
||||
use_mpc: true
|
||||
cem_iterations: 6
|
||||
max_std: 2.0
|
||||
min_std: 0.05
|
||||
n_gaussian_samples: 512
|
||||
n_pi_samples: 51
|
||||
uncertainty_regularizer_coeff: 1.0
|
||||
n_elites: 50
|
||||
elite_weighting_temperature: 0.5
|
||||
gaussian_mean_momentum: 0.1
|
||||
|
||||
# Training and loss computation.
|
||||
max_random_shift_ratio: 0.0476
|
||||
# Loss coefficients.
|
||||
reward_coeff: 0.5
|
||||
expectile_weight: 0.9
|
||||
value_coeff: 0.1
|
||||
consistency_coeff: 20.0
|
||||
advantage_scaling: 3.0
|
||||
pi_coeff: 0.5
|
||||
temporal_decay_coeff: 0.5
|
||||
# Target model.
|
||||
target_model_momentum: 0.995
|
|
@ -56,16 +56,13 @@ import einops
|
|||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value, concatenate_datasets
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from huggingface_hub.utils._validators import HFValidationError
|
||||
from PIL import Image as PILImage
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.logger import log_output_dir
|
||||
|
@ -318,41 +315,17 @@ def eval_policy(
|
|||
rollout_data,
|
||||
done_indices,
|
||||
start_episode_index=batch_ix * env.num_envs,
|
||||
start_data_index=(
|
||||
0 if episode_data is None else (episode_data["episode_data_index"]["to"][-1].item())
|
||||
),
|
||||
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
|
||||
fps=env.unwrapped.metadata["render_fps"],
|
||||
)
|
||||
if episode_data is None:
|
||||
episode_data = this_episode_data
|
||||
else:
|
||||
# Some sanity checks to make sure we are not correctly compiling the data.
|
||||
assert (
|
||||
episode_data["hf_dataset"]["episode_index"][-1] + 1
|
||||
== this_episode_data["hf_dataset"]["episode_index"][0]
|
||||
)
|
||||
assert (
|
||||
episode_data["hf_dataset"]["index"][-1] + 1 == this_episode_data["hf_dataset"]["index"][0]
|
||||
)
|
||||
assert torch.equal(
|
||||
episode_data["episode_data_index"]["to"][-1],
|
||||
this_episode_data["episode_data_index"]["from"][0],
|
||||
)
|
||||
# Some sanity checks to make sure we are correctly compiling the data.
|
||||
assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
|
||||
assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
|
||||
# Concatenate the episode data.
|
||||
episode_data = {
|
||||
"hf_dataset": concatenate_datasets(
|
||||
[episode_data["hf_dataset"], this_episode_data["hf_dataset"]]
|
||||
),
|
||||
"episode_data_index": {
|
||||
k: torch.cat(
|
||||
[
|
||||
episode_data["episode_data_index"][k],
|
||||
this_episode_data["episode_data_index"][k],
|
||||
]
|
||||
)
|
||||
for k in ["from", "to"]
|
||||
},
|
||||
}
|
||||
episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
|
||||
|
||||
# Maybe render video for visualization.
|
||||
if max_episodes_rendered > 0 and len(ep_frames) > 0:
|
||||
|
@ -434,89 +407,39 @@ def _compile_episode_data(
|
|||
Similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`).
|
||||
"""
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
total_frames = 0
|
||||
data_index_from = start_data_index
|
||||
for ep_ix in range(rollout_data["action"].shape[0]):
|
||||
num_frames = done_indices[ep_ix].item() + 1 # + 1 to include the first done frame
|
||||
# + 2 to include the first done frame and the last observation frame.
|
||||
num_frames = done_indices[ep_ix].item() + 2
|
||||
total_frames += num_frames
|
||||
|
||||
# TODO(rcadene): We need to add a missing last frame which is the observation
|
||||
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
|
||||
ep_dict = {
|
||||
"action": rollout_data["action"][ep_ix, :num_frames],
|
||||
"episode_index": torch.tensor([start_episode_index + ep_ix] * num_frames),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
"next.done": rollout_data["done"][ep_ix, :num_frames],
|
||||
"next.reward": rollout_data["reward"][ep_ix, :num_frames].type(torch.float32),
|
||||
"action": rollout_data["action"][ep_ix, : num_frames - 1],
|
||||
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
|
||||
"frame_index": torch.arange(0, num_frames - 1, 1),
|
||||
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
|
||||
"next.done": rollout_data["done"][ep_ix, : num_frames - 1],
|
||||
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
|
||||
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
|
||||
}
|
||||
|
||||
# For the last observation frame, all other keys will just be copy padded.
|
||||
for k in ep_dict:
|
||||
ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]])
|
||||
|
||||
for key in rollout_data["observation"]:
|
||||
ep_dict[key] = rollout_data["observation"][key][ep_ix][:num_frames]
|
||||
ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames]
|
||||
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(data_index_from)
|
||||
episode_data_index["to"].append(data_index_from + num_frames)
|
||||
|
||||
data_index_from += num_frames
|
||||
|
||||
data_dict = {}
|
||||
for key in ep_dicts[0]:
|
||||
if "image" not in key:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for img in ep_dict[key]:
|
||||
# sanity check that images are channel first
|
||||
c, h, w = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
|
||||
assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
|
||||
assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}"
|
||||
|
||||
# from float32 in range [0,1] to uint8 in range [0,255]
|
||||
img *= 255
|
||||
img = img.type(torch.uint8)
|
||||
|
||||
# convert to channel last and numpy as expected by PIL
|
||||
img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
|
||||
|
||||
data_dict[key].append(img)
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
|
||||
episode_data_index["from"] = torch.tensor(episode_data_index["from"])
|
||||
episode_data_index["to"] = torch.tensor(episode_data_index["to"])
|
||||
|
||||
# TODO(rcadene): clean this
|
||||
features = {}
|
||||
for key in rollout_data["observation"]:
|
||||
if "image" in key:
|
||||
features[key] = Image()
|
||||
else:
|
||||
features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
|
||||
features.update(
|
||||
{
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
"next.reward": Value(dtype="float32", id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
#'next.success': Value(dtype='bool', id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
}
|
||||
)
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return {
|
||||
"hf_dataset": hf_dataset,
|
||||
"episode_data_index": episode_data_index,
|
||||
}
|
||||
return data_dict
|
||||
|
||||
|
||||
def main(
|
||||
|
|
|
@ -15,20 +15,25 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from threading import Lock
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch import nn
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
|
||||
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
|
||||
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
|
@ -107,6 +112,7 @@ def update_policy(
|
|||
grad_scaler: GradScaler,
|
||||
lr_scheduler=None,
|
||||
use_amp: bool = False,
|
||||
lock=None,
|
||||
):
|
||||
"""Returns a dictionary of items for logging."""
|
||||
start_time = time.perf_counter()
|
||||
|
@ -129,7 +135,8 @@ def update_policy(
|
|||
|
||||
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
|
||||
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
|
||||
grad_scaler.step(optimizer)
|
||||
with lock if lock is not None else nullcontext():
|
||||
grad_scaler.step(optimizer)
|
||||
# Updates the scale for next iteration.
|
||||
grad_scaler.update()
|
||||
|
||||
|
@ -149,11 +156,12 @@ def update_policy(
|
|||
"update_s": time.perf_counter() - start_time,
|
||||
**{k: v for k, v in output_dict.items() if k != "loss"},
|
||||
}
|
||||
info.update({k: v for k, v in output_dict.items() if k not in info})
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
|
||||
def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
|
||||
loss = info["loss"]
|
||||
grad_norm = info["grad_norm"]
|
||||
lr = info["lr"]
|
||||
|
@ -187,12 +195,12 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
|
|||
info["num_samples"] = num_samples
|
||||
info["num_episodes"] = num_episodes
|
||||
info["num_epochs"] = num_epochs
|
||||
info["is_offline"] = is_offline
|
||||
info["is_online"] = is_online
|
||||
|
||||
logger.log_dict(info, step, mode="train")
|
||||
|
||||
|
||||
def log_eval_info(logger, info, step, cfg, dataset, is_offline):
|
||||
def log_eval_info(logger, info, step, cfg, dataset, is_online):
|
||||
eval_s = info["eval_s"]
|
||||
avg_sum_reward = info["avg_sum_reward"]
|
||||
pc_success = info["pc_success"]
|
||||
|
@ -221,7 +229,7 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
|
|||
info["num_samples"] = num_samples
|
||||
info["num_episodes"] = num_episodes
|
||||
info["num_epochs"] = num_epochs
|
||||
info["is_offline"] = is_offline
|
||||
info["is_online"] = is_online
|
||||
|
||||
logger.log_dict(info, step, mode="eval")
|
||||
|
||||
|
@ -234,6 +242,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
|
||||
init_logging()
|
||||
|
||||
if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
|
||||
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
|
||||
|
||||
# If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
|
||||
# to check for any differences between the provided config and the checkpoint's config.
|
||||
if cfg.resume:
|
||||
|
@ -279,9 +290,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
# log metrics to terminal and wandb
|
||||
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
||||
|
||||
if cfg.training.online_steps > 0:
|
||||
raise NotImplementedError("Online training is not implemented yet.")
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
# Check device is available
|
||||
|
@ -336,7 +344,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# Note: this helper will be used in offline and online training loops.
|
||||
def evaluate_and_checkpoint_if_needed(step):
|
||||
def evaluate_and_checkpoint_if_needed(step, is_online):
|
||||
_num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
|
||||
step_identifier = f"{step:0{_num_digits}d}"
|
||||
|
||||
|
@ -352,7 +360,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
|
||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_online=is_online)
|
||||
if cfg.wandb.enable:
|
||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
|
@ -396,8 +404,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
offline_step = 0
|
||||
for _ in range(step, cfg.training.offline_steps):
|
||||
if step == 0:
|
||||
if offline_step == 0:
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
@ -420,13 +429,207 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
train_info["dataloading_s"] = dataloading_s
|
||||
|
||||
if step % cfg.training.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
|
||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False)
|
||||
|
||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||
# so we pass in step + 1.
|
||||
evaluate_and_checkpoint_if_needed(step + 1)
|
||||
evaluate_and_checkpoint_if_needed(step + 1, is_online=False)
|
||||
|
||||
step += 1
|
||||
offline_step += 1 # noqa: SIM113
|
||||
|
||||
if cfg.training.online_steps == 0:
|
||||
if eval_env:
|
||||
eval_env.close()
|
||||
logging.info("End of training")
|
||||
return
|
||||
|
||||
# Online training.
|
||||
|
||||
# Create an env dedicated to online episodes collection from policy rollout.
|
||||
online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
|
||||
resolve_delta_timestamps(cfg)
|
||||
online_buffer_path = logger.log_dir / "online_buffer"
|
||||
if cfg.resume and not online_buffer_path.exists():
|
||||
# If we are resuming a run, we default to the data shapes and buffer capacity from the saved online
|
||||
# buffer.
|
||||
logging.warning(
|
||||
"When online training is resumed, we load the latest online buffer from the prior run, "
|
||||
"and this might not coincide with the state of the buffer as it was at the moment the checkpoint "
|
||||
"was made. This is because the online buffer is updated on disk during training, independently "
|
||||
"of our explicit checkpointing mechanisms."
|
||||
)
|
||||
online_dataset = OnlineBuffer(
|
||||
online_buffer_path,
|
||||
data_spec={
|
||||
**{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.input_shapes.items()},
|
||||
**{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()},
|
||||
"next.reward": {"shape": (), "dtype": np.dtype("float32")},
|
||||
"next.done": {"shape": (), "dtype": np.dtype("?")},
|
||||
"next.success": {"shape": (), "dtype": np.dtype("?")},
|
||||
},
|
||||
buffer_capacity=cfg.training.online_buffer_capacity,
|
||||
fps=online_env.unwrapped.metadata["render_fps"],
|
||||
delta_timestamps=cfg.training.delta_timestamps,
|
||||
)
|
||||
|
||||
# If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this
|
||||
# makes it possible to do online rollouts in parallel with training updates).
|
||||
online_rollout_policy = deepcopy(policy) if cfg.training.do_online_rollout_async else policy
|
||||
|
||||
# Create dataloader for online training.
|
||||
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
|
||||
sampler_weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
|
||||
online_dataset=online_dataset,
|
||||
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
|
||||
# this final observation in the offline datasets, but we might add them in future.
|
||||
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
|
||||
online_sampling_ratio=cfg.training.online_sampling_ratio,
|
||||
)
|
||||
sampler = torch.utils.data.WeightedRandomSampler(
|
||||
sampler_weights,
|
||||
num_samples=len(concat_dataset),
|
||||
replacement=True,
|
||||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
concat_dataset,
|
||||
batch_size=cfg.training.batch_size,
|
||||
num_workers=cfg.training.num_workers,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
# Lock and thread pool executor for asynchronous online rollouts. When asynchronous mode is disabled,
|
||||
# these are still used but effectively do nothing.
|
||||
lock = Lock()
|
||||
# Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
|
||||
# parallelization of rollouts is handled within the job.
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
online_step = 0
|
||||
online_rollout_s = 0 # time take to do online rollout
|
||||
update_online_buffer_s = 0 # time taken to update the online buffer with the online rollout data
|
||||
# Time taken waiting for the online buffer to finish being updated. This is relevant when using the async
|
||||
# online rollout option.
|
||||
await_update_online_buffer_s = 0
|
||||
rollout_start_seed = cfg.training.online_env_seed
|
||||
|
||||
while True:
|
||||
if online_step == cfg.training.online_steps:
|
||||
break
|
||||
|
||||
if online_step == 0:
|
||||
logging.info("Start online training by interacting with environment")
|
||||
|
||||
def sample_trajectory_and_update_buffer():
|
||||
nonlocal rollout_start_seed
|
||||
with lock:
|
||||
online_rollout_policy.load_state_dict(policy.state_dict())
|
||||
online_rollout_policy.eval()
|
||||
start_rollout_time = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
eval_info = eval_policy(
|
||||
online_env,
|
||||
online_rollout_policy,
|
||||
n_episodes=cfg.training.online_rollout_n_episodes,
|
||||
max_episodes_rendered=min(10, cfg.training.online_rollout_n_episodes),
|
||||
videos_dir=logger.log_dir / "online_rollout_videos",
|
||||
return_episode_data=True,
|
||||
start_seed=(
|
||||
rollout_start_seed := (rollout_start_seed + cfg.training.batch_size) % 1000000
|
||||
),
|
||||
)
|
||||
online_rollout_s = time.perf_counter() - start_rollout_time
|
||||
|
||||
with lock:
|
||||
start_update_buffer_time = time.perf_counter()
|
||||
online_dataset.add_data(eval_info["episodes"])
|
||||
|
||||
# Update the concatenated dataset length used during sampling.
|
||||
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
|
||||
|
||||
# Update the sampling weights.
|
||||
sampler.weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
|
||||
online_dataset=online_dataset,
|
||||
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
|
||||
# this final observation in the offline datasets, but we might add them in future.
|
||||
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
|
||||
online_sampling_ratio=cfg.training.online_sampling_ratio,
|
||||
)
|
||||
sampler.num_samples = len(concat_dataset)
|
||||
|
||||
update_online_buffer_s = time.perf_counter() - start_update_buffer_time
|
||||
|
||||
return online_rollout_s, update_online_buffer_s
|
||||
|
||||
future = executor.submit(sample_trajectory_and_update_buffer)
|
||||
# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
|
||||
# here until the rollout and buffer update is done, before proceeding to the policy update steps.
|
||||
if (
|
||||
not cfg.training.do_online_rollout_async
|
||||
or len(online_dataset) <= cfg.training.online_buffer_seed_size
|
||||
):
|
||||
online_rollout_s, update_online_buffer_s = future.result()
|
||||
|
||||
if len(online_dataset) <= cfg.training.online_buffer_seed_size:
|
||||
logging.info(
|
||||
f"Seeding online buffer: {len(online_dataset)}/{cfg.training.online_buffer_seed_size}"
|
||||
)
|
||||
continue
|
||||
|
||||
policy.train()
|
||||
for _ in range(cfg.training.online_steps_between_rollouts):
|
||||
with lock:
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||
|
||||
train_info = update_policy(
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
cfg.training.grad_clip_norm,
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.use_amp,
|
||||
lock=lock,
|
||||
)
|
||||
|
||||
train_info["dataloading_s"] = dataloading_s
|
||||
train_info["online_rollout_s"] = online_rollout_s
|
||||
train_info["update_online_buffer_s"] = update_online_buffer_s
|
||||
train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
|
||||
with lock:
|
||||
train_info["online_buffer_size"] = len(online_dataset)
|
||||
|
||||
if step % cfg.training.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
|
||||
|
||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||
# so we pass in step + 1.
|
||||
evaluate_and_checkpoint_if_needed(step + 1, is_online=True)
|
||||
|
||||
step += 1
|
||||
online_step += 1
|
||||
|
||||
# If we're doing async rollouts, we should now wait until we've completed them before proceeding
|
||||
# to do the next batch of rollouts.
|
||||
if future.running():
|
||||
start = time.perf_counter()
|
||||
online_rollout_s, update_online_buffer_s = future.result()
|
||||
await_update_online_buffer_s = time.perf_counter() - start
|
||||
|
||||
if online_step >= cfg.training.online_steps:
|
||||
break
|
||||
|
||||
if eval_env:
|
||||
eval_env.close()
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:376c501d2780c7204850b58210a5a9476347cf9b5afb8f45b185d23ad6b5be4d
|
||||
size 928
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:213310c6fceffa9fd31066b87b9305484ff7289051a67f3a2490c39640fe7e28
|
||||
size 16904
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2499552badd9201bc73bfa91ba881f43cecdfd93c6f3ca14d3aaf753828a0f4d
|
||||
size 240
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9ff50ee7b750022e17d867ff20307259da8d99a48fff440499e8ca9b4cf42a4a
|
||||
size 36312
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:81457cfd193d9d46b6871071a3971c2901fefa544ab225576132772087b4cf3a
|
||||
size 472
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0
|
||||
size 16904
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee
|
||||
size 240
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6
|
||||
size 36312
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6cdb181ba6acc4aa1209a9ea5dd783f077ff87760257de1026c33f8e2fb2b2b1
|
||||
size 472
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0
|
||||
size 16904
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee
|
||||
size 240
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6
|
||||
size 36312
|
|
@ -108,7 +108,8 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
|
|||
|
||||
if __name__ == "__main__":
|
||||
env_policies = [
|
||||
# ("xarm", "tdmpc", ["policy.use_mpc=false"], ""),
|
||||
# ("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
|
||||
# ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
|
||||
# (
|
||||
# "pusht",
|
||||
# "diffusion",
|
||||
|
|
|
@ -0,0 +1,320 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.d
|
||||
from copy import deepcopy
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
|
||||
# Some constants for OnlineBuffer tests.
|
||||
data_key = "data"
|
||||
data_shape = (2, 3) # just some arbitrary > 1D shape
|
||||
buffer_capacity = 100
|
||||
fps = 10
|
||||
|
||||
|
||||
def make_new_buffer(
|
||||
write_dir: str | None = None, delta_timestamps: dict[str, list[float]] | None = None
|
||||
) -> tuple[OnlineBuffer, str]:
|
||||
if write_dir is None:
|
||||
write_dir = f"/tmp/online_buffer_{uuid4().hex}"
|
||||
buffer = OnlineBuffer(
|
||||
write_dir,
|
||||
data_spec={data_key: {"shape": data_shape, "dtype": np.dtype("float32")}},
|
||||
buffer_capacity=buffer_capacity,
|
||||
fps=fps,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
return buffer, write_dir
|
||||
|
||||
|
||||
def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]:
|
||||
new_data = {
|
||||
data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape),
|
||||
OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes),
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode),
|
||||
OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes),
|
||||
OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes),
|
||||
}
|
||||
return new_data
|
||||
|
||||
|
||||
def test_non_mutate():
|
||||
"""Checks that the data provided to the add_data method is copied rather than passed by reference.
|
||||
|
||||
This means that mutating the data in the buffer does not mutate the original data.
|
||||
|
||||
NOTE: If this test fails, it means some of the other tests may be compromised. For example, we can't trust
|
||||
a success case for `test_write_read`.
|
||||
"""
|
||||
buffer, _ = make_new_buffer()
|
||||
new_data = make_spoof_data_frames(2, buffer_capacity // 4)
|
||||
new_data_copy = deepcopy(new_data)
|
||||
buffer.add_data(new_data)
|
||||
buffer._data[data_key][:] += 1
|
||||
assert all(np.array_equal(new_data[k], new_data_copy[k]) for k in new_data)
|
||||
|
||||
|
||||
def test_index_error_no_data():
|
||||
buffer, _ = make_new_buffer()
|
||||
with pytest.raises(IndexError):
|
||||
buffer[0]
|
||||
|
||||
|
||||
def test_index_error_with_data():
|
||||
buffer, _ = make_new_buffer()
|
||||
n_frames = buffer_capacity // 2
|
||||
new_data = make_spoof_data_frames(1, n_frames)
|
||||
buffer.add_data(new_data)
|
||||
with pytest.raises(IndexError):
|
||||
buffer[n_frames]
|
||||
with pytest.raises(IndexError):
|
||||
buffer[-n_frames - 1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("do_reload", [False, True])
|
||||
def test_write_read(do_reload: bool):
|
||||
"""Checks that data can be added to the buffer and read back.
|
||||
|
||||
If do_reload we delete the buffer object and load the buffer back from disk before reading.
|
||||
"""
|
||||
buffer, write_dir = make_new_buffer()
|
||||
n_episodes = 2
|
||||
n_frames_per_episode = buffer_capacity // 4
|
||||
new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
|
||||
buffer.add_data(new_data)
|
||||
|
||||
if do_reload:
|
||||
del buffer
|
||||
buffer, _ = make_new_buffer(write_dir)
|
||||
|
||||
assert len(buffer) == n_frames_per_episode * n_episodes
|
||||
for i, item in enumerate(buffer):
|
||||
assert all(isinstance(item[k], torch.Tensor) for k in item)
|
||||
assert np.array_equal(item[data_key].numpy(), new_data[data_key][i])
|
||||
|
||||
|
||||
def test_read_data_key():
|
||||
"""Tests that data can be added to a buffer and all data for a. specific key can be read back."""
|
||||
buffer, _ = make_new_buffer()
|
||||
n_episodes = 2
|
||||
n_frames_per_episode = buffer_capacity // 4
|
||||
new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
|
||||
buffer.add_data(new_data)
|
||||
|
||||
data_from_buffer = buffer.get_data_by_key(data_key)
|
||||
assert isinstance(data_from_buffer, torch.Tensor)
|
||||
assert np.array_equal(data_from_buffer.numpy(), new_data[data_key])
|
||||
|
||||
|
||||
def test_fifo():
|
||||
"""Checks that if data is added beyond the buffer capacity, we discard the oldest data first."""
|
||||
buffer, _ = make_new_buffer()
|
||||
n_frames_per_episode = buffer_capacity // 4
|
||||
n_episodes = 3
|
||||
new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
|
||||
buffer.add_data(new_data)
|
||||
n_more_episodes = 2
|
||||
# Developer sanity check (in case someone changes the global `buffer_capacity`).
|
||||
assert (
|
||||
n_episodes + n_more_episodes
|
||||
) * n_frames_per_episode > buffer_capacity, "Something went wrong with the test code."
|
||||
more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode)
|
||||
buffer.add_data(more_new_data)
|
||||
assert len(buffer) == buffer_capacity, "The buffer should be full."
|
||||
|
||||
expected_data = {}
|
||||
for k in new_data:
|
||||
# Concatenate, left-truncate, then roll, to imitate the cyclical FIFO pattern in OnlineBuffer.
|
||||
expected_data[k] = np.roll(
|
||||
np.concatenate([new_data[k], more_new_data[k]])[-buffer_capacity:],
|
||||
shift=len(new_data[k]) + len(more_new_data[k]) - buffer_capacity,
|
||||
axis=0,
|
||||
)
|
||||
|
||||
for i, item in enumerate(buffer):
|
||||
assert all(isinstance(item[k], torch.Tensor) for k in item)
|
||||
assert np.array_equal(item[data_key].numpy(), expected_data[data_key][i])
|
||||
|
||||
|
||||
def test_delta_timestamps_within_tolerance():
|
||||
"""Check that getting an item with delta_timestamps within tolerance succeeds.
|
||||
|
||||
Note: Copied from `test_datasets.py::test_load_previous_and_future_frames_within_tolerance`.
|
||||
"""
|
||||
# Sanity check on global fps as we are assuming it is 10 here.
|
||||
assert fps == 10, "This test assumes fps==10"
|
||||
buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.139]})
|
||||
new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
|
||||
buffer.add_data(new_data)
|
||||
buffer.tolerance_s = 0.04
|
||||
item = buffer[2]
|
||||
data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
|
||||
assert torch.allclose(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
|
||||
assert not is_pad.any(), "Unexpected padding detected"
|
||||
|
||||
|
||||
def test_delta_timestamps_outside_tolerance_inside_episode_range():
|
||||
"""Check that getting an item with delta_timestamps outside of tolerance fails.
|
||||
|
||||
We expect it to fail if and only if the requested timestamps are within the episode range.
|
||||
|
||||
Note: Copied from
|
||||
`test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_inside_episode_range`
|
||||
"""
|
||||
# Sanity check on global fps as we are assuming it is 10 here.
|
||||
assert fps == 10, "This test assumes fps==10"
|
||||
buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.141]})
|
||||
new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
|
||||
buffer.add_data(new_data)
|
||||
buffer.tolerance_s = 0.04
|
||||
with pytest.raises(AssertionError):
|
||||
buffer[2]
|
||||
|
||||
|
||||
def test_delta_timestamps_outside_tolerance_outside_episode_range():
|
||||
"""Check that copy-padding of timestamps outside of the episode range works.
|
||||
|
||||
Note: Copied from
|
||||
`test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_outside_episode_range`
|
||||
"""
|
||||
# Sanity check on global fps as we are assuming it is 10 here.
|
||||
assert fps == 10, "This test assumes fps==10"
|
||||
buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.3, -0.24, 0, 0.26, 0.3]})
|
||||
new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
|
||||
buffer.add_data(new_data)
|
||||
buffer.tolerance_s = 0.04
|
||||
item = buffer[2]
|
||||
data, is_pad = item["index"], item["index_is_pad"]
|
||||
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
|
||||
assert torch.equal(
|
||||
is_pad, torch.tensor([True, False, False, True, True])
|
||||
), "Padding does not match expected values"
|
||||
|
||||
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
@pytest.mark.parametrize("offline_dataset_size", [0, 6])
|
||||
@pytest.mark.parametrize("online_dataset_size", [0, 4])
|
||||
@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0])
|
||||
def test_compute_sampler_weights_trivial(
|
||||
offline_dataset_size: int, online_dataset_size: int, online_sampling_ratio: float
|
||||
):
|
||||
# Pass/skip the test if both datasets sizes are zero.
|
||||
if offline_dataset_size + online_dataset_size == 0:
|
||||
return
|
||||
# Create spoof offline dataset.
|
||||
offline_dataset = LeRobotDataset.from_preloaded(
|
||||
hf_dataset=Dataset.from_dict({"data": list(range(offline_dataset_size))})
|
||||
)
|
||||
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
if offline_dataset_size == 0:
|
||||
offline_dataset.episode_data_index = {}
|
||||
else:
|
||||
# Set up an episode_data_index with at least two episodes.
|
||||
offline_dataset.episode_data_index = {
|
||||
"from": torch.tensor([0, offline_dataset_size // 2]),
|
||||
"to": torch.tensor([offline_dataset_size // 2, offline_dataset_size]),
|
||||
}
|
||||
# Create spoof online datset.
|
||||
online_dataset, _ = make_new_buffer()
|
||||
if online_dataset_size > 0:
|
||||
online_dataset.add_data(
|
||||
make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2)
|
||||
)
|
||||
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
|
||||
)
|
||||
if offline_dataset_size == 0 or online_dataset_size == 0:
|
||||
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
|
||||
elif online_sampling_ratio == 0:
|
||||
expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)])
|
||||
elif online_sampling_ratio == 1:
|
||||
expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
|
||||
expected_weights /= expected_weights.sum()
|
||||
assert torch.allclose(weights, expected_weights)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_nontrivial_ratio():
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
# Create spoof offline dataset.
|
||||
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))}))
|
||||
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
offline_dataset.episode_data_index = {
|
||||
"from": torch.tensor([0, 2]),
|
||||
"to": torch.tensor([2, 4]),
|
||||
}
|
||||
# Create spoof online datset.
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
online_sampling_ratio = 0.8
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
|
||||
)
|
||||
assert torch.allclose(
|
||||
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
|
||||
)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n():
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
# Create spoof offline dataset.
|
||||
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))}))
|
||||
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
offline_dataset.episode_data_index = {
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([4]),
|
||||
}
|
||||
# Create spoof online datset.
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
|
||||
)
|
||||
assert torch.allclose(
|
||||
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
|
||||
)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_drop_n_last_frames():
|
||||
"""Note: test copied from test_sampler."""
|
||||
data_dict = {
|
||||
"timestamp": [0, 0.1],
|
||||
"index": [0, 1],
|
||||
"episode_index": [0, 0],
|
||||
"frame_index": [0, 1],
|
||||
}
|
||||
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict(data_dict))
|
||||
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
offline_dataset.episode_data_index = {"from": torch.tensor([0]), "to": torch.tensor([2])}
|
||||
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
offline_drop_n_last_frames=1,
|
||||
online_dataset=online_dataset,
|
||||
online_sampling_ratio=0.5,
|
||||
online_drop_n_last_frames=1,
|
||||
)
|
||||
assert torch.allclose(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
|
|
@ -357,7 +357,8 @@ def test_normalize(insert_temporal_dim):
|
|||
# TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
|
||||
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
|
||||
# to test with `policy.use_mpc=false`.
|
||||
("xarm", "tdmpc", ["policy.use_mpc=false"], ""),
|
||||
("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
|
||||
# ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
|
||||
(
|
||||
"pusht",
|
||||
"diffusion",
|
||||
|
|
Loading…
Reference in New Issue