Several fixes to move the actor_server and learner_server code from the maniskill environment to the real robot environment.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi 2025-02-10 16:03:39 +01:00 committed by Adil Zouitine
parent 4891270886
commit 174087eed2
10 changed files with 597 additions and 318 deletions

View File

@ -39,6 +39,12 @@ class SACConfig:
"observation.environment_state": "min_max",
}
)
input_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"observation.image": {"mean": [[0.485, 0.456, 0.406]], "std": [[0.229, 0.224, 0.225]]},
"observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
output_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {

View File

@ -51,18 +51,20 @@ class SACPolicy(
if config is None:
config = SACConfig()
self.config = config
if config.input_normalization_modes is not None:
input_normalization_params = _convert_normalization_params_to_tensor(
config.input_normalization_params
)
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
config.input_shapes, config.input_normalization_modes, input_normalization_params
)
else:
self.normalize_inputs = nn.Identity()
output_normalization_params = {}
for outer_key, inner_dict in config.output_normalization_params.items():
output_normalization_params[outer_key] = {}
for key, value in inner_dict.items():
output_normalization_params[outer_key][key] = torch.tensor(value)
output_normalization_params = _convert_normalization_params_to_tensor(
config.output_normalization_params
)
# HACK: This is hacky and should be removed
dataset_stats = dataset_stats or output_normalization_params
@ -75,7 +77,7 @@ class SACPolicy(
# NOTE: For images the encoder should be shared between the actor and critic
if config.shared_encoder:
encoder_critic = SACObservationEncoder(config)
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
encoder_actor: SACObservationEncoder = encoder_critic
else:
encoder_critic = SACObservationEncoder(config)
@ -92,6 +94,7 @@ class SACPolicy(
for _ in range(config.num_critics)
]
),
output_normalization=self.normalize_targets,
)
self.critic_target = CriticEnsemble(
@ -105,6 +108,7 @@ class SACPolicy(
for _ in range(config.num_critics)
]
),
output_normalization=self.normalize_targets,
)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
@ -122,7 +126,7 @@ class SACPolicy(
# TODO (azouitine): Handle the case where the temperature parameter is fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
# it triggers "can't optimize a non-leaf Tensor"
self.log_alpha = torch.zeros(1, requires_grad=True, device=torch.device("cuda:0"))
self.log_alpha = torch.tensor([0.0], requires_grad=True, device=torch.device("mps"))
self.temperature = self.log_alpha.exp().item()
def reset(self):
@ -313,12 +317,14 @@ class CriticEnsemble(nn.Module):
self,
encoder: Optional[nn.Module],
network_list: nn.ModuleList,
output_normalization: nn.Module,
init_final: Optional[float] = None,
):
super().__init__()
self.encoder = encoder
self.network_list = network_list
self.init_final = init_final
self.output_normalization = output_normalization
self.parameters_to_optimize = []
# Handle the case where a part of the encoder is frozen
@ -358,6 +364,10 @@ class CriticEnsemble(nn.Module):
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(device) for k, v in observations.items()}
# NOTE: We normalize the actions; it helps with sample efficiency
actions: dict[str, torch.tensor] = {"action": actions}
# NOTE: The normalization layer takes a dict as input and outputs a dict, hence the wrapping and unwrapping here
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = observations if self.encoder is None else self.encoder(observations)
@ -474,17 +484,18 @@ class Policy(nn.Module):
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig):
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
"""
Creates encoders for pixel and/or state modalities.
"""
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self.has_pretrained_vision_encoder = False
self.parameters_to_optimize = []
self.aggregation_size: int = 0
if "observation.image" in config.input_shapes:
if any("observation.image" in key for key in config.input_shapes):
self.camera_number = config.camera_number
if self.config.vision_encoder_name is not None:
@ -534,8 +545,9 @@ class SACObservationEncoder(nn.Module):
over all features.
"""
feat = []
obs_dict = self.input_normalization(obs_dict)
# Concatenate all images along the channel dimension.
image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
image_keys = [k for k in obs_dict if k.startswith("observation.image")]
for image_key in image_keys:
enc_feat = self.image_enc_layers(obs_dict[image_key])
@ -681,6 +693,18 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
converted_params = {}
for outer_key, inner_dict in normalization_params.items():
converted_params[outer_key] = {}
for key, value in inner_dict.items():
converted_params[outer_key][key] = torch.tensor(value)
if "image" in outer_key:
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
return converted_params
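A quick illustration of the helper above (assuming it is in scope): image statistics are reshaped to (3, 1, 1) so they broadcast over the H x W dimensions, while state statistics stay flat.
import torch

converted = _convert_normalization_params_to_tensor({
    "observation.image": {"mean": [[0.485, 0.456, 0.406]]},
    "observation.state": {"min": [-1, -1, -1, -1]},
})
assert converted["observation.image"]["mean"].shape == torch.Size([3, 1, 1])
assert converted["observation.state"]["min"].shape == torch.Size([4])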
if __name__ == "__main__":
# Test the SACObservationEncoder
import time

View File

@ -18,6 +18,7 @@ import os
import os.path as osp
import platform
import subprocess
import time
from copy import copy
from datetime import datetime, timezone
from pathlib import Path
@ -228,3 +229,28 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
except TypeError:
# If a TypeError is raised, the string is not a valid dtype
return False
class TimerManager:
def __init__(self, elapsed_time_list: list[float] | None = None, label="Elapsed time", log=True):
self.label = label
self.elapsed_time_list = elapsed_time_list
self.log = log
self.elapsed = 0.0
def __enter__(self):
self.start = time.perf_counter()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.elapsed: float = time.perf_counter() - self.start
if self.elapsed_time_list is not None:
self.elapsed_time_list.append(self.elapsed)
if self.log:
print(f"{self.label}: {self.elapsed:.6f} seconds")
@property
def elapsed_seconds(self):
return self.elapsed
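A usage sketch for the TimerManager context manager added above, mirroring how the actor loop times policy inference (the dummy workload stands in for policy.select_action):
policy_times: list[float] = []
with TimerManager(elapsed_time_list=policy_times, label="Policy inference time", log=False) as timer:
    _ = sum(i * i for i in range(10_000))  # stand-in for policy.select_action(batch)
fps = 1.0 / (policy_times[-1] + 1e-9)      # same frequency computation used by the actor
print(timer.elapsed_seconds, fps)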

View File

@ -0,0 +1,131 @@
defaults:
- _self_
- env: pusht
- policy: diffusion
- robot: so100
hydra:
run:
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
dir: outputs/train/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
job:
name: default
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
# `hydra.run.dir` is the directory of an existing run with at least one checkpoint in it.
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
# regardless of what's provided with the training command at the time of resumption.
resume: false
device: cuda # cpu
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: false
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: ???
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
# keys common between the datasets are kept. Each dataset gets an additional transform that inserts the
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
dataset_repo_id: lerobot/pusht
video_backend: pyav
training:
offline_steps: ???
# Number of workers for the offline training dataloader.
num_workers: 4
batch_size: ???
eval_freq: ???
log_freq: 200
save_checkpoint: true
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
save_freq: ???
# Online training. Note that the online training loop adopts most of the options above apart from the
# dataloader options, unless otherwise specified.
# The online training loop looks something like:
#
# for i in range(online_steps):
# do_online_rollout_and_update_online_buffer()
# for j in range(online_steps_between_rollouts):
# batch = next(dataloader_with_offline_and_online_data)
# loss = policy(batch)
# loss.backward()
# optimizer.step()
#
online_steps: ???
# How many episodes to collect at once when we reach the online rollout part of the training loop.
online_rollout_n_episodes: 1
# The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
# the policy. Ideally you should set this to be an even divisor of online_rollout_n_episodes.
online_rollout_batch_size: 1
# How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
online_steps_between_rollouts: null
# The proportion of online samples (vs offline samples) to include in the online training batches.
online_sampling_ratio: 0.5
# First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
online_env_seed: null
# Sets the maximum number of frames that are stored in the online buffer for online training. The buffer is
# FIFO.
online_buffer_capacity: null
# The minimum number of frames to have in the online buffer before commencing online training.
# If online_buffer_seed_size > online_rollout_n_episodes, the rollout will be run multiple times until the
# seed size condition is satisfied.
online_buffer_seed_size: 0
# Whether to run the online rollouts asynchronously. This means we can run the online training steps in
# parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training
# + eval + environment rendering simultaneously.
do_online_rollout_async: false
image_transforms:
# These transforms are all using standard torchvision.transforms.v2
# You can find out how these transformations affect images here:
# https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
# We use a custom RandomSubsetApply container to sample them.
# For each transform, the following parameters are available:
# weight: This represents the multinomial probability (with no replacement)
# used for sampling the transform. If the sum of the weights is not 1,
# they will be normalized.
# min_max: Lower & upper bound respectively used for sampling the transform's parameter
# (following uniform distribution) when it's applied.
# Set this flag to `true` to enable transforms during training
enable: false
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
# It's an integer in the interval [1, number of available transforms].
max_num_transforms: 3
# By default, transforms are applied in Torchvision's suggested order (shown below).
# Set this to True to apply them in a random order.
random_order: false
brightness:
weight: 1
min_max: [0.8, 1.2]
contrast:
weight: 1
min_max: [0.8, 1.2]
saturation:
weight: 1
min_max: [0.5, 1.5]
hue:
weight: 1
min_max: [-0.05, 0.05]
sharpness:
weight: 1
min_max: [0.8, 1.2]
eval:
n_episodes: 1
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
batch_size: 1
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
use_async_envs: false
wandb:
enable: false
# Set to true to disable saving an artifact despite save_checkpoint == True
disable_artifact: false
project: lerobot
notes: ""
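A hedged sketch of how this default config is consumed: a Hydra entry point (mirroring the actor_cli decorator shown later in this commit) composes "default" with the env/policy/robot groups listed in the defaults block, and CLI overrides such as env=so100_real, seed=1000, or device=mps fill the ??? and default values. The relative config_path is an assumption that the script lives under lerobot/scripts/server like the actor server does.
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def main(cfg: DictConfig):
    # After composition, keys such as cfg.env.name, cfg.training.online_steps and cfg.fps are available.
    print(OmegaConf.to_yaml(cfg, resolve=False))

if __name__ == "__main__":
    main()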

lerobot/configs/env/so100_real.yaml (new file)
View File

@ -0,0 +1,27 @@
# @package _global_
fps: 30
env:
name: real_world
task: null
state_dim: 6
action_dim: 6
fps: ${fps}
device: mps
wrapper:
crop_params_dict:
observation.images.laptop: [58, 89, 357, 455]
observation.images.phone: [3, 4, 471, 633]
resize_size: [128, 128]
control_time_s: 20
reset_follower_pos: true
use_relative_joint_positions: true
reset_time_s: 10
display_cameras: false
delta_action: 0.1
reward_classifier:
pretrained_path: outputs/classifier/checkpoints/best/pretrained_model
config_path: lerobot/configs/policy/hilserl_classifier.yaml
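The crop_params_dict values above are consumed by ImageCropResizeWrapper later in this commit via torchvision's functional crop. A small illustration, assuming the (top, left, height, width) ordering that F.crop expects and a dummy 480x640 camera frame:
import torch
import torchvision.transforms.functional as F

frame = torch.rand(3, 480, 640)              # dummy "observation.images.laptop" frame (C, H, W)
cropped = F.crop(frame, 58, 89, 357, 455)    # values from crop_params_dict above
resized = F.resize(cropped, [128, 128])      # resize_size above
print(resized.shape)                         # torch.Size([3, 128, 128])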

View File

@ -31,16 +31,21 @@ from omegaconf import DictConfig
from torch import nn
# TODO: Remove the import of maniskill
from lerobot.common.envs.factory import make_maniskill_env
from lerobot.common.envs.utils import preprocess_maniskill_observation
# from lerobot.common.envs.factory import make_maniskill_env
# from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.control_utils import busy_wait
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import (
TimerManager,
get_safe_torch_device,
set_global_seed,
)
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
logging.basicConfig(level=logging.INFO)
@ -152,7 +157,15 @@ def serve_actor_service(port=50052):
server.wait_for_termination()
def act_with_policy(cfg: DictConfig):
def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
if not parameters_queue.empty():
logging.debug("[ACTOR] Load new parameters from Learner.")
state_dict = parameters_queue.get()
state_dict = move_state_dict_to_device(state_dict, device=device)
policy.load_state_dict(state_dict)
def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
"""
Executes policy interaction within the environment.
@ -165,9 +178,7 @@ def act_with_policy(cfg: DictConfig):
logging.info("make_env online")
# online_env = make_env(cfg, n_envs=1)
# TODO: Remove the import of maniskill and unifiy with make env
online_env = make_maniskill_env(cfg, n_envs=1)
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg.env)
set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
@ -177,6 +188,16 @@ def act_with_policy(cfg: DictConfig):
logging.info("make_policy")
# HACK: This is an ugly hack to pass the normalization parameters to the policy.
# Because the action space is dynamic, we override the output normalization parameters here.
# It's ugly, we know ... and we will fix it.
min_action_space: list = online_env.action_space.spaces[0].low.tolist()
max_action_space: list = online_env.action_space.spaces[0].high.tolist()
output_normalization_params: dict[dict[str, list]] = {
"action": {"min": min_action_space, "max": max_action_space}
}
cfg.policy.output_normalization_params = output_normalization_params
### Instantiate the policy in both the actor and learner processes.
### To avoid sending a SACPolicy object through the port, we create a policy instance
### on both sides; the learner sends the updated parameters every n steps to update the actor's parameters.
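A minimal sketch of the override above: the robot env exposes a Tuple action space whose first element is a Box of joint commands, and its bounds become the policy's output normalization parameters. The 6-DoF shape and unit bounds are illustrative assumptions.
import gymnasium as gym
import numpy as np

action_space = gym.spaces.Tuple(
    (
        gym.spaces.Box(low=-1.0, high=1.0, shape=(6,), dtype=np.float32),  # joint commands (assumed 6-DoF)
        gym.spaces.Discrete(2),  # intervention flag
    )
)
box = action_space.spaces[0]
output_normalization_params = {"action": {"min": box.low.tolist(), "max": box.high.tolist()}}
print(output_normalization_params["action"]["min"])  # [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0]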
@ -187,92 +208,41 @@ def act_with_policy(cfg: DictConfig):
# Hack: But if we do online training, we do not need dataset_stats
dataset_stats=None,
# TODO: Handle resume training
device=device,
)
# pretrained_policy_name_or_path=None,
# device=device,
# )
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
# HACK for maniskill
obs, info = online_env.reset()
# obs = preprocess_observation(obs)
obs = preprocess_maniskill_observation(obs)
obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
# NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0
list_transition_to_send_to_learner = []
list_policy_fps = []
list_policy_time = []
for interaction_step in range(cfg.training.online_steps):
if interaction_step >= cfg.training.online_step_before_learning:
start = time.perf_counter()
action = policy.select_action(batch=obs)
list_policy_fps.append(1.0 / (time.perf_counter() - start + 1e-9))
if list_policy_fps[-1] < cfg.fps:
logging.warning(
f"[ACTOR] policy frame rate {list_policy_fps[-1]} during interaction step {interaction_step} is below the required control frame rate {cfg.fps}"
)
# Time policy inference and check if it meets FPS requirement
with TimerManager(
elapsed_time_list=list_policy_time, label="Policy inference time", log=False
) as timer: # noqa: F841
action = policy.select_action(batch=obs) * 0.0
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
else:
# TODO (azouitine): Make a custom space for torch tensor
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
# HACK
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
# HACK: For maniskill
# next_obs = preprocess_observation(next_obs)
next_obs = preprocess_maniskill_observation(next_obs)
next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
sum_reward_episode += float(reward[0])
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
action = torch.from_numpy(action[0]).to(device, non_blocking=True).unsqueeze(dim=0)
# Because we are using a single environment we can index at zero
if done[0].item() or truncated[0].item():
# TODO: Handle logging for episode information
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
sum_reward_episode += float(reward)
if not parameters_queue.empty():
logging.debug("[ACTOR] Load new parameters from Learner.")
state_dict = parameters_queue.get()
state_dict = move_state_dict_to_device(state_dict, device=device)
# strict=False for the case when the image encoder is frozen and not sent through
# the network. Be careful: this might cause issues if the wrong keys are passed.
policy.actor.load_state_dict(state_dict, strict=False)
if len(list_transition_to_send_to_learner) > 0:
logging.debug(
f"[ACTOR] Sending {len(list_transition_to_send_to_learner)} transitions to Learner."
)
message_queue.put(ActorInformation(transition=list_transition_to_send_to_learner))
list_transition_to_send_to_learner = []
stats = {}
if len(list_policy_fps) > 0:
policy_fps = mean(list_policy_fps)
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
list_policy_fps = []
# Send episodic reward to the learner
message_queue.put(
ActorInformation(
interaction_message={
"Episodic reward": sum_reward_episode,
"Interaction step": interaction_step,
**stats,
}
)
)
sum_reward_episode = 0.0
# TODO (michel-aractingi): Label the reward
# if config.label_reward_on_actor:
# reward = reward_classifier(obs)
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
if info["is_intervention"]:
# TODO: Check the shape
action = info["action_intervention"]
@ -291,17 +261,85 @@ def act_with_policy(cfg: DictConfig):
# assign obs to the next obs and continue the rollout
obs = next_obs
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
# Because we are using a single environment we can index at zero
if done or truncated:
# TODO: Handle logging for episode information
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
# update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
send_transitions_in_chunks(
transitions=list_transition_to_send_to_learner, message_queue=message_queue, chunk_size=4
)
list_transition_to_send_to_learner = []
stats = get_frequency_stats(list_policy_time)
list_policy_time.clear()
# Send episodic reward to the learner
message_queue.put(
ActorInformation(
interaction_message={
"Episodic reward": sum_reward_episode,
"Interaction step": interaction_step,
**stats,
}
)
)
sum_reward_episode = 0.0
obs, info = online_env.reset()
def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int = 100):
"""Send transitions to learner in smaller chunks to avoid network issues.
Args:
transitions: List of transitions to send
message_queue: Queue to send messages to learner
chunk_size: Size of each chunk to send
"""
for i in range(0, len(transitions), chunk_size):
chunk = transitions[i : i + chunk_size]
logging.debug(f"[ACTOR] Sending chunk of {len(chunk)} transitions to Learner.")
message_queue.put(ActorInformation(transition=chunk))
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
stats = {}
list_policy_fps = [1.0 / t for t in list_policy_time]
if len(list_policy_fps) > 0:
policy_fps = mean(list_policy_fps)
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
return stats
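A quick check of the statistics used above: statistics.quantiles(..., n=10) returns nine cut points, and the last one approximates the 90th percentile of the policy FPS samples.
from statistics import mean, quantiles

list_policy_fps = [28.0, 29.0, 29.5, 30.0, 30.5, 30.8, 31.0, 31.5, 32.0, 33.0]
print(mean(list_policy_fps))                 # 30.53
print(quantiles(list_policy_fps, n=10)[-1])  # ~90th percentile, 32.9 here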
def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
if policy_fps < cfg.fps:
logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
)
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def actor_cli(cfg: dict):
port = cfg.actor_learner_config.port
server_thread = Thread(target=serve_actor_service, args=(port,), daemon=True)
server_thread.start()
robot = make_robot(cfg=cfg.robot)
server_thread = Thread(target=serve_actor_service, args=(cfg.actor_learner_config.port,), daemon=True)
reward_classifier = get_classifier(
pretrained_path=cfg.env.reward_classifier.pretrained_path,
config_path=cfg.env.reward_classifier.config_path,
)
policy_thread = Thread(
target=act_with_policy,
daemon=True,
args=(cfg,),
args=(cfg, robot, reward_classifier),
)
server_thread.start()
policy_thread.start()
policy_thread.join()
server_thread.join()

View File

@ -56,10 +56,10 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
}
# If complementary_info is present, move its tensors to CPU
if transition["complementary_info"] is not None:
transition["complementary_info"] = {
key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
}
# if transition["complementary_info"] is not None:
# transition["complementary_info"] = {
# key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
# }
return transition
@ -309,6 +309,7 @@ class ReplayBuffer:
def sample(self, batch_size: int) -> BatchTransition:
"""Sample a random batch of transitions and collate them into batched tensors."""
batch_size = min(batch_size, len(self.memory))
list_of_transitions = random.sample(self.memory, batch_size)
# -- Build batched states --
@ -341,9 +342,6 @@ class ReplayBuffer:
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
# Return a BatchTransition typed dict
return BatchTransition(
@ -531,30 +529,31 @@ def concatenate_batch_transitions(
# if __name__ == "__main__":
# dataset_name = "lerobot/pusht_image"
# dataset = LeRobotDataset(repo_id=dataset_name, episodes=range(1, 3))
# replay_buffer = ReplayBuffer.from_lerobot_dataset(
# lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
# )
# replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted")
# for i in range(len(replay_buffer_converted)):
# replay_convert = replay_buffer_converted[i]
# dataset_convert = dataset[i]
# for key in replay_convert.keys():
# if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}:
# continue
# if key in dataset_convert.keys():
# assert torch.equal(replay_convert[key], dataset_convert[key])
# print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}")
# re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset(
# replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu"
# )
# for _ in range(20):
# batch = re_reconverted_dataset.sample(32)
# dataset_name = "aractingi/push_green_cube_hf_cropped_resized"
# dataset = LeRobotDataset(repo_id=dataset_name)
# for key in batch.keys():
# if key in {"state", "next_state"}:
# for key_state in batch[key].keys():
# print(key_state, batch[key][key_state].size())
# continue
# print(key, batch[key].size())
# replay_buffer = ReplayBuffer.from_lerobot_dataset(
# lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
# )
# replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted")
# for i in range(len(replay_buffer_converted)):
# replay_convert = replay_buffer_converted[i]
# dataset_convert = dataset[i]
# for key in replay_convert.keys():
# if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}:
# continue
# if key in dataset_convert.keys():
# assert torch.equal(replay_convert[key], dataset_convert[key])
# print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}")
# re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset(
# replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu"
# )
# for _ in range(20):
# batch = re_reconverted_dataset.sample(32)
# for key in batch.keys():
# if key in {"state", "next_state"}:
# for key_state in batch[key].keys():
# print(key_state, batch[key][key_state].size())
# continue
# print(key, batch[key].size())

View File

@ -4,7 +4,6 @@ import time
from threading import Lock
from typing import Annotated, Any, Callable, Dict, Optional, Tuple
import cv2
import gymnasium as gym
import numpy as np
import torch
@ -20,10 +19,15 @@ logging.basicConfig(level=logging.INFO)
class HILSerlRobotEnv(gym.Env):
"""
Gym-like environment wrapper for robot policy evaluation.
Gym-compatible environment for evaluating robotic control policies with integrated human intervention.
This wrapper provides a consistent interface for interacting with the robot,
following the OpenAI Gym environment conventions.
This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both relative (delta)
and absolute joint position commands and automatically configures its observation and action spaces based on the robot's
sensors and configuration.
The environment can switch between executing actions from a policy or using teleoperated actions (human intervention) during
each step. When teleoperation is used, the override action is captured and returned in the `info` dict along with a flag
`is_intervention`.
"""
def __init__(
@ -31,32 +35,34 @@ class HILSerlRobotEnv(gym.Env):
robot,
use_delta_action_space: bool = True,
delta: float | None = None,
display_cameras=False,
display_cameras: bool = False,
):
"""
Initialize the robot environment.
Initialize the HILSerlRobotEnv environment.
The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup
supports both relative (delta) adjustments and absolute joint positions for controlling the robot.
Args:
robot: The robot interface object
reward_classifier: Optional reward classifier
fps: Frames per second for control
control_time_s: Total control time for each episode
display_cameras: Whether to display camera feeds
output_normalization_params_action: Bound parameters for the action space
delta: The delta for the relative joint position action space
robot: The robot interface object used to connect and interact with the physical robot.
use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
joint positions are used.
delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
0 and 1 when using a delta action space.
display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
"""
super().__init__()
self.robot = robot
self.display_cameras = display_cameras
# connect robot
# Connect to the robot if not already connected.
if not self.robot.is_connected:
self.robot.connect()
self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")
# Episode tracking
# Episode tracking.
self.current_step = 0
self.episode_data = None
@ -64,6 +70,7 @@ class HILSerlRobotEnv(gym.Env):
self.use_delta_action_space = use_delta_action_space
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
# Retrieve the size of the joint position interval bound.
self.relative_bounds_size = (
self.robot.config.joint_position_relative_bounds["max"]
- self.robot.config.joint_position_relative_bounds["min"]
@ -73,20 +80,26 @@ class HILSerlRobotEnv(gym.Env):
self.robot.config.max_relative_target = self.delta_relative_bounds_size.float()
# Dynamically determine observation and action spaces
# Dynamically configure the observation and action spaces.
self._setup_spaces()
def _setup_spaces(self):
"""
Dynamically determine observation and action spaces based on robot capabilities.
Dynamically configure the observation and action spaces based on the robot's capabilities.
This method should be customized based on the specific robot's observation
and action representations.
Observation Space:
- For keys with "image": A Box space with pixel values ranging from 0 to 255.
- For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range.
Action Space:
- The action space is defined as a Tuple where:
The first element is a Box space representing joint position commands. It is defined as relative (delta)
or absolute, based on the configuration.
The second element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation).
"""
# Example space setup - you'll need to adapt this to your specific robot
example_obs = self.robot.capture_observation()
# Observation space (assuming image-based observations)
# Define observation spaces for images and other states.
image_keys = [key for key in example_obs if "image" in key]
state_keys = [key for key in example_obs if "image" not in key]
observation_spaces = {
@ -102,7 +115,7 @@ class HILSerlRobotEnv(gym.Env):
self.observation_space = gym.spaces.Dict(observation_spaces)
# Action space (assuming joint positions)
# Define the action space for joint positions along with setting an intervention flag.
action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
if self.use_delta_action_space:
action_space_robot = gym.spaces.Box(
@ -128,18 +141,24 @@ class HILSerlRobotEnv(gym.Env):
def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
"""
Reset the environment to initial state.
Reset the environment to its initial state.
This method resets the step counter and clears any episodic data.
Args:
seed (Optional[int]): A seed for random number generation to ensure reproducibility.
options (Optional[dict]): Additional options to influence the reset behavior.
Returns:
observation (dict): Initial observation
info (dict): Additional information
A tuple containing:
- observation (dict): The initial sensor observation.
- info (dict): A dictionary with supplementary information, including the key "initial_position".
"""
super().reset(seed=seed, options=options)
# Capture initial observation
# Capture the initial observation.
observation = self.robot.capture_observation()
# Reset tracking variables
# Reset episode tracking variables.
self.current_step = 0
self.episode_data = None
@ -149,28 +168,38 @@ class HILSerlRobotEnv(gym.Env):
self, action: Tuple[np.ndarray, bool]
) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
"""
Take a step in the environment.
Execute a single step within the environment using the specified action.
The provided action is a tuple comprised of:
A policy action (joint position commands) that may be either in absolute values or as a delta.
A boolean flag indicating whether teleoperation (human intervention) should be used for this step.
Behavior:
- When the intervention flag is False, the environment processes and sends the policy action to the robot.
- When True, a teleoperation step is executed. If using a delta action space, an absolute teleop action is converted
to relative change based on the current joint positions.
Args:
action tuple(np.ndarray, bool):
Policy action to be executed on the robot and boolean to determine
whether to choose policy action or expert action.
action (tuple): A tuple with two elements:
- policy_action (np.ndarray or torch.Tensor): The commanded joint positions.
- intervention_bool (bool): True if the human operator intervenes by providing a teleoperation input.
Returns:
observation (dict): Next observation
reward (float): Reward for this step
terminated (bool): Whether the episode has terminated
truncated (bool): Whether the episode was truncated
info (dict): Additional information
tuple: A tuple containing:
- observation (dict): The new sensor observation after taking the step.
- reward (float): The step reward (default is 0.0 within this wrapper).
- terminated (bool): True if the episode has reached a terminal state.
- truncated (bool): True if the episode was truncated (e.g., time constraints).
- info (dict): Additional debugging information including:
"action_intervention": The teleop action if intervention was used.
"is_intervention": Flag indicating whether teleoperation was employed.
"""
# The action received is a tuple containing the policy action and an intervention bool.
# The boolean indicates whether we use the expert's actions (through teleoperation) or the policy actions.
policy_action, intervention_bool = action
teleop_action = None
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
if isinstance(policy_action, torch.Tensor):
policy_action = policy_action.cpu().numpy()
olicy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
if not intervention_bool:
if self.use_delta_action_space:
target_joint_positions = self.current_joint_positions + self.delta * policy_action
@ -180,26 +209,26 @@ class HILSerlRobotEnv(gym.Env):
observation = self.robot.capture_observation()
else:
observation, teleop_action = self.robot.teleop_step(record_data=True)
teleop_action = teleop_action["action"] # teleop step returns torch tensors but in a dict
teleop_action = teleop_action["action"] # Convert tensor to appropriate format
# teleop actions are returned in absolute joint space
# If we are using a relative joint position action space,
# there will be a mismatch between the spaces of the policy and teleop actions
# Solution is to transform the teleop actions into relative space.
# teleop relative action is:
# When applying the delta action space, convert teleop absolute values to relative differences.
if self.use_delta_action_space:
teleop_action = teleop_action - self.current_joint_positions
if torch.any(teleop_action < -self.delta_relative_bounds_size * self.delta) and torch.any(
teleop_action > self.delta_relative_bounds_size
):
print(
f"relative teleop delta exceeded bounds {self.delta_relative_bounds_size}, teleop_action {teleop_action}\n"
f"Relative teleop delta exceeded bounds {self.delta_relative_bounds_size}, teleop_action {teleop_action}\n"
f"lower bounds condition {teleop_action < -self.delta_relative_bounds_size}\n"
f"upper bounds condition {teleop_action > self.delta_relative_bounds_size}"
)
teleop_action = torch.clamp(
teleop_action, -self.delta_relative_bounds_size, self.delta_relative_bounds_size
)
# NOTE: To mimic the shape of a neural network output, we add a batch dimension to the teleop action.
if teleop_action.dim() == 1:
teleop_action = teleop_action.unsqueeze(0)
self.current_step += 1
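A worked numeric example (illustrative values, not part of the diff) of the delta action path above: a policy action is applied as current + delta * action, while a teleop action, which is recorded in absolute joint space, is converted back into the same relative space before being returned in the info dict.
import numpy as np

current_joint_positions = np.array([10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
delta = 0.1

policy_action = np.array([1.0, -1.0, 0.0, 0.5, 0.0, 0.0])
target_joint_positions = current_joint_positions + delta * policy_action
# -> [10.1, 19.9, 30.0, 40.05, 50.0, 60.0]

teleop_absolute = np.array([10.2, 20.0, 30.0, 40.0, 50.0, 60.0])
teleop_relative = teleop_absolute - current_joint_positions
# -> [0.2, 0.0, 0.0, 0.0, 0.0, 0.0], then clamped to the relative bounds in the wrapper
print(target_joint_positions, teleop_relative)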
@ -217,7 +246,7 @@ class HILSerlRobotEnv(gym.Env):
def render(self):
"""
Render the environment (in this case, display camera feeds).
Render the current state of the environment by displaying the robot's camera feeds.
"""
import cv2
@ -231,7 +260,10 @@ class HILSerlRobotEnv(gym.Env):
def close(self):
"""
Close the environment and disconnect the robot.
Close the environment and clean up resources by disconnecting the robot.
If the robot is currently connected, this method properly terminates the connection to ensure that all
associated resources are released.
"""
if self.robot.is_connected:
self.robot.disconnect()
@ -250,48 +282,19 @@ class ActionRepeatWrapper(gym.Wrapper):
return obs, reward, done, truncated, info
class RelativeJointPositionActionWrapper(gym.Wrapper):
def __init__(
self,
env: HILSerlRobotEnv,
# output_normalization_params_action: dict[str, list[float]],
delta: float = 0.1,
):
super().__init__(env)
self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
self.delta = delta
if delta > 1:
raise ValueError("Delta should be less than 1")
def step(self, action):
action_joint = action
self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
if isinstance(self.env.action_space, gym.spaces.Tuple):
action_joint = action[0]
joint_positions = self.joint_positions + (self.delta * action_joint)
# clip the joint positions to the joint limits with the action space
joint_positions = np.clip(joint_positions, self.action_space.low, self.action_space.high)
if isinstance(self.env.action_space, gym.spaces.Tuple):
return self.env.step((joint_positions, action[1]))
obs, reward, terminated, truncated, info = self.env.step(joint_positions)
if info["is_intervention"]:
# teleop actions are returned in absolute joint space
# If we are using a relative joint position action space,
# there will be a mismatch between the spaces of the policy and teleop actions
# Solution is to transform the teleop actions into relative space.
self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
teleop_action = info["action_intervention"] # teleop actions are in absolute joint space
relative_teleop_action = (teleop_action - self.joint_positions) / self.delta
info["action_intervention"] = relative_teleop_action
return self.env.step(joint_positions)
class RewardWrapper(gym.Wrapper):
def __init__(self, env, reward_classifier: Optional[None], device: torch.device = "cuda"):
def __init__(self, env, reward_classifier, device: torch.device = "cuda"):
"""
Wrapper to add reward prediction to the environment; it uses a trained classifier.
Args:
env: The environment to wrap
reward_classifier: The reward classifier model
device: The device to run the model on
"""
self.env = env
# NOTE: We got 15% speedup by compiling the model
self.reward_classifier = torch.compile(reward_classifier)
self.device = device
@ -305,9 +308,7 @@ class RewardWrapper(gym.Wrapper):
reward = (
self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0
)
# print(f"fps for reward classifier {1/(time.perf_counter() - start_time)}")
reward = reward.item()
# print(f"Reward from reward classifier {reward}")
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
@ -323,17 +324,23 @@ class TimeLimitWrapper(gym.Wrapper):
self.last_timestamp = 0.0
self.episode_time_in_s = 0.0
self.max_episode_steps = int(self.control_time_s * self.fps)
self.current_step = 0
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
time_since_last_step = time.perf_counter() - self.last_timestamp
# logging.warning(f"Current timestep is lower than the expected fps {self.fps}")
self.episode_time_in_s += time_since_last_step
self.last_timestamp = time.perf_counter()
self.current_step += 1
# check if last timestep took more time than the expected fps
if 1.0 / time_since_last_step < self.fps:
logging.warning(f"Current timestep is lower than the expected fps {self.fps}")
# if 1.0 / time_since_last_step < self.fps:
# logging.warning(f"Current timestep exceeded expected fps {self.fps}")
if self.episode_time_in_s > self.control_time_s:
# if self.current_step >= self.max_episode_steps:
# Terminated = True
terminated = True
return obs, reward, terminated, truncated, info
@ -341,11 +348,13 @@ class TimeLimitWrapper(gym.Wrapper):
def reset(self, seed=None, options=None):
self.episode_time_in_s = 0.0
self.last_timestamp = time.perf_counter()
self.current_step = 0
return self.env.reset(seed=seed, options=options)
class ImageCropResizeWrapper(gym.Wrapper):
def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
super().__init__(env)
self.env = env
self.crop_params_dict = crop_params_dict
print(f"obs_keys , {self.env.observation_space}")
@ -372,10 +381,21 @@ class ImageCropResizeWrapper(gym.Wrapper):
obs[k] = F.resize(obs[k], self.resize_size)
obs[k] = obs[k].to(device)
# print(f"observation with key {k} with size {obs[k].size()}")
cv2.imshow(k, cv2.cvtColor(obs[k].cpu().squeeze(0).permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
# cv2.imshow(k, cv2.cvtColor(obs[k].cpu().squeeze(0).permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))
# cv2.waitKey(1)
return obs, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
obs, info = self.env.reset(seed=seed, options=options)
for k in self.crop_params_dict:
device = obs[k].device
if device == torch.device("mps:0"):
obs[k] = obs[k].cpu()
obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
obs[k] = F.resize(obs[k], self.resize_size)
obs[k] = obs[k].to(device)
return obs, info
class ConvertToLeRobotObservation(gym.ObservationWrapper):
def __init__(self, env, device):
@ -515,42 +535,64 @@ class ResetWrapper(gym.Wrapper):
return super().reset(seed=seed, options=options)
class BatchCompitableWrapper(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
for key in observation:
if "image" in key and observation[key].dim() == 3:
observation[key] = observation[key].unsqueeze(0)
if "state" in key and observation[key].dim() == 1:
observation[key] = observation[key].unsqueeze(0)
return observation
def make_robot_env(
robot,
reward_classifier,
crop_params_dict=None,
fps=30,
control_time_s=20,
reset_follower_pos=True,
display_cameras=False,
device="cuda:0",
resize_size=None,
reset_time_s=10,
delta_action=0.1,
nb_repeats=1,
use_relative_joint_positions=False,
):
cfg,
n_envs: int = 1,
) -> gym.vector.VectorEnv:
"""
Factory function to create the robot environment.
Factory function to create a vectorized robot environment.
Mimics gym.make() for consistent environment creation.
Args:
robot: Robot instance to control
reward_classifier: Classifier model for computing rewards
cfg: Configuration object containing environment parameters
n_envs: Number of environments to create in parallel. Defaults to 1.
Returns:
A vectorized gym environment with all the necessary wrappers applied.
"""
# Create base environment
env = HILSerlRobotEnv(
robot,
display_cameras=display_cameras,
delta=delta_action,
use_delta_action_space=use_relative_joint_positions,
robot=robot,
display_cameras=cfg.wrapper.display_cameras,
delta=cfg.wrapper.delta_action,
use_delta_action_space=cfg.wrapper.use_relative_joint_positions,
)
env = ConvertToLeRobotObservation(env, device)
if crop_params_dict is not None:
env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
env = RewardWrapper(env, reward_classifier, device=device)
env = TimeLimitWrapper(env, control_time_s, fps)
# env = ActionRepeatWrapper(env, nb_repeat=nb_repeats)
env = KeyboardInterfaceWrapper(env)
env = ResetWrapper(env, reset_fn=None, reset_time_s=reset_time_s)
# Add observation and image processing
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
if cfg.wrapper.crop_params_dict is not None:
env = ImageCropResizeWrapper(
env=env, crop_params_dict=cfg.wrapper.crop_params_dict, resize_size=cfg.wrapper.resize_size
)
# Add reward computation and control wrappers
env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
env = KeyboardInterfaceWrapper(env=env)
env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.wrapper.reset_time_s)
env = BatchCompitableWrapper(env=env)
return env
# batched version of the env that returns an observation of shape (b, c)
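A hedged usage sketch for the factory above. It assumes a connected so100 robot, a composed Hydra config `cfg` with robot/env sections (as built by actor_cli), and that make_robot, get_classifier and make_robot_env are importable as in the actor server earlier in this commit.
import torch

# cfg: composed Hydra/OmegaConf config with `robot` and `env` sections (assumed)
robot = make_robot(cfg=cfg.robot)  # requires real hardware
reward_classifier = get_classifier(
    cfg.env.reward_classifier.pretrained_path, cfg.env.reward_classifier.config_path
)
env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg.env)

obs, info = env.reset()
action = env.action_space.spaces[0].sample()  # joint-command half of the Tuple action space
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))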
def get_classifier(pretrained_path, config_path, device="mps"):
if pretrained_path is None or config_path is None:
@ -616,6 +658,8 @@ if __name__ == "__main__":
default=None,
help="Path to a yaml config file that is necessary to build the reward classifier model.",
)
parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file")
parser.add_argument("--env-overrides", type=str, default=None, help="Overrides for the env yaml file")
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes")
args = parser.parse_args()
@ -626,72 +670,38 @@ if __name__ == "__main__":
reward_classifier = get_classifier(
args.reward_classifier_pretrained_path, args.reward_classifier_config_file
)
crop_parameters = {
"observation.images.laptop": (58, 89, 357, 455),
"observation.images.phone": (3, 4, 471, 633),
}
user_relative_joint_positions = True
cfg = init_hydra_config(args.env_path, args.env_overrides)
env = make_robot_env(
robot,
reward_classifier,
crop_parameters,
args.fps,
args.control_time_s,
args.reset_follower_pos,
args.display_cameras,
device="mps",
resize_size=None,
reset_time_s=10,
delta_action=0.1,
nb_repeats=1,
use_relative_joint_positions=user_relative_joint_positions,
cfg.wrapper,
)
env.reset()
init_pos = env.unwrapped.initial_follower_position
right_goal = init_pos.copy()
right_goal[0] += 50
# Retrieve the robot's action space for joint commands.
action_space_robot = env.action_space.spaces[0]
left_goal = init_pos.copy()
left_goal[0] -= 50
# Initialize the smoothed action as a random sample.
smoothed_action = action_space_robot.sample()
pitch_angle = np.linspace(left_goal[0], right_goal[0], 1000)
delta_angle = np.concatenate((-np.ones(50), np.ones(50))) * 100
# Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
# A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
alpha = 0.4
while True:
action = np.zeros(len(init_pos))
for i in range(len(delta_angle)):
start_loop_s = time.perf_counter()
action[0] = delta_angle[i]
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))
if terminated or truncated:
env.reset()
start_loop_s = time.perf_counter()
# Sample a new random action from the robot's action space.
new_random_action = action_space_robot.sample()
# Update the smoothed action using an exponential moving average.
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
dt_s = time.perf_counter() - start_loop_s
busy_wait(1 / args.fps - dt_s)
# action = np.zeros(len(init_pos)) if user_relative_joint_positions else init_pos
# for i in range(len(pitch_angle)):
# if user_relative_joint_positions:
# action[0] = delta_angle[i]
# else:
# action[0] = pitch_angle[i]
# obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))
# if terminated or truncated:
# logging.info("Max control time reached, reset environment.")
# env.reset()
# Execute the step: wrap the NumPy action in a torch tensor.
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
if terminated or truncated:
env.reset()
# for i in reversed(range(len(pitch_angle))):
# if user_relative_joint_positions:
# action[0] = delta_angle[i]
# else:
# action[0] = pitch_angle[i]
# obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))
# if terminated or truncated:
# logging.info("Max control time reached, reset environment.")
# env.reset()
dt_s = time.perf_counter() - start_loop_s
busy_wait(1 / args.fps - dt_s)

View File

@ -36,6 +36,8 @@ from termcolor import colored
from torch import nn
from torch.optim.optimizer import Optimizer
from lerobot.common.datasets.factory import make_dataset
# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger, log_output_dir
@ -52,6 +54,7 @@ from lerobot.common.utils.utils import (
)
from lerobot.scripts.server.buffer import (
ReplayBuffer,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
)
@ -259,8 +262,15 @@ def learner_push_parameters(
while True:
with policy_lock:
params_dict = policy.actor.state_dict()
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
params_dict = {k: v for k, v in params_dict if not k.startswith("encoder.")}
if policy.config.vision_encoder_name is not None:
if policy.config.freeze_vision_encoder:
params_dict: dict[str, torch.Tensor] = {
k: v for k, v in params_dict.items() if not k.startswith("encoder.")
}
else:
raise NotImplementedError(
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
)
params_dict = move_state_dict_to_device(params_dict, device="cpu")
# Serialize
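A toy illustration (not part of the diff) of the filtering above: when the vision encoder is frozen, its weights (keys prefixed with "encoder.") are dropped from the actor's state_dict before the parameters are pushed to the actor over the network.
import torch.nn as nn

actor = nn.ModuleDict({"encoder": nn.Linear(4, 8), "head": nn.Linear(8, 2)})
params_dict = actor.state_dict()
params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}
print(list(params_dict))  # ['head.weight', 'head.bias']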
@ -322,6 +332,7 @@ def add_actor_information_and_train(
# in the future. The reason why we did that is the GIL in Python: it's super slow, and the performance
# is divided by 200. So we need to have a single thread that does all the work.
time.time()
logging.info("Starting learner thread")
interaction_message, transition = None, None
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
@ -340,16 +351,21 @@ def add_actor_information_and_train(
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
logging.info(f"Interaction message: {interaction_message}")
if len(replay_buffer) < cfg.training.online_step_before_learning:
continue
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
time_for_one_optimization_step = time.time()
for _ in range(cfg.policy.utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
# if cfg.offline_dataset_repo_id is not None:
# batch_offline = offline_replay_buffer.sample(batch_size)
# batch = concatenate_batch_transitions(batch, batch_offline)
if cfg.dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(batch, batch_offline)
actions = batch["action"]
rewards = batch["reward"]
@ -371,11 +387,11 @@ def add_actor_information_and_train(
batch = replay_buffer.sample(batch_size)
# if cfg.offline_dataset_repo_id is not None:
# batch_offline = offline_replay_buffer.sample(batch_size)
# batch = concatenate_batch_transitions(
# left_batch_transitions=batch, right_batch_transition=batch_offline
# )
if cfg.dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch["action"]
rewards = batch["reward"]
@ -423,7 +439,7 @@ def add_actor_information_and_train(
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
logging.debug(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
logger.log_dict(
{"Optimization frequency loop [Hz]": frequency_for_one_optimization_step},
@ -560,14 +576,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
batch_size = cfg.training.batch_size
offline_replay_buffer = None
# if cfg.dataset_repo_id is not None:
# logging.info("make_dataset offline buffer")
# offline_dataset = make_dataset(cfg)
# logging.info("Convertion to a offline replay buffer")
# offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
# offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
# )
# batch_size: int = batch_size // 2 # We will sample from both replay buffer
if cfg.dataset_repo_id is not None:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer")
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
)
batch_size: int = batch_size // 2 # We will sample from both replay buffer
start_learner_threads(
cfg,

View File

@ -279,8 +279,10 @@ def train(cfg: DictConfig) -> None:
logging.info(f"Dataset size: {len(dataset)}")
train_size = int(cfg.train_split_proportion * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# val_size = len(dataset) - train_size
# train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_dataset = dataset[:train_size]
val_dataset = dataset[train_size:]
sampler = create_balanced_sampler(train_dataset, cfg)
train_loader = DataLoader(