diff --git a/examples/real_robot_example/0_record_training_data.py b/examples/real_robot_example/0_record_training_data.py new file mode 100644 index 00000000..5412002f --- /dev/null +++ b/examples/real_robot_example/0_record_training_data.py @@ -0,0 +1,200 @@ +import argparse +import copy +import os +import time + +import gym_real_env # noqa: F401 +import gymnasium as gym +import numpy as np +import torch +from datasets import Dataset, Features, Sequence, Value +from tqdm import tqdm + +from lerobot.common.datasets.compute_stats import compute_stats +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, DATA_DIR, LeRobotDataset +from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently +from lerobot.common.datasets.utils import ( + hf_transform_to_torch, +) +from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames +from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_videos_to_hub, save_meta_data + +# parse the repo_id name via command line +parser = argparse.ArgumentParser() +parser.add_argument("--repo_id", type=str, default="blue_red_sort") +parser.add_argument("--num_episodes", type=int, default=2) +parser.add_argument("--num_frames", type=int, default=400) +parser.add_argument("--num_workers", type=int, default=16) +parser.add_argument("--keep_last", action="store_true") +parser.add_argument("--push_to_hub", action="store_true") +parser.add_argument( + "--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset." +) +args = parser.parse_args() + +repo_id = args.repo_id +num_episodes = args.num_episodes +num_frames = args.num_frames +revision = args.revision + +out_data = DATA_DIR / repo_id + +images_dir = out_data / "images" +videos_dir = out_data / "videos" +meta_data_dir = out_data / "meta_data" + + +# Create image and video directories +if not os.path.exists(images_dir): + os.makedirs(images_dir, exist_ok=True) +if not os.path.exists(videos_dir): + os.makedirs(videos_dir, exist_ok=True) + +if __name__ == "__main__": + # Create the gym environment - check the kwargs in gym_real_env/src/env.py + gym_handle = "gym_real_env/RealEnv-v0" + env = gym.make(gym_handle, disable_env_checker=True, record=True) + + ep_dicts = [] + episode_data_index = {"from": [], "to": []} + ep_fps = [] + id_from = 0 + id_to = 0 + os.system('spd-say "env created"') + + for ep_idx in range(num_episodes): + # bring the follower to the leader and start camera + env.reset() + + os.system(f'spd-say "go {ep_idx}"') + # init buffers + obs_replay = {k: [] for k in env.observation_space} + timestamps = [] + + starting_time = time.time() + for _ in tqdm(range(num_frames)): + # Apply the next action + observation, _, _, _, _ = env.step(action=None) + # images_stacked = np.hstack(list(observation['pixels'].values())) + # images_stacked = cv2.cvtColor(images_stacked, cv2.COLOR_RGB2BGR) + # cv2.imshow('frame', images_stacked) + + # store data + for key in observation: + obs_replay[key].append(copy.deepcopy(observation[key])) + timestamps.append(time.time() - starting_time) + # if cv2.waitKey(1) & 0xFF == ord('q'): + # break + + os.system('spd-say "stop"') + + ep_dict = {} + # store images in png and create the video + for img_key in env.cameras: + save_images_concurrently( + obs_replay[f"images.{img_key}"], + images_dir / f"{img_key}_episode_{ep_idx:06d}", + args.num_workers, + ) + # for i in tqdm(range(num_frames)): + # cv2.imwrite(str(images_dir / f"{img_key}_episode_{ep_idx:06d}" / f"frame_{i:06d}.png"), + # obs_replay[i]['pixels'][img_key]) + fname = f"{img_key}_episode_{ep_idx:06d}.mp4" + # store the reference to the video frame + ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps] + # shutil.rmtree(tmp_imgs_dir) + + state = torch.tensor(np.array(obs_replay["agent_pos"])) + action = torch.tensor(np.array(obs_replay["leader_pos"])) + next_done = torch.zeros(num_frames, dtype=torch.bool) + next_done[-1] = True + + ep_dict["observation.state"] = state + ep_dict["action"] = action + ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64) + ep_dict["frame_index"] = torch.arange(0, num_frames, 1) + ep_dict["timestamp"] = torch.tensor(timestamps) + ep_dict["next.done"] = next_done + ep_fps.append(num_frames / timestamps[-1]) + ep_dicts.append(ep_dict) + print(f"Episode {ep_idx} done, fps: {ep_fps[-1]:.2f}") + + episode_data_index["from"].append(id_from) + episode_data_index["to"].append(id_from + num_frames if args.keep_last else id_from + num_frames - 1) + + id_to = id_from + num_frames if args.keep_last else id_from + num_frames - 1 + id_from = id_to + + env.close() + + os.system('spd-say "encode video frames"') + for ep_idx in range(num_episodes): + for img_key in env.cameras: + # If necessary, we may want to encode the video + # with variable frame rate: https://superuser.com/questions/1661901/encoding-video-from-vfr-still-images + encode_video_frames( + images_dir / f"{img_key}_episode_{ep_idx:06d}", + videos_dir / f"{img_key}_episode_{ep_idx:06d}.mp4", + ep_fps[ep_idx], + ) + + os.system('spd-say "concatenate episodes"') + data_dict = concatenate_episodes( + ep_dicts, drop_episodes_last_frame=not args.keep_last + ) # Since our fps varies we are sometimes off tolerance for the last frame + + features = {} + + keys = [key for key in data_dict if "observation.images." in key] + for key in keys: + features[key] = VideoFrame() + + features["observation.state"] = Sequence( + length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) + ) + features["action"] = Sequence( + length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None) + ) + features["episode_index"] = Value(dtype="int64", id=None) + features["frame_index"] = Value(dtype="int64", id=None) + features["timestamp"] = Value(dtype="float32", id=None) + features["next.done"] = Value(dtype="bool", id=None) + features["index"] = Value(dtype="int64", id=None) + # TODO(rcadene): add success + # features["next.success"] = Value(dtype='bool', id=None) + + hf_dataset = Dataset.from_dict(data_dict, features=Features(features)) + hf_dataset.set_transform(hf_transform_to_torch) + + info = { + "fps": sum(ep_fps) / len(ep_fps), # to have a good tolerance in data processing for the slowest video + "video": 1, + } + + os.system('spd-say "from preloaded"') + lerobot_dataset = LeRobotDataset.from_preloaded( + repo_id=repo_id, + version=revision, + hf_dataset=hf_dataset, + episode_data_index=episode_data_index, + info=info, + videos_dir=videos_dir, + ) + os.system('spd-say "compute stats"') + stats = compute_stats(lerobot_dataset) + + os.system('spd-say "save to disk"') + hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved + hf_dataset.save_to_disk(str(out_data / "train")) + + save_meta_data(info, stats, episode_data_index, meta_data_dir) + + if args.push_to_hub: + hf_dataset.push_to_hub(repo_id, token=True, revision="main") + hf_dataset.push_to_hub(repo_id, token=True, revision=revision) + + push_meta_data_to_hub(repo_id, meta_data_dir, revision="main") + push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision) + + push_videos_to_hub(repo_id, videos_dir, revision="main") + push_videos_to_hub(repo_id, videos_dir, revision=revision) diff --git a/examples/real_robot_example/config/act_koch_real.yaml b/examples/real_robot_example/config/act_koch_real.yaml new file mode 100644 index 00000000..b5f4bd98 --- /dev/null +++ b/examples/real_robot_example/config/act_koch_real.yaml @@ -0,0 +1,103 @@ +# @package _global_ + +# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets. +# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images, +# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used +# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation. +# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot). +# Look at its README for more information on how to evaluate a checkpoint in the real-world. +# +# Example of usage for training: +# ```bash +# python lerobot/scripts/train.py \ +# policy=act_real \ +# env=aloha_real +# ``` + +seed: 1000 +dataset_repo_id: thomwolf/blue_sort + +override_dataset_stats: + observation.images.cam_high: + # stats from imagenet, since we use a pretrained vision model + mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) + std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) + observation.images.cam_low: + # stats from imagenet, since we use a pretrained vision model + mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) + std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) + +training: + offline_steps: 1000 + online_steps: 0 + eval_freq: -1 + save_freq: 1000 + log_freq: 100 + save_checkpoint: true + + batch_size: 8 + lr: 1e-5 + lr_backbone: 1e-5 + weight_decay: 1e-4 + grad_clip_norm: 10 + online_steps_between_rollouts: 1 + + delta_timestamps: + action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]" + +eval: + n_episodes: 50 + batch_size: 50 + +# See `configuration_act.py` for more details. +policy: + name: act + + # Input / output structure. + n_obs_steps: 1 + chunk_size: 100 # chunk_size + n_action_steps: 100 + + input_shapes: + # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? + observation.images.cam_high: [3, 480, 640] + observation.images.cam_low: [3, 480, 640] + observation.state: ["${env.state_dim}"] + output_shapes: + action: ["${env.action_dim}"] + + # Normalization / Unnormalization + input_normalization_modes: + observation.images.cam_high: mean_std + observation.images.cam_low: mean_std + observation.state: mean_std + output_normalization_modes: + action: mean_std + + # Architecture. + # Vision backbone. + vision_backbone: resnet18 + pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1 + replace_final_stride_with_dilation: false + # Transformer layers. + pre_norm: false + dim_model: 512 + n_heads: 8 + dim_feedforward: 3200 + feedforward_activation: relu + n_encoder_layers: 4 + # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code + # that means only the first layer is used. Here we match the original implementation by setting this to 1. + # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521. + n_decoder_layers: 1 + # VAE. + use_vae: true + latent_dim: 32 + n_vae_encoder_layers: 4 + + # Inference. + temporal_ensemble_momentum: null + + # Training and loss computation. + dropout: 0.1 + kl_weight: 10.0 diff --git a/examples/real_robot_example/config/dora_koch_real.yaml b/examples/real_robot_example/config/dora_koch_real.yaml new file mode 100644 index 00000000..bf067f50 --- /dev/null +++ b/examples/real_robot_example/config/dora_koch_real.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +fps: 30 + +env: + name: dora + task: DoraKoch-v0 + state_dim: 6 + action_dim: 6 + fps: ${fps} + episode_length: 400 + gym: + fps: ${fps} diff --git a/examples/real_robot_example/gym_real_env/__init__.py b/examples/real_robot_example/gym_real_env/__init__.py new file mode 100644 index 00000000..50cc2a7f --- /dev/null +++ b/examples/real_robot_example/gym_real_env/__init__.py @@ -0,0 +1,8 @@ +from gymnasium.envs.registration import register + +register( + id="gym_real_env/RealEnv-v0", + entry_point="gym_real_env.env:RealEnv", + max_episode_steps=300, + nondeterministic=True, +) diff --git a/examples/real_robot_example/gym_real_env/dynamixel.py b/examples/real_robot_example/gym_real_env/dynamixel.py new file mode 100644 index 00000000..5271acbb --- /dev/null +++ b/examples/real_robot_example/gym_real_env/dynamixel.py @@ -0,0 +1,359 @@ +# ruff: noqa +from __future__ import annotations + +import enum +import math +import os +from dataclasses import dataclass + +import numpy as np +from dynamixel_sdk import * # Uses Dynamixel SDK library + + +def pos2pwm(pos: np.ndarray) -> np.ndarray: + """ + :param pos: numpy array of joint positions in range [-pi, pi] + :return: numpy array of pwm values in range [0, 4096] + """ + return ((pos / 3.14 + 1.0) * 2048).astype(np.int64) + + +def pwm2pos(pwm: np.ndarray) -> np.ndarray: + """ + :param pwm: numpy array of pwm values in range [0, 4096] + :return: numpy array of joint positions in range [-pi, pi] + """ + return (pwm / 2048 - 1) * 3.14 + + +def pwm2vel(pwm: np.ndarray) -> np.ndarray: + """ + :param pwm: numpy array of pwm/s joint velocities + :return: numpy array of rad/s joint velocities + """ + return pwm * 3.14 / 2048 + + +def vel2pwm(vel: np.ndarray) -> np.ndarray: + """ + :param vel: numpy array of rad/s joint velocities + :return: numpy array of pwm/s joint velocities + """ + return (vel * 2048 / 3.14).astype(np.int64) + + +class ReadAttribute(enum.Enum): + TEMPERATURE = 146 + VOLTAGE = 145 + VELOCITY = 128 + POSITION = 132 + CURRENT = 126 + PWM = 124 + HARDWARE_ERROR_STATUS = 70 + HOMING_OFFSET = 20 + BAUDRATE = 8 + + +class OperatingMode(enum.Enum): + VELOCITY = 1 + POSITION = 3 + CURRENT_CONTROLLED_POSITION = 5 + PWM = 16 + UNKNOWN = -1 + + +class Dynamixel: + ADDR_TORQUE_ENABLE = 64 + ADDR_GOAL_POSITION = 116 + ADDR_VELOCITY_LIMIT = 44 + ADDR_GOAL_PWM = 100 + OPERATING_MODE_ADDR = 11 + POSITION_I = 82 + POSITION_P = 84 + ADDR_ID = 7 + + @dataclass + class Config: + def instantiate(self): + return Dynamixel(self) + + baudrate: int = 57600 + protocol_version: float = 2.0 + device_name: str = "" # /dev/tty.usbserial-1120' + dynamixel_id: int = 1 + + def __init__(self, config: Config): + self.config = config + self.connect() + + def connect(self): + if self.config.device_name == "": + for port_name in os.listdir("/dev"): + if "ttyUSB" in port_name or "ttyACM" in port_name: + self.config.device_name = "/dev/" + port_name + print(f"using device {self.config.device_name}") + self.portHandler = PortHandler(self.config.device_name) + # self.portHandler.LA + self.packetHandler = PacketHandler(self.config.protocol_version) + if not self.portHandler.openPort(): + raise Exception(f"Failed to open port {self.config.device_name}") + + if not self.portHandler.setBaudRate(self.config.baudrate): + raise Exception(f"failed to set baudrate to {self.config.baudrate}") + + # self.operating_mode = OperatingMode.UNKNOWN + # self.torque_enabled = False + # self._disable_torque() + + self.operating_modes = [None for _ in range(32)] + self.torque_enabled = [None for _ in range(32)] + return True + + def disconnect(self): + self.portHandler.closePort() + + def set_goal_position(self, motor_id, goal_position): + # if self.operating_modes[motor_id] is not OperatingMode.POSITION: + # self._disable_torque(motor_id) + # self.set_operating_mode(motor_id, OperatingMode.POSITION) + + # if not self.torque_enabled[motor_id]: + # self._enable_torque(motor_id) + + # self._enable_torque(motor_id) + dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx( + self.portHandler, motor_id, self.ADDR_GOAL_POSITION, goal_position + ) + # self._process_response(dxl_comm_result, dxl_error) + # print(f'set position of motor {motor_id} to {goal_position}') + + def set_pwm_value(self, motor_id: int, pwm_value, tries=3): + if self.operating_modes[motor_id] is not OperatingMode.PWM: + self._disable_torque(motor_id) + self.set_operating_mode(motor_id, OperatingMode.PWM) + + if not self.torque_enabled[motor_id]: + self._enable_torque(motor_id) + # print(f'enabling torque') + dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx( + self.portHandler, motor_id, self.ADDR_GOAL_PWM, pwm_value + ) + # self._process_response(dxl_comm_result, dxl_error) + # print(f'set pwm of motor {motor_id} to {pwm_value}') + if dxl_comm_result != COMM_SUCCESS: + if tries <= 1: + raise ConnectionError(f"dxl_comm_result: {self.packetHandler.getTxRxResult(dxl_comm_result)}") + else: + print(f"dynamixel pwm setting failure trying again with {tries - 1} tries") + self.set_pwm_value(motor_id, pwm_value, tries=tries - 1) + elif dxl_error != 0: + print(f"dxl error {dxl_error}") + raise ConnectionError(f"dynamixel error: {self.packetHandler.getTxRxResult(dxl_error)}") + + def read_temperature(self, motor_id: int): + return self._read_value(motor_id, ReadAttribute.TEMPERATURE, 1) + + def read_velocity(self, motor_id: int): + pos = self._read_value(motor_id, ReadAttribute.VELOCITY, 4) + if pos > 2**31: + pos -= 2**32 + # print(f'read position {pos} for motor {motor_id}') + return pos + + def read_position(self, motor_id: int): + pos = self._read_value(motor_id, ReadAttribute.POSITION, 4) + if pos > 2**31: + pos -= 2**32 + # print(f'read position {pos} for motor {motor_id}') + return pos + + def read_position_degrees(self, motor_id: int) -> float: + return (self.read_position(motor_id) / 4096) * 360 + + def read_position_radians(self, motor_id: int) -> float: + return (self.read_position(motor_id) / 4096) * 2 * math.pi + + def read_current(self, motor_id: int): + current = self._read_value(motor_id, ReadAttribute.CURRENT, 2) + if current > 2**15: + current -= 2**16 + return current + + def read_present_pwm(self, motor_id: int): + return self._read_value(motor_id, ReadAttribute.PWM, 2) + + def read_hardware_error_status(self, motor_id: int): + return self._read_value(motor_id, ReadAttribute.HARDWARE_ERROR_STATUS, 1) + + def disconnect(self): + self.portHandler.closePort() + + def set_id(self, old_id, new_id, use_broadcast_id: bool = False): + """ + sets the id of the dynamixel servo + @param old_id: current id of the servo + @param new_id: new id + @param use_broadcast_id: set ids of all connected dynamixels if True. + If False, change only servo with self.config.id + @return: + """ + if use_broadcast_id: + current_id = 254 + else: + current_id = old_id + dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx( + self.portHandler, current_id, self.ADDR_ID, new_id + ) + self._process_response(dxl_comm_result, dxl_error, old_id) + self.config.id = id + + def _enable_torque(self, motor_id): + dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx( + self.portHandler, motor_id, self.ADDR_TORQUE_ENABLE, 1 + ) + self._process_response(dxl_comm_result, dxl_error, motor_id) + self.torque_enabled[motor_id] = True + + def _disable_torque(self, motor_id): + dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx( + self.portHandler, motor_id, self.ADDR_TORQUE_ENABLE, 0 + ) + self._process_response(dxl_comm_result, dxl_error, motor_id) + self.torque_enabled[motor_id] = False + + def _process_response(self, dxl_comm_result: int, dxl_error: int, motor_id: int): + if dxl_comm_result != COMM_SUCCESS: + raise ConnectionError( + f"dxl_comm_result for motor {motor_id}: {self.packetHandler.getTxRxResult(dxl_comm_result)}" + ) + elif dxl_error != 0: + print(f"dxl error {dxl_error}") + raise ConnectionError( + f"dynamixel error for motor {motor_id}: {self.packetHandler.getTxRxResult(dxl_error)}" + ) + + def set_operating_mode(self, motor_id: int, operating_mode: OperatingMode): + dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx( + self.portHandler, motor_id, self.OPERATING_MODE_ADDR, operating_mode.value + ) + self._process_response(dxl_comm_result, dxl_error, motor_id) + self.operating_modes[motor_id] = operating_mode + + def set_pwm_limit(self, motor_id: int, limit: int): + dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(self.portHandler, motor_id, 36, limit) + self._process_response(dxl_comm_result, dxl_error, motor_id) + + def set_velocity_limit(self, motor_id: int, velocity_limit): + dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx( + self.portHandler, motor_id, self.ADDR_VELOCITY_LIMIT, velocity_limit + ) + self._process_response(dxl_comm_result, dxl_error, motor_id) + + def set_P(self, motor_id: int, P: int): + dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx( + self.portHandler, motor_id, self.POSITION_P, P + ) + self._process_response(dxl_comm_result, dxl_error, motor_id) + + def set_I(self, motor_id: int, I: int): + dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx( + self.portHandler, motor_id, self.POSITION_I, I + ) + self._process_response(dxl_comm_result, dxl_error, motor_id) + + def read_home_offset(self, motor_id: int): + self._disable_torque(motor_id) + # dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(self.portHandler, motor_id, + # ReadAttribute.HOMING_OFFSET.value, home_position) + home_offset = self._read_value(motor_id, ReadAttribute.HOMING_OFFSET, 4) + # self._process_response(dxl_comm_result, dxl_error) + self._enable_torque(motor_id) + return home_offset + + def set_home_offset(self, motor_id: int, home_position: int): + self._disable_torque(motor_id) + dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx( + self.portHandler, motor_id, ReadAttribute.HOMING_OFFSET.value, home_position + ) + self._process_response(dxl_comm_result, dxl_error, motor_id) + self._enable_torque(motor_id) + + def set_baudrate(self, motor_id: int, baudrate): + # translate baudrate into dynamixel baudrate setting id + if baudrate == 57600: + baudrate_id = 1 + elif baudrate == 1_000_000: + baudrate_id = 3 + elif baudrate == 2_000_000: + baudrate_id = 4 + elif baudrate == 3_000_000: + baudrate_id = 5 + elif baudrate == 4_000_000: + baudrate_id = 6 + else: + raise Exception("baudrate not implemented") + + self._disable_torque(motor_id) + dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx( + self.portHandler, motor_id, ReadAttribute.BAUDRATE.value, baudrate_id + ) + self._process_response(dxl_comm_result, dxl_error, motor_id) + + def _read_value(self, motor_id, attribute: ReadAttribute, num_bytes: int, tries=10): + try: + if num_bytes == 1: + value, dxl_comm_result, dxl_error = self.packetHandler.read1ByteTxRx( + self.portHandler, motor_id, attribute.value + ) + elif num_bytes == 2: + value, dxl_comm_result, dxl_error = self.packetHandler.read2ByteTxRx( + self.portHandler, motor_id, attribute.value + ) + elif num_bytes == 4: + value, dxl_comm_result, dxl_error = self.packetHandler.read4ByteTxRx( + self.portHandler, motor_id, attribute.value + ) + except Exception: + if tries == 0: + raise Exception + else: + return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1) + if dxl_comm_result != COMM_SUCCESS: + if tries <= 1: + # print("%s" % self.packetHandler.getTxRxResult(dxl_comm_result)) + raise ConnectionError(f"dxl_comm_result {dxl_comm_result} for servo {motor_id} value {value}") + else: + print(f"dynamixel read failure for servo {motor_id} trying again with {tries - 1} tries") + time.sleep(0.02) + return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1) + elif dxl_error != 0: # # print("%s" % self.packetHandler.getRxPacketError(dxl_error)) + # raise ConnectionError(f'dxl_error {dxl_error} binary ' + "{0:b}".format(37)) + if tries == 0 and dxl_error != 128: + raise Exception(f"Failed to read value from motor {motor_id} error is {dxl_error}") + else: + return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1) + return value + + def set_home_position(self, motor_id: int): + print(f"setting home position for motor {motor_id}") + self.set_home_offset(motor_id, 0) + current_position = self.read_position(motor_id) + print(f"position before {current_position}") + self.set_home_offset(motor_id, -current_position) + # dynamixel.set_home_offset(motor_id, -4096) + # dynamixel.set_home_offset(motor_id, -4294964109) + current_position = self.read_position(motor_id) + # print(f'signed position {current_position - 2** 32}') + print(f"position after {current_position}") + + +if __name__ == "__main__": + dynamixel = Dynamixel.Config(baudrate=1_000_000, device_name="/dev/tty.usbmodem57380045631").instantiate() + motor_id = 1 + pos = dynamixel.read_position(motor_id) + for i in range(10): + s = time.monotonic() + pos = dynamixel.read_position(motor_id) + delta = time.monotonic() - s + print(f"read position took {delta}") + print(f"position {pos}") diff --git a/examples/real_robot_example/gym_real_env/env.py b/examples/real_robot_example/gym_real_env/env.py new file mode 100644 index 00000000..9a1f5694 --- /dev/null +++ b/examples/real_robot_example/gym_real_env/env.py @@ -0,0 +1,158 @@ +import time + +import cv2 +import gymnasium as gym +import numpy as np +from gymnasium import spaces + +from .dynamixel import pos2pwm, pwm2pos +from .robot import Robot + +FPS = 30 + +CAMERAS_SHAPES = { + "observation.images.high": (480, 640, 3), + "observation.images.low": (480, 640, 3), +} + +CAMERAS_PORTS = { + "observation.images.high": "/dev/video6", + "observation.images.low": "/dev/video0", +} + +LEADER_PORT = "/dev/ttyACM1" +FOLLOWER_PORT = "/dev/ttyACM0" + + +def capture_image(cam, cam_width, cam_height): + # Capture a single frame + _, frame = cam.read() + image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + # # Define your crop coordinates (top left corner and bottom right corner) + # x1, y1 = 400, 0 # Example starting coordinates (top left of the crop rectangle) + # x2, y2 = 1600, 900 # Example ending coordinates (bottom right of the crop rectangle) + # # Crop the image + # image = image[y1:y2, x1:x2] + # Resize the image + image = cv2.resize(image, (cam_width, cam_height), interpolation=cv2.INTER_AREA) + + return image + + +class RealEnv(gym.Env): + metadata = {} + + def __init__( + self, + record: bool = False, + num_joints: int = 6, + cameras_shapes: dict = CAMERAS_SHAPES, + cameras_ports: dict = CAMERAS_PORTS, + follower_port: str = FOLLOWER_PORT, + leader_port: str = LEADER_PORT, + warmup_steps: int = 100, + trigger_torque=70, + ): + self.num_joints = num_joints + self.cameras_shapes = cameras_shapes + self.cameras_ports = cameras_ports + self.warmup_steps = warmup_steps + assert len(self.cameras_shapes) == len(self.cameras_ports), "Number of cameras and shapes must match." + + self.follower_port = follower_port + self.leader_port = leader_port + self.record = record + + # Initialize the robot + self.follower = Robot(device_name=self.follower_port) + if self.record: + self.leader = Robot(device_name=self.leader_port) + self.leader.set_trigger_torque(trigger_torque) + + # Initialize the cameras - sorted by camera names + self.cameras = {} + for cn, p in sorted(self.cameras_ports.items()): + assert cn.startswith("observation.images."), "Camera names must start with 'observation.images.'." + self.cameras[cn] = cv2.VideoCapture(p) + if not all(c.isOpened() for c in self.cameras.values()): + raise OSError("Cannot open all camera ports.") + + # Specify gym action and observation spaces + observation_space = {} + + if self.num_joints > 0: + observation_space["agent_pos"] = spaces.Box( + low=-1000.0, + high=1000.0, + shape=(num_joints,), + dtype=np.float64, + ) + if self.record: + observation_space["leader_pos"] = spaces.Box( + low=-1000.0, + high=1000.0, + shape=(num_joints,), + dtype=np.float64, + ) + + if self.cameras_shapes: + for cn, hwc_shape in self.cameras_shapes.items(): + # Assumes images are unsigned int8 in [0,255] + observation_space[f"images.{cn}"] = spaces.Box( + low=0, + high=255, + # height x width x channels (e.g. 480 x 640 x 3) + shape=hwc_shape, + dtype=np.uint8, + ) + + self.observation_space = spaces.Dict(observation_space) + self.action_space = spaces.Box(low=-1, high=1, shape=(num_joints,), dtype=np.float32) + + self._observation = {} + self._terminated = False + self._action_time = time.time() + + def _get_obs(self): + qpos = self.follower.read_position() + self._observation["agent_pos"] = pwm2pos(qpos) + for cn, c in self.cameras.items(): + self._observation[f"images.{cn}"] = capture_image( + c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0] + ) + + if self.record: + leader_pos = self.leader.read_position() + self._observation["leader_pos"] = pwm2pos(leader_pos) + + def reset(self, seed: int | None = None): + del seed + # Reset the robot and sync the leader and follower if we are recording + for _ in range(self.warmup_steps): + self._get_obs() + if self.record: + self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"])) + self._terminated = False + info = {} + return self._observation, info + + def step(self, action: np.ndarray = None): + # Reset the observation + self._get_obs() + if self.record: + # Teleoperate the leader + self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"])) + else: + # Apply the action to the follower + self.follower.set_goal_pos(pos2pwm(action)) + reward = 0 + terminated = truncated = self._terminated + info = {} + return self._observation, reward, terminated, truncated, info + + def render(self): ... + + def close(self): + self.follower._disable_torque() + if self.record: + self.leader._disable_torque() diff --git a/examples/real_robot_example/gym_real_env/robot.py b/examples/real_robot_example/gym_real_env/robot.py new file mode 100644 index 00000000..6c95ff16 --- /dev/null +++ b/examples/real_robot_example/gym_real_env/robot.py @@ -0,0 +1,163 @@ +# ruff: noqa +from enum import Enum, auto +from typing import Union + +import numpy as np +from dynamixel import Dynamixel, OperatingMode, ReadAttribute +from dynamixel_sdk import DXL_HIBYTE, DXL_HIWORD, DXL_LOBYTE, DXL_LOWORD, GroupSyncRead, GroupSyncWrite + + +class MotorControlType(Enum): + PWM = auto() + POSITION_CONTROL = auto() + DISABLED = auto() + UNKNOWN = auto() + + +class Robot: + def __init__(self, device_name: str, baudrate=1_000_000, servo_ids=[1, 2, 3, 4, 5, 6]) -> None: + self.servo_ids = servo_ids + self.dynamixel = Dynamixel.Config(baudrate=baudrate, device_name=device_name).instantiate() + self._init_motors() + + def _init_motors(self): + self.position_reader = GroupSyncRead( + self.dynamixel.portHandler, self.dynamixel.packetHandler, ReadAttribute.POSITION.value, 4 + ) + for id in self.servo_ids: + self.position_reader.addParam(id) + + self.velocity_reader = GroupSyncRead( + self.dynamixel.portHandler, self.dynamixel.packetHandler, ReadAttribute.VELOCITY.value, 4 + ) + for id in self.servo_ids: + self.velocity_reader.addParam(id) + + self.pos_writer = GroupSyncWrite( + self.dynamixel.portHandler, self.dynamixel.packetHandler, self.dynamixel.ADDR_GOAL_POSITION, 4 + ) + for id in self.servo_ids: + self.pos_writer.addParam(id, [2048]) + + self.pwm_writer = GroupSyncWrite( + self.dynamixel.portHandler, self.dynamixel.packetHandler, self.dynamixel.ADDR_GOAL_PWM, 2 + ) + for id in self.servo_ids: + self.pwm_writer.addParam(id, [2048]) + self._disable_torque() + self.motor_control_state = MotorControlType.DISABLED + + def read_position(self, tries=2): + """ + Reads the joint positions of the robot. 2048 is the center position. 0 and 4096 are 180 degrees in each direction. + :param tries: maximum number of tries to read the position + :return: list of joint positions in range [0, 4096] + """ + result = self.position_reader.txRxPacket() + if result != 0: + if tries > 0: + return self.read_position(tries=tries - 1) + else: + print("failed to read position!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + positions = [] + for id in self.servo_ids: + position = self.position_reader.getData(id, ReadAttribute.POSITION.value, 4) + if position > 2**31: + position -= 2**32 + positions.append(position) + return np.array(positions) + + def read_velocity(self): + """ + Reads the joint velocities of the robot. + :return: list of joint velocities, + """ + self.velocity_reader.txRxPacket() + velocties = [] + for id in self.servo_ids: + velocity = self.velocity_reader.getData(id, ReadAttribute.VELOCITY.value, 4) + if velocity > 2**31: + velocity -= 2**32 + velocties.append(velocity) + return np.array(velocties) + + def set_goal_pos(self, action): + """ + :param action: list or numpy array of target joint positions in range [0, 4096] + """ + if self.motor_control_state is not MotorControlType.POSITION_CONTROL: + self._set_position_control() + for i, motor_id in enumerate(self.servo_ids): + data_write = [ + DXL_LOBYTE(DXL_LOWORD(action[i])), + DXL_HIBYTE(DXL_LOWORD(action[i])), + DXL_LOBYTE(DXL_HIWORD(action[i])), + DXL_HIBYTE(DXL_HIWORD(action[i])), + ] + self.pos_writer.changeParam(motor_id, data_write) + + self.pos_writer.txPacket() + + def set_pwm(self, action): + """ + Sets the pwm values for the servos. + :param action: list or numpy array of pwm values in range [0, 885] + """ + if self.motor_control_state is not MotorControlType.PWM: + self._set_pwm_control() + for i, motor_id in enumerate(self.servo_ids): + data_write = [ + DXL_LOBYTE(DXL_LOWORD(action[i])), + DXL_HIBYTE(DXL_LOWORD(action[i])), + ] + self.pwm_writer.changeParam(motor_id, data_write) + + self.pwm_writer.txPacket() + + def set_trigger_torque(self, torque: int): + """ + Sets a constant torque torque for the last servo in the chain. This is useful for the trigger of the leader arm + """ + self.dynamixel._enable_torque(self.servo_ids[-1]) + self.dynamixel.set_pwm_value(self.servo_ids[-1], torque) + + def limit_pwm(self, limit: Union[int, list, np.ndarray]): + """ + Limits the pwm values for the servos in for position control + @param limit: 0 ~ 885 + @return: + """ + if isinstance(limit, int): + limits = [ + limit, + ] * 5 + else: + limits = limit + self._disable_torque() + for motor_id, limit in zip(self.servo_ids, limits, strict=False): + self.dynamixel.set_pwm_limit(motor_id, limit) + self._enable_torque() + + def _disable_torque(self): + print(f"disabling torque for servos {self.servo_ids}") + for motor_id in self.servo_ids: + self.dynamixel._disable_torque(motor_id) + + def _enable_torque(self): + print(f"enabling torque for servos {self.servo_ids}") + for motor_id in self.servo_ids: + self.dynamixel._enable_torque(motor_id) + + def _set_pwm_control(self): + self._disable_torque() + for motor_id in self.servo_ids: + self.dynamixel.set_operating_mode(motor_id, OperatingMode.PWM) + self._enable_torque() + self.motor_control_state = MotorControlType.PWM + + def _set_position_control(self): + self._disable_torque() + for motor_id in self.servo_ids: + self.dynamixel.set_operating_mode(motor_id, OperatingMode.POSITION) + self._enable_torque() + self.motor_control_state = MotorControlType.POSITION_CONTROL diff --git a/lerobot/common/datasets/push_dataset_to_hub/utils.py b/lerobot/common/datasets/push_dataset_to_hub/utils.py index 4feb1dcf..adfe42a5 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/utils.py +++ b/lerobot/common/datasets/push_dataset_to_hub/utils.py @@ -21,19 +21,24 @@ import PIL import torch -def concatenate_episodes(ep_dicts): +def concatenate_episodes(ep_dicts, drop_episodes_last_frame=False): data_dict = {} keys = ep_dicts[0].keys() for key in keys: if torch.is_tensor(ep_dicts[0][key][0]): - data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) + if drop_episodes_last_frame: + data_dict[key] = torch.cat([ep_dict[key][:-1] for ep_dict in ep_dicts]) + else: + data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) else: if key not in data_dict: data_dict[key] = [] for ep_dict in ep_dicts: for x in ep_dict[key]: data_dict[key].append(x) + if drop_episodes_last_frame: + data_dict[key].pop() total_frames = data_dict["frame_index"].shape[0] data_dict["index"] = torch.arange(0, total_frames, 1)