[Port HIL-SERL] Adjust Actor-Learner architecture & clean up dependency management for HIL-SERL (#722)
parent b8e9ee440b
commit 304d7136df

@@ -46,7 +46,7 @@ repos:
     rev: v3.19.1
     hooks:
       - id: pyupgrade
+        exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.9.10
     hooks:

@@ -0,0 +1,11 @@
+FROM huggingface/lerobot-gpu:latest
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    libvulkan1 vulkan-tools \
+    && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+RUN pip install --upgrade --no-cache-dir pip
+RUN pip install --no-cache-dir ".[mani-skill]"
+
+# Set EGL as the rendering backend for MuJoCo
+ENV MUJOCO_GL="egl"

@@ -81,3 +81,14 @@ You can also log sample predictions during evaluation. Each logged sample will i
 - The **classifier's "confidence" (logits/probability)**.

 These logs can be useful for diagnosing and debugging performance issues.
+
+#### Generate protobuf files
+
+```bash
+python -m grpc_tools.protoc \
+    -I lerobot/scripts/server \
+    --python_out=lerobot/scripts/server \
+    --grpc_python_out=lerobot/scripts/server \
+    lerobot/scripts/server/hilserl.proto
+```

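The command above regenerates `hilserl_pb2.py` and `hilserl_pb2_grpc.py` (the same files the new pre-commit `exclude` pattern skips). A minimal sketch of how the generated modules are consumed on the actor side, using only names that appear later in this diff (the host/port are the defaults from the configs):

```python
import grpc

from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc

channel = grpc.insecure_channel("127.0.0.1:50051")
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)

# Server-streaming RPC the actor uses to pull policy parameters.
for model_update in stub.StreamParameters(hilserl_pb2.Empty()):
    ...
```
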
@@ -41,11 +41,16 @@ class SACConfig:
     )
     input_normalization_params: dict[str, dict[str, list[float]]] = field(
         default_factory=lambda: {
-            "observation.image": {"mean": [[0.485, 0.456, 0.406]], "std": [[0.229, 0.224, 0.225]]},
+            "observation.image": {
+                "mean": [[0.485, 0.456, 0.406]],
+                "std": [[0.229, 0.224, 0.225]],
+            },
             "observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
         }
     )
-    output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
+    output_normalization_modes: dict[str, str] = field(
+        default_factory=lambda: {"action": "min_max"}
+    )
     output_normalization_params: dict[str, dict[str, list[float]]] = field(
         default_factory=lambda: {
             "action": {"min": [-1, -1], "max": [1, 1]},

@@ -54,9 +59,8 @@ class SACConfig:
     # TODO: Move it outside of the config
     actor_learner_config: dict[str, str | int] = field(
         default_factory=lambda: {
-            "actor_ip": "127.0.0.1",
-            "port": 50051,
-            "learner_ip": "127.0.0.1",
+            "learner_host": "127.0.0.1",
+            "learner_port": 50051,
         }
     )
     camera_number: int = 1

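The rename reflects the new topology: the actor no longer hosts its own gRPC server, so there is a single endpoint to name, the learner's. A sketch of how the renamed keys are consumed (import path is assumed from the sibling `modeling_sac` module seen later in this diff):

```python
from lerobot.common.policies.sac.configuration_sac import SACConfig

cfg = SACConfig()
# Both keys now describe the learner's endpoint; the actor dials out
# instead of listening on its own port.
host = cfg.actor_learner_config["learner_host"]  # formerly "actor_ip" / "learner_ip"
port = cfg.actor_learner_config["learner_port"]  # formerly "port"
print(f"actor will connect to {host}:{port}")
```
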
@@ -108,5 +108,6 @@ policy:
   utd_ratio: 2 # 10

   actor_learner_config:
-    actor_ip: "127.0.0.1"
-    port: 50051
+    learner_host: "127.0.0.1"
+    learner_port: 50051
+    policy_parameters_push_frequency: 15

@@ -65,7 +65,7 @@ policy:
   action: [4] # ["${env.action_dim}"]

   # Normalization / Unnormalization
   input_normalization_modes:
     observation.images.front: mean_std
     observation.images.side: mean_std
     observation.state: min_max

@@ -80,7 +80,7 @@ policy:
     min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
     max: [ 7.215820e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]

     # min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274]
     # max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
     # min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
     # max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]

@@ -112,8 +112,9 @@ policy:
   utd_ratio: 2 # 10

   actor_learner_config:
-    actor_ip: "127.0.0.1"
-    port: 50051
+    learner_host: "127.0.0.1"
+    learner_port: 50051
+    policy_parameters_push_frequency: 15

   # # Loss coefficients.
   # reward_coeff: 0.5

@@ -17,9 +17,9 @@ import io
 import logging
 import pickle
 import queue
-import time
-from concurrent import futures
 from statistics import mean, quantiles
+import signal
+from functools import lru_cache

 # from lerobot.scripts.eval import eval_policy
 from threading import Thread

@@ -35,7 +35,6 @@ from torch import nn
 # from lerobot.common.envs.utils import preprocess_maniskill_observation
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy
-from lerobot.common.robot_devices.control_utils import busy_wait
 from lerobot.common.robot_devices.robots.factory import make_robot
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.utils.utils import (

@@ -44,14 +43,24 @@ from lerobot.common.utils.utils import (
     set_global_seed,
 )
 from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
-from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device
+from lerobot.scripts.server.buffer import (
+    Transition,
+    move_state_dict_to_device,
+    move_transition_to_device,
+    bytes_buffer_size,
+)
 from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
+from lerobot.scripts.server import learner_service
+
+from threading import Event

 logging.basicConfig(level=logging.INFO)

 parameters_queue = queue.Queue(maxsize=1)
 message_queue = queue.Queue(maxsize=1_000_000)

+ACTOR_SHUTDOWN_TIMEOUT = 30
+

 class ActorInformation:
     """

@@ -70,95 +79,171 @@ class ActorInformation:
         self.interaction_message = interaction_message


-class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
-    """
-    gRPC service for actor-learner communication in reinforcement learning.
-
-    This service is responsible for:
-    1. Streaming batches of transition data and statistical metrics from the actor to the learner.
-    2. Receiving updated network parameters from the learner.
-    """
-
-    def StreamTransition(self, request, context):  # noqa: N802
-        """
-        Streams data from the actor to the learner.
-
-        This function continuously retrieves messages from the queue and processes them based on their type:
-
-        - **Transition Data:**
-          - A batch of transitions (observation, action, reward, next observation) is collected.
-          - Transitions are moved to the CPU and serialized using PyTorch.
-          - The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
-
-        - **Interaction Messages:**
-          - Contains useful statistics about episodic rewards and policy timings.
-          - The message is serialized using `pickle` and sent to the learner.
-
-        Yields:
-            hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
-        """
-        while True:
-            message = message_queue.get(block=True)
-
-            if message.transition is not None:
-                transition_to_send_to_learner: list[Transition] = [
-                    move_transition_to_device(transition=T, device="cpu") for T in message.transition
-                ]
-                # Check for NaNs in transitions before sending to learner
-                for transition in transition_to_send_to_learner:
-                    for key, value in transition["state"].items():
-                        if torch.isnan(value).any():
-                            logging.warning(f"Found NaN values in transition {key}")
-                buf = io.BytesIO()
-                torch.save(transition_to_send_to_learner, buf)
-                transition_bytes = buf.getvalue()
-
-                transition_message = hilserl_pb2.Transition(transition_bytes=transition_bytes)
-
-                response = hilserl_pb2.ActorInformation(transition=transition_message)
-
-            elif message.interaction_message is not None:
-                content = hilserl_pb2.InteractionMessage(
-                    interaction_message_bytes=pickle.dumps(message.interaction_message)
-                )
-                response = hilserl_pb2.ActorInformation(interaction_message=content)
-
-            yield response
-
-    def SendParameters(self, request, context):  # noqa: N802
-        """
-        Receives updated parameters from the learner and updates the actor.
-
-        The learner calls this method to send new model parameters. The received parameters are deserialized
-        and placed in a queue to be consumed by the actor.
-
-        Args:
-            request (hilserl_pb2.ParameterUpdate): The request containing serialized network parameters.
-            context (grpc.ServicerContext): The gRPC context.
-
-        Returns:
-            hilserl_pb2.Empty: An empty response to acknowledge receipt.
-        """
-        buffer = io.BytesIO(request.parameter_bytes)
-        params = torch.load(buffer)
-        parameters_queue.put(params)
-        return hilserl_pb2.Empty()
-
-
-def serve_actor_service(port=50052):
-    """
-    Runs a gRPC server to start streaming the data from the actor to the learner.
-    Throught this server the learner can push parameters to the Actor as well.
-    """
-    server = grpc.server(
-        futures.ThreadPoolExecutor(max_workers=20),
-        options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
-    )
-    hilserl_pb2_grpc.add_ActorServiceServicer_to_server(ActorServiceServicer(), server)
-    server.add_insecure_port(f"[::]:{port}")
-    server.start()
-    logging.info(f"[ACTOR] gRPC server listening on port {port}")
-    server.wait_for_termination()
+def receive_policy(
+    learner_client: hilserl_pb2_grpc.LearnerServiceStub,
+    shutdown_event: Event,
+    parameters_queue: queue.Queue,
+):
+    logging.info("[ACTOR] Start receiving parameters from the Learner")
+    bytes_buffer = io.BytesIO()
+    step = 0
+    try:
+        for model_update in learner_client.StreamParameters(hilserl_pb2.Empty()):
+            if shutdown_event.is_set():
+                logging.info("[ACTOR] Shutting down policy streaming receiver")
+                return hilserl_pb2.Empty()
+
+            if model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
+                bytes_buffer.seek(0)
+                bytes_buffer.truncate(0)
+                bytes_buffer.write(model_update.parameter_bytes)
+                logging.info("Received model update at step 0")
+                step = 0
+                continue
+            elif (
+                model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE
+            ):
+                bytes_buffer.write(model_update.parameter_bytes)
+                step += 1
+                logging.info(f"Received model update at step {step}")
+            elif model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
+                bytes_buffer.write(model_update.parameter_bytes)
+                logging.info(
+                    f"Received model update at step end size {bytes_buffer_size(bytes_buffer)}"
+                )
+
+                state_dict = torch.load(bytes_buffer)
+
+                bytes_buffer.seek(0)
+                bytes_buffer.truncate(0)
+                step = 0
+
+                logging.info("Model updated")
+
+                parameters_queue.put(state_dict)
+
+    except grpc.RpcError as e:
+        logging.error(f"[ACTOR] gRPC error: {e}")
+
+    return hilserl_pb2.Empty()
+
+
+def transitions_stream(shutdown_event: Event, message_queue: queue.Queue):
+    while not shutdown_event.is_set():
+        try:
+            message = message_queue.get(block=True, timeout=5)
+        except queue.Empty:
+            logging.debug("[ACTOR] Transition queue is empty")
+            continue
+
+        if message.transition is not None:
+            transition_to_send_to_learner: list[Transition] = [
+                move_transition_to_device(transition=T, device="cpu")
+                for T in message.transition
+            ]
+            # Check for NaNs in transitions before sending to learner
+            for transition in transition_to_send_to_learner:
+                for key, value in transition["state"].items():
+                    if torch.isnan(value).any():
+                        logging.warning(f"Found NaN values in transition {key}")
+            buf = io.BytesIO()
+            torch.save(transition_to_send_to_learner, buf)
+            transition_bytes = buf.getvalue()
+
+            transition_message = hilserl_pb2.Transition(
+                transition_bytes=transition_bytes
+            )
+
+            response = hilserl_pb2.ActorInformation(transition=transition_message)
+
+        elif message.interaction_message is not None:
+            content = hilserl_pb2.InteractionMessage(
+                interaction_message_bytes=pickle.dumps(message.interaction_message)
+            )
+            response = hilserl_pb2.ActorInformation(interaction_message=content)
+
+        yield response
+
+    return hilserl_pb2.Empty()
+
+
+def send_transitions(
+    learner_client: hilserl_pb2_grpc.LearnerServiceStub,
+    shutdown_event: Event,
+    message_queue: queue.Queue,
+):
+    """
+    Streams data from the actor to the learner.
+
+    This function continuously retrieves messages from the queue and processes them based on their type:
+
+    - **Transition Data:**
+      - A batch of transitions (observation, action, reward, next observation) is collected.
+      - Transitions are moved to the CPU and serialized using PyTorch.
+      - The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
+
+    - **Interaction Messages:**
+      - Contains useful statistics about episodic rewards and policy timings.
+      - The message is serialized using `pickle` and sent to the learner.
+
+    Yields:
+        hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
+    """
+    try:
+        learner_client.ReceiveTransitions(
+            transitions_stream(shutdown_event, message_queue)
+        )
+    except grpc.RpcError as e:
+        logging.error(f"[ACTOR] gRPC error: {e}")
+
+    logging.info("[ACTOR] Finished streaming transitions")
+
+
+@lru_cache(maxsize=1)
+def learner_service_client(
+    host="127.0.0.1", port=50051
+) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
+    import json
+
+    """
+    Returns a client for the learner service.
+
+    GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
+    So we need to create only one client and reuse it.
+    """
+
+    service_config = {
+        "methodConfig": [
+            {
+                "name": [{}],  # Applies to ALL methods in ALL services
+                "retryPolicy": {
+                    "maxAttempts": 5,  # Max retries (total attempts = 5)
+                    "initialBackoff": "0.1s",  # First retry after 0.1s
+                    "maxBackoff": "2s",  # Max wait time between retries
+                    "backoffMultiplier": 2,  # Exponential backoff factor
+                    "retryableStatusCodes": [
+                        "UNAVAILABLE",
+                        "DEADLINE_EXCEEDED",
+                    ],  # Retries on network failures
+                },
+            }
+        ]
+    }
+
+    service_config_json = json.dumps(service_config)
+
+    channel = grpc.insecure_channel(
+        f"{host}:{port}",
+        options=[
+            ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
+            ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
+            ("grpc.enable_retries", 1),
+            ("grpc.service_config", service_config_json),
+        ],
+    )
+    stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
+    logging.info("[LEARNER] Learner service client created")
+    return stub, channel


 def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):

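`receive_policy` reassembles the policy state dict from a TRANSFER_BEGIN / TRANSFER_MIDDLE / TRANSFER_END sequence of chunks, and only commits (loads and enqueues) on TRANSFER_END. The learner-side sender is not part of this hunk; a hedged sketch of what it must yield, reusing the `ParameterUpdate` message named in the removed docstring (`stream_state_dict_chunks` and the chunk size are assumptions, not PR code):

```python
import io

from lerobot.scripts.server import hilserl_pb2


def stream_state_dict_chunks(buffer: io.BytesIO, chunk_size: int = 4 * 1024 * 1024):
    """Hypothetical learner-side chunker matching what receive_policy expects."""
    data = buffer.getvalue()
    for start in range(0, len(data), chunk_size):
        chunk = data[start : start + chunk_size]
        if start + chunk_size >= len(data):
            # The actor writes the final chunk and then loads the state dict.
            state = hilserl_pb2.TransferState.TRANSFER_END
        elif start == 0:
            # The actor resets its buffer before writing the first chunk.
            state = hilserl_pb2.TransferState.TRANSFER_BEGIN
        else:
            state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
        yield hilserl_pb2.ParameterUpdate(transfer_state=state, parameter_bytes=chunk)
```
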
@@ -169,7 +254,9 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, d
     policy.load_state_dict(state_dict)


-def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
+def act_with_policy(
+    cfg: DictConfig, robot: Robot, reward_classifier: nn.Module, shutdown_event: Event
+):
     """
     Executes policy interaction within the environment.

@@ -182,7 +269,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)

     logging.info("make_env online")

-    online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
+    online_env = make_robot_env(
+        robot=robot, reward_classifier=reward_classifier, cfg=cfg
+    )

     set_global_seed(cfg.seed)
     device = get_safe_torch_device(cfg.device, log=True)

@@ -227,17 +316,27 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
     episode_intervention = False

     for interaction_step in range(cfg.training.online_steps):
+        if shutdown_event.is_set():
+            logging.info("[ACTOR] Shutdown signal received. Exiting...")
+            return
+
         if interaction_step >= cfg.training.online_step_before_learning:
             # Time policy inference and check if it meets FPS requirement
             with TimerManager(
-                elapsed_time_list=list_policy_time, label="Policy inference time", log=False
+                elapsed_time_list=list_policy_time,
+                label="Policy inference time",
+                log=False,
             ) as timer:  # noqa: F841
                 action = policy.select_action(batch=obs)
             policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)

-            log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
+            log_policy_frequency_issue(
+                policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step
+            )

-            next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
+            next_obs, reward, done, truncated, info = online_env.step(
+                action.squeeze(dim=0).cpu().numpy()
+            )
         else:
             # TODO (azouitine): Make a custom space for torch tensor
             action = online_env.action_space.sample()

@@ -245,7 +344,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)

             # HACK: We have only one env but we want to batch it, it will be resolved with the torch box
             action = (
-                torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
+                torch.from_numpy(action[0])
+                .to(device, non_blocking=device.type == "cuda")
+                .unsqueeze(dim=0)
             )

         sum_reward_episode += float(reward)

@@ -261,7 +362,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
         # Check for NaN values in observations
         for key, tensor in obs.items():
             if torch.isnan(tensor).any():
-                logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
+                logging.error(
+                    f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}"
+                )

         list_transition_to_send_to_learner.append(
             Transition(

@@ -281,13 +384,19 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
         # Because we are using a single environment we can index at zero
         if done or truncated:
             # TODO: Handle logging for episode information
-            logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
+            logging.info(
+                f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}"
+            )

-            update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
+            update_policy_parameters(
+                policy=policy.actor, parameters_queue=parameters_queue, device=device
+            )

             if len(list_transition_to_send_to_learner) > 0:
                 send_transitions_in_chunks(
-                    transitions=list_transition_to_send_to_learner, message_queue=message_queue, chunk_size=4
+                    transitions=list_transition_to_send_to_learner,
+                    message_queue=message_queue,
+                    chunk_size=4,
                 )
                 list_transition_to_send_to_learner = []

@@ -332,11 +441,16 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
         quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
         logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
         logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
-        stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
+        stats = {
+            "Policy frequency [Hz]": policy_fps,
+            "Policy frequency 90th-p [Hz]": quantiles_90,
+        }
     return stats


-def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
+def log_policy_frequency_issue(
+    policy_fps: float, cfg: DictConfig, interaction_step: int
+):
     if policy_fps < cfg.fps:
         logging.warning(
             f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"

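`quantiles(list_policy_fps, n=10)[-1]` from the standard library returns the last of the nine decile cut points, i.e. the 90th percentile, which is why the stat is labeled "90th-p". A quick standalone check (the FPS values are illustrative):

```python
from statistics import mean, quantiles

list_policy_fps = [28.0, 30.0, 31.0, 29.5, 30.5, 27.0, 30.2, 29.8, 30.1, 26.5]
policy_fps = mean(list_policy_fps)
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]  # 9 cut points; last = 90th percentile
print(policy_fps, quantiles_90)
```
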
@@ -347,7 +461,34 @@ def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_s
 def actor_cli(cfg: dict):
     robot = make_robot(cfg=cfg.robot)

-    server_thread = Thread(target=serve_actor_service, args=(cfg.actor_learner_config.port,), daemon=True)
+    shutdown_event = Event()
+
+    # Define signal handler
+    def signal_handler(signum, frame):
+        logging.info("Shutdown signal received. Cleaning up...")
+        shutdown_event.set()
+
+    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
+    signal.signal(signal.SIGTERM, signal_handler)  # Termination request (kill)
+    signal.signal(signal.SIGHUP, signal_handler)  # Terminal closed/Hangup
+    signal.signal(signal.SIGQUIT, signal_handler)  # Ctrl+\
+
+    learner_client, grpc_channel = learner_service_client(
+        host=cfg.actor_learner_config.learner_host,
+        port=cfg.actor_learner_config.learner_port,
+    )
+
+    receive_policy_thread = Thread(
+        target=receive_policy,
+        args=(learner_client, shutdown_event, parameters_queue),
+        daemon=True,
+    )
+
+    transitions_thread = Thread(
+        target=send_transitions,
+        args=(learner_client, shutdown_event, message_queue),
+        daemon=True,
+    )

     # HACK: FOR MANISKILL we do not have a reward classifier
     # TODO: Remove this once we merge into main

@@ -360,15 +501,27 @@ def actor_cli(cfg: dict):
         pretrained_path=cfg.env.reward_classifier.pretrained_path,
         config_path=cfg.env.reward_classifier.config_path,
     )

     policy_thread = Thread(
         target=act_with_policy,
         daemon=True,
-        args=(cfg, robot, reward_classifier),
+        args=(cfg, robot, reward_classifier, shutdown_event),
     )
-    server_thread.start()
+
+    transitions_thread.start()
     policy_thread.start()
+    receive_policy_thread.start()
+
+    shutdown_event.wait()
+    logging.info("[ACTOR] Shutdown event received")
+    grpc_channel.close()
+
     policy_thread.join()
-    server_thread.join()
+    logging.info("[ACTOR] Policy thread joined")
+    transitions_thread.join()
+    logging.info("[ACTOR] Transitions thread joined")
+    receive_policy_thread.join()
+    logging.info("[ACTOR] Receive policy thread joined")


 if __name__ == "__main__":

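The new lifecycle replaces a blocking gRPC server with client threads plus a shared `Event`: signal handlers set the event, `shutdown_event.wait()` parks the main thread, and every worker loop polls the event so the joins terminate. A stripped-down sketch of the same pattern (names and the sleep interval are illustrative, not PR code):

```python
import logging
import signal
import time
from threading import Event, Thread

shutdown_event = Event()


def signal_handler(signum, frame):
    shutdown_event.set()


signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)


def worker():
    # Workers poll the shared event instead of blocking forever,
    # mirroring queue.get(timeout=...) in transitions_stream.
    while not shutdown_event.is_set():
        time.sleep(0.1)
    logging.info("worker exiting cleanly")


t = Thread(target=worker, daemon=True)
t.start()
shutdown_event.wait()  # main thread parks here until a signal arrives
t.join()
```
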
@@ -17,6 +17,7 @@ import functools
 import random
 from typing import Any, Callable, Optional, Sequence, TypedDict

+import io
 import torch
 import torch.nn.functional as F  # noqa: N812
 from tqdm import tqdm

@@ -41,24 +42,33 @@ class BatchTransition(TypedDict):
     done: torch.Tensor


-def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
+def move_transition_to_device(
+    transition: Transition, device: str = "cpu"
+) -> Transition:
     # Move state tensors to CPU
     device = torch.device(device)
     transition["state"] = {
-        key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
+        key: val.to(device, non_blocking=device.type == "cuda")
+        for key, val in transition["state"].items()
     }

     # Move action to CPU
-    transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
+    transition["action"] = transition["action"].to(
+        device, non_blocking=device.type == "cuda"
+    )

     # No need to move reward or done, as they are float and bool

     # No need to move reward or done, as they are float and bool
     if isinstance(transition["reward"], torch.Tensor):
-        transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
+        transition["reward"] = transition["reward"].to(
+            device=device, non_blocking=device.type == "cuda"
+        )

     if isinstance(transition["done"], torch.Tensor):
-        transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
+        transition["done"] = transition["done"].to(
+            device, non_blocking=device.type == "cuda"
+        )

     # Move next_state tensors to CPU
     transition["next_state"] = {

@@ -82,7 +92,10 @@ def move_state_dict_to_device(state_dict, device):
     if isinstance(state_dict, torch.Tensor):
         return state_dict.to(device)
     elif isinstance(state_dict, dict):
-        return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
+        return {
+            k: move_state_dict_to_device(v, device=device)
+            for k, v in state_dict.items()
+        }
     elif isinstance(state_dict, list):
         return [move_state_dict_to_device(v, device=device) for v in state_dict]
     elif isinstance(state_dict, tuple):

@@ -91,6 +104,22 @@ def move_state_dict_to_device(state_dict, device):
     return state_dict


+def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> io.BytesIO:
+    """Convert model state dict to flat array for transmission"""
+    buffer = io.BytesIO()
+
+    torch.save(state_dict, buffer)
+
+    return buffer
+
+
+def bytes_buffer_size(buffer: io.BytesIO) -> int:
+    buffer.seek(0, io.SEEK_END)
+    result = buffer.tell()
+    buffer.seek(0)
+    return result
+
+
 def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
     """
     Perform a per-image random crop over a batch of images in a vectorized way.

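Together these helpers give the parameter stream a measurable, seekable payload: `state_to_bytes` serializes a state dict into a `BytesIO`, and `bytes_buffer_size` reports its length while rewinding the cursor to the start, so the buffer can be read immediately afterwards. A quick round-trip sketch (the tiny linear module is an illustrative stand-in for the SAC actor):

```python
import torch
from torch import nn

from lerobot.scripts.server.buffer import bytes_buffer_size, state_to_bytes

model = nn.Linear(4, 2)  # illustrative stand-in
buffer = state_to_bytes(model.state_dict())
print(bytes_buffer_size(buffer))  # payload size in bytes; cursor rewound to 0

# The receiving side can load straight from the rewound buffer.
model.load_state_dict(torch.load(buffer))
```
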
@@ -116,7 +145,9 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
     images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)

     # Gather pixels
-    cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
+    cropped_hwcn = images_hwcn[
+        torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
+    ]
     # cropped_hwcn => (B, crop_h, crop_w, C)

     cropped = cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)

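The gather line relies on broadcasting: a `(B, 1, 1)` batch index combined with `(B, crop_h, 1)` row and `(B, 1, crop_w)` column indices selects a different window per image in one advanced-indexing operation. A self-contained sketch of the same trick (shapes taken from the comments above; values illustrative):

```python
import torch

B, H, W, C = 2, 8, 8, 3
crop_h, crop_w = 4, 4
images_hwcn = torch.arange(B * H * W * C, dtype=torch.float32).reshape(B, H, W, C)

# One random top-left corner per image.
tops = torch.randint(0, H - crop_h + 1, (B,))
lefts = torch.randint(0, W - crop_w + 1, (B,))

# rows: (B, crop_h, 1) and cols: (B, 1, crop_w) broadcast against (B, 1, 1).
rows = tops.view(B, 1, 1) + torch.arange(crop_h).view(1, crop_h, 1)
cols = lefts.view(B, 1, 1) + torch.arange(crop_w).view(1, 1, crop_w)

cropped = images_hwcn[torch.arange(B).view(B, 1, 1), rows, cols, :]
print(cropped.shape)  # torch.Size([2, 4, 4, 3])
```
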
@@ -179,7 +210,9 @@ class ReplayBuffer:
         """Saves a transition, ensuring tensors are stored on the designated storage device."""
         # Move tensors to the storage device
         state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
-        next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
+        next_state = {
+            key: tensor.to(self.storage_device) for key, tensor in next_state.items()
+        }
         action = action.to(self.storage_device)
         # if complementary_info is not None:
         #     complementary_info = {

@@ -234,7 +267,9 @@ class ReplayBuffer:
         )

         replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys)
-        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
+        list_transition = cls._lerobotdataset_to_transitions(
+            dataset=lerobot_dataset, state_keys=state_keys
+        )
         # Fill the replay buffer with the lerobot dataset transitions
         for data in list_transition:
             for k, v in data.items():

@@ -295,7 +330,9 @@ class ReplayBuffer:

         # If not provided, you can either raise an error or define a default:
         if state_keys is None:
-            raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
+            raise ValueError(
+                "You must provide a list of keys in `state_keys` that define your 'state'."
+            )

         transitions: list[Transition] = []
         num_frames = len(dataset)

@@ -350,33 +387,37 @@ class ReplayBuffer:
         # -- Build batched states --
         batch_state = {}
         for key in self.state_keys:
-            batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
-                self.device
-            )
+            batch_state[key] = torch.cat(
+                [t["state"][key] for t in list_of_transitions], dim=0
+            ).to(self.device)
             if key.startswith("observation.image") and self.use_drq:
                 batch_state[key] = self.image_augmentation_function(batch_state[key])

         # -- Build batched actions --
-        batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
-
-        # -- Build batched rewards --
-        batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
+        batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
             self.device
         )

+        # -- Build batched rewards --
+        batch_rewards = torch.tensor(
+            [t["reward"] for t in list_of_transitions], dtype=torch.float32
+        ).to(self.device)
+
         # -- Build batched next states --
         batch_next_state = {}
         for key in self.state_keys:
-            batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
-                self.device
-            )
+            batch_next_state[key] = torch.cat(
+                [t["next_state"][key] for t in list_of_transitions], dim=0
+            ).to(self.device)
             if key.startswith("observation.image") and self.use_drq:
-                batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
+                batch_next_state[key] = self.image_augmentation_function(
+                    batch_next_state[key]
+                )

         # -- Build batched dones --
-        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
-            self.device
-        )
+        batch_dones = torch.tensor(
+            [t["done"] for t in list_of_transitions], dtype=torch.float32
+        ).to(self.device)

         # Return a BatchTransition typed dict
         return BatchTransition(

@@ -433,7 +474,9 @@ class ReplayBuffer:

         # Add state keys
         for key in self.state_keys:
-            sample_val = first_transition["state"][key].squeeze(dim=0)  # Remove batch dimension
+            sample_val = first_transition["state"][key].squeeze(
+                dim=0
+            )  # Remove batch dimension
             if not isinstance(sample_val, torch.Tensor):
                 raise ValueError(
                     f"State key '{key}' is not a torch.Tensor. Please ensure your states are stored as torch.Tensors."

@@ -465,7 +508,9 @@ class ReplayBuffer:
         # We detect episode boundaries by `done == True`.
         # --------------------------------------------------------------------------------------------
         episode_index = 0
-        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
+        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
+            episode_index
+        )

         frame_idx_in_episode = 0
         for global_frame_idx, transition in enumerate(self.memory):

@@ -476,16 +521,24 @@ class ReplayBuffer:
             # Expand dimension to match what the dataset expects (the dataset wants the raw shape)
             # We assume your buffer has shape [C, H, W] (if image) or [D] if vector
             # This is typically already correct, but if needed you can reshape below.
-            frame_dict[key] = transition["state"][key].cpu().squeeze(dim=0)  # Remove batch dimension
+            frame_dict[key] = (
+                transition["state"][key].cpu().squeeze(dim=0)
+            )  # Remove batch dimension

             # Fill action, reward, done
             # Make sure they are shape (X,) or (X,Y,...) as needed.
-            frame_dict["action"] = transition["action"].cpu().squeeze(dim=0)  # Remove batch dimension
+            frame_dict["action"] = (
+                transition["action"].cpu().squeeze(dim=0)
+            )  # Remove batch dimension
             frame_dict["next.reward"] = (
-                torch.tensor([transition["reward"]], dtype=torch.float32).cpu().squeeze(dim=0)
+                torch.tensor([transition["reward"]], dtype=torch.float32)
+                .cpu()
+                .squeeze(dim=0)
             )
             frame_dict["next.done"] = (
-                torch.tensor([transition["done"]], dtype=torch.bool).cpu().squeeze(dim=0)
+                torch.tensor([transition["done"]], dtype=torch.bool)
+                .cpu()
+                .squeeze(dim=0)
             )
             # Add to the dataset's buffer
             lerobot_dataset.add_frame(frame_dict)

@@ -499,7 +552,9 @@ class ReplayBuffer:
                 episode_index += 1
                 frame_idx_in_episode = 0
                 # Start a new buffer for the next episode
-                lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
+                lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
+                    episode_index
+                )

         # We are done adding frames
         # If the last transition wasn't done=True, we still have an open buffer with frames.

|
@ -541,7 +596,13 @@ def concatenate_batch_transitions(
|
||||||
) -> BatchTransition:
|
) -> BatchTransition:
|
||||||
"""NOTE: Be careful it change the left_batch_transitions in place"""
|
"""NOTE: Be careful it change the left_batch_transitions in place"""
|
||||||
left_batch_transitions["state"] = {
|
left_batch_transitions["state"] = {
|
||||||
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
|
key: torch.cat(
|
||||||
|
[
|
||||||
|
left_batch_transitions["state"][key],
|
||||||
|
right_batch_transition["state"][key],
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
for key in left_batch_transitions["state"]
|
for key in left_batch_transitions["state"]
|
||||||
}
|
}
|
||||||
left_batch_transitions["action"] = torch.cat(
|
left_batch_transitions["action"] = torch.cat(
|
||||||
|
@@ -552,7 +613,11 @@ def concatenate_batch_transitions(
     )
     left_batch_transitions["next_state"] = {
         key: torch.cat(
-            [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
+            [
+                left_batch_transitions["next_state"][key],
+                right_batch_transition["next_state"][key],
+            ],
+            dim=0,
         )
         for key in left_batch_transitions["next_state"]
     }

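As the NOTE warns, `concatenate_batch_transitions` mutates its first argument rather than building a fresh batch. A hedged usage sketch of the implied contract, with hand-built `BatchTransition`-shaped dicts (shapes and values are illustrative):

```python
import torch

from lerobot.scripts.server.buffer import concatenate_batch_transitions


def tiny_batch(n: int) -> dict:
    # Minimal BatchTransition-shaped dict matching the TypedDict keys above.
    return {
        "state": {"observation.state": torch.zeros(n, 4)},
        "action": torch.zeros(n, 2),
        "reward": torch.zeros(n),
        "next_state": {"observation.state": torch.zeros(n, 4)},
        "done": torch.zeros(n),
    }


left, right = tiny_batch(4), tiny_batch(2)
merged = concatenate_batch_transitions(left, right)

assert merged is left  # the left batch was extended in place
assert merged["action"].shape[0] == 6
```
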
@@ -10,11 +10,9 @@ import torch
 import torchvision.transforms.functional as F  # noqa: N812

 from lerobot.common.envs.utils import preprocess_observation
-from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
+from lerobot.common.robot_devices.control_utils import busy_wait, is_headless
 from lerobot.common.robot_devices.robots.factory import make_robot
 from lerobot.common.utils.utils import init_hydra_config, log_say
-from lerobot.scripts.server.maniskill_manipulator import make_maniskill


 logging.basicConfig(level=logging.INFO)

@@ -62,7 +60,9 @@ class HILSerlRobotEnv(gym.Env):
         if not self.robot.is_connected:
             self.robot.connect()

-        self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")
+        self.initial_follower_position = robot.follower_arms["main"].read(
+            "Present_Position"
+        )

         # Episode tracking.
         self.current_step = 0

@@ -70,7 +70,9 @@ class HILSerlRobotEnv(gym.Env):

         self.delta = delta
         self.use_delta_action_space = use_delta_action_space
-        self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
+        self.current_joint_positions = self.robot.follower_arms["main"].read(
+            "Present_Position"
+        )

         # Retrieve the size of the joint position interval bound.
         self.relative_bounds_size = (

|
||||||
image_keys = [key for key in example_obs if "image" in key]
|
image_keys = [key for key in example_obs if "image" in key]
|
||||||
state_keys = [key for key in example_obs if "image" not in key]
|
state_keys = [key for key in example_obs if "image" not in key]
|
||||||
observation_spaces = {
|
observation_spaces = {
|
||||||
key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
|
key: gym.spaces.Box(
|
||||||
|
low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8
|
||||||
|
)
|
||||||
for key in image_keys
|
for key in image_keys
|
||||||
}
|
}
|
||||||
observation_spaces["observation.state"] = gym.spaces.Dict(
|
observation_spaces["observation.state"] = gym.spaces.Dict(
|
||||||
{
|
{
|
||||||
key: gym.spaces.Box(low=0, high=10, shape=example_obs[key].shape, dtype=np.float32)
|
key: gym.spaces.Box(
|
||||||
|
low=0, high=10, shape=example_obs[key].shape, dtype=np.float32
|
||||||
|
)
|
||||||
for key in state_keys
|
for key in state_keys
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@@ -128,8 +134,12 @@ class HILSerlRobotEnv(gym.Env):
             )
         else:
             action_space_robot = gym.spaces.Box(
-                low=self.robot.config.joint_position_relative_bounds["min"].cpu().numpy(),
-                high=self.robot.config.joint_position_relative_bounds["max"].cpu().numpy(),
+                low=self.robot.config.joint_position_relative_bounds["min"]
+                .cpu()
+                .numpy(),
+                high=self.robot.config.joint_position_relative_bounds["max"]
+                .cpu()
+                .numpy(),
                 shape=(action_dim,),
                 dtype=np.float32,
             )

@@ -141,7 +151,9 @@ class HILSerlRobotEnv(gym.Env):
             ),
         )

-    def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
+    def reset(
+        self, seed=None, options=None
+    ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
         """
         Reset the environment to its initial state.
         This method resets the step counter and clears any episodic data.

@@ -198,24 +210,34 @@ class HILSerlRobotEnv(gym.Env):
         """
         policy_action, intervention_bool = action
         teleop_action = None
-        self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
+        self.current_joint_positions = self.robot.follower_arms["main"].read(
+            "Present_Position"
+        )
         if isinstance(policy_action, torch.Tensor):
             policy_action = policy_action.cpu().numpy()
-            policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
+            policy_action = np.clip(
+                policy_action, self.action_space[0].low, self.action_space[0].high
+            )
         if not intervention_bool:
             if self.use_delta_action_space:
-                target_joint_positions = self.current_joint_positions + self.delta * policy_action
+                target_joint_positions = (
+                    self.current_joint_positions + self.delta * policy_action
+                )
             else:
                 target_joint_positions = policy_action
             self.robot.send_action(torch.from_numpy(target_joint_positions))
             observation = self.robot.capture_observation()
         else:
             observation, teleop_action = self.robot.teleop_step(record_data=True)
-            teleop_action = teleop_action["action"]  # Convert tensor to appropriate format
+            teleop_action = teleop_action[
+                "action"
+            ]  # Convert tensor to appropriate format

             # When applying the delta action space, convert teleop absolute values to relative differences.
             if self.use_delta_action_space:
-                teleop_action = (teleop_action - self.current_joint_positions) / self.delta
+                teleop_action = (
+                    teleop_action - self.current_joint_positions
+                ) / self.delta
                 if torch.any(teleop_action < -self.relative_bounds_size) and torch.any(
                     teleop_action > self.relative_bounds_size
                 ):

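In the delta action space the policy emits a normalized offset, the env converts it to absolute joints via `current + delta * action`, and a human teleop reading is mapped back with the inverse `(teleop - current) / delta` so both sources live in the same space. A toy numeric check of that round trip (values are illustrative):

```python
import numpy as np

delta = 0.1
current_joint_positions = np.array([10.0, 20.0])
policy_action = np.array([0.5, -1.0])  # normalized offsets

# Forward map used when the policy is in control.
target_joint_positions = current_joint_positions + delta * policy_action
print(target_joint_positions)  # [10.05 19.9 ]

# Inverse map applied to an absolute teleop reading during intervention.
teleop_relative = (target_joint_positions - current_joint_positions) / delta
print(teleop_relative)  # recovers [ 0.5 -1. ]
```
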
@@ -226,7 +248,9 @@ class HILSerlRobotEnv(gym.Env):
                     )

                 teleop_action = torch.clamp(
-                    teleop_action, -self.relative_bounds_size, self.relative_bounds_size
+                    teleop_action,
+                    -self.relative_bounds_size,
+                    self.relative_bounds_size,
                 )
                 # NOTE: To mimic the shape of a neural network output, we add a batch dimension to the teleop action.
                 if teleop_action.dim() == 1:

@@ -245,7 +269,10 @@ class HILSerlRobotEnv(gym.Env):
             reward,
             terminated,
             truncated,
-            {"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
+            {
+                "action_intervention": teleop_action,
+                "is_intervention": teleop_action is not None,
+            },
         )

     def render(self):

|
@ -351,7 +378,9 @@ class JointMaskingActionSpace(gym.Wrapper):
|
||||||
raise ValueError("Mask length must match action space dimensions")
|
raise ValueError("Mask length must match action space dimensions")
|
||||||
low = env.action_space.low[self.active_dims]
|
low = env.action_space.low[self.active_dims]
|
||||||
high = env.action_space.high[self.active_dims]
|
high = env.action_space.high[self.active_dims]
|
||||||
self.action_space = gym.spaces.Box(low=low, high=high, dtype=env.action_space.dtype)
|
self.action_space = gym.spaces.Box(
|
||||||
|
low=low, high=high, dtype=env.action_space.dtype
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(env.action_space, gym.spaces.Tuple):
|
if isinstance(env.action_space, gym.spaces.Tuple):
|
||||||
if len(mask) != env.action_space[0].shape[0]:
|
if len(mask) != env.action_space[0].shape[0]:
|
||||||
|
@@ -359,8 +388,12 @@ class JointMaskingActionSpace(gym.Wrapper):

             low = env.action_space[0].low[self.active_dims]
             high = env.action_space[0].high[self.active_dims]
-            action_space_masked = gym.spaces.Box(low=low, high=high, dtype=env.action_space[0].dtype)
-            self.action_space = gym.spaces.Tuple((action_space_masked, env.action_space[1]))
+            action_space_masked = gym.spaces.Box(
+                low=low, high=high, dtype=env.action_space[0].dtype
+            )
+            self.action_space = gym.spaces.Tuple(
+                (action_space_masked, env.action_space[1])
+            )
             # Create new action space with masked dimensions

     def action(self, action):

def action(self, action):
|
def action(self, action):
|
||||||
|
@ -379,14 +412,18 @@ class JointMaskingActionSpace(gym.Wrapper):
|
||||||
# Extract the masked component from the tuple.
|
# Extract the masked component from the tuple.
|
||||||
masked_action = action[0] if isinstance(action, tuple) else action
|
masked_action = action[0] if isinstance(action, tuple) else action
|
||||||
# Create a full action for the Box element.
|
# Create a full action for the Box element.
|
||||||
full_box_action = np.zeros(self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype)
|
full_box_action = np.zeros(
|
||||||
|
self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype
|
||||||
|
)
|
||||||
full_box_action[self.active_dims] = masked_action
|
full_box_action[self.active_dims] = masked_action
|
||||||
# Return a tuple with the reconstructed Box action and the unchanged remainder.
|
# Return a tuple with the reconstructed Box action and the unchanged remainder.
|
||||||
return (full_box_action, action[1])
|
return (full_box_action, action[1])
|
||||||
else:
|
else:
|
||||||
# For Box action spaces.
|
# For Box action spaces.
|
||||||
masked_action = action if not isinstance(action, tuple) else action[0]
|
masked_action = action if not isinstance(action, tuple) else action[0]
|
||||||
full_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype)
|
full_action = np.zeros(
|
||||||
|
self.env.action_space.shape, dtype=self.env.action_space.dtype
|
||||||
|
)
|
||||||
full_action[self.active_dims] = masked_action
|
full_action[self.active_dims] = masked_action
|
||||||
return full_action
|
return full_action
|
||||||
|
|
||||||
|
@ -395,9 +432,13 @@ class JointMaskingActionSpace(gym.Wrapper):
|
||||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||||
if "action_intervention" in info and info["action_intervention"] is not None:
|
if "action_intervention" in info and info["action_intervention"] is not None:
|
||||||
if info["action_intervention"].dim() == 1:
|
if info["action_intervention"].dim() == 1:
|
||||||
info["action_intervention"] = info["action_intervention"][self.active_dims]
|
info["action_intervention"] = info["action_intervention"][
|
||||||
|
self.active_dims
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
info["action_intervention"] = info["action_intervention"][:, self.active_dims]
|
info["action_intervention"] = info["action_intervention"][
|
||||||
|
:, self.active_dims
|
||||||
|
]
|
||||||
return obs, reward, terminated, truncated, info
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
|
||||||
|
@ -438,7 +479,12 @@ class TimeLimitWrapper(gym.Wrapper):
|
||||||
|
|
||||||
|
|
||||||
class ImageCropResizeWrapper(gym.Wrapper):
|
class ImageCropResizeWrapper(gym.Wrapper):
|
||||||
def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
env,
|
||||||
|
crop_params_dict: Dict[str, Annotated[Tuple[int], 4]],
|
||||||
|
resize_size=None,
|
||||||
|
):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
self.env = env
|
self.env = env
|
||||||
self.crop_params_dict = crop_params_dict
|
self.crop_params_dict = crop_params_dict
|
||||||
|
@ -450,7 +496,9 @@ class ImageCropResizeWrapper(gym.Wrapper):
|
||||||
for key in crop_params_dict:
|
for key in crop_params_dict:
|
||||||
top, left, height, width = crop_params_dict[key]
|
top, left, height, width = crop_params_dict[key]
|
||||||
new_shape = (top + height, left + width)
|
new_shape = (top + height, left + width)
|
||||||
self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)
|
self.observation_space[key] = gym.spaces.Box(
|
||||||
|
low=0, high=255, shape=new_shape
|
||||||
|
)
|
||||||
|
|
||||||
self.resize_size = resize_size
|
self.resize_size = resize_size
|
||||||
if self.resize_size is None:
|
if self.resize_size is None:
|
||||||
|
@ -463,7 +511,9 @@ class ImageCropResizeWrapper(gym.Wrapper):
|
||||||
|
|
||||||
# Check for NaNs before processing
|
# Check for NaNs before processing
|
||||||
if torch.isnan(obs[k]).any():
|
if torch.isnan(obs[k]).any():
|
||||||
logging.error(f"NaN values detected in observation {k} before crop and resize")
|
logging.error(
|
||||||
|
f"NaN values detected in observation {k} before crop and resize"
|
||||||
|
)
|
||||||
|
|
||||||
if device == torch.device("mps:0"):
|
if device == torch.device("mps:0"):
|
||||||
obs[k] = obs[k].cpu()
|
obs[k] = obs[k].cpu()
|
||||||
|
@ -473,7 +523,9 @@ class ImageCropResizeWrapper(gym.Wrapper):
|
||||||
|
|
||||||
# Check for NaNs after processing
|
# Check for NaNs after processing
|
||||||
if torch.isnan(obs[k]).any():
|
if torch.isnan(obs[k]).any():
|
||||||
logging.error(f"NaN values detected in observation {k} after crop and resize")
|
logging.error(
|
||||||
|
f"NaN values detected in observation {k} after crop and resize"
|
||||||
|
)
|
||||||
|
|
||||||
obs[k] = obs[k].to(device)
|
obs[k] = obs[k].to(device)
|
||||||
|
|
||||||
|
@ -503,10 +555,14 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper):
|
||||||
observation = preprocess_observation(observation)
|
observation = preprocess_observation(observation)
|
||||||
|
|
||||||
observation = {
|
observation = {
|
||||||
key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
|
key: observation[key].to(
|
||||||
|
self.device, non_blocking=self.device.type == "cuda"
|
||||||
|
)
|
||||||
for key in observation
|
for key in observation
|
||||||
}
|
}
|
||||||
observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
|
observation = {
|
||||||
|
k: torch.tensor(v, device=self.device) for k, v in observation.items()
|
||||||
|
}
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
|
||||||
|
@ -553,18 +609,31 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
|
||||||
"Place the leader in similar pose to the follower and press space again."
|
"Place the leader in similar pose to the follower and press space again."
|
||||||
)
|
)
|
||||||
self.events["pause_policy"] = True
|
self.events["pause_policy"] = True
|
||||||
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
|
log_say(
|
||||||
|
"Human intervention stage. Get ready to take over.",
|
||||||
|
play_sounds=True,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
if self.events["pause_policy"] and not self.events["human_intervention_step"]:
|
if (
|
||||||
|
self.events["pause_policy"]
|
||||||
|
and not self.events["human_intervention_step"]
|
||||||
|
):
|
||||||
self.events["human_intervention_step"] = True
|
self.events["human_intervention_step"] = True
|
||||||
print("Space key pressed. Human intervention starting.")
|
print("Space key pressed. Human intervention starting.")
|
||||||
log_say("Starting human intervention.", play_sounds=True)
|
log_say(
|
||||||
|
"Starting human intervention.", play_sounds=True
|
||||||
|
)
|
||||||
return
|
return
|
||||||
if self.events["pause_policy"] and self.events["human_intervention_step"]:
|
if (
|
||||||
|
self.events["pause_policy"]
|
||||||
|
and self.events["human_intervention_step"]
|
||||||
|
):
|
||||||
self.events["pause_policy"] = False
|
self.events["pause_policy"] = False
|
||||||
self.events["human_intervention_step"] = False
|
self.events["human_intervention_step"] = False
|
||||||
print("Space key pressed for a third time.")
|
print("Space key pressed for a third time.")
|
||||||
log_say("Continuing with policy actions.", play_sounds=True)
|
log_say(
|
||||||
|
"Continuing with policy actions.", play_sounds=True
|
||||||
|
)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error handling key press: {e}")
|
print(f"Error handling key press: {e}")
|
||||||
|
@ -572,7 +641,9 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
|
||||||
self.listener = keyboard.Listener(on_press=on_press)
|
self.listener = keyboard.Listener(on_press=on_press)
|
||||||
self.listener.start()
|
self.listener.start()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logging.warning("Could not import pynput. Keyboard interface will not be available.")
|
logging.warning(
|
||||||
|
"Could not import pynput. Keyboard interface will not be available."
|
||||||
|
)
|
||||||
self.listener = None
|
self.listener = None
|
||||||
|
|
||||||
def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
|
def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
|
||||||
|
@ -599,7 +670,9 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
|
||||||
time.sleep(0.1) # Check more frequently if desired
|
time.sleep(0.1) # Check more frequently if desired
|
||||||
|
|
||||||
# Execute the step in the underlying environment
|
# Execute the step in the underlying environment
|
||||||
obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
|
obs, reward, terminated, truncated, info = self.env.step(
|
||||||
|
(policy_action, is_intervention)
|
||||||
|
)
|
||||||
|
|
||||||
# Override reward and termination if episode success event triggered
|
# Override reward and termination if episode success event triggered
|
||||||
with self.event_lock:
|
with self.event_lock:
|
||||||
|
@ -628,7 +701,10 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
|
||||||
|
|
||||||
class ResetWrapper(gym.Wrapper):
|
class ResetWrapper(gym.Wrapper):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5
|
self,
|
||||||
|
env: HILSerlRobotEnv,
|
||||||
|
reset_fn: Optional[Callable[[], None]] = None,
|
||||||
|
reset_time_s: float = 5,
|
||||||
):
|
):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
self.reset_fn = reset_fn
|
self.reset_fn = reset_fn
|
||||||
|
@ -641,7 +717,10 @@ class ResetWrapper(gym.Wrapper):
|
||||||
if self.reset_fn is not None:
|
if self.reset_fn is not None:
|
||||||
self.reset_fn(self.env)
|
self.reset_fn(self.env)
|
||||||
else:
|
else:
|
||||||
log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True)
|
log_say(
|
||||||
|
f"Manually reset the environment for {self.reset_time_s} seconds.",
|
||||||
|
play_sounds=True,
|
||||||
|
)
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
while time.perf_counter() - start_time < self.reset_time_s:
|
while time.perf_counter() - start_time < self.reset_time_s:
|
||||||
self.robot.teleop_step()
|
self.robot.teleop_step()
|
||||||
|
@ -654,7 +733,9 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
|
||||||
def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
def observation(
|
||||||
|
self, observation: dict[str, torch.Tensor]
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
for key in observation:
|
for key in observation:
|
||||||
if "image" in key and observation[key].dim() == 3:
|
if "image" in key and observation[key].dim() == 3:
|
||||||
observation[key] = observation[key].unsqueeze(0)
|
observation[key] = observation[key].unsqueeze(0)
|
||||||
|
@ -685,6 +766,8 @@ def make_robot_env(
|
||||||
A vectorized gym environment with all the necessary wrappers applied.
|
A vectorized gym environment with all the necessary wrappers applied.
|
||||||
"""
|
"""
|
||||||
if "maniskill" in cfg.env.name:
|
if "maniskill" in cfg.env.name:
|
||||||
|
from lerobot.scripts.server.maniskill_manipulator import make_maniskill
|
||||||
|
|
||||||
logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
|
logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
|
||||||
env = make_maniskill(
|
env = make_maniskill(
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
|
@ -703,15 +786,23 @@ def make_robot_env(
|
||||||
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
|
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
|
||||||
if cfg.env.wrapper.crop_params_dict is not None:
|
if cfg.env.wrapper.crop_params_dict is not None:
|
||||||
env = ImageCropResizeWrapper(
|
env = ImageCropResizeWrapper(
|
||||||
env=env, crop_params_dict=cfg.env.wrapper.crop_params_dict, resize_size=cfg.env.wrapper.resize_size
|
env=env,
|
||||||
|
crop_params_dict=cfg.env.wrapper.crop_params_dict,
|
||||||
|
resize_size=cfg.env.wrapper.resize_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add reward computation and control wrappers
|
# Add reward computation and control wrappers
|
||||||
env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
|
env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
|
||||||
env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps)
|
env = TimeLimitWrapper(
|
||||||
|
env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps
|
||||||
|
)
|
||||||
env = KeyboardInterfaceWrapper(env=env)
|
env = KeyboardInterfaceWrapper(env=env)
|
||||||
env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s)
|
env = ResetWrapper(
|
||||||
env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space)
|
env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s
|
||||||
|
)
|
||||||
|
env = JointMaskingActionSpace(
|
||||||
|
env=env, mask=cfg.env.wrapper.joint_masking_action_space
|
||||||
|
)
|
||||||
env = BatchCompitableWrapper(env=env)
|
env = BatchCompitableWrapper(env=env)
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
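`make_robot_env` stacks the wrappers in a fixed order (observation conversion, crop/resize, reward, time limit, keyboard interface, reset, joint masking, batch compatibility), so each layer sees the previous layer's output. The crop stage is driven by `cfg.env.wrapper.crop_params_dict`, whose entries `ImageCropResizeWrapper` unpacks as `(top, left, height, width)` per image key. A hedged sketch of such a value, with an assumed camera key:

```python
# Hypothetical crop parameters, matching the (top, left, height, width)
# unpacking in ImageCropResizeWrapper above. The observation key name
# is an assumption, not taken from this diff.
crop_params_dict = {
    "observation.images.cam0": (0, 0, 128, 128),
}
```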
@@ -724,13 +815,19 @@ def get_classifier(pretrained_path, config_path, device="mps"):
         return None

     from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
-    from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
-    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
+    from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
+        ClassifierConfig,
+    )
+    from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
+        Classifier,
+    )

     cfg = init_hydra_config(config_path)

     classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
-    classifier_config.num_cameras = len(cfg.training.image_keys)  # TODO automate these paths
+    classifier_config.num_cameras = len(
+        cfg.training.image_keys
+    )  # TODO automate these paths
     model = Classifier(classifier_config)
     model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
     model = model.to(device)
@@ -741,7 +838,9 @@ def replay_episode(env, repo_id, root=None, episode=0):
     from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

     local_files_only = root is not None
-    dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
+    dataset = LeRobotDataset(
+        repo_id, root=root, episodes=[episode], local_files_only=local_files_only
+    )
     actions = dataset.hf_dataset.select_columns("action")

     for idx in range(dataset.num_frames):
@@ -787,7 +886,8 @@ if __name__ == "__main__":
         ),
     )
     parser.add_argument(
-        "--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
+        "--display-cameras",
+        help=("Whether to display the camera feed while the rollout is happening"),
     )
     parser.add_argument(
         "--reward-classifier-pretrained-path",
@@ -801,13 +901,39 @@ if __name__ == "__main__":
         default=None,
         help="Path to a yaml config file that is necessary to build the reward classifier model.",
     )
-    parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file")
-    parser.add_argument("--env-overrides", type=str, default=None, help="Overrides for the env yaml file")
-    parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
-    parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes")
-    parser.add_argument("--replay-repo-id", type=str, default=None, help="Repo ID of the episode to replay")
-    parser.add_argument("--replay-root", type=str, default=None, help="Root of the dataset to replay")
-    parser.add_argument("--replay-episode", type=int, default=0, help="Episode to replay")
+    parser.add_argument(
+        "--env-path", type=str, default=None, help="Path to the env yaml file"
+    )
+    parser.add_argument(
+        "--env-overrides",
+        type=str,
+        default=None,
+        help="Overrides for the env yaml file",
+    )
+    parser.add_argument(
+        "--control-time-s",
+        type=float,
+        default=20,
+        help="Maximum episode length in seconds",
+    )
+    parser.add_argument(
+        "--reset-follower-pos",
+        type=int,
+        default=1,
+        help="Reset follower between episodes",
+    )
+    parser.add_argument(
+        "--replay-repo-id",
+        type=str,
+        default=None,
+        help="Repo ID of the episode to replay",
+    )
+    parser.add_argument(
+        "--replay-root", type=str, default=None, help="Root of the dataset to replay"
+    )
+    parser.add_argument(
+        "--replay-episode", type=int, default=0, help="Episode to replay"
+    )
     args = parser.parse_args()

     robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
@@ -828,7 +954,9 @@ if __name__ == "__main__":
     env.reset()

     if args.replay_repo_id is not None:
-        replay_episode(env, args.replay_repo_id, root=args.replay_root, episode=args.replay_episode)
+        replay_episode(
+            env, args.replay_repo_id, root=args.replay_root, episode=args.replay_episode
+        )
         exit()

     # Retrieve the robot's action space for joint commands.
@@ -849,7 +977,9 @@ if __name__ == "__main__":
         smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action

         # Execute the step: wrap the NumPy action in a torch tensor.
-        obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
+        obs, reward, terminated, truncated, info = env.step(
+            (torch.from_numpy(smoothed_action), False)
+        )
         if terminated or truncated:
             env.reset()
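That closes the environment-script changes. Before moving to the transport layer, note what the reworked step contract looks like from the caller's side: actions travel as a `(policy_action, is_intervention)` tuple, and any human override comes back through `info`, already restricted to the active joints by `JointMaskingActionSpace`. A minimal sketch, assuming `env` is the wrapped robot env built by `make_robot_env` above:

```python
import torch


def rollout_once(env):
    """Sketch of the (action, is_intervention) step contract."""
    obs, info = env.reset()
    while True:
        # Placeholder policy action over the masked (active) joint dims.
        policy_action = torch.zeros(env.action_space[0].shape)
        obs, reward, terminated, truncated, info = env.step((policy_action, False))
        if info.get("is_intervention"):
            # For HIL-SERL updates, store the human's action, not the policy's.
            policy_action = info["action_intervention"]
        if terminated or truncated:
            break
```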
@@ -22,19 +22,11 @@ package hil_serl;
 // The Learner implements this service.
 service LearnerService {
   // Actor -> Learner to store transitions
-  rpc SendTransition(Transition) returns (Empty);
-  rpc SendInteractionMessage(InteractionMessage) returns (Empty);
+  rpc SendInteractionMessage(InteractionMessage) returns (Empty);
+  rpc StreamParameters(Empty) returns (stream Parameters);
+  rpc ReceiveTransitions(stream ActorInformation) returns (Empty);
 }

-// ActorService: the Learner calls this to push parameters.
-// The Actor implements this service.
-service ActorService {
-  // Learner -> Actor to send new parameters
-  rpc StreamTransition(Empty) returns (stream ActorInformation) {};
-  rpc SendParameters(Parameters) returns (Empty);
-}
-
 message ActorInformation {
   oneof data {
     Transition transition = 1;
@@ -42,17 +34,25 @@ message ActorInformation {
   }
 }

+enum TransferState {
+  TRANSFER_UNKNOWN = 0;
+  TRANSFER_BEGIN = 1;
+  TRANSFER_MIDDLE = 2;
+  TRANSFER_END = 3;
+}
+
 // Messages
 message Transition {
   bytes transition_bytes = 1;
 }

 message Parameters {
-  bytes parameter_bytes = 1;
+  TransferState transfer_state = 1;
+  bytes parameter_bytes = 2;
 }

 message InteractionMessage {
   bytes interaction_message_bytes = 1;
 }

 message Empty {}
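The protocol is now learner-centric: a single `LearnerService` that the actor calls as a client, streaming `ActorInformation` up through `ReceiveTransitions` (client-streaming) and pulling `Parameters` down through `StreamParameters` (server-streaming). `TransferState` exists so a large parameter blob can be chunked across several messages. The payload fields stay opaque bytes; consistent with the learner code later in this diff, they are filled roughly like this (the helper names are illustrative, not from the commit):

```python
import io
import pickle

import torch

import hilserl_pb2  # generated module


def serialize_transition(transition) -> hilserl_pb2.Transition:
    # Transitions contain tensors, so torch.save into an in-memory buffer.
    buf = io.BytesIO()
    torch.save(transition, buf)
    return hilserl_pb2.Transition(transition_bytes=buf.getvalue())


def serialize_interaction_message(message: dict) -> hilserl_pb2.InteractionMessage:
    # Interaction stats are plain dicts, pickled directly.
    return hilserl_pb2.InteractionMessage(
        interaction_message_bytes=pickle.dumps(message)
    )
```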
@@ -24,25 +24,25 @@ _sym_db = _symbol_database.Default()



-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"\x83\x01\n\x10\x41\x63torInformation\x12*\n\ntransition\x18\x01 \x01(\x0b\x32\x14.hil_serl.TransitionH\x00\x12;\n\x13interaction_message\x18\x02 \x01(\x0b\x32\x1c.hil_serl.InteractionMessageH\x00\x42\x06\n\x04\x64\x61ta\"&\n\nTransition\x12\x18\n\x10transition_bytes\x18\x01 \x01(\x0c\"%\n\nParameters\x12\x17\n\x0fparameter_bytes\x18\x01 \x01(\x0c\"7\n\x12InteractionMessage\x12!\n\x19interaction_message_bytes\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty2\x92\x01\n\x0eLearnerService\x12\x37\n\x0eSendTransition\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty2\x8c\x01\n\x0c\x41\x63torService\x12\x43\n\x10StreamTransition\x12\x0f.hil_serl.Empty\x1a\x1a.hil_serl.ActorInformation\"\x00\x30\x01\x12\x37\n\x0eSendParameters\x12\x14.hil_serl.Parameters\x1a\x0f.hil_serl.Emptyb\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"\x83\x01\n\x10\x41\x63torInformation\x12*\n\ntransition\x18\x01 \x01(\x0b\x32\x14.hil_serl.TransitionH\x00\x12;\n\x13interaction_message\x18\x02 \x01(\x0b\x32\x1c.hil_serl.InteractionMessageH\x00\x42\x06\n\x04\x64\x61ta\"&\n\nTransition\x12\x18\n\x10transition_bytes\x18\x01 \x01(\x0c\"V\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x17\n\x0fparameter_bytes\x18\x02 \x01(\x0c\"7\n\x12InteractionMessage\x12!\n\x19interaction_message_bytes\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdb\x01\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12\x43\n\x12ReceiveTransitions\x12\x1a.hil_serl.ActorInformation\x1a\x0f.hil_serl.Empty(\x01\x62\x06proto3')

 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
 _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
 if not _descriptor._USE_C_DESCRIPTORS:
   DESCRIPTOR._loaded_options = None
+  _globals['_TRANSFERSTATE']._serialized_start=355
+  _globals['_TRANSFERSTATE']._serialized_end=451
   _globals['_ACTORINFORMATION']._serialized_start=28
   _globals['_ACTORINFORMATION']._serialized_end=159
   _globals['_TRANSITION']._serialized_start=161
   _globals['_TRANSITION']._serialized_end=199
   _globals['_PARAMETERS']._serialized_start=201
-  _globals['_PARAMETERS']._serialized_end=238
-  _globals['_INTERACTIONMESSAGE']._serialized_start=240
-  _globals['_INTERACTIONMESSAGE']._serialized_end=295
-  _globals['_EMPTY']._serialized_start=297
-  _globals['_EMPTY']._serialized_end=304
-  _globals['_LEARNERSERVICE']._serialized_start=307
-  _globals['_LEARNERSERVICE']._serialized_end=453
-  _globals['_ACTORSERVICE']._serialized_start=456
-  _globals['_ACTORSERVICE']._serialized_end=596
+  _globals['_PARAMETERS']._serialized_end=287
+  _globals['_INTERACTIONMESSAGE']._serialized_start=289
+  _globals['_INTERACTIONMESSAGE']._serialized_end=344
+  _globals['_EMPTY']._serialized_start=346
+  _globals['_EMPTY']._serialized_end=353
+  _globals['_LEARNERSERVICE']._serialized_start=454
+  _globals['_LEARNERSERVICE']._serialized_end=673
 # @@protoc_insertion_point(module_scope)
@@ -36,16 +36,21 @@ class LearnerServiceStub(object):
         Args:
             channel: A grpc.Channel.
         """
-        self.SendTransition = channel.unary_unary(
-                '/hil_serl.LearnerService/SendTransition',
-                request_serializer=hilserl__pb2.Transition.SerializeToString,
-                response_deserializer=hilserl__pb2.Empty.FromString,
-                _registered_method=True)
         self.SendInteractionMessage = channel.unary_unary(
                 '/hil_serl.LearnerService/SendInteractionMessage',
                 request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
                 response_deserializer=hilserl__pb2.Empty.FromString,
                 _registered_method=True)
+        self.StreamParameters = channel.unary_stream(
+                '/hil_serl.LearnerService/StreamParameters',
+                request_serializer=hilserl__pb2.Empty.SerializeToString,
+                response_deserializer=hilserl__pb2.Parameters.FromString,
+                _registered_method=True)
+        self.ReceiveTransitions = channel.stream_unary(
+                '/hil_serl.LearnerService/ReceiveTransitions',
+                request_serializer=hilserl__pb2.ActorInformation.SerializeToString,
+                response_deserializer=hilserl__pb2.Empty.FromString,
+                _registered_method=True)


 class LearnerServiceServicer(object):
@@ -53,14 +58,20 @@ class LearnerServiceServicer(object):
     The Learner implements this service.
     """

-    def SendTransition(self, request, context):
+    def SendInteractionMessage(self, request, context):
         """Actor -> Learner to store transitions
         """
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')

-    def SendInteractionMessage(self, request, context):
+    def StreamParameters(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def ReceiveTransitions(self, request_iterator, context):
         """Missing associated documentation comment in .proto file."""
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
         context.set_details('Method not implemented!')
@@ -69,16 +80,21 @@ class LearnerServiceServicer(object):

 def add_LearnerServiceServicer_to_server(servicer, server):
     rpc_method_handlers = {
-            'SendTransition': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendTransition,
-                    request_deserializer=hilserl__pb2.Transition.FromString,
-                    response_serializer=hilserl__pb2.Empty.SerializeToString,
-            ),
             'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
                     servicer.SendInteractionMessage,
                     request_deserializer=hilserl__pb2.InteractionMessage.FromString,
                     response_serializer=hilserl__pb2.Empty.SerializeToString,
             ),
+            'StreamParameters': grpc.unary_stream_rpc_method_handler(
+                    servicer.StreamParameters,
+                    request_deserializer=hilserl__pb2.Empty.FromString,
+                    response_serializer=hilserl__pb2.Parameters.SerializeToString,
+            ),
+            'ReceiveTransitions': grpc.stream_unary_rpc_method_handler(
+                    servicer.ReceiveTransitions,
+                    request_deserializer=hilserl__pb2.ActorInformation.FromString,
+                    response_serializer=hilserl__pb2.Empty.SerializeToString,
+            ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
             'hil_serl.LearnerService', rpc_method_handlers)
@@ -92,33 +108,6 @@ class LearnerService(object):
     The Learner implements this service.
     """

-    @staticmethod
-    def SendTransition(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/hil_serl.LearnerService/SendTransition',
-            hilserl__pb2.Transition.SerializeToString,
-            hilserl__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
     @staticmethod
     def SendInteractionMessage(request,
             target,
@@ -146,76 +135,8 @@ class LearnerService(object):
             metadata,
             _registered_method=True)

-
-class ActorServiceStub(object):
-    """ActorService: the Learner calls this to push parameters.
-    The Actor implements this service.
-    """
-
-    def __init__(self, channel):
-        """Constructor.
-
-        Args:
-            channel: A grpc.Channel.
-        """
-        self.StreamTransition = channel.unary_stream(
-                '/hil_serl.ActorService/StreamTransition',
-                request_serializer=hilserl__pb2.Empty.SerializeToString,
-                response_deserializer=hilserl__pb2.ActorInformation.FromString,
-                _registered_method=True)
-        self.SendParameters = channel.unary_unary(
-                '/hil_serl.ActorService/SendParameters',
-                request_serializer=hilserl__pb2.Parameters.SerializeToString,
-                response_deserializer=hilserl__pb2.Empty.FromString,
-                _registered_method=True)
-
-
-class ActorServiceServicer(object):
-    """ActorService: the Learner calls this to push parameters.
-    The Actor implements this service.
-    """
-
-    def StreamTransition(self, request, context):
-        """Learner -> Actor to send new parameters
-        """
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def SendParameters(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-
-def add_ActorServiceServicer_to_server(servicer, server):
-    rpc_method_handlers = {
-            'StreamTransition': grpc.unary_stream_rpc_method_handler(
-                    servicer.StreamTransition,
-                    request_deserializer=hilserl__pb2.Empty.FromString,
-                    response_serializer=hilserl__pb2.ActorInformation.SerializeToString,
-            ),
-            'SendParameters': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendParameters,
-                    request_deserializer=hilserl__pb2.Parameters.FromString,
-                    response_serializer=hilserl__pb2.Empty.SerializeToString,
-            ),
-    }
-    generic_handler = grpc.method_handlers_generic_handler(
-            'hil_serl.ActorService', rpc_method_handlers)
-    server.add_generic_rpc_handlers((generic_handler,))
-    server.add_registered_method_handlers('hil_serl.ActorService', rpc_method_handlers)
-
-
-# This class is part of an EXPERIMENTAL API.
-class ActorService(object):
-    """ActorService: the Learner calls this to push parameters.
-    The Actor implements this service.
-    """
-
     @staticmethod
-    def StreamTransition(request,
+    def StreamParameters(request,
             target,
             options=(),
             channel_credentials=None,
@@ -228,9 +149,9 @@ class ActorService(object):
         return grpc.experimental.unary_stream(
             request,
             target,
-            '/hil_serl.ActorService/StreamTransition',
+            '/hil_serl.LearnerService/StreamParameters',
             hilserl__pb2.Empty.SerializeToString,
-            hilserl__pb2.ActorInformation.FromString,
+            hilserl__pb2.Parameters.FromString,
             options,
             channel_credentials,
             insecure,
@@ -242,7 +163,7 @@ class ActorService(object):
             _registered_method=True)

     @staticmethod
-    def SendParameters(request,
+    def ReceiveTransitions(request_iterator,
             target,
             options=(),
             channel_credentials=None,
@@ -252,11 +173,11 @@ class ActorService(object):
             wait_for_ready=None,
             timeout=None,
             metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
+        return grpc.experimental.stream_unary(
+            request_iterator,
             target,
-            '/hil_serl.ActorService/SendParameters',
-            hilserl__pb2.Parameters.SerializeToString,
+            '/hil_serl.LearnerService/ReceiveTransitions',
+            hilserl__pb2.ActorInformation.SerializeToString,
             hilserl__pb2.Empty.FromString,
             options,
             channel_credentials,
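The regenerated stubs only declare the surface. The learner still has to subclass `LearnerServiceServicer` and register it via `add_LearnerServiceServicer_to_server`; the commit does this in a `learner_service` module that the learner script below imports but which is not shown in this diff. A minimal hedged sketch of such a servicer:

```python
import queue

import hilserl_pb2
import hilserl_pb2_grpc


class LearnerServiceSketch(hilserl_pb2_grpc.LearnerServiceServicer):
    """Illustrative only; not the repo's learner_service implementation."""

    def __init__(self, transition_queue: queue.Queue, interaction_queue: queue.Queue):
        self.transition_queue = transition_queue
        self.interaction_queue = interaction_queue

    def ReceiveTransitions(self, request_iterator, context):
        # Client-streaming RPC: drain everything the actor sends into queues.
        for info in request_iterator:
            if info.HasField("transition"):
                self.transition_queue.put(info.transition.transition_bytes)
            elif info.HasField("interaction_message"):
                self.interaction_queue.put(
                    info.interaction_message.interaction_message_bytes
                )
        return hilserl_pb2.Empty()

    def StreamParameters(self, request, context):
        # Server-streaming RPC: yield parameter snapshots while the client listens.
        # Rate limiting and TransferState chunking are omitted in this sketch.
        while context.is_active():
            blob = self.get_latest_params()  # assumed helper returning bytes
            yield hilserl_pb2.Parameters(
                transfer_state=hilserl_pb2.TRANSFER_END,
                parameter_bytes=blob,
            )
```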
@@ -14,19 +14,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import io
 import logging
-import pickle
 import queue
 import shutil
 import time
 from pprint import pformat
 from threading import Lock, Thread
+import signal
+from threading import Event
+from concurrent.futures import ThreadPoolExecutor

 import grpc

 # Import generated stubs
-import hilserl_pb2  # type: ignore
 import hilserl_pb2_grpc  # type: ignore
 import hydra
 import torch
@@ -55,10 +55,11 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.server.buffer import (
     ReplayBuffer,
     concatenate_batch_transitions,
-    move_state_dict_to_device,
     move_transition_to_device,
 )

+from lerobot.scripts.server import learner_service
+
 logging.basicConfig(level=logging.INFO)

 transition_queue = queue.Queue()
@@ -77,9 +78,13 @@ def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
     # if resume == True
     checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
     if not checkpoint_dir.exists():
-        raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
+        raise RuntimeError(
+            f"No model checkpoint found in {checkpoint_dir} for resume=True"
+        )

-    checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
+    checkpoint_cfg_path = str(
+        Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
+    )
     logging.info(
         colored(
             "Resume=True detected, resuming previous run",
@@ -112,7 +117,9 @@ def load_training_state(
     if not cfg.resume:
         return None, None

-    training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
+    training_state = torch.load(
+        logger.last_checkpoint_dir / logger.training_state_file_name
+    )

     if isinstance(training_state["optimizer"], dict):
         assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
@@ -126,7 +133,9 @@ def load_training_state(


 def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
-    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+    num_learnable_params = sum(
+        p.numel() for p in policy.parameters() if p.requires_grad
+    )
     num_total_params = sum(p.numel() for p in policy.parameters())

     log_output_dir(out_dir)
@@ -136,7 +145,9 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")


-def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
+def initialize_replay_buffer(
+    cfg: DictConfig, logger: Logger, device: str
+) -> ReplayBuffer:
     if not cfg.resume:
         return ReplayBuffer(
             capacity=cfg.training.online_buffer_capacity,
@@ -146,7 +157,9 @@ def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> Re
     )

     dataset = LeRobotDataset(
-        repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
+        repo_id=cfg.dataset_repo_id,
+        local_files_only=True,
+        root=logger.log_dir / "dataset",
     )
     return ReplayBuffer.from_lerobot_dataset(
         lerobot_dataset=dataset,
@@ -168,18 +181,10 @@ def start_learner_threads(
     logger: Logger,
     resume_optimization_step: int | None = None,
     resume_interaction_step: int | None = None,
+    shutdown_event: Event | None = None,
 ) -> None:
-    actor_ip = cfg.actor_learner_config.actor_ip
-    port = cfg.actor_learner_config.port
+    host = cfg.actor_learner_config.learner_host
+    port = cfg.actor_learner_config.learner_port

-    server_thread = Thread(
-        target=stream_transitions_from_actor,
-        args=(
-            actor_ip,
-            port,
-        ),
-        daemon=True,
-    )
-
     transition_thread = Thread(
         target=add_actor_information_and_train,
@@ -196,95 +201,56 @@ def start_learner_threads(
             logger,
             resume_optimization_step,
             resume_interaction_step,
+            shutdown_event,
         ),
     )

-    param_push_thread = Thread(
-        target=learner_push_parameters,
-        args=(policy, policy_lock, actor_ip, port, 15),
-        daemon=True,
-    )
-
-    server_thread.start()
     transition_thread.start()
-    param_push_thread.start()
-    param_push_thread.join()
+    service = learner_service.LearnerService(
+        shutdown_event,
+        policy,
+        policy_lock,
+        cfg.actor_learner_config.policy_parameters_push_frequency,
+        transition_queue,
+        interaction_message_queue,
+    )
+    server = start_learner_server(service, host, port)
+
+    shutdown_event.wait()
+    server.stop(learner_service.STUTDOWN_TIMEOUT)
+    logging.info("[LEARNER] gRPC server stopped")
+
     transition_thread.join()
-    server_thread.join()
+    logging.info("[LEARNER] Transition thread stopped")


-def stream_transitions_from_actor(host="127.0.0.1", port=50051):
-    """
-    Runs a gRPC client that listens for transition and interaction messages from an Actor service.
-
-    This function establishes a gRPC connection with the given `host` and `port`, then continuously
-    streams transition data from the `ActorServiceStub`. The received transition data is deserialized
-    and stored in a queue (`transition_queue`). Similarly, interaction messages are also deserialized
-    and stored in a separate queue (`interaction_message_queue`).
-
-    Args:
-        host (str, optional): The IP address or hostname of the gRPC server. Defaults to `"127.0.0.1"`.
-        port (int, optional): The port number on which the gRPC server is running. Defaults to `50051`.
-
-    """
-    # NOTE: This is waiting for the handshake to be done
-    # In the future we will do it in a canonical way with a proper handshake
-    time.sleep(10)
-    channel = grpc.insecure_channel(
-        f"{host}:{port}",
-        options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
-    )
-    stub = hilserl_pb2_grpc.ActorServiceStub(channel)
-    for response in stub.StreamTransition(hilserl_pb2.Empty()):
-        if response.HasField("transition"):
-            buffer = io.BytesIO(response.transition.transition_bytes)
-            transition = torch.load(buffer)
-            transition_queue.put(transition)
-        if response.HasField("interaction_message"):
-            content = pickle.loads(response.interaction_message.interaction_message_bytes)
-            interaction_message_queue.put(content)
-
-
-def learner_push_parameters(
-    policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5
-):
-    """
-    As a client, connect to the Actor's gRPC server (ActorService)
-    and periodically push new parameters.
-    """
-    time.sleep(10)
-    channel = grpc.insecure_channel(
-        f"{actor_host}:{actor_port}",
-        options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
-    )
-    actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
-
-    while True:
-        with policy_lock:
-            params_dict = policy.actor.state_dict()
-            if policy.config.vision_encoder_name is not None:
-                if policy.config.freeze_vision_encoder:
-                    params_dict: dict[str, torch.Tensor] = {
-                        k: v for k, v in params_dict.items() if not k.startswith("encoder.")
-                    }
-                else:
-                    raise NotImplementedError(
-                        "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
-                    )
-
-        params_dict = move_state_dict_to_device(params_dict, device="cpu")
-        # Serialize
-        buf = io.BytesIO()
-        torch.save(params_dict, buf)
-        params_bytes = buf.getvalue()
-
-        # Push them to the Actor's "SendParameters" method
-        logging.info("[LEARNER] Publishing parameters to the Actor")
-        response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes))  # noqa: F841
-        time.sleep(seconds_between_pushes)
-
-
-def check_nan_in_transition(observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor):
+def start_learner_server(
+    service: learner_service.LearnerService,
+    host="0.0.0.0",
+    port=50051,
+) -> grpc.server:
+    server = grpc.server(
+        ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
+        options=[
+            ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
+            ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
+        ],
+    )
+    hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
+        service,
+        server,
+    )
+    server.add_insecure_port(f"{host}:{port}")
+    server.start()
+    logging.info("[LEARNER] gRPC server started")
+
+    return server


+def check_nan_in_transition(
+    observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor
+):
     for k in observations:
         if torch.isnan(observations[k]).any():
             logging.error(f"observations[{k}] contains NaN values")
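`start_learner_server` is the server half; the actor process becomes a plain client of the same endpoint, which is what replaces the removed `stream_transitions_from_actor` and `learner_push_parameters` threads. A hedged sketch of the actor-side counterpart (not part of this hunk; chunked `TransferState` reassembly is ignored):

```python
import io

import grpc
import torch

import hilserl_pb2
import hilserl_pb2_grpc

channel = grpc.insecure_channel(
    "127.0.0.1:50051",  # learner_host:learner_port from the config
    options=[
        ("grpc.max_send_message_length", -1),
        ("grpc.max_receive_message_length", -1),
    ],
)
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)

# Server-streaming pull: iterate over parameter snapshots from the learner.
for params in stub.StreamParameters(hilserl_pb2.Empty()):
    state_dict = torch.load(io.BytesIO(params.parameter_bytes))
    break  # sketch: take the first snapshot only


# Client-streaming push: the stub consumes an iterator of requests.
def transition_requests():
    buf = io.BytesIO()
    torch.save({"placeholder": torch.zeros(1)}, buf)  # stand-in transition
    yield hilserl_pb2.ActorInformation(
        transition=hilserl_pb2.Transition(transition_bytes=buf.getvalue())
    )


stub.ReceiveTransitions(transition_requests())
```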
@@ -307,6 +273,7 @@ def add_actor_information_and_train(
     logger: Logger,
     resume_optimization_step: int | None = None,
     resume_interaction_step: int | None = None,
+    shutdown_event: Event | None = None,
 ):
     """
     Handles data transfer from the actor to the learner, manages training updates,
@@ -338,6 +305,7 @@ def add_actor_information_and_train(
         logger (Logger): Logger instance for tracking training progress.
         resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
         resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
+        shutdown_event (Event | None): Event to signal shutdown.
     """
     # NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
     # in the future. The reason why we did that is the GIL in Python. It's super slow the performance
@@ -345,9 +313,17 @@ def add_actor_information_and_train(
     time.time()
     logging.info("Starting learner thread")
     interaction_message, transition = None, None
-    optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
-    interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
+    optimization_step = (
+        resume_optimization_step if resume_optimization_step is not None else 0
+    )
+    interaction_step_shift = (
+        resume_interaction_step if resume_interaction_step is not None else 0
+    )
     while True:
+        if shutdown_event is not None and shutdown_event.is_set():
+            logging.info("[LEARNER] Shutdown signal received. Exiting...")
+            break
+
         while not transition_queue.empty():
             transition_list = transition_queue.get()
             for transition in transition_list:
@ -361,7 +337,9 @@ def add_actor_information_and_train(
|
||||||
interaction_message = interaction_message_queue.get()
|
interaction_message = interaction_message_queue.get()
|
||||||
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
|
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
|
||||||
interaction_message["Interaction step"] += interaction_step_shift
|
interaction_message["Interaction step"] += interaction_step_shift
|
||||||
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
|
logger.log_dict(
|
||||||
|
interaction_message, mode="train", custom_step_key="Interaction step"
|
||||||
|
)
|
||||||
# logging.info(f"Interaction message: {interaction_message}")
|
# logging.info(f"Interaction message: {interaction_message}")
|
||||||
|
|
||||||
if len(replay_buffer) < cfg.training.online_step_before_learning:
|
if len(replay_buffer) < cfg.training.online_step_before_learning:
|
||||||
|
@ -383,7 +361,9 @@ def add_actor_information_and_train(
|
||||||
observations = batch["state"]
|
observations = batch["state"]
|
||||||
next_observations = batch["next_state"]
|
next_observations = batch["next_state"]
|
||||||
done = batch["done"]
|
done = batch["done"]
|
||||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
check_nan_in_transition(
|
||||||
|
observations=observations, actions=actions, next_state=next_observations
|
||||||
|
)
|
||||||
|
|
||||||
with policy_lock:
|
with policy_lock:
|
||||||
loss_critic = policy.compute_loss_critic(
|
loss_critic = policy.compute_loss_critic(
|
||||||
|
@ -411,7 +391,9 @@ def add_actor_information_and_train(
|
||||||
next_observations = batch["next_state"]
|
next_observations = batch["next_state"]
|
||||||
done = batch["done"]
|
done = batch["done"]
|
||||||
|
|
||||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
check_nan_in_transition(
|
||||||
|
observations=observations, actions=actions, next_state=next_observations
|
||||||
|
)
|
||||||
|
|
||||||
with policy_lock:
|
with policy_lock:
|
||||||
loss_critic = policy.compute_loss_critic(
|
loss_critic = policy.compute_loss_critic(
|
||||||
|
@ -439,7 +421,9 @@ def add_actor_information_and_train(
|
||||||
|
|
||||||
training_infos["loss_actor"] = loss_actor.item()
|
training_infos["loss_actor"] = loss_actor.item()
|
||||||
|
|
||||||
loss_temperature = policy.compute_loss_temperature(observations=observations)
|
loss_temperature = policy.compute_loss_temperature(
|
||||||
|
observations=observations
|
||||||
|
)
|
||||||
optimizers["temperature"].zero_grad()
|
optimizers["temperature"].zero_grad()
|
||||||
loss_temperature.backward()
|
loss_temperature.backward()
|
||||||
optimizers["temperature"].step()
|
optimizers["temperature"].step()
|
||||||
|
@ -453,9 +437,13 @@ def add_actor_information_and_train(
|
||||||
# logging.info(f"Training infos: {training_infos}")
|
# logging.info(f"Training infos: {training_infos}")
|
||||||
|
|
||||||
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
|
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
|
||||||
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
|
frequency_for_one_optimization_step = 1 / (
|
||||||
|
time_for_one_optimization_step + 1e-9
|
||||||
|
)
|
||||||
|
|
||||||
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
|
logging.info(
|
||||||
|
f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}"
|
||||||
|
)
|
||||||
|
|
||||||
logger.log_dict(
|
logger.log_dict(
|
||||||
{
|
{
|
||||||
|
@ -471,7 +459,8 @@ def add_actor_information_and_train(
|
||||||
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
|
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
|
||||||
|
|
||||||
if cfg.training.save_checkpoint and (
|
if cfg.training.save_checkpoint and (
|
||||||
optimization_step % cfg.training.save_freq == 0 or optimization_step == cfg.training.online_steps
|
optimization_step % cfg.training.save_freq == 0
|
||||||
|
or optimization_step == cfg.training.online_steps
|
||||||
):
|
):
|
||||||
logging.info(f"Checkpoint policy after step {optimization_step}")
|
logging.info(f"Checkpoint policy after step {optimization_step}")
|
||||||
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
||||||
|
@@ -479,7 +468,9 @@ def add_actor_information_and_train(
             _num_digits = max(6, len(str(cfg.training.online_steps)))
             step_identifier = f"{optimization_step:0{_num_digits}d}"
             interaction_step = (
-                interaction_message["Interaction step"] if interaction_message is not None else 0
+                interaction_message["Interaction step"]
+                if interaction_message is not None
+                else 0
            )
             logger.save_checkpoint(
                 optimization_step,
@@ -538,7 +529,9 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
     optimizer_critic = torch.optim.Adam(
         params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
     )
-    optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
+    optimizer_temperature = torch.optim.Adam(
+        params=[policy.log_alpha], lr=policy.config.critic_lr
+    )
     lr_scheduler = None
     optimizers = {
         "actor": optimizer_actor,
@@ -580,14 +573,18 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
         # Hack: But if we do online training, we do not need dataset_stats
         dataset_stats=None,
-        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
+        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
+        if cfg.resume
+        else None,
     )
     # compile policy
     policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)
 
     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
-    resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
+    resume_optimization_step, resume_interaction_step = load_training_state(
+        cfg, logger, optimizers
+    )
 
     log_training_info(cfg, out_dir, policy)
 
@@ -599,7 +596,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         logging.info("make_dataset offline buffer")
         offline_dataset = make_dataset(cfg)
         logging.info("Conversion to an offline replay buffer")
-        active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask]
+        active_action_dims = [
+            i
+            for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
+            if mask
+        ]
         offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
             offline_dataset,
             device=device,
@@ -609,6 +610,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         )
         batch_size: int = batch_size // 2  # We will sample from both replay buffers
 
+    shutdown_event = Event()
+
+    def signal_handler(signum, frame):
+        print(
+            f"\nReceived signal {signal.Signals(signum).name}. Initiating learner shutdown..."
+        )
+        shutdown_event.set()
+
+    # Register signal handlers
+    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
+    signal.signal(signal.SIGTERM, signal_handler)  # Termination request
+    signal.signal(signal.SIGHUP, signal_handler)  # Terminal closed
+    signal.signal(signal.SIGQUIT, signal_handler)  # Ctrl+\
+
     start_learner_threads(
         cfg,
         device,
@@ -621,6 +636,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         logger,
         resume_optimization_step,
         resume_interaction_step,
+        shutdown_event,
     )
 
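The `shutdown_event` threaded into `start_learner_threads` above is the single stop signal shared by the learner's worker threads. A minimal sketch of the pattern (the `worker_loop` below is a hypothetical stand-in, not the commit's actual thread code):

```python
from threading import Event, Thread


def worker_loop(shutdown_event: Event):
    # Poll the shared event instead of looping forever, so a single
    # shutdown_event.set() from a signal handler stops every thread.
    while not shutdown_event.is_set():
        pass  # one unit of work, e.g. drain a queue or run an optimization step


shutdown_event = Event()
t = Thread(target=worker_loop, args=(shutdown_event,), daemon=True)
t.start()
shutdown_event.set()  # what the SIGINT/SIGTERM handler does
t.join(timeout=10)    # bounded wait for a clean exit
```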
@@ -0,0 +1,113 @@
+import hilserl_pb2  # type: ignore
+import hilserl_pb2_grpc  # type: ignore
+import torch
+from torch import nn
+from threading import Lock, Event
+import logging
+import queue
+import io
+import pickle
+
+from lerobot.scripts.server.buffer import (
+    move_state_dict_to_device,
+    bytes_buffer_size,
+    state_to_bytes,
+)
+
+
+MAX_MESSAGE_SIZE = 4 * 1024 * 1024  # 4 MB
+CHUNK_SIZE = 2 * 1024 * 1024  # 2 MB
+MAX_WORKERS = 10
+SHUTDOWN_TIMEOUT = 10
+
+
+class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
+    def __init__(
+        self,
+        shutdown_event: Event,
+        policy: nn.Module,
+        policy_lock: Lock,
+        seconds_between_pushes: float,
+        transition_queue: queue.Queue,
+        interaction_message_queue: queue.Queue,
+    ):
+        self.shutdown_event = shutdown_event
+        self.policy = policy
+        self.policy_lock = policy_lock
+        self.seconds_between_pushes = seconds_between_pushes
+        self.transition_queue = transition_queue
+        self.interaction_message_queue = interaction_message_queue
+
+    def _get_policy_state(self):
+        with self.policy_lock:
+            params_dict = self.policy.actor.state_dict()
+            if self.policy.config.vision_encoder_name is not None:
+                if self.policy.config.freeze_vision_encoder:
+                    # The actor already holds the frozen encoder weights, so skip them.
+                    params_dict: dict[str, torch.Tensor] = {
+                        k: v
+                        for k, v in params_dict.items()
+                        if not k.startswith("encoder.")
+                    }
+                else:
+                    raise NotImplementedError(
+                        "Vision encoder is not frozen; sending the full model over the network requires chunking the model."
+                    )
+
+        return move_state_dict_to_device(params_dict, device="cpu")
+
+    def _send_bytes(self, buffer: io.BytesIO):
+        size_in_bytes = bytes_buffer_size(buffer)
+
+        sent_bytes = 0
+
+        logging.info(f"Model state size {size_in_bytes / 1024 / 1024} MB")
+
+        while sent_bytes < size_in_bytes:
+            transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
+
+            if sent_bytes + CHUNK_SIZE >= size_in_bytes:
+                transfer_state = hilserl_pb2.TransferState.TRANSFER_END
+            elif sent_bytes == 0:
+                transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
+
+            size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
+            chunk = buffer.read(size_to_read)
+
+            yield hilserl_pb2.Parameters(
+                transfer_state=transfer_state, parameter_bytes=chunk
+            )
+            sent_bytes += size_to_read
+            logging.info(
+                f"[LEARNER] Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
+            )
+
+        logging.info(f"[LEARNER] Published {sent_bytes / 1024 / 1024} MB to the Actor")
+
+    def StreamParameters(self, request, context):
+        # TODO: authorize the request
+        logging.info("[LEARNER] Received request to stream parameters from the Actor")
+
+        while not self.shutdown_event.is_set():
+            logging.debug("[LEARNER] Push parameters to the Actor")
+            state_dict = self._get_policy_state()
+
+            with state_to_bytes(state_dict) as buffer:
+                yield from self._send_bytes(buffer)
+
+            self.shutdown_event.wait(self.seconds_between_pushes)
+
+    def ReceiveTransitions(self, request_iterator, context):
+        # TODO: authorize the request
+        logging.info("[LEARNER] Received request to receive transitions from the Actor")
+
+        for request in request_iterator:
+            logging.debug("[LEARNER] Received request")
+            if request.HasField("transition"):
+                buffer = io.BytesIO(request.transition.transition_bytes)
+                transition = torch.load(buffer)
+                self.transition_queue.put(transition)
+            if request.HasField("interaction_message"):
+                content = pickle.loads(
+                    request.interaction_message.interaction_message_bytes
+                )
+                self.interaction_message_queue.put(content)
File diff suppressed because it is too large
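The suppressed diff contains the rest of the server changes. As a rough guide to the chunking protocol, here is a hypothetical actor-side consumer of `StreamParameters`; the empty request message and the `torch.save` framing produced by `state_to_bytes` are assumptions, not confirmed by this commit:

```python
# Hypothetical client: reassemble the chunked Parameters stream into
# complete actor state_dicts, using the TRANSFER_BEGIN/MIDDLE/END markers.
import io

import grpc
import torch

import hilserl_pb2
import hilserl_pb2_grpc


def stream_policy_states(host: str = "127.0.0.1", port: int = 50051):
    channel = grpc.insecure_channel(f"{host}:{port}")
    stub = hilserl_pb2_grpc.LearnerServiceStub(channel)

    buffer = io.BytesIO()
    for params in stub.StreamParameters(hilserl_pb2.Empty()):  # request type assumed
        if params.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
            buffer = io.BytesIO()  # a fresh state_dict starts here
        buffer.write(params.parameter_bytes)
        if params.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
            buffer.seek(0)
            yield torch.load(buffer)  # one complete (CPU) actor state_dict
```

Note that a state small enough to fit in one chunk arrives with only `TRANSFER_END`, which this loop also handles.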
@@ -53,16 +53,19 @@ dependencies = [
     "einops>=0.8.0",
     "flask>=3.0.3",
     "gdown>=5.1.0",
+    "grpcio>=1.70.0",
     "gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work
     "h5py>=3.10.0",
     "huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
     "imageio[ffmpeg]>=2.34.0",
     "jsonlines>=4.0.0",
+    "mani-skill>=3.0.0b18",
     "numba>=0.59.0",
     "omegaconf>=2.3.0",
     "opencv-python>=4.9.0",
     "packaging>=24.2",
     "av>=12.0.5",
+    "protobuf>=5.29.3",
     "pymunk>=6.6.0",
     "pynput>=1.7.7",
     "pyzmq>=26.2.1",
@@ -87,6 +90,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
 feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
 hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"]
 intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
+mani_skill = ["mani-skill"]
 pi0 = ["transformers>=4.48.0"]
 pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
 stretch = [
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 110
|
line-length = 110
|
||||||
target-version = "py310"
|
target-version = "py310"
|
||||||
exclude = ["tests/artifacts/**/*.safetensors"]
|
exclude = [
|
||||||
|
"tests/data",
|
||||||
|
".bzr",
|
||||||
|
".direnv",
|
||||||
|
".eggs",
|
||||||
|
".git",
|
||||||
|
".git-rewrite",
|
||||||
|
".hg",
|
||||||
|
".mypy_cache",
|
||||||
|
".nox",
|
||||||
|
".pants.d",
|
||||||
|
".pytype",
|
||||||
|
".ruff_cache",
|
||||||
|
".svn",
|
||||||
|
".tox",
|
||||||
|
".venv",
|
||||||
|
"__pypackages__",
|
||||||
|
"_build",
|
||||||
|
"buck-out",
|
||||||
|
"build",
|
||||||
|
"dist",
|
||||||
|
"node_modules",
|
||||||
|
"venv",
|
||||||
|
"*_pb2.py",
|
||||||
|
"*_pb2_grpc.py",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||||
|
|
|
@@ -0,0 +1,11 @@
+# Exclude files/directories from Ruff
+exclude = [
+    "*_pb2.py", # Ignore all protobuf generated files
+    "*_pb2_grpc.py", # Ignore all gRPC generated files
+    "lerobot/scripts/server/hilserl_pb2.py", # Ignore specific file
+    ".git",
+    ".env",
+    ".venv",
+    "build",
+    "dist"
+]