Merge a8135629b4 into 0f706ce543
commit a72146e814
@@ -203,6 +203,9 @@ class EnvWrapperConfig:
     joint_masking_action_space: Optional[Any] = None
     ee_action_space_params: Optional[EEActionSpaceConfig] = None
     use_gripper: bool = False
+    gripper_quantization_threshold: float = 0.8
+    gripper_penalty: float = 0.0
+    open_gripper_on_reset: bool = False
 
 
 @EnvConfig.register_subclass(name="gym_manipulator")
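The three new wrapper fields expose discrete gripper handling: a quantization threshold (presumably for discretizing the continuous gripper command), a reward penalty applied to gripper actions, and a flag to open the gripper on reset. A hedged configuration sketch; the import path of EnvWrapperConfig is an assumption and is not visible in this diff:

# Assumed import path for illustration only.
from lerobot.common.envs.configs import EnvWrapperConfig  # hypothetical location

wrapper_cfg = EnvWrapperConfig(
    use_gripper=True,
    gripper_quantization_threshold=0.8,  # default value from the hunk above
    gripper_penalty=-0.05,               # illustrative non-zero penalty
    open_gripper_on_reset=True,
)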
@@ -254,6 +257,7 @@ class ManiskillEnvConfig(EnvConfig):
     robot: str = "so100"  # This is a hack to make the robot config work
     video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
     wrapper: WrapperConfig = field(default_factory=WrapperConfig)
+    mock_gripper: bool = False
     features: dict[str, PolicyFeature] = field(
         default_factory=lambda: {
             "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
@@ -85,12 +85,14 @@ class SACConfig(PreTrainedConfig):
         freeze_vision_encoder: Whether to freeze the vision encoder during training.
         image_encoder_hidden_dim: Hidden dimension size for the image encoder.
         shared_encoder: Whether to use a shared encoder for actor and critic.
+        num_discrete_actions: Number of discrete actions, eg for gripper actions.
         concurrency: Configuration for concurrency settings.
         actor_learner: Configuration for actor-learner architecture.
         online_steps: Number of steps for online training.
         online_env_seed: Seed for the online environment.
         online_buffer_capacity: Capacity of the online replay buffer.
         offline_buffer_capacity: Capacity of the offline replay buffer.
+        async_prefetch: Whether to use asynchronous prefetching for the buffers.
         online_step_before_learning: Number of steps before learning starts.
         policy_update_freq: Frequency of policy updates.
         discount: Discount factor for the SAC algorithm.
@@ -144,12 +146,14 @@ class SACConfig(PreTrainedConfig):
     freeze_vision_encoder: bool = True
     image_encoder_hidden_dim: int = 32
     shared_encoder: bool = True
+    num_discrete_actions: int | None = None
 
     # Training parameter
     online_steps: int = 1000000
     online_env_seed: int = 10000
     online_buffer_capacity: int = 100000
     offline_buffer_capacity: int = 100000
+    async_prefetch: bool = False
     online_step_before_learning: int = 100
     policy_update_freq: int = 1
 
@@ -173,7 +177,7 @@ class SACConfig(PreTrainedConfig):
     critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
     actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
     policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
+    grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
     actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
     concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
 
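Together these additions make the discrete gripper head and buffer prefetching opt-in at the config level. A minimal sketch of enabling them (SACConfig is a dataclass, so unspecified fields keep the defaults shown above; treat this as illustrative rather than a tested invocation):

from lerobot.common.policies.sac.configuration_sac import SACConfig

config = SACConfig(
    num_discrete_actions=3,  # e.g. gripper open / stay / close; enables the grasp critic
    async_prefetch=True,     # sample replay-buffer batches in a background thread
)
# grasp_critic_network_kwargs defaults to CriticNetworkConfig(), mirroring the continuous critics.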
@@ -33,6 +33,8 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.sac.configuration_sac import SACConfig
 from lerobot.common.policies.utils import get_device_from_parameters
 
+DISCRETE_DIMENSION_INDEX = -1  # Gripper is always the last dimension
+
 
 class SACPolicy(
     PreTrainedPolicy,
@@ -49,6 +51,8 @@ class SACPolicy(
         config.validate_features()
         self.config = config
 
+        continuous_action_dim = config.output_features["action"].shape[0]
+
         if config.dataset_stats is not None:
             input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
             self.normalize_inputs = Normalize(
@@ -77,11 +81,12 @@ class SACPolicy(
         else:
             encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
             encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
+        self.shared_encoder = config.shared_encoder
 
         # Create a list of critic heads
         critic_heads = [
             CriticHead(
-                input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
+                input_dim=encoder_critic.output_dim + continuous_action_dim,
                 **asdict(config.critic_network_kwargs),
             )
             for _ in range(config.num_critics)
@@ -96,7 +101,7 @@ class SACPolicy(
         # Create target critic heads as deepcopies of the original critic heads
         target_critic_heads = [
             CriticHead(
-                input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
+                input_dim=encoder_critic.output_dim + continuous_action_dim,
                 **asdict(config.critic_network_kwargs),
             )
             for _ in range(config.num_critics)
@@ -112,15 +117,41 @@ class SACPolicy(
 
         self.critic_ensemble = torch.compile(self.critic_ensemble)
         self.critic_target = torch.compile(self.critic_target)
 
+        self.grasp_critic = None
+        self.grasp_critic_target = None
+
+        if config.num_discrete_actions is not None:
+            # Create grasp critic
+            self.grasp_critic = GraspCritic(
+                encoder=encoder_critic,
+                input_dim=encoder_critic.output_dim,
+                output_dim=config.num_discrete_actions,
+                **asdict(config.grasp_critic_network_kwargs),
+            )
+
+            # Create target grasp critic
+            self.grasp_critic_target = GraspCritic(
+                encoder=encoder_critic,
+                input_dim=encoder_critic.output_dim,
+                output_dim=config.num_discrete_actions,
+                **asdict(config.grasp_critic_network_kwargs),
+            )
+
+            self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
+
+            self.grasp_critic = torch.compile(self.grasp_critic)
+            self.grasp_critic_target = torch.compile(self.grasp_critic_target)
+
         self.actor = Policy(
             encoder=encoder_actor,
             network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
-            action_dim=config.output_features["action"].shape[0],
+            action_dim=continuous_action_dim,
             encoder_is_shared=config.shared_encoder,
             **asdict(config.policy_kwargs),
         )
         if config.target_entropy is None:
-            config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2  # (-dim(A)/2)
+            config.target_entropy = -np.prod(continuous_action_dim) / 2  # (-dim(A)/2)
 
         # TODO (azouitine): Handle the case where the temparameter is a fixed
         # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@@ -131,11 +162,14 @@ class SACPolicy(
         self.temperature = self.log_alpha.exp().item()
 
     def get_optim_params(self) -> dict:
-        return {
+        optim_params = {
             "actor": self.actor.parameters_to_optimize,
             "critic": self.critic_ensemble.parameters_to_optimize,
             "temperature": self.log_alpha,
         }
+        if self.config.num_discrete_actions is not None:
+            optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize
+        return optim_params
 
     def reset(self):
         """Reset the policy"""
@@ -151,8 +185,19 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        actions, _, _ = self.actor(batch)
+        # We cached the encoder output to avoid recomputing it
+        observations_features = None
+        if self.shared_encoder:
+            observations_features = self.actor.encoder.get_image_features(batch)
+
+        actions, _, _ = self.actor(batch, observations_features)
         actions = self.unnormalize_outputs({"action": actions})["action"]
+
+        if self.config.num_discrete_actions is not None:
+            discrete_action_value = self.grasp_critic(batch, observations_features)
+            discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
+            actions = torch.cat([actions, discrete_action], dim=-1)
+
         return actions
 
     def critic_forward(
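With num_discrete_actions set, select_action returns the unnormalized continuous action with the argmax gripper index concatenated as the last dimension; the losses later split it apart again via DISCRETE_DIMENSION_INDEX. A runnable toy of that layout (numbers are illustrative: 6 continuous dims plus one gripper column):

import torch

DISCRETE_DIMENSION_INDEX = -1  # mirrors the constant added at the top of the module
action = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 2.0],
                       [0.3, -0.1, 0.0, 0.2, 0.1, 0.0, 0.0]])
continuous_part = action[:, :DISCRETE_DIMENSION_INDEX]               # (2, 6) arm command
gripper_index = action[:, DISCRETE_DIMENSION_INDEX:].round().long()  # (2, 1) discrete gripper action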
@@ -172,14 +217,30 @@ class SACPolicy(
         Returns:
             Tensor of Q-values from all critics
         """
 
         critics = self.critic_target if use_target else self.critic_ensemble
         q_values = critics(observations, actions, observation_features)
         return q_values
 
+    def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor:
+        """Forward pass through a grasp critic network
+
+        Args:
+            observations: Dictionary of observations
+            use_target: If True, use target critics, otherwise use ensemble critics
+            observation_features: Optional pre-computed observation features to avoid recomputing encoder output
+
+        Returns:
+            Tensor of Q-values from the grasp critic network
+        """
+        grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
+        q_values = grasp_critic(observations, observation_features)
+        return q_values
+
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "temperature"] = "critic",
+        model: Literal["actor", "critic", "temperature", "grasp_critic"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model
 
@@ -192,12 +253,11 @@ class SACPolicy(
                 - done: Done mask tensor
                 - observation_feature: Optional pre-computed observation features
                 - next_observation_feature: Optional pre-computed next observation features
-            model: Which model to compute the loss for ("actor", "critic", or "temperature")
+            model: Which model to compute the loss for ("actor", "critic", "grasp_critic", or "temperature")
 
         Returns:
             The computed loss tensor
         """
-        # TODO: (maractingi, azouitine) Respect the function signature we output tensors
         # Extract common components from batch
         actions: Tensor = batch["action"]
         observations: dict[str, Tensor] = batch["state"]
@@ -210,7 +270,7 @@ class SACPolicy(
             done: Tensor = batch["done"]
             next_observation_features: Tensor = batch.get("next_observation_feature")
 
-            return self.compute_loss_critic(
+            loss_critic = self.compute_loss_critic(
                 observations=observations,
                 actions=actions,
                 rewards=rewards,
@@ -220,17 +280,41 @@ class SACPolicy(
                 next_observation_features=next_observation_features,
             )
 
+            return {"loss_critic": loss_critic}
+
+        if model == "grasp_critic" and self.config.num_discrete_actions is not None:
+            # Extract critic-specific components
+            rewards: Tensor = batch["reward"]
+            next_observations: dict[str, Tensor] = batch["next_state"]
+            done: Tensor = batch["done"]
+            next_observation_features: Tensor = batch.get("next_observation_feature")
+            complementary_info = batch.get("complementary_info")
+            loss_grasp_critic = self.compute_loss_grasp_critic(
+                observations=observations,
+                actions=actions,
+                rewards=rewards,
+                next_observations=next_observations,
+                done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
+                complementary_info=complementary_info,
+            )
+            return {"loss_grasp_critic": loss_grasp_critic}
+
         if model == "actor":
-            return self.compute_loss_actor(
-                observations=observations,
-                observation_features=observation_features,
-            )
+            return {
+                "loss_actor": self.compute_loss_actor(
+                    observations=observations,
+                    observation_features=observation_features,
+                )
+            }
 
         if model == "temperature":
-            return self.compute_loss_temperature(
-                observations=observations,
-                observation_features=observation_features,
-            )
+            return {
+                "loss_temperature": self.compute_loss_temperature(
+                    observations=observations,
+                    observation_features=observation_features,
+                )
+            }
 
         raise ValueError(f"Unknown model type: {model}")
 
@@ -245,6 +329,16 @@ class SACPolicy(
                 param.data * self.config.critic_target_update_weight
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )
+        if self.config.num_discrete_actions is not None:
+            for target_param, param in zip(
+                self.grasp_critic_target.parameters(),
+                self.grasp_critic.parameters(),
+                strict=False,
+            ):
+                target_param.data.copy_(
+                    param.data * self.config.critic_target_update_weight
+                    + target_param.data * (1.0 - self.config.critic_target_update_weight)
+                )
 
     def update_temperature(self):
         self.temperature = self.log_alpha.exp().item()
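The grasp critic target follows the same Polyak averaging as the continuous critics: target <- tau * online + (1 - tau) * target, with tau = config.critic_target_update_weight. The same rule as a standalone helper, shown here only as a generic sketch (not code from this commit):

import torch
from torch import nn

def soft_update(target: nn.Module, online: nn.Module, tau: float) -> None:
    """In-place Polyak averaging of online parameters into the target network."""
    with torch.no_grad():
        for t_param, o_param in zip(target.parameters(), online.parameters()):
            t_param.data.mul_(1.0 - tau).add_(o_param.data, alpha=tau)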
@@ -287,6 +381,11 @@ class SACPolicy(
         td_target = rewards + (1 - done) * self.config.discount * min_q
 
         # 3- compute predicted qs
+        if self.config.num_discrete_actions is not None:
+            # NOTE: We only want to keep the continuous action part
+            # In the buffer we have the full action space (continuous + discrete)
+            # We need to split them before concatenating them in the critic forward
+            actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
         q_preds = self.critic_forward(
             observations=observations,
             actions=actions,
@@ -307,6 +406,64 @@ class SACPolicy(
         ).sum()
         return critics_loss
 
+    def compute_loss_grasp_critic(
+        self,
+        observations,
+        actions,
+        rewards,
+        next_observations,
+        done,
+        observation_features=None,
+        next_observation_features=None,
+        complementary_info=None,
+    ):
+        # NOTE: We only want to keep the discrete action part
+        # In the buffer we have the full action space (continuous + discrete)
+        # We need to split them before concatenating them in the critic forward
+        actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
+        actions_discrete = torch.round(actions_discrete)
+        actions_discrete = actions_discrete.long()
+
+        if complementary_info is not None:
+            gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty")
+
+        with torch.no_grad():
+            # For DQN, select actions using online network, evaluate with target network
+            next_grasp_qs = self.grasp_critic_forward(
+                next_observations, use_target=False, observation_features=next_observation_features
+            )
+            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
+
+            # Get target Q-values from target network
+            target_next_grasp_qs = self.grasp_critic_forward(
+                observations=next_observations,
+                use_target=True,
+                observation_features=next_observation_features,
+            )
+
+            # Use gather to select Q-values for best actions
+            target_next_grasp_q = torch.gather(
+                target_next_grasp_qs, dim=1, index=best_next_grasp_action
+            ).squeeze(-1)
+
+            # Compute target Q-value with Bellman equation
+            rewards_gripper = rewards
+            if gripper_penalties is not None:
+                rewards_gripper = rewards + gripper_penalties
+            target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
+
+        # Get predicted Q-values for current observations
+        predicted_grasp_qs = self.grasp_critic_forward(
+            observations=observations, use_target=False, observation_features=observation_features
+        )
+
+        # Use gather to select Q-values for taken actions
+        predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions_discrete).squeeze(-1)
+
+        # Compute MSE loss between predicted and target Q-values
+        grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
+        return grasp_critic_loss
+
     def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
         """Compute the temperature loss"""
         # calculate temperature loss
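The grasp critic is trained like a Double-DQN head: the online network picks the best next gripper action, the target network evaluates it, and the TD target is y = r + (1 - done) * gamma * Q_target(s', argmax_a Q_online(s', a)), optimized with an MSE loss against Q_online(s, a_taken). A runnable numeric sketch of the target computation (toy tensors, not the policy code):

import torch

gamma = 0.99
# Toy Q-values for a batch of 2 next-states and 3 discrete gripper actions.
next_qs_online = torch.tensor([[0.1, 0.5, 0.2], [0.3, 0.0, 0.4]])
next_qs_target = torch.tensor([[0.0, 0.6, 0.1], [0.2, 0.1, 0.5]])
rewards = torch.tensor([1.0, 0.0])
done = torch.tensor([0.0, 1.0])

best_next = next_qs_online.argmax(dim=-1, keepdim=True)     # action chosen by the online head
target_q = next_qs_target.gather(1, best_next).squeeze(-1)  # evaluated by the target head
td_target = rewards + (1 - done) * gamma * target_q         # -> tensor([1.594, 0.000])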
@@ -337,6 +494,109 @@ class SACPolicy(
         return actor_loss
 
 
+class SACObservationEncoder(nn.Module):
+    """Encode image and/or state vector observations."""
+
+    def __init__(self, config: SACConfig, input_normalizer: nn.Module):
+        """
+        Creates encoders for pixel and/or state modalities.
+        """
+        super().__init__()
+        self.config = config
+        self.input_normalization = input_normalizer
+        self.has_pretrained_vision_encoder = False
+        self.parameters_to_optimize = []
+
+        self.aggregation_size: int = 0
+        if any("observation.image" in key for key in config.input_features):
+            self.camera_number = config.camera_number
+
+            if self.config.vision_encoder_name is not None:
+                self.image_enc_layers = PretrainedImageEncoder(config)
+                self.has_pretrained_vision_encoder = True
+            else:
+                self.image_enc_layers = DefaultImageEncoder(config)
+
+            self.aggregation_size += config.latent_dim * self.camera_number
+
+            if config.freeze_vision_encoder:
+                freeze_image_encoder(self.image_enc_layers)
+            else:
+                self.parameters_to_optimize += list(self.image_enc_layers.parameters())
+            self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
+
+        if "observation.state" in config.input_features:
+            self.state_enc_layers = nn.Sequential(
+                nn.Linear(
+                    in_features=config.input_features["observation.state"].shape[0],
+                    out_features=config.latent_dim,
+                ),
+                nn.LayerNorm(normalized_shape=config.latent_dim),
+                nn.Tanh(),
+            )
+            self.aggregation_size += config.latent_dim
+
+            self.parameters_to_optimize += list(self.state_enc_layers.parameters())
+
+        if "observation.environment_state" in config.input_features:
+            self.env_state_enc_layers = nn.Sequential(
+                nn.Linear(
+                    in_features=config.input_features["observation.environment_state"].shape[0],
+                    out_features=config.latent_dim,
+                ),
+                nn.LayerNorm(normalized_shape=config.latent_dim),
+                nn.Tanh(),
+            )
+            self.aggregation_size += config.latent_dim
+            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
+
+        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
+        self.parameters_to_optimize += list(self.aggregation_layer.parameters())
+
+    def forward(
+        self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
+    ) -> Tensor:
+        """Encode the image and/or state vector.
+
+        Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
+        over all features.
+        """
+        feat = []
+        obs_dict = self.input_normalization(obs_dict)
+        if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
+            vision_encoder_cache = self.get_image_features(obs_dict)
+            feat.append(vision_encoder_cache)
+
+        if vision_encoder_cache is not None:
+            feat.append(vision_encoder_cache)
+
+        if "observation.environment_state" in self.config.input_features:
+            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
+        if "observation.state" in self.config.input_features:
+            feat.append(self.state_enc_layers(obs_dict["observation.state"]))
+
+        features = torch.cat(tensors=feat, dim=-1)
+        features = self.aggregation_layer(features)
+
+        return features
+
+    def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor:
+        # [N*B, C, H, W]
+        if len(self.all_image_keys) > 0:
+            # Batch all images along the batch dimension, then encode them.
+            images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
+            images_batched = self.image_enc_layers(images_batched)
+            embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
+            embeddings_image = torch.cat(embeddings_chunks, dim=-1)
+            return embeddings_image
+        return None
+
+    @property
+    def output_dim(self) -> int:
+        """Returns the dimension of the encoder output"""
+        return self.config.latent_dim
+
+
 class MLP(nn.Module):
     def __init__(
         self,
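The rewritten SACObservationEncoder takes an optional vision_encoder_cache so that, when the actor and critics share the encoder, the expensive image branch runs once per batch and its output is reused (this is what select_action does above). A runnable toy of the same pattern with a stand-in module (TinyEncoder is illustrative, not code from this commit):

import torch
from torch import nn

class TinyEncoder(nn.Module):
    """Stand-in for SACObservationEncoder: expensive image branch + cheap state branch."""
    def __init__(self):
        super().__init__()
        self.image_net = nn.Linear(16, 8)
        self.state_net = nn.Linear(4, 8)

    def get_image_features(self, obs):
        return self.image_net(obs["image"])

    def forward(self, obs, vision_encoder_cache=None):
        if vision_encoder_cache is None:
            vision_encoder_cache = self.get_image_features(obs)
        return torch.cat([vision_encoder_cache, self.state_net(obs["state"])], dim=-1)

enc = TinyEncoder()
obs = {"image": torch.randn(2, 16), "state": torch.randn(2, 4)}
cache = enc.get_image_features(obs)   # computed once
actor_latent = enc(obs, cache)        # reused by the actor...
critic_latent = enc(obs, cache)       # ...and by the critics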
@@ -459,7 +719,7 @@ class CriticEnsemble(nn.Module):
 
     def __init__(
         self,
-        encoder: Optional[nn.Module],
+        encoder: SACObservationEncoder,
         ensemble: List[CriticHead],
         output_normalization: nn.Module,
         init_final: Optional[float] = None,
@@ -491,11 +751,7 @@ class CriticEnsemble(nn.Module):
         actions = self.output_normalization(actions)["action"]
         actions = actions.to(device)
 
-        obs_enc = (
-            observation_features
-            if observation_features is not None
-            else (observations if self.encoder is None else self.encoder(observations))
-        )
+        obs_enc = self.encoder(observations, observation_features)
 
         inputs = torch.cat([obs_enc, actions], dim=-1)
 
@@ -509,10 +765,57 @@ class CriticEnsemble(nn.Module):
         return q_values
 
 
+class GraspCritic(nn.Module):
+    def __init__(
+        self,
+        encoder: nn.Module,
+        input_dim: int,
+        hidden_dims: list[int],
+        output_dim: int = 3,
+        activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
+        activate_final: bool = False,
+        dropout_rate: Optional[float] = None,
+        init_final: Optional[float] = None,
+        final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
+    ):
+        super().__init__()
+        self.encoder = encoder
+        self.output_dim = output_dim
+
+        self.net = MLP(
+            input_dim=input_dim,
+            hidden_dims=hidden_dims,
+            activations=activations,
+            activate_final=activate_final,
+            dropout_rate=dropout_rate,
+            final_activation=final_activation,
+        )
+
+        self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
+        if init_final is not None:
+            nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
+            nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
+        else:
+            orthogonal_init()(self.output_layer.weight)
+
+        self.parameters_to_optimize = []
+        self.parameters_to_optimize += list(self.net.parameters())
+        self.parameters_to_optimize += list(self.output_layer.parameters())
+
+    def forward(
+        self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        device = get_device_from_parameters(self)
+        # Move each tensor in observations to device by cloning first to avoid inplace operations
+        observations = {k: v.to(device) for k, v in observations.items()}
+        obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
+        return self.output_layer(self.net(obs_enc))
+
+
 class Policy(nn.Module):
     def __init__(
         self,
-        encoder: Optional[nn.Module],
+        encoder: SACObservationEncoder,
         network: nn.Module,
         action_dim: int,
         log_std_min: float = -5,
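GraspCritic is a DQN-style head: it encodes observations with the shared SACObservationEncoder, passes them through an MLP, and emits one Q-value per discrete gripper action (output_dim defaults to 3). A toy shape check using a plain stand-in head (not the class above):

import torch
from torch import nn

latent_dim, num_discrete_actions, batch = 8, 3, 4
head = nn.Sequential(nn.Linear(latent_dim, 16), nn.SiLU(), nn.Linear(16, num_discrete_actions))
obs_latent = torch.randn(batch, latent_dim)    # would come from SACObservationEncoder
q_values = head(obs_latent)                    # shape (4, 3): one Q-value per gripper action
best_gripper_action = q_values.argmax(dim=-1)  # what select_action uses at inference time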
@@ -523,7 +826,7 @@ class Policy(nn.Module):
         encoder_is_shared: bool = False,
     ):
         super().__init__()
-        self.encoder = encoder
+        self.encoder: SACObservationEncoder = encoder
         self.network = network
         self.action_dim = action_dim
         self.log_std_min = log_std_min
@@ -566,11 +869,7 @@ class Policy(nn.Module):
         observation_features: torch.Tensor | None = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Encode observations if encoder exists
-        obs_enc = (
-            observation_features
-            if observation_features is not None
-            else (observations if self.encoder is None else self.encoder(observations))
-        )
+        obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
 
         # Get network outputs
         outputs = self.network(obs_enc)
@@ -614,96 +913,6 @@ class Policy(nn.Module):
         return observations
 
 
-class SACObservationEncoder(nn.Module):
-    """Encode image and/or state vector observations."""
-
-    def __init__(self, config: SACConfig, input_normalizer: nn.Module):
-        """
-        Creates encoders for pixel and/or state modalities.
-        """
-        super().__init__()
-        self.config = config
-        self.input_normalization = input_normalizer
-        self.has_pretrained_vision_encoder = False
-        self.parameters_to_optimize = []
-
-        self.aggregation_size: int = 0
-        if any("observation.image" in key for key in config.input_features):
-            self.camera_number = config.camera_number
-
-            if self.config.vision_encoder_name is not None:
-                self.image_enc_layers = PretrainedImageEncoder(config)
-                self.has_pretrained_vision_encoder = True
-            else:
-                self.image_enc_layers = DefaultImageEncoder(config)
-
-            self.aggregation_size += config.latent_dim * self.camera_number
-
-            if config.freeze_vision_encoder:
-                freeze_image_encoder(self.image_enc_layers)
-            else:
-                self.parameters_to_optimize += list(self.image_enc_layers.parameters())
-            self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
-
-        if "observation.state" in config.input_features:
-            self.state_enc_layers = nn.Sequential(
-                nn.Linear(
-                    in_features=config.input_features["observation.state"].shape[0],
-                    out_features=config.latent_dim,
-                ),
-                nn.LayerNorm(normalized_shape=config.latent_dim),
-                nn.Tanh(),
-            )
-            self.aggregation_size += config.latent_dim
-
-            self.parameters_to_optimize += list(self.state_enc_layers.parameters())
-
-        if "observation.environment_state" in config.input_features:
-            self.env_state_enc_layers = nn.Sequential(
-                nn.Linear(
-                    in_features=config.input_features["observation.environment_state"].shape[0],
-                    out_features=config.latent_dim,
-                ),
-                nn.LayerNorm(normalized_shape=config.latent_dim),
-                nn.Tanh(),
-            )
-            self.aggregation_size += config.latent_dim
-            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
-
-        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
-        self.parameters_to_optimize += list(self.aggregation_layer.parameters())
-
-    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
-        """Encode the image and/or state vector.
-
-        Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
-        over all features.
-        """
-        feat = []
-        obs_dict = self.input_normalization(obs_dict)
-        # Batch all images along the batch dimension, then encode them.
-        if len(self.all_image_keys) > 0:
-            images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
-            images_batched = self.image_enc_layers(images_batched)
-            embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
-            feat.extend(embeddings_chunks)
-
-        if "observation.environment_state" in self.config.input_features:
-            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
-        if "observation.state" in self.config.input_features:
-            feat.append(self.state_enc_layers(obs_dict["observation.state"]))
-
-        features = torch.cat(tensors=feat, dim=-1)
-        features = self.aggregation_layer(features)
-
-        return features
-
-    @property
-    def output_dim(self) -> int:
-        """Returns the dimension of the encoder output"""
-        return self.config.latent_dim
-
-
 class DefaultImageEncoder(nn.Module):
     def __init__(self, config: SACConfig):
         super().__init__()
@@ -15,7 +15,6 @@
 # limitations under the License.
 import functools
 import io
-import os
 import pickle
 from typing import Any, Callable, Optional, Sequence, TypedDict
 
@@ -33,7 +32,7 @@ class Transition(TypedDict):
     next_state: dict[str, torch.Tensor]
     done: bool
     truncated: bool
-    complementary_info: dict[str, Any] = None
+    complementary_info: dict[str, torch.Tensor | float | int] | None = None
 
 
 class BatchTransition(TypedDict):
@@ -43,41 +42,47 @@ class BatchTransition(TypedDict):
     next_state: dict[str, torch.Tensor]
     done: torch.Tensor
     truncated: torch.Tensor
+    complementary_info: dict[str, torch.Tensor | float | int] | None = None
 
 
 def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
-    # Move state tensors to CPU
     device = torch.device(device)
+    non_blocking = device.type == "cuda"
+
+    # Move state tensors to device
     transition["state"] = {
-        key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
+        key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
     }
 
-    # Move action to CPU
-    transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
+    # Move action to device
+    transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
 
-    # No need to move reward or done, as they are float and bool
+    # Move reward and done if they are tensors
 
-    # No need to move reward or done, as they are float and bool
     if isinstance(transition["reward"], torch.Tensor):
-        transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
+        transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
 
     if isinstance(transition["done"], torch.Tensor):
-        transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
+        transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
 
     if isinstance(transition["truncated"], torch.Tensor):
-        transition["truncated"] = transition["truncated"].to(device, non_blocking=device.type == "cuda")
+        transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
 
-    # Move next_state tensors to CPU
+    # Move next_state tensors to device
     transition["next_state"] = {
-        key: val.to(device, non_blocking=device.type == "cuda")
-        for key, val in transition["next_state"].items()
+        key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
     }
 
-    # If complementary_info is present, move its tensors to CPU
-    # if transition["complementary_info"] is not None:
-    #     transition["complementary_info"] = {
-    #         key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
-    #     }
+    # Move complementary_info tensors if present
+    if transition.get("complementary_info") is not None:
+        for key, val in transition["complementary_info"].items():
+            if isinstance(val, torch.Tensor):
+                transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
+            elif isinstance(val, (int, float, bool)):
+                transition["complementary_info"][key] = torch.tensor(
+                    val, device=device, non_blocking=non_blocking
+                )
+            else:
+                raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
     return transition
 
 
@@ -217,7 +222,12 @@ class ReplayBuffer:
         self.image_augmentation_function = torch.compile(base_function)
         self.use_drq = use_drq
 
-    def _initialize_storage(self, state: dict[str, torch.Tensor], action: torch.Tensor):
+    def _initialize_storage(
+        self,
+        state: dict[str, torch.Tensor],
+        action: torch.Tensor,
+        complementary_info: Optional[dict[str, torch.Tensor]] = None,
+    ):
         """Initialize the storage tensors based on the first transition."""
         # Determine shapes from the first transition
         state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
@@ -244,6 +254,27 @@ class ReplayBuffer:
 
         self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
         self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
+
+        # Initialize storage for complementary_info
+        self.has_complementary_info = complementary_info is not None
+        self.complementary_info_keys = []
+        self.complementary_info = {}
+
+        if self.has_complementary_info:
+            self.complementary_info_keys = list(complementary_info.keys())
+            # Pre-allocate tensors for each key in complementary_info
+            for key, value in complementary_info.items():
+                if isinstance(value, torch.Tensor):
+                    value_shape = value.squeeze(0).shape
+                    self.complementary_info[key] = torch.empty(
+                        (self.capacity, *value_shape), device=self.storage_device
+                    )
+                elif isinstance(value, (int, float)):
+                    # Handle scalar values similar to reward
+                    self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
+                else:
+                    raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")
+
         self.initialized = True
 
     def __len__(self):
@@ -262,7 +293,7 @@ class ReplayBuffer:
         """Saves a transition, ensuring tensors are stored on the designated storage device."""
         # Initialize storage if this is the first transition
         if not self.initialized:
-            self._initialize_storage(state=state, action=action)
+            self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
 
         # Store the transition in pre-allocated tensors
         for key in self.states:
|
||||||
self.dones[self.position] = done
|
self.dones[self.position] = done
|
||||||
self.truncateds[self.position] = truncated
|
self.truncateds[self.position] = truncated
|
||||||
|
|
||||||
|
# Handle complementary_info if provided and storage is initialized
|
||||||
|
if complementary_info is not None and self.has_complementary_info:
|
||||||
|
# Store the complementary_info
|
||||||
|
for key in self.complementary_info_keys:
|
||||||
|
if key in complementary_info:
|
||||||
|
value = complementary_info[key]
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
|
||||||
|
elif isinstance(value, (int, float)):
|
||||||
|
self.complementary_info[key][self.position] = value
|
||||||
|
|
||||||
self.position = (self.position + 1) % self.capacity
|
self.position = (self.position + 1) % self.capacity
|
||||||
self.size = min(self.size + 1, self.capacity)
|
self.size = min(self.size + 1, self.capacity)
|
||||||
|
|
||||||
|
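A transition's complementary_info is a free-form dict stored alongside state/action/reward: tensor values keep their per-key shape in the pre-allocated storage, while bare int/float scalars get one slot per buffer index, like rewards. A sketch of a valid payload (gripper_penalty is the key actually read by compute_loss_grasp_critic; the other key is illustrative):

import torch

complementary_info = {
    "gripper_penalty": torch.tensor([-0.02]),  # read back during the grasp critic loss
    "grasp_success": 1.0,                      # plain scalar, stored like a reward
}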
@@ -335,6 +377,13 @@ class ReplayBuffer:
         batch_dones = self.dones[idx].to(self.device).float()
         batch_truncateds = self.truncateds[idx].to(self.device).float()
 
+        # Sample complementary_info if available
+        batch_complementary_info = None
+        if self.has_complementary_info:
+            batch_complementary_info = {}
+            for key in self.complementary_info_keys:
+                batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
+
         return BatchTransition(
             state=batch_state,
             action=batch_actions,
@@ -342,8 +391,112 @@ class ReplayBuffer:
             next_state=batch_next_state,
             done=batch_dones,
             truncated=batch_truncateds,
+            complementary_info=batch_complementary_info,
         )
 
+    def get_iterator(
+        self,
+        batch_size: int,
+        async_prefetch: bool = True,
+        queue_size: int = 2,
+    ):
+        """
+        Creates an infinite iterator that yields batches of transitions.
+        Will automatically restart when internal iterator is exhausted.
+
+        Args:
+            batch_size (int): Size of batches to sample
+            async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True)
+            queue_size (int): Number of batches to prefetch (default: 2)
+
+        Yields:
+            BatchTransition: Batched transitions
+        """
+        while True:  # Create an infinite loop
+            if async_prefetch:
+                # Get the standard iterator
+                iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size)
+            else:
+                iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size)
+
+            # Yield all items from the iterator
+            try:
+                yield from iterator
+            except StopIteration:
+                # Just continue the outer loop to create a new iterator
+                pass
+
+    def _get_async_iterator(self, batch_size: int, queue_size: int = 2):
+        """
+        Creates an iterator that prefetches batches in a background thread.
+
+        Args:
+            queue_size (int): Number of batches to prefetch (default: 2)
+            batch_size (int): Size of batches to sample (default: 128)
+
+        Yields:
+            BatchTransition: Prefetched batch transitions
+        """
+        import queue
+        import threading
+
+        # Use thread-safe queue
+        data_queue = queue.Queue(maxsize=queue_size)
+        running = [True]  # Use list to allow modification in nested function
+
+        def prefetch_worker():
+            while running[0]:
+                try:
+                    # Sample data and add to queue
+                    data = self.sample(batch_size)
+                    data_queue.put(data, block=True, timeout=0.5)
+                except queue.Full:
+                    continue
+                except Exception as e:
+                    print(f"Prefetch error: {e}")
+                    break
+
+        # Start prefetching thread
+        thread = threading.Thread(target=prefetch_worker, daemon=True)
+        thread.start()
+
+        try:
+            while running[0]:
+                try:
+                    yield data_queue.get(block=True, timeout=0.5)
+                except queue.Empty:
+                    if not thread.is_alive():
+                        break
+        finally:
+            # Clean up
+            running[0] = False
+            thread.join(timeout=1.0)
+
+    def _get_naive_iterator(self, batch_size: int, queue_size: int = 2):
+        """
+        Creates a simple non-threaded iterator that yields batches.
+
+        Args:
+            batch_size (int): Size of batches to sample
+            queue_size (int): Number of initial batches to prefetch
+
+        Yields:
+            BatchTransition: Batch transitions
+        """
+        import collections
+
+        queue = collections.deque()
+
+        def enqueue(n):
+            for _ in range(n):
+                data = self.sample(batch_size)
+                queue.append(data)
+
+        enqueue(queue_size)
+        while queue:
+            yield queue.popleft()
+            enqueue(1)
+
     @classmethod
     def from_lerobot_dataset(
         cls,
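get_iterator is presumably what the new async_prefetch flag in SACConfig toggles on the learner side: with prefetching enabled, a daemon thread keeps up to queue_size sampled batches queued while the learner runs the previous optimization step; otherwise a simple deque-based iterator is used. A hedged usage sketch for the training loop (buffer and num_updates are placeholders, not names from this commit):

# `buffer` is assumed to be an already-populated ReplayBuffer instance.
iterator = buffer.get_iterator(batch_size=256, async_prefetch=True, queue_size=2)
for _ in range(num_updates):   # num_updates: placeholder for the training-loop length
    batch = next(iterator)     # BatchTransition dict: state, action, reward, next_state, ...
    # ... run critic / grasp_critic / actor / temperature updates on `batch` ...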
@@ -415,7 +568,19 @@ class ReplayBuffer:
         if action_delta is not None:
             first_action = first_action / action_delta
 
-        replay_buffer._initialize_storage(state=first_state, action=first_action)
+        # Get complementary info if available
+        first_complementary_info = None
+        if (
+            "complementary_info" in first_transition
+            and first_transition["complementary_info"] is not None
+        ):
+            first_complementary_info = {
+                k: v.to(device) for k, v in first_transition["complementary_info"].items()
+            }
+
+        replay_buffer._initialize_storage(
+            state=first_state, action=first_action, complementary_info=first_complementary_info
+        )
 
         # Fill the buffer with all transitions
         for data in list_transition:
@@ -443,6 +608,7 @@ class ReplayBuffer:
                 next_state=data["next_state"],
                 done=data["done"],
                 truncated=False,  # NOTE: Truncation are not supported yet in lerobot dataset
+                complementary_info=data.get("complementary_info", None),
             )
 
         return replay_buffer
@@ -484,6 +650,15 @@ class ReplayBuffer:
             f_info = guess_feature_info(t=sample_val, name=key)
             features[key] = f_info
 
+        # Add complementary_info keys if available
+        if self.has_complementary_info:
+            for key in self.complementary_info_keys:
+                sample_val = self.complementary_info[key][0]
+                if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0:
+                    sample_val = sample_val.unsqueeze(0)
+                f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
+                features[f"complementary_info.{key}"] = f_info
+
         # Create an empty LeRobotDataset
         lerobot_dataset = LeRobotDataset.create(
             repo_id=repo_id,
@@ -517,6 +692,19 @@ class ReplayBuffer:
             frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
             frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
 
+            # Add complementary_info if available
+            if self.has_complementary_info:
+                for key in self.complementary_info_keys:
+                    val = self.complementary_info[key][actual_idx]
+                    # Convert tensors to CPU
+                    if isinstance(val, torch.Tensor):
+                        if val.ndim == 0:
+                            val = val.unsqueeze(0)
+                        frame_dict[f"complementary_info.{key}"] = val.cpu()
+                    # Non-tensor values can be used directly
+                    else:
+                        frame_dict[f"complementary_info.{key}"] = val
+
             # Add task field which is required by LeRobotDataset
             frame_dict["task"] = task_name
 
@@ -583,6 +771,10 @@ class ReplayBuffer:
         sample = dataset[0]
         has_done_key = "next.done" in sample
 
+        # Check for complementary_info keys
+        complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
+        has_complementary_info = len(complementary_info_keys) > 0
+
         # If not, we need to infer it from episode boundaries
         if not has_done_key:
             print("'next.done' key not found in dataset. Inferring from episode boundaries...")
@@ -632,6 +824,22 @@ class ReplayBuffer:
                     next_state_data[key] = val.unsqueeze(0)  # Add batch dimension
                 next_state = next_state_data
 
+            # ----- 5) Complementary info (if available) -----
+            complementary_info = None
+            if has_complementary_info:
+                complementary_info = {}
+                for key in complementary_info_keys:
+                    # Strip the "complementary_info." prefix to get the actual key
+                    clean_key = key[len("complementary_info.") :]
+                    val = current_sample[key]
+                    # Handle tensor and non-tensor values differently
+                    if isinstance(val, torch.Tensor):
+                        complementary_info[clean_key] = val.unsqueeze(0)  # Add batch dimension
+                    else:
+                        # TODO: (azouitine) Check if it's necessary to convert to tensor
+                        # For non-tensor values, use directly
+                        complementary_info[clean_key] = val
+
             # ----- Construct the Transition -----
             transition = Transition(
                 state=current_state,
@@ -640,6 +848,7 @@ class ReplayBuffer:
                 next_state=next_state,
                 done=done,
                 truncated=truncated,
+                complementary_info=complementary_info,
             )
             transitions.append(transition)
 
@@ -647,12 +856,13 @@ class ReplayBuffer:


 # Utility function to guess shapes/dtypes from a tensor
-def guess_feature_info(t: torch.Tensor, name: str):
+def guess_feature_info(t, name: str):
     """
-    Return a dictionary with the 'dtype' and 'shape' for a given tensor or array.
+    Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
     If it looks like a 3D (C,H,W) shape, we might consider it an 'image'.
-    Otherwise default to 'float32' for numeric. You can customize as needed.
+    Otherwise default to appropriate dtype for numeric.
     """

     shape = tuple(t.shape)
     # Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
     if len(shape) == 3 and shape[0] in [1, 3]:
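The hunk above only shows the head of `guess_feature_info`; as a rough illustration of the contract its docstring describes (a dict with 'dtype' and 'shape', treating 1- or 3-channel (C, H, W) tensors as images), a standalone sketch could look like this. The exact dtype strings used by the real helper are not visible in this diff and are assumptions.

import torch

def guess_feature_info_sketch(t, name: str):
    # Sketch only: mirror the documented behaviour for tensors and scalar values.
    shape = tuple(t.shape) if isinstance(t, torch.Tensor) else (1,)
    if len(shape) == 3 and shape[0] in (1, 3):
        return {"dtype": "image", "shape": shape}
    return {"dtype": "float32", "shape": shape}

print(guess_feature_info_sketch(torch.zeros(3, 84, 84), "observation.image"))  # treated as an image
print(guess_feature_info_sketch(torch.zeros(10), "observation.state"))         # treated as a float32 vector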
@@ -672,32 +882,33 @@ def concatenate_batch_transitions(
     left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
 ) -> BatchTransition:
     """NOTE: Be careful it change the left_batch_transitions in place"""
+    # Concatenate state fields
     left_batch_transitions["state"] = {
         key: torch.cat(
-            [
-                left_batch_transitions["state"][key],
-                right_batch_transition["state"][key],
-            ],
+            [left_batch_transitions["state"][key], right_batch_transition["state"][key]],
             dim=0,
         )
         for key in left_batch_transitions["state"]
     }

+    # Concatenate basic fields
     left_batch_transitions["action"] = torch.cat(
         [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
     )
     left_batch_transitions["reward"] = torch.cat(
         [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
     )

+    # Concatenate next_state fields
     left_batch_transitions["next_state"] = {
         key: torch.cat(
-            [
-                left_batch_transitions["next_state"][key],
-                right_batch_transition["next_state"][key],
-            ],
+            [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]],
             dim=0,
         )
         for key in left_batch_transitions["next_state"]
     }

+    # Concatenate done and truncated fields
     left_batch_transitions["done"] = torch.cat(
         [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
     )
@@ -705,479 +916,114 @@ def concatenate_batch_transitions(
         [left_batch_transitions["truncated"], right_batch_transition["truncated"]],
         dim=0,
     )

+    # Handle complementary_info
+    left_info = left_batch_transitions.get("complementary_info")
+    right_info = right_batch_transition.get("complementary_info")
+
+    # Only process if right_info exists
+    if right_info is not None:
+        # Initialize left complementary_info if needed
+        if left_info is None:
+            left_batch_transitions["complementary_info"] = right_info
+        else:
+            # Concatenate each field
+            for key in right_info:
+                if key in left_info:
+                    left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0)
+                else:
+                    left_info[key] = right_info[key]
+
     return left_batch_transitions


 if __name__ == "__main__":
-    from tempfile import TemporaryDirectory
-
-    # ===== Test 1: Create and use a synthetic ReplayBuffer =====
-    print("Testing synthetic ReplayBuffer...")
-
-    # Create sample data dimensions
-    batch_size = 32
-    state_dims = {"observation.image": (3, 84, 84), "observation.state": (10,)}
-    action_dim = (6,)
-
-    # Create a buffer
-    buffer = ReplayBuffer(
-        capacity=1000,
-        device="cpu",
-        state_keys=list(state_dims.keys()),
-        use_drq=True,
-        storage_device="cpu",
-    )
-
-    # Add some random transitions
-    for i in range(100):
-        # Create dummy transition data
-        state = {
-            "observation.image": torch.rand(1, 3, 84, 84),
-            "observation.state": torch.rand(1, 10),
-        }
-        action = torch.rand(1, 6)
-        reward = 0.5
-        next_state = {
-            "observation.image": torch.rand(1, 3, 84, 84),
-            "observation.state": torch.rand(1, 10),
-        }
-        done = False if i < 99 else True
-        truncated = False
-
-        buffer.add(
-            state=state,
-            action=action,
-            reward=reward,
-            next_state=next_state,
-            done=done,
-            truncated=truncated,
-        )
-
-    # Test sampling
-    batch = buffer.sample(batch_size)
-    print(f"Buffer size: {len(buffer)}")
-    print(
-        f"Sampled batch state shapes: {batch['state']['observation.image'].shape}, {batch['state']['observation.state'].shape}"
-    )
-    print(f"Sampled batch action shape: {batch['action'].shape}")
-    print(f"Sampled batch reward shape: {batch['reward'].shape}")
-    print(f"Sampled batch done shape: {batch['done'].shape}")
-    print(f"Sampled batch truncated shape: {batch['truncated'].shape}")
-
-    # ===== Test for state-action-reward alignment =====
-    print("\nTesting state-action-reward alignment...")
-
-    # Create a buffer with controlled transitions where we know the relationships
-    aligned_buffer = ReplayBuffer(
-        capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu"
-    )
-
-    # Create transitions with known relationships
-    # - Each state has a unique signature value
-    # - Action is 2x the state signature
-    # - Reward is 3x the state signature
-    # - Next state is signature + 0.01 (unless at episode end)
-    for i in range(100):
-        # Create a state with a signature value that encodes the transition number
-        signature = float(i) / 100.0
-        state = {"state_value": torch.tensor([[signature]]).float()}
-
-        # Action is 2x the signature
-        action = torch.tensor([[2.0 * signature]]).float()
-
-        # Reward is 3x the signature
-        reward = 3.0 * signature
-
-        # Next state is signature + 0.01, unless end of episode
-        # End episode every 10 steps
-        is_end = (i + 1) % 10 == 0
-
-        if is_end:
-            # At episode boundaries, next_state repeats current state (as per your implementation)
-            next_state = {"state_value": torch.tensor([[signature]]).float()}
-            done = True
-        else:
-            # Within episodes, next_state has signature + 0.01
-            next_signature = float(i + 1) / 100.0
-            next_state = {"state_value": torch.tensor([[next_signature]]).float()}
-            done = False
-
-        aligned_buffer.add(state, action, reward, next_state, done, False)
-
-    # Sample from this buffer
-    aligned_batch = aligned_buffer.sample(50)
-
-    # Verify alignments in sampled batch
-    correct_relationships = 0
-    total_checks = 0
-
-    # For each transition in the batch
-    for i in range(50):
-        # Extract signature from state
-        state_sig = aligned_batch["state"]["state_value"][i].item()
-
-        # Check action is 2x signature (within reasonable precision)
-        action_val = aligned_batch["action"][i].item()
-        action_check = abs(action_val - 2.0 * state_sig) < 1e-4
-
-        # Check reward is 3x signature (within reasonable precision)
-        reward_val = aligned_batch["reward"][i].item()
-        reward_check = abs(reward_val - 3.0 * state_sig) < 1e-4
-
-        # Check next_state relationship matches our pattern
-        next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
-        is_done = aligned_batch["done"][i].item() > 0.5
-
-        # Calculate expected next_state value based on done flag
-        if is_done:
-            # For episodes that end, next_state should equal state
-            next_state_check = abs(next_state_sig - state_sig) < 1e-4
-        else:
-            # For continuing episodes, check if next_state is approximately state + 0.01
-            # We need to be careful because we don't know the original index
-            # So we check if the increment is roughly 0.01
-            next_state_check = (
-                abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
-            )
-
-        # Count correct relationships
-        if action_check:
-            correct_relationships += 1
-        if reward_check:
-            correct_relationships += 1
-        if next_state_check:
-            correct_relationships += 1
-
-        total_checks += 3
-
-    alignment_accuracy = 100.0 * correct_relationships / total_checks
-    print(f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%")
-    if alignment_accuracy > 99.0:
-        print("✅ All relationships verified! Buffer maintains correct temporal relationships.")
-    else:
-        print("⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues.")
-
-        # Print some debug information about failures
-        print("\nDebug information for failed checks:")
-        for i in range(5):  # Print first 5 transitions for debugging
-            state_sig = aligned_batch["state"]["state_value"][i].item()
-            action_val = aligned_batch["action"][i].item()
-            reward_val = aligned_batch["reward"][i].item()
-            next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
-            is_done = aligned_batch["done"][i].item() > 0.5
-
-            print(f"Transition {i}:")
-            print(f"  State: {state_sig:.6f}")
-            print(f"  Action: {action_val:.6f} (expected: {2.0 * state_sig:.6f})")
-            print(f"  Reward: {reward_val:.6f} (expected: {3.0 * state_sig:.6f})")
-            print(f"  Done: {is_done}")
-            print(f"  Next state: {next_state_sig:.6f}")
-
-            # Calculate expected next state
-            if is_done:
-                expected_next = state_sig
-            else:
-                # This approximation might not be perfect
-                state_idx = round(state_sig * 100)
-                expected_next = (state_idx + 1) / 100.0
-
-            print(f"  Expected next state: {expected_next:.6f}")
-            print()
-
-    # ===== Test 2: Convert to LeRobotDataset and back =====
-    with TemporaryDirectory() as temp_dir:
-        print("\nTesting conversion to LeRobotDataset and back...")
-        # Convert buffer to dataset
-        repo_id = "test/replay_buffer_conversion"
-        # Create a subdirectory to avoid the "directory exists" error
-        dataset_dir = os.path.join(temp_dir, "dataset1")
-        dataset = buffer.to_lerobot_dataset(repo_id=repo_id, root=dataset_dir)
-
-        print(f"Dataset created with {len(dataset)} frames")
-        print(f"Dataset features: {list(dataset.features.keys())}")
-
-        # Check a random sample from the dataset
-        sample = dataset[0]
-        print(
-            f"Dataset sample types: {[(k, type(v)) for k, v in sample.items() if k.startswith('observation')]}"
-        )
-
-        # Convert dataset back to buffer
-        reconverted_buffer = ReplayBuffer.from_lerobot_dataset(
-            dataset, state_keys=list(state_dims.keys()), device="cpu"
-        )
-        print(f"Reconverted buffer size: {len(reconverted_buffer)}")
-
-        # Sample from the reconverted buffer
-        reconverted_batch = reconverted_buffer.sample(batch_size)
-        print(
-            f"Reconverted batch state shapes: {reconverted_batch['state']['observation.image'].shape}, {reconverted_batch['state']['observation.state'].shape}"
-        )
-
-        # Verify consistency before and after conversion
-        original_states = batch["state"]["observation.image"].mean().item()
-        reconverted_states = reconverted_batch["state"]["observation.image"].mean().item()
-        print(f"Original buffer state mean: {original_states:.4f}")
-        print(f"Reconverted buffer state mean: {reconverted_states:.4f}")
-
-        if abs(original_states - reconverted_states) < 1.0:
-            print("Values are reasonably similar - conversion works as expected")
-        else:
-            print("WARNING: Significant difference between original and reconverted values")
-
-    print("\nAll previous tests completed!")
-
-    # ===== Test for memory optimization =====
-    print("\n===== Testing Memory Optimization =====")
-
-    # Create two buffers, one with memory optimization and one without
-    standard_buffer = ReplayBuffer(
-        capacity=1000,
-        device="cpu",
-        state_keys=["observation.image", "observation.state"],
-        storage_device="cpu",
-        optimize_memory=False,
-        use_drq=True,
-    )
-
-    optimized_buffer = ReplayBuffer(
-        capacity=1000,
-        device="cpu",
-        state_keys=["observation.image", "observation.state"],
-        storage_device="cpu",
-        optimize_memory=True,
-        use_drq=True,
-    )
-
-    # Generate sample data with larger state dimensions for better memory impact
-    print("Generating test data...")
-    num_episodes = 10
-    steps_per_episode = 50
-    total_steps = num_episodes * steps_per_episode
-
-    for episode in range(num_episodes):
-        for step in range(steps_per_episode):
-            # Index in the overall sequence
-            i = episode * steps_per_episode + step
-
-            # Create state with identifiable values
-            img = torch.ones((3, 84, 84)) * (i / total_steps)
-            state_vec = torch.ones((10,)) * (i / total_steps)
-
-            state = {
-                "observation.image": img.unsqueeze(0),
-                "observation.state": state_vec.unsqueeze(0),
-            }
-
-            # Create next state (i+1 or same as current if last in episode)
-            is_last_step = step == steps_per_episode - 1
-
-            if is_last_step:
-                # At episode end, next state = current state
-                next_img = img.clone()
-                next_state_vec = state_vec.clone()
-                done = True
-                truncated = False
-            else:
-                # Within episode, next state has incremented value
-                next_val = (i + 1) / total_steps
-                next_img = torch.ones((3, 84, 84)) * next_val
-                next_state_vec = torch.ones((10,)) * next_val
-                done = False
-                truncated = False
-
-            next_state = {
-                "observation.image": next_img.unsqueeze(0),
-                "observation.state": next_state_vec.unsqueeze(0),
-            }
-
-            # Action and reward
-            action = torch.tensor([[i / total_steps]])
-            reward = float(i / total_steps)
-
-            # Add to both buffers
-            standard_buffer.add(state, action, reward, next_state, done, truncated)
-            optimized_buffer.add(state, action, reward, next_state, done, truncated)
-
-    # Verify episode boundaries with our simplified approach
-    print("\nVerifying simplified memory optimization...")
-
-    # Test with a new buffer with a small sequence
-    test_buffer = ReplayBuffer(
-        capacity=20,
-        device="cpu",
-        state_keys=["value"],
-        storage_device="cpu",
-        optimize_memory=True,
-        use_drq=False,
-    )
-
-    # Add a simple sequence with known episode boundaries
-    for i in range(20):
-        val = float(i)
-        state = {"value": torch.tensor([[val]]).float()}
-        next_val = float(i + 1) if i % 5 != 4 else val  # Episode ends every 5 steps
-        next_state = {"value": torch.tensor([[next_val]]).float()}
-
-        # Set done=True at every 5th step
-        done = (i % 5) == 4
-        action = torch.tensor([[0.0]])
-        reward = 1.0
-        truncated = False
-
-        test_buffer.add(state, action, reward, next_state, done, truncated)
-
-    # Get sequential batch for verification
-    sequential_batch_size = test_buffer.size
-    all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device)
-
-    # Get state tensors
-    batch_state = {"value": test_buffer.states["value"][all_indices].to(test_buffer.device)}
-
-    # Get next_state using memory-optimized approach (simply index+1)
-    next_indices = (all_indices + 1) % test_buffer.capacity
-    batch_next_state = {"value": test_buffer.states["value"][next_indices].to(test_buffer.device)}
-
-    # Get other tensors
-    batch_dones = test_buffer.dones[all_indices].to(test_buffer.device)
-
-    # Print sequential values
-    print("State, Next State, Done (Sequential values with simplified optimization):")
-    state_values = batch_state["value"].squeeze().tolist()
-    next_values = batch_next_state["value"].squeeze().tolist()
-    done_flags = batch_dones.tolist()
-
-    # Print all values
-    for i in range(len(state_values)):
-        print(f"  {state_values[i]:.1f} → {next_values[i]:.1f}, Done: {done_flags[i]}")
-
-    # Explain the memory optimization tradeoff
-    print("\nWith simplified memory optimization:")
-    print("- We always use the next state in the buffer (index+1) as next_state")
-    print("- For terminal states, this means using the first state of the next episode")
-    print("- This is a common tradeoff in RL implementations for memory efficiency")
-    print("- Since we track done flags, the algorithm can handle these transitions correctly")
-
-    # Test random sampling
-    print("\nVerifying random sampling with simplified memory optimization...")
-    random_samples = test_buffer.sample(20)  # Sample all transitions
-
-    # Extract values
-    random_state_values = random_samples["state"]["value"].squeeze().tolist()
-    random_next_values = random_samples["next_state"]["value"].squeeze().tolist()
-    random_done_flags = random_samples["done"].bool().tolist()
-
-    # Print a few samples
-    print("Random samples - State, Next State, Done (First 10):")
-    for i in range(10):
-        print(f"  {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}")
-
-    # Calculate memory savings
-    # Assume optimized_buffer and standard_buffer have already been initialized and filled
-    std_mem = (
-        sum(
-            standard_buffer.states[key].nelement() * standard_buffer.states[key].element_size()
-            for key in standard_buffer.states
-        )
-        * 2
-    )
-    opt_mem = sum(
-        optimized_buffer.states[key].nelement() * optimized_buffer.states[key].element_size()
-        for key in optimized_buffer.states
-    )
-
-    savings_percent = (std_mem - opt_mem) / std_mem * 100
-
-    print("\nMemory optimization result:")
-    print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
-    print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
-    print(f"- Memory savings for state tensors: {savings_percent:.1f}%")
-
-    print("\nAll memory optimization tests completed!")
-
-    # # ===== Test real dataset conversion =====
-    # print("\n===== Testing Real LeRobotDataset Conversion =====")
-    # try:
-    #     # Try to use a real dataset if available
-    #     dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
-    #     dataset = LeRobotDataset(repo_id=dataset_name)
-
-    #     # Print available keys to debug
-    #     sample = dataset[0]
-    #     print("Available keys in dataset:", list(sample.keys()))
-
-    #     # Check for required keys
-    #     if "action" not in sample or "next.reward" not in sample:
-    #         print("Dataset missing essential keys. Cannot convert.")
-    #         raise ValueError("Missing required keys in dataset")
-
-    #     # Auto-detect appropriate state keys
-    #     image_keys = []
-    #     state_keys = []
-    #     for k, v in sample.items():
-    #         # Skip metadata keys and action/reward keys
-    #         if k in {
-    #             "index",
-    #             "episode_index",
-    #             "frame_index",
-    #             "timestamp",
-    #             "task_index",
-    #             "action",
-    #             "next.reward",
-    #             "next.done",
-    #         }:
-    #             continue
-
-    #         # Infer key type from tensor shape
-    #         if isinstance(v, torch.Tensor):
-    #             if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
-    #                 # Likely an image (channels, height, width)
-    #                 image_keys.append(k)
-    #             else:
-    #                 # Likely state or other vector
-    #                 state_keys.append(k)
-
-    #     print(f"Detected image keys: {image_keys}")
-    #     print(f"Detected state keys: {state_keys}")
-
-    #     if not image_keys and not state_keys:
-    #         print("No usable keys found in dataset, skipping further tests")
-    #         raise ValueError("No usable keys found in dataset")
-
-    #     # Test with standard and memory-optimized buffers
-    #     for optimize_memory in [False, True]:
-    #         buffer_type = "Standard" if not optimize_memory else "Memory-optimized"
-    #         print(f"\nTesting {buffer_type} buffer with real dataset...")
-
-    #         # Convert to ReplayBuffer with detected keys
-    #         replay_buffer = ReplayBuffer.from_lerobot_dataset(
-    #             lerobot_dataset=dataset,
-    #             state_keys=image_keys + state_keys,
-    #             device="cpu",
-    #             optimize_memory=optimize_memory,
-    #         )
-    #         print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
-
-    #         # Test sampling
-    #         real_batch = replay_buffer.sample(32)
-    #         print(f"Sampled batch from real dataset ({buffer_type}), state shapes:")
-    #         for key in real_batch["state"]:
-    #             print(f"  {key}: {real_batch['state'][key].shape}")
-
-    #         # Convert back to LeRobotDataset
-    #         with TemporaryDirectory() as temp_dir:
-    #             dataset_name = f"test/real_dataset_converted_{buffer_type}"
-    #             replay_buffer_converted = replay_buffer.to_lerobot_dataset(
-    #                 repo_id=dataset_name,
-    #                 root=os.path.join(temp_dir, f"dataset_{buffer_type}"),
-    #             )
-    #             print(
-    #                 f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
-    #             )
-
-    # except Exception as e:
-    #     print(f"Real dataset test failed: {e}")
-    #     print("This is expected if running offline or if the dataset is not available.")
-
-    # print("\nAll tests completed!")
+    def test_load_dataset_with_complementary_info():
+        """
+        Test loading a dataset with complementary_info into a ReplayBuffer.
+        The dataset 'aractingi/pick_lift_cube_two_cameras_gripper_penalty' contains
+        gripper_penalty values in complementary_info.
+        """
+        import time
+
+        from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+        print("Loading dataset with complementary info...")
+        # Load a small subset of the dataset (first episode)
+        dataset = LeRobotDataset(
+            repo_id="aractingi/pick_lift_cube_two_cameras_gripper_penalty",
+        )
+
+        print(f"Dataset loaded with {len(dataset)} frames")
+        print(f"Dataset features: {list(dataset.features.keys())}")
+
+        # Check if dataset has complementary_info.gripper_penalty
+        sample = dataset[0]
+        complementary_info_keys = [key for key in sample if key.startswith("complementary_info")]
+        print(f"Complementary info keys: {complementary_info_keys}")
+
+        if "complementary_info.gripper_penalty" in sample:
+            print(f"Found gripper_penalty: {sample['complementary_info.gripper_penalty']}")
+
+        # Extract state keys for the buffer
+        state_keys = []
+        for key in sample:
+            if key.startswith("observation"):
+                state_keys.append(key)
+
+        print(f"Using state keys: {state_keys}")
+
+        # Create a replay buffer from the dataset
+        start_time = time.time()
+        buffer = ReplayBuffer.from_lerobot_dataset(
+            lerobot_dataset=dataset, state_keys=state_keys, use_drq=True, optimize_memory=False
+        )
+        load_time = time.time() - start_time
+        print(f"Loaded dataset into buffer in {load_time:.2f} seconds")
+        print(f"Buffer size: {len(buffer)}")
+
+        # Check if complementary_info was transferred correctly
+        print("Sampling from buffer to check complementary_info...")
+        batch = buffer.sample(batch_size=4)
+
+        if batch["complementary_info"] is not None:
+            print("Complementary info in batch:")
+            for key, value in batch["complementary_info"].items():
+                print(f"  {key}: {type(value)}, shape: {value.shape if hasattr(value, 'shape') else 'N/A'}")
+                if key == "gripper_penalty":
+                    print(f"  Sample gripper_penalty values: {value[:5]}")
+        else:
+            print("No complementary_info found in batch")
+
+        # Now convert the buffer back to a LeRobotDataset
+        print("\nConverting buffer back to LeRobotDataset...")
+        start_time = time.time()
+        new_dataset = buffer.to_lerobot_dataset(
+            repo_id="test_dataset_from_buffer",
+            fps=dataset.fps,
+            root="./test_dataset_from_buffer",
+            task_name="test_conversion",
+        )
+        convert_time = time.time() - start_time
+        print(f"Converted buffer to dataset in {convert_time:.2f} seconds")
+        print(f"New dataset size: {len(new_dataset)} frames")
+
+        # Check if complementary_info was preserved
+        new_sample = new_dataset[0]
+        new_complementary_info_keys = [key for key in new_sample if key.startswith("complementary_info")]
+        print(f"New dataset complementary info keys: {new_complementary_info_keys}")
+
+        if "complementary_info.gripper_penalty" in new_sample:
+            print(f"Found gripper_penalty in new dataset: {new_sample['complementary_info.gripper_penalty']}")
+
+        # Compare original and new datasets
+        print("\nComparing original and new datasets:")
+        print(f"Original dataset frames: {len(dataset)}, New dataset frames: {len(new_dataset)}")
+        print(f"Original features: {list(dataset.features.keys())}")
+        print(f"New features: {list(new_dataset.features.keys())}")
+
+        return buffer, dataset, new_dataset
+
+    # Run the test
+    test_load_dataset_with_complementary_info()
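As a quick illustration of the in-place concatenation defined in the hunks above (including the new complementary_info handling), here is a hedged sketch with dummy tensors; the exact BatchTransition layout beyond the keys shown in this diff is assumed.

import torch

def make_dummy_batch(n=4):
    # Dummy BatchTransition-like dict using only the keys handled above.
    return {
        "state": {"observation.state": torch.zeros(n, 10)},
        "action": torch.zeros(n, 6),
        "reward": torch.zeros(n),
        "next_state": {"observation.state": torch.zeros(n, 10)},
        "done": torch.zeros(n, dtype=torch.bool),
        "truncated": torch.zeros(n, dtype=torch.bool),
        "complementary_info": {"gripper_penalty": torch.zeros(n)},
    }

online, offline = make_dummy_batch(4), make_dummy_batch(4)
merged = concatenate_batch_transitions(online, offline)  # note: mutates `online` in place
assert merged["action"].shape[0] == 8
assert merged["complementary_info"]["gripper_penalty"].shape[0] == 8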
@@ -761,6 +761,62 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
         return observation


+class GripperPenaltyWrapper(gym.RewardWrapper):
+    def __init__(self, env, penalty: float = -0.1):
+        super().__init__(env)
+        self.penalty = penalty
+        self.last_gripper_state = None
+
+    def reward(self, reward, action):
+        gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
+
+        if isinstance(action, tuple):
+            action = action[0]
+        action_normalized = action[-1] / MAX_GRIPPER_COMMAND
+
+        gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
+            gripper_state_normalized > 0.9 and action_normalized < 0.1
+        )
+        breakpoint()
+
+        return reward + self.penalty * gripper_penalty_bool
+
+    def step(self, action):
+        self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        reward = self.reward(reward, action)
+        return obs, reward, terminated, truncated, info
+
+    def reset(self, **kwargs):
+        self.last_gripper_state = None
+        return super().reset(**kwargs)
+
+
+class GripperQuantizationWrapper(gym.ActionWrapper):
+    def __init__(self, env, quantization_threshold: float = 0.2):
+        super().__init__(env)
+        self.quantization_threshold = quantization_threshold
+
+    def action(self, action):
+        is_intervention = False
+        if isinstance(action, tuple):
+            action, is_intervention = action
+
+        gripper_command = action[-1]
+        # Quantize gripper command to -1, 0 or 1
+        if gripper_command < -self.quantization_threshold:
+            gripper_command = -MAX_GRIPPER_COMMAND
+        elif gripper_command > self.quantization_threshold:
+            gripper_command = MAX_GRIPPER_COMMAND
+        else:
+            gripper_command = 0.0
+
+        gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+        gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
+        action[-1] = gripper_action.item()
+        return action, is_intervention
+
+
 class EEActionWrapper(gym.ActionWrapper):
     def __init__(self, env, ee_action_space_params=None, use_gripper=False):
         super().__init__(env)
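For context, a hedged sketch of how the two new wrappers could be stacked on a real-robot env. Both wrappers read the follower arm's present position via `env.unwrapped.robot`, so they only make sense on top of a robot-backed environment; the helper name and default values below are placeholders, not part of this diff.

import gymnasium as gym

def add_gripper_wrappers(env: gym.Env, quantization_threshold: float = 0.2, penalty: float = -0.1) -> gym.Env:
    # Sketch: quantize the continuous gripper command to close/stay/open first,
    # then (optionally) shape the reward to discourage pointless open/close flips.
    env = GripperQuantizationWrapper(env, quantization_threshold=quantization_threshold)
    env = GripperPenaltyWrapper(env, penalty=penalty)
    return env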
@@ -820,17 +876,7 @@ class EEActionWrapper(gym.ActionWrapper):
             fk_func=self.fk_function,
         )
         if self.use_gripper:
-            # Quantize gripper command to -1, 0 or 1
-            if gripper_command < -0.2:
-                gripper_command = -1.0
-            elif gripper_command > 0.2:
-                gripper_command = 1.0
-            else:
-                gripper_command = 0.0
-
-            gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
-            gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
-            target_joint_pos[-1] = gripper_action
+            target_joint_pos[-1] = gripper_command

         return target_joint_pos, is_intervention

@@ -1118,6 +1164,12 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     # Add reward computation and control wrappers
     # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
     env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
+    if cfg.wrapper.use_gripper:
+        env = GripperQuantizationWrapper(
+            env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
+        )
+        # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
+
     if cfg.wrapper.ee_action_space_params is not None:
         env = EEActionWrapper(
             env=env,
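A hedged sketch of the wrapper-side settings this branch reads; the field names follow the `cfg.wrapper.*` accesses above, while the concrete values are placeholders rather than recommendations.

# Illustrative wrapper configuration fragment (values are placeholders).
wrapper_cfg = {
    "use_gripper": True,                    # enables GripperQuantizationWrapper above
    "gripper_quantization_threshold": 0.2,  # |command| below this is treated as "hold"
    "gripper_penalty": -0.1,                # only used if GripperPenaltyWrapper is re-enabled
    "control_time_s": 20.0,
}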
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+# !/usr/bin/env python

 # Copyright 2024 The HuggingFace Inc. team.
 # All rights reserved.
@@ -269,6 +269,7 @@ def add_actor_information_and_train(
     policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
     saving_checkpoint = cfg.save_checkpoint
     online_steps = cfg.policy.online_steps
+    async_prefetch = cfg.policy.async_prefetch

     # Initialize logging for multiprocessing
     if not use_threads(cfg):
@@ -326,6 +327,9 @@ def add_actor_information_and_train(
     if cfg.dataset is not None:
         dataset_repo_id = cfg.dataset.repo_id

+    # Initialize iterators
+    online_iterator = None
+    offline_iterator = None
     # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
     while True:
         # Exit the training loop if shutdown is requested
@@ -359,13 +363,26 @@ def add_actor_information_and_train(
         if len(replay_buffer) < online_step_before_learning:
             continue

+        if online_iterator is None:
+            logging.debug("[LEARNER] Initializing online replay buffer iterator")
+            online_iterator = replay_buffer.get_iterator(
+                batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
+            )
+
+        if offline_replay_buffer is not None and offline_iterator is None:
+            logging.debug("[LEARNER] Initializing offline replay buffer iterator")
+            offline_iterator = offline_replay_buffer.get_iterator(
+                batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
+            )
+
         logging.debug("[LEARNER] Starting optimization loop")
         time_for_one_optimization_step = time.time()
         for _ in range(utd_ratio - 1):
-            batch = replay_buffer.sample(batch_size=batch_size)
+            # Sample from the iterators
+            batch = next(online_iterator)

             if dataset_repo_id is not None:
-                batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
+                batch_offline = next(offline_iterator)
                 batch = concatenate_batch_transitions(
                     left_batch_transitions=batch, right_batch_transition=batch_offline
                 )
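A hedged sketch of the iterator pattern introduced here: `get_iterator` is used as an endless batch generator that can optionally fill a small prefetch queue on a background thread. The internals are not shown in this diff, so the implementation below is an approximation, not the repo's code.

import queue
import threading

def batch_iterator(sample_fn, batch_size, async_prefetch=True, queue_size=2):
    # Endless batch iterator with optional background prefetching.
    if not async_prefetch:
        while True:
            yield sample_fn(batch_size=batch_size)
    prefetch_queue = queue.Queue(maxsize=queue_size)

    def producer():
        while True:
            prefetch_queue.put(sample_fn(batch_size=batch_size))  # blocks when the queue is full

    threading.Thread(target=producer, daemon=True).start()
    while True:
        yield prefetch_queue.get()

# Usage sketch: iterator = batch_iterator(replay_buffer.sample, batch_size=256); batch = next(iterator)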
@@ -392,24 +409,37 @@ def add_actor_information_and_train(
                 "next_observation_feature": next_observation_features,
             }

-            # Use the forward method for critic loss
-            loss_critic = policy.forward(forward_batch, model="critic")
+            # Use the forward method for critic loss (includes both main critic and grasp critic)
+            critic_output = policy.forward(forward_batch, model="critic")
+
+            # Main critic optimization
+            loss_critic = critic_output["loss_critic"]
             optimizers["critic"].zero_grad()
             loss_critic.backward()
-
-            # clip gradients
             critic_grad_norm = torch.nn.utils.clip_grad_norm_(
                 parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
             )
-
             optimizers["critic"].step()

+            # Grasp critic optimization (if available)
+            if policy.config.num_discrete_actions is not None:
+                discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
+                loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
+                optimizers["grasp_critic"].zero_grad()
+                loss_grasp_critic.backward()
+                grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                    parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
+                )
+                optimizers["grasp_critic"].step()
+
+            # Update target networks
             policy.update_target_networks()

-        batch = replay_buffer.sample(batch_size=batch_size)
+        # Sample for the last update in the UTD ratio
+        batch = next(online_iterator)

         if dataset_repo_id is not None:
-            batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
+            batch_offline = next(offline_iterator)
             batch = concatenate_batch_transitions(
                 left_batch_transitions=batch, right_batch_transition=batch_offline
             )
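To make the discrete-gripper branch concrete, here is a hedged, self-contained sketch of a DQN-style grasp-critic update on top of precomputed features. The names `grasp_critic` and `num_discrete_actions` follow the diff; the network shape, target computation, and hyperparameters are assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Sketch: a small Q-head over encoded observations scoring each discrete gripper action
# (e.g. close / stay / open), trained with a one-step TD target.
num_discrete_actions, feat_dim, gamma = 3, 64, 0.99
grasp_critic = nn.Linear(feat_dim, num_discrete_actions)
optimizer = torch.optim.Adam(grasp_critic.parameters(), lr=3e-4)

feats = torch.randn(8, feat_dim)        # encoder output for o_t (placeholder)
next_feats = torch.randn(8, feat_dim)   # encoder output for o_{t+1} (placeholder)
gripper_action = torch.randint(0, num_discrete_actions, (8,))
reward, done = torch.randn(8), torch.zeros(8)

q_pred = grasp_critic(feats).gather(1, gripper_action.unsqueeze(1)).squeeze(1)
with torch.no_grad():
    td_target = reward + gamma * (1.0 - done) * grasp_critic(next_feats).max(dim=1).values

loss_grasp_critic = F.mse_loss(q_pred, td_target)
optimizer.zero_grad()
loss_grasp_critic.backward()
torch.nn.utils.clip_grad_norm_(grasp_critic.parameters(), max_norm=10.0)
optimizer.step()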
@@ -437,63 +467,80 @@ def add_actor_information_and_train(
             "next_observation_feature": next_observation_features,
         }

-        # Use the forward method for critic loss
-        loss_critic = policy.forward(forward_batch, model="critic")
+        # Use the forward method for critic loss (includes both main critic and grasp critic)
+        critic_output = policy.forward(forward_batch, model="critic")
+
+        # Main critic optimization
+        loss_critic = critic_output["loss_critic"]
         optimizers["critic"].zero_grad()
         loss_critic.backward()
-
-        # clip gradients
         critic_grad_norm = torch.nn.utils.clip_grad_norm_(
             parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
         ).item()
-
         optimizers["critic"].step()

-        training_infos = {}
-        training_infos["loss_critic"] = loss_critic.item()
-        training_infos["critic_grad_norm"] = critic_grad_norm
+        # Initialize training info dictionary
+        training_infos = {
+            "loss_critic": loss_critic.item(),
+            "critic_grad_norm": critic_grad_norm,
+        }
+
+        # Grasp critic optimization (if available)
+        if policy.config.num_discrete_actions is not None:
+            discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
+            loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
+            optimizers["grasp_critic"].zero_grad()
+            loss_grasp_critic.backward()
+            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
+            ).item()
+            optimizers["grasp_critic"].step()
+
+            # Add grasp critic info to training info
+            training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
+            training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
+
+        # Actor and temperature optimization (at specified frequency)
         if optimization_step % policy_update_freq == 0:
             for _ in range(policy_update_freq):
-                # Use the forward method for actor loss
-                loss_actor = policy.forward(forward_batch, model="actor")
+                # Actor optimization
+                actor_output = policy.forward(forward_batch, model="actor")
+                loss_actor = actor_output["loss_actor"]
                 optimizers["actor"].zero_grad()
                 loss_actor.backward()
-
-                # clip gradients
                 actor_grad_norm = torch.nn.utils.clip_grad_norm_(
                     parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
                 ).item()
-
                 optimizers["actor"].step()

+                # Add actor info to training info
                 training_infos["loss_actor"] = loss_actor.item()
                 training_infos["actor_grad_norm"] = actor_grad_norm

-                # Temperature optimization using forward method
-                loss_temperature = policy.forward(forward_batch, model="temperature")
+                # Temperature optimization
+                temperature_output = policy.forward(forward_batch, model="temperature")
+                loss_temperature = temperature_output["loss_temperature"]
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()
-
-                # clip gradients
                 temp_grad_norm = torch.nn.utils.clip_grad_norm_(
                     parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
                 ).item()
-
                 optimizers["temperature"].step()

+                # Add temperature info to training info
                 training_infos["loss_temperature"] = loss_temperature.item()
                 training_infos["temperature_grad_norm"] = temp_grad_norm
                 training_infos["temperature"] = policy.temperature

+                # Update temperature
                 policy.update_temperature()

-        # Check if it's time to push updated policy to actors
+        # Push policy to actors if needed
         if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
             push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
             last_time_policy_pushed = time.time()

+        # Update target networks
         policy.update_target_networks()

         # Log training metrics at specified intervals
@@ -697,7 +744,7 @@ def save_training_checkpoint(
     logging.info("Resume training")


-def make_optimizers_and_scheduler(cfg, policy: nn.Module):
+def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     """
     Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.

@@ -728,7 +775,14 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
         params=policy.actor.parameters_to_optimize,
         lr=cfg.policy.actor_lr,
     )
-    optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
+    optimizer_critic = torch.optim.Adam(
+        params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
+    )
+
+    if cfg.policy.num_discrete_actions is not None:
+        optimizer_grasp_critic = torch.optim.Adam(
+            params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
+        )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
     optimizers = {
@@ -736,6 +790,8 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
         "critic": optimizer_critic,
         "temperature": optimizer_temperature,
     }
+    if cfg.policy.num_discrete_actions is not None:
+        optimizers["grasp_critic"] = optimizer_grasp_critic
     return optimizers, lr_scheduler


@@ -970,12 +1026,8 @@ def get_observation_features(
         return None, None

     with torch.no_grad():
-        observation_features = (
-            policy.actor.encoder(observations) if policy.actor.encoder is not None else None
-        )
-        next_observation_features = (
-            policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
-        )
+        observation_features = policy.actor.encoder.get_image_features(observations)
+        next_observation_features = policy.actor.encoder.get_image_features(next_observations)

     return observation_features, next_observation_features

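The change above assumes a shared image encoder whose features can be computed once per batch and reused by the actor and the critics. A hedged sketch of that caching pattern follows; the encoder architecture and the `get_image_features` body are illustrative, not the actual modules in this repo.

import torch
import torch.nn as nn

class SharedEncoderSketch(nn.Module):
    # Sketch: a frozen CNN backbone whose per-batch outputs can be passed around
    # (e.g. via forward_batch) so actor and critics reuse the same features.
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 16, 3, stride=2), nn.ReLU(), nn.Flatten())
        for p in self.backbone.parameters():
            p.requires_grad_(False)

    def get_image_features(self, observations: dict) -> torch.Tensor:
        with torch.no_grad():
            return self.backbone(observations["observation.image"])

encoder = SharedEncoderSketch()
features = encoder.get_image_features({"observation.image": torch.rand(4, 3, 64, 64)})  # computed once per batch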
@@ -1,5 +1,3 @@
-import logging
-import time
 from typing import Any

 import einops
@@ -10,7 +8,6 @@ from mani_skill.utils.wrappers.record import RecordEpisode
 from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

 from lerobot.common.envs.configs import ManiskillEnvConfig
-from lerobot.configs import parser


 def preprocess_maniskill_observation(
@@ -153,6 +150,27 @@ class TimeLimitWrapper(gym.Wrapper):
         return super().reset(seed=seed, options=options)


+class ManiskillMockGripperWrapper(gym.Wrapper):
+    def __init__(self, env, nb_discrete_actions: int = 3):
+        super().__init__(env)
+        new_shape = env.action_space[0].shape[0] + 1
+        new_low = np.concatenate([env.action_space[0].low, [0]])
+        new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]])
+        action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,))
+        self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
+
+    def step(self, action):
+        if isinstance(action, tuple):
+            action_agent, telop_action = action
+        else:
+            telop_action = 0
+            action_agent = action
+        real_action = action_agent[:-1]
+        final_action = (real_action, telop_action)
+        obs, reward, terminated, truncated, info = self.env.step(final_action)
+        return obs, reward, terminated, truncated, info
+
+
 def make_maniskill(
     cfg: ManiskillEnvConfig,
     n_envs: int | None = None,
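A hedged sketch of what the extra action slot added by ManiskillMockGripperWrapper looks like from the caller's side; the base action dimensions are placeholders and ManiSkill itself is not imported here.

import numpy as np

# The wrapper appends one slot for a discrete gripper command to the continuous agent
# action and strips it off again before the underlying env sees the action, so the
# policy keeps a gripper-aware interface while the simulated task ignores it.
base_action = np.zeros(6, dtype=np.float32)   # placeholder arm command
gripper_command = 1.0                         # one of {0, 1, 2}
padded_action = np.concatenate([base_action, [gripper_command]])
assert padded_action.shape == (7,)
# env.step((padded_action, telop_action)) would forward padded_action[:-1] to ManiSkill.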
@@ -197,40 +215,42 @@ def make_maniskill(
     env = ManiSkillCompat(env)
     env = ManiSkillActionWrapper(env)
     env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)  # Scale actions for better control
+    if cfg.mock_gripper:
+        env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)

     return env


-@parser.wrap()
-def main(cfg: ManiskillEnvConfig):
-    """Main function to run the ManiSkill environment."""
-    # Create the ManiSkill environment
-    env = make_maniskill(cfg, n_envs=1)
-
-    # Reset the environment
-    obs, info = env.reset()
-
-    # Run a simple interaction loop
-    sum_reward = 0
-    for i in range(100):
-        # Sample a random action
-        action = env.action_space.sample()
-
-        # Step the environment
-        start_time = time.perf_counter()
-        obs, reward, terminated, truncated, info = env.step(action)
-        step_time = time.perf_counter() - start_time
-        sum_reward += reward
-        # Log information
-
-        # Reset if episode terminated
-        if terminated or truncated:
-            logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
-            sum_reward = 0
-            obs, info = env.reset()
-
-    # Close the environment
-    env.close()
+# @parser.wrap()
+# def main(cfg: TrainPipelineConfig):
+#     """Main function to run the ManiSkill environment."""
+#     # Create the ManiSkill environment
+#     env = make_maniskill(cfg.env, n_envs=1)
+
+#     # Reset the environment
+#     obs, info = env.reset()
+
+#     # Run a simple interaction loop
+#     sum_reward = 0
+#     for i in range(100):
+#         # Sample a random action
+#         action = env.action_space.sample()
+
+#         # Step the environment
+#         start_time = time.perf_counter()
+#         obs, reward, terminated, truncated, info = env.step(action)
+#         step_time = time.perf_counter() - start_time
+#         sum_reward += reward
+#         # Log information
+
+#         # Reset if episode terminated
+#         if terminated or truncated:
+#             logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
+#             sum_reward = 0
+#             obs, info = env.reset()
+
+#     # Close the environment
+#     env.close()


 # if __name__ == "__main__":