Eval reproduced! Train running (but not reproduced)
This commit is contained in:
parent
937b2f8cba
commit
228c045674
|
@ -5,6 +5,7 @@ wandb
|
|||
data
|
||||
outputs
|
||||
.vscode
|
||||
rl
|
||||
|
||||
# HPC
|
||||
nautilus/*.yaml
|
||||
|
|
|
@ -15,6 +15,7 @@ conda activate lerobot
|
|||
python setup.py develop
|
||||
```
|
||||
|
||||
|
||||
## Contribute
|
||||
|
||||
**style**
|
||||
|
|
|
@ -0,0 +1,190 @@
|
|||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import (
|
||||
TensorDictPrioritizedReplayBuffer,
|
||||
TensorDictReplayBuffer,
|
||||
)
|
||||
from torchrl.data.replay_buffers.samplers import (
|
||||
Sampler,
|
||||
SliceSampler,
|
||||
SliceSamplerWithoutReplacement,
|
||||
)
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
|
||||
|
||||
class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
||||
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id,
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
num_slices: int = None,
|
||||
slice_len: int = None,
|
||||
pad: float = None,
|
||||
replacement: bool = None,
|
||||
streaming: bool = False,
|
||||
root: Path = None,
|
||||
download: bool = False,
|
||||
sampler: Sampler = None,
|
||||
writer: Writer = None,
|
||||
collate_fn: Callable = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
transform: "torchrl.envs.Transform" = None, # noqa-F821
|
||||
split_trajs: bool = False,
|
||||
strict_length: bool = True,
|
||||
):
|
||||
self.download = download
|
||||
if streaming:
|
||||
raise NotImplementedError
|
||||
self.streaming = streaming
|
||||
self.dataset_id = dataset_id
|
||||
self.split_trajs = split_trajs
|
||||
self.shuffle = shuffle
|
||||
self.num_slices = num_slices
|
||||
self.slice_len = slice_len
|
||||
self.pad = pad
|
||||
|
||||
self.strict_length = strict_length
|
||||
if (self.num_slices is not None) and (self.slice_len is not None):
|
||||
raise ValueError("num_slices or slice_len can be not None, but not both.")
|
||||
if split_trajs:
|
||||
raise NotImplementedError
|
||||
|
||||
if root is None:
|
||||
root = _get_root_dir("simxarm")
|
||||
os.makedirs(root, exist_ok=True)
|
||||
self.root = Path(root)
|
||||
if self.download == "force" or (self.download and not self._is_downloaded()):
|
||||
storage = self._download_and_preproc()
|
||||
else:
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
||||
|
||||
if num_slices is not None or slice_len is not None:
|
||||
if sampler is not None:
|
||||
raise ValueError(
|
||||
"`num_slices` and `slice_len` are exclusive with the `sampler` argument."
|
||||
)
|
||||
|
||||
if replacement:
|
||||
if not self.shuffle:
|
||||
raise RuntimeError(
|
||||
"shuffle=False can only be used when replacement=False."
|
||||
)
|
||||
sampler = SliceSampler(
|
||||
num_slices=num_slices,
|
||||
slice_len=slice_len,
|
||||
strict_length=strict_length,
|
||||
)
|
||||
else:
|
||||
sampler = SliceSamplerWithoutReplacement(
|
||||
num_slices=num_slices,
|
||||
slice_len=slice_len,
|
||||
strict_length=strict_length,
|
||||
shuffle=self.shuffle,
|
||||
)
|
||||
|
||||
if writer is None:
|
||||
writer = ImmutableDatasetWriter()
|
||||
if collate_fn is None:
|
||||
collate_fn = _collate_id
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
sampler=sampler,
|
||||
writer=writer,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
batch_size=batch_size,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def data_path_root(self):
|
||||
if self.streaming:
|
||||
return None
|
||||
return self.root / self.dataset_id
|
||||
|
||||
def _is_downloaded(self):
|
||||
return os.path.exists(self.data_path_root)
|
||||
|
||||
def _download_and_preproc(self):
|
||||
# download
|
||||
# TODO(rcadene)
|
||||
|
||||
# load
|
||||
dataset_dir = Path("data") / self.dataset_id
|
||||
dataset_path = dataset_dir / f"buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
|
||||
total_frames = dataset_dict["actions"].shape[0]
|
||||
|
||||
idx0 = 0
|
||||
idx1 = 0
|
||||
episode_id = 0
|
||||
for i in tqdm.tqdm(range(total_frames)):
|
||||
idx1 += 1
|
||||
|
||||
if not dataset_dict["dones"][i]:
|
||||
continue
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
|
||||
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
|
||||
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
|
||||
next_image = torch.tensor(
|
||||
dataset_dict["next_observations"]["rgb"][idx0:idx1]
|
||||
)
|
||||
next_state = torch.tensor(
|
||||
dataset_dict["next_observations"]["state"][idx0:idx1]
|
||||
)
|
||||
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
|
||||
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
|
||||
|
||||
episode = TensorDict(
|
||||
{
|
||||
("observation", "image"): image,
|
||||
("observation", "state"): state,
|
||||
"action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
|
||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
("next", "observation", "image"): next_image,
|
||||
("next", "observation", "state"): next_state,
|
||||
("next", "observation", "reward"): next_reward,
|
||||
("next", "observation", "done"): next_done,
|
||||
},
|
||||
batch_size=num_frames,
|
||||
)
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = (
|
||||
episode[0]
|
||||
.expand(total_frames)
|
||||
.memmap_like(self.root / self.dataset_id)
|
||||
)
|
||||
|
||||
td_data[idx0:idx1] = episode
|
||||
|
||||
episode_id += 1
|
||||
idx0 = idx1
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
|
@ -7,6 +7,7 @@ def make_env(cfg):
|
|||
assert cfg.env == "simxarm"
|
||||
env = SimxarmEnv(
|
||||
task=cfg.task,
|
||||
frame_skip=cfg.action_repeat,
|
||||
from_pixels=cfg.from_pixels,
|
||||
pixels_only=cfg.pixels_only,
|
||||
image_size=cfg.image_size,
|
||||
|
|
|
@ -24,6 +24,7 @@ class SimxarmEnv(EnvBase):
|
|||
def __init__(
|
||||
self,
|
||||
task,
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
|
@ -32,6 +33,7 @@ class SimxarmEnv(EnvBase):
|
|||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.task = task
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
|
@ -115,12 +117,15 @@ class SimxarmEnv(EnvBase):
|
|||
# step expects shape=(4,) so we pad if necessary
|
||||
action = np.concatenate([action, self._action_padding])
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
for t in range(self.frame_skip):
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
sum_reward += reward
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": self._format_raw_obs(raw_obs),
|
||||
"reward": torch.tensor([reward], dtype=torch.float32),
|
||||
"reward": torch.tensor([sum_reward], dtype=torch.float32),
|
||||
"done": torch.tensor([done], dtype=torch.bool),
|
||||
"success": torch.tensor([info["success"]], dtype=torch.bool),
|
||||
},
|
||||
|
|
|
@ -0,0 +1,243 @@
|
|||
import datetime
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from omegaconf import OmegaConf
|
||||
from termcolor import colored
|
||||
|
||||
CONSOLE_FORMAT = [
|
||||
("episode", "E", "int"),
|
||||
("env_step", "S", "int"),
|
||||
("avg_reward", "R", "float"),
|
||||
("pc_success", "R", "float"),
|
||||
("total_time", "T", "time"),
|
||||
]
|
||||
AGENT_METRICS = [
|
||||
"consistency_loss",
|
||||
"reward_loss",
|
||||
"value_loss",
|
||||
"total_loss",
|
||||
"weighted_loss",
|
||||
"pi_loss",
|
||||
"grad_norm",
|
||||
]
|
||||
|
||||
|
||||
def make_dir(dir_path):
|
||||
"""Create directory if it does not already exist."""
|
||||
try:
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
return dir_path
|
||||
|
||||
|
||||
def print_run(cfg, reward=None):
|
||||
"""Pretty-printing of run information. Call at start of training."""
|
||||
prefix, color, attrs = " ", "green", ["bold"]
|
||||
|
||||
def limstr(s, maxlen=32):
|
||||
return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s
|
||||
|
||||
def pprint(k, v):
|
||||
print(
|
||||
prefix + colored(f'{k.capitalize() + ":":<16}', color, attrs=attrs),
|
||||
limstr(v),
|
||||
)
|
||||
|
||||
kvs = [
|
||||
("task", cfg.task),
|
||||
("train steps", f"{int(cfg.train_steps * cfg.action_repeat):,}"),
|
||||
# ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
|
||||
# ('actions', cfg.action_dim),
|
||||
# ('experiment', cfg.exp_name),
|
||||
]
|
||||
if reward is not None:
|
||||
kvs.append(
|
||||
("episode reward", colored(str(int(reward)), "white", attrs=["bold"]))
|
||||
)
|
||||
w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21
|
||||
div = "-" * w
|
||||
print(div)
|
||||
for k, v in kvs:
|
||||
pprint(k, v)
|
||||
print(div)
|
||||
|
||||
|
||||
def cfg_to_group(cfg, return_list=False):
|
||||
"""Return a wandb-safe group name for logging. Optionally returns group name as list."""
|
||||
lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
|
||||
return lst if return_list else "-".join(lst)
|
||||
|
||||
|
||||
class VideoRecorder:
|
||||
"""Utility class for logging evaluation videos."""
|
||||
|
||||
def __init__(self, root_dir, wandb, render_size=384, fps=15):
|
||||
self.save_dir = (root_dir / "eval_video") if root_dir else None
|
||||
self._wandb = wandb
|
||||
self.render_size = render_size
|
||||
self.fps = fps
|
||||
self.frames = []
|
||||
self.enabled = False
|
||||
self.camera_id = 0
|
||||
|
||||
def init(self, env, enabled=True):
|
||||
self.frames = []
|
||||
self.enabled = self.save_dir and self._wandb and enabled
|
||||
try:
|
||||
env_name = env.unwrapped.spec.id
|
||||
except:
|
||||
env_name = ""
|
||||
if "maze2d" in env_name:
|
||||
self.camera_id = -1
|
||||
elif "quadruped" in env_name:
|
||||
self.camera_id = 2
|
||||
self.record(env)
|
||||
|
||||
def record(self, env):
|
||||
if self.enabled:
|
||||
frame = env.render(
|
||||
mode="rgb_array",
|
||||
height=self.render_size,
|
||||
width=self.render_size,
|
||||
camera_id=self.camera_id,
|
||||
)
|
||||
self.frames.append(frame)
|
||||
|
||||
def save(self, step):
|
||||
if self.enabled:
|
||||
frames = np.stack(self.frames).transpose(0, 3, 1, 2)
|
||||
self._wandb.log(
|
||||
{"eval_video": self._wandb.Video(frames, fps=self.fps, format="mp4")},
|
||||
step=step,
|
||||
)
|
||||
|
||||
|
||||
class Logger(object):
|
||||
"""Primary logger object. Logs either locally or using wandb."""
|
||||
|
||||
def __init__(self, log_dir, cfg):
|
||||
self._log_dir = make_dir(Path(log_dir))
|
||||
self._model_dir = make_dir(self._log_dir / "models")
|
||||
self._buffer_dir = make_dir(self._log_dir / "buffers")
|
||||
self._save_model = cfg.save_model
|
||||
self._save_buffer = cfg.save_buffer
|
||||
self._group = cfg_to_group(cfg)
|
||||
self._seed = cfg.seed
|
||||
self._cfg = cfg
|
||||
self._eval = []
|
||||
print_run(cfg)
|
||||
project, entity = cfg.get("wandb_project", "none"), cfg.get(
|
||||
"wandb_entity", "none"
|
||||
)
|
||||
run_offline = (
|
||||
not cfg.get("use_wandb", False) or project == "none" or entity == "none"
|
||||
)
|
||||
if run_offline:
|
||||
print(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
self._wandb = None
|
||||
else:
|
||||
try:
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
import wandb
|
||||
|
||||
wandb.init(
|
||||
project=project,
|
||||
entity=entity,
|
||||
name=str(cfg.seed),
|
||||
notes=cfg.notes,
|
||||
group=self._group,
|
||||
tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"],
|
||||
dir=self._log_dir,
|
||||
config=OmegaConf.to_container(cfg, resolve=True),
|
||||
)
|
||||
print(
|
||||
colored("Logs will be synced with wandb.", "blue", attrs=["bold"])
|
||||
)
|
||||
self._wandb = wandb
|
||||
except:
|
||||
print(
|
||||
colored(
|
||||
"Warning: failed to init wandb. Make sure `wandb_entity` is set to your username in `config.yaml`. Logs will be saved locally.",
|
||||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
self._wandb = None
|
||||
self._video = (
|
||||
VideoRecorder(log_dir, self._wandb)
|
||||
if self._wandb and cfg.save_video
|
||||
else None
|
||||
)
|
||||
|
||||
@property
|
||||
def video(self):
|
||||
return self._video
|
||||
|
||||
def save_model(self, agent, identifier):
|
||||
if self._save_model:
|
||||
fp = self._model_dir / f"{str(identifier)}.pt"
|
||||
agent.save(fp)
|
||||
if self._wandb:
|
||||
artifact = self._wandb.Artifact(
|
||||
self._group + "-" + str(self._seed) + "-" + str(identifier),
|
||||
type="model",
|
||||
)
|
||||
artifact.add_file(fp)
|
||||
self._wandb.log_artifact(artifact)
|
||||
|
||||
def save_buffer(self, buffer, identifier):
|
||||
fp = self._buffer_dir / f"{str(identifier)}.pkl"
|
||||
buffer.save(fp)
|
||||
if self._wandb:
|
||||
artifact = self._wandb.Artifact(
|
||||
self._group + "-" + str(self._seed) + "-" + str(identifier),
|
||||
type="buffer",
|
||||
)
|
||||
artifact.add_file(fp)
|
||||
self._wandb.log_artifact(artifact)
|
||||
|
||||
def finish(self, agent, buffer):
|
||||
if self._save_model:
|
||||
self.save_model(agent, identifier="final")
|
||||
if self._save_buffer:
|
||||
self.save_buffer(buffer, identifier="buffer")
|
||||
if self._wandb:
|
||||
self._wandb.finish()
|
||||
print_run(self._cfg, self._eval[-1][-1])
|
||||
|
||||
def _format(self, key, value, ty):
|
||||
if ty == "int":
|
||||
return f'{colored(key + ":", "grey")} {int(value):,}'
|
||||
elif ty == "float":
|
||||
return f'{colored(key + ":", "grey")} {value:.01f}'
|
||||
elif ty == "time":
|
||||
value = str(datetime.timedelta(seconds=int(value)))
|
||||
return f'{colored(key + ":", "grey")} {value}'
|
||||
else:
|
||||
raise f"invalid log format type: {ty}"
|
||||
|
||||
def _print(self, d, category):
|
||||
category = colored(category, "blue" if category == "train" else "green")
|
||||
pieces = [f" {category:<14}"]
|
||||
for k, disp_k, ty in CONSOLE_FORMAT:
|
||||
pieces.append(f"{self._format(disp_k, d.get(k, 0), ty):<26}")
|
||||
print(" ".join(pieces))
|
||||
|
||||
def log(self, d, category="train"):
|
||||
assert category in {"train", "eval"}
|
||||
if self._wandb is not None:
|
||||
for k, v in d.items():
|
||||
self._wandb.log({category + "/" + k: v}, step=d["env_step"])
|
||||
if category == "eval":
|
||||
# keys = ['env_step', 'avg_reward']
|
||||
keys = ["env_step", "avg_reward", "pc_success"]
|
||||
self._eval.append(np.array([d[key] for key in keys]))
|
||||
pd.DataFrame(np.array(self._eval)).to_csv(
|
||||
self._log_dir / "eval.log", header=keys, index=None
|
||||
)
|
||||
self._print(d, category)
|
|
@ -1,5 +1,6 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -90,7 +91,7 @@ class TDMPC(nn.Module):
|
|||
self.model_target = deepcopy(self.model)
|
||||
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
|
||||
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
|
||||
self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
|
||||
# self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
|
||||
self.model.eval()
|
||||
self.model_target.eval()
|
||||
self.batch_size = cfg.batch_size
|
||||
|
@ -308,9 +309,41 @@ class TDMPC(nn.Module):
|
|||
self.demo_batch_size = 0
|
||||
|
||||
# Sample from interaction dataset
|
||||
obs, next_obses, action, reward, mask, done, idxs, weights = (
|
||||
replay_buffer.sample()
|
||||
|
||||
# to not have to mask
|
||||
# batch_size = (self.cfg.batch_size // self.cfg.horizon) * self.cfg.horizon
|
||||
batch_size = self.cfg.horizon * self.cfg.batch_size
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
|
||||
# trajectory t = 256, horizon h = 5
|
||||
# (t h) ... -> h t ...
|
||||
batch = (
|
||||
batch.reshape(self.cfg.batch_size, self.cfg.horizon)
|
||||
.transpose(1, 0)
|
||||
.contiguous()
|
||||
)
|
||||
batch = batch.to("cuda")
|
||||
|
||||
FIRST_FRAME = 0
|
||||
obs = {
|
||||
"rgb": batch["observation", "image"][FIRST_FRAME].float(),
|
||||
"state": batch["observation", "state"][FIRST_FRAME],
|
||||
}
|
||||
action = batch["action"]
|
||||
next_obses = {
|
||||
"rgb": batch["next", "observation", "image"].float(),
|
||||
"state": batch["next", "observation", "state"],
|
||||
}
|
||||
reward = batch["next", "reward"]
|
||||
reward = einops.rearrange(reward, "h t -> h t 1")
|
||||
# We dont use `batch["next", "done"]` since it only indicates the end of an
|
||||
# episode, but not the end of the trajectory of an episode.
|
||||
# Neither does `batch["next", "terminated"]`
|
||||
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
|
||||
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
|
||||
|
||||
idxs = batch["frame_id"][FIRST_FRAME]
|
||||
weights = batch["_weight"][FIRST_FRAME, :, None]
|
||||
|
||||
# Sample from demonstration dataset
|
||||
if self.demo_batch_size > 0:
|
||||
|
@ -341,6 +374,21 @@ class TDMPC(nn.Module):
|
|||
idxs = torch.cat([idxs, demo_idxs])
|
||||
weights = torch.cat([weights, demo_weights])
|
||||
|
||||
# Apply augmentations
|
||||
aug_tf = h.aug(self.cfg)
|
||||
obs = aug_tf(obs)
|
||||
|
||||
for k in next_obses:
|
||||
next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...")
|
||||
next_obses = aug_tf(next_obses)
|
||||
for k in next_obses:
|
||||
next_obses[k] = einops.rearrange(
|
||||
next_obses[k],
|
||||
"(h t) ... -> h t ...",
|
||||
h=self.cfg.horizon,
|
||||
t=self.cfg.batch_size,
|
||||
)
|
||||
|
||||
horizon = self.cfg.horizon
|
||||
loss_mask = torch.ones_like(mask, device=self.device)
|
||||
for t in range(1, horizon):
|
||||
|
@ -407,6 +455,7 @@ class TDMPC(nn.Module):
|
|||
weighted_loss = (total_loss.squeeze(1) * weights).mean()
|
||||
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
|
||||
weighted_loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
|
||||
)
|
||||
|
@ -415,13 +464,16 @@ class TDMPC(nn.Module):
|
|||
if self.cfg.per:
|
||||
# Update priorities
|
||||
priorities = priority_loss.clamp(max=1e4).detach()
|
||||
replay_buffer.update_priorities(
|
||||
idxs[: replay_buffer.cfg.batch_size],
|
||||
priorities[: replay_buffer.cfg.batch_size],
|
||||
# normalize between [0,1] to fit torchrl specification
|
||||
priorities /= 1e4
|
||||
priorities = priorities.clamp(max=1.0)
|
||||
replay_buffer.update_priority(
|
||||
idxs[: self.cfg.batch_size],
|
||||
priorities[: self.cfg.batch_size],
|
||||
)
|
||||
if self.demo_batch_size > 0:
|
||||
demo_buffer.update_priorities(
|
||||
demo_idxs, priorities[replay_buffer.cfg.batch_size :]
|
||||
demo_buffer.update_priority(
|
||||
demo_idxs, priorities[self.cfg.batch_size :]
|
||||
)
|
||||
|
||||
# Update policy + target network
|
||||
|
|
|
@ -306,13 +306,21 @@ class RandomShiftsAug(nn.Module):
|
|||
x = F.pad(x, padding, "replicate")
|
||||
eps = 1.0 / (h + 2 * self.pad)
|
||||
arange = torch.linspace(
|
||||
-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype
|
||||
-1.0 + eps,
|
||||
1.0 - eps,
|
||||
h + 2 * self.pad,
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)[:h]
|
||||
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
|
||||
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
|
||||
base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
|
||||
shift = torch.randint(
|
||||
0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype
|
||||
0,
|
||||
2 * self.pad + 1,
|
||||
size=(n, 1, 1, 2),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
shift *= 2.0 / (h + 2 * self.pad)
|
||||
grid = base_grid + shift
|
||||
|
|
|
@ -1,5 +1,14 @@
|
|||
seed: 1337
|
||||
log_dir: logs/2024_01_26_train
|
||||
exp_name: default
|
||||
device: cuda
|
||||
buffer_device: cuda
|
||||
eval_freq: 1000
|
||||
save_freq: 10000
|
||||
eval_episodes: 20
|
||||
save_video: false
|
||||
save_model: false
|
||||
save_buffer: false
|
||||
|
||||
# env
|
||||
env: simxarm
|
||||
|
@ -8,6 +17,14 @@ from_pixels: True
|
|||
pixels_only: False
|
||||
image_size: 84
|
||||
|
||||
reward_scale: 1.0
|
||||
|
||||
# xarm_lift
|
||||
episode_length: 25
|
||||
modality: 'all'
|
||||
action_repeat: 2 # TODO(rcadene): verify we use this
|
||||
discount: 0.9
|
||||
train_steps: 50000
|
||||
|
||||
# pixels
|
||||
frame_stack: 1
|
||||
|
@ -54,6 +71,19 @@ update_freq: 2
|
|||
tau: 0.01
|
||||
utd: 1
|
||||
|
||||
# offline rl
|
||||
# dataset_dir: ???
|
||||
data_first_percent: 1.0
|
||||
is_data_clip: true
|
||||
data_clip_eps: 1e-5
|
||||
expectile: 0.9
|
||||
A_scaling: 3.0
|
||||
|
||||
# offline->online
|
||||
offline_steps: ${train_steps}/2
|
||||
pretrained_model_path: ""
|
||||
balanced_sampling: true
|
||||
demo_schedule: 0.5
|
||||
|
||||
# architecture
|
||||
enc_dim: 256
|
||||
|
@ -61,11 +91,8 @@ num_q: 5
|
|||
mlp_dim: 512
|
||||
latent_dim: 50
|
||||
|
||||
# wandb
|
||||
use_wandb: false
|
||||
wandb_project: FOWM
|
||||
wandb_entity: rcadene # insert your own
|
||||
|
||||
# xarm_lift
|
||||
A_scaling: 3.0
|
||||
expectile: 0.9
|
||||
episode_length: 25
|
||||
modality: 'all'
|
||||
action_repeat: 2
|
||||
discount: 0.9
|
|
@ -32,23 +32,25 @@ def eval_policy(
|
|||
ep_frames.append(frame)
|
||||
|
||||
tensordict = env.reset()
|
||||
if save_video:
|
||||
# render first frame before rollout
|
||||
rendering_callback(env)
|
||||
|
||||
rollout = env.rollout(
|
||||
max_steps=max_steps,
|
||||
policy=policy,
|
||||
callback=rendering_callback,
|
||||
callback=rendering_callback if save_video else None,
|
||||
auto_reset=False,
|
||||
tensordict=tensordict,
|
||||
)
|
||||
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
|
||||
ep_reward = rollout["next", "reward"].sum()
|
||||
ep_success = rollout["next", "success"].any()
|
||||
rewards.append(ep_reward.item())
|
||||
successes.append(ep_success.item())
|
||||
|
||||
if save_video:
|
||||
video_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
# TODO(rcadene): make fps configurable
|
||||
video_path = video_dir / f"eval_episode_{i}.mp4"
|
||||
imageio.mimsave(video_path, np.stack(ep_frames), fps=15)
|
||||
|
@ -82,8 +84,8 @@ def eval(cfg: dict):
|
|||
metrics = eval_policy(
|
||||
env,
|
||||
policy=policy,
|
||||
num_episodes=10,
|
||||
save_video=True,
|
||||
num_episodes=20,
|
||||
save_video=False,
|
||||
video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
|
||||
)
|
||||
print(metrics)
|
||||
|
|
|
@ -1,11 +1,24 @@
|
|||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict.nn import TensorDictModule
|
||||
from termcolor import colored
|
||||
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
|
||||
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
|
||||
from torchrl.data.datasets.openx import OpenXExperienceReplay
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
||||
|
||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.tdmpc import TDMPC
|
||||
|
||||
from ..common.utils import set_seed
|
||||
from lerobot.common.utils import set_seed
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_name="default", config_path="../configs")
|
||||
|
@ -15,22 +28,169 @@ def train(cfg: dict):
|
|||
print(colored("Work dir:", "yellow", attrs=["bold"]), cfg.log_dir)
|
||||
|
||||
env = make_env(cfg)
|
||||
agent = TDMPC(cfg)
|
||||
policy = TDMPC(cfg)
|
||||
# ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
|
||||
ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
|
||||
agent.load(ckpt_path)
|
||||
policy.load(ckpt_path)
|
||||
|
||||
# online training
|
||||
|
||||
eval_metrics = train_agent(
|
||||
env,
|
||||
agent,
|
||||
num_episodes=10,
|
||||
save_video=True,
|
||||
video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
|
||||
td_policy = TensorDictModule(
|
||||
policy,
|
||||
in_keys=["observation", "step_count"],
|
||||
out_keys=["action"],
|
||||
)
|
||||
|
||||
print(eval_metrics)
|
||||
# initialize offline dataset
|
||||
|
||||
dataset_id = f"xarm_{cfg.task}_medium"
|
||||
|
||||
num_traj_per_batch = cfg.batch_size # // cfg.horizon
|
||||
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
||||
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
|
||||
sampler = PrioritizedSliceSampler(
|
||||
max_capacity=100_000,
|
||||
alpha=0.7,
|
||||
beta=0.9,
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
|
||||
# TODO(rcadene): use PrioritizedReplayBuffer
|
||||
offline_buffer = SimxarmExperienceReplay(
|
||||
dataset_id,
|
||||
# download="force",
|
||||
download=True,
|
||||
streaming=False,
|
||||
root="data",
|
||||
sampler=sampler,
|
||||
)
|
||||
|
||||
num_steps = len(offline_buffer)
|
||||
index = torch.arange(0, num_steps, 1)
|
||||
sampler.extend(index)
|
||||
|
||||
# offline_buffer._storage.device = torch.device("cuda")
|
||||
# offline_buffer._storage._storage.to(torch.device("cuda"))
|
||||
# TODO(rcadene): add online_buffer
|
||||
|
||||
# Observation encoder
|
||||
# Dynamics predictor
|
||||
# Reward predictor
|
||||
# Policy
|
||||
# Qs state-action value predictor
|
||||
# V state value predictor
|
||||
|
||||
L = Logger(cfg.log_dir, cfg)
|
||||
|
||||
episode_idx = 0
|
||||
start_time = time.time()
|
||||
step = 0
|
||||
last_log_step = 0
|
||||
last_save_step = 0
|
||||
|
||||
while step < cfg.train_steps:
|
||||
is_offline = True
|
||||
num_updates = cfg.episode_length
|
||||
_step = step + num_updates
|
||||
rollout_metrics = {}
|
||||
|
||||
# if step >= cfg.offline_steps:
|
||||
# is_offline = False
|
||||
|
||||
# # Collect trajectory
|
||||
# obs = env.reset()
|
||||
# episode = Episode(cfg, obs)
|
||||
# success = False
|
||||
# while not episode.done:
|
||||
# action = policy.act(obs, step=step, t0=episode.first)
|
||||
# obs, reward, done, info = env.step(action.cpu().numpy())
|
||||
# reward = reward_normalizer(reward)
|
||||
# mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0
|
||||
# success = info.get('success', False)
|
||||
# episode += (obs, action, reward, done, mask, success)
|
||||
# assert len(episode) <= cfg.episode_length
|
||||
# buffer += episode
|
||||
# episode_idx += 1
|
||||
# rollout_metrics = {
|
||||
# 'episode_reward': episode.cumulative_reward,
|
||||
# 'episode_success': float(success),
|
||||
# 'episode_length': len(episode)
|
||||
# }
|
||||
# num_updates = len(episode) * cfg.utd
|
||||
# _step = min(step + len(episode), cfg.train_steps)
|
||||
|
||||
# Update model
|
||||
train_metrics = {}
|
||||
if is_offline:
|
||||
for i in range(num_updates):
|
||||
train_metrics.update(policy.update(offline_buffer, step + i))
|
||||
# else:
|
||||
# for i in range(num_updates):
|
||||
# train_metrics.update(
|
||||
# policy.update(buffer, step + i // cfg.utd,
|
||||
# demo_buffer=offline_buffer if cfg.balanced_sampling else None)
|
||||
# )
|
||||
|
||||
# Log training metrics
|
||||
env_step = int(_step * cfg.action_repeat)
|
||||
common_metrics = {
|
||||
"episode": episode_idx,
|
||||
"step": _step,
|
||||
"env_step": env_step,
|
||||
"total_time": time.time() - start_time,
|
||||
"is_offline": float(is_offline),
|
||||
}
|
||||
train_metrics.update(common_metrics)
|
||||
train_metrics.update(rollout_metrics)
|
||||
L.log(train_metrics, category="train")
|
||||
|
||||
# Evaluate policy periodically
|
||||
if step == 0 or env_step - last_log_step >= cfg.eval_freq:
|
||||
|
||||
eval_metrics = eval_policy(
|
||||
env,
|
||||
td_policy,
|
||||
num_episodes=cfg.eval_episodes,
|
||||
# TODO(rcadene): add step, env_step, L.video
|
||||
)
|
||||
|
||||
# TODO(rcadene):
|
||||
# if hasattr(env, "get_normalized_score"):
|
||||
# eval_metrics['normalized_score'] = env.get_normalized_score(eval_metrics["episode_reward"]) * 100.0
|
||||
|
||||
common_metrics.update(eval_metrics)
|
||||
|
||||
L.log(common_metrics, category="eval")
|
||||
last_log_step = env_step - env_step % cfg.eval_freq
|
||||
|
||||
# Save model periodically
|
||||
# if cfg.save_model and env_step - last_save_step >= cfg.save_freq:
|
||||
# L.save_model(policy, identifier=env_step)
|
||||
# print(f"Model has been checkpointed at step {env_step}")
|
||||
# last_save_step = env_step - env_step % cfg.save_freq
|
||||
|
||||
# if cfg.save_model and is_offline and _step >= cfg.offline_steps:
|
||||
# # save the model after offline training
|
||||
# L.save_model(policy, identifier="offline")
|
||||
|
||||
step = _step
|
||||
|
||||
# dataset_d4rl = D4RLExperienceReplay(
|
||||
# dataset_id="maze2d-umaze-v1",
|
||||
# split_trajs=False,
|
||||
# batch_size=1,
|
||||
# sampler=SamplerWithoutReplacement(drop_last=False),
|
||||
# prefetch=4,
|
||||
# direct_download=True,
|
||||
# )
|
||||
|
||||
# dataset_openx = OpenXExperienceReplay(
|
||||
# "cmu_stretch",
|
||||
# batch_size=1,
|
||||
# num_slices=1,
|
||||
# #download="force",
|
||||
# streaming=False,
|
||||
# root="data",
|
||||
# )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,80 +0,0 @@
|
|||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import imageio
|
||||
import simxarm
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
task = "lift"
|
||||
dataset_dir = Path(f"data/xarm_{task}_medium")
|
||||
dataset_path = dataset_dir / f"buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
|
||||
required_keys = [
|
||||
"observations",
|
||||
"next_observations",
|
||||
"actions",
|
||||
"rewards",
|
||||
"dones",
|
||||
"masks",
|
||||
]
|
||||
for k in required_keys:
|
||||
if k not in dataset_dict and k[:-1] in dataset_dict:
|
||||
dataset_dict[k] = dataset_dict.pop(k[:-1])
|
||||
|
||||
out_dir = Path("tmp/2023_01_26_xarm_lift_medium")
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
frames = dataset_dict["observations"]["rgb"][:100]
|
||||
frames = frames.transpose(0, 2, 3, 1)
|
||||
imageio.mimsave(out_dir / "test.mp4", frames, fps=30)
|
||||
|
||||
frames = []
|
||||
cfg = {}
|
||||
|
||||
env = simxarm.make(
|
||||
task=task,
|
||||
obs_mode="all",
|
||||
image_size=84,
|
||||
action_repeat=cfg.get("action_repeat", 1),
|
||||
frame_stack=cfg.get("frame_stack", 1),
|
||||
seed=1,
|
||||
)
|
||||
|
||||
obs = env.reset()
|
||||
frame = env.render(mode="rgb_array", width=384, height=384)
|
||||
frames.append(frame)
|
||||
|
||||
# def is_first_obs(obs):
|
||||
# nonlocal first_obs
|
||||
# print(((dataset_dict["observations"]["state"][i]-obs["state"])**2).sum())
|
||||
# print(((dataset_dict["observations"]["rgb"][i]-obs["rgb"])**2).sum())
|
||||
|
||||
for i in range(25):
|
||||
action = dataset_dict["actions"][i]
|
||||
|
||||
print(f"#{i}")
|
||||
# print(obs["state"])
|
||||
# print(dataset_dict["observations"]["state"][i])
|
||||
print(((dataset_dict["observations"]["state"][i] - obs["state"]) ** 2).sum())
|
||||
print(((dataset_dict["observations"]["rgb"][i] - obs["rgb"]) ** 2).sum())
|
||||
|
||||
obs, reward, done, info = env.step(action)
|
||||
frame = env.render(mode="rgb_array", width=384, height=384)
|
||||
frames.append(frame)
|
||||
|
||||
print(reward)
|
||||
print(dataset_dict["rewards"][i])
|
||||
|
||||
print(done)
|
||||
print(dataset_dict["dones"][i])
|
||||
|
||||
if dataset_dict["dones"][i]:
|
||||
obs = env.reset()
|
||||
frame = env.render(mode="rgb_array", width=384, height=384)
|
||||
frames.append(frame)
|
||||
|
||||
# imageio.mimsave(out_dir / 'test_rollout.mp4', frames, fps=60)
|
|
@ -0,0 +1,59 @@
|
|||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import imageio
|
||||
import simxarm
|
||||
import torch
|
||||
from torchrl.data.replay_buffers import (
|
||||
SamplerWithoutReplacement,
|
||||
SliceSampler,
|
||||
SliceSamplerWithoutReplacement,
|
||||
)
|
||||
|
||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
|
||||
|
||||
def visualize_simxarm_dataset(dataset_id="xarm_lift_medium"):
|
||||
sampler = SliceSamplerWithoutReplacement(
|
||||
num_slices=1,
|
||||
strict_length=False,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
dataset = SimxarmExperienceReplay(
|
||||
dataset_id,
|
||||
# download="force",
|
||||
download=True,
|
||||
streaming=False,
|
||||
root="data",
|
||||
sampler=sampler,
|
||||
)
|
||||
|
||||
NUM_EPISODES_TO_RENDER = 10
|
||||
MAX_NUM_STEPS = 50
|
||||
FIRST_FRAME = 0
|
||||
for _ in range(NUM_EPISODES_TO_RENDER):
|
||||
episode = dataset.sample(MAX_NUM_STEPS)
|
||||
|
||||
ep_idx = episode["episode"][FIRST_FRAME].item()
|
||||
ep_frames = torch.cat(
|
||||
[
|
||||
episode["observation"]["image"][FIRST_FRAME][None, ...],
|
||||
episode["next", "observation"]["image"],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
video_dir = Path("tmp/2024_02_03_xarm_lift_medium")
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
# TODO(rcadene): make fps configurable
|
||||
video_path = video_dir / f"eval_episode_{ep_idx}.mp4"
|
||||
imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=15)
|
||||
|
||||
# ran out of episodes
|
||||
if dataset._sampler._sample_list.numel() == 0:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
visualize_simxarm_dataset()
|
Loading…
Reference in New Issue