refactor(config): Move device & amp args to PreTrainedConfig (#812)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
parent 10706ed753
commit 5e9473806c

Makefile | 18
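In short: `device` and `use_amp` move off the train/eval/control pipeline configs and onto the policy config itself, so they are saved with and restored from policy checkpoints. A minimal sketch of the new surface, using ACTConfig (a PreTrainedConfig subclass referenced in the tests below; import path assumed for this era of the repo):

from lerobot.common.policies.act.configuration_act import ACTConfig  # import path assumed

# Both fields are new on PreTrainedConfig; "cpu" keeps this runnable on any machine.
cfg = ACTConfig(device="cpu", use_amp=False)
print(cfg.device, cfg.use_amp)  # cpu False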
@@ -47,6 +47,7 @@ test-act-ete-train:
 		--policy.dim_model=64 \
 		--policy.n_action_steps=20 \
 		--policy.chunk_size=20 \
+		--policy.device=$(DEVICE) \
 		--env.type=aloha \
 		--env.episode_length=5 \
 		--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
@@ -61,7 +62,6 @@ test-act-ete-train:
 		--save_checkpoint=true \
 		--log_freq=1 \
 		--wandb.enable=false \
-		--device=$(DEVICE) \
 		--output_dir=tests/outputs/act/

 test-act-ete-train-resume:
@@ -72,11 +72,11 @@ test-act-ete-train-resume:
 test-act-ete-eval:
 	python lerobot/scripts/eval.py \
 		--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
+		--policy.device=$(DEVICE) \
 		--env.type=aloha \
 		--env.episode_length=5 \
 		--eval.n_episodes=1 \
-		--eval.batch_size=1 \
-		--device=$(DEVICE)
+		--eval.batch_size=1

 test-diffusion-ete-train:
 	python lerobot/scripts/train.py \
@@ -84,6 +84,7 @@ test-diffusion-ete-train:
 		--policy.down_dims='[64,128,256]' \
 		--policy.diffusion_step_embed_dim=32 \
 		--policy.num_inference_steps=10 \
+		--policy.device=$(DEVICE) \
 		--env.type=pusht \
 		--env.episode_length=5 \
 		--dataset.repo_id=lerobot/pusht \
@@ -98,21 +99,21 @@ test-diffusion-ete-train:
 		--save_freq=2 \
 		--log_freq=1 \
 		--wandb.enable=false \
-		--device=$(DEVICE) \
 		--output_dir=tests/outputs/diffusion/

 test-diffusion-ete-eval:
 	python lerobot/scripts/eval.py \
 		--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
+		--policy.device=$(DEVICE) \
 		--env.type=pusht \
 		--env.episode_length=5 \
 		--eval.n_episodes=1 \
-		--eval.batch_size=1 \
-		--device=$(DEVICE)
+		--eval.batch_size=1

 test-tdmpc-ete-train:
 	python lerobot/scripts/train.py \
 		--policy.type=tdmpc \
+		--policy.device=$(DEVICE) \
 		--env.type=xarm \
 		--env.task=XarmLift-v0 \
 		--env.episode_length=5 \
@@ -128,15 +129,14 @@ test-tdmpc-ete-train:
 		--save_freq=2 \
 		--log_freq=1 \
 		--wandb.enable=false \
-		--device=$(DEVICE) \
 		--output_dir=tests/outputs/tdmpc/

 test-tdmpc-ete-eval:
 	python lerobot/scripts/eval.py \
 		--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
+		--policy.device=$(DEVICE) \
 		--env.type=xarm \
 		--env.episode_length=5 \
 		--env.task=XarmLift-v0 \
 		--eval.n_episodes=1 \
-		--eval.batch_size=1 \
-		--device=$(DEVICE)
+		--eval.batch_size=1
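Every Makefile edit above follows one pattern: the top-level `--device=$(DEVICE)` flag disappears and the same value is passed as `--policy.device=$(DEVICE)` instead. In Python, the same setting now lands on the policy config, sketched here with `make_policy_config` (its signature appears in the factory hunks below):

from lerobot.common.policies.factory import make_policy_config

# What `--policy.device=cpu` now sets; "act" is one of the policy types used in these tests.
cfg = make_policy_config("act", device="cpu")
print(cfg.device)  # cpu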
@@ -16,7 +16,6 @@
 import logging

-import torch
 from torch import nn

 from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata

@@ -76,7 +75,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:

 def make_policy(
     cfg: PreTrainedConfig,
-    device: str | torch.device,
     ds_meta: LeRobotDatasetMetadata | None = None,
     env_cfg: EnvConfig | None = None,
 ) -> PreTrainedPolicy:
@@ -88,7 +86,6 @@ def make_policy(
    Args:
        cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
            be loaded with the weights from that path.
-        device (str): the device to load the policy onto.
        ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
            statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
        env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be

@@ -96,7 +93,7 @@ def make_policy(

    Raises:
        ValueError: Either ds_meta or env and env_cfg must be provided.
-        NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
+        NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)

    Returns:
        PreTrainedPolicy: _description_
@@ -111,7 +108,7 @@ def make_policy(
     # https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
     # variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
     # slower than running natively on MPS.
-    if cfg.type == "vqbet" and str(device) == "mps":
+    if cfg.type == "vqbet" and cfg.device == "mps":
         raise NotImplementedError(
             "Current implementation of VQBeT does not support `mps` backend. "
             "Please use `cpu` or `cuda` backend."

@@ -145,7 +142,7 @@ def make_policy(
     # Make a fresh policy.
     policy = policy_cls(**kwargs)

-    policy.to(device)
+    policy.to(cfg.device)
     assert isinstance(policy, nn.Module)

     # policy = torch.compile(policy, mode="reduce-overhead")
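With `device` dropped from the `make_policy` signature, callers hand over a config that already carries it. A hedged sketch of the new call shape (downloads metadata from the Hub; `lerobot/pusht` is a dataset repo id used elsewhere in this diff):

from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.policies.factory import make_policy, make_policy_config

cfg = make_policy_config("diffusion", device="cpu")  # device rides on the config
ds_meta = LeRobotDatasetMetadata("lerobot/pusht")
policy = make_policy(cfg, ds_meta=ds_meta)  # internally calls policy.to(cfg.device)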
@@ -90,6 +90,7 @@ class PI0Config(PreTrainedConfig):
     def __post_init__(self):
         super().__post_init__()

+        # TODO(Steven): Validate device and amp? in all policy configs?
         """Input validation (not exhaustive)."""
         if self.n_action_steps > self.chunk_size:
             raise ValueError(
@@ -45,7 +45,7 @@ def main():

     cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
     cfg.pretrained_path = ckpt_torch_dir
-    policy = make_policy(cfg, device, ds_meta=dataset.meta)
+    policy = make_policy(cfg, ds_meta=dataset.meta)

     # policy = torch.compile(policy, mode="reduce-overhead")

@@ -101,7 +101,7 @@ def main():

     cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
     cfg.pretrained_path = ckpt_torch_dir
-    policy = make_policy(cfg, device, dataset_meta)
+    policy = make_policy(cfg, dataset_meta)

     # loss_dict = policy.forward(batch, noise=noise, time=time_beta)
     # loss_dict["loss"].backward()
@@ -86,7 +86,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
         cache_dir: str | Path | None = None,
         local_files_only: bool = False,
         revision: str | None = None,
-        map_location: str = "cpu",
         strict: bool = False,
         **kwargs,
     ) -> T:

@@ -111,7 +110,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
         if os.path.isdir(model_id):
             print("Loading weights from local directory")
             model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
-            policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
+            policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
         else:
             try:
                 model_file = hf_hub_download(

@@ -125,13 +124,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
                     token=token,
                     local_files_only=local_files_only,
                 )
-                policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
+                policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
             except HfHubHTTPError as e:
                 raise FileNotFoundError(
                     f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
                 ) from e

-        policy.to(map_location)
+        policy.to(config.device)
         policy.eval()
         return policy
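Since `map_location` is gone from `_from_pretrained`, loaded weights now land on whatever `config.device` resolves to. A sketch, assuming a local checkpoint directory (the path is hypothetical) and this era's ACTPolicy import path:

from lerobot.common.policies.act.modeling_act import ACTPolicy  # import path assumed

policy = ACTPolicy.from_pretrained("outputs/train/act/checkpoints/last/pretrained_model")
print(next(policy.parameters()).device)  # follows policy.config.device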
@@ -12,17 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
 from dataclasses import dataclass
 from pathlib import Path

 import draccus

 from lerobot.common.robot_devices.robots.configs import RobotConfig
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
 from lerobot.configs import parser
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.train import TrainPipelineConfig


 @dataclass

@@ -57,11 +54,6 @@ class RecordControlConfig(ControlConfig):
     # Root directory where the dataset will be stored (e.g. 'dataset/path').
     root: str | Path | None = None
     policy: PreTrainedConfig | None = None
-    # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
-    device: str | None = None  # cuda | cpu | mps
-    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
-    # automatic gradient scaling is used.
-    use_amp: bool | None = None
     # Limit the frames per second. By default, uses the policy fps.
     fps: int | None = None
     # Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.

@@ -104,27 +96,6 @@ class RecordControlConfig(ControlConfig):
             self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
             self.policy.pretrained_path = policy_path

-            # When no device or use_amp are given, use the one from training config.
-            if self.device is None or self.use_amp is None:
-                train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
-                if self.device is None:
-                    self.device = train_cfg.device
-                if self.use_amp is None:
-                    self.use_amp = train_cfg.use_amp
-
-        # Automatically switch to available device if necessary
-        if not is_torch_device_available(self.device):
-            auto_device = auto_select_torch_device()
-            logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
-            self.device = auto_device
-
-        # Automatically deactivate AMP if necessary
-        if self.use_amp and not is_amp_available(self.device):
-            logging.warning(
-                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
-            )
-            self.use_amp = False


 @ControlConfig.register_subclass("replay")
 @dataclass
@@ -32,6 +32,7 @@ from termcolor import colored
 from lerobot.common.datasets.image_writer import safe_stop_image_writer
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.utils import get_features_from_robot
+from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.common.utils.utils import get_safe_torch_device, has_method

@@ -193,8 +194,6 @@ def record_episode(
     episode_time_s,
     display_cameras,
     policy,
-    device,
-    use_amp,
     fps,
     single_task,
 ):

@@ -205,8 +204,6 @@ def record_episode(
         dataset=dataset,
         events=events,
         policy=policy,
-        device=device,
-        use_amp=use_amp,
         fps=fps,
         teleoperate=policy is None,
         single_task=single_task,

@@ -221,9 +218,7 @@ def control_loop(
     display_cameras=False,
     dataset: LeRobotDataset | None = None,
     events=None,
-    policy=None,
-    device: torch.device | str | None = None,
-    use_amp: bool | None = None,
+    policy: PreTrainedPolicy = None,
     fps: int | None = None,
     single_task: str | None = None,
 ):

@@ -246,9 +241,6 @@ def control_loop(
     if dataset is not None and fps is not None and dataset.fps != fps:
         raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")

-    if isinstance(device, str):
-        device = get_safe_torch_device(device)
-
     timestamp = 0
     start_episode_t = time.perf_counter()
     while timestamp < control_time_s:

@@ -260,7 +252,9 @@ def control_loop(
         observation = robot.capture_observation()

         if policy is not None:
-            pred_action = predict_action(observation, policy, device, use_amp)
+            pred_action = predict_action(
+                observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
+            )
             # Action can eventually be clipped using `max_relative_target`,
             # so action actually sent is saved in the dataset.
             action = robot.send_action(pred_action)
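`control_loop` now reads both values off the policy it receives. The resolution it performs inline, factored into a standalone sketch (the helper name `resolve_runtime` is ours, not the repo's):

from lerobot.common.utils.utils import get_safe_torch_device

def resolve_runtime(policy):
    # Mirrors the inline call above: device and AMP come from the policy's own config.
    device = get_safe_torch_device(policy.config.device)
    return device, policy.config.use_amp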
@@ -51,8 +51,10 @@ def auto_select_torch_device() -> torch.device:
     return torch.device("cpu")


+# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
 def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
     """Given a string, return a torch.device with checks on whether the device is available."""
+    try_device = str(try_device)
     match try_device:
         case "cuda":
             assert torch.cuda.is_available()

@@ -85,6 +87,7 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):

 def is_torch_device_available(try_device: str) -> bool:
+    try_device = str(try_device)  # Ensure try_device is a string
     if try_device == "cuda":
         return torch.cuda.is_available()
     elif try_device == "mps":

@@ -92,7 +95,7 @@ def is_torch_device_available(try_device: str) -> bool:
     elif try_device == "cpu":
         return True
     else:
-        raise ValueError(f"Unknown device '{try_device}.")
+        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")


 def is_amp_available(device: str):
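Both helpers now coerce their argument to `str` first, so passing a `torch.device` (as callers holding converted `policy.config.device` values may do) works. A quick check that only assumes torch is installed:

import torch

from lerobot.common.utils.utils import get_safe_torch_device, is_torch_device_available

assert is_torch_device_available(torch.device("cpu"))  # str() turns this into "cpu"
print(get_safe_torch_device(torch.device("cpu")))  # device(type='cpu')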
@@ -18,11 +18,9 @@ from dataclasses import dataclass, field
 from pathlib import Path

 from lerobot.common import envs, policies  # noqa: F401
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
 from lerobot.configs import parser
 from lerobot.configs.default import EvalConfig
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.train import TrainPipelineConfig


 @dataclass

@@ -35,11 +33,6 @@ class EvalPipelineConfig:
     policy: PreTrainedConfig | None = None
     output_dir: Path | None = None
     job_name: str | None = None
-    # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
-    device: str | None = None  # cuda | cpu | mps
-    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
-    # automatic gradient scaling is used.
-    use_amp: bool = False
     seed: int | None = 1000

     def __post_init__(self):

@@ -50,27 +43,6 @@ class EvalPipelineConfig:
             self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
             self.policy.pretrained_path = policy_path

-            # When no device or use_amp are given, use the one from training config.
-            if self.device is None or self.use_amp is None:
-                train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
-                if self.device is None:
-                    self.device = train_cfg.device
-                if self.use_amp is None:
-                    self.use_amp = train_cfg.use_amp
-
-            # Automatically switch to available device if necessary
-            if not is_torch_device_available(self.device):
-                auto_device = auto_select_torch_device()
-                logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
-                self.device = auto_device
-
-            # Automatically deactivate AMP if necessary
-            if self.use_amp and not is_amp_available(self.device):
-                logging.warning(
-                    f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
-                )
-                self.use_amp = False
-
         else:
             logging.warning(
                 "No pretrained path was provided, evaluated policy will be built from scratch (random weights)."

@@ -87,11 +59,6 @@ class EvalPipelineConfig:
             eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
             self.output_dir = Path("outputs/eval") / eval_dir

-        if self.device is None:
-            raise ValueError("Set one of the following device: cuda, cpu or mps")
-        elif self.device == "cuda" and self.use_amp is None:
-            raise ValueError("Set 'use_amp' to True or False.")
-
     @classmethod
     def __get_path_fields__(cls) -> list[str]:
         """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
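With the fields removed from `EvalPipelineConfig`, an eval run inherits device and AMP settings from the policy config stored alongside the checkpoint. A sketch (the checkpoint path is hypothetical):

from lerobot.configs.policies import PreTrainedConfig

cfg = PreTrainedConfig.from_pretrained("outputs/train/act/checkpoints/last/pretrained_model")
print(cfg.device, cfg.use_amp)  # restored from the checkpoint, then validated in __post_init__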
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
+import logging
 import os
 from dataclasses import dataclass, field
 from pathlib import Path

@@ -25,6 +26,7 @@ from huggingface_hub.errors import HfHubHTTPError
 from lerobot.common.optim.optimizers import OptimizerConfig
 from lerobot.common.optim.schedulers import LRSchedulerConfig
 from lerobot.common.utils.hub import HubMixin
+from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

 # Generic variable that is either PreTrainedConfig or a subclass thereof

@@ -53,8 +55,24 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
     input_features: dict[str, PolicyFeature] = field(default_factory=dict)
     output_features: dict[str, PolicyFeature] = field(default_factory=dict)

+    device: str | None = None  # cuda | cpu | mps
+    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
+    # automatic gradient scaling is used.
+    use_amp: bool = False
+
     def __post_init__(self):
         self.pretrained_path = None
+        if not self.device or not is_torch_device_available(self.device):
+            auto_device = auto_select_torch_device()
+            logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
+            self.device = auto_device.type
+
+        # Automatically deactivate AMP if necessary
+        if self.use_amp and not is_amp_available(self.device):
+            logging.warning(
+                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
+            )
+            self.use_amp = False

     @property
     def type(self) -> str:
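The new `__post_init__` in behavior terms: an unset or unavailable device is replaced by an auto-selected one (with a warning), and AMP is switched off where it is not supported. A sketch using ACTConfig as a concrete subclass (import path assumed):

from lerobot.common.policies.act.configuration_act import ACTConfig

cfg = ACTConfig()  # no device given: auto_select_torch_device() picks one
print(cfg.device)  # e.g. "cuda" on a GPU machine, otherwise "mps" or "cpu"

cfg = ACTConfig(device="mps", use_amp=True)
print(cfg.device, cfg.use_amp)  # on non-MPS machines the device falls back; on MPS, AMP is deactivated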
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import datetime as dt
-import logging
 import os
 from dataclasses import dataclass, field
 from pathlib import Path

@@ -26,7 +25,6 @@ from lerobot.common import envs
 from lerobot.common.optim import OptimizerConfig
 from lerobot.common.optim.schedulers import LRSchedulerConfig
 from lerobot.common.utils.hub import HubMixin
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available
 from lerobot.configs import parser
 from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
 from lerobot.configs.policies import PreTrainedConfig

@@ -48,10 +46,6 @@ class TrainPipelineConfig(HubMixin):
     # Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
     # regardless of what's provided with the training command at the time of resumption.
     resume: bool = False
-    device: str | None = None  # cuda | cpu | mps
-    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
-    # automatic gradient scaling is used.
-    use_amp: bool = False
     # `seed` is used for training (eg: model initialization, dataset shuffling)
     # AND for the evaluation environments.
     seed: int | None = 1000

@@ -74,18 +68,6 @@ class TrainPipelineConfig(HubMixin):
         self.checkpoint_path = None

     def validate(self):
-        if not self.device:
-            logging.warning("No device specified, trying to infer device automatically")
-            device = auto_select_torch_device()
-            self.device = device.type
-
-        # Automatically deactivate AMP if necessary
-        if self.use_amp and not is_amp_available(self.device):
-            logging.warning(
-                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
-            )
-            self.use_amp = False
-
         # HACK: We parse again the cli args here to get the pretrained paths if there was some.
         policy_path = parser.get_path_arg("policy")
         if policy_path:
@@ -267,7 +267,7 @@ def record(
     )

     # Load pretrained policy
-    policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
+    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)

     if not robot.is_connected:
         robot.connect()

@@ -298,8 +298,6 @@ def record(
         episode_time_s=cfg.episode_time_s,
         display_cameras=cfg.display_cameras,
         policy=policy,
-        device=cfg.device,
-        use_amp=cfg.use_amp,
         fps=cfg.fps,
         single_task=cfg.single_task,
     )
@@ -458,7 +458,7 @@ def eval_main(cfg: EvalPipelineConfig):
     logging.info(pformat(asdict(cfg)))

     # Check device is available
-    device = get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.policy.device, log=True)

     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True

@@ -470,14 +470,14 @@ def eval_main(cfg: EvalPipelineConfig):
     env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

     logging.info("Making policy.")

     policy = make_policy(
         cfg=cfg.policy,
-        device=device,
         env_cfg=cfg.env,
     )
     policy.eval()

-    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
         info = eval_policy(
             env,
             policy,
@@ -120,7 +120,7 @@ def train(cfg: TrainPipelineConfig):
     set_seed(cfg.seed)

     # Check device is available
-    device = get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.policy.device, log=True)
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True

@@ -138,13 +138,12 @@ def train(cfg: TrainPipelineConfig):
     logging.info("Creating policy")
     policy = make_policy(
         cfg=cfg.policy,
-        device=device,
         ds_meta=dataset.meta,
     )

     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
-    grad_scaler = GradScaler(device, enabled=cfg.use_amp)
+    grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)

     step = 0  # number of policy updates (forward + backward + optim)

@@ -218,7 +217,7 @@ def train(cfg: TrainPipelineConfig):
             cfg.optimizer.grad_clip_norm,
             grad_scaler=grad_scaler,
             lr_scheduler=lr_scheduler,
-            use_amp=cfg.use_amp,
+            use_amp=cfg.policy.use_amp,
         )

         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we

@@ -249,7 +248,10 @@ def train(cfg: TrainPipelineConfig):
         if cfg.env and is_eval_step:
             step_id = get_step_identifier(step, cfg.steps)
             logging.info(f"Eval policy at step {step}")
-            with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+            with (
+                torch.no_grad(),
+                torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
+            ):
                 eval_info = eval_policy(
                     eval_env,
                     policy,
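train.py and eval.py now key their autocast gating on `cfg.policy.use_amp`. The pattern, reduced to a self-contained runnable sketch:

from contextlib import nullcontext

import torch

device = torch.device("cpu")
use_amp = False  # stands in for cfg.policy.use_amp
with torch.no_grad(), torch.autocast(device_type=device.type) if use_amp else nullcontext():
    pass  # evaluation rollouts run here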
@@ -33,12 +33,11 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
         # TODO(rcadene, aliberts): remove dataset download
         dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
         policy=make_policy_config(policy_name, **policy_kwargs),
-        device="cpu",
     )
     train_cfg.validate()  # Needed for auto-setting some parameters

     dataset = make_dataset(train_cfg)
-    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=train_cfg.device)
+    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
     policy.train()

     optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
@@ -52,7 +52,7 @@ from lerobot.common.robot_devices.control_configs import (
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
 from tests.test_robots import make_robot
-from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
+from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot


 @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)

@@ -184,7 +184,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
     replay(robot, replay_cfg)

     policy_cfg = ACTConfig()
-    policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE)
+    policy = make_policy(policy_cfg, ds_meta=dataset.meta)

     out_dir = tmp_path / "logger"

@@ -229,8 +229,6 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
         display_cameras=False,
         play_sounds=False,
         num_image_writer_processes=num_image_writer_processes,
-        device=DEVICE,
-        use_amp=False,
     )

     rec_eval_cfg.policy = PreTrainedConfig.from_pretrained(pretrained_policy_path)
@@ -45,7 +45,7 @@ from lerobot.common.robot_devices.robots.utils import make_robot
 from lerobot.configs.default import DatasetConfig
 from lerobot.configs.train import TrainPipelineConfig
 from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
-from tests.utils import DEVICE, require_x86_64_kernel
+from tests.utils import require_x86_64_kernel


 @pytest.fixture

@@ -349,7 +349,6 @@ def test_factory(env_name, repo_id, policy_name):
         dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
         env=make_env_config(env_name),
         policy=make_policy_config(policy_name),
-        device=DEVICE,
     )

     dataset = make_dataset(cfg)
@@ -143,12 +143,11 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
         dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
         policy=make_policy_config(policy_name, **policy_kwargs),
         env=make_env_config(env_name, **env_kwargs),
-        device=DEVICE,
     )

     # Check that we can make the policy object.
     dataset = make_dataset(train_cfg)
-    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=DEVICE)
+    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
     assert isinstance(policy, PreTrainedPolicy)

     # Check that we run select_actions and get the appropriate output.

@@ -214,7 +213,6 @@ def test_act_backbone_lr():
         # TODO(rcadene, aliberts): remove dataset download
         dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
         policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
-        device=DEVICE,
     )
     cfg.validate()  # Needed for auto-setting some parameters

@@ -222,7 +220,7 @@ def test_act_backbone_lr():
     assert cfg.policy.optimizer_lr_backbone == 0.001

     dataset = make_dataset(cfg)
-    policy = make_policy(cfg.policy, device=DEVICE, ds_meta=dataset.meta)
+    policy = make_policy(cfg.policy, ds_meta=dataset.meta)
     optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
     assert len(optimizer.param_groups) == 2
     assert optimizer.param_groups[0]["lr"] == cfg.policy.optimizer_lr