Update constants
This commit is contained in:
parent
a13e49073c
commit
2b24feb604
|
@ -3,7 +3,7 @@ from dataclasses import dataclass, field
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
|
|
||||||
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
|
from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ class AlohaEnv(EnvConfig):
|
||||||
features_map: dict[str, str] = field(
|
features_map: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"action": ACTION,
|
"action": ACTION,
|
||||||
"agent_pos": OBS_ROBOT,
|
"agent_pos": OBS_STATE,
|
||||||
"top": f"{OBS_IMAGE}.top",
|
"top": f"{OBS_IMAGE}.top",
|
||||||
"pixels/top": f"{OBS_IMAGES}.top",
|
"pixels/top": f"{OBS_IMAGES}.top",
|
||||||
}
|
}
|
||||||
|
@ -80,8 +80,8 @@ class PushtEnv(EnvConfig):
|
||||||
features_map: dict[str, str] = field(
|
features_map: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"action": ACTION,
|
"action": ACTION,
|
||||||
"agent_pos": OBS_ROBOT,
|
"agent_pos": OBS_STATE,
|
||||||
"environment_state": OBS_ENV,
|
"environment_state": OBS_ENV_STATE,
|
||||||
"pixels": OBS_IMAGE,
|
"pixels": OBS_IMAGE,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -122,7 +122,7 @@ class XarmEnv(EnvConfig):
|
||||||
features_map: dict[str, str] = field(
|
features_map: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"action": ACTION,
|
"action": ACTION,
|
||||||
"agent_pos": OBS_ROBOT,
|
"agent_pos": OBS_STATE,
|
||||||
"pixels": OBS_IMAGE,
|
"pixels": OBS_IMAGE,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -33,7 +33,7 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
|
||||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
|
@ -238,8 +238,8 @@ class DiffusionModel(nn.Module):
|
||||||
|
|
||||||
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
|
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Encode image features and concatenate them all together along with the state vector."""
|
"""Encode image features and concatenate them all together along with the state vector."""
|
||||||
batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
|
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||||
global_cond_feats = [batch[OBS_ROBOT]]
|
global_cond_feats = [batch[OBS_STATE]]
|
||||||
# Extract image features.
|
# Extract image features.
|
||||||
if self.config.image_features:
|
if self.config.image_features:
|
||||||
if self.config.use_separate_rgb_encoder_per_camera:
|
if self.config.use_separate_rgb_encoder_per_camera:
|
||||||
|
@ -269,7 +269,7 @@ class DiffusionModel(nn.Module):
|
||||||
global_cond_feats.append(img_features)
|
global_cond_feats.append(img_features)
|
||||||
|
|
||||||
if self.config.env_state_feature:
|
if self.config.env_state_feature:
|
||||||
global_cond_feats.append(batch[OBS_ENV])
|
global_cond_feats.append(batch[OBS_ENV_STATE])
|
||||||
|
|
||||||
# Concatenate features then flatten to (B, global_cond_dim).
|
# Concatenate features then flatten to (B, global_cond_dim).
|
||||||
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
||||||
|
|
|
@ -57,7 +57,7 @@ import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
from lerobot.common.constants import ACTION, OBS_STATE
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||||
from lerobot.common.policies.pi0.paligemma_with_expert import (
|
from lerobot.common.policies.pi0.paligemma_with_expert import (
|
||||||
|
@ -271,7 +271,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
if self.config.adapt_to_pi_aloha:
|
if self.config.adapt_to_pi_aloha:
|
||||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
|
||||||
|
@ -303,7 +303,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
|
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
|
||||||
"""Do a full training forward pass to compute the loss"""
|
"""Do a full training forward pass to compute the loss"""
|
||||||
if self.config.adapt_to_pi_aloha:
|
if self.config.adapt_to_pi_aloha:
|
||||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
@ -380,7 +380,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||||
|
|
||||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||||
"""Tokenize the text input"""
|
"""Tokenize the text input"""
|
||||||
device = batch[OBS_ROBOT].device
|
device = batch[OBS_STATE].device
|
||||||
tasks = batch["task"]
|
tasks = batch["task"]
|
||||||
|
|
||||||
# PaliGemma prompt has to end with a new line
|
# PaliGemma prompt has to end with a new line
|
||||||
|
@ -427,7 +427,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||||
|
|
||||||
def prepare_state(self, batch):
|
def prepare_state(self, batch):
|
||||||
"""Pad state"""
|
"""Pad state"""
|
||||||
state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim)
|
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def prepare_action(self, batch):
|
def prepare_action(self, batch):
|
||||||
|
|
|
@ -35,7 +35,7 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
|
@ -753,9 +753,9 @@ class TDMPCObservationEncoder(nn.Module):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if self.config.env_state_feature:
|
if self.config.env_state_feature:
|
||||||
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV]))
|
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV_STATE]))
|
||||||
if self.config.robot_state_feature:
|
if self.config.robot_state_feature:
|
||||||
feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT]))
|
feat.append(self.state_enc_layers(obs_dict[OBS_STATE]))
|
||||||
return torch.stack(feat, dim=0).mean(0)
|
return torch.stack(feat, dim=0).mean(0)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue