optimize shard
This commit is contained in:
parent
eda0b996cd
commit
689c5efc72
|
@ -205,7 +205,16 @@ def create_lerobot_dataset(
|
|||
|
||||
builder = tfds.builder(dataset_name, data_dir=data_dir, version=version)
|
||||
features = generate_features_from_raw(dataset_name, builder, use_videos)
|
||||
raw_dataset = builder.as_dataset(split="train")
|
||||
|
||||
if num_shards is not None:
|
||||
if num_shards != builder.info.splits["train"].num_shards:
|
||||
raise ValueError()
|
||||
if shard_index >= builder.info.splits["train"].num_shards:
|
||||
raise ValueError()
|
||||
|
||||
raw_dataset = builder.as_dataset(split=f"train[{shard_index}shard]")
|
||||
else:
|
||||
raw_dataset = builder.as_dataset(split="train")
|
||||
|
||||
if fps is None:
|
||||
if dataset_name in OXE_DATASET_CONFIGS:
|
||||
|
@ -234,8 +243,6 @@ def create_lerobot_dataset(
|
|||
dataset_name,
|
||||
lerobot_dataset,
|
||||
raw_dataset,
|
||||
num_shards=num_shards,
|
||||
shard_index=shard_index,
|
||||
)
|
||||
|
||||
if push_to_hub:
|
||||
|
|
|
@ -29,7 +29,7 @@ class PortOpenXDataset(PipelineStep):
|
|||
|
||||
create_lerobot_dataset(
|
||||
self.raw_dir,
|
||||
f"{self.repo_id}_2025-02-22_00-12-00_world_{world_size}_rank_{rank}",
|
||||
f"{self.repo_id}_world_{world_size}_rank_{rank}",
|
||||
image_writer_process=self.image_writer_process,
|
||||
image_writer_threads=self.image_writer_threads,
|
||||
push_to_hub=False,
|
||||
|
@ -64,7 +64,7 @@ def main(slurm=True):
|
|||
"partition": "hopper-cpu",
|
||||
"cpus_per_task": 24,
|
||||
"mem_per_cpu_gb": 2,
|
||||
"max_array_launch_parallel": True,
|
||||
# "max_array_launch_parallel": True,
|
||||
}
|
||||
else:
|
||||
executor_class = LocalPipelineExecutor
|
||||
|
@ -75,7 +75,9 @@ def main(slurm=True):
|
|||
|
||||
port_executor = executor_class(
|
||||
pipeline=[
|
||||
PortOpenXDataset(raw_dir=Path("/fsx/mustafa_shukor/droid"), repo_id="cadene/droid"),
|
||||
PortOpenXDataset(
|
||||
raw_dir=Path("/fsx/mustafa_shukor/droid"), repo_id="cadene/droid_2025-02-22_00-12-00"
|
||||
),
|
||||
],
|
||||
logging_dir=str(port_log_dir),
|
||||
**dist_extra_kwargs,
|
||||
|
|
Loading…
Reference in New Issue