[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-04-02 15:00:40 +00:00
parent 1daa9aae9e
commit 7755d43127
2 changed files with 13 additions and 8 deletions

View File

@ -13,13 +13,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import warnings
from typing import Any
import einops
import gym
import numpy as np
import torch
from torch import Tensor
import gym
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.utils.utils import get_channel_first_image_shape
from lerobot.configs.types import FeatureType, PolicyFeature
@ -89,27 +91,30 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
return policy_features
def check_all_envs_same_type(env: gym.vector.VectorEnv) -> bool:
first_type = type(env.envs[0]) # Get type of first env
return all(type(e) is first_type for e in env.envs) # Fast type check
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
with warnings.catch_warnings():
warnings.simplefilter("once", UserWarning) # Apply filter only in this function
if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
warnings.warn(
"The environment does not have 'task_description' and 'task'. Some policies require these features.",
UserWarning,
stacklevel=2
stacklevel=2,
)
if not check_all_envs_same_type(env):
warnings.warn(
"The environments have different types. Make sure you infer the right task from each environment. Empty task will be passed instead.",
UserWarning,
stacklevel=2
stacklevel=2,
)
def infer_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
if hasattr(env.envs[0], "task_description"):
observation["task"] = env.call("task_description")
@ -118,4 +123,4 @@ def infer_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> d
else: # For envs without language instructions, e.g. aloha transfer cube and etc.
num_envs = observation[list(observation.keys())[0]].shape[0]
observation["task"] = ["" for _ in range(num_envs)]
return observation
return observation

View File

@ -66,7 +66,7 @@ from torch import Tensor, nn
from tqdm import trange
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation, infer_envs_task, check_env_attributes_and_types
from lerobot.common.envs.utils import check_env_attributes_and_types, infer_envs_task, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters
@ -158,7 +158,7 @@ def rollout(
# Infer "task" from envs. Works with AsyncVectorEnv.
observation = infer_envs_task(env, observation)
with torch.inference_mode():
action = policy.select_action(observation)