added support for maniskill environment in factory.py

Michel Aractingi 2024-11-26 09:35:09 +00:00
parent 15090c2544
commit f3450aea41
6 changed files with 296 additions and 170 deletions


@@ -14,8 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from collections import deque

import gymnasium as gym
import numpy as np
import torch
from omegaconf import DictConfig
@@ -30,6 +33,10 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
    if cfg.env.name == "real_world":
        return
    if "maniskill" in cfg.env.name:
        env = make_maniskill_env(cfg, n_envs if n_envs is not None else cfg.eval.batch_size)
        return env

    package_name = f"gym_{cfg.env.name}"

    try:
@@ -56,3 +63,58 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
    )

    return env
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
    """Make ManiSkill3 gym environment"""
    from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

    env = gym.make(
        cfg.env.task,
        obs_mode=cfg.env.obs,
        control_mode=cfg.env.control_mode,
        render_mode=cfg.env.render_mode,
        sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
        num_envs=n_envs,
    )
    # cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
    env = ManiSkillVectorEnv(env, ignore_terminations=True)
    env = PixelWrapper(cfg, env, n_envs)
    env._max_episode_steps = env.max_episode_steps = 50  # gym_utils.find_max_episode_steps_value(env)
    env.unwrapped.metadata["render_fps"] = 20

    return env
class PixelWrapper(gym.Wrapper):
    """
    Wrapper for pixel observations. Works with ManiSkill vectorized environments.
    """

    def __init__(self, cfg, env, num_envs, num_frames=3):
        super().__init__(env)
        self.cfg = cfg
        self.env = env
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
            dtype=np.uint8,
        )
        self._frames = deque([], maxlen=num_frames)
        self._render_size = cfg.env.render_size

    def _get_obs(self, obs):
        frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
        self._frames.append(frame)
        return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}

    def reset(self, seed):
        obs, info = self.env.reset()  # (seed=seed)
        for _ in range(self._frames.maxlen):
            obs_frames = self._get_obs(obs)
        return obs_frames, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._get_obs(obs), reward, terminated, truncated, info
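
For orientation, a minimal usage sketch of the new code path, assuming the factory above lives at lerobot/common/envs/factory.py (as the commit title suggests); the task name and sizes below are placeholder values, not taken from this commit:

from omegaconf import OmegaConf

from lerobot.common.envs.factory import make_env

# Hypothetical config; the real values come from the project's Hydra configs.
cfg = OmegaConf.create(
    {
        "env": {
            "name": "maniskill/pushcube",  # any name containing "maniskill" takes the new branch
            "task": "PushCube-v1",
            "obs": "rgb",
            "control_mode": "pd_ee_delta_pose",
            "render_mode": "rgb_array",
            "image_size": 64,
            "render_size": 64,
        },
        "eval": {"batch_size": 4},
    }
)

env = make_env(cfg, n_envs=4)   # dispatches to make_maniskill_env(...)
obs, info = env.reset(seed=0)   # obs["pixels"]: (4, 3 * num_frames, 64, 64) uint8 frames stacked by PixelWrapper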


@@ -51,6 +51,13 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
        from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy

        return TDMPCPolicy, TDMPCConfig
    elif name == "tdmpc2":
        from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
        from lerobot.common.policies.tdmpc2.modeling_tdmpc2 import TDMPC2Policy

        return TDMPC2Policy, TDMPC2Config
    elif name == "diffusion":
        from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
        from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
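
With this branch in place, the policy factory can hand back the TD-MPC2 classes by name; a minimal sketch of the intended call (instantiation details are illustrative and omit dataset stats):

from lerobot.common.policies.factory import get_policy_and_config_classes

policy_cls, config_cls = get_policy_and_config_classes("tdmpc2")
config = config_cls()        # TDMPC2Config with its dataclass defaults
policy = policy_cls(config)  # TDMPC2Policy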


@@ -126,9 +126,12 @@ class TDMPC2Config:
    state_encoder_hidden_dim: int = 256
    latent_dim: int = 512
    q_ensemble_size: int = 5
    num_enc_layers: int = 2
    mlp_dim: int = 512
    # Reinforcement learning.
    discount: float = 0.9
    simnorm_dim: int = 8
    dropout: float = 0.01
    # actor
    log_std_min: float = -10
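
The new fields slot into the existing dataclass, so they can be overridden like any other hyperparameter; a small sketch with illustrative values:

from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config

config = TDMPC2Config(
    num_enc_layers=3,  # depth of the state encoder MLP
    simnorm_dim=8,     # group size used by SimNorm on the latent
    dropout=0.01,      # dropout in the first layer of each Q-ensemble member
)
# SimNorm partitions the latent into groups of simnorm_dim, so the two must divide evenly.
assert config.latent_dim % config.simnorm_dim == 0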


@@ -38,8 +38,16 @@ from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
from lerobot.common.policies.tdmpc2.tdmpc2_utils import (
    NormedLinear,
    SimNorm,
    gaussian_logprob,
    soft_cross_entropy,
    squash,
    two_hot_inv,
)
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
class TDMPC2Policy(
    nn.Module,
@@ -83,6 +91,8 @@ class TDMPC2Policy(
            config = TDMPC2Config()
        self.config = config
        self.model = TDMPC2WorldModel(config)
        # TODO (michel-aractingi) temp fix for gpu
        self.model = self.model.to("cuda:0")
        if config.input_normalization_modes is not None:
            self.normalize_inputs = Normalize(
@@ -109,7 +119,9 @@ class TDMPC2Policy(
            self._use_env_state = True

        self.scale = RunningScale(self.config.target_model_momentum)
        self.discount = (
            self.config.discount
        )  # TODO (michel-aractingi) downscale discount according to episode length

        self.reset()
@@ -204,7 +216,7 @@
            for t in range(self.config.horizon):
                # Note: Adding a small amount of noise here doesn't hurt during inference and may even be
                # helpful for CEM.
                pi_actions[t] = self.model.pi(_z)[0]
                _z = self.model.latent_dynamics(_z, pi_actions[t])

        # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
@@ -249,13 +261,16 @@
            score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
            score /= score.sum(axis=0, keepdim=True)
            # (horizon, batch, action_dim)
            mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (
                einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9
            )
            std = torch.sqrt(
                torch.sum(
                    einops.rearrange(score, "n b -> n b 1")
                    * (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2,
                    dim=1,
                )
                / (einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9)
            ).clamp_(self.config.min_std, self.config.max_std)

            # Keep track of the mean for warm-starting subsequent steps.
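
The change above makes the denominator broadcast explicitly against the (horizon, batch, action_dim) numerator. A standalone sketch of the same weighted-statistics computation on dummy tensors (shapes chosen only for illustration):

import einops
import torch

horizon, n_elites, batch, action_dim = 5, 6, 2, 4
score = torch.rand(n_elites, batch)
score /= score.sum(dim=0, keepdim=True)                       # per-batch weights over the elites
elite_actions = torch.randn(horizon, n_elites, batch, action_dim)

weights = einops.rearrange(score, "n b -> n b 1")
denom = einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9   # broadcasts over horizon and action_dim
mean = torch.sum(weights * elite_actions, dim=1) / denom      # (horizon, batch, action_dim)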
@@ -290,7 +305,7 @@
        # next_action = self.model.pi(z)[0]  # (batch, action_dim)
        # terminal_values = self.model.Qs(z, next_action, return_type="avg")  # (ensemble, batch)

        return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type="avg")
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
        """Run the batch through the model and compute the loss.
@@ -358,7 +373,8 @@
            pi = self.model.pi(z_targets)[0]
            td_targets = (
                reward
                + self.config.discount
                * self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
            )

        # Compute losses.
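
In equation form, the target computed in this block is the one-step TD target with a minimum over the target Q-ensemble, TD(z, a) = r + discount * min_k Qbar_k(z', pi(z')), where z' is the latent of the next observation.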
@@ -429,7 +445,9 @@
        self.scale.update(qs[0])
        qs = self.scale(qs)

        rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(
            -1
        )

        pi_loss = (
            (self.config.entropy_coef * log_pis - qs).mean(dim=(1, 2))
@@ -470,6 +488,7 @@
        """Update the target model's using polyak averaging."""
        self.model.update_target_Q()


class TDMPC2WorldModel(nn.Module):
    """Latent dynamics model used in TD-MPC2."""
@@ -480,28 +499,39 @@ class TDMPC2WorldModel(nn.Module):
        self._encoder = TDMPC2ObservationEncoder(config)
        # Define latent dynamics head
        self._dynamics = nn.Sequential(
            NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
            NormedLinear(config.mlp_dim, config.mlp_dim),
            NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)),
        )
        # Define reward head
        self._reward = nn.Sequential(
            NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
            NormedLinear(config.mlp_dim, config.mlp_dim),
            nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
        )
        # Define policy head
        self._pi = nn.Sequential(
            NormedLinear(config.latent_dim, config.mlp_dim),
            NormedLinear(config.mlp_dim, config.mlp_dim),
            nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0]),
        )
        # Define ensemble of Q functions
        self._Qs = nn.ModuleList(
            [
                nn.Sequential(
                    NormedLinear(
                        config.latent_dim + config.output_shapes["action"][0],
                        config.mlp_dim,
                        dropout=config.dropout,
                    ),
                    NormedLinear(config.mlp_dim, config.mlp_dim),
                    nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
                )
                for _ in range(config.q_ensemble_size)
            ]
        )
@@ -520,6 +550,7 @@
        Custom weight initializations proposed in TD-MPC2.
        """

        def _apply_fn(m):
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
@@ -641,7 +672,7 @@ class TDMPC2WorldModel(nn.Module):
        Soft-update target Q-networks using Polyak averaging.
        """
        with torch.no_grad():
            for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters(), strict=False):
                p_target.data.lerp_(p.data, self.config.target_model_momentum)
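
The update above is the usual exponential moving average of parameters; a self-contained sketch of the same rule on two small modules (the momentum value is illustrative, the model uses config.target_model_momentum):

import torch
import torch.nn as nn

tau = 0.01
online, target = nn.Linear(4, 4), nn.Linear(4, 4)

with torch.no_grad():
    for p, p_target in zip(online.parameters(), target.parameters(), strict=False):
        # theta_target <- (1 - tau) * theta_target + tau * theta_online
        p_target.data.lerp_(p.data, tau)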
@@ -672,7 +703,7 @@ class TDMPC2ObservationEncoder(nn.Module):
                )
                dummy_batch = torch.zeros(1, *config.input_shapes[obs_key])
                with torch.inference_mode():
                    out_shape = encoder_module(dummy_batch).shape[1:]
                encoder_module.extend(
                    nn.Sequential(
                        nn.Flatten(),
@@ -680,28 +711,30 @@
                    )
                )
            elif (
                "observation.state" in config.input_shapes
                or "observation.environment_state" in config.input_shapes
            ):
                encoder_module = nn.ModuleList()
                encoder_module.append(
                    NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim)
                )
                assert config.num_enc_layers > 0
                for _ in range(config.num_enc_layers - 1):
                    encoder_module.append(
                        NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim)
                    )
                encoder_module.append(
                    NormedLinear(
                        config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)
                    )
                )
                encoder_module = nn.Sequential(*encoder_module)
            else:
                raise NotImplementedError(f"No corresponding encoder module for key {obs_key}.")

            encoder_dict[obs_key.replace(".", "")] = encoder_module

        self.encoder = nn.ModuleDict(encoder_dict)
@@ -714,9 +747,11 @@
        feat = []
        for obs_key in self.config.input_shapes:
            if "observation.image" in obs_key:
                feat.append(
                    flatten_forward_unflatten(self.encoder[obs_key.replace(".", "")], obs_dict[obs_key])
                )
            else:
                feat.append(self.encoder[obs_key.replace(".", "")](obs_dict[obs_key]))
        return torch.stack(feat, dim=0).mean(0)
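
The obs_key.replace(".", "") calls are needed because nn.ModuleDict (via Module.add_module) rejects keys containing a dot; a minimal illustration:

import torch.nn as nn

modules = nn.ModuleDict()
# modules["observation.state"] = nn.Linear(4, 8)  # would raise: module name can't contain "."
modules["observationstate"] = nn.Linear(4, 8)     # sanitized key, matching the code above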


@@ -13,7 +13,7 @@ class Ensemble(nn.Module):
        super().__init__()
        modules = nn.ModuleList(modules)
        fn, params, _ = combine_state_for_ensemble(modules)
        self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness="different", **kwargs)
        self.params = nn.ParameterList([nn.Parameter(p) for p in params])
        self._repr = str(modules)
@@ -21,7 +21,8 @@
        return self.vmap([p for p in self.params], (), *args, **kwargs)

    def __repr__(self):
        return "Vectorized " + self._repr


class SimNorm(nn.Module):
    """
@@ -48,7 +49,7 @@
    Linear layer with LayerNorm, activation, and optionally dropout.
    """

    def __init__(self, *args, dropout=0.0, act=nn.Mish(inplace=True), **kwargs):
        super().__init__(*args, **kwargs)
        self.ln = nn.LayerNorm(self.out_features)
        self.act = act
@@ -62,16 +63,21 @@
    def __repr__(self):
        repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
        return (
            f"NormedLinear(in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"bias={self.bias is not None}{repr_dropout}, "
            f"act={self.act.__class__.__name__})"
        )
def soft_cross_entropy(pred, target, cfg):
    """Computes the cross entropy loss between predictions and soft targets."""
    pred = F.log_softmax(pred, dim=-1)
    target = two_hot(target, cfg)
    return -(target * pred).sum(-1, keepdim=True)
@@ -135,14 +141,15 @@ def two_hot(x, cfg):
        return x
    elif cfg.num_bins == 1:
        return symlog(x)
    x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax)
    bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
    bin_offset = (x - cfg.vmin) / cfg.bin_size - bin_idx.float()
    soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
    soft_two_hot.scatter_(1, bin_idx, 1 - bin_offset)
    soft_two_hot.scatter_(1, (bin_idx + 1) % cfg.num_bins, bin_offset)
    return soft_two_hot


def two_hot_inv(x, bins):
    """Converts a batch of soft two-hot encoded vectors to scalars."""
    num_bins = bins.shape[0]
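
As a quick illustration of the two-hot scheme above, a clamped scalar has its probability mass split between the two neighbouring bins; a small self-contained sketch with made-up settings (not the library defaults), skipping the symlog transform:

import torch

vmin, vmax, num_bins = -2.0, 2.0, 5
bin_size = (vmax - vmin) / (num_bins - 1)            # 1.0 -> bin centres at -2, -1, 0, 1, 2

x = torch.tensor([[0.3]])
x = torch.clamp(x, vmin, vmax)
bin_idx = torch.floor((x - vmin) / bin_size).long()          # bin 2 (centre 0.0)
bin_offset = (x - vmin) / bin_size - bin_idx.float()         # 0.3 of the way toward bin 3

soft_two_hot = torch.zeros(x.size(0), num_bins)
soft_two_hot.scatter_(1, bin_idx, 1 - bin_offset)
soft_two_hot.scatter_(1, (bin_idx + 1) % num_bins, bin_offset)
print(soft_two_hot)  # approximately [[0.0, 0.0, 0.7, 0.3, 0.0]]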


@@ -93,6 +93,18 @@ def make_optimizer_and_scheduler(cfg, policy):
    elif policy.name == "tdmpc":
        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
        lr_scheduler = None
    elif policy.name == "tdmpc2":
        params_group = [
            {"params": policy.model._encoder.parameters(), "lr": cfg.training.lr * cfg.training.enc_lr_scale},
            {"params": policy.model._dynamics.parameters()},
            {"params": policy.model._reward.parameters()},
            {"params": policy.model._Qs.parameters()},
            {"params": policy.model._pi.parameters(), "eps": 1e-5},
        ]
        optimizer = torch.optim.Adam(params_group, lr=cfg.training.lr)
        lr_scheduler = None
    elif cfg.policy.name == "vqbet":
        from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
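
The per-module parameter groups above let the encoder use a scaled learning rate and the policy head a larger Adam epsilon while the other heads inherit the base settings; a minimal standalone sketch of the same pattern (modules and values are placeholders):

import torch
import torch.nn as nn

base_lr, enc_lr_scale = 3e-4, 0.3   # illustrative values, analogous to cfg.training.lr and enc_lr_scale
encoder, dynamics, pi_head = nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 4)

optimizer = torch.optim.Adam(
    [
        {"params": encoder.parameters(), "lr": base_lr * enc_lr_scale},  # group-specific lr
        {"params": dynamics.parameters()},                               # inherits lr=base_lr
        {"params": pi_head.parameters(), "eps": 1e-5},                   # group-specific eps
    ],
    lr=base_lr,
)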