This commit is contained in:
Remi Cadene 2025-02-22 11:12:39 +00:00
parent 689c5efc72
commit 39ad2d16d4
4 changed files with 42 additions and 17 deletions

View File

@ -34,18 +34,20 @@ python examples/port_datasets/openx_rlds.py \
""" """
import argparse import argparse
import logging
import re import re
import shutil import shutil
import time
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import tensorflow_datasets as tfds import tensorflow_datasets as tfds
import tqdm
from examples.port_datasets.openx_utils.configs import OXE_DATASET_CONFIGS, StateEncoding from examples.port_datasets.openx_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
from examples.port_datasets.openx_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS from examples.port_datasets.openx_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
np.set_printoptions(precision=2) np.set_printoptions(precision=2)
@ -138,13 +140,14 @@ def save_as_lerobot_dataset(
num_shards: int | None = None, num_shards: int | None = None,
shard_index: int | None = None, shard_index: int | None = None,
): ):
start_time = time.time()
total_num_episodes = raw_dataset.cardinality().numpy().item() total_num_episodes = raw_dataset.cardinality().numpy().item()
print(f"Total number of episodes {total_num_episodes}") logging.info(f"Total number of episodes {total_num_episodes}")
if num_shards is not None: if num_shards is not None:
sharded_dataset = raw_dataset.shard(num_shards=num_shards, index=shard_index) sharded_dataset = raw_dataset.shard(num_shards=num_shards, index=shard_index)
sharded_num_episodes = sharded_dataset.cardinality().numpy().item() sharded_num_episodes = sharded_dataset.cardinality().numpy().item()
print(f"{sharded_num_episodes=}") logging.info(f"{sharded_num_episodes=}")
num_episodes = sharded_num_episodes num_episodes = sharded_num_episodes
iter_ = iter(sharded_dataset) iter_ = iter(sharded_dataset)
else: else:
@ -155,13 +158,18 @@ def save_as_lerobot_dataset(
raise ValueError(f"Number of episodes is {num_episodes}, but needs to be positive.") raise ValueError(f"Number of episodes is {num_episodes}, but needs to be positive.")
for episode_index in range(num_episodes): for episode_index in range(num_episodes):
print(f"{episode_index} / {num_episodes}") logging.info(f"{episode_index} / {num_episodes} episodes processed")
elapsed_time = time.time() - start_time
d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)
logging.info(f"It has been {d} days, {h} hours, {m} minutes, {s:.3f} seconds")
episode = next(iter_) episode = next(iter_)
print("\nnext\n") logging.info("next")
episode = transform_raw_dataset(episode, dataset_name) episode = transform_raw_dataset(episode, dataset_name)
traj = episode["steps"] traj = episode["steps"]
for i in tqdm.tqdm(range(traj["action"].shape[0])): for i in range(traj["action"].shape[0]):
image_dict = { image_dict = {
f"observation.images.{key}": value[i].numpy() f"observation.images.{key}": value[i].numpy()
for key, value in traj["observation"].items() for key, value in traj["observation"].items()
@ -176,9 +184,8 @@ def save_as_lerobot_dataset(
} }
) )
print()
lerobot_dataset.save_episode() lerobot_dataset.save_episode()
print("\nsave_episode\n") logging.info("save_episode")
def create_lerobot_dataset( def create_lerobot_dataset(

View File

@ -21,8 +21,14 @@ class PortOpenXDataset(PipelineStep):
self.image_writer_threads = image_writer_threads self.image_writer_threads = image_writer_threads
def run(self, data=None, rank: int = 0, world_size: int = 1): def run(self, data=None, rank: int = 0, world_size: int = 1):
from datasets.utils.tqdm import disable_progress_bars
from examples.port_datasets.openx_rlds import create_lerobot_dataset from examples.port_datasets.openx_rlds import create_lerobot_dataset
from examples.port_datasets.openx_utils.test import display_slurm_info, display_system_info from examples.port_datasets.openx_utils.test import display_slurm_info, display_system_info
from lerobot.common.utils.utils import init_logging
init_logging()
disable_progress_bars()
display_system_info() display_system_info()
display_slurm_info() display_slurm_info()
@ -48,18 +54,22 @@ def main(slurm=True):
# for dir_ in Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene").glob("droid_world*"): # for dir_ in Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene").glob("droid_world*"):
# shutil.rmtree(dir_) # shutil.rmtree(dir_)
now = dt.datetime.now()
port_job_name = "port_openx_droid" port_job_name = "port_openx_droid"
logs_dir = Path("/fsx/remi_cadene/logs") logs_dir = Path("/fsx/remi_cadene/logs")
# port_log_dir = logs_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_{port_job_name}"
port_log_dir = Path("/fsx/remi_cadene/logs/2025-02-22_00-12-00_port_openx_droid") now = dt.datetime.now()
datetime = f"{now:%Y-%m-%d}_{now:%H-%M-%S}"
# datetime = "2025-02-22_11-17-00"
port_log_dir = logs_dir / f"{datetime}_{port_job_name}"
if slurm: if slurm:
executor_class = SlurmPipelineExecutor executor_class = SlurmPipelineExecutor
dist_extra_kwargs = { dist_extra_kwargs = {
"job_name": port_job_name, "job_name": port_job_name,
"tasks": 10000, "tasks": 2048,
"workers": 20, # 8 * 16, # "workers": 20, # 8 * 16,
"workers": 1, # 8 * 16,
"time": "08:00:00", "time": "08:00:00",
"partition": "hopper-cpu", "partition": "hopper-cpu",
"cpus_per_task": 24, "cpus_per_task": 24,
@ -75,9 +85,7 @@ def main(slurm=True):
port_executor = executor_class( port_executor = executor_class(
pipeline=[ pipeline=[
PortOpenXDataset( PortOpenXDataset(raw_dir=Path("/fsx/mustafa_shukor/droid"), repo_id=f"cadene/droid_{datetime}"),
raw_dir=Path("/fsx/mustafa_shukor/droid"), repo_id="cadene/droid_2025-02-22_00-12-00"
),
], ],
logging_dir=str(port_log_dir), logging_dir=str(port_log_dir),
**dist_extra_kwargs, **dist_extra_kwargs,

View File

@ -136,7 +136,7 @@ def encode_video_frames(
g: int | None = 2, g: int | None = 2,
crf: int | None = 30, crf: int | None = 30,
fast_decode: int = 0, fast_decode: int = 0,
log_level: str | None = "error", log_level: str | None = "quiet",
overwrite: bool = False, overwrite: bool = False,
) -> None: ) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`""" """More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""

View File

@ -216,3 +216,13 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
except TypeError: except TypeError:
# If a TypeError is raised, the string is not a valid dtype # If a TypeError is raised, the string is not a valid dtype
return False return False
def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float):
    """Decompose a duration in seconds into (days, hours, minutes, seconds).

    Days, hours and minutes are returned as ints; the trailing seconds keep
    their fractional part (same type/precision as the `%` of the input).
    """
    days, remainder = divmod(elapsed_time_s, 24 * 3600)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)
    return int(days), int(hours), int(minutes), seconds