WIP
This commit is contained in:
parent
b520941cd9
commit
71d1f5e2c9
|
@ -17,37 +17,35 @@
|
|||
For all datasets in the RLDS format.
|
||||
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
|
||||
|
||||
NOTE: You need to install tensorflow and tensorflow_datsets before running this script.
|
||||
NOTE: Install `tensorflow` and `tensorflow_datasets` before running this script.
|
||||
```bash
|
||||
pip install tensorflow
|
||||
pip install tensorflow_datasets
|
||||
```
|
||||
|
||||
Example:
|
||||
python openx_rlds.py \
|
||||
--raw-dir /path/to/bridge_orig/1.0.0 \
|
||||
--local-dir /path/to/local_dir \
|
||||
--repo-id your_id \
|
||||
--use-videos \
|
||||
--push-to-hub
|
||||
```bash
|
||||
python examples/port_datasets/openx_rlds.py \
|
||||
--raw-dir /fsx/mustafa_shukor/droid \
|
||||
--repo-id cadene/droid \
|
||||
--use-videos \
|
||||
--push-to-hub
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
oxe_utils_dir = os.path.join(current_dir, "oxe_utils")
|
||||
sys.path.append(oxe_utils_dir)
|
||||
|
||||
from oxe_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
|
||||
from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
|
||||
from examples.port_datasets.openx_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
|
||||
from examples.port_datasets.openx_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
np.set_printoptions(precision=2)
|
||||
|
||||
|
@ -87,16 +85,23 @@ def transform_raw_dataset(episode, dataset_name):
|
|||
return episode
|
||||
|
||||
|
||||
def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bool = True):
|
||||
dataset_name = builder.name
|
||||
|
||||
def generate_features_from_raw(dataset_name: str, builder: tfds.core.DatasetBuilder, use_videos: bool = True):
|
||||
state_names = [f"motor_{i}" for i in range(8)]
|
||||
if dataset_name in OXE_DATASET_CONFIGS:
|
||||
state_encoding = OXE_DATASET_CONFIGS[dataset_name]["state_encoding"]
|
||||
if state_encoding == StateEncoding.POS_EULER:
|
||||
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"]
|
||||
if "libero" in dataset_name:
|
||||
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper", "gripper"] # 2D gripper state
|
||||
state_names = [
|
||||
"x",
|
||||
"y",
|
||||
"z",
|
||||
"roll",
|
||||
"pitch",
|
||||
"yaw",
|
||||
"gripper",
|
||||
"gripper",
|
||||
] # 2D gripper state
|
||||
elif state_encoding == StateEncoding.POS_QUAT:
|
||||
state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
|
||||
|
||||
|
@ -126,44 +131,68 @@ def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bo
|
|||
return {**features, **DEFAULT_FEATURES}
|
||||
|
||||
|
||||
def save_as_lerobot_dataset(lerobot_dataset: LeRobotDataset, raw_dataset: tf.data.Dataset, **kwargs):
|
||||
for episode in raw_dataset.as_numpy_iterator():
|
||||
def save_as_lerobot_dataset(
|
||||
dataset_name: str,
|
||||
lerobot_dataset: LeRobotDataset,
|
||||
raw_dataset: tf.data.Dataset,
|
||||
num_shards: int | None = None,
|
||||
shard_index: int | None = None,
|
||||
):
|
||||
total_num_episodes = raw_dataset.cardinality().numpy().item()
|
||||
print(f"Total number of episodes {total_num_episodes}")
|
||||
|
||||
if num_shards is not None:
|
||||
num_shards = 10000
|
||||
shard_index = 9999
|
||||
sharded_dataset = raw_dataset.shard(num_shards=num_shards, index=shard_index)
|
||||
sharded_num_episodes = sharded_dataset.cardinality().numpy().item()
|
||||
print(f"{sharded_num_episodes=}")
|
||||
num_episodes = sharded_num_episodes
|
||||
iter_ = iter(sharded_dataset)
|
||||
else:
|
||||
num_episodes = total_num_episodes
|
||||
iter_ = iter(raw_dataset)
|
||||
|
||||
for episode_index in range(num_episodes):
|
||||
print(f"{episode_index} / {num_episodes}")
|
||||
episode = next(iter_)
|
||||
print("\nnext\n")
|
||||
episode = transform_raw_dataset(episode, dataset_name)
|
||||
|
||||
traj = episode["steps"]
|
||||
for i in range(traj["action"].shape[0]):
|
||||
for i in tqdm.tqdm(range(traj["action"].shape[0])):
|
||||
image_dict = {
|
||||
f"observation.images.{key}": value[i]
|
||||
f"observation.images.{key}": value[i].numpy()
|
||||
for key, value in traj["observation"].items()
|
||||
if "depth" not in key and any(x in key for x in ["image", "rgb"])
|
||||
}
|
||||
lerobot_dataset.add_frame(
|
||||
{
|
||||
**image_dict,
|
||||
"observation.state": traj["proprio"][i],
|
||||
"action": traj["action"][i],
|
||||
"observation.state": traj["proprio"][i].numpy(),
|
||||
"action": traj["action"][i].numpy(),
|
||||
"task": traj["task"][i].numpy().decode(),
|
||||
}
|
||||
)
|
||||
lerobot_dataset.save_episode(task=traj["task"][0].decode())
|
||||
|
||||
lerobot_dataset.consolidate(
|
||||
run_compute_stats=True,
|
||||
keep_image_files=kwargs["keep_images"],
|
||||
stat_kwargs={"batch_size": kwargs["batch_size"], "num_workers": kwargs["num_workers"]},
|
||||
)
|
||||
print()
|
||||
lerobot_dataset.save_episode()
|
||||
print("\nsave_episode\n")
|
||||
|
||||
break
|
||||
|
||||
|
||||
def create_lerobot_dataset(
|
||||
raw_dir: Path,
|
||||
repo_id: str = None,
|
||||
local_dir: Path = None,
|
||||
push_to_hub: bool = False,
|
||||
fps: int = None,
|
||||
robot_type: str = None,
|
||||
use_videos: bool = True,
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 8,
|
||||
image_writer_process: int = 5,
|
||||
image_writer_threads: int = 10,
|
||||
keep_images: bool = True,
|
||||
num_shards: int | None = None,
|
||||
shard_index: int | None = None,
|
||||
):
|
||||
last_part = raw_dir.name
|
||||
if re.match(r"^\d+\.\d+\.\d+$", last_part):
|
||||
|
@ -175,15 +204,9 @@ def create_lerobot_dataset(
|
|||
dataset_name = last_part
|
||||
data_dir = raw_dir.parent
|
||||
|
||||
if local_dir is None:
|
||||
local_dir = Path(LEROBOT_HOME)
|
||||
local_dir /= f"{dataset_name}_{version}_lerobot"
|
||||
if local_dir.exists():
|
||||
shutil.rmtree(local_dir)
|
||||
|
||||
builder = tfds.builder(dataset_name, data_dir=data_dir, version=version)
|
||||
features = generate_features_from_raw(builder, use_videos)
|
||||
raw_dataset = builder.as_dataset(split="train").map(partial(transform_raw_dataset, dataset_name=dataset_name))
|
||||
features = generate_features_from_raw(dataset_name, builder, use_videos)
|
||||
raw_dataset = builder.as_dataset(split="train")
|
||||
|
||||
if fps is None:
|
||||
if dataset_name in OXE_DATASET_CONFIGS:
|
||||
|
@ -201,7 +224,6 @@ def create_lerobot_dataset(
|
|||
lerobot_dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
robot_type=robot_type,
|
||||
root=local_dir,
|
||||
fps=fps,
|
||||
use_videos=use_videos,
|
||||
features=features,
|
||||
|
@ -210,16 +232,18 @@ def create_lerobot_dataset(
|
|||
)
|
||||
|
||||
save_as_lerobot_dataset(
|
||||
lerobot_dataset, raw_dataset, keep_images=keep_images, batch_size=batch_size, num_workers=num_workers
|
||||
dataset_name,
|
||||
lerobot_dataset,
|
||||
raw_dataset,
|
||||
num_shards=num_shards,
|
||||
shard_index=shard_index,
|
||||
)
|
||||
|
||||
if push_to_hub:
|
||||
assert repo_id is not None
|
||||
tags = ["LeRobot", dataset_name, "rlds"]
|
||||
tags = []
|
||||
if dataset_name in OXE_DATASET_CONFIGS:
|
||||
tags.append("openx")
|
||||
if robot_type != "unknown":
|
||||
tags.append(robot_type)
|
||||
lerobot_dataset.push_to_hub(
|
||||
tags=tags,
|
||||
private=False,
|
||||
|
@ -237,12 +261,6 @@ def main():
|
|||
required=True,
|
||||
help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
|
@ -270,37 +288,25 @@ def main():
|
|||
action="store_true",
|
||||
help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Batch size loaded by DataLoader for computing the dataset statistics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of processes of Dataloader for computing the dataset statistics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image-writer-process",
|
||||
type=int,
|
||||
default=5,
|
||||
default=0,
|
||||
help="Number of processes of image writer for saving images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image-writer-threads",
|
||||
type=int,
|
||||
default=10,
|
||||
default=8,
|
||||
help="Number of threads per process of image writer for saving images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep-images",
|
||||
action="store_true",
|
||||
help="Whether to keep the cached images.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
droid_dir = Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene/droid")
|
||||
if droid_dir.exists():
|
||||
shutil.rmtree(droid_dir)
|
||||
|
||||
create_lerobot_dataset(**vars(args))
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
import datetime as dt
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
|
||||
class PortOpenXDataset(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
raw_dir: Path,
|
||||
repo_id: str = None,
|
||||
image_writer_process: int = 0,
|
||||
image_writer_threads: int = 8,
|
||||
):
|
||||
super().__init__()
|
||||
self.raw_dir = raw_dir
|
||||
self.repo_id = repo_id
|
||||
self.image_writer_process = image_writer_process
|
||||
self.image_writer_threads = image_writer_threads
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
from examples.port_datasets.openx_rlds import create_lerobot_dataset
|
||||
from examples.port_datasets.openx_utils.test import display_slurm_info, display_system_info
|
||||
|
||||
display_system_info()
|
||||
display_slurm_info()
|
||||
|
||||
create_lerobot_dataset(
|
||||
self.raw_dir,
|
||||
f"{self.repo_id}_world_{world_size}_rank_{rank}",
|
||||
image_writer_process=self.image_writer_process,
|
||||
image_writer_threads=self.image_writer_threads,
|
||||
push_to_hub=False,
|
||||
num_shards=world_size,
|
||||
shard_index=rank,
|
||||
)
|
||||
|
||||
|
||||
class AggregateDatasets(PipelineStep):
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
print("aggregation")
|
||||
|
||||
|
||||
def main(slurm=True):
|
||||
for dir_ in Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene").glob("droid_world*"):
|
||||
shutil.rmtree(dir_)
|
||||
|
||||
now = dt.datetime.now()
|
||||
port_job_name = "port_openx_droid"
|
||||
logs_dir = Path("/fsx/remi_cadene/logs")
|
||||
port_log_dir = logs_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_{port_job_name}"
|
||||
|
||||
if slurm:
|
||||
executor_class = SlurmPipelineExecutor
|
||||
dist_extra_kwargs = {
|
||||
"job_name": port_job_name,
|
||||
"tasks": 10000,
|
||||
"workers": 24,
|
||||
"time": "00:30:00",
|
||||
"partition": "hopper-cpu",
|
||||
"cpus_per_task": 12,
|
||||
"mem_per_cpu_gb": 4,
|
||||
}
|
||||
else:
|
||||
executor_class = LocalPipelineExecutor
|
||||
dist_extra_kwargs = {
|
||||
"tasks": 1,
|
||||
"workers": 1,
|
||||
}
|
||||
|
||||
port_executor = executor_class(
|
||||
pipeline=[
|
||||
PortOpenXDataset(raw_dir=Path("/fsx/mustafa_shukor/droid"), repo_id="cadene/droid"),
|
||||
],
|
||||
logging_dir=str(port_log_dir),
|
||||
**dist_extra_kwargs,
|
||||
)
|
||||
port_executor.run()
|
||||
|
||||
# if slurm:
|
||||
# merge_extra_kwargs = {}
|
||||
# else:
|
||||
# merge_extra_kwargs = {
|
||||
# "job_name": "aggregate",
|
||||
# "time": "00:01:00",
|
||||
# "partition": "hopper-cpu",
|
||||
# }
|
||||
|
||||
# merge_executor = executor_class(
|
||||
# depends=dist_executor,
|
||||
# pipeline=[
|
||||
# Aggregate(),
|
||||
# ],
|
||||
# logging_dir=f"/fsx/remi_cadene/logs/openx_rlds_merge",
|
||||
# tasks=1,
|
||||
# workers=1,
|
||||
# **merge_extra_kwargs,
|
||||
# )
|
||||
# merge_executor.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -56,7 +56,9 @@ def zero_action_filter(traj: Dict) -> bool:
|
|||
0.8897542208433151,
|
||||
]
|
||||
)
|
||||
DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1
|
||||
DROID_NORM_0_ACT = (
|
||||
2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1
|
||||
)
|
||||
|
||||
return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5)
|
||||
|
||||
|
@ -799,7 +801,11 @@ OXE_DATASET_CONFIGS = {
|
|||
},
|
||||
### DROID Finetuning datasets
|
||||
"droid_wipe": {
|
||||
"image_obs_keys": {"primary": "exterior_image_2_left", "secondary": None, "wrist": "wrist_image_left"},
|
||||
"image_obs_keys": {
|
||||
"primary": "exterior_image_2_left",
|
||||
"secondary": None,
|
||||
"wrist": "wrist_image_left",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["proprio"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
|
@ -0,0 +1,30 @@
|
|||
#!/bin/bash
|
||||
#SBATCH --job-name=openx_rlds
|
||||
#SBATCH --partition=hopper-cpu
|
||||
#SBATCH --requeue
|
||||
#SBATCH --time=00:01:00
|
||||
#SBATCH --nodes=1
|
||||
#SBATCH --ntasks-per-node=8
|
||||
#SBATCH --output=/fsx/%u/slurm/%j-%x.out
|
||||
|
||||
OUTPUT_DIR="/fsx/${USER}/slurm/${SLURM_JOB_NAME}-${SLURM_JOB_ID}-tasks"
|
||||
mkdir -p $OUTPUT_DIR
|
||||
|
||||
# Function to run a task and redirect output to a separate file
|
||||
run_task() {
|
||||
local task_id=$1
|
||||
local output_file="${OUTPUT_DIR}/task-${task_id}-${SLURM_JOB_ID}.out"
|
||||
|
||||
# Run the task and redirect output
|
||||
python examples/port_datasets/openx_utils/test.py > $output_file 2>&1
|
||||
}
|
||||
|
||||
echo $SBATCH_OUTPUT
|
||||
|
||||
# node has 380850M and 96 cpus
|
||||
trap 'scontrol requeue ${SLURM_JOB_ID}; exit 15' SIGUSR1
|
||||
echo "Starting job"
|
||||
# note the "&" to start srun as a background thread
|
||||
srun python examples/port_datasets/openx_utils/test.py &
|
||||
# wait for signals...
|
||||
wait
|
|
@ -0,0 +1,54 @@
|
|||
import os
|
||||
|
||||
import psutil
|
||||
|
||||
|
||||
def display_system_info():
|
||||
# Get the number of CPUs
|
||||
num_cpus = psutil.cpu_count(logical=True)
|
||||
print(f"Number of CPUs: {num_cpus}")
|
||||
|
||||
# Get memory information
|
||||
memory_info = psutil.virtual_memory()
|
||||
total_memory = memory_info.total / (1024**3) # Convert bytes to GB
|
||||
available_memory = memory_info.available / (1024**3) # Convert bytes to GB
|
||||
used_memory = memory_info.used / (1024**3) # Convert bytes to GB
|
||||
|
||||
print(f"Total Memory: {total_memory:.2f} GB")
|
||||
print(f"Available Memory: {available_memory:.2f} GB")
|
||||
print(f"Used Memory: {used_memory:.2f} GB")
|
||||
|
||||
|
||||
def display_slurm_info():
|
||||
# Get SLURM job ID
|
||||
job_id = os.getenv("SLURM_JOB_ID")
|
||||
print(f"SLURM Job ID: {job_id}")
|
||||
|
||||
# Get SLURM job name
|
||||
job_name = os.getenv("SLURM_JOB_NAME")
|
||||
print(f"SLURM Job Name: {job_name}")
|
||||
|
||||
# Get the number of tasks
|
||||
num_tasks = os.getenv("SLURM_NTASKS")
|
||||
print(f"Number of Tasks: {num_tasks}")
|
||||
|
||||
# Get the number of nodes
|
||||
num_nodes = os.getenv("SLURM_NNODES")
|
||||
print(f"Number of Nodes: {num_nodes}")
|
||||
|
||||
# Get the number of CPUs per task
|
||||
cpus_per_task = os.getenv("SLURM_CPUS_PER_TASK")
|
||||
print(f"CPUs per Task: {cpus_per_task}")
|
||||
|
||||
# Get the node list
|
||||
node_list = os.getenv("SLURM_NODELIST")
|
||||
print(f"Node List: {node_list}")
|
||||
|
||||
# Get the task ID (only available within an srun task)
|
||||
task_id = os.getenv("SLURM_PROCID")
|
||||
print(f"Task ID: {task_id}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
display_system_info()
|
||||
display_slurm_info()
|
|
@ -2,7 +2,6 @@
|
|||
Copied from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/utils/data_utils.py
|
||||
"""
|
||||
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import tensorflow as tf
|
||||
|
@ -66,6 +65,7 @@ def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
|
|||
|
||||
return new_actions
|
||||
|
||||
|
||||
# === Bridge-V2 =>> Dataset-Specific Transform ===
|
||||
def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Relabels actions to use reached proprioceptive state; discards last timestep (no-action)."""
|
|
@ -19,7 +19,8 @@ Transforms adopt the following structure:
|
|||
from typing import Any, Dict
|
||||
|
||||
import tensorflow as tf
|
||||
from oxe_utils.transform_utils import (
|
||||
|
||||
from examples.port_datasets.openx_utils.transform_utils import (
|
||||
binarize_gripper_actions,
|
||||
invert_gripper_actions,
|
||||
rel2abs_gripper_actions,
|
||||
|
@ -31,6 +32,7 @@ def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
|||
"""
|
||||
DROID dataset transformation for actions expressed in *base* frame of the robot.
|
||||
"""
|
||||
|
||||
def rand_swap_exterior_images(img1, img2):
|
||||
"""
|
||||
Randomly swaps the two exterior images (for training with single exterior input).
|
||||
|
@ -55,11 +57,11 @@ def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
|||
)
|
||||
)
|
||||
# trajectory["observation"]["proprio"] = tf.concat(
|
||||
# (
|
||||
# trajectory["observation"]["cartesian_position"],
|
||||
# trajectory["observation"]["gripper_position"],
|
||||
# ),
|
||||
# axis=-1,
|
||||
# (
|
||||
# trajectory["observation"]["cartesian_position"],
|
||||
# trajectory["observation"]["gripper_position"],
|
||||
# ),
|
||||
# axis=-1,
|
||||
# )
|
||||
trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"]
|
||||
|
@ -96,7 +98,7 @@ def bridge_oxe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
|||
|
||||
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
|
||||
"""
|
||||
for key in trajectory.keys():
|
||||
for key in trajectory:
|
||||
if key == "traj_metadata":
|
||||
continue
|
||||
elif key in ["observation", "action"]:
|
||||
|
@ -126,7 +128,7 @@ def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
|||
|
||||
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
|
||||
"""
|
||||
for key in trajectory.keys():
|
||||
for key in trajectory:
|
||||
if key == "traj_metadata":
|
||||
continue
|
||||
elif key == "observation":
|
||||
|
@ -198,7 +200,9 @@ def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
|||
)
|
||||
eef_value = tf.io.decode_raw(eef_value, tf.float32)
|
||||
trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
|
||||
gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB")
|
||||
gripper_value = tf.io.decode_compressed(
|
||||
trajectory["observation"]["gripper_closed"], compression_type="ZLIB"
|
||||
)
|
||||
gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
|
||||
trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
|
||||
# trajectory["language_instruction"] = tf.fill(
|
||||
|
@ -228,7 +232,9 @@ def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
|||
|
||||
def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
|
||||
trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][:, -1:]
|
||||
trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][
|
||||
:, -1:
|
||||
]
|
||||
|
||||
# make gripper action absolute action, +1 = open, 0 = close
|
||||
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
|
||||
|
@ -264,7 +270,9 @@ def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict
|
|||
|
||||
def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# invert absolute gripper action, +1 = open, 0 = close
|
||||
gripper_action = invert_gripper_actions(tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1))
|
||||
gripper_action = invert_gripper_actions(
|
||||
tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)
|
||||
)
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
|
@ -374,7 +382,9 @@ def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, An
|
|||
instruction_bytes = trajectory["observation"]["instruction"]
|
||||
instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
|
||||
# Remove trailing padding --> convert RaggedTensor to regular Tensor.
|
||||
trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[:, 0]
|
||||
trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[
|
||||
:, 0
|
||||
]
|
||||
return trajectory
|
||||
|
||||
|
||||
|
@ -900,7 +910,9 @@ def libero_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
|||
axis=1,
|
||||
)
|
||||
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -2:] # 2D gripper state
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][
|
||||
:, -2:
|
||||
] # 2D gripper state
|
||||
return trajectory
|
||||
|
||||
|
Loading…
Reference in New Issue