From f3450aea4197ee76c047e7ca14afd35b3d98e319 Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Tue, 26 Nov 2024 09:35:09 +0000
Subject: [PATCH] Added support for ManiSkill environment in factory.py

---
 lerobot/common/envs/factory.py                |  62 ++++++
 lerobot/common/policies/factory.py            |   7 +
 .../policies/tdmpc2/configuration_tdmpc2.py   |   9 +-
 .../common/policies/tdmpc2/modeling_tdmpc2.py | 169 ++++++++------
 .../common/policies/tdmpc2/tdmpc2_utils.py    | 207 +++++++++---------
 lerobot/scripts/train.py                      |  12 +
 6 files changed, 296 insertions(+), 170 deletions(-)

diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index 54f24ea8..ed6ec5c8 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -14,8 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
+from collections import deque
 
 import gymnasium as gym
+import numpy as np
+import torch
 from omegaconf import DictConfig
 
 
@@ -30,6 +33,10 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
     if cfg.env.name == "real_world":
         return
 
+    if "maniskill" in cfg.env.name:
+        env = make_maniskill_env(cfg, n_envs if n_envs is not None else cfg.eval.batch_size)
+        return env
+
     package_name = f"gym_{cfg.env.name}"
 
     try:
@@ -56,3 +63,58 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
         )
 
     return env
+
+
+def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
+    """Make ManiSkill3 gym environment"""
+    from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
+
+    env = gym.make(
+        cfg.env.task,
+        obs_mode=cfg.env.obs,
+        control_mode=cfg.env.control_mode,
+        render_mode=cfg.env.render_mode,
+        sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
+        num_envs=n_envs,
+    )
+    # cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
+    env = ManiSkillVectorEnv(env, ignore_terminations=True)
+    env = PixelWrapper(cfg, env, n_envs)
+    env._max_episode_steps = env.max_episode_steps = 50  # gym_utils.find_max_episode_steps_value(env)
+    env.unwrapped.metadata["render_fps"] = 20
+
+    return env
+
+
+class PixelWrapper(gym.Wrapper):
+    """
+    Wrapper for pixel observations. Works with ManiSkill vectorized environments
+    """
+
+    def __init__(self, cfg, env, num_envs, num_frames=3):
+        super().__init__(env)
+        self.cfg = cfg
+        self.env = env
+        self.observation_space = gym.spaces.Box(
+            low=0,
+            high=255,
+            shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
+            dtype=np.uint8,
+        )
+        self._frames = deque([], maxlen=num_frames)
+        self._render_size = cfg.env.render_size
+
+    def _get_obs(self, obs):
+        frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
+        self._frames.append(frame)
+        return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
+
+    def reset(self, seed=None, options=None):
+        obs, info = self.env.reset()  # (seed=seed)
+        for _ in range(self._frames.maxlen):
+            obs_frames = self._get_obs(obs)
+        return obs_frames, info
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        return self._get_obs(obs), reward, terminated, truncated, info
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 5cb2fd52..f75baec3 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -51,6 +51,13 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
         from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
 
         return TDMPCPolicy, TDMPCConfig
+
+    elif name == "tdmpc2":
+        from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
+        from lerobot.common.policies.tdmpc2.modeling_tdmpc2 import TDMPC2Policy
+
+        return TDMPC2Policy, TDMPC2Config
+
     elif name == "diffusion":
         from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
         from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
diff --git a/lerobot/common/policies/tdmpc2/configuration_tdmpc2.py b/lerobot/common/policies/tdmpc2/configuration_tdmpc2.py
index 946f70e3..661128a5 100644
--- a/lerobot/common/policies/tdmpc2/configuration_tdmpc2.py
+++ b/lerobot/common/policies/tdmpc2/configuration_tdmpc2.py
@@ -126,9 +126,12 @@ class TDMPC2Config:
     state_encoder_hidden_dim: int = 256
     latent_dim: int = 512
     q_ensemble_size: int = 5
+    num_enc_layers: int = 2
     mlp_dim: int = 512
     # Reinforcement learning.
     discount: float = 0.9
