fix caching

Cache only the vision encoder's image features (via SACObservationEncoder.get_image_features) instead of the full encoder output, and let SACObservationEncoder.forward accept the cached features through a vision_encoder_cache argument, so state and environment-state features are still encoded and fused even when a cache is supplied. Also build the critic optimizer from parameters_to_optimize so frozen vision weights are not handed to Adam.
commit 7741526ce4
parent 037ecae9e0
@@ -187,8 +187,8 @@ class SACPolicy(
         """Select action for inference/evaluation"""
         # We cached the encoder output to avoid recomputing it
         observations_features = None
-        if self.shared_encoder and self.actor.encoder is not None:
-            observations_features = self.actor.encoder(batch)
+        if self.shared_encoder:
+            observations_features = self.actor.encoder.get_image_features(batch)
 
         actions, _, _ = self.actor(batch, observations_features)
         actions = self.unnormalize_outputs({"action": actions})["action"]
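Note: after this hunk, the cache holds only the encoder's image features rather than its full fused output. A minimal sketch of the intended inference-time flow, assuming the attribute names used in this diff:

    import torch

    @torch.no_grad()
    def select_action_sketch(policy, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        # Run the vision backbone once; the actor (and critics, if the encoder
        # is shared) reuse the cached image features instead of re-encoding.
        observations_features = None
        if policy.shared_encoder:
            observations_features = policy.actor.encoder.get_image_features(batch)
        actions, _, _ = policy.actor(batch, observations_features)
        return policy.unnormalize_outputs({"action": actions})["action"]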
@@ -484,6 +484,109 @@ class SACPolicy(
         return actor_loss
 
 
+class SACObservationEncoder(nn.Module):
+    """Encode image and/or state vector observations."""
+
+    def __init__(self, config: SACConfig, input_normalizer: nn.Module):
+        """
+        Creates encoders for pixel and/or state modalities.
+        """
+        super().__init__()
+        self.config = config
+        self.input_normalization = input_normalizer
+        self.has_pretrained_vision_encoder = False
+        self.parameters_to_optimize = []
+
+        self.aggregation_size: int = 0
+        if any("observation.image" in key for key in config.input_features):
+            self.camera_number = config.camera_number
+
+            if self.config.vision_encoder_name is not None:
+                self.image_enc_layers = PretrainedImageEncoder(config)
+                self.has_pretrained_vision_encoder = True
+            else:
+                self.image_enc_layers = DefaultImageEncoder(config)
+
+            self.aggregation_size += config.latent_dim * self.camera_number
+
+            if config.freeze_vision_encoder:
+                freeze_image_encoder(self.image_enc_layers)
+            else:
+                self.parameters_to_optimize += list(self.image_enc_layers.parameters())
+        self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
+
+        if "observation.state" in config.input_features:
+            self.state_enc_layers = nn.Sequential(
+                nn.Linear(
+                    in_features=config.input_features["observation.state"].shape[0],
+                    out_features=config.latent_dim,
+                ),
+                nn.LayerNorm(normalized_shape=config.latent_dim),
+                nn.Tanh(),
+            )
+            self.aggregation_size += config.latent_dim
+
+            self.parameters_to_optimize += list(self.state_enc_layers.parameters())
+
+        if "observation.environment_state" in config.input_features:
+            self.env_state_enc_layers = nn.Sequential(
+                nn.Linear(
+                    in_features=config.input_features["observation.environment_state"].shape[0],
+                    out_features=config.latent_dim,
+                ),
+                nn.LayerNorm(normalized_shape=config.latent_dim),
+                nn.Tanh(),
+            )
+            self.aggregation_size += config.latent_dim
+            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
+
+        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
+        self.parameters_to_optimize += list(self.aggregation_layer.parameters())
+
+    def forward(
+        self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
+    ) -> Tensor:
+        """Encode the image and/or state vector.
+
+        Each modality is encoded into a feature vector of size (latent_dim,);
+        the per-modality features are concatenated and projected back to
+        latent_dim by the aggregation layer.
+        """
+        feat = []
+        obs_dict = self.input_normalization(obs_dict)
+        # Compute the image features once if no cache was provided.
+        if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
+            vision_encoder_cache = self.get_image_features(obs_dict)
+
+        if vision_encoder_cache is not None:
+            feat.append(vision_encoder_cache)
+
+        if "observation.environment_state" in self.config.input_features:
+            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
+        if "observation.state" in self.config.input_features:
+            feat.append(self.state_enc_layers(obs_dict["observation.state"]))
+
+        features = torch.cat(tensors=feat, dim=-1)
+        features = self.aggregation_layer(features)
+
+        return features
+
+    def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor | None:
+        # Batch all camera images along the batch dimension ([N*B, C, H, W]),
+        # encode them in one pass, then split per camera and concatenate the
+        # embeddings along the feature dimension.
+        if len(self.all_image_keys) > 0:
+            images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
+            images_batched = self.image_enc_layers(images_batched)
+            embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
+            embeddings_image = torch.cat(embeddings_chunks, dim=-1)
+            return embeddings_image
+        return None
+
+    @property
+    def output_dim(self) -> int:
+        """Returns the dimension of the encoder output"""
+        return self.config.latent_dim
+
+
 class MLP(nn.Module):
     def __init__(
         self,
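Note: with the new vision_encoder_cache argument, one image-feature tensor can be computed once and fed to every consumer of the encoder, while state and environment-state features are still encoded and fused on each call. A minimal sketch, assuming an encoder built from a config with image and state inputs:

    # encoder: SACObservationEncoder as defined above; obs: dict of tensors
    cache = encoder.get_image_features(obs)                # CNN runs once
    actor_emb = encoder(obs, vision_encoder_cache=cache)   # cache + state fused
    critic_emb = encoder(obs, vision_encoder_cache=cache)  # no CNN re-run
    assert actor_emb.shape[-1] == encoder.output_dim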
@@ -606,7 +709,7 @@ class CriticEnsemble(nn.Module):
 
     def __init__(
         self,
-        encoder: Optional[nn.Module],
+        encoder: SACObservationEncoder,
         ensemble: List[CriticHead],
         output_normalization: nn.Module,
         init_final: Optional[float] = None,
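Note: the encoder parameter on these modules is no longer Optional. The forward paths below call self.encoder unconditionally, so the old fallback of feeding raw observations straight into the heads is gone; a SACObservationEncoder (or a module with the same call signature) is now required.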
@@ -638,11 +741,7 @@ class CriticEnsemble(nn.Module):
         actions = self.output_normalization(actions)["action"]
         actions = actions.to(device)
 
-        obs_enc = (
-            observation_features
-            if observation_features is not None
-            else (observations if self.encoder is None else self.encoder(observations))
-        )
+        obs_enc = self.encoder(observations, observation_features)
 
         inputs = torch.cat([obs_enc, actions], dim=-1)
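Note: this replacement is the substance of the caching fix. The old ternary used the cached tensor directly as the observation embedding, which silently skipped the state and environment-state encoders and the aggregation layer whenever a cache was supplied. Passing the cache into the encoder keeps those modalities in the embedding while still avoiding the redundant CNN pass.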
@@ -659,7 +758,7 @@ class CriticEnsemble(nn.Module):
 class GraspCritic(nn.Module):
     def __init__(
         self,
-        encoder: Optional[nn.Module],
+        encoder: nn.Module,
         input_dim: int,
         hidden_dims: list[int],
         output_dim: int = 3,
@@ -699,19 +798,14 @@ class GraspCritic(nn.Module):
         device = get_device_from_parameters(self)
         # Move each tensor in observations to device by cloning first to avoid inplace operations
         observations = {k: v.to(device) for k, v in observations.items()}
-        # Encode observations if encoder exists
-        obs_enc = (
-            observation_features.to(device)
-            if observation_features is not None
-            else (observations if self.encoder is None else self.encoder(observations))
-        )
+        obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
         return self.output_layer(self.net(obs_enc))
 
 
 class Policy(nn.Module):
     def __init__(
         self,
-        encoder: Optional[nn.Module],
+        encoder: SACObservationEncoder,
         network: nn.Module,
         action_dim: int,
         log_std_min: float = -5,
@@ -722,7 +816,7 @@ class Policy(nn.Module):
         encoder_is_shared: bool = False,
     ):
         super().__init__()
-        self.encoder = encoder
+        self.encoder: SACObservationEncoder = encoder
         self.network = network
         self.action_dim = action_dim
         self.log_std_min = log_std_min
@@ -765,11 +859,7 @@ class Policy(nn.Module):
         observation_features: torch.Tensor | None = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Encode observations if encoder exists
-        obs_enc = (
-            observation_features
-            if observation_features is not None
-            else (observations if self.encoder is None else self.encoder(observations))
-        )
+        obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
 
         # Get network outputs
         outputs = self.network(obs_enc)
@@ -813,96 +903,6 @@ class Policy(nn.Module):
         return observations
 
 
-class SACObservationEncoder(nn.Module):
-    """Encode image and/or state vector observations."""
-
-    def __init__(self, config: SACConfig, input_normalizer: nn.Module):
-        """
-        Creates encoders for pixel and/or state modalities.
-        """
-        super().__init__()
-        self.config = config
-        self.input_normalization = input_normalizer
-        self.has_pretrained_vision_encoder = False
-        self.parameters_to_optimize = []
-
-        self.aggregation_size: int = 0
-        if any("observation.image" in key for key in config.input_features):
-            self.camera_number = config.camera_number
-
-            if self.config.vision_encoder_name is not None:
-                self.image_enc_layers = PretrainedImageEncoder(config)
-                self.has_pretrained_vision_encoder = True
-            else:
-                self.image_enc_layers = DefaultImageEncoder(config)
-
-            self.aggregation_size += config.latent_dim * self.camera_number
-
-            if config.freeze_vision_encoder:
-                freeze_image_encoder(self.image_enc_layers)
-            else:
-                self.parameters_to_optimize += list(self.image_enc_layers.parameters())
-        self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
-
-        if "observation.state" in config.input_features:
-            self.state_enc_layers = nn.Sequential(
-                nn.Linear(
-                    in_features=config.input_features["observation.state"].shape[0],
-                    out_features=config.latent_dim,
-                ),
-                nn.LayerNorm(normalized_shape=config.latent_dim),
-                nn.Tanh(),
-            )
-            self.aggregation_size += config.latent_dim
-
-            self.parameters_to_optimize += list(self.state_enc_layers.parameters())
-
-        if "observation.environment_state" in config.input_features:
-            self.env_state_enc_layers = nn.Sequential(
-                nn.Linear(
-                    in_features=config.input_features["observation.environment_state"].shape[0],
-                    out_features=config.latent_dim,
-                ),
-                nn.LayerNorm(normalized_shape=config.latent_dim),
-                nn.Tanh(),
-            )
-            self.aggregation_size += config.latent_dim
-            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
-
-        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
-        self.parameters_to_optimize += list(self.aggregation_layer.parameters())
-
-    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
-        """Encode the image and/or state vector.
-
-        Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
-        over all features.
-        """
-        feat = []
-        obs_dict = self.input_normalization(obs_dict)
-        # Batch all images along the batch dimension, then encode them.
-        if len(self.all_image_keys) > 0:
-            images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
-            images_batched = self.image_enc_layers(images_batched)
-            embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
-            feat.extend(embeddings_chunks)
-
-        if "observation.environment_state" in self.config.input_features:
-            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
-        if "observation.state" in self.config.input_features:
-            feat.append(self.state_enc_layers(obs_dict["observation.state"]))
-
-        features = torch.cat(tensors=feat, dim=-1)
-        features = self.aggregation_layer(features)
-
-        return features
-
-    @property
-    def output_dim(self) -> int:
-        """Returns the dimension of the encoder output"""
-        return self.config.latent_dim
-
-
 class DefaultImageEncoder(nn.Module):
     def __init__(self, config: SACConfig):
         super().__init__()
@@ -775,7 +775,9 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
         params=policy.actor.parameters_to_optimize,
         lr=cfg.policy.actor_lr,
     )
-    optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
+    optimizer_critic = torch.optim.Adam(
+        params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
+    )
 
     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
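Note: switching from parameters() to parameters_to_optimize matters when freeze_vision_encoder is set: as shown in the encoder added above, frozen vision weights are excluded from that list, so Adam no longer tracks (or allocates state for) parameters that never receive gradients. A small sanity check, assuming the attribute layout from this diff:

    frozen = {id(p) for p in policy.critic_ensemble.parameters() if not p.requires_grad}
    assert all(id(p) not in frozen for p in policy.critic_ensemble.parameters_to_optimize)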
@@ -1024,12 +1026,8 @@ def get_observation_features(
         return None, None
 
     with torch.no_grad():
-        observation_features = (
-            policy.actor.encoder(observations) if policy.actor.encoder is not None else None
-        )
-        next_observation_features = (
-            policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
-        )
+        observation_features = policy.actor.encoder.get_image_features(observations)
+        next_observation_features = policy.actor.encoder.get_image_features(next_observations)
 
     return observation_features, next_observation_features
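Note: a minimal usage sketch for the updated helper in a learner step. The actor call matches select_action above; the critic argument order is an assumption based on the CriticEnsemble.forward change in this diff:

    obs_feats, next_obs_feats = get_observation_features(policy, observations, next_observations)
    # Both caches are None when the encoder is not shared (early return above).
    actions, _, _ = policy.actor(observations, obs_feats)
    q_values = policy.critic_ensemble(observations, actions, obs_feats)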