testing syncenv

This commit is contained in:
mshukor 2025-04-02 19:40:04 +02:00
parent df2a312174
commit 6b2c91255a
3 changed files with 4 additions and 4 deletions

View File

@ -17,7 +17,7 @@ import warnings
from typing import Any from typing import Any
import einops import einops
import gym import gymnasium as gym
import numpy as np import numpy as np
import torch import torch
from torch import Tensor from torch import Tensor

View File

@ -66,7 +66,7 @@ from torch import Tensor, nn
from tqdm import trange from tqdm import trange
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import check_env_attributes_and_types, preprocess_observation from lerobot.common.envs.utils import check_env_attributes_and_types, preprocess_observation, add_envs_task
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters from lerobot.common.policies.utils import get_device_from_parameters
@ -157,7 +157,7 @@ def rollout(
} }
# Infer "task" from attributes of environments. Works with AsyncVectorEnv. # Infer "task" from attributes of environments. Works with AsyncVectorEnv.
observation = addd_envs_task(env, observation) observation = add_envs_task(env, observation)
with torch.inference_mode(): with torch.inference_mode():
action = policy.select_action(observation) action = policy.select_action(observation)

View File

@ -133,7 +133,7 @@ def train(cfg: TrainPipelineConfig):
eval_env = None eval_env = None
if cfg.eval_freq > 0 and cfg.env is not None: if cfg.eval_freq > 0 and cfg.env is not None:
logging.info("Creating env") logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size) eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Creating policy") logging.info("Creating policy")
policy = make_policy( policy = make_policy(