add train and evals

This commit is contained in:
Thomas Wolf 2024-06-09 14:03:47 +02:00
parent 797f79f182
commit ef074d7281
14 changed files with 184 additions and 85 deletions

View File

@ -27,9 +27,9 @@ Follow these steps:
## 0 - record examples
Run the `0_record_training_data.py` example, selecting the duration and number of episodes you want to record, e.g.
Run the `record_training_data.py` example, selecting the duration and number of episodes you want to record, e.g.
```
DATA_DIR='./data' python 0_record_training_data.py \
DATA_DIR='./data' python record_training_data.py \
--repo-id=thomwolf/blue_red_sort \
--num-episodes=50 \
--num-frames=400
@ -44,15 +44,34 @@ TODO:
Use the standard dataset visualization script pointing it to the right folder:
```
DATA_DIR='./data' python visualize_dataset.py python lerobot/scripts/visualize_dataset.py \
DATA_DIR='./data' python ../../lerobot/scripts/visualize_dataset.py \
--repo-id thomwolf/blue_red_sort \
--episode-index 0
```
## (soon) Train a policy
## 2 - Train a policy
Run `1_train_real_policy.py` example
From the example directory, run this command to train a model using ACT:
## (soon) Evaluate the policy in the real world
```
DATA_DIR='./data' python ../../lerobot/scripts/train.py \
device=cuda \
hydra.searchpath=[file://./train_config/] \
hydra.run.dir=./outputs/train/blue_red_sort \
dataset_repo_id=thomwolf/blue_red_sort \
env=gym_real_world \
policy=act_real_world \
wandb.enable=false
```
Run `2_evaluate_real_policy.py` example
## 3 - Evaluate the policy in the real world
From the example directory, run this command to evaluate the policy.
The configuration used to run the policy is stored in the model checkpoint.
You can override parameters as follows:
```
python run_policy.py \
-p ./outputs/train/blue_red_sort/checkpoints/last/pretrained_model/ \
env.episode_length=1000
```

View File

@ -11,13 +11,13 @@ from .robot import Robot
FPS = 30
CAMERAS_SHAPES = {
"observation.images.high": (480, 640, 3),
"observation.images.low": (480, 640, 3),
"images.high": (480, 640, 3),
"images.low": (480, 640, 3),
}
CAMERAS_PORTS = {
"observation.images.high": "/dev/video6",
"observation.images.low": "/dev/video0",
"images.high": "/dev/video6",
"images.low": "/dev/video0",
}
LEADER_PORT = "/dev/ttyACM1"
@ -52,6 +52,8 @@ class RealEnv(gym.Env):
leader_port: str = LEADER_PORT,
warmup_steps: int = 100,
trigger_torque=70,
fps: int = FPS,
fps_tolerance: float = 0.1,
):
self.num_joints = num_joints
self.cameras_shapes = cameras_shapes
@ -62,6 +64,8 @@ class RealEnv(gym.Env):
self.follower_port = follower_port
self.leader_port = leader_port
self.record = record
self.fps = fps
self.fps_tolerance = fps_tolerance
# Initialize the robot
self.follower = Robot(device_name=self.follower_port)
@ -72,10 +76,13 @@ class RealEnv(gym.Env):
# Initialize the cameras - sorted by camera names
self.cameras = {}
for cn, p in sorted(self.cameras_ports.items()):
assert cn.startswith("observation.images."), "Camera names must start with 'observation.images.'."
self.cameras[cn] = cv2.VideoCapture(p)
if not all(c.isOpened() for c in self.cameras.values()):
raise OSError("Cannot open all camera ports.")
if not self.cameras[cn].isOpened():
raise OSError(
f"Cannot open camera port {p} for {cn}."
f" Make sure the camera is connected and the port is correct."
f"Also check you are not spinning several instances of the same environment (eval.batch_size)"
)
# Specify gym action and observation spaces
observation_space = {}
@ -98,7 +105,7 @@ class RealEnv(gym.Env):
if self.cameras_shapes:
for cn, hwc_shape in self.cameras_shapes.items():
# Assumes images are unsigned int8 in [0,255]
observation_space[f"images.{cn}"] = spaces.Box(
observation_space[cn] = spaces.Box(
low=0,
high=255,
# height x width x channels (e.g. 480 x 640 x 3)
@ -111,22 +118,20 @@ class RealEnv(gym.Env):
self._observation = {}
self._terminated = False
self._action_time = time.time()
self.starting_time = time.time()
self.timestamps = []
def _get_obs(self):
qpos = self.follower.read_position()
self._observation["agent_pos"] = pwm2pos(qpos)
for cn, c in self.cameras.items():
self._observation[f"images.{cn}"] = capture_image(
c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0]
)
self._observation[cn] = capture_image(c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0])
if self.record:
leader_pos = self.leader.read_position()
self._observation["leader_pos"] = pwm2pos(leader_pos)
action = self.leader.read_position()
self._observation["leader_pos"] = pwm2pos(action)
def reset(self, seed: int | None = None):
del seed
# Reset the robot and sync the leader and follower if we are recording
for _ in range(self.warmup_steps):
self._get_obs()
@ -134,10 +139,22 @@ class RealEnv(gym.Env):
self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
self._terminated = False
info = {}
self.timestamps = []
return self._observation, info
def step(self, action: np.ndarray = None):
# Reset the observation
if self.timestamps:
# wait the right amount of time to stay at the desired fps
# (timestamps are relative to starting_time, so measure the elapsed time since the last frame first)
elapsed_since_last_frame = time.time() - self.starting_time - self.timestamps[-1]
time.sleep(max(0, 1 / self.fps - elapsed_since_last_frame))
recording_time = time.time() - self.starting_time
else:
# it's the first step so we start the timer
self.starting_time = time.time()
recording_time = 0
self.timestamps.append(recording_time)
# Get the observation
self._get_obs()
if self.record:
# Teleoperate the leader
@ -145,9 +162,20 @@ class RealEnv(gym.Env):
else:
# Apply the action to the follower
self.follower.set_goal_pos(pos2pwm(action))
reward = 0
terminated = truncated = self._terminated
info = {}
info = {"timestamp": recording_time, "fps_error": False}
# Check if we are able to keep up with the desired fps, comparing against the previous
# frame's timestamp (the current one was just appended)
if len(self.timestamps) > 1 and recording_time - self.timestamps[-2] > 1 / (self.fps - self.fps_tolerance):
print(
f"Error: recording time interval {recording_time - self.timestamps[-2]:.2f} is greater"
f" than expected {1 / (self.fps - self.fps_tolerance):.2f}"
f" at frame {len(self.timestamps)}"
)
info["fps_error"] = True
return self._observation, reward, terminated, truncated, info
def render(self): ...
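For orientation, here is a minimal usage sketch (not part of the commit) of the updated environment, assuming `gym_real_world` is installed and the leader/follower arms and cameras are connected. The handle and keyword arguments mirror the ones `record_training_data.py` passes to `gym.make` elsewhere in this commit; the loop length and printout are illustrative.
```
import gym_real_world  # noqa: F401  # registers gym_real_world/RealEnv-v0
import gymnasium as gym

# record=True makes the follower mirror the leader arm, as during data recording
env = gym.make(
    "gym_real_world/RealEnv-v0",
    disable_env_checker=True,
    record=True,
    fps=30,
    fps_tolerance=0.1,
)

observation, info = env.reset()
for _ in range(400):
    # With record=True no action is needed; the env reads the leader arm position itself
    observation, reward, terminated, truncated, info = env.step(action=None)
    if info["fps_error"]:
        print(f"Could not keep up with the requested fps at t={info['timestamp']:.2f}s")
        break
env.close()
```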

View File

@ -1,7 +1,11 @@
"""This script demonstrates how to record a LeRobot dataset of training data
using a very simple gym environment (see in examples/real_robot_example/gym_real_world/gym_environment.py).
"""
import argparse
import copy
import os
import time
import gym_real_world # noqa: F401
import gymnasium as gym
@ -27,15 +31,12 @@ parser.add_argument("--num-frames", type=int, default=400)
parser.add_argument("--num-workers", type=int, default=16)
parser.add_argument("--keep-last", action="store_true")
parser.add_argument("--push-to-hub", action="store_true")
parser.add_argument("--fps", type=int, default=30, help="Frames per second of the recording.")
parser.add_argument(
"--fps",
type=int,
default=30,
help="Frames per second of the recording."
"If we are not able to record at this fps, we will adjust the fps in the metadata.",
)
parser.add_argument(
"--tolerance", type=float, default=0.01, help="Tolerance in seconds for the recording time."
"--fps_tolerance",
type=float,
default=0.1,
help="Tolerance in fps for the recording before dropping episodes.",
)
parser.add_argument(
"--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
@ -47,7 +48,7 @@ num_episodes = args.num_episodes
num_frames = args.num_frames
revision = args.revision
fps = args.fps
tolerance = args.tolerance
fps_tolerance = args.fps_tolerance
out_data = DATA_DIR / repo_id
@ -67,7 +68,7 @@ if not os.path.exists(videos_dir):
if __name__ == "__main__":
# Create the gym environment - check the kwargs in gym_real_world/gym_environment.py
gym_handle = "gym_real_world/RealEnv-v0"
env = gym.make(gym_handle, disable_env_checker=True, record=True)
env = gym.make(gym_handle, disable_env_checker=True, record=True, fps=fps, fps_tolerance=fps_tolerance)
ep_dicts = []
episode_data_index = {"from": [], "to": []}
@ -84,59 +85,46 @@ if __name__ == "__main__":
os.system(f'spd-say "go {ep_idx}"')
# init buffers
obs_replay = {k: [] for k in env.observation_space}
timestamps = []
starting_time = time.time()
drop_episode = False
timestamps = []
for _ in tqdm(range(num_frames)):
# Apply the next action
observation, _, _, _, _ = env.step(action=None)
observation, _, _, _, info = env.step(action=None)
# images_stacked = np.hstack(list(observation['pixels'].values()))
# images_stacked = cv2.cvtColor(images_stacked, cv2.COLOR_RGB2BGR)
# cv2.imshow('frame', images_stacked)
if info["fps_error"]:
os.system(f'spd-say "Error fps too low, dropping episode {ep_idx}"')
drop_episode = True
break
# store data
for key in observation:
obs_replay[key].append(copy.deepcopy(observation[key]))
recording_time = time.time() - starting_time
timestamps.append(recording_time)
# Check if we are able to keep up with the desired fps
if recording_time > num_frames / fps + tolerance:
print(
f"Error: recording time {recording_time:.2f} is greater than expected {num_frames / fps:.2f}"
f" + tolerance {tolerance:.2f}"
f" at frame {len(timestamps)}"
f" in episode {ep_idx}."
f"Dropping the rest of the episode."
)
break
# wait the right amount of time to stay at the desired fps
time.sleep(max(0, 1 / fps - (time.time() - starting_time)))
timestamps.append(info["timestamp"])
# if cv2.waitKey(1) & 0xFF == ord('q'):
# break
os.system('spd-say "stop"')
if len(timestamps) == num_frames:
if not drop_episode:
os.system(f'spd-say "saving episode {ep_idx}"')
ep_dict = {}
# store images in png and create the video
for img_key in env.cameras:
save_images_concurrently(
obs_replay[f"images.{img_key}"],
obs_replay[img_key],
images_dir / f"{img_key}_episode_{ep_idx:06d}",
args.num_workers,
)
# for i in tqdm(range(num_frames)):
# cv2.imwrite(str(images_dir / f"{img_key}_episode_{ep_idx:06d}" / f"frame_{i:06d}.png"),
# obs_replay[i]['pixels'][img_key])
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
# store the reference to the video frame
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps]
# shutil.rmtree(tmp_imgs_dir)
ep_dict[f"observation.{img_key}"] = [
{"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps
]
state = torch.tensor(np.array(obs_replay["agent_pos"]))
action = torch.tensor(np.array(obs_replay["leader_pos"]))
@ -198,8 +186,6 @@ if __name__ == "__main__":
features["timestamp"] = Value(dtype="float32", id=None)
features["next.done"] = Value(dtype="bool", id=None)
features["index"] = Value(dtype="int64", id=None)
# TODO(rcadene): add success
# features["next.success"] = Value(dtype='bool', id=None)
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
hf_dataset.set_transform(hf_transform_to_torch)
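To make the stored video references concrete, here is a small sketch (not part of the commit, values illustrative) of what one camera's entries in `ep_dict` look like for episode 0: each frame points at the episode video and carries the timestamp returned by the environment.
```
# Illustrative reconstruction of ep_dict["observation.images.high"] for episode 0
fps = 30
timestamps = [i / fps for i in range(400)]  # in practice these come from info["timestamp"]
fname = "images.high_episode_000000.mp4"
ep_dict_entry = [{"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps]
print(ep_dict_entry[0])  # {'path': 'videos/images.high_episode_000000.mp4', 'timestamp': 0.0}
```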

View File

@ -0,0 +1,60 @@
import argparse
import logging
from pathlib import Path
import gym_real_world # noqa: F401
import gymnasium as gym # noqa: F401
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from lerobot.common.utils.utils import init_logging
from lerobot.scripts.eval import eval
if __name__ == "__main__":
init_logging()
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"-p",
"--pretrained-policy-name-or-path",
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
"(useful for debugging). This argument is mutually exclusive with `--config`."
),
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument(
"overrides",
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
args = parser.parse_args()
try:
pretrained_policy_path = Path(
snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
)
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)

View File

@ -3,11 +3,10 @@
fps: 30
env:
name: dora
task: DoraKoch-v0
name: real_world
task: RealEnv-v0
state_dim: 6
action_dim: 6
fps: ${fps}
episode_length: 400
gym:
fps: ${fps}
episode_length: 200
real_world: true

View File

@ -1,8 +1,8 @@
# @package _global_
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images,
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
# Compared to `act.yaml`, it uses the two real-world cameras (i.e. images.high, images.low)
# instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
@ -15,14 +15,14 @@
# ```
seed: 1000
dataset_repo_id: thomwolf/blue_sort
dataset_repo_id: ???
override_dataset_stats:
observation.images.cam_high:
observation.images.high:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_low:
observation.images.low:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
@ -46,8 +46,8 @@ training:
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 50
batch_size: 50
n_episodes: 1
batch_size: 1
# See `configuration_act.py` for more details.
policy:
@ -60,16 +60,16 @@ policy:
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.cam_high: [3, 480, 640]
observation.images.cam_low: [3, 480, 640]
observation.images.high: [3, 480, 640]
observation.images.low: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.cam_high: mean_std
observation.images.cam_low: mean_std
observation.images.high: mean_std
observation.images.low: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std

View File

@ -56,7 +56,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
)
# A soft check to warn if the environment does not match the dataset. Don't check if we are using a real-world env.
if cfg.env.name != "dora":
if not cfg.env.real_world:
if isinstance(cfg.dataset_repo_id, str):
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
else:

View File

@ -29,10 +29,12 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# map to expected inputs for the policy
return_observations = {}
if isinstance(observations["pixels"], dict):
if "pixels" in observations and isinstance(observations["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
else:
elif "pixels" in observations and isinstance(observations["pixels"], np.ndarray):
imgs = {"observation.image": observations["pixels"]}
else:
imgs = {f"observation.{key}": img for key, img in observations.items() if "images" in key}
for imgkey, img in imgs.items():
img = torch.from_numpy(img)
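As a small sanity check on the new fall-through branch, here is a sketch (not part of the commit) of how a real-world observation, which carries no `pixels` entry, gets routed to the policy's image keys; the arrays are dummies with the shapes used by `RealEnv`.
```
import numpy as np

# Hypothetical observation as returned by the real-world env (no "pixels" key)
observations = {
    "agent_pos": np.zeros(6, dtype=np.float32),
    "images.high": np.zeros((480, 640, 3), dtype=np.uint8),
    "images.low": np.zeros((480, 640, 3), dtype=np.uint8),
}
# The new else-branch keeps every key containing "images" and prefixes it with "observation."
imgs = {f"observation.{key}": img for key, img in observations.items() if "images" in key}
assert sorted(imgs) == ["observation.images.high", "observation.images.low"]
```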

View File

@ -9,6 +9,7 @@ env:
action_dim: 14
fps: ${fps}
episode_length: 400
real_world: false
gym:
obs_type: pixels_agent_pos
render_mode: rgb_array

View File

@ -9,5 +9,6 @@ env:
action_dim: 14
fps: ${fps}
episode_length: 400
real_world: true
gym:
fps: ${fps}

View File

@ -10,6 +10,7 @@ env:
action_dim: 2
fps: ${fps}
episode_length: 300
real_world: false
gym:
obs_type: pixels_agent_pos
render_mode: rgb_array

View File

@ -10,6 +10,7 @@ env:
action_dim: 4
fps: ${fps}
episode_length: 25
real_world: false
gym:
obs_type: pixels_agent_pos
render_mode: rgb_array

View File

@ -164,7 +164,7 @@ def rollout(
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available if none of the envs finished.
if "final_info" in info:
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
successes = [i["is_success"] if i is not None else False for i in info["final_info"]]
else:
successes = [False] * env.num_envs

View File

@ -406,7 +406,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
step += 1
eval_env.close()
if cfg.training.eval_freq > 0:
eval_env.close()
logging.info("End of training")