diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 9507586c..a956cb4b 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,6 +1,6 @@ def make_policy(cfg): if cfg.policy.name == "tdmpc": - from lerobot.common.policies.tdmpc import TDMPC + from lerobot.common.policies.tdmpc.policy import TDMPC policy = TDMPC(cfg.policy, cfg.device) elif cfg.policy.name == "diffusion": diff --git a/lerobot/common/policies/tdmpc/__init__.py b/lerobot/common/policies/tdmpc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lerobot/common/policies/tdmpc_helper.py b/lerobot/common/policies/tdmpc/helper.py similarity index 100% rename from lerobot/common/policies/tdmpc_helper.py rename to lerobot/common/policies/tdmpc/helper.py diff --git a/lerobot/common/policies/tdmpc.py b/lerobot/common/policies/tdmpc/policy.py similarity index 99% rename from lerobot/common/policies/tdmpc.py rename to lerobot/common/policies/tdmpc/policy.py index 42fbb825..ae9888a5 100644 --- a/lerobot/common/policies/tdmpc.py +++ b/lerobot/common/policies/tdmpc/policy.py @@ -8,7 +8,7 @@ import numpy as np import torch import torch.nn as nn -import lerobot.common.policies.tdmpc_helper as h +import lerobot.common.policies.tdmpc.helper as h FIRST_FRAME = 0