updated params

parent 53d67bb5b7
commit d5fb8e9802
@@ -51,6 +51,11 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
         from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy

         return TDMPCPolicy, TDMPCConfig
+    elif name == "tdmpc2":
+        from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
+        from lerobot.common.policies.tdmpc2.modeling_tdmpc2 import TDMPC2Policy
+
+        return TDMPC2Policy, TDMPC2Config
     elif name == "diffusion":
         from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
         from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

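Review note (not part of the commit): a minimal usage sketch of the new "tdmpc2" branch added above, assuming this hunk lives in lerobot/common/policies/factory.py and that the policy class accepts a config instance (both are assumptions, not shown in the diff):

```python
# Hypothetical usage of the factory dispatch added above.
from lerobot.common.policies.factory import get_policy_and_config_classes

policy_cls, config_cls = get_policy_and_config_classes("tdmpc2")
cfg = config_cls()        # TDMPC2Config with the defaults edited below
policy = policy_cls(cfg)  # TDMPC2Policy built from that config
```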
@@ -106,6 +106,7 @@ class TDMPC2Config:
     vmax = +10
     rho: float = 0.5
     tau: float = 0.01
+    simnorm_dim: int = 8

     # Input / output structure.
     n_action_repeats: int = 2

@@ -115,7 +116,7 @@ class TDMPC2Config:
     input_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
             "observation.image": [3, 64, 64],
-            # "observation.state": [4],
+            "observation.state": [4],
         }
     )
     output_shapes: dict[str, list[int]] = field(

@@ -134,7 +135,7 @@ class TDMPC2Config:
     # Neural networks.
     image_encoder_hidden_dim: int = 32
     state_encoder_hidden_dim: int = 256
-    latent_dim: int = 50
+    latent_dim: int = 8  # 50
     q_ensemble_size: int = 5
     mlp_dim: int = 512
     # Reinforcement learning.

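Review note (not in the commit): the jump from `latent_dim = 50` to 8 likely interacts with `SimNorm`, which is switched later in this diff to read `cfg.simnorm_dim` and reshapes the latent into groups of that size; 50 is not divisible by 8, so the old default would break the `view`. A minimal sketch of the constraint, assuming SimNorm's reshape as shown below:

```python
import torch
import torch.nn.functional as F

latent_dim, simnorm_dim = 8, 8  # defaults set in this commit; note 50 % 8 != 0
x = torch.randn(2, latent_dim)

# SimNorm's x.view(*shp[:-1], -1, simnorm_dim) needs latent_dim % simnorm_dim == 0.
assert latent_dim % simnorm_dim == 0
groups = F.softmax(x.view(2, -1, simnorm_dim), dim=-1)
print(groups.sum(-1))  # each group of simnorm_dim entries sums to 1
```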
@@ -111,7 +111,7 @@ class TDMPC2Policy(nn.Module, PyTorchModelHubMixin):
         called on `env.reset()`
         """
         self._queues = {
-            # "observation.state": deque(maxlen=1),
+            "observation.state": deque(maxlen=1),
             "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
         }
         if self._use_image:

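Aside on the queue sizes above: `deque(maxlen=...)` silently evicts the oldest entry once full, which is what lets the policy cache a single state observation and up to `max(n_action_steps, n_action_repeats)` pending actions. A tiny self-contained sketch:

```python
from collections import deque

q = deque(maxlen=3)  # e.g. max(n_action_steps, n_action_repeats) == 3
for a in range(5):
    q.append(a)
print(list(q))  # [2, 3, 4] -- the two oldest actions were evicted
```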
@@ -150,9 +150,7 @@ class TDMPC2Policy(nn.Module, PyTorchModelHubMixin):
             encode_keys.append("observation.image")
         if self._use_env_state:
             encode_keys.append("observation.environment_state")
-
-        # if False: # hardcoded for initial tdmpc2 impl
-        #     encode_keys.append("observation.state")
+        encode_keys.append("observation.state")

         z = self.model.encode({k: batch[k] for k in encode_keys})

@@ -337,9 +335,7 @@ class TDMPC2Policy(nn.Module, PyTorchModelHubMixin):

         td_targets = reward + discount * self.model.Qs(next_z, pi, return_type="min", target=True)

-        # Prepare for update
-        self.optim.zero_grad(set_to_none=True)
-        self.model.train()
+        #self.model.train()

         # Latent rollout
         zs = torch.empty(self.config.horizon + 1, batch_size, self.config.latent_dim, device=device)

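For context on the `td_targets` line kept in this hunk: it is the standard one-step TD target, reward plus discounted minimum over the target Q-ensemble evaluated at the next latent and the policy action, where the min gives a pessimistic value estimate. A toy numeric sketch (all values made up):

```python
import torch

reward = torch.tensor([1.0])
discount = 0.99
target_qs = torch.tensor([[3.2], [2.9], [3.0]])  # 3 target Q-heads on (z', pi(z'))

td_target = reward + discount * target_qs.min(dim=0).values  # pessimistic backup
print(td_target)  # tensor([3.8710]) == 1.0 + 0.99 * 2.9
```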
@@ -495,8 +491,6 @@ class TDMPC2TOLD(nn.Module):
         """
         for p in self._Qs.parameters():
             p.requires_grad_(mode)
-        if self.config.multitask:
-            raise NotImplementedError("Multitask not implemented for TOLD yet.")

     def weight_init(self, m):  # lifted from Nicklas' code
         """Custom weight initialization for TD-MPC2."""

@@ -556,9 +550,6 @@ class TDMPC2TOLD(nn.Module):
         The policy prior is a Gaussian distribution with
         mean and (log) std predicted by a neural network.
         """
-        if self.config.multitask:
-            raise NotImplementedError("Multitask not implemented for pi yet.")
-
         # Gaussian policy prior
         mu, log_std = self._pi(z).chunk(2, dim=-1)
         log_std = utils.log_std_fn(log_std, self.log_std_min, self.log_std_dif)

@@ -593,9 +584,6 @@ class TDMPC2TOLD(nn.Module):
         """
         assert return_type in {"min", "avg", "all"}

-        if self.config.multitask:
-            raise NotImplementedError("Multitask not implemented for Qs yet.")
-
         z = torch.cat([z, a], dim=-1)
         out = (self._target_Qs if target else self._Qs)(z)

@@ -639,6 +627,18 @@ class TDMPC2ObservationEncoder(nn.Module):
             elif "observation.image" in k:
                 obs_shape = config.input_shapes["observation.image"]
                 self.image_enc_layers = utils.conv(obs_shape, config.num_channels, act=utils.SimNorm(config))
+                dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
+                with torch.no_grad():
+                    out_shape = self.image_enc_layers(dummy_batch).shape[1]
+                self.image_enc_layers.extend(
+                    utils.mlp(
+                        out_shape,
+                        max(config.num_enc_layers - 1, 1) * [config.enc_dim],
+                        config.latent_dim,
+                        act=utils.SimNorm(config),
+                    )
+                )
+
     def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
         """Encode the image and/or state vector.

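The added lines use a common PyTorch trick: push a dummy batch through the conv stack once to measure its flattened output width, then size the MLP head from that. A self-contained sketch of the same pattern (layer sizes here are illustrative, not the ones in `utils.conv`/`utils.mlp`):

```python
import torch
import torch.nn as nn

conv = nn.Sequential(
    nn.Conv2d(3, 32, 7, stride=2), nn.ReLU(),
    nn.Conv2d(32, 32, 5, stride=2), nn.ReLU(),
    nn.Flatten(),
)

# Probe with a zero batch to discover the flattened feature width.
with torch.no_grad():
    out_dim = conv(torch.zeros(1, 3, 64, 64)).shape[1]

head = nn.Linear(out_dim, 8)  # e.g. project down to latent_dim
print(out_dim, head(conv(torch.zeros(1, 3, 64, 64))).shape)  # 5408 torch.Size([1, 8])
```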
@@ -654,6 +654,7 @@ class TDMPC2ObservationEncoder(nn.Module):
             feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
         if "observation.state" in self.config.input_shapes:
+            feat.append(self.state_enc_layers(obs_dict["observation.state"]))

         return torch.stack(feat, dim=0).mean(0)

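Small note on the `forward` above: with the state branch re-enabled, image and state features are fused by plain averaging via `torch.stack(feat, dim=0).mean(0)`. Toy sketch with made-up shapes:

```python
import torch

img_feat = torch.randn(4, 8)    # (batch, latent_dim) from the image encoder
state_feat = torch.randn(4, 8)  # (batch, latent_dim) from the state encoder

fused = torch.stack([img_feat, state_feat], dim=0).mean(0)  # elementwise mean
print(fused.shape)  # torch.Size([4, 8])
```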
@@ -106,7 +106,10 @@ def two_hot_inv(x, cfg):
     if DREG_BINS is None:
         DREG_BINS = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device)
     x = F.softmax(x, dim=-1)
-    x = torch.sum(x * DREG_BINS, dim=-1, keepdim=True)
+
+    # cloning bins to avoid the inference tensor error
+    x = torch.sum(x * DREG_BINS.clone(), dim=-1, keepdim=True)
+
     return symexp(x)

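For reference, `two_hot_inv` decodes a categorical value head back to a scalar: softmax over bins, expectation against the bin centers, then `symexp` undoes the symmetric-log transform. A standalone sketch under those assumptions (bin count and range are illustrative; `.clone()` mirrors the fix above for inference-mode tensors):

```python
import torch
import torch.nn.functional as F

def symexp(x):
    # inverse of symlog: sign(x) * (exp(|x|) - 1)
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)

vmin, vmax, num_bins = -10.0, 10.0, 101  # stand-ins for cfg.vmin / cfg.vmax / cfg.num_bins
bins = torch.linspace(vmin, vmax, num_bins)

logits = torch.randn(4, num_bins)                          # categorical value head
probs = F.softmax(logits, dim=-1)
x = torch.sum(probs * bins.clone(), dim=-1, keepdim=True)  # expectation over bin centers
print(symexp(x).shape)  # torch.Size([4, 1]) -- back on the raw value scale
```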
@@ -177,12 +180,12 @@ class SimNorm(nn.Module):

     def __init__(self, cfg):
         super().__init__()
-        # TODO: move to config
-        self.dim = 8  # cfg.simnorm_dim
+        self.dim = cfg.simnorm_dim

     def forward(self, x):
         shp = x.shape
         x = x.view(*shp[:-1], -1, self.dim)

         x = F.softmax(x, dim=-1)
         return x.view(*shp)

@@ -237,7 +240,7 @@ def conv(in_shape, num_channels, act=None):
     Basic convolutional encoder for TD-MPC2 with raw image observations.
     4 layers of convolution with ReLU activations, followed by a linear layer.
     """
-    assert in_shape[-1] == 64  # assumes rgb observations to be 64x64
+    # assert in_shape[-1] == 64  # assumes rgb observations to be 64x64
     layers = [
         ShiftAug(),
         PixelPreprocess(),

@@ -95,13 +95,13 @@ def make_optimizer_and_scheduler(cfg, policy):
         lr_scheduler = None
     elif policy.name == "tdmpc2":
         params_group = [
-            {"params": policy.model._encoder.parameters(), "lr": cfg.lr * cfg.enc_lr_scale},
+            {"params": policy.model._encoder.parameters(), "lr": cfg.training.lr * cfg.training.enc_lr_scale},
            {"params": policy.model._dynamics.parameters()},
            {"params": policy.model._reward.parameters()},
            {"params": policy.model._Qs.parameters()},
-            {"params": policy.model._pi.parameters(), "lr": cfg.lr, "eps": 1e-5},
+            {"params": policy.model._pi.parameters(), "eps": 1e-5},
         ]
-        optimizer = torch.optim.Adam(params_group, lr=cfg.lr)
+        optimizer = torch.optim.Adam(params_group, lr=cfg.training.lr)
         lr_scheduler = None
     elif cfg.policy.name == "vqbet":
         from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler

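On the fix above: a per-group `"lr"` overrides the optimizer-level default, so dropping `"lr"` from the `_pi` group makes it fall back to the `cfg.training.lr` passed to `Adam`, while the encoder group keeps its scaled rate. A minimal sketch of that behavior (hypothetical modules and values):

```python
import torch
import torch.nn as nn

enc, pi = nn.Linear(4, 4), nn.Linear(4, 4)  # hypothetical stand-ins for the submodules
lr, enc_lr_scale = 3e-4, 0.3

optimizer = torch.optim.Adam(
    [
        {"params": enc.parameters(), "lr": lr * enc_lr_scale},  # explicit override
        {"params": pi.parameters(), "eps": 1e-5},               # inherits lr=3e-4
    ],
    lr=lr,
)
print([g["lr"] for g in optimizer.param_groups])  # [9e-05, 0.0003] up to float rounding
```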