[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
parent 808cf63221
commit c05e4835d0
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
-from collections import deque
 
 import gymnasium as gym
 
@@ -99,36 +99,37 @@ class SGDConfig(OptimizerConfig):
 @dataclass
 class MultiAdamConfig(OptimizerConfig):
     """Configuration for multiple Adam optimizers with different parameter groups.
 
     This creates a dictionary of Adam optimizers, each with its own hyperparameters.
 
     Args:
         lr: Default learning rate (used if not specified for a group)
         weight_decay: Default weight decay (used if not specified for a group)
         optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
         grad_clip_norm: Gradient clipping norm
     """
 
     lr: float = 1e-3
     weight_decay: float = 0.0
     grad_clip_norm: float = 10.0
     optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
 
     def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
         """Build multiple Adam optimizers.
 
         Args:
             params_dict: Dictionary mapping parameter group names to lists of parameters
                 The keys should match the keys in optimizer_groups
 
         Returns:
             Dictionary mapping parameter group names to their optimizers
         """
         optimizers = {}
 
         for name, params in params_dict.items():
             # Get group-specific hyperparameters or use defaults
             group_config = self.optimizer_groups.get(name, {})
 
             # Create optimizer with merged parameters (defaults + group-specific)
             optimizer_kwargs = {
                 "lr": group_config.get("lr", self.lr),
@@ -136,15 +137,17 @@ class MultiAdamConfig(OptimizerConfig):
                 "eps": group_config.get("eps", 1e-5),
                 "weight_decay": group_config.get("weight_decay", self.weight_decay),
             }
 
             optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
 
         return optimizers
 
 
-def save_optimizer_state(optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path) -> None:
+def save_optimizer_state(
+    optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
+) -> None:
     """Save optimizer state to disk.
 
     Args:
         optimizer: Either a single optimizer or a dictionary of optimizers.
         save_dir: Directory to save the optimizer state.
@@ -173,11 +176,11 @@ def load_optimizer_state(
     optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
 ) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
     """Load optimizer state from disk.
 
     Args:
         optimizer: Either a single optimizer or a dictionary of optimizers.
         save_dir: Directory to load the optimizer state from.
 
     Returns:
         The updated optimizer(s) with loaded state.
     """
@@ -201,7 +204,7 @@ def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
     current_state_dict = optimizer.state_dict()
     flat_state = load_file(save_dir / OPTIMIZER_STATE)
     state = unflatten_dict(flat_state)
 
     # Handle case where 'state' key might not exist (for newly created optimizers)
     if "state" in state:
         loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
@@ -24,11 +24,11 @@ from lerobot.common.envs.configs import EnvConfig
 from lerobot.common.envs.utils import env_to_policy_features
 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
+from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
 from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
 from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
-from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import FeatureType
 
@@ -1,10 +1,9 @@
-from dataclasses import dataclass, field
-from typing import Dict, List
+from dataclasses import dataclass
+from typing import List
 
 from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
 from lerobot.common.optim.schedulers import LRSchedulerConfig
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, PolicyFeature
 
 
 @PreTrainedConfig.register_subclass(name="hilserl_classifier")
@@ -82,8 +82,10 @@ def create_stats_buffers(
         if stats and key in stats:
             if norm_mode is NormalizationMode.MEAN_STD:
                 if "mean" not in stats[key] or "std" not in stats[key]:
-                    raise ValueError(f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization")
+                    raise ValueError(
+                        f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
+                    )
 
                 if isinstance(stats[key]["mean"], np.ndarray):
                     buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
                     buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
@@ -96,12 +98,16 @@ def create_stats_buffers(
                     buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
                 else:
                     type_ = type(stats[key]["mean"])
-                    raise ValueError(f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead.")
+                    raise ValueError(
+                        f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
+                    )
 
             elif norm_mode is NormalizationMode.MIN_MAX:
                 if "min" not in stats[key] or "max" not in stats[key]:
-                    raise ValueError(f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization")
+                    raise ValueError(
+                        f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
+                    )
 
                 if isinstance(stats[key]["min"], np.ndarray):
                     buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
                     buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
@ -110,7 +116,9 @@ def create_stats_buffers(
|
||||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||||
else:
|
else:
|
||||||
type_ = type(stats[key]["min"])
|
type_ = type(stats[key]["min"])
|
||||||
raise ValueError(f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead.")
|
raise ValueError(
|
||||||
|
f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
|
||||||
|
)
|
||||||
|
|
||||||
stats_buffers[key] = buffer
|
stats_buffers[key] = buffer
|
||||||
return stats_buffers
|
return stats_buffers
|
||||||
|
|
|
@@ -19,7 +19,7 @@ from dataclasses import dataclass, field
 
 from lerobot.common.optim.optimizers import MultiAdamConfig
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.configs.types import NormalizationMode
 
 
 @dataclass
@@ -897,7 +897,6 @@ if __name__ == "__main__":
     # for j in range(i + 1, num_critics):
     #     diff = torch.abs(q_values[i] - q_values[j]).mean().item()
     #     print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
-    import draccus
 
     from lerobot.configs import parser
 
@@ -115,11 +115,13 @@ class WandBLogger:
         artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
         self._wandb.log_artifact(artifact)
 
-    def log_dict(self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None):
+    def log_dict(
+        self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
+    ):
         if mode not in {"train", "eval"}:
             raise ValueError(mode)
         if step is None and custom_step_key is None:
             raise ValueError("Either step or custom_step_key must be provided.")
 
         # NOTE: This is not simple. The wandb step must always monotonically increase and it
         # increases with each wandb.log call, but in the case of asynchronous RL for example,
@@ -142,10 +144,7 @@ class WandBLogger:
                 continue
 
             # Do not log the custom step key itself.
-            if (
-                self._wandb_custom_step_key is not None
-                and k in self._wandb_custom_step_key
-            ):
+            if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
                 continue
 
             if custom_step_key is not None:
@@ -160,7 +159,6 @@ class WandBLogger:
 
             self._wandb.log(data={f"{mode}/{k}": v}, step=step)
 
-
     def log_video(self, video_path: str, step: int, mode: str = "train"):
         if mode not in {"train", "eval"}:
             raise ValueError(mode)
@@ -34,7 +34,7 @@ TRAIN_CONFIG_NAME = "train_config.json"
 
 @dataclass
 class TrainPipelineConfig(HubMixin):
     dataset: DatasetConfig | None = None  # NOTE: In RL, we don't need a dataset
     env: envs.EnvConfig | None = None
     policy: PreTrainedConfig | None = None
     # Set `dir` to where you would like to save all of the run outputs. If you run another training session
     # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
@@ -47,7 +47,7 @@ from lerobot.scripts.server.buffer import (
     python_object_to_bytes,
     transitions_to_bytes,
 )
-from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
+from lerobot.scripts.server.gym_manipulator import make_robot_env
 from lerobot.scripts.server.network_utils import (
     receive_bytes_in_chunks,
     send_bytes_in_chunks,
@@ -444,7 +444,7 @@ def receive_policy(
 
     # Initialize logging with explicit log file
     init_logging(log_file=log_file)
-    logging.info(f"Actor receive policy process logging initialized")
+    logging.info("Actor receive policy process logging initialized")
 
     # Setup process handlers to handle shutdown signal
     # But use shutdown event from the main process
@@ -515,7 +515,7 @@ class ReplayBuffer:
             frame_dict["action"] = self.actions[actual_idx].cpu()
             frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
             frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
 
             # Add task field which is required by LeRobotDataset
             frame_dict["task"] = task_name
 
@@ -701,10 +701,10 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
 
 
 if __name__ == "__main__":
+    from lerobot.common.envs.configs import EEActionSpaceConfig, EnvWrapperConfig, HILSerlRobotEnvConfig
     from lerobot.common.robot_devices.robots.configs import RobotConfig
     from lerobot.common.robot_devices.robots.utils import make_robot_from_config
     from lerobot.scripts.server.gym_manipulator import make_robot_env
-    from lerobot.common.envs.configs import HILSerlRobotEnvConfig, EEActionSpaceConfig, EnvWrapperConfig
 
     parser = argparse.ArgumentParser(description="Test end-effector control")
     parser.add_argument(
@@ -5,14 +5,15 @@ import cv2
 import numpy as np
 
 from lerobot.common.robot_devices.control_utils import is_headless
-from lerobot.common.robot_devices.robots.utils import make_robot_from_config
-from lerobot.scripts.server.kinematics import RobotKinematics
-from lerobot.configs import parser
 from lerobot.common.robot_devices.robots.configs import RobotConfig
+from lerobot.common.robot_devices.robots.utils import make_robot_from_config
+from lerobot.configs import parser
+from lerobot.scripts.server.kinematics import RobotKinematics
 
 follower_port = "/dev/tty.usbmodem58760431631"
 leader_port = "/dev/tty.usbmodem58760433331"
 
 
 def find_joint_bounds(
     robot,
     control_time_s=30,
@@ -85,21 +86,22 @@ def find_ee_bounds(
 def make_robot(robot_type="so100"):
     """
     Create a robot instance using the appropriate robot config class.
 
     Args:
         robot_type: Robot type string (e.g., "so100", "koch", "aloha")
 
     Returns:
         Robot instance
     """
 
     # Get the appropriate robot config class based on robot_type
     robot_config = RobotConfig.get_choice_class(robot_type)(mock=False)
     robot_config.leader_arms["main"].port = leader_port
     robot_config.follower_arms["main"].port = follower_port
 
     return make_robot_from_config(robot_config)
 
 
 if __name__ == "__main__":
     # Create argparse for script-specific arguments
     parser = argparse.ArgumentParser(add_help=False)  # Set add_help=False to avoid conflict
@@ -125,14 +127,14 @@ if __name__ == "__main__":
 
     # Only parse known args, leaving robot config args for Hydra if used
     args = parser.parse_args()
 
     # Create robot with the appropriate config
     robot = make_robot(args.robot_type)
 
     if args.mode == "joint":
         find_joint_bounds(robot, args.control_time_s)
     elif args.mode == "ee":
         find_ee_bounds(robot, args.control_time_s)
 
     if robot.is_connected:
         robot.disconnect()
@@ -1,7 +1,6 @@
 import logging
 import time
-from dataclasses import dataclass, field
-from typing import Any, Dict, Optional, Tuple
+from typing import Any
 
 import einops
 import gymnasium as gym
@@ -10,10 +9,8 @@ import torch
 from mani_skill.utils.wrappers.record import RecordEpisode
 from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
 
-from lerobot.common.envs.configs import EnvConfig, ManiskillEnvConfig
+from lerobot.common.envs.configs import ManiskillEnvConfig
 from lerobot.configs import parser
-from lerobot.configs.types import FeatureType, PolicyFeature
-from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_ROBOT
 
 
 def preprocess_maniskill_observation(
@@ -53,9 +50,6 @@ def preprocess_maniskill_observation(
     return return_observations
 
 
-
-
-
 class ManiSkillObservationWrapper(gym.ObservationWrapper):
     def __init__(self, env, device: torch.device = "cuda"):
         super().__init__(env)
@@ -122,6 +116,7 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
 
 class BatchCompatibleWrapper(gym.ObservationWrapper):
     """Ensures observations are batch-compatible by adding a batch dimension if necessary."""
 
     def __init__(self, env):
         super().__init__(env)
@@ -136,6 +131,7 @@ class BatchCompatibleWrapper(gym.ObservationWrapper):
 
 class TimeLimitWrapper(gym.Wrapper):
     """Adds a time limit to the environment based on fps and control_time."""
 
     def __init__(self, env, control_time_s, fps):
         super().__init__(env)
         self.control_time_s = control_time_s
@@ -146,10 +142,10 @@ class TimeLimitWrapper(gym.Wrapper):
     def step(self, action):
         obs, reward, terminated, truncated, info = self.env.step(action)
         self.current_step += 1
 
         if self.current_step >= self.max_episode_steps:
             terminated = True
 
         return obs, reward, terminated, truncated, info
 
     def reset(self, *, seed=None, options=None):
@@ -190,18 +186,18 @@ def make_maniskill(
         save_video=True,
         video_fps=30,
     )
 
     # Add observation and image processing
     env = ManiSkillObservationWrapper(env, device=cfg.device)
     env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
     env._max_episode_steps = env.max_episode_steps = cfg.episode_length
     env.unwrapped.metadata["render_fps"] = cfg.fps
 
     # Add compatibility wrappers
     env = ManiSkillCompat(env)
     env = ManiSkillActionWrapper(env)
     env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)  # Scale actions for better control
 
     return env
 
 
@@ -210,29 +206,29 @@ def main(cfg: ManiskillEnvConfig):
     """Main function to run the ManiSkill environment."""
     # Create the ManiSkill environment
     env = make_maniskill(cfg, n_envs=1)
 
     # Reset the environment
     obs, info = env.reset()
 
     # Run a simple interaction loop
     sum_reward = 0
     for i in range(100):
         # Sample a random action
         action = env.action_space.sample()
 
         # Step the environment
         start_time = time.perf_counter()
         obs, reward, terminated, truncated, info = env.step(action)
         step_time = time.perf_counter() - start_time
         sum_reward += reward
         # Log information
 
         # Reset if episode terminated
         if terminated or truncated:
             logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
             sum_reward = 0
             obs, info = env.reset()
 
     # Close the environment
     env.close()
 
@@ -243,6 +239,10 @@ def main(cfg: ManiskillEnvConfig):
 
 if __name__ == "__main__":
     import draccus
 
     config = ManiskillEnvConfig()
     draccus.set_config_type("json")
-    draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )
+    draccus.dump(
+        config=config,
+        stream=open(file="run_config.json", mode="w"),
+    )
@@ -32,7 +32,6 @@ from tqdm import tqdm
 from lerobot.common.datasets.factory import resolve_delta_timestamps
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
-
 from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
 from lerobot.common.utils.utils import (
     format_big_number,
@@ -134,15 +134,15 @@ def test_multi_adam_configuration(base_params_dict, config_params, expected_valu
     # Create config with the given parameters
     config = MultiAdamConfig(**config_params)
     optimizers = config.build(base_params_dict)
 
     # Verify optimizer count and keys
     assert len(optimizers) == len(expected_values)
     assert set(optimizers.keys()) == set(expected_values.keys())
 
     # Check that all optimizers are Adam instances
     for opt in optimizers.values():
         assert isinstance(opt, torch.optim.Adam)
 
     # Verify hyperparameters for each optimizer
     for name, expected in expected_values.items():
         optimizer = optimizers[name]
@@ -166,7 +166,7 @@ def multi_optimizers(base_params_dict):
 def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
     # Save optimizer states
     save_optimizer_state(multi_optimizers, tmp_path)
 
     # Verify that directories were created for each optimizer
     for name in multi_optimizers.keys():
         assert (tmp_path / name).is_dir()
@@ -185,10 +185,10 @@ def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers,
         multi_optimizers[name].step()
         # Zero gradients for next steps
         multi_optimizers[name].zero_grad()
 
     # Save optimizer states
     save_optimizer_state(multi_optimizers, tmp_path)
 
     # Create new optimizers with the same config
     config = MultiAdamConfig(
         lr=1e-3,
@@ -199,16 +199,13 @@
         },
     )
     new_optimizers = config.build(base_params_dict)
 
     # Load optimizer states
     loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
 
     # Verify state dictionaries match
     for name in multi_optimizers.keys():
-        torch.testing.assert_close(
-            multi_optimizers[name].state_dict(),
-            loaded_optimizers[name].state_dict()
-        )
+        torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())
 
 
 def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
@@ -223,25 +220,23 @@ def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
         },
     )
     optimizers = config.build(base_params_dict)
 
     # Save optimizer states without any backward pass (empty state)
     save_optimizer_state(optimizers, tmp_path)
 
     # Create new optimizers with the same config
     new_optimizers = config.build(base_params_dict)
 
     # Load optimizer states
     loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
 
     # Verify hyperparameters match even with empty state
     for name, optimizer in optimizers.items():
         assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
         assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
         assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]
 
         # Verify state dictionaries match (they will be empty)
         torch.testing.assert_close(
-            optimizer.state_dict()["param_groups"],
-            loaded_optimizers[name].state_dict()["param_groups"]
+            optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"]
         )