[Port HIL-SERL] Adjust Actor-Learner architecture & clean up dependency management for HIL-SERL (#722)

Authored by Eugene Mironov on 2025-02-21 16:29:00 +07:00, committed by Michel Aractingi
parent 85242cac67
commit e1d55c7a44
17 changed files with 1949 additions and 475 deletions

View File

@ -46,7 +46,7 @@ repos:
rev: v3.19.1
hooks:
- id: pyupgrade
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.5
hooks:

View File

@ -0,0 +1,11 @@
FROM huggingface/lerobot-gpu:latest
RUN apt-get update && apt-get install -y --no-install-recommends \
libvulkan1 vulkan-tools \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
RUN pip install --upgrade --no-cache-dir pip
RUN pip install --no-cache-dir ".[mani-skill]"
# Set EGL as the rendering backend for MuJoCo
ENV MUJOCO_GL="egl"

View File

@ -81,3 +81,14 @@ You can also log sample predictions during evaluation. Each logged sample will i
- The **classifier's "confidence" (logits/probability)**.
These logs can be useful for diagnosing and debugging performance issues.
#### Generate protobuf files
```bash
python -m grpc_tools.protoc \
-I lerobot/scripts/server \
--python_out=lerobot/scripts/server \
--grpc_python_out=lerobot/scripts/server \
lerobot/scripts/server/hilserl.proto
```
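The regenerated `hilserl_pb2.py` / `hilserl_pb2_grpc.py` modules expose the new `LearnerService` that the actor connects to. Below is a minimal sketch of creating a client and consuming the server-streaming `StreamParameters` RPC; the host/port values are illustrative and error handling/retries are omitted.
```python
import grpc

from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc

# One channel is enough: gRPC (HTTP/2) multiplexes all RPCs over it.
channel = grpc.insecure_channel("127.0.0.1:50051")  # learner_host:learner_port
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)

# Server-streaming RPC: the learner pushes parameter updates in chunks.
for update in stub.StreamParameters(hilserl_pb2.Empty()):
    # `transfer_state` marks BEGIN/MIDDLE/END of a chunked state-dict transfer.
    print(update.transfer_state, len(update.parameter_bytes))
```
In this PR the actor accumulates the chunks in a `BytesIO` buffer and `torch.load`s the state dict once `TRANSFER_END` is received (see `receive_policy` in the actor server below).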

View File

@ -41,11 +41,16 @@ class SACConfig:
)
input_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"observation.image": {"mean": [[0.485, 0.456, 0.406]], "std": [[0.229, 0.224, 0.225]]},
"observation.image": {
"mean": [[0.485, 0.456, 0.406]],
"std": [[0.229, 0.224, 0.225]],
},
"observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"}
)
output_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"action": {"min": [-1, -1], "max": [1, 1]},
@ -54,9 +59,8 @@ class SACConfig:
# TODO: Move it outside of the config
actor_learner_config: dict[str, str | int] = field(
default_factory=lambda: {
"actor_ip": "127.0.0.1",
"port": 50051,
"learner_ip": "127.0.0.1",
"learner_host": "127.0.0.1",
"learner_port": 50051,
}
)
camera_number: int = 1

View File

@ -108,5 +108,6 @@ policy:
utd_ratio: 2 # 10
actor_learner_config:
actor_ip: "127.0.0.1"
port: 50051
learner_host: "127.0.0.1"
learner_port: 50051
policy_parameters_push_frequency: 15

View File

@ -112,8 +112,9 @@ policy:
utd_ratio: 2 # 10
actor_learner_config:
actor_ip: "127.0.0.1"
port: 50051
learner_host: "127.0.0.1"
learner_port: 50051
policy_parameters_push_frequency: 15
# # Loss coefficients.
# reward_coeff: 0.5

View File

@ -17,9 +17,9 @@ import io
import logging
import pickle
import queue
import time
from concurrent import futures
from statistics import mean, quantiles
import signal
from functools import lru_cache
# from lerobot.scripts.eval import eval_policy
from threading import Thread
@ -35,7 +35,6 @@ from torch import nn
# from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.control_utils import busy_wait
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import (
@ -44,14 +43,24 @@ from lerobot.common.utils.utils import (
set_global_seed,
)
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device
from lerobot.scripts.server.buffer import (
Transition,
move_state_dict_to_device,
move_transition_to_device,
bytes_buffer_size,
)
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
from lerobot.scripts.server import learner_service
from threading import Event
logging.basicConfig(level=logging.INFO)
parameters_queue = queue.Queue(maxsize=1)
message_queue = queue.Queue(maxsize=1_000_000)
ACTOR_SHUTDOWN_TIMEOUT = 30
class ActorInformation:
"""
@ -70,95 +79,171 @@ class ActorInformation:
self.interaction_message = interaction_message
class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
"""
gRPC service for actor-learner communication in reinforcement learning.
def receive_policy(
learner_client: hilserl_pb2_grpc.LearnerServiceStub,
shutdown_event: Event,
parameters_queue: queue.Queue,
):
logging.info("[ACTOR] Start receiving parameters from the Learner")
bytes_buffer = io.BytesIO()
step = 0
try:
for model_update in learner_client.StreamParameters(hilserl_pb2.Empty()):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down policy streaming receiver")
return hilserl_pb2.Empty()
This service is responsible for:
1. Streaming batches of transition data and statistical metrics from the actor to the learner.
2. Receiving updated network parameters from the learner.
"""
def StreamTransition(self, request, context): # noqa: N802
"""
Streams data from the actor to the learner.
This function continuously retrieves messages from the queue and processes them based on their type:
- **Transition Data:**
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
- **Interaction Messages:**
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
Yields:
hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
"""
while True:
message = message_queue.get(block=True)
if message.transition is not None:
transition_to_send_to_learner: list[Transition] = [
move_transition_to_device(transition=T, device="cpu") for T in message.transition
]
# Check for NaNs in transitions before sending to learner
for transition in transition_to_send_to_learner:
for key, value in transition["state"].items():
if torch.isnan(value).any():
logging.warning(f"Found NaN values in transition {key}")
buf = io.BytesIO()
torch.save(transition_to_send_to_learner, buf)
transition_bytes = buf.getvalue()
transition_message = hilserl_pb2.Transition(transition_bytes=transition_bytes)
response = hilserl_pb2.ActorInformation(transition=transition_message)
elif message.interaction_message is not None:
content = hilserl_pb2.InteractionMessage(
interaction_message_bytes=pickle.dumps(message.interaction_message)
if model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
bytes_buffer.write(model_update.parameter_bytes)
logging.info("Received model update at step 0")
step = 0
continue
elif (
model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE
):
bytes_buffer.write(model_update.parameter_bytes)
step += 1
logging.info(f"Received model update at step {step}")
elif model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
bytes_buffer.write(model_update.parameter_bytes)
logging.info(
f"Received model update at step end size {bytes_buffer_size(bytes_buffer)}"
)
response = hilserl_pb2.ActorInformation(interaction_message=content)
yield response
state_dict = torch.load(bytes_buffer)
def SendParameters(self, request, context): # noqa: N802
"""
Receives updated parameters from the learner and updates the actor.
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
step = 0
The learner calls this method to send new model parameters. The received parameters are deserialized
and placed in a queue to be consumed by the actor.
logging.info("Model updated")
Args:
request (hilserl_pb2.ParameterUpdate): The request containing serialized network parameters.
context (grpc.ServicerContext): The gRPC context.
parameters_queue.put(state_dict)
Returns:
hilserl_pb2.Empty: An empty response to acknowledge receipt.
"""
buffer = io.BytesIO(request.parameter_bytes)
params = torch.load(buffer)
parameters_queue.put(params)
return hilserl_pb2.Empty()
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
return hilserl_pb2.Empty()
def serve_actor_service(port=50052):
def transitions_stream(shutdown_event: Event, message_queue: queue.Queue):
while not shutdown_event.is_set():
try:
message = message_queue.get(block=True, timeout=5)
except queue.Empty:
logging.debug("[ACTOR] Transition queue is empty")
continue
if message.transition is not None:
transition_to_send_to_learner: list[Transition] = [
move_transition_to_device(transition=T, device="cpu")
for T in message.transition
]
# Check for NaNs in transitions before sending to learner
for transition in transition_to_send_to_learner:
for key, value in transition["state"].items():
if torch.isnan(value).any():
logging.warning(f"Found NaN values in transition {key}")
buf = io.BytesIO()
torch.save(transition_to_send_to_learner, buf)
transition_bytes = buf.getvalue()
transition_message = hilserl_pb2.Transition(
transition_bytes=transition_bytes
)
response = hilserl_pb2.ActorInformation(transition=transition_message)
elif message.interaction_message is not None:
content = hilserl_pb2.InteractionMessage(
interaction_message_bytes=pickle.dumps(message.interaction_message)
)
response = hilserl_pb2.ActorInformation(interaction_message=content)
yield response
return hilserl_pb2.Empty()
def send_transitions(
learner_client: hilserl_pb2_grpc.LearnerServiceStub,
shutdown_event: Event,
message_queue: queue.Queue,
):
"""
Runs a gRPC server to start streaming the data from the actor to the learner.
Through this server, the learner can push parameters to the Actor as well.
Streams data from the actor to the learner.
This function continuously retrieves messages from the queue and processes them based on their type:
- **Transition Data:**
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
- **Interaction Messages:**
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
Yields:
hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
"""
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=20),
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
try:
learner_client.ReceiveTransitions(
transitions_stream(shutdown_event, message_queue)
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
logging.info("[ACTOR] Finished streaming transitions")
@lru_cache(maxsize=1)
def learner_service_client(
host="127.0.0.1", port=50051
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
import json
"""
Returns a client for the learner service.
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
So we need to create only one client and reuse it.
"""
service_config = {
"methodConfig": [
{
"name": [{}], # Applies to ALL methods in ALL services
"retryPolicy": {
"maxAttempts": 5, # Max retries (total attempts = 5)
"initialBackoff": "0.1s", # First retry after 0.1s
"maxBackoff": "2s", # Max wait time between retries
"backoffMultiplier": 2, # Exponential backoff factor
"retryableStatusCodes": [
"UNAVAILABLE",
"DEADLINE_EXCEEDED",
], # Retries on network failures
},
}
]
}
service_config_json = json.dumps(service_config)
channel = grpc.insecure_channel(
f"{host}:{port}",
options=[
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.enable_retries", 1),
("grpc.service_config", service_config_json),
],
)
hilserl_pb2_grpc.add_ActorServiceServicer_to_server(ActorServiceServicer(), server)
server.add_insecure_port(f"[::]:{port}")
server.start()
logging.info(f"[ACTOR] gRPC server listening on port {port}")
server.wait_for_termination()
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
logging.info("[LEARNER] Learner service client created")
return stub, channel
def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
@ -169,7 +254,9 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, d
policy.load_state_dict(state_dict)
def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
def act_with_policy(
cfg: DictConfig, robot: Robot, reward_classifier: nn.Module, shutdown_event: Event
):
"""
Executes policy interaction within the environment.
@ -182,7 +269,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
logging.info("make_env online")
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
online_env = make_robot_env(
robot=robot, reward_classifier=reward_classifier, cfg=cfg
)
set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
@ -227,17 +316,27 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
episode_intervention = False
for interaction_step in range(cfg.training.online_steps):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutdown signal received. Exiting...")
return
if interaction_step >= cfg.training.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement
with TimerManager(
elapsed_time_list=list_policy_time, label="Policy inference time", log=False
elapsed_time_list=list_policy_time,
label="Policy inference time",
log=False,
) as timer: # noqa: F841
action = policy.select_action(batch=obs)
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
log_policy_frequency_issue(
policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step
)
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
next_obs, reward, done, truncated, info = online_env.step(
action.squeeze(dim=0).cpu().numpy()
)
else:
# TODO (azouitine): Make a custom space for torch tensor
action = online_env.action_space.sample()
@ -245,7 +344,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
action = (
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
torch.from_numpy(action[0])
.to(device, non_blocking=device.type == "cuda")
.unsqueeze(dim=0)
)
sum_reward_episode += float(reward)
@ -261,7 +362,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
# Check for NaN values in observations
for key, tensor in obs.items():
if torch.isnan(tensor).any():
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
logging.error(
f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}"
)
list_transition_to_send_to_learner.append(
Transition(
@ -281,13 +384,19 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
# Because we are using a single environment we can index at zero
if done or truncated:
# TODO: Handle logging for episode information
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
logging.info(
f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}"
)
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
update_policy_parameters(
policy=policy.actor, parameters_queue=parameters_queue, device=device
)
if len(list_transition_to_send_to_learner) > 0:
send_transitions_in_chunks(
transitions=list_transition_to_send_to_learner, message_queue=message_queue, chunk_size=4
transitions=list_transition_to_send_to_learner,
message_queue=message_queue,
chunk_size=4,
)
list_transition_to_send_to_learner = []
@ -332,11 +441,16 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
stats = {
"Policy frequency [Hz]": policy_fps,
"Policy frequency 90th-p [Hz]": quantiles_90,
}
return stats
def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
def log_policy_frequency_issue(
policy_fps: float, cfg: DictConfig, interaction_step: int
):
if policy_fps < cfg.fps:
logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
@ -347,7 +461,34 @@ def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_s
def actor_cli(cfg: dict):
robot = make_robot(cfg=cfg.robot)
server_thread = Thread(target=serve_actor_service, args=(cfg.actor_learner_config.port,), daemon=True)
shutdown_event = Event()
# Define signal handler
def signal_handler(signum, frame):
logging.info("Shutdown signal received. Cleaning up...")
shutdown_event.set()
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill)
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host,
port=cfg.actor_learner_config.learner_port,
)
receive_policy_thread = Thread(
target=receive_policy,
args=(learner_client, shutdown_event, parameters_queue),
daemon=True,
)
transitions_thread = Thread(
target=send_transitions,
args=(learner_client, shutdown_event, message_queue),
daemon=True,
)
# HACK: FOR MANISKILL we do not have a reward classifier
# TODO: Remove this once we merge into main
@ -360,15 +501,27 @@ def actor_cli(cfg: dict):
pretrained_path=cfg.env.reward_classifier.pretrained_path,
config_path=cfg.env.reward_classifier.config_path,
)
policy_thread = Thread(
target=act_with_policy,
daemon=True,
args=(cfg, robot, reward_classifier),
args=(cfg, robot, reward_classifier, shutdown_event),
)
server_thread.start()
transitions_thread.start()
policy_thread.start()
receive_policy_thread.start()
shutdown_event.wait()
logging.info("[ACTOR] Shutdown event received")
grpc_channel.close()
policy_thread.join()
server_thread.join()
logging.info("[ACTOR] Policy thread joined")
transitions_thread.join()
logging.info("[ACTOR] Transitions thread joined")
receive_policy_thread.join()
logging.info("[ACTOR] Receive policy thread joined")
if __name__ == "__main__":

View File

@ -17,6 +17,7 @@ import functools
import random
from typing import Any, Callable, Optional, Sequence, TypedDict
import io
import torch
import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
@ -41,24 +42,33 @@ class BatchTransition(TypedDict):
done: torch.Tensor
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
def move_transition_to_device(
transition: Transition, device: str = "cpu"
) -> Transition:
# Move state tensors to CPU
device = torch.device(device)
transition["state"] = {
key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
key: val.to(device, non_blocking=device.type == "cuda")
for key, val in transition["state"].items()
}
# Move action to CPU
transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
transition["action"] = transition["action"].to(
device, non_blocking=device.type == "cuda"
)
# No need to move reward or done, as they are float and bool
# No need to move reward or done, as they are float and bool
if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
transition["reward"] = transition["reward"].to(
device=device, non_blocking=device.type == "cuda"
)
if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
transition["done"] = transition["done"].to(
device, non_blocking=device.type == "cuda"
)
# Move next_state tensors to CPU
transition["next_state"] = {
@ -82,7 +92,10 @@ def move_state_dict_to_device(state_dict, device):
if isinstance(state_dict, torch.Tensor):
return state_dict.to(device)
elif isinstance(state_dict, dict):
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
return {
k: move_state_dict_to_device(v, device=device)
for k, v in state_dict.items()
}
elif isinstance(state_dict, list):
return [move_state_dict_to_device(v, device=device) for v in state_dict]
elif isinstance(state_dict, tuple):
@ -91,6 +104,22 @@ def move_state_dict_to_device(state_dict, device):
return state_dict
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> io.BytesIO:
"""Convert model state dict to flat array for transmission"""
buffer = io.BytesIO()
torch.save(state_dict, buffer)
return buffer
def bytes_buffer_size(buffer: io.BytesIO) -> int:
buffer.seek(0, io.SEEK_END)
result = buffer.tell()
buffer.seek(0)
return result
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
"""
Perform a per-image random crop over a batch of images in a vectorized way.
@ -116,7 +145,9 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
cropped_hwcn = images_hwcn[
torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
]
# cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
@ -179,7 +210,9 @@ class ReplayBuffer:
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
# Move tensors to the storage device
state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
next_state = {
key: tensor.to(self.storage_device) for key, tensor in next_state.items()
}
action = action.to(self.storage_device)
# if complementary_info is not None:
# complementary_info = {
@ -234,7 +267,9 @@ class ReplayBuffer:
)
replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys)
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
list_transition = cls._lerobotdataset_to_transitions(
dataset=lerobot_dataset, state_keys=state_keys
)
# Fill the replay buffer with the lerobot dataset transitions
for data in list_transition:
for k, v in data.items():
@ -295,7 +330,9 @@ class ReplayBuffer:
# If not provided, you can either raise an error or define a default:
if state_keys is None:
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
raise ValueError(
"You must provide a list of keys in `state_keys` that define your 'state'."
)
transitions: list[Transition] = []
num_frames = len(dataset)
@ -350,33 +387,37 @@ class ReplayBuffer:
# -- Build batched states --
batch_state = {}
for key in self.state_keys:
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
batch_state[key] = torch.cat(
[t["state"][key] for t in list_of_transitions], dim=0
).to(self.device)
if key.startswith("observation.image") and self.use_drq:
batch_state[key] = self.image_augmentation_function(batch_state[key])
# -- Build batched actions --
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
# -- Build batched rewards --
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
self.device
)
# -- Build batched rewards --
batch_rewards = torch.tensor(
[t["reward"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# -- Build batched next states --
batch_next_state = {}
for key in self.state_keys:
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
batch_next_state[key] = torch.cat(
[t["next_state"][key] for t in list_of_transitions], dim=0
).to(self.device)
if key.startswith("observation.image") and self.use_drq:
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
batch_next_state[key] = self.image_augmentation_function(
batch_next_state[key]
)
# -- Build batched dones --
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
batch_dones = torch.tensor(
[t["done"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# Return a BatchTransition typed dict
return BatchTransition(
@ -433,7 +474,9 @@ class ReplayBuffer:
# Add state keys
for key in self.state_keys:
sample_val = first_transition["state"][key].squeeze(dim=0) # Remove batch dimension
sample_val = first_transition["state"][key].squeeze(
dim=0
) # Remove batch dimension
if not isinstance(sample_val, torch.Tensor):
raise ValueError(
f"State key '{key}' is not a torch.Tensor. Please ensure your states are stored as torch.Tensors."
@ -465,7 +508,9 @@ class ReplayBuffer:
# We detect episode boundaries by `done == True`.
# --------------------------------------------------------------------------------------------
episode_index = 0
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
episode_index
)
frame_idx_in_episode = 0
for global_frame_idx, transition in enumerate(self.memory):
@ -476,16 +521,24 @@ class ReplayBuffer:
# Expand dimension to match what the dataset expects (the dataset wants the raw shape)
# We assume your buffer has shape [C, H, W] (if image) or [D] if vector
# This is typically already correct, but if needed you can reshape below.
frame_dict[key] = transition["state"][key].cpu().squeeze(dim=0) # Remove batch dimension
frame_dict[key] = (
transition["state"][key].cpu().squeeze(dim=0)
) # Remove batch dimension
# Fill action, reward, done
# Make sure they are shape (X,) or (X,Y,...) as needed.
frame_dict["action"] = transition["action"].cpu().squeeze(dim=0) # Remove batch dimension
frame_dict["action"] = (
transition["action"].cpu().squeeze(dim=0)
) # Remove batch dimension
frame_dict["next.reward"] = (
torch.tensor([transition["reward"]], dtype=torch.float32).cpu().squeeze(dim=0)
torch.tensor([transition["reward"]], dtype=torch.float32)
.cpu()
.squeeze(dim=0)
)
frame_dict["next.done"] = (
torch.tensor([transition["done"]], dtype=torch.bool).cpu().squeeze(dim=0)
torch.tensor([transition["done"]], dtype=torch.bool)
.cpu()
.squeeze(dim=0)
)
# Add to the dataset's buffer
lerobot_dataset.add_frame(frame_dict)
@ -499,7 +552,9 @@ class ReplayBuffer:
episode_index += 1
frame_idx_in_episode = 0
# Start a new buffer for the next episode
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
episode_index
)
# We are done adding frames
# If the last transition wasn't done=True, we still have an open buffer with frames.
@ -541,7 +596,13 @@ def concatenate_batch_transitions(
) -> BatchTransition:
"""NOTE: Be careful it change the left_batch_transitions in place"""
left_batch_transitions["state"] = {
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
key: torch.cat(
[
left_batch_transitions["state"][key],
right_batch_transition["state"][key],
],
dim=0,
)
for key in left_batch_transitions["state"]
}
left_batch_transitions["action"] = torch.cat(
@ -552,7 +613,11 @@ def concatenate_batch_transitions(
)
left_batch_transitions["next_state"] = {
key: torch.cat(
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
[
left_batch_transitions["next_state"][key],
right_batch_transition["next_state"][key],
],
dim=0,
)
for key in left_batch_transitions["next_state"]
}

View File

@ -10,11 +10,9 @@ import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config, log_say
from lerobot.scripts.server.maniskill_manipulator import make_maniskill
logging.basicConfig(level=logging.INFO)
@ -62,7 +60,9 @@ class HILSerlRobotEnv(gym.Env):
if not self.robot.is_connected:
self.robot.connect()
self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")
self.initial_follower_position = robot.follower_arms["main"].read(
"Present_Position"
)
# Episode tracking.
self.current_step = 0
@ -70,7 +70,9 @@ class HILSerlRobotEnv(gym.Env):
self.delta = delta
self.use_delta_action_space = use_delta_action_space
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
self.current_joint_positions = self.robot.follower_arms["main"].read(
"Present_Position"
)
# Retrieve the size of the joint position interval bound.
self.relative_bounds_size = (
@ -105,12 +107,16 @@ class HILSerlRobotEnv(gym.Env):
image_keys = [key for key in example_obs if "image" in key]
state_keys = [key for key in example_obs if "image" not in key]
observation_spaces = {
key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
key: gym.spaces.Box(
low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8
)
for key in image_keys
}
observation_spaces["observation.state"] = gym.spaces.Dict(
{
key: gym.spaces.Box(low=0, high=10, shape=example_obs[key].shape, dtype=np.float32)
key: gym.spaces.Box(
low=0, high=10, shape=example_obs[key].shape, dtype=np.float32
)
for key in state_keys
}
)
@ -128,8 +134,12 @@ class HILSerlRobotEnv(gym.Env):
)
else:
action_space_robot = gym.spaces.Box(
low=self.robot.config.joint_position_relative_bounds["min"].cpu().numpy(),
high=self.robot.config.joint_position_relative_bounds["max"].cpu().numpy(),
low=self.robot.config.joint_position_relative_bounds["min"]
.cpu()
.numpy(),
high=self.robot.config.joint_position_relative_bounds["max"]
.cpu()
.numpy(),
shape=(action_dim,),
dtype=np.float32,
)
@ -141,7 +151,9 @@ class HILSerlRobotEnv(gym.Env):
),
)
def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
def reset(
self, seed=None, options=None
) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
"""
Reset the environment to its initial state.
This method resets the step counter and clears any episodic data.
@ -198,24 +210,34 @@ class HILSerlRobotEnv(gym.Env):
"""
policy_action, intervention_bool = action
teleop_action = None
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
self.current_joint_positions = self.robot.follower_arms["main"].read(
"Present_Position"
)
if isinstance(policy_action, torch.Tensor):
policy_action = policy_action.cpu().numpy()
policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
policy_action = np.clip(
policy_action, self.action_space[0].low, self.action_space[0].high
)
if not intervention_bool:
if self.use_delta_action_space:
target_joint_positions = self.current_joint_positions + self.delta * policy_action
target_joint_positions = (
self.current_joint_positions + self.delta * policy_action
)
else:
target_joint_positions = policy_action
self.robot.send_action(torch.from_numpy(target_joint_positions))
observation = self.robot.capture_observation()
else:
observation, teleop_action = self.robot.teleop_step(record_data=True)
teleop_action = teleop_action["action"] # Convert tensor to appropriate format
teleop_action = teleop_action[
"action"
] # Convert tensor to appropriate format
# When applying the delta action space, convert teleop absolute values to relative differences.
if self.use_delta_action_space:
teleop_action = (teleop_action - self.current_joint_positions) / self.delta
teleop_action = (
teleop_action - self.current_joint_positions
) / self.delta
if torch.any(teleop_action < -self.relative_bounds_size) and torch.any(
teleop_action > self.relative_bounds_size
):
@ -226,7 +248,9 @@ class HILSerlRobotEnv(gym.Env):
)
teleop_action = torch.clamp(
teleop_action, -self.relative_bounds_size, self.relative_bounds_size
teleop_action,
-self.relative_bounds_size,
self.relative_bounds_size,
)
# NOTE: To mimic the shape of a neural network output, we add a batch dimension to the teleop action.
if teleop_action.dim() == 1:
@ -245,7 +269,10 @@ class HILSerlRobotEnv(gym.Env):
reward,
terminated,
truncated,
{"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
{
"action_intervention": teleop_action,
"is_intervention": teleop_action is not None,
},
)
def render(self):
@ -351,7 +378,9 @@ class JointMaskingActionSpace(gym.Wrapper):
raise ValueError("Mask length must match action space dimensions")
low = env.action_space.low[self.active_dims]
high = env.action_space.high[self.active_dims]
self.action_space = gym.spaces.Box(low=low, high=high, dtype=env.action_space.dtype)
self.action_space = gym.spaces.Box(
low=low, high=high, dtype=env.action_space.dtype
)
if isinstance(env.action_space, gym.spaces.Tuple):
if len(mask) != env.action_space[0].shape[0]:
@ -359,8 +388,12 @@ class JointMaskingActionSpace(gym.Wrapper):
low = env.action_space[0].low[self.active_dims]
high = env.action_space[0].high[self.active_dims]
action_space_masked = gym.spaces.Box(low=low, high=high, dtype=env.action_space[0].dtype)
self.action_space = gym.spaces.Tuple((action_space_masked, env.action_space[1]))
action_space_masked = gym.spaces.Box(
low=low, high=high, dtype=env.action_space[0].dtype
)
self.action_space = gym.spaces.Tuple(
(action_space_masked, env.action_space[1])
)
# Create new action space with masked dimensions
def action(self, action):
@ -379,14 +412,18 @@ class JointMaskingActionSpace(gym.Wrapper):
# Extract the masked component from the tuple.
masked_action = action[0] if isinstance(action, tuple) else action
# Create a full action for the Box element.
full_box_action = np.zeros(self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype)
full_box_action = np.zeros(
self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype
)
full_box_action[self.active_dims] = masked_action
# Return a tuple with the reconstructed Box action and the unchanged remainder.
return (full_box_action, action[1])
else:
# For Box action spaces.
masked_action = action if not isinstance(action, tuple) else action[0]
full_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype)
full_action = np.zeros(
self.env.action_space.shape, dtype=self.env.action_space.dtype
)
full_action[self.active_dims] = masked_action
return full_action
@ -395,9 +432,13 @@ class JointMaskingActionSpace(gym.Wrapper):
obs, reward, terminated, truncated, info = self.env.step(action)
if "action_intervention" in info and info["action_intervention"] is not None:
if info["action_intervention"].dim() == 1:
info["action_intervention"] = info["action_intervention"][self.active_dims]
info["action_intervention"] = info["action_intervention"][
self.active_dims
]
else:
info["action_intervention"] = info["action_intervention"][:, self.active_dims]
info["action_intervention"] = info["action_intervention"][
:, self.active_dims
]
return obs, reward, terminated, truncated, info
@ -438,7 +479,12 @@ class TimeLimitWrapper(gym.Wrapper):
class ImageCropResizeWrapper(gym.Wrapper):
def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
def __init__(
self,
env,
crop_params_dict: Dict[str, Annotated[Tuple[int], 4]],
resize_size=None,
):
super().__init__(env)
self.env = env
self.crop_params_dict = crop_params_dict
@ -450,7 +496,9 @@ class ImageCropResizeWrapper(gym.Wrapper):
for key in crop_params_dict:
top, left, height, width = crop_params_dict[key]
new_shape = (top + height, left + width)
self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)
self.observation_space[key] = gym.spaces.Box(
low=0, high=255, shape=new_shape
)
self.resize_size = resize_size
if self.resize_size is None:
@ -463,7 +511,9 @@ class ImageCropResizeWrapper(gym.Wrapper):
# Check for NaNs before processing
if torch.isnan(obs[k]).any():
logging.error(f"NaN values detected in observation {k} before crop and resize")
logging.error(
f"NaN values detected in observation {k} before crop and resize"
)
if device == torch.device("mps:0"):
obs[k] = obs[k].cpu()
@ -473,7 +523,9 @@ class ImageCropResizeWrapper(gym.Wrapper):
# Check for NaNs after processing
if torch.isnan(obs[k]).any():
logging.error(f"NaN values detected in observation {k} after crop and resize")
logging.error(
f"NaN values detected in observation {k} after crop and resize"
)
obs[k] = obs[k].to(device)
@ -503,10 +555,14 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper):
observation = preprocess_observation(observation)
observation = {
key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
key: observation[key].to(
self.device, non_blocking=self.device.type == "cuda"
)
for key in observation
}
observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
observation = {
k: torch.tensor(v, device=self.device) for k, v in observation.items()
}
return observation
@ -553,18 +609,31 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
"Place the leader in similar pose to the follower and press space again."
)
self.events["pause_policy"] = True
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
log_say(
"Human intervention stage. Get ready to take over.",
play_sounds=True,
)
return
if self.events["pause_policy"] and not self.events["human_intervention_step"]:
if (
self.events["pause_policy"]
and not self.events["human_intervention_step"]
):
self.events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.")
log_say("Starting human intervention.", play_sounds=True)
log_say(
"Starting human intervention.", play_sounds=True
)
return
if self.events["pause_policy"] and self.events["human_intervention_step"]:
if (
self.events["pause_policy"]
and self.events["human_intervention_step"]
):
self.events["pause_policy"] = False
self.events["human_intervention_step"] = False
print("Space key pressed for a third time.")
log_say("Continuing with policy actions.", play_sounds=True)
log_say(
"Continuing with policy actions.", play_sounds=True
)
return
except Exception as e:
print(f"Error handling key press: {e}")
@ -572,7 +641,9 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
self.listener = keyboard.Listener(on_press=on_press)
self.listener.start()
except ImportError:
logging.warning("Could not import pynput. Keyboard interface will not be available.")
logging.warning(
"Could not import pynput. Keyboard interface will not be available."
)
self.listener = None
def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
@ -599,7 +670,9 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
time.sleep(0.1) # Check more frequently if desired
# Execute the step in the underlying environment
obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
obs, reward, terminated, truncated, info = self.env.step(
(policy_action, is_intervention)
)
# Override reward and termination if episode success event triggered
with self.event_lock:
@ -628,7 +701,10 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
class ResetWrapper(gym.Wrapper):
def __init__(
self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5
self,
env: HILSerlRobotEnv,
reset_fn: Optional[Callable[[], None]] = None,
reset_time_s: float = 5,
):
super().__init__(env)
self.reset_fn = reset_fn
@ -641,7 +717,10 @@ class ResetWrapper(gym.Wrapper):
if self.reset_fn is not None:
self.reset_fn(self.env)
else:
log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True)
log_say(
f"Manually reset the environment for {self.reset_time_s} seconds.",
play_sounds=True,
)
start_time = time.perf_counter()
while time.perf_counter() - start_time < self.reset_time_s:
self.robot.teleop_step()
@ -654,7 +733,9 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
def observation(
self, observation: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
for key in observation:
if "image" in key and observation[key].dim() == 3:
observation[key] = observation[key].unsqueeze(0)
@ -685,6 +766,8 @@ def make_robot_env(
A vectorized gym environment with all the necessary wrappers applied.
"""
if "maniskill" in cfg.env.name:
from lerobot.scripts.server.maniskill_manipulator import make_maniskill
logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
env = make_maniskill(
cfg=cfg,
@ -703,15 +786,23 @@ def make_robot_env(
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
if cfg.env.wrapper.crop_params_dict is not None:
env = ImageCropResizeWrapper(
env=env, crop_params_dict=cfg.env.wrapper.crop_params_dict, resize_size=cfg.env.wrapper.resize_size
env=env,
crop_params_dict=cfg.env.wrapper.crop_params_dict,
resize_size=cfg.env.wrapper.resize_size,
)
# Add reward computation and control wrappers
env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps)
env = TimeLimitWrapper(
env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps
)
env = KeyboardInterfaceWrapper(env=env)
env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s)
env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space)
env = ResetWrapper(
env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s
)
env = JointMaskingActionSpace(
env=env, mask=cfg.env.wrapper.joint_masking_action_space
)
env = BatchCompitableWrapper(env=env)
return env
@ -724,13 +815,19 @@ def get_classifier(pretrained_path, config_path, device="mps"):
return None
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
classifier_config.num_cameras = len(
cfg.training.image_keys
) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to(device)
@ -741,7 +838,9 @@ def replay_episode(env, repo_id, root=None, episode=0):
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
local_files_only = root is not None
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
dataset = LeRobotDataset(
repo_id, root=root, episodes=[episode], local_files_only=local_files_only
)
actions = dataset.hf_dataset.select_columns("action")
for idx in range(dataset.num_frames):
@ -787,7 +886,8 @@ if __name__ == "__main__":
),
)
parser.add_argument(
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
"--display-cameras",
help=("Whether to display the camera feed while the rollout is happening"),
)
parser.add_argument(
"--reward-classifier-pretrained-path",
@ -801,13 +901,39 @@ if __name__ == "__main__":
default=None,
help="Path to a yaml config file that is necessary to build the reward classifier model.",
)
parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file")
parser.add_argument("--env-overrides", type=str, default=None, help="Overrides for the env yaml file")
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes")
parser.add_argument("--replay-repo-id", type=str, default=None, help="Repo ID of the episode to replay")
parser.add_argument("--replay-root", type=str, default=None, help="Root of the dataset to replay")
parser.add_argument("--replay-episode", type=int, default=0, help="Episode to replay")
parser.add_argument(
"--env-path", type=str, default=None, help="Path to the env yaml file"
)
parser.add_argument(
"--env-overrides",
type=str,
default=None,
help="Overrides for the env yaml file",
)
parser.add_argument(
"--control-time-s",
type=float,
default=20,
help="Maximum episode length in seconds",
)
parser.add_argument(
"--reset-follower-pos",
type=int,
default=1,
help="Reset follower between episodes",
)
parser.add_argument(
"--replay-repo-id",
type=str,
default=None,
help="Repo ID of the episode to replay",
)
parser.add_argument(
"--replay-root", type=str, default=None, help="Root of the dataset to replay"
)
parser.add_argument(
"--replay-episode", type=int, default=0, help="Episode to replay"
)
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
@ -828,7 +954,9 @@ if __name__ == "__main__":
env.reset()
if args.replay_repo_id is not None:
replay_episode(env, args.replay_repo_id, root=args.replay_root, episode=args.replay_episode)
replay_episode(
env, args.replay_repo_id, root=args.replay_root, episode=args.replay_episode
)
exit()
# Retrieve the robot's action space for joint commands.
@ -849,7 +977,9 @@ if __name__ == "__main__":
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
# Execute the step: wrap the NumPy action in a torch tensor.
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
obs, reward, terminated, truncated, info = env.step(
(torch.from_numpy(smoothed_action), False)
)
if terminated or truncated:
env.reset()

View File

@ -22,19 +22,11 @@ package hil_serl;
// The Learner implements this service.
service LearnerService {
// Actor -> Learner to store transitions
rpc SendTransition(Transition) returns (Empty);
rpc SendInteractionMessage(InteractionMessage) returns (Empty);
rpc StreamParameters(Empty) returns (stream Parameters);
rpc ReceiveTransitions(stream ActorInformation) returns (Empty);
}
// ActorService: the Learner calls this to push parameters.
// The Actor implements this service.
service ActorService {
// Learner -> Actor to send new parameters
rpc StreamTransition(Empty) returns (stream ActorInformation) {};
rpc SendParameters(Parameters) returns (Empty);
}
message ActorInformation {
oneof data {
Transition transition = 1;
@ -42,13 +34,21 @@ message ActorInformation {
}
}
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
TRANSFER_MIDDLE = 2;
TRANSFER_END = 3;
}
// Messages
message Transition {
bytes transition_bytes = 1;
}
message Parameters {
bytes parameter_bytes = 1;
TransferState transfer_state = 1;
bytes parameter_bytes = 2;
}
message InteractionMessage {

View File

@ -24,25 +24,25 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"\x83\x01\n\x10\x41\x63torInformation\x12*\n\ntransition\x18\x01 \x01(\x0b\x32\x14.hil_serl.TransitionH\x00\x12;\n\x13interaction_message\x18\x02 \x01(\x0b\x32\x1c.hil_serl.InteractionMessageH\x00\x42\x06\n\x04\x64\x61ta\"&\n\nTransition\x12\x18\n\x10transition_bytes\x18\x01 \x01(\x0c\"%\n\nParameters\x12\x17\n\x0fparameter_bytes\x18\x01 \x01(\x0c\"7\n\x12InteractionMessage\x12!\n\x19interaction_message_bytes\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty2\x92\x01\n\x0eLearnerService\x12\x37\n\x0eSendTransition\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty2\x8c\x01\n\x0c\x41\x63torService\x12\x43\n\x10StreamTransition\x12\x0f.hil_serl.Empty\x1a\x1a.hil_serl.ActorInformation\"\x00\x30\x01\x12\x37\n\x0eSendParameters\x12\x14.hil_serl.Parameters\x1a\x0f.hil_serl.Emptyb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"\x83\x01\n\x10\x41\x63torInformation\x12*\n\ntransition\x18\x01 \x01(\x0b\x32\x14.hil_serl.TransitionH\x00\x12;\n\x13interaction_message\x18\x02 \x01(\x0b\x32\x1c.hil_serl.InteractionMessageH\x00\x42\x06\n\x04\x64\x61ta\"&\n\nTransition\x12\x18\n\x10transition_bytes\x18\x01 \x01(\x0c\"V\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x17\n\x0fparameter_bytes\x18\x02 \x01(\x0c\"7\n\x12InteractionMessage\x12!\n\x19interaction_message_bytes\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdb\x01\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12\x43\n\x12ReceiveTransitions\x12\x1a.hil_serl.ActorInformation\x1a\x0f.hil_serl.Empty(\x01\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=355
_globals['_TRANSFERSTATE']._serialized_end=451
_globals['_ACTORINFORMATION']._serialized_start=28
_globals['_ACTORINFORMATION']._serialized_end=159
_globals['_TRANSITION']._serialized_start=161
_globals['_TRANSITION']._serialized_end=199
_globals['_PARAMETERS']._serialized_start=201
_globals['_PARAMETERS']._serialized_end=238
_globals['_INTERACTIONMESSAGE']._serialized_start=240
_globals['_INTERACTIONMESSAGE']._serialized_end=295
_globals['_EMPTY']._serialized_start=297
_globals['_EMPTY']._serialized_end=304
_globals['_LEARNERSERVICE']._serialized_start=307
_globals['_LEARNERSERVICE']._serialized_end=453
_globals['_ACTORSERVICE']._serialized_start=456
_globals['_ACTORSERVICE']._serialized_end=596
_globals['_PARAMETERS']._serialized_end=287
_globals['_INTERACTIONMESSAGE']._serialized_start=289
_globals['_INTERACTIONMESSAGE']._serialized_end=344
_globals['_EMPTY']._serialized_start=346
_globals['_EMPTY']._serialized_end=353
_globals['_LEARNERSERVICE']._serialized_start=454
_globals['_LEARNERSERVICE']._serialized_end=673
# @@protoc_insertion_point(module_scope)

View File

@ -36,16 +36,21 @@ class LearnerServiceStub(object):
Args:
channel: A grpc.Channel.
"""
self.SendTransition = channel.unary_unary(
'/hil_serl.LearnerService/SendTransition',
request_serializer=hilserl__pb2.Transition.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
self.SendInteractionMessage = channel.unary_unary(
'/hil_serl.LearnerService/SendInteractionMessage',
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
self.StreamParameters = channel.unary_stream(
'/hil_serl.LearnerService/StreamParameters',
request_serializer=hilserl__pb2.Empty.SerializeToString,
response_deserializer=hilserl__pb2.Parameters.FromString,
_registered_method=True)
self.ReceiveTransitions = channel.stream_unary(
'/hil_serl.LearnerService/ReceiveTransitions',
request_serializer=hilserl__pb2.ActorInformation.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
class LearnerServiceServicer(object):
@ -53,14 +58,20 @@ class LearnerServiceServicer(object):
The Learner implements this service.
"""
def SendTransition(self, request, context):
def SendInteractionMessage(self, request, context):
"""Actor -> Learner to store transitions
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendInteractionMessage(self, request, context):
def StreamParameters(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def ReceiveTransitions(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
@ -69,16 +80,21 @@ class LearnerServiceServicer(object):
def add_LearnerServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendTransition': grpc.unary_unary_rpc_method_handler(
servicer.SendTransition,
request_deserializer=hilserl__pb2.Transition.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
servicer.SendInteractionMessage,
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
'StreamParameters': grpc.unary_stream_rpc_method_handler(
servicer.StreamParameters,
request_deserializer=hilserl__pb2.Empty.FromString,
response_serializer=hilserl__pb2.Parameters.SerializeToString,
),
'ReceiveTransitions': grpc.stream_unary_rpc_method_handler(
servicer.ReceiveTransitions,
request_deserializer=hilserl__pb2.ActorInformation.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'hil_serl.LearnerService', rpc_method_handlers)
@ -92,33 +108,6 @@ class LearnerService(object):
The Learner implements this service.
"""
@staticmethod
def SendTransition(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/hil_serl.LearnerService/SendTransition',
hilserl__pb2.Transition.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendInteractionMessage(request,
target,
@ -146,76 +135,8 @@ class LearnerService(object):
metadata,
_registered_method=True)
class ActorServiceStub(object):
"""ActorService: the Learner calls this to push parameters.
The Actor implements this service.
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.StreamTransition = channel.unary_stream(
'/hil_serl.ActorService/StreamTransition',
request_serializer=hilserl__pb2.Empty.SerializeToString,
response_deserializer=hilserl__pb2.ActorInformation.FromString,
_registered_method=True)
self.SendParameters = channel.unary_unary(
'/hil_serl.ActorService/SendParameters',
request_serializer=hilserl__pb2.Parameters.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
class ActorServiceServicer(object):
"""ActorService: the Learner calls this to push parameters.
The Actor implements this service.
"""
def StreamTransition(self, request, context):
"""Learner -> Actor to send new parameters
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendParameters(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_ActorServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'StreamTransition': grpc.unary_stream_rpc_method_handler(
servicer.StreamTransition,
request_deserializer=hilserl__pb2.Empty.FromString,
response_serializer=hilserl__pb2.ActorInformation.SerializeToString,
),
'SendParameters': grpc.unary_unary_rpc_method_handler(
servicer.SendParameters,
request_deserializer=hilserl__pb2.Parameters.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'hil_serl.ActorService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('hil_serl.ActorService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class ActorService(object):
"""ActorService: the Learner calls this to push parameters.
The Actor implements this service.
"""
@staticmethod
def StreamTransition(request,
def StreamParameters(request,
target,
options=(),
channel_credentials=None,
@ -228,9 +149,9 @@ class ActorService(object):
return grpc.experimental.unary_stream(
request,
target,
'/hil_serl.ActorService/StreamTransition',
'/hil_serl.LearnerService/StreamParameters',
hilserl__pb2.Empty.SerializeToString,
hilserl__pb2.ActorInformation.FromString,
hilserl__pb2.Parameters.FromString,
options,
channel_credentials,
insecure,
@ -242,7 +163,7 @@ class ActorService(object):
_registered_method=True)
@staticmethod
def SendParameters(request,
def ReceiveTransitions(request_iterator,
target,
options=(),
channel_credentials=None,
@ -252,11 +173,11 @@ class ActorService(object):
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
return grpc.experimental.stream_unary(
request_iterator,
target,
'/hil_serl.ActorService/SendParameters',
hilserl__pb2.Parameters.SerializeToString,
'/hil_serl.LearnerService/ReceiveTransitions',
hilserl__pb2.ActorInformation.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,

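The regenerated stubs above replace the old ActorService (which the learner had to dial) with two RPCs on LearnerService: a server-streaming `StreamParameters` (Empty in, Parameters out) and a client-streaming `ReceiveTransitions` (ActorInformation in, Empty out), so the actor is now purely a gRPC client. Below is a minimal sketch of the actor-side calls; the `LearnerServiceStub` class name, channel options, and helper names are assumptions based on the usual grpcio-generated layout, not something this diff shows. Consuming the parameter stream is sketched after the new `learner_service.py` file further down.

```python
# Sketch (assumption): actor-side usage of the regenerated LearnerService stub.
# Only the RPC names, shapes, and message types come from the generated code
# above; the stub class name, channel options, and helpers are assumed.
import grpc

import hilserl_pb2  # type: ignore
import hilserl_pb2_grpc  # type: ignore


def open_learner_stub(host: str = "127.0.0.1", port: int = 50051):
    channel = grpc.insecure_channel(
        f"{host}:{port}",
        options=[
            ("grpc.max_send_message_length", -1),
            ("grpc.max_receive_message_length", -1),
        ],
    )
    return hilserl_pb2_grpc.LearnerServiceStub(channel)


def push_transitions(stub, serialized_transitions: list[bytes]) -> None:
    # Client-streaming RPC: yield one ActorInformation per serialized
    # transition, then block until the learner answers with Empty.
    def request_iterator():
        for payload in serialized_transitions:
            yield hilserl_pb2.ActorInformation(
                transition=hilserl_pb2.Transition(transition_bytes=payload)
            )

    stub.ReceiveTransitions(request_iterator())
```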
View File

@ -14,19 +14,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import pickle
import queue
import shutil
import time
from pprint import pformat
from threading import Lock, Thread
import signal
from threading import Event
from concurrent.futures import ThreadPoolExecutor
import grpc
# Import generated stubs
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
import hydra
import torch
@ -55,10 +55,11 @@ from lerobot.common.utils.utils import (
from lerobot.scripts.server.buffer import (
ReplayBuffer,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
)
from lerobot.scripts.server import learner_service
logging.basicConfig(level=logging.INFO)
transition_queue = queue.Queue()
@ -77,9 +78,13 @@ def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
# if resume == True
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
if not checkpoint_dir.exists():
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
raise RuntimeError(
f"No model checkpoint found in {checkpoint_dir} for resume=True"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
checkpoint_cfg_path = str(
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
)
logging.info(
colored(
"Resume=True detected, resuming previous run",
@ -112,7 +117,9 @@ def load_training_state(
if not cfg.resume:
return None, None
training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
training_state = torch.load(
logger.last_checkpoint_dir / logger.training_state_file_name
)
if isinstance(training_state["optimizer"], dict):
assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
@ -126,7 +133,9 @@ def load_training_state(
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
@ -136,7 +145,9 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
def initialize_replay_buffer(
cfg: DictConfig, logger: Logger, device: str
) -> ReplayBuffer:
if not cfg.resume:
return ReplayBuffer(
capacity=cfg.training.online_buffer_capacity,
@ -146,7 +157,9 @@ def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> Re
)
dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
repo_id=cfg.dataset_repo_id,
local_files_only=True,
root=logger.log_dir / "dataset",
)
return ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
@ -168,18 +181,10 @@ def start_learner_threads(
logger: Logger,
resume_optimization_step: int | None = None,
resume_interaction_step: int | None = None,
shutdown_event: Event | None = None,
) -> None:
actor_ip = cfg.actor_learner_config.actor_ip
port = cfg.actor_learner_config.port
server_thread = Thread(
target=stream_transitions_from_actor,
args=(
actor_ip,
port,
),
daemon=True,
)
host = cfg.actor_learner_config.learner_host
port = cfg.actor_learner_config.learner_port
transition_thread = Thread(
target=add_actor_information_and_train,
@ -196,95 +201,56 @@ def start_learner_threads(
logger,
resume_optimization_step,
resume_interaction_step,
shutdown_event,
),
)
param_push_thread = Thread(
target=learner_push_parameters,
args=(policy, policy_lock, actor_ip, port, 15),
daemon=True,
)
server_thread.start()
transition_thread.start()
param_push_thread.start()
param_push_thread.join()
service = learner_service.LearnerService(
shutdown_event,
policy,
policy_lock,
cfg.actor_learner_config.policy_parameters_push_frequency,
transition_queue,
interaction_message_queue,
)
server = start_learner_server(service, host, port)
shutdown_event.wait()
server.stop(learner_service.SHUTDOWN_TIMEOUT)
logging.info("[LEARNER] gRPC server stopped")
transition_thread.join()
server_thread.join()
logging.info("[LEARNER] Transition thread stopped")
def stream_transitions_from_actor(host="127.0.0.1", port=50051):
"""
Runs a gRPC client that listens for transition and interaction messages from an Actor service.
This function establishes a gRPC connection with the given `host` and `port`, then continuously
streams transition data from the `ActorServiceStub`. The received transition data is deserialized
and stored in a queue (`transition_queue`). Similarly, interaction messages are also deserialized
and stored in a separate queue (`interaction_message_queue`).
Args:
host (str, optional): The IP address or hostname of the gRPC server. Defaults to `"127.0.0.1"`.
port (int, optional): The port number on which the gRPC server is running. Defaults to `50051`.
"""
# NOTE: This is waiting for the handshake to be done
# In the future we will do it in a canonical way with a proper handshake
time.sleep(10)
channel = grpc.insecure_channel(
f"{host}:{port}",
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
def start_learner_server(
service: learner_service.LearnerService,
host="0.0.0.0",
port=50051,
) -> grpc.server:
server = grpc.server(
ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
options=[
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
],
)
stub = hilserl_pb2_grpc.ActorServiceStub(channel)
for response in stub.StreamTransition(hilserl_pb2.Empty()):
if response.HasField("transition"):
buffer = io.BytesIO(response.transition.transition_bytes)
transition = torch.load(buffer)
transition_queue.put(transition)
if response.HasField("interaction_message"):
content = pickle.loads(response.interaction_message.interaction_message_bytes)
interaction_message_queue.put(content)
hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
service,
server,
)
server.add_insecure_port(f"{host}:{port}")
server.start()
logging.info("[LEARNER] gRPC server started")
return server
def learner_push_parameters(
policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5
def check_nan_in_transition(
observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor
):
"""
As a client, connect to the Actor's gRPC server (ActorService)
and periodically push new parameters.
"""
time.sleep(10)
channel = grpc.insecure_channel(
f"{actor_host}:{actor_port}",
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
)
actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
while True:
with policy_lock:
params_dict = policy.actor.state_dict()
if policy.config.vision_encoder_name is not None:
if policy.config.freeze_vision_encoder:
params_dict: dict[str, torch.Tensor] = {
k: v for k, v in params_dict.items() if not k.startswith("encoder.")
}
else:
raise NotImplementedError(
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
)
params_dict = move_state_dict_to_device(params_dict, device="cpu")
# Serialize
buf = io.BytesIO()
torch.save(params_dict, buf)
params_bytes = buf.getvalue()
# Push them to the Actor's "SendParameters" method
logging.info("[LEARNER] Publishing parameters to the Actor")
response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes)) # noqa: F841
time.sleep(seconds_between_pushes)
def check_nan_in_transition(observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor):
for k in observations:
if torch.isnan(observations[k]).any():
logging.error(f"observations[{k}] contains NaN values")
@ -307,6 +273,7 @@ def add_actor_information_and_train(
logger: Logger,
resume_optimization_step: int | None = None,
resume_interaction_step: int | None = None,
shutdown_event: Event | None = None,
):
"""
Handles data transfer from the actor to the learner, manages training updates,
@ -338,6 +305,7 @@ def add_actor_information_and_train(
logger (Logger): Logger instance for tracking training progress.
resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
shutdown_event (Event | None): Event to signal shutdown.
"""
# NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
# in the future. The reason why we did that is the GIL in Python. It's super slow the performance
@ -345,9 +313,17 @@ def add_actor_information_and_train(
time.time()
logging.info("Starting learner thread")
interaction_message, transition = None, None
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
optimization_step = (
resume_optimization_step if resume_optimization_step is not None else 0
)
interaction_step_shift = (
resume_interaction_step if resume_interaction_step is not None else 0
)
while True:
if shutdown_event is not None and shutdown_event.is_set():
logging.info("[LEARNER] Shutdown signal received. Exiting...")
break
while not transition_queue.empty():
transition_list = transition_queue.get()
for transition in transition_list:
@ -361,7 +337,9 @@ def add_actor_information_and_train(
interaction_message = interaction_message_queue.get()
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
logger.log_dict(
interaction_message, mode="train", custom_step_key="Interaction step"
)
# logging.info(f"Interaction message: {interaction_message}")
if len(replay_buffer) < cfg.training.online_step_before_learning:
@ -383,7 +361,9 @@ def add_actor_information_and_train(
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
check_nan_in_transition(
observations=observations, actions=actions, next_state=next_observations
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
@ -411,7 +391,9 @@ def add_actor_information_and_train(
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
check_nan_in_transition(
observations=observations, actions=actions, next_state=next_observations
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
@ -439,7 +421,9 @@ def add_actor_information_and_train(
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(observations=observations)
loss_temperature = policy.compute_loss_temperature(
observations=observations
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
@ -453,9 +437,13 @@ def add_actor_information_and_train(
# logging.info(f"Training infos: {training_infos}")
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
frequency_for_one_optimization_step = 1 / (
time_for_one_optimization_step + 1e-9
)
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
logging.info(
f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}"
)
logger.log_dict(
{
@ -471,7 +459,8 @@ def add_actor_information_and_train(
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
if cfg.training.save_checkpoint and (
optimization_step % cfg.training.save_freq == 0 or optimization_step == cfg.training.online_steps
optimization_step % cfg.training.save_freq == 0
or optimization_step == cfg.training.online_steps
):
logging.info(f"Checkpoint policy after step {optimization_step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
@ -479,7 +468,9 @@ def add_actor_information_and_train(
_num_digits = max(6, len(str(cfg.training.online_steps)))
step_identifier = f"{optimization_step:0{_num_digits}d}"
interaction_step = (
interaction_message["Interaction step"] if interaction_message is not None else 0
interaction_message["Interaction step"]
if interaction_message is not None
else 0
)
logger.save_checkpoint(
optimization_step,
@ -538,7 +529,9 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha], lr=policy.config.critic_lr
)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
@ -580,14 +573,18 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: if we do online training, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
if cfg.resume
else None,
)
# compile policy
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
resume_optimization_step, resume_interaction_step = load_training_state(
cfg, logger, optimizers
)
log_training_info(cfg, out_dir, policy)
@ -599,7 +596,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer")
active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask]
active_action_dims = [
i
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
]
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
device=device,
@ -609,6 +610,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
)
batch_size: int = batch_size // 2  # We will sample from both replay buffers
shutdown_event = Event()
def signal_handler(signum, frame):
print(
f"\nReceived signal {signal.Signals(signum).name}. Initiating learner shutdown..."
)
shutdown_event.set()
# Register signal handlers
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Termination request
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
start_learner_threads(
cfg,
device,
@ -621,6 +636,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logger,
resume_optimization_step,
resume_interaction_step,
shutdown_event,
)
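Taken together, the hunks above flip who hosts the connection: the learner no longer dials the actor but registers signal handlers, serves gRPC itself via `start_learner_server`, and unwinds everything through the shared `shutdown_event`. The condensed sketch below restates that lifecycle in one place; `service_factory` and `training_loop` are hypothetical stand-ins for the `LearnerService` construction and `add_actor_information_and_train` shown in the diff.

```python
# Condensed, illustrative sketch of the new learner lifecycle (not a drop-in
# replacement): signal -> shutdown_event -> server.stop() -> thread.join().
import signal
from threading import Event, Thread


def run_learner(cfg, service_factory, start_learner_server, training_loop):
    shutdown_event = Event()

    def signal_handler(signum, frame):
        # Any registered signal flips the shared event.
        shutdown_event.set()

    for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT):
        signal.signal(sig, signal_handler)

    # The training thread polls shutdown_event at the top of its loop.
    training_thread = Thread(target=training_loop, args=(shutdown_event,))
    training_thread.start()

    service = service_factory(shutdown_event)
    server = start_learner_server(
        service,
        cfg.actor_learner_config.learner_host,
        cfg.actor_learner_config.learner_port,
    )

    shutdown_event.wait()  # block until a signal arrives
    server.stop(10)        # grace period in seconds, mirroring SHUTDOWN_TIMEOUT
    training_thread.join()
```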

View File

@ -0,0 +1,113 @@
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
import torch
from torch import nn
from threading import Lock, Event
import logging
import queue
import io
import pickle
from lerobot.scripts.server.buffer import (
move_state_dict_to_device,
bytes_buffer_size,
state_to_bytes,
)
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
MAX_WORKERS = 10
SHUTDOWN_TIMEOUT = 10
class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
def __init__(
self,
shutdown_event: Event,
policy: nn.Module,
policy_lock: Lock,
seconds_between_pushes: float,
transition_queue: queue.Queue,
interaction_message_queue: queue.Queue,
):
self.shutdown_event = shutdown_event
self.policy = policy
self.policy_lock = policy_lock
self.seconds_between_pushes = seconds_between_pushes
self.transition_queue = transition_queue
self.interaction_message_queue = interaction_message_queue
def _get_policy_state(self):
with self.policy_lock:
params_dict = self.policy.actor.state_dict()
if self.policy.config.vision_encoder_name is not None:
if self.policy.config.freeze_vision_encoder:
params_dict: dict[str, torch.Tensor] = {
k: v
for k, v in params_dict.items()
if not k.startswith("encoder.")
}
else:
raise NotImplementedError(
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
)
return move_state_dict_to_device(params_dict, device="cpu")
def _send_bytes(self, buffer: io.BytesIO):
size_in_bytes = bytes_buffer_size(buffer)
sent_bytes = 0
logging.info(f"Model state size {size_in_bytes/1024/1024} MB with")
while sent_bytes < size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_END
elif sent_bytes == 0:
transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
chunk = buffer.read(size_to_read)
yield hilserl_pb2.Parameters(
transfer_state=transfer_state, parameter_bytes=chunk
)
sent_bytes += size_to_read
logging.info(
f"[Learner] Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
)
logging.info(f"[LEARNER] Published {sent_bytes/1024/1024} MB to the Actor")
def StreamParameters(self, request, context):
# TODO: authorize the request
logging.info("[LEARNER] Received request to stream parameters from the Actor")
while not self.shutdown_event.is_set():
logging.debug("[LEARNER] Push parameters to the Actor")
state_dict = self._get_policy_state()
with state_to_bytes(state_dict) as buffer:
yield from self._send_bytes(buffer)
self.shutdown_event.wait(self.seconds_between_pushes)
def ReceiveTransitions(self, request_iterator, context):
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive transitions from the Actor")
for request in request_iterator:
logging.debug("[LEARNER] Received request")
if request.HasField("transition"):
buffer = io.BytesIO(request.transition.transition_bytes)
transition = torch.load(buffer)
self.transition_queue.put(transition)
if request.HasField("interaction_message"):
content = pickle.loads(
request.interaction_message.interaction_message_bytes
)
self.interaction_message_queue.put(content)
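Because `_send_bytes` splits each serialized state dict into 2 MB chunks tagged with a `TransferState`, the receiving side has to accumulate bytes until it sees `TRANSFER_END` before deserializing. A minimal actor-side sketch is below; it assumes `state_to_bytes` wraps `torch.save` (as the previous `learner_push_parameters` did), and the function name and `stub` argument are hypothetical.

```python
# Sketch (assumption): actor-side reassembly of the chunked parameter stream
# produced by StreamParameters. Assumes the learner serialized the state dict
# with torch.save, so torch.load can read the concatenated chunks back.
import io

import torch

import hilserl_pb2  # type: ignore


def iter_policy_state_dicts(stub):
    buffer = io.BytesIO()
    for chunk in stub.StreamParameters(hilserl_pb2.Empty()):
        buffer.write(chunk.parameter_bytes)
        if chunk.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
            buffer.seek(0)
            yield torch.load(buffer)  # one complete actor state_dict per push
            buffer = io.BytesIO()     # reset for the next push
```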

poetry.lock (generated, 1051 lines changed): diff suppressed because it is too large

View File

@ -53,16 +53,19 @@ dependencies = [
"einops>=0.8.0",
"flask>=3.0.3",
"gdown>=5.1.0",
"grpcio>=1.70.0",
"gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work
"h5py>=3.10.0",
"huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
"imageio[ffmpeg]>=2.34.0",
"jsonlines>=4.0.0",
"mani-skill>=3.0.0b18",
"numba>=0.59.0",
"omegaconf>=2.3.0",
"opencv-python-headless>=4.9.0",
"packaging>=24.2",
"av>=12.0.5",
"protobuf>=5.29.3",
"pymunk>=6.6.0",
"pynput>=1.7.7",
"pyzmq>=26.2.1",
@ -87,6 +90,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"]
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
mani_skill = ["mani-skill"]
pi0 = ["transformers>=4.48.0"]
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
stretch = [
@ -107,7 +111,33 @@ requires-poetry = ">=2.1"
[tool.ruff]
line-length = 110
target-version = "py310"
exclude = ["tests/artifacts/**/*.safetensors"]
exclude = [
"tests/data",
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".mypy_cache",
".nox",
".pants.d",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
"*_pb2.py",
"*_pb2_grpc.py",
]
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]

ruff.toml (new file, 11 lines)
View File

@ -0,0 +1,11 @@
# Exclude files/directories from Ruff
exclude = [
"*_pb2.py", # Ignore all protobuf generated files
"*_pb2_grpc.py", # Ignore all gRPC generated files
"lerobot/scripts/server/hilserl_pb2.py", # Ignore specific file
".git",
".env",
".venv",
"build",
"dist"
]