Added the possibility to record and replay delta actions during teleoperation, rather than absolute actions
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
parent 6868c88ef1
commit b9217b06db
@@ -84,7 +84,7 @@ class LeRobotDatasetMetadata:
         # Load metadata
         (self.root / "meta").mkdir(exist_ok=True, parents=True)
-        self.pull_from_repo(allow_patterns="meta/")
+        # self.pull_from_repo(allow_patterns="meta/")
         self.info = load_info(self.root)
         self.stats = load_stats(self.root)
         self.tasks = load_tasks(self.root)

@@ -537,7 +537,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
             ]
             files += video_files

-        self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
+        # HACK: UNCOMMENT IF YOU REVIEW THAT, PLEASE SUGGEST TO UNCOMMENT
+        logging.warning("HACK: WE COMMENT THIS LINE, IF SOMETHING IS WEIRD WITH DATASETS UNCOMMENT")
+        # self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)

     def load_hf_dataset(self) -> datasets.Dataset:
         """hf_dataset contains all the observations, states, actions, rewards, etc."""

@@ -147,6 +147,8 @@ class Classifier(

     def predict_reward(self, x, threshold=0.6):
         if self.config.num_classes == 2:
-            return (self.forward(x).probabilities > threshold).float()
+            probs = self.forward(x).probabilities
+            logging.info(f"Predicted reward images: {probs}")
+            return (probs > threshold).float()
         else:
             return torch.argmax(self.forward(x).probabilities, dim=1)

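For context, the thresholding added above behaves as follows; this is a standalone sketch with made-up probabilities, not code from the commit:

    # Sketch: binary reward = 1.0 when the classifier's success probability clears the threshold.
    import torch

    probs = torch.tensor([0.81, 0.42, 0.65])   # stand-in for self.forward(x).probabilities
    threshold = 0.6
    rewards = (probs > threshold).float()      # tensor([1., 0., 1.])

    # With num_classes > 2, the argmax branch returns a class index instead:
    multi_probs = torch.tensor([[0.1, 0.7, 0.2]])
    labels = torch.argmax(multi_probs, dim=1)  # tensor([1])
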
@@ -225,6 +225,7 @@ def record_episode(
     device,
     use_amp,
     fps,
+    record_delta_actions,
 ):
     control_loop(
         robot=robot,

@@ -236,6 +237,7 @@ def record_episode(
         device=device,
         use_amp=use_amp,
         fps=fps,
+        record_delta_actions=record_delta_actions,
         teleoperate=policy is None,
     )

@@ -252,6 +254,7 @@ def control_loop(
     device=None,
     use_amp=None,
     fps=None,
+    record_delta_actions=False,
 ):
     # TODO(rcadene): Add option to record logs
     if not robot.is_connected:

@@ -274,8 +277,12 @@ def control_loop(
     while timestamp < control_time_s:
         start_loop_t = time.perf_counter()

+        current_joint_positions = robot.follower_arms["main"].read("Present_Position")
+
         if teleoperate:
             observation, action = robot.teleop_step(record_data=True)
+            if record_delta_actions:
+                action["action"] = action["action"] - current_joint_positions
         else:
             observation = robot.capture_observation()

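The recording side above and the replay side further down are symmetric; a minimal sketch of the round trip, with hypothetical joint values:

    # Sketch: delta actions store the offset from the follower's current joint positions.
    import numpy as np

    current_joint_positions = np.array([10.0, 45.0, -5.0])  # hypothetical follower pose
    absolute_action = np.array([12.0, 44.0, -5.0])          # hypothetical teleop target

    delta = absolute_action - current_joint_positions   # stored when record_delta_actions is set
    replayed = delta + current_joint_positions          # rebuilt when replay_delta_actions is set
    assert np.allclose(replayed, absolute_action)
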
@@ -290,8 +297,12 @@ def control_loop(
         frame = {**observation, **action}
+        if "next.reward" in events:
             frame["next.reward"] = events["next.reward"]
+            frame["next.done"] = (events["next.reward"] == 1) or (events["exit_early"])
         dataset.add_frame(frame)

+        # if frame["next.done"]:
+        #     break
+
         if display_cameras and not is_headless():
             image_keys = [key for key in observation if "image" in key]
             for key in image_keys:

@@ -12,8 +12,10 @@ env:

   wrapper:
     crop_params_dict:
-      observation.images.laptop: [58, 89, 357, 455]
-      observation.images.phone: [3, 4, 471, 633]
+      observation.images.front: [126, 43, 329, 518]
+      observation.images.side: [93, 69, 381, 434]
+      # observation.images.front: [135, 59, 331, 527]
+      # observation.images.side: [79, 47, 397, 450]
     resize_size: [128, 128]
     control_time_s: 20
     reset_follower_pos: true

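Each four-number entry follows the (top, left, height, width) convention that the ROI-selection script further down prints. A sketch of how such params are applied to a CHW tensor, assuming torchvision's functional API as used by the crop wrapper in this PR:

    # Sketch: crop with (top, left, height, width), then resize to the configured size.
    import torch
    import torchvision.transforms.functional as F  # noqa: N812

    img = torch.rand(3, 480, 640)                    # hypothetical 480x640 camera frame
    top, left, height, width = 126, 43, 329, 518     # observation.images.front above
    cropped = F.crop(img, top, left, height, width)  # shape (3, 329, 518)
    resized = F.resize(cropped, [128, 128])          # shape (3, 128, 128)
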
@@ -4,7 +4,9 @@ defaults:
   - _self_

 seed: 13
-dataset_repo_id: aractingi/push_green_cube_hf_cropped_resized
+dataset_repo_id: aractingi/push_cube_square_reward_cropped_resized
+dataset_root: data/aractingi/push_cube_square_reward_cropped_resized
+local_files_only: true
 train_split_proportion: 0.8

 # Required by logger

@@ -14,7 +16,7 @@ env:


 training:
-  num_epochs: 5
+  num_epochs: 6
   batch_size: 16
   learning_rate: 1e-4
   num_workers: 4

@@ -25,7 +27,7 @@ training:
   save_freq: 1 # How often to save checkpoints (in epochs)
   save_checkpoint: true
   # image_keys: ["observation.images.top", "observation.images.wrist"]
-  image_keys: ["observation.images.laptop", "observation.images.phone"]
+  image_keys: ["observation.images.front", "observation.images.side"]
   label_key: "next.reward"
   profile_inference_time: false
   profile_inference_time_iters: 20

@@ -35,8 +37,8 @@ eval:
   num_samples_to_log: 30 # Number of validation samples to log in the table

 policy:
-  name: "hilserl/classifier/push_green_cube_hf_cropped_resized" #"hilserl/classifier/pick_place_lego_cube_1"
-  model_name: "helper2424/resnet10"
+  name: "hilserl/classifier/push_cube_square_reward_cropped_resized" #"hilserl/classifier/pick_place_lego_cube_120
+  model_name: "helper2424/resnet10" # "facebook/convnext-base-224" #"helper2424/resnet10"
   model_type: "cnn"
   num_cameras: 2 # Has to be len(training.image_keys)

@@ -48,4 +50,4 @@ wandb:

 device: "mps"
 resume: false
-output_dir: "outputs/classifier"
+output_dir: "outputs/classifier/resnet10_frozen"

@@ -57,22 +57,22 @@ policy:
   input_shapes:
     # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
     observation.state: ["${env.state_dim}"]
-    observation.images.laptop: [3, 128, 128]
-    observation.images.phone: [3, 128, 128]
+    observation.images.front: [3, 128, 128]
+    observation.images.side: [3, 128, 128]
     # observation.image: [3, 128, 128]
   output_shapes:
     action: ["${env.action_dim}"]

   # Normalization / Unnormalization
   input_normalization_modes:
-    observation.images.laptop: mean_std
-    observation.images.phone: mean_std
+    observation.images.front: mean_std
+    observation.images.side: mean_std
     observation.state: min_max
   input_normalization_params:
-    observation.images.laptop:
+    observation.images.front:
       mean: [0.485, 0.456, 0.406]
       std: [0.229, 0.224, 0.225]
-    observation.images.phone:
+    observation.images.side:
       mean: [0.485, 0.456, 0.406]
       std: [0.229, 0.224, 0.225]
     observation.state:

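The mean/std values above are the standard ImageNet statistics, presumably to match the pretrained backbone; mean_std normalization amounts to per-channel standardization, sketched here:

    # Sketch: what input_normalization_modes: mean_std does with the params above.
    import torch

    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    img = torch.rand(3, 128, 128)    # image already scaled to [0, 1]
    normalized = (img - mean) / std  # per-channel standardization
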
@@ -50,13 +50,13 @@ follower_arms:
       gripper: [6, "sts3215"]

 cameras:
-  laptop:
+  front:
     _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
     camera_index: 0
     fps: 30
     width: 640
     height: 480
-  phone:
+  side:
     _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
     camera_index: 1
     fps: 30

@@ -206,7 +206,8 @@ def record(
     num_image_writer_threads_per_camera: int = 4,
     display_cameras: bool = True,
     play_sounds: bool = True,
     reset_follower: bool = False,
+    record_delta_actions: bool = False,
     resume: bool = False,
     # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
     local_files_only: bool = False,

@@ -218,7 +219,12 @@ def record(
     device = None
     use_amp = None
     extra_features = (
-        {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
+        {
+            "next.reward": {"dtype": "int64", "shape": (1,), "names": None},
+            "next.done": {"dtype": "bool", "shape": (1,), "names": None},
+        }
+        if assign_rewards
+        else None
     )

     if single_task:

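With this change every recorded frame carries a done flag alongside the reward; a sketch of the schema when assign_rewards is true, mirroring the dtype/shape/names convention above:

    # Sketch: the value extra_features takes when assign_rewards=True (from the hunk above).
    extra_features = {
        "next.reward": {"dtype": "int64", "shape": (1,), "names": None},
        "next.done": {"dtype": "bool", "shape": (1,), "names": None},
    }
    # A recorded frame then includes matching values, e.g. {"next.reward": 1, "next.done": True}.
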
@@ -269,7 +275,7 @@ def record(

     if reset_follower:
         initial_position = robot.follower_arms["main"].read("Present_Position")

     # Execute a few seconds without recording to:
     # 1. teleoperate the robot to move it in starting position if no policy provided,
     # 2. give times to the robot devices to connect and start synchronizing,

@@ -302,6 +308,7 @@ def record(
         device=device,
         use_amp=use_amp,
         fps=fps,
+        record_delta_actions=record_delta_actions,
     )

     # Execute a few seconds without recording to give time to manually reset the environment

@@ -353,21 +360,24 @@ def replay(
     fps: int | None = None,
     play_sounds: bool = True,
     local_files_only: bool = False,
+    replay_delta_actions: bool = False,
 ):
     # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
     # TODO(rcadene): Add option to record logs

     dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
     actions = dataset.hf_dataset.select_columns("action")

     if not robot.is_connected:
         robot.connect()

     log_say("Replaying episode", play_sounds, blocking=True)
     for idx in range(dataset.num_frames):
+        current_joint_positions = robot.follower_arms["main"].read("Present_Position")
         start_episode_t = time.perf_counter()

         action = actions[idx]["action"]
+        if replay_delta_actions:
+            action = action + current_joint_positions
         robot.send_action(action)

         dt_s = time.perf_counter() - start_episode_t

@@ -534,6 +544,12 @@ if __name__ == "__main__":
         default=0,
         help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assigns a 0 reward to frames until the space bar is pressed, which assigns a 1 reward. Press the space bar a second time to assign a 0 reward. The assigned reward is reset to 0 when the episode ends.",
     )
+    parser_record.add_argument(
+        "--record-delta-actions",
+        type=int,
+        default=0,
+        help="Enables the recording of delta actions instead of absolute actions.",
+    )
     parser_record.add_argument(
         "--reset-follower",
         type=int,

@@ -563,6 +579,12 @@ if __name__ == "__main__":
         default=0,
         help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
     )
+    parser_replay.add_argument(
+        "--replay-delta-actions",
+        type=int,
+        default=0,
+        help="Enables the replay of delta actions instead of absolute actions.",
+    )
     parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")

     args = parser.parse_args()

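Assuming these flags are wired to the usual record/replay subcommands (the script path below is an assumption, not shown in this diff), typical invocations would look like:

    # Hypothetical usage; script path and other required flags are assumed/omitted.
    # python lerobot/scripts/control_robot.py record --record-delta-actions 1 ...
    # python lerobot/scripts/control_robot.py replay --replay-delta-actions 1 --episode 0 ...
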
@@ -239,13 +239,17 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

-    dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
+    dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=True)

     images = get_image_from_lerobot_dataset(dataset)
     images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
     images = {k: (v * 255).astype("uint8") for k, v in images.items()}

     rois = select_square_roi_for_images(images)
+    # rois = {
+    #     "observation.images.front": [126, 43, 329, 518],
+    #     "observation.images.side": [93, 69, 381, 434],
+    # }

     # Print the selected rectangular ROIs
     print("\nSelected Rectangular Regions of Interest (top, left, height, width):")

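The two dict comprehensions above convert dataset tensors into something OpenCV-friendly; a standalone sketch of that CHW-float to HWC-uint8 conversion:

    # Sketch: CHW float tensor in [0, 1] -> HWC uint8 array for OpenCV-style display.
    import torch

    chw = torch.rand(3, 480, 640)             # hypothetical frame, channels first
    hwc = chw.cpu().permute(1, 2, 0).numpy()  # (480, 640, 3), float in [0, 1]
    hwc_u8 = (hwc * 255).astype("uint8")      # (480, 640, 3), uint8 in [0, 255]
    assert hwc_u8.shape == (480, 640, 3)
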
@@ -230,6 +230,8 @@ class HILSerlRobotEnv(gym.Env):
             if teleop_action.dim() == 1:
                 teleop_action = teleop_action.unsqueeze(0)

+        # self.render()
+
         self.current_step += 1

         reward = 0.0

@@ -255,8 +257,7 @@ class HILSerlRobotEnv(gym.Env):

         for key in image_keys:
             cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
-
-            cv2.waitKey(1)
+        cv2.waitKey(1)

     def close(self):
         """

@@ -311,10 +312,14 @@ class RewardWrapper(gym.Wrapper):
         start_time = time.perf_counter()
         with torch.inference_mode():
             reward = (
-                self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0
+                self.reward_classifier.predict_reward(images, threshold=0.5)
+                if self.reward_classifier is not None
+                else 0.0
             )
         info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)

+        logging.info(f"Reward: {reward}")
+
+        if reward == 1.0:
+            terminated = True
         return observation, reward, terminated, truncated, info

@@ -760,7 +765,7 @@ if __name__ == "__main__":
     env = make_robot_env(
         robot,
         reward_classifier,
-        cfg.wrapper,
+        cfg.env,  # .wrapper,
     )

     env.reset()

@@ -1,584 +0,0 @@
import argparse
import logging
import time
from threading import Lock
from typing import Annotated, Any, Callable, Dict, Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms.functional as F  # noqa: N812

from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.robot_devices.control_utils import is_headless, reset_follower_position
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config, log_say

logging.basicConfig(level=logging.INFO)


class HILSerlRobotEnv(gym.Env):
    """
    Gym-like environment wrapper for robot policy evaluation.

    This wrapper provides a consistent interface for interacting with the robot,
    following the OpenAI Gym environment conventions.
    """

    def __init__(
        self,
        robot,
        display_cameras=False,
    ):
        """
        Initialize the robot environment.

        Args:
            robot: The robot interface object
            reward_classifier: Optional reward classifier
            fps: Frames per second for control
            control_time_s: Total control time for each episode
            display_cameras: Whether to display camera feeds
        """
        super().__init__()

        self.robot = robot
        self.display_cameras = display_cameras

        # connect robot
        if not self.robot.is_connected:
            self.robot.connect()

        # Dynamically determine observation and action spaces
        self._setup_spaces()

        self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")

        # Episode tracking
        self.current_step = 0
        self.episode_data = None

    def _setup_spaces(self):
        """
        Dynamically determine observation and action spaces based on robot capabilities.

        This method should be customized based on the specific robot's observation
        and action representations.
        """
        # Example space setup - you'll need to adapt this to your specific robot
        example_obs = self.robot.capture_observation()

        # Observation space (assuming image-based observations)
        image_keys = [key for key in example_obs if "image" in key]
        state_keys = [key for key in example_obs if "image" not in key]
        observation_spaces = {
            key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
            for key in image_keys
        }
        observation_spaces["observation.state"] = gym.spaces.Dict(
            {
                key: gym.spaces.Box(low=0, high=10, shape=example_obs[key].shape, dtype=np.float32)
                for key in state_keys
            }
        )

        self.observation_space = gym.spaces.Dict(observation_spaces)

        # Action space (assuming joint positions)
        action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
        self.action_space = gym.spaces.Tuple(
            (
                gym.spaces.Box(low=-np.inf, high=np.inf, shape=(action_dim,), dtype=np.float32),
                gym.spaces.Discrete(2),
            ),
        )

    def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
        """
        Reset the environment to initial state.

        Returns:
            observation (dict): Initial observation
            info (dict): Additional information
        """
        super().reset(seed=seed, options=options)

        # Capture initial observation
        observation = self.robot.capture_observation()

        # Reset tracking variables
        self.current_step = 0
        self.episode_data = None

        return observation, {"initial_position": self.initial_follower_position}

    def step(
        self, action: Tuple[np.ndarray, bool]
    ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
        """
        Take a step in the environment.

        Args:
            action tuple(np.ndarray, bool):
                Policy action to be executed on the robot and boolean to determine
                whether to choose policy action or expert action.

        Returns:
            observation (dict): Next observation
            reward (float): Reward for this step
            terminated (bool): Whether the episode has terminated
            truncated (bool): Whether the episode was truncated
            info (dict): Additional information
        """
        # The actions received are in the form of a tuple containing the policy action and an intervention bool
        # The boolean indicates whether we will use the expert's actions (through teleoperation) or the policy actions
        policy_action, intervention_bool = action
        teleop_action = None
        if not intervention_bool:
            self.robot.send_action(policy_action.cpu())
            observation = self.robot.capture_observation()
        else:
            observation, teleop_action = self.robot.teleop_step(record_data=True)
            teleop_action = teleop_action["action"]  # teleop step returns torch tensors but in a dict

        self.current_step += 1

        reward = 0.0
        terminated = False
        truncated = False

        return (
            observation,
            reward,
            terminated,
            truncated,
            {"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
        )

    def render(self):
        """
        Render the environment (in this case, display camera feeds).
        """
        import cv2

        observation = self.robot.capture_observation()
        image_keys = [key for key in observation if "image" in key]

        for key in image_keys:
            cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))

        cv2.waitKey(1)

    def close(self):
        """
        Close the environment and disconnect the robot.
        """
        if self.robot.is_connected:
            self.robot.disconnect()


class ActionRepeatWrapper(gym.Wrapper):
    def __init__(self, env, nb_repeat: int = 1):
        super().__init__(env)
        self.nb_repeat = nb_repeat

    def step(self, action):
        for _ in range(self.nb_repeat):
            obs, reward, done, truncated, info = self.env.step(action)
            if done or truncated:
                break
        return obs, reward, done, truncated, info


class RelativeJointPositionActionWrapper(gym.Wrapper):
    def __init__(self, env: HILSerlRobotEnv, delta: float = 0.1):
        super().__init__(env)
        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
        self.delta = delta

    def step(self, action):
        action_joint = action
        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
        if isinstance(self.env.action_space, gym.spaces.Tuple):
            action_joint = action[0]
        joint_positions = self.joint_positions + (self.delta * action_joint)
        # clip the joint positions to the joint limits with the action space
        joint_positions = np.clip(joint_positions, self.action_space.low, self.action_space.high)

        if isinstance(self.env.action_space, gym.spaces.Tuple):
            return self.env.step((joint_positions, action[1]))

        obs, reward, terminated, truncated, info = self.env.step(joint_positions)
        if info["is_intervention"]:
            # teleop actions are returned in absolute joint space
            # If we are using a relative joint position action space,
            # there will be a mismatch between the spaces of the policy and teleop actions
            # Solution is to transform the teleop actions into relative space.
            teleop_action = info["action_intervention"]  # teleop actions are in absolute joint space
            relative_teleop_action = (teleop_action - self.joint_positions) / self.delta
            info["action_intervention"] = relative_teleop_action

        return self.env.step(joint_positions)


class RewardWrapper(gym.Wrapper):
    def __init__(self, env, reward_classifier: Optional[None], device: torch.device = "cuda"):
        self.env = env
        self.reward_classifier = reward_classifier
        self.device = device

    def step(self, action):
        observation, _, terminated, truncated, info = self.env.step(action)
        images = [
            observation[key].to(self.device, non_blocking=True) for key in observation if "image" in key
        ]
        reward = self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0
        reward = reward.item()
        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        return self.env.reset(seed=seed, options=options)


class TimeLimitWrapper(gym.Wrapper):
    def __init__(self, env, control_time_s, fps):
        self.env = env
        self.control_time_s = control_time_s
        self.fps = fps

        self.last_timestamp = 0.0
        self.episode_time_in_s = 0.0

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        time_since_last_step = time.perf_counter() - self.last_timestamp
        self.episode_time_in_s += time_since_last_step
        self.last_timestamp = time.perf_counter()

        # check if last timestep took more time than the expected fps
        if 1.0 / time_since_last_step < self.fps:
            logging.warning(f"Current timestep exceeded expected fps {self.fps}")

        if self.episode_time_in_s > self.control_time_s:
            # Terminated = True
            terminated = True
        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        self.episode_time_in_s = 0.0
        self.last_timestamp = time.perf_counter()
        return self.env.reset(seed=seed, options=options)


class ImageCropResizeWrapper(gym.Wrapper):
    def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
        self.env = env
        self.crop_params_dict = crop_params_dict
        for key in crop_params_dict:
            assert key in self.env.observation_space, f"Key {key} not in observation space"
            top, left, height, width = crop_params_dict[key]
            new_shape = (top + height, left + width)
            self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)

        self.resize_size = resize_size
        if self.resize_size is None:
            self.resize_size = (128, 128)

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        for k in self.crop_params_dict:
            obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
            obs[k] = F.resize(obs[k], self.resize_size)
        return obs, reward, terminated, truncated, info


class ConvertToLeRobotObservation(gym.ObservationWrapper):
    def __init__(self, env, device):
        super().__init__(env)
        self.device = device

    def observation(self, observation):
        observation = preprocess_observation(observation)

        observation = {key: observation[key].to(self.device, non_blocking=True) for key in observation}
        observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
        return observation


class KeyboardInterfaceWrapper(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.listener = None
        self.events = {
            "exit_early": False,
            "pause_policy": False,
            "reset_env": False,
            "human_intervention_step": False,
        }
        self.event_lock = Lock()  # Thread-safe access to events
        self._init_keyboard_listener()

    def _init_keyboard_listener(self):
        """Initialize keyboard listener if not in headless mode"""

        if is_headless():
            logging.warning(
                "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
            )
            return
        try:
            from pynput import keyboard

            def on_press(key):
                with self.event_lock:
                    try:
                        if key == keyboard.Key.right or key == keyboard.Key.esc:
                            print("Right arrow key pressed. Exiting loop...")
                            self.events["exit_early"] = True
                        elif key == keyboard.Key.space:
                            if not self.events["pause_policy"]:
                                print(
                                    "Space key pressed. Human intervention required.\n"
                                    "Place the leader in similar pose to the follower and press space again."
                                )
                                self.events["pause_policy"] = True
                                log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
                            elif self.events["pause_policy"] and not self.events["human_intervention_step"]:
                                self.events["human_intervention_step"] = True
                                print("Space key pressed. Human intervention starting.")
                                log_say("Starting human intervention.", play_sounds=True)
                            else:
                                self.events["pause_policy"] = False
                                self.events["human_intervention_step"] = False
                                print("Space key pressed for a third time.")
                                log_say("Continuing with policy actions.", play_sounds=True)
                    except Exception as e:
                        print(f"Error handling key press: {e}")

            self.listener = keyboard.Listener(on_press=on_press)
            self.listener.start()
        except ImportError:
            logging.warning("Could not import pynput. Keyboard interface will not be available.")
            self.listener = None

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
        is_intervention = False
        terminated_by_keyboard = False

        # Extract policy_action if needed
        if isinstance(self.env.action_space, gym.spaces.Tuple):
            policy_action = action[0]

        # Check the event flags without holding the lock for too long.
        with self.event_lock:
            if self.events["exit_early"]:
                terminated_by_keyboard = True
            # If we need to wait for human intervention, we note that outside the lock.
            pause_policy = self.events["pause_policy"]

        if pause_policy:
            # Now, wait for human_intervention_step without holding the lock
            while True:
                with self.event_lock:
                    if self.events["human_intervention_step"]:
                        is_intervention = True
                        break
                time.sleep(0.1)  # Check more frequently if desired

        # Execute the step in the underlying environment
        obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
        return obs, reward, terminated or terminated_by_keyboard, truncated, info

    def reset(self, **kwargs) -> Tuple[Any, Dict]:
        """
        Reset the environment and clear any pending events
        """
        with self.event_lock:
            self.events = {k: False for k in self.events}
        return self.env.reset(**kwargs)

    def close(self):
        """
        Properly clean up the keyboard listener when the environment is closed
        """
        if self.listener is not None:
            self.listener.stop()
        super().close()


class ResetWrapper(gym.Wrapper):
    def __init__(
        self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5
    ):
        super().__init__(env)
        self.reset_fn = reset_fn
        self.reset_time_s = reset_time_s

        self.robot = self.unwrapped.robot
        self.init_pos = self.unwrapped.initial_follower_position

    def reset(self, *, seed=None, options=None):
        if self.reset_fn is not None:
            self.reset_fn(self.env)
        else:
            log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True)
            start_time = time.perf_counter()
            while time.perf_counter() - start_time < self.reset_time_s:
                self.robot.teleop_step()

            log_say("Manual resetting of the environment done.", play_sounds=True)
        return super().reset(seed=seed, options=options)


def make_robot_env(
    robot,
    reward_classifier,
    crop_params_dict=None,
    fps=30,
    control_time_s=20,
    reset_follower_pos=True,
    display_cameras=False,
    device="cuda:0",
    resize_size=None,
    reset_time_s=10,
    delta_action=0.1,
    nb_repeats=1,
    use_relative_joint_positions=False,
):
    """
    Factory function to create the robot environment.

    Mimics gym.make() for consistent environment creation.
    """
    env = HILSerlRobotEnv(robot, display_cameras)
    env = ConvertToLeRobotObservation(env, device)
    # if crop_params_dict is not None:
    #     env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
    # env = RewardWrapper(env, reward_classifier)
    env = TimeLimitWrapper(env, control_time_s, fps)
    # if use_relative_joint_positions:
    #     env = RelativeJointPositionActionWrapper(env, delta=delta_action)
    # env = ActionRepeatWrapper(env, nb_repeat=nb_repeats)
    env = KeyboardInterfaceWrapper(env)
    env = ResetWrapper(env, reset_fn=None, reset_time_s=reset_time_s)
    return env


def get_classifier(pretrained_path, config_path, device="mps"):
    if pretrained_path is None or config_path is None:
        return

    from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
    from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier

    cfg = init_hydra_config(config_path)

    classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
    classifier_config.num_cameras = len(cfg.training.image_keys)  # TODO automate these paths
    model = Classifier(classifier_config)
    model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
    model = model.to(device)
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fps", type=int, default=30, help="control frequency")
    parser.add_argument(
        "--robot-path",
        type=str,
        default="lerobot/configs/robot/koch.yaml",
        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
    )
    parser.add_argument(
        "--robot-overrides",
        type=str,
        nargs="*",
        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
    )
    parser.add_argument(
        "-p",
        "--pretrained-policy-name-or-path",
        help=(
            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
            "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
            "(useful for debugging). This argument is mutually exclusive with `--config`."
        ),
    )
    parser.add_argument(
        "--config",
        help=(
            "Path to a yaml config you want to use for initializing a policy from scratch (useful for "
            "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
        ),
    )
    parser.add_argument(
        "--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
    )
    parser.add_argument(
        "--reward-classifier-pretrained-path",
        type=str,
        default=None,
        help="Path to the pretrained classifier weights.",
    )
    parser.add_argument(
        "--reward-classifier-config-file",
        type=str,
        default=None,
        help="Path to a yaml config file that is necessary to build the reward classifier model.",
    )
    parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
    parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes")
    args = parser.parse_args()

    robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
    robot = make_robot(robot_cfg)

    reward_classifier = get_classifier(
        args.reward_classifier_pretrained_path, args.reward_classifier_config_file
    )

    env = make_robot_env(
        robot,
        reward_classifier,
        None,
        args.fps,
        args.control_time_s,
        args.reset_follower_pos,
        args.display_cameras,
        device="mps",
        resize_size=None,
        reset_time_s=10,
        delta_action=0.1,
        nb_repeats=1,
        use_relative_joint_positions=False,
    )

    env.reset()
    init_pos = env.unwrapped.initial_follower_position
    goal_pos = init_pos

    right_goal = init_pos.copy()
    right_goal[0] += 50

    left_goal = init_pos.copy()
    left_goal[0] -= 50

    # Michel is a beast
    pitch_angle = np.linspace(left_goal[0], right_goal[0], 1000)

    while True:
        for i in range(len(pitch_angle)):
            goal_pos[0] = pitch_angle[i]
            obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False))
            if terminated or truncated:
                logging.info("Max control time reached, reset environment.")
                env.reset()

        for i in reversed(range(len(pitch_angle))):
            goal_pos[0] = pitch_angle[i]
            obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False))
            if terminated or truncated:
                logging.info("Max control time reached, reset environment.")
                env.reset()

@@ -1,5 +1,3 @@
 #!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

@@ -23,6 +21,7 @@ import hydra
 import numpy as np
 import torch
 import torch.nn as nn
+import wandb
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored

@@ -32,7 +31,6 @@ from torch.cuda.amp import GradScaler
 from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
 from tqdm import tqdm

-import wandb
 from lerobot.common.datasets.factory import resolve_delta_timestamps
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.logger import Logger

@@ -45,6 +43,7 @@ from lerobot.common.utils.utils import (
     init_hydra_config,
     set_global_seed,
 )
+from lerobot.scripts.server.buffer import random_shift


 def get_model(cfg, logger):  # noqa I001

@@ -82,6 +81,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
     for batch_idx, batch in enumerate(pbar):
         start_time = time.perf_counter()
         images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
+        images = [random_shift(img, 4) for img in images]
         labels = batch[cfg.training.label_key].float().to(device)

         # Forward pass with optional AMP

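random_shift(img, 4) is the usual DrQ-style pad-and-crop image augmentation. The helper's exact implementation lives in lerobot.scripts.server.buffer, so the sketch below is an assumption of the standard recipe, not that code:

    # Sketch: DrQ-style random shift - replicate-pad by `pad` pixels, then crop back
    # to the original size at a random offset.
    import torch
    import torch.nn.functional as F

    def random_shift_sketch(imgs: torch.Tensor, pad: int = 4) -> torch.Tensor:
        _, _, h, w = imgs.shape
        padded = F.pad(imgs, (pad, pad, pad, pad), mode="replicate")
        top = int(torch.randint(0, 2 * pad + 1, (1,)))
        left = int(torch.randint(0, 2 * pad + 1, (1,)))
        return padded[:, :, top : top + h, left : left + w]

    batch = torch.rand(16, 3, 128, 128)
    assert random_shift_sketch(batch).shape == batch.shape
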
@@ -161,14 +161,17 @@ def validate(model, val_loader, criterion, device, logger, cfg):

         # Log sample predictions for visualization
         if len(samples) < cfg.eval.num_samples_to_log:
-            for i in range(min( cfg.eval.num_samples_to_log - len(samples), len(images))):
+            for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
                 if model.config.num_classes == 2:
                     confidence = round(outputs.probabilities[i].item(), 3)
                 else:
                     confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
                 samples.append(
                     {
-                        **{f"image_{img_key}": wandb.Image(images[img_idx][i].cpu()) for img_idx, img_key in enumerate(cfg.training.image_keys)},
+                        **{
+                            f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
+                            for img_idx, img_key in enumerate(cfg.training.image_keys)
+                        },
                         "true_label": labels[i].item(),
                         "predicted": predictions[i].item(),
                         "confidence": confidence,

@@ -270,11 +273,13 @@ def train(cfg: DictConfig) -> None:
     device = get_safe_torch_device(cfg.device, log=True)
     set_global_seed(cfg.seed)

-    out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "classifier"
+    out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "frozen_resnet10_2"
     logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)

     # Setup dataset and dataloaders
-    dataset = LeRobotDataset(cfg.dataset_repo_id)
+    dataset = LeRobotDataset(
+        cfg.dataset_repo_id, root=cfg.dataset_root, local_files_only=cfg.local_files_only
+    )
     logging.info(f"Dataset size: {len(dataset)}")

     n_total = len(dataset)

@@ -282,14 +287,13 @@ def train(cfg: DictConfig) -> None:
     train_dataset = torch.utils.data.Subset(dataset, range(0, n_train))
     val_dataset = torch.utils.data.Subset(dataset, range(n_train, n_total))

-
     sampler = create_balanced_sampler(train_dataset, cfg)
     train_loader = DataLoader(
         train_dataset,
         batch_size=cfg.training.batch_size,
         num_workers=cfg.training.num_workers,
         sampler=sampler,
-        pin_memory=True,
+        pin_memory=device.type == "cuda",
     )

     val_loader = DataLoader(

@@ -297,7 +301,7 @@ def train(cfg: DictConfig) -> None:
         batch_size=cfg.eval.batch_size,
         shuffle=False,
         num_workers=cfg.training.num_workers,
-        pin_memory=True,
+        pin_memory=device.type == "cuda",
     )

     # Resume training if requested

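Gating pin_memory on the device type is a small but real fix: pinned host memory only speeds up host-to-CUDA copies and is wasted effort on CPU or MPS runs (this config defaults to device: "mps"). The pattern in isolation:

    # Sketch: enable pinned memory only when batches will be copied to a CUDA device.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = TensorDataset(torch.rand(64, 3, 128, 128), torch.randint(0, 2, (64,)))
    loader = DataLoader(data, batch_size=16, pin_memory=device.type == "cuda")
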