Handle caching
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
parent
eac79a006d
commit
cf8d995c3a
|
@ -772,11 +772,10 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
|
|||
|
||||
"""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not n.startswith("encoder") or not policy.config.shared_encoder
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=cfg.policy.actor_lr,
|
||||
)
|
||||
|
@ -1028,8 +1027,8 @@ def get_observation_features(
|
|||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = policy.actor.encoder.get_base_image_features(observations, normalize=True)
|
||||
next_observation_features = policy.actor.encoder.get_base_image_features(
|
||||
observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True)
|
||||
next_observation_features = policy.actor.encoder.get_cached_image_features(
|
||||
next_observations, normalize=True
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue