Update constants

This commit is contained in:
Simon Alibert 2025-03-04 11:07:15 +01:00
parent a13e49073c
commit 2b24feb604
4 changed files with 17 additions and 17 deletions

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass, field
import draccus import draccus
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
@ -39,7 +39,7 @@ class AlohaEnv(EnvConfig):
features_map: dict[str, str] = field( features_map: dict[str, str] = field(
default_factory=lambda: { default_factory=lambda: {
"action": ACTION, "action": ACTION,
"agent_pos": OBS_ROBOT, "agent_pos": OBS_STATE,
"top": f"{OBS_IMAGE}.top", "top": f"{OBS_IMAGE}.top",
"pixels/top": f"{OBS_IMAGES}.top", "pixels/top": f"{OBS_IMAGES}.top",
} }
@ -80,8 +80,8 @@ class PushtEnv(EnvConfig):
features_map: dict[str, str] = field( features_map: dict[str, str] = field(
default_factory=lambda: { default_factory=lambda: {
"action": ACTION, "action": ACTION,
"agent_pos": OBS_ROBOT, "agent_pos": OBS_STATE,
"environment_state": OBS_ENV, "environment_state": OBS_ENV_STATE,
"pixels": OBS_IMAGE, "pixels": OBS_IMAGE,
} }
) )
@ -122,7 +122,7 @@ class XarmEnv(EnvConfig):
features_map: dict[str, str] = field( features_map: dict[str, str] = field(
default_factory=lambda: { default_factory=lambda: {
"action": ACTION, "action": ACTION,
"agent_pos": OBS_ROBOT, "agent_pos": OBS_STATE,
"pixels": OBS_IMAGE, "pixels": OBS_IMAGE,
} }
) )

View File

@ -33,7 +33,7 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor, nn from torch import Tensor, nn
from lerobot.common.constants import OBS_ENV, OBS_ROBOT from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.pretrained import PreTrainedPolicy
@ -238,8 +238,8 @@ class DiffusionModel(nn.Module):
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor: def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
"""Encode image features and concatenate them all together along with the state vector.""" """Encode image features and concatenate them all together along with the state vector."""
batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2] batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
global_cond_feats = [batch[OBS_ROBOT]] global_cond_feats = [batch[OBS_STATE]]
# Extract image features. # Extract image features.
if self.config.image_features: if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera: if self.config.use_separate_rgb_encoder_per_camera:
@ -269,7 +269,7 @@ class DiffusionModel(nn.Module):
global_cond_feats.append(img_features) global_cond_feats.append(img_features)
if self.config.env_state_feature: if self.config.env_state_feature:
global_cond_feats.append(batch[OBS_ENV]) global_cond_feats.append(batch[OBS_ENV_STATE])
# Concatenate features then flatten to (B, global_cond_dim). # Concatenate features then flatten to (B, global_cond_dim).
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1) return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)

View File

@ -57,7 +57,7 @@ import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn from torch import Tensor, nn
from transformers import AutoTokenizer from transformers import AutoTokenizer
from lerobot.common.constants import ACTION, OBS_ROBOT from lerobot.common.constants import ACTION, OBS_STATE
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0.paligemma_with_expert import ( from lerobot.common.policies.pi0.paligemma_with_expert import (
@ -271,7 +271,7 @@ class PI0Policy(PreTrainedPolicy):
self.eval() self.eval()
if self.config.adapt_to_pi_aloha: if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
@ -303,7 +303,7 @@ class PI0Policy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]: def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
"""Do a full training forward pass to compute the loss""" """Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha: if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
@ -380,7 +380,7 @@ class PI0Policy(PreTrainedPolicy):
def prepare_language(self, batch) -> tuple[Tensor, Tensor]: def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input""" """Tokenize the text input"""
device = batch[OBS_ROBOT].device device = batch[OBS_STATE].device
tasks = batch["task"] tasks = batch["task"]
# PaliGemma prompt has to end with a new line # PaliGemma prompt has to end with a new line
@ -427,7 +427,7 @@ class PI0Policy(PreTrainedPolicy):
def prepare_state(self, batch): def prepare_state(self, batch):
"""Pad state""" """Pad state"""
state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim) state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
return state return state
def prepare_action(self, batch): def prepare_action(self, batch):

View File

@ -35,7 +35,7 @@ import torch.nn as nn
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
from torch import Tensor from torch import Tensor
from lerobot.common.constants import OBS_ENV, OBS_ROBOT from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
@ -753,9 +753,9 @@ class TDMPCObservationEncoder(nn.Module):
) )
) )
if self.config.env_state_feature: if self.config.env_state_feature:
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV])) feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV_STATE]))
if self.config.robot_state_feature: if self.config.robot_state_feature:
feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT])) feat.append(self.state_enc_layers(obs_dict[OBS_STATE]))
return torch.stack(feat, dim=0).mean(0) return torch.stack(feat, dim=0).mean(0)