online training works (loss goes down), remove repeat_action, eval_policy outputs episodes data, eval_policy uses max_episodes_rendered
parent 19e7661b8d
commit 06573d7f67
@@ -105,7 +105,7 @@ class AlohaDataset(torch.utils.data.Dataset):

     @property
     def num_samples(self) -> int:
-        return len(self.data_dict["index"])
+        return len(self.data_dict["index"]) if "index" in self.data_dict else 0

     @property
     def num_episodes(self) -> int:
@@ -119,7 +119,7 @@ class PushtDataset(torch.utils.data.Dataset):

     @property
     def num_samples(self) -> int:
-        return len(self.data_dict["index"])
+        return len(self.data_dict["index"]) if "index" in self.data_dict else 0

     @property
     def num_episodes(self) -> int:
@@ -60,7 +60,7 @@ class XarmDataset(torch.utils.data.Dataset):

     @property
     def num_samples(self) -> int:
-        return len(self.data_dict["index"])
+        return len(self.data_dict["index"]) if "index" in self.data_dict else 0

     @property
     def num_episodes(self) -> int:
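The same guard is applied to all three dataset classes so that a freshly created, still-empty dataset (like the online dataset built in train.py further below, whose data_dict starts as {}) reports zero samples instead of raising a KeyError. A condensed, illustrative sketch of the pattern; only data_dict and num_samples come from the hunks above, the rest is a stand-in:

    import torch

    class SketchDataset(torch.utils.data.Dataset):
        """Illustrative stand-in for AlohaDataset / PushtDataset / XarmDataset."""

        def __init__(self):
            # data_dict starts empty and is only filled once episodes are added.
            self.data_dict = {}

        @property
        def num_samples(self) -> int:
            # Same guard as in the hunks above: an unpopulated dataset reports 0 samples.
            return len(self.data_dict["index"]) if "index" in self.data_dict else 0

        def __len__(self) -> int:
            return self.num_samples

    empty = SketchDataset()
    print(empty.num_samples)  # 0
    empty.data_dict["index"] = torch.arange(0, 10, 1)
    print(empty.num_samples)  # 10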
@@ -126,7 +126,8 @@ class XarmDataset(torch.utils.data.Dataset):
         image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
         state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
         action = torch.tensor(dataset_dict["actions"][idx0:idx1])
-        # TODO(rcadene): concat the last "next_observations" to "observations"
+        # TODO(rcadene): we have a missing last frame which is the observation when the env is done
+        # it is critical to have this frame for tdmpc to predict a "done observation/state"
         # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
         # next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
         next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
@@ -35,9 +35,9 @@ def make_policy(cfg):
     if cfg.policy.pretrained_model_path:
         # TODO(rcadene): hack for old pretrained models from fowm
         if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
-            if "offline" in cfg.pretrained_model_path:
+            if "offline" in cfg.policy.pretrained_model_path:
                 policy.step[0] = 25000
-            elif "final" in cfg.pretrained_model_path:
+            elif "final" in cfg.policy.pretrained_model_path:
                 policy.step[0] = 100000
             else:
                 raise NotImplementedError()
@@ -333,94 +333,6 @@ class TDMPCPolicy(nn.Module):
         """Main update function. Corresponds to one iteration of the model learning."""
         start_time = time.time()

-        # num_slices = self.cfg.batch_size
-        # batch_size = self.cfg.horizon * num_slices
-
-        # if demo_buffer is None:
-        # demo_batch_size = 0
-        # else:
-        # # Update oversampling ratio
-        # demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
-        # demo_num_slices = int(demo_pc_batch * self.batch_size)
-        # demo_batch_size = self.cfg.horizon * demo_num_slices
-        # batch_size -= demo_batch_size
-        # num_slices -= demo_num_slices
-        # replay_buffer._sampler.num_slices = num_slices
-        # demo_buffer._sampler.num_slices = demo_num_slices
-
-        # assert demo_batch_size % self.cfg.horizon == 0
-        # assert demo_batch_size % demo_num_slices == 0
-
-        # assert batch_size % self.cfg.horizon == 0
-        # assert batch_size % num_slices == 0
-
-        # # Sample from interaction dataset
-
-        # def process_batch(batch, horizon, num_slices):
-        # # trajectory t = 256, horizon h = 5
-        # # (t h) ... -> h t ...
-        # batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
-
-        # obs = {
-        # "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
-        # "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
-        # }
-        # action = batch["action"].to(self.device, non_blocking=True)
-        # next_obses = {
-        # "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
-        # "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
-        # }
-        # reward = batch["next", "reward"].to(self.device, non_blocking=True)
-
-        # idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
-        # weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
-
-        # # TODO(rcadene): rearrange directly in offline dataset
-        # if reward.ndim == 2:
-        # reward = einops.rearrange(reward, "h t -> h t 1")
-
-        # assert reward.ndim == 3
-        # assert reward.shape == (horizon, num_slices, 1)
-        # # We dont use `batch["next", "done"]` since it only indicates the end of an
-        # # episode, but not the end of the trajectory of an episode.
-        # # Neither does `batch["next", "terminated"]`
-        # done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
-        # mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-        # return obs, action, next_obses, reward, mask, done, idxs, weights
-
-        # batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
-
-        # obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
-        # batch, self.cfg.horizon, num_slices
-        # )
-
-        # Sample from demonstration dataset
-        # if demo_batch_size > 0:
-        # demo_batch = demo_buffer.sample(demo_batch_size)
-        # (
-        # demo_obs,
-        # demo_action,
-        # demo_next_obses,
-        # demo_reward,
-        # demo_mask,
-        # demo_done,
-        # demo_idxs,
-        # demo_weights,
-        # ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
-
-        # if isinstance(obs, dict):
-        # obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
-        # next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
-        # else:
-        # obs = torch.cat([obs, demo_obs])
-        # next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
-        # action = torch.cat([action, demo_action], dim=1)
-        # reward = torch.cat([reward, demo_reward], dim=1)
-        # mask = torch.cat([mask, demo_mask], dim=1)
-        # done = torch.cat([done, demo_done], dim=1)
-        # idxs = torch.cat([idxs, demo_idxs])
-        # weights = torch.cat([weights, demo_weights])
-
         batch_size = batch["index"].shape[0]

         # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
@@ -534,6 +446,7 @@ class TDMPCPolicy(nn.Module):
         )
         self.optim.step()

+        # TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion
         # if self.cfg.per:
         # # Update priorities
         # priorities = priority_loss.clamp(max=1e4).detach()
@@ -18,7 +18,6 @@ env:
   from_pixels: True
   pixels_only: False
   image_size: [3, 480, 640]
-  action_repeat: 1
   episode_length: 400
   fps: ${fps}

@@ -18,7 +18,6 @@ env:
   from_pixels: True
   pixels_only: False
   image_size: 96
-  action_repeat: 1
   episode_length: 300
   fps: ${fps}

@@ -17,7 +17,6 @@ env:
   from_pixels: True
   pixels_only: False
   image_size: 84
-  # action_repeat: 2  # we can remove if policy has n_action_steps=2
   episode_length: 25
   fps: ${fps}

@@ -36,6 +36,7 @@ policy:
   log_std_max: 2

   # learning
+  batch_size: 256
   max_buffer_size: 10000
   horizon: 5
   reward_coef: 0.5
@@ -32,6 +32,7 @@ import json
 import logging
 import threading
 import time
+from copy import deepcopy
 from datetime import datetime as dt
 from pathlib import Path

@@ -57,14 +58,14 @@ def write_video(video_path, stacked_frames, fps):
 def eval_policy(
     env: gym.vector.VectorEnv,
     policy,
-    save_video: bool = False,
+    max_episodes_rendered: int = 0,
     video_dir: Path = None,
     # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
-    fps: int = 15,
-    return_first_video: bool = False,
     transform: callable = None,
     seed=None,
 ):
+    fps = env.unwrapped.metadata["render_fps"]
+
     if policy is not None:
         policy.eval()
     device = "cpu" if policy is None else next(policy.parameters()).device
@@ -83,14 +84,11 @@ def eval_policy(
     # needed as I'm currently taking a ceil.
     ep_frames = []

-    def maybe_render_frame(env):
-        if save_video:  # noqa: B023
-            if return_first_video:
-                visu = env.envs[0].render()
-                visu = visu[None, ...]  # add batch dim
-            else:
-                visu = np.stack([env.render() for env in env.envs])
-            ep_frames.append(visu)  # noqa: B023
+    def render_frame(env):
+        # noqa: B023
+        eps_rendered = min(max_episodes_rendered, len(env.envs))
+        visu = np.stack([env.envs[i].render() for i in range(eps_rendered)])
+        ep_frames.append(visu)  # noqa: B023

     for _ in range(num_episodes):
         seeds.append("TODO")
@@ -104,8 +102,14 @@ def eval_policy(

     # reset the environment
     observation, info = env.reset(seed=seed)
-    maybe_render_frame(env)
+    if max_episodes_rendered > 0:
+        render_frame(env)

+    observations = []
+    actions = []
+    # episode
+    # frame_id
+    # timestamp
     rewards = []
     successes = []
     dones = []
@@ -113,8 +117,13 @@ def eval_policy(
     done = torch.tensor([False for _ in env.envs])
     step = 0
     while not done.all():
+        # format from env keys to lerobot keys
+        observation = preprocess_observation(observation)
+        observations.append(deepcopy(observation))
+
         # apply transform to normalize the observations
-        observation = preprocess_observation(observation, transform)
+        for key in observation:
+            observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])

         # send observation to device/gpu
         observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
@@ -128,9 +137,11 @@ def eval_policy(

         # apply the next
         observation, reward, terminated, truncated, info = env.step(action)
-        maybe_render_frame(env)
+        if max_episodes_rendered > 0:
+            render_frame(env)

         # TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
+        action = torch.from_numpy(action)
         reward = torch.from_numpy(reward)
         terminated = torch.from_numpy(terminated)
         truncated = torch.from_numpy(truncated)
@@ -147,12 +158,24 @@ def eval_policy(
             success = [False for _ in env.envs]
         success = torch.tensor(success)

+        actions.append(action)
         rewards.append(reward)
         dones.append(done)
         successes.append(success)

         step += 1

+    env.close()
+
+    # add the last observation when the env is done
+    observation = preprocess_observation(observation)
+    observations.append(deepcopy(observation))
+
+    new_obses = {}
+    for key in observations[0].keys():  # noqa: SIM118
+        new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
+    observations = new_obses
+    actions = torch.stack(actions, dim=1)
     rewards = torch.stack(rewards, dim=1)
     successes = torch.stack(successes, dim=1)
     dones = torch.stack(dones, dim=1)
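For reference on the layout produced by the stacking above (a tiny self-contained illustration, not code from the commit): each per-step tensor carries the vectorized-env dimension first, so stacking the per-step list along dim=1 yields tensors indexed as (env, step, ...).

    import torch

    # Five per-step tensors, each shaped (num_envs,), as collected in the rollout loop above.
    num_steps, num_envs = 5, 2
    rewards = [torch.full((num_envs,), float(t)) for t in range(num_steps)]

    stacked = torch.stack(rewards, dim=1)
    print(stacked.shape)  # torch.Size([2, 5]) -> (num_envs, num_steps)
    print(stacked[0])     # tensor([0., 1., 2., 3., 4.]): the reward trajectory of env 0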
@@ -172,29 +195,61 @@ def eval_policy(
         max_rewards.extend(batch_max_reward.tolist())
         all_successes.extend(batch_success.tolist())

-    env.close()
+    # similar logic is implemented in dataset preprocessing
+    ep_dicts = []
+    num_episodes = dones.shape[0]
+    total_frames = 0
+    idx0 = idx1 = 0
+    data_ids_per_episode = {}
+    for ep_id in range(num_episodes):
+        num_frames = done_indices[ep_id].item() + 1
+        # TODO(rcadene): We need to add a missing last frame which is the observation
+        # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
+        ep_dict = {
+            "action": actions[ep_id, :num_frames],
+            "episode": torch.tensor([ep_id] * num_frames),
+            "frame_id": torch.arange(0, num_frames, 1),
+            "timestamp": torch.arange(0, num_frames, 1) / fps,
+            "next.done": dones[ep_id, :num_frames],
+            "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
+        }
+        for key in observations:
+            ep_dict[key] = observations[key][ep_id, :num_frames]
+        ep_dicts.append(ep_dict)

-    if save_video or return_first_video:
+        total_frames += num_frames
+        idx1 += num_frames
+
+        data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
+
+        idx0 = idx1
+
+    # similar logic is implemented in dataset preprocessing
+    data_dict = {}
+    keys = ep_dicts[0].keys()
+    for key in keys:
+        data_dict[key] = torch.cat([x[key] for x in ep_dicts])
+    data_dict["index"] = torch.arange(0, total_frames, 1)
+
+    if max_episodes_rendered > 0:
         batch_stacked_frames = np.stack(ep_frames, 1)  # (b, t, *)

-        if save_video:
-            for stacked_frames, done_index in zip(
-                batch_stacked_frames, done_indices.flatten().tolist(), strict=False
-            ):
-                if episode_counter >= num_episodes:
-                    continue
-                video_dir.mkdir(parents=True, exist_ok=True)
-                video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
-                thread = threading.Thread(
-                    target=write_video,
-                    args=(str(video_path), stacked_frames[:done_index], fps),
-                )
-                thread.start()
-                threads.append(thread)
-                episode_counter += 1
-
-        if return_first_video:
-            first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
+        for stacked_frames, done_index in zip(
+            batch_stacked_frames, done_indices.flatten().tolist(), strict=False
+        ):
+            if episode_counter >= num_episodes:
+                continue
+            video_dir.mkdir(parents=True, exist_ok=True)
+            video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
+            thread = threading.Thread(
+                target=write_video,
+                args=(str(video_path), stacked_frames[:done_index], fps),
+            )
+            thread.start()
+            threads.append(thread)
+            episode_counter += 1
+
+        videos = batch_stacked_frames.transpose(0, 3, 1, 2)

     for thread in threads:
         thread.join()
@@ -225,9 +280,13 @@ def eval_policy(
             "eval_s": time.time() - start,
             "eval_ep_s": (time.time() - start) / num_episodes,
         },
+        "episodes": {
+            "data_dict": data_dict,
+            "data_ids_per_episode": data_ids_per_episode,
+        },
     }
-    if return_first_video:
-        return info, first_video
+    if max_episodes_rendered > 0:
+        info["videos"] = videos
     return info

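Since eval_policy no longer returns an (info, first_video) tuple, callers read everything from the single dict built above. A hedged usage sketch, not code from the commit: it assumes eval_policy from the patched script plus env, policy, transform and out_dir are already in scope.

    from pathlib import Path

    # Assumed in scope: eval_policy (patched version above), env, policy, transform, out_dir.
    info = eval_policy(
        env,
        policy=policy,
        max_episodes_rendered=10,  # replaces save_video=True / return_first_video=True
        video_dir=Path(out_dir) / "eval",
        transform=transform,
        seed=1000,
    )

    episodes = info["episodes"]
    data_dict = episodes["data_dict"]  # flat tensors: "action", "episode", "frame_id", "timestamp", ...
    ep0_frame_ids = episodes["data_ids_per_episode"][0]  # indices into data_dict for episode 0

    if "videos" in info:  # only set when max_episodes_rendered > 0
        first_video = info["videos"][0]  # rendered frames of the first episode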
@@ -259,7 +318,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     info = eval_policy(
         env,
         policy=policy,
-        save_video=True,
+        max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
         fps=cfg.env.fps,
         # TODO(rcadene): what should we do with the transform?
@@ -1,8 +1,8 @@
 import logging
+from copy import deepcopy
 from pathlib import Path

 import hydra
-import numpy as np
 import torch

 from lerobot.common.datasets.factory import make_dataset
@@ -110,6 +110,64 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
     logger.log_dict(info, step, mode="eval")


+def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
+    """
+    Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average).
+
+    Parameters:
+    - n_off (int): Number of offline samples, each with a sampling weight of 1.
+    - n_on (int): Number of online samples.
+    - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
+
+    The total weight of offline samples is n_off * 1.0.
+    The total weight of offline samples is n_on * w.
+    The total combined weight of all samples is n_off + n_on * w.
+    The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
+    We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
+    The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
+    """
+    assert 0.0 <= pc_on <= 1.0
+    return -(n_off * pc_on) / (n_on * (pc_on - 1))
+
+
+def add_episodes_inplace(episodes, online_dataset, concat_dataset, sampler, pc_online_samples):
+    data_dict = episodes["data_dict"]
+    data_ids_per_episode = episodes["data_ids_per_episode"]
+
+    if len(online_dataset) == 0:
+        # initialize online dataset
+        online_dataset.data_dict = data_dict
+        online_dataset.data_ids_per_episode = data_ids_per_episode
+    else:
+        # find episode index and data frame indices according to previous episode in online_dataset
+        start_episode = max(online_dataset.data_ids_per_episode.keys()) + 1
+        start_index = online_dataset.data_dict["index"][-1].item() + 1
+        data_dict["episode"] += start_episode
+        data_dict["index"] += start_index
+
+        # extend online dataset
+        for key in data_dict:
+            # TODO(rcadene): avoid reallocating memory at every step by preallocating memory or changing our data structure
+            online_dataset.data_dict[key] = torch.cat([online_dataset.data_dict[key], data_dict[key]])
+        for ep_id in data_ids_per_episode:
+            online_dataset.data_ids_per_episode[ep_id + start_episode] = (
+                data_ids_per_episode[ep_id] + start_index
+            )
+
+    # update the concatenated dataset length used during sampling
+    concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
+
+    # update the sampling weights for each frame so that online frames get sampled a certain percentage of times
+    len_online = len(online_dataset)
+    len_offline = len(concat_dataset) - len_online
+    weight_offline = 1.0
+    weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
+    sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
+
+    # update the total number of samples used during sampling
+    sampler.num_samples = len(concat_dataset)
+
+
 def train(cfg: dict, out_dir=None, job_name=None):
     if out_dir is None:
         raise NotImplementedError()
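A quick numeric check of the weight formula above (the numbers are illustrative, not from the commit): with 1000 offline frames, 100 online frames and pc_on = 0.5, each online frame must carry weight 10 so that the total online weight matches the total offline weight.

    # Copied from the hunk above so the check is self-contained.
    def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
        assert 0.0 <= pc_on <= 1.0
        return -(n_off * pc_on) / (n_on * (pc_on - 1))

    w = calculate_online_sample_weight(n_off=1000, n_on=100, pc_on=0.5)
    print(w)  # 10.0: total online weight 100 * 10 equals total offline weight 1000 * 1.0
    online_fraction = (100 * w) / (1000 * 1.0 + 100 * w)
    assert abs(online_fraction - 0.5) < 1e-9  # matches the requested pc_on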
@@ -128,26 +186,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     set_global_seed(cfg.seed)

     logging.info("make_dataset")
-    dataset = make_dataset(cfg)
-
-    # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
-    # if cfg.policy.balanced_sampling:
-    # logging.info("make online_buffer")
-    # num_traj_per_batch = cfg.policy.batch_size
-
-    # online_sampler = PrioritizedSliceSampler(
-    # max_capacity=100_000,
-    # alpha=cfg.policy.per_alpha,
-    # beta=cfg.policy.per_beta,
-    # num_slices=num_traj_per_batch,
-    # strict_length=True,
-    # )
-
-    # online_buffer = TensorDictReplayBuffer(
-    # storage=LazyMemmapStorage(100_000),
-    # sampler=online_sampler,
-    # transform=dataset.transform,
-    # )
+    offline_dataset = make_dataset(cfg)

     logging.info("make_env")
     env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
@@ -165,9 +204,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info(f"{cfg.env.task=}")
     logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
     logging.info(f"{cfg.online_steps=}")
-    logging.info(f"{cfg.env.action_repeat=}")
-    logging.info(f"{dataset.num_samples=} ({format_big_number(dataset.num_samples)})")
-    logging.info(f"{dataset.num_episodes=}")
+    logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
+    logging.info(f"{offline_dataset.num_episodes=}")
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

@@ -175,18 +213,17 @@ def train(cfg: dict, out_dir=None, job_name=None):
     def _maybe_eval_and_maybe_save(step):
         if step % cfg.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
-            eval_info, first_video = eval_policy(
+            eval_info = eval_policy(
                 env,
                 policy,
-                return_first_video=True,
                 video_dir=Path(out_dir) / "eval",
-                save_video=True,
-                transform=dataset.transform,
+                max_episodes_rendered=4,
+                transform=offline_dataset.transform,
                 seed=cfg.seed,
             )
-            log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
+            log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
             if cfg.wandb.enable:
-                logger.log_video(first_video, step, mode="eval")
+                logger.log_video(eval_info["videos"][0], step, mode="eval")
             logging.info("Resume training")

         if cfg.save_model and step % cfg.save_freq == 0:
@@ -194,11 +231,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
             logger.save_model(policy, identifier=step)
             logging.info("Resume training")

-    step = 0  # number of policy update (forward + backward + optim)
-
-    is_offline = True
+    # create dataloader for offline training
     dataloader = torch.utils.data.DataLoader(
-        dataset,
+        offline_dataset,
         num_workers=4,
         batch_size=cfg.policy.batch_size,
         shuffle=True,
@@ -206,6 +241,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
         drop_last=True,
     )
     dl_iter = cycle(dataloader)
+
+    step = 0  # number of policy update (forward + backward + optim)
+    is_offline = True
     for offline_step in range(cfg.offline_steps):
         if offline_step == 0:
             logging.info("Start offline training on a fixed dataset")
@@ -219,7 +257,7 @@ def train(cfg: dict, out_dir=None, job_name=None):

         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
-            log_train_info(logger, train_info, step, cfg, dataset, is_offline)
+            log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)

         # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
         # step + 1.
@@ -227,61 +265,60 @@ def train(cfg: dict, out_dir=None, job_name=None):

         step += 1

-    raise NotImplementedError()
+    # create an env dedicated to online episodes collection from policy rollout
+    rollout_env = make_env(cfg, num_parallel_envs=1)
+
+    # create an empty online dataset similar to offline dataset
+    online_dataset = deepcopy(offline_dataset)
+    online_dataset.data_dict = {}
+    online_dataset.data_ids_per_episode = {}
+
+    # create dataloader for online training
+    concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
+    weights = [1.0] * len(concat_dataset)
+    sampler = torch.utils.data.WeightedRandomSampler(
+        weights, num_samples=len(concat_dataset), replacement=True
+    )
+    dataloader = torch.utils.data.DataLoader(
+        concat_dataset,
+        num_workers=4,
+        batch_size=cfg.policy.batch_size,
+        sampler=sampler,
+        pin_memory=cfg.device != "cpu",
+        drop_last=True,
+    )
+    dl_iter = cycle(dataloader)

-    demo_buffer = dataset if cfg.policy.balanced_sampling else None
     online_step = 0
     is_offline = False
     for env_step in range(cfg.online_steps):
         if env_step == 0:
             logging.info("Start online training by interacting with environment")
-        # TODO: add configurable number of rollout? (default=1)
+
         with torch.no_grad():
-            rollout = env.rollout(
-                max_steps=cfg.env.episode_length,
-                policy=policy,
-                auto_cast_to_device=True,
+            eval_info = eval_policy(
+                rollout_env,
+                policy,
+                transform=offline_dataset.transform,
+                seed=cfg.seed,
             )

-        assert (
-            len(rollout.batch_size) == 2
-        ), "2 dimensions expected: number of env in parallel x max number of steps during rollout"
-        num_parallel_env = rollout.batch_size[0]
-        if num_parallel_env != 1:
-            # TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests
-            raise NotImplementedError()
-
-        num_max_steps = rollout.batch_size[1]
-        assert num_max_steps <= cfg.env.episode_length
-
-        # reshape to have a list of steps to insert into online_buffer
-        rollout = rollout.reshape(num_parallel_env * num_max_steps)
-
-        # set same episode index for all time steps contained in this rollout
-        rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
-        # online_buffer.extend(rollout)
-
-        ep_sum_reward = rollout["next", "reward"].sum()
-        ep_max_reward = rollout["next", "reward"].max()
-        ep_success = rollout["next", "success"].any()
-        rollout_info = {
-            "avg_sum_reward": np.nanmean(ep_sum_reward),
-            "avg_max_reward": np.nanmean(ep_max_reward),
-            "pc_success": np.nanmean(ep_success) * 100,
-            "env_step": env_step,
-            "ep_length": len(rollout),
-        }
+        online_pc_sampling = cfg.get("demo_schedule", 0.5)
+        add_episodes_inplace(
+            eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling
+        )

         for _ in range(cfg.policy.utd):
-            train_info = policy.update(
-                # online_buffer,
-                step,
-                demo_buffer=demo_buffer,
-            )
+            policy.train()
+            batch = next(dl_iter)
+
+            for key in batch:
+                batch[key] = batch[key].to(cfg.device, non_blocking=True)
+
+            train_info = policy(batch, step)

             if step % cfg.log_freq == 0:
-                train_info.update(rollout_info)
-                log_train_info(logger, train_info, step, cfg, dataset, is_offline)
+                log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)

             # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
             # in step + 1.
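A minimal, self-contained sketch of the sampling scheme this hunk wires up (tiny TensorDatasets stand in for offline_dataset and online_dataset, and the weight update that add_episodes_inplace performs after each rollout is done inline): a ConcatDataset is drawn from through a WeightedRandomSampler whose per-frame weights keep online frames at roughly the requested share of each batch.

    import torch

    # Toy stand-ins for the offline dataset and the growing online dataset.
    offline = torch.utils.data.TensorDataset(torch.zeros(1000, 1))
    online = torch.utils.data.TensorDataset(torch.ones(100, 1))
    concat = torch.utils.data.ConcatDataset([offline, online])

    # Same weight formula as calculate_online_sample_weight, with pc_online_samples = 0.5.
    pc_on = 0.5
    w_online = -(len(offline) * pc_on) / (len(online) * (pc_on - 1))  # 10.0 here

    weights = [1.0] * len(offline) + [w_online] * len(online)
    sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(concat), replacement=True)
    dataloader = torch.utils.data.DataLoader(concat, batch_size=256, sampler=sampler, drop_last=True)

    batch, = next(iter(dataloader))
    print(batch.mean())  # roughly 0.5: about half of the sampled frames come from the "online" dataset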