fix caching
This commit is contained in:
parent
037ecae9e0
commit
7741526ce4
|
@ -187,8 +187,8 @@ class SACPolicy(
|
|||
"""Select action for inference/evaluation"""
|
||||
# We cached the encoder output to avoid recomputing it
|
||||
observations_features = None
|
||||
if self.shared_encoder and self.actor.encoder is not None:
|
||||
observations_features = self.actor.encoder(batch)
|
||||
if self.shared_encoder:
|
||||
observations_features = self.actor.encoder.get_image_features(batch)
|
||||
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
@ -484,6 +484,109 @@ class SACPolicy(
|
|||
return actor_loss
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
|
||||
"""
|
||||
Creates encoders for pixel and/or state modalities.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.input_normalization = input_normalizer
|
||||
self.has_pretrained_vision_encoder = False
|
||||
self.parameters_to_optimize = []
|
||||
|
||||
self.aggregation_size: int = 0
|
||||
if any("observation.image" in key for key in config.input_features):
|
||||
self.camera_number = config.camera_number
|
||||
|
||||
if self.config.vision_encoder_name is not None:
|
||||
self.image_enc_layers = PretrainedImageEncoder(config)
|
||||
self.has_pretrained_vision_encoder = True
|
||||
else:
|
||||
self.image_enc_layers = DefaultImageEncoder(config)
|
||||
|
||||
self.aggregation_size += config.latent_dim * self.camera_number
|
||||
|
||||
if config.freeze_vision_encoder:
|
||||
freeze_image_encoder(self.image_enc_layers)
|
||||
else:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
|
||||
|
||||
if "observation.state" in config.input_features:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_features["observation.state"].shape[0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
|
||||
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
|
||||
|
||||
if "observation.environment_state" in config.input_features:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_features["observation.environment_state"].shape[0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
|
||||
|
||||
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
|
||||
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
|
||||
|
||||
def forward(
|
||||
self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
|
||||
) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
|
||||
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
|
||||
over all features.
|
||||
"""
|
||||
feat = []
|
||||
obs_dict = self.input_normalization(obs_dict)
|
||||
if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
|
||||
vision_encoder_cache = self.get_image_features(obs_dict)
|
||||
feat.append(vision_encoder_cache)
|
||||
|
||||
if vision_encoder_cache is not None:
|
||||
feat.append(vision_encoder_cache)
|
||||
|
||||
if "observation.environment_state" in self.config.input_features:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_features:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
|
||||
features = torch.cat(tensors=feat, dim=-1)
|
||||
features = self.aggregation_layer(features)
|
||||
|
||||
return features
|
||||
|
||||
def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor:
|
||||
# [N*B, C, H, W]
|
||||
if len(self.all_image_keys) > 0:
|
||||
# Batch all images along the batch dimension, then encode them.
|
||||
images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
|
||||
images_batched = self.image_enc_layers(images_batched)
|
||||
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
|
||||
embeddings_image = torch.cat(embeddings_chunks, dim=-1)
|
||||
return embeddings_image
|
||||
return None
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
"""Returns the dimension of the encoder output"""
|
||||
return self.config.latent_dim
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -606,7 +709,7 @@ class CriticEnsemble(nn.Module):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
encoder: SACObservationEncoder,
|
||||
ensemble: List[CriticHead],
|
||||
output_normalization: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
|
@ -638,11 +741,7 @@ class CriticEnsemble(nn.Module):
|
|||
actions = self.output_normalization(actions)["action"]
|
||||
actions = actions.to(device)
|
||||
|
||||
obs_enc = (
|
||||
observation_features
|
||||
if observation_features is not None
|
||||
else (observations if self.encoder is None else self.encoder(observations))
|
||||
)
|
||||
obs_enc = self.encoder(observations, observation_features)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
|
||||
|
@ -659,7 +758,7 @@ class CriticEnsemble(nn.Module):
|
|||
class GraspCritic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
encoder: nn.Module,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
output_dim: int = 3,
|
||||
|
@ -699,19 +798,14 @@ class GraspCritic(nn.Module):
|
|||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device by cloning first to avoid inplace operations
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = (
|
||||
observation_features.to(device)
|
||||
if observation_features is not None
|
||||
else (observations if self.encoder is None else self.encoder(observations))
|
||||
)
|
||||
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
|
||||
return self.output_layer(self.net(obs_enc))
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
encoder: SACObservationEncoder,
|
||||
network: nn.Module,
|
||||
action_dim: int,
|
||||
log_std_min: float = -5,
|
||||
|
@ -722,7 +816,7 @@ class Policy(nn.Module):
|
|||
encoder_is_shared: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder: SACObservationEncoder = encoder
|
||||
self.network = network
|
||||
self.action_dim = action_dim
|
||||
self.log_std_min = log_std_min
|
||||
|
@ -765,11 +859,7 @@ class Policy(nn.Module):
|
|||
observation_features: torch.Tensor | None = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = (
|
||||
observation_features
|
||||
if observation_features is not None
|
||||
else (observations if self.encoder is None else self.encoder(observations))
|
||||
)
|
||||
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
|
@ -813,96 +903,6 @@ class Policy(nn.Module):
|
|||
return observations
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
|
||||
"""
|
||||
Creates encoders for pixel and/or state modalities.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.input_normalization = input_normalizer
|
||||
self.has_pretrained_vision_encoder = False
|
||||
self.parameters_to_optimize = []
|
||||
|
||||
self.aggregation_size: int = 0
|
||||
if any("observation.image" in key for key in config.input_features):
|
||||
self.camera_number = config.camera_number
|
||||
|
||||
if self.config.vision_encoder_name is not None:
|
||||
self.image_enc_layers = PretrainedImageEncoder(config)
|
||||
self.has_pretrained_vision_encoder = True
|
||||
else:
|
||||
self.image_enc_layers = DefaultImageEncoder(config)
|
||||
|
||||
self.aggregation_size += config.latent_dim * self.camera_number
|
||||
|
||||
if config.freeze_vision_encoder:
|
||||
freeze_image_encoder(self.image_enc_layers)
|
||||
else:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
|
||||
|
||||
if "observation.state" in config.input_features:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_features["observation.state"].shape[0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
|
||||
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
|
||||
|
||||
if "observation.environment_state" in config.input_features:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_features["observation.environment_state"].shape[0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
|
||||
|
||||
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
|
||||
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
|
||||
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
|
||||
over all features.
|
||||
"""
|
||||
feat = []
|
||||
obs_dict = self.input_normalization(obs_dict)
|
||||
# Batch all images along the batch dimension, then encode them.
|
||||
if len(self.all_image_keys) > 0:
|
||||
images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
|
||||
images_batched = self.image_enc_layers(images_batched)
|
||||
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
|
||||
feat.extend(embeddings_chunks)
|
||||
|
||||
if "observation.environment_state" in self.config.input_features:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_features:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
|
||||
features = torch.cat(tensors=feat, dim=-1)
|
||||
features = self.aggregation_layer(features)
|
||||
|
||||
return features
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
"""Returns the dimension of the encoder output"""
|
||||
return self.config.latent_dim
|
||||
|
||||
|
||||
class DefaultImageEncoder(nn.Module):
|
||||
def __init__(self, config: SACConfig):
|
||||
super().__init__()
|
||||
|
|
|
@ -775,7 +775,9 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
|
|||
params=policy.actor.parameters_to_optimize,
|
||||
lr=cfg.policy.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
|
||||
)
|
||||
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizer_grasp_critic = torch.optim.Adam(
|
||||
|
@ -1024,12 +1026,8 @@ def get_observation_features(
|
|||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = (
|
||||
policy.actor.encoder(observations) if policy.actor.encoder is not None else None
|
||||
)
|
||||
next_observation_features = (
|
||||
policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
|
||||
)
|
||||
observation_features = policy.actor.encoder.get_image_features(observations)
|
||||
next_observation_features = policy.actor.encoder.get_image_features(next_observations)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
|
|
Loading…
Reference in New Issue