fixed bug in crop_dataset_roi.py

added missing buffer.py in server dir

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi 2025-02-05 18:22:50 +00:00
parent e0527b4a6b
commit 7d5a9530f7
2 changed files with 586 additions and 30 deletions

buffer.py

@@ -0,0 +1,560 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import random
from typing import Any, Callable, Optional, Sequence, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
class Transition(TypedDict):
    state: dict[str, torch.Tensor]
    action: torch.Tensor
    reward: float
    next_state: dict[str, torch.Tensor]
    done: bool
    complementary_info: Optional[dict[str, Any]]
class BatchTransition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: torch.Tensor
next_state: dict[str, torch.Tensor]
done: torch.Tensor
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
    # Move state tensors to the target device
    transition["state"] = {key: val.to(device, non_blocking=True) for key, val in transition["state"].items()}
    # Move action to the target device
    transition["action"] = transition["action"].to(device, non_blocking=True)
    # No need to move reward or done, as they are a plain float and bool
    # Move next_state tensors to the target device
    transition["next_state"] = {
        key: val.to(device, non_blocking=True) for key, val in transition["next_state"].items()
    }
    # If complementary_info is present, move its tensors to the target device
    # (use .get: transitions built without complementary_info would otherwise raise a KeyError)
    if transition.get("complementary_info") is not None:
        transition["complementary_info"] = {
            key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
        }
    return transition
def move_state_dict_to_device(state_dict, device):
    """
    Recursively move all tensors in a (potentially) nested
    dict/list/tuple structure to the given device.
    """
if isinstance(state_dict, torch.Tensor):
return state_dict.to(device)
elif isinstance(state_dict, dict):
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
elif isinstance(state_dict, list):
return [move_state_dict_to_device(v, device=device) for v in state_dict]
elif isinstance(state_dict, tuple):
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
else:
return state_dict
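# Illustrative sketch (not part of the module): nested containers are handled recursively,
# e.g. an optimizer-style state dict keeps its structure, with only the tensors moved.
#
# nested = {"step": 3, "params": [torch.ones(2), (torch.zeros(3),)]}
# moved = move_state_dict_to_device(nested, device="cpu")
# # same structure; every tensor now on "cpu", non-tensor leaves (like 3) untouched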
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
    """
    Perform a per-image random crop over a batch of images in a vectorized way.
    """
B, C, H, W = images.shape # noqa: N806
crop_h, crop_w = output_size
if crop_h > H or crop_w > W:
raise ValueError(
f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
)
tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)
rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w)
cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w)
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
# cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
return cropped
def random_shift(images: torch.Tensor, pad: int = 4):
    """Vectorized random shift. images: (B, C, H, W), pad: number of pixels to pad on each side."""
_, _, h, w = images.shape
images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
return random_crop_vectorized(images=images, output_size=(h, w))
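# Illustrative sketch (not part of the module): `random_shift` replicate-pads each image
# by `pad` pixels on every side, then takes an independent random crop per image back to
# the original size, i.e. a DrQ-style shift of up to +/- `pad` pixels per image.
#
# images = torch.rand(8, 3, 84, 84)
# shifted = random_shift(images, pad=4)
# assert shifted.shape == images.shape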
class ReplayBuffer:
def __init__(
self,
capacity: int,
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
):
"""
Args:
capacity (int): Maximum number of transitions to store in the buffer.
device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
and returns a batch of augmented images. If None, a default augmentation function is used.
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
"""
self.capacity = capacity
self.device = device
self.memory: list[Transition] = []
self.position = 0
# If no state_keys provided, default to an empty list
# (you can handle this differently if needed)
self.state_keys = state_keys if state_keys is not None else []
        if image_augmentation_function is None:
            self.image_augmentation_function = functools.partial(random_shift, pad=4)
        else:
            self.image_augmentation_function = image_augmentation_function
        self.use_drq = use_drq
def __len__(self):
return len(self.memory)
def add(
self,
state: dict[str, torch.Tensor],
action: torch.Tensor,
reward: float,
next_state: dict[str, torch.Tensor],
done: bool,
complementary_info: Optional[dict[str, torch.Tensor]] = None,
):
"""Saves a transition."""
if len(self.memory) < self.capacity:
self.memory.append(None)
# Create and store the Transition
self.memory[self.position] = Transition(
state=state,
action=action,
reward=reward,
next_state=next_state,
done=done,
complementary_info=complementary_info,
)
self.position: int = (self.position + 1) % self.capacity
# TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
@classmethod
def from_lerobot_dataset(
cls,
lerobot_dataset: LeRobotDataset,
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
capacity: Optional[int] = None,
) -> "ReplayBuffer":
"""
Convert a LeRobotDataset into a ReplayBuffer.
Args:
lerobot_dataset (LeRobotDataset): The dataset to convert.
device (str): The device . Defaults to "cuda:0".
state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`.
Defaults to None.
Returns:
ReplayBuffer: The replay buffer with offline dataset transitions.
"""
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
# a replay buffer than from a lerobot dataset.
if capacity is None:
capacity = len(lerobot_dataset)
if capacity < len(lerobot_dataset):
raise ValueError(
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
)
replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys)
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
# Fill the replay buffer with the lerobot dataset transitions
for data in list_transition:
for k, v in data.items():
if isinstance(v, dict):
for key, tensor in v.items():
v[key] = tensor.to(device)
elif isinstance(v, torch.Tensor):
data[k] = v.to(device)
replay_buffer.add(
state=data["state"],
action=data["action"],
reward=data["reward"],
next_state=data["next_state"],
done=data["done"],
)
return replay_buffer
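    # Usage sketch (hypothetical repo id, mirroring the commented example at the bottom
    # of this file):
    #
    # dataset = LeRobotDataset(repo_id="lerobot/pusht_image")
    # buffer = ReplayBuffer.from_lerobot_dataset(
    #     lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
    # )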
@staticmethod
def _lerobotdataset_to_transitions(
dataset: LeRobotDataset,
state_keys: Optional[Sequence[str]] = None,
) -> list[Transition]:
"""
Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
Args:
dataset (LeRobotDataset):
The dataset to convert. Each item in the dataset is expected to have
at least the following keys:
{
"action": ...
"next.reward": ...
"next.done": ...
"episode_index": ...
}
plus whatever your 'state_keys' specify.
state_keys (Optional[Sequence[str]]):
The dataset keys to include in 'state' and 'next_state'. Their names
will be kept as-is in the output transitions. E.g.
["observation.state", "observation.environment_state"].
If None, you must handle or define default keys.
Returns:
transitions (List[Transition]):
A list of Transition dictionaries with the same length as `dataset`.
"""
# If not provided, you can either raise an error or define a default:
if state_keys is None:
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
transitions: list[Transition] = []
num_frames = len(dataset)
for i in tqdm(range(num_frames)):
current_sample = dataset[i]
# ----- 1) Current state -----
current_state: dict[str, torch.Tensor] = {}
for key in state_keys:
val = current_sample[key]
current_state[key] = val.unsqueeze(0) # Add batch dimension
# ----- 2) Action -----
action = current_sample["action"].unsqueeze(0) # Add batch dimension
# ----- 3) Reward and done -----
reward = float(current_sample["next.reward"].item()) # ensure float
done = bool(current_sample["next.done"].item()) # ensure bool
# ----- 4) Next state -----
# If not done and the next sample is in the same episode, we pull the next sample's state.
# Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
next_state = current_state # default
if not done and (i < num_frames - 1):
next_sample = dataset[i + 1]
if next_sample["episode_index"] == current_sample["episode_index"]:
# Build next_state from the same keys
next_state_data: dict[str, torch.Tensor] = {}
for key in state_keys:
val = next_sample[key]
next_state_data[key] = val.unsqueeze(0) # Add batch dimension
next_state = next_state_data
# ----- Construct the Transition -----
transition = Transition(
state=current_state,
action=action,
reward=reward,
next_state=next_state,
done=done,
)
transitions.append(transition)
return transitions
def sample(self, batch_size: int) -> BatchTransition:
"""Sample a random batch of transitions and collate them into batched tensors."""
list_of_transitions = random.sample(self.memory, batch_size)
# -- Build batched states --
batch_state = {}
for key in self.state_keys:
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
if key.startswith("observation.image") and self.use_drq:
batch_state[key] = self.image_augmentation_function(batch_state[key])
# -- Build batched actions --
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
# -- Build batched rewards --
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
# -- Build batched next states --
batch_next_state = {}
for key in self.state_keys:
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
if key.startswith("observation.image") and self.use_drq:
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
# -- Build batched dones --
        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
            self.device
        )
# Return a BatchTransition typed dict
return BatchTransition(
state=batch_state,
action=batch_actions,
reward=batch_rewards,
next_state=batch_next_state,
done=batch_dones,
)
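    # Usage sketch (hypothetical key and shapes): image keys are augmented on sampling
    # when `use_drq=True`; rewards and dones come back as float32 tensors of shape (B,).
    #
    # buffer = ReplayBuffer(capacity=10_000, device="cpu", state_keys=["observation.state"])
    # ...fill via buffer.add(...)...
    # batch = buffer.sample(batch_size=32)
    # batch["state"]["observation.state"].shape  # -> (32, state_dim)
    # batch["reward"].shape                      # -> (32,)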
def to_lerobot_dataset(
self,
repo_id: str,
fps=1, # If you have real timestamps, adjust this
root=None,
task_name="from_replay_buffer",
) -> LeRobotDataset:
"""
Converts all transitions in this ReplayBuffer into a single LeRobotDataset object,
splitting episodes by transitions where 'done=True'.
Returns:
LeRobotDataset: The resulting offline dataset.
"""
if len(self.memory) == 0:
raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")
# Infer the shapes and dtypes of your features
# We'll create a features dict that is suitable for LeRobotDataset
# --------------------------------------------------------------------------------------------
# First, grab one transition to inspect shapes
first_transition = self.memory[0]
# We'll store default metadata for every episode: indexes, timestamps, etc.
features = {
"index": {"dtype": "int64", "shape": [1]}, # global index across episodes
"episode_index": {"dtype": "int64", "shape": [1]}, # which episode
"frame_index": {"dtype": "int64", "shape": [1]}, # index inside an episode
"timestamp": {"dtype": "float32", "shape": [1]}, # for now we store dummy
"task_index": {"dtype": "int64", "shape": [1]},
}
# Add "action"
act_info = guess_feature_info(
first_transition["action"].squeeze(dim=0), "action"
) # Remove batch dimension
features["action"] = act_info
# Add "reward" (scalars)
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
# Add "done" (boolean scalars)
features["next.done"] = {"dtype": "bool", "shape": (1,)}
# Add state keys
for key in self.state_keys:
sample_val = first_transition["state"][key].squeeze(dim=0) # Remove batch dimension
if not isinstance(sample_val, torch.Tensor):
raise ValueError(
f"State key '{key}' is not a torch.Tensor. Please ensure your states are stored as torch.Tensors."
)
f_info = guess_feature_info(sample_val, key)
features[key] = f_info
# --------------------------------------------------------------------------------------------
# Create an empty LeRobotDataset
        # We'll treat features with shape (3, H, W) or (1, H, W) as images and store the
        # frames via the dataset's image writer (see `guess_feature_info` below).
        # --------------------------------------------------------------------------------------------
lerobot_dataset = LeRobotDataset.create(
repo_id=repo_id,
fps=fps, # If you have real timestamps, adjust this
root=root, # Or some local path where you'd like the dataset files to go
robot=None,
robot_type=None,
features=features,
            use_videos=True,
)
# Start writing images if needed. If you have no image features, this is harmless.
# Set num_processes or num_threads if you want concurrency.
lerobot_dataset.start_image_writer(num_processes=0, num_threads=2)
# --------------------------------------------------------------------------------------------
# Convert transitions into episodes and frames
# We detect episode boundaries by `done == True`.
# --------------------------------------------------------------------------------------------
episode_index = 0
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
frame_idx_in_episode = 0
for global_frame_idx, transition in enumerate(self.memory):
frame_dict = {}
# Fill the data for state keys
for key in self.state_keys:
# Expand dimension to match what the dataset expects (the dataset wants the raw shape)
# We assume your buffer has shape [C, H, W] (if image) or [D] if vector
# This is typically already correct, but if needed you can reshape below.
frame_dict[key] = transition["state"][key].cpu().squeeze(dim=0) # Remove batch dimension
# Fill action, reward, done
# Make sure they are shape (X,) or (X,Y,...) as needed.
frame_dict["action"] = transition["action"].cpu().squeeze(dim=0) # Remove batch dimension
frame_dict["next.reward"] = (
torch.tensor([transition["reward"]], dtype=torch.float32).cpu().squeeze(dim=0)
)
frame_dict["next.done"] = (
torch.tensor([transition["done"]], dtype=torch.bool).cpu().squeeze(dim=0)
)
# Add to the dataset's buffer
lerobot_dataset.add_frame(frame_dict)
# Move to next frame
frame_idx_in_episode += 1
# If we reached an episode boundary, call save_episode, reset counters
if transition["done"]:
# Use some placeholder name for the task
lerobot_dataset.save_episode(task="from_replay_buffer")
episode_index += 1
frame_idx_in_episode = 0
# Start a new buffer for the next episode
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
# We are done adding frames
# If the last transition wasn't done=True, we still have an open buffer with frames.
# We'll consider that an incomplete episode and still save it:
if lerobot_dataset.episode_buffer["size"] > 0:
lerobot_dataset.save_episode(task=task_name)
lerobot_dataset.stop_image_writer()
lerobot_dataset.consolidate(run_compute_stats=False, keep_image_files=False)
return lerobot_dataset
# Utility function to guess shapes/dtypes from a tensor
def guess_feature_info(t: torch.Tensor, name: str):
    """
    Return a dictionary with the 'dtype' and 'shape' for a given tensor.
    If it looks like a 3D (C, H, W) shape, we consider it an 'image';
    otherwise we default to 'float32' for numeric data. Customize as needed.
    (`name` is currently unused.)
    """
shape = tuple(t.shape)
# Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
if len(shape) == 3 and shape[0] in [1, 3]:
return {
"dtype": "image",
"shape": shape,
}
else:
# Otherwise treat as numeric
return {
"dtype": "float32",
"shape": shape,
}
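# Quick illustration of the guessing rule above:
#
# guess_feature_info(torch.zeros(3, 128, 128), "observation.image")
# # -> {"dtype": "image", "shape": (3, 128, 128)}
# guess_feature_info(torch.zeros(7), "observation.state")
# # -> {"dtype": "float32", "shape": (7,)}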
def concatenate_batch_transitions(
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
) -> BatchTransition:
"""NOTE: Be careful it change the left_batch_transitions in place"""
left_batch_transitions["state"] = {
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
for key in left_batch_transitions["state"]
}
left_batch_transitions["action"] = torch.cat(
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
)
left_batch_transitions["reward"] = torch.cat(
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
)
left_batch_transitions["next_state"] = {
key: torch.cat(
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
)
for key in left_batch_transitions["next_state"]
}
left_batch_transitions["done"] = torch.cat(
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
)
return left_batch_transitions
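# Usage sketch (hypothetical buffers), e.g. for mixing online and offline batches;
# note that the left batch is mutated in place and also returned:
#
# online_batch = online_buffer.sample(24)
# offline_batch = offline_buffer.sample(8)
# mixed = concatenate_batch_transitions(online_batch, offline_batch)
# # mixed["action"].shape[0] == 32, and mixed is online_batch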
# if __name__ == "__main__":
# dataset_name = "lerobot/pusht_image"
# dataset = LeRobotDataset(repo_id=dataset_name, episodes=range(1, 3))
# replay_buffer = ReplayBuffer.from_lerobot_dataset(
# lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
# )
# replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted")
# for i in range(len(replay_buffer_converted)):
# replay_convert = replay_buffer_converted[i]
# dataset_convert = dataset[i]
# for key in replay_convert.keys():
# if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}:
# continue
# if key in dataset_convert.keys():
# assert torch.equal(replay_convert[key], dataset_convert[key])
# print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}")
# re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset(
# replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu"
# )
# for _ in range(20):
# batch = re_reconverted_dataset.sample(32)
# for key in batch.keys():
# if key in {"state", "next_state"}:
# for key_state in batch[key].keys():
# print(key_state, batch[key][key_state].size())
# continue
# print(key, batch[key].size())

crop_dataset_roi.py

@@ -187,43 +187,39 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
     # 2. Process each episode in the original dataset.
     episodes_info = original_dataset.meta.episodes
     # (Sort episodes by episode_index for consistency.)
     episodes_info = sorted(episodes_info, key=lambda x: x["episode_index"])
-    for ep in tqdm(episodes_info[:3], desc="Processing episodes"):
-        ep_index = ep.pop("episode_index")
-        # Use the first task from the episode metadata (or "unknown" if not provided)
-        task = ep["tasks"][0] if ep.get("tasks") else "unknown"
-        # Reset the episode buffer in the new dataset (this will store frames for one episode).
-        new_dataset.episode_buffer = new_dataset.create_episode_buffer(episode_index=ep_index)
-        # 3. Filter and process all frames belonging to this episode.
-        # Here we loop over the entire dataset and select the frames with the matching episode_index.
-        # (Depending on the dataset size, you might want a more efficient method.)
-        ep_frames = [sample for sample in original_dataset if sample["episode_index"] == ep_index]
-        for sample in tqdm(ep_frames):
-            sample.pop("episode_index")
-            sample.pop("frame_index")
-            # Make a shallow copy of the sample (the values, e.g. torch tensors, are assumed immutable)
-            new_sample = sample.copy()
-            # Loop over each observation key that should be cropped/resized.
-            for key, params in crop_params_dict.items():
-                if key in new_sample:
-                    top, left, height, width = params
-                    # Apply crop then resize.
-                    cropped = F.crop(new_sample[key], top, left, height, width)
-                    resized = F.resize(cropped, resize_size)
-                    new_sample[key] = resized
-            # Add the transformed frame to the new dataset.
-            new_dataset.add_frame(new_sample)
-        # 4. Save the episode (this writes the parquet file and image files).
-        new_dataset.save_episode(task, encode_videos=True)
+    # Use the first task from the episode metadata (or "unknown" if not provided)
+    task = episodes_info[0]["tasks"][0] if episodes_info[0].get("tasks") else "unknown"
+    last_episode_index = 0
+    for sample in tqdm(original_dataset):
+        episode_index = sample.pop("episode_index")
+        if episode_index != last_episode_index:
+            new_dataset.save_episode(task, encode_videos=True)
+            last_episode_index = episode_index
+        sample.pop("frame_index")
+        # Make a shallow copy of the sample (the values, e.g. torch tensors, are assumed immutable)
+        new_sample = sample.copy()
+        # Loop over each observation key that should be cropped/resized.
+        for key, params in crop_params_dict.items():
+            if key in new_sample:
+                top, left, height, width = params
+                # Apply crop then resize.
+                cropped = F.crop(new_sample[key], top, left, height, width)
+                resized = F.resize(cropped, resize_size)
+                new_sample[key] = resized
+        # Add the transformed frame to the new dataset.
+        new_dataset.add_frame(new_sample)
+    # save last episode
+    new_dataset.save_episode(task, encode_videos=True)
     # Optionally, consolidate the new dataset to compute statistics and update video info.
     new_dataset.consolidate(run_compute_stats=True, keep_image_files=True)
     new_dataset.push_to_hub(tags=None)
     return new_dataset
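# Usage sketch for the conversion above (hedged: only `original_dataset`,
# `crop_params_dict` and `resize_size` are visible in this hunk; the other argument
# names are hypothetical):
#
# crop_params_dict = {"observation.images.front": (10, 20, 200, 200)}  # (top, left, height, width)
# cropped_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
#     original_dataset=LeRobotDataset("user/original"),  # hypothetical repo id
#     crop_params_dict=crop_params_dict,
#     new_repo_id="user/original_cropped",               # hypothetical parameter
#     resize_size=(128, 128),
# )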