diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index 6259ca94..d000d1dd 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field import draccus -from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT +from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.configs.types import FeatureType, PolicyFeature @@ -39,7 +39,7 @@ class AlohaEnv(EnvConfig): features_map: dict[str, str] = field( default_factory=lambda: { "action": ACTION, - "agent_pos": OBS_ROBOT, + "agent_pos": OBS_STATE, "top": f"{OBS_IMAGE}.top", "pixels/top": f"{OBS_IMAGES}.top", } @@ -80,8 +80,8 @@ class PushtEnv(EnvConfig): features_map: dict[str, str] = field( default_factory=lambda: { "action": ACTION, - "agent_pos": OBS_ROBOT, - "environment_state": OBS_ENV, + "agent_pos": OBS_STATE, + "environment_state": OBS_ENV_STATE, "pixels": OBS_IMAGE, } ) @@ -122,7 +122,7 @@ class XarmEnv(EnvConfig): features_map: dict[str, str] = field( default_factory=lambda: { "action": ACTION, - "agent_pos": OBS_ROBOT, + "agent_pos": OBS_STATE, "pixels": OBS_IMAGE, } ) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 9ecadcb0..3edaf852 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -33,7 +33,7 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from torch import Tensor, nn -from lerobot.common.constants import OBS_ENV, OBS_ROBOT +from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy @@ -238,8 +238,8 @@ class DiffusionModel(nn.Module): def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor: """Encode image features and concatenate them all together along with the state vector.""" - batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2] - global_cond_feats = [batch[OBS_ROBOT]] + batch_size, n_obs_steps = batch[OBS_STATE].shape[:2] + global_cond_feats = [batch[OBS_STATE]] # Extract image features. if self.config.image_features: if self.config.use_separate_rgb_encoder_per_camera: @@ -269,7 +269,7 @@ class DiffusionModel(nn.Module): global_cond_feats.append(img_features) if self.config.env_state_feature: - global_cond_feats.append(batch[OBS_ENV]) + global_cond_feats.append(batch[OBS_ENV_STATE]) # Concatenate features then flatten to (B, global_cond_dim). return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1) diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py index c8b12caf..555a86bd 100644 --- a/lerobot/common/policies/pi0/modeling_pi0.py +++ b/lerobot/common/policies/pi0/modeling_pi0.py @@ -57,7 +57,7 @@ import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn from transformers import AutoTokenizer -from lerobot.common.constants import ACTION, OBS_ROBOT +from lerobot.common.constants import ACTION, OBS_STATE from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pi0.paligemma_with_expert import ( @@ -271,7 +271,7 @@ class PI0Policy(PreTrainedPolicy): self.eval() if self.config.adapt_to_pi_aloha: - batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) batch = self.normalize_inputs(batch) @@ -303,7 +303,7 @@ class PI0Policy(PreTrainedPolicy): def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]: """Do a full training forward pass to compute the loss""" if self.config.adapt_to_pi_aloha: - batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) batch = self.normalize_inputs(batch) @@ -380,7 +380,7 @@ class PI0Policy(PreTrainedPolicy): def prepare_language(self, batch) -> tuple[Tensor, Tensor]: """Tokenize the text input""" - device = batch[OBS_ROBOT].device + device = batch[OBS_STATE].device tasks = batch["task"] # PaliGemma prompt has to end with a new line @@ -427,7 +427,7 @@ class PI0Policy(PreTrainedPolicy): def prepare_state(self, batch): """Pad state""" - state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim) + state = pad_vector(batch[OBS_STATE], self.config.max_state_dim) return state def prepare_action(self, batch): diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 0940f198..615c156c 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -35,7 +35,7 @@ import torch.nn as nn import torch.nn.functional as F # noqa: N812 from torch import Tensor -from lerobot.common.constants import OBS_ENV, OBS_ROBOT +from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig @@ -753,9 +753,9 @@ class TDMPCObservationEncoder(nn.Module): ) ) if self.config.env_state_feature: - feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV])) + feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV_STATE])) if self.config.robot_state_feature: - feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT])) + feat.append(self.state_enc_layers(obs_dict[OBS_STATE])) return torch.stack(feat, dim=0).mean(0)