updated configuration parameters

commit 166c1fc776
parent 31984645da
@@ -102,8 +102,8 @@ class TDMPC2Config:
     """

     # Input / output structure.
-    n_action_repeats: int = 2
-    horizon: int = 5
+    n_action_repeats: int = 1
+    horizon: int = 3
     n_action_steps: int = 1

     input_shapes: dict[str, list[int]] = field(
@@ -128,7 +128,7 @@ class TDMPC2Config:
     # Neural networks.
     image_encoder_hidden_dim: int = 32
     state_encoder_hidden_dim: int = 256
-    latent_dim: int = 50
+    latent_dim: int = 512
     q_ensemble_size: int = 5
     mlp_dim: int = 512
     # Reinforcement learning.
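Note: the fields above are plain dataclass defaults, so the new values can still be overridden per experiment at construction time. A minimal usage sketch, assuming `TDMPC2Config` is imported from this configuration module (import path omitted):

```python
# Hypothetical override of the new defaults shown above; importing TDMPC2Config
# from the configuration module is assumed and omitted here.
config = TDMPC2Config(n_action_repeats=1, horizon=3, latent_dim=512)
print(config.horizon, config.latent_dim)  # 3 512
```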
@@ -137,39 +137,34 @@ class TDMPC2Config:
     # actor
     log_std_min: float = -10
     log_std_max: float = 2
-    entropy_coef: float = 1e-4

     # critic
     num_bins: int = 101
     vmin: int = -10
     vmax: int = +10

-    rho: float = 0.5
-    tau: float = 0.01
     # Inference.
     use_mpc: bool = True
     cem_iterations: int = 6
     max_std: float = 2.0
     min_std: float = 0.05
     n_gaussian_samples: int = 512
-    n_pi_samples: int = 51
-    uncertainty_regularizer_coeff: float = 1.0
-    n_elites: int = 50
+    n_pi_samples: int = 24
+    n_elites: int = 64
     elite_weighting_temperature: float = 0.5
-    gaussian_mean_momentum: float = 0.1

     # Training and loss computation.
     max_random_shift_ratio: float = 0.0476
     # Loss coefficients.
     reward_coeff: float = 0.1
-    expectile_weight: float = 0.9
     value_coeff: float = 0.1
     consistency_coeff: float = 20.0
-    advantage_scaling: float = 3.0
-    pi_coeff: float = 0.5
+    entropy_coef: float = 1e-4
     temporal_decay_coeff: float = 0.5
-    # Target model.
-    target_model_momentum: float = 0.995
+    # Target model. NOTE (michel_aractingi) this is equivalent to
+    # 1 - target_model_momentum of our TD-MPC1 implementation because
+    # of the use of `torch.lerp`
+    target_model_momentum: float = 0.01

     def __post_init__(self):
         """Input validation (not exhaustive)."""
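The NOTE on `target_model_momentum` flips the convention: with `torch.lerp`, the configured value is the interpolation weight toward the online network rather than the weight kept on the target. A minimal sketch of that kind of soft update (the helper below is illustrative, not the repo's code):

```python
import torch
from torch import nn

# Illustrative soft update via torch.lerp: target <- target + m * (online - target),
# i.e. (1 - m) * target + m * online. The new target_model_momentum = 0.01 plays the
# role of 1 - momentum in the older "momentum * target + (1 - momentum) * online" form.
def soft_update(target: nn.Module, online: nn.Module, m: float = 0.01) -> None:
    with torch.no_grad():
        for p_target, p_online in zip(target.parameters(), online.parameters()):
            p_target.lerp_(p_online, m)
```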
@@ -108,7 +108,7 @@ class TDMPC2Policy(
         if "observation.environment_state" in config.input_shapes:
             self._use_env_state = True

-        self.scale = RunningScale(self.config.tau)
+        self.scale = RunningScale(self.config.target_model_momentum)
         self.discount = self.config.discount #TODO (michel-aractingi) downscale discount according to episode length

         self.reset()
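Since `tau` was dropped from the config in the hunk above, `RunningScale` now takes `target_model_momentum` as its smoothing factor. A generic EMA tracker, shown only to illustrate how a small smoothing factor like 0.01 behaves (this is not the repo's `RunningScale` implementation, which may compute its statistic differently):

```python
import torch

# Generic exponential-moving-average tracker of a batch statistic; names and the
# choice of statistic (mean absolute value) are assumptions for illustration only.
class EmaScale:
    def __init__(self, tau: float = 0.01):
        self.tau = tau
        self.value = torch.tensor(1.0)

    def update(self, x: torch.Tensor) -> torch.Tensor:
        # Move the running value a small step toward the new batch statistic.
        self.value = torch.lerp(self.value, x.abs().mean(), self.tau)
        return self.value
```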
@@ -249,19 +249,14 @@ class TDMPC2Policy(
             score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
             score /= score.sum(axis=0, keepdim=True)
             # (horizon, batch, action_dim)
-            _mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
-            _std = torch.sqrt(
+            mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (score.sum(0) + 1e-9)
+            std = torch.sqrt(
                 torch.sum(
                     einops.rearrange(score, "n b -> n b 1")
-                    * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
+                    * (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2,
                     dim=1,
-                )
-            )
-            # Update mean with an exponential moving average, and std with a direct replacement.
-            mean = (
-                self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
-            )
-            std = _std.clamp_(self.config.min_std, self.config.max_std)
+                ) / (score.sum(0) + 1e-9)
+            ).clamp_(self.config.min_std, self.config.max_std)

         # Keep track of the mean for warm-starting subsequent steps.
         self._prev_mean = mean
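For readers scanning the planner change: the new code folds the old two-step EMA update into a single score-weighted mean/std over the elite action samples, normalized with a small epsilon and clamped to `[min_std, max_std]`. A self-contained sketch of just that computation, with made-up tensor shapes (h = horizon, n = elites, b = batch, d = action dim):

```python
import torch

# Standalone sketch of the score-weighted CEM statistics; shapes and values are
# illustrative, not the policy's real tensors.
horizon, n_elites, batch, action_dim = 3, 64, 1, 6
score = torch.rand(n_elites, batch)                                  # elite weights, (n, b)
elite_actions = torch.randn(horizon, n_elites, batch, action_dim)    # (h, n, b, d)

w = score[None, :, :, None]                                          # (1, n, b, 1), broadcasts over h and d
denom = score.sum(0)[None, :, None] + 1e-9                           # (1, b, 1)
mean = (w * elite_actions).sum(dim=1) / denom                        # (h, b, d)
var = (w * (elite_actions - mean[:, None]) ** 2).sum(dim=1) / denom  # (h, b, d)
std = var.sqrt().clamp(0.05, 2.0)                                    # min_std / max_std defaults from the config
```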
@@ -687,20 +682,20 @@ class TDMPC2ObservationEncoder(nn.Module):

            elif "observation.state" in config.input_shapes:
                encoder_module = nn.ModuleList()
-                encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.enc_dim))
+                encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
                assert config.num_enc_layers > 0
                for _ in range(config.num_enc_layers - 1):
-                    encoder_module.append(NormedLinear(config.enc_dim, config.enc_dim))
-                encoder_module.append(NormedLinear(config.enc_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
+                    encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
+                encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
                encoder_module = nn.Sequential(*encoder_module)

            elif "observation.environment_state" in config.input_shapes:
                encoder_module = nn.ModuleList()
-                encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.enc_dim))
+                encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
                assert config.num_enc_layers > 0
                for _ in range(config.num_enc_layers - 1):
-                    encoder_module.append(NormedLinear(config.enc_dim, config.enc_dim))
-                encoder_module.append(NormedLinear(config.enc_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
+                    encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
+                encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
                encoder_module = nn.Sequential(*encoder_module)

            else:
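The encoder hunks only rename the width parameter from `config.enc_dim` to `config.state_encoder_hidden_dim`; the stacking pattern (input projection, `num_enc_layers - 1` hidden layers, final projection to `latent_dim`) is unchanged. A stand-in sketch of that pattern, with plain `Linear + LayerNorm + Mish` in place of the repo's `NormedLinear` and the final `SimNorm` activation omitted (names and defaults below are assumptions, not the repo's API):

```python
import torch
from torch import nn

# Illustrative-only builder mirroring the stacking pattern above; NormedLinear and
# SimNorm from the repo are replaced by generic torch.nn layers.
def build_state_encoder(in_dim: int, hidden_dim: int = 256, latent_dim: int = 512,
                        num_enc_layers: int = 2) -> nn.Sequential:
    assert num_enc_layers > 0
    layers: list[nn.Module] = [nn.Linear(in_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Mish()]
    for _ in range(num_enc_layers - 1):
        layers += [nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Mish()]
    layers.append(nn.Linear(hidden_dim, latent_dim))  # the diff applies SimNorm to this output
    return nn.Sequential(*layers)

encoder = build_state_encoder(in_dim=4)
print(encoder(torch.randn(8, 4)).shape)  # torch.Size([8, 512])
```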