- Added `lerobot/scripts/server/gym_manipulator.py`, which contains all the wrappers needed to run a gym-style environment around the real robot.
- Added `lerobot/scripts/server/find_joint_limits.py` to find the min and max joint angles of the motion you want the robot to explore during RL training.
- Added logic in `manipulator.py` to clamp goal joint angles to a predefined joint position range. The limits are specified in the yaml config of each robot; see `so100.yaml` for an example.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
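For reference, a minimal sketch of the clamping logic this adds (the bound values below are placeholders; in the repo they come from `joint_position_relative_bounds` in the robot's yaml config, e.g. `so100.yaml`, and the clamping runs inside `ManipulatorRobot` before goal positions are written to the motors):

```python
# Illustrative sketch only, not the exact manipulator.py code.
import torch

# Placeholder per-joint bounds; the real values live in the robot yaml config
# under joint_position_relative_bounds (see the so100.yaml hunk below).
joint_position_relative_bounds = {
    "min": torch.tensor([-88.5, 23.8, 0.9, -32.2, 78.7, 0.5]),
    "max": torch.tensor([84.6, 187.1, 146.0, 101.6, 146.6, 88.2]),
}

goal_pos = torch.tensor([120.0, 10.0, 50.0, 0.0, 160.0, 45.0])  # hypothetical raw goal
if joint_position_relative_bounds is not None:
    # Clamp each joint's goal into its allowed range before sending it to the robot.
    goal_pos = torch.clamp(
        goal_pos,
        joint_position_relative_bounds["min"],
        joint_position_relative_bounds["max"],
    )
```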
parent 163bcbcad4 · commit 729b4ed697
@@ -126,27 +126,29 @@ class PixelWrapper(gym.Wrapper):
         obs, reward, terminated, truncated, info = self.env.step(action)
         return self._get_obs(obs), reward, terminated, truncated, info


 class ConvertToLeRobotEnv(gym.Wrapper):
     def __init__(self, env, num_envs):
         super().__init__(env)

     def reset(self, seed=None, options=None):
         obs, info = self.env.reset(seed=seed, options={})
         return self._get_obs(obs), info

     def step(self, action):
         obs, reward, terminated, truncated, info = self.env.step(action)
         return self._get_obs(obs), reward, terminated, truncated, info

     def _get_obs(self, observation):
         sensor_data = observation.pop("sensor_data")
         del observation["sensor_param"]
         images = []
         for cam_data in sensor_data.values():
             images.append(cam_data["rgb"])

         images = torch.concat(images, axis=-1)
         # flatten the rest of the data which should just be state data
-        observation = common.flatten_state_dict(
-            observation, use_torch=True, device=self.base_env.device
-        )
+        observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
         ret = dict()
         ret["state"] = observation
         ret["pixels"] = images
@@ -36,11 +36,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
     # TODO: You have to merge all tensors from agent key and extra key
     # You don't keep sensor param key in the observation
     # And you keep sensor data rgb
-    if "pixels" in observations:
-        if isinstance(observations["pixels"], dict):
-            imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
-        else:
-            imgs = {"observation.image": observations["pixels"]}
+    for key, img in observations.items():
+        if "images" not in key:
+            continue

         for imgkey, img in imgs.items():
             # TODO(aliberts, rcadene): use transforms.ToTensor()?

@@ -50,15 +48,15 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
             _, h, w, c = img.shape
             assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"

             # sanity check that images are uint8
             assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

             # convert to channel first of type float32 in range [0,1]
             img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
             img = img.type(torch.float32)
             img /= 255

-            return_observations[imgkey] = img
+            return_observations[key] = img
     # obs state agent qpos and qvel
     # image

@@ -69,7 +67,8 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten

     # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
     # requirement for "agent_pos"
-    return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
+    # return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
+    return_observations["observation.state"] = observations["observation.state"].float()
     return return_observations

@@ -47,7 +47,7 @@ class Classifier(

         super().__init__()
         self.config = config
-        self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
+        # self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
         encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
         # Extract vision model if we're given a multimodal model
         if hasattr(encoder, "vision_model"):

@@ -108,11 +108,12 @@ class Classifier(
     def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
         """Extract the appropriate output from the encoder."""
         # Process images with the processor (handles resizing and normalization)
-        processed = self.processor(
-            images=x, # LeRobotDataset already provides proper tensor format
-            return_tensors="pt",
-        )
-        processed = processed["pixel_values"].to(x.device)
+        # processed = self.processor(
+        #     images=x, # LeRobotDataset already provides proper tensor format
+        #     return_tensors="pt",
+        # )
+        # processed = processed["pixel_values"].to(x.device)
+        processed = x

         with torch.no_grad():
             if self.is_cnn:

@@ -146,6 +147,6 @@ class Classifier(

     def predict_reward(self, x):
         if self.config.num_classes == 2:
-            return (self.forward(x).probabilities > 0.5).float()
+            return (self.forward(x).probabilities > 0.6).float()
         else:
             return torch.argmax(self.forward(x).probabilities, dim=1)
@@ -45,7 +45,7 @@ def ensure_safe_goal_position(
     safe_goal_pos = present_pos + safe_diff

     if not torch.allclose(goal_pos, safe_goal_pos):
-        logging.warning(
+        logging.debug(
             "Relative goal position magnitude had to be clamped to be safe.\n"
             f" requested relative goal position target: {diff}\n"
             f" clamped relative goal position target: {safe_diff}"
@@ -464,6 +464,14 @@ class ManipulatorRobot:
             before_fwrite_t = time.perf_counter()
             goal_pos = leader_pos[name]

+            # If specified, clip the goal positions within predefined bounds specified in the config of the robot
+            if self.config.joint_position_relative_bounds is not None:
+                goal_pos = torch.clamp(
+                    goal_pos,
+                    self.config.joint_position_relative_bounds["min"],
+                    self.config.joint_position_relative_bounds["max"],
+                )
+
             # Cap goal position when too far away from present position.
             # Slower fps expected due to reading from the follower.
             if self.config.max_relative_target is not None:

@@ -585,6 +593,14 @@ class ManipulatorRobot:
             goal_pos = action[from_idx:to_idx]
             from_idx = to_idx

+            # If specified, clip the goal positions within predefined bounds specified in the config of the robot
+            if self.config.joint_position_relative_bounds is not None:
+                goal_pos = torch.clamp(
+                    goal_pos,
+                    self.config.joint_position_relative_bounds["min"],
+                    self.config.joint_position_relative_bounds["max"],
+                )
+
             # Cap goal position when too far away from present position.
             # Slower fps expected due to reading from the follower.
             if self.config.max_relative_target is not None:
@@ -4,7 +4,7 @@ defaults:
   - _self_

 seed: 13
-dataset_repo_id: aractingi/pick_place_lego_cube_1
+dataset_repo_id: aractingi/push_green_cube_hf_cropped_resized
 train_split_proportion: 0.8

 # Required by logger

@@ -24,7 +24,8 @@ training:
   eval_freq: 1 # How often to run validation (in epochs)
   save_freq: 1 # How often to save checkpoints (in epochs)
   save_checkpoint: true
-  image_keys: ["observation.images.top", "observation.images.wrist"]
+  # image_keys: ["observation.images.top", "observation.images.wrist"]
+  image_keys: ["observation.images.laptop", "observation.images.phone"]
   label_key: "next.reward"

 eval:

@@ -32,7 +33,7 @@ eval:
   num_samples_to_log: 30 # Number of validation samples to log in the table

 policy:
-  name: "hilserl/classifier/pick_place_lego_cube_1"
+  name: "hilserl/classifier/push_green_cube_hf_cropped_resized" #"hilserl/classifier/pick_place_lego_cube_1"
   model_name: "facebook/convnext-base-224"
   model_type: "cnn"
   num_cameras: 2 # Has to be len(training.image_keys)
@@ -14,6 +14,9 @@ calibration_dir: .cache/calibration/so100
 # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
 # the number of motors in your follower arms.
 max_relative_target: null
+joint_position_relative_bounds:
+  min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
+  max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]

 leader_arms:
   main:
@@ -0,0 +1,64 @@
+import argparse
+import time
+
+import cv2
+import numpy as np
+
+from lerobot.common.robot_devices.control_utils import is_headless
+from lerobot.common.robot_devices.robots.factory import make_robot
+from lerobot.common.utils.utils import init_hydra_config
+
+
+def find_joint_bounds(
+    robot,
+    control_time_s=20,
+    display_cameras=False,
+):
+    # TODO(rcadene): Add option to record logs
+    if not robot.is_connected:
+        robot.connect()
+
+    control_time_s = float("inf")
+
+    timestamp = 0
+    start_episode_t = time.perf_counter()
+    pos_list = []
+    while timestamp < control_time_s:
+        observation, action = robot.teleop_step(record_data=True)
+
+        pos_list.append(robot.follower_arms["main"].read("Present_Position"))
+
+        if display_cameras and not is_headless():
+            image_keys = [key for key in observation if "image" in key]
+            for key in image_keys:
+                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
+            cv2.waitKey(1)
+
+        timestamp = time.perf_counter() - start_episode_t
+        if timestamp > 60:
+            max = np.max(np.stack(pos_list), 0)
+            min = np.min(np.stack(pos_list), 0)
+            print(f"Max angle position per joint {max}")
+            print(f"Min angle position per joint {min}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--robot-path",
+        type=str,
+        default="lerobot/configs/robot/koch.yaml",
+        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
+    )
+    parser.add_argument(
+        "--robot-overrides",
+        type=str,
+        nargs="*",
+        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
+    )
+    parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
+    args = parser.parse_args()
+    robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
+
+    robot = make_robot(robot_cfg)
+    find_joint_bounds(robot, control_time_s=args.control_time_s)
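Usage note (not part of the diff): the script is launched like the other robot scripts, e.g. `python lerobot/scripts/server/find_joint_limits.py --robot-path lerobot/configs/robot/so100.yaml` (assuming the so100 config lives at that usual path). After roughly 60 seconds of teleoperation it prints the per-joint min/max positions, which can then be copied into `joint_position_relative_bounds` in the robot yaml, as in the `so100.yaml` hunk above.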
@@ -0,0 +1,697 @@
+import argparse
+import logging
+import time
+from threading import Lock
+from typing import Annotated, Any, Callable, Dict, Optional, Tuple
+
+import cv2
+import gymnasium as gym
+import numpy as np
+import torch
+import torchvision.transforms.functional as F  # noqa: N812
+
+from lerobot.common.envs.utils import preprocess_observation
+from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
+from lerobot.common.robot_devices.robots.factory import make_robot
+from lerobot.common.utils.utils import init_hydra_config, log_say
+
+logging.basicConfig(level=logging.INFO)
+
+
+class HILSerlRobotEnv(gym.Env):
+    """
+    Gym-like environment wrapper for robot policy evaluation.
+
+    This wrapper provides a consistent interface for interacting with the robot,
+    following the OpenAI Gym environment conventions.
+    """
+
+    def __init__(
+        self,
+        robot,
+        use_delta_action_space: bool = True,
+        delta: float | None = None,
+        display_cameras=False,
+    ):
+        """
+        Initialize the robot environment.
+
+        Args:
+            robot: The robot interface object
+            reward_classifier: Optional reward classifier
+            fps: Frames per second for control
+            control_time_s: Total control time for each episode
+            display_cameras: Whether to display camera feeds
+            output_normalization_params_action: Bound parameters for the action space
+            delta: The delta for the relative joint position action space
+        """
+        super().__init__()
+
+        self.robot = robot
+        self.display_cameras = display_cameras
+
+        # connect robot
+        if not self.robot.is_connected:
+            self.robot.connect()
+
+        self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")
+
+        # Episode tracking
+        self.current_step = 0
+        self.episode_data = None
+
+        self.delta = delta
+        self.use_delta_action_space = use_delta_action_space
+        self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
+
+        self.relative_bounds_size = (
+            self.robot.config.joint_position_relative_bounds["max"]
+            - self.robot.config.joint_position_relative_bounds["min"]
+        )
+
+        self.delta_relative_bounds_size = self.relative_bounds_size * self.delta
+
+        self.robot.config.max_relative_target = self.delta_relative_bounds_size.float()
+
+        # Dynamically determine observation and action spaces
+        self._setup_spaces()
+
+    def _setup_spaces(self):
+        """
+        Dynamically determine observation and action spaces based on robot capabilities.
+
+        This method should be customized based on the specific robot's observation
+        and action representations.
+        """
+        # Example space setup - you'll need to adapt this to your specific robot
+        example_obs = self.robot.capture_observation()
+
+        # Observation space (assuming image-based observations)
+        image_keys = [key for key in example_obs if "image" in key]
+        state_keys = [key for key in example_obs if "image" not in key]
+        observation_spaces = {
+            key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
+            for key in image_keys
+        }
+        observation_spaces["observation.state"] = gym.spaces.Dict(
+            {
+                key: gym.spaces.Box(low=0, high=10, shape=example_obs[key].shape, dtype=np.float32)
+                for key in state_keys
+            }
+        )
+
+        self.observation_space = gym.spaces.Dict(observation_spaces)
+
+        # Action space (assuming joint positions)
+        action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
+        if self.use_delta_action_space:
+            action_space_robot = gym.spaces.Box(
+                low=-self.relative_bounds_size.cpu().numpy(),
+                high=self.relative_bounds_size.cpu().numpy(),
+                shape=(action_dim,),
+                dtype=np.float32,
+            )
+        else:
+            action_space_robot = gym.spaces.Box(
+                low=self.robot.config.joint_position_relative_bounds["min"].cpu().numpy(),
+                high=self.robot.config.joint_position_relative_bounds["max"].cpu().numpy(),
+                shape=(action_dim,),
+                dtype=np.float32,
+            )
+
+        self.action_space = gym.spaces.Tuple(
+            (
+                action_space_robot,
+                gym.spaces.Discrete(2),
+            ),
+        )
+
+    def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
+        """
+        Reset the environment to initial state.
+
+        Returns:
+            observation (dict): Initial observation
+            info (dict): Additional information
+        """
+        super().reset(seed=seed, options=options)
+
+        # Capture initial observation
+        observation = self.robot.capture_observation()
+
+        # Reset tracking variables
+        self.current_step = 0
+        self.episode_data = None
+
+        return observation, {"initial_position": self.initial_follower_position}
+
+    def step(
+        self, action: Tuple[np.ndarray, bool]
+    ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
+        """
+        Take a step in the environment.
+
+        Args:
+            action tuple(np.ndarray, bool):
+                Policy action to be executed on the robot and boolean to determine
+                whether to choose policy action or expert action.
+
+        Returns:
+            observation (dict): Next observation
+            reward (float): Reward for this step
+            terminated (bool): Whether the episode has terminated
+            truncated (bool): Whether the episode was truncated
+            info (dict): Additional information
+        """
+        # The actions recieved are the in form of a tuple containing the policy action and an intervention bool
+        # The boolean inidicated whether we will use the expert's actions (through teleoperation) or the policy actions
+        policy_action, intervention_bool = action
+        teleop_action = None
+        self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
+        if isinstance(policy_action, torch.Tensor):
+            policy_action = policy_action.cpu().numpy()
+            olicy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
+        if not intervention_bool:
+            if self.use_delta_action_space:
+                target_joint_positions = self.current_joint_positions + self.delta * policy_action
+            else:
+                target_joint_positions = policy_action
+            self.robot.send_action(torch.from_numpy(target_joint_positions))
+            observation = self.robot.capture_observation()
+        else:
+            observation, teleop_action = self.robot.teleop_step(record_data=True)
+            teleop_action = teleop_action["action"]  # teleop step returns torch tensors but in a dict
+
+            # teleop actions are returned in absolute joint space
+            # If we are using a relative joint position action space,
+            # there will be a mismatch between the spaces of the policy and teleop actions
+            # Solution is to transform the teleop actions into relative space.
+            # teleop relative action is:
+            if self.use_delta_action_space:
+                teleop_action = teleop_action - self.current_joint_positions
+                if torch.any(teleop_action < -self.delta_relative_bounds_size * self.delta) and torch.any(
+                    teleop_action > self.delta_relative_bounds_size
+                ):
+                    print(
+                        f"relative teleop delta exceeded bounds {self.delta_relative_bounds_size}, teleop_action {teleop_action}\n"
+                        f"lower bounds condition {teleop_action < -self.delta_relative_bounds_size}\n"
+                        f"upper bounds condition {teleop_action > self.delta_relative_bounds_size}"
+                    )
+                    teleop_action = torch.clamp(
+                        teleop_action, -self.delta_relative_bounds_size, self.delta_relative_bounds_size
+                    )
+
+        self.current_step += 1
+
+        reward = 0.0
+        terminated = False
+        truncated = False
+
+        return (
+            observation,
+            reward,
+            terminated,
+            truncated,
+            {"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
+        )
+
+    def render(self):
+        """
+        Render the environment (in this case, display camera feeds).
+        """
+        import cv2
+
+        observation = self.robot.capture_observation()
+        image_keys = [key for key in observation if "image" in key]
+
+        for key in image_keys:
+            cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
+
+        cv2.waitKey(1)
+
+    def close(self):
+        """
+        Close the environment and disconnect the robot.
+        """
+        if self.robot.is_connected:
+            self.robot.disconnect()
+
+
+class ActionRepeatWrapper(gym.Wrapper):
+    def __init__(self, env, nb_repeat: int = 1):
+        super().__init__(env)
+        self.nb_repeat = nb_repeat
+
+    def step(self, action):
+        for _ in range(self.nb_repeat):
+            obs, reward, done, truncated, info = self.env.step(action)
+            if done or truncated:
+                break
+        return obs, reward, done, truncated, info
+
+
+class RelativeJointPositionActionWrapper(gym.Wrapper):
+    def __init__(
+        self,
+        env: HILSerlRobotEnv,
+        # output_normalization_params_action: dict[str, list[float]],
+        delta: float = 0.1,
+    ):
+        super().__init__(env)
+        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
+        self.delta = delta
+        if delta > 1:
+            raise ValueError("Delta should be less than 1")
+
+    def step(self, action):
+        action_joint = action
+        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            action_joint = action[0]
+        joint_positions = self.joint_positions + (self.delta * action_joint)
+        # clip the joint positions to the joint limits with the action space
+        joint_positions = np.clip(joint_positions, self.action_space.low, self.action_space.high)
+
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            return self.env.step((joint_positions, action[1]))
+
+        obs, reward, terminated, truncated, info = self.env.step(joint_positions)
+        if info["is_intervention"]:
+            # teleop actions are returned in absolute joint space
+            # If we are using a relative joint position action space,
+            # there will be a mismatch between the spaces of the policy and teleop actions
+            # Solution is to transform the teleop actions into relative space.
+            self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
+            teleop_action = info["action_intervention"]  # teleop actions are in absolute joint space
+            relative_teleop_action = (teleop_action - self.joint_positions) / self.delta
+            info["action_intervention"] = relative_teleop_action
+
+        return self.env.step(joint_positions)
+
+
+class RewardWrapper(gym.Wrapper):
+    def __init__(self, env, reward_classifier: Optional[None], device: torch.device = "cuda"):
+        self.env = env
+        self.reward_classifier = torch.compile(reward_classifier)
+        self.device = device
+
+    def step(self, action):
+        observation, _, terminated, truncated, info = self.env.step(action)
+        images = [
+            observation[key].to(self.device, non_blocking=True) for key in observation if "image" in key
+        ]
+        start_time = time.perf_counter()
+        with torch.inference_mode():
+            reward = (
+                self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0
+            )
+        # print(f"fps for reward classifier {1/(time.perf_counter() - start_time)}")
+        reward = reward.item()
+        # print(f"Reward from reward classifier {reward}")
+        return observation, reward, terminated, truncated, info
+
+    def reset(self, seed=None, options=None):
+        return self.env.reset(seed=seed, options=options)
+
+
+class TimeLimitWrapper(gym.Wrapper):
+    def __init__(self, env, control_time_s, fps):
+        self.env = env
+        self.control_time_s = control_time_s
+        self.fps = fps
+
+        self.last_timestamp = 0.0
+        self.episode_time_in_s = 0.0
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        time_since_last_step = time.perf_counter() - self.last_timestamp
+        self.episode_time_in_s += time_since_last_step
+        self.last_timestamp = time.perf_counter()
+
+        # check if last timestep took more time than the expected fps
+        if 1.0 / time_since_last_step < self.fps:
+            logging.warning(f"Current timestep is lower than the expected fps {self.fps}")
+
+        if self.episode_time_in_s > self.control_time_s:
+            # Terminated = True
+            terminated = True
+        return obs, reward, terminated, truncated, info
+
+    def reset(self, seed=None, options=None):
+        self.episode_time_in_s = 0.0
+        self.last_timestamp = time.perf_counter()
+        return self.env.reset(seed=seed, options=options)
+
+
+class ImageCropResizeWrapper(gym.Wrapper):
+    def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
+        self.env = env
+        self.crop_params_dict = crop_params_dict
+        print(f"obs_keys , {self.env.observation_space}")
+        print(f"crop params dict {crop_params_dict.keys()}")
+        for key_crop in crop_params_dict:
+            if key_crop not in self.env.observation_space.keys():
+                raise ValueError(f"Key {key_crop} not in observation space")
+        for key in crop_params_dict:
+            top, left, height, width = crop_params_dict[key]
+            new_shape = (top + height, left + width)
+            self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)
+
+        self.resize_size = resize_size
+        if self.resize_size is None:
+            self.resize_size = (128, 128)
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        for k in self.crop_params_dict:
+            device = obs[k].device
+            if device == torch.device("mps:0"):
+                obs[k] = obs[k].cpu()
+            obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
+            obs[k] = F.resize(obs[k], self.resize_size)
+            obs[k] = obs[k].to(device)
+            # print(f"observation with key {k} with size {obs[k].size()}")
+            cv2.imshow(k, cv2.cvtColor(obs[k].cpu().squeeze(0).permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))
+            cv2.waitKey(1)
+        return obs, reward, terminated, truncated, info
+
+
+class ConvertToLeRobotObservation(gym.ObservationWrapper):
+    def __init__(self, env, device):
+        super().__init__(env)
+        self.device = device
+
+    def observation(self, observation):
+        observation = preprocess_observation(observation)
+
+        observation = {key: observation[key].to(self.device, non_blocking=True) for key in observation}
+        observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
+        return observation
+
+
+class KeyboardInterfaceWrapper(gym.Wrapper):
+    def __init__(self, env):
+        super().__init__(env)
+        self.listener = None
+        self.events = {
+            "exit_early": False,
+            "pause_policy": False,
+            "reset_env": False,
+            "human_intervention_step": False,
+        }
+        self.event_lock = Lock()  # Thread-safe access to events
+        self._init_keyboard_listener()
+
+    def _init_keyboard_listener(self):
+        """Initialize keyboard listener if not in headless mode"""
+
+        if is_headless():
+            logging.warning(
+                "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
+            )
+            return
+        try:
+            from pynput import keyboard
+
+            def on_press(key):
+                with self.event_lock:
+                    try:
+                        if key == keyboard.Key.right or key == keyboard.Key.esc:
+                            print("Right arrow key pressed. Exiting loop...")
+                            self.events["exit_early"] = True
+                        elif key == keyboard.Key.space:
+                            if not self.events["pause_policy"]:
+                                print(
+                                    "Space key pressed. Human intervention required.\n"
+                                    "Place the leader in similar pose to the follower and press space again."
+                                )
+                                self.events["pause_policy"] = True
+                                log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
+                            elif self.events["pause_policy"] and not self.events["human_intervention_step"]:
+                                self.events["human_intervention_step"] = True
+                                print("Space key pressed. Human intervention starting.")
+                                log_say("Starting human intervention.", play_sounds=True)
+                            else:
+                                self.events["pause_policy"] = False
+                                self.events["human_intervention_step"] = False
+                                print("Space key pressed for a third time.")
+                                log_say("Continuing with policy actions.", play_sounds=True)
+                    except Exception as e:
+                        print(f"Error handling key press: {e}")
+
+            self.listener = keyboard.Listener(on_press=on_press)
+            self.listener.start()
+        except ImportError:
+            logging.warning("Could not import pynput. Keyboard interface will not be available.")
+            self.listener = None
+
+    def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
+        is_intervention = False
+        terminated_by_keyboard = False
+
+        # Extract policy_action if needed
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            policy_action = action[0]
+
+        # Check the event flags without holding the lock for too long.
+        with self.event_lock:
+            if self.events["exit_early"]:
+                terminated_by_keyboard = True
+            # If we need to wait for human intervention, we note that outside the lock.
+            pause_policy = self.events["pause_policy"]
+
+        if pause_policy:
+            # Now, wait for human_intervention_step without holding the lock
+            while True:
+                with self.event_lock:
+                    if self.events["human_intervention_step"]:
+                        is_intervention = True
+                        break
+                time.sleep(0.1)  # Check more frequently if desired
+
+        # Execute the step in the underlying environment
+        obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
+        return obs, reward, terminated or terminated_by_keyboard, truncated, info
+
+    def reset(self, **kwargs) -> Tuple[Any, Dict]:
+        """
+        Reset the environment and clear any pending events
+        """
+        with self.event_lock:
+            self.events = {k: False for k in self.events}
+        return self.env.reset(**kwargs)
+
+    def close(self):
+        """
+        Properly clean up the keyboard listener when the environment is closed
+        """
+        if self.listener is not None:
+            self.listener.stop()
+        super().close()
+
+
+class ResetWrapper(gym.Wrapper):
+    def __init__(
+        self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5
+    ):
+        super().__init__(env)
+        self.reset_fn = reset_fn
+        self.reset_time_s = reset_time_s
+
+        self.robot = self.unwrapped.robot
+        self.init_pos = self.unwrapped.initial_follower_position
+
+    def reset(self, *, seed=None, options=None):
+        if self.reset_fn is not None:
+            self.reset_fn(self.env)
+        else:
+            log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True)
+            start_time = time.perf_counter()
+            while time.perf_counter() - start_time < self.reset_time_s:
+                self.robot.teleop_step()
+
+            log_say("Manual reseting of the environment done.", play_sounds=True)
+        return super().reset(seed=seed, options=options)
+
+
+def make_robot_env(
+    robot,
+    reward_classifier,
+    crop_params_dict=None,
+    fps=30,
+    control_time_s=20,
+    reset_follower_pos=True,
+    display_cameras=False,
+    device="cuda:0",
+    resize_size=None,
+    reset_time_s=10,
+    delta_action=0.1,
+    nb_repeats=1,
+    use_relative_joint_positions=False,
+):
+    """
+    Factory function to create the robot environment.
+
+    Mimics gym.make() for consistent environment creation.
+    """
+    env = HILSerlRobotEnv(
+        robot,
+        display_cameras=display_cameras,
+        delta=delta_action,
+        use_delta_action_space=use_relative_joint_positions,
+    )
+    env = ConvertToLeRobotObservation(env, device)
+    if crop_params_dict is not None:
+        env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
+    env = RewardWrapper(env, reward_classifier, device=device)
+    env = TimeLimitWrapper(env, control_time_s, fps)
+    # env = ActionRepeatWrapper(env, nb_repeat=nb_repeats)
+    env = KeyboardInterfaceWrapper(env)
+    env = ResetWrapper(env, reset_fn=None, reset_time_s=reset_time_s)
+    return env
+
+
+def get_classifier(pretrained_path, config_path, device="mps"):
+    if pretrained_path is None or config_path is None:
+        return
+
+    from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
+    from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
+    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
+
+    cfg = init_hydra_config(config_path)
+
+    classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
+    classifier_config.num_cameras = len(cfg.training.image_keys)  # TODO automate these paths
+    model = Classifier(classifier_config)
+    model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
+    model = model.to(device)
+    return model
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--fps", type=int, default=30, help="control frequency")
+    parser.add_argument(
+        "--robot-path",
+        type=str,
+        default="lerobot/configs/robot/koch.yaml",
+        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
+    )
+    parser.add_argument(
+        "--robot-overrides",
+        type=str,
+        nargs="*",
+        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
+    )
+    parser.add_argument(
+        "-p",
+        "--pretrained-policy-name-or-path",
+        help=(
+            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
+            "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
+            "(useful for debugging). This argument is mutually exclusive with `--config`."
+        ),
+    )
+    parser.add_argument(
+        "--config",
+        help=(
+            "Path to a yaml config you want to use for initializing a policy from scratch (useful for "
+            "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
+        ),
+    )
+    parser.add_argument(
+        "--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
+    )
+    parser.add_argument(
+        "--reward-classifier-pretrained-path",
+        type=str,
+        default=None,
+        help="Path to the pretrained classifier weights.",
+    )
+    parser.add_argument(
+        "--reward-classifier-config-file",
+        type=str,
+        default=None,
+        help="Path to a yaml config file that is necessary to build the reward classifier model.",
+    )
+    parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
+    parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes")
+    args = parser.parse_args()
+
+    robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
+    robot = make_robot(robot_cfg)
+
+    reward_classifier = get_classifier(
+        args.reward_classifier_pretrained_path, args.reward_classifier_config_file
+    )
+
+    crop_parameters = {
+        "observation.images.laptop": (58, 89, 357, 455),
+        "observation.images.phone": (3, 4, 471, 633),
+    }
+
+    user_relative_joint_positions = True
+
+    env = make_robot_env(
+        robot,
+        reward_classifier,
+        crop_parameters,
+        args.fps,
+        args.control_time_s,
+        args.reset_follower_pos,
+        args.display_cameras,
+        device="mps",
+        resize_size=None,
+        reset_time_s=10,
+        delta_action=0.1,
+        nb_repeats=1,
+        use_relative_joint_positions=user_relative_joint_positions,
+    )
+
+    env.reset()
+    init_pos = env.unwrapped.initial_follower_position
+
+    right_goal = init_pos.copy()
+    right_goal[0] += 50
+
+    left_goal = init_pos.copy()
+    left_goal[0] -= 50
+
+    pitch_angle = np.linspace(left_goal[0], right_goal[0], 1000)
+
+    delta_angle = np.concatenate((-np.ones(50), np.ones(50))) * 100
+
+    while True:
+        action = np.zeros(len(init_pos))
+        for i in range(len(delta_angle)):
+            start_loop_s = time.perf_counter()
+            action[0] = delta_angle[i]
+            obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))
+            if terminated or truncated:
+                env.reset()
+
+            dt_s = time.perf_counter() - start_loop_s
+            busy_wait(1 / args.fps - dt_s)
+        # action = np.zeros(len(init_pos)) if user_relative_joint_positions else init_pos
+        # for i in range(len(pitch_angle)):
+        #     if user_relative_joint_positions:
+        #         action[0] = delta_angle[i]
+        #     else:
+        #         action[0] = pitch_angle[i]
+        #     obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))
+        #     if terminated or truncated:
+        #         logging.info("Max control time reached, reset environment.")
+        #         env.reset()
+
+        # for i in reversed(range(len(pitch_angle))):
+        #     if user_relative_joint_positions:
+        #         action[0] = delta_angle[i]
+        #     else:
+        #         action[0] = pitch_angle[i]
+        #     obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))
+
+        #     if terminated or truncated:
+        #         logging.info("Max control time reached, reset environment.")
+        #         env.reset()
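For orientation, a minimal sketch (not part of this commit) of how a caller might drive the wrapped environment. It reuses `make_robot_env` and assumes `robot`, `reward_classifier` and `crop_parameters` were built as in the `__main__` block above; the loop length and the zero placeholder policy are purely illustrative:

```python
# Hypothetical driver loop for the wrapper stack above (illustrative only).
import numpy as np
import torch

# With use_relative_joint_positions=True a zero action is a zero delta, so the arm holds its pose.
env = make_robot_env(
    robot, reward_classifier, crop_parameters,
    fps=30, control_time_s=20, device="mps", use_relative_joint_positions=True,
)
obs, info = env.reset()
for _ in range(200):
    policy_action = np.zeros(env.action_space[0].shape, dtype=np.float32)  # placeholder policy
    # Second tuple element is the intervention flag; the keyboard wrapper flips the actual
    # behaviour when a human takes over with the space key.
    obs, reward, terminated, truncated, info = env.step((torch.from_numpy(policy_action), False))
    if info["is_intervention"]:
        expert_action = info["action_intervention"]  # teleop action reported by the env
    if terminated or truncated:
        obs, info = env.reset()
```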