updated params

Michel Aractingi 2024-09-02 06:34:24 +00:00
parent 53d67bb5b7
commit d5fb8e9802
5 changed files with 34 additions and 24 deletions


@@ -51,6 +51,11 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
         from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
 
         return TDMPCPolicy, TDMPCConfig
+    elif name == "tdmpc2":
+        from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
+        from lerobot.common.policies.tdmpc2.modeling_tdmpc2 import TDMPC2Policy
+
+        return TDMPC2Policy, TDMPC2Config
     elif name == "diffusion":
         from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
         from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
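With this branch in place, the TD-MPC2 classes resolve by name just like the existing policies. A minimal usage sketch, assuming the helper keeps living at lerobot.common.policies.factory:

    from lerobot.common.policies.factory import get_policy_and_config_classes

    # Look up the classes registered above by their short name.
    policy_cls, config_cls = get_policy_and_config_classes("tdmpc2")
    print(policy_cls.__name__, config_cls.__name__)  # TDMPC2Policy TDMPC2Config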


@@ -106,6 +106,7 @@ class TDMPC2Config:
     vmax = +10
     rho: float = 0.5
     tau: float = 0.01
+    simnorm_dim: int = 8
 
     # Input / output structure.
     n_action_repeats: int = 2
@@ -115,7 +116,7 @@ class TDMPC2Config:
     input_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
             "observation.image": [3, 64, 64],
-            # "observation.state": [4],
+            "observation.state": [4],
         }
     )
     output_shapes: dict[str, list[int]] = field(
@@ -134,7 +135,7 @@ class TDMPC2Config:
     # Neural networks.
     image_encoder_hidden_dim: int = 32
     state_encoder_hidden_dim: int = 256
-    latent_dim: int = 50
+    latent_dim: int = 8  # 50
     q_ensemble_size: int = 5
     mlp_dim: int = 512
     # Reinforcement learning.
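Side note on latent_dim dropping from 50 to 8: SimNorm (see the utils change further down) views the latent's last dimension in chunks of simnorm_dim, so latent_dim has to be a multiple of simnorm_dim, and 50 is not a multiple of 8. A tiny check of that constraint with the values from this config:

    import torch

    latent_dim, simnorm_dim = 8, 8        # values set in TDMPC2Config above
    assert latent_dim % simnorm_dim == 0  # required by SimNorm's view(*shp[:-1], -1, simnorm_dim)
    z = torch.randn(2, latent_dim)
    z.view(2, -1, simnorm_dim)            # this view would fail for latent_dim=50, simnorm_dim=8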


@@ -111,7 +111,7 @@ class TDMPC2Policy(nn.Module, PyTorchModelHubMixin):
         called on `env.reset()`
         """
         self._queues = {
-            # "observation.state": deque(maxlen=1),
+            "observation.state": deque(maxlen=1),
             "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
         }
         if self._use_image:
@@ -150,9 +150,7 @@ class TDMPC2Policy(nn.Module, PyTorchModelHubMixin):
             encode_keys.append("observation.image")
         if self._use_env_state:
             encode_keys.append("observation.environment_state")
-        # if False:  # hardcoded for initial tdmpc2 impl
-        #     encode_keys.append("observation.state")
+        encode_keys.append("observation.state")
 
         z = self.model.encode({k: batch[k] for k in encode_keys})
@@ -337,9 +335,7 @@ class TDMPC2Policy(nn.Module, PyTorchModelHubMixin):
             td_targets = reward + discount * self.model.Qs(next_z, pi, return_type="min", target=True)
 
-        # Prepare for update
-        self.optim.zero_grad(set_to_none=True)
-        self.model.train()
+        # self.model.train()
 
         # Latent rollout
         zs = torch.empty(self.config.horizon + 1, batch_size, self.config.latent_dim, device=device)
@@ -495,8 +491,6 @@ class TDMPC2TOLD(nn.Module):
         """
         for p in self._Qs.parameters():
             p.requires_grad_(mode)
-        if self.config.multitask:
-            raise NotImplementedError("Multitask not implemented for TOLD yet.")
 
     def weight_init(self, m):  # lifted from Nicklas' code
         """Custom weight initialization for TD-MPC2."""
@@ -556,9 +550,6 @@ class TDMPC2TOLD(nn.Module):
         The policy prior is a Gaussian distribution with
         mean and (log) std predicted by a neural network.
         """
-        if self.config.multitask:
-            raise NotImplementedError("Multitask not implemented for pi yet.")
-
         # Gaussian policy prior
         mu, log_std = self._pi(z).chunk(2, dim=-1)
         log_std = utils.log_std_fn(log_std, self.log_std_min, self.log_std_dif)
@@ -593,9 +584,6 @@ class TDMPC2TOLD(nn.Module):
         """
         assert return_type in {"min", "avg", "all"}
 
-        if self.config.multitask:
-            raise NotImplementedError("Multitask not implemented for Qs yet.")
-
         z = torch.cat([z, a], dim=-1)
         out = (self._target_Qs if target else self._Qs)(z)
@@ -639,6 +627,18 @@ class TDMPC2ObservationEncoder(nn.Module):
             elif "observation.image" in k:
                 obs_shape = config.input_shapes["observation.image"]
                 self.image_enc_layers = utils.conv(obs_shape, config.num_channels, act=utils.SimNorm(config))
+                dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
+                with torch.no_grad():
+                    out_shape = self.image_enc_layers(dummy_batch).shape[1]
+                self.image_enc_layers.extend(
+                    utils.mlp(
+                        out_shape,
+                        max(config.num_enc_layers - 1, 1) * [config.enc_dim],
+                        config.latent_dim,
+                        act=utils.SimNorm(config),
+                    ))
 
     def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
         """Encode the image and/or state vector.
@@ -654,6 +654,7 @@ class TDMPC2ObservationEncoder(nn.Module):
             feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
         if "observation.state" in self.config.input_shapes:
             feat.append(self.state_enc_layers(obs_dict["observation.state"]))
         return torch.stack(feat, dim=0).mean(0)


@@ -106,7 +106,10 @@ def two_hot_inv(x, cfg):
     if DREG_BINS is None:
         DREG_BINS = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device)
     x = F.softmax(x, dim=-1)
-    x = torch.sum(x * DREG_BINS, dim=-1, keepdim=True)
+    # clone the bins to avoid the inference tensor error
+    x = torch.sum(x * DREG_BINS.clone(), dim=-1, keepdim=True)
     return symexp(x)
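For readers new to this discrete-regression head: two_hot_inv turns bin logits back into a scalar by taking the softmax-weighted average of the bin values and then undoing the symlog squashing. A rough, self-contained sketch of that decoding, using symexp as defined in the TD-MPC2/DreamerV3 symlog formulation and illustrative bin settings:

    import torch
    import torch.nn.functional as F

    def symexp(x):
        # inverse of symlog(x) = sign(x) * log(1 + |x|)
        return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)

    vmin, vmax, num_bins = -10.0, 10.0, 101      # illustrative values in the spirit of the config
    bins = torch.linspace(vmin, vmax, num_bins)  # bin centres in symlog space
    logits = torch.randn(4, num_bins)            # e.g. raw reward/Q-head outputs for a batch of 4
    probs = F.softmax(logits, dim=-1)
    value = symexp(torch.sum(probs * bins, dim=-1, keepdim=True))  # decoded scalar per sample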
@@ -177,12 +180,12 @@ class SimNorm(nn.Module):
     def __init__(self, cfg):
         super().__init__()
-        # TODO: move to config
-        self.dim = 8  # cfg.simnorm_dim
+        self.dim = cfg.simnorm_dim
 
     def forward(self, x):
         shp = x.shape
         x = x.view(*shp[:-1], -1, self.dim)
         x = F.softmax(x, dim=-1)
         return x.view(*shp)
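As a quick illustration of what SimNorm does with the new simnorm_dim setting: the last dimension is split into chunks of simnorm_dim and each chunk is normalised with a softmax, so the latent becomes a concatenation of small probability simplices. A minimal sketch, with the config object replaced by hard-coded illustrative values:

    import torch
    import torch.nn.functional as F

    simnorm_dim = 8                        # matches the value added to TDMPC2Config above
    x = torch.randn(2, 16)                 # a latent of size 16 -> two groups of 8
    groups = x.view(2, -1, simnorm_dim)    # (batch, n_groups, simnorm_dim)
    groups = F.softmax(groups, dim=-1)     # each group now sums to 1
    out = groups.view(2, 16)               # same shape as the input latent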
@@ -237,7 +240,7 @@ def conv(in_shape, num_channels, act=None):
     Basic convolutional encoder for TD-MPC2 with raw image observations.
     4 layers of convolution with ReLU activations, followed by a linear layer.
     """
-    assert in_shape[-1] == 64  # assumes rgb observations to be 64x64
+    # assert in_shape[-1] == 64  # assumes rgb observations to be 64x64
     layers = [
         ShiftAug(),
         PixelPreprocess(),


@@ -95,13 +95,13 @@ def make_optimizer_and_scheduler(cfg, policy):
         lr_scheduler = None
     elif policy.name == "tdmpc2":
         params_group = [
-            {"params": policy.model._encoder.parameters(), "lr": cfg.lr * cfg.enc_lr_scale},
+            {"params": policy.model._encoder.parameters(), "lr": cfg.training.lr * cfg.training.enc_lr_scale},
             {"params": policy.model._dynamics.parameters()},
             {"params": policy.model._reward.parameters()},
             {"params": policy.model._Qs.parameters()},
-            {"params": policy.model._pi.parameters(), "lr": cfg.lr, "eps": 1e-5},
+            {"params": policy.model._pi.parameters(), "eps": 1e-5},
         ]
-        optimizer = torch.optim.Adam(params_group, lr=cfg.lr)
+        optimizer = torch.optim.Adam(params_group, lr=cfg.training.lr)
         lr_scheduler = None
     elif cfg.policy.name == "vqbet":
         from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
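For context on the per-module learning rates being wired up here: torch.optim.Adam accepts a list of parameter groups, and each group dict may override the default lr (or eps) passed to the constructor, which is how the encoder gets a scaled learning rate while the other modules fall back to the default. A minimal, generic sketch of the same pattern with toy modules and illustrative hyperparameters:

    import torch
    import torch.nn as nn

    encoder = nn.Linear(8, 16)
    head = nn.Linear(16, 4)

    lr, enc_lr_scale = 3e-4, 0.3  # illustrative; the real values come from cfg.training
    optimizer = torch.optim.Adam(
        [
            {"params": encoder.parameters(), "lr": lr * enc_lr_scale},  # scaled encoder lr
            {"params": head.parameters(), "eps": 1e-5},                 # uses the default lr
        ],
        lr=lr,  # default for any group that does not override "lr"
    )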