+    simnorm_dim: int = 8
+    dropout: float = 0.01
 
     # actor
     log_std_min: float = -10
@@ -157,10 +160,10 @@ class TDMPC2Config:
     consistency_coeff: float = 20.0
    entropy_coef: float = 1e-4
     temporal_decay_coeff: float = 0.5
-    # Target model. NOTE (michel_aractingi) this is equivelant to
-    # 1 - target_model_momentum of our TD-MPC1 implementation because
+    # Target model. 
NOTE (michel_aractingi) this is equivelant to + # 1 - target_model_momentum of our TD-MPC1 implementation because # of the use of `torch.lerp` - target_model_momentum: float = 0.01 + target_model_momentum: float = 0.01 def __post_init__(self): """Input validation (not exhaustive).""" diff --git a/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py b/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py index e3127168..35cdd4ce 100644 --- a/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py +++ b/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py @@ -38,8 +38,16 @@ from torch import Tensor from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config +from lerobot.common.policies.tdmpc2.tdmpc2_utils import ( + NormedLinear, + SimNorm, + gaussian_logprob, + soft_cross_entropy, + squash, + two_hot_inv, +) from lerobot.common.policies.utils import get_device_from_parameters, populate_queues -from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash, soft_cross_entropy + class TDMPC2Policy( nn.Module, @@ -83,6 +91,8 @@ class TDMPC2Policy( config = TDMPC2Config() self.config = config self.model = TDMPC2WorldModel(config) + # TODO (michel-aractingi) temp fix for gpu + self.model = self.model.to("cuda:0") if config.input_normalization_modes is not None: self.normalize_inputs = Normalize( @@ -109,7 +119,9 @@ class TDMPC2Policy( self._use_env_state = True self.scale = RunningScale(self.config.target_model_momentum) - self.discount = self.config.discount #TODO (michel-aractingi) downscale discount according to episode length + self.discount = ( + self.config.discount + ) # TODO (michel-aractingi) downscale discount according to episode length self.reset() @@ -204,7 +216,7 @@ class TDMPC2Policy( for t in range(self.config.horizon): # Note: Adding a small amount of noise here doesn't hurt during inference and may even be # helpful for CEM. - pi_actions[t] = self.model.pi(_z, self.config.min_std)[0] + pi_actions[t] = self.model.pi(_z)[0] _z = self.model.latent_dynamics(_z, pi_actions[t]) # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled @@ -249,14 +261,17 @@ class TDMPC2Policy( score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value)) score /= score.sum(axis=0, keepdim=True) # (horizon, batch, action_dim) - mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (score.sum(0) + 1e-9) + mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / ( + einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9 + ) std = torch.sqrt( torch.sum( einops.rearrange(score, "n b -> n b 1") * (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2, dim=1, - ) / (score.sum(0) + 1e-9) - ).clamp_(self.config.min_std, self.config.max_std) + ) + / (einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9) + ).clamp_(self.config.min_std, self.config.max_std) # Keep track of the mean for warm-starting subsequent steps. self._prev_mean = mean @@ -286,11 +301,11 @@ class TDMPC2Policy( # Update the return and running discount. 
G += running_discount * reward running_discount *= self.config.discount - - #next_action = self.model.pi(z)[0] # (batch, action_dim) - #terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch) - return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type='avg') + # next_action = self.model.pi(z)[0] # (batch, action_dim) + # terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch) + + return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type="avg") def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: """Run the batch through the model and compute the loss. @@ -351,14 +366,15 @@ class TDMPC2Policy( # Compute various targets with stopgrad. with torch.no_grad(): - # Latent state consistency targets for consistency loss. + # Latent state consistency targets for consistency loss. z_targets = self.model.encode(next_observations) # Compute the TD-target from a reward and the next observation pi = self.model.pi(z_targets)[0] td_targets = ( reward - + self.config.discount * self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze() + + self.config.discount + * self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze() ) # Compute losses. @@ -421,18 +437,20 @@ class TDMPC2Policy( z_preds = z_preds.detach() self.model.change_q_grad(mode=False) action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1]) - + with torch.no_grad(): # avoid unnessecary computation of the gradients during policy optimization - # TODO (michel-aractingi): the same logic should be extended when adding task embeddings + # TODO (michel-aractingi): the same logic should be extended when adding task embeddings qs = self.model.Qs(z_preds[:-1], action_preds, return_type="avg") self.scale.update(qs[0]) qs = self.scale(qs) - rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(-1) + rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze( + -1 + ) pi_loss = ( - (self.config.entropy_coef * log_pis - qs).mean(dim=(1,2)) + (self.config.entropy_coef * log_pis - qs).mean(dim=(1, 2)) * rho # * temporal_loss_coeffs # `action_preds` depends on the first observation and the actions. 
@@ -470,56 +488,69 @@ class TDMPC2Policy( """Update the target model's using polyak averaging.""" self.model.update_target_Q() + class TDMPC2WorldModel(nn.Module): """Latent dynamics model used in TD-MPC2.""" def __init__(self, config: TDMPC2Config): super().__init__() self.config = config - + self._encoder = TDMPC2ObservationEncoder(config) # Define latent dynamics head - self._dynamics = nn.Sequential(NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), - NormedLinear(config.mlp_dim, config.mlp_dim), - NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim))) - + self._dynamics = nn.Sequential( + NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), + NormedLinear(config.mlp_dim, config.mlp_dim), + NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)), + ) + # Define reward head - self._reward = nn.Sequential(NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), - NormedLinear(config.mlp_dim, config.mlp_dim), - nn.Linear(config.mlp_dim, max(config.num_bins, 1))) - + self._reward = nn.Sequential( + NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), + NormedLinear(config.mlp_dim, config.mlp_dim), + nn.Linear(config.mlp_dim, max(config.num_bins, 1)), + ) + # Define policy head - self._pi = nn.Sequential(NormedLinear(config.latent_dim, config.mlp_dim), - NormedLinear(config.mlp_dim, config.mlp_dim), - nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0])) - + self._pi = nn.Sequential( + NormedLinear(config.latent_dim, config.mlp_dim), + NormedLinear(config.mlp_dim, config.mlp_dim), + nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0]), + ) + # Define ensemble of Q functions self._Qs = nn.ModuleList( [ nn.Sequential( - NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim, dropout=config.dropout), + NormedLinear( + config.latent_dim + config.output_shapes["action"][0], + config.mlp_dim, + dropout=config.dropout, + ), NormedLinear(config.mlp_dim, config.mlp_dim), - nn.Linear(config.mlp_dim, max(config.num_bins, 1)) - ) for _ in range(config.q_ensemble_size) + nn.Linear(config.mlp_dim, max(config.num_bins, 1)), + ) + for _ in range(config.q_ensemble_size) ] ) - + self._init_weights() self._target_Qs = deepcopy(self._Qs).requires_grad_(False) self.log_std_min = torch.tensor(config.log_std_min) self.log_std_dif = torch.tensor(config.log_std_max) - self.log_std_min - + self.bins = torch.linspace(config.vmin, config.vmax, config.num_bins) self.config.bin_size = (config.vmax - config.vmin) / (config.num_bins - 1) def _init_weights(self): """Initialize model weights. - Custom weight initializations proposed in TD-MPC2. + Custom weight initializations proposed in TD-MPC2. 
""" + def _apply_fn(m): if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std=0.02) @@ -527,13 +558,13 @@ class TDMPC2WorldModel(nn.Module): nn.init.constant_(m.bias, 0) elif isinstance(m, nn.ParameterList): for i, p in enumerate(m): - if p.dim() == 3: # Linear - nn.init.trunc_normal_(p, std=0.02) # Weight - nn.init.constant_(m[i+1], 0) # Bias + if p.dim() == 3: # Linear + nn.init.trunc_normal_(p, std=0.02) # Weight + nn.init.constant_(m[i + 1], 0) # Bias self.apply(_apply_fn) - - # initialize parameters of the + + # initialize parameters of the for m in [self._reward, *self._Qs]: assert isinstance( m[-1], nn.Linear @@ -549,7 +580,7 @@ class TDMPC2WorldModel(nn.Module): self.log_std_dif = self.log_std_dif.to(*args, **kwargs) self.bins = self.bins.to(*args, **kwargs) return self - + def train(self, mode): super().train(mode) self._target_Qs.train(False) @@ -641,7 +672,7 @@ class TDMPC2WorldModel(nn.Module): Soft-update target Q-networks using Polyak averaging. """ with torch.no_grad(): - for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()): + for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters(), strict=False): p_target.data.lerp_(p.data, self.config.target_model_momentum) @@ -662,49 +693,51 @@ class TDMPC2ObservationEncoder(nn.Module): for obs_key in config.input_shapes: if "observation.image" in config.input_shapes: encoder_module = nn.Sequential( - nn.Conv2d(config.input_shapes[obs_key][0], config.image_encoder_hidden_dim, 7, stride=2), + nn.Conv2d(config.input_shapes[obs_key][0], config.image_encoder_hidden_dim, 7, stride=2), nn.ReLU(inplace=True), - nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), nn.ReLU(inplace=True), - nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), nn.ReLU(inplace=True), - nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=1), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=1), ) dummy_batch = torch.zeros(1, *config.input_shapes[obs_key]) with torch.inference_mode(): - out_shape = self.image_enc_layers(dummy_batch).shape[1:] + out_shape = encoder_module(dummy_batch).shape[1:] encoder_module.extend( nn.Sequential( nn.Flatten(), NormedLinear(np.prod(out_shape), config.latent_dim, act=SimNorm(config.simnorm_dim)), ) ) - - elif "observation.state" in config.input_shapes: + + elif ( + "observation.state" in config.input_shapes + or "observation.environment_state" in config.input_shapes + ): encoder_module = nn.ModuleList() - encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim)) + encoder_module.append( + NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim) + ) assert config.num_enc_layers > 0 for _ in range(config.num_enc_layers - 1): - encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim)) - encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim))) - encoder_module = nn.Sequential(*encoder_module) - - elif "observation.environment_state" in config.input_shapes: - encoder_module = nn.ModuleList() - encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim)) - 
assert config.num_enc_layers > 0 - for _ in range(config.num_enc_layers - 1): - encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim)) - encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim))) + encoder_module.append( + NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim) + ) + encoder_module.append( + NormedLinear( + config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim) + ) + ) encoder_module = nn.Sequential(*encoder_module) else: raise NotImplementedError(f"No corresponding encoder module for key {obs_key}.") - - encoder_dict[obs_key] = encoder_module + + encoder_dict[obs_key.replace(".", "")] = encoder_module self.encoder = nn.ModuleDict(encoder_dict) - + def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: """Encode the image and/or state vector. @@ -714,9 +747,11 @@ class TDMPC2ObservationEncoder(nn.Module): feat = [] for obs_key in self.config.input_shapes: if "observation.image" in obs_key: - feat.append(flatten_forward_unflatten(self.encoder[obs_key], obs_dict[obs_key])) + feat.append( + flatten_forward_unflatten(self.encoder[obs_key.replace(".", "")], obs_dict[obs_key]) + ) else: - feat.append(self.encoder[obs_key](obs_dict[obs_key])) + feat.append(self.encoder[obs_key.replace(".", "")](obs_dict[obs_key])) return torch.stack(feat, dim=0).mean(0) @@ -816,4 +851,4 @@ class RunningScale: return x * (1 / self.value) def __repr__(self): - return f"RunningScale(S: {self.value})" \ No newline at end of file + return f"RunningScale(S: {self.value})" diff --git a/lerobot/common/policies/tdmpc2/tdmpc2_utils.py b/lerobot/common/policies/tdmpc2/tdmpc2_utils.py index e1a38eab..22f1ca06 100644 --- a/lerobot/common/policies/tdmpc2/tdmpc2_utils.py +++ b/lerobot/common/policies/tdmpc2/tdmpc2_utils.py @@ -5,152 +5,159 @@ from functorch import combine_state_for_ensemble class Ensemble(nn.Module): - """ - Vectorized ensemble of modules. - """ + """ + Vectorized ensemble of modules. + """ - def __init__(self, modules, **kwargs): - super().__init__() - modules = nn.ModuleList(modules) - fn, params, _ = combine_state_for_ensemble(modules) - self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness='different', **kwargs) - self.params = nn.ParameterList([nn.Parameter(p) for p in params]) - self._repr = str(modules) + def __init__(self, modules, **kwargs): + super().__init__() + modules = nn.ModuleList(modules) + fn, params, _ = combine_state_for_ensemble(modules) + self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness="different", **kwargs) + self.params = nn.ParameterList([nn.Parameter(p) for p in params]) + self._repr = str(modules) - def forward(self, *args, **kwargs): - return self.vmap([p for p in self.params], (), *args, **kwargs) + def forward(self, *args, **kwargs): + return self.vmap([p for p in self.params], (), *args, **kwargs) + + def __repr__(self): + return "Vectorized " + self._repr - def __repr__(self): - return 'Vectorized ' + self._repr class SimNorm(nn.Module): - """ - Simplicial normalization. - Adapted from https://arxiv.org/abs/2204.00616. - """ - - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x): - shp = x.shape - x = x.view(*shp[:-1], -1, self.dim) - x = F.softmax(x, dim=-1) - return x.view(*shp) - - def __repr__(self): - return f"SimNorm(dim={self.dim})" + """ + Simplicial normalization. + Adapted from https://arxiv.org/abs/2204.00616. 
+    """
+
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        shp = x.shape
+        x = x.view(*shp[:-1], -1, self.dim)
+        x = F.softmax(x, dim=-1)
+        return x.view(*shp)
+
+    def __repr__(self):
+        return f"SimNorm(dim={self.dim})"
 
 
 class NormedLinear(nn.Linear):
-    """
-    Linear layer with LayerNorm, activation, and optionally dropout.
-    """
+    """
+    Linear layer with LayerNorm, activation, and optionally dropout.
+    """
 
-    def __init__(self, *args, dropout=0., act=nn.Mish(inplace=True), **kwargs):
-        super().__init__(*args, **kwargs)
-        self.ln = nn.LayerNorm(self.out_features)
-        self.act = act
-        self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None
+    def __init__(self, *args, dropout=0.0, act=nn.Mish(inplace=True), **kwargs):
+        super().__init__(*args, **kwargs)
+        self.ln = nn.LayerNorm(self.out_features)
+        self.act = act
+        self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None
 
-    def forward(self, x):
-        x = super().forward(x)
-        if self.dropout:
-            x = self.dropout(x)
-        return self.act(self.ln(x))
-
-    def __repr__(self):
-        repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
-        return f"NormedLinear(in_features={self.in_features}, "\
-            f"out_features={self.out_features}, "\
-            f"bias={self.bias is not None}{repr_dropout}, "\
-            f"act={self.act.__class__.__name__})"
+    def forward(self, x):
+        x = super().forward(x)
+        if self.dropout:
+            x = self.dropout(x)
+        return self.act(self.ln(x))
+
+    def __repr__(self):
+        repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
+        return (
+            f"NormedLinear(in_features={self.in_features}, "
+            f"out_features={self.out_features}, "
+            f"bias={self.bias is not None}{repr_dropout}, "
+            f"act={self.act.__class__.__name__})"
+        )
 
 
 def soft_cross_entropy(pred, target, cfg):
-    """Computes the cross entropy loss between predictions and soft targets."""
-    pred = F.log_softmax(pred, dim=-1)
-    target = two_hot(target, cfg)
-    return -(target * pred).sum(-1, keepdim=True)
+    """Computes the cross entropy loss between predictions and soft targets."""
+    pred = F.log_softmax(pred, dim=-1)
+    target = two_hot(target, cfg)
+    return -(target * pred).sum(-1, keepdim=True)
 
 
 @torch.jit.script
 def log_std(x, low, dif):
-    return low + 0.5 * dif * (torch.tanh(x) + 1)
+    return low + 0.5 * dif * (torch.tanh(x) + 1)
 
 
 @torch.jit.script
 def _gaussian_residual(eps, log_std):
-    return -0.5 * eps.pow(2) - log_std
+    return -0.5 * eps.pow(2) - log_std
 
 
 @torch.jit.script
 def _gaussian_logprob(residual):
-    return residual - 0.5 * torch.log(2 * torch.pi)
+    return residual - 0.5 * torch.log(2 * torch.pi)
 
 
 def gaussian_logprob(eps, log_std, size=None):
-    """Compute Gaussian log probability."""
-    residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
-    if size is None:
-        size = eps.size(-1)
-    return _gaussian_logprob(residual) * size
+    """Compute Gaussian log probability."""
+    residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
+    if size is None:
+        size = eps.size(-1)
+    return _gaussian_logprob(residual) * size
 
 
 @torch.jit.script
 def _squash(pi):
-    return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
+    return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
 
 
 def squash(mu, pi, log_pi):
-    """Apply squashing function."""
-    mu = torch.tanh(mu)
-    pi = torch.tanh(pi)
-    log_pi -= _squash(pi).sum(-1, keepdim=True)
-    return mu, pi, log_pi
+    """Apply squashing function."""
+    mu = torch.tanh(mu)
+    pi = torch.tanh(pi)
+    log_pi -= _squash(pi).sum(-1, keepdim=True)
+    return mu, pi, log_pi
@torch.jit.script def symlog(x): - """ - Symmetric logarithmic function. - Adapted from https://github.com/danijar/dreamerv3. - """ - return torch.sign(x) * torch.log(1 + torch.abs(x)) + """ + Symmetric logarithmic function. + Adapted from https://github.com/danijar/dreamerv3. + """ + return torch.sign(x) * torch.log(1 + torch.abs(x)) @torch.jit.script def symexp(x): - """ - Symmetric exponential function. - Adapted from https://github.com/danijar/dreamerv3. - """ - return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) + """ + Symmetric exponential function. + Adapted from https://github.com/danijar/dreamerv3. + """ + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) def two_hot(x, cfg): - """Converts a batch of scalars to soft two-hot encoded targets for discrete regression.""" - if cfg.num_bins == 0: - return x - elif cfg.num_bins == 1: - return symlog(x) - x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1) - bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long() - bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1) - soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device) - soft_two_hot.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset) - soft_two_hot.scatter_(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset) - return soft_two_hot + """Converts a batch of scalars to soft two-hot encoded targets for discrete regression.""" + if cfg.num_bins == 0: + return x + elif cfg.num_bins == 1: + return symlog(x) + x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax) + bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long() + bin_offset = (x - cfg.vmin) / cfg.bin_size - bin_idx.float() + soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device) + soft_two_hot.scatter_(1, bin_idx, 1 - bin_offset) + soft_two_hot.scatter_(1, (bin_idx + 1) % cfg.num_bins, bin_offset) + return soft_two_hot + def two_hot_inv(x, bins): - """Converts a batch of soft two-hot encoded vectors to scalars.""" - num_bins = bins.shape[0] - if num_bins == 0: - return x - elif num_bins == 1: - return symexp(x) + """Converts a batch of soft two-hot encoded vectors to scalars.""" + num_bins = bins.shape[0] + if num_bins == 0: + return x + elif num_bins == 1: + return symexp(x) - x = F.softmax(x, dim=-1) - x = torch.sum(x * bins, dim=-1, keepdim=True) - return symexp(x) + x = F.softmax(x, dim=-1) + x = torch.sum(x * bins, dim=-1, keepdim=True) + return symexp(x) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index f60f904e..9c7df689 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -93,6 +93,18 @@ def make_optimizer_and_scheduler(cfg, policy): elif policy.name == "tdmpc": optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr) lr_scheduler = None + + elif policy.name == "tdmpc2": + params_group = [ + {"params": policy.model._encoder.parameters(), "lr": cfg.training.lr * cfg.training.enc_lr_scale}, + {"params": policy.model._dynamics.parameters()}, + {"params": policy.model._reward.parameters()}, + {"params": policy.model._Qs.parameters()}, + {"params": policy.model._pi.parameters(), "eps": 1e-5}, + ] + optimizer = torch.optim.Adam(params_group, lr=cfg.training.lr) + lr_scheduler = None + elif cfg.policy.name == "vqbet": from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler