Handle caching

Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
AdilZouitine 2025-04-15 13:02:31 +00:00
parent eac79a006d
commit cf8d995c3a
1 changed files with 3 additions and 4 deletions

@@ -772,11 +772,10 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     """
     optimizer_actor = torch.optim.Adam(
         # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
         params=[
             p
             for n, p in policy.actor.named_parameters()
-            if not n.startswith("encoder") or not policy.config.shared_encoder
+            if not policy.config.shared_encoder or not n.startswith("encoder")
         ],
         lr=cfg.policy.actor_lr,
    )
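
The two orderings of the filter are logically equivalent; the rewrite simply checks the config flag first, so the condition reads as "if the encoder is shared, skip its parameters so the actor's gradients never touch the shared weights". Below is a minimal, self-contained sketch of that filtering, using a stand-in Actor module and a plain shared_encoder flag rather than the repository's policy classes:

```python
import torch
from torch import nn

# Stand-in actor: an "encoder" submodule plus an actor-only head.
# This is an illustration, not the repository's Actor class.
class Actor(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)  # shared with the critic when shared_encoder is True
        self.head = nn.Linear(8, 2)     # actor-only parameters

actor = Actor()
shared_encoder = True  # mirrors policy.config.shared_encoder

# Keep a parameter unless the encoder is shared AND the parameter belongs to the encoder.
actor_params = [
    p
    for n, p in actor.named_parameters()
    if not shared_encoder or not n.startswith("encoder")
]

optimizer_actor = torch.optim.Adam(params=actor_params, lr=3e-4)
```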
@@ -1028,8 +1027,8 @@ def get_observation_features(
         return None, None
     with torch.no_grad():
-        observation_features = policy.actor.encoder.get_base_image_features(observations, normalize=True)
-        next_observation_features = policy.actor.encoder.get_base_image_features(
+        observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True)
+        next_observation_features = policy.actor.encoder.get_cached_image_features(
             next_observations, normalize=True
         )
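
The second hunk switches the call to get_cached_image_features, which suggests the encoder now reuses precomputed image features instead of running the base image encoder again on every call. A hypothetical sketch of such a getter follows; the Encoder class, the observation key, and the identity-based cache are illustrative assumptions, not the repository's implementation:

```python
import torch
from torch import nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self._cache = {}  # maps tensor id -> cached features (assumed caching scheme)

    def get_cached_image_features(self, observations: dict, normalize: bool = True) -> torch.Tensor:
        # "observation.image" is an assumed key; caching on the tensor's identity
        # lets a repeated call with the same batch skip the backbone forward pass.
        images = observations["observation.image"]
        key = id(images)
        if key not in self._cache:
            feats = self.backbone(images).flatten(1)
            if normalize:
                feats = F.normalize(feats, dim=-1)
            self._cache[key] = feats
        return self._cache[key]

# Usage mirroring the hunk above: features are computed without gradients.
encoder = Encoder()
observations = {"observation.image": torch.rand(2, 3, 16, 16)}
with torch.no_grad():
    observation_features = encoder.get_cached_image_features(observations, normalize=True)
```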