Make sure to make remove all traces of omegaconf from policy config

This commit is contained in:
Alexander Soare 2024-04-15 09:59:18 +01:00
parent 9241b5e830
commit 40d417ef60
1 changed files with 9 additions and 1 deletions

View File

@ -1,5 +1,7 @@
import inspect
from omegaconf import OmegaConf
from lerobot.common.utils import get_safe_torch_device
@ -33,7 +35,13 @@ def make_policy(cfg):
assert set(cfg.policy).issuperset(
expected_kwargs
), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
policy_cfg = ActConfig(**{k: v for k, v in cfg.policy.items() if k in expected_kwargs})
policy_cfg = ActConfig(
**{
k: v
for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items()
if k in expected_kwargs
}
)
policy = ActPolicy(policy_cfg)
policy.to(get_safe_torch_device(cfg.device))
else: