add simple manual real world gym env example
This commit is contained in:
parent
0935e49c8a
commit
57d3d27c78
|
@ -0,0 +1,200 @@
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import gym_real_env # noqa: F401
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from datasets import Dataset, Features, Sequence, Value
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from lerobot.common.datasets.compute_stats import compute_stats
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, DATA_DIR, LeRobotDataset
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||||
|
from lerobot.common.datasets.utils import (
|
||||||
|
hf_transform_to_torch,
|
||||||
|
)
|
||||||
|
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||||
|
from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_videos_to_hub, save_meta_data
|
||||||
|
|
||||||
|
# parse the repo_id name via command line
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--repo_id", type=str, default="blue_red_sort")
|
||||||
|
parser.add_argument("--num_episodes", type=int, default=2)
|
||||||
|
parser.add_argument("--num_frames", type=int, default=400)
|
||||||
|
parser.add_argument("--num_workers", type=int, default=16)
|
||||||
|
parser.add_argument("--keep_last", action="store_true")
|
||||||
|
parser.add_argument("--push_to_hub", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
repo_id = args.repo_id
|
||||||
|
num_episodes = args.num_episodes
|
||||||
|
num_frames = args.num_frames
|
||||||
|
revision = args.revision
|
||||||
|
|
||||||
|
out_data = DATA_DIR / repo_id
|
||||||
|
|
||||||
|
images_dir = out_data / "images"
|
||||||
|
videos_dir = out_data / "videos"
|
||||||
|
meta_data_dir = out_data / "meta_data"
|
||||||
|
|
||||||
|
|
||||||
|
# Create image and video directories
|
||||||
|
if not os.path.exists(images_dir):
|
||||||
|
os.makedirs(images_dir, exist_ok=True)
|
||||||
|
if not os.path.exists(videos_dir):
|
||||||
|
os.makedirs(videos_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Create the gym environment - check the kwargs in gym_real_env/src/env.py
|
||||||
|
gym_handle = "gym_real_env/RealEnv-v0"
|
||||||
|
env = gym.make(gym_handle, disable_env_checker=True, record=True)
|
||||||
|
|
||||||
|
ep_dicts = []
|
||||||
|
episode_data_index = {"from": [], "to": []}
|
||||||
|
ep_fps = []
|
||||||
|
id_from = 0
|
||||||
|
id_to = 0
|
||||||
|
os.system('spd-say "env created"')
|
||||||
|
|
||||||
|
for ep_idx in range(num_episodes):
|
||||||
|
# bring the follower to the leader and start camera
|
||||||
|
env.reset()
|
||||||
|
|
||||||
|
os.system(f'spd-say "go {ep_idx}"')
|
||||||
|
# init buffers
|
||||||
|
obs_replay = {k: [] for k in env.observation_space}
|
||||||
|
timestamps = []
|
||||||
|
|
||||||
|
starting_time = time.time()
|
||||||
|
for _ in tqdm(range(num_frames)):
|
||||||
|
# Apply the next action
|
||||||
|
observation, _, _, _, _ = env.step(action=None)
|
||||||
|
# images_stacked = np.hstack(list(observation['pixels'].values()))
|
||||||
|
# images_stacked = cv2.cvtColor(images_stacked, cv2.COLOR_RGB2BGR)
|
||||||
|
# cv2.imshow('frame', images_stacked)
|
||||||
|
|
||||||
|
# store data
|
||||||
|
for key in observation:
|
||||||
|
obs_replay[key].append(copy.deepcopy(observation[key]))
|
||||||
|
timestamps.append(time.time() - starting_time)
|
||||||
|
# if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
|
# break
|
||||||
|
|
||||||
|
os.system('spd-say "stop"')
|
||||||
|
|
||||||
|
ep_dict = {}
|
||||||
|
# store images in png and create the video
|
||||||
|
for img_key in env.cameras:
|
||||||
|
save_images_concurrently(
|
||||||
|
obs_replay[f"images.{img_key}"],
|
||||||
|
images_dir / f"{img_key}_episode_{ep_idx:06d}",
|
||||||
|
args.num_workers,
|
||||||
|
)
|
||||||
|
# for i in tqdm(range(num_frames)):
|
||||||
|
# cv2.imwrite(str(images_dir / f"{img_key}_episode_{ep_idx:06d}" / f"frame_{i:06d}.png"),
|
||||||
|
# obs_replay[i]['pixels'][img_key])
|
||||||
|
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||||
|
# store the reference to the video frame
|
||||||
|
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps]
|
||||||
|
# shutil.rmtree(tmp_imgs_dir)
|
||||||
|
|
||||||
|
state = torch.tensor(np.array(obs_replay["agent_pos"]))
|
||||||
|
action = torch.tensor(np.array(obs_replay["leader_pos"]))
|
||||||
|
next_done = torch.zeros(num_frames, dtype=torch.bool)
|
||||||
|
next_done[-1] = True
|
||||||
|
|
||||||
|
ep_dict["observation.state"] = state
|
||||||
|
ep_dict["action"] = action
|
||||||
|
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||||
|
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||||
|
ep_dict["timestamp"] = torch.tensor(timestamps)
|
||||||
|
ep_dict["next.done"] = next_done
|
||||||
|
ep_fps.append(num_frames / timestamps[-1])
|
||||||
|
ep_dicts.append(ep_dict)
|
||||||
|
print(f"Episode {ep_idx} done, fps: {ep_fps[-1]:.2f}")
|
||||||
|
|
||||||
|
episode_data_index["from"].append(id_from)
|
||||||
|
episode_data_index["to"].append(id_from + num_frames if args.keep_last else id_from + num_frames - 1)
|
||||||
|
|
||||||
|
id_to = id_from + num_frames if args.keep_last else id_from + num_frames - 1
|
||||||
|
id_from = id_to
|
||||||
|
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
os.system('spd-say "encode video frames"')
|
||||||
|
for ep_idx in range(num_episodes):
|
||||||
|
for img_key in env.cameras:
|
||||||
|
# If necessary, we may want to encode the video
|
||||||
|
# with variable frame rate: https://superuser.com/questions/1661901/encoding-video-from-vfr-still-images
|
||||||
|
encode_video_frames(
|
||||||
|
images_dir / f"{img_key}_episode_{ep_idx:06d}",
|
||||||
|
videos_dir / f"{img_key}_episode_{ep_idx:06d}.mp4",
|
||||||
|
ep_fps[ep_idx],
|
||||||
|
)
|
||||||
|
|
||||||
|
os.system('spd-say "concatenate episodes"')
|
||||||
|
data_dict = concatenate_episodes(
|
||||||
|
ep_dicts, drop_episodes_last_frame=not args.keep_last
|
||||||
|
) # Since our fps varies we are sometimes off tolerance for the last frame
|
||||||
|
|
||||||
|
features = {}
|
||||||
|
|
||||||
|
keys = [key for key in data_dict if "observation.images." in key]
|
||||||
|
for key in keys:
|
||||||
|
features[key] = VideoFrame()
|
||||||
|
|
||||||
|
features["observation.state"] = Sequence(
|
||||||
|
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
)
|
||||||
|
features["action"] = Sequence(
|
||||||
|
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
)
|
||||||
|
features["episode_index"] = Value(dtype="int64", id=None)
|
||||||
|
features["frame_index"] = Value(dtype="int64", id=None)
|
||||||
|
features["timestamp"] = Value(dtype="float32", id=None)
|
||||||
|
features["next.done"] = Value(dtype="bool", id=None)
|
||||||
|
features["index"] = Value(dtype="int64", id=None)
|
||||||
|
# TODO(rcadene): add success
|
||||||
|
# features["next.success"] = Value(dtype='bool', id=None)
|
||||||
|
|
||||||
|
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
||||||
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
|
|
||||||
|
info = {
|
||||||
|
"fps": sum(ep_fps) / len(ep_fps), # to have a good tolerance in data processing for the slowest video
|
||||||
|
"video": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
os.system('spd-say "from preloaded"')
|
||||||
|
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||||
|
repo_id=repo_id,
|
||||||
|
version=revision,
|
||||||
|
hf_dataset=hf_dataset,
|
||||||
|
episode_data_index=episode_data_index,
|
||||||
|
info=info,
|
||||||
|
videos_dir=videos_dir,
|
||||||
|
)
|
||||||
|
os.system('spd-say "compute stats"')
|
||||||
|
stats = compute_stats(lerobot_dataset)
|
||||||
|
|
||||||
|
os.system('spd-say "save to disk"')
|
||||||
|
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||||
|
hf_dataset.save_to_disk(str(out_data / "train"))
|
||||||
|
|
||||||
|
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
||||||
|
|
||||||
|
if args.push_to_hub:
|
||||||
|
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
|
||||||
|
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
|
||||||
|
|
||||||
|
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||||
|
push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)
|
||||||
|
|
||||||
|
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||||
|
push_videos_to_hub(repo_id, videos_dir, revision=revision)
|
|
@ -0,0 +1,103 @@
|
||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
|
||||||
|
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images,
|
||||||
|
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
|
||||||
|
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
|
||||||
|
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
|
||||||
|
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
|
||||||
|
#
|
||||||
|
# Example of usage for training:
|
||||||
|
# ```bash
|
||||||
|
# python lerobot/scripts/train.py \
|
||||||
|
# policy=act_real \
|
||||||
|
# env=aloha_real
|
||||||
|
# ```
|
||||||
|
|
||||||
|
seed: 1000
|
||||||
|
dataset_repo_id: thomwolf/blue_sort
|
||||||
|
|
||||||
|
override_dataset_stats:
|
||||||
|
observation.images.cam_high:
|
||||||
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
observation.images.cam_low:
|
||||||
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
|
||||||
|
training:
|
||||||
|
offline_steps: 1000
|
||||||
|
online_steps: 0
|
||||||
|
eval_freq: -1
|
||||||
|
save_freq: 1000
|
||||||
|
log_freq: 100
|
||||||
|
save_checkpoint: true
|
||||||
|
|
||||||
|
batch_size: 8
|
||||||
|
lr: 1e-5
|
||||||
|
lr_backbone: 1e-5
|
||||||
|
weight_decay: 1e-4
|
||||||
|
grad_clip_norm: 10
|
||||||
|
online_steps_between_rollouts: 1
|
||||||
|
|
||||||
|
delta_timestamps:
|
||||||
|
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
|
||||||
|
|
||||||
|
eval:
|
||||||
|
n_episodes: 50
|
||||||
|
batch_size: 50
|
||||||
|
|
||||||
|
# See `configuration_act.py` for more details.
|
||||||
|
policy:
|
||||||
|
name: act
|
||||||
|
|
||||||
|
# Input / output structure.
|
||||||
|
n_obs_steps: 1
|
||||||
|
chunk_size: 100 # chunk_size
|
||||||
|
n_action_steps: 100
|
||||||
|
|
||||||
|
input_shapes:
|
||||||
|
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||||
|
observation.images.cam_high: [3, 480, 640]
|
||||||
|
observation.images.cam_low: [3, 480, 640]
|
||||||
|
observation.state: ["${env.state_dim}"]
|
||||||
|
output_shapes:
|
||||||
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
|
# Normalization / Unnormalization
|
||||||
|
input_normalization_modes:
|
||||||
|
observation.images.cam_high: mean_std
|
||||||
|
observation.images.cam_low: mean_std
|
||||||
|
observation.state: mean_std
|
||||||
|
output_normalization_modes:
|
||||||
|
action: mean_std
|
||||||
|
|
||||||
|
# Architecture.
|
||||||
|
# Vision backbone.
|
||||||
|
vision_backbone: resnet18
|
||||||
|
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||||
|
replace_final_stride_with_dilation: false
|
||||||
|
# Transformer layers.
|
||||||
|
pre_norm: false
|
||||||
|
dim_model: 512
|
||||||
|
n_heads: 8
|
||||||
|
dim_feedforward: 3200
|
||||||
|
feedforward_activation: relu
|
||||||
|
n_encoder_layers: 4
|
||||||
|
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||||
|
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||||
|
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||||
|
n_decoder_layers: 1
|
||||||
|
# VAE.
|
||||||
|
use_vae: true
|
||||||
|
latent_dim: 32
|
||||||
|
n_vae_encoder_layers: 4
|
||||||
|
|
||||||
|
# Inference.
|
||||||
|
temporal_ensemble_momentum: null
|
||||||
|
|
||||||
|
# Training and loss computation.
|
||||||
|
dropout: 0.1
|
||||||
|
kl_weight: 10.0
|
|
@ -0,0 +1,13 @@
|
||||||
|
# @package _global_
|
||||||
|
|
||||||
|
fps: 30
|
||||||
|
|
||||||
|
env:
|
||||||
|
name: dora
|
||||||
|
task: DoraKoch-v0
|
||||||
|
state_dim: 6
|
||||||
|
action_dim: 6
|
||||||
|
fps: ${fps}
|
||||||
|
episode_length: 400
|
||||||
|
gym:
|
||||||
|
fps: ${fps}
|
|
@ -0,0 +1,8 @@
|
||||||
|
from gymnasium.envs.registration import register
|
||||||
|
|
||||||
|
register(
|
||||||
|
id="gym_real_env/RealEnv-v0",
|
||||||
|
entry_point="gym_real_env.env:RealEnv",
|
||||||
|
max_episode_steps=300,
|
||||||
|
nondeterministic=True,
|
||||||
|
)
|
|
@ -0,0 +1,359 @@
|
||||||
|
# ruff: noqa
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from dynamixel_sdk import * # Uses Dynamixel SDK library
|
||||||
|
|
||||||
|
|
||||||
|
def pos2pwm(pos: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
:param pos: numpy array of joint positions in range [-pi, pi]
|
||||||
|
:return: numpy array of pwm values in range [0, 4096]
|
||||||
|
"""
|
||||||
|
return ((pos / 3.14 + 1.0) * 2048).astype(np.int64)
|
||||||
|
|
||||||
|
|
||||||
|
def pwm2pos(pwm: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
:param pwm: numpy array of pwm values in range [0, 4096]
|
||||||
|
:return: numpy array of joint positions in range [-pi, pi]
|
||||||
|
"""
|
||||||
|
return (pwm / 2048 - 1) * 3.14
|
||||||
|
|
||||||
|
|
||||||
|
def pwm2vel(pwm: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
:param pwm: numpy array of pwm/s joint velocities
|
||||||
|
:return: numpy array of rad/s joint velocities
|
||||||
|
"""
|
||||||
|
return pwm * 3.14 / 2048
|
||||||
|
|
||||||
|
|
||||||
|
def vel2pwm(vel: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
:param vel: numpy array of rad/s joint velocities
|
||||||
|
:return: numpy array of pwm/s joint velocities
|
||||||
|
"""
|
||||||
|
return (vel * 2048 / 3.14).astype(np.int64)
|
||||||
|
|
||||||
|
|
||||||
|
class ReadAttribute(enum.Enum):
|
||||||
|
TEMPERATURE = 146
|
||||||
|
VOLTAGE = 145
|
||||||
|
VELOCITY = 128
|
||||||
|
POSITION = 132
|
||||||
|
CURRENT = 126
|
||||||
|
PWM = 124
|
||||||
|
HARDWARE_ERROR_STATUS = 70
|
||||||
|
HOMING_OFFSET = 20
|
||||||
|
BAUDRATE = 8
|
||||||
|
|
||||||
|
|
||||||
|
class OperatingMode(enum.Enum):
|
||||||
|
VELOCITY = 1
|
||||||
|
POSITION = 3
|
||||||
|
CURRENT_CONTROLLED_POSITION = 5
|
||||||
|
PWM = 16
|
||||||
|
UNKNOWN = -1
|
||||||
|
|
||||||
|
|
||||||
|
class Dynamixel:
|
||||||
|
ADDR_TORQUE_ENABLE = 64
|
||||||
|
ADDR_GOAL_POSITION = 116
|
||||||
|
ADDR_VELOCITY_LIMIT = 44
|
||||||
|
ADDR_GOAL_PWM = 100
|
||||||
|
OPERATING_MODE_ADDR = 11
|
||||||
|
POSITION_I = 82
|
||||||
|
POSITION_P = 84
|
||||||
|
ADDR_ID = 7
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Config:
|
||||||
|
def instantiate(self):
|
||||||
|
return Dynamixel(self)
|
||||||
|
|
||||||
|
baudrate: int = 57600
|
||||||
|
protocol_version: float = 2.0
|
||||||
|
device_name: str = "" # /dev/tty.usbserial-1120'
|
||||||
|
dynamixel_id: int = 1
|
||||||
|
|
||||||
|
def __init__(self, config: Config):
|
||||||
|
self.config = config
|
||||||
|
self.connect()
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
if self.config.device_name == "":
|
||||||
|
for port_name in os.listdir("/dev"):
|
||||||
|
if "ttyUSB" in port_name or "ttyACM" in port_name:
|
||||||
|
self.config.device_name = "/dev/" + port_name
|
||||||
|
print(f"using device {self.config.device_name}")
|
||||||
|
self.portHandler = PortHandler(self.config.device_name)
|
||||||
|
# self.portHandler.LA
|
||||||
|
self.packetHandler = PacketHandler(self.config.protocol_version)
|
||||||
|
if not self.portHandler.openPort():
|
||||||
|
raise Exception(f"Failed to open port {self.config.device_name}")
|
||||||
|
|
||||||
|
if not self.portHandler.setBaudRate(self.config.baudrate):
|
||||||
|
raise Exception(f"failed to set baudrate to {self.config.baudrate}")
|
||||||
|
|
||||||
|
# self.operating_mode = OperatingMode.UNKNOWN
|
||||||
|
# self.torque_enabled = False
|
||||||
|
# self._disable_torque()
|
||||||
|
|
||||||
|
self.operating_modes = [None for _ in range(32)]
|
||||||
|
self.torque_enabled = [None for _ in range(32)]
|
||||||
|
return True
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
self.portHandler.closePort()
|
||||||
|
|
||||||
|
def set_goal_position(self, motor_id, goal_position):
|
||||||
|
# if self.operating_modes[motor_id] is not OperatingMode.POSITION:
|
||||||
|
# self._disable_torque(motor_id)
|
||||||
|
# self.set_operating_mode(motor_id, OperatingMode.POSITION)
|
||||||
|
|
||||||
|
# if not self.torque_enabled[motor_id]:
|
||||||
|
# self._enable_torque(motor_id)
|
||||||
|
|
||||||
|
# self._enable_torque(motor_id)
|
||||||
|
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
|
||||||
|
self.portHandler, motor_id, self.ADDR_GOAL_POSITION, goal_position
|
||||||
|
)
|
||||||
|
# self._process_response(dxl_comm_result, dxl_error)
|
||||||
|
# print(f'set position of motor {motor_id} to {goal_position}')
|
||||||
|
|
||||||
|
def set_pwm_value(self, motor_id: int, pwm_value, tries=3):
|
||||||
|
if self.operating_modes[motor_id] is not OperatingMode.PWM:
|
||||||
|
self._disable_torque(motor_id)
|
||||||
|
self.set_operating_mode(motor_id, OperatingMode.PWM)
|
||||||
|
|
||||||
|
if not self.torque_enabled[motor_id]:
|
||||||
|
self._enable_torque(motor_id)
|
||||||
|
# print(f'enabling torque')
|
||||||
|
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
|
||||||
|
self.portHandler, motor_id, self.ADDR_GOAL_PWM, pwm_value
|
||||||
|
)
|
||||||
|
# self._process_response(dxl_comm_result, dxl_error)
|
||||||
|
# print(f'set pwm of motor {motor_id} to {pwm_value}')
|
||||||
|
if dxl_comm_result != COMM_SUCCESS:
|
||||||
|
if tries <= 1:
|
||||||
|
raise ConnectionError(f"dxl_comm_result: {self.packetHandler.getTxRxResult(dxl_comm_result)}")
|
||||||
|
else:
|
||||||
|
print(f"dynamixel pwm setting failure trying again with {tries - 1} tries")
|
||||||
|
self.set_pwm_value(motor_id, pwm_value, tries=tries - 1)
|
||||||
|
elif dxl_error != 0:
|
||||||
|
print(f"dxl error {dxl_error}")
|
||||||
|
raise ConnectionError(f"dynamixel error: {self.packetHandler.getTxRxResult(dxl_error)}")
|
||||||
|
|
||||||
|
def read_temperature(self, motor_id: int):
|
||||||
|
return self._read_value(motor_id, ReadAttribute.TEMPERATURE, 1)
|
||||||
|
|
||||||
|
def read_velocity(self, motor_id: int):
|
||||||
|
pos = self._read_value(motor_id, ReadAttribute.VELOCITY, 4)
|
||||||
|
if pos > 2**31:
|
||||||
|
pos -= 2**32
|
||||||
|
# print(f'read position {pos} for motor {motor_id}')
|
||||||
|
return pos
|
||||||
|
|
||||||
|
def read_position(self, motor_id: int):
|
||||||
|
pos = self._read_value(motor_id, ReadAttribute.POSITION, 4)
|
||||||
|
if pos > 2**31:
|
||||||
|
pos -= 2**32
|
||||||
|
# print(f'read position {pos} for motor {motor_id}')
|
||||||
|
return pos
|
||||||
|
|
||||||
|
def read_position_degrees(self, motor_id: int) -> float:
|
||||||
|
return (self.read_position(motor_id) / 4096) * 360
|
||||||
|
|
||||||
|
def read_position_radians(self, motor_id: int) -> float:
|
||||||
|
return (self.read_position(motor_id) / 4096) * 2 * math.pi
|
||||||
|
|
||||||
|
def read_current(self, motor_id: int):
|
||||||
|
current = self._read_value(motor_id, ReadAttribute.CURRENT, 2)
|
||||||
|
if current > 2**15:
|
||||||
|
current -= 2**16
|
||||||
|
return current
|
||||||
|
|
||||||
|
def read_present_pwm(self, motor_id: int):
|
||||||
|
return self._read_value(motor_id, ReadAttribute.PWM, 2)
|
||||||
|
|
||||||
|
def read_hardware_error_status(self, motor_id: int):
|
||||||
|
return self._read_value(motor_id, ReadAttribute.HARDWARE_ERROR_STATUS, 1)
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
self.portHandler.closePort()
|
||||||
|
|
||||||
|
def set_id(self, old_id, new_id, use_broadcast_id: bool = False):
|
||||||
|
"""
|
||||||
|
sets the id of the dynamixel servo
|
||||||
|
@param old_id: current id of the servo
|
||||||
|
@param new_id: new id
|
||||||
|
@param use_broadcast_id: set ids of all connected dynamixels if True.
|
||||||
|
If False, change only servo with self.config.id
|
||||||
|
@return:
|
||||||
|
"""
|
||||||
|
if use_broadcast_id:
|
||||||
|
current_id = 254
|
||||||
|
else:
|
||||||
|
current_id = old_id
|
||||||
|
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
|
||||||
|
self.portHandler, current_id, self.ADDR_ID, new_id
|
||||||
|
)
|
||||||
|
self._process_response(dxl_comm_result, dxl_error, old_id)
|
||||||
|
self.config.id = id
|
||||||
|
|
||||||
|
def _enable_torque(self, motor_id):
|
||||||
|
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
|
||||||
|
self.portHandler, motor_id, self.ADDR_TORQUE_ENABLE, 1
|
||||||
|
)
|
||||||
|
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||||
|
self.torque_enabled[motor_id] = True
|
||||||
|
|
||||||
|
def _disable_torque(self, motor_id):
|
||||||
|
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
|
||||||
|
self.portHandler, motor_id, self.ADDR_TORQUE_ENABLE, 0
|
||||||
|
)
|
||||||
|
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||||
|
self.torque_enabled[motor_id] = False
|
||||||
|
|
||||||
|
def _process_response(self, dxl_comm_result: int, dxl_error: int, motor_id: int):
|
||||||
|
if dxl_comm_result != COMM_SUCCESS:
|
||||||
|
raise ConnectionError(
|
||||||
|
f"dxl_comm_result for motor {motor_id}: {self.packetHandler.getTxRxResult(dxl_comm_result)}"
|
||||||
|
)
|
||||||
|
elif dxl_error != 0:
|
||||||
|
print(f"dxl error {dxl_error}")
|
||||||
|
raise ConnectionError(
|
||||||
|
f"dynamixel error for motor {motor_id}: {self.packetHandler.getTxRxResult(dxl_error)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_operating_mode(self, motor_id: int, operating_mode: OperatingMode):
|
||||||
|
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
|
||||||
|
self.portHandler, motor_id, self.OPERATING_MODE_ADDR, operating_mode.value
|
||||||
|
)
|
||||||
|
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||||
|
self.operating_modes[motor_id] = operating_mode
|
||||||
|
|
||||||
|
def set_pwm_limit(self, motor_id: int, limit: int):
|
||||||
|
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(self.portHandler, motor_id, 36, limit)
|
||||||
|
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||||
|
|
||||||
|
def set_velocity_limit(self, motor_id: int, velocity_limit):
|
||||||
|
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
|
||||||
|
self.portHandler, motor_id, self.ADDR_VELOCITY_LIMIT, velocity_limit
|
||||||
|
)
|
||||||
|
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||||
|
|
||||||
|
def set_P(self, motor_id: int, P: int):
|
||||||
|
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
|
||||||
|
self.portHandler, motor_id, self.POSITION_P, P
|
||||||
|
)
|
||||||
|
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||||
|
|
||||||
|
def set_I(self, motor_id: int, I: int):
|
||||||
|
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
|
||||||
|
self.portHandler, motor_id, self.POSITION_I, I
|
||||||
|
)
|
||||||
|
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||||
|
|
||||||
|
def read_home_offset(self, motor_id: int):
|
||||||
|
self._disable_torque(motor_id)
|
||||||
|
# dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(self.portHandler, motor_id,
|
||||||
|
# ReadAttribute.HOMING_OFFSET.value, home_position)
|
||||||
|
home_offset = self._read_value(motor_id, ReadAttribute.HOMING_OFFSET, 4)
|
||||||
|
# self._process_response(dxl_comm_result, dxl_error)
|
||||||
|
self._enable_torque(motor_id)
|
||||||
|
return home_offset
|
||||||
|
|
||||||
|
def set_home_offset(self, motor_id: int, home_position: int):
|
||||||
|
self._disable_torque(motor_id)
|
||||||
|
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
|
||||||
|
self.portHandler, motor_id, ReadAttribute.HOMING_OFFSET.value, home_position
|
||||||
|
)
|
||||||
|
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||||
|
self._enable_torque(motor_id)
|
||||||
|
|
||||||
|
def set_baudrate(self, motor_id: int, baudrate):
|
||||||
|
# translate baudrate into dynamixel baudrate setting id
|
||||||
|
if baudrate == 57600:
|
||||||
|
baudrate_id = 1
|
||||||
|
elif baudrate == 1_000_000:
|
||||||
|
baudrate_id = 3
|
||||||
|
elif baudrate == 2_000_000:
|
||||||
|
baudrate_id = 4
|
||||||
|
elif baudrate == 3_000_000:
|
||||||
|
baudrate_id = 5
|
||||||
|
elif baudrate == 4_000_000:
|
||||||
|
baudrate_id = 6
|
||||||
|
else:
|
||||||
|
raise Exception("baudrate not implemented")
|
||||||
|
|
||||||
|
self._disable_torque(motor_id)
|
||||||
|
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
|
||||||
|
self.portHandler, motor_id, ReadAttribute.BAUDRATE.value, baudrate_id
|
||||||
|
)
|
||||||
|
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||||
|
|
||||||
|
def _read_value(self, motor_id, attribute: ReadAttribute, num_bytes: int, tries=10):
|
||||||
|
try:
|
||||||
|
if num_bytes == 1:
|
||||||
|
value, dxl_comm_result, dxl_error = self.packetHandler.read1ByteTxRx(
|
||||||
|
self.portHandler, motor_id, attribute.value
|
||||||
|
)
|
||||||
|
elif num_bytes == 2:
|
||||||
|
value, dxl_comm_result, dxl_error = self.packetHandler.read2ByteTxRx(
|
||||||
|
self.portHandler, motor_id, attribute.value
|
||||||
|
)
|
||||||
|
elif num_bytes == 4:
|
||||||
|
value, dxl_comm_result, dxl_error = self.packetHandler.read4ByteTxRx(
|
||||||
|
self.portHandler, motor_id, attribute.value
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
if tries == 0:
|
||||||
|
raise Exception
|
||||||
|
else:
|
||||||
|
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
|
||||||
|
if dxl_comm_result != COMM_SUCCESS:
|
||||||
|
if tries <= 1:
|
||||||
|
# print("%s" % self.packetHandler.getTxRxResult(dxl_comm_result))
|
||||||
|
raise ConnectionError(f"dxl_comm_result {dxl_comm_result} for servo {motor_id} value {value}")
|
||||||
|
else:
|
||||||
|
print(f"dynamixel read failure for servo {motor_id} trying again with {tries - 1} tries")
|
||||||
|
time.sleep(0.02)
|
||||||
|
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
|
||||||
|
elif dxl_error != 0: # # print("%s" % self.packetHandler.getRxPacketError(dxl_error))
|
||||||
|
# raise ConnectionError(f'dxl_error {dxl_error} binary ' + "{0:b}".format(37))
|
||||||
|
if tries == 0 and dxl_error != 128:
|
||||||
|
raise Exception(f"Failed to read value from motor {motor_id} error is {dxl_error}")
|
||||||
|
else:
|
||||||
|
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def set_home_position(self, motor_id: int):
|
||||||
|
print(f"setting home position for motor {motor_id}")
|
||||||
|
self.set_home_offset(motor_id, 0)
|
||||||
|
current_position = self.read_position(motor_id)
|
||||||
|
print(f"position before {current_position}")
|
||||||
|
self.set_home_offset(motor_id, -current_position)
|
||||||
|
# dynamixel.set_home_offset(motor_id, -4096)
|
||||||
|
# dynamixel.set_home_offset(motor_id, -4294964109)
|
||||||
|
current_position = self.read_position(motor_id)
|
||||||
|
# print(f'signed position {current_position - 2** 32}')
|
||||||
|
print(f"position after {current_position}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
dynamixel = Dynamixel.Config(baudrate=1_000_000, device_name="/dev/tty.usbmodem57380045631").instantiate()
|
||||||
|
motor_id = 1
|
||||||
|
pos = dynamixel.read_position(motor_id)
|
||||||
|
for i in range(10):
|
||||||
|
s = time.monotonic()
|
||||||
|
pos = dynamixel.read_position(motor_id)
|
||||||
|
delta = time.monotonic() - s
|
||||||
|
print(f"read position took {delta}")
|
||||||
|
print(f"position {pos}")
|
|
@ -0,0 +1,158 @@
|
||||||
|
import time
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
from .dynamixel import pos2pwm, pwm2pos
|
||||||
|
from .robot import Robot
|
||||||
|
|
||||||
|
FPS = 30
|
||||||
|
|
||||||
|
CAMERAS_SHAPES = {
|
||||||
|
"observation.images.high": (480, 640, 3),
|
||||||
|
"observation.images.low": (480, 640, 3),
|
||||||
|
}
|
||||||
|
|
||||||
|
CAMERAS_PORTS = {
|
||||||
|
"observation.images.high": "/dev/video6",
|
||||||
|
"observation.images.low": "/dev/video0",
|
||||||
|
}
|
||||||
|
|
||||||
|
LEADER_PORT = "/dev/ttyACM1"
|
||||||
|
FOLLOWER_PORT = "/dev/ttyACM0"
|
||||||
|
|
||||||
|
|
||||||
|
def capture_image(cam, cam_width, cam_height):
|
||||||
|
# Capture a single frame
|
||||||
|
_, frame = cam.read()
|
||||||
|
image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
|
# # Define your crop coordinates (top left corner and bottom right corner)
|
||||||
|
# x1, y1 = 400, 0 # Example starting coordinates (top left of the crop rectangle)
|
||||||
|
# x2, y2 = 1600, 900 # Example ending coordinates (bottom right of the crop rectangle)
|
||||||
|
# # Crop the image
|
||||||
|
# image = image[y1:y2, x1:x2]
|
||||||
|
# Resize the image
|
||||||
|
image = cv2.resize(image, (cam_width, cam_height), interpolation=cv2.INTER_AREA)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class RealEnv(gym.Env):
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
record: bool = False,
|
||||||
|
num_joints: int = 6,
|
||||||
|
cameras_shapes: dict = CAMERAS_SHAPES,
|
||||||
|
cameras_ports: dict = CAMERAS_PORTS,
|
||||||
|
follower_port: str = FOLLOWER_PORT,
|
||||||
|
leader_port: str = LEADER_PORT,
|
||||||
|
warmup_steps: int = 100,
|
||||||
|
trigger_torque=70,
|
||||||
|
):
|
||||||
|
self.num_joints = num_joints
|
||||||
|
self.cameras_shapes = cameras_shapes
|
||||||
|
self.cameras_ports = cameras_ports
|
||||||
|
self.warmup_steps = warmup_steps
|
||||||
|
assert len(self.cameras_shapes) == len(self.cameras_ports), "Number of cameras and shapes must match."
|
||||||
|
|
||||||
|
self.follower_port = follower_port
|
||||||
|
self.leader_port = leader_port
|
||||||
|
self.record = record
|
||||||
|
|
||||||
|
# Initialize the robot
|
||||||
|
self.follower = Robot(device_name=self.follower_port)
|
||||||
|
if self.record:
|
||||||
|
self.leader = Robot(device_name=self.leader_port)
|
||||||
|
self.leader.set_trigger_torque(trigger_torque)
|
||||||
|
|
||||||
|
# Initialize the cameras - sorted by camera names
|
||||||
|
self.cameras = {}
|
||||||
|
for cn, p in sorted(self.cameras_ports.items()):
|
||||||
|
assert cn.startswith("observation.images."), "Camera names must start with 'observation.images.'."
|
||||||
|
self.cameras[cn] = cv2.VideoCapture(p)
|
||||||
|
if not all(c.isOpened() for c in self.cameras.values()):
|
||||||
|
raise OSError("Cannot open all camera ports.")
|
||||||
|
|
||||||
|
# Specify gym action and observation spaces
|
||||||
|
observation_space = {}
|
||||||
|
|
||||||
|
if self.num_joints > 0:
|
||||||
|
observation_space["agent_pos"] = spaces.Box(
|
||||||
|
low=-1000.0,
|
||||||
|
high=1000.0,
|
||||||
|
shape=(num_joints,),
|
||||||
|
dtype=np.float64,
|
||||||
|
)
|
||||||
|
if self.record:
|
||||||
|
observation_space["leader_pos"] = spaces.Box(
|
||||||
|
low=-1000.0,
|
||||||
|
high=1000.0,
|
||||||
|
shape=(num_joints,),
|
||||||
|
dtype=np.float64,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.cameras_shapes:
|
||||||
|
for cn, hwc_shape in self.cameras_shapes.items():
|
||||||
|
# Assumes images are unsigned int8 in [0,255]
|
||||||
|
observation_space[f"images.{cn}"] = spaces.Box(
|
||||||
|
low=0,
|
||||||
|
high=255,
|
||||||
|
# height x width x channels (e.g. 480 x 640 x 3)
|
||||||
|
shape=hwc_shape,
|
||||||
|
dtype=np.uint8,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.observation_space = spaces.Dict(observation_space)
|
||||||
|
self.action_space = spaces.Box(low=-1, high=1, shape=(num_joints,), dtype=np.float32)
|
||||||
|
|
||||||
|
self._observation = {}
|
||||||
|
self._terminated = False
|
||||||
|
self._action_time = time.time()
|
||||||
|
|
||||||
|
def _get_obs(self):
|
||||||
|
qpos = self.follower.read_position()
|
||||||
|
self._observation["agent_pos"] = pwm2pos(qpos)
|
||||||
|
for cn, c in self.cameras.items():
|
||||||
|
self._observation[f"images.{cn}"] = capture_image(
|
||||||
|
c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0]
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.record:
|
||||||
|
leader_pos = self.leader.read_position()
|
||||||
|
self._observation["leader_pos"] = pwm2pos(leader_pos)
|
||||||
|
|
||||||
|
def reset(self, seed: int | None = None):
|
||||||
|
del seed
|
||||||
|
# Reset the robot and sync the leader and follower if we are recording
|
||||||
|
for _ in range(self.warmup_steps):
|
||||||
|
self._get_obs()
|
||||||
|
if self.record:
|
||||||
|
self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
|
||||||
|
self._terminated = False
|
||||||
|
info = {}
|
||||||
|
return self._observation, info
|
||||||
|
|
||||||
|
def step(self, action: np.ndarray = None):
|
||||||
|
# Reset the observation
|
||||||
|
self._get_obs()
|
||||||
|
if self.record:
|
||||||
|
# Teleoperate the leader
|
||||||
|
self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
|
||||||
|
else:
|
||||||
|
# Apply the action to the follower
|
||||||
|
self.follower.set_goal_pos(pos2pwm(action))
|
||||||
|
reward = 0
|
||||||
|
terminated = truncated = self._terminated
|
||||||
|
info = {}
|
||||||
|
return self._observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def render(self): ...
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.follower._disable_torque()
|
||||||
|
if self.record:
|
||||||
|
self.leader._disable_torque()
|
|
@ -0,0 +1,163 @@
|
||||||
|
# ruff: noqa
|
||||||
|
from enum import Enum, auto
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from dynamixel import Dynamixel, OperatingMode, ReadAttribute
|
||||||
|
from dynamixel_sdk import DXL_HIBYTE, DXL_HIWORD, DXL_LOBYTE, DXL_LOWORD, GroupSyncRead, GroupSyncWrite
|
||||||
|
|
||||||
|
|
||||||
|
class MotorControlType(Enum):
|
||||||
|
PWM = auto()
|
||||||
|
POSITION_CONTROL = auto()
|
||||||
|
DISABLED = auto()
|
||||||
|
UNKNOWN = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class Robot:
|
||||||
|
def __init__(self, device_name: str, baudrate=1_000_000, servo_ids=[1, 2, 3, 4, 5, 6]) -> None:
|
||||||
|
self.servo_ids = servo_ids
|
||||||
|
self.dynamixel = Dynamixel.Config(baudrate=baudrate, device_name=device_name).instantiate()
|
||||||
|
self._init_motors()
|
||||||
|
|
||||||
|
def _init_motors(self):
|
||||||
|
self.position_reader = GroupSyncRead(
|
||||||
|
self.dynamixel.portHandler, self.dynamixel.packetHandler, ReadAttribute.POSITION.value, 4
|
||||||
|
)
|
||||||
|
for id in self.servo_ids:
|
||||||
|
self.position_reader.addParam(id)
|
||||||
|
|
||||||
|
self.velocity_reader = GroupSyncRead(
|
||||||
|
self.dynamixel.portHandler, self.dynamixel.packetHandler, ReadAttribute.VELOCITY.value, 4
|
||||||
|
)
|
||||||
|
for id in self.servo_ids:
|
||||||
|
self.velocity_reader.addParam(id)
|
||||||
|
|
||||||
|
self.pos_writer = GroupSyncWrite(
|
||||||
|
self.dynamixel.portHandler, self.dynamixel.packetHandler, self.dynamixel.ADDR_GOAL_POSITION, 4
|
||||||
|
)
|
||||||
|
for id in self.servo_ids:
|
||||||
|
self.pos_writer.addParam(id, [2048])
|
||||||
|
|
||||||
|
self.pwm_writer = GroupSyncWrite(
|
||||||
|
self.dynamixel.portHandler, self.dynamixel.packetHandler, self.dynamixel.ADDR_GOAL_PWM, 2
|
||||||
|
)
|
||||||
|
for id in self.servo_ids:
|
||||||
|
self.pwm_writer.addParam(id, [2048])
|
||||||
|
self._disable_torque()
|
||||||
|
self.motor_control_state = MotorControlType.DISABLED
|
||||||
|
|
||||||
|
def read_position(self, tries=2):
|
||||||
|
"""
|
||||||
|
Reads the joint positions of the robot. 2048 is the center position. 0 and 4096 are 180 degrees in each direction.
|
||||||
|
:param tries: maximum number of tries to read the position
|
||||||
|
:return: list of joint positions in range [0, 4096]
|
||||||
|
"""
|
||||||
|
result = self.position_reader.txRxPacket()
|
||||||
|
if result != 0:
|
||||||
|
if tries > 0:
|
||||||
|
return self.read_position(tries=tries - 1)
|
||||||
|
else:
|
||||||
|
print("failed to read position!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
|
positions = []
|
||||||
|
for id in self.servo_ids:
|
||||||
|
position = self.position_reader.getData(id, ReadAttribute.POSITION.value, 4)
|
||||||
|
if position > 2**31:
|
||||||
|
position -= 2**32
|
||||||
|
positions.append(position)
|
||||||
|
return np.array(positions)
|
||||||
|
|
||||||
|
def read_velocity(self):
|
||||||
|
"""
|
||||||
|
Reads the joint velocities of the robot.
|
||||||
|
:return: list of joint velocities,
|
||||||
|
"""
|
||||||
|
self.velocity_reader.txRxPacket()
|
||||||
|
velocties = []
|
||||||
|
for id in self.servo_ids:
|
||||||
|
velocity = self.velocity_reader.getData(id, ReadAttribute.VELOCITY.value, 4)
|
||||||
|
if velocity > 2**31:
|
||||||
|
velocity -= 2**32
|
||||||
|
velocties.append(velocity)
|
||||||
|
return np.array(velocties)
|
||||||
|
|
||||||
|
def set_goal_pos(self, action):
|
||||||
|
"""
|
||||||
|
:param action: list or numpy array of target joint positions in range [0, 4096]
|
||||||
|
"""
|
||||||
|
if self.motor_control_state is not MotorControlType.POSITION_CONTROL:
|
||||||
|
self._set_position_control()
|
||||||
|
for i, motor_id in enumerate(self.servo_ids):
|
||||||
|
data_write = [
|
||||||
|
DXL_LOBYTE(DXL_LOWORD(action[i])),
|
||||||
|
DXL_HIBYTE(DXL_LOWORD(action[i])),
|
||||||
|
DXL_LOBYTE(DXL_HIWORD(action[i])),
|
||||||
|
DXL_HIBYTE(DXL_HIWORD(action[i])),
|
||||||
|
]
|
||||||
|
self.pos_writer.changeParam(motor_id, data_write)
|
||||||
|
|
||||||
|
self.pos_writer.txPacket()
|
||||||
|
|
||||||
|
def set_pwm(self, action):
|
||||||
|
"""
|
||||||
|
Sets the pwm values for the servos.
|
||||||
|
:param action: list or numpy array of pwm values in range [0, 885]
|
||||||
|
"""
|
||||||
|
if self.motor_control_state is not MotorControlType.PWM:
|
||||||
|
self._set_pwm_control()
|
||||||
|
for i, motor_id in enumerate(self.servo_ids):
|
||||||
|
data_write = [
|
||||||
|
DXL_LOBYTE(DXL_LOWORD(action[i])),
|
||||||
|
DXL_HIBYTE(DXL_LOWORD(action[i])),
|
||||||
|
]
|
||||||
|
self.pwm_writer.changeParam(motor_id, data_write)
|
||||||
|
|
||||||
|
self.pwm_writer.txPacket()
|
||||||
|
|
||||||
|
def set_trigger_torque(self, torque: int):
|
||||||
|
"""
|
||||||
|
Sets a constant torque torque for the last servo in the chain. This is useful for the trigger of the leader arm
|
||||||
|
"""
|
||||||
|
self.dynamixel._enable_torque(self.servo_ids[-1])
|
||||||
|
self.dynamixel.set_pwm_value(self.servo_ids[-1], torque)
|
||||||
|
|
||||||
|
def limit_pwm(self, limit: Union[int, list, np.ndarray]):
|
||||||
|
"""
|
||||||
|
Limits the pwm values for the servos in for position control
|
||||||
|
@param limit: 0 ~ 885
|
||||||
|
@return:
|
||||||
|
"""
|
||||||
|
if isinstance(limit, int):
|
||||||
|
limits = [
|
||||||
|
limit,
|
||||||
|
] * 5
|
||||||
|
else:
|
||||||
|
limits = limit
|
||||||
|
self._disable_torque()
|
||||||
|
for motor_id, limit in zip(self.servo_ids, limits, strict=False):
|
||||||
|
self.dynamixel.set_pwm_limit(motor_id, limit)
|
||||||
|
self._enable_torque()
|
||||||
|
|
||||||
|
def _disable_torque(self):
|
||||||
|
print(f"disabling torque for servos {self.servo_ids}")
|
||||||
|
for motor_id in self.servo_ids:
|
||||||
|
self.dynamixel._disable_torque(motor_id)
|
||||||
|
|
||||||
|
def _enable_torque(self):
|
||||||
|
print(f"enabling torque for servos {self.servo_ids}")
|
||||||
|
for motor_id in self.servo_ids:
|
||||||
|
self.dynamixel._enable_torque(motor_id)
|
||||||
|
|
||||||
|
def _set_pwm_control(self):
|
||||||
|
self._disable_torque()
|
||||||
|
for motor_id in self.servo_ids:
|
||||||
|
self.dynamixel.set_operating_mode(motor_id, OperatingMode.PWM)
|
||||||
|
self._enable_torque()
|
||||||
|
self.motor_control_state = MotorControlType.PWM
|
||||||
|
|
||||||
|
def _set_position_control(self):
|
||||||
|
self._disable_torque()
|
||||||
|
for motor_id in self.servo_ids:
|
||||||
|
self.dynamixel.set_operating_mode(motor_id, OperatingMode.POSITION)
|
||||||
|
self._enable_torque()
|
||||||
|
self.motor_control_state = MotorControlType.POSITION_CONTROL
|
|
@ -21,19 +21,24 @@ import PIL
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def concatenate_episodes(ep_dicts):
|
def concatenate_episodes(ep_dicts, drop_episodes_last_frame=False):
|
||||||
data_dict = {}
|
data_dict = {}
|
||||||
|
|
||||||
keys = ep_dicts[0].keys()
|
keys = ep_dicts[0].keys()
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if torch.is_tensor(ep_dicts[0][key][0]):
|
if torch.is_tensor(ep_dicts[0][key][0]):
|
||||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
if drop_episodes_last_frame:
|
||||||
|
data_dict[key] = torch.cat([ep_dict[key][:-1] for ep_dict in ep_dicts])
|
||||||
|
else:
|
||||||
|
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
||||||
else:
|
else:
|
||||||
if key not in data_dict:
|
if key not in data_dict:
|
||||||
data_dict[key] = []
|
data_dict[key] = []
|
||||||
for ep_dict in ep_dicts:
|
for ep_dict in ep_dicts:
|
||||||
for x in ep_dict[key]:
|
for x in ep_dict[key]:
|
||||||
data_dict[key].append(x)
|
data_dict[key].append(x)
|
||||||
|
if drop_episodes_last_frame:
|
||||||
|
data_dict[key].pop()
|
||||||
|
|
||||||
total_frames = data_dict["frame_index"].shape[0]
|
total_frames = data_dict["frame_index"].shape[0]
|
||||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||||
|
|
Loading…
Reference in New Issue