updated params
parent 53d67bb5b7
commit d5fb8e9802
@@ -51,6 +51,11 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
         from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
 
         return TDMPCPolicy, TDMPCConfig
+    elif name == "tdmpc2":
+        from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
+        from lerobot.common.policies.tdmpc2.modeling_tdmpc2 import TDMPC2Policy
+
+        return TDMPC2Policy, TDMPC2Config
     elif name == "diffusion":
         from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
         from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
@@ -106,6 +106,7 @@ class TDMPC2Config:
     vmax = +10
     rho: float = 0.5
     tau: float = 0.01
+    simnorm_dim: int = 8
 
     # Input / output structure.
     n_action_repeats: int = 2
@@ -115,7 +116,7 @@ class TDMPC2Config:
     input_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
             "observation.image": [3, 64, 64],
-            # "observation.state": [4],
+            "observation.state": [4],
         }
     )
     output_shapes: dict[str, list[int]] = field(
@@ -134,7 +135,7 @@ class TDMPC2Config:
     # Neural networks.
     image_encoder_hidden_dim: int = 32
     state_encoder_hidden_dim: int = 256
-    latent_dim: int = 50
+    latent_dim: int = 8  # 50
     q_ensemble_size: int = 5
     mlp_dim: int = 512
     # Reinforcement learning.
@@ -111,7 +111,7 @@ class TDMPC2Policy(nn.Module, PyTorchModelHubMixin):
             called on `env.reset()`
         """
         self._queues = {
-            # "observation.state": deque(maxlen=1),
+            "observation.state": deque(maxlen=1),
             "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
         }
         if self._use_image:
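For context on the queue setup above: collections.deque with maxlen gives a fixed-size rolling buffer that silently evicts the oldest entry on append, which is how the policy keeps only the latest observation and a short buffer of planned actions. A minimal sketch, assuming illustrative values for n_action_steps and n_action_repeats (not taken from the commit):

from collections import deque

# Placeholder values standing in for config.n_action_steps / config.n_action_repeats.
n_action_steps, n_action_repeats = 1, 2

obs_queue = deque(maxlen=1)
action_queue = deque(maxlen=max(n_action_steps, n_action_repeats))

for t in range(4):
    obs_queue.append(f"obs_{t}")
    action_queue.append(f"act_{t}")

print(list(obs_queue))     # ['obs_3'] -- only the most recent observation survives
print(list(action_queue))  # ['act_2', 'act_3'] -- a short rolling buffer of actions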
@@ -150,9 +150,7 @@ class TDMPC2Policy(nn.Module, PyTorchModelHubMixin):
             encode_keys.append("observation.image")
         if self._use_env_state:
             encode_keys.append("observation.environment_state")
-        # if False: # hardcoded for initial tdmpc2 impl
-        # encode_keys.append("observation.state")
-
+        encode_keys.append("observation.state")
 
         z = self.model.encode({k: batch[k] for k in encode_keys})
 
@@ -337,9 +335,7 @@ class TDMPC2Policy(nn.Module, PyTorchModelHubMixin):
 
             td_targets = reward + discount * self.model.Qs(next_z, pi, return_type="min", target=True)
 
-        # Prepare for update
-        self.optim.zero_grad(set_to_none=True)
-        self.model.train()
+        #self.model.train()
 
         # Latent rollout
         zs = torch.empty(self.config.horizon + 1, batch_size, self.config.latent_dim, device=device)
@@ -495,8 +491,6 @@ class TDMPC2TOLD(nn.Module):
         """
         for p in self._Qs.parameters():
             p.requires_grad_(mode)
-        if self.config.multitask:
-            raise NotImplementedError("Multitask not implemented for TOLD yet.")
 
     def weight_init(self, m):  # lifted from Nicklas' code
         """Custom weight initialization for TD-MPC2."""
@@ -556,9 +550,6 @@ class TDMPC2TOLD(nn.Module):
         The policy prior is a Gaussian distribution with
         mean and (log) std predicted by a neural network.
         """
-        if self.config.multitask:
-            raise NotImplementedError("Multitask not implemented for pi yet.")
-
         # Gaussian policy prior
         mu, log_std = self._pi(z).chunk(2, dim=-1)
         log_std = utils.log_std_fn(log_std, self.log_std_min, self.log_std_dif)
@@ -593,9 +584,6 @@ class TDMPC2TOLD(nn.Module):
         """
         assert return_type in {"min", "avg", "all"}
 
-        if self.config.multitask:
-            raise NotImplementedError("Multitask not implemented for Qs yet.")
-
         z = torch.cat([z, a], dim=-1)
         out = (self._target_Qs if target else self._Qs)(z)
 
@@ -639,6 +627,18 @@ class TDMPC2ObservationEncoder(nn.Module):
             elif "observation.image" in k:
                 obs_shape = config.input_shapes["observation.image"]
                 self.image_enc_layers = utils.conv(obs_shape, config.num_channels, act=utils.SimNorm(config))
+                dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
+                with torch.no_grad():
+                    out_shape = self.image_enc_layers(dummy_batch).shape[1]
+                self.image_enc_layers.extend(
+                    utils.mlp(
+                        out_shape,
+                        max(config.num_enc_layers - 1, 1) * [config.enc_dim],
+                        config.latent_dim,
+                        act=utils.SimNorm(config),
+                    ))
+
+
 
     def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
         """Encode the image and/or state vector.
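The dummy-batch lines added above size the MLP head from the convolutional stack's actual flattened output width instead of hard-coding it. A minimal sketch of the same probe pattern with plain torch.nn modules (the layer shapes are illustrative; the real encoder comes from utils.conv and utils.mlp):

import torch
from torch import nn

# Illustrative conv stack ending in Flatten, standing in for utils.conv(...).
conv_stack = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=7, stride=2), nn.ReLU(),
    nn.Conv2d(32, 32, kernel_size=5, stride=2), nn.ReLU(),
    nn.Flatten(),
)

# Probe the flattened output width once with a dummy batch, without tracking gradients.
with torch.no_grad():
    out_dim = conv_stack(torch.zeros(1, 3, 64, 64)).shape[1]

# Size the projection head from the probed width (latent_dim = 8, as in the config change above).
head = nn.Linear(out_dim, 8)
print(out_dim, head)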
@@ -654,6 +654,7 @@ class TDMPC2ObservationEncoder(nn.Module):
             feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
         if "observation.state" in self.config.input_shapes:
             feat.append(self.state_enc_layers(obs_dict["observation.state"]))
+
         return torch.stack(feat, dim=0).mean(0)
 
@@ -106,7 +106,10 @@ def two_hot_inv(x, cfg):
     if DREG_BINS is None:
         DREG_BINS = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device)
     x = F.softmax(x, dim=-1)
-    x = torch.sum(x * DREG_BINS, dim=-1, keepdim=True)
+    # cloning bins to avoid the inference tensor error
+    x = torch.sum(x * DREG_BINS.clone(), dim=-1, keepdim=True)
+
     return symexp(x)
 
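The .clone() added above most likely works around a PyTorch restriction: tensors created under torch.inference_mode() remain "inference tensors" and cannot be saved for backward, so a DREG_BINS cached during an inference-mode rollout breaks the subsequent training step. A minimal sketch of that failure mode and the clone workaround (my reading of the comment, not code from the commit):

import torch

# Bins cached while running under inference mode become inference tensors.
with torch.inference_mode():
    bins = torch.linspace(-10, 10, 101)

logits = torch.randn(4, 101, requires_grad=True)
probs = torch.softmax(logits, dim=-1)

try:
    # Using an inference tensor inside an autograd graph raises a RuntimeError.
    (probs * bins).sum(dim=-1, keepdim=True).sum().backward()
except RuntimeError as err:
    print("without clone:", err)

# Cloning outside inference mode yields a regular tensor that autograd accepts.
(probs * bins.clone()).sum(dim=-1, keepdim=True).sum().backward()
print("with clone, grad shape:", logits.grad.shape)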
@@ -177,12 +180,12 @@ class SimNorm(nn.Module):
 
     def __init__(self, cfg):
         super().__init__()
-        # TODO: move to config
-        self.dim = 8 # cfg.simnorm_dim
+        self.dim = cfg.simnorm_dim
 
     def forward(self, x):
         shp = x.shape
         x = x.view(*shp[:-1], -1, self.dim)
 
         x = F.softmax(x, dim=-1)
         return x.view(*shp)
 
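With simnorm_dim now read from the config (set to 8 above), SimNorm splits the last dimension into groups of that size and softmax-normalizes each group, so the latent becomes a concatenation of small probability simplices. A standalone sketch of the same computation with illustrative shapes (latent_dim must be a multiple of simnorm_dim):

import torch
import torch.nn.functional as F

simnorm_dim = 8
x = torch.randn(2, 24)               # (batch, latent_dim), illustrative latent_dim = 24

groups = x.view(2, -1, simnorm_dim)  # (batch, 3, 8): three groups of size simnorm_dim
groups = F.softmax(groups, dim=-1)   # each group becomes a probability simplex

print(groups.sum(dim=-1))            # all ones: every group sums to 1
print(groups.view(2, 24).shape)      # flattened back to the original latent shape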
@@ -237,7 +240,7 @@ def conv(in_shape, num_channels, act=None):
     Basic convolutional encoder for TD-MPC2 with raw image observations.
     4 layers of convolution with ReLU activations, followed by a linear layer.
     """
-    assert in_shape[-1] == 64  # assumes rgb observations to be 64x64
+    #assert in_shape[-1] == 64  # assumes rgb observations to be 64x64
     layers = [
         ShiftAug(),
         PixelPreprocess(),
@@ -95,13 +95,13 @@ def make_optimizer_and_scheduler(cfg, policy):
         lr_scheduler = None
     elif policy.name == "tdmpc2":
         params_group = [
-            {"params": policy.model._encoder.parameters(), "lr": cfg.lr * cfg.enc_lr_scale},
+            {"params": policy.model._encoder.parameters(), "lr": cfg.training.lr * cfg.training.enc_lr_scale},
             {"params": policy.model._dynamics.parameters()},
             {"params": policy.model._reward.parameters()},
             {"params": policy.model._Qs.parameters()},
-            {"params": policy.model._pi.parameters(), "lr": cfg.lr, "eps": 1e-5},
+            {"params": policy.model._pi.parameters(), "eps": 1e-5},
         ]
-        optimizer = torch.optim.Adam(params_group, lr=cfg.lr)
+        optimizer = torch.optim.Adam(params_group, lr=cfg.training.lr)
         lr_scheduler = None
     elif cfg.policy.name == "vqbet":
         from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
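On the optimizer change above: in torch.optim.Adam, a parameter group that specifies its own "lr" keeps it, while groups that omit it inherit the default lr passed to the constructor, so only the encoder group needs an explicit scaled rate once the policy group's "lr" entry is dropped. A small sketch of that behaviour with placeholder modules and values (not cfg.training.*):

import torch
from torch import nn

encoder, pi_head = nn.Linear(4, 4), nn.Linear(4, 2)
default_lr, enc_lr_scale = 3e-4, 0.3  # placeholders for cfg.training.lr and enc_lr_scale

optimizer = torch.optim.Adam(
    [
        {"params": encoder.parameters(), "lr": default_lr * enc_lr_scale},  # explicit per-group lr
        {"params": pi_head.parameters(), "eps": 1e-5},                      # inherits default_lr
    ],
    lr=default_lr,
)

# The encoder group runs at the scaled rate; the pi group falls back to the default.
print([group["lr"] for group in optimizer.param_groups])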