Several fixes to move the actor_server and learner_server code from the maniskill environment to the real robot environment.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi 2025-02-10 16:03:39 +01:00
parent b63738674c
commit d51374ce12
10 changed files with 457 additions and 318 deletions

View File

@@ -39,6 +39,12 @@ class SACConfig:
             "observation.environment_state": "min_max",
         }
     )
+    input_normalization_params: dict[str, dict[str, list[float]]] = field(
+        default_factory=lambda: {
+            "observation.image": {"mean": [[0.485, 0.456, 0.406]], "std": [[0.229, 0.224, 0.225]]},
+            "observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
+        }
+    )
     output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
     output_normalization_params: dict[str, dict[str, list[float]]] = field(
         default_factory=lambda: {
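For orientation, a minimal standalone sketch of what these defaults imply, assuming the usual reading of the modes above ("mean_std" statistics for images, "min_max" rescaling of the state to [-1, 1]); this is an illustration, not the Normalize module itself:

import torch

# Hypothetical image and state observations, normalized by hand the way the config suggests.
image = torch.rand(3, 128, 128)
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
image_normalized = (image - mean) / std

state = torch.tensor([0.2, -0.5, 0.9, 0.0])
state_min = torch.tensor([-1.0, -1.0, -1.0, -1.0])
state_max = torch.tensor([1.0, 1.0, 1.0, 1.0])
state_normalized = (state - state_min) / (state_max - state_min) * 2 - 1  # maps to [-1, 1]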

View File

@@ -51,18 +51,20 @@ class SACPolicy(
         if config is None:
             config = SACConfig()
         self.config = config

         if config.input_normalization_modes is not None:
+            input_normalization_params = _convert_normalization_params_to_tensor(
+                config.input_normalization_params
+            )
             self.normalize_inputs = Normalize(
-                config.input_shapes, config.input_normalization_modes, dataset_stats
+                config.input_shapes, config.input_normalization_modes, input_normalization_params
             )
         else:
             self.normalize_inputs = nn.Identity()
-        output_normalization_params = {}
-        for outer_key, inner_dict in config.output_normalization_params.items():
-            output_normalization_params[outer_key] = {}
-            for key, value in inner_dict.items():
-                output_normalization_params[outer_key][key] = torch.tensor(value)
+        output_normalization_params = _convert_normalization_params_to_tensor(
+            config.output_normalization_params
+        )

         # HACK: This is hacky and should be removed
         dataset_stats = dataset_stats or output_normalization_params
@@ -75,7 +77,7 @@ class SACPolicy(
         # NOTE: For images the encoder should be shared between the actor and critic
         if config.shared_encoder:
-            encoder_critic = SACObservationEncoder(config)
+            encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
             encoder_actor: SACObservationEncoder = encoder_critic
         else:
             encoder_critic = SACObservationEncoder(config)
@@ -92,6 +94,7 @@ class SACPolicy(
                     for _ in range(config.num_critics)
                 ]
             ),
+            output_normalization=self.normalize_targets,
         )

         self.critic_target = CriticEnsemble(
@@ -105,6 +108,7 @@ class SACPolicy(
                     for _ in range(config.num_critics)
                 ]
            ),
+            output_normalization=self.normalize_targets,
        )

        self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
@@ -122,7 +126,7 @@ class SACPolicy(
         # TODO (azouitine): Handle the case where the temparameter is a fixed
         # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
         # it triggers "can't optimize a non-leaf Tensor"
-        self.log_alpha = torch.zeros(1, requires_grad=True, device=torch.device("cuda:0"))
+        self.log_alpha = torch.tensor([0.0], requires_grad=True, device=torch.device("mps"))
         self.temperature = self.log_alpha.exp().item()

     def reset(self):
@@ -313,12 +317,14 @@ class CriticEnsemble(nn.Module):
         self,
         encoder: Optional[nn.Module],
         network_list: nn.ModuleList,
+        output_normalization: nn.Module,
         init_final: Optional[float] = None,
     ):
         super().__init__()
         self.encoder = encoder
         self.network_list = network_list
         self.init_final = init_final
+        self.output_normalization = output_normalization
         self.parameters_to_optimize = []

         # Handle the case where a part of the encoder if frozen
@@ -358,6 +364,10 @@ class CriticEnsemble(nn.Module):
         device = get_device_from_parameters(self)
         # Move each tensor in observations to device
         observations = {k: v.to(device) for k, v in observations.items()}
+        # NOTE: We normalize actions it helps for sample efficiency
+        actions: dict[str, torch.tensor] = {"action": actions}
+        # NOTE: Normalization layer took dict in input and outputs a dict that why
+        actions = self.output_normalization(actions)["action"]
         actions = actions.to(device)

         obs_enc = observations if self.encoder is None else self.encoder(observations)
@@ -474,17 +484,18 @@ class Policy(nn.Module):
 class SACObservationEncoder(nn.Module):
     """Encode image and/or state vector observations."""

-    def __init__(self, config: SACConfig):
+    def __init__(self, config: SACConfig, input_normalizer: nn.Module):
         """
         Creates encoders for pixel and/or state modalities.
         """
         super().__init__()
         self.config = config
+        self.input_normalization = input_normalizer
         self.has_pretrained_vision_encoder = False
         self.parameters_to_optimize = []

         self.aggregation_size: int = 0
-        if "observation.image" in config.input_shapes:
+        if any("observation.image" in key for key in config.input_shapes):
             self.camera_number = config.camera_number

             if self.config.vision_encoder_name is not None:
@@ -534,8 +545,9 @@ class SACObservationEncoder(nn.Module):
         over all features.
         """
         feat = []
+        obs_dict = self.input_normalization(obs_dict)
         # Concatenate all images along the channel dimension.
-        image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
+        image_keys = [k for k in obs_dict if k.startswith("observation.image")]
         for image_key in image_keys:
             enc_feat = self.image_enc_layers(obs_dict[image_key])
@@ -681,6 +693,18 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
     return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))


+def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
+    converted_params = {}
+    for outer_key, inner_dict in normalization_params.items():
+        converted_params[outer_key] = {}
+        for key, value in inner_dict.items():
+            converted_params[outer_key][key] = torch.tensor(value)
+            if "image" in outer_key:
+                converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
+
+    return converted_params
+
+
 if __name__ == "__main__":
     # Test the SACObservationEncoder
     import time
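The _convert_normalization_params_to_tensor helper added above is self-contained, so its effect can be checked in isolation; the snippet below repeats its logic on a made-up parameter dict to show the resulting shapes:

import torch

params = {
    "observation.image": {"mean": [[0.485, 0.456, 0.406]], "std": [[0.229, 0.224, 0.225]]},
    "action": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
}

converted = {}
for outer_key, inner_dict in params.items():
    converted[outer_key] = {}
    for key, value in inner_dict.items():
        converted[outer_key][key] = torch.tensor(value)
        if "image" in outer_key:
            # Reshape to (3, 1, 1) so the statistics broadcast over (C, H, W) image tensors.
            converted[outer_key][key] = converted[outer_key][key].view(3, 1, 1)

print(converted["observation.image"]["mean"].shape)  # torch.Size([3, 1, 1])
print(converted["action"]["min"].shape)              # torch.Size([4])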

View File

@@ -18,6 +18,7 @@ import os
 import os.path as osp
 import platform
 import random
+import time
 from contextlib import contextmanager
 from datetime import datetime, timezone
 from pathlib import Path
@@ -217,3 +218,28 @@ def log_say(text, play_sounds, blocking=False):
     if play_sounds:
         say(text, blocking)
+
+
+class TimerManager:
+    def __init__(self, elapsed_time_list: list[float] | None = None, label="Elapsed time", log=True):
+        self.label = label
+        self.elapsed_time_list = elapsed_time_list
+        self.log = log
+        self.elapsed = 0.0
+
+    def __enter__(self):
+        self.start = time.perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.elapsed: float = time.perf_counter() - self.start
+
+        if self.elapsed_time_list is not None:
+            self.elapsed_time_list.append(self.elapsed)
+
+        if self.log:
+            print(f"{self.label}: {self.elapsed:.6f} seconds")
+
+    @property
+    def elapsed_seconds(self):
+        return self.elapsed
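A minimal usage sketch of the new TimerManager, imported from lerobot.common.utils.utils exactly as the actor code below does; the timed workload here is only a placeholder:

from lerobot.common.utils.utils import TimerManager

inference_times: list[float] = []

# Time a block of work; with log=False nothing is printed, the duration is only
# appended to `inference_times` so the caller can compute frequency statistics later.
with TimerManager(elapsed_time_list=inference_times, label="Policy inference time", log=False) as timer:
    _ = sum(i * i for i in range(10_000))  # placeholder for policy.select_action(...)

print(f"last duration: {timer.elapsed_seconds:.6f} s, fps: {1.0 / (inference_times[-1] + 1e-9):.1f}")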

View File

@@ -2,6 +2,7 @@ defaults:
   - _self_
   - env: pusht
   - policy: diffusion
+  - robot: so100

 hydra:
   run:

View File

@@ -8,3 +8,20 @@ env:
   state_dim: 6
   action_dim: 6
   fps: ${fps}
+  device: mps
+
+  wrapper:
+    crop_params_dict:
+      observation.images.laptop: [58, 89, 357, 455]
+      observation.images.phone: [3, 4, 471, 633]
+    resize_size: [128, 128]
+    control_time_s: 20
+    reset_follower_pos: true
+    use_relative_joint_positions: true
+    reset_time_s: 10
+    display_cameras: false
+    delta_action: 0.1
+
+  reward_classifier:
+    pretrained_path: outputs/classifier/checkpoints/best/pretrained_model
+    config_path: lerobot/configs/policy/hilserl_classifier.yaml
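The crop_params_dict values above are consumed later by ImageCropResizeWrapper; a standalone sketch, assuming the (top, left, height, width) ordering of torchvision.transforms.functional.crop that the wrapper relies on:

import torch
import torchvision.transforms.functional as F

crop_params = {"observation.images.laptop": (58, 89, 357, 455)}  # (top, left, height, width)
resize_size = [128, 128]

frame = torch.rand(1, 3, 480, 640)  # dummy camera frame with a batch dimension
cropped = F.crop(frame, *crop_params["observation.images.laptop"])  # -> (1, 3, 357, 455)
resized = F.resize(cropped, resize_size)  # -> (1, 3, 128, 128)
print(resized.shape)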

View File

@@ -31,16 +31,21 @@ from omegaconf import DictConfig
 from torch import nn

 # TODO: Remove the import of maniskill
-from lerobot.common.envs.factory import make_maniskill_env
-from lerobot.common.envs.utils import preprocess_maniskill_observation
+# from lerobot.common.envs.factory import make_maniskill_env
+# from lerobot.common.envs.utils import preprocess_maniskill_observation
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy
+from lerobot.common.robot_devices.control_utils import busy_wait
+from lerobot.common.robot_devices.robots.factory import make_robot
+from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.utils.utils import (
+    TimerManager,
     get_safe_torch_device,
     set_global_seed,
 )
 from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
 from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device
+from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env

 logging.basicConfig(level=logging.INFO)
@@ -152,7 +157,15 @@ def serve_actor_service(port=50052):
     server.wait_for_termination()


-def act_with_policy(cfg: DictConfig):
+def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
+    if not parameters_queue.empty():
+        logging.debug("[ACTOR] Load new parameters from Learner.")
+        state_dict = parameters_queue.get()
+        state_dict = move_state_dict_to_device(state_dict, device=device)
+        policy.load_state_dict(state_dict)
+
+
+def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
     """
     Executes policy interaction within the environment.
@@ -165,9 +178,7 @@ def act_with_policy(cfg: DictConfig):
     logging.info("make_env online")

-    # online_env = make_env(cfg, n_envs=1)
-    # TODO: Remove the import of maniskill and unifiy with make env
-    online_env = make_maniskill_env(cfg, n_envs=1)
+    online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg.env)

     set_global_seed(cfg.seed)
     device = get_safe_torch_device(cfg.device, log=True)
@@ -177,6 +188,16 @@ def act_with_policy(cfg: DictConfig):

     logging.info("make_policy")

+    # HACK: This is an ugly hack to pass the normalization parameters to the policy
+    # Because the action space is dynamic so we override the output normalization parameters
+    # it's ugly, we know ... and we will fix it
+    min_action_space: list = online_env.action_space.spaces[0].low.tolist()
+    max_action_space: list = online_env.action_space.spaces[0].high.tolist()
+    output_normalization_params: dict[dict[str, list]] = {
+        "action": {"min": min_action_space, "max": max_action_space}
+    }
+    cfg.policy.output_normalization_params = output_normalization_params
+
     ### Instantiate the policy in both the actor and learner processes
     ### To avoid sending a SACPolicy object through the port, we create a policy intance
     ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
@@ -187,92 +208,41 @@ def act_with_policy(cfg: DictConfig):
         # Hack: But if we do online training, we do not need dataset_stats
         dataset_stats=None,
         # TODO: Handle resume training
+        device=device,
     )
-    # pretrained_policy_name_or_path=None,
-    # device=device,
-    # )

     policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)

-    # HACK for maniskill
     obs, info = online_env.reset()
-    # obs = preprocess_observation(obs)
-    obs = preprocess_maniskill_observation(obs)
-    obs = {key: obs[key].to(device, non_blocking=True) for key in obs}

     # NOTE: For the moment we will solely handle the case of a single environment
     sum_reward_episode = 0
     list_transition_to_send_to_learner = []
-    list_policy_fps = []
+    list_policy_time = []

     for interaction_step in range(cfg.training.online_steps):
         if interaction_step >= cfg.training.online_step_before_learning:
-            start = time.perf_counter()
-            action = policy.select_action(batch=obs)
-            list_policy_fps.append(1.0 / (time.perf_counter() - start + 1e-9))
-            if list_policy_fps[-1] < cfg.fps:
-                logging.warning(
-                    f"[ACTOR] policy frame rate {list_policy_fps[-1]} during interaction step {interaction_step} is below the required control frame rate {cfg.fps}"
-                )
-            next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
+            # Time policy inference and check if it meets FPS requirement
+            with TimerManager(
+                elapsed_time_list=list_policy_time, label="Policy inference time", log=False
+            ) as timer:  # noqa: F841
+                action = policy.select_action(batch=obs) * 0.0
+            policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
+
+            log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
+
+            next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
         else:
-            # TODO (azouitine): Make a custom space for torch tensor
             action = online_env.action_space.sample()
             next_obs, reward, done, truncated, info = online_env.step(action)
-            # HACK
-            action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
-
-        # HACK: For maniskill
-        # next_obs = preprocess_observation(next_obs)
-        next_obs = preprocess_maniskill_observation(next_obs)
-        next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
-        sum_reward_episode += float(reward[0])
-        # Because we are using a single environment we can index at zero
-        if done[0].item() or truncated[0].item():
-            # TODO: Handle logging for episode information
-            logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
-
-            if not parameters_queue.empty():
-                logging.debug("[ACTOR] Load new parameters from Learner.")
-                state_dict = parameters_queue.get()
-                state_dict = move_state_dict_to_device(state_dict, device=device)
-                # strict=False for the case when the image encoder is frozen and not sent through
-                # the network. Becareful might cause issues if the wrong keys are passed
-                policy.actor.load_state_dict(state_dict, strict=False)
-
-            if len(list_transition_to_send_to_learner) > 0:
-                logging.debug(
-                    f"[ACTOR] Sending {len(list_transition_to_send_to_learner)} transitions to Learner."
-                )
-                message_queue.put(ActorInformation(transition=list_transition_to_send_to_learner))
-                list_transition_to_send_to_learner = []
-
-            stats = {}
-            if len(list_policy_fps) > 0:
-                policy_fps = mean(list_policy_fps)
-                quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
-                logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
-                logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
-                stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
-                list_policy_fps = []
-
-            # Send episodic reward to the learner
-            message_queue.put(
-                ActorInformation(
-                    interaction_message={
-                        "Episodic reward": sum_reward_episode,
-                        "Interaction step": interaction_step,
-                        **stats,
-                    }
-                )
-            )
-            sum_reward_episode = 0.0
-
-        # TODO (michel-aractingi): Label the reward
-        # if config.label_reward_on_actor:
-        #     reward = reward_classifier(obs)
+            # HACK: We have only one env but we want to batch it, it will be resolved with the torch box
+            action = torch.from_numpy(action[0]).to(device, non_blocking=True).unsqueeze(dim=0)
+
+        sum_reward_episode += float(reward)

+        # NOTE: We overide the action if the intervention is True, because the action applied is the intervention action
         if info["is_intervention"]:
             # TODO: Check the shape
             action = info["action_intervention"]
@@ -291,17 +261,85 @@ def act_with_policy(cfg: DictConfig):
         # assign obs to the next obs and continue the rollout
         obs = next_obs

+        # HACK: We have only one env but we want to batch it, it will be resolved with the torch box
+        # Because we are using a single environment we can index at zero
+        if done or truncated:
+            # TODO: Handle logging for episode information
+            logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
+
+            # update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
+
+            if len(list_transition_to_send_to_learner) > 0:
+                send_transitions_in_chunks(
+                    transitions=list_transition_to_send_to_learner, message_queue=message_queue, chunk_size=4
+                )
+                list_transition_to_send_to_learner = []
+
+            stats = get_frequency_stats(list_policy_time)
+            list_policy_time.clear()
+
+            # Send episodic reward to the learner
+            message_queue.put(
+                ActorInformation(
+                    interaction_message={
+                        "Episodic reward": sum_reward_episode,
+                        "Interaction step": interaction_step,
+                        **stats,
+                    }
+                )
+            )
+            sum_reward_episode = 0.0
+            obs, info = online_env.reset()
+
+
+def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int = 100):
+    """Send transitions to learner in smaller chunks to avoid network issues.
+
+    Args:
+        transitions: List of transitions to send
+        message_queue: Queue to send messages to learner
+        chunk_size: Size of each chunk to send
+    """
+    for i in range(0, len(transitions), chunk_size):
+        chunk = transitions[i : i + chunk_size]
+        logging.debug(f"[ACTOR] Sending chunk of {len(chunk)} transitions to Learner.")
+        message_queue.put(ActorInformation(transition=chunk))
+
+
+def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
+    stats = {}
+    list_policy_fps = [1.0 / t for t in list_policy_time]
+    if len(list_policy_fps) > 0:
+        policy_fps = mean(list_policy_fps)
+        quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
+        logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
+        logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
+        stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
+    return stats
+
+
+def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
+    if policy_fps < cfg.fps:
+        logging.warning(
+            f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
+        )
+

 @hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
 def actor_cli(cfg: dict):
-    port = cfg.actor_learner_config.port
-    server_thread = Thread(target=serve_actor_service, args=(port,), daemon=True)
-    server_thread.start()
+    robot = make_robot(cfg=cfg.robot)
+
+    server_thread = Thread(target=serve_actor_service, args=(cfg.actor_learner_config.port,), daemon=True)
+
+    reward_classifier = get_classifier(
+        pretrained_path=cfg.env.reward_classifier.pretrained_path,
+        config_path=cfg.env.reward_classifier.config_path,
+    )
+
     policy_thread = Thread(
         target=act_with_policy,
         daemon=True,
-        args=(cfg,),
+        args=(cfg, robot, reward_classifier),
     )
+
+    server_thread.start()
     policy_thread.start()
     policy_thread.join()
     server_thread.join()
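A standalone sketch of what get_frequency_stats computes, with made-up timings and assuming mean and quantiles come from Python's statistics module:

from statistics import mean, quantiles

# Hypothetical per-step inference durations in seconds collected by TimerManager.
list_policy_time = [0.020, 0.025, 0.033, 0.021, 0.040]

list_policy_fps = [1.0 / t for t in list_policy_time]
stats = {
    "Policy frequency [Hz]": mean(list_policy_fps),
    # quantiles(..., n=10) returns 9 cut points; the last one is the 90th percentile.
    "Policy frequency 90th-p [Hz]": quantiles(list_policy_fps, n=10)[-1],
}
print(stats)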

View File

@@ -56,10 +56,10 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
     }

     # If complementary_info is present, move its tensors to CPU
-    if transition["complementary_info"] is not None:
-        transition["complementary_info"] = {
-            key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
-        }
+    # if transition["complementary_info"] is not None:
+    #     transition["complementary_info"] = {
+    #         key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
+    #     }

     return transition
@@ -309,6 +309,7 @@ class ReplayBuffer:
     def sample(self, batch_size: int) -> BatchTransition:
         """Sample a random batch of transitions and collate them into batched tensors."""
+        batch_size = min(batch_size, len(self.memory))
         list_of_transitions = random.sample(self.memory, batch_size)

         # -- Build batched states --
@@ -341,9 +342,6 @@ class ReplayBuffer:
         batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
             self.device
         )
-        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
-            self.device
-        )

         # Return a BatchTransition typed dict
         return BatchTransition(
@@ -531,30 +529,31 @@ def concatenate_batch_transitions(

 # if __name__ == "__main__":
-#     dataset_name = "lerobot/pusht_image"
-#     dataset = LeRobotDataset(repo_id=dataset_name, episodes=range(1, 3))
+#     dataset_name = "aractingi/push_green_cube_hf_cropped_resized"
+#     dataset = LeRobotDataset(repo_id=dataset_name)
+
 #     replay_buffer = ReplayBuffer.from_lerobot_dataset(
 #         lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
 #     )
 #     replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted")
 #     for i in range(len(replay_buffer_converted)):
 #         replay_convert = replay_buffer_converted[i]
 #         dataset_convert = dataset[i]
 #         for key in replay_convert.keys():
 #             if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}:
 #                 continue
 #             if key in dataset_convert.keys():
 #                 assert torch.equal(replay_convert[key], dataset_convert[key])
 #                 print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}")
 #     re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset(
 #         replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu"
 #     )
 #     for _ in range(20):
 #         batch = re_reconverted_dataset.sample(32)
 #         for key in batch.keys():
 #             if key in {"state", "next_state"}:
 #                 for key_state in batch[key].keys():
 #                     print(key_state, batch[key][key_state].size())
 #                 continue
 #             print(key, batch[key].size())
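A standalone illustration of the guard added to ReplayBuffer.sample(): without clamping, random.sample raises ValueError whenever fewer transitions are stored than requested:

import random

memory = list(range(10))  # stand-in for the replay buffer's transition list
requested_batch_size = 32

# Clamp the batch size to what is actually stored before sampling.
batch_size = min(requested_batch_size, len(memory))
batch = random.sample(memory, batch_size)
print(len(batch))  # 10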

View File

@@ -4,7 +4,6 @@ import time
 from threading import Lock
 from typing import Annotated, Any, Callable, Dict, Optional, Tuple

-import cv2
 import gymnasium as gym
 import numpy as np
 import torch
@@ -20,10 +19,15 @@ logging.basicConfig(level=logging.INFO)
 class HILSerlRobotEnv(gym.Env):
     """
-    Gym-like environment wrapper for robot policy evaluation.
+    Gym-compatible environment for evaluating robotic control policies with integrated human intervention.

-    This wrapper provides a consistent interface for interacting with the robot,
-    following the OpenAI Gym environment conventions.
+    This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both relative (delta)
+    and absolute joint position commands and automatically configures its observation and action spaces based on the robot's
+    sensors and configuration.
+
+    The environment can switch between executing actions from a policy or using teleoperated actions (human intervention) during
+    each step. When teleoperation is used, the override action is captured and returned in the `info` dict along with a flag
+    `is_intervention`.
     """

     def __init__(
@@ -31,32 +35,34 @@ class HILSerlRobotEnv(gym.Env):
         robot,
         use_delta_action_space: bool = True,
         delta: float | None = None,
-        display_cameras=False,
+        display_cameras: bool = False,
     ):
         """
-        Initialize the robot environment.
+        Initialize the HILSerlRobotEnv environment.
+
+        The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup
+        supports both relative (delta) adjustments and absolute joint positions for controlling the robot.

         Args:
-            robot: The robot interface object
-            reward_classifier: Optional reward classifier
-            fps: Frames per second for control
-            control_time_s: Total control time for each episode
-            display_cameras: Whether to display camera feeds
-            output_normalization_params_action: Bound parameters for the action space
-            delta: The delta for the relative joint position action space
+            robot: The robot interface object used to connect and interact with the physical robot.
+            use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
+                joint positions are used.
+            delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
+                0 and 1 when using a delta action space.
+            display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
         """
         super().__init__()

         self.robot = robot
         self.display_cameras = display_cameras

-        # connect robot
+        # Connect to the robot if not already connected.
         if not self.robot.is_connected:
             self.robot.connect()

         self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")

-        # Episode tracking
+        # Episode tracking.
         self.current_step = 0
         self.episode_data = None
@@ -64,6 +70,7 @@ class HILSerlRobotEnv(gym.Env):
         self.use_delta_action_space = use_delta_action_space
         self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")

+        # Retrieve the size of the joint position interval bound.
         self.relative_bounds_size = (
             self.robot.config.joint_position_relative_bounds["max"]
             - self.robot.config.joint_position_relative_bounds["min"]
@@ -73,20 +80,26 @@ class HILSerlRobotEnv(gym.Env):
             self.robot.config.max_relative_target = self.delta_relative_bounds_size.float()

-        # Dynamically determine observation and action spaces
+        # Dynamically configure the observation and action spaces.
         self._setup_spaces()

     def _setup_spaces(self):
         """
-        Dynamically determine observation and action spaces based on robot capabilities.
+        Dynamically configure the observation and action spaces based on the robot's capabilities.

-        This method should be customized based on the specific robot's observation
-        and action representations.
+        Observation Space:
+            - For keys with "image": A Box space with pixel values ranging from 0 to 255.
+            - For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range.
+
+        Action Space:
+            - The action space is defined as a Tuple where:
+                The first element is a Box space representing joint position commands. It is defined as relative (delta)
+                or absolute, based on the configuration.
+                The second element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation).
         """
-        # Example space setup - you'll need to adapt this to your specific robot
         example_obs = self.robot.capture_observation()

-        # Observation space (assuming image-based observations)
+        # Define observation spaces for images and other states.
         image_keys = [key for key in example_obs if "image" in key]
         state_keys = [key for key in example_obs if "image" not in key]
         observation_spaces = {
@@ -102,7 +115,7 @@ class HILSerlRobotEnv(gym.Env):

         self.observation_space = gym.spaces.Dict(observation_spaces)

-        # Action space (assuming joint positions)
+        # Define the action space for joint positions along with setting an intervention flag.
         action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
         if self.use_delta_action_space:
             action_space_robot = gym.spaces.Box(
@@ -128,18 +141,24 @@ class HILSerlRobotEnv(gym.Env):
     def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
         """
-        Reset the environment to initial state.
+        Reset the environment to its initial state.
+
+        This method resets the step counter and clears any episodic data.
+
+        Args:
+            seed (Optional[int]): A seed for random number generation to ensure reproducibility.
+            options (Optional[dict]): Additional options to influence the reset behavior.

         Returns:
-            observation (dict): Initial observation
-            info (dict): Additional information
+            A tuple containing:
+                - observation (dict): The initial sensor observation.
+                - info (dict): A dictionary with supplementary information, including the key "initial_position".
         """
         super().reset(seed=seed, options=options)

-        # Capture initial observation
+        # Capture the initial observation.
         observation = self.robot.capture_observation()

-        # Reset tracking variables
+        # Reset episode tracking variables.
         self.current_step = 0
         self.episode_data = None
@@ -149,28 +168,38 @@
         self, action: Tuple[np.ndarray, bool]
     ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
         """
-        Take a step in the environment.
+        Execute a single step within the environment using the specified action.
+
+        The provided action is a tuple comprised of:
+            A policy action (joint position commands) that may be either in absolute values or as a delta.
+            A boolean flag indicating whether teleoperation (human intervention) should be used for this step.
+
+        Behavior:
+            - When the intervention flag is False, the environment processes and sends the policy action to the robot.
+            - When True, a teleoperation step is executed. If using a delta action space, an absolute teleop action is converted
+              to relative change based on the current joint positions.

         Args:
-            action tuple(np.ndarray, bool):
-                Policy action to be executed on the robot and boolean to determine
-                whether to choose policy action or expert action.
+            action (tuple): A tuple with two elements:
+                - policy_action (np.ndarray or torch.Tensor): The commanded joint positions.
+                - intervention_bool (bool): True if the human operator intervenes by providing a teleoperation input.

         Returns:
-            observation (dict): Next observation
-            reward (float): Reward for this step
-            terminated (bool): Whether the episode has terminated
-            truncated (bool): Whether the episode was truncated
-            info (dict): Additional information
+            tuple: A tuple containing:
+                - observation (dict): The new sensor observation after taking the step.
+                - reward (float): The step reward (default is 0.0 within this wrapper).
+                - terminated (bool): True if the episode has reached a terminal state.
+                - truncated (bool): True if the episode was truncated (e.g., time constraints).
+                - info (dict): Additional debugging information including:
+                    "action_intervention": The teleop action if intervention was used.
+                    "is_intervention": Flag indicating whether teleoperation was employed.
         """
-        # The actions recieved are the in form of a tuple containing the policy action and an intervention bool
-        # The boolean inidicated whether we will use the expert's actions (through teleoperation) or the policy actions
         policy_action, intervention_bool = action
         teleop_action = None
         self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
         if isinstance(policy_action, torch.Tensor):
             policy_action = policy_action.cpu().numpy()
-            olicy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
+            policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)

         if not intervention_bool:
             if self.use_delta_action_space:
                 target_joint_positions = self.current_joint_positions + self.delta * policy_action
@@ -180,26 +209,26 @@
             observation = self.robot.capture_observation()
         else:
             observation, teleop_action = self.robot.teleop_step(record_data=True)
-            teleop_action = teleop_action["action"]  # teleop step returns torch tensors but in a dict
+            teleop_action = teleop_action["action"]  # Convert tensor to appropriate format

-            # teleop actions are returned in absolute joint space
-            # If we are using a relative joint position action space,
-            # there will be a mismatch between the spaces of the policy and teleop actions
-            # Solution is to transform the teleop actions into relative space.
-            # teleop relative action is:
+            # When applying the delta action space, convert teleop absolute values to relative differences.
             if self.use_delta_action_space:
                 teleop_action = teleop_action - self.current_joint_positions
                 if torch.any(teleop_action < -self.delta_relative_bounds_size * self.delta) and torch.any(
                     teleop_action > self.delta_relative_bounds_size
                 ):
                     print(
-                        f"relative teleop delta exceeded bounds {self.delta_relative_bounds_size}, teleop_action {teleop_action}\n"
+                        f"Relative teleop delta exceeded bounds {self.delta_relative_bounds_size}, teleop_action {teleop_action}\n"
                         f"lower bounds condition {teleop_action < -self.delta_relative_bounds_size}\n"
                         f"upper bounds condition {teleop_action > self.delta_relative_bounds_size}"
                     )
                     teleop_action = torch.clamp(
                         teleop_action, -self.delta_relative_bounds_size, self.delta_relative_bounds_size
                     )
+
+            # NOTE: To mimic the shape of a neural network output, we add a batch dimension to the teleop action.
+            if teleop_action.dim() == 1:
+                teleop_action = teleop_action.unsqueeze(0)

         self.current_step += 1
@@ -217,7 +246,7 @@
     def render(self):
         """
-        Render the environment (in this case, display camera feeds).
+        Render the current state of the environment by displaying the robot's camera feeds.
         """
         import cv2
@@ -231,7 +260,10 @@
     def close(self):
         """
-        Close the environment and disconnect the robot.
+        Close the environment and clean up resources by disconnecting the robot.
+
+        If the robot is currently connected, this method properly terminates the connection to ensure that all
+        associated resources are released.
         """
         if self.robot.is_connected:
             self.robot.disconnect()
@@ -250,48 +282,19 @@ class ActionRepeatWrapper(gym.Wrapper):
         return obs, reward, done, truncated, info


-class RelativeJointPositionActionWrapper(gym.Wrapper):
-    def __init__(
-        self,
-        env: HILSerlRobotEnv,
-        # output_normalization_params_action: dict[str, list[float]],
-        delta: float = 0.1,
-    ):
-        super().__init__(env)
-        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
-        self.delta = delta
-        if delta > 1:
-            raise ValueError("Delta should be less than 1")
-
-    def step(self, action):
-        action_joint = action
-        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
-        if isinstance(self.env.action_space, gym.spaces.Tuple):
-            action_joint = action[0]
-        joint_positions = self.joint_positions + (self.delta * action_joint)
-        # clip the joint positions to the joint limits with the action space
-        joint_positions = np.clip(joint_positions, self.action_space.low, self.action_space.high)
-
-        if isinstance(self.env.action_space, gym.spaces.Tuple):
-            return self.env.step((joint_positions, action[1]))
-
-        obs, reward, terminated, truncated, info = self.env.step(joint_positions)
-        if info["is_intervention"]:
-            # teleop actions are returned in absolute joint space
-            # If we are using a relative joint position action space,
-            # there will be a mismatch between the spaces of the policy and teleop actions
-            # Solution is to transform the teleop actions into relative space.
-            self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
-            teleop_action = info["action_intervention"]  # teleop actions are in absolute joint space
-            relative_teleop_action = (teleop_action - self.joint_positions) / self.delta
-            info["action_intervention"] = relative_teleop_action
-
-        return self.env.step(joint_positions)
-
-
 class RewardWrapper(gym.Wrapper):
-    def __init__(self, env, reward_classifier: Optional[None], device: torch.device = "cuda"):
+    def __init__(self, env, reward_classifier, device: torch.device = "cuda"):
+        """
+        Wrapper to add reward prediction to the environment, it use a trained classifer.
+
+        Args:
+            env: The environment to wrap
+            reward_classifier: The reward classifier model
+            device: The device to run the model on
+        """
         self.env = env
+
+        # NOTE: We got 15% speedup by compiling the model
         self.reward_classifier = torch.compile(reward_classifier)
         self.device = device
@@ -305,9 +308,7 @@ class RewardWrapper(gym.Wrapper):
         reward = (
             self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0
         )
-        # print(f"fps for reward classifier {1/(time.perf_counter() - start_time)}")
-        reward = reward.item()
-        # print(f"Reward from reward classifier {reward}")
+        info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)

         return observation, reward, terminated, truncated, info

     def reset(self, seed=None, options=None):
@@ -323,17 +324,23 @@ class TimeLimitWrapper(gym.Wrapper):
         self.last_timestamp = 0.0
         self.episode_time_in_s = 0.0

+        self.max_episode_steps = int(self.control_time_s * self.fps)
+
+        self.current_step = 0
+
     def step(self, action):
         obs, reward, terminated, truncated, info = self.env.step(action)
         time_since_last_step = time.perf_counter() - self.last_timestamp
+        # logging.warning(f"Current timestep is lower than the expected fps {self.fps}")
         self.episode_time_in_s += time_since_last_step
         self.last_timestamp = time.perf_counter()
+        self.current_step += 1

         # check if last timestep took more time than the expected fps
-        if 1.0 / time_since_last_step < self.fps:
-            logging.warning(f"Current timestep is lower than the expected fps {self.fps}")
+        # if 1.0 / time_since_last_step < self.fps:
+        #     logging.warning(f"Current timestep exceeded expected fps {self.fps}")

         if self.episode_time_in_s > self.control_time_s:
+            # if self.current_step >= self.max_episode_steps:
             # Terminated = True
             terminated = True
         return obs, reward, terminated, truncated, info
@@ -341,11 +348,13 @@ class TimeLimitWrapper(gym.Wrapper):
     def reset(self, seed=None, options=None):
         self.episode_time_in_s = 0.0
         self.last_timestamp = time.perf_counter()
+        self.current_step = 0
         return self.env.reset(seed=seed, options=options)


 class ImageCropResizeWrapper(gym.Wrapper):
     def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
+        super().__init__(env)
         self.env = env
         self.crop_params_dict = crop_params_dict
         print(f"obs_keys , {self.env.observation_space}")
@@ -372,10 +381,21 @@ class ImageCropResizeWrapper(gym.Wrapper):
             obs[k] = F.resize(obs[k], self.resize_size)
             obs[k] = obs[k].to(device)
             # print(f"observation with key {k} with size {obs[k].size()}")
-            cv2.imshow(k, cv2.cvtColor(obs[k].cpu().squeeze(0).permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))
-            cv2.waitKey(1)
+            # cv2.imshow(k, cv2.cvtColor(obs[k].cpu().squeeze(0).permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))
+            # cv2.waitKey(1)
         return obs, reward, terminated, truncated, info

+    def reset(self, seed=None, options=None):
+        obs, info = self.env.reset(seed=seed, options=options)
+        for k in self.crop_params_dict:
+            device = obs[k].device
+            if device == torch.device("mps:0"):
+                obs[k] = obs[k].cpu()
+            obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
+            obs[k] = F.resize(obs[k], self.resize_size)
+            obs[k] = obs[k].to(device)
+        return obs, info
+

 class ConvertToLeRobotObservation(gym.ObservationWrapper):
     def __init__(self, env, device):
@@ -515,42 +535,64 @@ class ResetWrapper(gym.Wrapper):
         return super().reset(seed=seed, options=options)


+class BatchCompitableWrapper(gym.ObservationWrapper):
+    def __init__(self, env):
+        super().__init__(env)
+
+    def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        for key in observation:
+            if "image" in key and observation[key].dim() == 3:
+                observation[key] = observation[key].unsqueeze(0)
+            if "state" in key and observation[key].dim() == 1:
+                observation[key] = observation[key].unsqueeze(0)
+        return observation
+
+
 def make_robot_env(
     robot,
     reward_classifier,
-    crop_params_dict=None,
-    fps=30,
-    control_time_s=20,
-    reset_follower_pos=True,
-    display_cameras=False,
-    device="cuda:0",
-    resize_size=None,
-    reset_time_s=10,
-    delta_action=0.1,
-    nb_repeats=1,
-    use_relative_joint_positions=False,
-):
+    cfg,
+    n_envs: int = 1,
+) -> gym.vector.VectorEnv:
     """
-    Factory function to create the robot environment.
-    Mimics gym.make() for consistent environment creation.
+    Factory function to create a vectorized robot environment.
+
+    Args:
+        robot: Robot instance to control
+        reward_classifier: Classifier model for computing rewards
+        cfg: Configuration object containing environment parameters
+        n_envs: Number of environments to create in parallel. Defaults to 1.
+
+    Returns:
+        A vectorized gym environment with all the necessary wrappers applied.
     """
+    # Create base environment
     env = HILSerlRobotEnv(
-        robot,
-        display_cameras=display_cameras,
-        delta=delta_action,
-        use_delta_action_space=use_relative_joint_positions,
+        robot=robot,
+        display_cameras=cfg.wrapper.display_cameras,
+        delta=cfg.wrapper.delta_action,
+        use_delta_action_space=cfg.wrapper.use_relative_joint_positions,
     )
-    env = ConvertToLeRobotObservation(env, device)
-    if crop_params_dict is not None:
-        env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
-    env = RewardWrapper(env, reward_classifier, device=device)
-    env = TimeLimitWrapper(env, control_time_s, fps)
-    # env = ActionRepeatWrapper(env, nb_repeat=nb_repeats)
-    env = KeyboardInterfaceWrapper(env)
-    env = ResetWrapper(env, reset_fn=None, reset_time_s=reset_time_s)
+
+    # Add observation and image processing
+    env = ConvertToLeRobotObservation(env=env, device=cfg.device)
+    if cfg.wrapper.crop_params_dict is not None:
+        env = ImageCropResizeWrapper(
+            env=env, crop_params_dict=cfg.wrapper.crop_params_dict, resize_size=cfg.wrapper.resize_size
+        )
+
+    # Add reward computation and control wrappers
+    env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
+    env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
+    env = KeyboardInterfaceWrapper(env=env)
+    env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.wrapper.reset_time_s)
+    env = BatchCompitableWrapper(env=env)
+
     return env

-    # batched version of the env that returns an observation of shape (b, c)

 def get_classifier(pretrained_path, config_path, device="mps"):
     if pretrained_path is None or config_path is None:
@@ -616,6 +658,8 @@ if __name__ == "__main__":
         default=None,
         help="Path to a yaml config file that is necessary to build the reward classifier model.",
     )
+    parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file")
+    parser.add_argument("--env-overrides", type=str, default=None, help="Overrides for the env yaml file")
     parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
     parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes")
     args = parser.parse_args()
@@ -626,72 +670,38 @@ if __name__ == "__main__":
     reward_classifier = get_classifier(
         args.reward_classifier_pretrained_path, args.reward_classifier_config_file
     )
-    crop_parameters = {
-        "observation.images.laptop": (58, 89, 357, 455),
-        "observation.images.phone": (3, 4, 471, 633),
-    }
-
     user_relative_joint_positions = True

+    cfg = init_hydra_config(args.env_path, args.env_overrides)
     env = make_robot_env(
         robot,
         reward_classifier,
-        crop_parameters,
-        args.fps,
-        args.control_time_s,
-        args.reset_follower_pos,
-        args.display_cameras,
-        device="mps",
-        resize_size=None,
-        reset_time_s=10,
-        delta_action=0.1,
-        nb_repeats=1,
-        use_relative_joint_positions=user_relative_joint_positions,
+        cfg.wrapper,
     )

     env.reset()
-    init_pos = env.unwrapped.initial_follower_position
-    right_goal = init_pos.copy()
-    right_goal[0] += 50
-    left_goal = init_pos.copy()
-    left_goal[0] -= 50
-    pitch_angle = np.linspace(left_goal[0], right_goal[0], 1000)
-
-    delta_angle = np.concatenate((-np.ones(50), np.ones(50))) * 100
+
+    # Retrieve the robot's action space for joint commands.
+    action_space_robot = env.action_space.spaces[0]
+
+    # Initialize the smoothed action as a random sample.
+    smoothed_action = action_space_robot.sample()
+
+    # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
+    # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
+    alpha = 0.4

     while True:
-        action = np.zeros(len(init_pos))
-        for i in range(len(delta_angle)):
-            start_loop_s = time.perf_counter()
-            action[0] = delta_angle[i]
-            obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))
-            if terminated or truncated:
-                env.reset()
-            dt_s = time.perf_counter() - start_loop_s
-            busy_wait(1 / args.fps - dt_s)
-
-        # action = np.zeros(len(init_pos)) if user_relative_joint_positions else init_pos
-        # for i in range(len(pitch_angle)):
-        #     if user_relative_joint_positions:
-        #         action[0] = delta_angle[i]
-        #     else:
-        #         action[0] = pitch_angle[i]
-        #     obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))
-        #     if terminated or truncated:
-        #         logging.info("Max control time reached, reset environment.")
-        #         env.reset()
-
-        # for i in reversed(range(len(pitch_angle))):
-        #     if user_relative_joint_positions:
-        #         action[0] = delta_angle[i]
-        #     else:
-        #         action[0] = pitch_angle[i]
-        #     obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))
-        #     if terminated or truncated:
-        #         logging.info("Max control time reached, reset environment.")
-        #         env.reset()
+        start_loop_s = time.perf_counter()
+
+        # Sample a new random action from the robot's action space.
+        new_random_action = action_space_robot.sample()
+        # Update the smoothed action using an exponential moving average.
+        smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
+
+        # Execute the step: wrap the NumPy action in a torch tensor.
+        obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
+        if terminated or truncated:
+            env.reset()
+
+        dt_s = time.perf_counter() - start_loop_s
+        busy_wait(1 / args.fps - dt_s)
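A standalone numpy illustration of the exponential moving average used in the loop above to smooth the randomly sampled actions:

import numpy as np

rng = np.random.default_rng(0)
alpha = 0.4  # same coefficient as in the test loop above
smoothed = rng.uniform(-1.0, 1.0, size=6)  # stand-in for action_space_robot.sample()

for _ in range(5):
    new_sample = rng.uniform(-1.0, 1.0, size=6)
    # Small alpha changes the command slowly; large alpha tracks the new noise more closely.
    smoothed = alpha * new_sample + (1 - alpha) * smoothed
    print(np.round(smoothed, 3))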

View File

@@ -36,6 +36,8 @@ from termcolor import colored
 from torch import nn
 from torch.optim.optimizer import Optimizer

+from lerobot.common.datasets.factory import make_dataset
+
 # TODO: Remove the import of maniskill
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.logger import Logger, log_output_dir
@@ -52,6 +54,7 @@ from lerobot.common.utils.utils import (
 )
 from lerobot.scripts.server.buffer import (
     ReplayBuffer,
+    concatenate_batch_transitions,
     move_state_dict_to_device,
     move_transition_to_device,
 )
@@ -259,8 +262,15 @@ def learner_push_parameters(
     while True:
         with policy_lock:
             params_dict = policy.actor.state_dict()
-            if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
-                params_dict = {k: v for k, v in params_dict if not k.startswith("encoder.")}
+            if policy.config.vision_encoder_name is not None:
+                if policy.config.freeze_vision_encoder:
+                    params_dict: dict[str, torch.Tensor] = {
+                        k: v for k, v in params_dict.items() if not k.startswith("encoder.")
+                    }
+                else:
+                    raise NotImplementedError(
+                        "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
+                    )

         params_dict = move_state_dict_to_device(params_dict, device="cpu")

         # Serialize
@@ -322,6 +332,7 @@ def add_actor_information_and_train(
     # in the future. The reason why we did that is the GIL in Python. It's super slow the performance
     # are divided by 200. So we need to have a single thread that does all the work.
     time.time()
+    logging.info("Starting learner thread")
     interaction_message, transition = None, None
     optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
     interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
@@ -340,16 +351,21 @@
             # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
             interaction_message["Interaction step"] += interaction_step_shift
             logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
+            logging.info(f"Interaction message: {interaction_message}")

         if len(replay_buffer) < cfg.training.online_step_before_learning:
             continue

+        # logging.info(f"Size of replay buffer: {len(replay_buffer)}")
+        # logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
+
         time_for_one_optimization_step = time.time()
         for _ in range(cfg.policy.utd_ratio - 1):
             batch = replay_buffer.sample(batch_size)

-            # if cfg.offline_dataset_repo_id is not None:
-            #     batch_offline = offline_replay_buffer.sample(batch_size)
-            #     batch = concatenate_batch_transitions(batch, batch_offline)
+            if cfg.dataset_repo_id is not None:
+                batch_offline = offline_replay_buffer.sample(batch_size)
+                batch = concatenate_batch_transitions(batch, batch_offline)

             actions = batch["action"]
             rewards = batch["reward"]
@@ -371,11 +387,11 @@
         batch = replay_buffer.sample(batch_size)

-        # if cfg.offline_dataset_repo_id is not None:
-        #     batch_offline = offline_replay_buffer.sample(batch_size)
-        #     batch = concatenate_batch_transitions(
-        #         left_batch_transitions=batch, right_batch_transition=batch_offline
-        #     )
+        if cfg.dataset_repo_id is not None:
+            batch_offline = offline_replay_buffer.sample(batch_size)
+            batch = concatenate_batch_transitions(
+                left_batch_transitions=batch, right_batch_transition=batch_offline
+            )

         actions = batch["action"]
         rewards = batch["reward"]
@@ -423,7 +439,7 @@
         time_for_one_optimization_step = time.time() - time_for_one_optimization_step
         frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)

-        logging.debug(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
+        logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")

         logger.log_dict(
             {"Optimization frequency loop [Hz]": frequency_for_one_optimization_step},
@@ -560,14 +576,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     batch_size = cfg.training.batch_size
     offline_replay_buffer = None

-    # if cfg.dataset_repo_id is not None:
-    #     logging.info("make_dataset offline buffer")
-    #     offline_dataset = make_dataset(cfg)
-    #     logging.info("Convertion to a offline replay buffer")
-    #     offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
-    #         offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
-    #     )
-    #     batch_size: int = batch_size // 2  # We will sample from both replay buffer
+    if cfg.dataset_repo_id is not None:
+        logging.info("make_dataset offline buffer")
+        offline_dataset = make_dataset(cfg)
+        logging.info("Convertion to a offline replay buffer")
+        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
+            offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
+        )
+        batch_size: int = batch_size // 2  # We will sample from both replay buffer

     start_learner_threads(
         cfg,
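A simplified sketch of the 50/50 offline/online sampling the learner performs when cfg.dataset_repo_id is set; the buffer objects and flat tensor batches here are hypothetical stand-ins (the real code merges BatchTransition dicts with concatenate_batch_transitions):

import torch

def sample_mixed_batch(online_buffer, offline_buffer, batch_size: int) -> dict[str, torch.Tensor]:
    # Half of the batch comes from online interaction, half from the offline dataset,
    # mirroring how batch_size is halved when an offline dataset is configured.
    half = batch_size // 2
    online_batch = online_buffer.sample(half)
    offline_batch = offline_buffer.sample(half)
    # Concatenate each field along the batch dimension (assumes flat tensor batches).
    return {key: torch.cat([online_batch[key], offline_batch[key]], dim=0) for key in online_batch}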

View File

@@ -279,8 +279,10 @@ def train(cfg: DictConfig) -> None:
     logging.info(f"Dataset size: {len(dataset)}")

     train_size = int(cfg.train_split_proportion * len(dataset))
-    val_size = len(dataset) - train_size
+    # val_size = len(dataset) - train_size

-    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+    # train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+    train_dataset = dataset[:train_size]
+    val_dataset = dataset[train_size:]

     sampler = create_balanced_sampler(train_dataset, cfg)
     train_loader = DataLoader(