diff --git a/examples/port_datasets/openx_rlds.py b/examples/port_datasets/openx_rlds.py
index 99ff6d4d..c57a0c52 100644
--- a/examples/port_datasets/openx_rlds.py
+++ b/examples/port_datasets/openx_rlds.py
@@ -205,7 +205,26 @@ def create_lerobot_dataset(
 
     builder = tfds.builder(dataset_name, data_dir=data_dir, version=version)
     features = generate_features_from_raw(dataset_name, builder, use_videos)
-    raw_dataset = builder.as_dataset(split="train")
+
+    if num_shards is not None:
+        # The split below selects a single TFDS shard by index, so the requested layout
+        # must match the shard count recorded in the builder metadata.
+        num_train_shards = builder.info.splits["train"].num_shards
+        if shard_index is None:
+            raise ValueError("`shard_index` must be provided when `num_shards` is set.")
+        if num_shards != num_train_shards:
+            raise ValueError(
+                f"`num_shards` ({num_shards}) must equal the number of shards in the 'train' split "
+                f"({num_train_shards})."
+            )
+        if not 0 <= shard_index < num_train_shards:
+            raise ValueError(
+                f"`shard_index` ({shard_index}) must be in the range [0, {num_train_shards})."
+            )
+
+        raw_dataset = builder.as_dataset(split=f"train[{shard_index}shard]")
+    else:
+        raw_dataset = builder.as_dataset(split="train")
 
     if fps is None:
         if dataset_name in OXE_DATASET_CONFIGS:
@@ -234,8 +253,6 @@ def create_lerobot_dataset(
         dataset_name,
         lerobot_dataset,
         raw_dataset,
-        num_shards=num_shards,
-        shard_index=shard_index,
     )
 
     if push_to_hub:
diff --git a/examples/port_datasets/openx_rlds_datatrove.py b/examples/port_datasets/openx_rlds_datatrove.py
index 9ecbec59..32ccaa1a 100644
--- a/examples/port_datasets/openx_rlds_datatrove.py
+++ b/examples/port_datasets/openx_rlds_datatrove.py
@@ -29,7 +29,7 @@ class PortOpenXDataset(PipelineStep):
 
         create_lerobot_dataset(
             self.raw_dir,
-            f"{self.repo_id}_2025-02-22_00-12-00_world_{world_size}_rank_{rank}",
+            f"{self.repo_id}_world_{world_size}_rank_{rank}",
             image_writer_process=self.image_writer_process,
             image_writer_threads=self.image_writer_threads,
             push_to_hub=False,
@@ -64,7 +64,7 @@ def main(slurm=True):
             "partition": "hopper-cpu",
             "cpus_per_task": 24,
             "mem_per_cpu_gb": 2,
-            "max_array_launch_parallel": True,
+            # "max_array_launch_parallel": True,
         }
     else:
         executor_class = LocalPipelineExecutor
@@ -75,7 +75,9 @@ def main(slurm=True):
 
     port_executor = executor_class(
         pipeline=[
-            PortOpenXDataset(raw_dir=Path("/fsx/mustafa_shukor/droid"), repo_id="cadene/droid"),
+            PortOpenXDataset(
+                raw_dir=Path("/fsx/mustafa_shukor/droid"), repo_id="cadene/droid_2025-02-22_00-12-00"
+            ),
         ],
         logging_dir=str(port_log_dir),
         **dist_extra_kwargs,