Added the possibility to record with a policy; added temporary fixes to train.py to enable training on Mac

Michel Aractingi 2024-10-24 23:35:25 +02:00
parent 9a5356d0ac
commit 5e01c21692
3 changed files with 143 additions and 51 deletions

View File

@@ -345,7 +345,7 @@ class TDMPCPolicy(
                 batch[key] = batch[key].transpose(1, 0)
 
         action = batch["action"]  # (t, b, action_dim)
-        reward = batch["reward"]  # (t, b)
+        reward = batch["next.reward"]  # (t, b)
         observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
 
         # Apply random image augmentations.
@@ -422,7 +422,7 @@ class TDMPCPolicy(
             (
                 temporal_loss_coeffs
                 * F.mse_loss(reward_preds, reward, reduction="none")
-                * ~batch["reward_is_pad"]
+                * ~batch["next.reward_is_pad"]
                 # `reward_preds` depends on the current observation and the actions.
                 * ~batch["observation.state_is_pad"][0]
                 * ~batch["action_is_pad"]
@@ -443,7 +443,7 @@ class TDMPCPolicy(
                 * ~batch["observation.state_is_pad"][0]
                 * ~batch["action_is_pad"]
                 # q_targets depends on the reward and the next observations.
-                * ~batch["reward_is_pad"]
+                * ~batch["next.reward_is_pad"]
                 * ~batch["observation.state_is_pad"][1:]
             )
             .sum(0)
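
The `reward` → `next.reward` renaming also has to hold for any batch fed into this loss. A minimal sketch of a remapping helper for data recorded before this change (the helper name is ours, not part of this commit):

import torch

def remap_reward_keys(batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # Rename legacy keys to the `next.reward` convention expected by the loss above.
    key_map = {"reward": "next.reward", "reward_is_pad": "next.reward_is_pad"}
    return {key_map.get(key, key): value for key, value in batch.items()}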

View File

@@ -85,11 +85,15 @@ from functools import cache
 from pathlib import Path
 import gymnasium as gym
 import multiprocessing
+from contextlib import nullcontext
 import cv2
 import torch
 import numpy as np
 import tqdm
+from omegaconf import DictConfig
 from PIL import Image
 from datasets import Dataset, Features, Sequence, Value
@@ -99,12 +103,15 @@ from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDat
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
 from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch, hf_transform_to_torch
+from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
 from lerobot.common.datasets.video_utils import encode_video_frames
 from lerobot.common.robot_devices.robots.factory import make_robot
+from lerobot.common.policies.factory import make_policy
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.common.envs.factory import make_env
 from lerobot.common.utils.utils import init_hydra_config, init_logging
+from lerobot.scripts.eval import get_pretrained_policy_path
 from lerobot.scripts.push_dataset_to_hub import (
     push_dataset_card_to_hub,
     push_meta_data_to_hub,
@@ -178,6 +185,29 @@ def is_headless():
         print()
         return True
 
+
+def get_action_from_policy(policy, observation, device, use_amp=False):
+    with (
+        torch.inference_mode(),
+        torch.autocast(device_type=device.type)
+        if device.type == "cuda" and use_amp
+        else nullcontext(),
+    ):
+        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
+        for name in observation:
+            if "image" in name:
+                observation[name] = observation[name].type(torch.float32) / 255
+                observation[name] = observation[name].permute(2, 0, 1).contiguous()
+            observation[name] = observation[name].unsqueeze(0)
+            observation[name] = observation[name].to(device)
+
+        # Compute the next action with the policy
+        # based on the current observation
+        action = policy.select_action(observation)
+        # Remove batch dimension
+        action = action.squeeze(0)
+        # Move to cpu, if not already the case
+        return action.to("cpu")
+
+
 def init_read_leader(robot, fps, **kwargs):
     axis_directions = kwargs.get('axis_directions', [1])
     offsets = kwargs.get('offsets', [0])
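
For reference, the new helper expects raw environment observations without a batch dimension: HWC uint8 images and float32 state vectors. A small usage sketch, assuming the function above is in scope (the dummy policy and key names are illustrative, not from this commit):

import torch

class DummyPolicy:
    def select_action(self, observation):
        # A real LeRobot policy returns a batched action; mimic that shape here.
        return torch.zeros(1, 6)

observation = {
    "observation.images.image_top": torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8),
    "observation.state": torch.zeros(6, dtype=torch.float32),
}
action = get_action_from_policy(DummyPolicy(), observation, device=torch.device("cpu"))
print(action.shape)  # torch.Size([6])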
@@ -240,7 +270,7 @@ def create_rl_hf_dataset(data_dict):
     features["action"] = Sequence(
         length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
     )
-    features["reward"] = Value(dtype="float32", id=None)
+    features["next.reward"] = Value(dtype="float32", id=None)
     features["seed"] = Value(dtype="int64", id=None)
     features["episode_index"] = Value(dtype="int64", id=None)
@@ -277,6 +307,8 @@ def teleoperate(env, robot: Robot, teleop_time_s=None, **kwargs):
 def record(
     env,
     robot: Robot,
+    policy: torch.nn.Module | None = None,
+    policy_cfg: DictConfig | None = None,
     fps: int | None = None,
     root="data",
     repo_id="lerobot/debug",
@@ -355,7 +387,23 @@ def record(
     num_image_writers = num_image_writers_per_camera * 2 ###############
     num_image_writers = max(num_image_writers, 1)
 
-    read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
+    # Load policy if any
+    if policy is not None:
+        # Check device is available
+        device = get_safe_torch_device(policy_cfg.device, log=True)
+
+        policy.eval()
+        policy.to(device)
+
+        torch.backends.cudnn.benchmark = True
+        torch.backends.cuda.matmul.allow_tf32 = True
+        set_global_seed(policy_cfg.seed)
+
+        # override fps using policy fps
+        fps = policy_cfg.env.fps
+    else:
+        read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
 
     if not is_headless() and visualize_images:
         observations_queue = multiprocessing.Queue(1000)
         show_images = multiprocessing.Process(target=show_image_observations, args=(observations_queue, ))
@@ -369,7 +417,7 @@ def record(
     while episode_index < num_episodes:
         logging.info(f"Recording episode {episode_index}")
         say(f"Recording episode {episode_index}")
-        ep_dict = {'action':[], 'reward':[]}
+        ep_dict = {'action':[], 'next.reward':[]}
         for k in state_keys_dict:
             ep_dict[k] = []
         frame_index = 0
@@ -381,9 +429,14 @@ def record(
         observation, info = env.reset(seed=seed)
         #with stop_reading_leader.get_lock():
         #stop_reading_leader.Value = 0
-        read_leader.start()
+        if policy is None:
+            read_leader.start()
 
         while timestamp < episode_time_s:
-            action = command_queue.get()
+            if policy is None:
+                action = command_queue.get()
+            else:
+                action = get_action_from_policy(policy, observation)
+
             for key in image_keys:
                 str_key = key if key.startswith('observation.images.') else 'observation.images.' + key
                 futures += [
@@ -402,7 +455,7 @@ def record(
             action = np.expand_dims(action, 0)
             observation, reward, _, _ , info = env.step(action)
             ep_dict['action'].append(torch.from_numpy(action))
-            ep_dict['reward'].append(torch.tensor(reward))
+            ep_dict['next.reward'].append(torch.tensor(reward))
 
             print(reward)
             frame_index += 1
@@ -417,9 +470,10 @@ def record(
         #stop_reading_leader.Value = 1
         # TODO (michel_aractinig): temp fix until I figure out the problem with shared memory
         # stop_reading_leader is blocking
-        command_queue.close()
-        read_leader.terminate()
-        read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
+        if policy is None:
+            command_queue.close()
+            read_leader.terminate()
+            read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
 
         timestamp = 0
@@ -451,7 +505,7 @@ def record(
         for key in state_keys_dict:
             ep_dict[key] = torch.vstack(ep_dict[key]) * 180.0 / np.pi
         ep_dict['action'] = torch.vstack(ep_dict['action']) * 180.0 / np.pi
-        ep_dict['reward'] = torch.stack(ep_dict['reward'])
+        ep_dict['next.reward'] = torch.stack(ep_dict['next.reward'])
 
         ep_dict["seed"] = torch.tensor([seed] * num_frames)
         ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames)
@@ -577,7 +631,11 @@ def record(
     return lerobot_dataset
 
 
-def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="lerobot/debug"):
+def replay(env,
+           episodes: list,
+           fps: int | None = None,
+           root="data",
+           repo_id="lerobot/debug"):
     env = env()
 
     local_dir = Path(root) / repo_id
@@ -700,6 +758,21 @@ if __name__ == "__main__":
         default=0,
         help="Visualize image observations with opencv.",
     )
+    parser_record.add_argument(
+        "-p",
+        "--pretrained-policy-name-or-path",
+        type=str,
+        help=(
+            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
+            "saved using `Policy.save_pretrained`."
+        ),
+    )
+    parser_record.add_argument(
+        "--policy-overrides",
+        type=str,
+        nargs="*",
+        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
+    )
 
     parser_replay = subparsers.add_parser("replay", parents=[base_parser])
     parser_replay.add_argument(
@@ -748,6 +821,16 @@ if __name__ == "__main__":
         teleoperate(env_fn, robot, **kwargs)
 
     elif control_mode == "record":
+        pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
+        policy_overrides = args.policy_overrides
+        del kwargs["pretrained_policy_name_or_path"]
+        del kwargs["policy_overrides"]
+
+        if pretrained_policy_name_or_path is not None:
+            pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
+            kwargs["policy_cfg"] = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
+            kwargs["policy"] = make_policy(hydra_cfg=kwargs["policy_cfg"], pretrained_policy_name_or_path=pretrained_policy_path)
+
         record(env_fn, robot, **kwargs)
 
     elif control_mode == "replay":
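
With these arguments wired through, recording can be driven by a pretrained policy instead of the leader arm. A hypothetical invocation (the script path, checkpoint path, and override value are placeholders, not taken from this commit):

python lerobot/scripts/control_sim_robot.py record \
    --pretrained-policy-name-or-path outputs/train/my_run/checkpoints/last/pretrained_model \
    --policy-overrides device=cpu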

View File

@@ -135,8 +135,8 @@ def update_policy(
     # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
     # although it still skips optimizer.step() if the gradients contain infs or NaNs.
-    with lock if lock is not None else nullcontext():
-        grad_scaler.step(optimizer)
+    #with lock if lock is not None else nullcontext():
+    grad_scaler.step(optimizer)
 
     # Updates the scale for next iteration.
     grad_scaler.update()
@@ -311,6 +311,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     logging.info("make_dataset")
     offline_dataset = make_dataset(cfg)
 
+    remove_indices=['observation.images.image_top', 'observation.velocity', 'seed']
+    # temp fix michel_Aractingi TODO
+    offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(remove_indices)
+
     if isinstance(offline_dataset, MultiLeRobotDataset):
         logging.info(
             "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
@@ -477,7 +482,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             **{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()},
             "next.reward": {"shape": (), "dtype": np.dtype("float32")},
             "next.done": {"shape": (), "dtype": np.dtype("?")},
-            "next.success": {"shape": (), "dtype": np.dtype("?")},
+            #"next.success": {"shape": (), "dtype": np.dtype("?")},
         },
         buffer_capacity=cfg.training.online_buffer_capacity,
         fps=online_env.unwrapped.metadata["render_fps"],
@@ -504,6 +509,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         num_samples=len(concat_dataset),
         replacement=True,
     )
+
+    # TODO michel_aractingi temp fix for incosistent keys
+
     dataloader = torch.utils.data.DataLoader(
         concat_dataset,
         batch_size=cfg.training.batch_size,
@@ -538,8 +546,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         def sample_trajectory_and_update_buffer():
             nonlocal rollout_start_seed
-            with lock:
-                online_rollout_policy.load_state_dict(policy.state_dict())
+            #with lock:
+            online_rollout_policy.load_state_dict(policy.state_dict())
             online_rollout_policy.eval()
             start_rollout_time = time.perf_counter()
             with torch.no_grad():
@@ -556,37 +564,35 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
                 )
             online_rollout_s = time.perf_counter() - start_rollout_time
 
-            with lock:
-                start_update_buffer_time = time.perf_counter()
-                online_dataset.add_data(eval_info["episodes"])
-
-                # Update the concatenated dataset length used during sampling.
-                concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
-
-                # Update the sampling weights.
-                sampler.weights = compute_sampler_weights(
-                    offline_dataset,
-                    offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
-                    online_dataset=online_dataset,
-                    # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
-                    # this final observation in the offline datasets, but we might add them in future.
-                    online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
-                    online_sampling_ratio=cfg.training.online_sampling_ratio,
-                )
-                sampler.num_samples = len(concat_dataset)
-
-                update_online_buffer_s = time.perf_counter() - start_update_buffer_time
+            #with lock:
+            start_update_buffer_time = time.perf_counter()
+            online_dataset.add_data(eval_info["episodes"])
+            # Update the concatenated dataset length used during sampling.
+            concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
+            # Update the sampling weights.
+            sampler.weights = compute_sampler_weights(
+                offline_dataset,
+                offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
+                online_dataset=online_dataset,
+                # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
+                # this final observation in the offline datasets, but we might add them in future.
+                online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
+                online_sampling_ratio=cfg.training.online_sampling_ratio,
+            )
+            sampler.num_samples = len(concat_dataset)
+            update_online_buffer_s = time.perf_counter() - start_update_buffer_time
 
             return online_rollout_s, update_online_buffer_s
 
-        future = executor.submit(sample_trajectory_and_update_buffer)
+        # TODO remove parallelization for sim
+        #future = executor.submit(sample_trajectory_and_update_buffer)
 
         # If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
        # here until the rollout and buffer update is done, before proceeding to the policy update steps.
         if (
             not cfg.training.do_online_rollout_async
             or len(online_dataset) <= cfg.training.online_buffer_seed_size
         ):
-            online_rollout_s, update_online_buffer_s = future.result()
+            online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()#future.result()
 
         if len(online_dataset) <= cfg.training.online_buffer_seed_size:
             logging.info(
@@ -596,12 +602,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         policy.train()
         for _ in range(cfg.training.online_steps_between_rollouts):
-            with lock:
-                start_time = time.perf_counter()
-                batch = next(dl_iter)
-                dataloading_s = time.perf_counter() - start_time
+            #with lock:
+            start_time = time.perf_counter()
+            batch = next(dl_iter)
+            dataloading_s = time.perf_counter() - start_time
 
             for key in batch:
+                # TODO michel aractingi convert float64 to float32 for mac
+                if batch[key].dtype == torch.float64:
+                    batch[key] = batch[key].float()
                 batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
             train_info = update_policy(
@@ -619,8 +628,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             train_info["online_rollout_s"] = online_rollout_s
             train_info["update_online_buffer_s"] = update_online_buffer_s
             train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
-            with lock:
-                train_info["online_buffer_size"] = len(online_dataset)
+            #with lock:
+            train_info["online_buffer_size"] = len(online_dataset)
 
             if step % cfg.training.log_freq == 0:
                 log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
@@ -634,10 +643,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         # If we're doing async rollouts, we should now wait until we've completed them before proceeding
         # to do the next batch of rollouts.
-        if future.running():
-            start = time.perf_counter()
-            online_rollout_s, update_online_buffer_s = future.result()
-            await_update_online_buffer_s = time.perf_counter() - start
+        #if future.running():
+        #start = time.perf_counter()
+        #online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()#future.result()
+        #await_update_online_buffer_s = time.perf_counter() - start
 
         if online_step >= cfg.training.online_steps:
             break