Add type annotations and restructure SACConfig class fields
This commit is contained in:
parent
7d2970fdfe
commit
1fb03d4cf2
|
@ -16,6 +16,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -26,6 +27,7 @@ class SACConfig:
|
||||||
"observation.state": [4],
|
"observation.state": [4],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
output_shapes: dict[str, list[int]] = field(
|
output_shapes: dict[str, list[int]] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"action": [2],
|
"action": [2],
|
||||||
|
@ -43,36 +45,64 @@ class SACConfig:
|
||||||
output_normalization_modes: dict[str, str] = field(
|
output_normalization_modes: dict[str, str] = field(
|
||||||
default_factory=lambda: {"action": "min_max"},
|
default_factory=lambda: {"action": "min_max"},
|
||||||
)
|
)
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
shared_encoder = False
|
@dataclass
|
||||||
discount = 0.99
|
class SACConfig:
|
||||||
temperature_init = 1.0
|
input_shapes: dict[str, list[int]] = field(
|
||||||
num_critics = 2
|
default_factory=lambda: {
|
||||||
# num_critics = 8
|
"observation.image": [3, 84, 84],
|
||||||
num_subsample_critics = None
|
"observation.state": [4],
|
||||||
# num_subsample_critics = 2
|
}
|
||||||
# critic_lr = 1e-3
|
)
|
||||||
critic_lr = 3e-4
|
output_shapes: dict[str, list[int]] = field(
|
||||||
actor_lr = 3e-4
|
default_factory=lambda: {
|
||||||
temperature_lr = 3e-4
|
"action": [2],
|
||||||
critic_target_update_weight = 0.005
|
}
|
||||||
# utd_ratio = 8
|
)
|
||||||
utd_ratio = 1 # If you want enable utd_ratio, you need to set it to >1
|
input_normalization_modes: dict[str, str] = field(
|
||||||
state_encoder_hidden_dim = 256
|
default_factory=lambda: {
|
||||||
latent_dim = 256
|
"observation.image": "mean_std",
|
||||||
target_entropy = None
|
"observation.state": "min_max",
|
||||||
# backup_entropy = False
|
"observation.environment_state": "min_max",
|
||||||
use_backup_entropy = True
|
}
|
||||||
critic_network_kwargs = {
|
)
|
||||||
|
output_normalization_modes: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {"action": "min_max"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add type annotations for these fields:
|
||||||
|
image_encoder_hidden_dim: int = 32
|
||||||
|
shared_encoder: bool = False
|
||||||
|
discount: float = 0.99
|
||||||
|
temperature_init: float = 1.0
|
||||||
|
num_critics: int = 2
|
||||||
|
num_subsample_critics: int | None = None
|
||||||
|
critic_lr: float = 3e-4
|
||||||
|
actor_lr: float = 3e-4
|
||||||
|
temperature_lr: float = 3e-4
|
||||||
|
critic_target_update_weight: float = 0.005
|
||||||
|
utd_ratio: int = 1 # If you want enable utd_ratio, you need to set it to >1
|
||||||
|
state_encoder_hidden_dim: int = 256
|
||||||
|
latent_dim: int = 256
|
||||||
|
target_entropy: float | None = None
|
||||||
|
use_backup_entropy: bool = True
|
||||||
|
critic_network_kwargs: dict[str, Any] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
"hidden_dims": [256, 256],
|
"hidden_dims": [256, 256],
|
||||||
"activate_final": True,
|
"activate_final": True,
|
||||||
}
|
}
|
||||||
actor_network_kwargs = {
|
)
|
||||||
|
actor_network_kwargs: dict[str, Any] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
"hidden_dims": [256, 256],
|
"hidden_dims": [256, 256],
|
||||||
"activate_final": True,
|
"activate_final": True,
|
||||||
}
|
}
|
||||||
policy_kwargs = {
|
)
|
||||||
|
policy_kwargs: dict[str, Any] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
"use_tanh_squash": True,
|
"use_tanh_squash": True,
|
||||||
"log_std_min": -5,
|
"log_std_min": -5,
|
||||||
"log_std_max": 2,
|
"log_std_max": 2,
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue