split encoder for critic and actor
This commit is contained in:
parent
2c2ed084cc
commit
7bb142b707
|
@ -48,7 +48,7 @@ class SACConfig:
|
|||
critic_target_update_weight = 0.005
|
||||
utd_ratio = 2
|
||||
state_encoder_hidden_dim = 256
|
||||
latent_dim = 50
|
||||
latent_dim = 128
|
||||
target_entropy = None
|
||||
critic_network_kwargs = {
|
||||
"hidden_dims": [256, 256],
|
||||
|
|
|
@ -63,25 +63,35 @@ class SACPolicy(
|
|||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
encoder = SACObservationEncoder(config)
|
||||
encoder_critic = SACObservationEncoder(config)
|
||||
encoder_actor = SACObservationEncoder(config)
|
||||
# Define networks
|
||||
critic_nets = []
|
||||
for _ in range(config.num_critics):
|
||||
critic_net = Critic(encoder=encoder, network=MLP(**config.critic_network_kwargs))
|
||||
critic_net = Critic(
|
||||
encoder=encoder_critic,
|
||||
network=MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs
|
||||
)
|
||||
)
|
||||
critic_nets.append(critic_net)
|
||||
|
||||
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
|
||||
self.critic_target = deepcopy(self.critic_ensemble)
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder,
|
||||
network=MLP(**config.actor_network_kwargs),
|
||||
encoder=encoder_actor,
|
||||
network=MLP(
|
||||
input_dim=encoder_actor.output_dim,
|
||||
**config.actor_network_kwargs
|
||||
),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
**config.policy_kwargs,
|
||||
**config.policy_kwargs
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A))
|
||||
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A))
|
||||
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
|
@ -104,15 +114,31 @@ class SACPolicy(
|
|||
actions, _ = self.actor(batch)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
return actions
|
||||
|
||||
def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = torch.stack([critic(observations, actions) for critic in critics])
|
||||
return q_values
|
||||
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
# batch shape is (b, 2, ...) where index 1 returns the current observation and
|
||||
# the next observation for caluculating the right td index.
|
||||
# batch shape is (b, 2, ...) where index 1 returns the current observation and
|
||||
# the next observation for calculating the right td index.
|
||||
actions = batch["action"][:, 0]
|
||||
rewards = batch["next.reward"][:, 0]
|
||||
observations = {}
|
||||
|
@ -121,113 +147,109 @@ class SACPolicy(
|
|||
if k.startswith("observation."):
|
||||
observations[k] = batch[k][:, 0]
|
||||
next_observations[k] = batch[k][:, 1]
|
||||
|
||||
|
||||
# perform image augmentation
|
||||
|
||||
# reward bias from HIL-SERL code base
|
||||
# reward bias from HIL-SERL code base
|
||||
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
|
||||
|
||||
|
||||
# calculate critics loss
|
||||
# 1- compute actions from policy
|
||||
action_preds, log_probs = self.actor(next_observations)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.target_qs(next_observations, action_preds)
|
||||
q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
indices = indices[:self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q = q_targets.min(dim=0)
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
|
||||
# compute td target
|
||||
td_target = (
|
||||
rewards + self.config.discount * min_q
|
||||
) # + self.config.discount * self.temperature() * log_probs # add entropy term
|
||||
td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_ensemble(observations, actions)
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
# critics_loss = (
|
||||
# (
|
||||
# F.mse_loss(
|
||||
# q_preds,
|
||||
# einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
|
||||
# reduction="none",
|
||||
# ).sum(0) # sum over ensemble
|
||||
# # `q_preds_ensemble` depends on the first observation and the actions.
|
||||
# * ~batch["observation.state_is_pad"][0]
|
||||
# * ~batch["action_is_pad"]
|
||||
# # q_targets depends on the reward and the next observations.
|
||||
# * ~batch["next.reward_is_pad"]
|
||||
# * ~batch["observation.state_is_pad"][1:]
|
||||
# )
|
||||
# .sum(0)
|
||||
# .mean()
|
||||
# )
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
q_preds, # shape: [num_critics, batch_size]
|
||||
einops.repeat(
|
||||
td_target, "b -> e b", e=q_preds.shape[0]
|
||||
), # expand td_target to match q_preds shape
|
||||
reduction="none",
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
critics_loss = F.mse_loss(
|
||||
q_preds, # shape: [num_critics, batch_size]
|
||||
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape
|
||||
reduction="none"
|
||||
).sum(0).mean()
|
||||
|
||||
# critics_loss = (
|
||||
# F.mse_loss(
|
||||
# q_preds,
|
||||
# einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
|
||||
# reduction="none",
|
||||
# ).sum(0) # sum over ensemble
|
||||
# # `q_preds_ensemble` depends on the first observation and the actions.
|
||||
# * ~batch["observation.state_is_pad"][0]
|
||||
# * ~batch["action_is_pad"]
|
||||
# # q_targets depends on the reward and the next observations.
|
||||
# * ~batch["next.reward_is_pad"]
|
||||
# * ~batch["observation.state_is_pad"][1:]
|
||||
# ).sum(0).mean()
|
||||
|
||||
# calculate actors loss
|
||||
# 1- temperature
|
||||
temperature = self.temperature()
|
||||
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
||||
actions, log_probs = self.actor(observations)
|
||||
# 3- get q-value predictions
|
||||
with torch.no_grad():
|
||||
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
|
||||
with torch.inference_mode():
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||
actor_loss = (
|
||||
-(q_preds - temperature * log_probs).mean()
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# * ~batch["observation.state_is_pad"][0]
|
||||
# * ~batch["action_is_pad"]
|
||||
).mean()
|
||||
|
||||
|
||||
# calculate temperature loss
|
||||
# 1- calculate entropy
|
||||
entropy = -log_probs.mean()
|
||||
temperature_loss = self.temp(lhs=entropy, rhs=self.config.target_entropy)
|
||||
temperature_loss = self.temperature(
|
||||
lhs=entropy,
|
||||
rhs=self.config.target_entropy
|
||||
)
|
||||
|
||||
loss = critics_loss + actor_loss + temperature_loss
|
||||
|
||||
return {
|
||||
"critics_loss": critics_loss.item(),
|
||||
"actor_loss": actor_loss.item(),
|
||||
"temperature_loss": temperature_loss.item(),
|
||||
"temperature": temperature.item(),
|
||||
"entropy": entropy.item(),
|
||||
"loss": loss,
|
||||
}
|
||||
|
||||
"critics_loss": critics_loss.item(),
|
||||
"actor_loss": actor_loss.item(),
|
||||
"temperature_loss": temperature_loss.item(),
|
||||
"temperature": temperature.item(),
|
||||
"entropy": entropy.item(),
|
||||
"loss": loss,
|
||||
}
|
||||
|
||||
def update(self):
|
||||
self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
|
||||
# TODO: implement UTD update
|
||||
# First update only critics for utd_ratio-1 times
|
||||
# for critic_step in range(self.config.utd_ratio - 1):
|
||||
# only update critic and critic target
|
||||
#for critic_step in range(self.config.utd_ratio - 1):
|
||||
# only update critic and critic target
|
||||
# Then update critic, critic target, actor and temperature
|
||||
|
||||
# for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
|
||||
# target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
|
||||
|
||||
|
||||
"""Update target networks with exponential moving average"""
|
||||
with torch.no_grad():
|
||||
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
|
||||
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
|
||||
target_param.data.copy_(
|
||||
target_param.data * self.config.critic_target_update_weight +
|
||||
param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
|
@ -236,46 +258,52 @@ class MLP(nn.Module):
|
|||
super().__init__()
|
||||
self.activate_final = activate_final
|
||||
layers = []
|
||||
|
||||
for i, size in enumerate(hidden_dims):
|
||||
layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))
|
||||
|
||||
|
||||
# First layer uses input_dim
|
||||
layers.append(nn.Linear(input_dim, hidden_dims[0]))
|
||||
|
||||
# Add activation after first layer
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[0]))
|
||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||
|
||||
# Rest of the layers
|
||||
for i in range(1, len(hidden_dims)):
|
||||
layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i]))
|
||||
|
||||
if i + 1 < len(hidden_dims) or activate_final:
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(size))
|
||||
layers.append(
|
||||
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
layers.append(nn.LayerNorm(hidden_dims[i]))
|
||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
|
||||
# in training mode or not. TODO: find better way to do this
|
||||
self.train(train)
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x)
|
||||
|
||||
|
||||
|
||||
|
||||
class Critic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
device: str = "cuda",
|
||||
device: str = "cuda"
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.init_final = init_final
|
||||
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
|
||||
# Output layer
|
||||
if init_final is not None:
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
|
@ -284,17 +312,22 @@ class Critic(nn.Module):
|
|||
else:
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False) -> torch.Tensor:
|
||||
self.train(train)
|
||||
|
||||
observations = observations.to(self.device)
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Move each tensor in observations to device
|
||||
observations = {
|
||||
k: v.to(self.device) for k, v in observations.items()
|
||||
}
|
||||
actions = actions.to(self.device)
|
||||
|
||||
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
x = self.network(inputs)
|
||||
value = self.output_layer(x)
|
||||
|
@ -312,7 +345,7 @@ class Policy(nn.Module):
|
|||
fixed_std: Optional[torch.Tensor] = None,
|
||||
init_final: Optional[float] = None,
|
||||
use_tanh_squash: bool = False,
|
||||
device: str = "cuda",
|
||||
device: str = "cuda"
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
|
@ -323,13 +356,13 @@ class Policy(nn.Module):
|
|||
self.log_std_max = log_std_max
|
||||
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
||||
self.use_tanh_squash = use_tanh_squash
|
||||
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
|
||||
# Mean layer
|
||||
self.mean_layer = nn.Linear(out_features, action_dim)
|
||||
if init_final is not None:
|
||||
|
@ -337,7 +370,7 @@ class Policy(nn.Module):
|
|||
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.mean_layer.weight)
|
||||
|
||||
|
||||
# Standard deviation layer or parameter
|
||||
if fixed_std is None:
|
||||
self.std_layer = nn.Linear(out_features, action_dim)
|
||||
|
@ -346,20 +379,21 @@ class Policy(nn.Module):
|
|||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.std_layer.weight)
|
||||
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = observations if self.encoder is not None else self.encoder(observations)
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
means = self.mean_layer(outputs)
|
||||
|
||||
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
|
@ -367,25 +401,25 @@ class Policy(nn.Module):
|
|||
log_std = torch.tanh(log_std)
|
||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
||||
else:
|
||||
stds = self.fixed_std.expand_as(means)
|
||||
|
||||
log_std = self.fixed_std.expand_as(means)
|
||||
|
||||
# uses tahn activation function to squash the action to be in the range of [-1, 1]
|
||||
normal = torch.distributions.Normal(means, stds)
|
||||
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
|
||||
normal = torch.distributions.Normal(means, torch.exp(log_std))
|
||||
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
|
||||
log_probs = normal.log_prob(x_t)
|
||||
if self.use_tanh_squash:
|
||||
actions = torch.tanh(x_t)
|
||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
|
||||
log_probs = log_probs.sum(-1) # sum over action dim
|
||||
log_probs = log_probs.sum(-1) # sum over action dim
|
||||
|
||||
return actions, log_probs
|
||||
|
||||
|
||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
"""Get encoded features from observations"""
|
||||
observations = observations.to(self.device)
|
||||
if self.encoder is not None:
|
||||
with torch.no_grad():
|
||||
return self.encoder(observations, train=False)
|
||||
with torch.inference_mode():
|
||||
return self.encoder(observations)
|
||||
return observations
|
||||
|
||||
|
||||
|
@ -459,43 +493,56 @@ class SACObservationEncoder(nn.Module):
|
|||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
# TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
"""Returns the dimension of the encoder output"""
|
||||
return self.config.latent_dim
|
||||
|
||||
|
||||
class LagrangeMultiplier(nn.Module):
|
||||
def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
|
||||
def __init__(
|
||||
self,
|
||||
init_value: float = 1.0,
|
||||
constraint_shape: Sequence[int] = (),
|
||||
device: str = "cuda"
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
|
||||
|
||||
|
||||
# Initialize the Lagrange multiplier as a parameter
|
||||
self.lagrange = nn.Parameter(
|
||||
torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
|
||||
)
|
||||
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
# Get the multiplier value based on parameterization
|
||||
def forward(
|
||||
self,
|
||||
lhs: Optional[torch.Tensor | float | int] = None,
|
||||
rhs: Optional[torch.Tensor | float | int] = None
|
||||
) -> torch.Tensor:
|
||||
# Get the multiplier value based on parameterization
|
||||
multiplier = torch.nn.functional.softplus(self.lagrange)
|
||||
|
||||
|
||||
# Return the raw multiplier if no constraint values provided
|
||||
if lhs is None:
|
||||
return multiplier
|
||||
|
||||
# Move inputs to device
|
||||
lhs = lhs.to(self.device)
|
||||
|
||||
# Convert inputs to tensors and move to device
|
||||
lhs = torch.tensor(lhs, device=self.device) if not isinstance(lhs, torch.Tensor) else lhs.to(self.device)
|
||||
if rhs is not None:
|
||||
rhs = rhs.to(self.device)
|
||||
|
||||
# Use the multiplier to compute the Lagrange penalty
|
||||
if rhs is None:
|
||||
rhs = torch.tensor(rhs, device=self.device) if not isinstance(rhs, torch.Tensor) else rhs.to(self.device)
|
||||
else:
|
||||
rhs = torch.zeros_like(lhs, device=self.device)
|
||||
|
||||
|
||||
diff = lhs - rhs
|
||||
|
||||
|
||||
assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
|
||||
|
||||
|
||||
return multiplier * diff
|
||||
|
||||
|
||||
|
@ -508,7 +555,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
|
|||
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
||||
return nn.ModuleList(critics).to(device)
|
||||
|
||||
|
||||
# borrowed from tdmpc
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
@ -516,7 +562,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
|
|||
Args:
|
||||
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
||||
(B, *), where * is any number of dimensions.
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
||||
can be more than 1 dimensions, generally different from *.
|
||||
Returns:
|
||||
A return value from the callable reshaped to (**, *).
|
||||
|
@ -526,4 +572,4 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
|
|||
start_dims = image_tensor.shape[:-3]
|
||||
inp = torch.flatten(image_tensor, end_dim=-4)
|
||||
flat_out = fn(inp)
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
Loading…
Reference in New Issue