fix caching

This commit is contained in:
AdilZouitine 2025-04-04 14:29:38 +00:00
parent 037ecae9e0
commit 7741526ce4
2 changed files with 117 additions and 119 deletions

View File

@ -187,8 +187,8 @@ class SACPolicy(
"""Select action for inference/evaluation"""
# We cached the encoder output to avoid recomputing it
observations_features = None
if self.shared_encoder and self.actor.encoder is not None:
observations_features = self.actor.encoder(batch)
if self.shared_encoder:
observations_features = self.actor.encoder.get_image_features(batch)
actions, _, _ = self.actor(batch, observations_features)
actions = self.unnormalize_outputs({"action": actions})["action"]
@ -484,6 +484,109 @@ class SACPolicy(
return actor_loss
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
"""
Creates encoders for pixel and/or state modalities.
"""
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self.has_pretrained_vision_encoder = False
self.parameters_to_optimize = []
self.aggregation_size: int = 0
if any("observation.image" in key for key in config.input_features):
self.camera_number = config.camera_number
if self.config.vision_encoder_name is not None:
self.image_enc_layers = PretrainedImageEncoder(config)
self.has_pretrained_vision_encoder = True
else:
self.image_enc_layers = DefaultImageEncoder(config)
self.aggregation_size += config.latent_dim * self.camera_number
if config.freeze_vision_encoder:
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
if "observation.state" in config.input_features:
self.state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_features["observation.state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
if "observation.environment_state" in config.input_features:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_features["observation.environment_state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward(
self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
) -> Tensor:
"""Encode the image and/or state vector.
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
over all features.
"""
feat = []
obs_dict = self.input_normalization(obs_dict)
if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
vision_encoder_cache = self.get_image_features(obs_dict)
feat.append(vision_encoder_cache)
if vision_encoder_cache is not None:
feat.append(vision_encoder_cache)
if "observation.environment_state" in self.config.input_features:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_features:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
features = torch.cat(tensors=feat, dim=-1)
features = self.aggregation_layer(features)
return features
def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor:
# [N*B, C, H, W]
if len(self.all_image_keys) > 0:
# Batch all images along the batch dimension, then encode them.
images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
embeddings_image = torch.cat(embeddings_chunks, dim=-1)
return embeddings_image
return None
@property
def output_dim(self) -> int:
"""Returns the dimension of the encoder output"""
return self.config.latent_dim
class MLP(nn.Module):
def __init__(
self,
@ -606,7 +709,7 @@ class CriticEnsemble(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
encoder: SACObservationEncoder,
ensemble: List[CriticHead],
output_normalization: nn.Module,
init_final: Optional[float] = None,
@ -638,11 +741,7 @@ class CriticEnsemble(nn.Module):
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
obs_enc = self.encoder(observations, observation_features)
inputs = torch.cat([obs_enc, actions], dim=-1)
@ -659,7 +758,7 @@ class CriticEnsemble(nn.Module):
class GraspCritic(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
encoder: nn.Module,
input_dim: int,
hidden_dims: list[int],
output_dim: int = 3,
@ -699,19 +798,14 @@ class GraspCritic(nn.Module):
device = get_device_from_parameters(self)
# Move each tensor in observations to device by cloning first to avoid inplace operations
observations = {k: v.to(device) for k, v in observations.items()}
# Encode observations if encoder exists
obs_enc = (
observation_features.to(device)
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
return self.output_layer(self.net(obs_enc))
class Policy(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
encoder: SACObservationEncoder,
network: nn.Module,
action_dim: int,
log_std_min: float = -5,
@ -722,7 +816,7 @@ class Policy(nn.Module):
encoder_is_shared: bool = False,
):
super().__init__()
self.encoder = encoder
self.encoder: SACObservationEncoder = encoder
self.network = network
self.action_dim = action_dim
self.log_std_min = log_std_min
@ -765,11 +859,7 @@ class Policy(nn.Module):
observation_features: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
# Get network outputs
outputs = self.network(obs_enc)
@ -813,96 +903,6 @@ class Policy(nn.Module):
return observations
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
"""
Creates encoders for pixel and/or state modalities.
"""
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self.has_pretrained_vision_encoder = False
self.parameters_to_optimize = []
self.aggregation_size: int = 0
if any("observation.image" in key for key in config.input_features):
self.camera_number = config.camera_number
if self.config.vision_encoder_name is not None:
self.image_enc_layers = PretrainedImageEncoder(config)
self.has_pretrained_vision_encoder = True
else:
self.image_enc_layers = DefaultImageEncoder(config)
self.aggregation_size += config.latent_dim * self.camera_number
if config.freeze_vision_encoder:
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
if "observation.state" in config.input_features:
self.state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_features["observation.state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
if "observation.environment_state" in config.input_features:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_features["observation.environment_state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
over all features.
"""
feat = []
obs_dict = self.input_normalization(obs_dict)
# Batch all images along the batch dimension, then encode them.
if len(self.all_image_keys) > 0:
images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
feat.extend(embeddings_chunks)
if "observation.environment_state" in self.config.input_features:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_features:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
features = torch.cat(tensors=feat, dim=-1)
features = self.aggregation_layer(features)
return features
@property
def output_dim(self) -> int:
"""Returns the dimension of the encoder output"""
return self.config.latent_dim
class DefaultImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
super().__init__()

View File

@ -775,7 +775,9 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
params=policy.actor.parameters_to_optimize,
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
)
if cfg.policy.num_discrete_actions is not None:
optimizer_grasp_critic = torch.optim.Adam(
@ -1024,12 +1026,8 @@ def get_observation_features(
return None, None
with torch.no_grad():
observation_features = (
policy.actor.encoder(observations) if policy.actor.encoder is not None else None
)
next_observation_features = (
policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
)
observation_features = policy.actor.encoder.get_image_features(observations)
next_observation_features = policy.actor.encoder.get_image_features(next_observations)
return observation_features, next_observation_features