diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 4b20bce4..629c6576 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -16,11 +16,10 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Any, Optional from lerobot.common.optim.optimizers import MultiAdamConfig from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature @dataclass @@ -221,7 +220,7 @@ class SACConfig(PreTrainedConfig): @property def image_features(self) -> list[str]: - return [key for key in self.input_features.keys() if "image" in key] + return [key for key in self.input_features if "image" in key] @property def observation_delta_indices(self) -> list: @@ -234,14 +233,3 @@ class SACConfig(PreTrainedConfig): @property def reward_delta_indices(self) -> None: return None - - -if __name__ == "__main__": - import draccus - - config = SACConfig() - draccus.set_config_type("json") - draccus.dump( - config=config, - stream=open(file="run_config.json", mode="w"), - )