Added the possibility to record with a policy; added temporary fixes to train.py to enable training on Mac
parent 9a5356d0ac
commit 5e01c21692
@@ -345,7 +345,7 @@ class TDMPCPolicy(
                 batch[key] = batch[key].transpose(1, 0)
 
         action = batch["action"]  # (t, b, action_dim)
-        reward = batch["reward"]  # (t, b)
+        reward = batch["next.reward"]  # (t, b)
         observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
 
         # Apply random image augmentations.
@@ -422,7 +422,7 @@ class TDMPCPolicy(
             (
                 temporal_loss_coeffs
                 * F.mse_loss(reward_preds, reward, reduction="none")
-                * ~batch["reward_is_pad"]
+                * ~batch["next.reward_is_pad"]
                 # `reward_preds` depends on the current observation and the actions.
                 * ~batch["observation.state_is_pad"][0]
                 * ~batch["action_is_pad"]
@@ -443,7 +443,7 @@ class TDMPCPolicy(
                 * ~batch["observation.state_is_pad"][0]
                 * ~batch["action_is_pad"]
                 # q_targets depends on the reward and the next observations.
-                * ~batch["reward_is_pad"]
+                * ~batch["next.reward_is_pad"]
                 * ~batch["observation.state_is_pad"][1:]
             )
             .sum(0)
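The three hunks above rename the reward keys that TDMPCPolicy reads from a batch: `reward`/`reward_is_pad` become `next.reward`/`next.reward_is_pad`, matching the key that the recording script below writes into its episodes. A minimal sketch of the masking pattern used in these loss terms, with assumed shapes and an illustrative decay coefficient:

import torch
import torch.nn.functional as F

# Assumed shapes: predictions and targets are (t, b); the pad mask is True where a
# frame is padding and must not contribute to the loss.
t, b = 5, 2
reward_preds = torch.randn(t, b)
reward = torch.randn(t, b)
next_reward_is_pad = torch.zeros(t, b, dtype=torch.bool)
rho = 0.9  # illustrative temporal decay coefficient
temporal_loss_coeffs = torch.pow(rho, torch.arange(t, dtype=torch.float32)).unsqueeze(-1)  # (t, 1)

reward_loss = (
    temporal_loss_coeffs
    * F.mse_loss(reward_preds, reward, reduction="none")
    * ~next_reward_is_pad  # zero out the terms that come from padded frames
).sum(0).mean()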
@@ -85,11 +85,15 @@ from functools import cache
 from pathlib import Path
 import gymnasium as gym
 import multiprocessing
+from contextlib import nullcontext
+
 
 import cv2
 import torch
 import numpy as np
 import tqdm
+from omegaconf import DictConfig
 
 from PIL import Image
 from datasets import Dataset, Features, Sequence, Value
@@ -99,12 +103,15 @@ from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDat
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
 from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch, hf_transform_to_torch
+from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
 from lerobot.common.datasets.video_utils import encode_video_frames
 from lerobot.common.robot_devices.robots.factory import make_robot
+from lerobot.common.policies.factory import make_policy
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.common.envs.factory import make_env
 from lerobot.common.utils.utils import init_hydra_config, init_logging
+from lerobot.scripts.eval import get_pretrained_policy_path
 from lerobot.scripts.push_dataset_to_hub import (
     push_dataset_card_to_hub,
     push_meta_data_to_hub,
@@ -178,6 +185,29 @@ def is_headless():
         print()
         return True
 
+def get_action_from_policy(policy, observation, device, use_amp=False):
+    with (
+        torch.inference_mode(),
+        torch.autocast(device_type=device.type)
+        if device.type == "cuda" and use_amp
+        else nullcontext(),
+    ):
+        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
+        for name in observation:
+            if "image" in name:
+                observation[name] = observation[name].type(torch.float32) / 255
+                observation[name] = observation[name].permute(2, 0, 1).contiguous()
+            observation[name] = observation[name].unsqueeze(0)
+            observation[name] = observation[name].to(device)
+
+        # Compute the next action with the policy
+        # based on the current observation
+        action = policy.select_action(observation)
+        # Remove batch dimension
+        action = action.squeeze(0)
+        # Move to cpu, if not already the case
+        return action.to("cpu")
+
 def init_read_leader(robot, fps, **kwargs):
     axis_directions = kwargs.get('axis_directions', [1])
     offsets = kwargs.get('offsets', [0])
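`get_action_from_policy` is the new inference helper: it casts image observations to float32 in [0, 1], permutes them to channel-first, adds a batch dimension, moves everything to the policy's device, and runs `policy.select_action` under `inference_mode` (with autocast only when the device is CUDA and AMP is requested, so the path degrades cleanly on CPU or Apple Silicon). Note that the call site further down invokes it as `get_action_from_policy(policy, observation)`, without the `device` argument this signature requires. A self-contained toy exercise of the helper's input contract (the observation keys, image size, and action dimension are illustrative, not from the commit):

import torch

class DummyPolicy:
    def select_action(self, obs):
        return torch.zeros(1, 6)  # (batch, action_dim); 6 is an arbitrary choice

# Assumes observations are torch tensors: HWC uint8 images and 1-D float state vectors.
observation = {
    "observation.images.top": torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8),
    "observation.state": torch.zeros(6),
}
action = get_action_from_policy(DummyPolicy(), observation, torch.device("cpu"))
print(action.shape)  # torch.Size([6])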
@@ -240,7 +270,7 @@ def create_rl_hf_dataset(data_dict):
     features["action"] = Sequence(
         length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
     )
-    features["reward"] = Value(dtype="float32", id=None)
+    features["next.reward"] = Value(dtype="float32", id=None)
 
     features["seed"] = Value(dtype="int64", id=None)
     features["episode_index"] = Value(dtype="int64", id=None)
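The HF dataset schema written by `create_rl_hf_dataset` now stores the reward under `next.reward` as well, so recorded episodes line up with what the TDMPC loss above and train.py's online buffer below expect. A minimal sketch of the resulting schema (the action dimension of 6 is an assumption; the real code takes it from `data_dict["action"].shape[1]`):

from datasets import Features, Sequence, Value

features = Features({
    "action": Sequence(length=6, feature=Value(dtype="float32", id=None)),
    "next.reward": Value(dtype="float32", id=None),
    "seed": Value(dtype="int64", id=None),
    "episode_index": Value(dtype="int64", id=None),
})
print(features)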
@@ -277,6 +307,8 @@ def teleoperate(env, robot: Robot, teleop_time_s=None, **kwargs):
 def record(
     env,
     robot: Robot,
+    policy: torch.nn.Module | None = None,
+    policy_cfg: DictConfig | None = None,
     fps: int | None = None,
     root="data",
     repo_id="lerobot/debug",
@@ -355,7 +387,23 @@ def record(
     num_image_writers = num_image_writers_per_camera * 2 ###############
     num_image_writers = max(num_image_writers, 1)
 
-    read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
+    # Load policy if any
+    if policy is not None:
+        # Check device is available
+        device = get_safe_torch_device(policy_cfg.device, log=True)
+
+        policy.eval()
+        policy.to(device)
+
+        torch.backends.cudnn.benchmark = True
+        torch.backends.cuda.matmul.allow_tf32 = True
+        set_global_seed(policy_cfg.seed)
+
+        # override fps using policy fps
+        fps = policy_cfg.env.fps
+    else:
+        read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
+
     if not is_headless() and visualize_images:
         observations_queue = multiprocessing.Queue(1000)
         show_images = multiprocessing.Process(target=show_image_observations, args=(observations_queue, ))
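When a policy is passed, `record()` never starts the leader-arm reader: it resolves the device from the policy's Hydra config, switches the policy to eval mode, sets the CUDA backend flags, seeds everything, and overrides the recording fps with the fps the policy was trained at. A sketch of a device-aware variant of this setup (the CUDA-only guard around the backend flags is my addition, not part of the commit; it is harmless either way, since those flags have no effect off-CUDA):

import torch

# Assumes policy and policy_cfg come from make_policy()/init_hydra_config(),
# as wired up in __main__ further below.
device = get_safe_torch_device(policy_cfg.device, log=True)  # resolves the configured "cuda"/"mps"/"cpu" string
policy.eval()
policy.to(device)
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(policy_cfg.seed)
fps = policy_cfg.env.fps  # record at the control rate the policy was trained for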
@@ -369,7 +417,7 @@ def record(
     while episode_index < num_episodes:
         logging.info(f"Recording episode {episode_index}")
         say(f"Recording episode {episode_index}")
-        ep_dict = {'action':[], 'reward':[]}
+        ep_dict = {'action':[], 'next.reward':[]}
         for k in state_keys_dict:
             ep_dict[k] = []
         frame_index = 0
@@ -381,9 +429,14 @@ def record(
         observation, info = env.reset(seed=seed)
         #with stop_reading_leader.get_lock():
         #stop_reading_leader.Value = 0
-        read_leader.start()
+        if policy is None:
+            read_leader.start()
         while timestamp < episode_time_s:
-            action = command_queue.get()
+            if policy is None:
+                action = command_queue.get()
+            else:
+                action = get_action_from_policy(policy, observation)
+
             for key in image_keys:
                 str_key = key if key.startswith('observation.images.') else 'observation.images.' + key
                 futures += [
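Inside the episode loop the action source now branches: teleoperation pulls the next action from the leader-arm `command_queue`, while policy mode computes it from the most recent observation, closing the control loop through `env.step`. A self-contained toy version of that loop structure (gymnasium-style 5-tuple step API, as the surrounding code uses; the stand-in env and policy are placeholders):

import numpy as np

class ToyEnv:
    def reset(self, seed=None):
        return np.zeros(6, dtype=np.float32), {}
    def step(self, action):
        return np.zeros(6, dtype=np.float32), 0.0, False, False, {}

def toy_policy(observation):
    # Stand-in for get_action_from_policy(policy, observation, device)
    return np.zeros(6, dtype=np.float32)

env = ToyEnv()
observation, info = env.reset(seed=0)
for _ in range(10):  # stands in for `while timestamp < episode_time_s`
    action = toy_policy(observation)
    observation, reward, terminated, truncated, info = env.step(np.expand_dims(action, 0))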
@@ -402,7 +455,7 @@ def record(
             action = np.expand_dims(action, 0)
             observation, reward, _, _ , info = env.step(action)
             ep_dict['action'].append(torch.from_numpy(action))
-            ep_dict['reward'].append(torch.tensor(reward))
+            ep_dict['next.reward'].append(torch.tensor(reward))
             print(reward)
 
             frame_index += 1
@@ -417,9 +470,10 @@ def record(
         #stop_reading_leader.Value = 1
         # TODO (michel_aractinig): temp fix until I figure out the problem with shared memory
         # stop_reading_leader is blocking
-        command_queue.close()
-        read_leader.terminate()
-        read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
+        if policy is None:
+            command_queue.close()
+            read_leader.terminate()
+            read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
 
         timestamp = 0
 
@@ -451,7 +505,7 @@ def record(
         for key in state_keys_dict:
             ep_dict[key] = torch.vstack(ep_dict[key]) * 180.0 / np.pi
         ep_dict['action'] = torch.vstack(ep_dict['action']) * 180.0 / np.pi
-        ep_dict['reward'] = torch.stack(ep_dict['reward'])
+        ep_dict['next.reward'] = torch.stack(ep_dict['next.reward'])
 
         ep_dict["seed"] = torch.tensor([seed] * num_frames)
         ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames)
@@ -577,7 +631,11 @@ def record(
     return lerobot_dataset
 
 
-def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="lerobot/debug"):
+def replay(env,
+           episodes: list,
+           fps: int | None = None,
+           root="data",
+           repo_id="lerobot/debug"):
 
     env = env()
     local_dir = Path(root) / repo_id
@@ -700,6 +758,21 @@ if __name__ == "__main__":
         default=0,
         help="Visualize image observations with opencv.",
     )
+    parser_record.add_argument(
+        "-p",
+        "--pretrained-policy-name-or-path",
+        type=str,
+        help=(
+            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
+            "saved using `Policy.save_pretrained`."
+        ),
+    )
+    parser_record.add_argument(
+        "--policy-overrides",
+        type=str,
+        nargs="*",
+        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
+    )
 
     parser_replay = subparsers.add_parser("replay", parents=[base_parser])
     parser_replay.add_argument(
@@ -748,6 +821,16 @@ if __name__ == "__main__":
         teleoperate(env_fn, robot, **kwargs)
 
     elif control_mode == "record":
+        pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
+        policy_overrides = args.policy_overrides
+        del kwargs["pretrained_policy_name_or_path"]
+        del kwargs["policy_overrides"]
+
+        if pretrained_policy_name_or_path is not None:
+            pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
+            kwargs["policy_cfg"] = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
+            kwargs["policy"] = make_policy(hydra_cfg=kwargs["policy_cfg"], pretrained_policy_name_or_path=pretrained_policy_path)
+
         record(env_fn, robot, **kwargs)
 
     elif control_mode == "replay":
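In `__main__`, record mode now pops the two new CLI arguments out of `kwargs`, resolves the checkpoint, rebuilds its Hydra config (optionally patched by `--policy-overrides`, a list of `key=value` strings; passing something like `device=cpu` or `device=mps` is the natural way to run inference on a machine without CUDA), and hands the instantiated policy to `record()`. A programmatic sketch of the same wiring (the checkpoint path and the `device=cpu` override are placeholders; `env_fn`, `robot`, and `other_record_kwargs` stand for the objects and remaining CLI arguments built earlier in the script):

pretrained_policy_path = get_pretrained_policy_path(
    "outputs/train/my_run/checkpoints/last/pretrained_model"  # placeholder path or Hub repo ID
)
policy_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", ["device=cpu"])
policy = make_policy(hydra_cfg=policy_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
record(env_fn, robot, policy=policy, policy_cfg=policy_cfg, **other_record_kwargs)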
@@ -135,8 +135,8 @@ def update_policy(
 
     # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
     # although it still skips optimizer.step() if the gradients contain infs or NaNs.
-    with lock if lock is not None else nullcontext():
-        grad_scaler.step(optimizer)
+    #with lock if lock is not None else nullcontext():
+    grad_scaler.step(optimizer)
     # Updates the scale for next iteration.
     grad_scaler.update()
 
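The lock that serialized `grad_scaler.step` against the asynchronous rollout thread is commented out here; since the rollout is now called synchronously (see the hunks below), nothing else touches the optimizer concurrently, which is presumably why the author treats dropping it as safe. For reference, a self-contained sketch of the scaler sequence that `update_policy` follows (the clip-norm value is illustrative; upstream takes it from the config):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# With enabled=False (CPU/MPS), every GradScaler call below is a pass-through.
grad_scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

loss = model(torch.randn(8, 4)).pow(2).mean()
grad_scaler.scale(loss).backward()
grad_scaler.unscale_(optimizer)                                    # unscale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # 10.0 is illustrative
grad_scaler.step(optimizer)   # skips the step if gradients contain inf/NaN
grad_scaler.update()          # adjust the scale for the next iteration
optimizer.zero_grad()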
@@ -311,6 +311,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
     logging.info("make_dataset")
     offline_dataset = make_dataset(cfg)
+
+    remove_indices=['observation.images.image_top', 'observation.velocity', 'seed']
+    # temp fix michel_Aractingi TODO
+    offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(remove_indices)
+
     if isinstance(offline_dataset, MultiLeRobotDataset):
         logging.info(
             "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
@@ -477,7 +482,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             **{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()},
             "next.reward": {"shape": (), "dtype": np.dtype("float32")},
             "next.done": {"shape": (), "dtype": np.dtype("?")},
-            "next.success": {"shape": (), "dtype": np.dtype("?")},
+            #"next.success": {"shape": (), "dtype": np.dtype("?")},
         },
         buffer_capacity=cfg.training.online_buffer_capacity,
         fps=online_env.unwrapped.metadata["render_fps"],
@@ -504,6 +509,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         num_samples=len(concat_dataset),
         replacement=True,
     )
+
+    # TODO michel_aractingi temp fix for incosistent keys
+
     dataloader = torch.utils.data.DataLoader(
         concat_dataset,
         batch_size=cfg.training.batch_size,
@@ -538,8 +546,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
         def sample_trajectory_and_update_buffer():
            nonlocal rollout_start_seed
-            with lock:
-                online_rollout_policy.load_state_dict(policy.state_dict())
+            #with lock:
+            online_rollout_policy.load_state_dict(policy.state_dict())
             online_rollout_policy.eval()
             start_rollout_time = time.perf_counter()
             with torch.no_grad():
@@ -556,37 +564,35 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
                 )
             online_rollout_s = time.perf_counter() - start_rollout_time
 
-            with lock:
-                start_update_buffer_time = time.perf_counter()
-                online_dataset.add_data(eval_info["episodes"])
-
-                # Update the concatenated dataset length used during sampling.
-                concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
-
-                # Update the sampling weights.
-                sampler.weights = compute_sampler_weights(
-                    offline_dataset,
-                    offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
-                    online_dataset=online_dataset,
-                    # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
-                    # this final observation in the offline datasets, but we might add them in future.
-                    online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
-                    online_sampling_ratio=cfg.training.online_sampling_ratio,
-                )
-                sampler.num_samples = len(concat_dataset)
-
-                update_online_buffer_s = time.perf_counter() - start_update_buffer_time
+            #with lock:
+            start_update_buffer_time = time.perf_counter()
+            online_dataset.add_data(eval_info["episodes"])
+            # Update the concatenated dataset length used during sampling.
+            concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
+            # Update the sampling weights.
+            sampler.weights = compute_sampler_weights(
+                offline_dataset,
+                offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
+                online_dataset=online_dataset,
+                # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
+                # this final observation in the offline datasets, but we might add them in future.
+                online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
+                online_sampling_ratio=cfg.training.online_sampling_ratio,
+            )
+            sampler.num_samples = len(concat_dataset)
+            update_online_buffer_s = time.perf_counter() - start_update_buffer_time
 
             return online_rollout_s, update_online_buffer_s
 
-        future = executor.submit(sample_trajectory_and_update_buffer)
+        # TODO remove parallelization for sim
+        #future = executor.submit(sample_trajectory_and_update_buffer)
         # If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
         # here until the rollout and buffer update is done, before proceeding to the policy update steps.
         if (
             not cfg.training.do_online_rollout_async
             or len(online_dataset) <= cfg.training.online_buffer_seed_size
         ):
-            online_rollout_s, update_online_buffer_s = future.result()
+            online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()#future.result()
 
         if len(online_dataset) <= cfg.training.online_buffer_seed_size:
             logging.info(
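The rollout is no longer submitted to the thread-pool executor; `sample_trajectory_and_update_buffer()` is simply called inline, and the locking around the buffer update is dropped along with it. The author flags this as a temporary de-parallelization for the simulated setup. A self-contained toy showing the two patterns side by side (the sleep stands in for the actual environment rollout and buffer update):

import time
from concurrent.futures import ThreadPoolExecutor

def sample_trajectory_and_update_buffer():
    start = time.perf_counter()
    time.sleep(0.1)  # stand-in for stepping the env and filling the online buffer
    return time.perf_counter() - start, 0.0

# Upstream (async): overlap the rollout with the policy-update steps.
with ThreadPoolExecutor(max_workers=1) as executor:
    future = executor.submit(sample_trajectory_and_update_buffer)
    online_rollout_s, update_online_buffer_s = future.result()

# This commit (sync): run the rollout on the main thread before updating the policy.
online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()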
@@ -596,12 +602,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
         policy.train()
         for _ in range(cfg.training.online_steps_between_rollouts):
-            with lock:
-                start_time = time.perf_counter()
-                batch = next(dl_iter)
-                dataloading_s = time.perf_counter() - start_time
+            #with lock:
+            start_time = time.perf_counter()
+            batch = next(dl_iter)
+            dataloading_s = time.perf_counter() - start_time
 
             for key in batch:
+                # TODO michel aractingi convert float64 to float32 for mac
+                if batch[key].dtype == torch.float64:
+                    batch[key] = batch[key].float()
                 batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
             train_info = update_policy(
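The dtype loop is the actual Mac enabler: PyTorch's MPS backend does not support float64 tensors, so any double-precision entries in the batch are downcast to float32 before being moved to the device. A minimal sketch of the same guard outside the training loop (keys and shapes are illustrative):

import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
batch = {
    "action": torch.zeros(8, 6, dtype=torch.float32),
    "next.reward": torch.zeros(8, dtype=torch.float64),  # would fail to move to MPS as-is
}
for key in batch:
    if batch[key].dtype == torch.float64:
        batch[key] = batch[key].float()  # float64 -> float32
    batch[key] = batch[key].to(device, non_blocking=True)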
@@ -619,8 +628,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             train_info["online_rollout_s"] = online_rollout_s
             train_info["update_online_buffer_s"] = update_online_buffer_s
             train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
-            with lock:
-                train_info["online_buffer_size"] = len(online_dataset)
+            #with lock:
+            train_info["online_buffer_size"] = len(online_dataset)
 
             if step % cfg.training.log_freq == 0:
                 log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
@@ -634,10 +643,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
         # If we're doing async rollouts, we should now wait until we've completed them before proceeding
         # to do the next batch of rollouts.
-        if future.running():
-            start = time.perf_counter()
-            online_rollout_s, update_online_buffer_s = future.result()
-            await_update_online_buffer_s = time.perf_counter() - start
+        #if future.running():
+        #start = time.perf_counter()
+        #online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()#future.result()
+        #await_update_online_buffer_s = time.perf_counter() - start
 
         if online_step >= cfg.training.online_steps:
             break