Added normalization schemes and style checks

This commit is contained in:
Michel Aractingi 2024-12-29 12:51:21 +00:00
parent 08ec971086
commit dc54d357ca
10 changed files with 150 additions and 156 deletions

View File

@@ -25,13 +25,13 @@ from glob import glob
 from pathlib import Path
 import torch
-import wandb
 from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
+import wandb
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import get_global_random_state, set_global_random_state

View File

@@ -2,8 +2,6 @@ import json
 import os
 from dataclasses import asdict, dataclass
-import torch
 @dataclass
 class ClassifierConfig:

View File

@@ -23,9 +23,11 @@ class ClassifierOutput:
         self.hidden_states = hidden_states
     def __repr__(self):
-        return (f"ClassifierOutput(logits={self.logits}, "
-                f"probabilities={self.probabilities}, "
-                f"hidden_states={self.hidden_states})")
+        return (
+            f"ClassifierOutput(logits={self.logits}, "
+            f"probabilities={self.probabilities}, "
+            f"hidden_states={self.hidden_states})"
+        )
 class Classifier(
@@ -74,7 +76,7 @@ class Classifier(
             self.feature_dim = self.encoder.config.hidden_sizes[-1]  # Last channel dimension
         else:
             raise ValueError("Unsupported CNN architecture")
         self.encoder = self.encoder.to(self.config.device)
     def _freeze_encoder(self) -> None:

View File

@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # Copyright 2024 The HuggingFace Inc. team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # Copyright 2024 The HuggingFace Inc. team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,4 +26,4 @@ class HILSerlPolicy(
     repo_url="https://github.com/huggingface/lerobot",
     tags=["robotics", "hilserl"],
 ):
     pass

View File

@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # Copyright 2024 The HuggingFace Inc. team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 @dataclass
@@ -30,14 +30,36 @@ class SACConfig:
     critic_target_update_weight = 0.005
     utd_ratio = 2
     critic_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
     }
     actor_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
     }
     policy_kwargs = {
         "tanh_squash_distribution": True,
         "std_parameterization": "uniform",
     }
+    input_shapes: dict[str, list[int]] = field(
+        default_factory=lambda: {
+            "observation.image": [3, 84, 84],
+            "observation.state": [4],
+        }
+    )
+    output_shapes: dict[str, list[int]] = field(
+        default_factory=lambda: {
+            "action": [4],
+        }
+    )
+    state_encoder_hidden_dim: int = 256
+    latent_dim: int = 256
+    network_hidden_dims: int = 256
+    # Normalization / Unnormalization
+    input_normalization_modes: dict[str, str] | None = None
+    output_normalization_modes: dict[str, str] = field(
+        default_factory=lambda: {"action": "min_max"},
+    )
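For reference, the "min_max" mode used for the "action" output normalizes against dataset min/max statistics and, as in other LeRobot policies, maps values to [-1, 1]. A minimal, self-contained sketch of the corresponding unnormalization (the function name and stats values are illustrative, not this repository's helper):

```python
# Sketch of min_max unnormalization for the "action" output; names and values are illustrative.
import torch

def unnormalize_min_max(action: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
    # map a normalized action in [-1, 1] back to [min, max]
    return (action + 1) / 2 * (stats["max"] - stats["min"]) + stats["min"]

stats = {"min": torch.full((4,), -0.5), "max": torch.full((4,), 0.5)}
print(unnormalize_min_max(torch.zeros(4), stats))  # midpoint of the range, i.e. all zeros here
```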

View File

@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # Copyright 2024 The HuggingFace Inc. team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -40,11 +40,9 @@ class SACPolicy(
     repo_url="https://github.com/huggingface/lerobot",
     tags=["robotics", "RL", "SAC"],
 ):
     def __init__(
         self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
     ):
         super().__init__()
         if config is None:
@@ -67,12 +65,9 @@ class SACPolicy(
         # Define networks
         critic_nets = []
         for _ in range(config.num_critics):
-            critic_net = Critic(
-                encoder=encoder,
-                network=MLP(**config.critic_network_kwargs)
-            )
+            critic_net = Critic(encoder=encoder, network=MLP(**config.critic_network_kwargs))
             critic_nets.append(critic_net)
         self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
         self.critic_target = deepcopy(self.critic_ensemble)
@@ -80,11 +75,11 @@ class SACPolicy(
             encoder=encoder,
             network=MLP(**config.actor_network_kwargs),
             action_dim=config.output_shapes["action"][0],
-            **config.policy_kwargs
+            **config.policy_kwargs,
         )
         if config.target_entropy is None:
             config.target_entropy = -np.prod(config.output_shapes["action"][0])  # (-dim(A))
         self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
     def reset(self):
         """
@@ -100,10 +95,10 @@ class SACPolicy(
         self._queues["observation.image"] = deque(maxlen=1)
         if self._use_env_state:
             self._queues["observation.environment_state"] = deque(maxlen=1)
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
-        actions, _ = self.actor_network(batch['observations'])###
+        actions, _ = self.actor_network(batch["observations"])  ###
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
         """Run the batch through the model and compute the loss.
@@ -111,8 +106,8 @@ class SACPolicy(
         Returns a dictionary with loss as a tensor, and other information as native floats.
         """
         batch = self.normalize_inputs(batch)
         # batch shape is (b, 2, ...) where index 1 returns the current observation and
         # the next observation for calculating the right td index.
         actions = batch["action"][:, 0]
         rewards = batch["next.reward"][:, 0]
         observations = {}
@@ -121,13 +116,12 @@ class SACPolicy(
             if k.startswith("observation."):
                 observations[k] = batch[k][:, 0]
                 next_observations[k] = batch[k][:, 1]
         # perform image augmentation
         # reward bias
         # from HIL-SERL code base
         # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
         # calculate critics loss
         # 1- compute actions from policy
@@ -137,7 +131,7 @@ class SACPolicy(
         # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
         if self.config.num_subsample_critics is not None:
             indices = torch.randperm(self.config.num_critics)
-            indices = indices[:self.config.num_subsample_critics]
+            indices = indices[: self.config.num_subsample_critics]
             q_targets = q_targets[indices]
         # critics subsample size
@@ -151,8 +145,9 @@ class SACPolicy(
         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
         critics_loss = (
-            F.mse_loss(
+            (
+                F.mse_loss(
                     q_preds,
                     einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
                     reduction="none",
@@ -163,15 +158,17 @@ class SACPolicy(
                 # q_targets depends on the reward and the next observations.
                 * ~batch["next.reward_is_pad"]
                 * ~batch["observation.state_is_pad"][1:]
-        ).sum(0).mean()
+            )
+            .sum(0)
+            .mean()
+        )
         # calculate actors loss
         # 1- temperature
         temperature = self.temperature()
         # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        actions, log_probs = self.actor_network(observations) \
+        actions, log_probs = self.actor_network(observations)
         # 3- get q-value predictions
         with torch.no_grad():
             q_preds = self.critic_ensemble(observations, actions, return_type="mean")
@@ -181,36 +178,31 @@ class SACPolicy(
             * ~batch["action_is_pad"]
         ).mean()
         # calculate temperature loss
         # 1- calculate entropy
         entropy = -log_probs.mean()
-        temperature_loss = self.temp(
-            lhs=entropy,
-            rhs=self.config.target_entropy
-        )
+        temperature_loss = self.temp(lhs=entropy, rhs=self.config.target_entropy)
         loss = critics_loss + actor_loss + temperature_loss
         return {
             "critics_loss": critics_loss.item(),
             "actor_loss": actor_loss.item(),
             "temperature_loss": temperature_loss.item(),
             "temperature": temperature.item(),
             "entropy": entropy.item(),
             "loss": loss,
         }
     def update(self):
         self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
         # TODO: implement UTD update
         # First update only critics for utd_ratio-1 times
-        #for critic_step in range(self.config.utd_ratio - 1):
+        # for critic_step in range(self.config.utd_ratio - 1):
         # only update critic and critic target
         # Then update critic, critic target, actor and temperature
-        #for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
+        # for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
         #    target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
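For readers unfamiliar with the soft target update in update(), the lerp_-based line is meant to perform per-parameter Polyak averaging, target <- (1 - tau) * target + tau * online. A small, self-contained sketch of that operation (the modules and tau value below are illustrative, not taken from this commit):

```python
# Polyak (soft) target update sketch; tau plays the role of critic_target_update_weight.
import torch
import torch.nn as nn

online = nn.Linear(4, 2)
target = nn.Linear(4, 2)
target.load_state_dict(online.state_dict())

tau = 0.005
with torch.no_grad():
    for target_param, param in zip(target.parameters(), online.parameters()):
        # target <- (1 - tau) * target + tau * online
        target_param.lerp_(param, tau)
```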
@@ -225,24 +217,28 @@ class MLP(nn.Module):
         super().__init__()
         self.activate_final = config.activate_final
         layers = []
         for i, size in enumerate(config.network_hidden_dims):
-            layers.append(nn.Linear(config.network_hidden_dims[i-1] if i > 0 else config.network_hidden_dims[0], size))
+            layers.append(
+                nn.Linear(config.network_hidden_dims[i - 1] if i > 0 else config.network_hidden_dims[0], size)
+            )
             if i + 1 < len(config.network_hidden_dims) or activate_final:
                 if dropout_rate is not None and dropout_rate > 0:
                     layers.append(nn.Dropout(p=dropout_rate))
                 layers.append(nn.LayerNorm(size))
-                layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
+                layers.append(
+                    activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
+                )
         self.net = nn.Sequential(*layers)
     def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
         # in training mode or not. TODO: find better way to do this
         self.train(train)
         return self.net(x)
 class Critic(nn.Module):
     def __init__(
         self,
@@ -250,7 +246,7 @@ class Critic(nn.Module):
         network: nn.Module,
         init_final: Optional[float] = None,
         activate_final: bool = False,
-        device: str = "cuda"
+        device: str = "cuda",
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -258,7 +254,7 @@ class Critic(nn.Module):
         self.network = network
         self.init_final = init_final
         self.activate_final = activate_final
         # Output layer
         if init_final is not None:
             if self.activate_final:
@@ -273,36 +269,28 @@ class Critic(nn.Module):
         else:
             self.output_layer = nn.Linear(network.net[-2].out_features, 1)
             orthogonal_init()(self.output_layer.weight)
         self.to(self.device)
-    def forward(
-        self,
-        observations: torch.Tensor,
-        actions: torch.Tensor,
-        train: bool = False
-    ) -> torch.Tensor:
+    def forward(self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False) -> torch.Tensor:
         self.train(train)
         observations = observations.to(self.device)
         actions = actions.to(self.device)
         obs_enc = observations if self.encoder is None else self.encoder(observations)
         inputs = torch.cat([obs_enc, actions], dim=-1)
         x = self.network(inputs)
         value = self.output_layer(x)
         return value.squeeze(-1)
     def q_value_ensemble(
-        self,
-        observations: torch.Tensor,
-        actions: torch.Tensor,
-        train: bool = False
+        self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False
     ) -> torch.Tensor:
         observations = observations.to(self.device)
         actions = actions.to(self.device)
         if len(actions.shape) == 3:  # [batch_size, num_actions, action_dim]
             batch_size, num_actions = actions.shape[:2]
             obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
@@ -327,7 +315,7 @@ class Policy(nn.Module):
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
         activate_final: bool = False,
-        device: str = "cuda"
+        device: str = "cuda",
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -340,7 +328,7 @@ class Policy(nn.Module):
         self.tanh_squash_distribution = tanh_squash_distribution
         self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
         self.activate_final = activate_final
         # Mean layer
         if self.activate_final:
             self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
@@ -351,7 +339,7 @@ class Policy(nn.Module):
             nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
         else:
             orthogonal_init()(self.mean_layer.weight)
         # Standard deviation layer or parameter
         if fixed_std is None:
             if std_parameterization == "uniform":
@@ -366,18 +354,18 @@ class Policy(nn.Module):
                 nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
             else:
                 orthogonal_init()(self.std_layer.weight)
         self.to(self.device)
     def forward(
         self,
         observations: torch.Tensor,
         temperature: float = 1.0,
         train: bool = False,
-        non_squash_distribution: bool = False
+        non_squash_distribution: bool = False,
     ) -> torch.distributions.Distribution:
         self.train(train)
         # Encode observations if encoder exists
         if self.encoder is not None:
             with torch.set_grad_enabled(train):
@@ -387,7 +375,7 @@ class Policy(nn.Module):
         # Get network outputs
         outputs = self.network(obs_enc)
         means = self.mean_layer(outputs)
         # Compute standard deviations
         if self.fixed_std is None:
             if self.std_parameterization == "exp":
@@ -398,9 +386,7 @@ class Policy(nn.Module):
             elif self.std_parameterization == "uniform":
                 stds = torch.exp(self.log_stds).expand_as(means)
             else:
-                raise ValueError(
-                    f"Invalid std_parameterization: {self.std_parameterization}"
-                )
+                raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")
         else:
             assert self.std_parameterization == "fixed"
             stds = self.fixed_std.expand_as(means)
@@ -422,7 +408,7 @@ class Policy(nn.Module):
         )
         return distribution
     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
         """Get encoded features from observations"""
         observations = observations.to(self.device)
@@ -503,56 +489,47 @@ class SACObservationEncoder(nn.Module):
         if "observation.state" in self.config.input_shapes:
             feat.append(self.state_enc_layers(obs_dict["observation.state"]))
         return torch.stack(feat, dim=0).mean(0)
 class LagrangeMultiplier(nn.Module):
-    def __init__(
-        self,
-        init_value: float = 1.0,
-        constraint_shape: Sequence[int] = (),
-        device: str = "cuda"
-    ):
+    def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
         super().__init__()
         self.device = torch.device(device)
         init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
         # Initialize the Lagrange multiplier as a parameter
         self.lagrange = nn.Parameter(
             torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
         )
         self.to(self.device)
-    def forward(
-        self,
-        lhs: Optional[torch.Tensor] = None,
-        rhs: Optional[torch.Tensor] = None
-    ) -> torch.Tensor:
+    def forward(self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor] = None) -> torch.Tensor:
         # Get the multiplier value based on parameterization
         multiplier = torch.nn.functional.softplus(self.lagrange)
         # Return the raw multiplier if no constraint values provided
         if lhs is None:
             return multiplier
         # Move inputs to device
         lhs = lhs.to(self.device)
         if rhs is not None:
             rhs = rhs.to(self.device)
         # Use the multiplier to compute the Lagrange penalty
         if rhs is None:
             rhs = torch.zeros_like(lhs, device=self.device)
         diff = lhs - rhs
         assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
         return multiplier * diff
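The softplus parameterization above keeps the temperature (Lagrange multiplier) positive while the learnable parameter stays unconstrained, and the inverse-softplus applied to init_value makes the initial multiplier equal init_value exactly. A tiny self-contained check of that identity (the values are illustrative):

```python
# softplus(inverse_softplus(x)) == x, so the multiplier starts exactly at init_value.
import torch
import torch.nn.functional as F

init_value = 1.0
raw = torch.log(torch.exp(torch.tensor(init_value)) - 1)  # inverse softplus
print(F.softplus(raw))  # tensor(1.0000)
```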
 # The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where:
 # 1. The base distribution is a diagonal multivariate normal distribution
 # 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1
 # 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
 # This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
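The same construction can be reproduced with stock torch.distributions primitives, which is what this class builds on; a minimal sketch (shapes and values below are illustrative, not taken from the repository):

```python
# Tanh-squashed diagonal Gaussian built from stock torch.distributions components.
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

base = Normal(loc=torch.zeros(4), scale=torch.ones(4))
dist = TransformedDistribution(base, [TanhTransform()])

action = dist.rsample()                    # differentiable sample, each component in (-1, 1)
log_prob = dist.log_prob(action).sum(-1)   # sum per-dimension log-probs over the action dimension
```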
@@ -568,28 +545,22 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
     ):
         # Create base normal distribution
         base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)
         # Create list of transforms
         transforms = []
         # Add tanh transform
         transforms.append(torch.distributions.transforms.TanhTransform())
         # Add rescaling transform if bounds are provided
         if low is not None and high is not None:
             transforms.append(
-                torch.distributions.transforms.AffineTransform(
-                    loc=(high + low) / 2,
-                    scale=(high - low) / 2
-                )
+                torch.distributions.transforms.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2)
             )
         # Initialize parent class
-        super().__init__(
-            base_distribution=base_distribution,
-            transforms=transforms
-        )
+        super().__init__(base_distribution=base_distribution, transforms=transforms)
         # Store parameters
         self.loc = loc
         self.scale_diag = scale_diag
@@ -600,11 +571,11 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         """Get the mode of the transformed distribution"""
         # The mode of a normal distribution is its mean
         mode = self.loc
         # Apply transforms
         for transform in self.transforms:
             mode = transform(mode)
         return mode
     def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor:
@@ -613,11 +584,11 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         """
         # Sample from base distribution
         x = self.base_dist.rsample(sample_shape)
         # Apply transforms
         for transform in self.transforms:
             x = transform(x)
         return x
     def log_prob(self, value: torch.Tensor) -> torch.Tensor:
@@ -627,16 +598,16 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         """
         # Initialize log prob
         log_prob = torch.zeros_like(value[..., 0])
         # Inverse transforms to get back to normal distribution
         q = value
         for transform in reversed(self.transforms):
             q = transform.inv(q)
             log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q))
         # Add base distribution log prob
         log_prob = log_prob + self.base_dist.log_prob(q).sum(-1)
         return log_prob
     def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -653,13 +624,13 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         """
         # Start with base distribution entropy
         entropy = self.base_dist.entropy().sum(-1)
         # Add log det jacobian for each transform
         x = self.rsample()
         for transform in self.transforms:
             entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
             x = transform(x)
         return entropy
@@ -680,7 +651,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
     Args:
         fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
             (B, *), where * is any number of dimensions.
         image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
             can be more than one dimension, generally different from *.
     Returns:
         A return value from the callable reshaped to (**, *).
@@ -691,4 +662,3 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
     inp = torch.flatten(image_tensor, end_dim=-4)
     flat_out = fn(inp)
     return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
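A short usage sketch of flatten_forward_unflatten: it folds all leading dimensions into one batch dimension, applies fn, then restores them. The encoder below is a stand-in, not the repository's image encoder:

```python
# Usage sketch: leading dims (B, T) are flattened to B*T for fn, then restored.
import torch
import torch.nn as nn

fn = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16))  # stand-in for an image encoder
images = torch.randn(2, 5, 3, 8, 8)                         # (B, T, C, H, W)
out = flatten_forward_unflatten(fn, images)
print(out.shape)  # torch.Size([2, 5, 16])
```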

View File

@@ -24,7 +24,7 @@ python lerobot/scripts/eval_on_robot.py \
 ```
 **NOTE** (michel-aractingi): This script is incomplete and it is being prepared
 for running training on the real robot.
 """
 import argparse
@@ -47,7 +47,7 @@ from lerobot.common.utils.utils import (
 def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict:
     """Run a batched policy rollout on the real robot.
     The return dictionary contains:
         "robot": A dictionary of (batch, sequence + 1, *) tensors mapped to observation
@@ -64,7 +64,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
         extraneous elements from the sequences above.
     Args:
         robot: The robot class that defines the interface with the real robot.
         policy: The policy. Must be a PyTorch nn module.
     Returns:
@@ -77,7 +77,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
     listener, events = init_keyboard_listener()
     # Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
     # policy.reset()
     # Get observation from real robot
     observation = robot.capture_observation()

View File

@@ -95,12 +95,14 @@ def make_optimizer_and_scheduler(cfg, policy):
         lr_scheduler = None
     elif policy.name == "sac":
-        optimizer = torch.optim.Adam([
-            {'params': policy.actor.parameters(), 'lr': policy.config.actor_lr},
-            {'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr},
-            {'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr},
-        ])
-        lr_scheduler = None
+        optimizer = torch.optim.Adam(
+            [
+                {"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
+                {"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
+                {"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
+            ]
+        )
+        lr_scheduler = None
     elif cfg.policy.name == "vqbet":
         from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
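As context for the snippet above: torch.optim.Adam accepts a list of parameter groups, each with its own options, which is how the actor, critic ensemble, and temperature get separate learning rates. A tiny self-contained illustration (modules and rates here are made up):

```python
# Per-parameter-group learning rates with a single Adam optimizer.
import torch
import torch.nn as nn

actor, critic = nn.Linear(4, 2), nn.Linear(6, 1)
optimizer = torch.optim.Adam(
    [
        {"params": actor.parameters(), "lr": 3e-4},
        {"params": critic.parameters(), "lr": 1e-3},
    ]
)
print([group["lr"] for group in optimizer.param_groups])  # [0.0003, 0.001]
```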

View File

@@ -22,7 +22,6 @@ from pprint import pformat
 import hydra
 import torch
 import torch.nn as nn
-import wandb
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
@@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler
 from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
 from tqdm import tqdm
+import wandb
 from lerobot.common.datasets.factory import resolve_delta_timestamps
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.logger import Logger