Added normalization schemes and style checks
This commit is contained in:
parent
08ec971086
commit
dc54d357ca
|
@ -25,13 +25,13 @@ from glob import glob
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import wandb
|
|
||||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
|
||||||
|
import wandb
|
||||||
from lerobot.common.policies.policy_protocol import Policy
|
from lerobot.common.policies.policy_protocol import Policy
|
||||||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,6 @@ import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ClassifierConfig:
|
class ClassifierConfig:
|
||||||
|
|
|
@ -23,9 +23,11 @@ class ClassifierOutput:
|
||||||
self.hidden_states = hidden_states
|
self.hidden_states = hidden_states
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (f"ClassifierOutput(logits={self.logits}, "
|
return (
|
||||||
f"probabilities={self.probabilities}, "
|
f"ClassifierOutput(logits={self.logits}, "
|
||||||
f"hidden_states={self.hidden_states})")
|
f"probabilities={self.probabilities}, "
|
||||||
|
f"hidden_states={self.hidden_states})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Classifier(
|
class Classifier(
|
||||||
|
@ -74,7 +76,7 @@ class Classifier(
|
||||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported CNN architecture")
|
raise ValueError("Unsupported CNN architecture")
|
||||||
|
|
||||||
self.encoder = self.encoder.to(self.config.device)
|
self.encoder = self.encoder.to(self.config.device)
|
||||||
|
|
||||||
def _freeze_encoder(self) -> None:
|
def _freeze_encoder(self) -> None:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team.
|
# Copyright 2024 The HuggingFace Inc. team.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team.
|
# Copyright 2024 The HuggingFace Inc. team.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -26,4 +26,4 @@ class HILSerlPolicy(
|
||||||
repo_url="https://github.com/huggingface/lerobot",
|
repo_url="https://github.com/huggingface/lerobot",
|
||||||
tags=["robotics", "hilserl"],
|
tags=["robotics", "hilserl"],
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team.
|
# Copyright 2024 The HuggingFace Inc. team.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -15,7 +15,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -30,14 +30,36 @@ class SACConfig:
|
||||||
critic_target_update_weight = 0.005
|
critic_target_update_weight = 0.005
|
||||||
utd_ratio = 2
|
utd_ratio = 2
|
||||||
critic_network_kwargs = {
|
critic_network_kwargs = {
|
||||||
"hidden_dims": [256, 256],
|
"hidden_dims": [256, 256],
|
||||||
"activate_final": True,
|
"activate_final": True,
|
||||||
}
|
}
|
||||||
actor_network_kwargs = {
|
actor_network_kwargs = {
|
||||||
"hidden_dims": [256, 256],
|
"hidden_dims": [256, 256],
|
||||||
"activate_final": True,
|
"activate_final": True,
|
||||||
}
|
}
|
||||||
policy_kwargs = {
|
policy_kwargs = {
|
||||||
"tanh_squash_distribution": True,
|
"tanh_squash_distribution": True,
|
||||||
"std_parameterization": "uniform",
|
"std_parameterization": "uniform",
|
||||||
|
}
|
||||||
|
|
||||||
|
input_shapes: dict[str, list[int]] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"observation.image": [3, 84, 84],
|
||||||
|
"observation.state": [4],
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
output_shapes: dict[str, list[int]] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"action": [4],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
state_encoder_hidden_dim: int = 256
|
||||||
|
latent_dim: int = 256
|
||||||
|
network_hidden_dims: int = 256
|
||||||
|
|
||||||
|
# Normalization / Unnormalization
|
||||||
|
input_normalization_modes: dict[str, str] | None = None
|
||||||
|
output_normalization_modes: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {"action": "min_max"},
|
||||||
|
)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team.
|
# Copyright 2024 The HuggingFace Inc. team.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -40,11 +40,9 @@ class SACPolicy(
|
||||||
repo_url="https://github.com/huggingface/lerobot",
|
repo_url="https://github.com/huggingface/lerobot",
|
||||||
tags=["robotics", "RL", "SAC"],
|
tags=["robotics", "RL", "SAC"],
|
||||||
):
|
):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
|
@ -67,12 +65,9 @@ class SACPolicy(
|
||||||
# Define networks
|
# Define networks
|
||||||
critic_nets = []
|
critic_nets = []
|
||||||
for _ in range(config.num_critics):
|
for _ in range(config.num_critics):
|
||||||
critic_net = Critic(
|
critic_net = Critic(encoder=encoder, network=MLP(**config.critic_network_kwargs))
|
||||||
encoder=encoder,
|
|
||||||
network=MLP(**config.critic_network_kwargs)
|
|
||||||
)
|
|
||||||
critic_nets.append(critic_net)
|
critic_nets.append(critic_net)
|
||||||
|
|
||||||
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
|
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
|
||||||
self.critic_target = deepcopy(self.critic_ensemble)
|
self.critic_target = deepcopy(self.critic_ensemble)
|
||||||
|
|
||||||
|
@ -80,11 +75,11 @@ class SACPolicy(
|
||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
network=MLP(**config.actor_network_kwargs),
|
network=MLP(**config.actor_network_kwargs),
|
||||||
action_dim=config.output_shapes["action"][0],
|
action_dim=config.output_shapes["action"][0],
|
||||||
**config.policy_kwargs
|
**config.policy_kwargs,
|
||||||
)
|
)
|
||||||
if config.target_entropy is None:
|
if config.target_entropy is None:
|
||||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A))
|
config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A))
|
||||||
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
|
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
|
@ -100,10 +95,10 @@ class SACPolicy(
|
||||||
self._queues["observation.image"] = deque(maxlen=1)
|
self._queues["observation.image"] = deque(maxlen=1)
|
||||||
if self._use_env_state:
|
if self._use_env_state:
|
||||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
actions, _ = self.actor_network(batch['observations'])###
|
actions, _ = self.actor_network(batch["observations"]) ###
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||||
"""Run the batch through the model and compute the loss.
|
"""Run the batch through the model and compute the loss.
|
||||||
|
@ -111,8 +106,8 @@ class SACPolicy(
|
||||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||||
"""
|
"""
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
# batch shape is (b, 2, ...) where index 1 returns the current observation and
|
# batch shape is (b, 2, ...) where index 1 returns the current observation and
|
||||||
# the next observation for caluculating the right td index.
|
# the next observation for caluculating the right td index.
|
||||||
actions = batch["action"][:, 0]
|
actions = batch["action"][:, 0]
|
||||||
rewards = batch["next.reward"][:, 0]
|
rewards = batch["next.reward"][:, 0]
|
||||||
observations = {}
|
observations = {}
|
||||||
|
@ -121,13 +116,12 @@ class SACPolicy(
|
||||||
if k.startswith("observation."):
|
if k.startswith("observation."):
|
||||||
observations[k] = batch[k][:, 0]
|
observations[k] = batch[k][:, 0]
|
||||||
next_observations[k] = batch[k][:, 1]
|
next_observations[k] = batch[k][:, 1]
|
||||||
|
|
||||||
# perform image augmentation
|
# perform image augmentation
|
||||||
|
|
||||||
# reward bias
|
# reward bias
|
||||||
# from HIL-SERL code base
|
# from HIL-SERL code base
|
||||||
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
|
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
|
||||||
|
|
||||||
|
|
||||||
# calculate critics loss
|
# calculate critics loss
|
||||||
# 1- compute actions from policy
|
# 1- compute actions from policy
|
||||||
|
@ -137,7 +131,7 @@ class SACPolicy(
|
||||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||||
if self.config.num_subsample_critics is not None:
|
if self.config.num_subsample_critics is not None:
|
||||||
indices = torch.randperm(self.config.num_critics)
|
indices = torch.randperm(self.config.num_critics)
|
||||||
indices = indices[:self.config.num_subsample_critics]
|
indices = indices[: self.config.num_subsample_critics]
|
||||||
q_targets = q_targets[indices]
|
q_targets = q_targets[indices]
|
||||||
|
|
||||||
# critics subsample size
|
# critics subsample size
|
||||||
|
@ -151,8 +145,9 @@ class SACPolicy(
|
||||||
|
|
||||||
# 4- Calculate loss
|
# 4- Calculate loss
|
||||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||||
critics_loss = (
|
critics_loss = (
|
||||||
F.mse_loss(
|
(
|
||||||
|
F.mse_loss(
|
||||||
q_preds,
|
q_preds,
|
||||||
einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
|
einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
|
||||||
reduction="none",
|
reduction="none",
|
||||||
|
@ -163,15 +158,17 @@ class SACPolicy(
|
||||||
# q_targets depends on the reward and the next observations.
|
# q_targets depends on the reward and the next observations.
|
||||||
* ~batch["next.reward_is_pad"]
|
* ~batch["next.reward_is_pad"]
|
||||||
* ~batch["observation.state_is_pad"][1:]
|
* ~batch["observation.state_is_pad"][1:]
|
||||||
).sum(0).mean()
|
)
|
||||||
|
.sum(0)
|
||||||
|
.mean()
|
||||||
|
)
|
||||||
|
|
||||||
# calculate actors loss
|
# calculate actors loss
|
||||||
# 1- temperature
|
# 1- temperature
|
||||||
temperature = self.temperature()
|
temperature = self.temperature()
|
||||||
|
|
||||||
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
||||||
actions, log_probs = self.actor_network(observations) \
|
actions, log_probs = self.actor_network(observations)
|
||||||
|
|
||||||
# 3- get q-value predictions
|
# 3- get q-value predictions
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
|
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
|
||||||
|
@ -181,36 +178,31 @@ class SACPolicy(
|
||||||
* ~batch["action_is_pad"]
|
* ~batch["action_is_pad"]
|
||||||
).mean()
|
).mean()
|
||||||
|
|
||||||
|
|
||||||
# calculate temperature loss
|
# calculate temperature loss
|
||||||
# 1- calculate entropy
|
# 1- calculate entropy
|
||||||
entropy = -log_probs.mean()
|
entropy = -log_probs.mean()
|
||||||
temperature_loss = self.temp(
|
temperature_loss = self.temp(lhs=entropy, rhs=self.config.target_entropy)
|
||||||
lhs=entropy,
|
|
||||||
rhs=self.config.target_entropy
|
|
||||||
)
|
|
||||||
|
|
||||||
loss = critics_loss + actor_loss + temperature_loss
|
loss = critics_loss + actor_loss + temperature_loss
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"critics_loss": critics_loss.item(),
|
"critics_loss": critics_loss.item(),
|
||||||
"actor_loss": actor_loss.item(),
|
"actor_loss": actor_loss.item(),
|
||||||
"temperature_loss": temperature_loss.item(),
|
"temperature_loss": temperature_loss.item(),
|
||||||
"temperature": temperature.item(),
|
"temperature": temperature.item(),
|
||||||
"entropy": entropy.item(),
|
"entropy": entropy.item(),
|
||||||
"loss": loss,
|
"loss": loss,
|
||||||
|
}
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
|
self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
|
||||||
# TODO: implement UTD update
|
# TODO: implement UTD update
|
||||||
# First update only critics for utd_ratio-1 times
|
# First update only critics for utd_ratio-1 times
|
||||||
#for critic_step in range(self.config.utd_ratio - 1):
|
# for critic_step in range(self.config.utd_ratio - 1):
|
||||||
# only update critic and critic target
|
# only update critic and critic target
|
||||||
# Then update critic, critic target, actor and temperature
|
# Then update critic, critic target, actor and temperature
|
||||||
|
|
||||||
#for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
|
# for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
|
||||||
# target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
|
# target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
|
||||||
|
|
||||||
|
|
||||||
|
@ -225,24 +217,28 @@ class MLP(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.activate_final = config.activate_final
|
self.activate_final = config.activate_final
|
||||||
layers = []
|
layers = []
|
||||||
|
|
||||||
for i, size in enumerate(config.network_hidden_dims):
|
for i, size in enumerate(config.network_hidden_dims):
|
||||||
layers.append(nn.Linear(config.network_hidden_dims[i-1] if i > 0 else config.network_hidden_dims[0], size))
|
layers.append(
|
||||||
|
nn.Linear(config.network_hidden_dims[i - 1] if i > 0 else config.network_hidden_dims[0], size)
|
||||||
|
)
|
||||||
|
|
||||||
if i + 1 < len(config.network_hidden_dims) or activate_final:
|
if i + 1 < len(config.network_hidden_dims) or activate_final:
|
||||||
if dropout_rate is not None and dropout_rate > 0:
|
if dropout_rate is not None and dropout_rate > 0:
|
||||||
layers.append(nn.Dropout(p=dropout_rate))
|
layers.append(nn.Dropout(p=dropout_rate))
|
||||||
layers.append(nn.LayerNorm(size))
|
layers.append(nn.LayerNorm(size))
|
||||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
layers.append(
|
||||||
|
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
|
||||||
|
)
|
||||||
|
|
||||||
self.net = nn.Sequential(*layers)
|
self.net = nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
|
||||||
# in training mode or not. TODO: find better way to do this
|
# in training mode or not. TODO: find better way to do this
|
||||||
self.train(train)
|
self.train(train)
|
||||||
return self.net(x)
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
class Critic(nn.Module):
|
class Critic(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -250,7 +246,7 @@ class Critic(nn.Module):
|
||||||
network: nn.Module,
|
network: nn.Module,
|
||||||
init_final: Optional[float] = None,
|
init_final: Optional[float] = None,
|
||||||
activate_final: bool = False,
|
activate_final: bool = False,
|
||||||
device: str = "cuda"
|
device: str = "cuda",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = torch.device(device)
|
self.device = torch.device(device)
|
||||||
|
@ -258,7 +254,7 @@ class Critic(nn.Module):
|
||||||
self.network = network
|
self.network = network
|
||||||
self.init_final = init_final
|
self.init_final = init_final
|
||||||
self.activate_final = activate_final
|
self.activate_final = activate_final
|
||||||
|
|
||||||
# Output layer
|
# Output layer
|
||||||
if init_final is not None:
|
if init_final is not None:
|
||||||
if self.activate_final:
|
if self.activate_final:
|
||||||
|
@ -273,36 +269,28 @@ class Critic(nn.Module):
|
||||||
else:
|
else:
|
||||||
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
|
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
|
||||||
orthogonal_init()(self.output_layer.weight)
|
orthogonal_init()(self.output_layer.weight)
|
||||||
|
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
|
||||||
def forward(
|
def forward(self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False) -> torch.Tensor:
|
||||||
self,
|
|
||||||
observations: torch.Tensor,
|
|
||||||
actions: torch.Tensor,
|
|
||||||
train: bool = False
|
|
||||||
) -> torch.Tensor:
|
|
||||||
self.train(train)
|
self.train(train)
|
||||||
|
|
||||||
observations = observations.to(self.device)
|
observations = observations.to(self.device)
|
||||||
actions = actions.to(self.device)
|
actions = actions.to(self.device)
|
||||||
|
|
||||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||||
|
|
||||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||||
x = self.network(inputs)
|
x = self.network(inputs)
|
||||||
value = self.output_layer(x)
|
value = self.output_layer(x)
|
||||||
return value.squeeze(-1)
|
return value.squeeze(-1)
|
||||||
|
|
||||||
def q_value_ensemble(
|
def q_value_ensemble(
|
||||||
self,
|
self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False
|
||||||
observations: torch.Tensor,
|
|
||||||
actions: torch.Tensor,
|
|
||||||
train: bool = False
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
observations = observations.to(self.device)
|
observations = observations.to(self.device)
|
||||||
actions = actions.to(self.device)
|
actions = actions.to(self.device)
|
||||||
|
|
||||||
if len(actions.shape) == 3: # [batch_size, num_actions, action_dim]
|
if len(actions.shape) == 3: # [batch_size, num_actions, action_dim]
|
||||||
batch_size, num_actions = actions.shape[:2]
|
batch_size, num_actions = actions.shape[:2]
|
||||||
obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
|
obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
|
||||||
|
@ -327,7 +315,7 @@ class Policy(nn.Module):
|
||||||
fixed_std: Optional[torch.Tensor] = None,
|
fixed_std: Optional[torch.Tensor] = None,
|
||||||
init_final: Optional[float] = None,
|
init_final: Optional[float] = None,
|
||||||
activate_final: bool = False,
|
activate_final: bool = False,
|
||||||
device: str = "cuda"
|
device: str = "cuda",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = torch.device(device)
|
self.device = torch.device(device)
|
||||||
|
@ -340,7 +328,7 @@ class Policy(nn.Module):
|
||||||
self.tanh_squash_distribution = tanh_squash_distribution
|
self.tanh_squash_distribution = tanh_squash_distribution
|
||||||
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
||||||
self.activate_final = activate_final
|
self.activate_final = activate_final
|
||||||
|
|
||||||
# Mean layer
|
# Mean layer
|
||||||
if self.activate_final:
|
if self.activate_final:
|
||||||
self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
|
self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
|
||||||
|
@ -351,7 +339,7 @@ class Policy(nn.Module):
|
||||||
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
|
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
|
||||||
else:
|
else:
|
||||||
orthogonal_init()(self.mean_layer.weight)
|
orthogonal_init()(self.mean_layer.weight)
|
||||||
|
|
||||||
# Standard deviation layer or parameter
|
# Standard deviation layer or parameter
|
||||||
if fixed_std is None:
|
if fixed_std is None:
|
||||||
if std_parameterization == "uniform":
|
if std_parameterization == "uniform":
|
||||||
|
@ -366,18 +354,18 @@ class Policy(nn.Module):
|
||||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||||
else:
|
else:
|
||||||
orthogonal_init()(self.std_layer.weight)
|
orthogonal_init()(self.std_layer.weight)
|
||||||
|
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
observations: torch.Tensor,
|
observations: torch.Tensor,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
train: bool = False,
|
train: bool = False,
|
||||||
non_squash_distribution: bool = False
|
non_squash_distribution: bool = False,
|
||||||
) -> torch.distributions.Distribution:
|
) -> torch.distributions.Distribution:
|
||||||
self.train(train)
|
self.train(train)
|
||||||
|
|
||||||
# Encode observations if encoder exists
|
# Encode observations if encoder exists
|
||||||
if self.encoder is not None:
|
if self.encoder is not None:
|
||||||
with torch.set_grad_enabled(train):
|
with torch.set_grad_enabled(train):
|
||||||
|
@ -387,7 +375,7 @@ class Policy(nn.Module):
|
||||||
# Get network outputs
|
# Get network outputs
|
||||||
outputs = self.network(obs_enc)
|
outputs = self.network(obs_enc)
|
||||||
means = self.mean_layer(outputs)
|
means = self.mean_layer(outputs)
|
||||||
|
|
||||||
# Compute standard deviations
|
# Compute standard deviations
|
||||||
if self.fixed_std is None:
|
if self.fixed_std is None:
|
||||||
if self.std_parameterization == "exp":
|
if self.std_parameterization == "exp":
|
||||||
|
@ -398,9 +386,7 @@ class Policy(nn.Module):
|
||||||
elif self.std_parameterization == "uniform":
|
elif self.std_parameterization == "uniform":
|
||||||
stds = torch.exp(self.log_stds).expand_as(means)
|
stds = torch.exp(self.log_stds).expand_as(means)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")
|
||||||
f"Invalid std_parameterization: {self.std_parameterization}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
assert self.std_parameterization == "fixed"
|
assert self.std_parameterization == "fixed"
|
||||||
stds = self.fixed_std.expand_as(means)
|
stds = self.fixed_std.expand_as(means)
|
||||||
|
@ -422,7 +408,7 @@ class Policy(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
return distribution
|
return distribution
|
||||||
|
|
||||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||||
"""Get encoded features from observations"""
|
"""Get encoded features from observations"""
|
||||||
observations = observations.to(self.device)
|
observations = observations.to(self.device)
|
||||||
|
@ -503,56 +489,47 @@ class SACObservationEncoder(nn.Module):
|
||||||
if "observation.state" in self.config.input_shapes:
|
if "observation.state" in self.config.input_shapes:
|
||||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||||
return torch.stack(feat, dim=0).mean(0)
|
return torch.stack(feat, dim=0).mean(0)
|
||||||
|
|
||||||
|
|
||||||
class LagrangeMultiplier(nn.Module):
|
class LagrangeMultiplier(nn.Module):
|
||||||
def __init__(
|
def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
|
||||||
self,
|
|
||||||
init_value: float = 1.0,
|
|
||||||
constraint_shape: Sequence[int] = (),
|
|
||||||
device: str = "cuda"
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = torch.device(device)
|
self.device = torch.device(device)
|
||||||
init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
|
init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
|
||||||
|
|
||||||
# Initialize the Lagrange multiplier as a parameter
|
# Initialize the Lagrange multiplier as a parameter
|
||||||
self.lagrange = nn.Parameter(
|
self.lagrange = nn.Parameter(
|
||||||
torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
|
torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
|
||||||
def forward(
|
def forward(self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
self,
|
# Get the multiplier value based on parameterization
|
||||||
lhs: Optional[torch.Tensor] = None,
|
|
||||||
rhs: Optional[torch.Tensor] = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# Get the multiplier value based on parameterization
|
|
||||||
multiplier = torch.nn.functional.softplus(self.lagrange)
|
multiplier = torch.nn.functional.softplus(self.lagrange)
|
||||||
|
|
||||||
# Return the raw multiplier if no constraint values provided
|
# Return the raw multiplier if no constraint values provided
|
||||||
if lhs is None:
|
if lhs is None:
|
||||||
return multiplier
|
return multiplier
|
||||||
|
|
||||||
# Move inputs to device
|
# Move inputs to device
|
||||||
lhs = lhs.to(self.device)
|
lhs = lhs.to(self.device)
|
||||||
if rhs is not None:
|
if rhs is not None:
|
||||||
rhs = rhs.to(self.device)
|
rhs = rhs.to(self.device)
|
||||||
|
|
||||||
# Use the multiplier to compute the Lagrange penalty
|
# Use the multiplier to compute the Lagrange penalty
|
||||||
if rhs is None:
|
if rhs is None:
|
||||||
rhs = torch.zeros_like(lhs, device=self.device)
|
rhs = torch.zeros_like(lhs, device=self.device)
|
||||||
|
|
||||||
diff = lhs - rhs
|
diff = lhs - rhs
|
||||||
|
|
||||||
assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
|
assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
|
||||||
|
|
||||||
return multiplier * diff
|
return multiplier * diff
|
||||||
|
|
||||||
|
|
||||||
# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where:
|
# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where:
|
||||||
# 1. The base distribution is a diagonal multivariate normal distribution
|
# 1. The base distribution is a diagonal multivariate normal distribution
|
||||||
# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1
|
# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1
|
||||||
# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
|
# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
|
||||||
# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
|
# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
|
||||||
|
@ -568,28 +545,22 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
||||||
):
|
):
|
||||||
# Create base normal distribution
|
# Create base normal distribution
|
||||||
base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)
|
base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)
|
||||||
|
|
||||||
# Create list of transforms
|
# Create list of transforms
|
||||||
transforms = []
|
transforms = []
|
||||||
|
|
||||||
# Add tanh transform
|
# Add tanh transform
|
||||||
transforms.append(torch.distributions.transforms.TanhTransform())
|
transforms.append(torch.distributions.transforms.TanhTransform())
|
||||||
|
|
||||||
# Add rescaling transform if bounds are provided
|
# Add rescaling transform if bounds are provided
|
||||||
if low is not None and high is not None:
|
if low is not None and high is not None:
|
||||||
transforms.append(
|
transforms.append(
|
||||||
torch.distributions.transforms.AffineTransform(
|
torch.distributions.transforms.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2)
|
||||||
loc=(high + low) / 2,
|
|
||||||
scale=(high - low) / 2
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize parent class
|
# Initialize parent class
|
||||||
super().__init__(
|
super().__init__(base_distribution=base_distribution, transforms=transforms)
|
||||||
base_distribution=base_distribution,
|
|
||||||
transforms=transforms
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store parameters
|
# Store parameters
|
||||||
self.loc = loc
|
self.loc = loc
|
||||||
self.scale_diag = scale_diag
|
self.scale_diag = scale_diag
|
||||||
|
@ -600,11 +571,11 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
||||||
"""Get the mode of the transformed distribution"""
|
"""Get the mode of the transformed distribution"""
|
||||||
# The mode of a normal distribution is its mean
|
# The mode of a normal distribution is its mean
|
||||||
mode = self.loc
|
mode = self.loc
|
||||||
|
|
||||||
# Apply transforms
|
# Apply transforms
|
||||||
for transform in self.transforms:
|
for transform in self.transforms:
|
||||||
mode = transform(mode)
|
mode = transform(mode)
|
||||||
|
|
||||||
return mode
|
return mode
|
||||||
|
|
||||||
def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor:
|
def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor:
|
||||||
|
@ -613,11 +584,11 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
||||||
"""
|
"""
|
||||||
# Sample from base distribution
|
# Sample from base distribution
|
||||||
x = self.base_dist.rsample(sample_shape)
|
x = self.base_dist.rsample(sample_shape)
|
||||||
|
|
||||||
# Apply transforms
|
# Apply transforms
|
||||||
for transform in self.transforms:
|
for transform in self.transforms:
|
||||||
x = transform(x)
|
x = transform(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
|
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
|
||||||
|
@ -627,16 +598,16 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
||||||
"""
|
"""
|
||||||
# Initialize log prob
|
# Initialize log prob
|
||||||
log_prob = torch.zeros_like(value[..., 0])
|
log_prob = torch.zeros_like(value[..., 0])
|
||||||
|
|
||||||
# Inverse transforms to get back to normal distribution
|
# Inverse transforms to get back to normal distribution
|
||||||
q = value
|
q = value
|
||||||
for transform in reversed(self.transforms):
|
for transform in reversed(self.transforms):
|
||||||
q = transform.inv(q)
|
q = transform.inv(q)
|
||||||
log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q))
|
log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q))
|
||||||
|
|
||||||
# Add base distribution log prob
|
# Add base distribution log prob
|
||||||
log_prob = log_prob + self.base_dist.log_prob(q).sum(-1)
|
log_prob = log_prob + self.base_dist.log_prob(q).sum(-1)
|
||||||
|
|
||||||
return log_prob
|
return log_prob
|
||||||
|
|
||||||
def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
|
def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
@ -653,13 +624,13 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
||||||
"""
|
"""
|
||||||
# Start with base distribution entropy
|
# Start with base distribution entropy
|
||||||
entropy = self.base_dist.entropy().sum(-1)
|
entropy = self.base_dist.entropy().sum(-1)
|
||||||
|
|
||||||
# Add log det jacobian for each transform
|
# Add log det jacobian for each transform
|
||||||
x = self.rsample()
|
x = self.rsample()
|
||||||
for transform in self.transforms:
|
for transform in self.transforms:
|
||||||
entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
|
entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
|
||||||
x = transform(x)
|
x = transform(x)
|
||||||
|
|
||||||
return entropy
|
return entropy
|
||||||
|
|
||||||
|
|
||||||
|
@ -680,7 +651,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
|
||||||
Args:
|
Args:
|
||||||
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
||||||
(B, *), where * is any number of dimensions.
|
(B, *), where * is any number of dimensions.
|
||||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
||||||
can be more than 1 dimensions, generally different from *.
|
can be more than 1 dimensions, generally different from *.
|
||||||
Returns:
|
Returns:
|
||||||
A return value from the callable reshaped to (**, *).
|
A return value from the callable reshaped to (**, *).
|
||||||
|
@ -691,4 +662,3 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
|
||||||
inp = torch.flatten(image_tensor, end_dim=-4)
|
inp = torch.flatten(image_tensor, end_dim=-4)
|
||||||
flat_out = fn(inp)
|
flat_out = fn(inp)
|
||||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ python lerobot/scripts/eval_on_robot.py \
|
||||||
```
|
```
|
||||||
|
|
||||||
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
|
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
|
||||||
for running training on the real robot.
|
for running training on the real robot.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
@ -47,7 +47,7 @@ from lerobot.common.utils.utils import (
|
||||||
|
|
||||||
|
|
||||||
def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict:
|
def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict:
|
||||||
"""Run a batched policy rollout on the real robot.
|
"""Run a batched policy rollout on the real robot.
|
||||||
|
|
||||||
The return dictionary contains:
|
The return dictionary contains:
|
||||||
"robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
|
"robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
|
||||||
|
@ -64,7 +64,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
|
||||||
extraneous elements from the sequences above.
|
extraneous elements from the sequences above.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
robot: The robot class that defines the interface with the real robot.
|
robot: The robot class that defines the interface with the real robot.
|
||||||
policy: The policy. Must be a PyTorch nn module.
|
policy: The policy. Must be a PyTorch nn module.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -77,7 +77,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
|
||||||
listener, events = init_keyboard_listener()
|
listener, events = init_keyboard_listener()
|
||||||
|
|
||||||
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
|
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
|
||||||
# policy.reset()
|
# policy.reset()
|
||||||
|
|
||||||
# Get observation from real robot
|
# Get observation from real robot
|
||||||
observation = robot.capture_observation()
|
observation = robot.capture_observation()
|
||||||
|
|
|
@ -95,12 +95,14 @@ def make_optimizer_and_scheduler(cfg, policy):
|
||||||
lr_scheduler = None
|
lr_scheduler = None
|
||||||
|
|
||||||
elif policy.name == "sac":
|
elif policy.name == "sac":
|
||||||
optimizer = torch.optim.Adam([
|
optimizer = torch.optim.Adam(
|
||||||
{'params': policy.actor.parameters(), 'lr': policy.config.actor_lr},
|
[
|
||||||
{'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr},
|
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
|
||||||
{'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr},
|
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
|
||||||
])
|
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
|
||||||
lr_scheduler = None
|
]
|
||||||
|
)
|
||||||
|
lr_scheduler = None
|
||||||
|
|
||||||
elif cfg.policy.name == "vqbet":
|
elif cfg.policy.name == "vqbet":
|
||||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
||||||
|
|
|
@ -22,7 +22,6 @@ from pprint import pformat
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import wandb
|
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler
|
||||||
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
|
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import wandb
|
||||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.logger import Logger
|
from lerobot.common.logger import Logger
|
||||||
|
|
Loading…
Reference in New Issue