Refactor train, eval_policy, logger, Add diffusion.yaml (WIP)
parent 5a219fed6e
commit 21670dce90
@@ -4,6 +4,26 @@ from lerobot.common.datasets.pusht import PushtExperienceReplay
 from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
 from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
 
+# TODO(rcadene): implement
+
+# dataset_d4rl = D4RLExperienceReplay(
+# dataset_id="maze2d-umaze-v1",
+# split_trajs=False,
+# batch_size=1,
+# sampler=SamplerWithoutReplacement(drop_last=False),
+# prefetch=4,
+# direct_download=True,
+# )
+
+# dataset_openx = OpenXExperienceReplay(
+# "cmu_stretch",
+# batch_size=1,
+# num_slices=1,
+# #download="force",
+# streaming=False,
+# root="data",
+# )
+
+
 def make_offline_buffer(cfg, sampler=None):
 
@@ -10,10 +10,10 @@ from termcolor import colored
 
 CONSOLE_FORMAT = [
     ("episode", "E", "int"),
-    ("env_step", "S", "int"),
+    ("step", "S", "int"),
     ("avg_sum_reward", "RS", "float"),
     ("avg_max_reward", "RM", "float"),
-    ("pc_success", "S", "float"),
+    ("pc_success", "SR", "float"),
     ("total_time", "T", "time"),
 ]
 AGENT_METRICS = [
@@ -51,7 +51,9 @@ def print_run(cfg, reward=None):
 
     kvs = [
         ("task", cfg.env.task),
-        ("train steps", f"{int(cfg.train_steps * cfg.env.action_repeat):,}"),
+        ("offline_steps", f"{cfg.offline_steps}"),
+        ("online_steps", f"{cfg.online_steps}"),
+        ("action_repeat", f"{cfg.env.action_repeat}"),
         # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
         # ('actions', cfg.action_dim),
         # ('experiment', cfg.exp_name),
@@ -78,54 +80,6 @@ def cfg_to_group(cfg, return_list=False):
     return lst if return_list else "-".join(lst)
 
 
-class VideoRecorder:
-    """Utility class for logging evaluation videos."""
-
-    def __init__(self, root_dir, wandb, render_size=384, fps=15):
-        self.save_dir = (root_dir / "eval_video") if root_dir else None
-        self._wandb = wandb
-        self.render_size = render_size
-        self.fps = fps
-        self.frames = []
-        self.enabled = False
-        self.camera_id = 0
-
-    def init(self, env, enabled=True):
-        self.frames = []
-        self.enabled = self.save_dir and self._wandb and enabled
-        try:
-            env_name = env.unwrapped.spec.id
-        except:
-            env_name = ""
-        if "maze2d" in env_name:
-            self.camera_id = -1
-        elif "quadruped" in env_name:
-            self.camera_id = 2
-        self.record(env)
-
-    def record(self, env):
-        if self.enabled:
-            frame = env.render(
-                mode="rgb_array",
-                height=self.render_size,
-                width=self.render_size,
-                camera_id=self.camera_id,
-            )
-            self.frames.append(frame)
-
-    def save(self, step):
-        if self.enabled:
-            frames = np.stack(self.frames).transpose(0, 3, 1, 2)
-            self._wandb.log(
-                {
-                    "eval_video": self._wandb.Video(
-                        frames, fps=self.env.fps, format="mp4"
-                    )
-                },
-                step=step,
-            )
-
-
 class Logger(object):
     """Primary logger object. Logs either locally or using wandb."""
 
@@ -170,15 +124,6 @@ class Logger(object):
             )
             print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
             self._wandb = wandb
-        self._video = (
-            VideoRecorder(self._log_dir, self._wandb)
-            if self._wandb and cfg.save_video
-            else None
-        )
-
-    @property
-    def video(self):
-        return self._video
 
     def save_model(self, agent, identifier):
         if self._save_model:
@@ -214,12 +159,12 @@ class Logger(object):
 
     def _format(self, key, value, ty):
         if ty == "int":
-            return f'{colored(key + ":", "grey")} {int(value):,}'
+            return f'{colored(key + ":", "yellow")} {int(value):,}'
         elif ty == "float":
-            return f'{colored(key + ":", "grey")} {value:.01f}'
+            return f'{colored(key + ":", "yellow")} {value:.01f}'
         elif ty == "time":
             value = str(datetime.timedelta(seconds=int(value)))
-            return f'{colored(key + ":", "grey")} {value}'
+            return f'{colored(key + ":", "yellow")} {value}'
         else:
             raise f"invalid log format type: {ty}"
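Aside: a minimal sketch (not part of the commit) of how CONSOLE_FORMAT-style specs plus a _format-style helper render one console line; the spec subset and metric values below are made up.

import datetime
from termcolor import colored

specs = [("step", "S", "int"), ("avg_sum_reward", "RS", "float"), ("total_time", "T", "time")]
metrics = {"step": 25000, "avg_sum_reward": 123.44, "total_time": 3671}

def fmt(label, value, ty):
    # mirrors Logger._format above, now colored yellow instead of grey
    if ty == "int":
        return f'{colored(label + ":", "yellow")} {int(value):,}'
    if ty == "float":
        return f'{colored(label + ":", "yellow")} {value:.01f}'
    if ty == "time":
        return f'{colored(label + ":", "yellow")} {datetime.timedelta(seconds=int(value))}'
    raise ValueError(f"invalid log format type: {ty}")

print("   ".join(fmt(label, metrics[key], ty) for key, label, ty in specs))
# S: 25,000   RS: 123.4   T: 1:01:11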
@@ -234,10 +179,9 @@ class Logger(object):
         assert category in {"train", "eval"}
         if self._wandb is not None:
             for k, v in d.items():
-                self._wandb.log({category + "/" + k: v}, step=d["env_step"])
+                self._wandb.log({category + "/" + k: v}, step=d["step"])
         if category == "eval":
-            # keys = ['env_step', 'avg_reward']
-            keys = ["env_step", "avg_sum_reward", "avg_max_reward", "pc_success"]
+            keys = ["step", "avg_sum_reward", "avg_max_reward", "pc_success"]
             self._eval.append(np.array([d[key] for key in keys]))
             pd.DataFrame(np.array(self._eval)).to_csv(
                 self._log_dir / "eval.log", header=keys, index=None
@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
 from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 
@@ -4,9 +4,29 @@ def make_policy(cfg):
 
         policy = TDMPC(cfg.policy)
     elif cfg.policy.name == "diffusion":
+        from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+        from diffusion_policy.model.vision.model_getter import get_resnet
+        from diffusion_policy.model.vision.multi_image_obs_encoder import (
+            MultiImageObsEncoder,
+        )
+
         from lerobot.common.policies.diffusion import DiffusionPolicy
 
-        policy = DiffusionPolicy(cfg.policy)
+        noise_scheduler = DDPMScheduler(**cfg.noise_scheduler)
+
+        rgb_model = get_resnet(**cfg.rgb_model)
+
+        obs_encoder = MultiImageObsEncoder(
+            rgb_model=rgb_model,
+            **cfg.obs_encoder,
+        )
+
+        policy = DiffusionPolicy(
+            noise_scheduler=noise_scheduler,
+            obs_encoder=obs_encoder,
+            n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
+            **cfg.policy,
+        )
     else:
         raise ValueError(cfg.policy.name)
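Aside: a hedged sketch of the scheduler wiring the new "diffusion" branch performs, with a plain dict standing in for cfg.noise_scheduler (values copied from the diffusion.yaml added further below; the obs encoder and policy are built the same way from cfg.obs_encoder and cfg.policy).

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

# stand-in for cfg.noise_scheduler; same keys as the yaml's noise_scheduler node
noise_scheduler_cfg = {
    "num_train_timesteps": 100,
    "beta_start": 0.0001,
    "beta_end": 0.02,
    "beta_schedule": "squaredcos_cap_v2",
    "variance_type": "fixed_small",
    "clip_sample": True,
    "prediction_type": "epsilon",
}
noise_scheduler = DDPMScheduler(**noise_scheduler_cfg)  # same call shape as DDPMScheduler(**cfg.noise_scheduler)
print(noise_scheduler.config.num_train_timesteps)  # 100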
@@ -441,261 +441,6 @@ class Episode(object):
         self._idx += 1
 
 
-class ReplayBuffer:
-    """
-    Storage and sampling functionality.
-    """
-
-    def __init__(self, cfg, dataset=None):
-        action_dim = cfg.action_dim
-        obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (cfg.state_dim,)}
-
-        self.cfg = cfg
-        self.device = torch.device(cfg.buffer_device)
-        print("Replay buffer device: ", self.device)
-
-        if dataset is not None:
-            self.capacity = max(dataset["rewards"].shape[0], cfg.max_buffer_size)
-        else:
-            self.capacity = min(cfg.train_steps, cfg.max_buffer_size)
-
-        if cfg.modality in {"pixels", "state"}:
-            dtype = torch.float32 if cfg.modality == "state" else torch.uint8
-            # Note self.obs_shape always has single frame, which is different from cfg.obs_shape
-            self.obs_shape = (
-                obs_shape if cfg.modality == "state" else (3, *obs_shape[-2:])
-            )
-            self._obs = torch.zeros(
-                (self.capacity + cfg.horizon - 1, *self.obs_shape),
-                dtype=dtype,
-                device=self.device,
-            )
-            self._next_obs = torch.zeros(
-                (self.capacity + cfg.horizon - 1, *self.obs_shape),
-                dtype=dtype,
-                device=self.device,
-            )
-        elif cfg.modality == "all":
-            self.obs_shape = {}
-            self._obs, self._next_obs = {}, {}
-            for k, v in obs_shape.items():
-                assert k in {"rgb", "state"}
-                dtype = torch.float32 if k == "state" else torch.uint8
-                self.obs_shape[k] = v if k == "state" else (3, *v[-2:])
-                self._obs[k] = torch.zeros(
-                    (self.capacity + cfg.horizon - 1, *self.obs_shape[k]),
-                    dtype=dtype,
-                    device=self.device,
-                )
-                self._next_obs[k] = self._obs[k].clone()
-        else:
-            raise ValueError
-
-        self._action = torch.zeros(
-            (self.capacity + cfg.horizon - 1, action_dim),
-            dtype=torch.float32,
-            device=self.device,
-        )
-        self._reward = torch.zeros(
-            (self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
-        )
-        self._mask = torch.zeros(
-            (self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
-        )
-        self._done = torch.zeros(
-            (self.capacity + cfg.horizon - 1,), dtype=torch.bool, device=self.device
-        )
-        self._priorities = torch.ones(
-            (self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
-        )
-        self._eps = 1e-6
-        self._full = False
-        self.idx = 0
-        if dataset is not None:
-            self.init_from_offline_dataset(dataset)
-
-        self._aug = aug(cfg)
-
-    def init_from_offline_dataset(self, dataset):
-        """Initialize the replay buffer from an offline dataset."""
-        assert self.idx == 0 and not self._full
-        n_transitions = int(len(dataset["rewards"]) * self.cfg.data_first_percent)
-
-        def copy_data(dst, src, n):
-            assert isinstance(dst, dict) == isinstance(src, dict)
-            if isinstance(dst, dict):
-                for k in dst:
-                    copy_data(dst[k], src[k], n)
-            else:
-                dst[:n] = torch.from_numpy(src[:n])
-
-        copy_data(self._obs, dataset["observations"], n_transitions)
-        copy_data(self._next_obs, dataset["next_observations"], n_transitions)
-        copy_data(self._action, dataset["actions"], n_transitions)
-        copy_data(self._reward, dataset["rewards"], n_transitions)
-        copy_data(self._mask, dataset["masks"], n_transitions)
-        copy_data(self._done, dataset["dones"], n_transitions)
-        self.idx = (self.idx + n_transitions) % self.capacity
-        self._full = n_transitions >= self.capacity
-
-    def __add__(self, episode: Episode):
-        self.add(episode)
-        return self
-
-    def add(self, episode: Episode):
-        """Add an episode to the replay buffer."""
-        if self.idx + len(episode) > self.capacity:
-            print("Warning: episode got truncated")
-        ep_len = min(len(episode), self.capacity - self.idx)
-        idxs = slice(self.idx, self.idx + ep_len)
-        assert self.idx + ep_len <= self.capacity
-        if self.cfg.modality in {"pixels", "state"}:
-            self._obs[idxs] = (
-                episode.obses[:ep_len]
-                if self.cfg.modality == "state"
-                else episode.obses[:ep_len, -3:]
-            )
-            self._next_obs[idxs] = (
-                episode.obses[1 : ep_len + 1]
-                if self.cfg.modality == "state"
-                else episode.obses[1 : ep_len + 1, -3:]
-            )
-        elif self.cfg.modality == "all":
-            for k, v in episode.obses.items():
-                assert k in {"rgb", "state"}
-                assert k in self._obs
-                assert k in self._next_obs
-                if k == "rgb":
-                    self._obs[k][idxs] = episode.obses[k][:ep_len, -3:]
-                    self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1, -3:]
-                else:
-                    self._obs[k][idxs] = episode.obses[k][:ep_len]
-                    self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1]
-        self._action[idxs] = episode.actions[:ep_len]
-        self._reward[idxs] = episode.rewards[:ep_len]
-        self._mask[idxs] = episode.masks[:ep_len]
-        self._done[idxs] = episode.dones[:ep_len]
-        self._done[self.idx + ep_len - 1] = True  # in case truncated
-        if self._full:
-            max_priority = (
-                self._priorities[: self.capacity].max().to(self.device).item()
-            )
-        else:
-            max_priority = (
-                1.0
-                if self.idx == 0
-                else self._priorities[: self.idx].max().to(self.device).item()
-            )
-        new_priorities = torch.full((ep_len,), max_priority, device=self.device)
-        self._priorities[idxs] = new_priorities
-        self.idx = (self.idx + ep_len) % self.capacity
-        self._full = self._full or self.idx == 0
-
-    def update_priorities(self, idxs, priorities):
-        """Update priorities for Prioritized Experience Replay (PER)"""
-        self._priorities[idxs] = priorities.squeeze(1).to(self.device) + self._eps
-
-    def _get_obs(self, arr, idxs):
-        """Retrieve observations by indices"""
-        if isinstance(arr, dict):
-            return {k: self._get_obs(v, idxs) for k, v in arr.items()}
-        if arr.ndim <= 2:  # if self.cfg.modality == 'state':
-            return arr[idxs].cuda()
-        obs = torch.empty(
-            (self.cfg.batch_size, 3 * self.cfg.frame_stack, *arr.shape[-2:]),
-            dtype=arr.dtype,
-            device=torch.device("cuda"),
-        )
-        obs[:, -3:] = arr[idxs].cuda()
-        _idxs = idxs.clone()
-        mask = torch.ones_like(_idxs, dtype=torch.bool)
-        for i in range(1, self.cfg.frame_stack):
-            mask[_idxs % self.cfg.episode_length == 0] = False
-            _idxs[mask] -= 1
-            obs[:, -(i + 1) * 3 : -i * 3] = arr[_idxs].cuda()
-        return obs.float()
-
-    def sample(self):
-        """Sample transitions from the replay buffer."""
-        probs = (
-            self._priorities[: self.capacity]
-            if self._full
-            else self._priorities[: self.idx]
-        ) ** self.cfg.per_alpha
-        probs /= probs.sum()
-        total = len(probs)
-        idxs = torch.from_numpy(
-            np.random.choice(
-                total,
-                self.cfg.batch_size,
-                p=probs.cpu().numpy(),
-                replace=not self._full,
-            )
-        ).to(self.device)
-        weights = (total * probs[idxs]) ** (-self.cfg.per_beta)
-        weights /= weights.max()
-
-        idxs_in_horizon = torch.stack([idxs + t for t in range(self.cfg.horizon)])
-
-        obs = self._aug(self._get_obs(self._obs, idxs))
-        next_obs = [
-            self._aug(self._get_obs(self._next_obs, _idxs)) for _idxs in idxs_in_horizon
-        ]
-        if isinstance(next_obs[0], dict):
-            next_obs = {k: torch.stack([o[k] for o in next_obs]) for k in next_obs[0]}
-        else:
-            next_obs = torch.stack(next_obs)
-        action = self._action[idxs_in_horizon]
-        reward = self._reward[idxs_in_horizon]
-        mask = self._mask[idxs_in_horizon]
-        done = self._done[idxs_in_horizon]
-
-        if not action.is_cuda:
-            action, reward, mask, done, idxs, weights = (
-                action.cuda(),
-                reward.cuda(),
-                mask.cuda(),
-                done.cuda(),
-                idxs.cuda(),
-                weights.cuda(),
-            )
-
-        return (
-            obs,
-            next_obs,
-            action,
-            reward.unsqueeze(2),
-            mask.unsqueeze(2),
-            done.unsqueeze(2),
-            idxs,
-            weights,
-        )
-
-    def save(self, path):
-        """Save the replay buffer to path"""
-        print(f"saving replay buffer to '{path}'...")
-        sz = self.capacity if self._full else self.idx
-        dataset = {
-            "observations": (
-                {k: v[:sz].cpu().numpy() for k, v in self._obs.items()}
-                if isinstance(self._obs, dict)
-                else self._obs[:sz].cpu().numpy()
-            ),
-            "next_observations": (
-                {k: v[:sz].cpu().numpy() for k, v in self._next_obs.items()}
-                if isinstance(self._next_obs, dict)
-                else self._next_obs[:sz].cpu().numpy()
-            ),
-            "actions": self._action[:sz].cpu().numpy(),
-            "rewards": self._reward[:sz].cpu().numpy(),
-            "dones": self._done[:sz].cpu().numpy(),
-            "masks": self._mask[:sz].cpu().numpy(),
-        }
-        with open(path, "wb") as f:
-            pickle.dump(dataset, f)
-        return dataset
-
-
 def get_dataset_dict(cfg, env, return_reward_normalizer=False):
     """Construct a dataset for env"""
     required_keys = [
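Aside: a minimal sketch of the prioritized-sampling math the deleted ReplayBuffer implemented (probabilities proportional to priority^alpha, importance weights (N*p)^-beta normalized by their max); presumably this is the role the torchrl PrioritizedSliceSampler import above is meant to take over. Toy values only, not from the repo.

import torch

priorities = torch.tensor([1.0, 0.5, 2.0, 0.1])
per_alpha, per_beta, batch_size = 0.6, 0.4, 2

probs = priorities ** per_alpha
probs /= probs.sum()                                  # sampling distribution over stored transitions
idxs = torch.multinomial(probs, batch_size, replacement=True)
weights = (len(probs) * probs[idxs]) ** (-per_beta)   # importance-sampling correction
weights /= weights.max()                              # normalized, as in the deleted sample()
print(idxs.tolist(), weights.tolist())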
@@ -3,7 +3,9 @@
 eval_episodes: 50
 eval_freq: 7500
 save_freq: 75000
-train_steps: 50000 # TODO: same as simxarm, need to adjust
+# TODO: same as simxarm, need to adjust
+offline_steps: 25000
+online_steps: 25000
 
 fps: 10
@@ -3,7 +3,9 @@
 eval_episodes: 20
 eval_freq: 1000
 save_freq: 10000
-train_steps: 50000
+log_freq: 50
+offline_steps: 25000
+online_steps: 25000
 
 fps: 15
@@ -0,0 +1,117 @@
+# @package _global_
+
+shape_meta:
+  # acceptable types: rgb, low_dim
+  obs:
+    image:
+      shape: [3, 96, 96]
+      type: rgb
+    agent_pos:
+      shape: [2]
+      type: low_dim
+  action:
+    shape: [2]
+
+horizon: 16
+n_obs_steps: 2
+n_action_steps: 8
+n_latency_steps: 0
+dataset_obs_steps: ${n_obs_steps}
+past_action_visible: False
+keypoint_visible_rate: 1.0
+obs_as_global_cond: True
+
+policy:
+  name: diffusion
+
+  shape_meta: ${shape_meta}
+
+  horizon: ${horizon}
+  # n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
+  n_obs_steps: ${n_obs_steps}
+  num_inference_steps: 100
+  obs_as_global_cond: ${obs_as_global_cond}
+  # crop_shape: null
+  diffusion_step_embed_dim: 128
+  down_dims: [512, 1024, 2048]
+  kernel_size: 5
+  n_groups: 8
+  cond_predict_scale: True
+
+  pretrained_model_path:
+
+  batch_size: 64
+
+  per_alpha: 0.6
+  per_beta: 0.4
+
+  balanced_sampling: true
+
+  utd: 1
+
+noise_scheduler:
+  # _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
+  num_train_timesteps: 100
+  beta_start: 0.0001
+  beta_end: 0.02
+  beta_schedule: squaredcos_cap_v2
+  variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
+  clip_sample: True # required when predict_epsilon=False
+  prediction_type: epsilon # or sample
+
+obs_encoder:
+  # _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder
+  shape_meta: ${shape_meta}
+  resize_shape: null
+  crop_shape: [76, 76]
+  # constant center crop
+  random_crop: True
+  use_group_norm: True
+  share_rgb_model: False
+  imagenet_norm: True
+
+rgb_model:
+  #_target_: diffusion_policy.model.vision.model_getter.get_resnet
+  name: resnet18
+  weights: null
+
+ema:
+  _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
+  update_after_step: 0
+  inv_gamma: 1.0
+  power: 0.75
+  min_value: 0.0
+  max_value: 0.9999
+
+optimizer:
+  _target_: torch.optim.AdamW
+  lr: 1.0e-4
+  betas: [0.95, 0.999]
+  eps: 1.0e-8
+  weight_decay: 1.0e-6
+
+training:
+  device: "cuda:0"
+  seed: 42
+  debug: False
+  resume: True
+  # optimization
+  lr_scheduler: cosine
+  lr_warmup_steps: 500
+  num_epochs: 8000
+  gradient_accumulate_every: 1
+  # EMA destroys performance when used with BatchNorm
+  # replace BatchNorm with GroupNorm.
+  use_ema: True
+  freeze_encoder: False
+  # training loop control
+  # in epochs
+  rollout_every: 50
+  checkpoint_every: 50
+  val_every: 1
+  sample_every: 5
+  # steps per epoch
+  max_train_steps: null
+  max_val_steps: null
+  # misc
+  tqdm_interval_sec: 1.0
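Aside: a hedged sketch of how the ${...} interpolations in this config resolve with OmegaConf; the file path is an assumption, only the keys come from the yaml above.

from omegaconf import OmegaConf

cfg = OmegaConf.load("lerobot/configs/policy/diffusion.yaml")  # assumed location of the new file
print(cfg.policy.horizon)        # "${horizon}" resolves to 16 on access
print(OmegaConf.to_container(cfg.noise_scheduler, resolve=True))  # plain dict, ready for DDPMScheduler(**...)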
@@ -5,8 +5,6 @@ policy:
 
   reward_scale: 1.0
 
-  # xarm_lift
-  train_steps: ${train_steps}
   episode_length: ${env.episode_length}
   discount: 0.9
   modality: 'all'
@@ -26,31 +26,31 @@ def eval_policy(
     save_video: bool = False,
     video_dir: Path = None,
     fps: int = 15,
-    env_step: int = None,
-    wandb=None,
+    return_first_video: bool = False,
 ):
-    if wandb is not None:
-        assert env_step is not None
     sum_rewards = []
     max_rewards = []
     successes = []
     threads = []
     for i in range(num_episodes):
-        ep_frames = []
-
-        def rendering_callback(env, td=None):
-            ep_frames.append(env.render())
-
         tensordict = env.reset()
-        if save_video or wandb:
+
+        ep_frames = []
+        if save_video or (return_first_video and i == 0):
+
+            def rendering_callback(env, td=None):
+                ep_frames.append(env.render())
+
             # render first frame before rollout
             rendering_callback(env)
+        else:
+            rendering_callback = None
+
         with torch.inference_mode():
             rollout = env.rollout(
                 max_steps=max_steps,
                 policy=policy,
-                callback=rendering_callback if save_video or wandb else None,
+                callback=rendering_callback,
                 auto_reset=False,
                 tensordict=tensordict,
                 auto_cast_to_device=True,
@@ -63,7 +63,7 @@ def eval_policy(
         max_rewards.append(ep_max_reward.item())
         successes.append(ep_success.item())
 
-        if save_video or wandb:
+        if save_video or (return_first_video and i == 0):
             stacked_frames = np.stack(ep_frames)
 
             if save_video:
@@ -76,12 +76,8 @@ def eval_policy(
                 thread.start()
                 threads.append(thread)
 
-            first_episode = i == 0
-            if wandb and first_episode:
-                eval_video = wandb.Video(
-                    stacked_frames.transpose(0, 3, 1, 2), fps=fps, format="mp4"
-                )
-                wandb.log({"eval_video": eval_video}, step=env_step)
+            if return_first_video and i == 0:
+                first_video = stacked_frames.transpose(0, 3, 1, 2)
 
     for thread in threads:
         thread.join()
@@ -91,6 +87,8 @@ def eval_policy(
         "avg_max_reward": np.nanmean(max_rewards),
         "pc_success": np.nanmean(successes) * 100,
     }
+    if return_first_video:
+        return metrics, first_video
     return metrics
 
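Aside: the frame-layout conversion behind first_video, shown on dummy frames (toy sizes, not from the repo): rendered images are (H, W, C), while wandb.Video expects (T, C, H, W).

import numpy as np

ep_frames = [np.zeros((96, 96, 3), dtype=np.uint8) for _ in range(8)]  # 8 rendered rgb frames
stacked_frames = np.stack(ep_frames)                  # (8, 96, 96, 3)
first_video = stacked_frames.transpose(0, 3, 1, 2)    # (8, 3, 96, 96), as returned with return_first_video=True
print(stacked_frames.shape, first_video.shape)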
@@ -38,6 +38,40 @@ def train_notebook(
     train(cfg, out_dir=out_dir, job_name=job_name)
 
 
+def log_training_metrics(L, metrics, step, online_episode_idx, start_time, is_offline):
+    common_metrics = {
+        "episode": online_episode_idx,
+        "step": step,
+        "total_time": time.time() - start_time,
+        "is_offline": float(is_offline),
+    }
+    metrics.update(common_metrics)
+    L.log(metrics, category="train")
+
+
+def eval_policy_and_log(
+    env, td_policy, step, online_episode_idx, start_time, is_offline, cfg, L
+):
+    common_metrics = {
+        "episode": online_episode_idx,
+        "step": step,
+        "total_time": time.time() - start_time,
+        "is_offline": float(is_offline),
+    }
+    metrics, first_video = eval_policy(
+        env,
+        td_policy,
+        num_episodes=cfg.eval_episodes,
+        return_first_video=True,
+    )
+    metrics.update(common_metrics)
+    L.log(metrics, category="eval")
+
+    if cfg.wandb.enable:
+        eval_video = L._wandb.Video(first_video, fps=cfg.fps, format="mp4")
+        L._wandb.log({"eval_video": eval_video}, step=step)
+
+
 def train(cfg: dict, out_dir=None, job_name=None):
     if out_dir is None:
         raise NotImplementedError()
@@ -84,115 +118,89 @@ def train(cfg: dict, out_dir=None, job_name=None):
     online_episode_idx = 0
     start_time = time.time()
     step = 0
-    last_log_step = 0
-    last_save_step = 0
-
-    while step < cfg.train_steps:
-        is_offline = True
-        num_updates = cfg.env.episode_length
-        _step = step + num_updates
-        rollout_metrics = {}
-
-        # TODO(rcadene): move offline_steps outside policy
-        if step >= cfg.policy.offline_steps:
-            is_offline = False
-
-            # TODO: use SyncDataCollector for that?
-            with torch.no_grad():
-                rollout = env.rollout(
-                    max_steps=cfg.env.episode_length,
-                    policy=td_policy,
-                    auto_cast_to_device=True,
-                )
-            assert len(rollout) <= cfg.env.episode_length
-            rollout["episode"] = torch.tensor(
-                [online_episode_idx] * len(rollout), dtype=torch.int
-            )
-            online_buffer.extend(rollout)
-
-            ep_sum_reward = rollout["next", "reward"].sum()
-            ep_max_reward = rollout["next", "reward"].max()
-            ep_success = rollout["next", "success"].any()
-
-            online_episode_idx += 1
-            rollout_metrics = {
-                "avg_sum_reward": np.nanmean(ep_sum_reward),
-                "avg_max_reward": np.nanmean(ep_max_reward),
-                "pc_success": np.nanmean(ep_success) * 100,
-            }
-            num_updates = len(rollout) * cfg.policy.utd
-            _step = min(step + len(rollout), cfg.train_steps)
-
-        # Update model
-        for i in range(num_updates):
-            if is_offline:
-                train_metrics = policy.update(offline_buffer, step + i)
-            else:
-                train_metrics = policy.update(
-                    online_buffer,
-                    step + i // cfg.policy.utd,
-                    demo_buffer=(
-                        offline_buffer if cfg.policy.balanced_sampling else None
-                    ),
-                )
-
-        # Log training metrics
-        env_step = int(_step * cfg.env.action_repeat)
-        common_metrics = {
-            "episode": online_episode_idx,
-            "step": _step,
-            "env_step": env_step,
-            "total_time": time.time() - start_time,
-            "is_offline": float(is_offline),
-        }
-        train_metrics.update(common_metrics)
-        train_metrics.update(rollout_metrics)
-        L.log(train_metrics, category="train")
-
-        # Evaluate policy periodically
-        if step == 0 or env_step - last_log_step >= cfg.eval_freq:
-            eval_metrics = eval_policy(
-                env,
-                td_policy,
-                num_episodes=cfg.eval_episodes,
-                env_step=env_step,
-                wandb=L._wandb,
-            )
-            common_metrics.update(eval_metrics)
-            L.log(common_metrics, category="eval")
-            last_log_step = env_step - env_step % cfg.eval_freq
-
-        # Save model periodically
-        if cfg.save_model and env_step - last_save_step >= cfg.save_freq:
-            L.save_model(policy, identifier=env_step)
-            print(f"Model has been checkpointed at step {env_step}")
-            last_save_step = env_step - env_step % cfg.save_freq
-
-        if cfg.save_model and is_offline and _step >= cfg.offline_steps:
-            # save the model after offline training
-            L.save_model(policy, identifier="offline")
-
-        step = _step
-
-# dataset_d4rl = D4RLExperienceReplay(
-# dataset_id="maze2d-umaze-v1",
-# split_trajs=False,
-# batch_size=1,
-# sampler=SamplerWithoutReplacement(drop_last=False),
-# prefetch=4,
-# direct_download=True,
-# )
-
-# dataset_openx = OpenXExperienceReplay(
-# "cmu_stretch",
-# batch_size=1,
-# num_slices=1,
-# #download="force",
-# streaming=False,
-# root="data",
-# )
+
+    # First eval with a random model or pretrained
+    eval_policy_and_log(
+        env, td_policy, step, online_episode_idx, start_time, is_offline, cfg, L
+    )
+
+    # Train offline
+    for _ in range(cfg.offline_steps):
+        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
+        metrics = policy.update(offline_buffer, step)
+
+        if step % cfg.log_freq == 0:
+            log_training_metrics(
+                L, metrics, step, online_episode_idx, start_time, is_offline=False
+            )
+
+        if step > 0 and step % cfg.eval_freq == 0:
+            eval_policy_and_log(
+                env, td_policy, step, online_episode_idx, start_time, is_offline, cfg, L
+            )
+
+        if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
+            print(f"Checkpoint model at step {step}")
+            L.save_model(policy, identifier=step)
+
+        step += 1
+
+    # Train online
+    demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None
+    for _ in range(cfg.online_steps):
+        # TODO: use SyncDataCollector for that?
+        with torch.no_grad():
+            rollout = env.rollout(
+                max_steps=cfg.env.episode_length,
+                policy=td_policy,
+                auto_cast_to_device=True,
+            )
+        assert len(rollout) <= cfg.env.episode_length
+        rollout["episode"] = torch.tensor(
+            [online_episode_idx] * len(rollout), dtype=torch.int
+        )
+        online_buffer.extend(rollout)
+
+        ep_sum_reward = rollout["next", "reward"].sum()
+        ep_max_reward = rollout["next", "reward"].max()
+        ep_success = rollout["next", "success"].any()
+        metrics = {
+            "avg_sum_reward": np.nanmean(ep_sum_reward),
+            "avg_max_reward": np.nanmean(ep_max_reward),
+            "pc_success": np.nanmean(ep_success) * 100,
+        }
+
+        online_episode_idx += 1
+
+        for _ in range(cfg.policy.utd):
+            train_metrics = policy.update(
+                online_buffer,
+                step,
+                demo_buffer=demo_buffer,
+            )
+            metrics.update(train_metrics)
+
+        if step % cfg.log_freq == 0:
+            log_training_metrics(
+                L, metrics, step, online_episode_idx, start_time, is_offline=False
+            )
+
+        if step > 0 and step & cfg.eval_freq == 0:
+            eval_policy_and_log(
+                env,
+                td_policy,
+                step,
+                online_episode_idx,
+                start_time,
+                is_offline,
+                cfg,
+                L,
+            )
+
+        if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
+            print(f"Checkpoint model at step {step}")
+            L.save_model(policy, identifier=step)
+
+        step += 1
 
 
 if __name__ == "__main__":
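Aside: a condensed, hedged sketch of the control flow train() moves to in this commit: a first eval, an offline loop over cfg.offline_steps, then an online loop that interleaves rollouts with cfg.policy.utd updates. Everything below is stubbed with toy values; only the loop shape mirrors the diff.

offline_steps, online_steps, utd, log_freq, eval_freq = 4, 4, 1, 2, 2

def update(step):                 # stand-in for policy.update(buffer, step)
    return {"loss": 0.0}

step = 0
print("eval at step", step)       # first eval with a random or pretrained model
for _ in range(offline_steps):    # offline phase: sample only from the offline buffer
    metrics = update(step)
    if step % log_freq == 0:
        print("train (offline)", step, metrics)
    if step > 0 and step % eval_freq == 0:
        print("eval at step", step)
    step += 1

for _ in range(online_steps):     # online phase: collect a rollout, then do utd updates
    # env rollout + online_buffer.extend(rollout) would happen here
    for _ in range(utd):
        metrics = update(step)
    if step % log_freq == 0:
        print("train (online)", step, metrics)
    if step > 0 and step % eval_freq == 0:
        print("eval at step", step)
    step += 1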
@@ -6,12 +6,19 @@ from .utils import init_config
 
 
 @pytest.mark.parametrize(
-    "env_name",
+    "env_name,policy_name",
     [
-        "simxarm",
-        "pusht",
+        ("simxarm", "tdmpc"),
+        ("pusht", "tdmpc"),
+        ("simxarm", "diffusion"),
+        ("pusht", "diffusion"),
     ],
 )
-def test_factory(env_name):
-    cfg = init_config(overrides=[f"env={env_name}"])
+def test_factory(env_name, policy_name):
+    cfg = init_config(
+        overrides=[
+            f"env={env_name}",
+            f"policy={policy_name}",
+        ]
+    )
     policy = make_policy(cfg)