Handle caching
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
parent
157f719d5f
commit
0c9a3ec301
|
@ -772,11 +772,10 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
optimizer_actor = torch.optim.Adam(
|
optimizer_actor = torch.optim.Adam(
|
||||||
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
|
||||||
params=[
|
params=[
|
||||||
p
|
p
|
||||||
for n, p in policy.actor.named_parameters()
|
for n, p in policy.actor.named_parameters()
|
||||||
if not n.startswith("encoder") or not policy.config.shared_encoder
|
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||||
],
|
],
|
||||||
lr=cfg.policy.actor_lr,
|
lr=cfg.policy.actor_lr,
|
||||||
)
|
)
|
||||||
|
@ -1028,8 +1027,8 @@ def get_observation_features(
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
observation_features = policy.actor.encoder.get_base_image_features(observations, normalize=True)
|
observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True)
|
||||||
next_observation_features = policy.actor.encoder.get_base_image_features(
|
next_observation_features = policy.actor.encoder.get_cached_image_features(
|
||||||
next_observations, normalize=True
|
next_observations, normalize=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue