From 35de91ef2bed8d25ef6aa40e6ff8514a39666436 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 30 Dec 2024 13:47:28 +0000 Subject: [PATCH] added temporary fix for missing task_index key in online environment --- lerobot/common/policies/sac/configuration_sac.py | 1 + lerobot/scripts/train.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 5f676933..4ae6e5d4 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -50,6 +50,7 @@ class SACConfig: state_encoder_hidden_dim = 256 latent_dim = 128 target_entropy = None + backup_entropy = True critic_network_kwargs = { "hidden_dims": [256, 256], "activate_final": True, diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index fbe7927d..a4eb3528 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -322,6 +322,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No logging.info("make_dataset") offline_dataset = make_dataset(cfg) + # TODO (michel-aractingi): temporary fix to avoid datasets with task_index key that doesn't exist in online environment + # i.e., pusht + if "task_index" in offline_dataset.hf_dataset[0]: + offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(["task_index"]) + if isinstance(offline_dataset, MultiLeRobotDataset): logging.info( "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "