added temporary fix for missing task_index key in online environment
This commit is contained in:
parent
ee306e2f9b
commit
35de91ef2b
|
@ -50,6 +50,7 @@ class SACConfig:
|
||||||
state_encoder_hidden_dim = 256
|
state_encoder_hidden_dim = 256
|
||||||
latent_dim = 128
|
latent_dim = 128
|
||||||
target_entropy = None
|
target_entropy = None
|
||||||
|
backup_entropy = True
|
||||||
critic_network_kwargs = {
|
critic_network_kwargs = {
|
||||||
"hidden_dims": [256, 256],
|
"hidden_dims": [256, 256],
|
||||||
"activate_final": True,
|
"activate_final": True,
|
||||||
|
|
|
@ -322,6 +322,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
|
|
||||||
logging.info("make_dataset")
|
logging.info("make_dataset")
|
||||||
offline_dataset = make_dataset(cfg)
|
offline_dataset = make_dataset(cfg)
|
||||||
|
# TODO (michel-aractingi): temporary fix to avoid datasets with task_index key that doesn't exist in online environment
|
||||||
|
# i.e., pusht
|
||||||
|
if "task_index" in offline_dataset.hf_dataset[0]:
|
||||||
|
offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(["task_index"])
|
||||||
|
|
||||||
if isinstance(offline_dataset, MultiLeRobotDataset):
|
if isinstance(offline_dataset, MultiLeRobotDataset):
|
||||||
logging.info(
|
logging.info(
|
||||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||||
|
|
Loading…
Reference in New Issue