Several fixes to move the actor_server and learner_server code from the ManiSkill simulation environment to the real-robot environment.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
@@ -39,6 +39,12 @@ class SACConfig:
             "observation.environment_state": "min_max",
         }
     )
+    input_normalization_params: dict[str, dict[str, list[float]]] = field(
+        default_factory=lambda: {
+            "observation.image": {"mean": [[0.485, 0.456, 0.406]], "std": [[0.229, 0.224, 0.225]]},
+            "observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
+        }
+    )
     output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
     output_normalization_params: dict[str, dict[str, list[float]]] = field(
         default_factory=lambda: {
@@ -51,18 +51,20 @@ class SACPolicy(
         if config is None:
             config = SACConfig()
         self.config = config

         if config.input_normalization_modes is not None:
+            input_normalization_params = _convert_normalization_params_to_tensor(
+                config.input_normalization_params
+            )
             self.normalize_inputs = Normalize(
-                config.input_shapes, config.input_normalization_modes, dataset_stats
+                config.input_shapes, config.input_normalization_modes, input_normalization_params
             )
         else:
             self.normalize_inputs = nn.Identity()

-        output_normalization_params = {}
-        for outer_key, inner_dict in config.output_normalization_params.items():
-            output_normalization_params[outer_key] = {}
-            for key, value in inner_dict.items():
-                output_normalization_params[outer_key][key] = torch.tensor(value)
+        output_normalization_params = _convert_normalization_params_to_tensor(
+            config.output_normalization_params
+        )

         # HACK: This is hacky and should be removed
         dataset_stats = dataset_stats or output_normalization_params
@@ -75,7 +77,7 @@ class SACPolicy(

         # NOTE: For images the encoder should be shared between the actor and critic
         if config.shared_encoder:
-            encoder_critic = SACObservationEncoder(config)
+            encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
             encoder_actor: SACObservationEncoder = encoder_critic
         else:
             encoder_critic = SACObservationEncoder(config)
@@ -92,6 +94,7 @@ class SACPolicy(
                     for _ in range(config.num_critics)
                 ]
             ),
+            output_normalization=self.normalize_targets,
         )

         self.critic_target = CriticEnsemble(
@@ -105,6 +108,7 @@ class SACPolicy(
                     for _ in range(config.num_critics)
                 ]
             ),
+            output_normalization=self.normalize_targets,
        )

         self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
@@ -122,7 +126,7 @@ class SACPolicy(
         # TODO (azouitine): Handle the case where the temparameter is a fixed
         # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
         # it triggers "can't optimize a non-leaf Tensor"
-        self.log_alpha = torch.zeros(1, requires_grad=True, device=torch.device("cuda:0"))
+        self.log_alpha = torch.tensor([0.0], requires_grad=True, device=torch.device("mps"))
         self.temperature = self.log_alpha.exp().item()

     def reset(self):
@@ -313,12 +317,14 @@ class CriticEnsemble(nn.Module):
         self,
         encoder: Optional[nn.Module],
         network_list: nn.ModuleList,
+        output_normalization: nn.Module,
         init_final: Optional[float] = None,
     ):
         super().__init__()
         self.encoder = encoder
         self.network_list = network_list
         self.init_final = init_final
+        self.output_normalization = output_normalization

         self.parameters_to_optimize = []
         # Handle the case where a part of the encoder if frozen
@@ -358,6 +364,10 @@ class CriticEnsemble(nn.Module):
         device = get_device_from_parameters(self)
         # Move each tensor in observations to device
         observations = {k: v.to(device) for k, v in observations.items()}
+        # NOTE: We normalize actions it helps for sample efficiency
+        actions: dict[str, torch.tensor] = {"action": actions}
+        # NOTE: Normalization layer took dict in input and outputs a dict that why
+        actions = self.output_normalization(actions)["action"]
         actions = actions.to(device)

         obs_enc = observations if self.encoder is None else self.encoder(observations)
@@ -474,17 +484,18 @@ class Policy(nn.Module):
 class SACObservationEncoder(nn.Module):
     """Encode image and/or state vector observations."""

-    def __init__(self, config: SACConfig):
+    def __init__(self, config: SACConfig, input_normalizer: nn.Module):
         """
         Creates encoders for pixel and/or state modalities.
         """
         super().__init__()
         self.config = config
+        self.input_normalization = input_normalizer
         self.has_pretrained_vision_encoder = False
         self.parameters_to_optimize = []

         self.aggregation_size: int = 0
-        if "observation.image" in config.input_shapes:
+        if any("observation.image" in key for key in config.input_shapes):
             self.camera_number = config.camera_number

             if self.config.vision_encoder_name is not None:
@@ -534,8 +545,9 @@ class SACObservationEncoder(nn.Module):
         over all features.
         """
         feat = []
+        obs_dict = self.input_normalization(obs_dict)
         # Concatenate all images along the channel dimension.
-        image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
+        image_keys = [k for k in obs_dict if k.startswith("observation.image")]
         for image_key in image_keys:
             enc_feat = self.image_enc_layers(obs_dict[image_key])
@@ -681,6 +693,18 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
     return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))


+def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
+    converted_params = {}
+    for outer_key, inner_dict in normalization_params.items():
+        converted_params[outer_key] = {}
+        for key, value in inner_dict.items():
+            converted_params[outer_key][key] = torch.tensor(value)
+            if "image" in outer_key:
+                converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
+
+    return converted_params
+
+
 if __name__ == "__main__":
     # Test the SACObservationEncoder
     import time
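Note (editor, not part of the diff): a minimal sketch of what the new helper produces, assuming the default SACConfig normalization parameters shown earlier in this commit.

    # Hypothetical illustration of _convert_normalization_params_to_tensor.
    params = {
        "observation.image": {"mean": [[0.485, 0.456, 0.406]], "std": [[0.229, 0.224, 0.225]]},
        "observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
    }
    converted = _convert_normalization_params_to_tensor(params)
    # Image statistics are reshaped to (3, 1, 1) so they broadcast over (C, H, W) frames.
    assert converted["observation.image"]["mean"].shape == (3, 1, 1)
    # Non-image entries stay 1-D tensors.
    assert converted["observation.state"]["min"].shape == (4,)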
@@ -18,6 +18,7 @@ import os
 import os.path as osp
 import platform
 import random
+import time
 from contextlib import contextmanager
 from datetime import datetime, timezone
 from pathlib import Path
@@ -217,3 +218,28 @@ def log_say(text, play_sounds, blocking=False):

     if play_sounds:
         say(text, blocking)
+
+
+class TimerManager:
+    def __init__(self, elapsed_time_list: list[float] | None = None, label="Elapsed time", log=True):
+        self.label = label
+        self.elapsed_time_list = elapsed_time_list
+        self.log = log
+        self.elapsed = 0.0
+
+    def __enter__(self):
+        self.start = time.perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.elapsed: float = time.perf_counter() - self.start
+
+        if self.elapsed_time_list is not None:
+            self.elapsed_time_list.append(self.elapsed)
+
+        if self.log:
+            print(f"{self.label}: {self.elapsed:.6f} seconds")
+
+    @property
+    def elapsed_seconds(self):
+        return self.elapsed
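Note (editor, not part of the diff): a minimal usage sketch of the new TimerManager context manager, matching how the actor loop below uses it; `run_inference` is a hypothetical workload.

    times: list[float] = []
    with TimerManager(elapsed_time_list=times, label="Policy inference time", log=False):
        run_inference()
    # The elapsed duration is appended to `times`; 1.0 / times[-1] gives the achieved frequency in Hz.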
@@ -2,6 +2,7 @@ defaults:
   - _self_
   - env: pusht
   - policy: diffusion
+  - robot: so100

 hydra:
   run:
@@ -8,3 +8,20 @@ env:
   state_dim: 6
   action_dim: 6
   fps: ${fps}
+  device: mps
+
+  wrapper:
+    crop_params_dict:
+      observation.images.laptop: [58, 89, 357, 455]
+      observation.images.phone: [3, 4, 471, 633]
+    resize_size: [128, 128]
+    control_time_s: 20
+    reset_follower_pos: true
+    use_relative_joint_positions: true
+    reset_time_s: 10
+    display_cameras: false
+    delta_action: 0.1
+
+  reward_classifier:
+    pretrained_path: outputs/classifier/checkpoints/best/pretrained_model
+    config_path: lerobot/configs/policy/hilserl_classifier.yaml
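Note (editor, not part of the diff): the four crop values are consumed by ImageCropResizeWrapper through torchvision's functional crop, so they read as (top, left, height, width). A small sketch under that assumption, with a hypothetical camera frame:

    import torch
    import torchvision.transforms.functional as F

    frame = torch.zeros(3, 480, 640)          # hypothetical (C, H, W) camera frame
    top, left, height, width = 58, 89, 357, 455
    cropped = F.crop(frame, top, left, height, width)
    resized = F.resize(cropped, [128, 128])   # matches resize_size in the config above
    assert resized.shape == (3, 128, 128)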
@@ -31,16 +31,21 @@ from omegaconf import DictConfig
 from torch import nn

 # TODO: Remove the import of maniskill
-from lerobot.common.envs.factory import make_maniskill_env
-from lerobot.common.envs.utils import preprocess_maniskill_observation
+# from lerobot.common.envs.factory import make_maniskill_env
+# from lerobot.common.envs.utils import preprocess_maniskill_observation
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy
+from lerobot.common.robot_devices.control_utils import busy_wait
+from lerobot.common.robot_devices.robots.factory import make_robot
+from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.utils.utils import (
+    TimerManager,
     get_safe_torch_device,
     set_global_seed,
 )
 from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
 from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device
+from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env

 logging.basicConfig(level=logging.INFO)
@@ -152,7 +157,15 @@ def serve_actor_service(port=50052):
     server.wait_for_termination()


-def act_with_policy(cfg: DictConfig):
+def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
+    if not parameters_queue.empty():
+        logging.debug("[ACTOR] Load new parameters from Learner.")
+        state_dict = parameters_queue.get()
+        state_dict = move_state_dict_to_device(state_dict, device=device)
+        policy.load_state_dict(state_dict)
+
+
+def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
     """
     Executes policy interaction within the environment.
@@ -165,9 +178,7 @@ def act_with_policy(cfg: DictConfig):

     logging.info("make_env online")

-    # online_env = make_env(cfg, n_envs=1)
-    # TODO: Remove the import of maniskill and unifiy with make env
-    online_env = make_maniskill_env(cfg, n_envs=1)
+    online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg.env)

     set_global_seed(cfg.seed)
     device = get_safe_torch_device(cfg.device, log=True)
@@ -177,6 +188,16 @@ def act_with_policy(cfg: DictConfig):

     logging.info("make_policy")

+    # HACK: This is an ugly hack to pass the normalization parameters to the policy
+    # Because the action space is dynamic so we override the output normalization parameters
+    # it's ugly, we know ... and we will fix it
+    min_action_space: list = online_env.action_space.spaces[0].low.tolist()
+    max_action_space: list = online_env.action_space.spaces[0].high.tolist()
+    output_normalization_params: dict[dict[str, list]] = {
+        "action": {"min": min_action_space, "max": max_action_space}
+    }
+    cfg.policy.output_normalization_params = output_normalization_params
+
     ### Instantiate the policy in both the actor and learner processes
     ### To avoid sending a SACPolicy object through the port, we create a policy intance
     ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
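Note (editor, not part of the diff): a sketch of the override above with a stand-in action space, to show what ends up in cfg.policy.output_normalization_params; the Tuple space here is hypothetical and only mirrors the (joint Box, intervention flag) layout used by the environment.

    import gymnasium as gym
    import numpy as np

    action_space = gym.spaces.Tuple(
        (gym.spaces.Box(low=-1.0, high=1.0, shape=(6,), dtype=np.float32), gym.spaces.Discrete(2))
    )
    output_normalization_params = {
        "action": {
            "min": action_space.spaces[0].low.tolist(),
            "max": action_space.spaces[0].high.tolist(),
        }
    }
    # -> {"action": {"min": [-1.0, ...], "max": [1.0, ...]}}, later turned into tensors by the policy.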
@@ -187,92 +208,41 @@ def act_with_policy(cfg: DictConfig):
         # Hack: But if we do online training, we do not need dataset_stats
         dataset_stats=None,
         # TODO: Handle resume training
+        device=device,
     )
-    # pretrained_policy_name_or_path=None,
-    # device=device,
-    # )
     policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)

-    # HACK for maniskill
     obs, info = online_env.reset()

-    # obs = preprocess_observation(obs)
-    obs = preprocess_maniskill_observation(obs)
-    obs = {key: obs[key].to(device, non_blocking=True) for key in obs}

     # NOTE: For the moment we will solely handle the case of a single environment
     sum_reward_episode = 0
     list_transition_to_send_to_learner = []
-    list_policy_fps = []
+    list_policy_time = []

     for interaction_step in range(cfg.training.online_steps):
         if interaction_step >= cfg.training.online_step_before_learning:
-            start = time.perf_counter()
-            action = policy.select_action(batch=obs)
-            list_policy_fps.append(1.0 / (time.perf_counter() - start + 1e-9))
-            if list_policy_fps[-1] < cfg.fps:
-                logging.warning(
-                    f"[ACTOR] policy frame rate {list_policy_fps[-1]} during interaction step {interaction_step} is below the required control frame rate {cfg.fps}"
-                )
-
-            next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
+            # Time policy inference and check if it meets FPS requirement
+            with TimerManager(
+                elapsed_time_list=list_policy_time, label="Policy inference time", log=False
+            ) as timer:  # noqa: F841
+                action = policy.select_action(batch=obs) * 0.0
+            policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
+
+            log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
+
+            next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
         else:
+            # TODO (azouitine): Make a custom space for torch tensor
             action = online_env.action_space.sample()
             next_obs, reward, done, truncated, info = online_env.step(action)
-            # HACK
-            action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)

-        # HACK: For maniskill
-        # next_obs = preprocess_observation(next_obs)
-        next_obs = preprocess_maniskill_observation(next_obs)
-        next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
-        sum_reward_episode += float(reward[0])
+            # HACK: We have only one env but we want to batch it, it will be resolved with the torch box
+            action = torch.from_numpy(action[0]).to(device, non_blocking=True).unsqueeze(dim=0)

-        # Because we are using a single environment we can index at zero
-        if done[0].item() or truncated[0].item():
-            # TODO: Handle logging for episode information
-            logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
-
-            if not parameters_queue.empty():
-                logging.debug("[ACTOR] Load new parameters from Learner.")
-                state_dict = parameters_queue.get()
-                state_dict = move_state_dict_to_device(state_dict, device=device)
-                # strict=False for the case when the image encoder is frozen and not sent through
-                # the network. Becareful might cause issues if the wrong keys are passed
-                policy.actor.load_state_dict(state_dict, strict=False)
-
-            if len(list_transition_to_send_to_learner) > 0:
-                logging.debug(
-                    f"[ACTOR] Sending {len(list_transition_to_send_to_learner)} transitions to Learner."
-                )
-                message_queue.put(ActorInformation(transition=list_transition_to_send_to_learner))
-                list_transition_to_send_to_learner = []
-
-            stats = {}
-            if len(list_policy_fps) > 0:
-                policy_fps = mean(list_policy_fps)
-                quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
-                logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
-                logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
-                stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
-                list_policy_fps = []
-
-            # Send episodic reward to the learner
-            message_queue.put(
-                ActorInformation(
-                    interaction_message={
-                        "Episodic reward": sum_reward_episode,
-                        "Interaction step": interaction_step,
-                        **stats,
-                    }
-                )
-            )
-            sum_reward_episode = 0.0
-
-        # TODO (michel-aractingi): Label the reward
-        # if config.label_reward_on_actor:
-        #     reward = reward_classifier(obs)
+        sum_reward_episode += float(reward)
+
+        # NOTE: We overide the action if the intervention is True, because the action applied is the intervention action
         if info["is_intervention"]:
             # TODO: Check the shape
             action = info["action_intervention"]
@@ -291,17 +261,85 @@ def act_with_policy(cfg: DictConfig):
         # assign obs to the next obs and continue the rollout
         obs = next_obs

+        # HACK: We have only one env but we want to batch it, it will be resolved with the torch box
+        # Because we are using a single environment we can index at zero
+        if done or truncated:
+            # TODO: Handle logging for episode information
+            logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
+
+            # update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
+
+            if len(list_transition_to_send_to_learner) > 0:
+                send_transitions_in_chunks(
+                    transitions=list_transition_to_send_to_learner, message_queue=message_queue, chunk_size=4
+                )
+                list_transition_to_send_to_learner = []
+
+            stats = get_frequency_stats(list_policy_time)
+            list_policy_time.clear()
+
+            # Send episodic reward to the learner
+            message_queue.put(
+                ActorInformation(
+                    interaction_message={
+                        "Episodic reward": sum_reward_episode,
+                        "Interaction step": interaction_step,
+                        **stats,
+                    }
+                )
+            )
+            sum_reward_episode = 0.0
+            obs, info = online_env.reset()
+
+
+def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int = 100):
+    """Send transitions to learner in smaller chunks to avoid network issues.
+
+    Args:
+        transitions: List of transitions to send
+        message_queue: Queue to send messages to learner
+        chunk_size: Size of each chunk to send
+    """
+    for i in range(0, len(transitions), chunk_size):
+        chunk = transitions[i : i + chunk_size]
+        logging.debug(f"[ACTOR] Sending chunk of {len(chunk)} transitions to Learner.")
+        message_queue.put(ActorInformation(transition=chunk))
+
+
+def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
+    stats = {}
+    list_policy_fps = [1.0 / t for t in list_policy_time]
+    if len(list_policy_fps) > 0:
+        policy_fps = mean(list_policy_fps)
+        quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
+        logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
+        logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
+        stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
+    return stats
+
+
+def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
+    if policy_fps < cfg.fps:
+        logging.warning(
+            f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
+        )
+
+
 @hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
 def actor_cli(cfg: dict):
-    port = cfg.actor_learner_config.port
-    server_thread = Thread(target=serve_actor_service, args=(port,), daemon=True)
-    server_thread.start()
+    robot = make_robot(cfg=cfg.robot)
+
+    server_thread = Thread(target=serve_actor_service, args=(cfg.actor_learner_config.port,), daemon=True)
+
+    reward_classifier = get_classifier(
+        pretrained_path=cfg.env.reward_classifier.pretrained_path,
+        config_path=cfg.env.reward_classifier.config_path,
+    )
     policy_thread = Thread(
         target=act_with_policy,
         daemon=True,
-        args=(cfg,),
+        args=(cfg, robot, reward_classifier),
     )
+
+    server_thread.start()
     policy_thread.start()
     policy_thread.join()
     server_thread.join()
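Note (editor, not part of the diff): a quick sketch of the new frequency helper on synthetic timings (the numbers below are hypothetical).

    timings = [0.02, 0.015, 0.012, 0.018, 0.011, 0.02, 0.016, 0.013, 0.014, 0.017]
    stats = get_frequency_stats(timings)
    # stats holds "Policy frequency [Hz]" (the mean FPS) and "Policy frequency 90th-p [Hz]",
    # taken as the last cut point of statistics.quantiles(..., n=10).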
@@ -56,10 +56,10 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
     }

     # If complementary_info is present, move its tensors to CPU
-    if transition["complementary_info"] is not None:
-        transition["complementary_info"] = {
-            key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
-        }
+    # if transition["complementary_info"] is not None:
+    #     transition["complementary_info"] = {
+    #         key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
+    #     }
     return transition
@@ -309,6 +309,7 @@ class ReplayBuffer:

     def sample(self, batch_size: int) -> BatchTransition:
         """Sample a random batch of transitions and collate them into batched tensors."""
+        batch_size = min(batch_size, len(self.memory))
         list_of_transitions = random.sample(self.memory, batch_size)

         # -- Build batched states --
@@ -341,9 +342,6 @@ class ReplayBuffer:
         batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
             self.device
         )
-        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
-            self.device
-        )

         # Return a BatchTransition typed dict
         return BatchTransition(
@@ -531,30 +529,31 @@ def concatenate_batch_transitions(


 # if __name__ == "__main__":
-#     dataset_name = "lerobot/pusht_image"
-#     dataset = LeRobotDataset(repo_id=dataset_name, episodes=range(1, 3))
+#     dataset_name = "aractingi/push_green_cube_hf_cropped_resized"
+#     dataset = LeRobotDataset(repo_id=dataset_name)
+
 #     replay_buffer = ReplayBuffer.from_lerobot_dataset(
 #         lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
 #     )
 #     replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted")
 #     for i in range(len(replay_buffer_converted)):
 #         replay_convert = replay_buffer_converted[i]
 #         dataset_convert = dataset[i]
 #         for key in replay_convert.keys():
 #             if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}:
 #                 continue
 #             if key in dataset_convert.keys():
 #                 assert torch.equal(replay_convert[key], dataset_convert[key])
 #                 print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}")
 #     re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset(
 #         replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu"
 #     )
 #     for _ in range(20):
 #         batch = re_reconverted_dataset.sample(32)

 #         for key in batch.keys():
 #             if key in {"state", "next_state"}:
 #                 for key_state in batch[key].keys():
 #                     print(key_state, batch[key][key_state].size())
 #                 continue
 #             print(key, batch[key].size())
@@ -4,7 +4,6 @@ import time
 from threading import Lock
 from typing import Annotated, Any, Callable, Dict, Optional, Tuple

-import cv2
 import gymnasium as gym
 import numpy as np
 import torch
@@ -20,10 +19,15 @@ logging.basicConfig(level=logging.INFO)

 class HILSerlRobotEnv(gym.Env):
     """
-    Gym-like environment wrapper for robot policy evaluation.
+    Gym-compatible environment for evaluating robotic control policies with integrated human intervention.

-    This wrapper provides a consistent interface for interacting with the robot,
-    following the OpenAI Gym environment conventions.
+    This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both relative (delta)
+    and absolute joint position commands and automatically configures its observation and action spaces based on the robot's
+    sensors and configuration.
+
+    The environment can switch between executing actions from a policy or using teleoperated actions (human intervention) during
+    each step. When teleoperation is used, the override action is captured and returned in the `info` dict along with a flag
+    `is_intervention`.
     """

     def __init__(
@@ -31,32 +35,34 @@ class HILSerlRobotEnv(gym.Env):
         robot,
         use_delta_action_space: bool = True,
         delta: float | None = None,
-        display_cameras=False,
+        display_cameras: bool = False,
     ):
         """
-        Initialize the robot environment.
+        Initialize the HILSerlRobotEnv environment.
+
+        The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup
+        supports both relative (delta) adjustments and absolute joint positions for controlling the robot.

         Args:
-            robot: The robot interface object
-            reward_classifier: Optional reward classifier
-            fps: Frames per second for control
-            control_time_s: Total control time for each episode
-            display_cameras: Whether to display camera feeds
-            output_normalization_params_action: Bound parameters for the action space
-            delta: The delta for the relative joint position action space
+            robot: The robot interface object used to connect and interact with the physical robot.
+            use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
+                joint positions are used.
+            delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
+                0 and 1 when using a delta action space.
+            display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
         """
         super().__init__()

         self.robot = robot
         self.display_cameras = display_cameras

-        # connect robot
+        # Connect to the robot if not already connected.
         if not self.robot.is_connected:
             self.robot.connect()

         self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")

-        # Episode tracking
+        # Episode tracking.
         self.current_step = 0
         self.episode_data = None
@@ -64,6 +70,7 @@ class HILSerlRobotEnv(gym.Env):
         self.use_delta_action_space = use_delta_action_space
         self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")

+        # Retrieve the size of the joint position interval bound.
         self.relative_bounds_size = (
             self.robot.config.joint_position_relative_bounds["max"]
             - self.robot.config.joint_position_relative_bounds["min"]
@@ -73,20 +80,26 @@ class HILSerlRobotEnv(gym.Env):

         self.robot.config.max_relative_target = self.delta_relative_bounds_size.float()

-        # Dynamically determine observation and action spaces
+        # Dynamically configure the observation and action spaces.
         self._setup_spaces()

     def _setup_spaces(self):
         """
-        Dynamically determine observation and action spaces based on robot capabilities.
+        Dynamically configure the observation and action spaces based on the robot's capabilities.

-        This method should be customized based on the specific robot's observation
-        and action representations.
+        Observation Space:
+            - For keys with "image": A Box space with pixel values ranging from 0 to 255.
+            - For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range.
+
+        Action Space:
+            - The action space is defined as a Tuple where:
+                • The first element is a Box space representing joint position commands. It is defined as relative (delta)
+                  or absolute, based on the configuration.
+                • The second element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation).
         """
-        # Example space setup - you'll need to adapt this to your specific robot
         example_obs = self.robot.capture_observation()

-        # Observation space (assuming image-based observations)
+        # Define observation spaces for images and other states.
         image_keys = [key for key in example_obs if "image" in key]
         state_keys = [key for key in example_obs if "image" not in key]
         observation_spaces = {
@@ -102,7 +115,7 @@ class HILSerlRobotEnv(gym.Env):

         self.observation_space = gym.spaces.Dict(observation_spaces)

-        # Action space (assuming joint positions)
+        # Define the action space for joint positions along with setting an intervention flag.
         action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
         if self.use_delta_action_space:
             action_space_robot = gym.spaces.Box(
@@ -128,18 +141,24 @@ class HILSerlRobotEnv(gym.Env):

     def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
         """
-        Reset the environment to initial state.
+        Reset the environment to its initial state.
+        This method resets the step counter and clears any episodic data.
+
+        Args:
+            seed (Optional[int]): A seed for random number generation to ensure reproducibility.
+            options (Optional[dict]): Additional options to influence the reset behavior.

         Returns:
-            observation (dict): Initial observation
-            info (dict): Additional information
+            A tuple containing:
+                - observation (dict): The initial sensor observation.
+                - info (dict): A dictionary with supplementary information, including the key "initial_position".
         """
         super().reset(seed=seed, options=options)

-        # Capture initial observation
+        # Capture the initial observation.
         observation = self.robot.capture_observation()

-        # Reset tracking variables
+        # Reset episode tracking variables.
         self.current_step = 0
         self.episode_data = None
|
||||||
self, action: Tuple[np.ndarray, bool]
|
self, action: Tuple[np.ndarray, bool]
|
||||||
) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
|
) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Take a step in the environment.
|
Execute a single step within the environment using the specified action.
|
||||||
|
|
||||||
|
The provided action is a tuple comprised of:
|
||||||
|
• A policy action (joint position commands) that may be either in absolute values or as a delta.
|
||||||
|
• A boolean flag indicating whether teleoperation (human intervention) should be used for this step.
|
||||||
|
|
||||||
|
Behavior:
|
||||||
|
- When the intervention flag is False, the environment processes and sends the policy action to the robot.
|
||||||
|
- When True, a teleoperation step is executed. If using a delta action space, an absolute teleop action is converted
|
||||||
|
to relative change based on the current joint positions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
action tuple(np.ndarray, bool):
|
action (tuple): A tuple with two elements:
|
||||||
Policy action to be executed on the robot and boolean to determine
|
- policy_action (np.ndarray or torch.Tensor): The commanded joint positions.
|
||||||
whether to choose policy action or expert action.
|
- intervention_bool (bool): True if the human operator intervenes by providing a teleoperation input.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
observation (dict): Next observation
|
tuple: A tuple containing:
|
||||||
reward (float): Reward for this step
|
- observation (dict): The new sensor observation after taking the step.
|
||||||
terminated (bool): Whether the episode has terminated
|
- reward (float): The step reward (default is 0.0 within this wrapper).
|
||||||
truncated (bool): Whether the episode was truncated
|
- terminated (bool): True if the episode has reached a terminal state.
|
||||||
info (dict): Additional information
|
- truncated (bool): True if the episode was truncated (e.g., time constraints).
|
||||||
|
- info (dict): Additional debugging information including:
|
||||||
|
◦ "action_intervention": The teleop action if intervention was used.
|
||||||
|
◦ "is_intervention": Flag indicating whether teleoperation was employed.
|
||||||
"""
|
"""
|
||||||
# The actions recieved are the in form of a tuple containing the policy action and an intervention bool
|
|
||||||
# The boolean inidicated whether we will use the expert's actions (through teleoperation) or the policy actions
|
|
||||||
policy_action, intervention_bool = action
|
policy_action, intervention_bool = action
|
||||||
teleop_action = None
|
teleop_action = None
|
||||||
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
|
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
|
||||||
if isinstance(policy_action, torch.Tensor):
|
if isinstance(policy_action, torch.Tensor):
|
||||||
policy_action = policy_action.cpu().numpy()
|
policy_action = policy_action.cpu().numpy()
|
||||||
olicy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
|
policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
|
||||||
if not intervention_bool:
|
if not intervention_bool:
|
||||||
if self.use_delta_action_space:
|
if self.use_delta_action_space:
|
||||||
target_joint_positions = self.current_joint_positions + self.delta * policy_action
|
target_joint_positions = self.current_joint_positions + self.delta * policy_action
|
||||||
|
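Note (editor, not part of the diff): a sketch of driving the environment with the (action, intervention) tuple described above; `env` is a hypothetical instance built by the factory added later in this commit.

    import numpy as np

    obs, info = env.reset()
    zero_action = np.zeros(env.action_space[0].shape, dtype=np.float32)
    # Let the human teleoperate this step; the applied action comes back in `info`.
    obs, reward, terminated, truncated, info = env.step((zero_action, True))
    if info["is_intervention"]:
        applied_action = info["action_intervention"]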
@@ -180,26 +209,26 @@ class HILSerlRobotEnv(gym.Env):
             observation = self.robot.capture_observation()
         else:
             observation, teleop_action = self.robot.teleop_step(record_data=True)
-            teleop_action = teleop_action["action"]  # teleop step returns torch tensors but in a dict
+            teleop_action = teleop_action["action"]  # Convert tensor to appropriate format

-            # teleop actions are returned in absolute joint space
-            # If we are using a relative joint position action space,
-            # there will be a mismatch between the spaces of the policy and teleop actions
-            # Solution is to transform the teleop actions into relative space.
-            # teleop relative action is:
+            # When applying the delta action space, convert teleop absolute values to relative differences.
             if self.use_delta_action_space:
                 teleop_action = teleop_action - self.current_joint_positions
                 if torch.any(teleop_action < -self.delta_relative_bounds_size * self.delta) and torch.any(
                     teleop_action > self.delta_relative_bounds_size
                 ):
                     print(
-                        f"relative teleop delta exceeded bounds {self.delta_relative_bounds_size}, teleop_action {teleop_action}\n"
+                        f"Relative teleop delta exceeded bounds {self.delta_relative_bounds_size}, teleop_action {teleop_action}\n"
                         f"lower bounds condition {teleop_action < -self.delta_relative_bounds_size}\n"
                         f"upper bounds condition {teleop_action > self.delta_relative_bounds_size}"
                     )

                     teleop_action = torch.clamp(
                         teleop_action, -self.delta_relative_bounds_size, self.delta_relative_bounds_size
                     )
+            # NOTE: To mimic the shape of a neural network output, we add a batch dimension to the teleop action.
+            if teleop_action.dim() == 1:
+                teleop_action = teleop_action.unsqueeze(0)

         self.current_step += 1
@@ -217,7 +246,7 @@ class HILSerlRobotEnv(gym.Env):

     def render(self):
         """
-        Render the environment (in this case, display camera feeds).
+        Render the current state of the environment by displaying the robot's camera feeds.
         """
         import cv2
@@ -231,7 +260,10 @@ class HILSerlRobotEnv(gym.Env):

     def close(self):
         """
-        Close the environment and disconnect the robot.
+        Close the environment and clean up resources by disconnecting the robot.
+
+        If the robot is currently connected, this method properly terminates the connection to ensure that all
+        associated resources are released.
         """
         if self.robot.is_connected:
             self.robot.disconnect()
@@ -250,48 +282,19 @@ class ActionRepeatWrapper(gym.Wrapper):
         return obs, reward, done, truncated, info


-class RelativeJointPositionActionWrapper(gym.Wrapper):
-    def __init__(
-        self,
-        env: HILSerlRobotEnv,
-        # output_normalization_params_action: dict[str, list[float]],
-        delta: float = 0.1,
-    ):
-        super().__init__(env)
-        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
-        self.delta = delta
-        if delta > 1:
-            raise ValueError("Delta should be less than 1")
-
-    def step(self, action):
-        action_joint = action
-        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
-        if isinstance(self.env.action_space, gym.spaces.Tuple):
-            action_joint = action[0]
-        joint_positions = self.joint_positions + (self.delta * action_joint)
-        # clip the joint positions to the joint limits with the action space
-        joint_positions = np.clip(joint_positions, self.action_space.low, self.action_space.high)
-
-        if isinstance(self.env.action_space, gym.spaces.Tuple):
-            return self.env.step((joint_positions, action[1]))
-
-        obs, reward, terminated, truncated, info = self.env.step(joint_positions)
-        if info["is_intervention"]:
-            # teleop actions are returned in absolute joint space
-            # If we are using a relative joint position action space,
-            # there will be a mismatch between the spaces of the policy and teleop actions
-            # Solution is to transform the teleop actions into relative space.
-            self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
-            teleop_action = info["action_intervention"]  # teleop actions are in absolute joint space
-            relative_teleop_action = (teleop_action - self.joint_positions) / self.delta
-            info["action_intervention"] = relative_teleop_action
-
-        return self.env.step(joint_positions)
-
-
 class RewardWrapper(gym.Wrapper):
-    def __init__(self, env, reward_classifier: Optional[None], device: torch.device = "cuda"):
+    def __init__(self, env, reward_classifier, device: torch.device = "cuda"):
+        """
+        Wrapper to add reward prediction to the environment, it use a trained classifer.
+
+        Args:
+            env: The environment to wrap
+            reward_classifier: The reward classifier model
+            device: The device to run the model on
+        """
         self.env = env
+
+        # NOTE: We got 15% speedup by compiling the model
         self.reward_classifier = torch.compile(reward_classifier)
         self.device = device
@@ -305,9 +308,7 @@ class RewardWrapper(gym.Wrapper):
         reward = (
             self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0
         )
-        # print(f"fps for reward classifier {1/(time.perf_counter() - start_time)}")
-        reward = reward.item()
-        # print(f"Reward from reward classifier {reward}")
+        info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
         return observation, reward, terminated, truncated, info

     def reset(self, seed=None, options=None):
@@ -323,17 +324,23 @@ class TimeLimitWrapper(gym.Wrapper):
         self.last_timestamp = 0.0
         self.episode_time_in_s = 0.0

+        self.max_episode_steps = int(self.control_time_s * self.fps)
+
+        self.current_step = 0
+
     def step(self, action):
         obs, reward, terminated, truncated, info = self.env.step(action)
         time_since_last_step = time.perf_counter() - self.last_timestamp
+        # logging.warning(f"Current timestep is lower than the expected fps {self.fps}")
         self.episode_time_in_s += time_since_last_step
         self.last_timestamp = time.perf_counter()
+        self.current_step += 1
         # check if last timestep took more time than the expected fps
-        if 1.0 / time_since_last_step < self.fps:
-            logging.warning(f"Current timestep is lower than the expected fps {self.fps}")
+        # if 1.0 / time_since_last_step < self.fps:
+        #     logging.warning(f"Current timestep exceeded expected fps {self.fps}")

         if self.episode_time_in_s > self.control_time_s:
+        # if self.current_step >= self.max_episode_steps:
             # Terminated = True
             terminated = True
         return obs, reward, terminated, truncated, info
@@ -341,11 +348,13 @@ class TimeLimitWrapper(gym.Wrapper):
     def reset(self, seed=None, options=None):
         self.episode_time_in_s = 0.0
         self.last_timestamp = time.perf_counter()
+        self.current_step = 0
         return self.env.reset(seed=seed, options=options)


 class ImageCropResizeWrapper(gym.Wrapper):
     def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
+        super().__init__(env)
         self.env = env
         self.crop_params_dict = crop_params_dict
         print(f"obs_keys , {self.env.observation_space}")
@@ -372,10 +381,21 @@ class ImageCropResizeWrapper(gym.Wrapper):
             obs[k] = F.resize(obs[k], self.resize_size)
             obs[k] = obs[k].to(device)
             # print(f"observation with key {k} with size {obs[k].size()}")
-            cv2.imshow(k, cv2.cvtColor(obs[k].cpu().squeeze(0).permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))
-            cv2.waitKey(1)
+            # cv2.imshow(k, cv2.cvtColor(obs[k].cpu().squeeze(0).permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))
+            # cv2.waitKey(1)
         return obs, reward, terminated, truncated, info

+    def reset(self, seed=None, options=None):
+        obs, info = self.env.reset(seed=seed, options=options)
+        for k in self.crop_params_dict:
+            device = obs[k].device
+            if device == torch.device("mps:0"):
+                obs[k] = obs[k].cpu()
+            obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
+            obs[k] = F.resize(obs[k], self.resize_size)
+            obs[k] = obs[k].to(device)
+        return obs, info
+

 class ConvertToLeRobotObservation(gym.ObservationWrapper):
     def __init__(self, env, device):
@@ -515,42 +535,64 @@ class ResetWrapper(gym.Wrapper):
         return super().reset(seed=seed, options=options)


+class BatchCompitableWrapper(gym.ObservationWrapper):
+    def __init__(self, env):
+        super().__init__(env)
+
+    def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        for key in observation:
+            if "image" in key and observation[key].dim() == 3:
+                observation[key] = observation[key].unsqueeze(0)
+            if "state" in key and observation[key].dim() == 1:
+                observation[key] = observation[key].unsqueeze(0)
+        return observation
+
+
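A quick illustration of what the new wrapper above does: observations coming from a single real-robot environment lack a batch dimension, so image tensors of shape (C, H, W) and state vectors of shape (D,) are unsqueezed to (1, C, H, W) and (1, D). The key names below are illustrative.

# Hedged illustration of the batch-dimension fix the wrapper above applies.
import torch

obs = {
    "observation.images.cam0": torch.rand(3, 128, 128),  # (C, H, W), no batch dim
    "observation.state": torch.rand(6),                  # (D,), no batch dim
}

for key in obs:
    if "image" in key and obs[key].dim() == 3:
        obs[key] = obs[key].unsqueeze(0)  # -> (1, C, H, W)
    if "state" in key and obs[key].dim() == 1:
        obs[key] = obs[key].unsqueeze(0)  # -> (1, D)

print(obs["observation.images.cam0"].shape)  # torch.Size([1, 3, 128, 128])
print(obs["observation.state"].shape)        # torch.Size([1, 6])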
 def make_robot_env(
     robot,
     reward_classifier,
-    crop_params_dict=None,
-    fps=30,
-    control_time_s=20,
-    reset_follower_pos=True,
-    display_cameras=False,
-    device="cuda:0",
-    resize_size=None,
-    reset_time_s=10,
-    delta_action=0.1,
-    nb_repeats=1,
-    use_relative_joint_positions=False,
-):
+    cfg,
+    n_envs: int = 1,
+) -> gym.vector.VectorEnv:
     """
-    Factory function to create the robot environment.
+    Factory function to create a vectorized robot environment.

-    Mimics gym.make() for consistent environment creation.
+    Args:
+        robot: Robot instance to control
+        reward_classifier: Classifier model for computing rewards
+        cfg: Configuration object containing environment parameters
+        n_envs: Number of environments to create in parallel. Defaults to 1.
+
+    Returns:
+        A vectorized gym environment with all the necessary wrappers applied.
     """

+    # Create base environment
     env = HILSerlRobotEnv(
-        robot,
-        display_cameras=display_cameras,
-        delta=delta_action,
-        use_delta_action_space=use_relative_joint_positions,
+        robot=robot,
+        display_cameras=cfg.wrapper.display_cameras,
+        delta=cfg.wrapper.delta_action,
+        use_delta_action_space=cfg.wrapper.use_relative_joint_positions,
     )
-    env = ConvertToLeRobotObservation(env, device)
-    if crop_params_dict is not None:
-        env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
-    env = RewardWrapper(env, reward_classifier, device=device)
-    env = TimeLimitWrapper(env, control_time_s, fps)
-    # env = ActionRepeatWrapper(env, nb_repeat=nb_repeats)
-    env = KeyboardInterfaceWrapper(env)
-    env = ResetWrapper(env, reset_fn=None, reset_time_s=reset_time_s)
+
+    # Add observation and image processing
+    env = ConvertToLeRobotObservation(env=env, device=cfg.device)
+    if cfg.wrapper.crop_params_dict is not None:
+        env = ImageCropResizeWrapper(
+            env=env, crop_params_dict=cfg.wrapper.crop_params_dict, resize_size=cfg.wrapper.resize_size
+        )
+
+    # Add reward computation and control wrappers
+    env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
+    env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
+    env = KeyboardInterfaceWrapper(env=env)
+    env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.wrapper.reset_time_s)
+    env = BatchCompitableWrapper(env=env)
+
     return env
+
+    # batched version of the env that returns an observation of shape (b, c)


 def get_classifier(pretrained_path, config_path, device="mps"):
     if pretrained_path is None or config_path is None:
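The factory now reads everything from a single config object instead of a long argument list. The sketch below shows the shape such a config needs, inferred from the attributes referenced above, with SimpleNamespace standing in for the Hydra/OmegaConf object the script actually builds; all values are illustrative.

# Hedged sketch of the config layout make_robot_env now expects
# (fields taken from the attributes referenced above; values are illustrative).
from types import SimpleNamespace

cfg = SimpleNamespace(
    fps=30,
    device="cuda:0",
    wrapper=SimpleNamespace(
        display_cameras=False,
        delta_action=0.1,
        use_relative_joint_positions=True,
        crop_params_dict=None,       # or a dict of per-camera (top, left, height, width) tuples
        resize_size=(128, 128),
        control_time_s=20,
        reset_time_s=10,
    ),
)

# env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)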
@@ -616,6 +658,8 @@ if __name__ == "__main__":
         default=None,
         help="Path to a yaml config file that is necessary to build the reward classifier model.",
     )
+    parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file")
+    parser.add_argument("--env-overrides", type=str, default=None, help="Overrides for the env yaml file")
     parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
     parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes")
     args = parser.parse_args()
@@ -626,72 +670,38 @@ if __name__ == "__main__":
     reward_classifier = get_classifier(
         args.reward_classifier_pretrained_path, args.reward_classifier_config_file
     )
-
-    crop_parameters = {
-        "observation.images.laptop": (58, 89, 357, 455),
-        "observation.images.phone": (3, 4, 471, 633),
-    }
-
     user_relative_joint_positions = True
+
+    cfg = init_hydra_config(args.env_path, args.env_overrides)
     env = make_robot_env(
         robot,
         reward_classifier,
-        crop_parameters,
-        args.fps,
-        args.control_time_s,
-        args.reset_follower_pos,
-        args.display_cameras,
-        device="mps",
-        resize_size=None,
-        reset_time_s=10,
-        delta_action=0.1,
-        nb_repeats=1,
-        use_relative_joint_positions=user_relative_joint_positions,
+        cfg.wrapper,
     )

     env.reset()
-    init_pos = env.unwrapped.initial_follower_position
-
-    right_goal = init_pos.copy()
-    right_goal[0] += 50
-
-    left_goal = init_pos.copy()
-    left_goal[0] -= 50
-
-    pitch_angle = np.linspace(left_goal[0], right_goal[0], 1000)
-    delta_angle = np.concatenate((-np.ones(50), np.ones(50))) * 100
+
+    # Retrieve the robot's action space for joint commands.
+    action_space_robot = env.action_space.spaces[0]
+
+    # Initialize the smoothed action as a random sample.
+    smoothed_action = action_space_robot.sample()
+
+    # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
+    # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
+    alpha = 0.4

     while True:
-        action = np.zeros(len(init_pos))
-        for i in range(len(delta_angle)):
-            start_loop_s = time.perf_counter()
-            action[0] = delta_angle[i]
-            obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))
-            if terminated or truncated:
-                env.reset()
-
-            dt_s = time.perf_counter() - start_loop_s
-            busy_wait(1 / args.fps - dt_s)
-        # action = np.zeros(len(init_pos)) if user_relative_joint_positions else init_pos
-        # for i in range(len(pitch_angle)):
-        #     if user_relative_joint_positions:
-        #         action[0] = delta_angle[i]
-        #     else:
-        #         action[0] = pitch_angle[i]
-        #     obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))
-        #     if terminated or truncated:
-        #         logging.info("Max control time reached, reset environment.")
-        #         env.reset()
-
-        # for i in reversed(range(len(pitch_angle))):
-        #     if user_relative_joint_positions:
-        #         action[0] = delta_angle[i]
-        #     else:
-        #         action[0] = pitch_angle[i]
-        #     obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))
-
-        #     if terminated or truncated:
-        #         logging.info("Max control time reached, reset environment.")
-        #         env.reset()
+        start_loop_s = time.perf_counter()
+        # Sample a new random action from the robot's action space.
+        new_random_action = action_space_robot.sample()
+        # Update the smoothed action using an exponential moving average.
+        smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
+
+        # Execute the step: wrap the NumPy action in a torch tensor.
+        obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
+        if terminated or truncated:
+            env.reset()
+
+        dt_s = time.perf_counter() - start_loop_s
+        busy_wait(1 / args.fps - dt_s)
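The rewritten test loop drives the robot with low-pass-filtered random actions instead of a scripted pitch sweep. The short, self-contained snippet below illustrates the exponential moving average it relies on and how alpha trades smoothness against responsiveness; the values are illustrative.

# Hedged illustration of the exponential-moving-average action smoothing used above.
import numpy as np

rng = np.random.default_rng(0)
alpha = 0.4                         # weight given to each new random sample
smoothed = rng.uniform(-1, 1, size=3)

for _ in range(5):
    new_random = rng.uniform(-1, 1, size=3)
    # Each step keeps (1 - alpha) of the previous command and mixes in alpha of the
    # new sample, so consecutive commands change gradually instead of jumping around.
    smoothed = alpha * new_random + (1 - alpha) * smoothed
    print(np.round(smoothed, 3))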
@@ -36,6 +36,8 @@ from termcolor import colored
 from torch import nn
 from torch.optim.optimizer import Optimizer

+from lerobot.common.datasets.factory import make_dataset
+
 # TODO: Remove the import of maniskill
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.logger import Logger, log_output_dir
@@ -52,6 +54,7 @@ from lerobot.common.utils.utils import (
 )
 from lerobot.scripts.server.buffer import (
     ReplayBuffer,
+    concatenate_batch_transitions,
     move_state_dict_to_device,
     move_transition_to_device,
 )
@@ -259,8 +262,15 @@ def learner_push_parameters(
     while True:
         with policy_lock:
             params_dict = policy.actor.state_dict()
-            if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
-                params_dict = {k: v for k, v in params_dict if not k.startswith("encoder.")}
+            if policy.config.vision_encoder_name is not None:
+                if policy.config.freeze_vision_encoder:
+                    params_dict: dict[str, torch.Tensor] = {
+                        k: v for k, v in params_dict.items() if not k.startswith("encoder.")
+                    }
+                else:
+                    raise NotImplementedError(
+                        "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
+                    )

         params_dict = move_state_dict_to_device(params_dict, device="cpu")
         # Serialize
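Besides splitting the condition, the new code also fixes the filtering expression: iterating params_dict directly yields keys only, so the old comprehension could not unpack k, v, while the new one iterates params_dict.items(). A hedged sketch of the same filtering step on a toy module (names are illustrative, not the actor in this file):

# Hedged sketch of stripping frozen vision-encoder weights from a state_dict
# before pushing parameters over the network.
import torch
from torch import nn

actor = nn.ModuleDict({
    "encoder": nn.Linear(8, 8),   # frozen, already present on the receiving side
    "head": nn.Linear(8, 2),      # trainable part that actually needs to be sent
})

params_dict = actor.state_dict()
# Note .items(): iterating the dict directly would yield keys only and break the unpacking.
params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}

print(sorted(params_dict))  # ['head.bias', 'head.weight']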
@@ -322,6 +332,7 @@ def add_actor_information_and_train(
     # in the future. The reason why we did that is the GIL in Python. It's super slow the performance
     # are divided by 200. So we need to have a single thread that does all the work.
     time.time()
+    logging.info("Starting learner thread")
     interaction_message, transition = None, None
     optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
     interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
@@ -340,16 +351,21 @@ def add_actor_information_and_train(
             # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
             interaction_message["Interaction step"] += interaction_step_shift
             logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
+            logging.info(f"Interaction message: {interaction_message}")

         if len(replay_buffer) < cfg.training.online_step_before_learning:
             continue

+        # logging.info(f"Size of replay buffer: {len(replay_buffer)}")
+        # logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
+
         time_for_one_optimization_step = time.time()
         for _ in range(cfg.policy.utd_ratio - 1):
             batch = replay_buffer.sample(batch_size)
-            # if cfg.offline_dataset_repo_id is not None:
-            #     batch_offline = offline_replay_buffer.sample(batch_size)
-            #     batch = concatenate_batch_transitions(batch, batch_offline)
+
+            if cfg.dataset_repo_id is not None:
+                batch_offline = offline_replay_buffer.sample(batch_size)
+                batch = concatenate_batch_transitions(batch, batch_offline)

             actions = batch["action"]
             rewards = batch["reward"]
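The loop above applies an update-to-data (UTD) schedule: each sampled batch can now be a 50/50 mix of fresh online transitions and offline demonstrations. The self-contained sketch below illustrates that schedule with dummy buffers; the assumption that the actor is only updated on the final pass reflects the structure of the surrounding code, not a line visible in this hunk.

# Hedged, self-contained sketch of a UTD schedule over mixed online/offline batches.
# Buffers and "updates" are dummies for illustration only.
import random

online_buffer = [{"src": "online", "idx": i} for i in range(100)]
offline_buffer = [{"src": "offline", "idx": i} for i in range(100)]

utd_ratio = 3
batch_size = 4  # already halved because two buffers are mixed

for i in range(utd_ratio):
    batch = random.sample(online_buffer, batch_size) + random.sample(offline_buffer, batch_size)
    update_actor = i == utd_ratio - 1  # actor/temperature only on the last pass (assumption)
    print(f"pass {i}: {len(batch)} transitions, update_actor={update_actor}")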
@@ -371,11 +387,11 @@ def add_actor_information_and_train(

         batch = replay_buffer.sample(batch_size)

-        # if cfg.offline_dataset_repo_id is not None:
-        #     batch_offline = offline_replay_buffer.sample(batch_size)
-        #     batch = concatenate_batch_transitions(
-        #         left_batch_transitions=batch, right_batch_transition=batch_offline
-        #     )
+        if cfg.dataset_repo_id is not None:
+            batch_offline = offline_replay_buffer.sample(batch_size)
+            batch = concatenate_batch_transitions(
+                left_batch_transitions=batch, right_batch_transition=batch_offline
+            )

         actions = batch["action"]
         rewards = batch["reward"]
@@ -423,7 +439,7 @@ def add_actor_information_and_train(
         time_for_one_optimization_step = time.time() - time_for_one_optimization_step
         frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)

-        logging.debug(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
+        logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")

         logger.log_dict(
             {"Optimization frequency loop [Hz]": frequency_for_one_optimization_step},
@@ -560,14 +576,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     batch_size = cfg.training.batch_size
     offline_replay_buffer = None

-    # if cfg.dataset_repo_id is not None:
-    #     logging.info("make_dataset offline buffer")
-    #     offline_dataset = make_dataset(cfg)
-    #     logging.info("Convertion to a offline replay buffer")
-    #     offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
-    #         offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
-    #     )
-    #     batch_size: int = batch_size // 2  # We will sample from both replay buffer
+    if cfg.dataset_repo_id is not None:
+        logging.info("make_dataset offline buffer")
+        offline_dataset = make_dataset(cfg)
+        logging.info("Convertion to a offline replay buffer")
+        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
+            offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
+        )
+        batch_size: int = batch_size // 2  # We will sample from both replay buffer

     start_learner_threads(
         cfg,
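Re-enabling the offline buffer means every gradient step samples from two sources, and halving batch_size keeps the combined batch at the configured size. A small illustration of that arithmetic (256 is an illustrative value for cfg.training.batch_size):

# Hedged illustration of why batch_size is halved when the offline buffer is enabled:
# each gradient step still sees the configured number of transitions in total.
configured_batch_size = 256          # stand-in for cfg.training.batch_size
use_offline_buffer = True            # stand-in for cfg.dataset_repo_id is not None

batch_size = configured_batch_size
if use_offline_buffer:
    batch_size = batch_size // 2     # sample half from each buffer

online_part = batch_size
offline_part = batch_size if use_offline_buffer else 0
print(online_part, offline_part, online_part + offline_part)  # 128 128 256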
@@ -279,8 +279,10 @@ def train(cfg: DictConfig) -> None:
     logging.info(f"Dataset size: {len(dataset)}")

     train_size = int(cfg.train_split_proportion * len(dataset))
-    val_size = len(dataset) - train_size
-    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+    # val_size = len(dataset) - train_size
+    # train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+    train_dataset = dataset[:train_size]
+    val_dataset = dataset[train_size:]

     sampler = create_balanced_sampler(train_dataset, cfg)
     train_loader = DataLoader(
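This replaces the random split with a deterministic, index-ordered one, so validation always covers the last recorded portion of the dataset. The sketch below shows an equivalent split via torch.utils.data.Subset, which also works when the dataset object does not support slicing; the 0.8 proportion and toy dataset are illustrative.

# Hedged sketch of a deterministic train/val split by index order.
import torch
from torch.utils.data import Subset, TensorDataset

dataset = TensorDataset(torch.arange(10).unsqueeze(1))  # toy dataset of 10 items
train_size = int(0.8 * len(dataset))

train_dataset = Subset(dataset, range(0, train_size))
val_dataset = Subset(dataset, range(train_size, len(dataset)))

print(len(train_dataset), len(val_dataset))  # 8 2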