added support for maniskill environment in factory.py
This commit is contained in:
parent 15090c2544
commit f3450aea41
@@ -14,8 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from collections import deque

import gymnasium as gym
import numpy as np
import torch
from omegaconf import DictConfig

@@ -30,6 +33,10 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
    if cfg.env.name == "real_world":
        return

    if "maniskill" in cfg.env.name:
        env = make_maniskill_env(cfg, n_envs if n_envs is not None else cfg.eval.batch_size)
        return env

    package_name = f"gym_{cfg.env.name}"

    try:

@@ -56,3 +63,58 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
    )

    return env


def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
    """Make ManiSkill3 gym environment"""
    from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

    env = gym.make(
        cfg.env.task,
        obs_mode=cfg.env.obs,
        control_mode=cfg.env.control_mode,
        render_mode=cfg.env.render_mode,
        sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
        num_envs=n_envs,
    )
    # cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
    env = ManiSkillVectorEnv(env, ignore_terminations=True)
    env = PixelWrapper(cfg, env, n_envs)
    env._max_episode_steps = env.max_episode_steps = 50  # gym_utils.find_max_episode_steps_value(env)
    env.unwrapped.metadata["render_fps"] = 20

    return env


class PixelWrapper(gym.Wrapper):
    """
    Wrapper for pixel observations. Works with Maniskill vectorized environments
    """

    def __init__(self, cfg, env, num_envs, num_frames=3):
        super().__init__(env)
        self.cfg = cfg
        self.env = env
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
            dtype=np.uint8,
        )
        self._frames = deque([], maxlen=num_frames)
        self._render_size = cfg.env.render_size

    def _get_obs(self, obs):
        frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
        self._frames.append(frame)
        return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}

    def reset(self, seed):
        obs, info = self.env.reset()  # (seed=seed)
        for _ in range(self._frames.maxlen):
            obs_frames = self._get_obs(obs)
        return obs_frames, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._get_obs(obs), reward, terminated, truncated, info
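For orientation, here is a minimal sketch of how the new branch could be exercised. The config keys mirror the ones read by make_maniskill_env and PixelWrapper above, but the concrete task id, obs mode, and image sizes are assumptions, not values taken from this commit.

```python
from omegaconf import OmegaConf

# from lerobot.common.envs.factory import make_env  # assumed module path for the factory above

cfg = OmegaConf.create(
    {
        "env": {
            "name": "maniskill",                 # any name containing "maniskill" takes the new branch
            "task": "PickCube-v1",               # assumed ManiSkill3 task id
            "obs": "rgb",                        # assumed obs_mode
            "control_mode": "pd_ee_delta_pose",  # assumed control mode
            "render_mode": "rgb_array",
            "image_size": 64,
            "render_size": 64,
        },
        "eval": {"batch_size": 4},
    }
)

# env = make_env(cfg)      # ManiSkillVectorEnv wrapped in PixelWrapper, 4 parallel envs
# obs, info = env.reset(seed=0)
# obs["pixels"].shape      # (4, 9, 64, 64): 3 stacked RGB frames per env
```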
@@ -51,6 +51,13 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
        from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy

        return TDMPCPolicy, TDMPCConfig

    elif name == "tdmpc2":
        from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
        from lerobot.common.policies.tdmpc2.modeling_tdmpc2 import TDMPC2Policy

        return TDMPC2Policy, TDMPC2Config

    elif name == "diffusion":
        from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
        from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
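A quick sanity check of the new branch, assuming get_policy_and_config_classes lives in lerobot.common.policies.factory (the module this hunk appears to modify):

```python
from lerobot.common.policies.factory import get_policy_and_config_classes

policy_cls, config_cls = get_policy_and_config_classes("tdmpc2")
print(policy_cls.__name__, config_cls.__name__)  # TDMPC2Policy TDMPC2Config
```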
@@ -126,9 +126,12 @@ class TDMPC2Config:
    state_encoder_hidden_dim: int = 256
    latent_dim: int = 512
    q_ensemble_size: int = 5
    num_enc_layers: int = 2
    mlp_dim: int = 512
    # Reinforcement learning.
    discount: float = 0.9
    simnorm_dim: int = 8
    dropout: float = 0.01

    # actor
    log_std_min: float = -10

@@ -157,10 +160,10 @@ class TDMPC2Config:
    consistency_coeff: float = 20.0
    entropy_coef: float = 1e-4
    temporal_decay_coeff: float = 0.5
    # Target model. NOTE (michel_aractingi) this is equivelant to
    # 1 - target_model_momentum of our TD-MPC1 implementation because
    # Target model. NOTE (michel_aractingi) this is equivelant to
    # 1 - target_model_momentum of our TD-MPC1 implementation because
    # of the use of `torch.lerp`
    target_model_momentum: float = 0.01
    target_model_momentum: float = 0.01

    def __post_init__(self):
        """Input validation (not exhaustive)."""
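The NOTE above (shown twice because the hunk contains both the old and the new copy of the comment) is about the sign convention: with torch.lerp the coefficient moves the target toward the online weights, so it plays the role of 1 - momentum in the TD-MPC1 formulation. A minimal sketch of that update:

```python
import torch

tau = 0.01  # target_model_momentum above

p = torch.tensor([1.0])         # online parameter
p_target = torch.tensor([0.0])  # target parameter

# lerp_(end, weight): p_target <- p_target + tau * (p - p_target)
#                              = (1 - tau) * p_target + tau * p
p_target.lerp_(p, tau)
print(p_target)  # tensor([0.0100])
```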
@@ -38,8 +38,16 @@ from torch import Tensor

from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
from lerobot.common.policies.tdmpc2.tdmpc2_utils import (
    NormedLinear,
    SimNorm,
    gaussian_logprob,
    soft_cross_entropy,
    squash,
    two_hot_inv,
)
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash, soft_cross_entropy


class TDMPC2Policy(
    nn.Module,
@@ -83,6 +91,8 @@ class TDMPC2Policy(
            config = TDMPC2Config()
        self.config = config
        self.model = TDMPC2WorldModel(config)
        # TODO (michel-aractingi) temp fix for gpu
        self.model = self.model.to("cuda:0")

        if config.input_normalization_modes is not None:
            self.normalize_inputs = Normalize(
@@ -109,7 +119,9 @@ class TDMPC2Policy(
            self._use_env_state = True

        self.scale = RunningScale(self.config.target_model_momentum)
        self.discount = self.config.discount #TODO (michel-aractingi) downscale discount according to episode length
        self.discount = (
            self.config.discount
        )  # TODO (michel-aractingi) downscale discount according to episode length

        self.reset()
@@ -204,7 +216,7 @@ class TDMPC2Policy(
            for t in range(self.config.horizon):
                # Note: Adding a small amount of noise here doesn't hurt during inference and may even be
                # helpful for CEM.
                pi_actions[t] = self.model.pi(_z, self.config.min_std)[0]
                pi_actions[t] = self.model.pi(_z)[0]
                _z = self.model.latent_dynamics(_z, pi_actions[t])

        # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
@@ -249,14 +261,17 @@ class TDMPC2Policy(
            score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
            score /= score.sum(axis=0, keepdim=True)
            # (horizon, batch, action_dim)
            mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (score.sum(0) + 1e-9)
            mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (
                einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9
            )
            std = torch.sqrt(
                torch.sum(
                    einops.rearrange(score, "n b -> n b 1")
                    * (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2,
                    dim=1,
                ) / (score.sum(0) + 1e-9)
            ).clamp_(self.config.min_std, self.config.max_std)
                )
                / (einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9)
            ).clamp_(self.config.min_std, self.config.max_std)

            # Keep track of the mean for warm-starting subsequent steps.
            self._prev_mean = mean
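The change in this hunk is purely about broadcasting: score.sum(0) has shape (batch,), and dividing a (horizon, batch, action_dim) tensor by it lines the vector up with the trailing action dimension rather than the batch dimension. A small sketch with assumed shapes:

```python
import torch

h, n, b, d = 2, 3, 4, 5  # horizon, n_elites, batch, action_dim (assumed)
score = torch.rand(n, b)
elite_actions = torch.rand(h, n, b, d)

weighted = (score.unsqueeze(-1) * elite_actions).sum(dim=1)  # (h, b, d)
denom = score.sum(0)                                         # (b,)

# (h, b, d) / (b,) would broadcast over the action dim, which is why the new
# code reshapes the denominator to (1, b, 1) before dividing.
mean = weighted / (denom.view(1, b, 1) + 1e-9)               # (h, b, d)
```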
@@ -286,11 +301,11 @@ class TDMPC2Policy(
                # Update the return and running discount.
                G += running_discount * reward
                running_discount *= self.config.discount

            #next_action = self.model.pi(z)[0] # (batch, action_dim)
            #terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch)

            return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type='avg')
            # next_action = self.model.pi(z)[0] # (batch, action_dim)
            # terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch)

            return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type="avg")

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
        """Run the batch through the model and compute the loss.
@@ -351,14 +366,15 @@ class TDMPC2Policy(

        # Compute various targets with stopgrad.
        with torch.no_grad():
            # Latent state consistency targets for consistency loss.
            # Latent state consistency targets for consistency loss.
            z_targets = self.model.encode(next_observations)

            # Compute the TD-target from a reward and the next observation
            pi = self.model.pi(z_targets)[0]
            td_targets = (
                reward
                + self.config.discount * self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
                + self.config.discount
                * self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
            )

        # Compute losses.
@@ -421,18 +437,20 @@ class TDMPC2Policy(
        z_preds = z_preds.detach()
        self.model.change_q_grad(mode=False)
        action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])

        with torch.no_grad():
            # avoid unnessecary computation of the gradients during policy optimization
            # TODO (michel-aractingi): the same logic should be extended when adding task embeddings
            # TODO (michel-aractingi): the same logic should be extended when adding task embeddings
            qs = self.model.Qs(z_preds[:-1], action_preds, return_type="avg")
            self.scale.update(qs[0])
            qs = self.scale(qs)

        rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(-1)
        rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(
            -1
        )

        pi_loss = (
            (self.config.entropy_coef * log_pis - qs).mean(dim=(1,2))
            (self.config.entropy_coef * log_pis - qs).mean(dim=(1, 2))
            * rho
            # * temporal_loss_coeffs
            # `action_preds` depends on the first observation and the actions.
@@ -470,56 +488,69 @@ class TDMPC2Policy(
        """Update the target model's using polyak averaging."""
        self.model.update_target_Q()


class TDMPC2WorldModel(nn.Module):
    """Latent dynamics model used in TD-MPC2."""

    def __init__(self, config: TDMPC2Config):
        super().__init__()
        self.config = config

        self._encoder = TDMPC2ObservationEncoder(config)

        # Define latent dynamics head
        self._dynamics = nn.Sequential(NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
                                       NormedLinear(config.mlp_dim, config.mlp_dim),
                                       NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))

        self._dynamics = nn.Sequential(
            NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
            NormedLinear(config.mlp_dim, config.mlp_dim),
            NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)),
        )

        # Define reward head
        self._reward = nn.Sequential(NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
                                     NormedLinear(config.mlp_dim, config.mlp_dim),
                                     nn.Linear(config.mlp_dim, max(config.num_bins, 1)))

        self._reward = nn.Sequential(
            NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
            NormedLinear(config.mlp_dim, config.mlp_dim),
            nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
        )

        # Define policy head
        self._pi = nn.Sequential(NormedLinear(config.latent_dim, config.mlp_dim),
                                 NormedLinear(config.mlp_dim, config.mlp_dim),
                                 nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0]))

        self._pi = nn.Sequential(
            NormedLinear(config.latent_dim, config.mlp_dim),
            NormedLinear(config.mlp_dim, config.mlp_dim),
            nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0]),
        )

        # Define ensemble of Q functions
        self._Qs = nn.ModuleList(
            [
                nn.Sequential(
                    NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim, dropout=config.dropout),
                    NormedLinear(
                        config.latent_dim + config.output_shapes["action"][0],
                        config.mlp_dim,
                        dropout=config.dropout,
                    ),
                    NormedLinear(config.mlp_dim, config.mlp_dim),
                    nn.Linear(config.mlp_dim, max(config.num_bins, 1))
                ) for _ in range(config.q_ensemble_size)
                    nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
                )
                for _ in range(config.q_ensemble_size)
            ]
        )

        self._init_weights()

        self._target_Qs = deepcopy(self._Qs).requires_grad_(False)

        self.log_std_min = torch.tensor(config.log_std_min)
        self.log_std_dif = torch.tensor(config.log_std_max) - self.log_std_min

        self.bins = torch.linspace(config.vmin, config.vmax, config.num_bins)
        self.config.bin_size = (config.vmax - config.vmin) / (config.num_bins - 1)

    def _init_weights(self):
        """Initialize model weights.
        Custom weight initializations proposed in TD-MPC2.
        Custom weight initializations proposed in TD-MPC2.

        """

        def _apply_fn(m):
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
@@ -527,13 +558,13 @@ class TDMPC2WorldModel(nn.Module):
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.ParameterList):
                for i, p in enumerate(m):
                    if p.dim() == 3: # Linear
                        nn.init.trunc_normal_(p, std=0.02) # Weight
                        nn.init.constant_(m[i+1], 0) # Bias
                    if p.dim() == 3:  # Linear
                        nn.init.trunc_normal_(p, std=0.02)  # Weight
                        nn.init.constant_(m[i + 1], 0)  # Bias

        self.apply(_apply_fn)

        # initialize parameters of the

        # initialize parameters of the
        for m in [self._reward, *self._Qs]:
            assert isinstance(
                m[-1], nn.Linear
@@ -549,7 +580,7 @@ class TDMPC2WorldModel(nn.Module):
        self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
        self.bins = self.bins.to(*args, **kwargs)
        return self

    def train(self, mode):
        super().train(mode)
        self._target_Qs.train(False)

@@ -641,7 +672,7 @@ class TDMPC2WorldModel(nn.Module):
        Soft-update target Q-networks using Polyak averaging.
        """
        with torch.no_grad():
            for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
            for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters(), strict=False):
                p_target.data.lerp_(p.data, self.config.target_model_momentum)
@@ -662,49 +693,51 @@ class TDMPC2ObservationEncoder(nn.Module):
        for obs_key in config.input_shapes:
            if "observation.image" in config.input_shapes:
                encoder_module = nn.Sequential(
                    nn.Conv2d(config.input_shapes[obs_key][0], config.image_encoder_hidden_dim, 7, stride=2),
                    nn.Conv2d(config.input_shapes[obs_key][0], config.image_encoder_hidden_dim, 7, stride=2),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
                    nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
                    nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=1),
                    nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=1),
                )
                dummy_batch = torch.zeros(1, *config.input_shapes[obs_key])
                with torch.inference_mode():
                    out_shape = self.image_enc_layers(dummy_batch).shape[1:]
                    out_shape = encoder_module(dummy_batch).shape[1:]
                encoder_module.extend(
                    nn.Sequential(
                        nn.Flatten(),
                        NormedLinear(np.prod(out_shape), config.latent_dim, act=SimNorm(config.simnorm_dim)),
                    )
                )

            elif "observation.state" in config.input_shapes:

            elif (
                "observation.state" in config.input_shapes
                or "observation.environment_state" in config.input_shapes
            ):
                encoder_module = nn.ModuleList()
                encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
                encoder_module.append(
                    NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim)
                )
                assert config.num_enc_layers > 0
                for _ in range(config.num_enc_layers - 1):
                    encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
                encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
                encoder_module = nn.Sequential(*encoder_module)

            elif "observation.environment_state" in config.input_shapes:
                encoder_module = nn.ModuleList()
                encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
                assert config.num_enc_layers > 0
                for _ in range(config.num_enc_layers - 1):
                    encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
                encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
                encoder_module.append(
                    NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim)
                )
                encoder_module.append(
                    NormedLinear(
                        config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)
                    )
                )
                encoder_module = nn.Sequential(*encoder_module)

            else:
                raise NotImplementedError(f"No corresponding encoder module for key {obs_key}.")

            encoder_dict[obs_key] = encoder_module

            encoder_dict[obs_key.replace(".", "")] = encoder_module

        self.encoder = nn.ModuleDict(encoder_dict)

    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
        """Encode the image and/or state vector.

@@ -714,9 +747,11 @@ class TDMPC2ObservationEncoder(nn.Module):
        feat = []
        for obs_key in self.config.input_shapes:
            if "observation.image" in obs_key:
                feat.append(flatten_forward_unflatten(self.encoder[obs_key], obs_dict[obs_key]))
                feat.append(
                    flatten_forward_unflatten(self.encoder[obs_key.replace(".", "")], obs_dict[obs_key])
                )
            else:
                feat.append(self.encoder[obs_key](obs_dict[obs_key]))
                feat.append(self.encoder[obs_key.replace(".", "")](obs_dict[obs_key]))
        return torch.stack(feat, dim=0).mean(0)
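The switch from encoder_dict[obs_key] to encoder_dict[obs_key.replace(".", "")] is needed because nn.ModuleDict rejects keys containing a dot. A quick illustration:

```python
import torch.nn as nn

try:
    nn.ModuleDict({"observation.image": nn.Identity()})
except KeyError as e:
    print(e)  # module name can't contain "."

nn.ModuleDict({"observation.image".replace(".", ""): nn.Identity()})  # "observationimage" is accepted
```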
@@ -816,4 +851,4 @@ class RunningScale:
        return x * (1 / self.value)

    def __repr__(self):
        return f"RunningScale(S: {self.value})"
        return f"RunningScale(S: {self.value})"
@@ -5,152 +5,159 @@ from functorch import combine_state_for_ensemble


class Ensemble(nn.Module):
    """
    Vectorized ensemble of modules.
    """
    """
    Vectorized ensemble of modules.
    """

    def __init__(self, modules, **kwargs):
        super().__init__()
        modules = nn.ModuleList(modules)
        fn, params, _ = combine_state_for_ensemble(modules)
        self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness='different', **kwargs)
        self.params = nn.ParameterList([nn.Parameter(p) for p in params])
        self._repr = str(modules)
    def __init__(self, modules, **kwargs):
        super().__init__()
        modules = nn.ModuleList(modules)
        fn, params, _ = combine_state_for_ensemble(modules)
        self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness="different", **kwargs)
        self.params = nn.ParameterList([nn.Parameter(p) for p in params])
        self._repr = str(modules)

    def forward(self, *args, **kwargs):
        return self.vmap([p for p in self.params], (), *args, **kwargs)
    def forward(self, *args, **kwargs):
        return self.vmap([p for p in self.params], (), *args, **kwargs)

    def __repr__(self):
        return "Vectorized " + self._repr

    def __repr__(self):
        return 'Vectorized ' + self._repr


class SimNorm(nn.Module):
    """
    Simplicial normalization.
    Adapted from https://arxiv.org/abs/2204.00616.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        shp = x.shape
        x = x.view(*shp[:-1], -1, self.dim)
        x = F.softmax(x, dim=-1)
        return x.view(*shp)

    def __repr__(self):
        return f"SimNorm(dim={self.dim})"
    """
    Simplicial normalization.
    Adapted from https://arxiv.org/abs/2204.00616.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        shp = x.shape
        x = x.view(*shp[:-1], -1, self.dim)
        x = F.softmax(x, dim=-1)
        return x.view(*shp)

    def __repr__(self):
        return f"SimNorm(dim={self.dim})"

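The hunk above mostly re-indents the old copies of these classes; SimNorm's behaviour is unchanged. As a small illustration of what it computes (a group-wise softmax over the last dimension):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 512)            # e.g. a latent of size latent_dim = 512
g = x.view(2, -1, 8)               # split into 64 groups of simnorm_dim = 8
g = F.softmax(g, dim=-1)           # softmax within each group (one simplex per group)
out = g.view(2, 512)

print(out.view(2, -1, 8).sum(-1))  # every group sums to 1
```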
class NormedLinear(nn.Linear):
    """
    Linear layer with LayerNorm, activation, and optionally dropout.
    """
    """
    Linear layer with LayerNorm, activation, and optionally dropout.
    """

    def __init__(self, *args, dropout=0., act=nn.Mish(inplace=True), **kwargs):
        super().__init__(*args, **kwargs)
        self.ln = nn.LayerNorm(self.out_features)
        self.act = act
        self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None
    def __init__(self, *args, dropout=0.0, act=nn.Mish(inplace=True), **kwargs):
        super().__init__(*args, **kwargs)
        self.ln = nn.LayerNorm(self.out_features)
        self.act = act
        self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None

    def forward(self, x):
        x = super().forward(x)
        if self.dropout:
            x = self.dropout(x)
        return self.act(self.ln(x))

    def __repr__(self):
        repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
        return f"NormedLinear(in_features={self.in_features}, "\
            f"out_features={self.out_features}, "\
            f"bias={self.bias is not None}{repr_dropout}, "\
            f"act={self.act.__class__.__name__})"
    def forward(self, x):
        x = super().forward(x)
        if self.dropout:
            x = self.dropout(x)
        return self.act(self.ln(x))

    def __repr__(self):
        repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
        return (
            f"NormedLinear(in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"bias={self.bias is not None}{repr_dropout}, "
            f"act={self.act.__class__.__name__})"
        )


def soft_cross_entropy(pred, target, cfg):
    """Computes the cross entropy loss between predictions and soft targets."""
    pred = F.log_softmax(pred, dim=-1)
    target = two_hot(target, cfg)
    return -(target * pred).sum(-1, keepdim=True)
    """Computes the cross entropy loss between predictions and soft targets."""
    pred = F.log_softmax(pred, dim=-1)
    target = two_hot(target, cfg)
    import pudb

    pudb.set_trace()
    return -(target * pred).sum(-1, keepdim=True)


@torch.jit.script
def log_std(x, low, dif):
    return low + 0.5 * dif * (torch.tanh(x) + 1)
    return low + 0.5 * dif * (torch.tanh(x) + 1)


@torch.jit.script
def _gaussian_residual(eps, log_std):
    return -0.5 * eps.pow(2) - log_std
    return -0.5 * eps.pow(2) - log_std


@torch.jit.script
def _gaussian_logprob(residual):
    return residual - 0.5 * torch.log(2 * torch.pi)
    return residual - 0.5 * torch.log(2 * torch.pi)


def gaussian_logprob(eps, log_std, size=None):
    """Compute Gaussian log probability."""
    residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
    if size is None:
        size = eps.size(-1)
    return _gaussian_logprob(residual) * size
    """Compute Gaussian log probability."""
    residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
    if size is None:
        size = eps.size(-1)
    return _gaussian_logprob(residual) * size


@torch.jit.script
def _squash(pi):
    return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
    return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)


def squash(mu, pi, log_pi):
    """Apply squashing function."""
    mu = torch.tanh(mu)
    pi = torch.tanh(pi)
    log_pi -= _squash(pi).sum(-1, keepdim=True)
    return mu, pi, log_pi
    """Apply squashing function."""
    mu = torch.tanh(mu)
    pi = torch.tanh(pi)
    log_pi -= _squash(pi).sum(-1, keepdim=True)
    return mu, pi, log_pi
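For context, a small self-contained sketch of how gaussian_logprob and squash fit together in a tanh-squashed Gaussian policy head. The shapes and the sampling step are assumed from the usual TD-MPC2 actor (they are not part of this hunk); the formulas below simply restate the helpers above in plain torch.

```python
import math

import torch

mu = torch.zeros(2, 4)                 # (batch, action_dim), assumed shapes
log_std = torch.full((2, 4), -1.0)
eps = torch.randn_like(mu)
pi = mu + eps * log_std.exp()          # reparameterized Gaussian sample

# gaussian_logprob: per-dim residuals summed, then the normalization constant,
# scaled by the action dimension exactly as the helper above does.
residual = (-0.5 * eps.pow(2) - log_std).sum(-1, keepdim=True)
log_pi = (residual - 0.5 * math.log(2 * math.pi)) * mu.size(-1)

# squash: tanh-squash the sample and correct the log-prob for the change of variables.
pi = torch.tanh(pi)
log_pi = log_pi - torch.log(torch.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
```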
@torch.jit.script
def symlog(x):
    """
    Symmetric logarithmic function.
    Adapted from https://github.com/danijar/dreamerv3.
    """
    return torch.sign(x) * torch.log(1 + torch.abs(x))
    """
    Symmetric logarithmic function.
    Adapted from https://github.com/danijar/dreamerv3.
    """
    return torch.sign(x) * torch.log(1 + torch.abs(x))


@torch.jit.script
def symexp(x):
    """
    Symmetric exponential function.
    Adapted from https://github.com/danijar/dreamerv3.
    """
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
    """
    Symmetric exponential function.
    Adapted from https://github.com/danijar/dreamerv3.
    """
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)


def two_hot(x, cfg):
    """Converts a batch of scalars to soft two-hot encoded targets for discrete regression."""
    if cfg.num_bins == 0:
        return x
    elif cfg.num_bins == 1:
        return symlog(x)
    x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1)
    bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
    bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1)
    soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
    soft_two_hot.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset)
    soft_two_hot.scatter_(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset)
    return soft_two_hot
    """Converts a batch of scalars to soft two-hot encoded targets for discrete regression."""
    if cfg.num_bins == 0:
        return x
    elif cfg.num_bins == 1:
        return symlog(x)
    x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax)
    bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
    bin_offset = (x - cfg.vmin) / cfg.bin_size - bin_idx.float()
    soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
    soft_two_hot.scatter_(1, bin_idx, 1 - bin_offset)
    soft_two_hot.scatter_(1, (bin_idx + 1) % cfg.num_bins, bin_offset)
    return soft_two_hot


def two_hot_inv(x, bins):
    """Converts a batch of soft two-hot encoded vectors to scalars."""
    num_bins = bins.shape[0]
    if num_bins == 0:
        return x
    elif num_bins == 1:
        return symexp(x)
    """Converts a batch of soft two-hot encoded vectors to scalars."""
    num_bins = bins.shape[0]
    if num_bins == 0:
        return x
    elif num_bins == 1:
        return symexp(x)

    x = F.softmax(x, dim=-1)
    x = torch.sum(x * bins, dim=-1, keepdim=True)
    return symexp(x)
    x = F.softmax(x, dim=-1)
    x = torch.sum(x * bins, dim=-1, keepdim=True)
    return symexp(x)
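A small worked example of the two-hot encoding. The vmin/vmax/num_bins values below are assumptions; in the model they come from TDMPC2Config, with bin_size derived in TDMPC2WorldModel.__init__.

```python
import torch

vmin, vmax, num_bins = -10.0, 10.0, 101
bin_size = (vmax - vmin) / (num_bins - 1)             # 0.2

x = torch.tensor(0.5)
x_sym = torch.sign(x) * torch.log(1 + torch.abs(x))   # symlog(0.5) ~= 0.4055
idx = torch.floor((x_sym - vmin) / bin_size).long()    # bin 52
offset = (x_sym - vmin) / bin_size - idx               # ~0.03 of the way to bin 53

# two_hot puts (1 - offset) on bin idx and offset on bin idx + 1;
# two_hot_inv recovers the scalar by softmax-weighting the bin centers and applying symexp.
```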
@@ -93,6 +93,18 @@ def make_optimizer_and_scheduler(cfg, policy):
    elif policy.name == "tdmpc":
        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
        lr_scheduler = None

    elif policy.name == "tdmpc2":
        params_group = [
            {"params": policy.model._encoder.parameters(), "lr": cfg.training.lr * cfg.training.enc_lr_scale},
            {"params": policy.model._dynamics.parameters()},
            {"params": policy.model._reward.parameters()},
            {"params": policy.model._Qs.parameters()},
            {"params": policy.model._pi.parameters(), "eps": 1e-5},
        ]
        optimizer = torch.optim.Adam(params_group, lr=cfg.training.lr)
        lr_scheduler = None

    elif cfg.policy.name == "vqbet":
        from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
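A minimal sketch of what the per-module parameter groups do: the encoder runs at a scaled learning rate and the policy head gets a larger Adam eps, while everything else uses the default. The lr and enc_lr_scale values are placeholders standing in for cfg.training.lr and cfg.training.enc_lr_scale.

```python
import torch

lr, enc_lr_scale = 3e-4, 0.3

params_group = [
    {"params": torch.nn.Linear(4, 4).parameters(), "lr": lr * enc_lr_scale},  # stands in for policy.model._encoder
    {"params": torch.nn.Linear(4, 4).parameters()},                           # dynamics / reward / Qs use the default lr
    {"params": torch.nn.Linear(4, 4).parameters(), "eps": 1e-5},              # stands in for policy.model._pi
]
optimizer = torch.optim.Adam(params_group, lr=lr)

print([g["lr"] for g in optimizer.param_groups])  # [9e-05, 0.0003, 0.0003]
```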