updated configuration parameters
parent 31984645da
commit 166c1fc776
@@ -102,8 +102,8 @@ class TDMPC2Config:
     """
 
     # Input / output structure.
-    n_action_repeats: int = 2
-    horizon: int = 5
+    n_action_repeats: int = 1
+    horizon: int = 3
     n_action_steps: int = 1
 
     input_shapes: dict[str, list[int]] = field(
@@ -128,7 +128,7 @@ class TDMPC2Config:
     # Neural networks.
     image_encoder_hidden_dim: int = 32
     state_encoder_hidden_dim: int = 256
-    latent_dim: int = 50
+    latent_dim: int = 512
     q_ensemble_size: int = 5
     mlp_dim: int = 512
     # Reinforcement learning.
@@ -137,39 +137,34 @@ class TDMPC2Config:
     # actor
     log_std_min: float = -10
     log_std_max: float = 2
+    entropy_coef: float = 1e-4
 
     # critic
     num_bins: int = 101
     vmin: int = -10
     vmax: int = +10
-
-    rho: float = 0.5
-    tau: float = 0.01
     # Inference.
     use_mpc: bool = True
     cem_iterations: int = 6
     max_std: float = 2.0
     min_std: float = 0.05
     n_gaussian_samples: int = 512
-    n_pi_samples: int = 51
-    uncertainty_regularizer_coeff: float = 1.0
-    n_elites: int = 50
+    n_pi_samples: int = 24
+    n_elites: int = 64
     elite_weighting_temperature: float = 0.5
-    gaussian_mean_momentum: float = 0.1
 
     # Training and loss computation.
     max_random_shift_ratio: float = 0.0476
     # Loss coefficients.
     reward_coeff: float = 0.1
-    expectile_weight: float = 0.9
     value_coeff: float = 0.1
     consistency_coeff: float = 20.0
-    advantage_scaling: float = 3.0
     pi_coeff: float = 0.5
-    entropy_coef: float = 1e-4
     temporal_decay_coeff: float = 0.5
-    # Target model.
-    target_model_momentum: float = 0.995
+    # Target model. NOTE (michel_aractingi) this is equivalent to
+    # 1 - target_model_momentum of our TD-MPC1 implementation because
+    # of the use of `torch.lerp`
+    target_model_momentum: float = 0.01
 
     def __post_init__(self):
         """Input validation (not exhaustive)."""
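The NOTE in this hunk rests on a small identity between the two ways of writing a target-network update. A minimal sketch of that equivalence, using standalone tensors rather than the PR's actual model parameters:

import torch

# TD-MPC1-style exponential moving average with momentum m:
#     target = m * target + (1 - m) * online
# torch.lerp-style update with weight w:
#     torch.lerp(target, online, w) = target + w * (online - target)
# The two coincide when w = 1 - m, which is why target_model_momentum = 0.01 here
# plays the role that (1 - target_model_momentum) played in the TD-MPC1 config.
target = torch.zeros(4)
online = torch.ones(4)

m = 0.995  # TD-MPC1-style momentum
classic = m * target + (1 - m) * online
lerped = torch.lerp(target, online, 1 - m)
assert torch.allclose(classic, lerped)

The next hunk is the same rename seen from the call site: RunningScale now takes config.target_model_momentum where it previously took config.tau.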
@@ -108,7 +108,7 @@ class TDMPC2Policy(
         if "observation.environment_state" in config.input_shapes:
             self._use_env_state = True
 
-        self.scale = RunningScale(self.config.tau)
+        self.scale = RunningScale(self.config.target_model_momentum)
         self.discount = self.config.discount #TODO (michel-aractingi) downscale discount according to episode length
 
         self.reset()
@@ -249,19 +249,14 @@ class TDMPC2Policy(
             score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
-            score /= score.sum(axis=0, keepdim=True)
             # (horizon, batch, action_dim)
-            _mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
-            _std = torch.sqrt(
+            mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (score.sum(0) + 1e-9)
+            std = torch.sqrt(
                 torch.sum(
                     einops.rearrange(score, "n b -> n b 1")
-                    * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
+                    * (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2,
                     dim=1,
                 )
-            )
-            # Update mean with an exponential moving average, and std with a direct replacement.
-            mean = (
-                self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
-            )
-            std = _std.clamp_(self.config.min_std, self.config.max_std)
+                / (score.sum(0) + 1e-9)
+            ).clamp_(self.config.min_std, self.config.max_std)
 
             # Keep track of the mean for warm-starting subsequent steps.
             self._prev_mean = mean
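Read in isolation, the new statistics are a score-weighted mean and standard deviation over the elite action sequences, normalized by the score sum and clamped to [min_std, max_std]. A standalone sketch with shapes taken from the diff's comment and rearrange patterns (the tensor sizes and random values are illustrative, not the policy's actual wiring):

import torch

# elite_actions: (horizon, n_elites, batch, action_dim); elite_value and score: (n_elites, batch).
horizon, n_elites, batch, action_dim = 3, 64, 2, 4
elite_actions = torch.randn(horizon, n_elites, batch, action_dim)
elite_value = torch.randn(n_elites, batch)

elite_weighting_temperature = 0.5
min_std, max_std = 0.05, 2.0

max_value = elite_value.max(dim=0, keepdim=True).values
score = torch.exp(elite_weighting_temperature * (elite_value - max_value))

weights = score[None, :, :, None]                    # broadcast over horizon and action_dim
denom = score.sum(dim=0)[None, :, None] + 1e-9       # per-batch normalizer
mean = (weights * elite_actions).sum(dim=1) / denom  # (horizon, batch, action_dim)
std = torch.sqrt(
    (weights * (elite_actions - mean[:, None]) ** 2).sum(dim=1) / denom
).clamp_(min_std, max_std)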
@@ -687,20 +682,20 @@ class TDMPC2ObservationEncoder(nn.Module):
 
         elif "observation.state" in config.input_shapes:
             encoder_module = nn.ModuleList()
-            encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.enc_dim))
+            encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
             assert config.num_enc_layers > 0
             for _ in range(config.num_enc_layers - 1):
-                encoder_module.append(NormedLinear(config.enc_dim, config.enc_dim))
-            encoder_module.append(NormedLinear(config.enc_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
+                encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
+            encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
             encoder_module = nn.Sequential(*encoder_module)
 
         elif "observation.environment_state" in config.input_shapes:
             encoder_module = nn.ModuleList()
-            encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.enc_dim))
+            encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
             assert config.num_enc_layers > 0
             for _ in range(config.num_enc_layers - 1):
-                encoder_module.append(NormedLinear(config.enc_dim, config.enc_dim))
-            encoder_module.append(NormedLinear(config.enc_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
+                encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
+            encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
             encoder_module = nn.Sequential(*encoder_module)
 
         else:
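Both vector-observation branches now route through state_encoder_hidden_dim instead of a separate enc_dim. A rough stand-in for the resulting stack, using plain nn.Linear + LayerNorm + Mish because NormedLinear and SimNorm are the PR's own modules and their exact signatures are not shown here:

import torch
from torch import nn

def build_state_encoder(
    state_dim: int,
    hidden_dim: int = 256,   # state_encoder_hidden_dim
    latent_dim: int = 512,
    num_enc_layers: int = 2,
) -> nn.Sequential:
    # First layer: raw state -> hidden; then (num_enc_layers - 1) hidden -> hidden layers;
    # finally hidden -> latent, where the real code applies SimNorm rather than Mish.
    layers: list[nn.Module] = [nn.Linear(state_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Mish()]
    for _ in range(num_enc_layers - 1):
        layers += [nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Mish()]
    layers += [nn.Linear(hidden_dim, latent_dim)]
    return nn.Sequential(*layers)

encoder = build_state_encoder(state_dim=24)
print(encoder(torch.randn(8, 24)).shape)  # torch.Size([8, 512])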