split encoder for critic and actor

Michel Aractingi 2024-12-29 23:59:39 +00:00
parent bae3b02928
commit ee306e2f9b
2 changed files with 177 additions and 131 deletions

```diff
@@ -48,7 +48,7 @@ class SACConfig:
     critic_target_update_weight = 0.005
     utd_ratio = 2
     state_encoder_hidden_dim = 256
-    latent_dim = 50
+    latent_dim = 128
     target_entropy = None
     critic_network_kwargs = {
         "hidden_dims": [256, 256],
```
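The config side of this commit doubles the shared latent size from 50 to 128. With the `input_dim` plumbing added in the modeling file, each critic MLP now consumes the encoder output concatenated with the action, while the actor MLP consumes the encoder output alone. A quick sketch of the dimension bookkeeping, assuming this config and a hypothetical 6-dimensional action space:

```python
# Dimension bookkeeping implied by this commit (action_dim is a made-up example).
latent_dim = 128   # SACConfig.latent_dim after this change
action_dim = 6     # hypothetical, e.g. a 6-DoF arm

critic_input_dim = latent_dim + action_dim  # encoder features + action -> Q-value
actor_input_dim = latent_dim                # encoder features only -> action distribution
print(critic_input_dim, actor_input_dim)    # 134 128
```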
```diff
@@ -63,25 +63,35 @@ class SACPolicy(
         self.unnormalize_outputs = Unnormalize(
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )
-        encoder = SACObservationEncoder(config)
+        encoder_critic = SACObservationEncoder(config)
+        encoder_actor = SACObservationEncoder(config)

         # Define networks
         critic_nets = []
         for _ in range(config.num_critics):
-            critic_net = Critic(encoder=encoder, network=MLP(**config.critic_network_kwargs))
+            critic_net = Critic(
+                encoder=encoder_critic,
+                network=MLP(
+                    input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
+                    **config.critic_network_kwargs
+                )
+            )
             critic_nets.append(critic_net)

         self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
         self.critic_target = deepcopy(self.critic_ensemble)

         self.actor = Policy(
-            encoder=encoder,
-            network=MLP(**config.actor_network_kwargs),
+            encoder=encoder_actor,
+            network=MLP(
+                input_dim=encoder_actor.output_dim,
+                **config.actor_network_kwargs
+            ),
             action_dim=config.output_shapes["action"][0],
-            **config.policy_kwargs,
+            **config.policy_kwargs
         )

         if config.target_entropy is None:
             config.target_entropy = -np.prod(config.output_shapes["action"][0])  # (-dim(A))
         self.temperature = LagrangeMultiplier(init_value=config.temperature_init)

     def reset(self):
         """
```
```diff
@@ -104,15 +114,31 @@ class SACPolicy(
             actions, _ = self.actor(batch)
             actions = self.unnormalize_outputs({"action": actions})["action"]
         return actions

+    def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
+        """Forward pass through a critic network ensemble
+
+        Args:
+            observations: Dictionary of observations
+            actions: Action tensor
+            use_target: If True, use target critics, otherwise use ensemble critics
+
+        Returns:
+            Tensor of Q-values from all critics
+        """
+        critics = self.critic_target if use_target else self.critic_ensemble
+        q_values = torch.stack([critic(observations, actions) for critic in critics])
+        return q_values
+
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
         """Run the batch through the model and compute the loss.

         Returns a dictionary with loss as a tensor, and other information as native floats.
         """
         batch = self.normalize_inputs(batch)
         # batch shape is (b, 2, ...) where index 1 returns the current observation and
-        # the next observation for caluculating the right td index.
+        # the next observation for calculating the right td index.
         actions = batch["action"][:, 0]
         rewards = batch["next.reward"][:, 0]
         observations = {}
```
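`critic_forward` stacks the per-critic outputs into a single `[num_critics, batch_size]` tensor, which is what lets the loss code below take an ensemble-wise minimum for the TD target (the clipped double-Q trick). A self-contained sketch of that reduction with made-up numbers:

```python
import torch

num_critics, batch_size = 2, 4
# Pretend this came from critic_forward(next_observations, action_preds, use_target=True).
q_targets = torch.randn(num_critics, batch_size)

min_q, _ = q_targets.min(dim=0)          # clipped double-Q: per-sample min over the ensemble
rewards, discount = torch.randn(batch_size), 0.99
td_target = rewards + discount * min_q   # shape: [batch_size]
print(td_target.shape)                   # torch.Size([4])
```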
```diff
@@ -121,113 +147,109 @@ class SACPolicy(
             if k.startswith("observation."):
                 observations[k] = batch[k][:, 0]
                 next_observations[k] = batch[k][:, 1]

         # perform image augmentation

         # reward bias from HIL-SERL code base
         # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch

         # calculate critics loss
         # 1- compute actions from policy
         action_preds, log_probs = self.actor(next_observations)
         # 2- compute q targets
-        q_targets = self.target_qs(next_observations, action_preds)
+        q_targets = self.critic_forward(next_observations, action_preds, use_target=True)

         # subsample critics to prevent overfitting if use high UTD (update to date)
         if self.config.num_subsample_critics is not None:
             indices = torch.randperm(self.config.num_critics)
-            indices = indices[: self.config.num_subsample_critics]
+            indices = indices[:self.config.num_subsample_critics]
             q_targets = q_targets[indices]

         # critics subsample size
-        min_q = q_targets.min(dim=0)
+        min_q, _ = q_targets.min(dim=0)  # Get values from min operation

         # compute td target
-        td_target = (
-            rewards + self.config.discount * min_q
-        )  # + self.config.discount * self.temperature() * log_probs # add entropy term
+        td_target = rewards + self.config.discount * min_q  # + self.config.discount * self.temperature() * log_probs # add entropy term

         # 3- compute predicted qs
-        q_preds = self.critic_ensemble(observations, actions)
+        q_preds = self.critic_forward(observations, actions, use_target=False)

         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        # critics_loss = (
-        #     (
-        #         F.mse_loss(
-        #             q_preds,
-        #             einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
-        #             reduction="none",
-        #         ).sum(0)  # sum over ensemble
-        #         # `q_preds_ensemble` depends on the first observation and the actions.
-        #         * ~batch["observation.state_is_pad"][0]
-        #         * ~batch["action_is_pad"]
-        #         # q_targets depends on the reward and the next observations.
-        #         * ~batch["next.reward_is_pad"]
-        #         * ~batch["observation.state_is_pad"][1:]
-        #     )
-        #     .sum(0)
-        #     .mean()
-        # )
-        # 4- Calculate loss
-        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        critics_loss = (
-            F.mse_loss(
-                q_preds,  # shape: [num_critics, batch_size]
-                einops.repeat(
-                    td_target, "b -> e b", e=q_preds.shape[0]
-                ),  # expand td_target to match q_preds shape
-                reduction="none",
-            )
-            .sum(0)
-            .mean()
-        )
+        critics_loss = F.mse_loss(
+            q_preds,  # shape: [num_critics, batch_size]
+            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds shape
+            reduction="none"
+        ).sum(0).mean()
+        # critics_loss = (
+        #     F.mse_loss(
+        #         q_preds,
+        #         einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
+        #         reduction="none",
+        #     ).sum(0)  # sum over ensemble
+        #     # `q_preds_ensemble` depends on the first observation and the actions.
+        #     * ~batch["observation.state_is_pad"][0]
+        #     * ~batch["action_is_pad"]
+        #     # q_targets depends on the reward and the next observations.
+        #     * ~batch["next.reward_is_pad"]
+        #     * ~batch["observation.state_is_pad"][1:]
+        # ).sum(0).mean()

         # calculate actors loss
         # 1- temperature
         temperature = self.temperature()
         # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
         actions, log_probs = self.actor(observations)
         # 3- get q-value predictions
-        with torch.no_grad():
-            q_preds = self.critic_ensemble(observations, actions, return_type="mean")
+        with torch.inference_mode():
+            q_preds = self.critic_forward(observations, actions, use_target=False)
         actor_loss = (
             -(q_preds - temperature * log_probs).mean()
-            * ~batch["observation.state_is_pad"][0]
-            * ~batch["action_is_pad"]
+            # * ~batch["observation.state_is_pad"][0]
+            # * ~batch["action_is_pad"]
         ).mean()

         # calculate temperature loss
         # 1- calculate entropy
         entropy = -log_probs.mean()
-        temperature_loss = self.temp(lhs=entropy, rhs=self.config.target_entropy)
+        temperature_loss = self.temperature(
+            lhs=entropy,
+            rhs=self.config.target_entropy
+        )

         loss = critics_loss + actor_loss + temperature_loss

         return {
             "critics_loss": critics_loss.item(),
             "actor_loss": actor_loss.item(),
             "temperature_loss": temperature_loss.item(),
             "temperature": temperature.item(),
             "entropy": entropy.item(),
             "loss": loss,
         }

     def update(self):
-        self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
         # TODO: implement UTD update
         # First update only critics for utd_ratio-1 times
-        # for critic_step in range(self.config.utd_ratio - 1):
+        #for critic_step in range(self.config.utd_ratio - 1):
         # only update critic and critic target
         # Then update critic, critic target, actor and temperature
-        # for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
-        #     target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
+        """Update target networks with exponential moving average"""
+        with torch.no_grad():
+            for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
+                for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
+                    target_param.data.copy_(
+                        target_param.data * self.config.critic_target_update_weight +
+                        param.data * (1.0 - self.config.critic_target_update_weight)
+                    )


 class MLP(nn.Module):
     def __init__(
         self,
+        input_dim: int,
         hidden_dims: list[int],
         activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
         activate_final: bool = False,
```
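The hand-rolled loop in `update()` replaces the earlier one-line `lerp_`. Note the convention as written: `critic_target_update_weight` multiplies the *target* parameters, so with the config value of 0.005 the target copies 99.5% of the online critic on every call. The conventional Polyak update keeps most of the target instead; a standalone sketch of that form for comparison (not this repo's API):

```python
import torch
import torch.nn as nn

def soft_update(target: nn.Module, online: nn.Module, tau: float = 0.005) -> None:
    """Polyak averaging: target <- (1 - tau) * target + tau * online."""
    with torch.no_grad():
        for tp, p in zip(target.parameters(), online.parameters()):
            tp.data.mul_(1.0 - tau).add_(p.data, alpha=tau)

online, target = nn.Linear(4, 1), nn.Linear(4, 1)
target.load_state_dict(online.state_dict())  # start in sync
soft_update(target, online, tau=0.005)       # target drifts slowly toward online
```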
```diff
@@ -236,46 +258,52 @@ class MLP(nn.Module):
         super().__init__()
         self.activate_final = activate_final
         layers = []

-        for i, size in enumerate(hidden_dims):
-            layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))
+        # First layer uses input_dim
+        layers.append(nn.Linear(input_dim, hidden_dims[0]))
+
+        # Add activation after first layer
+        if dropout_rate is not None and dropout_rate > 0:
+            layers.append(nn.Dropout(p=dropout_rate))
+        layers.append(nn.LayerNorm(hidden_dims[0]))
+        layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
+
+        # Rest of the layers
+        for i in range(1, len(hidden_dims)):
+            layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i]))

             if i + 1 < len(hidden_dims) or activate_final:
                 if dropout_rate is not None and dropout_rate > 0:
                     layers.append(nn.Dropout(p=dropout_rate))
-                layers.append(nn.LayerNorm(size))
-                layers.append(
-                    activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
-                )
+                layers.append(nn.LayerNorm(hidden_dims[i]))
+                layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())

         self.net = nn.Sequential(*layers)

-    def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
-        # in training mode or not. TODO: find better way to do this
-        self.train(train)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.net(x)


 class Critic(nn.Module):
     def __init__(
         self,
         encoder: Optional[nn.Module],
         network: nn.Module,
         init_final: Optional[float] = None,
-        device: str = "cuda",
+        device: str = "cuda"
     ):
         super().__init__()
         self.device = torch.device(device)
         self.encoder = encoder
         self.network = network
         self.init_final = init_final

         # Find the last Linear layer's output dimension
         for layer in reversed(network.net):
             if isinstance(layer, nn.Linear):
                 out_features = layer.out_features
                 break

         # Output layer
         if init_final is not None:
             self.output_layer = nn.Linear(out_features, 1)
```
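The rewritten builder fixes the first layer: the old loop wired `hidden_dims[0] -> hidden_dims[0]` regardless of the actual feature size, so the explicit `input_dim` is what makes the encoder/action concatenation possible. A condensed restatement of the layering (Linear -> LayerNorm -> activation, dropout omitted for brevity), kept here only as a shape check:

```python
import torch
import torch.nn as nn

def build_mlp(input_dim: int, hidden_dims: list[int], activate_final: bool = False) -> nn.Sequential:
    # First layer consumes the explicit input_dim (the old code wired hidden_dims[0] here).
    layers: list[nn.Module] = [nn.Linear(input_dim, hidden_dims[0]), nn.LayerNorm(hidden_dims[0]), nn.SiLU()]
    for i in range(1, len(hidden_dims)):
        layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
        if i + 1 < len(hidden_dims) or activate_final:
            layers += [nn.LayerNorm(hidden_dims[i]), nn.SiLU()]
    return nn.Sequential(*layers)

net = build_mlp(input_dim=134, hidden_dims=[256, 256])  # e.g. latent 128 + action 6
print(net(torch.randn(4, 134)).shape)                   # torch.Size([4, 256])
```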
```diff
@@ -284,17 +312,22 @@ class Critic(nn.Module):
         else:
             self.output_layer = nn.Linear(out_features, 1)
             orthogonal_init()(self.output_layer.weight)

         self.to(self.device)

-    def forward(self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False) -> torch.Tensor:
-        self.train(train)
-
-        observations = observations.to(self.device)
+    def forward(
+        self,
+        observations: dict[str, torch.Tensor],
+        actions: torch.Tensor,
+    ) -> torch.Tensor:
+        # Move each tensor in observations to device
+        observations = {
+            k: v.to(self.device) for k, v in observations.items()
+        }
         actions = actions.to(self.device)

         obs_enc = observations if self.encoder is None else self.encoder(observations)

         inputs = torch.cat([obs_enc, actions], dim=-1)
         x = self.network(inputs)
         value = self.output_layer(x)
```
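The critic's `forward` now accepts the raw observation dict, moves each tensor to the device, encodes, and concatenates the action before the Q-head. Schematically, under assumed shapes:

```python
import torch

# Assumed: the encoder maps the observation dict to [batch, latent_dim].
obs = {"observation.state": torch.randn(4, 10)}
obs = {k: v.to("cpu") for k, v in obs.items()}   # per-key device move, as in the new forward

obs_enc = torch.randn(4, 128)                    # stand-in for encoder(obs)
actions = torch.randn(4, 6)
q_input = torch.cat([obs_enc, actions], dim=-1)  # [4, 134], matching the critic MLP input_dim
print(q_input.shape)
```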
```diff
@@ -312,7 +345,7 @@ class Policy(nn.Module):
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
         use_tanh_squash: bool = False,
-        device: str = "cuda",
+        device: str = "cuda"
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -323,13 +356,13 @@ class Policy(nn.Module):
         self.log_std_max = log_std_max
         self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
         self.use_tanh_squash = use_tanh_squash

         # Find the last Linear layer's output dimension
         for layer in reversed(network.net):
             if isinstance(layer, nn.Linear):
                 out_features = layer.out_features
                 break

         # Mean layer
         self.mean_layer = nn.Linear(out_features, action_dim)
         if init_final is not None:
@@ -337,7 +370,7 @@ class Policy(nn.Module):
             nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
         else:
             orthogonal_init()(self.mean_layer.weight)

         # Standard deviation layer or parameter
         if fixed_std is None:
             self.std_layer = nn.Linear(out_features, action_dim)
@@ -346,20 +379,21 @@ class Policy(nn.Module):
             nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
         else:
             orthogonal_init()(self.std_layer.weight)

         self.to(self.device)

     def forward(
         self,
         observations: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Encode observations if encoder exists
-        obs_enc = observations if self.encoder is not None else self.encoder(observations)
+        obs_enc = observations if self.encoder is None else self.encoder(observations)

         # Get network outputs
         outputs = self.network(obs_enc)
         means = self.mean_layer(outputs)

         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
@@ -367,25 +401,25 @@ class Policy(nn.Module):
             log_std = torch.tanh(log_std)
             log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
         else:
-            stds = self.fixed_std.expand_as(means)
+            log_std = self.fixed_std.expand_as(means)

         # uses tanh activation function to squash the action to be in the range of [-1, 1]
-        normal = torch.distributions.Normal(means, stds)
+        normal = torch.distributions.Normal(means, torch.exp(log_std))
         x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
         log_probs = normal.log_prob(x_t)

         if self.use_tanh_squash:
             actions = torch.tanh(x_t)
             log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
         log_probs = log_probs.sum(-1)  # sum over action dim

         return actions, log_probs

     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
         """Get encoded features from observations"""
         observations = observations.to(self.device)
         if self.encoder is not None:
-            with torch.no_grad():
-                return self.encoder(observations, train=False)
+            with torch.inference_mode():
+                return self.encoder(observations)
         return observations
```
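The squashing branch implements the standard change-of-variables correction: for `a = tanh(u)` with `u ~ Normal(mean, std)`, `log pi(a) = log N(u) - sum_j log(1 - tanh(u)_j^2 + eps)`. A standalone re-run of the same computation with toy parameters:

```python
import torch

means = torch.zeros(4, 6)
log_std = torch.full((4, 6), -0.5)
normal = torch.distributions.Normal(means, torch.exp(log_std))

x_t = normal.rsample()                             # reparameterization trick: mean + std * N(0, 1)
actions = torch.tanh(x_t)                          # squash into [-1, 1]

log_probs = normal.log_prob(x_t)                   # per-dimension Gaussian log-density
log_probs -= torch.log(1 - actions.pow(2) + 1e-6)  # tanh Jacobian correction
log_probs = log_probs.sum(-1)                      # sum over action dimensions
print(actions.shape, log_probs.shape)              # torch.Size([4, 6]) torch.Size([4])
```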
```diff
@@ -459,43 +493,56 @@ class SACObservationEncoder(nn.Module):
             feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
         if "observation.state" in self.config.input_shapes:
             feat.append(self.state_enc_layers(obs_dict["observation.state"]))
+        # TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
         return torch.stack(feat, dim=0).mean(0)

+    @property
+    def output_dim(self) -> int:
+        """Returns the dimension of the encoder output"""
+        return self.config.latent_dim


 class LagrangeMultiplier(nn.Module):
-    def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
+    def __init__(
+        self,
+        init_value: float = 1.0,
+        constraint_shape: Sequence[int] = (),
+        device: str = "cuda"
+    ):
         super().__init__()
         self.device = torch.device(device)
         init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)

         # Initialize the Lagrange multiplier as a parameter
         self.lagrange = nn.Parameter(
             torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
         )
         self.to(self.device)

-    def forward(self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor] = None) -> torch.Tensor:
-        # Get the multiplier value based on parameterization
+    def forward(
+        self,
+        lhs: Optional[torch.Tensor | float | int] = None,
+        rhs: Optional[torch.Tensor | float | int] = None
+    ) -> torch.Tensor:
+        # Get the multiplier value based on parameterization
         multiplier = torch.nn.functional.softplus(self.lagrange)

         # Return the raw multiplier if no constraint values provided
         if lhs is None:
             return multiplier

-        # Move inputs to device
-        lhs = lhs.to(self.device)
-        if rhs is not None:
-            rhs = rhs.to(self.device)
-        else:
+        # Convert inputs to tensors and move to device
+        lhs = torch.tensor(lhs, device=self.device) if not isinstance(lhs, torch.Tensor) else lhs.to(self.device)
+        if rhs is not None:
+            rhs = torch.tensor(rhs, device=self.device) if not isinstance(rhs, torch.Tensor) else rhs.to(self.device)
+
+        # Use the multiplier to compute the Lagrange penalty
+        if rhs is None:
             rhs = torch.zeros_like(lhs, device=self.device)
         diff = lhs - rhs

         assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"

         return multiplier * diff
```
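`LagrangeMultiplier` stores the multiplier pre-softplus (note the inverse-softplus transform of `init_value` in `__init__`), so `softplus(lagrange)` always yields a strictly positive temperature; called with `lhs=entropy, rhs=target_entropy`, it returns the dual temperature loss `alpha * (H - H_target)`. A compact standalone sketch of the same parameterization:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

init_value = 1.0
# Inverse softplus, so that softplus(lagrange) == init_value at initialization.
raw = torch.log(torch.exp(torch.tensor(init_value)) - 1)
lagrange = nn.Parameter(raw)

temperature = F.softplus(lagrange)                           # always > 0
entropy, target_entropy = torch.tensor(2.3), torch.tensor(-6.0)
temperature_loss = temperature * (entropy - target_entropy)
temperature_loss.backward()                                  # gradient reaches `lagrange`
print(round(float(temperature), 4))                          # 1.0
```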
```diff
@@ -508,7 +555,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
     assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
     return nn.ModuleList(critics).to(device)

-
 # borrowed from tdmpc
 def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
     """Helper to temporarily flatten extra dims at the start of the image tensor.
@@ -516,7 +562,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
     Args:
         fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
             (B, *), where * is any number of dimensions.
         image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
             can be more than 1 dimensions, generally different from *.
     Returns:
         A return value from the callable reshaped to (**, *).
@@ -526,4 +572,4 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
     start_dims = image_tensor.shape[:-3]
     inp = torch.flatten(image_tensor, end_dim=-4)
     flat_out = fn(inp)
     return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
```
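`flatten_forward_unflatten` folds any number of leading dimensions into one batch dimension, applies the callable, and restores them, e.g. pushing a `(B, T, C, H, W)` sequence through a CNN that only accepts `(B, C, H, W)`. Re-stated here only to make the shape behavior checkable:

```python
import torch

def flatten_forward_unflatten(fn, image_tensor):
    start_dims = image_tensor.shape[:-3]
    inp = torch.flatten(image_tensor, end_dim=-4)  # merge leading dims: (B*T, C, H, W)
    flat_out = fn(inp)
    return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))

imgs = torch.randn(2, 5, 3, 8, 8)  # (B, T, C, H, W)
out = flatten_forward_unflatten(lambda x: x.mean(dim=(-1, -2)), imgs)
print(out.shape)                   # torch.Size([2, 5, 3])
```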