From 5e01c216921a3baf91c14ef4cfa7e226cf2f11b9 Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Thu, 24 Oct 2024 23:35:25 +0200
Subject: [PATCH] added possibility to record with a policy; added temporary fixes to train.py to enable training on mac

---
 .../common/policies/tdmpc/modeling_tdmpc.py |   6 +-
 lerobot/scripts/control_sim_robot.py        | 105 ++++++++++++++++--
 lerobot/scripts/train.py                    |  83 ++++++++------
 3 files changed, 143 insertions(+), 51 deletions(-)

diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
index 7b081537..7ee57693 100644
--- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
@@ -345,7 +345,7 @@ class TDMPCPolicy(
                 batch[key] = batch[key].transpose(1, 0)
 
         action = batch["action"]  # (t, b, action_dim)
-        reward = batch["reward"]  # (t, b)
+        reward = batch["next.reward"]  # (t, b)
         observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
 
         # Apply random image augmentations.
@@ -422,7 +422,7 @@ class TDMPCPolicy(
             (
                 temporal_loss_coeffs
                 * F.mse_loss(reward_preds, reward, reduction="none")
-                * ~batch["reward_is_pad"]
+                * ~batch["next.reward_is_pad"]
                 # `reward_preds` depends on the current observation and the actions.
                 * ~batch["observation.state_is_pad"][0]
                 * ~batch["action_is_pad"]
@@ -443,7 +443,7 @@ class TDMPCPolicy(
                 * ~batch["observation.state_is_pad"][0]
                 * ~batch["action_is_pad"]
                 # q_targets depends on the reward and the next observations.
-                * ~batch["reward_is_pad"]
+                * ~batch["next.reward_is_pad"]
                 * ~batch["observation.state_is_pad"][1:]
             )
             .sum(0)
diff --git a/lerobot/scripts/control_sim_robot.py b/lerobot/scripts/control_sim_robot.py
index 91344efe..99268cf4 100644
--- a/lerobot/scripts/control_sim_robot.py
+++ b/lerobot/scripts/control_sim_robot.py
@@ -85,11 +85,15 @@ from functools import cache
 from pathlib import Path
 
 import gymnasium as gym
 import multiprocessing
+from contextlib import nullcontext
+
 import cv2
 import torch
 import numpy as np
 import tqdm
+from omegaconf import DictConfig
+
 from PIL import Image
 from datasets import Dataset, Features, Sequence, Value
 
@@ -99,12 +103,15 @@ from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDat
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
 from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch, hf_transform_to_torch
+from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
 from lerobot.common.datasets.video_utils import encode_video_frames
 from lerobot.common.robot_devices.robots.factory import make_robot
+from lerobot.common.policies.factory import make_policy
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.common.envs.factory import make_env
 from lerobot.common.utils.utils import init_hydra_config, init_logging
+from lerobot.scripts.eval import get_pretrained_policy_path
 from lerobot.scripts.push_dataset_to_hub import (
     push_dataset_card_to_hub,
     push_meta_data_to_hub,
@@ -178,6 +185,29 @@ def is_headless():
         print()
         return True
 
+def get_action_from_policy(policy, observation, device, use_amp=False):
+    with (
+        torch.inference_mode(),
+        torch.autocast(device_type=device.type)
+        if device.type == "cuda" and use_amp
+        else nullcontext(),
+    ):
+        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
+        for name in observation:
+            if "image" in name:
+                observation[name] = observation[name].type(torch.float32) / 255
+                observation[name] = observation[name].permute(2, 0, 1).contiguous()
+            observation[name] = observation[name].unsqueeze(0)
+            observation[name] = observation[name].to(device)
+
+        # Compute the next action with the policy
+        # based on the current observation
+        action = policy.select_action(observation)
+        # Remove batch dimension
+        action = action.squeeze(0)
+        # Move to cpu, if not already the case
+        return action.to("cpu")
+
 def init_read_leader(robot, fps, **kwargs):
     axis_directions = kwargs.get('axis_directions', [1])
     offsets = kwargs.get('offsets', [0])
@@ -240,7 +270,7 @@ def create_rl_hf_dataset(data_dict):
     features["action"] = Sequence(
         length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
     )
-    features["reward"] = Value(dtype="float32", id=None)
+    features["next.reward"] = Value(dtype="float32", id=None)
     features["seed"] = Value(dtype="int64", id=None)
 
     features["episode_index"] = Value(dtype="int64", id=None)
@@ -277,6 +307,8 @@ def teleoperate(env, robot: Robot, teleop_time_s=None, **kwargs):
 def record(
     env,
     robot: Robot,
+    policy: torch.nn.Module | None = None,
+    policy_cfg: DictConfig | None = None,
     fps: int | None = None,
     root="data",
     repo_id="lerobot/debug",
@@ -355,7 +387,23 @@ def record(
     num_image_writers = num_image_writers_per_camera * 2 ###############
     num_image_writers = max(num_image_writers, 1)
 
-    read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
+    # Load policy if any
+    if policy is not None:
+        # Check device is available
+        device = get_safe_torch_device(policy_cfg.device, log=True)
+
+        policy.eval()
+        policy.to(device)
+
+        torch.backends.cudnn.benchmark = True
+        torch.backends.cuda.matmul.allow_tf32 = True
+        set_global_seed(policy_cfg.seed)
+
+        # override fps using policy fps
+        fps = policy_cfg.env.fps
+    else:
+        read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
+
     if not is_headless() and visualize_images:
         observations_queue = multiprocessing.Queue(1000)
         show_images = multiprocessing.Process(target=show_image_observations, args=(observations_queue, ))
@@ -369,7 +417,7 @@ def record(
     while episode_index < num_episodes:
         logging.info(f"Recording episode {episode_index}")
         say(f"Recording episode {episode_index}")
-        ep_dict = {'action':[], 'reward':[]}
+        ep_dict = {'action':[], 'next.reward':[]}
         for k in state_keys_dict:
             ep_dict[k] = []
         frame_index = 0
@@ -381,9 +429,14 @@ def record(
         observation, info = env.reset(seed=seed)
         #with stop_reading_leader.get_lock():
         #stop_reading_leader.Value = 0
-        read_leader.start()
+        if policy is None:
+            read_leader.start()
         while timestamp < episode_time_s:
-            action = command_queue.get()
+            if policy is None:
+                action = command_queue.get()
+            else:
+                action = get_action_from_policy(policy, observation, device)
+
             for key in image_keys:
                 str_key = key if key.startswith('observation.images.') else 'observation.images.' + key
                 futures += [
@@ -402,7 +455,7 @@ def record(
             action = np.expand_dims(action, 0)
             observation, reward, _, _ , info = env.step(action)
             ep_dict['action'].append(torch.from_numpy(action))
-            ep_dict['reward'].append(torch.tensor(reward))
+            ep_dict['next.reward'].append(torch.tensor(reward))
             print(reward)
 
             frame_index += 1
@@ -417,9 +470,10 @@ def record(
         #stop_reading_leader.Value = 1
         # TODO (michel_aractinig): temp fix until I figure out the problem with shared memory
         # stop_reading_leader is blocking
-        command_queue.close()
-        read_leader.terminate()
-        read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
+        if policy is None:
+            command_queue.close()
+            read_leader.terminate()
+            read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
 
         timestamp = 0
 
@@ -451,7 +505,7 @@ def record(
         for key in state_keys_dict:
             ep_dict[key] = torch.vstack(ep_dict[key]) * 180.0 / np.pi
         ep_dict['action'] = torch.vstack(ep_dict['action']) * 180.0 / np.pi
-        ep_dict['reward'] = torch.stack(ep_dict['reward'])
+        ep_dict['next.reward'] = torch.stack(ep_dict['next.reward'])
 
         ep_dict["seed"] = torch.tensor([seed] * num_frames)
         ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames)
@@ -577,7 +631,11 @@ def record(
     return lerobot_dataset
 
 
-def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="lerobot/debug"):
+def replay(env,
+           episodes: list,
+           fps: int | None = None,
+           root="data",
+           repo_id="lerobot/debug"):
     env = env()
 
     local_dir = Path(root) / repo_id
@@ -700,6 +758,21 @@ if __name__ == "__main__":
         default=0,
         help="Visualize image observations with opencv.",
     )
+    parser_record.add_argument(
+        "-p",
+        "--pretrained-policy-name-or-path",
+        type=str,
+        help=(
+            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
+            "saved using `Policy.save_pretrained`."
+        ),
+    )
+    parser_record.add_argument(
+        "--policy-overrides",
+        type=str,
+        nargs="*",
+        help="Any key=value arguments to override config values (use dots for nested.overrides)",
+    )
 
     parser_replay = subparsers.add_parser("replay", parents=[base_parser])
     parser_replay.add_argument(
@@ -748,6 +821,16 @@ if __name__ == "__main__":
         teleoperate(env_fn, robot, **kwargs)
 
     elif control_mode == "record":
+        pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
+        policy_overrides = args.policy_overrides
+        del kwargs["pretrained_policy_name_or_path"]
+        del kwargs["policy_overrides"]
+
+        if pretrained_policy_name_or_path is not None:
+            pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
+            kwargs["policy_cfg"] = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
+            kwargs["policy"] = make_policy(hydra_cfg=kwargs["policy_cfg"], pretrained_policy_name_or_path=pretrained_policy_path)
+
         record(env_fn, robot, **kwargs)
 
     elif control_mode == "replay":
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 45807503..795189b0 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -135,8 +135,8 @@ def update_policy(
 
     # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
-    with lock if lock is not None else nullcontext():
-        grad_scaler.step(optimizer)
+    #with lock if lock is not None else nullcontext():
+    grad_scaler.step(optimizer)
     # Updates the scale for next iteration.
     grad_scaler.update()
 
@@ -311,6 +311,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
     logging.info("make_dataset")
     offline_dataset = make_dataset(cfg)
+
+    # TODO (michel_aractingi): temporary fix for inconsistent dataset keys
+    remove_indices = ['observation.images.image_top', 'observation.velocity', 'seed']
+    offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(remove_indices)
+
     if isinstance(offline_dataset, MultiLeRobotDataset):
         logging.info(
             "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
@@ -477,7 +482,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
                 **{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()},
                 "next.reward": {"shape": (), "dtype": np.dtype("float32")},
                 "next.done": {"shape": (), "dtype": np.dtype("?")},
-                "next.success": {"shape": (), "dtype": np.dtype("?")},
+                #"next.success": {"shape": (), "dtype": np.dtype("?")},
             },
             buffer_capacity=cfg.training.online_buffer_capacity,
             fps=online_env.unwrapped.metadata["render_fps"],
@@ -504,6 +509,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             num_samples=len(concat_dataset),
             replacement=True,
         )
+
+        # TODO (michel_aractingi): temporary fix for inconsistent keys
+
         dataloader = torch.utils.data.DataLoader(
             concat_dataset,
             batch_size=cfg.training.batch_size,
@@ -538,8 +546,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
         def sample_trajectory_and_update_buffer():
             nonlocal rollout_start_seed
-            with lock:
-                online_rollout_policy.load_state_dict(policy.state_dict())
+            #with lock:
+            online_rollout_policy.load_state_dict(policy.state_dict())
             online_rollout_policy.eval()
             start_rollout_time = time.perf_counter()
             with torch.no_grad():
@@ -556,37 +564,35 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
                 )
             online_rollout_s = time.perf_counter() - start_rollout_time
-            with lock:
-                start_update_buffer_time = time.perf_counter()
-                online_dataset.add_data(eval_info["episodes"])
-
-                # Update the concatenated dataset length used during sampling.
-                concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
-
-                # Update the sampling weights.
-                sampler.weights = compute_sampler_weights(
-                    offline_dataset,
-                    offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
-                    online_dataset=online_dataset,
-                    # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
-                    # this final observation in the offline datasets, but we might add them in future.
-                    online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
-                    online_sampling_ratio=cfg.training.online_sampling_ratio,
-                )
-                sampler.num_samples = len(concat_dataset)
-
-                update_online_buffer_s = time.perf_counter() - start_update_buffer_time
+            #with lock:
+            start_update_buffer_time = time.perf_counter()
+            online_dataset.add_data(eval_info["episodes"])
+            # Update the concatenated dataset length used during sampling.
+            concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
+            # Update the sampling weights.
+            sampler.weights = compute_sampler_weights(
+                offline_dataset,
+                offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
+                online_dataset=online_dataset,
+                # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
+                # this final observation in the offline datasets, but we might add them in future.
+                online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
+                online_sampling_ratio=cfg.training.online_sampling_ratio,
+            )
+            sampler.num_samples = len(concat_dataset)
+            update_online_buffer_s = time.perf_counter() - start_update_buffer_time
 
             return online_rollout_s, update_online_buffer_s
 
-        future = executor.submit(sample_trajectory_and_update_buffer)
+        # TODO remove parallelization for sim
+        #future = executor.submit(sample_trajectory_and_update_buffer)
 
         # If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
         # here until the rollout and buffer update is done, before proceeding to the policy update steps.
         if (
             not cfg.training.do_online_rollout_async
             or len(online_dataset) <= cfg.training.online_buffer_seed_size
         ):
-            online_rollout_s, update_online_buffer_s = future.result()
+            online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()  # future.result()
 
         if len(online_dataset) <= cfg.training.online_buffer_seed_size:
             logging.info(
@@ -596,12 +602,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
         policy.train()
         for _ in range(cfg.training.online_steps_between_rollouts):
-            with lock:
-                start_time = time.perf_counter()
-                batch = next(dl_iter)
-                dataloading_s = time.perf_counter() - start_time
+            #with lock:
+            start_time = time.perf_counter()
+            batch = next(dl_iter)
+            dataloading_s = time.perf_counter() - start_time
 
             for key in batch:
+                # TODO (michel_aractingi): convert float64 to float32 for Mac (the MPS backend does not support float64)
+                if batch[key].dtype == torch.float64:
+                    batch[key] = batch[key].float()
                 batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
             train_info = update_policy(
@@ -619,8 +628,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             train_info["online_rollout_s"] = online_rollout_s
             train_info["update_online_buffer_s"] = update_online_buffer_s
             train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
-            with lock:
-                train_info["online_buffer_size"] = len(online_dataset)
+            #with lock:
+            train_info["online_buffer_size"] = len(online_dataset)
 
             if step % cfg.training.log_freq == 0:
                 log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
@@ -634,10 +643,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
         # If we're doing async rollouts, we should now wait until we've completed them before proceeding
         # to do the next batch of rollouts.
-        if future.running():
-            start = time.perf_counter()
-            online_rollout_s, update_online_buffer_s = future.result()
-            await_update_online_buffer_s = time.perf_counter() - start
+        #if future.running():
+        #start = time.perf_counter()
+        #online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()  # future.result()
+        #await_update_online_buffer_s = time.perf_counter() - start
 
         if online_step >= cfg.training.online_steps:
             break