let's go
This commit is contained in:
parent
689c5efc72
commit
39ad2d16d4
|
@ -34,18 +34,20 @@ python examples/port_datasets/openx_rlds.py \
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import tensorflow_datasets as tfds
|
import tensorflow_datasets as tfds
|
||||||
import tqdm
|
|
||||||
|
|
||||||
from examples.port_datasets.openx_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
|
from examples.port_datasets.openx_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
|
||||||
from examples.port_datasets.openx_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
|
from examples.port_datasets.openx_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.common.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
|
||||||
|
|
||||||
np.set_printoptions(precision=2)
|
np.set_printoptions(precision=2)
|
||||||
|
|
||||||
|
@ -138,13 +140,14 @@ def save_as_lerobot_dataset(
|
||||||
num_shards: int | None = None,
|
num_shards: int | None = None,
|
||||||
shard_index: int | None = None,
|
shard_index: int | None = None,
|
||||||
):
|
):
|
||||||
|
start_time = time.time()
|
||||||
total_num_episodes = raw_dataset.cardinality().numpy().item()
|
total_num_episodes = raw_dataset.cardinality().numpy().item()
|
||||||
print(f"Total number of episodes {total_num_episodes}")
|
logging.info(f"Total number of episodes {total_num_episodes}")
|
||||||
|
|
||||||
if num_shards is not None:
|
if num_shards is not None:
|
||||||
sharded_dataset = raw_dataset.shard(num_shards=num_shards, index=shard_index)
|
sharded_dataset = raw_dataset.shard(num_shards=num_shards, index=shard_index)
|
||||||
sharded_num_episodes = sharded_dataset.cardinality().numpy().item()
|
sharded_num_episodes = sharded_dataset.cardinality().numpy().item()
|
||||||
print(f"{sharded_num_episodes=}")
|
logging.info(f"{sharded_num_episodes=}")
|
||||||
num_episodes = sharded_num_episodes
|
num_episodes = sharded_num_episodes
|
||||||
iter_ = iter(sharded_dataset)
|
iter_ = iter(sharded_dataset)
|
||||||
else:
|
else:
|
||||||
|
@ -155,13 +158,18 @@ def save_as_lerobot_dataset(
|
||||||
raise ValueError(f"Number of episodes is {num_episodes}, but needs to be positive.")
|
raise ValueError(f"Number of episodes is {num_episodes}, but needs to be positive.")
|
||||||
|
|
||||||
for episode_index in range(num_episodes):
|
for episode_index in range(num_episodes):
|
||||||
print(f"{episode_index} / {num_episodes}")
|
logging.info(f"{episode_index} / {num_episodes} episodes processed")
|
||||||
|
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)
|
||||||
|
logging.info(f"It has been {d} days, {h} hours, {m} minutes, {s:.3f} seconds")
|
||||||
|
|
||||||
episode = next(iter_)
|
episode = next(iter_)
|
||||||
print("\nnext\n")
|
logging.info("next")
|
||||||
episode = transform_raw_dataset(episode, dataset_name)
|
episode = transform_raw_dataset(episode, dataset_name)
|
||||||
|
|
||||||
traj = episode["steps"]
|
traj = episode["steps"]
|
||||||
for i in tqdm.tqdm(range(traj["action"].shape[0])):
|
for i in range(traj["action"].shape[0]):
|
||||||
image_dict = {
|
image_dict = {
|
||||||
f"observation.images.{key}": value[i].numpy()
|
f"observation.images.{key}": value[i].numpy()
|
||||||
for key, value in traj["observation"].items()
|
for key, value in traj["observation"].items()
|
||||||
|
@ -176,9 +184,8 @@ def save_as_lerobot_dataset(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
print()
|
|
||||||
lerobot_dataset.save_episode()
|
lerobot_dataset.save_episode()
|
||||||
print("\nsave_episode\n")
|
logging.info("save_episode")
|
||||||
|
|
||||||
|
|
||||||
def create_lerobot_dataset(
|
def create_lerobot_dataset(
|
||||||
|
|
|
@ -21,8 +21,14 @@ class PortOpenXDataset(PipelineStep):
|
||||||
self.image_writer_threads = image_writer_threads
|
self.image_writer_threads = image_writer_threads
|
||||||
|
|
||||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||||
|
from datasets.utils.tqdm import disable_progress_bars
|
||||||
|
|
||||||
from examples.port_datasets.openx_rlds import create_lerobot_dataset
|
from examples.port_datasets.openx_rlds import create_lerobot_dataset
|
||||||
from examples.port_datasets.openx_utils.test import display_slurm_info, display_system_info
|
from examples.port_datasets.openx_utils.test import display_slurm_info, display_system_info
|
||||||
|
from lerobot.common.utils.utils import init_logging
|
||||||
|
|
||||||
|
init_logging()
|
||||||
|
disable_progress_bars()
|
||||||
|
|
||||||
display_system_info()
|
display_system_info()
|
||||||
display_slurm_info()
|
display_slurm_info()
|
||||||
|
@ -48,18 +54,22 @@ def main(slurm=True):
|
||||||
# for dir_ in Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene").glob("droid_world*"):
|
# for dir_ in Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene").glob("droid_world*"):
|
||||||
# shutil.rmtree(dir_)
|
# shutil.rmtree(dir_)
|
||||||
|
|
||||||
now = dt.datetime.now()
|
|
||||||
port_job_name = "port_openx_droid"
|
port_job_name = "port_openx_droid"
|
||||||
logs_dir = Path("/fsx/remi_cadene/logs")
|
logs_dir = Path("/fsx/remi_cadene/logs")
|
||||||
# port_log_dir = logs_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_{port_job_name}"
|
|
||||||
port_log_dir = Path("/fsx/remi_cadene/logs/2025-02-22_00-12-00_port_openx_droid")
|
now = dt.datetime.now()
|
||||||
|
datetime = f"{now:%Y-%m-%d}_{now:%H-%M-%S}"
|
||||||
|
# datetime = "2025-02-22_11-17-00"
|
||||||
|
|
||||||
|
port_log_dir = logs_dir / f"{datetime}_{port_job_name}"
|
||||||
|
|
||||||
if slurm:
|
if slurm:
|
||||||
executor_class = SlurmPipelineExecutor
|
executor_class = SlurmPipelineExecutor
|
||||||
dist_extra_kwargs = {
|
dist_extra_kwargs = {
|
||||||
"job_name": port_job_name,
|
"job_name": port_job_name,
|
||||||
"tasks": 10000,
|
"tasks": 2048,
|
||||||
"workers": 20, # 8 * 16,
|
# "workers": 20, # 8 * 16,
|
||||||
|
"workers": 1, # 8 * 16,
|
||||||
"time": "08:00:00",
|
"time": "08:00:00",
|
||||||
"partition": "hopper-cpu",
|
"partition": "hopper-cpu",
|
||||||
"cpus_per_task": 24,
|
"cpus_per_task": 24,
|
||||||
|
@ -75,9 +85,7 @@ def main(slurm=True):
|
||||||
|
|
||||||
port_executor = executor_class(
|
port_executor = executor_class(
|
||||||
pipeline=[
|
pipeline=[
|
||||||
PortOpenXDataset(
|
PortOpenXDataset(raw_dir=Path("/fsx/mustafa_shukor/droid"), repo_id=f"cadene/droid_{datetime}"),
|
||||||
raw_dir=Path("/fsx/mustafa_shukor/droid"), repo_id="cadene/droid_2025-02-22_00-12-00"
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
logging_dir=str(port_log_dir),
|
logging_dir=str(port_log_dir),
|
||||||
**dist_extra_kwargs,
|
**dist_extra_kwargs,
|
||||||
|
|
|
@ -136,7 +136,7 @@ def encode_video_frames(
|
||||||
g: int | None = 2,
|
g: int | None = 2,
|
||||||
crf: int | None = 30,
|
crf: int | None = 30,
|
||||||
fast_decode: int = 0,
|
fast_decode: int = 0,
|
||||||
log_level: str | None = "error",
|
log_level: str | None = "quiet",
|
||||||
overwrite: bool = False,
|
overwrite: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||||
|
|
|
@ -216,3 +216,13 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# If a TypeError is raised, the string is not a valid dtype
|
# If a TypeError is raised, the string is not a valid dtype
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_elapsed_time_in_days_hours_minutes_seconds(
    elapsed_time_s: float,
) -> tuple[int, int, int, float]:
    """Split a duration in seconds into days, hours, minutes and seconds.

    Args:
        elapsed_time_s: Duration expressed in seconds.

    Returns:
        A ``(days, hours, minutes, seconds)`` tuple. ``days``, ``hours`` and
        ``minutes`` are whole numbers converted with ``int``; ``seconds`` is
        the remainder left after removing whole minutes, so it keeps any
        fractional part of the input.
    """
    # Peel off the largest unit first; each divmod hands the remainder
    # down to the next smaller unit.
    days, remainder = divmod(elapsed_time_s, 24 * 3600)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)
    return int(days), int(hours), int(minutes), seconds
|
||||||
|
|
Loading…
Reference in New Issue