Update constants
This commit is contained in:
parent
a13e49073c
commit
2b24feb604
lerobot/common
|
@ -3,7 +3,7 @@ from dataclasses import dataclass, field
|
|||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
|
||||
from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
|
@ -39,7 +39,7 @@ class AlohaEnv(EnvConfig):
|
|||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"agent_pos": OBS_STATE,
|
||||
"top": f"{OBS_IMAGE}.top",
|
||||
"pixels/top": f"{OBS_IMAGES}.top",
|
||||
}
|
||||
|
@ -80,8 +80,8 @@ class PushtEnv(EnvConfig):
|
|||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"environment_state": OBS_ENV,
|
||||
"agent_pos": OBS_STATE,
|
||||
"environment_state": OBS_ENV_STATE,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
|
@ -122,7 +122,7 @@ class XarmEnv(EnvConfig):
|
|||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"agent_pos": OBS_STATE,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
|
|
|
@ -33,7 +33,7 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
|||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
||||
from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
|
@ -238,8 +238,8 @@ class DiffusionModel(nn.Module):
|
|||
|
||||
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode image features and concatenate them all together along with the state vector."""
|
||||
batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
|
||||
global_cond_feats = [batch[OBS_ROBOT]]
|
||||
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||
global_cond_feats = [batch[OBS_STATE]]
|
||||
# Extract image features.
|
||||
if self.config.image_features:
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
|
@ -269,7 +269,7 @@ class DiffusionModel(nn.Module):
|
|||
global_cond_feats.append(img_features)
|
||||
|
||||
if self.config.env_state_feature:
|
||||
global_cond_feats.append(batch[OBS_ENV])
|
||||
global_cond_feats.append(batch[OBS_ENV_STATE])
|
||||
|
||||
# Concatenate features then flatten to (B, global_cond_dim).
|
||||
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
||||
|
|
|
@ -57,7 +57,7 @@ import torch.nn.functional as F # noqa: N812
|
|||
from torch import Tensor, nn
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||
from lerobot.common.constants import ACTION, OBS_STATE
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0.paligemma_with_expert import (
|
||||
|
@ -271,7 +271,7 @@ class PI0Policy(PreTrainedPolicy):
|
|||
self.eval()
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
|
@ -303,7 +303,7 @@ class PI0Policy(PreTrainedPolicy):
|
|||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
@ -380,7 +380,7 @@ class PI0Policy(PreTrainedPolicy):
|
|||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_ROBOT].device
|
||||
device = batch[OBS_STATE].device
|
||||
tasks = batch["task"]
|
||||
|
||||
# PaliGemma prompt has to end with a new line
|
||||
|
@ -427,7 +427,7 @@ class PI0Policy(PreTrainedPolicy):
|
|||
|
||||
def prepare_state(self, batch):
|
||||
"""Pad state"""
|
||||
state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim)
|
||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||
return state
|
||||
|
||||
def prepare_action(self, batch):
|
||||
|
|
|
@ -35,7 +35,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
||||
from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
|
@ -753,9 +753,9 @@ class TDMPCObservationEncoder(nn.Module):
|
|||
)
|
||||
)
|
||||
if self.config.env_state_feature:
|
||||
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV]))
|
||||
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV_STATE]))
|
||||
if self.config.robot_state_feature:
|
||||
feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT]))
|
||||
feat.append(self.state_enc_layers(obs_dict[OBS_STATE]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue