Added the possibility to record and replay delta actions during teleoperation rather than absolute actions

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi 2025-02-12 19:25:41 +01:00
parent 6868c88ef1
commit b9217b06db
12 changed files with 92 additions and 622 deletions
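In short, when delta actions are enabled the recording loop stores each action as an offset from the follower arm's current joint positions, and replay adds that offset back onto the current positions before sending the command. A minimal sketch of the idea (hypothetical helper names; the actual changes live in control_loop and replay in the diffs below):

import numpy as np

def to_delta_action(action: np.ndarray, current_joint_positions: np.ndarray) -> np.ndarray:
    # Recording side: store the action relative to the follower's current position.
    return action - current_joint_positions

def to_absolute_action(delta_action: np.ndarray, current_joint_positions: np.ndarray) -> np.ndarray:
    # Replay side: reconstruct the absolute target before sending it to the robot.
    return delta_action + current_joint_positions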

View File

@ -84,7 +84,7 @@ class LeRobotDatasetMetadata:
# Load metadata
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
# self.pull_from_repo(allow_patterns="meta/")
self.info = load_info(self.root)
self.stats = load_stats(self.root)
self.tasks = load_tasks(self.root)
@ -537,7 +537,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
]
files += video_files
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
# HACK: pull_from_repo is commented out below; reviewers, please advise whether it should be re-enabled
logging.warning("HACK: pull_from_repo is disabled; if datasets behave unexpectedly, uncomment the line below")
# self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""

View File

@ -147,6 +147,8 @@ class Classifier(
def predict_reward(self, x, threshold=0.6):
if self.config.num_classes == 2:
return (self.forward(x).probabilities > threshold).float()
probs = self.forward(x).probabilities
logging.info(f"Predicted reward images: {probs}")
return (probs > threshold).float()
else:
return torch.argmax(self.forward(x).probabilities, dim=1)
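The two-class branch keeps the returned reward binary while the added logging exposes the raw probabilities. The thresholding itself reduces to the following standalone illustration (made-up probabilities, not actual classifier output):

import torch

probabilities = torch.tensor([0.83, 0.41, 0.67])  # hypothetical per-frame probabilities
threshold = 0.5
reward = (probabilities > threshold).float()
print(reward)  # tensor([1., 0., 1.])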

View File

@ -225,6 +225,7 @@ def record_episode(
device,
use_amp,
fps,
record_delta_actions,
):
control_loop(
robot=robot,
@ -236,6 +237,7 @@ def record_episode(
device=device,
use_amp=use_amp,
fps=fps,
record_delta_actions=record_delta_actions,
teleoperate=policy is None,
)
@ -252,6 +254,7 @@ def control_loop(
device=None,
use_amp=None,
fps=None,
record_delta_actions=False,
):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
@ -274,8 +277,12 @@ def control_loop(
while timestamp < control_time_s:
start_loop_t = time.perf_counter()
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
if teleoperate:
observation, action = robot.teleop_step(record_data=True)
if record_delta_actions:
action["action"] = action["action"] - current_joint_positions
else:
observation = robot.capture_observation()
@ -290,8 +297,12 @@ def control_loop(
frame = {**observation, **action}
if "next.reward" in events:
frame["next.reward"] = events["next.reward"]
frame["next.done"] = (events["next.reward"] == 1) or (events["exit_early"])
dataset.add_frame(frame)
# if frame["next.done"]:
# break
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:

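With reward assignment enabled, each recorded frame now carries the reward and a done flag derived from it, and the stored action is already a delta when record_delta_actions is set. A standalone illustration of the frame layout passed to add_frame above, using made-up values:

# Hypothetical stand-ins for the loop variables above.
events = {"next.reward": 1, "exit_early": False}
observation = {"observation.state": [0.0] * 6}
action = {"action": [0.1] * 6}  # already a delta when record_delta_actions is set

frame = {
    **observation,
    **action,
    "next.reward": events["next.reward"],
    "next.done": (events["next.reward"] == 1) or events["exit_early"],
}
print(frame["next.done"])  # True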
View File

@ -12,8 +12,10 @@ env:
wrapper:
crop_params_dict:
observation.images.laptop: [58, 89, 357, 455]
observation.images.phone: [3, 4, 471, 633]
observation.images.front: [126, 43, 329, 518]
observation.images.side: [93, 69, 381, 434]
# observation.images.front: [135, 59, 331, 527]
# observation.images.side: [79, 47, 397, 450]
resize_size: [128, 128]
control_time_s: 20
reset_follower_pos: true

View File

@ -4,7 +4,9 @@ defaults:
- _self_
seed: 13
dataset_repo_id: aractingi/push_green_cube_hf_cropped_resized
dataset_repo_id: aractingi/push_cube_square_reward_cropped_resized
dataset_root: data/aractingi/push_cube_square_reward_cropped_resized
local_files_only: true
train_split_proportion: 0.8
# Required by logger
@ -14,7 +16,7 @@ env:
training:
num_epochs: 5
num_epochs: 6
batch_size: 16
learning_rate: 1e-4
num_workers: 4
@ -25,7 +27,7 @@ training:
save_freq: 1 # How often to save checkpoints (in epochs)
save_checkpoint: true
# image_keys: ["observation.images.top", "observation.images.wrist"]
image_keys: ["observation.images.laptop", "observation.images.phone"]
image_keys: ["observation.images.front", "observation.images.side"]
label_key: "next.reward"
profile_inference_time: false
profile_inference_time_iters: 20
@ -35,8 +37,8 @@ eval:
num_samples_to_log: 30 # Number of validation samples to log in the table
policy:
name: "hilserl/classifier/push_green_cube_hf_cropped_resized" #"hilserl/classifier/pick_place_lego_cube_1"
model_name: "helper2424/resnet10"
name: "hilserl/classifier/push_cube_square_reward_cropped_resized" #"hilserl/classifier/pick_place_lego_cube_120
model_name: "helper2424/resnet10" # "facebook/convnext-base-224" #"helper2424/resnet10"
model_type: "cnn"
num_cameras: 2 # Has to be len(training.image_keys)
@ -48,4 +50,4 @@ wandb:
device: "mps"
resume: false
output_dir: "outputs/classifier"
output_dir: "outputs/classifier/resnet10_frozen"

View File

@ -57,22 +57,22 @@ policy:
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"]
observation.images.laptop: [3, 128, 128]
observation.images.phone: [3, 128, 128]
observation.images.front: [3, 128, 128]
observation.images.side: [3, 128, 128]
# observation.image: [3, 128, 128]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.laptop: mean_std
observation.images.phone: mean_std
observation.images.front: mean_std
observation.images.side: mean_std
observation.state: min_max
input_normalization_params:
observation.images.laptop:
observation.images.front:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
observation.images.phone:
observation.images.side:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
observation.state:

View File

@ -50,13 +50,13 @@ follower_arms:
gripper: [6, "sts3215"]
cameras:
laptop:
front:
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
camera_index: 0
fps: 30
width: 640
height: 480
phone:
side:
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
camera_index: 1
fps: 30

View File

@ -206,7 +206,8 @@ def record(
num_image_writer_threads_per_camera: int = 4,
display_cameras: bool = True,
play_sounds: bool = True,
reset_follower: bool = False,
reset_follower: bool = False,
record_delta_actions: bool = False,
resume: bool = False,
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
local_files_only: bool = False,
@ -218,7 +219,12 @@ def record(
device = None
use_amp = None
extra_features = (
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
{
"next.reward": {"dtype": "int64", "shape": (1,), "names": None},
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
}
if assign_rewards
else None
)
if single_task:
@ -269,7 +275,7 @@ def record(
if reset_follower:
initial_position = robot.follower_arms["main"].read("Present_Position")
# Execute a few seconds without recording to:
# 1. teleoperate the robot to move it in starting position if no policy provided,
# 2. give times to the robot devices to connect and start synchronizing,
@ -302,6 +308,7 @@ def record(
device=device,
use_amp=use_amp,
fps=fps,
record_delta_actions=record_delta_actions,
)
# Execute a few seconds without recording to give time to manually reset the environment
@ -353,21 +360,24 @@ def replay(
fps: int | None = None,
play_sounds: bool = True,
local_files_only: bool = False,
replay_delta_actions: bool = False,
):
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
actions = dataset.hf_dataset.select_columns("action")
if not robot.is_connected:
robot.connect()
log_say("Replaying episode", play_sounds, blocking=True)
for idx in range(dataset.num_frames):
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
start_episode_t = time.perf_counter()
action = actions[idx]["action"]
if replay_delta_actions:
action = action + current_joint_positions
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
@ -534,6 +544,12 @@ if __name__ == "__main__":
default=0,
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
)
parser_record.add_argument(
"--record-delta-actions",
type=int,
default=0,
help="Enables the recording of delta actions instead of absolute actions.",
)
parser_record.add_argument(
"--reset-follower",
type=int,
@ -563,6 +579,12 @@ if __name__ == "__main__":
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
parser_replay.add_argument(
"--replay-delta-actions",
type=int,
default=0,
help="Enables the replay of delta actions instead of absolute actions.",
)
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
args = parser.parse_args()

View File

@ -239,13 +239,17 @@ if __name__ == "__main__":
)
args = parser.parse_args()
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=True)
images = get_image_from_lerobot_dataset(dataset)
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
rois = select_square_roi_for_images(images)
# rois = {
# "observation.images.front": [126, 43, 329, 518],
# "observation.images.side": [93, 69, 381, 434],
# }
# Print the selected rectangular ROIs
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")

View File

@ -230,6 +230,8 @@ class HILSerlRobotEnv(gym.Env):
if teleop_action.dim() == 1:
teleop_action = teleop_action.unsqueeze(0)
# self.render()
self.current_step += 1
reward = 0.0
@ -255,8 +257,7 @@ class HILSerlRobotEnv(gym.Env):
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
cv2.waitKey(1)
def close(self):
"""
@ -311,10 +312,14 @@ class RewardWrapper(gym.Wrapper):
start_time = time.perf_counter()
with torch.inference_mode():
reward = (
self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0
self.reward_classifier.predict_reward(images, threshold=0.5)
if self.reward_classifier is not None
else 0.0
)
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
logging.info(f"Reward: {reward}")
if reward == 1.0:
terminated = True
return observation, reward, terminated, truncated, info
@ -760,7 +765,7 @@ if __name__ == "__main__":
env = make_robot_env(
robot,
reward_classifier,
cfg.wrapper,
cfg.env, # .wrapper,
)
env.reset()

View File

@ -1,584 +0,0 @@
import argparse
import logging
import time
from threading import Lock
from typing import Annotated, Any, Callable, Dict, Optional, Tuple
import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.robot_devices.control_utils import is_headless, reset_follower_position
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config, log_say
logging.basicConfig(level=logging.INFO)
class HILSerlRobotEnv(gym.Env):
"""
Gym-like environment wrapper for robot policy evaluation.
This wrapper provides a consistent interface for interacting with the robot,
following the OpenAI Gym environment conventions.
"""
def __init__(
self,
robot,
display_cameras=False,
):
"""
Initialize the robot environment.
Args:
robot: The robot interface object
reward_classifier: Optional reward classifier
fps: Frames per second for control
control_time_s: Total control time for each episode
display_cameras: Whether to display camera feeds
"""
super().__init__()
self.robot = robot
self.display_cameras = display_cameras
# connect robot
if not self.robot.is_connected:
self.robot.connect()
# Dynamically determine observation and action spaces
self._setup_spaces()
self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")
# Episode tracking
self.current_step = 0
self.episode_data = None
def _setup_spaces(self):
"""
Dynamically determine observation and action spaces based on robot capabilities.
This method should be customized based on the specific robot's observation
and action representations.
"""
# Example space setup - you'll need to adapt this to your specific robot
example_obs = self.robot.capture_observation()
# Observation space (assuming image-based observations)
image_keys = [key for key in example_obs if "image" in key]
state_keys = [key for key in example_obs if "image" not in key]
observation_spaces = {
key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
for key in image_keys
}
observation_spaces["observation.state"] = gym.spaces.Dict(
{
key: gym.spaces.Box(low=0, high=10, shape=example_obs[key].shape, dtype=np.float32)
for key in state_keys
}
)
self.observation_space = gym.spaces.Dict(observation_spaces)
# Action space (assuming joint positions)
action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
self.action_space = gym.spaces.Tuple(
(
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(action_dim,), dtype=np.float32),
gym.spaces.Discrete(2),
),
)
def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
"""
Reset the environment to initial state.
Returns:
observation (dict): Initial observation
info (dict): Additional information
"""
super().reset(seed=seed, options=options)
# Capture initial observation
observation = self.robot.capture_observation()
# Reset tracking variables
self.current_step = 0
self.episode_data = None
return observation, {"initial_position": self.initial_follower_position}
def step(
self, action: Tuple[np.ndarray, bool]
) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
"""
Take a step in the environment.
Args:
action tuple(np.ndarray, bool):
Policy action to be executed on the robot and boolean to determine
whether to choose policy action or expert action.
Returns:
observation (dict): Next observation
reward (float): Reward for this step
terminated (bool): Whether the episode has terminated
truncated (bool): Whether the episode was truncated
info (dict): Additional information
"""
# The actions received are in the form of a tuple containing the policy action and an intervention bool
# The boolean indicates whether we will use the expert's actions (through teleoperation) or the policy actions
policy_action, intervention_bool = action
teleop_action = None
if not intervention_bool:
self.robot.send_action(policy_action.cpu())
observation = self.robot.capture_observation()
else:
observation, teleop_action = self.robot.teleop_step(record_data=True)
teleop_action = teleop_action["action"] # teleop step returns torch tensors but in a dict
self.current_step += 1
reward = 0.0
terminated = False
truncated = False
return (
observation,
reward,
terminated,
truncated,
{"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
)
def render(self):
"""
Render the environment (in this case, display camera feeds).
"""
import cv2
observation = self.robot.capture_observation()
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
def close(self):
"""
Close the environment and disconnect the robot.
"""
if self.robot.is_connected:
self.robot.disconnect()
class ActionRepeatWrapper(gym.Wrapper):
def __init__(self, env, nb_repeat: int = 1):
super().__init__(env)
self.nb_repeat = nb_repeat
def step(self, action):
for _ in range(self.nb_repeat):
obs, reward, done, truncated, info = self.env.step(action)
if done or truncated:
break
return obs, reward, done, truncated, info
class RelativeJointPositionActionWrapper(gym.Wrapper):
def __init__(self, env: HILSerlRobotEnv, delta: float = 0.1):
super().__init__(env)
self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
self.delta = delta
def step(self, action):
action_joint = action
self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
if isinstance(self.env.action_space, gym.spaces.Tuple):
action_joint = action[0]
joint_positions = self.joint_positions + (self.delta * action_joint)
# clip the joint positions to the joint limits defined by the action space
joint_positions = np.clip(joint_positions, self.action_space.low, self.action_space.high)
if isinstance(self.env.action_space, gym.spaces.Tuple):
return self.env.step((joint_positions, action[1]))
obs, reward, terminated, truncated, info = self.env.step(joint_positions)
if info["is_intervention"]:
# teleop actions are returned in absolute joint space
# If we are using a relative joint position action space,
# there will be a mismatch between the spaces of the policy and teleop actions
# Solution is to transform the teleop actions into relative space.
teleop_action = info["action_intervention"] # teleop actions are in absolute joint space
relative_teleop_action = (teleop_action - self.joint_positions) / self.delta
info["action_intervention"] = relative_teleop_action
return self.env.step(joint_positions)
class RewardWrapper(gym.Wrapper):
def __init__(self, env, reward_classifier: Optional[None], device: torch.device = "cuda"):
self.env = env
self.reward_classifier = reward_classifier
self.device = device
def step(self, action):
observation, _, terminated, truncated, info = self.env.step(action)
images = [
observation[key].to(self.device, non_blocking=True) for key in observation if "image" in key
]
reward = self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0
reward = reward.item()
return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
return self.env.reset(seed=seed, options=options)
class TimeLimitWrapper(gym.Wrapper):
def __init__(self, env, control_time_s, fps):
self.env = env
self.control_time_s = control_time_s
self.fps = fps
self.last_timestamp = 0.0
self.episode_time_in_s = 0.0
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
time_since_last_step = time.perf_counter() - self.last_timestamp
self.episode_time_in_s += time_since_last_step
self.last_timestamp = time.perf_counter()
# check if last timestep took more time than the expected fps
if 1.0 / time_since_last_step < self.fps:
logging.warning(f"Current timestep exceeded expected fps {self.fps}")
if self.episode_time_in_s > self.control_time_s:
# Terminated = True
terminated = True
return obs, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
self.episode_time_in_s = 0.0
self.last_timestamp = time.perf_counter()
return self.env.reset(seed=seed, options=options)
class ImageCropResizeWrapper(gym.Wrapper):
def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
self.env = env
self.crop_params_dict = crop_params_dict
for key in crop_params_dict:
assert key in self.env.observation_space, f"Key {key} not in observation space"
top, left, height, width = crop_params_dict[key]
new_shape = (top + height, left + width)
self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)
self.resize_size = resize_size
if self.resize_size is None:
self.resize_size = (128, 128)
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
for k in self.crop_params_dict:
obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
obs[k] = F.resize(obs[k], self.resize_size)
return obs, reward, terminated, truncated, info
class ConvertToLeRobotObservation(gym.ObservationWrapper):
def __init__(self, env, device):
super().__init__(env)
self.device = device
def observation(self, observation):
observation = preprocess_observation(observation)
observation = {key: observation[key].to(self.device, non_blocking=True) for key in observation}
observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
return observation
class KeyboardInterfaceWrapper(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.listener = None
self.events = {
"exit_early": False,
"pause_policy": False,
"reset_env": False,
"human_intervention_step": False,
}
self.event_lock = Lock() # Thread-safe access to events
self._init_keyboard_listener()
def _init_keyboard_listener(self):
"""Initialize keyboard listener if not in headless mode"""
if is_headless():
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
return
try:
from pynput import keyboard
def on_press(key):
with self.event_lock:
try:
if key == keyboard.Key.right or key == keyboard.Key.esc:
print("Right arrow key pressed. Exiting loop...")
self.events["exit_early"] = True
elif key == keyboard.Key.space:
if not self.events["pause_policy"]:
print(
"Space key pressed. Human intervention required.\n"
"Place the leader in similar pose to the follower and press space again."
)
self.events["pause_policy"] = True
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
elif self.events["pause_policy"] and not self.events["human_intervention_step"]:
self.events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.")
log_say("Starting human intervention.", play_sounds=True)
else:
self.events["pause_policy"] = False
self.events["human_intervention_step"] = False
print("Space key pressed for a third time.")
log_say("Continuing with policy actions.", play_sounds=True)
except Exception as e:
print(f"Error handling key press: {e}")
self.listener = keyboard.Listener(on_press=on_press)
self.listener.start()
except ImportError:
logging.warning("Could not import pynput. Keyboard interface will not be available.")
self.listener = None
def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
is_intervention = False
terminated_by_keyboard = False
# Extract policy_action if needed
if isinstance(self.env.action_space, gym.spaces.Tuple):
policy_action = action[0]
# Check the event flags without holding the lock for too long.
with self.event_lock:
if self.events["exit_early"]:
terminated_by_keyboard = True
# If we need to wait for human intervention, we note that outside the lock.
pause_policy = self.events["pause_policy"]
if pause_policy:
# Now, wait for human_intervention_step without holding the lock
while True:
with self.event_lock:
if self.events["human_intervention_step"]:
is_intervention = True
break
time.sleep(0.1) # Check more frequently if desired
# Execute the step in the underlying environment
obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
return obs, reward, terminated or terminated_by_keyboard, truncated, info
def reset(self, **kwargs) -> Tuple[Any, Dict]:
"""
Reset the environment and clear any pending events
"""
with self.event_lock:
self.events = {k: False for k in self.events}
return self.env.reset(**kwargs)
def close(self):
"""
Properly clean up the keyboard listener when the environment is closed
"""
if self.listener is not None:
self.listener.stop()
super().close()
class ResetWrapper(gym.Wrapper):
def __init__(
self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5
):
super().__init__(env)
self.reset_fn = reset_fn
self.reset_time_s = reset_time_s
self.robot = self.unwrapped.robot
self.init_pos = self.unwrapped.initial_follower_position
def reset(self, *, seed=None, options=None):
if self.reset_fn is not None:
self.reset_fn(self.env)
else:
log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True)
start_time = time.perf_counter()
while time.perf_counter() - start_time < self.reset_time_s:
self.robot.teleop_step()
log_say("Manual reseting of the environment done.", play_sounds=True)
return super().reset(seed=seed, options=options)
def make_robot_env(
robot,
reward_classifier,
crop_params_dict=None,
fps=30,
control_time_s=20,
reset_follower_pos=True,
display_cameras=False,
device="cuda:0",
resize_size=None,
reset_time_s=10,
delta_action=0.1,
nb_repeats=1,
use_relative_joint_positions=False,
):
"""
Factory function to create the robot environment.
Mimics gym.make() for consistent environment creation.
"""
env = HILSerlRobotEnv(robot, display_cameras)
env = ConvertToLeRobotObservation(env, device)
# if crop_params_dict is not None:
# env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
# env = RewardWrapper(env, reward_classifier)
env = TimeLimitWrapper(env, control_time_s, fps)
# if use_relative_joint_positions:
# env = RelativeJointPositionActionWrapper(env, delta=delta_action)
# env = ActionRepeatWrapper(env, nb_repeat=nb_repeats)
env = KeyboardInterfaceWrapper(env)
env = ResetWrapper(env, reset_fn=None, reset_time_s=reset_time_s)
return env
def get_classifier(pretrained_path, config_path, device="mps"):
if pretrained_path is None or config_path is None:
return
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to(device)
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fps", type=int, default=30, help="control frequency")
parser.add_argument(
"--robot-path",
type=str,
default="lerobot/configs/robot/koch.yaml",
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
)
parser.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser.add_argument(
"-p",
"--pretrained-policy-name-or-path",
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
"(useful for debugging). This argument is mutually exclusive with `--config`."
),
)
parser.add_argument(
"--config",
help=(
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument(
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
)
parser.add_argument(
"--reward-classifier-pretrained-path",
type=str,
default=None,
help="Path to the pretrained classifier weights.",
)
parser.add_argument(
"--reward-classifier-config-file",
type=str,
default=None,
help="Path to a yaml config file that is necessary to build the reward classifier model.",
)
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes")
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
robot = make_robot(robot_cfg)
reward_classifier = get_classifier(
args.reward_classifier_pretrained_path, args.reward_classifier_config_file
)
env = make_robot_env(
robot,
reward_classifier,
None,
args.fps,
args.control_time_s,
args.reset_follower_pos,
args.display_cameras,
device="mps",
resize_size=None,
reset_time_s=10,
delta_action=0.1,
nb_repeats=1,
use_relative_joint_positions=False,
)
env.reset()
init_pos = env.unwrapped.initial_follower_position
goal_pos = init_pos
right_goal = init_pos.copy()
right_goal[0] += 50
left_goal = init_pos.copy()
left_goal[0] -= 50
# Michel is a beast
pitch_angle = np.linspace(left_goal[0], right_goal[0], 1000)
while True:
for i in range(len(pitch_angle)):
goal_pos[0] = pitch_angle[i]
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False))
if terminated or truncated:
logging.info("Max control time reached, reset environment.")
env.reset()
for i in reversed(range(len(pitch_angle))):
goal_pos[0] = pitch_angle[i]
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False))
if terminated or truncated:
logging.info("Max control time reached, reset environment.")
env.reset()

View File

@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -23,6 +21,7 @@ import hydra
import numpy as np
import torch
import torch.nn as nn
import wandb
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
@ -32,7 +31,6 @@ from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
from tqdm import tqdm
import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
@ -45,6 +43,7 @@ from lerobot.common.utils.utils import (
init_hydra_config,
set_global_seed,
)
from lerobot.scripts.server.buffer import random_shift
def get_model(cfg, logger): # noqa I001
@ -82,6 +81,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
for batch_idx, batch in enumerate(pbar):
start_time = time.perf_counter()
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
images = [random_shift(img, 4) for img in images]
labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP
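random_shift here is the image augmentation imported from lerobot.scripts.server.buffer; its implementation is not part of this diff. It is presumably the DrQ-style shift (pad, then randomly crop back to the original size). A minimal sketch of that kind of augmentation, assuming batched CHW image tensors (the real buffer.random_shift may differ, e.g. by shifting each sample independently):

import torch
import torch.nn.functional as F  # noqa: N812

def random_shift(images: torch.Tensor, pad: int = 4) -> torch.Tensor:
    # Replicate-pad by `pad` pixels, then randomly crop back to the original
    # spatial size; the same shift is applied to the whole batch in this sketch.
    _, _, h, w = images.shape
    padded = F.pad(images, (pad, pad, pad, pad), mode="replicate")
    top = int(torch.randint(0, 2 * pad + 1, (1,)))
    left = int(torch.randint(0, 2 * pad + 1, (1,)))
    return padded[:, :, top : top + h, left : left + w]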
@ -161,14 +161,17 @@ def validate(model, val_loader, criterion, device, logger, cfg):
# Log sample predictions for visualization
if len(samples) < cfg.eval.num_samples_to_log:
for i in range(min( cfg.eval.num_samples_to_log - len(samples), len(images))):
for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
if model.config.num_classes == 2:
confidence = round(outputs.probabilities[i].item(), 3)
else:
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
samples.append(
{
**{f"image_{img_key}": wandb.Image(images[img_idx][i].cpu()) for img_idx, img_key in enumerate(cfg.training.image_keys)},
**{
f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
for img_idx, img_key in enumerate(cfg.training.image_keys)
},
"true_label": labels[i].item(),
"predicted": predictions[i].item(),
"confidence": confidence,
@ -270,11 +273,13 @@ def train(cfg: DictConfig) -> None:
device = get_safe_torch_device(cfg.device, log=True)
set_global_seed(cfg.seed)
out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "classifier"
out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "frozen_resnet10_2"
logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
# Setup dataset and dataloaders
dataset = LeRobotDataset(cfg.dataset_repo_id)
dataset = LeRobotDataset(
cfg.dataset_repo_id, root=cfg.dataset_root, local_files_only=cfg.local_files_only
)
logging.info(f"Dataset size: {len(dataset)}")
n_total = len(dataset)
@ -282,14 +287,13 @@ def train(cfg: DictConfig) -> None:
train_dataset = torch.utils.data.Subset(dataset, range(0, n_train))
val_dataset = torch.utils.data.Subset(dataset, range(n_train, n_total))
sampler = create_balanced_sampler(train_dataset, cfg)
train_loader = DataLoader(
train_dataset,
batch_size=cfg.training.batch_size,
num_workers=cfg.training.num_workers,
sampler=sampler,
pin_memory=True,
pin_memory=device.type == "cuda",
)
val_loader = DataLoader(
@ -297,7 +301,7 @@ def train(cfg: DictConfig) -> None:
batch_size=cfg.eval.batch_size,
shuffle=False,
num_workers=cfg.training.num_workers,
pin_memory=True,
pin_memory=device.type == "cuda",
)
# Resume training if requested