commit 71d1f5e2c9 (parent b520941cd9)

    WIP
@@ -17,37 +17,35 @@
 For all datasets in the RLDS format.
 For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.

-NOTE: You need to install tensorflow and tensorflow_datasets before running this script.
+NOTE: Install `tensorflow` and `tensorflow_datasets` before running this script.
+
+```bash
+pip install tensorflow
+pip install tensorflow_datasets
+```

 Example:
-python openx_rlds.py \
-    --raw-dir /path/to/bridge_orig/1.0.0 \
-    --local-dir /path/to/local_dir \
-    --repo-id your_id \
-    --use-videos \
-    --push-to-hub
-
+```bash
+python examples/port_datasets/openx_rlds.py \
+    --raw-dir /fsx/mustafa_shukor/droid \
+    --repo-id cadene/droid \
+    --use-videos \
+    --push-to-hub
+```
 """

 import argparse
-import os
 import re
 import shutil
-import sys
-from functools import partial
 from pathlib import Path

 import numpy as np
 import tensorflow as tf
 import tensorflow_datasets as tfds
+import tqdm

-from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
-
-current_dir = os.path.dirname(os.path.abspath(__file__))
-oxe_utils_dir = os.path.join(current_dir, "oxe_utils")
-sys.path.append(oxe_utils_dir)
-
-from oxe_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
-from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
+from examples.port_datasets.openx_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
+from examples.port_datasets.openx_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

 np.set_printoptions(precision=2)
@@ -87,16 +85,23 @@ def transform_raw_dataset(episode, dataset_name):
     return episode


-def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bool = True):
-    dataset_name = builder.name
-
+def generate_features_from_raw(dataset_name: str, builder: tfds.core.DatasetBuilder, use_videos: bool = True):
     state_names = [f"motor_{i}" for i in range(8)]
     if dataset_name in OXE_DATASET_CONFIGS:
         state_encoding = OXE_DATASET_CONFIGS[dataset_name]["state_encoding"]
         if state_encoding == StateEncoding.POS_EULER:
             state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"]
             if "libero" in dataset_name:
-                state_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper", "gripper"]  # 2D gripper state
+                state_names = [
+                    "x",
+                    "y",
+                    "z",
+                    "roll",
+                    "pitch",
+                    "yaw",
+                    "gripper",
+                    "gripper",
+                ]  # 2D gripper state
         elif state_encoding == StateEncoding.POS_QUAT:
             state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
@@ -126,44 +131,68 @@ def generate_features_from_raw(dataset_name: str, builder: tfds.core.DatasetBuilder, use_videos: bool = True):
     return {**features, **DEFAULT_FEATURES}


-def save_as_lerobot_dataset(lerobot_dataset: LeRobotDataset, raw_dataset: tf.data.Dataset, **kwargs):
-    for episode in raw_dataset.as_numpy_iterator():
+def save_as_lerobot_dataset(
+    dataset_name: str,
+    lerobot_dataset: LeRobotDataset,
+    raw_dataset: tf.data.Dataset,
+    num_shards: int | None = None,
+    shard_index: int | None = None,
+):
+    total_num_episodes = raw_dataset.cardinality().numpy().item()
+    print(f"Total number of episodes {total_num_episodes}")
+
+    if num_shards is not None:
+        num_shards = 10000
+        shard_index = 9999
+        sharded_dataset = raw_dataset.shard(num_shards=num_shards, index=shard_index)
+        sharded_num_episodes = sharded_dataset.cardinality().numpy().item()
+        print(f"{sharded_num_episodes=}")
+        num_episodes = sharded_num_episodes
+        iter_ = iter(sharded_dataset)
+    else:
+        num_episodes = total_num_episodes
+        iter_ = iter(raw_dataset)
+
+    for episode_index in range(num_episodes):
+        print(f"{episode_index} / {num_episodes}")
+        episode = next(iter_)
+        print("\nnext\n")
+        episode = transform_raw_dataset(episode, dataset_name)
+
         traj = episode["steps"]
-        for i in range(traj["action"].shape[0]):
+        for i in tqdm.tqdm(range(traj["action"].shape[0])):
             image_dict = {
-                f"observation.images.{key}": value[i]
+                f"observation.images.{key}": value[i].numpy()
                 for key, value in traj["observation"].items()
                 if "depth" not in key and any(x in key for x in ["image", "rgb"])
             }
             lerobot_dataset.add_frame(
                 {
                     **image_dict,
-                    "observation.state": traj["proprio"][i],
-                    "action": traj["action"][i],
+                    "observation.state": traj["proprio"][i].numpy(),
+                    "action": traj["action"][i].numpy(),
+                    "task": traj["task"][i].numpy().decode(),
                 }
             )
-        lerobot_dataset.save_episode(task=traj["task"][0].decode())

-    lerobot_dataset.consolidate(
-        run_compute_stats=True,
-        keep_image_files=kwargs["keep_images"],
-        stat_kwargs={"batch_size": kwargs["batch_size"], "num_workers": kwargs["num_workers"]},
-    )
+        print()
+        lerobot_dataset.save_episode()
+        print("\nsave_episode\n")
+
+        break


 def create_lerobot_dataset(
     raw_dir: Path,
     repo_id: str = None,
-    local_dir: Path = None,
     push_to_hub: bool = False,
     fps: int = None,
     robot_type: str = None,
     use_videos: bool = True,
-    batch_size: int = 32,
-    num_workers: int = 8,
     image_writer_process: int = 5,
     image_writer_threads: int = 10,
-    keep_images: bool = True,
+    num_shards: int | None = None,
+    shard_index: int | None = None,
 ):
     last_part = raw_dir.name
     if re.match(r"^\d+\.\d+\.\d+$", last_part):
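An aside on the sharding pattern this hunk introduces (an illustration, not part of the commit): `tf.data.Dataset.shard(num_shards, index)` keeps every `num_shards`-th element starting at offset `index`, so workers with distinct indices process disjoint subsets of episodes. A minimal sketch:

```python
import tensorflow as tf

# Toy dataset standing in for the episode dataset above (hypothetical).
dataset = tf.data.Dataset.range(10)

# Each of 4 workers keeps the elements whose position modulo 4 equals its rank.
for rank in range(4):
    shard = dataset.shard(num_shards=4, index=rank)
    print(rank, list(shard.as_numpy_iterator()))
# rank 0 -> [0, 4, 8], rank 1 -> [1, 5, 9], rank 2 -> [2, 6], rank 3 -> [3, 7]
```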
@@ -175,15 +204,9 @@ def create_lerobot_dataset(
         dataset_name = last_part
         data_dir = raw_dir.parent

-    if local_dir is None:
-        local_dir = Path(LEROBOT_HOME)
-    local_dir /= f"{dataset_name}_{version}_lerobot"
-    if local_dir.exists():
-        shutil.rmtree(local_dir)
-
     builder = tfds.builder(dataset_name, data_dir=data_dir, version=version)
-    features = generate_features_from_raw(builder, use_videos)
-    raw_dataset = builder.as_dataset(split="train").map(partial(transform_raw_dataset, dataset_name=dataset_name))
+    features = generate_features_from_raw(dataset_name, builder, use_videos)
+    raw_dataset = builder.as_dataset(split="train")

     if fps is None:
         if dataset_name in OXE_DATASET_CONFIGS:
@@ -201,7 +224,6 @@ def create_lerobot_dataset(
     lerobot_dataset = LeRobotDataset.create(
         repo_id=repo_id,
         robot_type=robot_type,
-        root=local_dir,
         fps=fps,
         use_videos=use_videos,
         features=features,
@@ -210,16 +232,18 @@ def create_lerobot_dataset(
     )

     save_as_lerobot_dataset(
-        lerobot_dataset, raw_dataset, keep_images=keep_images, batch_size=batch_size, num_workers=num_workers
+        dataset_name,
+        lerobot_dataset,
+        raw_dataset,
+        num_shards=num_shards,
+        shard_index=shard_index,
     )

     if push_to_hub:
         assert repo_id is not None
-        tags = ["LeRobot", dataset_name, "rlds"]
+        tags = []
         if dataset_name in OXE_DATASET_CONFIGS:
             tags.append("openx")
-        if robot_type != "unknown":
-            tags.append(robot_type)
         lerobot_dataset.push_to_hub(
             tags=tags,
             private=False,
@@ -237,12 +261,6 @@ def main():
         required=True,
         help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version`).",
     )
-    parser.add_argument(
-        "--local-dir",
-        type=Path,
-        required=True,
-        help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
-    )
     parser.add_argument(
         "--repo-id",
         type=str,
@@ -270,37 +288,25 @@ def main():
         action="store_true",
         help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 times faster loading time during training.",
     )
-    parser.add_argument(
-        "--batch-size",
-        type=int,
-        default=32,
-        help="Batch size loaded by DataLoader for computing the dataset statistics.",
-    )
-    parser.add_argument(
-        "--num-workers",
-        type=int,
-        default=8,
-        help="Number of processes of Dataloader for computing the dataset statistics.",
-    )
     parser.add_argument(
         "--image-writer-process",
         type=int,
-        default=5,
+        default=0,
         help="Number of processes of image writer for saving images.",
     )
     parser.add_argument(
         "--image-writer-threads",
         type=int,
-        default=10,
+        default=8,
         help="Number of threads per process of image writer for saving images.",
     )
-    parser.add_argument(
-        "--keep-images",
-        action="store_true",
-        help="Whether to keep the cached images.",
-    )

     args = parser.parse_args()

+    droid_dir = Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene/droid")
+    if droid_dir.exists():
+        shutil.rmtree(droid_dir)
+
     create_lerobot_dataset(**vars(args))
@@ -0,0 +1,106 @@
+import datetime as dt
+import shutil
+from pathlib import Path
+
+from datatrove.executor import LocalPipelineExecutor
+from datatrove.executor.slurm import SlurmPipelineExecutor
+from datatrove.pipeline.base import PipelineStep
+
+
+class PortOpenXDataset(PipelineStep):
+    def __init__(
+        self,
+        raw_dir: Path,
+        repo_id: str = None,
+        image_writer_process: int = 0,
+        image_writer_threads: int = 8,
+    ):
+        super().__init__()
+        self.raw_dir = raw_dir
+        self.repo_id = repo_id
+        self.image_writer_process = image_writer_process
+        self.image_writer_threads = image_writer_threads
+
+    def run(self, data=None, rank: int = 0, world_size: int = 1):
+        from examples.port_datasets.openx_rlds import create_lerobot_dataset
+        from examples.port_datasets.openx_utils.test import display_slurm_info, display_system_info
+
+        display_system_info()
+        display_slurm_info()
+
+        create_lerobot_dataset(
+            self.raw_dir,
+            f"{self.repo_id}_world_{world_size}_rank_{rank}",
+            image_writer_process=self.image_writer_process,
+            image_writer_threads=self.image_writer_threads,
+            push_to_hub=False,
+            num_shards=world_size,
+            shard_index=rank,
+        )
+
+
+class AggregateDatasets(PipelineStep):
+    def run(self, data=None, rank: int = 0, world_size: int = 1):
+        print("aggregation")
+
+
+def main(slurm=True):
+    for dir_ in Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene").glob("droid_world*"):
+        shutil.rmtree(dir_)
+
+    now = dt.datetime.now()
+    port_job_name = "port_openx_droid"
+    logs_dir = Path("/fsx/remi_cadene/logs")
+    port_log_dir = logs_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_{port_job_name}"
+
+    if slurm:
+        executor_class = SlurmPipelineExecutor
+        dist_extra_kwargs = {
+            "job_name": port_job_name,
+            "tasks": 10000,
+            "workers": 24,
+            "time": "00:30:00",
+            "partition": "hopper-cpu",
+            "cpus_per_task": 12,
+            "mem_per_cpu_gb": 4,
+        }
+    else:
+        executor_class = LocalPipelineExecutor
+        dist_extra_kwargs = {
+            "tasks": 1,
+            "workers": 1,
+        }
+
+    port_executor = executor_class(
+        pipeline=[
+            PortOpenXDataset(raw_dir=Path("/fsx/mustafa_shukor/droid"), repo_id="cadene/droid"),
+        ],
+        logging_dir=str(port_log_dir),
+        **dist_extra_kwargs,
+    )
+    port_executor.run()
+
+    # if slurm:
+    #     merge_extra_kwargs = {}
+    # else:
+    #     merge_extra_kwargs = {
+    #         "job_name": "aggregate",
+    #         "time": "00:01:00",
+    #         "partition": "hopper-cpu",
+    #     }
+
+    # merge_executor = executor_class(
+    #     depends=dist_executor,
+    #     pipeline=[
+    #         Aggregate(),
+    #     ],
+    #     logging_dir=f"/fsx/remi_cadene/logs/openx_rlds_merge",
+    #     tasks=1,
+    #     workers=1,
+    #     **merge_extra_kwargs,
+    # )
+    # merge_executor.run()
+
+
+if __name__ == "__main__":
+    main()
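For context on the executor pattern in this new file (a sketch under the assumption that only `pipeline`, `tasks`, `workers`, and `logging_dir` are needed; `HelloStep` is hypothetical): datatrove invokes each pipeline task's `run` with its own `rank` and the total `world_size`, which `PortOpenXDataset.run` above forwards as `shard_index` and `num_shards`:

```python
from datatrove.executor import LocalPipelineExecutor
from datatrove.pipeline.base import PipelineStep


class HelloStep(PipelineStep):
    # Hypothetical step: each task receives its rank and the world size,
    # mirroring how PortOpenXDataset maps them to shard_index/num_shards.
    def run(self, data=None, rank: int = 0, world_size: int = 1):
        print(f"task {rank} of {world_size}")


if __name__ == "__main__":
    LocalPipelineExecutor(pipeline=[HelloStep()], tasks=4, workers=2, logging_dir="logs").run()
```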
@@ -56,7 +56,9 @@ def zero_action_filter(traj: Dict) -> bool:
             0.8897542208433151,
         ]
     )
-    DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1
+    DROID_NORM_0_ACT = (
+        2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1
+    )

     return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5)
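A worked aside on the normalization in the hunk above (hypothetical bounds, not from the commit): `2 * (x - q01) / (q99 - q01 + 1e-8) - 1` maps the 1st-percentile action to -1 and the 99th-percentile action to +1, so `DROID_NORM_0_ACT` is simply the all-zero action expressed in those normalized coordinates:

```python
import numpy as np

q01, q99 = np.array([-0.3]), np.array([0.5])  # hypothetical per-dimension bounds

def normalize(x):
    return 2 * (x - q01) / (q99 - q01 + 1e-8) - 1

print(normalize(q01))          # [-1.]  lower bound maps to -1
print(normalize(q99))          # [ 1.]  upper bound maps to +1
print(normalize(np.zeros(1)))  # the "zero action" the filter compares against
```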
@@ -799,7 +801,11 @@ OXE_DATASET_CONFIGS = {
     },
     ### DROID Finetuning datasets
     "droid_wipe": {
-        "image_obs_keys": {"primary": "exterior_image_2_left", "secondary": None, "wrist": "wrist_image_left"},
+        "image_obs_keys": {
+            "primary": "exterior_image_2_left",
+            "secondary": None,
+            "wrist": "wrist_image_left",
+        },
         "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
         "state_obs_keys": ["proprio"],
         "state_encoding": StateEncoding.POS_EULER,
@@ -0,0 +1,30 @@
+#!/bin/bash
+#SBATCH --job-name=openx_rlds
+#SBATCH --partition=hopper-cpu
+#SBATCH --requeue
+#SBATCH --time=00:01:00
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8
+#SBATCH --output=/fsx/%u/slurm/%j-%x.out
+
+OUTPUT_DIR="/fsx/${USER}/slurm/${SLURM_JOB_NAME}-${SLURM_JOB_ID}-tasks"
+mkdir -p $OUTPUT_DIR
+
+# Function to run a task and redirect output to a separate file
+run_task() {
+    local task_id=$1
+    local output_file="${OUTPUT_DIR}/task-${task_id}-${SLURM_JOB_ID}.out"
+
+    # Run the task and redirect output
+    python examples/port_datasets/openx_utils/test.py > $output_file 2>&1
+}
+
+echo $SBATCH_OUTPUT
+
+# node has 380850M and 96 cpus
+trap 'scontrol requeue ${SLURM_JOB_ID}; exit 15' SIGUSR1
+echo "Starting job"
+# note the "&" to start srun as a background thread
+srun python examples/port_datasets/openx_utils/test.py &
+# wait for signals...
+wait
@@ -0,0 +1,54 @@
+import os
+
+import psutil
+
+
+def display_system_info():
+    # Get the number of CPUs
+    num_cpus = psutil.cpu_count(logical=True)
+    print(f"Number of CPUs: {num_cpus}")
+
+    # Get memory information
+    memory_info = psutil.virtual_memory()
+    total_memory = memory_info.total / (1024**3)  # Convert bytes to GB
+    available_memory = memory_info.available / (1024**3)  # Convert bytes to GB
+    used_memory = memory_info.used / (1024**3)  # Convert bytes to GB
+
+    print(f"Total Memory: {total_memory:.2f} GB")
+    print(f"Available Memory: {available_memory:.2f} GB")
+    print(f"Used Memory: {used_memory:.2f} GB")
+
+
+def display_slurm_info():
+    # Get SLURM job ID
+    job_id = os.getenv("SLURM_JOB_ID")
+    print(f"SLURM Job ID: {job_id}")
+
+    # Get SLURM job name
+    job_name = os.getenv("SLURM_JOB_NAME")
+    print(f"SLURM Job Name: {job_name}")
+
+    # Get the number of tasks
+    num_tasks = os.getenv("SLURM_NTASKS")
+    print(f"Number of Tasks: {num_tasks}")
+
+    # Get the number of nodes
+    num_nodes = os.getenv("SLURM_NNODES")
+    print(f"Number of Nodes: {num_nodes}")
+
+    # Get the number of CPUs per task
+    cpus_per_task = os.getenv("SLURM_CPUS_PER_TASK")
+    print(f"CPUs per Task: {cpus_per_task}")
+
+    # Get the node list
+    node_list = os.getenv("SLURM_NODELIST")
+    print(f"Node List: {node_list}")
+
+    # Get the task ID (only available within an srun task)
+    task_id = os.getenv("SLURM_PROCID")
+    print(f"Task ID: {task_id}")
+
+
+if __name__ == "__main__":
+    display_system_info()
+    display_slurm_info()
@@ -2,7 +2,6 @@
 Copied from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/utils/data_utils.py
 """
-
 from typing import Any, Dict

 import tensorflow as tf
@@ -66,6 +65,7 @@ def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
     return new_actions


+
 # === Bridge-V2 =>> Dataset-Specific Transform ===
 def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
     """Relabels actions to use reached proprioceptive state; discards last timestep (no-action)."""
@@ -19,7 +19,8 @@ Transforms adopt the following structure:
 from typing import Any, Dict

 import tensorflow as tf
-from oxe_utils.transform_utils import (
+
+from examples.port_datasets.openx_utils.transform_utils import (
     binarize_gripper_actions,
     invert_gripper_actions,
     rel2abs_gripper_actions,
@@ -31,6 +32,7 @@ def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
     """
     DROID dataset transformation for actions expressed in *base* frame of the robot.
     """
+
     def rand_swap_exterior_images(img1, img2):
         """
         Randomly swaps the two exterior images (for training with single exterior input).
@@ -96,7 +98,7 @@ def bridge_oxe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:

     Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
     """
-    for key in trajectory.keys():
+    for key in trajectory:
         if key == "traj_metadata":
             continue
         elif key in ["observation", "action"]:
@@ -126,7 +128,7 @@ def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:

     Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
     """
-    for key in trajectory.keys():
+    for key in trajectory:
         if key == "traj_metadata":
             continue
         elif key == "observation":
@@ -198,7 +200,9 @@ def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
     )
     eef_value = tf.io.decode_raw(eef_value, tf.float32)
     trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
-    gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB")
+    gripper_value = tf.io.decode_compressed(
+        trajectory["observation"]["gripper_closed"], compression_type="ZLIB"
+    )
     gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
     trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
     # trajectory["language_instruction"] = tf.fill(
@@ -228,7 +232,9 @@ def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:

 def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
     trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
-    trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][:, -1:]
+    trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][
+        :, -1:
+    ]

     # make gripper action absolute action, +1 = open, 0 = close
     gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
@@ -264,7 +270,9 @@ def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:

 def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
     # invert absolute gripper action, +1 = open, 0 = close
-    gripper_action = invert_gripper_actions(tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1))
+    gripper_action = invert_gripper_actions(
+        tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)
+    )

     trajectory["action"] = tf.concat(
         (
@@ -374,7 +382,9 @@ def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
     instruction_bytes = trajectory["observation"]["instruction"]
     instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
     # Remove trailing padding --> convert RaggedTensor to regular Tensor.
-    trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[:, 0]
+    trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[
+        :, 0
+    ]
     return trajectory
@@ -900,7 +910,9 @@ def libero_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
         axis=1,
     )
     trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
-    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -2:]  # 2D gripper state
+    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][
+        :, -2:
+    ]  # 2D gripper state
     return trajectory