[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
808cf63221
commit
c05e4835d0
|
@ -14,7 +14,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import importlib
|
import importlib
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
|
@ -108,6 +108,7 @@ class MultiAdamConfig(OptimizerConfig):
|
||||||
optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
|
optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
|
||||||
grad_clip_norm: Gradient clipping norm
|
grad_clip_norm: Gradient clipping norm
|
||||||
"""
|
"""
|
||||||
|
|
||||||
lr: float = 1e-3
|
lr: float = 1e-3
|
||||||
weight_decay: float = 0.0
|
weight_decay: float = 0.0
|
||||||
grad_clip_norm: float = 10.0
|
grad_clip_norm: float = 10.0
|
||||||
|
@ -142,7 +143,9 @@ class MultiAdamConfig(OptimizerConfig):
|
||||||
return optimizers
|
return optimizers
|
||||||
|
|
||||||
|
|
||||||
def save_optimizer_state(optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path) -> None:
|
def save_optimizer_state(
|
||||||
|
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
||||||
|
) -> None:
|
||||||
"""Save optimizer state to disk.
|
"""Save optimizer state to disk.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -24,11 +24,11 @@ from lerobot.common.envs.configs import EnvConfig
|
||||||
from lerobot.common.envs.utils import env_to_policy_features
|
from lerobot.common.envs.utils import env_to_policy_features
|
||||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
|
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.types import FeatureType
|
from lerobot.configs.types import FeatureType
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List
|
from typing import List
|
||||||
|
|
||||||
from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
|
from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
|
||||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass(name="hilserl_classifier")
|
@PreTrainedConfig.register_subclass(name="hilserl_classifier")
|
||||||
|
|
|
@ -82,7 +82,9 @@ def create_stats_buffers(
|
||||||
if stats and key in stats:
|
if stats and key in stats:
|
||||||
if norm_mode is NormalizationMode.MEAN_STD:
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
if "mean" not in stats[key] or "std" not in stats[key]:
|
if "mean" not in stats[key] or "std" not in stats[key]:
|
||||||
raise ValueError(f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization")
|
raise ValueError(
|
||||||
|
f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(stats[key]["mean"], np.ndarray):
|
if isinstance(stats[key]["mean"], np.ndarray):
|
||||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||||
|
@ -96,11 +98,15 @@ def create_stats_buffers(
|
||||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||||
else:
|
else:
|
||||||
type_ = type(stats[key]["mean"])
|
type_ = type(stats[key]["mean"])
|
||||||
raise ValueError(f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead.")
|
raise ValueError(
|
||||||
|
f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
|
||||||
|
)
|
||||||
|
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
if "min" not in stats[key] or "max" not in stats[key]:
|
if "min" not in stats[key] or "max" not in stats[key]:
|
||||||
raise ValueError(f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization")
|
raise ValueError(
|
||||||
|
f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(stats[key]["min"], np.ndarray):
|
if isinstance(stats[key]["min"], np.ndarray):
|
||||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||||
|
@ -110,7 +116,9 @@ def create_stats_buffers(
|
||||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||||
else:
|
else:
|
||||||
type_ = type(stats[key]["min"])
|
type_ = type(stats[key]["min"])
|
||||||
raise ValueError(f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead.")
|
raise ValueError(
|
||||||
|
f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
|
||||||
|
)
|
||||||
|
|
||||||
stats_buffers[key] = buffer
|
stats_buffers[key] = buffer
|
||||||
return stats_buffers
|
return stats_buffers
|
||||||
|
|
|
@ -19,7 +19,7 @@ from dataclasses import dataclass, field
|
||||||
|
|
||||||
from lerobot.common.optim.optimizers import MultiAdamConfig
|
from lerobot.common.optim.optimizers import MultiAdamConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import NormalizationMode
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -897,7 +897,6 @@ if __name__ == "__main__":
|
||||||
# for j in range(i + 1, num_critics):
|
# for j in range(i + 1, num_critics):
|
||||||
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
|
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
|
||||||
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
|
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
|
||||||
import draccus
|
|
||||||
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
|
|
||||||
|
|
|
@ -115,11 +115,13 @@ class WandBLogger:
|
||||||
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
|
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
|
||||||
self._wandb.log_artifact(artifact)
|
self._wandb.log_artifact(artifact)
|
||||||
|
|
||||||
def log_dict(self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None):
|
def log_dict(
|
||||||
|
self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
|
||||||
|
):
|
||||||
if mode not in {"train", "eval"}:
|
if mode not in {"train", "eval"}:
|
||||||
raise ValueError(mode)
|
raise ValueError(mode)
|
||||||
if step is None and custom_step_key is None:
|
if step is None and custom_step_key is None:
|
||||||
raise ValueError("Either step or custom_step_key must be provided.")
|
raise ValueError("Either step or custom_step_key must be provided.")
|
||||||
|
|
||||||
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it
|
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it
|
||||||
# increases with each wandb.log call, but in the case of asynchronous RL for example,
|
# increases with each wandb.log call, but in the case of asynchronous RL for example,
|
||||||
|
@ -142,10 +144,7 @@ class WandBLogger:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Do not log the custom step key itself.
|
# Do not log the custom step key itself.
|
||||||
if (
|
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
|
||||||
self._wandb_custom_step_key is not None
|
|
||||||
and k in self._wandb_custom_step_key
|
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if custom_step_key is not None:
|
if custom_step_key is not None:
|
||||||
|
@ -160,7 +159,6 @@ class WandBLogger:
|
||||||
|
|
||||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||||
|
|
||||||
|
|
||||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||||
if mode not in {"train", "eval"}:
|
if mode not in {"train", "eval"}:
|
||||||
raise ValueError(mode)
|
raise ValueError(mode)
|
||||||
|
|
|
@ -34,7 +34,7 @@ TRAIN_CONFIG_NAME = "train_config.json"
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainPipelineConfig(HubMixin):
|
class TrainPipelineConfig(HubMixin):
|
||||||
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need a dataset
|
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need a dataset
|
||||||
env: envs.EnvConfig | None = None
|
env: envs.EnvConfig | None = None
|
||||||
policy: PreTrainedConfig | None = None
|
policy: PreTrainedConfig | None = None
|
||||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
# Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||||
|
|
|
@ -47,7 +47,7 @@ from lerobot.scripts.server.buffer import (
|
||||||
python_object_to_bytes,
|
python_object_to_bytes,
|
||||||
transitions_to_bytes,
|
transitions_to_bytes,
|
||||||
)
|
)
|
||||||
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
|
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||||
from lerobot.scripts.server.network_utils import (
|
from lerobot.scripts.server.network_utils import (
|
||||||
receive_bytes_in_chunks,
|
receive_bytes_in_chunks,
|
||||||
send_bytes_in_chunks,
|
send_bytes_in_chunks,
|
||||||
|
@ -444,7 +444,7 @@ def receive_policy(
|
||||||
|
|
||||||
# Initialize logging with explicit log file
|
# Initialize logging with explicit log file
|
||||||
init_logging(log_file=log_file)
|
init_logging(log_file=log_file)
|
||||||
logging.info(f"Actor receive policy process logging initialized")
|
logging.info("Actor receive policy process logging initialized")
|
||||||
|
|
||||||
# Setup process handlers to handle shutdown signal
|
# Setup process handlers to handle shutdown signal
|
||||||
# But use shutdown event from the main process
|
# But use shutdown event from the main process
|
||||||
|
|
|
@ -701,10 +701,10 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
from lerobot.common.envs.configs import EEActionSpaceConfig, EnvWrapperConfig, HILSerlRobotEnvConfig
|
||||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||||
from lerobot.common.envs.configs import HILSerlRobotEnvConfig, EEActionSpaceConfig, EnvWrapperConfig
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Test end-effector control")
|
parser = argparse.ArgumentParser(description="Test end-effector control")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
|
@ -5,14 +5,15 @@ import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lerobot.common.robot_devices.control_utils import is_headless
|
from lerobot.common.robot_devices.control_utils import is_headless
|
||||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
|
||||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
|
||||||
from lerobot.configs import parser
|
|
||||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||||
|
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||||
|
from lerobot.configs import parser
|
||||||
|
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||||
|
|
||||||
follower_port = "/dev/tty.usbmodem58760431631"
|
follower_port = "/dev/tty.usbmodem58760431631"
|
||||||
leader_port = "/dev/tty.usbmodem58760433331"
|
leader_port = "/dev/tty.usbmodem58760433331"
|
||||||
|
|
||||||
|
|
||||||
def find_joint_bounds(
|
def find_joint_bounds(
|
||||||
robot,
|
robot,
|
||||||
control_time_s=30,
|
control_time_s=30,
|
||||||
|
@ -100,6 +101,7 @@ def make_robot(robot_type="so100"):
|
||||||
|
|
||||||
return make_robot_from_config(robot_config)
|
return make_robot_from_config(robot_config)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Create argparse for script-specific arguments
|
# Create argparse for script-specific arguments
|
||||||
parser = argparse.ArgumentParser(add_help=False) # Set add_help=False to avoid conflict
|
parser = argparse.ArgumentParser(add_help=False) # Set add_help=False to avoid conflict
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from typing import Any
|
||||||
from typing import Any, Dict, Optional, Tuple
|
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
@ -10,10 +9,8 @@ import torch
|
||||||
from mani_skill.utils.wrappers.record import RecordEpisode
|
from mani_skill.utils.wrappers.record import RecordEpisode
|
||||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||||
|
|
||||||
from lerobot.common.envs.configs import EnvConfig, ManiskillEnvConfig
|
from lerobot.common.envs.configs import ManiskillEnvConfig
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
|
||||||
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_ROBOT
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_maniskill_observation(
|
def preprocess_maniskill_observation(
|
||||||
|
@ -53,9 +50,6 @@ def preprocess_maniskill_observation(
|
||||||
return return_observations
|
return return_observations
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ManiSkillObservationWrapper(gym.ObservationWrapper):
|
class ManiSkillObservationWrapper(gym.ObservationWrapper):
|
||||||
def __init__(self, env, device: torch.device = "cuda"):
|
def __init__(self, env, device: torch.device = "cuda"):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
@ -122,6 +116,7 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
|
||||||
|
|
||||||
class BatchCompatibleWrapper(gym.ObservationWrapper):
|
class BatchCompatibleWrapper(gym.ObservationWrapper):
|
||||||
"""Ensures observations are batch-compatible by adding a batch dimension if necessary."""
|
"""Ensures observations are batch-compatible by adding a batch dimension if necessary."""
|
||||||
|
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
|
||||||
|
@ -136,6 +131,7 @@ class BatchCompatibleWrapper(gym.ObservationWrapper):
|
||||||
|
|
||||||
class TimeLimitWrapper(gym.Wrapper):
|
class TimeLimitWrapper(gym.Wrapper):
|
||||||
"""Adds a time limit to the environment based on fps and control_time."""
|
"""Adds a time limit to the environment based on fps and control_time."""
|
||||||
|
|
||||||
def __init__(self, env, control_time_s, fps):
|
def __init__(self, env, control_time_s, fps):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
self.control_time_s = control_time_s
|
self.control_time_s = control_time_s
|
||||||
|
@ -243,6 +239,10 @@ def main(cfg: ManiskillEnvConfig):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import draccus
|
import draccus
|
||||||
|
|
||||||
config = ManiskillEnvConfig()
|
config = ManiskillEnvConfig()
|
||||||
draccus.set_config_type("json")
|
draccus.set_config_type("json")
|
||||||
draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )
|
draccus.dump(
|
||||||
|
config=config,
|
||||||
|
stream=open(file="run_config.json", mode="w"),
|
||||||
|
)
|
||||||
|
|
|
@ -32,7 +32,6 @@ from tqdm import tqdm
|
||||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||||
|
|
||||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||||
from lerobot.common.utils.utils import (
|
from lerobot.common.utils.utils import (
|
||||||
format_big_number,
|
format_big_number,
|
||||||
|
|
|
@ -205,10 +205,7 @@ def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers,
|
||||||
|
|
||||||
# Verify state dictionaries match
|
# Verify state dictionaries match
|
||||||
for name in multi_optimizers.keys():
|
for name in multi_optimizers.keys():
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())
|
||||||
multi_optimizers[name].state_dict(),
|
|
||||||
loaded_optimizers[name].state_dict()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
|
def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
|
||||||
|
@ -241,7 +238,5 @@ def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
|
||||||
|
|
||||||
# Verify state dictionaries match (they will be empty)
|
# Verify state dictionaries match (they will be empty)
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
optimizer.state_dict()["param_groups"],
|
optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"]
|
||||||
loaded_optimizers[name].state_dict()["param_groups"]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue