online training works (loss goes down), remove action_repeat, eval_policy outputs episodes data, eval_policy uses max_episodes_rendered

Cadene 2024-04-10 11:34:01 +00:00
parent 19e7661b8d
commit 06573d7f67
11 changed files with 219 additions and 211 deletions

View File

@@ -105,7 +105,7 @@ class AlohaDataset(torch.utils.data.Dataset):
     @property
     def num_samples(self) -> int:
-        return len(self.data_dict["index"])
+        return len(self.data_dict["index"]) if "index" in self.data_dict else 0

     @property
     def num_episodes(self) -> int:
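
The same guard is repeated for PushtDataset and XarmDataset below: a dataset whose data_dict is still empty (as with the online dataset created later in this commit's train script) now reports zero samples instead of raising a KeyError. A minimal, self-contained sketch of the behavior (the _TinyDataset class is illustrative only, not part of the repo):

import torch

class _TinyDataset:
    def __init__(self, data_dict):
        self.data_dict = data_dict

    @property
    def num_samples(self) -> int:
        # empty data_dict -> 0 samples, otherwise the number of indexed frames
        return len(self.data_dict["index"]) if "index" in self.data_dict else 0

assert _TinyDataset({}).num_samples == 0
assert _TinyDataset({"index": torch.arange(0, 50, 1)}).num_samples == 50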

View File

@@ -119,7 +119,7 @@ class PushtDataset(torch.utils.data.Dataset):
     @property
     def num_samples(self) -> int:
-        return len(self.data_dict["index"])
+        return len(self.data_dict["index"]) if "index" in self.data_dict else 0

     @property
     def num_episodes(self) -> int:

View File

@@ -60,7 +60,7 @@ class XarmDataset(torch.utils.data.Dataset):
     @property
     def num_samples(self) -> int:
-        return len(self.data_dict["index"])
+        return len(self.data_dict["index"]) if "index" in self.data_dict else 0

     @property
     def num_episodes(self) -> int:
@@ -126,7 +126,8 @@ class XarmDataset(torch.utils.data.Dataset):
         image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
         state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
         action = torch.tensor(dataset_dict["actions"][idx0:idx1])
-        # TODO(rcadene): concat the last "next_observations" to "observations"
+        # TODO(rcadene): we have a missing last frame which is the observation when the env is done
+        # it is critical to have this frame for tdmpc to predict a "done observation/state"
         # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
         # next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
         next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])

View File

@@ -35,9 +35,9 @@ def make_policy(cfg):
     if cfg.policy.pretrained_model_path:
         # TODO(rcadene): hack for old pretrained models from fowm
         if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
-            if "offline" in cfg.pretrained_model_path:
+            if "offline" in cfg.policy.pretrained_model_path:
                 policy.step[0] = 25000
-            elif "final" in cfg.pretrained_model_path:
+            elif "final" in cfg.policy.pretrained_model_path:
                 policy.step[0] = 100000
             else:
                 raise NotImplementedError()

View File

@@ -333,94 +333,6 @@ class TDMPCPolicy(nn.Module):
         """Main update function. Corresponds to one iteration of the model learning."""
         start_time = time.time()

-        # num_slices = self.cfg.batch_size
-        # batch_size = self.cfg.horizon * num_slices
-        # if demo_buffer is None:
-        # demo_batch_size = 0
-        # else:
-        # # Update oversampling ratio
-        # demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
-        # demo_num_slices = int(demo_pc_batch * self.batch_size)
-        # demo_batch_size = self.cfg.horizon * demo_num_slices
-        # batch_size -= demo_batch_size
-        # num_slices -= demo_num_slices
-        # replay_buffer._sampler.num_slices = num_slices
-        # demo_buffer._sampler.num_slices = demo_num_slices
-        # assert demo_batch_size % self.cfg.horizon == 0
-        # assert demo_batch_size % demo_num_slices == 0
-        # assert batch_size % self.cfg.horizon == 0
-        # assert batch_size % num_slices == 0
-        # # Sample from interaction dataset
-        # def process_batch(batch, horizon, num_slices):
-        # # trajectory t = 256, horizon h = 5
-        # # (t h) ... -> h t ...
-        # batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
-        # obs = {
-        # "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
-        # "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
-        # }
-        # action = batch["action"].to(self.device, non_blocking=True)
-        # next_obses = {
-        # "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
-        # "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
-        # }
-        # reward = batch["next", "reward"].to(self.device, non_blocking=True)
-        # idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
-        # weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
-        # # TODO(rcadene): rearrange directly in offline dataset
-        # if reward.ndim == 2:
-        # reward = einops.rearrange(reward, "h t -> h t 1")
-        # assert reward.ndim == 3
-        # assert reward.shape == (horizon, num_slices, 1)
-        # # We dont use `batch["next", "done"]` since it only indicates the end of an
-        # # episode, but not the end of the trajectory of an episode.
-        # # Neither does `batch["next", "terminated"]`
-        # done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
-        # mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-        # return obs, action, next_obses, reward, mask, done, idxs, weights
-        # batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
-        # obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
-        # batch, self.cfg.horizon, num_slices
-        # )
-        # Sample from demonstration dataset
-        # if demo_batch_size > 0:
-        # demo_batch = demo_buffer.sample(demo_batch_size)
-        # (
-        # demo_obs,
-        # demo_action,
-        # demo_next_obses,
-        # demo_reward,
-        # demo_mask,
-        # demo_done,
-        # demo_idxs,
-        # demo_weights,
-        # ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
-        # if isinstance(obs, dict):
-        # obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
-        # next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
-        # else:
-        # obs = torch.cat([obs, demo_obs])
-        # next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
-        # action = torch.cat([action, demo_action], dim=1)
-        # reward = torch.cat([reward, demo_reward], dim=1)
-        # mask = torch.cat([mask, demo_mask], dim=1)
-        # done = torch.cat([done, demo_done], dim=1)
-        # idxs = torch.cat([idxs, demo_idxs])
-        # weights = torch.cat([weights, demo_weights])
-
         batch_size = batch["index"].shape[0]

         # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
@@ -534,6 +446,7 @@ class TDMPCPolicy(nn.Module):
         )
         self.optim.step()

+        # TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion
         # if self.cfg.per:
         # # Update priorities
         # priorities = priority_loss.clamp(max=1e4).detach()

View File

@@ -18,7 +18,6 @@ env:
   from_pixels: True
   pixels_only: False
   image_size: [3, 480, 640]
-  action_repeat: 1
   episode_length: 400
   fps: ${fps}

View File

@@ -18,7 +18,6 @@ env:
   from_pixels: True
   pixels_only: False
   image_size: 96
-  action_repeat: 1
   episode_length: 300
   fps: ${fps}

View File

@@ -17,7 +17,6 @@ env:
   from_pixels: True
   pixels_only: False
   image_size: 84
-  # action_repeat: 2  # we can remove if policy has n_action_steps=2
   episode_length: 25
   fps: ${fps}

View File

@@ -36,6 +36,7 @@ policy:
   log_std_max: 2

   # learning
+  batch_size: 256
   max_buffer_size: 10000
   horizon: 5
   reward_coef: 0.5

View File

@@ -32,6 +32,7 @@ import json
 import logging
 import threading
 import time
+from copy import deepcopy
 from datetime import datetime as dt
 from pathlib import Path
@@ -57,14 +58,14 @@ def write_video(video_path, stacked_frames, fps):
 def eval_policy(
     env: gym.vector.VectorEnv,
     policy,
-    save_video: bool = False,
+    max_episodes_rendered: int = 0,
     video_dir: Path = None,
     # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
-    fps: int = 15,
-    return_first_video: bool = False,
     transform: callable = None,
     seed=None,
 ):
+    fps = env.unwrapped.metadata["render_fps"]
     if policy is not None:
         policy.eval()
     device = "cpu" if policy is None else next(policy.parameters()).device
@@ -83,14 +84,11 @@ def eval_policy(
     # needed as I'm currently taking a ceil.
     ep_frames = []

-    def maybe_render_frame(env):
-        if save_video:  # noqa: B023
-            if return_first_video:
-                visu = env.envs[0].render()
-                visu = visu[None, ...]  # add batch dim
-            else:
-                visu = np.stack([env.render() for env in env.envs])
-            ep_frames.append(visu)  # noqa: B023
+    def render_frame(env):
+        # noqa: B023
+        eps_rendered = min(max_episodes_rendered, len(env.envs))
+        visu = np.stack([env.envs[i].render() for i in range(eps_rendered)])
+        ep_frames.append(visu)  # noqa: B023

     for _ in range(num_episodes):
         seeds.append("TODO")
@@ -104,8 +102,14 @@ def eval_policy(
         # reset the environment
         observation, info = env.reset(seed=seed)
-        maybe_render_frame(env)
+        if max_episodes_rendered > 0:
+            render_frame(env)

+        observations = []
+        actions = []
+        # episode
+        # frame_id
+        # timestamp
         rewards = []
         successes = []
         dones = []
@@ -113,8 +117,13 @@ def eval_policy(
         done = torch.tensor([False for _ in env.envs])
         step = 0
         while not done.all():
+            # format from env keys to lerobot keys
+            observation = preprocess_observation(observation)
+            observations.append(deepcopy(observation))
+
             # apply transform to normalize the observations
-            observation = preprocess_observation(observation, transform)
+            for key in observation:
+                observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])

             # send observation to device/gpu
             observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
@@ -128,9 +137,11 @@ def eval_policy(
             # apply the next
             observation, reward, terminated, truncated, info = env.step(action)
-            maybe_render_frame(env)
+            if max_episodes_rendered > 0:
+                render_frame(env)

             # TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
+            action = torch.from_numpy(action)
             reward = torch.from_numpy(reward)
             terminated = torch.from_numpy(terminated)
             truncated = torch.from_numpy(truncated)
@@ -147,12 +158,24 @@ def eval_policy(
                 success = [False for _ in env.envs]
             success = torch.tensor(success)

+            actions.append(action)
             rewards.append(reward)
             dones.append(done)
             successes.append(success)

             step += 1

+        env.close()
+
+        # add the last observation when the env is done
+        observation = preprocess_observation(observation)
+        observations.append(deepcopy(observation))
+
+        new_obses = {}
+        for key in observations[0].keys():  # noqa: SIM118
+            new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
+        observations = new_obses
+        actions = torch.stack(actions, dim=1)
         rewards = torch.stack(rewards, dim=1)
         successes = torch.stack(successes, dim=1)
         dones = torch.stack(dones, dim=1)
@@ -172,29 +195,61 @@ def eval_policy(
         max_rewards.extend(batch_max_reward.tolist())
         all_successes.extend(batch_success.tolist())

-    env.close()
+    # similar logic is implemented in dataset preprocessing
+    ep_dicts = []
+    num_episodes = dones.shape[0]
+    total_frames = 0
+    idx0 = idx1 = 0
+    data_ids_per_episode = {}
+    for ep_id in range(num_episodes):
+        num_frames = done_indices[ep_id].item() + 1
+        # TODO(rcadene): We need to add a missing last frame which is the observation
+        # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
+        ep_dict = {
+            "action": actions[ep_id, :num_frames],
+            "episode": torch.tensor([ep_id] * num_frames),
+            "frame_id": torch.arange(0, num_frames, 1),
+            "timestamp": torch.arange(0, num_frames, 1) / fps,
+            "next.done": dones[ep_id, :num_frames],
+            "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
+        }
+        for key in observations:
+            ep_dict[key] = observations[key][ep_id, :num_frames]
+        ep_dicts.append(ep_dict)
+
+        total_frames += num_frames
+        idx1 += num_frames
+        data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
+        idx0 = idx1
+
+    # similar logic is implemented in dataset preprocessing
+    data_dict = {}
+    keys = ep_dicts[0].keys()
+    for key in keys:
+        data_dict[key] = torch.cat([x[key] for x in ep_dicts])
+    data_dict["index"] = torch.arange(0, total_frames, 1)

-    if save_video or return_first_video:
+    if max_episodes_rendered > 0:
         batch_stacked_frames = np.stack(ep_frames, 1)  # (b, t, *)

-        if save_video:
-            for stacked_frames, done_index in zip(
-                batch_stacked_frames, done_indices.flatten().tolist(), strict=False
-            ):
-                if episode_counter >= num_episodes:
-                    continue
-                video_dir.mkdir(parents=True, exist_ok=True)
-                video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
-                thread = threading.Thread(
-                    target=write_video,
-                    args=(str(video_path), stacked_frames[:done_index], fps),
-                )
-                thread.start()
-                threads.append(thread)
-                episode_counter += 1
-
-        if return_first_video:
-            first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
+        for stacked_frames, done_index in zip(
+            batch_stacked_frames, done_indices.flatten().tolist(), strict=False
+        ):
+            if episode_counter >= num_episodes:
+                continue
+            video_dir.mkdir(parents=True, exist_ok=True)
+            video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
+            thread = threading.Thread(
+                target=write_video,
+                args=(str(video_path), stacked_frames[:done_index], fps),
+            )
+            thread.start()
+            threads.append(thread)
+            episode_counter += 1
+
+        videos = batch_stacked_frames.transpose(0, 3, 1, 2)

     for thread in threads:
         thread.join()
@@ -225,9 +280,13 @@ def eval_policy(
             "eval_s": time.time() - start,
             "eval_ep_s": (time.time() - start) / num_episodes,
         },
+        "episodes": {
+            "data_dict": data_dict,
+            "data_ids_per_episode": data_ids_per_episode,
+        },
     }
-    if return_first_video:
-        return info, first_video
+    if max_episodes_rendered > 0:
+        info["videos"] = videos

     return info
@@ -259,7 +318,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     info = eval_policy(
         env,
         policy=policy,
-        save_video=True,
+        max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
         fps=cfg.env.fps,
         # TODO(rcadene): what should we do with the transform?
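
For reference (not part of the commit), the new "episodes" entry is simply the per-episode dicts flattened into one data_dict plus a map from episode id to the frame indices it owns. A self-contained sketch of that packing with two fake episodes of 3 and 2 frames:

import torch

ep_dicts = [
    {"action": torch.zeros(3, 2), "next.reward": torch.tensor([0.0, 0.0, 1.0])},
    {"action": torch.ones(2, 2), "next.reward": torch.tensor([0.0, 1.0])},
]
data_ids_per_episode, idx0 = {}, 0
for ep_id, ep_dict in enumerate(ep_dicts):
    num_frames = len(ep_dict["next.reward"])
    data_ids_per_episode[ep_id] = torch.arange(idx0, idx0 + num_frames, 1)
    idx0 += num_frames

data_dict = {key: torch.cat([x[key] for x in ep_dicts]) for key in ep_dicts[0]}
data_dict["index"] = torch.arange(0, idx0, 1)

assert data_dict["action"].shape == (5, 2)
assert data_ids_per_episode[1].tolist() == [3, 4]  # frames of episode 1 in the flat tensors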

View File

@@ -1,8 +1,8 @@
 import logging
+from copy import deepcopy
 from pathlib import Path

 import hydra
-import numpy as np
 import torch

 from lerobot.common.datasets.factory import make_dataset
@@ -110,6 +110,64 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
     logger.log_dict(info, step, mode="eval")


+def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
+    """
+    Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average).
+
+    Parameters:
+    - n_off (int): Number of offline samples, each with a sampling weight of 1.
+    - n_on (int): Number of online samples.
+    - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
+
+    The total weight of offline samples is n_off * 1.0.
+    The total weight of online samples is n_on * w.
+    The total combined weight of all samples is n_off + n_on * w.
+    The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
+    We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
+    The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
+    """
+    assert 0.0 <= pc_on <= 1.0
+    return -(n_off * pc_on) / (n_on * (pc_on - 1))
+
+
+def add_episodes_inplace(episodes, online_dataset, concat_dataset, sampler, pc_online_samples):
+    data_dict = episodes["data_dict"]
+    data_ids_per_episode = episodes["data_ids_per_episode"]
+
+    if len(online_dataset) == 0:
+        # initialize online dataset
+        online_dataset.data_dict = data_dict
+        online_dataset.data_ids_per_episode = data_ids_per_episode
+    else:
+        # find episode index and data frame indices according to previous episode in online_dataset
+        start_episode = max(online_dataset.data_ids_per_episode.keys()) + 1
+        start_index = online_dataset.data_dict["index"][-1].item() + 1
+        data_dict["episode"] += start_episode
+        data_dict["index"] += start_index
+
+        # extend online dataset
+        for key in data_dict:
+            # TODO(rcadene): avoid reallocating memory at every step by preallocating memory or changing our data structure
+            online_dataset.data_dict[key] = torch.cat([online_dataset.data_dict[key], data_dict[key]])
+        for ep_id in data_ids_per_episode:
+            online_dataset.data_ids_per_episode[ep_id + start_episode] = (
+                data_ids_per_episode[ep_id] + start_index
+            )
+
+    # update the concatenated dataset length used during sampling
+    concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
+
+    # update the sampling weights for each frame so that online frames get sampled a certain percentage of times
+    len_online = len(online_dataset)
+    len_offline = len(concat_dataset) - len_online
+    weight_offline = 1.0
+    weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
+    sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
+
+    # update the total number of samples used during sampling
+    sampler.num_samples = len(concat_dataset)
+
+
 def train(cfg: dict, out_dir=None, job_name=None):
     if out_dir is None:
         raise NotImplementedError()
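
A quick numerical check of the weight formula in the docstring above (illustrative, not part of the commit): with 1000 offline frames, 100 online frames and pc_on = 0.5, the weight comes out to 10, and the online share of the total sampling weight is indeed 0.5.

def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
    assert 0.0 <= pc_on <= 1.0
    return -(n_off * pc_on) / (n_on * (pc_on - 1))

w = calculate_online_sample_weight(1000, 100, 0.5)
assert w == 10.0
# online fraction of the combined weight
assert (100 * w) / (1000 + 100 * w) == 0.5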
@@ -128,26 +186,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     set_global_seed(cfg.seed)

     logging.info("make_dataset")
-    dataset = make_dataset(cfg)
-
-    # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
-    # if cfg.policy.balanced_sampling:
-    # logging.info("make online_buffer")
-    # num_traj_per_batch = cfg.policy.batch_size
-    # online_sampler = PrioritizedSliceSampler(
-    # max_capacity=100_000,
-    # alpha=cfg.policy.per_alpha,
-    # beta=cfg.policy.per_beta,
-    # num_slices=num_traj_per_batch,
-    # strict_length=True,
-    # )
-    # online_buffer = TensorDictReplayBuffer(
-    # storage=LazyMemmapStorage(100_000),
-    # sampler=online_sampler,
-    # transform=dataset.transform,
-    # )
+    offline_dataset = make_dataset(cfg)

     logging.info("make_env")
     env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
@@ -165,9 +204,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info(f"{cfg.env.task=}")
     logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
     logging.info(f"{cfg.online_steps=}")
-    logging.info(f"{cfg.env.action_repeat=}")
-    logging.info(f"{dataset.num_samples=} ({format_big_number(dataset.num_samples)})")
-    logging.info(f"{dataset.num_episodes=}")
+    logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
+    logging.info(f"{offline_dataset.num_episodes=}")
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
@@ -175,18 +213,17 @@ def train(cfg: dict, out_dir=None, job_name=None):
     def _maybe_eval_and_maybe_save(step):
         if step % cfg.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
-            eval_info, first_video = eval_policy(
+            eval_info = eval_policy(
                 env,
                 policy,
-                return_first_video=True,
                 video_dir=Path(out_dir) / "eval",
-                save_video=True,
-                transform=dataset.transform,
+                max_episodes_rendered=4,
+                transform=offline_dataset.transform,
                 seed=cfg.seed,
             )
-            log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
+            log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
             if cfg.wandb.enable:
-                logger.log_video(first_video, step, mode="eval")
+                logger.log_video(eval_info["videos"][0], step, mode="eval")
             logging.info("Resume training")

         if cfg.save_model and step % cfg.save_freq == 0:
@@ -194,11 +231,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
             logger.save_model(policy, identifier=step)
             logging.info("Resume training")

-    step = 0  # number of policy update (forward + backward + optim)
-    is_offline = True
+    # create dataloader for offline training
     dataloader = torch.utils.data.DataLoader(
-        dataset,
+        offline_dataset,
         num_workers=4,
         batch_size=cfg.policy.batch_size,
         shuffle=True,
@@ -206,6 +241,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
         drop_last=True,
     )
     dl_iter = cycle(dataloader)

+    step = 0  # number of policy update (forward + backward + optim)
+    is_offline = True
     for offline_step in range(cfg.offline_steps):
         if offline_step == 0:
             logging.info("Start offline training on a fixed dataset")
@@ -219,7 +257,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
-            log_train_info(logger, train_info, step, cfg, dataset, is_offline)
+            log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)

         # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
         # step + 1.
@@ -227,61 +265,60 @@ def train(cfg: dict, out_dir=None, job_name=None):
         step += 1

-    raise NotImplementedError()
+    # create an env dedicated to online episodes collection from policy rollout
+    rollout_env = make_env(cfg, num_parallel_envs=1)
+
+    # create an empty online dataset similar to offline dataset
+    online_dataset = deepcopy(offline_dataset)
+    online_dataset.data_dict = {}
+    online_dataset.data_ids_per_episode = {}
+
+    # create dataloader for online training
+    concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
+    weights = [1.0] * len(concat_dataset)
+    sampler = torch.utils.data.WeightedRandomSampler(
+        weights, num_samples=len(concat_dataset), replacement=True
+    )
+    dataloader = torch.utils.data.DataLoader(
+        concat_dataset,
+        num_workers=4,
+        batch_size=cfg.policy.batch_size,
+        sampler=sampler,
+        pin_memory=cfg.device != "cpu",
+        drop_last=True,
+    )
+    dl_iter = cycle(dataloader)

-    demo_buffer = dataset if cfg.policy.balanced_sampling else None
     online_step = 0
     is_offline = False
     for env_step in range(cfg.online_steps):
         if env_step == 0:
             logging.info("Start online training by interacting with environment")
-        # TODO: add configurable number of rollout? (default=1)
+
         with torch.no_grad():
-            rollout = env.rollout(
-                max_steps=cfg.env.episode_length,
-                policy=policy,
-                auto_cast_to_device=True,
+            eval_info = eval_policy(
+                rollout_env,
+                policy,
+                transform=offline_dataset.transform,
+                seed=cfg.seed,
             )

-        assert (
-            len(rollout.batch_size) == 2
-        ), "2 dimensions expected: number of env in parallel x max number of steps during rollout"
-
-        num_parallel_env = rollout.batch_size[0]
-        if num_parallel_env != 1:
-            # TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests
-            raise NotImplementedError()
-
-        num_max_steps = rollout.batch_size[1]
-        assert num_max_steps <= cfg.env.episode_length
-
-        # reshape to have a list of steps to insert into online_buffer
-        rollout = rollout.reshape(num_parallel_env * num_max_steps)
-
-        # set same episode index for all time steps contained in this rollout
-        rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
-
-        # online_buffer.extend(rollout)
-
-        ep_sum_reward = rollout["next", "reward"].sum()
-        ep_max_reward = rollout["next", "reward"].max()
-        ep_success = rollout["next", "success"].any()
-        rollout_info = {
-            "avg_sum_reward": np.nanmean(ep_sum_reward),
-            "avg_max_reward": np.nanmean(ep_max_reward),
-            "pc_success": np.nanmean(ep_success) * 100,
-            "env_step": env_step,
-            "ep_length": len(rollout),
-        }
+        online_pc_sampling = cfg.get("demo_schedule", 0.5)
+        add_episodes_inplace(
+            eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling
+        )

         for _ in range(cfg.policy.utd):
-            train_info = policy.update(
-                # online_buffer,
-                step,
-                demo_buffer=demo_buffer,
-            )
+            policy.train()
+            batch = next(dl_iter)
+
+            for key in batch:
+                batch[key] = batch[key].to(cfg.device, non_blocking=True)
+
+            train_info = policy(batch, step)
+
             if step % cfg.log_freq == 0:
-                train_info.update(rollout_info)
-                log_train_info(logger, train_info, step, cfg, dataset, is_offline)
+                log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)

             # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
             # in step + 1.
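
To illustrate what add_episodes_inplace does to the sampler (a sketch with assumed toy sizes, not part of the commit): after online frames are appended, sampler.weights and sampler.num_samples are overwritten in place so that roughly pc_online_samples of each batch comes from online frames.

import torch
from torch.utils.data import WeightedRandomSampler

len_offline, len_online, pc_online = 1000, 100, 0.5
weight_online = -(len_offline * pc_online) / (len_online * (pc_online - 1))  # 10.0

# sampler initially covers only the offline frames, all with weight 1.0
sampler = WeightedRandomSampler([1.0] * len_offline, num_samples=len_offline, replacement=True)

# mirror the in-place update performed by add_episodes_inplace
sampler.weights = torch.tensor([1.0] * len_offline + [weight_online] * len_online)
sampler.num_samples = len_offline + len_online

draws = torch.tensor(list(sampler))
print((draws >= len_offline).float().mean())  # ~0.5: about half the sampled indices are online frames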