diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index b73b1171..468e7a1a 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -137,10 +137,7 @@ def train(cfg: TrainPipelineConfig): eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) logging.info("Creating policy") - if isinstance(dataset, MultiLeRobotDataset): - ds_meta = dataset._datasets[0].meta - else: - ds_meta = dataset.meta + ds_meta = dataset._datasets[0].meta if isinstance(dataset, MultiLeRobotDataset) else dataset.meta policy = make_policy( cfg=cfg.policy, ds_meta=ds_meta,