fixed bug in crop_dataset_roi.py
added missing buffer.pt in server dir Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in: parent e0527b4a6b, commit 7d5a9530f7
@@ -0,0 +1,560 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import random
from typing import Any, Callable, Optional, Sequence, TypedDict

import torch
import torch.nn.functional as F  # noqa: N812
from tqdm import tqdm

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


class Transition(TypedDict):
    state: dict[str, torch.Tensor]
    action: torch.Tensor
    reward: float
    next_state: dict[str, torch.Tensor]
    done: bool
    complementary_info: Optional[dict[str, Any]]


class BatchTransition(TypedDict):
    state: dict[str, torch.Tensor]
    action: torch.Tensor
    reward: torch.Tensor
    next_state: dict[str, torch.Tensor]
    done: torch.Tensor


def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
    # Move state tensors to the target device
    transition["state"] = {key: val.to(device, non_blocking=True) for key, val in transition["state"].items()}

    # Move action to the target device
    transition["action"] = transition["action"].to(device, non_blocking=True)

    # No need to move reward or done, as they are float and bool

    # Move next_state tensors to the target device
    transition["next_state"] = {
        key: val.to(device, non_blocking=True) for key, val in transition["next_state"].items()
    }

    # If complementary_info is present, move its tensors to the target device
    if transition["complementary_info"] is not None:
        transition["complementary_info"] = {
            key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
        }
    return transition


def move_state_dict_to_device(state_dict, device):
    """
    Recursively move all tensors in a (potentially) nested
    dict/list/tuple structure to the given device.
    """
    if isinstance(state_dict, torch.Tensor):
        return state_dict.to(device)
    elif isinstance(state_dict, dict):
        return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
    elif isinstance(state_dict, list):
        return [move_state_dict_to_device(v, device=device) for v in state_dict]
    elif isinstance(state_dict, tuple):
        return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
    else:
        return state_dict
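
# Usage sketch (illustrative only, not part of the original module; the nested
# structure below is made up for the example):
#
#   nested = {"step": torch.tensor(10), "layers": [{"w": torch.rand(4, 4)}]}
#   nested_cpu = move_state_dict_to_device(nested, device="cpu")
#   assert nested_cpu["layers"][0]["w"].device.type == "cpu"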


def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
    """
    Perform a per-image random crop over a batch of images in a vectorized way.
    """
    B, C, H, W = images.shape  # noqa: N806
    crop_h, crop_w = output_size

    if crop_h > H or crop_w > W:
        raise ValueError(
            f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
        )

    tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
    lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)

    rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
    cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)

    rows = rows.unsqueeze(2).expand(-1, -1, crop_w)  # (B, crop_h, crop_w)
    cols = cols.unsqueeze(1).expand(-1, crop_h, -1)  # (B, crop_h, crop_w)

    images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)

    # Gather pixels
    cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
    # cropped_hwcn => (B, crop_h, crop_w, C)

    cropped = cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)
    return cropped


def random_shift(images: torch.Tensor, pad: int = 4):
    """Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
    _, _, h, w = images.shape
    images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
    return random_crop_vectorized(images=images, output_size=(h, w))
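
# Usage sketch of the DrQ-style augmentation (illustrative only, not part of the
# original module): pad each image by 4 px with edge replication, then take a random
# crop back to the original size, which shifts every image in the batch independently.
#
#   images = torch.rand(8, 3, 84, 84)        # (B, C, H, W)
#   shifted = random_shift(images, pad=4)    # same shape, each image shifted by up to 4 px
#   assert shifted.shape == images.shape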


class ReplayBuffer:
    def __init__(
        self,
        capacity: int,
        device: str = "cuda:0",
        state_keys: Optional[Sequence[str]] = None,
        image_augmentation_function: Optional[Callable] = None,
        use_drq: bool = True,
    ):
        """
        Args:
            capacity (int): Maximum number of transitions to store in the buffer.
            device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
            state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
            image_augmentation_function (Optional[Callable]): A function that takes a batch of images
                and returns a batch of augmented images. If None, a default augmentation function is used.
            use_drq (bool): Whether to apply DrQ-style image augmentation when sampling from the buffer.
        """
        self.capacity = capacity
        self.device = device
        self.memory: list[Transition] = []
        self.position = 0

        # If no state_keys provided, default to an empty list
        # (you can handle this differently if needed)
        self.state_keys = state_keys if state_keys is not None else []
        if image_augmentation_function is None:
            image_augmentation_function = functools.partial(random_shift, pad=4)
        self.image_augmentation_function = image_augmentation_function
        self.use_drq = use_drq

    def __len__(self):
        return len(self.memory)

    def add(
        self,
        state: dict[str, torch.Tensor],
        action: torch.Tensor,
        reward: float,
        next_state: dict[str, torch.Tensor],
        done: bool,
        complementary_info: Optional[dict[str, torch.Tensor]] = None,
    ):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)

        # Create and store the Transition
        self.memory[self.position] = Transition(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
            complementary_info=complementary_info,
        )
        self.position: int = (self.position + 1) % self.capacity

    # TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
    @classmethod
    def from_lerobot_dataset(
        cls,
        lerobot_dataset: LeRobotDataset,
        device: str = "cuda:0",
        state_keys: Optional[Sequence[str]] = None,
        capacity: Optional[int] = None,
    ) -> "ReplayBuffer":
        """
        Convert a LeRobotDataset into a ReplayBuffer.

        Args:
            lerobot_dataset (LeRobotDataset): The dataset to convert.
            device (str): The device to store the transitions on. Defaults to "cuda:0".
            state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`.
                Defaults to None.
            capacity (Optional[int], optional): Buffer capacity. Defaults to the dataset length.

        Returns:
            ReplayBuffer: The replay buffer with offline dataset transitions.
        """
        # We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
        # a replay buffer than from a lerobot dataset.
        if capacity is None:
            capacity = len(lerobot_dataset)

        if capacity < len(lerobot_dataset):
            raise ValueError(
                "The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
            )

        replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys)
        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
        # Fill the replay buffer with the lerobot dataset transitions
        for data in list_transition:
            for k, v in data.items():
                if isinstance(v, dict):
                    for key, tensor in v.items():
                        v[key] = tensor.to(device)
                elif isinstance(v, torch.Tensor):
                    data[k] = v.to(device)

            replay_buffer.add(
                state=data["state"],
                action=data["action"],
                reward=data["reward"],
                next_state=data["next_state"],
                done=data["done"],
            )
        return replay_buffer

    @staticmethod
    def _lerobotdataset_to_transitions(
        dataset: LeRobotDataset,
        state_keys: Optional[Sequence[str]] = None,
    ) -> list[Transition]:
        """
        Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.

        Args:
            dataset (LeRobotDataset):
                The dataset to convert. Each item in the dataset is expected to have
                at least the following keys:
                {
                    "action": ...
                    "next.reward": ...
                    "next.done": ...
                    "episode_index": ...
                }
                plus whatever your 'state_keys' specify.

            state_keys (Optional[Sequence[str]]):
                The dataset keys to include in 'state' and 'next_state'. Their names
                will be kept as-is in the output transitions. E.g.
                ["observation.state", "observation.environment_state"].
                If None, you must handle or define default keys.

        Returns:
            transitions (List[Transition]):
                A list of Transition dictionaries with the same length as `dataset`.
        """

        # If not provided, you can either raise an error or define a default:
        if state_keys is None:
            raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")

        transitions: list[Transition] = []
        num_frames = len(dataset)

        for i in tqdm(range(num_frames)):
            current_sample = dataset[i]

            # ----- 1) Current state -----
            current_state: dict[str, torch.Tensor] = {}
            for key in state_keys:
                val = current_sample[key]
                current_state[key] = val.unsqueeze(0)  # Add batch dimension

            # ----- 2) Action -----
            action = current_sample["action"].unsqueeze(0)  # Add batch dimension

            # ----- 3) Reward and done -----
            reward = float(current_sample["next.reward"].item())  # ensure float
            done = bool(current_sample["next.done"].item())  # ensure bool

            # ----- 4) Next state -----
            # If not done and the next sample is in the same episode, we pull the next sample's state.
            # Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
            next_state = current_state  # default
            if not done and (i < num_frames - 1):
                next_sample = dataset[i + 1]
                if next_sample["episode_index"] == current_sample["episode_index"]:
                    # Build next_state from the same keys
                    next_state_data: dict[str, torch.Tensor] = {}
                    for key in state_keys:
                        val = next_sample[key]
                        next_state_data[key] = val.unsqueeze(0)  # Add batch dimension
                    next_state = next_state_data

            # ----- Construct the Transition -----
            transition = Transition(
                state=current_state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=done,
            )
            transitions.append(transition)

        return transitions

    def sample(self, batch_size: int) -> BatchTransition:
        """Sample a random batch of transitions and collate them into batched tensors."""
        list_of_transitions = random.sample(self.memory, batch_size)

        # -- Build batched states --
        batch_state = {}
        for key in self.state_keys:
            batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
                self.device
            )
            if key.startswith("observation.image") and self.use_drq:
                batch_state[key] = self.image_augmentation_function(batch_state[key])

        # -- Build batched actions --
        batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)

        # -- Build batched rewards --
        batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
            self.device
        )

        # -- Build batched next states --
        batch_next_state = {}
        for key in self.state_keys:
            batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
                self.device
            )
            if key.startswith("observation.image") and self.use_drq:
                batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])

        # -- Build batched dones --
        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
            self.device
        )

        # Return a BatchTransition typed dict
        return BatchTransition(
            state=batch_state,
            action=batch_actions,
            reward=batch_rewards,
            next_state=batch_next_state,
            done=batch_dones,
        )
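
    # Usage sketch (illustrative only, not part of the original class): filling the
    # buffer from an online loop and sampling a training batch. The observation key
    # and shapes below are hypothetical stand-ins for env/policy outputs.
    #
    #   buffer = ReplayBuffer(capacity=10_000, device="cpu", state_keys=["observation.state"])
    #   obs = {"observation.state": torch.zeros(1, 6)}
    #   for _ in range(100):
    #       action = torch.zeros(1, 2)                      # e.g. policy(obs)
    #       next_obs, reward, done = obs, 0.0, False        # e.g. env.step(action)
    #       buffer.add(state=obs, action=action, reward=reward, next_state=next_obs, done=done)
    #       obs = next_obs
    #   batch = buffer.sample(32)                           # BatchTransition of batched tensors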

    def to_lerobot_dataset(
        self,
        repo_id: str,
        fps=1,  # If you have real timestamps, adjust this
        root=None,
        task_name="from_replay_buffer",
    ) -> LeRobotDataset:
        """
        Converts all transitions in this ReplayBuffer into a single LeRobotDataset object,
        splitting episodes by transitions where 'done=True'.

        Returns:
            LeRobotDataset: The resulting offline dataset.
        """
        if len(self.memory) == 0:
            raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")

        # Infer the shapes and dtypes of your features
        # We'll create a features dict that is suitable for LeRobotDataset
        # --------------------------------------------------------------------------------------------
        # First, grab one transition to inspect shapes
        first_transition = self.memory[0]

        # We'll store default metadata for every episode: indexes, timestamps, etc.
        features = {
            "index": {"dtype": "int64", "shape": [1]},  # global index across episodes
            "episode_index": {"dtype": "int64", "shape": [1]},  # which episode
            "frame_index": {"dtype": "int64", "shape": [1]},  # index inside an episode
            "timestamp": {"dtype": "float32", "shape": [1]},  # for now we store dummy
            "task_index": {"dtype": "int64", "shape": [1]},
        }

        # Add "action"
        act_info = guess_feature_info(
            first_transition["action"].squeeze(dim=0), "action"
        )  # Remove batch dimension
        features["action"] = act_info

        # Add "reward" (scalars)
        features["next.reward"] = {"dtype": "float32", "shape": (1,)}

        # Add "done" (boolean scalars)
        features["next.done"] = {"dtype": "bool", "shape": (1,)}

        # Add state keys
        for key in self.state_keys:
            sample_val = first_transition["state"][key].squeeze(dim=0)  # Remove batch dimension
            if not isinstance(sample_val, torch.Tensor):
                raise ValueError(
                    f"State key '{key}' is not a torch.Tensor. Please ensure your states are stored as torch.Tensors."
                )
            f_info = guess_feature_info(sample_val, key)
            features[key] = f_info

        # --------------------------------------------------------------------------------------------
        # Create an empty LeRobotDataset
        # We treat features with shape (3, H, W) or (1, H, W) as images (see guess_feature_info).
        # With use_videos=True the dataset encodes image features as videos; adapt if you prefer raw images.
        # --------------------------------------------------------------------------------------------
        lerobot_dataset = LeRobotDataset.create(
            repo_id=repo_id,
            fps=fps,  # If you have real timestamps, adjust this
            root=root,  # Or some local path where you'd like the dataset files to go
            robot=None,
            robot_type=None,
            features=features,
            use_videos=True,
        )

        # Start writing images if needed. If you have no image features, this is harmless.
        # Set num_processes or num_threads if you want concurrency.
        lerobot_dataset.start_image_writer(num_processes=0, num_threads=2)

        # --------------------------------------------------------------------------------------------
        # Convert transitions into episodes and frames
        # We detect episode boundaries by `done == True`.
        # --------------------------------------------------------------------------------------------
        episode_index = 0
        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)

        frame_idx_in_episode = 0
        for global_frame_idx, transition in enumerate(self.memory):
            frame_dict = {}

            # Fill the data for state keys
            for key in self.state_keys:
                # Expand dimension to match what the dataset expects (the dataset wants the raw shape)
                # We assume your buffer has shape [C, H, W] (if image) or [D] if vector
                # This is typically already correct, but if needed you can reshape below.
                frame_dict[key] = transition["state"][key].cpu().squeeze(dim=0)  # Remove batch dimension

            # Fill action, reward, done
            # Make sure they are shape (X,) or (X,Y,...) as needed.
            frame_dict["action"] = transition["action"].cpu().squeeze(dim=0)  # Remove batch dimension
            frame_dict["next.reward"] = (
                torch.tensor([transition["reward"]], dtype=torch.float32).cpu().squeeze(dim=0)
            )
            frame_dict["next.done"] = (
                torch.tensor([transition["done"]], dtype=torch.bool).cpu().squeeze(dim=0)
            )
            # Add to the dataset's buffer
            lerobot_dataset.add_frame(frame_dict)

            # Move to next frame
            frame_idx_in_episode += 1

            # If we reached an episode boundary, call save_episode, reset counters
            if transition["done"]:
                lerobot_dataset.save_episode(task=task_name)
                episode_index += 1
                frame_idx_in_episode = 0
                # Start a new buffer for the next episode
                lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)

        # We are done adding frames
        # If the last transition wasn't done=True, we still have an open buffer with frames.
        # We'll consider that an incomplete episode and still save it:
        if lerobot_dataset.episode_buffer["size"] > 0:
            lerobot_dataset.save_episode(task=task_name)

        lerobot_dataset.stop_image_writer()

        lerobot_dataset.consolidate(run_compute_stats=False, keep_image_files=False)

        return lerobot_dataset


# Utility function to guess shapes/dtypes from a tensor
def guess_feature_info(t: torch.Tensor, name: str):
    """
    Return a dictionary with the 'dtype' and 'shape' for a given tensor or array.
    If it looks like a 3D (C,H,W) shape, we might consider it an 'image'.
    Otherwise default to 'float32' for numeric. You can customize as needed.
    """
    shape = tuple(t.shape)
    # Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
    if len(shape) == 3 and shape[0] in [1, 3]:
        return {
            "dtype": "image",
            "shape": shape,
        }
    else:
        # Otherwise treat as numeric
        return {
            "dtype": "float32",
            "shape": shape,
        }
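
# Illustrative check of the guessing heuristic above (not part of the original module;
# the tensors are made up for the example):
#
#   guess_feature_info(torch.rand(3, 128, 128), "observation.image")
#   # -> {"dtype": "image", "shape": (3, 128, 128)}
#   guess_feature_info(torch.rand(7), "observation.state")
#   # -> {"dtype": "float32", "shape": (7,)}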


def concatenate_batch_transitions(
    left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
) -> BatchTransition:
    """NOTE: Be careful, this modifies left_batch_transitions in place."""
    left_batch_transitions["state"] = {
        key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
        for key in left_batch_transitions["state"]
    }
    left_batch_transitions["action"] = torch.cat(
        [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
    )
    left_batch_transitions["reward"] = torch.cat(
        [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
    )
    left_batch_transitions["next_state"] = {
        key: torch.cat(
            [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
        )
        for key in left_batch_transitions["next_state"]
    }
    left_batch_transitions["done"] = torch.cat(
        [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
    )
    return left_batch_transitions
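
# Usage sketch (illustrative only, not part of the original module): mixing an offline
# batch with an online batch before an update, as is common in offline-to-online RL.
# `offline_buffer` and `online_buffer` are hypothetical ReplayBuffer instances with the
# same state keys.
#
#   offline_batch = offline_buffer.sample(128)
#   online_batch = online_buffer.sample(128)
#   mixed_batch = concatenate_batch_transitions(offline_batch, online_batch)
#   # NOTE: offline_batch is modified in place and is the same object as mixed_batch.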


# if __name__ == "__main__":
#     dataset_name = "lerobot/pusht_image"
#     dataset = LeRobotDataset(repo_id=dataset_name, episodes=range(1, 3))
#     replay_buffer = ReplayBuffer.from_lerobot_dataset(
#         lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
#     )
#     replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted")
#     for i in range(len(replay_buffer_converted)):
#         replay_convert = replay_buffer_converted[i]
#         dataset_convert = dataset[i]
#         for key in replay_convert.keys():
#             if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}:
#                 continue
#             if key in dataset_convert.keys():
#                 assert torch.equal(replay_convert[key], dataset_convert[key])
#                 print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}")
#     re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset(
#         replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu"
#     )
#     for _ in range(20):
#         batch = re_reconverted_dataset.sample(32)

#         for key in batch.keys():
#             if key in {"state", "next_state"}:
#                 for key_state in batch[key].keys():
#                     print(key_state, batch[key][key_state].size())
#                 continue
#             print(key, batch[key].size())

@@ -187,43 +187,39 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
    # 2. Process each episode in the original dataset.
    episodes_info = original_dataset.meta.episodes
    # (Sort episodes by episode_index for consistency.)

    episodes_info = sorted(episodes_info, key=lambda x: x["episode_index"])
    # Use the first task from the episode metadata (or "unknown" if not provided)
    task = episodes_info[0]["tasks"][0] if episodes_info[0].get("tasks") else "unknown"

    for ep in tqdm(episodes_info[:3], desc="Processing episodes"):
        ep_index = ep.pop("episode_index")
        # Use the first task from the episode metadata (or "unknown" if not provided)
        task = ep["tasks"][0] if ep.get("tasks") else "unknown"
    last_episode_index = 0
    for sample in tqdm(original_dataset):
        episode_index = sample.pop("episode_index")
        if episode_index != last_episode_index:
            new_dataset.save_episode(task, encode_videos=True)
            last_episode_index = episode_index
        sample.pop("frame_index")
        # Make a shallow copy of the sample (the values—e.g. torch tensors—are assumed immutable)
        new_sample = sample.copy()
        # Loop over each observation key that should be cropped/resized.
        for key, params in crop_params_dict.items():
            if key in new_sample:
                top, left, height, width = params
                # Apply crop then resize.
                cropped = F.crop(new_sample[key], top, left, height, width)
                resized = F.resize(cropped, resize_size)
                new_sample[key] = resized
        # Add the transformed frame to the new dataset.
        new_dataset.add_frame(new_sample)

        # Reset the episode buffer in the new dataset (this will store frames for one episode).
        new_dataset.episode_buffer = new_dataset.create_episode_buffer(episode_index=ep_index)

        # 3. Filter and process all frames belonging to this episode.
        # Here we loop over the entire dataset and select the frames with the matching episode_index.
        # (Depending on the dataset size, you might want a more efficient method.)
        ep_frames = [sample for sample in original_dataset if sample["episode_index"] == ep_index]

        for sample in tqdm(ep_frames):
            sample.pop("episode_index")
            sample.pop("frame_index")
            # Make a shallow copy of the sample (the values—e.g. torch tensors—are assumed immutable)
            new_sample = sample.copy()
            # Loop over each observation key that should be cropped/resized.
            for key, params in crop_params_dict.items():
                if key in new_sample:
                    top, left, height, width = params
                    # Apply crop then resize.
                    cropped = F.crop(new_sample[key], top, left, height, width)
                    resized = F.resize(cropped, resize_size)
                    new_sample[key] = resized
            # Add the transformed frame to the new dataset.
            new_dataset.add_frame(new_sample)

        # 4. Save the episode (this writes the parquet file and image files).
        new_dataset.save_episode(task, encode_videos=True)
    # save last episode
    new_dataset.save_episode(task, encode_videos=True)

    # Optionally, consolidate the new dataset to compute statistics and update video info.
    new_dataset.consolidate(run_compute_stats=True, keep_image_files=True)

    new_dataset.push_to_hub(tags=None)

    return new_dataset
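
# Illustrative sketch of the crop-then-resize step used above (not part of the original
# file), assuming `F` refers to `torchvision.transforms.functional` in crop_dataset_roi.py
# and the ROI parameters below are made up:
#
#   frame = torch.rand(3, 480, 640)                  # (C, H, W) image tensor
#   top, left, height, width = 100, 150, 200, 200    # hypothetical ROI
#   cropped = F.crop(frame, top, left, height, width)
#   resized = F.resize(cropped, [128, 128])          # -> (3, 128, 128)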