[Port HIL-SERL] Adjust Actor-Learner architecture & clean up dependency management for HIL-SERL (#722)

This commit is contained in:
Eugene Mironov 2025-02-21 16:29:00 +07:00 committed by Adil Zouitine
parent b8e9ee440b
commit 304d7136df
17 changed files with 1949 additions and 475 deletions

View File

@@ -46,7 +46,7 @@ repos:
    rev: v3.19.1
    hooks:
      - id: pyupgrade
+        exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.9.10
    hooks:

View File

@ -0,0 +1,11 @@
FROM huggingface/lerobot-gpu:latest
RUN apt-get update && apt-get install -y --no-install-recommends \
libvulkan1 vulkan-tools \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
RUN pip install --upgrade --no-cache-dir pip
RUN pip install --no-cache-dir ".[mani-skill]"
# Set EGL as the rendering backend for MuJoCo
ENV MUJOCO_GL="egl"

View File

@@ -81,3 +81,14 @@ You can also log sample predictions during evaluation. Each logged sample will include:
- The **classifier's "confidence" (logits/probability)**.

These logs can be useful for diagnosing and debugging performance issues.
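For illustration, such a sample table could be assembled roughly as follows (a minimal sketch assuming Weights & Biases is the experiment logger; keys, column names, and values are placeholders rather than the exact ones used by the evaluation script):

```python
import numpy as np
import wandb

# Stand-ins for real evaluation outputs; shapes, keys, and values are illustrative.
samples = [
    {"image": np.zeros((128, 128, 3), dtype=np.uint8), "true": 1, "pred": 1, "confidence": 0.93},
    {"image": np.zeros((128, 128, 3), dtype=np.uint8), "true": 0, "pred": 1, "confidence": 0.61},
]

wandb.init(project="reward-classifier-eval", mode="offline")  # offline keeps the sketch self-contained
table = wandb.Table(columns=["image", "true_label", "predicted_label", "confidence"])
for s in samples:
    table.add_data(wandb.Image(s["image"]), s["true"], s["pred"], s["confidence"])
wandb.log({"eval/sample_predictions": table})
wandb.finish()
```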
#### Generate protobuf files
```bash
python -m grpc_tools.protoc \
-I lerobot/scripts/server \
--python_out=lerobot/scripts/server \
--grpc_python_out=lerobot/scripts/server \
lerobot/scripts/server/hilserl.proto
```
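
Once generated, the modules are consumed by the actor roughly as follows (a minimal sketch mirroring the client code added in `actor_server.py` below; the address is a placeholder and a learner must already be listening there):

```python
import grpc

from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc

# Placeholder address; in practice it comes from policy.actor_learner_config.
channel = grpc.insecure_channel("127.0.0.1:50051")
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)

# The learner streams parameter updates; each message carries a transfer_state
# marker and a chunk of serialized weights in parameter_bytes.
for update in stub.StreamParameters(hilserl_pb2.Empty()):
    print(update.transfer_state, len(update.parameter_bytes))
```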

View File

@@ -41,11 +41,16 @@ class SACConfig:
    )
    input_normalization_params: dict[str, dict[str, list[float]]] = field(
        default_factory=lambda: {
-            "observation.image": {"mean": [[0.485, 0.456, 0.406]], "std": [[0.229, 0.224, 0.225]]},
+            "observation.image": {
+                "mean": [[0.485, 0.456, 0.406]],
+                "std": [[0.229, 0.224, 0.225]],
+            },
            "observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
        }
    )
-    output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
+    output_normalization_modes: dict[str, str] = field(
+        default_factory=lambda: {"action": "min_max"}
+    )
    output_normalization_params: dict[str, dict[str, list[float]]] = field(
        default_factory=lambda: {
            "action": {"min": [-1, -1], "max": [1, 1]},
@@ -54,9 +59,8 @@ class SACConfig:
    # TODO: Move it outside of the config
    actor_learner_config: dict[str, str | int] = field(
        default_factory=lambda: {
-            "actor_ip": "127.0.0.1",
-            "port": 50051,
-            "learner_ip": "127.0.0.1",
+            "learner_host": "127.0.0.1",
+            "learner_port": 50051,
        }
    )
    camera_number: int = 1
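
The renaming reflects the new topology: the actor no longer serves its own gRPC endpoint but dials the learner at `learner_host:learner_port`. A minimal sketch of how the renamed keys are read (mirroring the client setup added in `actor_server.py` below; the OmegaConf wrapper is only for illustration):

```python
from omegaconf import OmegaConf

# Illustrative config fragment with the renamed keys
# (previously: "actor_ip", "learner_ip", and "port").
cfg = OmegaConf.create(
    {
        "actor_learner_config": {
            "learner_host": "127.0.0.1",
            "learner_port": 50051,
            "policy_parameters_push_frequency": 15,
        }
    }
)

# The actor is now the gRPC client: it connects to the learner at this address
# instead of binding a server socket of its own.
address = f"{cfg.actor_learner_config.learner_host}:{cfg.actor_learner_config.learner_port}"
print(f"actor connects to learner at {address}")
```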

View File

@@ -108,5 +108,6 @@ policy:
  utd_ratio: 2 # 10

  actor_learner_config:
-    actor_ip: "127.0.0.1"
-    port: 50051
+    learner_host: "127.0.0.1"
+    learner_port: 50051
+    policy_parameters_push_frequency: 15

View File

@@ -65,7 +65,7 @@ policy:
  action: [4] # ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.images.front: mean_std
    observation.images.side: mean_std
    observation.state: min_max
@@ -80,7 +80,7 @@ policy:
      min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
      max: [ 7.215820e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
      # min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274]
      # max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
      # min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
      # max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
@@ -112,8 +112,9 @@ policy:
  utd_ratio: 2 # 10

  actor_learner_config:
-    actor_ip: "127.0.0.1"
-    port: 50051
+    learner_host: "127.0.0.1"
+    learner_port: 50051
+    policy_parameters_push_frequency: 15

  # # Loss coefficients.
  # reward_coeff: 0.5

View File

@@ -17,9 +17,9 @@ import io
import logging
import pickle
import queue
-import time
-from concurrent import futures
from statistics import mean, quantiles
+import signal
+from functools import lru_cache

# from lerobot.scripts.eval import eval_policy
from threading import Thread
@@ -35,7 +35,6 @@ from torch import nn
# from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
-from lerobot.common.robot_devices.control_utils import busy_wait
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import (
@@ -44,14 +43,24 @@ from lerobot.common.utils.utils import (
    set_global_seed,
)
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
-from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device
+from lerobot.scripts.server.buffer import (
+    Transition,
+    move_state_dict_to_device,
+    move_transition_to_device,
+    bytes_buffer_size,
+)
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
+from lerobot.scripts.server import learner_service
+from threading import Event

logging.basicConfig(level=logging.INFO)

parameters_queue = queue.Queue(maxsize=1)
message_queue = queue.Queue(maxsize=1_000_000)

+ACTOR_SHUTDOWN_TIMEOUT = 30
class ActorInformation: class ActorInformation:
""" """
@ -70,95 +79,171 @@ class ActorInformation:
self.interaction_message = interaction_message self.interaction_message = interaction_message
class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer): def receive_policy(
""" learner_client: hilserl_pb2_grpc.LearnerServiceStub,
gRPC service for actor-learner communication in reinforcement learning. shutdown_event: Event,
parameters_queue: queue.Queue,
):
logging.info("[ACTOR] Start receiving parameters from the Learner")
bytes_buffer = io.BytesIO()
step = 0
try:
for model_update in learner_client.StreamParameters(hilserl_pb2.Empty()):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down policy streaming receiver")
return hilserl_pb2.Empty()
This service is responsible for: if model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
1. Streaming batches of transition data and statistical metrics from the actor to the learner. bytes_buffer.seek(0)
2. Receiving updated network parameters from the learner. bytes_buffer.truncate(0)
""" bytes_buffer.write(model_update.parameter_bytes)
logging.info("Received model update at step 0")
def StreamTransition(self, request, context): # noqa: N802 step = 0
""" continue
Streams data from the actor to the learner. elif (
model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE
This function continuously retrieves messages from the queue and processes them based on their type: ):
bytes_buffer.write(model_update.parameter_bytes)
- **Transition Data:** step += 1
- A batch of transitions (observation, action, reward, next observation) is collected. logging.info(f"Received model update at step {step}")
- Transitions are moved to the CPU and serialized using PyTorch. elif model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner. bytes_buffer.write(model_update.parameter_bytes)
logging.info(
- **Interaction Messages:** f"Received model update at step end size {bytes_buffer_size(bytes_buffer)}"
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
Yields:
hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
"""
while True:
message = message_queue.get(block=True)
if message.transition is not None:
transition_to_send_to_learner: list[Transition] = [
move_transition_to_device(transition=T, device="cpu") for T in message.transition
]
# Check for NaNs in transitions before sending to learner
for transition in transition_to_send_to_learner:
for key, value in transition["state"].items():
if torch.isnan(value).any():
logging.warning(f"Found NaN values in transition {key}")
buf = io.BytesIO()
torch.save(transition_to_send_to_learner, buf)
transition_bytes = buf.getvalue()
transition_message = hilserl_pb2.Transition(transition_bytes=transition_bytes)
response = hilserl_pb2.ActorInformation(transition=transition_message)
elif message.interaction_message is not None:
content = hilserl_pb2.InteractionMessage(
interaction_message_bytes=pickle.dumps(message.interaction_message)
) )
response = hilserl_pb2.ActorInformation(interaction_message=content)
yield response state_dict = torch.load(bytes_buffer)
def SendParameters(self, request, context): # noqa: N802 bytes_buffer.seek(0)
""" bytes_buffer.truncate(0)
Receives updated parameters from the learner and updates the actor. step = 0
The learner calls this method to send new model parameters. The received parameters are deserialized logging.info("Model updated")
and placed in a queue to be consumed by the actor.
Args: parameters_queue.put(state_dict)
request (hilserl_pb2.ParameterUpdate): The request containing serialized network parameters.
context (grpc.ServicerContext): The gRPC context.
Returns: except grpc.RpcError as e:
hilserl_pb2.Empty: An empty response to acknowledge receipt. logging.error(f"[ACTOR] gRPC error: {e}")
"""
buffer = io.BytesIO(request.parameter_bytes) return hilserl_pb2.Empty()
params = torch.load(buffer)
parameters_queue.put(params)
return hilserl_pb2.Empty()
def serve_actor_service(port=50052): def transitions_stream(shutdown_event: Event, message_queue: queue.Queue):
while not shutdown_event.is_set():
try:
message = message_queue.get(block=True, timeout=5)
except queue.Empty:
logging.debug("[ACTOR] Transition queue is empty")
continue
if message.transition is not None:
transition_to_send_to_learner: list[Transition] = [
move_transition_to_device(transition=T, device="cpu")
for T in message.transition
]
# Check for NaNs in transitions before sending to learner
for transition in transition_to_send_to_learner:
for key, value in transition["state"].items():
if torch.isnan(value).any():
logging.warning(f"Found NaN values in transition {key}")
buf = io.BytesIO()
torch.save(transition_to_send_to_learner, buf)
transition_bytes = buf.getvalue()
transition_message = hilserl_pb2.Transition(
transition_bytes=transition_bytes
)
response = hilserl_pb2.ActorInformation(transition=transition_message)
elif message.interaction_message is not None:
content = hilserl_pb2.InteractionMessage(
interaction_message_bytes=pickle.dumps(message.interaction_message)
)
response = hilserl_pb2.ActorInformation(interaction_message=content)
yield response
return hilserl_pb2.Empty()
def send_transitions(
learner_client: hilserl_pb2_grpc.LearnerServiceStub,
shutdown_event: Event,
message_queue: queue.Queue,
):
""" """
Runs a gRPC server to start streaming the data from the actor to the learner. Streams data from the actor to the learner.
Throught this server the learner can push parameters to the Actor as well.
This function continuously retrieves messages from the queue and processes them based on their type:
- **Transition Data:**
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
- **Interaction Messages:**
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
Yields:
hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
""" """
server = grpc.server( try:
futures.ThreadPoolExecutor(max_workers=20), learner_client.ReceiveTransitions(
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)], transitions_stream(shutdown_event, message_queue)
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
logging.info("[ACTOR] Finished streaming transitions")
@lru_cache(maxsize=1)
def learner_service_client(
host="127.0.0.1", port=50051
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
import json
"""
Returns a client for the learner service.
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
So we need to create only one client and reuse it.
"""
service_config = {
"methodConfig": [
{
"name": [{}], # Applies to ALL methods in ALL services
"retryPolicy": {
"maxAttempts": 5, # Max retries (total attempts = 5)
"initialBackoff": "0.1s", # First retry after 0.1s
"maxBackoff": "2s", # Max wait time between retries
"backoffMultiplier": 2, # Exponential backoff factor
"retryableStatusCodes": [
"UNAVAILABLE",
"DEADLINE_EXCEEDED",
], # Retries on network failures
},
}
]
}
service_config_json = json.dumps(service_config)
channel = grpc.insecure_channel(
f"{host}:{port}",
options=[
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.enable_retries", 1),
("grpc.service_config", service_config_json),
],
) )
hilserl_pb2_grpc.add_ActorServiceServicer_to_server(ActorServiceServicer(), server) stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
server.add_insecure_port(f"[::]:{port}") logging.info("[LEARNER] Learner service client created")
server.start() return stub, channel
logging.info(f"[ACTOR] gRPC server listening on port {port}")
server.wait_for_termination()
def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device): def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
@ -169,7 +254,9 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, d
policy.load_state_dict(state_dict) policy.load_state_dict(state_dict)
def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module): def act_with_policy(
cfg: DictConfig, robot: Robot, reward_classifier: nn.Module, shutdown_event: Event
):
""" """
Executes policy interaction within the environment. Executes policy interaction within the environment.
@ -182,7 +269,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
logging.info("make_env online") logging.info("make_env online")
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg) online_env = make_robot_env(
robot=robot, reward_classifier=reward_classifier, cfg=cfg
)
set_global_seed(cfg.seed) set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True) device = get_safe_torch_device(cfg.device, log=True)
@ -227,17 +316,27 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
episode_intervention = False episode_intervention = False
for interaction_step in range(cfg.training.online_steps): for interaction_step in range(cfg.training.online_steps):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutdown signal received. Exiting...")
return
if interaction_step >= cfg.training.online_step_before_learning: if interaction_step >= cfg.training.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement # Time policy inference and check if it meets FPS requirement
with TimerManager( with TimerManager(
elapsed_time_list=list_policy_time, label="Policy inference time", log=False elapsed_time_list=list_policy_time,
label="Policy inference time",
log=False,
) as timer: # noqa: F841 ) as timer: # noqa: F841
action = policy.select_action(batch=obs) action = policy.select_action(batch=obs)
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9) policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) log_policy_frequency_issue(
policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step
)
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy()) next_obs, reward, done, truncated, info = online_env.step(
action.squeeze(dim=0).cpu().numpy()
)
else: else:
# TODO (azouitine): Make a custom space for torch tensor # TODO (azouitine): Make a custom space for torch tensor
action = online_env.action_space.sample() action = online_env.action_space.sample()
@ -245,7 +344,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box # HACK: We have only one env but we want to batch it, it will be resolved with the torch box
action = ( action = (
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0) torch.from_numpy(action[0])
.to(device, non_blocking=device.type == "cuda")
.unsqueeze(dim=0)
) )
sum_reward_episode += float(reward) sum_reward_episode += float(reward)
@ -261,7 +362,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
# Check for NaN values in observations # Check for NaN values in observations
for key, tensor in obs.items(): for key, tensor in obs.items():
if torch.isnan(tensor).any(): if torch.isnan(tensor).any():
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}") logging.error(
f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}"
)
list_transition_to_send_to_learner.append( list_transition_to_send_to_learner.append(
Transition( Transition(
@ -281,13 +384,19 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
# Because we are using a single environment we can index at zero # Because we are using a single environment we can index at zero
if done or truncated: if done or truncated:
# TODO: Handle logging for episode information # TODO: Handle logging for episode information
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") logging.info(
f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}"
)
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device) update_policy_parameters(
policy=policy.actor, parameters_queue=parameters_queue, device=device
)
if len(list_transition_to_send_to_learner) > 0: if len(list_transition_to_send_to_learner) > 0:
send_transitions_in_chunks( send_transitions_in_chunks(
transitions=list_transition_to_send_to_learner, message_queue=message_queue, chunk_size=4 transitions=list_transition_to_send_to_learner,
message_queue=message_queue,
chunk_size=4,
) )
list_transition_to_send_to_learner = [] list_transition_to_send_to_learner = []
@ -332,11 +441,16 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
quantiles_90 = quantiles(list_policy_fps, n=10)[-1] quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}") logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}") logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90} stats = {
"Policy frequency [Hz]": policy_fps,
"Policy frequency 90th-p [Hz]": quantiles_90,
}
return stats return stats
def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int): def log_policy_frequency_issue(
policy_fps: float, cfg: DictConfig, interaction_step: int
):
if policy_fps < cfg.fps: if policy_fps < cfg.fps:
logging.warning( logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}" f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
@@ -347,7 +461,34 @@ def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
def actor_cli(cfg: dict):
    robot = make_robot(cfg=cfg.robot)

-    server_thread = Thread(target=serve_actor_service, args=(cfg.actor_learner_config.port,), daemon=True)
+    shutdown_event = Event()
+
+    # Define signal handler
+    def signal_handler(signum, frame):
+        logging.info("Shutdown signal received. Cleaning up...")
+        shutdown_event.set()
+
+    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
+    signal.signal(signal.SIGTERM, signal_handler)  # Termination request (kill)
+    signal.signal(signal.SIGHUP, signal_handler)  # Terminal closed/Hangup
+    signal.signal(signal.SIGQUIT, signal_handler)  # Ctrl+\
+
+    learner_client, grpc_channel = learner_service_client(
+        host=cfg.actor_learner_config.learner_host,
+        port=cfg.actor_learner_config.learner_port,
+    )
+
+    receive_policy_thread = Thread(
+        target=receive_policy,
+        args=(learner_client, shutdown_event, parameters_queue),
+        daemon=True,
+    )
+
+    transitions_thread = Thread(
+        target=send_transitions,
+        args=(learner_client, shutdown_event, message_queue),
+        daemon=True,
+    )

    # HACK: FOR MANISKILL we do not have a reward classifier
    # TODO: Remove this once we merge into main
@@ -360,15 +501,27 @@ def actor_cli(cfg: dict):
        pretrained_path=cfg.env.reward_classifier.pretrained_path,
        config_path=cfg.env.reward_classifier.config_path,
    )

    policy_thread = Thread(
        target=act_with_policy,
        daemon=True,
-        args=(cfg, robot, reward_classifier),
+        args=(cfg, robot, reward_classifier, shutdown_event),
    )

-    server_thread.start()
+    transitions_thread.start()
    policy_thread.start()
+    receive_policy_thread.start()
+
+    shutdown_event.wait()
+    logging.info("[ACTOR] Shutdown event received")
+    grpc_channel.close()
+
    policy_thread.join()
-    server_thread.join()
+    logging.info("[ACTOR] Policy thread joined")
+
+    transitions_thread.join()
+    logging.info("[ACTOR] Transitions thread joined")
+
+    receive_policy_thread.join()
+    logging.info("[ACTOR] Receive policy thread joined")


if __name__ == "__main__":

View File

@@ -17,6 +17,7 @@ import functools
import random
from typing import Any, Callable, Optional, Sequence, TypedDict
+import io

import torch
import torch.nn.functional as F  # noqa: N812
from tqdm import tqdm
@ -41,24 +42,33 @@ class BatchTransition(TypedDict):
done: torch.Tensor done: torch.Tensor
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition: def move_transition_to_device(
transition: Transition, device: str = "cpu"
) -> Transition:
# Move state tensors to CPU # Move state tensors to CPU
device = torch.device(device) device = torch.device(device)
transition["state"] = { transition["state"] = {
key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items() key: val.to(device, non_blocking=device.type == "cuda")
for key, val in transition["state"].items()
} }
# Move action to CPU # Move action to CPU
transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda") transition["action"] = transition["action"].to(
device, non_blocking=device.type == "cuda"
)
# No need to move reward or done, as they are float and bool # No need to move reward or done, as they are float and bool
# No need to move reward or done, as they are float and bool # No need to move reward or done, as they are float and bool
if isinstance(transition["reward"], torch.Tensor): if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda") transition["reward"] = transition["reward"].to(
device=device, non_blocking=device.type == "cuda"
)
if isinstance(transition["done"], torch.Tensor): if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda") transition["done"] = transition["done"].to(
device, non_blocking=device.type == "cuda"
)
# Move next_state tensors to CPU # Move next_state tensors to CPU
transition["next_state"] = { transition["next_state"] = {
@ -82,7 +92,10 @@ def move_state_dict_to_device(state_dict, device):
if isinstance(state_dict, torch.Tensor): if isinstance(state_dict, torch.Tensor):
return state_dict.to(device) return state_dict.to(device)
elif isinstance(state_dict, dict): elif isinstance(state_dict, dict):
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()} return {
k: move_state_dict_to_device(v, device=device)
for k, v in state_dict.items()
}
elif isinstance(state_dict, list): elif isinstance(state_dict, list):
return [move_state_dict_to_device(v, device=device) for v in state_dict] return [move_state_dict_to_device(v, device=device) for v in state_dict]
elif isinstance(state_dict, tuple): elif isinstance(state_dict, tuple):
@ -91,6 +104,22 @@ def move_state_dict_to_device(state_dict, device):
return state_dict return state_dict
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> io.BytesIO:
"""Convert model state dict to flat array for transmission"""
buffer = io.BytesIO()
torch.save(state_dict, buffer)
return buffer
def bytes_buffer_size(buffer: io.BytesIO) -> int:
buffer.seek(0, io.SEEK_END)
result = buffer.tell()
buffer.seek(0)
return result
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor: def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
""" """
Perform a per-image random crop over a batch of images in a vectorized way. Perform a per-image random crop over a batch of images in a vectorized way.
@ -116,7 +145,9 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C) images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels # Gather pixels
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :] cropped_hwcn = images_hwcn[
torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
]
# cropped_hwcn => (B, crop_h, crop_w, C) # cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w) cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
@ -179,7 +210,9 @@ class ReplayBuffer:
"""Saves a transition, ensuring tensors are stored on the designated storage device.""" """Saves a transition, ensuring tensors are stored on the designated storage device."""
# Move tensors to the storage device # Move tensors to the storage device
state = {key: tensor.to(self.storage_device) for key, tensor in state.items()} state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()} next_state = {
key: tensor.to(self.storage_device) for key, tensor in next_state.items()
}
action = action.to(self.storage_device) action = action.to(self.storage_device)
# if complementary_info is not None: # if complementary_info is not None:
# complementary_info = { # complementary_info = {
@ -234,7 +267,9 @@ class ReplayBuffer:
) )
replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys) replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys)
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys) list_transition = cls._lerobotdataset_to_transitions(
dataset=lerobot_dataset, state_keys=state_keys
)
# Fill the replay buffer with the lerobot dataset transitions # Fill the replay buffer with the lerobot dataset transitions
for data in list_transition: for data in list_transition:
for k, v in data.items(): for k, v in data.items():
@ -295,7 +330,9 @@ class ReplayBuffer:
# If not provided, you can either raise an error or define a default: # If not provided, you can either raise an error or define a default:
if state_keys is None: if state_keys is None:
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.") raise ValueError(
"You must provide a list of keys in `state_keys` that define your 'state'."
)
transitions: list[Transition] = [] transitions: list[Transition] = []
num_frames = len(dataset) num_frames = len(dataset)
@ -350,33 +387,37 @@ class ReplayBuffer:
# -- Build batched states -- # -- Build batched states --
batch_state = {} batch_state = {}
for key in self.state_keys: for key in self.state_keys:
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to( batch_state[key] = torch.cat(
self.device [t["state"][key] for t in list_of_transitions], dim=0
) ).to(self.device)
if key.startswith("observation.image") and self.use_drq: if key.startswith("observation.image") and self.use_drq:
batch_state[key] = self.image_augmentation_function(batch_state[key]) batch_state[key] = self.image_augmentation_function(batch_state[key])
# -- Build batched actions -- # -- Build batched actions --
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device) batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
# -- Build batched rewards --
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
self.device self.device
) )
# -- Build batched rewards --
batch_rewards = torch.tensor(
[t["reward"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# -- Build batched next states -- # -- Build batched next states --
batch_next_state = {} batch_next_state = {}
for key in self.state_keys: for key in self.state_keys:
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to( batch_next_state[key] = torch.cat(
self.device [t["next_state"][key] for t in list_of_transitions], dim=0
) ).to(self.device)
if key.startswith("observation.image") and self.use_drq: if key.startswith("observation.image") and self.use_drq:
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key]) batch_next_state[key] = self.image_augmentation_function(
batch_next_state[key]
)
# -- Build batched dones -- # -- Build batched dones --
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to( batch_dones = torch.tensor(
self.device [t["done"] for t in list_of_transitions], dtype=torch.float32
) ).to(self.device)
# Return a BatchTransition typed dict # Return a BatchTransition typed dict
return BatchTransition( return BatchTransition(
@ -433,7 +474,9 @@ class ReplayBuffer:
# Add state keys # Add state keys
for key in self.state_keys: for key in self.state_keys:
sample_val = first_transition["state"][key].squeeze(dim=0) # Remove batch dimension sample_val = first_transition["state"][key].squeeze(
dim=0
) # Remove batch dimension
if not isinstance(sample_val, torch.Tensor): if not isinstance(sample_val, torch.Tensor):
raise ValueError( raise ValueError(
f"State key '{key}' is not a torch.Tensor. Please ensure your states are stored as torch.Tensors." f"State key '{key}' is not a torch.Tensor. Please ensure your states are stored as torch.Tensors."
@ -465,7 +508,9 @@ class ReplayBuffer:
# We detect episode boundaries by `done == True`. # We detect episode boundaries by `done == True`.
# -------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------
episode_index = 0 episode_index = 0
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index) lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
episode_index
)
frame_idx_in_episode = 0 frame_idx_in_episode = 0
for global_frame_idx, transition in enumerate(self.memory): for global_frame_idx, transition in enumerate(self.memory):
@ -476,16 +521,24 @@ class ReplayBuffer:
# Expand dimension to match what the dataset expects (the dataset wants the raw shape) # Expand dimension to match what the dataset expects (the dataset wants the raw shape)
# We assume your buffer has shape [C, H, W] (if image) or [D] if vector # We assume your buffer has shape [C, H, W] (if image) or [D] if vector
# This is typically already correct, but if needed you can reshape below. # This is typically already correct, but if needed you can reshape below.
frame_dict[key] = transition["state"][key].cpu().squeeze(dim=0) # Remove batch dimension frame_dict[key] = (
transition["state"][key].cpu().squeeze(dim=0)
) # Remove batch dimension
# Fill action, reward, done # Fill action, reward, done
# Make sure they are shape (X,) or (X,Y,...) as needed. # Make sure they are shape (X,) or (X,Y,...) as needed.
frame_dict["action"] = transition["action"].cpu().squeeze(dim=0) # Remove batch dimension frame_dict["action"] = (
transition["action"].cpu().squeeze(dim=0)
) # Remove batch dimension
frame_dict["next.reward"] = ( frame_dict["next.reward"] = (
torch.tensor([transition["reward"]], dtype=torch.float32).cpu().squeeze(dim=0) torch.tensor([transition["reward"]], dtype=torch.float32)
.cpu()
.squeeze(dim=0)
) )
frame_dict["next.done"] = ( frame_dict["next.done"] = (
torch.tensor([transition["done"]], dtype=torch.bool).cpu().squeeze(dim=0) torch.tensor([transition["done"]], dtype=torch.bool)
.cpu()
.squeeze(dim=0)
) )
# Add to the dataset's buffer # Add to the dataset's buffer
lerobot_dataset.add_frame(frame_dict) lerobot_dataset.add_frame(frame_dict)
@ -499,7 +552,9 @@ class ReplayBuffer:
episode_index += 1 episode_index += 1
frame_idx_in_episode = 0 frame_idx_in_episode = 0
# Start a new buffer for the next episode # Start a new buffer for the next episode
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index) lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
episode_index
)
# We are done adding frames # We are done adding frames
# If the last transition wasn't done=True, we still have an open buffer with frames. # If the last transition wasn't done=True, we still have an open buffer with frames.
@ -541,7 +596,13 @@ def concatenate_batch_transitions(
) -> BatchTransition: ) -> BatchTransition:
"""NOTE: Be careful it change the left_batch_transitions in place""" """NOTE: Be careful it change the left_batch_transitions in place"""
left_batch_transitions["state"] = { left_batch_transitions["state"] = {
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0) key: torch.cat(
[
left_batch_transitions["state"][key],
right_batch_transition["state"][key],
],
dim=0,
)
for key in left_batch_transitions["state"] for key in left_batch_transitions["state"]
} }
left_batch_transitions["action"] = torch.cat( left_batch_transitions["action"] = torch.cat(
@ -552,7 +613,11 @@ def concatenate_batch_transitions(
) )
left_batch_transitions["next_state"] = { left_batch_transitions["next_state"] = {
key: torch.cat( key: torch.cat(
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0 [
left_batch_transitions["next_state"][key],
right_batch_transition["next_state"][key],
],
dim=0,
) )
for key in left_batch_transitions["next_state"] for key in left_batch_transitions["next_state"]
} }

View File

@@ -10,11 +10,9 @@ import torch
import torchvision.transforms.functional as F  # noqa: N812

from lerobot.common.envs.utils import preprocess_observation
-from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
+from lerobot.common.robot_devices.control_utils import busy_wait, is_headless
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config, log_say
-from lerobot.scripts.server.maniskill_manipulator import make_maniskill

logging.basicConfig(level=logging.INFO)
@ -62,7 +60,9 @@ class HILSerlRobotEnv(gym.Env):
if not self.robot.is_connected: if not self.robot.is_connected:
self.robot.connect() self.robot.connect()
self.initial_follower_position = robot.follower_arms["main"].read("Present_Position") self.initial_follower_position = robot.follower_arms["main"].read(
"Present_Position"
)
# Episode tracking. # Episode tracking.
self.current_step = 0 self.current_step = 0
@ -70,7 +70,9 @@ class HILSerlRobotEnv(gym.Env):
self.delta = delta self.delta = delta
self.use_delta_action_space = use_delta_action_space self.use_delta_action_space = use_delta_action_space
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position") self.current_joint_positions = self.robot.follower_arms["main"].read(
"Present_Position"
)
# Retrieve the size of the joint position interval bound. # Retrieve the size of the joint position interval bound.
self.relative_bounds_size = ( self.relative_bounds_size = (
@ -105,12 +107,16 @@ class HILSerlRobotEnv(gym.Env):
image_keys = [key for key in example_obs if "image" in key] image_keys = [key for key in example_obs if "image" in key]
state_keys = [key for key in example_obs if "image" not in key] state_keys = [key for key in example_obs if "image" not in key]
observation_spaces = { observation_spaces = {
key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8) key: gym.spaces.Box(
low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8
)
for key in image_keys for key in image_keys
} }
observation_spaces["observation.state"] = gym.spaces.Dict( observation_spaces["observation.state"] = gym.spaces.Dict(
{ {
key: gym.spaces.Box(low=0, high=10, shape=example_obs[key].shape, dtype=np.float32) key: gym.spaces.Box(
low=0, high=10, shape=example_obs[key].shape, dtype=np.float32
)
for key in state_keys for key in state_keys
} }
) )
@ -128,8 +134,12 @@ class HILSerlRobotEnv(gym.Env):
) )
else: else:
action_space_robot = gym.spaces.Box( action_space_robot = gym.spaces.Box(
low=self.robot.config.joint_position_relative_bounds["min"].cpu().numpy(), low=self.robot.config.joint_position_relative_bounds["min"]
high=self.robot.config.joint_position_relative_bounds["max"].cpu().numpy(), .cpu()
.numpy(),
high=self.robot.config.joint_position_relative_bounds["max"]
.cpu()
.numpy(),
shape=(action_dim,), shape=(action_dim,),
dtype=np.float32, dtype=np.float32,
) )
@ -141,7 +151,9 @@ class HILSerlRobotEnv(gym.Env):
), ),
) )
def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: def reset(
self, seed=None, options=None
) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
""" """
Reset the environment to its initial state. Reset the environment to its initial state.
This method resets the step counter and clears any episodic data. This method resets the step counter and clears any episodic data.
@ -198,24 +210,34 @@ class HILSerlRobotEnv(gym.Env):
""" """
policy_action, intervention_bool = action policy_action, intervention_bool = action
teleop_action = None teleop_action = None
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position") self.current_joint_positions = self.robot.follower_arms["main"].read(
"Present_Position"
)
if isinstance(policy_action, torch.Tensor): if isinstance(policy_action, torch.Tensor):
policy_action = policy_action.cpu().numpy() policy_action = policy_action.cpu().numpy()
policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high) policy_action = np.clip(
policy_action, self.action_space[0].low, self.action_space[0].high
)
if not intervention_bool: if not intervention_bool:
if self.use_delta_action_space: if self.use_delta_action_space:
target_joint_positions = self.current_joint_positions + self.delta * policy_action target_joint_positions = (
self.current_joint_positions + self.delta * policy_action
)
else: else:
target_joint_positions = policy_action target_joint_positions = policy_action
self.robot.send_action(torch.from_numpy(target_joint_positions)) self.robot.send_action(torch.from_numpy(target_joint_positions))
observation = self.robot.capture_observation() observation = self.robot.capture_observation()
else: else:
observation, teleop_action = self.robot.teleop_step(record_data=True) observation, teleop_action = self.robot.teleop_step(record_data=True)
teleop_action = teleop_action["action"] # Convert tensor to appropriate format teleop_action = teleop_action[
"action"
] # Convert tensor to appropriate format
# When applying the delta action space, convert teleop absolute values to relative differences. # When applying the delta action space, convert teleop absolute values to relative differences.
if self.use_delta_action_space: if self.use_delta_action_space:
teleop_action = (teleop_action - self.current_joint_positions) / self.delta teleop_action = (
teleop_action - self.current_joint_positions
) / self.delta
if torch.any(teleop_action < -self.relative_bounds_size) and torch.any( if torch.any(teleop_action < -self.relative_bounds_size) and torch.any(
teleop_action > self.relative_bounds_size teleop_action > self.relative_bounds_size
): ):
@ -226,7 +248,9 @@ class HILSerlRobotEnv(gym.Env):
) )
teleop_action = torch.clamp( teleop_action = torch.clamp(
teleop_action, -self.relative_bounds_size, self.relative_bounds_size teleop_action,
-self.relative_bounds_size,
self.relative_bounds_size,
) )
# NOTE: To mimic the shape of a neural network output, we add a batch dimension to the teleop action. # NOTE: To mimic the shape of a neural network output, we add a batch dimension to the teleop action.
if teleop_action.dim() == 1: if teleop_action.dim() == 1:
@ -245,7 +269,10 @@ class HILSerlRobotEnv(gym.Env):
reward, reward,
terminated, terminated,
truncated, truncated,
{"action_intervention": teleop_action, "is_intervention": teleop_action is not None}, {
"action_intervention": teleop_action,
"is_intervention": teleop_action is not None,
},
) )
def render(self): def render(self):
@ -351,7 +378,9 @@ class JointMaskingActionSpace(gym.Wrapper):
raise ValueError("Mask length must match action space dimensions") raise ValueError("Mask length must match action space dimensions")
low = env.action_space.low[self.active_dims] low = env.action_space.low[self.active_dims]
high = env.action_space.high[self.active_dims] high = env.action_space.high[self.active_dims]
self.action_space = gym.spaces.Box(low=low, high=high, dtype=env.action_space.dtype) self.action_space = gym.spaces.Box(
low=low, high=high, dtype=env.action_space.dtype
)
if isinstance(env.action_space, gym.spaces.Tuple): if isinstance(env.action_space, gym.spaces.Tuple):
if len(mask) != env.action_space[0].shape[0]: if len(mask) != env.action_space[0].shape[0]:
@ -359,8 +388,12 @@ class JointMaskingActionSpace(gym.Wrapper):
low = env.action_space[0].low[self.active_dims] low = env.action_space[0].low[self.active_dims]
high = env.action_space[0].high[self.active_dims] high = env.action_space[0].high[self.active_dims]
action_space_masked = gym.spaces.Box(low=low, high=high, dtype=env.action_space[0].dtype) action_space_masked = gym.spaces.Box(
self.action_space = gym.spaces.Tuple((action_space_masked, env.action_space[1])) low=low, high=high, dtype=env.action_space[0].dtype
)
self.action_space = gym.spaces.Tuple(
(action_space_masked, env.action_space[1])
)
# Create new action space with masked dimensions # Create new action space with masked dimensions
def action(self, action): def action(self, action):
@ -379,14 +412,18 @@ class JointMaskingActionSpace(gym.Wrapper):
# Extract the masked component from the tuple. # Extract the masked component from the tuple.
masked_action = action[0] if isinstance(action, tuple) else action masked_action = action[0] if isinstance(action, tuple) else action
# Create a full action for the Box element. # Create a full action for the Box element.
full_box_action = np.zeros(self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype) full_box_action = np.zeros(
self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype
)
full_box_action[self.active_dims] = masked_action full_box_action[self.active_dims] = masked_action
# Return a tuple with the reconstructed Box action and the unchanged remainder. # Return a tuple with the reconstructed Box action and the unchanged remainder.
return (full_box_action, action[1]) return (full_box_action, action[1])
else: else:
# For Box action spaces. # For Box action spaces.
masked_action = action if not isinstance(action, tuple) else action[0] masked_action = action if not isinstance(action, tuple) else action[0]
full_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype) full_action = np.zeros(
self.env.action_space.shape, dtype=self.env.action_space.dtype
)
full_action[self.active_dims] = masked_action full_action[self.active_dims] = masked_action
return full_action return full_action
@ -395,9 +432,13 @@ class JointMaskingActionSpace(gym.Wrapper):
obs, reward, terminated, truncated, info = self.env.step(action) obs, reward, terminated, truncated, info = self.env.step(action)
if "action_intervention" in info and info["action_intervention"] is not None: if "action_intervention" in info and info["action_intervention"] is not None:
if info["action_intervention"].dim() == 1: if info["action_intervention"].dim() == 1:
info["action_intervention"] = info["action_intervention"][self.active_dims] info["action_intervention"] = info["action_intervention"][
self.active_dims
]
else: else:
info["action_intervention"] = info["action_intervention"][:, self.active_dims] info["action_intervention"] = info["action_intervention"][
:, self.active_dims
]
return obs, reward, terminated, truncated, info return obs, reward, terminated, truncated, info
@ -438,7 +479,12 @@ class TimeLimitWrapper(gym.Wrapper):
class ImageCropResizeWrapper(gym.Wrapper): class ImageCropResizeWrapper(gym.Wrapper):
def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None): def __init__(
self,
env,
crop_params_dict: Dict[str, Annotated[Tuple[int], 4]],
resize_size=None,
):
super().__init__(env) super().__init__(env)
self.env = env self.env = env
self.crop_params_dict = crop_params_dict self.crop_params_dict = crop_params_dict
@ -450,7 +496,9 @@ class ImageCropResizeWrapper(gym.Wrapper):
for key in crop_params_dict: for key in crop_params_dict:
top, left, height, width = crop_params_dict[key] top, left, height, width = crop_params_dict[key]
new_shape = (top + height, left + width) new_shape = (top + height, left + width)
self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape) self.observation_space[key] = gym.spaces.Box(
low=0, high=255, shape=new_shape
)
self.resize_size = resize_size self.resize_size = resize_size
if self.resize_size is None: if self.resize_size is None:
@ -463,7 +511,9 @@ class ImageCropResizeWrapper(gym.Wrapper):
# Check for NaNs before processing # Check for NaNs before processing
if torch.isnan(obs[k]).any(): if torch.isnan(obs[k]).any():
logging.error(f"NaN values detected in observation {k} before crop and resize") logging.error(
f"NaN values detected in observation {k} before crop and resize"
)
if device == torch.device("mps:0"): if device == torch.device("mps:0"):
obs[k] = obs[k].cpu() obs[k] = obs[k].cpu()
@ -473,7 +523,9 @@ class ImageCropResizeWrapper(gym.Wrapper):
# Check for NaNs after processing # Check for NaNs after processing
if torch.isnan(obs[k]).any(): if torch.isnan(obs[k]).any():
logging.error(f"NaN values detected in observation {k} after crop and resize") logging.error(
f"NaN values detected in observation {k} after crop and resize"
)
obs[k] = obs[k].to(device) obs[k] = obs[k].to(device)
@ -503,10 +555,14 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper):
observation = preprocess_observation(observation) observation = preprocess_observation(observation)
observation = { observation = {
key: observation[key].to(self.device, non_blocking=self.device.type == "cuda") key: observation[key].to(
self.device, non_blocking=self.device.type == "cuda"
)
for key in observation for key in observation
} }
observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()} observation = {
k: torch.tensor(v, device=self.device) for k, v in observation.items()
}
return observation return observation
@ -553,18 +609,31 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
"Place the leader in similar pose to the follower and press space again." "Place the leader in similar pose to the follower and press space again."
) )
self.events["pause_policy"] = True self.events["pause_policy"] = True
log_say("Human intervention stage. Get ready to take over.", play_sounds=True) log_say(
"Human intervention stage. Get ready to take over.",
play_sounds=True,
)
return return
if self.events["pause_policy"] and not self.events["human_intervention_step"]: if (
self.events["pause_policy"]
and not self.events["human_intervention_step"]
):
self.events["human_intervention_step"] = True self.events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.") print("Space key pressed. Human intervention starting.")
log_say("Starting human intervention.", play_sounds=True) log_say(
"Starting human intervention.", play_sounds=True
)
return return
if self.events["pause_policy"] and self.events["human_intervention_step"]: if (
self.events["pause_policy"]
and self.events["human_intervention_step"]
):
self.events["pause_policy"] = False self.events["pause_policy"] = False
self.events["human_intervention_step"] = False self.events["human_intervention_step"] = False
print("Space key pressed for a third time.") print("Space key pressed for a third time.")
log_say("Continuing with policy actions.", play_sounds=True) log_say(
"Continuing with policy actions.", play_sounds=True
)
return return
except Exception as e: except Exception as e:
print(f"Error handling key press: {e}") print(f"Error handling key press: {e}")
@ -572,7 +641,9 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
self.listener = keyboard.Listener(on_press=on_press) self.listener = keyboard.Listener(on_press=on_press)
self.listener.start() self.listener.start()
except ImportError: except ImportError:
logging.warning("Could not import pynput. Keyboard interface will not be available.") logging.warning(
"Could not import pynput. Keyboard interface will not be available."
)
self.listener = None self.listener = None
def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]: def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
@ -599,7 +670,9 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
time.sleep(0.1) # Check more frequently if desired time.sleep(0.1) # Check more frequently if desired
# Execute the step in the underlying environment # Execute the step in the underlying environment
obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention)) obs, reward, terminated, truncated, info = self.env.step(
(policy_action, is_intervention)
)
# Override reward and termination if episode success event triggered # Override reward and termination if episode success event triggered
with self.event_lock: with self.event_lock:
@ -628,7 +701,10 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
class ResetWrapper(gym.Wrapper): class ResetWrapper(gym.Wrapper):
def __init__( def __init__(
self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5 self,
env: HILSerlRobotEnv,
reset_fn: Optional[Callable[[], None]] = None,
reset_time_s: float = 5,
): ):
super().__init__(env) super().__init__(env)
self.reset_fn = reset_fn self.reset_fn = reset_fn
@ -641,7 +717,10 @@ class ResetWrapper(gym.Wrapper):
if self.reset_fn is not None: if self.reset_fn is not None:
self.reset_fn(self.env) self.reset_fn(self.env)
else: else:
log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True) log_say(
f"Manually reset the environment for {self.reset_time_s} seconds.",
play_sounds=True,
)
start_time = time.perf_counter() start_time = time.perf_counter()
while time.perf_counter() - start_time < self.reset_time_s: while time.perf_counter() - start_time < self.reset_time_s:
self.robot.teleop_step() self.robot.teleop_step()
@ -654,7 +733,9 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
def __init__(self, env): def __init__(self, env):
super().__init__(env) super().__init__(env)
def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: def observation(
self, observation: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
for key in observation: for key in observation:
if "image" in key and observation[key].dim() == 3: if "image" in key and observation[key].dim() == 3:
observation[key] = observation[key].unsqueeze(0) observation[key] = observation[key].unsqueeze(0)
@ -685,6 +766,8 @@ def make_robot_env(
A vectorized gym environment with all the necessary wrappers applied. A vectorized gym environment with all the necessary wrappers applied.
""" """
if "maniskill" in cfg.env.name: if "maniskill" in cfg.env.name:
from lerobot.scripts.server.maniskill_manipulator import make_maniskill
logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN") logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
env = make_maniskill( env = make_maniskill(
cfg=cfg, cfg=cfg,
@ -703,15 +786,23 @@ def make_robot_env(
env = ConvertToLeRobotObservation(env=env, device=cfg.device) env = ConvertToLeRobotObservation(env=env, device=cfg.device)
if cfg.env.wrapper.crop_params_dict is not None: if cfg.env.wrapper.crop_params_dict is not None:
env = ImageCropResizeWrapper( env = ImageCropResizeWrapper(
env=env, crop_params_dict=cfg.env.wrapper.crop_params_dict, resize_size=cfg.env.wrapper.resize_size env=env,
crop_params_dict=cfg.env.wrapper.crop_params_dict,
resize_size=cfg.env.wrapper.resize_size,
) )
# Add reward computation and control wrappers # Add reward computation and control wrappers
env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps) env = TimeLimitWrapper(
env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps
)
env = KeyboardInterfaceWrapper(env=env) env = KeyboardInterfaceWrapper(env=env)
env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s) env = ResetWrapper(
env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space) env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s
)
env = JointMaskingActionSpace(
env=env, mask=cfg.env.wrapper.joint_masking_action_space
)
env = BatchCompitableWrapper(env=env) env = BatchCompitableWrapper(env=env)
return env return env
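The wrapper chain above is driven entirely by the env config. Below is a minimal sketch of the fields this factory reads (cfg.fps, cfg.device, cfg.env.name, cfg.env.wrapper.*), expressed with OmegaConf since the scripts pass Hydra/OmegaConf configs around; the values are placeholders, not recommended settings.

```python
from omegaconf import OmegaConf

# Placeholder values; only the key layout mirrors the cfg.* accesses in make_robot_env.
cfg = OmegaConf.create(
    {
        "fps": 10,
        "device": "cuda",
        "env": {
            "name": "real_robot",  # any name without "maniskill" skips the sim branch
            "wrapper": {
                "crop_params_dict": None,  # or a dict mapping image keys to crop parameters
                "resize_size": [128, 128],
                "control_time_s": 20,
                "reset_time_s": 5,
                "joint_masking_action_space": [1, 1, 1, 1, 0, 0],
            },
        },
    }
)
```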
@ -724,13 +815,19 @@ def get_classifier(pretrained_path, config_path, device="mps"):
return None return None
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
cfg = init_hydra_config(config_path) cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths classifier_config.num_cameras = len(
cfg.training.image_keys
) # TODO automate these paths
model = Classifier(classifier_config) model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict()) model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to(device) model = model.to(device)
@ -741,7 +838,9 @@ def replay_episode(env, repo_id, root=None, episode=0):
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
local_files_only = root is not None local_files_only = root is not None
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only) dataset = LeRobotDataset(
repo_id, root=root, episodes=[episode], local_files_only=local_files_only
)
actions = dataset.hf_dataset.select_columns("action") actions = dataset.hf_dataset.select_columns("action")
for idx in range(dataset.num_frames): for idx in range(dataset.num_frames):
@ -787,7 +886,8 @@ if __name__ == "__main__":
), ),
) )
parser.add_argument( parser.add_argument(
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening") "--display-cameras",
help=("Whether to display the camera feed while the rollout is happening"),
) )
parser.add_argument( parser.add_argument(
"--reward-classifier-pretrained-path", "--reward-classifier-pretrained-path",
@ -801,13 +901,39 @@ if __name__ == "__main__":
default=None, default=None,
help="Path to a yaml config file that is necessary to build the reward classifier model.", help="Path to a yaml config file that is necessary to build the reward classifier model.",
) )
parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file") parser.add_argument(
parser.add_argument("--env-overrides", type=str, default=None, help="Overrides for the env yaml file") "--env-path", type=str, default=None, help="Path to the env yaml file"
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds") )
parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes") parser.add_argument(
parser.add_argument("--replay-repo-id", type=str, default=None, help="Repo ID of the episode to replay") "--env-overrides",
parser.add_argument("--replay-root", type=str, default=None, help="Root of the dataset to replay") type=str,
parser.add_argument("--replay-episode", type=int, default=0, help="Episode to replay") default=None,
help="Overrides for the env yaml file",
)
parser.add_argument(
"--control-time-s",
type=float,
default=20,
help="Maximum episode length in seconds",
)
parser.add_argument(
"--reset-follower-pos",
type=int,
default=1,
help="Reset follower between episodes",
)
parser.add_argument(
"--replay-repo-id",
type=str,
default=None,
help="Repo ID of the episode to replay",
)
parser.add_argument(
"--replay-root", type=str, default=None, help="Root of the dataset to replay"
)
parser.add_argument(
"--replay-episode", type=int, default=0, help="Episode to replay"
)
args = parser.parse_args() args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides) robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
@ -828,7 +954,9 @@ if __name__ == "__main__":
env.reset() env.reset()
if args.replay_repo_id is not None: if args.replay_repo_id is not None:
replay_episode(env, args.replay_repo_id, root=args.replay_root, episode=args.replay_episode) replay_episode(
env, args.replay_repo_id, root=args.replay_root, episode=args.replay_episode
)
exit() exit()
# Retrieve the robot's action space for joint commands. # Retrieve the robot's action space for joint commands.
@ -849,7 +977,9 @@ if __name__ == "__main__":
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
# Execute the step: wrap the NumPy action in a torch tensor. # Execute the step: wrap the NumPy action in a torch tensor.
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False)) obs, reward, terminated, truncated, info = env.step(
(torch.from_numpy(smoothed_action), False)
)
if terminated or truncated: if terminated or truncated:
env.reset() env.reset()
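The exploration loop above reduces to the following sketch once the old and new lines are untangled. Here `env` is assumed to be the wrapped environment returned by `make_robot_env`, and `action_dim`/`alpha` are illustrative values rather than the script's defaults.

```python
import numpy as np
import torch

action_dim = 4   # illustrative
alpha = 0.9      # smoothing factor, illustrative
smoothed_action = np.zeros(action_dim, dtype=np.float32)

for _ in range(1000):
    new_random_action = np.random.uniform(-1.0, 1.0, size=action_dim).astype(np.float32)
    # Exponential moving average keeps the random exploration smooth on hardware.
    smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
    # The env expects an (action_tensor, bool) tuple, matching the call in the script.
    obs, reward, terminated, truncated, info = env.step(
        (torch.from_numpy(smoothed_action), False)
    )
    if terminated or truncated:
        env.reset()
```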

View File

@ -22,19 +22,11 @@ package hil_serl;
// The Learner implements this service. // The Learner implements this service.
service LearnerService { service LearnerService {
// Actor -> Learner to store transitions // Actor -> Learner to store transitions
rpc SendTransition(Transition) returns (Empty); rpc SendInteractionMessage(InteractionMessage) returns (Empty);
rpc SendInteractionMessage(InteractionMessage) returns (Empty); rpc StreamParameters(Empty) returns (stream Parameters);
rpc ReceiveTransitions(stream ActorInformation) returns (Empty);
} }
// ActorService: the Learner calls this to push parameters.
// The Actor implements this service.
service ActorService {
// Learner -> Actor to send new parameters
rpc StreamTransition(Empty) returns (stream ActorInformation) {};
rpc SendParameters(Parameters) returns (Empty);
}
message ActorInformation { message ActorInformation {
oneof data { oneof data {
Transition transition = 1; Transition transition = 1;
@ -42,17 +34,25 @@ message ActorInformation {
} }
} }
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
TRANSFER_MIDDLE = 2;
TRANSFER_END = 3;
}
// Messages // Messages
message Transition { message Transition {
bytes transition_bytes = 1; bytes transition_bytes = 1;
} }
message Parameters { message Parameters {
bytes parameter_bytes = 1; TransferState transfer_state = 1;
bytes parameter_bytes = 2;
} }
message InteractionMessage { message InteractionMessage {
bytes interaction_message_bytes = 1; bytes interaction_message_bytes = 1;
} }
message Empty {} message Empty {}
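With the `ActorService` removed, the actor now initiates every connection: it pushes transitions and interaction messages to the learner and pulls parameters back over a server-streaming RPC. A minimal actor-side sketch against the generated stubs is shown below; it assumes the learner is reachable on `127.0.0.1:50051`, that transitions are torch-serialized lists (as the learner deserializes them), and it omits retries, threading and chunk reassembly.

```python
import io
import pickle

import grpc
import torch

import hilserl_pb2  # type: ignore
import hilserl_pb2_grpc  # type: ignore

channel = grpc.insecure_channel(
    "127.0.0.1:50051",
    options=[
        ("grpc.max_send_message_length", 4 * 1024 * 1024),
        ("grpc.max_receive_message_length", 4 * 1024 * 1024),
    ],
)
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)

# Actor -> Learner (unary): pickled interaction statistics.
stub.SendInteractionMessage(
    hilserl_pb2.InteractionMessage(
        interaction_message_bytes=pickle.dumps({"Interaction step": 0})
    )
)

# Actor -> Learner (client streaming): transitions wrapped in ActorInformation.
def actor_information_stream():
    buf = io.BytesIO()
    torch.save([], buf)  # placeholder for a list of transitions
    yield hilserl_pb2.ActorInformation(
        transition=hilserl_pb2.Transition(transition_bytes=buf.getvalue())
    )

stub.ReceiveTransitions(actor_information_stream())

# Learner -> Actor (server streaming): parameter chunks until the learner shuts down.
for chunk in stub.StreamParameters(hilserl_pb2.Empty()):
    print(chunk.transfer_state, len(chunk.parameter_bytes))
    break  # a real actor keeps iterating; see the reassembly sketch further below
```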

View File

@ -24,25 +24,25 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"\x83\x01\n\x10\x41\x63torInformation\x12*\n\ntransition\x18\x01 \x01(\x0b\x32\x14.hil_serl.TransitionH\x00\x12;\n\x13interaction_message\x18\x02 \x01(\x0b\x32\x1c.hil_serl.InteractionMessageH\x00\x42\x06\n\x04\x64\x61ta\"&\n\nTransition\x12\x18\n\x10transition_bytes\x18\x01 \x01(\x0c\"%\n\nParameters\x12\x17\n\x0fparameter_bytes\x18\x01 \x01(\x0c\"7\n\x12InteractionMessage\x12!\n\x19interaction_message_bytes\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty2\x92\x01\n\x0eLearnerService\x12\x37\n\x0eSendTransition\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty2\x8c\x01\n\x0c\x41\x63torService\x12\x43\n\x10StreamTransition\x12\x0f.hil_serl.Empty\x1a\x1a.hil_serl.ActorInformation\"\x00\x30\x01\x12\x37\n\x0eSendParameters\x12\x14.hil_serl.Parameters\x1a\x0f.hil_serl.Emptyb\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"\x83\x01\n\x10\x41\x63torInformation\x12*\n\ntransition\x18\x01 \x01(\x0b\x32\x14.hil_serl.TransitionH\x00\x12;\n\x13interaction_message\x18\x02 \x01(\x0b\x32\x1c.hil_serl.InteractionMessageH\x00\x42\x06\n\x04\x64\x61ta\"&\n\nTransition\x12\x18\n\x10transition_bytes\x18\x01 \x01(\x0c\"V\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x17\n\x0fparameter_bytes\x18\x02 \x01(\x0c\"7\n\x12InteractionMessage\x12!\n\x19interaction_message_bytes\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdb\x01\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12\x43\n\x12ReceiveTransitions\x12\x1a.hil_serl.ActorInformation\x1a\x0f.hil_serl.Empty(\x01\x62\x06proto3')
_globals = globals() _globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS: if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=355
_globals['_TRANSFERSTATE']._serialized_end=451
_globals['_ACTORINFORMATION']._serialized_start=28 _globals['_ACTORINFORMATION']._serialized_start=28
_globals['_ACTORINFORMATION']._serialized_end=159 _globals['_ACTORINFORMATION']._serialized_end=159
_globals['_TRANSITION']._serialized_start=161 _globals['_TRANSITION']._serialized_start=161
_globals['_TRANSITION']._serialized_end=199 _globals['_TRANSITION']._serialized_end=199
_globals['_PARAMETERS']._serialized_start=201 _globals['_PARAMETERS']._serialized_start=201
_globals['_PARAMETERS']._serialized_end=238 _globals['_PARAMETERS']._serialized_end=287
_globals['_INTERACTIONMESSAGE']._serialized_start=240 _globals['_INTERACTIONMESSAGE']._serialized_start=289
_globals['_INTERACTIONMESSAGE']._serialized_end=295 _globals['_INTERACTIONMESSAGE']._serialized_end=344
_globals['_EMPTY']._serialized_start=297 _globals['_EMPTY']._serialized_start=346
_globals['_EMPTY']._serialized_end=304 _globals['_EMPTY']._serialized_end=353
_globals['_LEARNERSERVICE']._serialized_start=307 _globals['_LEARNERSERVICE']._serialized_start=454
_globals['_LEARNERSERVICE']._serialized_end=453 _globals['_LEARNERSERVICE']._serialized_end=673
_globals['_ACTORSERVICE']._serialized_start=456
_globals['_ACTORSERVICE']._serialized_end=596
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)

View File

@ -36,16 +36,21 @@ class LearnerServiceStub(object):
Args: Args:
channel: A grpc.Channel. channel: A grpc.Channel.
""" """
self.SendTransition = channel.unary_unary(
'/hil_serl.LearnerService/SendTransition',
request_serializer=hilserl__pb2.Transition.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
self.SendInteractionMessage = channel.unary_unary( self.SendInteractionMessage = channel.unary_unary(
'/hil_serl.LearnerService/SendInteractionMessage', '/hil_serl.LearnerService/SendInteractionMessage',
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString, request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString, response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True) _registered_method=True)
self.StreamParameters = channel.unary_stream(
'/hil_serl.LearnerService/StreamParameters',
request_serializer=hilserl__pb2.Empty.SerializeToString,
response_deserializer=hilserl__pb2.Parameters.FromString,
_registered_method=True)
self.ReceiveTransitions = channel.stream_unary(
'/hil_serl.LearnerService/ReceiveTransitions',
request_serializer=hilserl__pb2.ActorInformation.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
class LearnerServiceServicer(object): class LearnerServiceServicer(object):
@ -53,14 +58,20 @@ class LearnerServiceServicer(object):
The Learner implements this service. The Learner implements this service.
""" """
def SendTransition(self, request, context): def SendInteractionMessage(self, request, context):
"""Actor -> Learner to store transitions """Actor -> Learner to store transitions
""" """
context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!') context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!') raise NotImplementedError('Method not implemented!')
def SendInteractionMessage(self, request, context): def StreamParameters(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def ReceiveTransitions(self, request_iterator, context):
"""Missing associated documentation comment in .proto file.""" """Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!') context.set_details('Method not implemented!')
@ -69,16 +80,21 @@ class LearnerServiceServicer(object):
def add_LearnerServiceServicer_to_server(servicer, server): def add_LearnerServiceServicer_to_server(servicer, server):
rpc_method_handlers = { rpc_method_handlers = {
'SendTransition': grpc.unary_unary_rpc_method_handler(
servicer.SendTransition,
request_deserializer=hilserl__pb2.Transition.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
'SendInteractionMessage': grpc.unary_unary_rpc_method_handler( 'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
servicer.SendInteractionMessage, servicer.SendInteractionMessage,
request_deserializer=hilserl__pb2.InteractionMessage.FromString, request_deserializer=hilserl__pb2.InteractionMessage.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString, response_serializer=hilserl__pb2.Empty.SerializeToString,
), ),
'StreamParameters': grpc.unary_stream_rpc_method_handler(
servicer.StreamParameters,
request_deserializer=hilserl__pb2.Empty.FromString,
response_serializer=hilserl__pb2.Parameters.SerializeToString,
),
'ReceiveTransitions': grpc.stream_unary_rpc_method_handler(
servicer.ReceiveTransitions,
request_deserializer=hilserl__pb2.ActorInformation.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
} }
generic_handler = grpc.method_handlers_generic_handler( generic_handler = grpc.method_handlers_generic_handler(
'hil_serl.LearnerService', rpc_method_handlers) 'hil_serl.LearnerService', rpc_method_handlers)
@ -92,33 +108,6 @@ class LearnerService(object):
The Learner implements this service. The Learner implements this service.
""" """
@staticmethod
def SendTransition(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/hil_serl.LearnerService/SendTransition',
hilserl__pb2.Transition.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod @staticmethod
def SendInteractionMessage(request, def SendInteractionMessage(request,
target, target,
@ -146,76 +135,8 @@ class LearnerService(object):
metadata, metadata,
_registered_method=True) _registered_method=True)
class ActorServiceStub(object):
"""ActorService: the Learner calls this to push parameters.
The Actor implements this service.
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.StreamTransition = channel.unary_stream(
'/hil_serl.ActorService/StreamTransition',
request_serializer=hilserl__pb2.Empty.SerializeToString,
response_deserializer=hilserl__pb2.ActorInformation.FromString,
_registered_method=True)
self.SendParameters = channel.unary_unary(
'/hil_serl.ActorService/SendParameters',
request_serializer=hilserl__pb2.Parameters.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
class ActorServiceServicer(object):
"""ActorService: the Learner calls this to push parameters.
The Actor implements this service.
"""
def StreamTransition(self, request, context):
"""Learner -> Actor to send new parameters
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendParameters(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_ActorServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'StreamTransition': grpc.unary_stream_rpc_method_handler(
servicer.StreamTransition,
request_deserializer=hilserl__pb2.Empty.FromString,
response_serializer=hilserl__pb2.ActorInformation.SerializeToString,
),
'SendParameters': grpc.unary_unary_rpc_method_handler(
servicer.SendParameters,
request_deserializer=hilserl__pb2.Parameters.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'hil_serl.ActorService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('hil_serl.ActorService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class ActorService(object):
"""ActorService: the Learner calls this to push parameters.
The Actor implements this service.
"""
@staticmethod @staticmethod
def StreamTransition(request, def StreamParameters(request,
target, target,
options=(), options=(),
channel_credentials=None, channel_credentials=None,
@ -228,9 +149,9 @@ class ActorService(object):
return grpc.experimental.unary_stream( return grpc.experimental.unary_stream(
request, request,
target, target,
'/hil_serl.ActorService/StreamTransition', '/hil_serl.LearnerService/StreamParameters',
hilserl__pb2.Empty.SerializeToString, hilserl__pb2.Empty.SerializeToString,
hilserl__pb2.ActorInformation.FromString, hilserl__pb2.Parameters.FromString,
options, options,
channel_credentials, channel_credentials,
insecure, insecure,
@ -242,7 +163,7 @@ class ActorService(object):
_registered_method=True) _registered_method=True)
@staticmethod @staticmethod
def SendParameters(request, def ReceiveTransitions(request_iterator,
target, target,
options=(), options=(),
channel_credentials=None, channel_credentials=None,
@ -252,11 +173,11 @@ class ActorService(object):
wait_for_ready=None, wait_for_ready=None,
timeout=None, timeout=None,
metadata=None): metadata=None):
return grpc.experimental.unary_unary( return grpc.experimental.stream_unary(
request, request_iterator,
target, target,
'/hil_serl.ActorService/SendParameters', '/hil_serl.LearnerService/ReceiveTransitions',
hilserl__pb2.Parameters.SerializeToString, hilserl__pb2.ActorInformation.SerializeToString,
hilserl__pb2.Empty.FromString, hilserl__pb2.Empty.FromString,
options, options,
channel_credentials, channel_credentials,

View File

@ -14,19 +14,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import io
import logging import logging
import pickle
import queue import queue
import shutil import shutil
import time import time
from pprint import pformat from pprint import pformat
from threading import Lock, Thread from threading import Lock, Thread
import signal
from threading import Event
from concurrent.futures import ThreadPoolExecutor
import grpc import grpc
# Import generated stubs # Import generated stubs
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore import hilserl_pb2_grpc # type: ignore
import hydra import hydra
import torch import torch
@ -55,10 +55,11 @@ from lerobot.common.utils.utils import (
from lerobot.scripts.server.buffer import ( from lerobot.scripts.server.buffer import (
ReplayBuffer, ReplayBuffer,
concatenate_batch_transitions, concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device, move_transition_to_device,
) )
from lerobot.scripts.server import learner_service
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
transition_queue = queue.Queue() transition_queue = queue.Queue()
@ -77,9 +78,13 @@ def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
# if resume == True # if resume == True
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir) checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
if not checkpoint_dir.exists(): if not checkpoint_dir.exists():
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True") raise RuntimeError(
f"No model checkpoint found in {checkpoint_dir} for resume=True"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml") checkpoint_cfg_path = str(
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
)
logging.info( logging.info(
colored( colored(
"Resume=True detected, resuming previous run", "Resume=True detected, resuming previous run",
@ -112,7 +117,9 @@ def load_training_state(
if not cfg.resume: if not cfg.resume:
return None, None return None, None
training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name) training_state = torch.load(
logger.last_checkpoint_dir / logger.training_state_file_name
)
if isinstance(training_state["optimizer"], dict): if isinstance(training_state["optimizer"], dict):
assert set(training_state["optimizer"].keys()) == set(optimizers.keys()) assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
@ -126,7 +133,9 @@ def load_training_state(
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None: def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters()) num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir) log_output_dir(out_dir)
@ -136,7 +145,9 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer: def initialize_replay_buffer(
cfg: DictConfig, logger: Logger, device: str
) -> ReplayBuffer:
if not cfg.resume: if not cfg.resume:
return ReplayBuffer( return ReplayBuffer(
capacity=cfg.training.online_buffer_capacity, capacity=cfg.training.online_buffer_capacity,
@ -146,7 +157,9 @@ def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> Re
) )
dataset = LeRobotDataset( dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset" repo_id=cfg.dataset_repo_id,
local_files_only=True,
root=logger.log_dir / "dataset",
) )
return ReplayBuffer.from_lerobot_dataset( return ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset, lerobot_dataset=dataset,
@ -168,18 +181,10 @@ def start_learner_threads(
logger: Logger, logger: Logger,
resume_optimization_step: int | None = None, resume_optimization_step: int | None = None,
resume_interaction_step: int | None = None, resume_interaction_step: int | None = None,
shutdown_event: Event | None = None,
) -> None: ) -> None:
actor_ip = cfg.actor_learner_config.actor_ip host = cfg.actor_learner_config.learner_host
port = cfg.actor_learner_config.port port = cfg.actor_learner_config.learner_port
server_thread = Thread(
target=stream_transitions_from_actor,
args=(
actor_ip,
port,
),
daemon=True,
)
transition_thread = Thread( transition_thread = Thread(
target=add_actor_information_and_train, target=add_actor_information_and_train,
@ -196,95 +201,56 @@ def start_learner_threads(
logger, logger,
resume_optimization_step, resume_optimization_step,
resume_interaction_step, resume_interaction_step,
shutdown_event,
), ),
) )
param_push_thread = Thread(
target=learner_push_parameters,
args=(policy, policy_lock, actor_ip, port, 15),
daemon=True,
)
server_thread.start()
transition_thread.start() transition_thread.start()
param_push_thread.start()
param_push_thread.join() service = learner_service.LearnerService(
shutdown_event,
policy,
policy_lock,
cfg.actor_learner_config.policy_parameters_push_frequency,
transition_queue,
interaction_message_queue,
)
server = start_learner_server(service, host, port)
shutdown_event.wait()
server.stop(learner_service.STUTDOWN_TIMEOUT)
logging.info("[LEARNER] gRPC server stopped")
transition_thread.join() transition_thread.join()
server_thread.join() logging.info("[LEARNER] Transition thread stopped")
def stream_transitions_from_actor(host="127.0.0.1", port=50051): def start_learner_server(
""" service: learner_service.LearnerService,
Runs a gRPC client that listens for transition and interaction messages from an Actor service. host="0.0.0.0",
port=50051,
This function establishes a gRPC connection with the given `host` and `port`, then continuously ) -> grpc.server:
streams transition data from the `ActorServiceStub`. The received transition data is deserialized server = grpc.server(
and stored in a queue (`transition_queue`). Similarly, interaction messages are also deserialized ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
and stored in a separate queue (`interaction_message_queue`). options=[
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
Args: ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
host (str, optional): The IP address or hostname of the gRPC server. Defaults to `"127.0.0.1"`. ],
port (int, optional): The port number on which the gRPC server is running. Defaults to `50051`.
"""
# NOTE: This is waiting for the handshake to be done
# In the future we will do it in a canonical way with a proper handshake
time.sleep(10)
channel = grpc.insecure_channel(
f"{host}:{port}",
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
) )
stub = hilserl_pb2_grpc.ActorServiceStub(channel) hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
for response in stub.StreamTransition(hilserl_pb2.Empty()): service,
if response.HasField("transition"): server,
buffer = io.BytesIO(response.transition.transition_bytes) )
transition = torch.load(buffer) server.add_insecure_port(f"{host}:{port}")
transition_queue.put(transition) server.start()
if response.HasField("interaction_message"): logging.info("[LEARNER] gRPC server started")
content = pickle.loads(response.interaction_message.interaction_message_bytes)
interaction_message_queue.put(content) return server
def learner_push_parameters( def check_nan_in_transition(
policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5 observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor
): ):
"""
As a client, connect to the Actor's gRPC server (ActorService)
and periodically push new parameters.
"""
time.sleep(10)
channel = grpc.insecure_channel(
f"{actor_host}:{actor_port}",
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
)
actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
while True:
with policy_lock:
params_dict = policy.actor.state_dict()
if policy.config.vision_encoder_name is not None:
if policy.config.freeze_vision_encoder:
params_dict: dict[str, torch.Tensor] = {
k: v for k, v in params_dict.items() if not k.startswith("encoder.")
}
else:
raise NotImplementedError(
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
)
params_dict = move_state_dict_to_device(params_dict, device="cpu")
# Serialize
buf = io.BytesIO()
torch.save(params_dict, buf)
params_bytes = buf.getvalue()
# Push them to the Actor's "SendParameters" method
logging.info("[LEARNER] Publishing parameters to the Actor")
response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes)) # noqa: F841
time.sleep(seconds_between_pushes)
def check_nan_in_transition(observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor):
for k in observations: for k in observations:
if torch.isnan(observations[k]).any(): if torch.isnan(observations[k]).any():
logging.error(f"observations[{k}] contains NaN values") logging.error(f"observations[{k}] contains NaN values")
@ -307,6 +273,7 @@ def add_actor_information_and_train(
logger: Logger, logger: Logger,
resume_optimization_step: int | None = None, resume_optimization_step: int | None = None,
resume_interaction_step: int | None = None, resume_interaction_step: int | None = None,
shutdown_event: Event | None = None,
): ):
""" """
Handles data transfer from the actor to the learner, manages training updates, Handles data transfer from the actor to the learner, manages training updates,
@ -338,6 +305,7 @@ def add_actor_information_and_train(
logger (Logger): Logger instance for tracking training progress. logger (Logger): Logger instance for tracking training progress.
resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached. resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging. resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
shutdown_event (Event | None): Event to signal shutdown.
""" """
# NOTE: This function doesn't have a single responsibility, it should be split into multiple functions # NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
# in the future. The reason why we did that is the GIL in Python. It's super slow the performance # in the future. The reason why we did that is the GIL in Python. It's super slow the performance
@ -345,9 +313,17 @@ def add_actor_information_and_train(
time.time() time.time()
logging.info("Starting learner thread") logging.info("Starting learner thread")
interaction_message, transition = None, None interaction_message, transition = None, None
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0 optimization_step = (
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0 resume_optimization_step if resume_optimization_step is not None else 0
)
interaction_step_shift = (
resume_interaction_step if resume_interaction_step is not None else 0
)
while True: while True:
if shutdown_event is not None and shutdown_event.is_set():
logging.info("[LEARNER] Shutdown signal received. Exiting...")
break
while not transition_queue.empty(): while not transition_queue.empty():
transition_list = transition_queue.get() transition_list = transition_queue.get()
for transition in transition_list: for transition in transition_list:
@ -361,7 +337,9 @@ def add_actor_information_and_train(
interaction_message = interaction_message_queue.get() interaction_message = interaction_message_queue.get()
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift interaction_message["Interaction step"] += interaction_step_shift
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step") logger.log_dict(
interaction_message, mode="train", custom_step_key="Interaction step"
)
# logging.info(f"Interaction message: {interaction_message}") # logging.info(f"Interaction message: {interaction_message}")
if len(replay_buffer) < cfg.training.online_step_before_learning: if len(replay_buffer) < cfg.training.online_step_before_learning:
@ -383,7 +361,9 @@ def add_actor_information_and_train(
observations = batch["state"] observations = batch["state"]
next_observations = batch["next_state"] next_observations = batch["next_state"]
done = batch["done"] done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) check_nan_in_transition(
observations=observations, actions=actions, next_state=next_observations
)
with policy_lock: with policy_lock:
loss_critic = policy.compute_loss_critic( loss_critic = policy.compute_loss_critic(
@ -411,7 +391,9 @@ def add_actor_information_and_train(
next_observations = batch["next_state"] next_observations = batch["next_state"]
done = batch["done"] done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) check_nan_in_transition(
observations=observations, actions=actions, next_state=next_observations
)
with policy_lock: with policy_lock:
loss_critic = policy.compute_loss_critic( loss_critic = policy.compute_loss_critic(
@ -439,7 +421,9 @@ def add_actor_information_and_train(
training_infos["loss_actor"] = loss_actor.item() training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(observations=observations) loss_temperature = policy.compute_loss_temperature(
observations=observations
)
optimizers["temperature"].zero_grad() optimizers["temperature"].zero_grad()
loss_temperature.backward() loss_temperature.backward()
optimizers["temperature"].step() optimizers["temperature"].step()
@ -453,9 +437,13 @@ def add_actor_information_and_train(
# logging.info(f"Training infos: {training_infos}") # logging.info(f"Training infos: {training_infos}")
time_for_one_optimization_step = time.time() - time_for_one_optimization_step time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9) frequency_for_one_optimization_step = 1 / (
time_for_one_optimization_step + 1e-9
)
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}") logging.info(
f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}"
)
logger.log_dict( logger.log_dict(
{ {
@ -471,7 +459,8 @@ def add_actor_information_and_train(
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}") logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
if cfg.training.save_checkpoint and ( if cfg.training.save_checkpoint and (
optimization_step % cfg.training.save_freq == 0 or optimization_step == cfg.training.online_steps optimization_step % cfg.training.save_freq == 0
or optimization_step == cfg.training.online_steps
): ):
logging.info(f"Checkpoint policy after step {optimization_step}") logging.info(f"Checkpoint policy after step {optimization_step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
@ -479,7 +468,9 @@ def add_actor_information_and_train(
_num_digits = max(6, len(str(cfg.training.online_steps))) _num_digits = max(6, len(str(cfg.training.online_steps)))
step_identifier = f"{optimization_step:0{_num_digits}d}" step_identifier = f"{optimization_step:0{_num_digits}d}"
interaction_step = ( interaction_step = (
interaction_message["Interaction step"] if interaction_message is not None else 0 interaction_message["Interaction step"]
if interaction_message is not None
else 0
) )
logger.save_checkpoint( logger.save_checkpoint(
optimization_step, optimization_step,
@ -538,7 +529,9 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
optimizer_critic = torch.optim.Adam( optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
) )
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr) optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha], lr=policy.config.critic_lr
)
lr_scheduler = None lr_scheduler = None
optimizers = { optimizers = {
"actor": optimizer_actor, "actor": optimizer_actor,
@ -580,14 +573,18 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats # Hack: But if we do online traning, we do not need dataset_stats
dataset_stats=None, dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None, pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
if cfg.resume
else None,
) )
# compile policy # compile policy
policy = torch.compile(policy) policy = torch.compile(policy)
assert isinstance(policy, nn.Module) assert isinstance(policy, nn.Module)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy) optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers) resume_optimization_step, resume_interaction_step = load_training_state(
cfg, logger, optimizers
)
log_training_info(cfg, out_dir, policy) log_training_info(cfg, out_dir, policy)
@ -599,7 +596,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_dataset offline buffer") logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg) offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer") logging.info("Convertion to a offline replay buffer")
active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask] active_action_dims = [
i
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
]
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset( offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset, offline_dataset,
device=device, device=device,
@ -609,6 +610,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
) )
batch_size: int = batch_size // 2 # We will sample from both replay buffer batch_size: int = batch_size // 2 # We will sample from both replay buffer
shutdown_event = Event()
def signal_handler(signum, frame):
print(
f"\nReceived signal {signal.Signals(signum).name}. Initiating learner shutdown..."
)
shutdown_event.set()
# Register signal handlers
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Termination request
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
start_learner_threads( start_learner_threads(
cfg, cfg,
device, device,
@ -621,6 +636,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logger, logger,
resume_optimization_step, resume_optimization_step,
resume_interaction_step, resume_interaction_step,
shutdown_event,
) )
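The pieces added in this hunk (the shared `Event`, the signal handlers, and the `shutdown_event` threaded into `start_learner_threads`) implement a cooperative shutdown: any registered signal flips the event, the gRPC server is stopped with a grace period, and the training thread exits its loop. A condensed, standalone sketch of that pattern, with assumed names (`worker`, the 10 s grace period), not the learner's actual code:

```python
import signal
import threading
from concurrent.futures import ThreadPoolExecutor

import grpc

shutdown_event = threading.Event()

def _handle_signal(signum, frame):
    # Any registered signal flips the event; loops that poll it exit cleanly.
    shutdown_event.set()

for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT):
    signal.signal(sig, _handle_signal)

server = grpc.server(ThreadPoolExecutor(max_workers=4))
server.add_insecure_port("0.0.0.0:50051")
server.start()

# Stand-in for the training thread: it simply waits on the same event.
worker = threading.Thread(target=shutdown_event.wait, daemon=True)
worker.start()

shutdown_event.wait()   # block the main thread until a signal arrives
server.stop(grace=10)   # refuse new RPCs, give in-flight calls 10 s to finish
worker.join()
```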

View File

@ -0,0 +1,113 @@
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
import torch
from torch import nn
from threading import Lock, Event
import logging
import queue
import io
import pickle
from lerobot.scripts.server.buffer import (
move_state_dict_to_device,
bytes_buffer_size,
state_to_bytes,
)
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
MAX_WORKERS = 10
STUTDOWN_TIMEOUT = 10
class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
def __init__(
self,
shutdown_event: Event,
policy: nn.Module,
policy_lock: Lock,
seconds_between_pushes: float,
transition_queue: queue.Queue,
interaction_message_queue: queue.Queue,
):
self.shutdown_event = shutdown_event
self.policy = policy
self.policy_lock = policy_lock
self.seconds_between_pushes = seconds_between_pushes
self.transition_queue = transition_queue
self.interaction_message_queue = interaction_message_queue
def _get_policy_state(self):
with self.policy_lock:
params_dict = self.policy.actor.state_dict()
if self.policy.config.vision_encoder_name is not None:
if self.policy.config.freeze_vision_encoder:
params_dict: dict[str, torch.Tensor] = {
k: v
for k, v in params_dict.items()
if not k.startswith("encoder.")
}
else:
raise NotImplementedError(
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
)
return move_state_dict_to_device(params_dict, device="cpu")
def _send_bytes(self, buffer: bytes):
size_in_bytes = bytes_buffer_size(buffer)
sent_bytes = 0
logging.info(f"Model state size {size_in_bytes/1024/1024} MB with")
while sent_bytes < size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_END
elif sent_bytes == 0:
transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
chunk = buffer.read(size_to_read)
yield hilserl_pb2.Parameters(
transfer_state=transfer_state, parameter_bytes=chunk
)
sent_bytes += size_to_read
logging.info(
f"[Learner] Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
)
logging.info(f"[LEARNER] Published {sent_bytes/1024/1024} MB to the Actor")
def StreamParameters(self, request, context):
# TODO: authorize the request
logging.info("[LEARNER] Received request to stream parameters from the Actor")
while not self.shutdown_event.is_set():
logging.debug("[LEARNER] Push parameters to the Actor")
state_dict = self._get_policy_state()
with state_to_bytes(state_dict) as buffer:
yield from self._send_bytes(buffer)
self.shutdown_event.wait(self.seconds_between_pushes)
def ReceiveTransitions(self, request_iterator, context):
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive transitions from the Actor")
for request in request_iterator:
logging.debug("[LEARNER] Received request")
if request.HasField("transition"):
buffer = io.BytesIO(request.transition.transition_bytes)
transition = torch.load(buffer)
self.transition_queue.put(transition)
if request.HasField("interaction_message"):
content = pickle.loads(
request.interaction_message.interaction_message_bytes
)
self.interaction_message_queue.put(content)
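On the actor side, the chunked `Parameters` stream has to be reassembled before it can be loaded. A minimal consumer sketch, assuming `state_to_bytes` wraps `torch.save` (as the replaced `learner_push_parameters` did) and that the actor already holds a `LearnerServiceStub`:

```python
import io

import torch

import hilserl_pb2  # type: ignore


def receive_policy_state(stub) -> dict[str, torch.Tensor]:
    """Accumulate parameter chunks until TRANSFER_END, then deserialize."""
    buffer = io.BytesIO()
    for chunk in stub.StreamParameters(hilserl_pb2.Empty()):
        buffer.write(chunk.parameter_bytes)
        if chunk.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
            break
    buffer.seek(0)
    return torch.load(buffer)
```

Each push repeats the BEGIN/MIDDLE/END cycle, so an actor that wants continuous updates keeps the stream open and loads a fresh state dict every time it sees `TRANSFER_END` instead of breaking out.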

poetry.lock (generated, 1051 changed lines)

File diff suppressed because it is too large

View File

@ -53,16 +53,19 @@ dependencies = [
"einops>=0.8.0", "einops>=0.8.0",
"flask>=3.0.3", "flask>=3.0.3",
"gdown>=5.1.0", "gdown>=5.1.0",
"grpcio>=1.70.0",
"gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work "gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work
"h5py>=3.10.0", "h5py>=3.10.0",
"huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'", "huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
"imageio[ffmpeg]>=2.34.0", "imageio[ffmpeg]>=2.34.0",
"jsonlines>=4.0.0", "jsonlines>=4.0.0",
"mani-skill>=3.0.0b18",
"numba>=0.59.0", "numba>=0.59.0",
"omegaconf>=2.3.0", "omegaconf>=2.3.0",
"opencv-python>=4.9.0", "opencv-python>=4.9.0",
"packaging>=24.2", "packaging>=24.2",
"av>=12.0.5", "av>=12.0.5",
"protobuf>=5.29.3",
"pymunk>=6.6.0", "pymunk>=6.6.0",
"pynput>=1.7.7", "pynput>=1.7.7",
"pyzmq>=26.2.1", "pyzmq>=26.2.1",
@ -87,6 +90,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"] feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"] hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"]
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"] intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
mani_skill = ["mani-skill"]
pi0 = ["transformers>=4.48.0"] pi0 = ["transformers>=4.48.0"]
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"] pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
stretch = [ stretch = [
@ -107,7 +111,33 @@ requires-poetry = ">=2.1"
[tool.ruff] [tool.ruff]
line-length = 110 line-length = 110
target-version = "py310" target-version = "py310"
exclude = ["tests/artifacts/**/*.safetensors"] exclude = [
"tests/data",
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".mypy_cache",
".nox",
".pants.d",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
"*_pb2.py",
"*_pb2_grpc.py",
]
[tool.ruff.lint] [tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"] select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]

ruff.toml (new file, 11 lines)
View File

@ -0,0 +1,11 @@
# Exclude files/directories from Ruff
exclude = [
"*_pb2.py", # Ignore all protobuf generated files
"*_pb2_grpc.py", # Ignore all gRPC generated files
"lerobot/scripts/server/hilserl_pb2.py", # Ignore specific file
".git",
".env",
".venv",
"build",
"dist"
]