diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 4015492d..dec8b465 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -25,13 +25,13 @@
 from glob import glob
 from pathlib import Path
 
 import torch
-import wandb
 from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
+import wandb
 
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
diff --git a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py
index 553e4262..f0b9352f 100644
--- a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py
+++ b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py
@@ -2,8 +2,6 @@ import json
 import os
 from dataclasses import asdict, dataclass
 
-import torch
-
 
 @dataclass
 class ClassifierConfig:
diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py
index 0b8d66ac..28b05744 100644
--- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py
+++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py
@@ -23,9 +23,11 @@ class ClassifierOutput:
         self.hidden_states = hidden_states
 
     def __repr__(self):
-        return (f"ClassifierOutput(logits={self.logits}, "
-                f"probabilities={self.probabilities}, "
-                f"hidden_states={self.hidden_states})")
+        return (
+            f"ClassifierOutput(logits={self.logits}, "
+            f"probabilities={self.probabilities}, "
+            f"hidden_states={self.hidden_states})"
+        )
 
 
 class Classifier(
@@ -74,7 +76,7 @@ class Classifier(
             self.feature_dim = self.encoder.config.hidden_sizes[-1]  # Last channel dimension
         else:
             raise ValueError("Unsupported CNN architecture")
-        
+
         self.encoder = self.encoder.to(self.config.device)
 
     def _freeze_encoder(self) -> None:
diff --git a/lerobot/common/policies/hilserl/configuration_hilserl.py b/lerobot/common/policies/hilserl/configuration_hilserl.py
index f1bc850f..80d2f578 100644
--- a/lerobot/common/policies/hilserl/configuration_hilserl.py
+++ b/lerobot/common/policies/hilserl/configuration_hilserl.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-# Copyright 2024 The HuggingFace Inc. team. 
+# Copyright 2024 The HuggingFace Inc. team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/lerobot/common/policies/hilserl/modeling_hilserl.py b/lerobot/common/policies/hilserl/modeling_hilserl.py
index 236ed433..679eb010 100644
--- a/lerobot/common/policies/hilserl/modeling_hilserl.py
+++ b/lerobot/common/policies/hilserl/modeling_hilserl.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-# Copyright 2024 The HuggingFace Inc. team. 
+# Copyright 2024 The HuggingFace Inc. team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,4 +26,4 @@ class HILSerlPolicy(
     repo_url="https://github.com/huggingface/lerobot",
     tags=["robotics", "hilserl"],
 ):
-    pass
\ No newline at end of file
+    pass
diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 6db198e8..f4a2bc4c 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-# Copyright 2024 The HuggingFace Inc. team. 
+# Copyright 2024 The HuggingFace Inc. team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 
 @dataclass
@@ -30,14 +30,36 @@ class SACConfig:
     critic_target_update_weight = 0.005
     utd_ratio = 2
     critic_network_kwargs = {
-        "hidden_dims": [256, 256],
-        "activate_final": True,
-    }
+        "hidden_dims": [256, 256],
+        "activate_final": True,
+    }
     actor_network_kwargs = {
-        "hidden_dims": [256, 256],
-        "activate_final": True,
-    }
+        "hidden_dims": [256, 256],
+        "activate_final": True,
+    }
     policy_kwargs = {
-        "tanh_squash_distribution": True,
-        "std_parameterization": "uniform",
+        "tanh_squash_distribution": True,
+        "std_parameterization": "uniform",
+    }
+
+    input_shapes: dict[str, list[int]] = field(
+        default_factory=lambda: {
+            "observation.image": [3, 84, 84],
+            "observation.state": [4],
         }
+    )
+    output_shapes: dict[str, list[int]] = field(
+        default_factory=lambda: {
+            "action": [4],
+        }
+    )
+
+    state_encoder_hidden_dim: int = 256
+    latent_dim: int = 256
+    network_hidden_dims: int = 256
+
+    # Normalization / Unnormalization
+    input_normalization_modes: dict[str, str] | None = None
+    output_normalization_modes: dict[str, str] = field(
+        default_factory=lambda: {"action": "min_max"},
+    )
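Note on the new `SACConfig` fields: a `dict` default has to go through `field(default_factory=...)` because dataclasses reject mutable class-level defaults. A minimal sketch of the pattern (the `Example` class here is hypothetical, not part of the patch):

    from dataclasses import dataclass, field

    @dataclass
    class Example:
        # writing `shapes: dict = {"action": [4]}` would raise
        # "ValueError: mutable default ... use default_factory" at class-definition
        # time; default_factory builds a fresh dict for every instance instead.
        shapes: dict[str, list[int]] = field(default_factory=lambda: {"action": [4]})
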
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index c5e3f209..51258fac 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-# Copyright 2024 The HuggingFace Inc. team. 
+# Copyright 2024 The HuggingFace Inc. team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -40,11 +40,9 @@ class SACPolicy(
     repo_url="https://github.com/huggingface/lerobot",
     tags=["robotics", "RL", "SAC"],
 ):
-
     def __init__(
         self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
     ):
-
         super().__init__()
 
         if config is None:
@@ -67,12 +65,9 @@ class SACPolicy(
         # Define networks
         critic_nets = []
         for _ in range(config.num_critics):
-            critic_net = Critic(
-                encoder=encoder,
-                network=MLP(**config.critic_network_kwargs)
-            )
+            critic_net = Critic(encoder=encoder, network=MLP(**config.critic_network_kwargs))
             critic_nets.append(critic_net)
-        
+
         self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
         self.critic_target = deepcopy(self.critic_ensemble)
 
@@ -80,11 +75,11 @@ class SACPolicy(
             encoder=encoder,
             network=MLP(**config.actor_network_kwargs),
             action_dim=config.output_shapes["action"][0],
-            **config.policy_kwargs
+            **config.policy_kwargs,
         )
 
         if config.target_entropy is None:
-            config.target_entropy = -np.prod(config.output_shapes["action"][0])  # (-dim(A)) 
-        self.temperature = LagrangeMultiplier(init_value=config.temperature_init) 
+            config.target_entropy = -np.prod(config.output_shapes["action"][0])  # (-dim(A))
+        self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
 
     def reset(self):
         """
@@ -100,10 +95,10 @@ class SACPolicy(
             self._queues["observation.image"] = deque(maxlen=1)
         if self._use_env_state:
             self._queues["observation.environment_state"] = deque(maxlen=1)
-        
+
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
-        actions, _ = self.actor_network(batch['observations'])###
+        actions, _ = self.actor_network(batch["observations"])
+        return actions
 
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
         """Run the batch through the model and compute the loss.
@@ -111,8 +106,8 @@ class SACPolicy(
 
         Returns a dictionary with loss as a tensor, and other information as native floats.
         """
         batch = self.normalize_inputs(batch)
-        # batch shape is (b, 2, ...) where index 1 returns the current observation and 
-        # the next observation for caluculating the right td index. 
+        # batch shape is (b, 2, ...) where index 0 gives the current observation and
+        # index 1 the next observation, for calculating the right TD target.
         actions = batch["action"][:, 0]
         rewards = batch["next.reward"][:, 0]
         observations = {}
@@ -121,13 +116,12 @@ class SACPolicy(
         next_observations = {}
         for k in batch:
             if k.startswith("observation."):
                 observations[k] = batch[k][:, 0]
                 next_observations[k] = batch[k][:, 1]
-        
+
         # perform image augmentation
 
         # reward bias
-        # from HIL-SERL code base 
+        # from HIL-SERL code base
         # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
-
 
         # calculate critics loss
         # 1- compute actions from policy
@@ -137,7 +131,7 @@ class SACPolicy(
         # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
         if self.config.num_subsample_critics is not None:
             indices = torch.randperm(self.config.num_critics)
-            indices = indices[:self.config.num_subsample_critics]
+            indices = indices[: self.config.num_subsample_critics]
             q_targets = q_targets[indices]
 
         # critics subsample size
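For context on the hunk above: `q_targets` holds one target Q-value per critic, and the random subsampling keeps only `num_subsample_critics` of them before the minimum is taken. The `td_target` consumed by the loss in the next hunk follows the standard SAC backup; a sketch of the computation (it sits outside the visible hunks, so the variable names are assumptions inferred from this file):

    # next_actions, next_log_probs sampled from the current policy at next_observations
    # q_targets = critic_target(next_observations, next_actions)   # one value per critic
    # min_q = q_targets.min(dim=0).values                          # pessimistic estimate
    # td_target = rewards + discount * (min_q - temperature * next_log_probs)
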
@@ -151,8 +145,9 @@ class SACPolicy(
         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        critics_loss = (
-            F.mse_loss(
+        critics_loss = (
+            (
+                F.mse_loss(
                 q_preds,
                 einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
                 reduction="none",
@@ -163,15 +158,17 @@ class SACPolicy(
             # q_targets depends on the reward and the next observations.
             * ~batch["next.reward_is_pad"]
             * ~batch["observation.state_is_pad"][1:]
-        ).sum(0).mean()
-        
+            )
+            .sum(0)
+            .mean()
+        )
+
         # calculate actors loss
         # 1- temperature
         temperature = self.temperature()
         # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        actions, log_probs = self.actor_network(observations) \
-
+        actions, log_probs = self.actor_network(observations)
         # 3- get q-value predictions
         with torch.no_grad():
             q_preds = self.critic_ensemble(observations, actions, return_type="mean")
@@ -181,36 +178,31 @@ class SACPolicy(
             * ~batch["action_is_pad"]
         ).mean()
 
-
         # calculate temperature loss
         # 1- calculate entropy
         entropy = -log_probs.mean()
-        temperature_loss = self.temp(
-            lhs=entropy,
-            rhs=self.config.target_entropy
-        )
+        temperature_loss = self.temp(lhs=entropy, rhs=self.config.target_entropy)
 
         loss = critics_loss + actor_loss + temperature_loss
 
         return {
-                "critics_loss": critics_loss.item(),
-                "actor_loss": actor_loss.item(),
-                "temperature_loss": temperature_loss.item(),
-                "temperature": temperature.item(),
-                "entropy": entropy.item(),
-                "loss": loss,
+            "critics_loss": critics_loss.item(),
+            "actor_loss": actor_loss.item(),
+            "temperature_loss": temperature_loss.item(),
+            "temperature": temperature.item(),
+            "entropy": entropy.item(),
+            "loss": loss,
+        }
 
-        }
-
     def update(self):
         self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
         # TODO: implement UTD update
         # First update only critics for utd_ratio-1 times
-        #for critic_step in range(self.config.utd_ratio - 1):
-        # only update critic and critic target
+        # for critic_step in range(self.config.utd_ratio - 1):
+        #     only update critic and critic target
         # Then update critic, critic target, actor and temperature
-        #for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
+        # for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
         #     target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
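The `lerp_` call in `update()` is the usual Polyak (soft target) update written as one in-place linear interpolation. Torch's `lerp_` semantics make the equivalence explicit:

    # target.lerp_(source, w)  computes  target <- target + w * (source - target)
    # i.e.                     target <- (1 - w) * target + w * source
    # With w = critic_target_update_weight (0.005 here), the target network trails
    # the online critic slowly, which keeps the TD targets stable.

The commented-out `copy_` loop below it is the same update spelled out per parameter.
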
@@ -225,24 +217,28 @@ class MLP(nn.Module):
         super().__init__()
         self.activate_final = config.activate_final
         layers = []
-        
+
         for i, size in enumerate(config.network_hidden_dims):
-            layers.append(nn.Linear(config.network_hidden_dims[i-1] if i > 0 else config.network_hidden_dims[0], size))
-            
+            layers.append(
+                nn.Linear(config.network_hidden_dims[i - 1] if i > 0 else config.network_hidden_dims[0], size)
+            )
+
             if i + 1 < len(config.network_hidden_dims) or activate_final:
                 if dropout_rate is not None and dropout_rate > 0:
                     layers.append(nn.Dropout(p=dropout_rate))
                 layers.append(nn.LayerNorm(size))
-                layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
-        
+                layers.append(
+                    activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
+                )
+
         self.net = nn.Sequential(*layers)
 
     def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
         # in training mode or not. TODO: find better way to do this
-        self.train(train) 
+        self.train(train)
         return self.net(x)
-    
-    
+
+
 class Critic(nn.Module):
     def __init__(
         self,
@@ -250,7 +246,7 @@ class Critic(nn.Module):
         network: nn.Module,
         init_final: Optional[float] = None,
         activate_final: bool = False,
-        device: str = "cuda"
+        device: str = "cuda",
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -258,7 +254,7 @@ class Critic(nn.Module):
         self.network = network
         self.init_final = init_final
         self.activate_final = activate_final
-        
+
         # Output layer
         if init_final is not None:
             if self.activate_final:
@@ -273,36 +269,28 @@ class Critic(nn.Module):
         else:
             self.output_layer = nn.Linear(network.net[-2].out_features, 1)
             orthogonal_init()(self.output_layer.weight)
-        
+
         self.to(self.device)
 
-    def forward(
-        self, 
-        observations: torch.Tensor, 
-        actions: torch.Tensor, 
-        train: bool = False
-    ) -> torch.Tensor:
+    def forward(self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False) -> torch.Tensor:
         self.train(train)
-        
+
         observations = observations.to(self.device)
         actions = actions.to(self.device)
-        
+
         obs_enc = observations if self.encoder is None else self.encoder(observations)
-        
+
         inputs = torch.cat([obs_enc, actions], dim=-1)
         x = self.network(inputs)
         value = self.output_layer(x)
         return value.squeeze(-1)
-    
+
     def q_value_ensemble(
-        self, 
-        observations: torch.Tensor, 
-        actions: torch.Tensor, 
-        train: bool = False
+        self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False
     ) -> torch.Tensor:
         observations = observations.to(self.device)
         actions = actions.to(self.device)
-        
+
         if len(actions.shape) == 3:  # [batch_size, num_actions, action_dim]
             batch_size, num_actions = actions.shape[:2]
             obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
@@ -327,7 +315,7 @@ class Policy(nn.Module):
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
         activate_final: bool = False,
-        device: str = "cuda"
+        device: str = "cuda",
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -340,7 +328,7 @@ class Policy(nn.Module):
         self.tanh_squash_distribution = tanh_squash_distribution
         self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
         self.activate_final = activate_final
-        
+
         # Mean layer
         if self.activate_final:
             self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
@@ -351,7 +339,7 @@ class Policy(nn.Module):
                 nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
         else:
             orthogonal_init()(self.mean_layer.weight)
-        
+
         # Standard deviation layer or parameter
         if fixed_std is None:
             if std_parameterization == "uniform":
@@ -366,18 +354,18 @@ class Policy(nn.Module):
                     nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
                 else:
                     orthogonal_init()(self.std_layer.weight)
-        
+
         self.to(self.device)
 
     def forward(
-        self, 
+        self,
         observations: torch.Tensor,
         temperature: float = 1.0,
         train: bool = False,
-        non_squash_distribution: bool = False
+        non_squash_distribution: bool = False,
     ) -> torch.distributions.Distribution:
         self.train(train)
-        
+
         # Encode observations if encoder exists
         if self.encoder is not None:
             with torch.set_grad_enabled(train):
@@ -387,7 +375,7 @@ class Policy(nn.Module):
         # Get network outputs
         outputs = self.network(obs_enc)
         means = self.mean_layer(outputs)
-        
+
         # Compute standard deviations
         if self.fixed_std is None:
             if self.std_parameterization == "exp":
@@ -398,9 +386,7 @@ class Policy(nn.Module):
             elif self.std_parameterization == "uniform":
                 stds = torch.exp(self.log_stds).expand_as(means)
             else:
-                raise ValueError(
-                    f"Invalid std_parameterization: {self.std_parameterization}"
-                )
+                raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")
         else:
             assert self.std_parameterization == "fixed"
             stds = self.fixed_std.expand_as(means)
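A reading of the `std_parameterization` branches above (the "exp" and "softplus" bodies fall between hunks, so those two lines are inferred, not quoted):

    # "exp":      stds = torch.exp(self.std_layer(outputs))       state-dependent
    # "softplus": stds = F.softplus(self.std_layer(outputs))      state-dependent
    # "uniform":  stds = torch.exp(self.log_stds).expand_as(means)
    #             a single learned log-std parameter, independent of the observation
    # "fixed":    stds = self.fixed_std.expand_as(means)          not learned at all
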
@@ -422,7 +408,7 @@ class Policy(nn.Module):
             )
 
         return distribution
-    
+
     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
         """Get encoded features from observations"""
         observations = observations.to(self.device)
@@ -503,56 +489,47 @@ class SACObservationEncoder(nn.Module):
         if "observation.state" in self.config.input_shapes:
             feat.append(self.state_enc_layers(obs_dict["observation.state"]))
         return torch.stack(feat, dim=0).mean(0)
-    
+
 
 class LagrangeMultiplier(nn.Module):
-    def __init__(
-        self, 
-        init_value: float = 1.0, 
-        constraint_shape: Sequence[int] = (), 
-        device: str = "cuda"
-    ):
+    def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
         super().__init__()
         self.device = torch.device(device)
         init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
-        
+
         # Initialize the Lagrange multiplier as a parameter
         self.lagrange = nn.Parameter(
             torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
         )
-        
+
         self.to(self.device)
 
-    def forward(
-        self, 
-        lhs: Optional[torch.Tensor] = None, 
-        rhs: Optional[torch.Tensor] = None
-    ) -> torch.Tensor:
-        # Get the multiplier value based on parameterization 
+    def forward(self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor] = None) -> torch.Tensor:
+        # Get the multiplier value based on parameterization
         multiplier = torch.nn.functional.softplus(self.lagrange)
-        
+
         # Return the raw multiplier if no constraint values provided
         if lhs is None:
             return multiplier
-        
+
         # Move inputs to device
         lhs = lhs.to(self.device)
         if rhs is not None:
             rhs = rhs.to(self.device)
-        
+
         # Use the multiplier to compute the Lagrange penalty
        if rhs is None:
            rhs = torch.zeros_like(lhs, device=self.device)
-        
+
        diff = lhs - rhs
-        
+
        assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
-        
+
        return multiplier * diff
 
 
 # The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where:
-# 1. The base distribution is a diagonal multivariate normal distribution 
+# 1. The base distribution is a diagonal multivariate normal distribution
 # 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1
 # 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
 # This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
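A brief usage sketch for the class defined below (illustrative only; the tensor shapes are assumptions):

    means, stds = ...                            # (batch, action_dim) policy head outputs
    dist = TanhMultivariateNormalDiag(loc=means, scale_diag=stds)
    action = dist.rsample()                      # squashed into (-1, 1) by TanhTransform
    logp = dist.log_prob(action)                 # includes the tanh log-det-Jacobian correction
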
@@ -568,28 +545,22 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
     ):
         # Create base normal distribution
         base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)
-        
+
         # Create list of transforms
         transforms = []
-        
+
         # Add tanh transform
         transforms.append(torch.distributions.transforms.TanhTransform())
-        
+
         # Add rescaling transform if bounds are provided
         if low is not None and high is not None:
             transforms.append(
-                torch.distributions.transforms.AffineTransform(
-                    loc=(high + low) / 2,
-                    scale=(high - low) / 2
-                )
+                torch.distributions.transforms.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2)
             )
-        
+
         # Initialize parent class
-        super().__init__(
-            base_distribution=base_distribution,
-            transforms=transforms
-        )
-        
+        super().__init__(base_distribution=base_distribution, transforms=transforms)
+
         # Store parameters
         self.loc = loc
         self.scale_diag = scale_diag
@@ -600,11 +571,11 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         """Get the mode of the transformed distribution"""
         # The mode of a normal distribution is its mean
         mode = self.loc
-        
+
         # Apply transforms
         for transform in self.transforms:
             mode = transform(mode)
-            
+
         return mode
 
     def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor:
@@ -613,11 +584,11 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         """
         # Sample from base distribution
         x = self.base_dist.rsample(sample_shape)
-        
+
         # Apply transforms
         for transform in self.transforms:
             x = transform(x)
-            
+
         return x
 
     def log_prob(self, value: torch.Tensor) -> torch.Tensor:
@@ -627,16 +598,16 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         """
         # Initialize log prob
         log_prob = torch.zeros_like(value[..., 0])
-        
+
         # Inverse transforms to get back to normal distribution
         q = value
         for transform in reversed(self.transforms):
             q = transform.inv(q)
             log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q))
-        
+
         # Add base distribution log prob
         log_prob = log_prob + self.base_dist.log_prob(q).sum(-1)
-        
+
         return log_prob
 
     def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -653,13 +624,13 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         """
         # Start with base distribution entropy
         entropy = self.base_dist.entropy().sum(-1)
-        
+
         # Add log det jacobian for each transform
         x = self.rsample()
         for transform in self.transforms:
             entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
             x = transform(x)
-        
+
         return entropy
 
 
@@ -680,7 +651,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
     Args:
         fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
             (B, *), where * is any number of dimensions.
-        image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and 
+        image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
            can be more than one dimension, generally different from *.
     Returns:
         A return value from the callable reshaped to (**, *).
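A shape walk-through of what this docstring describes, using an assumed (B, T) leading pattern (the next hunk shows the implementation):

    # fn: (B, C, H, W) -> (B, D), e.g. an image encoder
    # image_tensor: (B, T, C, H, W)
    # torch.flatten(image_tensor, end_dim=-4)   -> (B * T, C, H, W)
    # fn(inp)                                   -> (B * T, D)
    # torch.reshape(flat_out, (*start_dims, ...)) -> (B, T, D)
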
@@ -691,4 +662,3 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
     inp = torch.flatten(image_tensor, end_dim=-4)
     flat_out = fn(inp)
     return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
-
diff --git a/lerobot/scripts/eval_on_robot.py b/lerobot/scripts/eval_on_robot.py
index 6a790f0a..92daa860 100644
--- a/lerobot/scripts/eval_on_robot.py
+++ b/lerobot/scripts/eval_on_robot.py
@@ -24,7 +24,7 @@ python lerobot/scripts/eval_on_robot.py \
 ```
 
 **NOTE** (michel-aractingi): This script is incomplete and it is being prepared
-for running training on the real robot. 
+for running training on the real robot.
 """
 
 import argparse
@@ -47,7 +47,7 @@ from lerobot.common.utils.utils import (
 
 
 def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict:
-    """Run a batched policy rollout on the real robot. 
+    """Run a batched policy rollout on the real robot.
 
     The return dictionary contains:
         "robot": A dictionary of (batch, sequence + 1, *) tensors mapped to observation
@@ -64,7 +64,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
     extraneous elements from the sequences above.
 
     Args:
-        robot: The robot class that defines the interface with the real robot. 
+        robot: The robot class that defines the interface with the real robot.
         policy: The policy. Must be a PyTorch nn module.
 
     Returns:
@@ -77,7 +77,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
     listener, events = init_keyboard_listener()
 
     # Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
-    # policy.reset() 
+    # policy.reset()
 
     # Get observation from real robot
     observation = robot.capture_observation()
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 346c3acd..fbe7927d 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -95,12 +95,14 @@ def make_optimizer_and_scheduler(cfg, policy):
         lr_scheduler = None
 
     elif policy.name == "sac":
-        optimizer = torch.optim.Adam([
-            {'params': policy.actor.parameters(), 'lr': policy.config.actor_lr},
-            {'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr},
-            {'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr},
-        ])
-        lr_scheduler = None 
+        optimizer = torch.optim.Adam(
+            [
+                {"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
+                {"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
+                {"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
+            ]
+        )
+        lr_scheduler = None
 
     elif cfg.policy.name == "vqbet":
         from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py
index 78659dc8..ea8336a9 100644
--- a/lerobot/scripts/train_hilserl_classifier.py
+++ b/lerobot/scripts/train_hilserl_classifier.py
@@ -22,7 +22,6 @@ from pprint import pformat
 import hydra
 import torch
 import torch.nn as nn
-import wandb
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
@@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler
 from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
 from tqdm import tqdm
 
+import wandb
 from lerobot.common.datasets.factory import resolve_delta_timestamps
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.logger import Logger