add train and evals
This commit is contained in:
parent
797f79f182
commit
ef074d7281
|
@ -27,9 +27,9 @@ Follow these steps:
|
||||||
|
|
||||||
## 0 - record examples
|
## 0 - record examples
|
||||||
|
|
||||||
Run the `0_record_training_data.py` example, selecting the duration and number of episodes you want to record, e.g.
|
Run the `record_training_data.py` example, selecting the duration and number of episodes you want to record, e.g.
|
||||||
```
|
```
|
||||||
DATA_DIR='./data' python 0_record_training_data.py \
|
DATA_DIR='./data' python record_training_data.py \
|
||||||
--repo-id=thomwolf/blue_red_sort \
|
--repo-id=thomwolf/blue_red_sort \
|
||||||
--num-episodes=50 \
|
--num-episodes=50 \
|
||||||
--num-frames=400
|
--num-frames=400
|
||||||
|
@ -44,15 +44,34 @@ TODO:
|
||||||
|
|
||||||
Use the standard dataset visualization script pointing it to the right folder:
|
Use the standard dataset visualization script pointing it to the right folder:
|
||||||
```
|
```
|
||||||
DATA_DIR='./data' python visualize_dataset.py python lerobot/scripts/visualize_dataset.py \
|
DATA_DIR='./data' python ../../lerobot/scripts/visualize_dataset.py \
|
||||||
--repo-id thomwolf/blue_red_sort \
|
--repo-id thomwolf/blue_red_sort \
|
||||||
--episode-index 0
|
--episode-index 0
|
||||||
```
|
```
|
||||||
|
|
||||||
## (soon) Train a policy
|
## 2 - Train a policy
|
||||||
|
|
||||||
Run `1_train_real_policy.py` example
|
From the example directory let's run this command to train a model using ACT
|
||||||
|
|
||||||
## (soon) Evaluate the policy in the real world
|
```
|
||||||
|
DATA_DIR='./data' python ../../lerobot/scripts/train.py \
|
||||||
|
device=cuda \
|
||||||
|
hydra.searchpath=[file://./train_config/] \
|
||||||
|
hydra.run.dir=./outputs/train/blue_red_sort \
|
||||||
|
dataset_repo_id=thomwolf/blue_red_sort \
|
||||||
|
env=gym_real_world \
|
||||||
|
policy=act_real_world \
|
||||||
|
wandb.enable=false
|
||||||
|
```
|
||||||
|
|
||||||
Run `2_evaluate_real_policy.py` example
|
## 3 - Evaluate the policy in the real world
|
||||||
|
|
||||||
|
From the example directory let's run this command to evaluate our policy.
|
||||||
|
The configuration for running the policy is in the checkpoint of the model.
|
||||||
|
You can override parameters as follow:
|
||||||
|
|
||||||
|
```
|
||||||
|
python run_policy.py \
|
||||||
|
-p ./outputs/train/blue_red_sort/checkpoints/last/pretrained_model/
|
||||||
|
env.episode_length=1000
|
||||||
|
```
|
||||||
|
|
|
@ -11,13 +11,13 @@ from .robot import Robot
|
||||||
FPS = 30
|
FPS = 30
|
||||||
|
|
||||||
CAMERAS_SHAPES = {
|
CAMERAS_SHAPES = {
|
||||||
"observation.images.high": (480, 640, 3),
|
"images.high": (480, 640, 3),
|
||||||
"observation.images.low": (480, 640, 3),
|
"images.low": (480, 640, 3),
|
||||||
}
|
}
|
||||||
|
|
||||||
CAMERAS_PORTS = {
|
CAMERAS_PORTS = {
|
||||||
"observation.images.high": "/dev/video6",
|
"images.high": "/dev/video6",
|
||||||
"observation.images.low": "/dev/video0",
|
"images.low": "/dev/video0",
|
||||||
}
|
}
|
||||||
|
|
||||||
LEADER_PORT = "/dev/ttyACM1"
|
LEADER_PORT = "/dev/ttyACM1"
|
||||||
|
@ -52,6 +52,8 @@ class RealEnv(gym.Env):
|
||||||
leader_port: str = LEADER_PORT,
|
leader_port: str = LEADER_PORT,
|
||||||
warmup_steps: int = 100,
|
warmup_steps: int = 100,
|
||||||
trigger_torque=70,
|
trigger_torque=70,
|
||||||
|
fps: int = FPS,
|
||||||
|
fps_tolerance: float = 0.1,
|
||||||
):
|
):
|
||||||
self.num_joints = num_joints
|
self.num_joints = num_joints
|
||||||
self.cameras_shapes = cameras_shapes
|
self.cameras_shapes = cameras_shapes
|
||||||
|
@ -62,6 +64,8 @@ class RealEnv(gym.Env):
|
||||||
self.follower_port = follower_port
|
self.follower_port = follower_port
|
||||||
self.leader_port = leader_port
|
self.leader_port = leader_port
|
||||||
self.record = record
|
self.record = record
|
||||||
|
self.fps = fps
|
||||||
|
self.fps_tolerance = fps_tolerance
|
||||||
|
|
||||||
# Initialize the robot
|
# Initialize the robot
|
||||||
self.follower = Robot(device_name=self.follower_port)
|
self.follower = Robot(device_name=self.follower_port)
|
||||||
|
@ -72,10 +76,13 @@ class RealEnv(gym.Env):
|
||||||
# Initialize the cameras - sorted by camera names
|
# Initialize the cameras - sorted by camera names
|
||||||
self.cameras = {}
|
self.cameras = {}
|
||||||
for cn, p in sorted(self.cameras_ports.items()):
|
for cn, p in sorted(self.cameras_ports.items()):
|
||||||
assert cn.startswith("observation.images."), "Camera names must start with 'observation.images.'."
|
|
||||||
self.cameras[cn] = cv2.VideoCapture(p)
|
self.cameras[cn] = cv2.VideoCapture(p)
|
||||||
if not all(c.isOpened() for c in self.cameras.values()):
|
if not self.cameras[cn].isOpened():
|
||||||
raise OSError("Cannot open all camera ports.")
|
raise OSError(
|
||||||
|
f"Cannot open camera port {p} for {cn}."
|
||||||
|
f" Make sure the camera is connected and the port is correct."
|
||||||
|
f"Also check you are not spinning several instances of the same environment (eval.batch_size)"
|
||||||
|
)
|
||||||
|
|
||||||
# Specify gym action and observation spaces
|
# Specify gym action and observation spaces
|
||||||
observation_space = {}
|
observation_space = {}
|
||||||
|
@ -98,7 +105,7 @@ class RealEnv(gym.Env):
|
||||||
if self.cameras_shapes:
|
if self.cameras_shapes:
|
||||||
for cn, hwc_shape in self.cameras_shapes.items():
|
for cn, hwc_shape in self.cameras_shapes.items():
|
||||||
# Assumes images are unsigned int8 in [0,255]
|
# Assumes images are unsigned int8 in [0,255]
|
||||||
observation_space[f"images.{cn}"] = spaces.Box(
|
observation_space[cn] = spaces.Box(
|
||||||
low=0,
|
low=0,
|
||||||
high=255,
|
high=255,
|
||||||
# height x width x channels (e.g. 480 x 640 x 3)
|
# height x width x channels (e.g. 480 x 640 x 3)
|
||||||
|
@ -111,22 +118,20 @@ class RealEnv(gym.Env):
|
||||||
|
|
||||||
self._observation = {}
|
self._observation = {}
|
||||||
self._terminated = False
|
self._terminated = False
|
||||||
self._action_time = time.time()
|
self.starting_time = time.time()
|
||||||
|
self.timestamps = []
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
qpos = self.follower.read_position()
|
qpos = self.follower.read_position()
|
||||||
self._observation["agent_pos"] = pwm2pos(qpos)
|
self._observation["agent_pos"] = pwm2pos(qpos)
|
||||||
for cn, c in self.cameras.items():
|
for cn, c in self.cameras.items():
|
||||||
self._observation[f"images.{cn}"] = capture_image(
|
self._observation[cn] = capture_image(c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0])
|
||||||
c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0]
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.record:
|
if self.record:
|
||||||
leader_pos = self.leader.read_position()
|
action = self.leader.read_position()
|
||||||
self._observation["leader_pos"] = pwm2pos(leader_pos)
|
self._observation["leader_pos"] = pwm2pos(action)
|
||||||
|
|
||||||
def reset(self, seed: int | None = None):
|
def reset(self, seed: int | None = None):
|
||||||
del seed
|
|
||||||
# Reset the robot and sync the leader and follower if we are recording
|
# Reset the robot and sync the leader and follower if we are recording
|
||||||
for _ in range(self.warmup_steps):
|
for _ in range(self.warmup_steps):
|
||||||
self._get_obs()
|
self._get_obs()
|
||||||
|
@ -134,10 +139,22 @@ class RealEnv(gym.Env):
|
||||||
self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
|
self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
|
||||||
self._terminated = False
|
self._terminated = False
|
||||||
info = {}
|
info = {}
|
||||||
|
self.timestamps = []
|
||||||
return self._observation, info
|
return self._observation, info
|
||||||
|
|
||||||
def step(self, action: np.ndarray = None):
|
def step(self, action: np.ndarray = None):
|
||||||
# Reset the observation
|
if self.timestamps:
|
||||||
|
# wait the right amount of time to stay at the desired fps
|
||||||
|
time.sleep(max(0, 1 / self.fps - (time.time() - self.timestamps[-1])))
|
||||||
|
recording_time = time.time() - self.starting_time
|
||||||
|
else:
|
||||||
|
# it's the first step so we start the timer
|
||||||
|
self.starting_time = time.time()
|
||||||
|
recording_time = 0
|
||||||
|
|
||||||
|
self.timestamps.append(recording_time)
|
||||||
|
|
||||||
|
# Get the observation
|
||||||
self._get_obs()
|
self._get_obs()
|
||||||
if self.record:
|
if self.record:
|
||||||
# Teleoperate the leader
|
# Teleoperate the leader
|
||||||
|
@ -145,9 +162,20 @@ class RealEnv(gym.Env):
|
||||||
else:
|
else:
|
||||||
# Apply the action to the follower
|
# Apply the action to the follower
|
||||||
self.follower.set_goal_pos(pos2pwm(action))
|
self.follower.set_goal_pos(pos2pwm(action))
|
||||||
|
|
||||||
reward = 0
|
reward = 0
|
||||||
terminated = truncated = self._terminated
|
terminated = truncated = self._terminated
|
||||||
info = {}
|
info = {"timestamp": recording_time, "fps_error": False}
|
||||||
|
|
||||||
|
# Check if we are able to keep up with the desired fps
|
||||||
|
if recording_time - self.timestamps[-1] > 1 / (self.fps - self.fps_tolerance):
|
||||||
|
print(
|
||||||
|
f"Error: recording time interval {recording_time - self.timestamps[-1]:.2f} is greater"
|
||||||
|
f"than expected {1 / (self.fps - self.fps_tolerance):.2f}"
|
||||||
|
f" at frame {len(self.timestamps)}"
|
||||||
|
)
|
||||||
|
info["fps_error"] = True
|
||||||
|
|
||||||
return self._observation, reward, terminated, truncated, info
|
return self._observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
def render(self): ...
|
def render(self): ...
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
|
"""This script demonstrates how to record a LeRobot dataset of training data
|
||||||
|
using a very simple gym environment (see in examples/real_robot_example/gym_real_world/gym_environment.py).
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
|
|
||||||
import gym_real_world # noqa: F401
|
import gym_real_world # noqa: F401
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
@ -27,15 +31,12 @@ parser.add_argument("--num-frames", type=int, default=400)
|
||||||
parser.add_argument("--num-workers", type=int, default=16)
|
parser.add_argument("--num-workers", type=int, default=16)
|
||||||
parser.add_argument("--keep-last", action="store_true")
|
parser.add_argument("--keep-last", action="store_true")
|
||||||
parser.add_argument("--push-to-hub", action="store_true")
|
parser.add_argument("--push-to-hub", action="store_true")
|
||||||
|
parser.add_argument("--fps", type=int, default=30, help="Frames per second of the recording.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--fps",
|
"--fps_tolerance",
|
||||||
type=int,
|
type=float,
|
||||||
default=30,
|
default=0.1,
|
||||||
help="Frames per second of the recording."
|
help="Tolerance in fps for the recording before dropping episodes.",
|
||||||
"If we are not able to record at this fps, we will adjust the fps in the metadata.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--tolerance", type=float, default=0.01, help="Tolerance in seconds for the recording time."
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
|
"--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
|
||||||
|
@ -47,7 +48,7 @@ num_episodes = args.num_episodes
|
||||||
num_frames = args.num_frames
|
num_frames = args.num_frames
|
||||||
revision = args.revision
|
revision = args.revision
|
||||||
fps = args.fps
|
fps = args.fps
|
||||||
tolerance = args.tolerance
|
fps_tolerance = args.fps_tolerance
|
||||||
|
|
||||||
out_data = DATA_DIR / repo_id
|
out_data = DATA_DIR / repo_id
|
||||||
|
|
||||||
|
@ -67,7 +68,7 @@ if not os.path.exists(videos_dir):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Create the gym environment - check the kwargs in gym_real_world/gym_environment.py
|
# Create the gym environment - check the kwargs in gym_real_world/gym_environment.py
|
||||||
gym_handle = "gym_real_world/RealEnv-v0"
|
gym_handle = "gym_real_world/RealEnv-v0"
|
||||||
env = gym.make(gym_handle, disable_env_checker=True, record=True)
|
env = gym.make(gym_handle, disable_env_checker=True, record=True, fps=fps, fps_tolerance=fps_tolerance)
|
||||||
|
|
||||||
ep_dicts = []
|
ep_dicts = []
|
||||||
episode_data_index = {"from": [], "to": []}
|
episode_data_index = {"from": [], "to": []}
|
||||||
|
@ -84,59 +85,46 @@ if __name__ == "__main__":
|
||||||
os.system(f'spd-say "go {ep_idx}"')
|
os.system(f'spd-say "go {ep_idx}"')
|
||||||
# init buffers
|
# init buffers
|
||||||
obs_replay = {k: [] for k in env.observation_space}
|
obs_replay = {k: [] for k in env.observation_space}
|
||||||
timestamps = []
|
|
||||||
|
|
||||||
starting_time = time.time()
|
drop_episode = False
|
||||||
|
timestamps = []
|
||||||
for _ in tqdm(range(num_frames)):
|
for _ in tqdm(range(num_frames)):
|
||||||
# Apply the next action
|
# Apply the next action
|
||||||
observation, _, _, _, _ = env.step(action=None)
|
observation, _, _, _, info = env.step(action=None)
|
||||||
# images_stacked = np.hstack(list(observation['pixels'].values()))
|
# images_stacked = np.hstack(list(observation['pixels'].values()))
|
||||||
# images_stacked = cv2.cvtColor(images_stacked, cv2.COLOR_RGB2BGR)
|
# images_stacked = cv2.cvtColor(images_stacked, cv2.COLOR_RGB2BGR)
|
||||||
# cv2.imshow('frame', images_stacked)
|
# cv2.imshow('frame', images_stacked)
|
||||||
|
|
||||||
|
if info["fps_error"]:
|
||||||
|
os.system(f'spd-say "Error fps too low, dropping episode {ep_idx}"')
|
||||||
|
drop_episode = True
|
||||||
|
break
|
||||||
|
|
||||||
# store data
|
# store data
|
||||||
for key in observation:
|
for key in observation:
|
||||||
obs_replay[key].append(copy.deepcopy(observation[key]))
|
obs_replay[key].append(copy.deepcopy(observation[key]))
|
||||||
|
timestamps.append(info["timestamp"])
|
||||||
recording_time = time.time() - starting_time
|
|
||||||
timestamps.append(recording_time)
|
|
||||||
|
|
||||||
# Check if we are able to keep up with the desired fps
|
|
||||||
if recording_time > num_frames / fps + tolerance:
|
|
||||||
print(
|
|
||||||
f"Error: recording time {recording_time:.2f} is greater than expected {num_frames / fps:.2f}"
|
|
||||||
f" + tolerance {tolerance:.2f}"
|
|
||||||
f" at frame {len(timestamps)}"
|
|
||||||
f" in episode {ep_idx}."
|
|
||||||
f"Dropping the rest of the episode."
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
# wait the right amount of time to stay at the desired fps
|
|
||||||
time.sleep(max(0, 1 / fps - (time.time() - starting_time)))
|
|
||||||
|
|
||||||
# if cv2.waitKey(1) & 0xFF == ord('q'):
|
# if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
# break
|
# break
|
||||||
|
|
||||||
os.system('spd-say "stop"')
|
os.system('spd-say "stop"')
|
||||||
|
|
||||||
if len(timestamps) == num_frames:
|
if not drop_episode:
|
||||||
os.system(f'spd-say "saving episode {ep_idx}"')
|
os.system(f'spd-say "saving episode {ep_idx}"')
|
||||||
ep_dict = {}
|
ep_dict = {}
|
||||||
# store images in png and create the video
|
# store images in png and create the video
|
||||||
for img_key in env.cameras:
|
for img_key in env.cameras:
|
||||||
save_images_concurrently(
|
save_images_concurrently(
|
||||||
obs_replay[f"images.{img_key}"],
|
obs_replay[img_key],
|
||||||
images_dir / f"{img_key}_episode_{ep_idx:06d}",
|
images_dir / f"{img_key}_episode_{ep_idx:06d}",
|
||||||
args.num_workers,
|
args.num_workers,
|
||||||
)
|
)
|
||||||
# for i in tqdm(range(num_frames)):
|
|
||||||
# cv2.imwrite(str(images_dir / f"{img_key}_episode_{ep_idx:06d}" / f"frame_{i:06d}.png"),
|
|
||||||
# obs_replay[i]['pixels'][img_key])
|
|
||||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||||
# store the reference to the video frame
|
# store the reference to the video frame
|
||||||
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps]
|
ep_dict[f"observation.{img_key}"] = [
|
||||||
# shutil.rmtree(tmp_imgs_dir)
|
{"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps
|
||||||
|
]
|
||||||
|
|
||||||
state = torch.tensor(np.array(obs_replay["agent_pos"]))
|
state = torch.tensor(np.array(obs_replay["agent_pos"]))
|
||||||
action = torch.tensor(np.array(obs_replay["leader_pos"]))
|
action = torch.tensor(np.array(obs_replay["leader_pos"]))
|
||||||
|
@ -198,8 +186,6 @@ if __name__ == "__main__":
|
||||||
features["timestamp"] = Value(dtype="float32", id=None)
|
features["timestamp"] = Value(dtype="float32", id=None)
|
||||||
features["next.done"] = Value(dtype="bool", id=None)
|
features["next.done"] = Value(dtype="bool", id=None)
|
||||||
features["index"] = Value(dtype="int64", id=None)
|
features["index"] = Value(dtype="int64", id=None)
|
||||||
# TODO(rcadene): add success
|
|
||||||
# features["next.success"] = Value(dtype='bool', id=None)
|
|
||||||
|
|
||||||
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
hf_dataset.set_transform(hf_transform_to_torch)
|
|
@ -0,0 +1,60 @@
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import gym_real_world # noqa: F401
|
||||||
|
import gymnasium as gym # noqa: F401
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||||
|
from huggingface_hub.utils._validators import HFValidationError
|
||||||
|
|
||||||
|
from lerobot.common.utils.utils import init_logging
|
||||||
|
from lerobot.scripts.eval import eval
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
init_logging()
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||||
|
)
|
||||||
|
group = parser.add_mutually_exclusive_group(required=True)
|
||||||
|
group.add_argument(
|
||||||
|
"-p",
|
||||||
|
"--pretrained-policy-name-or-path",
|
||||||
|
help=(
|
||||||
|
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||||
|
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
|
||||||
|
"(useful for debugging). This argument is mutually exclusive with `--config`."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
|
||||||
|
parser.add_argument(
|
||||||
|
"overrides",
|
||||||
|
nargs="*",
|
||||||
|
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
pretrained_policy_path = Path(
|
||||||
|
snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
|
||||||
|
)
|
||||||
|
except (HFValidationError, RepositoryNotFoundError) as e:
|
||||||
|
if isinstance(e, HFValidationError):
|
||||||
|
error_message = (
|
||||||
|
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
error_message = (
|
||||||
|
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.warning(f"{error_message} Treating it as a local directory.")
|
||||||
|
pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
|
||||||
|
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
|
||||||
|
raise ValueError(
|
||||||
|
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
|
||||||
|
"repo ID, nor is it an existing local directory."
|
||||||
|
)
|
||||||
|
|
||||||
|
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
|
|
@ -3,11 +3,10 @@
|
||||||
fps: 30
|
fps: 30
|
||||||
|
|
||||||
env:
|
env:
|
||||||
name: dora
|
name: real_world
|
||||||
task: DoraKoch-v0
|
task: RealEnv-v0
|
||||||
state_dim: 6
|
state_dim: 6
|
||||||
action_dim: 6
|
action_dim: 6
|
||||||
fps: ${fps}
|
fps: ${fps}
|
||||||
episode_length: 400
|
episode_length: 200
|
||||||
gym:
|
real_world: true
|
||||||
fps: ${fps}
|
|
|
@ -1,8 +1,8 @@
|
||||||
# @package _global_
|
# @package _global_
|
||||||
|
|
||||||
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
|
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
|
||||||
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images,
|
# Compared to `act.yaml`, it contains 4 cameras (i.e. right_wrist, left_wrist, images,
|
||||||
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
|
# low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
|
||||||
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
|
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
|
||||||
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
|
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
|
||||||
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
|
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
|
||||||
|
@ -15,14 +15,14 @@
|
||||||
# ```
|
# ```
|
||||||
|
|
||||||
seed: 1000
|
seed: 1000
|
||||||
dataset_repo_id: thomwolf/blue_sort
|
dataset_repo_id: ???
|
||||||
|
|
||||||
override_dataset_stats:
|
override_dataset_stats:
|
||||||
observation.images.cam_high:
|
observation.images.high:
|
||||||
# stats from imagenet, since we use a pretrained vision model
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
observation.images.cam_low:
|
observation.images.low:
|
||||||
# stats from imagenet, since we use a pretrained vision model
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
@ -46,8 +46,8 @@ training:
|
||||||
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
|
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
|
||||||
|
|
||||||
eval:
|
eval:
|
||||||
n_episodes: 50
|
n_episodes: 1
|
||||||
batch_size: 50
|
batch_size: 1
|
||||||
|
|
||||||
# See `configuration_act.py` for more details.
|
# See `configuration_act.py` for more details.
|
||||||
policy:
|
policy:
|
||||||
|
@ -60,16 +60,16 @@ policy:
|
||||||
|
|
||||||
input_shapes:
|
input_shapes:
|
||||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||||
observation.images.cam_high: [3, 480, 640]
|
observation.images.high: [3, 480, 640]
|
||||||
observation.images.cam_low: [3, 480, 640]
|
observation.images.low: [3, 480, 640]
|
||||||
observation.state: ["${env.state_dim}"]
|
observation.state: ["${env.state_dim}"]
|
||||||
output_shapes:
|
output_shapes:
|
||||||
action: ["${env.action_dim}"]
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
input_normalization_modes:
|
input_normalization_modes:
|
||||||
observation.images.cam_high: mean_std
|
observation.images.high: mean_std
|
||||||
observation.images.cam_low: mean_std
|
observation.images.low: mean_std
|
||||||
observation.state: mean_std
|
observation.state: mean_std
|
||||||
output_normalization_modes:
|
output_normalization_modes:
|
||||||
action: mean_std
|
action: mean_std
|
|
@ -56,7 +56,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||||
)
|
)
|
||||||
|
|
||||||
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
|
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
|
||||||
if cfg.env.name != "dora":
|
if not cfg.env.real_world:
|
||||||
if isinstance(cfg.dataset_repo_id, str):
|
if isinstance(cfg.dataset_repo_id, str):
|
||||||
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
|
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -29,10 +29,12 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||||
# map to expected inputs for the policy
|
# map to expected inputs for the policy
|
||||||
return_observations = {}
|
return_observations = {}
|
||||||
|
|
||||||
if isinstance(observations["pixels"], dict):
|
if "pixels" in observations and isinstance(observations["pixels"], dict):
|
||||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||||
else:
|
elif "pixels" in observations and isinstance(observations["pixels"], np.ndarray):
|
||||||
imgs = {"observation.image": observations["pixels"]}
|
imgs = {"observation.image": observations["pixels"]}
|
||||||
|
else:
|
||||||
|
imgs = {f"observation.{key}": img for key, img in observations.items() if "images" in key}
|
||||||
|
|
||||||
for imgkey, img in imgs.items():
|
for imgkey, img in imgs.items():
|
||||||
img = torch.from_numpy(img)
|
img = torch.from_numpy(img)
|
||||||
|
|
|
@ -9,6 +9,7 @@ env:
|
||||||
action_dim: 14
|
action_dim: 14
|
||||||
fps: ${fps}
|
fps: ${fps}
|
||||||
episode_length: 400
|
episode_length: 400
|
||||||
|
real_world: false
|
||||||
gym:
|
gym:
|
||||||
obs_type: pixels_agent_pos
|
obs_type: pixels_agent_pos
|
||||||
render_mode: rgb_array
|
render_mode: rgb_array
|
||||||
|
|
|
@ -9,5 +9,6 @@ env:
|
||||||
action_dim: 14
|
action_dim: 14
|
||||||
fps: ${fps}
|
fps: ${fps}
|
||||||
episode_length: 400
|
episode_length: 400
|
||||||
|
real_world: true
|
||||||
gym:
|
gym:
|
||||||
fps: ${fps}
|
fps: ${fps}
|
||||||
|
|
|
@ -10,6 +10,7 @@ env:
|
||||||
action_dim: 2
|
action_dim: 2
|
||||||
fps: ${fps}
|
fps: ${fps}
|
||||||
episode_length: 300
|
episode_length: 300
|
||||||
|
real_world: false
|
||||||
gym:
|
gym:
|
||||||
obs_type: pixels_agent_pos
|
obs_type: pixels_agent_pos
|
||||||
render_mode: rgb_array
|
render_mode: rgb_array
|
||||||
|
|
|
@ -10,6 +10,7 @@ env:
|
||||||
action_dim: 4
|
action_dim: 4
|
||||||
fps: ${fps}
|
fps: ${fps}
|
||||||
episode_length: 25
|
episode_length: 25
|
||||||
|
real_world: false
|
||||||
gym:
|
gym:
|
||||||
obs_type: pixels_agent_pos
|
obs_type: pixels_agent_pos
|
||||||
render_mode: rgb_array
|
render_mode: rgb_array
|
||||||
|
|
|
@ -164,7 +164,7 @@ def rollout(
|
||||||
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
|
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
|
||||||
# available of none of the envs finished.
|
# available of none of the envs finished.
|
||||||
if "final_info" in info:
|
if "final_info" in info:
|
||||||
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
|
successes = [i["is_success"] if i is not None else False for i in info["final_info"]]
|
||||||
else:
|
else:
|
||||||
successes = [False] * env.num_envs
|
successes = [False] * env.num_envs
|
||||||
|
|
||||||
|
|
|
@ -406,6 +406,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
|
|
||||||
step += 1
|
step += 1
|
||||||
|
|
||||||
|
if cfg.training.eval_freq > 0:
|
||||||
eval_env.close()
|
eval_env.close()
|
||||||
logging.info("End of training")
|
logging.info("End of training")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue