diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 7bb7f167..bcca8976 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -39,6 +39,12 @@ class SACConfig: "observation.environment_state": "min_max", } ) + input_normalization_params: dict[str, dict[str, list[float]]] = field( + default_factory=lambda: { + "observation.image": {"mean": [[0.485, 0.456, 0.406]], "std": [[0.229, 0.224, 0.225]]}, + "observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]}, + } + ) output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"}) output_normalization_params: dict[str, dict[str, list[float]]] = field( default_factory=lambda: { diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 9faeeeb6..a3d5d8e6 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -51,18 +51,20 @@ class SACPolicy( if config is None: config = SACConfig() self.config = config + if config.input_normalization_modes is not None: + input_normalization_params = _convert_normalization_params_to_tensor( + config.input_normalization_params + ) self.normalize_inputs = Normalize( - config.input_shapes, config.input_normalization_modes, dataset_stats + config.input_shapes, config.input_normalization_modes, input_normalization_params ) else: self.normalize_inputs = nn.Identity() - output_normalization_params = {} - for outer_key, inner_dict in config.output_normalization_params.items(): - output_normalization_params[outer_key] = {} - for key, value in inner_dict.items(): - output_normalization_params[outer_key][key] = torch.tensor(value) + output_normalization_params = _convert_normalization_params_to_tensor( + config.output_normalization_params + ) # HACK: This is hacky and should be removed dataset_stats = dataset_stats or output_normalization_params @@ -75,7 +77,7 @@ class SACPolicy( # NOTE: For images the encoder should be shared between the actor and critic if config.shared_encoder: - encoder_critic = SACObservationEncoder(config) + encoder_critic = SACObservationEncoder(config, self.normalize_inputs) encoder_actor: SACObservationEncoder = encoder_critic else: encoder_critic = SACObservationEncoder(config) @@ -92,6 +94,7 @@ class SACPolicy( for _ in range(config.num_critics) ] ), + output_normalization=self.normalize_targets, ) self.critic_target = CriticEnsemble( @@ -105,6 +108,7 @@ class SACPolicy( for _ in range(config.num_critics) ] ), + output_normalization=self.normalize_targets, ) self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) @@ -122,7 +126,7 @@ class SACPolicy( # TODO (azouitine): Handle the case where the temparameter is a fixed # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise # it triggers "can't optimize a non-leaf Tensor" - self.log_alpha = torch.zeros(1, requires_grad=True, device=torch.device("cuda:0")) + self.log_alpha = torch.tensor([0.0], requires_grad=True, device=torch.device("mps")) self.temperature = self.log_alpha.exp().item() def reset(self): @@ -313,12 +317,14 @@ class CriticEnsemble(nn.Module): self, encoder: Optional[nn.Module], network_list: nn.ModuleList, + output_normalization: nn.Module, init_final: Optional[float] = None, ): super().__init__() self.encoder = encoder self.network_list = network_list self.init_final = init_final + self.output_normalization = 
output_normalization
        self.parameters_to_optimize = []

        # Handle the case where a part of the encoder if frozen
@@ -358,6 +364,10 @@ class CriticEnsemble(nn.Module):
         device = get_device_from_parameters(self)
         # Move each tensor in observations to device
         observations = {k: v.to(device) for k, v in observations.items()}
+        # NOTE: We normalize the actions; it helps with sample efficiency
+        actions: dict[str, torch.Tensor] = {"action": actions}
+        # NOTE: The normalization layer takes a dict as input and returns a dict, hence the wrapping
+        actions = self.output_normalization(actions)["action"]
         actions = actions.to(device)

         obs_enc = observations if self.encoder is None else self.encoder(observations)
@@ -474,17 +484,18 @@ class Policy(nn.Module):
 class SACObservationEncoder(nn.Module):
     """Encode image and/or state vector observations."""

-    def __init__(self, config: SACConfig):
+    def __init__(self, config: SACConfig, input_normalizer: nn.Module):
         """
         Creates encoders for pixel and/or state modalities.
         """
         super().__init__()
         self.config = config
+        self.input_normalization = input_normalizer
         self.has_pretrained_vision_encoder = False
         self.parameters_to_optimize = []

         self.aggregation_size: int = 0
-        if "observation.image" in config.input_shapes:
+        if any("observation.image" in key for key in config.input_shapes):
             self.camera_number = config.camera_number

             if self.config.vision_encoder_name is not None:
@@ -534,8 +545,9 @@ class SACObservationEncoder(nn.Module):
         over all features.
         """
         feat = []
+        obs_dict = self.input_normalization(obs_dict)
         # Concatenate all images along the channel dimension.
-        image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
+        image_keys = [k for k in obs_dict if k.startswith("observation.image")]

         for image_key in image_keys:
             enc_feat = self.image_enc_layers(obs_dict[image_key])
@@ -681,6 +693,18 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
     return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))


+def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
+    converted_params = {}
+    for outer_key, inner_dict in normalization_params.items():
+        converted_params[outer_key] = {}
+        for key, value in inner_dict.items():
+            converted_params[outer_key][key] = torch.tensor(value)
+            if "image" in outer_key:
+                converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
+
+    return converted_params
+
+
 if __name__ == "__main__":
     # Test the SACObservationEncoder
     import time
diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py
index 4e276e16..e4460f5f 100644
--- a/lerobot/common/utils/utils.py
+++ b/lerobot/common/utils/utils.py
@@ -18,6 +18,7 @@ import os
 import os.path as osp
 import platform
 import random
+import time
 from contextlib import contextmanager
 from datetime import datetime, timezone
 from pathlib import Path
@@ -217,3 +218,28 @@ def log_say(text, play_sounds, blocking=False):

     if play_sounds:
         say(text, blocking)
+
+
+class TimerManager:
+    def __init__(self, elapsed_time_list: list[float] | None = None, label="Elapsed time", log=True):
+        self.label = label
+        self.elapsed_time_list = elapsed_time_list
+        self.log = log
+        self.elapsed = 0.0
+
+    def __enter__(self):
+        self.start = time.perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.elapsed: float = time.perf_counter() - self.start
+
+        if self.elapsed_time_list is not None:
+            self.elapsed_time_list.append(self.elapsed)
+
+        if self.log:
+            print(f"{self.label}: 
{self.elapsed:.6f} seconds") + + @property + def elapsed_seconds(self): + return self.elapsed diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index a3ff1d41..7750ba3a 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -2,6 +2,7 @@ defaults: - _self_ - env: pusht - policy: diffusion + - robot: so100 hydra: run: diff --git a/lerobot/configs/env/so100_real.yaml b/lerobot/configs/env/so100_real.yaml index 8e65d72f..862ea951 100644 --- a/lerobot/configs/env/so100_real.yaml +++ b/lerobot/configs/env/so100_real.yaml @@ -8,3 +8,20 @@ env: state_dim: 6 action_dim: 6 fps: ${fps} + device: mps + + wrapper: + crop_params_dict: + observation.images.laptop: [58, 89, 357, 455] + observation.images.phone: [3, 4, 471, 633] + resize_size: [128, 128] + control_time_s: 20 + reset_follower_pos: true + use_relative_joint_positions: true + reset_time_s: 10 + display_cameras: false + delta_action: 0.1 + + reward_classifier: + pretrained_path: outputs/classifier/checkpoints/best/pretrained_model + config_path: lerobot/configs/policy/hilserl_classifier.yaml diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py index be5c0818..28c582d2 100644 --- a/lerobot/scripts/server/actor_server.py +++ b/lerobot/scripts/server/actor_server.py @@ -31,16 +31,21 @@ from omegaconf import DictConfig from torch import nn # TODO: Remove the import of maniskill -from lerobot.common.envs.factory import make_maniskill_env -from lerobot.common.envs.utils import preprocess_maniskill_observation +# from lerobot.common.envs.factory import make_maniskill_env +# from lerobot.common.envs.utils import preprocess_maniskill_observation from lerobot.common.policies.factory import make_policy from lerobot.common.policies.sac.modeling_sac import SACPolicy +from lerobot.common.robot_devices.control_utils import busy_wait +from lerobot.common.robot_devices.robots.factory import make_robot +from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.utils.utils import ( + TimerManager, get_safe_torch_device, set_global_seed, ) from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device +from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env logging.basicConfig(level=logging.INFO) @@ -152,7 +157,15 @@ def serve_actor_service(port=50052): server.wait_for_termination() -def act_with_policy(cfg: DictConfig): +def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device): + if not parameters_queue.empty(): + logging.debug("[ACTOR] Load new parameters from Learner.") + state_dict = parameters_queue.get() + state_dict = move_state_dict_to_device(state_dict, device=device) + policy.load_state_dict(state_dict) + + +def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module): """ Executes policy interaction within the environment. 
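A minimal usage sketch of the TimerManager added above, mirroring how the actor loop below times policy inference (assumes a `policy` and `obs` already exist; this snippet is illustrative and not part of the patch):

from lerobot.common.utils.utils import TimerManager

inference_times: list[float] = []
with TimerManager(elapsed_time_list=inference_times, label="Policy inference time", log=False):
    action = policy.select_action(batch=obs)  # timed section
policy_fps = 1.0 / (inference_times[-1] + 1e-9)  # last measured duration converted to Hz
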
@@ -165,9 +178,7 @@ def act_with_policy(cfg: DictConfig): logging.info("make_env online") - # online_env = make_env(cfg, n_envs=1) - # TODO: Remove the import of maniskill and unifiy with make env - online_env = make_maniskill_env(cfg, n_envs=1) + online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg.env) set_global_seed(cfg.seed) device = get_safe_torch_device(cfg.device, log=True) @@ -177,6 +188,16 @@ def act_with_policy(cfg: DictConfig): logging.info("make_policy") + # HACK: This is an ugly hack to pass the normalization parameters to the policy + # Because the action space is dynamic so we override the output normalization parameters + # it's ugly, we know ... and we will fix it + min_action_space: list = online_env.action_space.spaces[0].low.tolist() + max_action_space: list = online_env.action_space.spaces[0].high.tolist() + output_normalization_params: dict[dict[str, list]] = { + "action": {"min": min_action_space, "max": max_action_space} + } + cfg.policy.output_normalization_params = output_normalization_params + ### Instantiate the policy in both the actor and learner processes ### To avoid sending a SACPolicy object through the port, we create a policy intance ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters @@ -187,92 +208,41 @@ def act_with_policy(cfg: DictConfig): # Hack: But if we do online training, we do not need dataset_stats dataset_stats=None, # TODO: Handle resume training + device=device, ) - # pretrained_policy_name_or_path=None, - # device=device, - # ) policy = torch.compile(policy) assert isinstance(policy, nn.Module) - # HACK for maniskill obs, info = online_env.reset() - # obs = preprocess_observation(obs) - obs = preprocess_maniskill_observation(obs) - obs = {key: obs[key].to(device, non_blocking=True) for key in obs} - # NOTE: For the moment we will solely handle the case of a single environment sum_reward_episode = 0 list_transition_to_send_to_learner = [] - list_policy_fps = [] + list_policy_time = [] for interaction_step in range(cfg.training.online_steps): if interaction_step >= cfg.training.online_step_before_learning: - start = time.perf_counter() - action = policy.select_action(batch=obs) - list_policy_fps.append(1.0 / (time.perf_counter() - start + 1e-9)) - if list_policy_fps[-1] < cfg.fps: - logging.warning( - f"[ACTOR] policy frame rate {list_policy_fps[-1]} during interaction step {interaction_step} is below the required control frame rate {cfg.fps}" - ) + # Time policy inference and check if it meets FPS requirement + with TimerManager( + elapsed_time_list=list_policy_time, label="Policy inference time", log=False + ) as timer: # noqa: F841 + action = policy.select_action(batch=obs) * 0.0 + policy_fps = 1.0 / (list_policy_time[-1] + 1e-9) - next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy()) + log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) + + next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy()) else: + # TODO (azouitine): Make a custom space for torch tensor action = online_env.action_space.sample() next_obs, reward, done, truncated, info = online_env.step(action) - # HACK - action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True) - # HACK: For maniskill - # next_obs = preprocess_observation(next_obs) - next_obs = preprocess_maniskill_observation(next_obs) - next_obs = {key: next_obs[key].to(device, non_blocking=True) 
for key in obs} - sum_reward_episode += float(reward[0]) + # HACK: We have only one env but we want to batch it, it will be resolved with the torch box + action = torch.from_numpy(action[0]).to(device, non_blocking=True).unsqueeze(dim=0) - # Because we are using a single environment we can index at zero - if done[0].item() or truncated[0].item(): - # TODO: Handle logging for episode information - logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") + sum_reward_episode += float(reward) - if not parameters_queue.empty(): - logging.debug("[ACTOR] Load new parameters from Learner.") - state_dict = parameters_queue.get() - state_dict = move_state_dict_to_device(state_dict, device=device) - # strict=False for the case when the image encoder is frozen and not sent through - # the network. Becareful might cause issues if the wrong keys are passed - policy.actor.load_state_dict(state_dict, strict=False) - - if len(list_transition_to_send_to_learner) > 0: - logging.debug( - f"[ACTOR] Sending {len(list_transition_to_send_to_learner)} transitions to Learner." - ) - message_queue.put(ActorInformation(transition=list_transition_to_send_to_learner)) - list_transition_to_send_to_learner = [] - - stats = {} - if len(list_policy_fps) > 0: - policy_fps = mean(list_policy_fps) - quantiles_90 = quantiles(list_policy_fps, n=10)[-1] - logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}") - logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}") - stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90} - list_policy_fps = [] - - # Send episodic reward to the learner - message_queue.put( - ActorInformation( - interaction_message={ - "Episodic reward": sum_reward_episode, - "Interaction step": interaction_step, - **stats, - } - ) - ) - sum_reward_episode = 0.0 - - # TODO (michel-aractingi): Label the reward - # if config.label_reward_on_actor: - # reward = reward_classifier(obs) + # NOTE: We overide the action if the intervention is True, because the action applied is the intervention action if info["is_intervention"]: # TODO: Check the shape action = info["action_intervention"] @@ -291,17 +261,85 @@ def act_with_policy(cfg: DictConfig): # assign obs to the next obs and continue the rollout obs = next_obs + # HACK: We have only one env but we want to batch it, it will be resolved with the torch box + # Because we are using a single environment we can index at zero + if done or truncated: + # TODO: Handle logging for episode information + logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") + + # update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device) + + if len(list_transition_to_send_to_learner) > 0: + send_transitions_in_chunks( + transitions=list_transition_to_send_to_learner, message_queue=message_queue, chunk_size=4 + ) + list_transition_to_send_to_learner = [] + + stats = get_frequency_stats(list_policy_time) + list_policy_time.clear() + + # Send episodic reward to the learner + message_queue.put( + ActorInformation( + interaction_message={ + "Episodic reward": sum_reward_episode, + "Interaction step": interaction_step, + **stats, + } + ) + ) + sum_reward_episode = 0.0 + obs, info = online_env.reset() + + +def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int = 100): + """Send transitions to learner in smaller chunks to avoid network issues. 
+ + Args: + transitions: List of transitions to send + message_queue: Queue to send messages to learner + chunk_size: Size of each chunk to send + """ + for i in range(0, len(transitions), chunk_size): + chunk = transitions[i : i + chunk_size] + logging.debug(f"[ACTOR] Sending chunk of {len(chunk)} transitions to Learner.") + message_queue.put(ActorInformation(transition=chunk)) + + +def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]: + stats = {} + list_policy_fps = [1.0 / t for t in list_policy_time] + if len(list_policy_fps) > 0: + policy_fps = mean(list_policy_fps) + quantiles_90 = quantiles(list_policy_fps, n=10)[-1] + logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}") + logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}") + stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90} + return stats + + +def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int): + if policy_fps < cfg.fps: + logging.warning( + f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}" + ) + @hydra.main(version_base="1.2", config_name="default", config_path="../../configs") def actor_cli(cfg: dict): - port = cfg.actor_learner_config.port - server_thread = Thread(target=serve_actor_service, args=(port,), daemon=True) - server_thread.start() + robot = make_robot(cfg=cfg.robot) + + server_thread = Thread(target=serve_actor_service, args=(cfg.actor_learner_config.port,), daemon=True) + reward_classifier = get_classifier( + pretrained_path=cfg.env.reward_classifier.pretrained_path, + config_path=cfg.env.reward_classifier.config_path, + ) policy_thread = Thread( target=act_with_policy, daemon=True, - args=(cfg,), + args=(cfg, robot, reward_classifier), ) + server_thread.start() policy_thread.start() policy_thread.join() server_thread.join() diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 828116b9..8be21365 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -56,10 +56,10 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr } # If complementary_info is present, move its tensors to CPU - if transition["complementary_info"] is not None: - transition["complementary_info"] = { - key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items() - } + # if transition["complementary_info"] is not None: + # transition["complementary_info"] = { + # key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items() + # } return transition @@ -309,6 +309,7 @@ class ReplayBuffer: def sample(self, batch_size: int) -> BatchTransition: """Sample a random batch of transitions and collate them into batched tensors.""" + batch_size = min(batch_size, len(self.memory)) list_of_transitions = random.sample(self.memory, batch_size) # -- Build batched states -- @@ -341,9 +342,6 @@ class ReplayBuffer: batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to( self.device ) - batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to( - self.device - ) # Return a BatchTransition typed dict return BatchTransition( @@ -531,30 +529,31 @@ def concatenate_batch_transitions( # if __name__ == "__main__": -# dataset_name = "lerobot/pusht_image" -# dataset = LeRobotDataset(repo_id=dataset_name, episodes=range(1, 3)) -# replay_buffer = 
ReplayBuffer.from_lerobot_dataset( -# lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"] -# ) -# replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted") -# for i in range(len(replay_buffer_converted)): -# replay_convert = replay_buffer_converted[i] -# dataset_convert = dataset[i] -# for key in replay_convert.keys(): -# if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}: -# continue -# if key in dataset_convert.keys(): -# assert torch.equal(replay_convert[key], dataset_convert[key]) -# print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}") -# re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset( -# replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu" -# ) -# for _ in range(20): -# batch = re_reconverted_dataset.sample(32) +# dataset_name = "aractingi/push_green_cube_hf_cropped_resized" +# dataset = LeRobotDataset(repo_id=dataset_name) -# for key in batch.keys(): -# if key in {"state", "next_state"}: -# for key_state in batch[key].keys(): -# print(key_state, batch[key][key_state].size()) -# continue -# print(key, batch[key].size()) +# replay_buffer = ReplayBuffer.from_lerobot_dataset( +# lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"] +# ) +# replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted") +# for i in range(len(replay_buffer_converted)): +# replay_convert = replay_buffer_converted[i] +# dataset_convert = dataset[i] +# for key in replay_convert.keys(): +# if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}: +# continue +# if key in dataset_convert.keys(): +# assert torch.equal(replay_convert[key], dataset_convert[key]) +# print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}") +# re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset( +# replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu" +# ) +# for _ in range(20): +# batch = re_reconverted_dataset.sample(32) + +# for key in batch.keys(): +# if key in {"state", "next_state"}: +# for key_state in batch[key].keys(): +# print(key_state, batch[key][key_state].size()) +# continue +# print(key, batch[key].size()) diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 40dc2784..5bf51868 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -4,7 +4,6 @@ import time from threading import Lock from typing import Annotated, Any, Callable, Dict, Optional, Tuple -import cv2 import gymnasium as gym import numpy as np import torch @@ -20,10 +19,15 @@ logging.basicConfig(level=logging.INFO) class HILSerlRobotEnv(gym.Env): """ - Gym-like environment wrapper for robot policy evaluation. + Gym-compatible environment for evaluating robotic control policies with integrated human intervention. - This wrapper provides a consistent interface for interacting with the robot, - following the OpenAI Gym environment conventions. + This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both relative (delta) + and absolute joint position commands and automatically configures its observation and action spaces based on the robot's + sensors and configuration. 
+ + The environment can switch between executing actions from a policy or using teleoperated actions (human intervention) during + each step. When teleoperation is used, the override action is captured and returned in the `info` dict along with a flag + `is_intervention`. """ def __init__( @@ -31,32 +35,34 @@ class HILSerlRobotEnv(gym.Env): robot, use_delta_action_space: bool = True, delta: float | None = None, - display_cameras=False, + display_cameras: bool = False, ): """ - Initialize the robot environment. + Initialize the HILSerlRobotEnv environment. + + The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup + supports both relative (delta) adjustments and absolute joint positions for controlling the robot. Args: - robot: The robot interface object - reward_classifier: Optional reward classifier - fps: Frames per second for control - control_time_s: Total control time for each episode - display_cameras: Whether to display camera feeds - output_normalization_params_action: Bound parameters for the action space - delta: The delta for the relative joint position action space + robot: The robot interface object used to connect and interact with the physical robot. + use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute + joint positions are used. + delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between + 0 and 1 when using a delta action space. + display_cameras (bool): If True, the robot's camera feeds will be displayed during execution. """ super().__init__() self.robot = robot self.display_cameras = display_cameras - # connect robot + # Connect to the robot if not already connected. if not self.robot.is_connected: self.robot.connect() self.initial_follower_position = robot.follower_arms["main"].read("Present_Position") - # Episode tracking + # Episode tracking. self.current_step = 0 self.episode_data = None @@ -64,6 +70,7 @@ class HILSerlRobotEnv(gym.Env): self.use_delta_action_space = use_delta_action_space self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position") + # Retrieve the size of the joint position interval bound. self.relative_bounds_size = ( self.robot.config.joint_position_relative_bounds["max"] - self.robot.config.joint_position_relative_bounds["min"] @@ -73,20 +80,26 @@ class HILSerlRobotEnv(gym.Env): self.robot.config.max_relative_target = self.delta_relative_bounds_size.float() - # Dynamically determine observation and action spaces + # Dynamically configure the observation and action spaces. self._setup_spaces() def _setup_spaces(self): """ - Dynamically determine observation and action spaces based on robot capabilities. + Dynamically configure the observation and action spaces based on the robot's capabilities. - This method should be customized based on the specific robot's observation - and action representations. + Observation Space: + - For keys with "image": A Box space with pixel values ranging from 0 to 255. + - For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range. + + Action Space: + - The action space is defined as a Tuple where: + • The first element is a Box space representing joint position commands. It is defined as relative (delta) + or absolute, based on the configuration. + • The second element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation). 
""" - # Example space setup - you'll need to adapt this to your specific robot example_obs = self.robot.capture_observation() - # Observation space (assuming image-based observations) + # Define observation spaces for images and other states. image_keys = [key for key in example_obs if "image" in key] state_keys = [key for key in example_obs if "image" not in key] observation_spaces = { @@ -102,7 +115,7 @@ class HILSerlRobotEnv(gym.Env): self.observation_space = gym.spaces.Dict(observation_spaces) - # Action space (assuming joint positions) + # Define the action space for joint positions along with setting an intervention flag. action_dim = len(self.robot.follower_arms["main"].read("Present_Position")) if self.use_delta_action_space: action_space_robot = gym.spaces.Box( @@ -128,18 +141,24 @@ class HILSerlRobotEnv(gym.Env): def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: """ - Reset the environment to initial state. + Reset the environment to its initial state. + This method resets the step counter and clears any episodic data. + + Args: + seed (Optional[int]): A seed for random number generation to ensure reproducibility. + options (Optional[dict]): Additional options to influence the reset behavior. Returns: - observation (dict): Initial observation - info (dict): Additional information + A tuple containing: + - observation (dict): The initial sensor observation. + - info (dict): A dictionary with supplementary information, including the key "initial_position". """ super().reset(seed=seed, options=options) - # Capture initial observation + # Capture the initial observation. observation = self.robot.capture_observation() - # Reset tracking variables + # Reset episode tracking variables. self.current_step = 0 self.episode_data = None @@ -149,28 +168,38 @@ class HILSerlRobotEnv(gym.Env): self, action: Tuple[np.ndarray, bool] ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]: """ - Take a step in the environment. + Execute a single step within the environment using the specified action. + + The provided action is a tuple comprised of: + • A policy action (joint position commands) that may be either in absolute values or as a delta. + • A boolean flag indicating whether teleoperation (human intervention) should be used for this step. + + Behavior: + - When the intervention flag is False, the environment processes and sends the policy action to the robot. + - When True, a teleoperation step is executed. If using a delta action space, an absolute teleop action is converted + to relative change based on the current joint positions. Args: - action tuple(np.ndarray, bool): - Policy action to be executed on the robot and boolean to determine - whether to choose policy action or expert action. + action (tuple): A tuple with two elements: + - policy_action (np.ndarray or torch.Tensor): The commanded joint positions. + - intervention_bool (bool): True if the human operator intervenes by providing a teleoperation input. Returns: - observation (dict): Next observation - reward (float): Reward for this step - terminated (bool): Whether the episode has terminated - truncated (bool): Whether the episode was truncated - info (dict): Additional information + tuple: A tuple containing: + - observation (dict): The new sensor observation after taking the step. + - reward (float): The step reward (default is 0.0 within this wrapper). + - terminated (bool): True if the episode has reached a terminal state. 
+ - truncated (bool): True if the episode was truncated (e.g., time constraints). + - info (dict): Additional debugging information including: + ◦ "action_intervention": The teleop action if intervention was used. + ◦ "is_intervention": Flag indicating whether teleoperation was employed. """ - # The actions recieved are the in form of a tuple containing the policy action and an intervention bool - # The boolean inidicated whether we will use the expert's actions (through teleoperation) or the policy actions policy_action, intervention_bool = action teleop_action = None self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position") if isinstance(policy_action, torch.Tensor): policy_action = policy_action.cpu().numpy() - olicy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high) + policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high) if not intervention_bool: if self.use_delta_action_space: target_joint_positions = self.current_joint_positions + self.delta * policy_action @@ -180,26 +209,26 @@ class HILSerlRobotEnv(gym.Env): observation = self.robot.capture_observation() else: observation, teleop_action = self.robot.teleop_step(record_data=True) - teleop_action = teleop_action["action"] # teleop step returns torch tensors but in a dict + teleop_action = teleop_action["action"] # Convert tensor to appropriate format - # teleop actions are returned in absolute joint space - # If we are using a relative joint position action space, - # there will be a mismatch between the spaces of the policy and teleop actions - # Solution is to transform the teleop actions into relative space. - # teleop relative action is: + # When applying the delta action space, convert teleop absolute values to relative differences. if self.use_delta_action_space: teleop_action = teleop_action - self.current_joint_positions if torch.any(teleop_action < -self.delta_relative_bounds_size * self.delta) and torch.any( teleop_action > self.delta_relative_bounds_size ): print( - f"relative teleop delta exceeded bounds {self.delta_relative_bounds_size}, teleop_action {teleop_action}\n" + f"Relative teleop delta exceeded bounds {self.delta_relative_bounds_size}, teleop_action {teleop_action}\n" f"lower bounds condition {teleop_action < -self.delta_relative_bounds_size}\n" f"upper bounds condition {teleop_action > self.delta_relative_bounds_size}" ) + teleop_action = torch.clamp( teleop_action, -self.delta_relative_bounds_size, self.delta_relative_bounds_size ) + # NOTE: To mimic the shape of a neural network output, we add a batch dimension to the teleop action. + if teleop_action.dim() == 1: + teleop_action = teleop_action.unsqueeze(0) self.current_step += 1 @@ -217,7 +246,7 @@ class HILSerlRobotEnv(gym.Env): def render(self): """ - Render the environment (in this case, display camera feeds). + Render the current state of the environment by displaying the robot's camera feeds. """ import cv2 @@ -231,7 +260,10 @@ class HILSerlRobotEnv(gym.Env): def close(self): """ - Close the environment and disconnect the robot. + Close the environment and clean up resources by disconnecting the robot. + + If the robot is currently connected, this method properly terminates the connection to ensure that all + associated resources are released. 
""" if self.robot.is_connected: self.robot.disconnect() @@ -250,48 +282,19 @@ class ActionRepeatWrapper(gym.Wrapper): return obs, reward, done, truncated, info -class RelativeJointPositionActionWrapper(gym.Wrapper): - def __init__( - self, - env: HILSerlRobotEnv, - # output_normalization_params_action: dict[str, list[float]], - delta: float = 0.1, - ): - super().__init__(env) - self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position") - self.delta = delta - if delta > 1: - raise ValueError("Delta should be less than 1") - - def step(self, action): - action_joint = action - self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position") - if isinstance(self.env.action_space, gym.spaces.Tuple): - action_joint = action[0] - joint_positions = self.joint_positions + (self.delta * action_joint) - # clip the joint positions to the joint limits with the action space - joint_positions = np.clip(joint_positions, self.action_space.low, self.action_space.high) - - if isinstance(self.env.action_space, gym.spaces.Tuple): - return self.env.step((joint_positions, action[1])) - - obs, reward, terminated, truncated, info = self.env.step(joint_positions) - if info["is_intervention"]: - # teleop actions are returned in absolute joint space - # If we are using a relative joint position action space, - # there will be a mismatch between the spaces of the policy and teleop actions - # Solution is to transform the teleop actions into relative space. - self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position") - teleop_action = info["action_intervention"] # teleop actions are in absolute joint space - relative_teleop_action = (teleop_action - self.joint_positions) / self.delta - info["action_intervention"] = relative_teleop_action - - return self.env.step(joint_positions) - - class RewardWrapper(gym.Wrapper): - def __init__(self, env, reward_classifier: Optional[None], device: torch.device = "cuda"): + def __init__(self, env, reward_classifier, device: torch.device = "cuda"): + """ + Wrapper to add reward prediction to the environment, it use a trained classifer. 
+ + Args: + env: The environment to wrap + reward_classifier: The reward classifier model + device: The device to run the model on + """ self.env = env + + # NOTE: We got 15% speedup by compiling the model self.reward_classifier = torch.compile(reward_classifier) self.device = device @@ -305,9 +308,7 @@ class RewardWrapper(gym.Wrapper): reward = ( self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0 ) - # print(f"fps for reward classifier {1/(time.perf_counter() - start_time)}") - reward = reward.item() - # print(f"Reward from reward classifier {reward}") + info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time) return observation, reward, terminated, truncated, info def reset(self, seed=None, options=None): @@ -323,17 +324,23 @@ class TimeLimitWrapper(gym.Wrapper): self.last_timestamp = 0.0 self.episode_time_in_s = 0.0 + self.max_episode_steps = int(self.control_time_s * self.fps) + + self.current_step = 0 + def step(self, action): obs, reward, terminated, truncated, info = self.env.step(action) time_since_last_step = time.perf_counter() - self.last_timestamp + # logging.warning(f"Current timestep is lower than the expected fps {self.fps}") self.episode_time_in_s += time_since_last_step self.last_timestamp = time.perf_counter() - + self.current_step += 1 # check if last timestep took more time than the expected fps - if 1.0 / time_since_last_step < self.fps: - logging.warning(f"Current timestep is lower than the expected fps {self.fps}") + # if 1.0 / time_since_last_step < self.fps: + # logging.warning(f"Current timestep exceeded expected fps {self.fps}") if self.episode_time_in_s > self.control_time_s: + # if self.current_step >= self.max_episode_steps: # Terminated = True terminated = True return obs, reward, terminated, truncated, info @@ -341,11 +348,13 @@ class TimeLimitWrapper(gym.Wrapper): def reset(self, seed=None, options=None): self.episode_time_in_s = 0.0 self.last_timestamp = time.perf_counter() + self.current_step = 0 return self.env.reset(seed=seed, options=options) class ImageCropResizeWrapper(gym.Wrapper): def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None): + super().__init__(env) self.env = env self.crop_params_dict = crop_params_dict print(f"obs_keys , {self.env.observation_space}") @@ -372,10 +381,21 @@ class ImageCropResizeWrapper(gym.Wrapper): obs[k] = F.resize(obs[k], self.resize_size) obs[k] = obs[k].to(device) # print(f"observation with key {k} with size {obs[k].size()}") - cv2.imshow(k, cv2.cvtColor(obs[k].cpu().squeeze(0).permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR)) - cv2.waitKey(1) + # cv2.imshow(k, cv2.cvtColor(obs[k].cpu().squeeze(0).permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR)) + # cv2.waitKey(1) return obs, reward, terminated, truncated, info + def reset(self, seed=None, options=None): + obs, info = self.env.reset(seed=seed, options=options) + for k in self.crop_params_dict: + device = obs[k].device + if device == torch.device("mps:0"): + obs[k] = obs[k].cpu() + obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) + obs[k] = F.resize(obs[k], self.resize_size) + obs[k] = obs[k].to(device) + return obs, info + class ConvertToLeRobotObservation(gym.ObservationWrapper): def __init__(self, env, device): @@ -515,42 +535,64 @@ class ResetWrapper(gym.Wrapper): return super().reset(seed=seed, options=options) +class BatchCompitableWrapper(gym.ObservationWrapper): + def __init__(self, env): + super().__init__(env) + + def observation(self, observation: dict[str, 
torch.Tensor]) -> dict[str, torch.Tensor]: + for key in observation: + if "image" in key and observation[key].dim() == 3: + observation[key] = observation[key].unsqueeze(0) + if "state" in key and observation[key].dim() == 1: + observation[key] = observation[key].unsqueeze(0) + return observation + + def make_robot_env( robot, reward_classifier, - crop_params_dict=None, - fps=30, - control_time_s=20, - reset_follower_pos=True, - display_cameras=False, - device="cuda:0", - resize_size=None, - reset_time_s=10, - delta_action=0.1, - nb_repeats=1, - use_relative_joint_positions=False, -): + cfg, + n_envs: int = 1, +) -> gym.vector.VectorEnv: """ - Factory function to create the robot environment. + Factory function to create a vectorized robot environment. - Mimics gym.make() for consistent environment creation. + Args: + robot: Robot instance to control + reward_classifier: Classifier model for computing rewards + cfg: Configuration object containing environment parameters + n_envs: Number of environments to create in parallel. Defaults to 1. + + Returns: + A vectorized gym environment with all the necessary wrappers applied. """ + + # Create base environment env = HILSerlRobotEnv( - robot, - display_cameras=display_cameras, - delta=delta_action, - use_delta_action_space=use_relative_joint_positions, + robot=robot, + display_cameras=cfg.wrapper.display_cameras, + delta=cfg.wrapper.delta_action, + use_delta_action_space=cfg.wrapper.use_relative_joint_positions, ) - env = ConvertToLeRobotObservation(env, device) - if crop_params_dict is not None: - env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size) - env = RewardWrapper(env, reward_classifier, device=device) - env = TimeLimitWrapper(env, control_time_s, fps) - # env = ActionRepeatWrapper(env, nb_repeat=nb_repeats) - env = KeyboardInterfaceWrapper(env) - env = ResetWrapper(env, reset_fn=None, reset_time_s=reset_time_s) + + # Add observation and image processing + env = ConvertToLeRobotObservation(env=env, device=cfg.device) + if cfg.wrapper.crop_params_dict is not None: + env = ImageCropResizeWrapper( + env=env, crop_params_dict=cfg.wrapper.crop_params_dict, resize_size=cfg.wrapper.resize_size + ) + + # Add reward computation and control wrappers + env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) + env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) + env = KeyboardInterfaceWrapper(env=env) + env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.wrapper.reset_time_s) + env = BatchCompitableWrapper(env=env) + return env + # batched version of the env that returns an observation of shape (b, c) + def get_classifier(pretrained_path, config_path, device="mps"): if pretrained_path is None or config_path is None: @@ -616,6 +658,8 @@ if __name__ == "__main__": default=None, help="Path to a yaml config file that is necessary to build the reward classifier model.", ) + parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file") + parser.add_argument("--env-overrides", type=str, default=None, help="Overrides for the env yaml file") parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds") parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes") args = parser.parse_args() @@ -626,72 +670,38 @@ if __name__ == "__main__": reward_classifier = get_classifier( args.reward_classifier_pretrained_path, args.reward_classifier_config_file 
) - - crop_parameters = { - "observation.images.laptop": (58, 89, 357, 455), - "observation.images.phone": (3, 4, 471, 633), - } - user_relative_joint_positions = True + cfg = init_hydra_config(args.env_path, args.env_overrides) env = make_robot_env( robot, reward_classifier, - crop_parameters, - args.fps, - args.control_time_s, - args.reset_follower_pos, - args.display_cameras, - device="mps", - resize_size=None, - reset_time_s=10, - delta_action=0.1, - nb_repeats=1, - use_relative_joint_positions=user_relative_joint_positions, + cfg.wrapper, ) env.reset() - init_pos = env.unwrapped.initial_follower_position - right_goal = init_pos.copy() - right_goal[0] += 50 + # Retrieve the robot's action space for joint commands. + action_space_robot = env.action_space.spaces[0] - left_goal = init_pos.copy() - left_goal[0] -= 50 + # Initialize the smoothed action as a random sample. + smoothed_action = action_space_robot.sample() - pitch_angle = np.linspace(left_goal[0], right_goal[0], 1000) - - delta_angle = np.concatenate((-np.ones(50), np.ones(50))) * 100 + # Smoothing coefficient (alpha) defines how much of the new random sample to mix in. + # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth. + alpha = 0.4 while True: - action = np.zeros(len(init_pos)) - for i in range(len(delta_angle)): - start_loop_s = time.perf_counter() - action[0] = delta_angle[i] - obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False)) - if terminated or truncated: - env.reset() + start_loop_s = time.perf_counter() + # Sample a new random action from the robot's action space. + new_random_action = action_space_robot.sample() + # Update the smoothed action using an exponential moving average. + smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action - dt_s = time.perf_counter() - start_loop_s - busy_wait(1 / args.fps - dt_s) - # action = np.zeros(len(init_pos)) if user_relative_joint_positions else init_pos - # for i in range(len(pitch_angle)): - # if user_relative_joint_positions: - # action[0] = delta_angle[i] - # else: - # action[0] = pitch_angle[i] - # obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False)) - # if terminated or truncated: - # logging.info("Max control time reached, reset environment.") - # env.reset() + # Execute the step: wrap the NumPy action in a torch tensor. 
+ obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False)) + if terminated or truncated: + env.reset() - # for i in reversed(range(len(pitch_angle))): - # if user_relative_joint_positions: - # action[0] = delta_angle[i] - # else: - # action[0] = pitch_angle[i] - # obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False)) - - # if terminated or truncated: - # logging.info("Max control time reached, reset environment.") - # env.reset() + dt_s = time.perf_counter() - start_loop_s + busy_wait(1 / args.fps - dt_s) diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 5766c69c..bbd70598 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -36,6 +36,8 @@ from termcolor import colored from torch import nn from torch.optim.optimizer import Optimizer +from lerobot.common.datasets.factory import make_dataset + # TODO: Remove the import of maniskill from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.logger import Logger, log_output_dir @@ -52,6 +54,7 @@ from lerobot.common.utils.utils import ( ) from lerobot.scripts.server.buffer import ( ReplayBuffer, + concatenate_batch_transitions, move_state_dict_to_device, move_transition_to_device, ) @@ -259,8 +262,15 @@ def learner_push_parameters( while True: with policy_lock: params_dict = policy.actor.state_dict() - if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder: - params_dict = {k: v for k, v in params_dict if not k.startswith("encoder.")} + if policy.config.vision_encoder_name is not None: + if policy.config.freeze_vision_encoder: + params_dict: dict[str, torch.Tensor] = { + k: v for k, v in params_dict.items() if not k.startswith("encoder.") + } + else: + raise NotImplementedError( + "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model." + ) params_dict = move_state_dict_to_device(params_dict, device="cpu") # Serialize @@ -322,6 +332,7 @@ def add_actor_information_and_train( # in the future. The reason why we did that is the GIL in Python. It's super slow the performance # are divided by 200. So we need to have a single thread that does all the work. 
time.time() + logging.info("Starting learner thread") interaction_message, transition = None, None optimization_step = resume_optimization_step if resume_optimization_step is not None else 0 interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0 @@ -340,16 +351,21 @@ def add_actor_information_and_train( # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging interaction_message["Interaction step"] += interaction_step_shift logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step") + logging.info(f"Interaction message: {interaction_message}") if len(replay_buffer) < cfg.training.online_step_before_learning: continue + + # logging.info(f"Size of replay buffer: {len(replay_buffer)}") + # logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}") + time_for_one_optimization_step = time.time() for _ in range(cfg.policy.utd_ratio - 1): batch = replay_buffer.sample(batch_size) - # if cfg.offline_dataset_repo_id is not None: - # batch_offline = offline_replay_buffer.sample(batch_size) - # batch = concatenate_batch_transitions(batch, batch_offline) + if cfg.dataset_repo_id is not None: + batch_offline = offline_replay_buffer.sample(batch_size) + batch = concatenate_batch_transitions(batch, batch_offline) actions = batch["action"] rewards = batch["reward"] @@ -371,11 +387,11 @@ def add_actor_information_and_train( batch = replay_buffer.sample(batch_size) - # if cfg.offline_dataset_repo_id is not None: - # batch_offline = offline_replay_buffer.sample(batch_size) - # batch = concatenate_batch_transitions( - # left_batch_transitions=batch, right_batch_transition=batch_offline - # ) + if cfg.dataset_repo_id is not None: + batch_offline = offline_replay_buffer.sample(batch_size) + batch = concatenate_batch_transitions( + left_batch_transitions=batch, right_batch_transition=batch_offline + ) actions = batch["action"] rewards = batch["reward"] @@ -423,7 +439,7 @@ def add_actor_information_and_train( time_for_one_optimization_step = time.time() - time_for_one_optimization_step frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9) - logging.debug(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}") + logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}") logger.log_dict( {"Optimization frequency loop [Hz]": frequency_for_one_optimization_step}, @@ -560,14 +576,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No batch_size = cfg.training.batch_size offline_replay_buffer = None - # if cfg.dataset_repo_id is not None: - # logging.info("make_dataset offline buffer") - # offline_dataset = make_dataset(cfg) - # logging.info("Convertion to a offline replay buffer") - # offline_replay_buffer = ReplayBuffer.from_lerobot_dataset( - # offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys() - # ) - # batch_size: int = batch_size // 2 # We will sample from both replay buffer + if cfg.dataset_repo_id is not None: + logging.info("make_dataset offline buffer") + offline_dataset = make_dataset(cfg) + logging.info("Convertion to a offline replay buffer") + offline_replay_buffer = ReplayBuffer.from_lerobot_dataset( + offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys() + ) + batch_size: int = batch_size // 2 # We will sample from both replay buffer start_learner_threads( cfg, diff --git 
a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index 0ca8eae4..15cf3d0b 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -279,8 +279,10 @@ def train(cfg: DictConfig) -> None: logging.info(f"Dataset size: {len(dataset)}") train_size = int(cfg.train_split_proportion * len(dataset)) - val_size = len(dataset) - train_size - train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) + # val_size = len(dataset) - train_size + # train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) + train_dataset = dataset[:train_size] + val_dataset = dataset[train_size:] sampler = create_balanced_sampler(train_dataset, cfg) train_loader = DataLoader(
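
Note on the deterministic split above: if LeRobotDataset does not support Python slice indexing, an equivalent non-shuffled train/val split can be written with torch.utils.data.Subset; a sketch under that assumption, not part of the patch:

from torch.utils.data import Subset

train_dataset = Subset(dataset, range(train_size))              # first train_size samples
val_dataset = Subset(dataset, range(train_size, len(dataset)))  # remaining samples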