added support for maniskill environment in factory.py

Michel Aractingi 2024-11-26 09:35:09 +00:00
parent 15090c2544
commit f3450aea41
6 changed files with 296 additions and 170 deletions

View File

@@ -14,8 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from collections import deque

import gymnasium as gym
import numpy as np
import torch
from omegaconf import DictConfig
@@ -30,6 +33,10 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
    if cfg.env.name == "real_world":
        return
    if "maniskill" in cfg.env.name:
        env = make_maniskill_env(cfg, n_envs if n_envs is not None else cfg.eval.batch_size)
        return env
    package_name = f"gym_{cfg.env.name}"
    try:
@@ -56,3 +63,58 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
    )
    return env
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
    """Make ManiSkill3 gym environment"""
    from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

    env = gym.make(
        cfg.env.task,
        obs_mode=cfg.env.obs,
        control_mode=cfg.env.control_mode,
        render_mode=cfg.env.render_mode,
        sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
        num_envs=n_envs,
    )
    # cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
    env = ManiSkillVectorEnv(env, ignore_terminations=True)
    env = PixelWrapper(cfg, env, n_envs)
    env._max_episode_steps = env.max_episode_steps = 50  # gym_utils.find_max_episode_steps_value(env)
    env.unwrapped.metadata["render_fps"] = 20

    return env
class PixelWrapper(gym.Wrapper):
    """
    Wrapper for pixel observations. Works with Maniskill vectorized environments
    """

    def __init__(self, cfg, env, num_envs, num_frames=3):
        super().__init__(env)
        self.cfg = cfg
        self.env = env
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
            dtype=np.uint8,
        )
        self._frames = deque([], maxlen=num_frames)
        self._render_size = cfg.env.render_size

    def _get_obs(self, obs):
        frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
        self._frames.append(frame)
        return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}

    def reset(self, seed):
        obs, info = self.env.reset()  # (seed=seed)
        for _ in range(self._frames.maxlen):
            obs_frames = self._get_obs(obs)
        return obs_frames, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._get_obs(obs), reward, terminated, truncated, info
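A minimal sketch of how this factory path might be exercised, assuming an OmegaConf config that carries only the fields read above (the task, obs, and control_mode values are illustrative ManiSkill3 settings, not taken from this commit):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "env": {
            "name": "maniskill/pushcube",       # "maniskill" in the name triggers make_maniskill_env
            "task": "PushCube-v1",              # illustrative ManiSkill3 task id
            "obs": "rgb",
            "control_mode": "pd_ee_delta_pose",
            "render_mode": "rgb_array",
            "image_size": 64,
            "render_size": 64,
        },
        "eval": {"batch_size": 4},
    }
)

env = make_env(cfg)            # dispatches to make_maniskill_env(cfg, n_envs=4)
obs, info = env.reset(seed=0)
print(obs["pixels"].shape)     # expected (4, 9, 64, 64): 3 stacked RGB frames per env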

View File

@@ -51,6 +51,13 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
        from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy

        return TDMPCPolicy, TDMPCConfig
    elif name == "tdmpc2":
        from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
        from lerobot.common.policies.tdmpc2.modeling_tdmpc2 import TDMPC2Policy

        return TDMPC2Policy, TDMPC2Config
    elif name == "diffusion":
        from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
        from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

View File

@@ -126,9 +126,12 @@ class TDMPC2Config:
    state_encoder_hidden_dim: int = 256
    latent_dim: int = 512
    q_ensemble_size: int = 5
    num_enc_layers: int = 2
    mlp_dim: int = 512
    # Reinforcement learning.
    discount: float = 0.9
    simnorm_dim: int = 8
    dropout: float = 0.01

    # actor
    log_std_min: float = -10
@@ -157,10 +160,10 @@ class TDMPC2Config:
    consistency_coeff: float = 20.0
    entropy_coef: float = 1e-4
    temporal_decay_coeff: float = 0.5
    # Target model. NOTE (michel_aractingi) this is equivalent to
    # 1 - target_model_momentum of our TD-MPC1 implementation because
    # of the use of `torch.lerp`
    target_model_momentum: float = 0.01

    def __post_init__(self):
        """Input validation (not exhaustive)."""

View File

@@ -38,8 +38,16 @@ from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
from lerobot.common.policies.tdmpc2.tdmpc2_utils import (
    NormedLinear,
    SimNorm,
    gaussian_logprob,
    soft_cross_entropy,
    squash,
    two_hot_inv,
)
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues

class TDMPC2Policy(
    nn.Module,
@@ -83,6 +91,8 @@ class TDMPC2Policy(
            config = TDMPC2Config()
        self.config = config
        self.model = TDMPC2WorldModel(config)
        # TODO (michel-aractingi) temp fix for gpu
        self.model = self.model.to("cuda:0")

        if config.input_normalization_modes is not None:
            self.normalize_inputs = Normalize(
@@ -109,7 +119,9 @@ class TDMPC2Policy(
            self._use_env_state = True

        self.scale = RunningScale(self.config.target_model_momentum)
        self.discount = (
            self.config.discount
        )  # TODO (michel-aractingi) downscale discount according to episode length

        self.reset()
@@ -204,7 +216,7 @@ class TDMPC2Policy(
            for t in range(self.config.horizon):
                # Note: Adding a small amount of noise here doesn't hurt during inference and may even be
                # helpful for CEM.
                pi_actions[t] = self.model.pi(_z)[0]
                _z = self.model.latent_dynamics(_z, pi_actions[t])

        # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
@@ -249,14 +261,17 @@ class TDMPC2Policy(
            score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
            score /= score.sum(axis=0, keepdim=True)
            # (horizon, batch, action_dim)
            mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (
                einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9
            )
            std = torch.sqrt(
                torch.sum(
                    einops.rearrange(score, "n b -> n b 1")
                    * (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2,
                    dim=1,
                )
                / (einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9)
            ).clamp_(self.config.min_std, self.config.max_std)

        # Keep track of the mean for warm-starting subsequent steps.
        self._prev_mean = mean
@@ -286,11 +301,11 @@ class TDMPC2Policy(
            # Update the return and running discount.
            G += running_discount * reward
            running_discount *= self.config.discount

        # next_action = self.model.pi(z)[0]  # (batch, action_dim)
        # terminal_values = self.model.Qs(z, next_action, return_type="avg")  # (ensemble, batch)

        return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type="avg")
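For intuition, the value returned above is the usual H-step estimate used by the planner: the discounted sum of predicted rewards plus a bootstrapped Q value at the final latent state. A toy scalar sketch (illustrative numbers only):

# V = sum_{t<H} gamma^t * r_t + gamma^H * Q(z_H, pi(z_H))
gamma, rewards, q_terminal = 0.9, [1.0, 0.5, 0.25], 2.0
returns = sum(gamma**t * r for t, r in enumerate(rewards))
value_estimate = returns + gamma ** len(rewards) * q_terminal
print(round(value_estimate, 4))  # 3.1105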

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
        """Run the batch through the model and compute the loss.
@@ -351,14 +366,15 @@ class TDMPC2Policy(

        # Compute various targets with stopgrad.
        with torch.no_grad():
            # Latent state consistency targets for consistency loss.
            z_targets = self.model.encode(next_observations)

            # Compute the TD-target from a reward and the next observation
            pi = self.model.pi(z_targets)[0]
            td_targets = (
                reward
                + self.config.discount
                * self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
            )

        # Compute losses.
@@ -421,18 +437,20 @@ class TDMPC2Policy(
        z_preds = z_preds.detach()
        self.model.change_q_grad(mode=False)
        action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])

        with torch.no_grad():
            # avoid unnecessary computation of the gradients during policy optimization
            # TODO (michel-aractingi): the same logic should be extended when adding task embeddings
            qs = self.model.Qs(z_preds[:-1], action_preds, return_type="avg")
            self.scale.update(qs[0])
            qs = self.scale(qs)

        rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(
            -1
        )

        pi_loss = (
            (self.config.entropy_coef * log_pis - qs).mean(dim=(1, 2))
            * rho
            # * temporal_loss_coeffs
            # `action_preds` depends on the first observation and the actions.
@@ -470,56 +488,69 @@ class TDMPC2Policy(
        """Update the target model's using polyak averaging."""
        self.model.update_target_Q()


class TDMPC2WorldModel(nn.Module):
    """Latent dynamics model used in TD-MPC2."""

    def __init__(self, config: TDMPC2Config):
        super().__init__()
        self.config = config
        self._encoder = TDMPC2ObservationEncoder(config)

        # Define latent dynamics head
        self._dynamics = nn.Sequential(
            NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
            NormedLinear(config.mlp_dim, config.mlp_dim),
            NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)),
        )

        # Define reward head
        self._reward = nn.Sequential(
            NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
            NormedLinear(config.mlp_dim, config.mlp_dim),
            nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
        )

        # Define policy head
        self._pi = nn.Sequential(
            NormedLinear(config.latent_dim, config.mlp_dim),
            NormedLinear(config.mlp_dim, config.mlp_dim),
            nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0]),
        )

        # Define ensemble of Q functions
        self._Qs = nn.ModuleList(
            [
                nn.Sequential(
                    NormedLinear(
                        config.latent_dim + config.output_shapes["action"][0],
                        config.mlp_dim,
                        dropout=config.dropout,
                    ),
                    NormedLinear(config.mlp_dim, config.mlp_dim),
                    nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
                )
                for _ in range(config.q_ensemble_size)
            ]
        )

        self._init_weights()
        self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
        self.log_std_min = torch.tensor(config.log_std_min)
        self.log_std_dif = torch.tensor(config.log_std_max) - self.log_std_min
        self.bins = torch.linspace(config.vmin, config.vmax, config.num_bins)
        self.config.bin_size = (config.vmax - config.vmin) / (config.num_bins - 1)

    def _init_weights(self):
        """Initialize model weights.

        Custom weight initializations proposed in TD-MPC2.
        """

        def _apply_fn(m):
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
@@ -527,13 +558,13 @@ class TDMPC2WorldModel(nn.Module):
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.ParameterList):
                for i, p in enumerate(m):
                    if p.dim() == 3:  # Linear
                        nn.init.trunc_normal_(p, std=0.02)  # Weight
                        nn.init.constant_(m[i + 1], 0)  # Bias

        self.apply(_apply_fn)

        # initialize parameters of the
        for m in [self._reward, *self._Qs]:
            assert isinstance(
                m[-1], nn.Linear
@@ -549,7 +580,7 @@ class TDMPC2WorldModel(nn.Module):
        self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
        self.bins = self.bins.to(*args, **kwargs)
        return self

    def train(self, mode):
        super().train(mode)
        self._target_Qs.train(False)
@@ -641,7 +672,7 @@ class TDMPC2WorldModel(nn.Module):
        Soft-update target Q-networks using Polyak averaging.
        """
        with torch.no_grad():
            for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters(), strict=False):
                p_target.data.lerp_(p.data, self.config.target_model_momentum)
@@ -662,49 +693,51 @@ class TDMPC2ObservationEncoder(nn.Module):
        for obs_key in config.input_shapes:
            if "observation.image" in config.input_shapes:
                encoder_module = nn.Sequential(
                    nn.Conv2d(config.input_shapes[obs_key][0], config.image_encoder_hidden_dim, 7, stride=2),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=1),
                )
                dummy_batch = torch.zeros(1, *config.input_shapes[obs_key])
                with torch.inference_mode():
                    out_shape = encoder_module(dummy_batch).shape[1:]
                encoder_module.extend(
                    nn.Sequential(
                        nn.Flatten(),
                        NormedLinear(np.prod(out_shape), config.latent_dim, act=SimNorm(config.simnorm_dim)),
                    )
                )
            elif (
                "observation.state" in config.input_shapes
                or "observation.environment_state" in config.input_shapes
            ):
                encoder_module = nn.ModuleList()
                encoder_module.append(
                    NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim)
                )
                assert config.num_enc_layers > 0
                for _ in range(config.num_enc_layers - 1):
                    encoder_module.append(
                        NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim)
                    )
                encoder_module.append(
                    NormedLinear(
                        config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)
                    )
                )
                encoder_module = nn.Sequential(*encoder_module)
            else:
                raise NotImplementedError(f"No corresponding encoder module for key {obs_key}.")

            encoder_dict[obs_key.replace(".", "")] = encoder_module

        self.encoder = nn.ModuleDict(encoder_dict)

    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
        """Encode the image and/or state vector.
@@ -714,9 +747,11 @@ class TDMPC2ObservationEncoder(nn.Module):
        feat = []
        for obs_key in self.config.input_shapes:
            if "observation.image" in obs_key:
                feat.append(
                    flatten_forward_unflatten(self.encoder[obs_key.replace(".", "")], obs_dict[obs_key])
                )
            else:
                feat.append(self.encoder[obs_key.replace(".", "")](obs_dict[obs_key]))
        return torch.stack(feat, dim=0).mean(0)
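The .replace(".", "") applied to the keys above and in __init__ is presumably needed because nn.ModuleDict (like nn.Module.add_module) rejects names containing dots, while the config keys look like "observation.state". A quick illustration of the constraint:

import torch.nn as nn

nn.ModuleDict({"observationstate": nn.Linear(4, 8)})       # accepted
try:
    nn.ModuleDict({"observation.state": nn.Linear(4, 8)})  # rejected
except KeyError as err:
    print(err)  # module name can't contain "."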
@@ -816,4 +851,4 @@ class RunningScale:
        return x * (1 / self.value)

    def __repr__(self):
        return f"RunningScale(S: {self.value})"

View File

@@ -5,152 +5,159 @@ from functorch import combine_state_for_ensemble


class Ensemble(nn.Module):
    """
    Vectorized ensemble of modules.
    """

    def __init__(self, modules, **kwargs):
        super().__init__()
        modules = nn.ModuleList(modules)
        fn, params, _ = combine_state_for_ensemble(modules)
        self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness="different", **kwargs)
        self.params = nn.ParameterList([nn.Parameter(p) for p in params])
        self._repr = str(modules)

    def forward(self, *args, **kwargs):
        return self.vmap([p for p in self.params], (), *args, **kwargs)

    def __repr__(self):
        return "Vectorized " + self._repr


class SimNorm(nn.Module):
    """
    Simplicial normalization.
    Adapted from https://arxiv.org/abs/2204.00616.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        shp = x.shape
        x = x.view(*shp[:-1], -1, self.dim)
        x = F.softmax(x, dim=-1)
        return x.view(*shp)

    def __repr__(self):
        return f"SimNorm(dim={self.dim})"


class NormedLinear(nn.Linear):
    """
    Linear layer with LayerNorm, activation, and optionally dropout.
    """

    def __init__(self, *args, dropout=0.0, act=nn.Mish(inplace=True), **kwargs):
        super().__init__(*args, **kwargs)
        self.ln = nn.LayerNorm(self.out_features)
        self.act = act
        self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None

    def forward(self, x):
        x = super().forward(x)
        if self.dropout:
            x = self.dropout(x)
        return self.act(self.ln(x))

    def __repr__(self):
        repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
        return (
            f"NormedLinear(in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"bias={self.bias is not None}{repr_dropout}, "
            f"act={self.act.__class__.__name__})"
        )

def soft_cross_entropy(pred, target, cfg):
    """Computes the cross entropy loss between predictions and soft targets."""
    pred = F.log_softmax(pred, dim=-1)
    target = two_hot(target, cfg)
    return -(target * pred).sum(-1, keepdim=True)

@torch.jit.script
def log_std(x, low, dif):
    return low + 0.5 * dif * (torch.tanh(x) + 1)


@torch.jit.script
def _gaussian_residual(eps, log_std):
    return -0.5 * eps.pow(2) - log_std


@torch.jit.script
def _gaussian_logprob(residual):
    return residual - 0.5 * torch.log(2 * torch.pi)


def gaussian_logprob(eps, log_std, size=None):
    """Compute Gaussian log probability."""
    residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
    if size is None:
        size = eps.size(-1)
    return _gaussian_logprob(residual) * size


@torch.jit.script
def _squash(pi):
    return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)


def squash(mu, pi, log_pi):
    """Apply squashing function."""
    mu = torch.tanh(mu)
    pi = torch.tanh(pi)
    log_pi -= _squash(pi).sum(-1, keepdim=True)
    return mu, pi, log_pi


@torch.jit.script
def symlog(x):
    """
    Symmetric logarithmic function.
    Adapted from https://github.com/danijar/dreamerv3.
    """
    return torch.sign(x) * torch.log(1 + torch.abs(x))


@torch.jit.script
def symexp(x):
    """
    Symmetric exponential function.
    Adapted from https://github.com/danijar/dreamerv3.
    """
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)


def two_hot(x, cfg):
    """Converts a batch of scalars to soft two-hot encoded targets for discrete regression."""
    if cfg.num_bins == 0:
        return x
    elif cfg.num_bins == 1:
        return symlog(x)
    x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax)
    bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
    bin_offset = (x - cfg.vmin) / cfg.bin_size - bin_idx.float()
    soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
    soft_two_hot.scatter_(1, bin_idx, 1 - bin_offset)
    soft_two_hot.scatter_(1, (bin_idx + 1) % cfg.num_bins, bin_offset)
    return soft_two_hot

def two_hot_inv(x, bins):
    """Converts a batch of soft two-hot encoded vectors to scalars."""
    num_bins = bins.shape[0]
    if num_bins == 0:
        return x
    elif num_bins == 1:
        return symexp(x)
    x = F.softmax(x, dim=-1)
    x = torch.sum(x * bins, dim=-1, keepdim=True)
    return symexp(x)
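A small worked example of the two-hot encoding, assuming a toy config with 5 bins over [-2, 2] in symlog space (so bin_size = 1.0):

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(num_bins=5, vmin=-2.0, vmax=2.0, bin_size=1.0)

x = torch.tensor([[0.3]])    # one scalar reward/value per batch element
print(two_hot(x, cfg))
# symlog(0.3) ~= 0.262 lies between bin 2 (value 0.0) and bin 3 (value 1.0),
# so the probability mass is split ~0.74 / 0.26 between those bins:
# tensor([[0.0000, 0.0000, 0.7376, 0.2624, 0.0000]])

Note that two_hot_inv is not applied to such a soft target directly: it expects raw logits (it applies its own softmax), takes the expectation over `bins`, and maps the result back through symexp.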

View File

@@ -93,6 +93,18 @@ def make_optimizer_and_scheduler(cfg, policy):
    elif policy.name == "tdmpc":
        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
        lr_scheduler = None
    elif policy.name == "tdmpc2":
        params_group = [
            {"params": policy.model._encoder.parameters(), "lr": cfg.training.lr * cfg.training.enc_lr_scale},
            {"params": policy.model._dynamics.parameters()},
            {"params": policy.model._reward.parameters()},
            {"params": policy.model._Qs.parameters()},
            {"params": policy.model._pi.parameters(), "eps": 1e-5},
        ]
        optimizer = torch.optim.Adam(params_group, lr=cfg.training.lr)
        lr_scheduler = None
    elif cfg.policy.name == "vqbet":
        from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler