added temporary fix for missing task_index key in online environment

This commit is contained in:
Michel Aractingi 2024-12-30 13:47:28 +00:00
parent 41b377211c
commit 13441f0d98
1 changed files with 5 additions and 0 deletions

View File

@ -322,6 +322,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_dataset")
offline_dataset = make_dataset(cfg)
# TODO (michel-aractingi): temporary fix to avoid datasets with task_index key that doesn't exist in online environment
# i.e., pusht
if "task_index" in offline_dataset.hf_dataset[0]:
offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(["task_index"])
if isinstance(offline_dataset, MultiLeRobotDataset):
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "