diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 639acf1f..1fe27e95 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -105,7 +105,7 @@ class AlohaDataset(torch.utils.data.Dataset): @property def num_samples(self) -> int: - return len(self.data_dict["index"]) + return len(self.data_dict["index"]) if "index" in self.data_dict else 0 @property def num_episodes(self) -> int: diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index b468637e..068b154e 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -119,7 +119,7 @@ class PushtDataset(torch.utils.data.Dataset): @property def num_samples(self) -> int: - return len(self.data_dict["index"]) + return len(self.data_dict["index"]) if "index" in self.data_dict else 0 @property def num_episodes(self) -> int: diff --git a/lerobot/common/datasets/xarm.py b/lerobot/common/datasets/xarm.py index 733267ab..0dfcc5c9 100644 --- a/lerobot/common/datasets/xarm.py +++ b/lerobot/common/datasets/xarm.py @@ -60,7 +60,7 @@ class XarmDataset(torch.utils.data.Dataset): @property def num_samples(self) -> int: - return len(self.data_dict["index"]) + return len(self.data_dict["index"]) if "index" in self.data_dict else 0 @property def num_episodes(self) -> int: @@ -126,7 +126,8 @@ class XarmDataset(torch.utils.data.Dataset): image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1]) state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1]) action = torch.tensor(dataset_dict["actions"][idx0:idx1]) - # TODO(rcadene): concat the last "next_observations" to "observations" + # TODO(rcadene): we have a missing last frame which is the observation when the env is done + # it is critical to have this frame for tdmpc to predict a "done observation/state" # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1]) # next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1]) next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1]) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 8636aa6e..98880f4a 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -35,9 +35,9 @@ def make_policy(cfg): if cfg.policy.pretrained_model_path: # TODO(rcadene): hack for old pretrained models from fowm if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path: - if "offline" in cfg.pretrained_model_path: + if "offline" in cfg.policy.pretrained_model_path: policy.step[0] = 25000 - elif "final" in cfg.pretrained_model_path: + elif "final" in cfg.policy.pretrained_model_path: policy.step[0] = 100000 else: raise NotImplementedError() diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index 942ee9b1..14728576 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ b/lerobot/common/policies/tdmpc/policy.py @@ -333,94 +333,6 @@ class TDMPCPolicy(nn.Module): """Main update function. 
Corresponds to one iteration of the model learning.""" start_time = time.time() - # num_slices = self.cfg.batch_size - # batch_size = self.cfg.horizon * num_slices - - # if demo_buffer is None: - # demo_batch_size = 0 - # else: - # # Update oversampling ratio - # demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step) - # demo_num_slices = int(demo_pc_batch * self.batch_size) - # demo_batch_size = self.cfg.horizon * demo_num_slices - # batch_size -= demo_batch_size - # num_slices -= demo_num_slices - # replay_buffer._sampler.num_slices = num_slices - # demo_buffer._sampler.num_slices = demo_num_slices - - # assert demo_batch_size % self.cfg.horizon == 0 - # assert demo_batch_size % demo_num_slices == 0 - - # assert batch_size % self.cfg.horizon == 0 - # assert batch_size % num_slices == 0 - - # # Sample from interaction dataset - - # def process_batch(batch, horizon, num_slices): - # # trajectory t = 256, horizon h = 5 - # # (t h) ... -> h t ... - # batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous() - - # obs = { - # "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True), - # "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True), - # } - # action = batch["action"].to(self.device, non_blocking=True) - # next_obses = { - # "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True), - # "state": batch["next", "observation", "state"].to(self.device, non_blocking=True), - # } - # reward = batch["next", "reward"].to(self.device, non_blocking=True) - - # idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True) - # weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True) - - # # TODO(rcadene): rearrange directly in offline dataset - # if reward.ndim == 2: - # reward = einops.rearrange(reward, "h t -> h t 1") - - # assert reward.ndim == 3 - # assert reward.shape == (horizon, num_slices, 1) - # # We dont use `batch["next", "done"]` since it only indicates the end of an - # # episode, but not the end of the trajectory of an episode. 
- # # Neither does `batch["next", "terminated"]` - # done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device) - # mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device) - # return obs, action, next_obses, reward, mask, done, idxs, weights - - # batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample() - - # obs, action, next_obses, reward, mask, done, idxs, weights = process_batch( - # batch, self.cfg.horizon, num_slices - # ) - - # Sample from demonstration dataset - # if demo_batch_size > 0: - # demo_batch = demo_buffer.sample(demo_batch_size) - # ( - # demo_obs, - # demo_action, - # demo_next_obses, - # demo_reward, - # demo_mask, - # demo_done, - # demo_idxs, - # demo_weights, - # ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices) - - # if isinstance(obs, dict): - # obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs} - # next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses} - # else: - # obs = torch.cat([obs, demo_obs]) - # next_obses = torch.cat([next_obses, demo_next_obses], dim=1) - # action = torch.cat([action, demo_action], dim=1) - # reward = torch.cat([reward, demo_reward], dim=1) - # mask = torch.cat([mask, demo_mask], dim=1) - # done = torch.cat([done, demo_done], dim=1) - # idxs = torch.cat([idxs, demo_idxs]) - # weights = torch.cat([weights, demo_weights]) - batch_size = batch["index"].shape[0] # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels) @@ -534,6 +446,7 @@ class TDMPCPolicy(nn.Module): ) self.optim.step() + # TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion # if self.cfg.per: # # Update priorities # priorities = priority_loss.clamp(max=1e4).detach() diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml index 7a8d8b58..6b836795 100644 --- a/lerobot/configs/env/aloha.yaml +++ b/lerobot/configs/env/aloha.yaml @@ -18,7 +18,6 @@ env: from_pixels: True pixels_only: False image_size: [3, 480, 640] - action_repeat: 1 episode_length: 400 fps: ${fps} diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml index a5fbcc25..a7097ffd 100644 --- a/lerobot/configs/env/pusht.yaml +++ b/lerobot/configs/env/pusht.yaml @@ -18,7 +18,6 @@ env: from_pixels: True pixels_only: False image_size: 96 - action_repeat: 1 episode_length: 300 fps: ${fps} diff --git a/lerobot/configs/env/xarm.yaml b/lerobot/configs/env/xarm.yaml index 8b3c72ef..bcba659e 100644 --- a/lerobot/configs/env/xarm.yaml +++ b/lerobot/configs/env/xarm.yaml @@ -17,7 +17,6 @@ env: from_pixels: True pixels_only: False image_size: 84 - # action_repeat: 2 # we can remove if policy has n_action_steps=2 episode_length: 25 fps: ${fps} diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 2ebaad9b..4fd2b6bb 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -36,6 +36,7 @@ policy: log_std_max: 2 # learning + batch_size: 256 max_buffer_size: 10000 horizon: 5 reward_coef: 0.5 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 512bb451..394a5d15 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -32,6 +32,7 @@ import json import logging import threading import time +from copy import deepcopy from datetime import datetime as dt from pathlib import Path @@ -57,14 +58,14 @@ def write_video(video_path, stacked_frames, fps): def eval_policy( env: gym.vector.VectorEnv, policy, - 
save_video: bool = False, + max_episodes_rendered: int = 0, video_dir: Path = None, # TODO(rcadene): make it possible to overwrite fps? we should use env.fps - fps: int = 15, - return_first_video: bool = False, transform: callable = None, seed=None, ): + fps = env.unwrapped.metadata["render_fps"] + if policy is not None: policy.eval() device = "cpu" if policy is None else next(policy.parameters()).device @@ -83,14 +84,11 @@ def eval_policy( # needed as I'm currently taking a ceil. ep_frames = [] - def maybe_render_frame(env): - if save_video: # noqa: B023 - if return_first_video: - visu = env.envs[0].render() - visu = visu[None, ...] # add batch dim - else: - visu = np.stack([env.render() for env in env.envs]) - ep_frames.append(visu) # noqa: B023 + def render_frame(env): + # noqa: B023 + eps_rendered = min(max_episodes_rendered, len(env.envs)) + visu = np.stack([env.envs[i].render() for i in range(eps_rendered)]) + ep_frames.append(visu) # noqa: B023 for _ in range(num_episodes): seeds.append("TODO") @@ -104,8 +102,14 @@ def eval_policy( # reset the environment observation, info = env.reset(seed=seed) - maybe_render_frame(env) + if max_episodes_rendered > 0: + render_frame(env) + observations = [] + actions = [] + # episode + # frame_id + # timestamp rewards = [] successes = [] dones = [] @@ -113,8 +117,13 @@ def eval_policy( done = torch.tensor([False for _ in env.envs]) step = 0 while not done.all(): + # format from env keys to lerobot keys + observation = preprocess_observation(observation) + observations.append(deepcopy(observation)) + # apply transform to normalize the observations - observation = preprocess_observation(observation, transform) + for key in observation: + observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]]) # send observation to device/gpu observation = {key: observation[key].to(device, non_blocking=True) for key in observation} @@ -128,9 +137,11 @@ def eval_policy( # apply the next observation, reward, terminated, truncated, info = env.step(action) - maybe_render_frame(env) + if max_episodes_rendered > 0: + render_frame(env) # TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?) 
+ action = torch.from_numpy(action) reward = torch.from_numpy(reward) terminated = torch.from_numpy(terminated) truncated = torch.from_numpy(truncated) @@ -147,12 +158,24 @@ def eval_policy( success = [False for _ in env.envs] success = torch.tensor(success) + actions.append(action) rewards.append(reward) dones.append(done) successes.append(success) step += 1 + env.close() + + # add the last observation when the env is done + observation = preprocess_observation(observation) + observations.append(deepcopy(observation)) + + new_obses = {} + for key in observations[0].keys(): # noqa: SIM118 + new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1) + observations = new_obses + actions = torch.stack(actions, dim=1) rewards = torch.stack(rewards, dim=1) successes = torch.stack(successes, dim=1) dones = torch.stack(dones, dim=1) @@ -172,29 +195,61 @@ def eval_policy( max_rewards.extend(batch_max_reward.tolist()) all_successes.extend(batch_success.tolist()) - env.close() + # similar logic is implemented in dataset preprocessing + ep_dicts = [] + num_episodes = dones.shape[0] + total_frames = 0 + idx0 = idx1 = 0 + data_ids_per_episode = {} + for ep_id in range(num_episodes): + num_frames = done_indices[ep_id].item() + 1 + # TODO(rcadene): We need to add a missing last frame which is the observation + # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state" + ep_dict = { + "action": actions[ep_id, :num_frames], + "episode": torch.tensor([ep_id] * num_frames), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / fps, + "next.done": dones[ep_id, :num_frames], + "next.reward": rewards[ep_id, :num_frames].type(torch.float32), + } + for key in observations: + ep_dict[key] = observations[key][ep_id, :num_frames] + ep_dicts.append(ep_dict) - if save_video or return_first_video: + total_frames += num_frames + idx1 += num_frames + + data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1) + + idx0 = idx1 + + # similar logic is implemented in dataset preprocessing + data_dict = {} + keys = ep_dicts[0].keys() + for key in keys: + data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + data_dict["index"] = torch.arange(0, total_frames, 1) + + if max_episodes_rendered > 0: batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *) - if save_video: - for stacked_frames, done_index in zip( - batch_stacked_frames, done_indices.flatten().tolist(), strict=False - ): - if episode_counter >= num_episodes: - continue - video_dir.mkdir(parents=True, exist_ok=True) - video_path = video_dir / f"eval_episode_{episode_counter}.mp4" - thread = threading.Thread( - target=write_video, - args=(str(video_path), stacked_frames[:done_index], fps), - ) - thread.start() - threads.append(thread) - episode_counter += 1 + for stacked_frames, done_index in zip( + batch_stacked_frames, done_indices.flatten().tolist(), strict=False - ): + ): + if episode_counter >= num_episodes: + continue + video_dir.mkdir(parents=True, exist_ok=True) + video_path = video_dir / f"eval_episode_{episode_counter}.mp4" + thread = threading.Thread( + target=write_video, + args=(str(video_path), stacked_frames[:done_index], fps), + ) + thread.start() + threads.append(thread) + episode_counter += 1 - if return_first_video: - first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2) + videos = batch_stacked_frames.transpose(0, 1, 4, 2, 3) for thread in threads: thread.join() @@ -225,9 +280,13 @@ def eval_policy( "eval_s": time.time() - start, "eval_ep_s": (time.time()
- start) / num_episodes, }, + "episodes": { + "data_dict": data_dict, + "data_ids_per_episode": data_ids_per_episode, + }, } - if return_first_video: - return info, first_video + if max_episodes_rendered > 0: + info["videos"] = videos return info @@ -259,7 +318,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None): info = eval_policy( env, policy=policy, - save_video=True, + max_episodes_rendered=10, video_dir=Path(out_dir) / "eval", fps=cfg.env.fps, # TODO(rcadene): what should we do with the transform? diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index cca26902..6dfbd12b 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -1,8 +1,8 @@ import logging +from copy import deepcopy from pathlib import Path import hydra -import numpy as np import torch from lerobot.common.datasets.factory import make_dataset @@ -110,6 +110,64 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline): logger.log_dict(info, step, mode="eval") + +def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float): + """ + Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from the online dataset (on average). + + Parameters: + - n_off (int): Number of offline samples, each with a sampling weight of 1. + - n_on (int): Number of online samples. + - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5). + + The total weight of offline samples is n_off * 1.0. + The total weight of online samples is n_on * w. + The total combined weight of all samples is n_off + n_on * w. + The fraction of the weight that is online is n_on * w / (n_off + n_on * w). + We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
+ The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1)) + """ + assert 0.0 <= pc_on <= 1.0 + return -(n_off * pc_on) / (n_on * (pc_on - 1)) + + +def add_episodes_inplace(episodes, online_dataset, concat_dataset, sampler, pc_online_samples): + data_dict = episodes["data_dict"] + data_ids_per_episode = episodes["data_ids_per_episode"] + + if len(online_dataset) == 0: + # initialize online dataset + online_dataset.data_dict = data_dict + online_dataset.data_ids_per_episode = data_ids_per_episode + else: + # find episode index and data frame indices according to previous episode in online_dataset + start_episode = max(online_dataset.data_ids_per_episode.keys()) + 1 + start_index = online_dataset.data_dict["index"][-1].item() + 1 + data_dict["episode"] += start_episode + data_dict["index"] += start_index + + # extend online dataset + for key in data_dict: + # TODO(rcadene): avoid reallocating memory at every step by preallocating memory or changing our data structure + online_dataset.data_dict[key] = torch.cat([online_dataset.data_dict[key], data_dict[key]]) + for ep_id in data_ids_per_episode: + online_dataset.data_ids_per_episode[ep_id + start_episode] = ( + data_ids_per_episode[ep_id] + start_index + ) + + # update the concatenated dataset length used during sampling + concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets) + + # update the sampling weights for each frame so that online frames get sampled a certain percentage of times + len_online = len(online_dataset) + len_offline = len(concat_dataset) - len_online + weight_offline = 1.0 + weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples) + sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset)) + + # update the total number of samples used during sampling + sampler.num_samples = len(concat_dataset) + + def train(cfg: dict, out_dir=None, job_name=None): if out_dir is None: raise NotImplementedError() @@ -128,26 +186,7 @@ def train(cfg: dict, out_dir=None, job_name=None): set_global_seed(cfg.seed) logging.info("make_dataset") - dataset = make_dataset(cfg) - - # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy - # if cfg.policy.balanced_sampling: - # logging.info("make online_buffer") - # num_traj_per_batch = cfg.policy.batch_size - - # online_sampler = PrioritizedSliceSampler( - # max_capacity=100_000, - # alpha=cfg.policy.per_alpha, - # beta=cfg.policy.per_beta, - # num_slices=num_traj_per_batch, - # strict_length=True, - # ) - - # online_buffer = TensorDictReplayBuffer( - # storage=LazyMemmapStorage(100_000), - # sampler=online_sampler, - # transform=dataset.transform, - # ) + offline_dataset = make_dataset(cfg) logging.info("make_env") env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) @@ -165,9 +204,8 @@ def train(cfg: dict, out_dir=None, job_name=None): logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})") logging.info(f"{cfg.online_steps=}") - logging.info(f"{cfg.env.action_repeat=}") - logging.info(f"{dataset.num_samples=} ({format_big_number(dataset.num_samples)})") - logging.info(f"{dataset.num_episodes=}") + logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})") + logging.info(f"{offline_dataset.num_episodes=}") logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") @@ -175,18 
+213,17 @@ def train(cfg: dict, out_dir=None, job_name=None): def _maybe_eval_and_maybe_save(step): if step % cfg.eval_freq == 0: logging.info(f"Eval policy at step {step}") - eval_info, first_video = eval_policy( + eval_info = eval_policy( env, policy, - return_first_video=True, video_dir=Path(out_dir) / "eval", - save_video=True, - transform=dataset.transform, + max_episodes_rendered=4, + transform=offline_dataset.transform, seed=cfg.seed, ) - log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline) + log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline) if cfg.wandb.enable: - logger.log_video(first_video, step, mode="eval") + logger.log_video(eval_info["videos"][0], step, mode="eval") logging.info("Resume training") if cfg.save_model and step % cfg.save_freq == 0: @@ -194,11 +231,9 @@ def train(cfg: dict, out_dir=None, job_name=None): logger.save_model(policy, identifier=step) logging.info("Resume training") - step = 0 # number of policy update (forward + backward + optim) - - is_offline = True + # create dataloader for offline training dataloader = torch.utils.data.DataLoader( - dataset, + offline_dataset, num_workers=4, batch_size=cfg.policy.batch_size, shuffle=True, @@ -206,6 +241,9 @@ def train(cfg: dict, out_dir=None, job_name=None): drop_last=True, ) dl_iter = cycle(dataloader) + + step = 0 # number of policy update (forward + backward + optim) + is_offline = True for offline_step in range(cfg.offline_steps): if offline_step == 0: logging.info("Start offline training on a fixed dataset") @@ -219,7 +257,7 @@ def train(cfg: dict, out_dir=None, job_name=None): # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? if step % cfg.log_freq == 0: - log_train_info(logger, train_info, step, cfg, dataset, is_offline) + log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline) # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in # step + 1. @@ -227,61 +265,60 @@ def train(cfg: dict, out_dir=None, job_name=None): step += 1 - raise NotImplementedError() + # create an env dedicated to online episodes collection from policy rollout + rollout_env = make_env(cfg, num_parallel_envs=1) + + # create an empty online dataset similar to offline dataset + online_dataset = deepcopy(offline_dataset) + online_dataset.data_dict = {} + online_dataset.data_ids_per_episode = {} + + # create dataloader for online training + concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset]) + weights = [1.0] * len(concat_dataset) + sampler = torch.utils.data.WeightedRandomSampler( + weights, num_samples=len(concat_dataset), replacement=True + ) + dataloader = torch.utils.data.DataLoader( + concat_dataset, + num_workers=4, + batch_size=cfg.policy.batch_size, + sampler=sampler, + pin_memory=cfg.device != "cpu", + drop_last=True, + ) + dl_iter = cycle(dataloader) - demo_buffer = dataset if cfg.policy.balanced_sampling else None online_step = 0 is_offline = False for env_step in range(cfg.online_steps): if env_step == 0: logging.info("Start online training by interacting with environment") - # TODO: add configurable number of rollout? 
(default=1) + with torch.no_grad(): - rollout = env.rollout( - max_steps=cfg.env.episode_length, - policy=policy, - auto_cast_to_device=True, + eval_info = eval_policy( + rollout_env, + policy, + transform=offline_dataset.transform, + seed=cfg.seed, ) - assert ( - len(rollout.batch_size) == 2 - ), "2 dimensions expected: number of env in parallel x max number of steps during rollout" - - num_parallel_env = rollout.batch_size[0] - if num_parallel_env != 1: - # TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests - raise NotImplementedError() - - num_max_steps = rollout.batch_size[1] - assert num_max_steps <= cfg.env.episode_length - - # reshape to have a list of steps to insert into online_buffer - rollout = rollout.reshape(num_parallel_env * num_max_steps) - - # set same episode index for all time steps contained in this rollout - rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int) - # online_buffer.extend(rollout) - - ep_sum_reward = rollout["next", "reward"].sum() - ep_max_reward = rollout["next", "reward"].max() - ep_success = rollout["next", "success"].any() - rollout_info = { - "avg_sum_reward": np.nanmean(ep_sum_reward), - "avg_max_reward": np.nanmean(ep_max_reward), - "pc_success": np.nanmean(ep_success) * 100, - "env_step": env_step, - "ep_length": len(rollout), - } + online_pc_sampling = cfg.get("demo_schedule", 0.5) + add_episodes_inplace( + eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling + ) for _ in range(cfg.policy.utd): - train_info = policy.update( - # online_buffer, - step, - demo_buffer=demo_buffer, - ) + policy.train() + batch = next(dl_iter) + + for key in batch: + batch[key] = batch[key].to(cfg.device, non_blocking=True) + + train_info = policy(batch, step) + if step % cfg.log_freq == 0: - train_info.update(rollout_info) - log_train_info(logger, train_info, step, cfg, dataset, is_offline) + log_train_info(logger, train_info, step, cfg, online_dataset, is_offline) # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass # in step + 1.
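Note: the snippet below is a minimal, self-contained sketch (not part of the patch above) that checks the weight derivation used by calculate_online_sample_weight in lerobot/scripts/train.py and shows how the resulting weight feeds torch.utils.data.WeightedRandomSampler. The sample counts and the 256-index draw are arbitrary toy values.

import torch

def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float) -> float:
    # Solve n_on * w / (n_off + n_on * w) = pc_on for w; requires n_on > 0 and pc_on < 1.
    assert 0.0 <= pc_on < 1.0 and n_on > 0
    return -(n_off * pc_on) / (n_on * (pc_on - 1))

n_off, n_on, pc_on = 10_000, 500, 0.5
w = calculate_online_sample_weight(n_off, n_on, pc_on)  # 20.0 here: 500 * 20 matches the 10_000 offline weight
weights = torch.tensor([1.0] * n_off + [w] * n_on)
sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=256, replacement=True)
indices = torch.tensor(list(sampler))
online_fraction = (indices >= n_off).float().mean().item()
print(f"w = {w:.1f}, online fraction in a 256-index draw ~ {online_fraction:.2f} (target {pc_on})")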
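A second sketch, also outside the patch, illustrates the bookkeeping that add_episodes_inplace performs when appending freshly collected rollouts to the online dataset: the new episode ids and frame indices are shifted so they continue where the online dataset left off. The tensors below are toy values; the real code stores them in online_dataset.data_dict and online_dataset.data_ids_per_episode.

import torch

# Existing online dataset: two episodes (ids 0 and 1) covering three frames.
online_data_dict = {"episode": torch.tensor([0, 0, 1]), "index": torch.tensor([0, 1, 2])}
online_ids_per_episode = {0: torch.tensor([0, 1]), 1: torch.tensor([2])}

# Freshly collected rollout from eval_policy: one episode of two frames, numbered from zero.
new_data_dict = {"episode": torch.tensor([0, 0]), "index": torch.tensor([0, 1])}
new_ids_per_episode = {0: torch.tensor([0, 1])}

start_episode = max(online_ids_per_episode.keys()) + 1  # 2
start_index = online_data_dict["index"][-1].item() + 1  # 3
new_data_dict["episode"] += start_episode
new_data_dict["index"] += start_index

for key in online_data_dict:
    online_data_dict[key] = torch.cat([online_data_dict[key], new_data_dict[key]])
for ep_id, ids in new_ids_per_episode.items():
    online_ids_per_episode[ep_id + start_episode] = ids + start_index

print(online_data_dict["episode"])  # tensor([0, 0, 1, 2, 2])
print(online_data_dict["index"])    # tensor([0, 1, 2, 3, 4])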