split encoder for critic and actor
This commit is contained in:
parent
2c2ed084cc
commit
7bb142b707
|
@ -48,7 +48,7 @@ class SACConfig:
|
||||||
critic_target_update_weight = 0.005
|
critic_target_update_weight = 0.005
|
||||||
utd_ratio = 2
|
utd_ratio = 2
|
||||||
state_encoder_hidden_dim = 256
|
state_encoder_hidden_dim = 256
|
||||||
latent_dim = 50
|
latent_dim = 128
|
||||||
target_entropy = None
|
target_entropy = None
|
||||||
critic_network_kwargs = {
|
critic_network_kwargs = {
|
||||||
"hidden_dims": [256, 256],
|
"hidden_dims": [256, 256],
|
||||||
|
|
|
@ -63,21 +63,31 @@ class SACPolicy(
|
||||||
self.unnormalize_outputs = Unnormalize(
|
self.unnormalize_outputs = Unnormalize(
|
||||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||||
)
|
)
|
||||||
encoder = SACObservationEncoder(config)
|
encoder_critic = SACObservationEncoder(config)
|
||||||
|
encoder_actor = SACObservationEncoder(config)
|
||||||
# Define networks
|
# Define networks
|
||||||
critic_nets = []
|
critic_nets = []
|
||||||
for _ in range(config.num_critics):
|
for _ in range(config.num_critics):
|
||||||
critic_net = Critic(encoder=encoder, network=MLP(**config.critic_network_kwargs))
|
critic_net = Critic(
|
||||||
|
encoder=encoder_critic,
|
||||||
|
network=MLP(
|
||||||
|
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||||
|
**config.critic_network_kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
critic_nets.append(critic_net)
|
critic_nets.append(critic_net)
|
||||||
|
|
||||||
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
|
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
|
||||||
self.critic_target = deepcopy(self.critic_ensemble)
|
self.critic_target = deepcopy(self.critic_ensemble)
|
||||||
|
|
||||||
self.actor = Policy(
|
self.actor = Policy(
|
||||||
encoder=encoder,
|
encoder=encoder_actor,
|
||||||
network=MLP(**config.actor_network_kwargs),
|
network=MLP(
|
||||||
|
input_dim=encoder_actor.output_dim,
|
||||||
|
**config.actor_network_kwargs
|
||||||
|
),
|
||||||
action_dim=config.output_shapes["action"][0],
|
action_dim=config.output_shapes["action"][0],
|
||||||
**config.policy_kwargs,
|
**config.policy_kwargs
|
||||||
)
|
)
|
||||||
if config.target_entropy is None:
|
if config.target_entropy is None:
|
||||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A))
|
config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A))
|
||||||
|
@ -105,6 +115,22 @@ class SACPolicy(
|
||||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
|
||||||
|
"""Forward pass through a critic network ensemble
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observations: Dictionary of observations
|
||||||
|
actions: Action tensor
|
||||||
|
use_target: If True, use target critics, otherwise use ensemble critics
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor of Q-values from all critics
|
||||||
|
"""
|
||||||
|
critics = self.critic_target if use_target else self.critic_ensemble
|
||||||
|
q_values = torch.stack([critic(observations, actions) for critic in critics])
|
||||||
|
return q_values
|
||||||
|
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||||
"""Run the batch through the model and compute the loss.
|
"""Run the batch through the model and compute the loss.
|
||||||
|
|
||||||
|
@ -112,7 +138,7 @@ class SACPolicy(
|
||||||
"""
|
"""
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
# batch shape is (b, 2, ...) where index 1 returns the current observation and
|
# batch shape is (b, 2, ...) where index 1 returns the current observation and
|
||||||
# the next observation for caluculating the right td index.
|
# the next observation for calculating the right td index.
|
||||||
actions = batch["action"][:, 0]
|
actions = batch["action"][:, 0]
|
||||||
rewards = batch["next.reward"][:, 0]
|
rewards = batch["next.reward"][:, 0]
|
||||||
observations = {}
|
observations = {}
|
||||||
|
@ -132,7 +158,8 @@ class SACPolicy(
|
||||||
action_preds, log_probs = self.actor(next_observations)
|
action_preds, log_probs = self.actor(next_observations)
|
||||||
|
|
||||||
# 2- compute q targets
|
# 2- compute q targets
|
||||||
q_targets = self.target_qs(next_observations, action_preds)
|
q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
|
||||||
|
|
||||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||||
if self.config.num_subsample_critics is not None:
|
if self.config.num_subsample_critics is not None:
|
||||||
indices = torch.randperm(self.config.num_critics)
|
indices = torch.randperm(self.config.num_critics)
|
||||||
|
@ -140,23 +167,26 @@ class SACPolicy(
|
||||||
q_targets = q_targets[indices]
|
q_targets = q_targets[indices]
|
||||||
|
|
||||||
# critics subsample size
|
# critics subsample size
|
||||||
min_q = q_targets.min(dim=0)
|
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||||
|
|
||||||
# compute td target
|
# compute td target
|
||||||
td_target = (
|
td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term
|
||||||
rewards + self.config.discount * min_q
|
|
||||||
) # + self.config.discount * self.temperature() * log_probs # add entropy term
|
|
||||||
|
|
||||||
# 3- compute predicted qs
|
# 3- compute predicted qs
|
||||||
q_preds = self.critic_ensemble(observations, actions)
|
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||||
|
|
||||||
# 4- Calculate loss
|
# 4- Calculate loss
|
||||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||||
|
critics_loss = F.mse_loss(
|
||||||
|
q_preds, # shape: [num_critics, batch_size]
|
||||||
|
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape
|
||||||
|
reduction="none"
|
||||||
|
).sum(0).mean()
|
||||||
|
|
||||||
# critics_loss = (
|
# critics_loss = (
|
||||||
# (
|
|
||||||
# F.mse_loss(
|
# F.mse_loss(
|
||||||
# q_preds,
|
# q_preds,
|
||||||
# einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
|
# einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
|
||||||
# reduction="none",
|
# reduction="none",
|
||||||
# ).sum(0) # sum over ensemble
|
# ).sum(0) # sum over ensemble
|
||||||
# # `q_preds_ensemble` depends on the first observation and the actions.
|
# # `q_preds_ensemble` depends on the first observation and the actions.
|
||||||
|
@ -165,23 +195,7 @@ class SACPolicy(
|
||||||
# # q_targets depends on the reward and the next observations.
|
# # q_targets depends on the reward and the next observations.
|
||||||
# * ~batch["next.reward_is_pad"]
|
# * ~batch["next.reward_is_pad"]
|
||||||
# * ~batch["observation.state_is_pad"][1:]
|
# * ~batch["observation.state_is_pad"][1:]
|
||||||
# )
|
# ).sum(0).mean()
|
||||||
# .sum(0)
|
|
||||||
# .mean()
|
|
||||||
# )
|
|
||||||
# 4- Calculate loss
|
|
||||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
|
||||||
critics_loss = (
|
|
||||||
F.mse_loss(
|
|
||||||
q_preds, # shape: [num_critics, batch_size]
|
|
||||||
einops.repeat(
|
|
||||||
td_target, "b -> e b", e=q_preds.shape[0]
|
|
||||||
), # expand td_target to match q_preds shape
|
|
||||||
reduction="none",
|
|
||||||
)
|
|
||||||
.sum(0)
|
|
||||||
.mean()
|
|
||||||
)
|
|
||||||
|
|
||||||
# calculate actors loss
|
# calculate actors loss
|
||||||
# 1- temperature
|
# 1- temperature
|
||||||
|
@ -189,18 +203,22 @@ class SACPolicy(
|
||||||
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
||||||
actions, log_probs = self.actor(observations)
|
actions, log_probs = self.actor(observations)
|
||||||
# 3- get q-value predictions
|
# 3- get q-value predictions
|
||||||
with torch.no_grad():
|
with torch.inference_mode():
|
||||||
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
|
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||||
actor_loss = (
|
actor_loss = (
|
||||||
-(q_preds - temperature * log_probs).mean()
|
-(q_preds - temperature * log_probs).mean()
|
||||||
* ~batch["observation.state_is_pad"][0]
|
# * ~batch["observation.state_is_pad"][0]
|
||||||
* ~batch["action_is_pad"]
|
# * ~batch["action_is_pad"]
|
||||||
).mean()
|
).mean()
|
||||||
|
|
||||||
|
|
||||||
# calculate temperature loss
|
# calculate temperature loss
|
||||||
# 1- calculate entropy
|
# 1- calculate entropy
|
||||||
entropy = -log_probs.mean()
|
entropy = -log_probs.mean()
|
||||||
temperature_loss = self.temp(lhs=entropy, rhs=self.config.target_entropy)
|
temperature_loss = self.temperature(
|
||||||
|
lhs=entropy,
|
||||||
|
rhs=self.config.target_entropy
|
||||||
|
)
|
||||||
|
|
||||||
loss = critics_loss + actor_loss + temperature_loss
|
loss = critics_loss + actor_loss + temperature_loss
|
||||||
|
|
||||||
|
@ -214,20 +232,24 @@ class SACPolicy(
|
||||||
}
|
}
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
|
|
||||||
# TODO: implement UTD update
|
# TODO: implement UTD update
|
||||||
# First update only critics for utd_ratio-1 times
|
# First update only critics for utd_ratio-1 times
|
||||||
#for critic_step in range(self.config.utd_ratio - 1):
|
#for critic_step in range(self.config.utd_ratio - 1):
|
||||||
# only update critic and critic target
|
# only update critic and critic target
|
||||||
# Then update critic, critic target, actor and temperature
|
# Then update critic, critic target, actor and temperature
|
||||||
|
"""Update target networks with exponential moving average"""
|
||||||
# for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
|
with torch.no_grad():
|
||||||
# target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
|
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
|
||||||
|
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
|
||||||
|
target_param.data.copy_(
|
||||||
|
target_param.data * self.config.critic_target_update_weight +
|
||||||
|
param.data * (1.0 - self.config.critic_target_update_weight)
|
||||||
|
)
|
||||||
|
|
||||||
class MLP(nn.Module):
|
class MLP(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
input_dim: int,
|
||||||
hidden_dims: list[int],
|
hidden_dims: list[int],
|
||||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||||
activate_final: bool = False,
|
activate_final: bool = False,
|
||||||
|
@ -237,22 +259,28 @@ class MLP(nn.Module):
|
||||||
self.activate_final = activate_final
|
self.activate_final = activate_final
|
||||||
layers = []
|
layers = []
|
||||||
|
|
||||||
for i, size in enumerate(hidden_dims):
|
# First layer uses input_dim
|
||||||
layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))
|
layers.append(nn.Linear(input_dim, hidden_dims[0]))
|
||||||
|
|
||||||
|
# Add activation after first layer
|
||||||
|
if dropout_rate is not None and dropout_rate > 0:
|
||||||
|
layers.append(nn.Dropout(p=dropout_rate))
|
||||||
|
layers.append(nn.LayerNorm(hidden_dims[0]))
|
||||||
|
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||||
|
|
||||||
|
# Rest of the layers
|
||||||
|
for i in range(1, len(hidden_dims)):
|
||||||
|
layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i]))
|
||||||
|
|
||||||
if i + 1 < len(hidden_dims) or activate_final:
|
if i + 1 < len(hidden_dims) or activate_final:
|
||||||
if dropout_rate is not None and dropout_rate > 0:
|
if dropout_rate is not None and dropout_rate > 0:
|
||||||
layers.append(nn.Dropout(p=dropout_rate))
|
layers.append(nn.Dropout(p=dropout_rate))
|
||||||
layers.append(nn.LayerNorm(size))
|
layers.append(nn.LayerNorm(hidden_dims[i]))
|
||||||
layers.append(
|
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||||
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.net = nn.Sequential(*layers)
|
self.net = nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
# in training mode or not. TODO: find better way to do this
|
|
||||||
self.train(train)
|
|
||||||
return self.net(x)
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
@ -262,7 +290,7 @@ class Critic(nn.Module):
|
||||||
encoder: Optional[nn.Module],
|
encoder: Optional[nn.Module],
|
||||||
network: nn.Module,
|
network: nn.Module,
|
||||||
init_final: Optional[float] = None,
|
init_final: Optional[float] = None,
|
||||||
device: str = "cuda",
|
device: str = "cuda"
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = torch.device(device)
|
self.device = torch.device(device)
|
||||||
|
@ -287,10 +315,15 @@ class Critic(nn.Module):
|
||||||
|
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
|
||||||
def forward(self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False) -> torch.Tensor:
|
def forward(
|
||||||
self.train(train)
|
self,
|
||||||
|
observations: dict[str, torch.Tensor],
|
||||||
observations = observations.to(self.device)
|
actions: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Move each tensor in observations to device
|
||||||
|
observations = {
|
||||||
|
k: v.to(self.device) for k, v in observations.items()
|
||||||
|
}
|
||||||
actions = actions.to(self.device)
|
actions = actions.to(self.device)
|
||||||
|
|
||||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||||
|
@ -312,7 +345,7 @@ class Policy(nn.Module):
|
||||||
fixed_std: Optional[torch.Tensor] = None,
|
fixed_std: Optional[torch.Tensor] = None,
|
||||||
init_final: Optional[float] = None,
|
init_final: Optional[float] = None,
|
||||||
use_tanh_squash: bool = False,
|
use_tanh_squash: bool = False,
|
||||||
device: str = "cuda",
|
device: str = "cuda"
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = torch.device(device)
|
self.device = torch.device(device)
|
||||||
|
@ -353,8 +386,9 @@ class Policy(nn.Module):
|
||||||
self,
|
self,
|
||||||
observations: torch.Tensor,
|
observations: torch.Tensor,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
# Encode observations if encoder exists
|
# Encode observations if encoder exists
|
||||||
obs_enc = observations if self.encoder is not None else self.encoder(observations)
|
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||||
|
|
||||||
# Get network outputs
|
# Get network outputs
|
||||||
outputs = self.network(obs_enc)
|
outputs = self.network(obs_enc)
|
||||||
|
@ -367,10 +401,10 @@ class Policy(nn.Module):
|
||||||
log_std = torch.tanh(log_std)
|
log_std = torch.tanh(log_std)
|
||||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
||||||
else:
|
else:
|
||||||
stds = self.fixed_std.expand_as(means)
|
log_std = self.fixed_std.expand_as(means)
|
||||||
|
|
||||||
# uses tahn activation function to squash the action to be in the range of [-1, 1]
|
# uses tahn activation function to squash the action to be in the range of [-1, 1]
|
||||||
normal = torch.distributions.Normal(means, stds)
|
normal = torch.distributions.Normal(means, torch.exp(log_std))
|
||||||
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
|
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
|
||||||
log_probs = normal.log_prob(x_t)
|
log_probs = normal.log_prob(x_t)
|
||||||
if self.use_tanh_squash:
|
if self.use_tanh_squash:
|
||||||
|
@ -384,8 +418,8 @@ class Policy(nn.Module):
|
||||||
"""Get encoded features from observations"""
|
"""Get encoded features from observations"""
|
||||||
observations = observations.to(self.device)
|
observations = observations.to(self.device)
|
||||||
if self.encoder is not None:
|
if self.encoder is not None:
|
||||||
with torch.no_grad():
|
with torch.inference_mode():
|
||||||
return self.encoder(observations, train=False)
|
return self.encoder(observations)
|
||||||
return observations
|
return observations
|
||||||
|
|
||||||
|
|
||||||
|
@ -459,11 +493,22 @@ class SACObservationEncoder(nn.Module):
|
||||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||||
if "observation.state" in self.config.input_shapes:
|
if "observation.state" in self.config.input_shapes:
|
||||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||||
|
# TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
|
||||||
return torch.stack(feat, dim=0).mean(0)
|
return torch.stack(feat, dim=0).mean(0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_dim(self) -> int:
|
||||||
|
"""Returns the dimension of the encoder output"""
|
||||||
|
return self.config.latent_dim
|
||||||
|
|
||||||
|
|
||||||
class LagrangeMultiplier(nn.Module):
|
class LagrangeMultiplier(nn.Module):
|
||||||
def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
init_value: float = 1.0,
|
||||||
|
constraint_shape: Sequence[int] = (),
|
||||||
|
device: str = "cuda"
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = torch.device(device)
|
self.device = torch.device(device)
|
||||||
init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
|
init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
|
||||||
|
@ -475,7 +520,11 @@ class LagrangeMultiplier(nn.Module):
|
||||||
|
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
|
||||||
def forward(self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor] = None) -> torch.Tensor:
|
def forward(
|
||||||
|
self,
|
||||||
|
lhs: Optional[torch.Tensor | float | int] = None,
|
||||||
|
rhs: Optional[torch.Tensor | float | int] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
# Get the multiplier value based on parameterization
|
# Get the multiplier value based on parameterization
|
||||||
multiplier = torch.nn.functional.softplus(self.lagrange)
|
multiplier = torch.nn.functional.softplus(self.lagrange)
|
||||||
|
|
||||||
|
@ -483,13 +532,11 @@ class LagrangeMultiplier(nn.Module):
|
||||||
if lhs is None:
|
if lhs is None:
|
||||||
return multiplier
|
return multiplier
|
||||||
|
|
||||||
# Move inputs to device
|
# Convert inputs to tensors and move to device
|
||||||
lhs = lhs.to(self.device)
|
lhs = torch.tensor(lhs, device=self.device) if not isinstance(lhs, torch.Tensor) else lhs.to(self.device)
|
||||||
if rhs is not None:
|
if rhs is not None:
|
||||||
rhs = rhs.to(self.device)
|
rhs = torch.tensor(rhs, device=self.device) if not isinstance(rhs, torch.Tensor) else rhs.to(self.device)
|
||||||
|
else:
|
||||||
# Use the multiplier to compute the Lagrange penalty
|
|
||||||
if rhs is None:
|
|
||||||
rhs = torch.zeros_like(lhs, device=self.device)
|
rhs = torch.zeros_like(lhs, device=self.device)
|
||||||
|
|
||||||
diff = lhs - rhs
|
diff = lhs - rhs
|
||||||
|
@ -508,7 +555,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
|
||||||
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
||||||
return nn.ModuleList(critics).to(device)
|
return nn.ModuleList(critics).to(device)
|
||||||
|
|
||||||
|
|
||||||
# borrowed from tdmpc
|
# borrowed from tdmpc
|
||||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||||
|
|
Loading…
Reference in New Issue