From 6b2c91255af7295640cbdda62b0de7a5d79cd60f Mon Sep 17 00:00:00 2001 From: mshukor Date: Wed, 2 Apr 2025 19:40:04 +0200 Subject: [PATCH] testing syncenv --- lerobot/common/envs/utils.py | 2 +- lerobot/scripts/eval.py | 4 ++-- lerobot/scripts/train.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index c51e6090..83334f87 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -17,7 +17,7 @@ import warnings from typing import Any import einops -import gym +import gymnasium as gym import numpy as np import torch from torch import Tensor diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 2ad3b36e..e6300dbd 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -66,7 +66,7 @@ from torch import Tensor, nn from tqdm import trange from lerobot.common.envs.factory import make_env -from lerobot.common.envs.utils import check_env_attributes_and_types, preprocess_observation +from lerobot.common.envs.utils import check_env_attributes_and_types, preprocess_observation, add_envs_task from lerobot.common.policies.factory import make_policy from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.utils import get_device_from_parameters @@ -157,7 +157,7 @@ def rollout( } # Infer "task" from attributes of environments. Works with AsyncVectorEnv. - observation = addd_envs_task(env, observation) + observation = add_envs_task(env, observation) with torch.inference_mode(): action = policy.select_action(observation) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index f2b1e29e..0de247be 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -133,7 +133,7 @@ def train(cfg: TrainPipelineConfig): eval_env = None if cfg.eval_freq > 0 and cfg.env is not None: logging.info("Creating env") - eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size) + eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) logging.info("Creating policy") policy = make_policy(