This commit is contained in:
Remi Cadene 2025-02-20 23:04:31 +00:00
parent b520941cd9
commit 71d1f5e2c9
7 changed files with 306 additions and 92 deletions

View File

@ -17,37 +17,35 @@
For all datasets in the RLDS format. For all datasets in the RLDS format.
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets. For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
NOTE: You need to install tensorflow and tensorflow_datsets before running this script. NOTE: Install `tensorflow` and `tensorflow_datasets` before running this script.
```bash
pip install tensorflow
pip install tensorflow_datasets
```
Example: Example:
python openx_rlds.py \ ```bash
--raw-dir /path/to/bridge_orig/1.0.0 \ python examples/port_datasets/openx_rlds.py \
--local-dir /path/to/local_dir \ --raw-dir /fsx/mustafa_shukor/droid \
--repo-id your_id \ --repo-id cadene/droid \
--use-videos \ --use-videos \
--push-to-hub --push-to-hub
```
""" """
import argparse import argparse
import os
import re import re
import shutil import shutil
import sys
from functools import partial
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import tensorflow_datasets as tfds import tensorflow_datasets as tfds
import tqdm
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset from examples.port_datasets.openx_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
from examples.port_datasets.openx_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
current_dir = os.path.dirname(os.path.abspath(__file__)) from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
oxe_utils_dir = os.path.join(current_dir, "oxe_utils")
sys.path.append(oxe_utils_dir)
from oxe_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
np.set_printoptions(precision=2) np.set_printoptions(precision=2)
@ -87,16 +85,23 @@ def transform_raw_dataset(episode, dataset_name):
return episode return episode
def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bool = True): def generate_features_from_raw(dataset_name: str, builder: tfds.core.DatasetBuilder, use_videos: bool = True):
dataset_name = builder.name
state_names = [f"motor_{i}" for i in range(8)] state_names = [f"motor_{i}" for i in range(8)]
if dataset_name in OXE_DATASET_CONFIGS: if dataset_name in OXE_DATASET_CONFIGS:
state_encoding = OXE_DATASET_CONFIGS[dataset_name]["state_encoding"] state_encoding = OXE_DATASET_CONFIGS[dataset_name]["state_encoding"]
if state_encoding == StateEncoding.POS_EULER: if state_encoding == StateEncoding.POS_EULER:
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"] state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"]
if "libero" in dataset_name: if "libero" in dataset_name:
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper", "gripper"] # 2D gripper state state_names = [
"x",
"y",
"z",
"roll",
"pitch",
"yaw",
"gripper",
"gripper",
] # 2D gripper state
elif state_encoding == StateEncoding.POS_QUAT: elif state_encoding == StateEncoding.POS_QUAT:
state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"] state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
@ -126,44 +131,68 @@ def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bo
return {**features, **DEFAULT_FEATURES} return {**features, **DEFAULT_FEATURES}
def save_as_lerobot_dataset(lerobot_dataset: LeRobotDataset, raw_dataset: tf.data.Dataset, **kwargs): def save_as_lerobot_dataset(
for episode in raw_dataset.as_numpy_iterator(): dataset_name: str,
lerobot_dataset: LeRobotDataset,
raw_dataset: tf.data.Dataset,
num_shards: int | None = None,
shard_index: int | None = None,
):
total_num_episodes = raw_dataset.cardinality().numpy().item()
print(f"Total number of episodes {total_num_episodes}")
if num_shards is not None:
num_shards = 10000
shard_index = 9999
sharded_dataset = raw_dataset.shard(num_shards=num_shards, index=shard_index)
sharded_num_episodes = sharded_dataset.cardinality().numpy().item()
print(f"{sharded_num_episodes=}")
num_episodes = sharded_num_episodes
iter_ = iter(sharded_dataset)
else:
num_episodes = total_num_episodes
iter_ = iter(raw_dataset)
for episode_index in range(num_episodes):
print(f"{episode_index} / {num_episodes}")
episode = next(iter_)
print("\nnext\n")
episode = transform_raw_dataset(episode, dataset_name)
traj = episode["steps"] traj = episode["steps"]
for i in range(traj["action"].shape[0]): for i in tqdm.tqdm(range(traj["action"].shape[0])):
image_dict = { image_dict = {
f"observation.images.{key}": value[i] f"observation.images.{key}": value[i].numpy()
for key, value in traj["observation"].items() for key, value in traj["observation"].items()
if "depth" not in key and any(x in key for x in ["image", "rgb"]) if "depth" not in key and any(x in key for x in ["image", "rgb"])
} }
lerobot_dataset.add_frame( lerobot_dataset.add_frame(
{ {
**image_dict, **image_dict,
"observation.state": traj["proprio"][i], "observation.state": traj["proprio"][i].numpy(),
"action": traj["action"][i], "action": traj["action"][i].numpy(),
"task": traj["task"][i].numpy().decode(),
} }
) )
lerobot_dataset.save_episode(task=traj["task"][0].decode())
lerobot_dataset.consolidate( print()
run_compute_stats=True, lerobot_dataset.save_episode()
keep_image_files=kwargs["keep_images"], print("\nsave_episode\n")
stat_kwargs={"batch_size": kwargs["batch_size"], "num_workers": kwargs["num_workers"]},
) break
def create_lerobot_dataset( def create_lerobot_dataset(
raw_dir: Path, raw_dir: Path,
repo_id: str = None, repo_id: str = None,
local_dir: Path = None,
push_to_hub: bool = False, push_to_hub: bool = False,
fps: int = None, fps: int = None,
robot_type: str = None, robot_type: str = None,
use_videos: bool = True, use_videos: bool = True,
batch_size: int = 32,
num_workers: int = 8,
image_writer_process: int = 5, image_writer_process: int = 5,
image_writer_threads: int = 10, image_writer_threads: int = 10,
keep_images: bool = True, num_shards: int | None = None,
shard_index: int | None = None,
): ):
last_part = raw_dir.name last_part = raw_dir.name
if re.match(r"^\d+\.\d+\.\d+$", last_part): if re.match(r"^\d+\.\d+\.\d+$", last_part):
@ -175,15 +204,9 @@ def create_lerobot_dataset(
dataset_name = last_part dataset_name = last_part
data_dir = raw_dir.parent data_dir = raw_dir.parent
if local_dir is None:
local_dir = Path(LEROBOT_HOME)
local_dir /= f"{dataset_name}_{version}_lerobot"
if local_dir.exists():
shutil.rmtree(local_dir)
builder = tfds.builder(dataset_name, data_dir=data_dir, version=version) builder = tfds.builder(dataset_name, data_dir=data_dir, version=version)
features = generate_features_from_raw(builder, use_videos) features = generate_features_from_raw(dataset_name, builder, use_videos)
raw_dataset = builder.as_dataset(split="train").map(partial(transform_raw_dataset, dataset_name=dataset_name)) raw_dataset = builder.as_dataset(split="train")
if fps is None: if fps is None:
if dataset_name in OXE_DATASET_CONFIGS: if dataset_name in OXE_DATASET_CONFIGS:
@ -201,7 +224,6 @@ def create_lerobot_dataset(
lerobot_dataset = LeRobotDataset.create( lerobot_dataset = LeRobotDataset.create(
repo_id=repo_id, repo_id=repo_id,
robot_type=robot_type, robot_type=robot_type,
root=local_dir,
fps=fps, fps=fps,
use_videos=use_videos, use_videos=use_videos,
features=features, features=features,
@ -210,16 +232,18 @@ def create_lerobot_dataset(
) )
save_as_lerobot_dataset( save_as_lerobot_dataset(
lerobot_dataset, raw_dataset, keep_images=keep_images, batch_size=batch_size, num_workers=num_workers dataset_name,
lerobot_dataset,
raw_dataset,
num_shards=num_shards,
shard_index=shard_index,
) )
if push_to_hub: if push_to_hub:
assert repo_id is not None assert repo_id is not None
tags = ["LeRobot", dataset_name, "rlds"] tags = []
if dataset_name in OXE_DATASET_CONFIGS: if dataset_name in OXE_DATASET_CONFIGS:
tags.append("openx") tags.append("openx")
if robot_type != "unknown":
tags.append(robot_type)
lerobot_dataset.push_to_hub( lerobot_dataset.push_to_hub(
tags=tags, tags=tags,
private=False, private=False,
@ -237,12 +261,6 @@ def main():
required=True, required=True,
help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).", help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
) )
parser.add_argument(
"--local-dir",
type=Path,
required=True,
help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
)
parser.add_argument( parser.add_argument(
"--repo-id", "--repo-id",
type=str, type=str,
@ -270,37 +288,25 @@ def main():
action="store_true", action="store_true",
help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.", help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
) )
parser.add_argument(
"--batch-size",
type=int,
default=32,
help="Batch size loaded by DataLoader for computing the dataset statistics.",
)
parser.add_argument(
"--num-workers",
type=int,
default=8,
help="Number of processes of Dataloader for computing the dataset statistics.",
)
parser.add_argument( parser.add_argument(
"--image-writer-process", "--image-writer-process",
type=int, type=int,
default=5, default=0,
help="Number of processes of image writer for saving images.", help="Number of processes of image writer for saving images.",
) )
parser.add_argument( parser.add_argument(
"--image-writer-threads", "--image-writer-threads",
type=int, type=int,
default=10, default=8,
help="Number of threads per process of image writer for saving images.", help="Number of threads per process of image writer for saving images.",
) )
parser.add_argument(
"--keep-images",
action="store_true",
help="Whether to keep the cached images.",
)
args = parser.parse_args() args = parser.parse_args()
droid_dir = Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene/droid")
if droid_dir.exists():
shutil.rmtree(droid_dir)
create_lerobot_dataset(**vars(args)) create_lerobot_dataset(**vars(args))

View File

@ -0,0 +1,106 @@
import datetime as dt
import shutil
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
class PortOpenXDataset(PipelineStep):
def __init__(
self,
raw_dir: Path,
repo_id: str = None,
image_writer_process: int = 0,
image_writer_threads: int = 8,
):
super().__init__()
self.raw_dir = raw_dir
self.repo_id = repo_id
self.image_writer_process = image_writer_process
self.image_writer_threads = image_writer_threads
def run(self, data=None, rank: int = 0, world_size: int = 1):
from examples.port_datasets.openx_rlds import create_lerobot_dataset
from examples.port_datasets.openx_utils.test import display_slurm_info, display_system_info
display_system_info()
display_slurm_info()
create_lerobot_dataset(
self.raw_dir,
f"{self.repo_id}_world_{world_size}_rank_{rank}",
image_writer_process=self.image_writer_process,
image_writer_threads=self.image_writer_threads,
push_to_hub=False,
num_shards=world_size,
shard_index=rank,
)
class AggregateDatasets(PipelineStep):
def run(self, data=None, rank: int = 0, world_size: int = 1):
print("aggregation")
def main(slurm=True):
for dir_ in Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene").glob("droid_world*"):
shutil.rmtree(dir_)
now = dt.datetime.now()
port_job_name = "port_openx_droid"
logs_dir = Path("/fsx/remi_cadene/logs")
port_log_dir = logs_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_{port_job_name}"
if slurm:
executor_class = SlurmPipelineExecutor
dist_extra_kwargs = {
"job_name": port_job_name,
"tasks": 10000,
"workers": 24,
"time": "00:30:00",
"partition": "hopper-cpu",
"cpus_per_task": 12,
"mem_per_cpu_gb": 4,
}
else:
executor_class = LocalPipelineExecutor
dist_extra_kwargs = {
"tasks": 1,
"workers": 1,
}
port_executor = executor_class(
pipeline=[
PortOpenXDataset(raw_dir=Path("/fsx/mustafa_shukor/droid"), repo_id="cadene/droid"),
],
logging_dir=str(port_log_dir),
**dist_extra_kwargs,
)
port_executor.run()
# if slurm:
# merge_extra_kwargs = {}
# else:
# merge_extra_kwargs = {
# "job_name": "aggregate",
# "time": "00:01:00",
# "partition": "hopper-cpu",
# }
# merge_executor = executor_class(
# depends=dist_executor,
# pipeline=[
# Aggregate(),
# ],
# logging_dir=f"/fsx/remi_cadene/logs/openx_rlds_merge",
# tasks=1,
# workers=1,
# **merge_extra_kwargs,
# )
# merge_executor.run()
if __name__ == "__main__":
main()

View File

@ -56,7 +56,9 @@ def zero_action_filter(traj: Dict) -> bool:
0.8897542208433151, 0.8897542208433151,
] ]
) )
DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1 DROID_NORM_0_ACT = (
2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1
)
return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5) return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5)
@ -799,7 +801,11 @@ OXE_DATASET_CONFIGS = {
}, },
### DROID Finetuning datasets ### DROID Finetuning datasets
"droid_wipe": { "droid_wipe": {
"image_obs_keys": {"primary": "exterior_image_2_left", "secondary": None, "wrist": "wrist_image_left"}, "image_obs_keys": {
"primary": "exterior_image_2_left",
"secondary": None,
"wrist": "wrist_image_left",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["proprio"], "state_obs_keys": ["proprio"],
"state_encoding": StateEncoding.POS_EULER, "state_encoding": StateEncoding.POS_EULER,

View File

@ -0,0 +1,30 @@
#!/bin/bash
#SBATCH --job-name=openx_rlds
#SBATCH --partition=hopper-cpu
#SBATCH --requeue
#SBATCH --time=00:01:00
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --output=/fsx/%u/slurm/%j-%x.out
OUTPUT_DIR="/fsx/${USER}/slurm/${SLURM_JOB_NAME}-${SLURM_JOB_ID}-tasks"
mkdir -p $OUTPUT_DIR
# Function to run a task and redirect output to a separate file
run_task() {
local task_id=$1
local output_file="${OUTPUT_DIR}/task-${task_id}-${SLURM_JOB_ID}.out"
# Run the task and redirect output
python examples/port_datasets/openx_utils/test.py > $output_file 2>&1
}
echo $SBATCH_OUTPUT
# node has 380850M and 96 cpus
trap 'scontrol requeue ${SLURM_JOB_ID}; exit 15' SIGUSR1
echo "Starting job"
# note the "&" to start srun as a background thread
srun python examples/port_datasets/openx_utils/test.py &
# wait for signals...
wait

View File

@ -0,0 +1,54 @@
import os
import psutil
def display_system_info():
# Get the number of CPUs
num_cpus = psutil.cpu_count(logical=True)
print(f"Number of CPUs: {num_cpus}")
# Get memory information
memory_info = psutil.virtual_memory()
total_memory = memory_info.total / (1024**3) # Convert bytes to GB
available_memory = memory_info.available / (1024**3) # Convert bytes to GB
used_memory = memory_info.used / (1024**3) # Convert bytes to GB
print(f"Total Memory: {total_memory:.2f} GB")
print(f"Available Memory: {available_memory:.2f} GB")
print(f"Used Memory: {used_memory:.2f} GB")
def display_slurm_info():
# Get SLURM job ID
job_id = os.getenv("SLURM_JOB_ID")
print(f"SLURM Job ID: {job_id}")
# Get SLURM job name
job_name = os.getenv("SLURM_JOB_NAME")
print(f"SLURM Job Name: {job_name}")
# Get the number of tasks
num_tasks = os.getenv("SLURM_NTASKS")
print(f"Number of Tasks: {num_tasks}")
# Get the number of nodes
num_nodes = os.getenv("SLURM_NNODES")
print(f"Number of Nodes: {num_nodes}")
# Get the number of CPUs per task
cpus_per_task = os.getenv("SLURM_CPUS_PER_TASK")
print(f"CPUs per Task: {cpus_per_task}")
# Get the node list
node_list = os.getenv("SLURM_NODELIST")
print(f"Node List: {node_list}")
# Get the task ID (only available within an srun task)
task_id = os.getenv("SLURM_PROCID")
print(f"Task ID: {task_id}")
if __name__ == "__main__":
display_system_info()
display_slurm_info()

View File

@ -2,7 +2,6 @@
Copied from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/utils/data_utils.py Copied from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/utils/data_utils.py
""" """
from typing import Any, Dict from typing import Any, Dict
import tensorflow as tf import tensorflow as tf
@ -66,6 +65,7 @@ def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
return new_actions return new_actions
# === Bridge-V2 =>> Dataset-Specific Transform === # === Bridge-V2 =>> Dataset-Specific Transform ===
def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]: def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
"""Relabels actions to use reached proprioceptive state; discards last timestep (no-action).""" """Relabels actions to use reached proprioceptive state; discards last timestep (no-action)."""

View File

@ -19,7 +19,8 @@ Transforms adopt the following structure:
from typing import Any, Dict from typing import Any, Dict
import tensorflow as tf import tensorflow as tf
from oxe_utils.transform_utils import (
from examples.port_datasets.openx_utils.transform_utils import (
binarize_gripper_actions, binarize_gripper_actions,
invert_gripper_actions, invert_gripper_actions,
rel2abs_gripper_actions, rel2abs_gripper_actions,
@ -31,6 +32,7 @@ def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
""" """
DROID dataset transformation for actions expressed in *base* frame of the robot. DROID dataset transformation for actions expressed in *base* frame of the robot.
""" """
def rand_swap_exterior_images(img1, img2): def rand_swap_exterior_images(img1, img2):
""" """
Randomly swaps the two exterior images (for training with single exterior input). Randomly swaps the two exterior images (for training with single exterior input).
@ -55,11 +57,11 @@ def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
) )
) )
# trajectory["observation"]["proprio"] = tf.concat( # trajectory["observation"]["proprio"] = tf.concat(
# ( # (
# trajectory["observation"]["cartesian_position"], # trajectory["observation"]["cartesian_position"],
# trajectory["observation"]["gripper_position"], # trajectory["observation"]["gripper_position"],
# ), # ),
# axis=-1, # axis=-1,
# ) # )
trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"] trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"] trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"]
@ -96,7 +98,7 @@ def bridge_oxe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it! Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
""" """
for key in trajectory.keys(): for key in trajectory:
if key == "traj_metadata": if key == "traj_metadata":
continue continue
elif key in ["observation", "action"]: elif key in ["observation", "action"]:
@ -126,7 +128,7 @@ def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it! Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
""" """
for key in trajectory.keys(): for key in trajectory:
if key == "traj_metadata": if key == "traj_metadata":
continue continue
elif key == "observation": elif key == "observation":
@ -198,7 +200,9 @@ def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
) )
eef_value = tf.io.decode_raw(eef_value, tf.float32) eef_value = tf.io.decode_raw(eef_value, tf.float32)
trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7)) trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB") gripper_value = tf.io.decode_compressed(
trajectory["observation"]["gripper_closed"], compression_type="ZLIB"
)
gripper_value = tf.io.decode_raw(gripper_value, tf.float32) gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1)) trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
# trajectory["language_instruction"] = tf.fill( # trajectory["language_instruction"] = tf.fill(
@ -228,7 +232,9 @@ def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6] trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][:, -1:] trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][
:, -1:
]
# make gripper action absolute action, +1 = open, 0 = close # make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
@ -264,7 +270,9 @@ def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict
def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert absolute gripper action, +1 = open, 0 = close # invert absolute gripper action, +1 = open, 0 = close
gripper_action = invert_gripper_actions(tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)) gripper_action = invert_gripper_actions(
tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)
)
trajectory["action"] = tf.concat( trajectory["action"] = tf.concat(
( (
@ -374,7 +382,9 @@ def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, An
instruction_bytes = trajectory["observation"]["instruction"] instruction_bytes = trajectory["observation"]["instruction"]
instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8") instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
# Remove trailing padding --> convert RaggedTensor to regular Tensor. # Remove trailing padding --> convert RaggedTensor to regular Tensor.
trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[:, 0] trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[
:, 0
]
return trajectory return trajectory
@ -900,7 +910,9 @@ def libero_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
axis=1, axis=1,
) )
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -2:] # 2D gripper state trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][
:, -2:
] # 2D gripper state
return trajectory return trajectory