updated params

Michel Aractingi 2024-09-02 06:34:24 +00:00
parent 53d67bb5b7
commit d5fb8e9802
5 changed files with 34 additions and 24 deletions

View File

@@ -51,6 +51,11 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
return TDMPCPolicy, TDMPCConfig
elif name == "tdmpc2":
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
from lerobot.common.policies.tdmpc2.modeling_tdmpc2 import TDMPC2Policy
return TDMPC2Policy, TDMPC2Config
elif name == "diffusion":
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
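The new branch mirrors the existing per-policy imports, so the factory can now resolve the TDMPC2 classes from the string name "tdmpc2". A minimal usage sketch; the factory's module path and the policy constructor arguments are assumptions (the diff does not show file names):

# Assumed module path; the diff does not show file names.
from lerobot.common.policies.factory import get_policy_and_config_classes

policy_cls, config_cls = get_policy_and_config_classes("tdmpc2")
config = config_cls()        # TDMPC2Config with the defaults changed in this commit
policy = policy_cls(config)  # TDMPC2Policy; extra constructor args (e.g. dataset stats) may apply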

View File

@@ -106,6 +106,7 @@ class TDMPC2Config:
vmax = +10
rho: float = 0.5
tau: float = 0.01
simnorm_dim: int = 8
# Input / output structure.
n_action_repeats: int = 2
@@ -115,7 +116,7 @@ class TDMPC2Config:
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.image": [3, 64, 64],
# "observation.state": [4],
"observation.state": [4],
}
)
output_shapes: dict[str, list[int]] = field(
@@ -134,7 +135,7 @@ class TDMPC2Config:
# Neural networks.
image_encoder_hidden_dim: int = 32
state_encoder_hidden_dim: int = 256
latent_dim: int = 50
latent_dim: int = 8 #50
q_ensemble_size: int = 5
mlp_dim: int = 512
# Reinforcement learning.
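Two of these defaults interact: SimNorm (see the utils changes further down) reshapes its input into groups of `simnorm_dim` and softmaxes each group, so any feature size passed through it, including `latent_dim`, must be divisible by `simnorm_dim`. That is presumably why `latent_dim` drops from 50 (not a multiple of 8) to 8. A small sketch of the constraint with illustrative values:

simnorm_dim = 8
latent_dim = 8      # the previous default of 50 is not a multiple of 8
assert latent_dim % simnorm_dim == 0, "latent_dim must be a multiple of simnorm_dim"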

View File

@@ -111,7 +111,7 @@ class TDMPC2Policy(nn.Module, PyTorchModelHubMixin):
called on `env.reset()`
"""
self._queues = {
# "observation.state": deque(maxlen=1),
"observation.state": deque(maxlen=1),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
}
if self._use_image:
@@ -150,9 +150,7 @@ class TDMPC2Policy(nn.Module, PyTorchModelHubMixin):
encode_keys.append("observation.image")
if self._use_env_state:
encode_keys.append("observation.environment_state")
# if False: # hardcoded for initial tdmpc2 impl
# encode_keys.append("observation.state")
encode_keys.append("observation.state")
z = self.model.encode({k: batch[k] for k in encode_keys})
@@ -337,9 +335,7 @@ class TDMPC2Policy(nn.Module, PyTorchModelHubMixin):
td_targets = reward + discount * self.model.Qs(next_z, pi, return_type="min", target=True)
# Prepare for update
self.optim.zero_grad(set_to_none=True)
self.model.train()
#self.model.train()
# Latent rollout
zs = torch.empty(self.config.horizon + 1, batch_size, self.config.latent_dim, device=device)
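The TD target at the top of this hunk bootstraps the decoded reward with a discounted, ensemble-aggregated target Q-value; with `self.model.train()` commented out, the module keeps whatever mode the training loop set. A stand-alone sketch of the target computation with stand-in tensors (upstream TD-MPC2's "min" mode takes the minimum over a random pair of Q heads; this sketch simply minimizes over the full ensemble):

import torch

batch_size = 4
reward = torch.randn(batch_size, 1)          # decoded scalar reward per transition
discount = 0.99
q_ensemble = torch.randn(5, batch_size, 1)   # 5 target Q heads evaluated at (next_z, pi)
td_targets = reward + discount * q_ensemble.min(dim=0).values  # "min" aggregation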
@@ -495,8 +491,6 @@ class TDMPC2TOLD(nn.Module):
"""
for p in self._Qs.parameters():
p.requires_grad_(mode)
if self.config.multitask:
raise NotImplementedError("Multitask not implemented for TOLD yet.")
def weight_init(self, m): # lifted from Nicklas' code
"""Custom weight initialization for TD-MPC2."""
@@ -556,9 +550,6 @@ class TDMPC2TOLD(nn.Module):
The policy prior is a Gaussian distribution with
mean and (log) std predicted by a neural network.
"""
if self.config.multitask:
raise NotImplementedError("Multitask not implemented for pi yet.")
# Gaussian policy prior
mu, log_std = self._pi(z).chunk(2, dim=-1)
log_std = utils.log_std_fn(log_std, self.log_std_min, self.log_std_dif)
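The policy prior predicts a mean and a raw log-std per action dimension, and `utils.log_std_fn` bounds the log-std. A sketch of a common tanh-based bounding; the exact form of `log_std_fn` is an assumption here, mirroring upstream TD-MPC2 where `log_std_dif = log_std_max - log_std_min`:

import torch

def log_std_fn(x: torch.Tensor, log_std_min: float, log_std_dif: float) -> torch.Tensor:
    # Map an unbounded network output into [log_std_min, log_std_min + log_std_dif].
    return log_std_min + 0.5 * log_std_dif * (torch.tanh(x) + 1)

raw = torch.randn(4, 6)                 # raw log-std from the policy head
bounded = log_std_fn(raw, -10.0, 12.0)  # values in [-10, 2]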
@@ -593,9 +584,6 @@ class TDMPC2TOLD(nn.Module):
"""
assert return_type in {"min", "avg", "all"}
if self.config.multitask:
raise NotImplementedError("Multitask not implemented for Qs yet.")
z = torch.cat([z, a], dim=-1)
out = (self._target_Qs if target else self._Qs)(z)
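The Q heads take the concatenated latent and action and return one estimate per ensemble member; `return_type` then selects how the ensemble is reduced. A stand-in sketch of the three reductions (any subsampling used by the real "min" mode is not shown in this hunk):

import torch

out = torch.randn(5, 4, 1)        # (ensemble, batch, 1) stand-in for the Q heads' output
q_min = out.min(dim=0).values     # return_type == "min"
q_avg = out.mean(dim=0)           # return_type == "avg"
q_all = out                       # return_type == "all"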
@@ -639,6 +627,18 @@ class TDMPC2ObservationEncoder(nn.Module):
elif "observation.image" in k:
obs_shape = config.input_shapes["observation.image"]
self.image_enc_layers = utils.conv(obs_shape, config.num_channels, act=utils.SimNorm(config))
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.no_grad():
out_shape = self.image_enc_layers(dummy_batch).shape[1]
self.image_enc_layers.extend(
utils.mlp(
out_shape,
max(config.num_enc_layers - 1, 1) * [config.enc_dim],
config.latent_dim,
act=utils.SimNorm(config),
))
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
@@ -654,6 +654,7 @@ class TDMPC2ObservationEncoder(nn.Module):
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
return torch.stack(feat, dim=0).mean(0)
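With the proprioceptive state re-enabled, the encoder produces one `latent_dim`-sized feature per active modality and fuses them by averaging, as in this sketch with stand-in tensors:

import torch

latent_dim = 8
image_feat = torch.randn(4, latent_dim)   # output of the image encoder
state_feat = torch.randn(4, latent_dim)   # output of the state encoder
z = torch.stack([image_feat, state_feat], dim=0).mean(0)   # fused latent, shape (4, latent_dim)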

View File

@@ -106,7 +106,10 @@ def two_hot_inv(x, cfg):
if DREG_BINS is None:
DREG_BINS = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device)
x = F.softmax(x, dim=-1)
x = torch.sum(x * DREG_BINS, dim=-1, keepdim=True)
# cloning bins to avoid the inference tensor error
x = torch.sum(x * DREG_BINS.clone(), dim=-1, keepdim=True)
return symexp(x)
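`two_hot_inv` decodes a discrete-regression head: softmax the logits over the bins, take the expectation over the bin values, and apply `symexp` to undo the symlog transform. Cloning `DREG_BINS` sidesteps PyTorch's "inference tensor" error, likely because the cached bins were first created under `torch.inference_mode()` and later reused in a gradient-tracking pass. A self-contained sketch with illustrative bin settings:

import torch
import torch.nn.functional as F

def symexp(x):
    # Inverse of symlog: sign(x) * (exp(|x|) - 1).
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)

vmin, vmax, num_bins = -10.0, 10.0, 101
bins = torch.linspace(vmin, vmax, num_bins)
logits = torch.randn(4, num_bins)        # raw output of the reward/Q head
probs = F.softmax(logits, dim=-1)
value = symexp(torch.sum(probs * bins, dim=-1, keepdim=True))   # shape (4, 1)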
@@ -177,12 +180,12 @@ class SimNorm(nn.Module):
def __init__(self, cfg):
super().__init__()
# TODO: move to config
self.dim = 8 # cfg.simnorm_dim
self.dim = cfg.simnorm_dim
def forward(self, x):
shp = x.shape
x = x.view(*shp[:-1], -1, self.dim)
x = F.softmax(x, dim=-1)
return x.view(*shp)
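SimNorm ("simplicial normalization") splits the last dimension into groups of `simnorm_dim` and softmaxes each group, which is the source of the divisibility requirement noted above. A minimal sketch:

import torch
import torch.nn.functional as F

simnorm_dim = 8
x = torch.randn(4, 32)                 # 32 = 4 groups of 8
groups = x.view(4, -1, simnorm_dim)    # fails if the last dim is not a multiple of simnorm_dim
groups = F.softmax(groups, dim=-1)     # each group now sums to 1 (lies on a simplex)
out = groups.view(4, 32)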
@@ -237,7 +240,7 @@ def conv(in_shape, num_channels, act=None):
Basic convolutional encoder for TD-MPC2 with raw image observations.
4 layers of convolution with ReLU activations, followed by a linear layer.
"""
assert in_shape[-1] == 64 # assumes rgb observations to be 64x64
#assert in_shape[-1] == 64 # assumes rgb observations to be 64x64
layers = [
ShiftAug(),
PixelPreprocess(),

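Dropping the 64x64 assertion relies on the dummy-batch trick added to the observation encoder above: the flattened conv output size is measured at construction time instead of being hard-coded. A sketch of that pattern with an assumed conv stack and a 96x96 input:

import torch
import torch.nn as nn

conv_stack = nn.Sequential(            # stand-in for utils.conv(...)
    nn.Conv2d(3, 32, 7, stride=2), nn.ReLU(),
    nn.Conv2d(32, 32, 5, stride=2), nn.ReLU(),
    nn.Conv2d(32, 32, 3, stride=2), nn.ReLU(),
    nn.Flatten(),
)
with torch.no_grad():
    out_dim = conv_stack(torch.zeros(1, 3, 96, 96)).shape[1]   # measured, not assumed
# out_dim then sizes the MLP that projects the image features to latent_dim.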
View File

@@ -95,13 +95,13 @@ def make_optimizer_and_scheduler(cfg, policy):
lr_scheduler = None
elif policy.name == "tdmpc2":
params_group = [
{"params": policy.model._encoder.parameters(), "lr": cfg.lr * cfg.enc_lr_scale},
{"params": policy.model._encoder.parameters(), "lr": cfg.training.lr * cfg.training.enc_lr_scale},
{"params": policy.model._dynamics.parameters()},
{"params": policy.model._reward.parameters()},
{"params": policy.model._Qs.parameters()},
{"params": policy.model._pi.parameters(), "lr": cfg.lr, "eps": 1e-5},
{"params": policy.model._pi.parameters(), "eps": 1e-5},
]
optimizer = torch.optim.Adam(params_group, lr=cfg.lr)
optimizer = torch.optim.Adam(params_group, lr=cfg.training.lr)
lr_scheduler = None
elif cfg.policy.name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
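The tdmpc2 optimizer change above reads the learning rate from the `cfg.training` namespace and lets torch's param groups handle the per-module overrides: the encoder group scales the base lr, the policy-prior group overrides only `eps`, and every other group inherits the optimizer-level lr. A minimal stand-alone sketch with illustrative values:

import torch
import torch.nn as nn

lr, enc_lr_scale = 3e-4, 0.3
encoder, dynamics, pi = nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8)

optimizer = torch.optim.Adam(
    [
        {"params": encoder.parameters(), "lr": lr * enc_lr_scale},  # scaled lr
        {"params": dynamics.parameters()},                          # inherits lr below
        {"params": pi.parameters(), "eps": 1e-5},                   # inherits lr, custom eps
    ],
    lr=lr,
)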