diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 3a3c1c57..fee598d8 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -13,13 +13,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any import warnings +from typing import Any + import einops +import gym import numpy as np import torch from torch import Tensor -import gym + from lerobot.common.envs.configs import EnvConfig from lerobot.common.utils.utils import get_channel_first_image_shape from lerobot.configs.types import FeatureType, PolicyFeature @@ -89,27 +91,30 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: return policy_features + def check_all_envs_same_type(env: gym.vector.VectorEnv) -> bool: first_type = type(env.envs[0]) # Get type of first env return all(type(e) is first_type for e in env.envs) # Fast type check + def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None: with warnings.catch_warnings(): warnings.simplefilter("once", UserWarning) # Apply filter only in this function - + if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")): warnings.warn( "The environment does not have 'task_description' and 'task'. Some policies require these features.", UserWarning, - stacklevel=2 + stacklevel=2, ) if not check_all_envs_same_type(env): warnings.warn( "The environments have different types. Make sure you infer the right task from each environment. Empty task will be passed instead.", UserWarning, - stacklevel=2 + stacklevel=2, ) + def infer_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]: if hasattr(env.envs[0], "task_description"): observation["task"] = env.call("task_description") @@ -118,4 +123,4 @@ def infer_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> d else: # For envs without language instructions, e.g. aloha transfer cube and etc. num_envs = observation[list(observation.keys())[0]].shape[0] observation["task"] = ["" for _ in range(num_envs)] - return observation \ No newline at end of file + return observation diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index cc8c298d..4cad2dd1 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -66,7 +66,7 @@ from torch import Tensor, nn from tqdm import trange from lerobot.common.envs.factory import make_env -from lerobot.common.envs.utils import preprocess_observation, infer_envs_task, check_env_attributes_and_types +from lerobot.common.envs.utils import check_env_attributes_and_types, infer_envs_task, preprocess_observation from lerobot.common.policies.factory import make_policy from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.utils import get_device_from_parameters @@ -158,7 +158,7 @@ def rollout( # Infer "task" from envs. Works with AsyncVectorEnv. observation = infer_envs_task(env, observation) - + with torch.inference_mode(): action = policy.select_action(observation)