Eval reproduced! Train running (but not reproduced)
parent 937b2f8cba
commit 228c045674
@@ -5,6 +5,7 @@ wandb
 data
 outputs
 .vscode
+rl
 
 # HPC
 nautilus/*.yaml
@@ -15,6 +15,7 @@ conda activate lerobot
 python setup.py develop
 ```
 
+
 ## Contribute
 
 **style**
@@ -0,0 +1,190 @@
import os
import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Tuple

import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
    TensorDictPrioritizedReplayBuffer,
    TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import (
    Sampler,
    SliceSampler,
    SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer


class SimxarmExperienceReplay(TensorDictReplayBuffer):

    available_datasets = [
        "xarm_lift_medium",
    ]

    def __init__(
        self,
        dataset_id,
        batch_size: int = None,
        *,
        shuffle: bool = True,
        num_slices: int = None,
        slice_len: int = None,
        pad: float = None,
        replacement: bool = None,
        streaming: bool = False,
        root: Path = None,
        download: bool = False,
        sampler: Sampler = None,
        writer: Writer = None,
        collate_fn: Callable = None,
        pin_memory: bool = False,
        prefetch: int = None,
        transform: "torchrl.envs.Transform" = None,  # noqa: F821
        split_trajs: bool = False,
        strict_length: bool = True,
    ):
        self.download = download
        if streaming:
            raise NotImplementedError
        self.streaming = streaming
        self.dataset_id = dataset_id
        self.split_trajs = split_trajs
        self.shuffle = shuffle
        self.num_slices = num_slices
        self.slice_len = slice_len
        self.pad = pad

        self.strict_length = strict_length
        if (self.num_slices is not None) and (self.slice_len is not None):
            raise ValueError("num_slices or slice_len can be not None, but not both.")
        if split_trajs:
            raise NotImplementedError

        if root is None:
            root = _get_root_dir("simxarm")
            os.makedirs(root, exist_ok=True)
        self.root = Path(root)
        if self.download == "force" or (self.download and not self._is_downloaded()):
            storage = self._download_and_preproc()
        else:
            storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))

        if num_slices is not None or slice_len is not None:
            if sampler is not None:
                raise ValueError(
                    "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
                )

            if replacement:
                if not self.shuffle:
                    raise RuntimeError(
                        "shuffle=False can only be used when replacement=False."
                    )
                sampler = SliceSampler(
                    num_slices=num_slices,
                    slice_len=slice_len,
                    strict_length=strict_length,
                )
            else:
                sampler = SliceSamplerWithoutReplacement(
                    num_slices=num_slices,
                    slice_len=slice_len,
                    strict_length=strict_length,
                    shuffle=self.shuffle,
                )

        if writer is None:
            writer = ImmutableDatasetWriter()
        if collate_fn is None:
            collate_fn = _collate_id

        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=writer,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            batch_size=batch_size,
            transform=transform,
        )

    @property
    def data_path_root(self):
        if self.streaming:
            return None
        return self.root / self.dataset_id

    def _is_downloaded(self):
        return os.path.exists(self.data_path_root)

    def _download_and_preproc(self):
        # download
        # TODO(rcadene)

        # load
        dataset_dir = Path("data") / self.dataset_id
        dataset_path = dataset_dir / "buffer.pkl"
        print(f"Using offline dataset '{dataset_path}'")
        with open(dataset_path, "rb") as f:
            dataset_dict = pickle.load(f)

        total_frames = dataset_dict["actions"].shape[0]

        idx0 = 0
        idx1 = 0
        episode_id = 0
        for i in tqdm.tqdm(range(total_frames)):
            idx1 += 1

            if not dataset_dict["dones"][i]:
                continue

            num_frames = idx1 - idx0

            image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
            state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
            next_image = torch.tensor(
                dataset_dict["next_observations"]["rgb"][idx0:idx1]
            )
            next_state = torch.tensor(
                dataset_dict["next_observations"]["state"][idx0:idx1]
            )
            next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
            next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])

            episode = TensorDict(
                {
                    ("observation", "image"): image,
                    ("observation", "state"): state,
                    "action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
                    "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
                    "frame_id": torch.arange(0, num_frames, 1),
                    ("next", "observation", "image"): next_image,
                    ("next", "observation", "state"): next_state,
                    ("next", "observation", "reward"): next_reward,
                    ("next", "observation", "done"): next_done,
                },
                batch_size=num_frames,
            )

            if episode_id == 0:
                # hack to initialize tensordict data structure to store episodes
                td_data = (
                    episode[0]
                    .expand(total_frames)
                    .memmap_like(self.root / self.dataset_id)
                )

            td_data[idx0:idx1] = episode

            episode_id += 1
            idx0 = idx1

        return TensorStorage(td_data.lock_())
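The class above wraps the pickled `xarm_lift_medium` buffer as a memmapped torchrl replay buffer. A minimal consumption sketch, mirroring how the visualization script added later in this commit uses it; it assumes `data/xarm_lift_medium/buffer.pkl` is already in place and that the class lives at `lerobot.common.datasets.simxarm`, the import path the new train script uses:

```python
from torchrl.data.replay_buffers.samplers import SliceSamplerWithoutReplacement

from lerobot.common.datasets.simxarm import SimxarmExperienceReplay

# Sample whole trajectory slices, one episode at a time, in original order.
sampler = SliceSamplerWithoutReplacement(num_slices=1, strict_length=False, shuffle=False)

offline_buffer = SimxarmExperienceReplay(
    "xarm_lift_medium",
    download=True,      # converts data/xarm_lift_medium/buffer.pkl into a memmapped TensorDict on first use
    streaming=False,
    root="data",
    sampler=sampler,
)

batch = offline_buffer.sample(50)            # a TensorDict of up to 50 consecutive frames
print(batch["observation", "image"].shape)   # e.g. torch.Size([50, C, H, W])
print(batch["episode"][0].item(), batch["frame_id"][:5])
```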
@@ -7,6 +7,7 @@ def make_env(cfg):
     assert cfg.env == "simxarm"
     env = SimxarmEnv(
         task=cfg.task,
+        frame_skip=cfg.action_repeat,
         from_pixels=cfg.from_pixels,
         pixels_only=cfg.pixels_only,
         image_size=cfg.image_size,
@@ -24,6 +24,7 @@ class SimxarmEnv(EnvBase):
     def __init__(
         self,
         task,
+        frame_skip: int = 1,
         from_pixels: bool = False,
         pixels_only: bool = False,
         image_size=None,
@@ -32,6 +33,7 @@ class SimxarmEnv(EnvBase):
     ):
         super().__init__(device=device, batch_size=[])
         self.task = task
+        self.frame_skip = frame_skip
         self.from_pixels = from_pixels
         self.pixels_only = pixels_only
         self.image_size = image_size
@@ -115,12 +117,15 @@ class SimxarmEnv(EnvBase):
         # step expects shape=(4,) so we pad if necessary
         action = np.concatenate([action, self._action_padding])
         # TODO(rcadene): add info["is_success"] and info["success"] ?
+        sum_reward = 0
+        for t in range(self.frame_skip):
-        raw_obs, reward, done, info = self._env.step(action)
+            raw_obs, reward, done, info = self._env.step(action)
+            sum_reward += reward
 
         td = TensorDict(
             {
                 "observation": self._format_raw_obs(raw_obs),
-                "reward": torch.tensor([reward], dtype=torch.float32),
+                "reward": torch.tensor([sum_reward], dtype=torch.float32),
                 "done": torch.tensor([done], dtype=torch.bool),
                 "success": torch.tensor([info["success"]], dtype=torch.bool),
             },
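The new `frame_skip` argument repeats a single policy action for `cfg.action_repeat` simulator steps and returns the summed reward. A standalone sketch of that accumulation, using a hypothetical stub in place of the real simxarm backend:

```python
class StubEnv:
    """Stand-in for the wrapped simulator; each step returns a reward of 0.1."""

    def step(self, action):
        obs, reward, done, info = {"state": None}, 0.1, False, {"success": False}
        return obs, reward, done, info


frame_skip = 2  # what cfg.action_repeat feeds into SimxarmEnv
env = StubEnv()

sum_reward = 0.0
for _ in range(frame_skip):
    raw_obs, reward, done, info = env.step(action=None)
    sum_reward += reward

print(sum_reward)  # 0.2 -> the reward written into the TensorDict for this step
```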
@@ -0,0 +1,243 @@
import datetime
import os
import re
from pathlib import Path

import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from termcolor import colored

CONSOLE_FORMAT = [
    ("episode", "E", "int"),
    ("env_step", "S", "int"),
    ("avg_reward", "R", "float"),
    ("pc_success", "R", "float"),
    ("total_time", "T", "time"),
]
AGENT_METRICS = [
    "consistency_loss",
    "reward_loss",
    "value_loss",
    "total_loss",
    "weighted_loss",
    "pi_loss",
    "grad_norm",
]


def make_dir(dir_path):
    """Create directory if it does not already exist."""
    try:
        dir_path.mkdir(parents=True, exist_ok=True)
    except OSError:
        pass
    return dir_path


def print_run(cfg, reward=None):
    """Pretty-printing of run information. Call at start of training."""
    prefix, color, attrs = " ", "green", ["bold"]

    def limstr(s, maxlen=32):
        return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s

    def pprint(k, v):
        print(
            prefix + colored(f'{k.capitalize() + ":":<16}', color, attrs=attrs),
            limstr(v),
        )

    kvs = [
        ("task", cfg.task),
        ("train steps", f"{int(cfg.train_steps * cfg.action_repeat):,}"),
        # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
        # ('actions', cfg.action_dim),
        # ('experiment', cfg.exp_name),
    ]
    if reward is not None:
        kvs.append(
            ("episode reward", colored(str(int(reward)), "white", attrs=["bold"]))
        )
    w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21
    div = "-" * w
    print(div)
    for k, v in kvs:
        pprint(k, v)
    print(div)


def cfg_to_group(cfg, return_list=False):
    """Return a wandb-safe group name for logging. Optionally returns group name as list."""
    lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
    return lst if return_list else "-".join(lst)


class VideoRecorder:
    """Utility class for logging evaluation videos."""

    def __init__(self, root_dir, wandb, render_size=384, fps=15):
        self.save_dir = (root_dir / "eval_video") if root_dir else None
        self._wandb = wandb
        self.render_size = render_size
        self.fps = fps
        self.frames = []
        self.enabled = False
        self.camera_id = 0

    def init(self, env, enabled=True):
        self.frames = []
        self.enabled = self.save_dir and self._wandb and enabled
        try:
            env_name = env.unwrapped.spec.id
        except Exception:
            env_name = ""
        if "maze2d" in env_name:
            self.camera_id = -1
        elif "quadruped" in env_name:
            self.camera_id = 2
        self.record(env)

    def record(self, env):
        if self.enabled:
            frame = env.render(
                mode="rgb_array",
                height=self.render_size,
                width=self.render_size,
                camera_id=self.camera_id,
            )
            self.frames.append(frame)

    def save(self, step):
        if self.enabled:
            frames = np.stack(self.frames).transpose(0, 3, 1, 2)
            self._wandb.log(
                {"eval_video": self._wandb.Video(frames, fps=self.fps, format="mp4")},
                step=step,
            )


class Logger(object):
    """Primary logger object. Logs either locally or using wandb."""

    def __init__(self, log_dir, cfg):
        self._log_dir = make_dir(Path(log_dir))
        self._model_dir = make_dir(self._log_dir / "models")
        self._buffer_dir = make_dir(self._log_dir / "buffers")
        self._save_model = cfg.save_model
        self._save_buffer = cfg.save_buffer
        self._group = cfg_to_group(cfg)
        self._seed = cfg.seed
        self._cfg = cfg
        self._eval = []
        print_run(cfg)
        project, entity = cfg.get("wandb_project", "none"), cfg.get(
            "wandb_entity", "none"
        )
        run_offline = (
            not cfg.get("use_wandb", False) or project == "none" or entity == "none"
        )
        if run_offline:
            print(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
            self._wandb = None
        else:
            try:
                os.environ["WANDB_SILENT"] = "true"
                import wandb

                wandb.init(
                    project=project,
                    entity=entity,
                    name=str(cfg.seed),
                    notes=cfg.notes,
                    group=self._group,
                    tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"],
                    dir=self._log_dir,
                    config=OmegaConf.to_container(cfg, resolve=True),
                )
                print(
                    colored("Logs will be synced with wandb.", "blue", attrs=["bold"])
                )
                self._wandb = wandb
            except Exception:
                print(
                    colored(
                        "Warning: failed to init wandb. Make sure `wandb_entity` is set to your username in `config.yaml`. Logs will be saved locally.",
                        "yellow",
                        attrs=["bold"],
                    )
                )
                self._wandb = None
        self._video = (
            VideoRecorder(log_dir, self._wandb)
            if self._wandb and cfg.save_video
            else None
        )

    @property
    def video(self):
        return self._video

    def save_model(self, agent, identifier):
        if self._save_model:
            fp = self._model_dir / f"{str(identifier)}.pt"
            agent.save(fp)
            if self._wandb:
                artifact = self._wandb.Artifact(
                    self._group + "-" + str(self._seed) + "-" + str(identifier),
                    type="model",
                )
                artifact.add_file(fp)
                self._wandb.log_artifact(artifact)

    def save_buffer(self, buffer, identifier):
        fp = self._buffer_dir / f"{str(identifier)}.pkl"
        buffer.save(fp)
        if self._wandb:
            artifact = self._wandb.Artifact(
                self._group + "-" + str(self._seed) + "-" + str(identifier),
                type="buffer",
            )
            artifact.add_file(fp)
            self._wandb.log_artifact(artifact)

    def finish(self, agent, buffer):
        if self._save_model:
            self.save_model(agent, identifier="final")
        if self._save_buffer:
            self.save_buffer(buffer, identifier="buffer")
        if self._wandb:
            self._wandb.finish()
        print_run(self._cfg, self._eval[-1][-1])

    def _format(self, key, value, ty):
        if ty == "int":
            return f'{colored(key + ":", "grey")} {int(value):,}'
        elif ty == "float":
            return f'{colored(key + ":", "grey")} {value:.01f}'
        elif ty == "time":
            value = str(datetime.timedelta(seconds=int(value)))
            return f'{colored(key + ":", "grey")} {value}'
        else:
            raise ValueError(f"invalid log format type: {ty}")

    def _print(self, d, category):
        category = colored(category, "blue" if category == "train" else "green")
        pieces = [f" {category:<14}"]
        for k, disp_k, ty in CONSOLE_FORMAT:
            pieces.append(f"{self._format(disp_k, d.get(k, 0), ty):<26}")
        print(" ".join(pieces))

    def log(self, d, category="train"):
        assert category in {"train", "eval"}
        if self._wandb is not None:
            for k, v in d.items():
                self._wandb.log({category + "/" + k: v}, step=d["env_step"])
        if category == "eval":
            # keys = ['env_step', 'avg_reward']
            keys = ["env_step", "avg_reward", "pc_success"]
            self._eval.append(np.array([d[key] for key in keys]))
            pd.DataFrame(np.array(self._eval)).to_csv(
                self._log_dir / "eval.log", header=keys, index=None
            )
        self._print(d, category)
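A minimal sketch of driving this logger locally (no wandb), assuming the file lands at `lerobot.common.logger` as imported by the new train script, and that the config carries the fields `Logger` and `print_run` read:

```python
from omegaconf import OmegaConf

from lerobot.common.logger import Logger

# Hypothetical minimal config with only the fields the logger touches offline.
cfg = OmegaConf.create(
    {
        "task": "lift",
        "modality": "all",
        "exp_name": "default",
        "seed": 1337,
        "train_steps": 50000,
        "action_repeat": 2,
        "save_model": False,
        "save_buffer": False,
        "save_video": False,
        "use_wandb": False,  # keeps the run fully local
        "notes": "",
    }
)

L = Logger("logs/example_run", cfg)
L.log(
    {"episode": 0, "env_step": 0, "avg_reward": 0.0, "pc_success": 0.0, "total_time": 0.0},
    category="train",
)
```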
@@ -1,5 +1,6 @@
 from copy import deepcopy
 
+import einops
 import numpy as np
 import torch
 import torch.nn as nn
@@ -90,7 +91,7 @@ class TDMPC(nn.Module):
         self.model_target = deepcopy(self.model)
         self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
         self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
-        self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
+        # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
         self.model.eval()
         self.model_target.eval()
         self.batch_size = cfg.batch_size
@@ -308,9 +309,41 @@ class TDMPC(nn.Module):
            self.demo_batch_size = 0
 
        # Sample from interaction dataset
-       obs, next_obses, action, reward, mask, done, idxs, weights = (
-           replay_buffer.sample()
-       )
+       # to not have to mask
+       # batch_size = (self.cfg.batch_size // self.cfg.horizon) * self.cfg.horizon
+       batch_size = self.cfg.horizon * self.cfg.batch_size
+       batch = replay_buffer.sample(batch_size)
+
+       # trajectory t = 256, horizon h = 5
+       # (t h) ... -> h t ...
+       batch = (
+           batch.reshape(self.cfg.batch_size, self.cfg.horizon)
+           .transpose(1, 0)
+           .contiguous()
+       )
+       batch = batch.to("cuda")
+
+       FIRST_FRAME = 0
+       obs = {
+           "rgb": batch["observation", "image"][FIRST_FRAME].float(),
+           "state": batch["observation", "state"][FIRST_FRAME],
+       }
+       action = batch["action"]
+       next_obses = {
+           "rgb": batch["next", "observation", "image"].float(),
+           "state": batch["next", "observation", "state"],
+       }
+       reward = batch["next", "reward"]
+       reward = einops.rearrange(reward, "h t -> h t 1")
+       # We don't use `batch["next", "done"]` since it only indicates the end of an
+       # episode, but not the end of the trajectory of an episode.
+       # Neither does `batch["next", "terminated"]`.
+       done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
+       mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+
+       idxs = batch["frame_id"][FIRST_FRAME]
+       weights = batch["_weight"][FIRST_FRAME, :, None]
 
        # Sample from demonstration dataset
        if self.demo_batch_size > 0:
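The sampler returns trajectory slices flattened along the batch dimension, and the update reorders them so the horizon comes first. A small self-contained sketch of that `(t h) -> h t` reshape on a dummy tensor:

```python
import einops
import torch

batch_size, horizon = 4, 5  # stand-ins for cfg.batch_size and cfg.horizon
# The sampler returns slices flattened as (batch_size * horizon, ...):
reward = torch.arange(batch_size * horizon, dtype=torch.float32)

# Group frames per trajectory, then put the horizon dimension first: (t h) -> h t
reward_ht = reward.reshape(batch_size, horizon).transpose(1, 0).contiguous()
print(reward_ht.shape)  # torch.Size([5, 4])

# The TD-MPC update additionally expects a trailing singleton dim: h t -> h t 1
reward_ht1 = einops.rearrange(reward_ht, "h t -> h t 1")
print(reward_ht1.shape)  # torch.Size([5, 4, 1])
```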
@@ -341,6 +374,21 @@ class TDMPC(nn.Module):
            idxs = torch.cat([idxs, demo_idxs])
            weights = torch.cat([weights, demo_weights])
 
+       # Apply augmentations
+       aug_tf = h.aug(self.cfg)
+       obs = aug_tf(obs)
+
+       for k in next_obses:
+           next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...")
+       next_obses = aug_tf(next_obses)
+       for k in next_obses:
+           next_obses[k] = einops.rearrange(
+               next_obses[k],
+               "(h t) ... -> h t ...",
+               h=self.cfg.horizon,
+               t=self.cfg.batch_size,
+           )
+
        horizon = self.cfg.horizon
        loss_mask = torch.ones_like(mask, device=self.device)
        for t in range(1, horizon):
@@ -407,6 +455,7 @@ class TDMPC(nn.Module):
        weighted_loss = (total_loss.squeeze(1) * weights).mean()
        weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
        weighted_loss.backward()
+
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
        )
@@ -415,13 +464,16 @@ class TDMPC(nn.Module):
        if self.cfg.per:
            # Update priorities
            priorities = priority_loss.clamp(max=1e4).detach()
-           replay_buffer.update_priorities(
-               idxs[: replay_buffer.cfg.batch_size],
-               priorities[: replay_buffer.cfg.batch_size],
+           # normalize between [0,1] to fit torchrl specification
+           priorities /= 1e4
+           priorities = priorities.clamp(max=1.0)
+           replay_buffer.update_priority(
+               idxs[: self.cfg.batch_size],
+               priorities[: self.cfg.batch_size],
            )
            if self.demo_batch_size > 0:
-               demo_buffer.update_priorities(
-                   demo_idxs, priorities[replay_buffer.cfg.batch_size :]
+               demo_buffer.update_priority(
+                   demo_idxs, priorities[self.cfg.batch_size :]
                )
 
        # Update policy + target network
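The per-sample loss used as a priority can be arbitrarily large, so the change rescales it into [0, 1] before handing it to torchrl's `update_priority`. The same normalization on a dummy tensor:

```python
import torch

priority_loss = torch.tensor([0.5, 3.0, 2.5e4])  # dummy per-sample losses
priorities = priority_loss.clamp(max=1e4).detach()
priorities /= 1e4                       # rescale into [0, 1]
priorities = priorities.clamp(max=1.0)
print(priorities)  # tensor([5.0000e-05, 3.0000e-04, 1.0000e+00])
```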
@@ -306,13 +306,21 @@ class RandomShiftsAug(nn.Module):
        x = F.pad(x, padding, "replicate")
        eps = 1.0 / (h + 2 * self.pad)
        arange = torch.linspace(
-           -1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype
+           -1.0 + eps,
+           1.0 - eps,
+           h + 2 * self.pad,
+           device=x.device,
+           dtype=torch.float32,
        )[:h]
        arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
        base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
        base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
        shift = torch.randint(
-           0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype
+           0,
+           2 * self.pad + 1,
+           size=(n, 1, 1, 2),
+           device=x.device,
+           dtype=torch.float32,
        )
        shift *= 2.0 / (h + 2 * self.pad)
        grid = base_grid + shift
@@ -1,5 +1,14 @@
 seed: 1337
 log_dir: logs/2024_01_26_train
+exp_name: default
+device: cuda
+buffer_device: cuda
+eval_freq: 1000
+save_freq: 10000
+eval_episodes: 20
+save_video: false
+save_model: false
+save_buffer: false
 
 # env
 env: simxarm
@@ -8,6 +17,14 @@ from_pixels: True
 pixels_only: False
 image_size: 84
 
+reward_scale: 1.0
+
+# xarm_lift
+episode_length: 25
+modality: 'all'
+action_repeat: 2  # TODO(rcadene): verify we use this
+discount: 0.9
+train_steps: 50000
 
 # pixels
 frame_stack: 1
@@ -54,6 +71,19 @@ update_freq: 2
 tau: 0.01
 utd: 1
 
+# offline rl
+# dataset_dir: ???
+data_first_percent: 1.0
+is_data_clip: true
+data_clip_eps: 1e-5
+expectile: 0.9
+A_scaling: 3.0
+
+# offline->online
+offline_steps: ${train_steps}/2
+pretrained_model_path: ""
+balanced_sampling: true
+demo_schedule: 0.5
 
 # architecture
 enc_dim: 256
@@ -61,11 +91,8 @@ num_q: 5
 mlp_dim: 512
 latent_dim: 50
 
+# wandb
+use_wandb: false
+wandb_project: FOWM
+wandb_entity: rcadene  # insert your own
+
-# xarm_lift
-A_scaling: 3.0
-expectile: 0.9
-episode_length: 25
-modality: 'all'
-action_repeat: 2
-discount: 0.9
@@ -32,23 +32,25 @@ def eval_policy(
            ep_frames.append(frame)
 
        tensordict = env.reset()
+       if save_video:
            # render first frame before rollout
            rendering_callback(env)
 
        rollout = env.rollout(
            max_steps=max_steps,
            policy=policy,
-           callback=rendering_callback,
+           callback=rendering_callback if save_video else None,
            auto_reset=False,
            tensordict=tensordict,
        )
+       # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:, 0].tolist()]))
        ep_reward = rollout["next", "reward"].sum()
        ep_success = rollout["next", "success"].any()
        rewards.append(ep_reward.item())
        successes.append(ep_success.item())
 
        if save_video:
-           video_dir.parent.mkdir(parents=True, exist_ok=True)
+           video_dir.mkdir(parents=True, exist_ok=True)
            # TODO(rcadene): make fps configurable
            video_path = video_dir / f"eval_episode_{i}.mp4"
            imageio.mimsave(video_path, np.stack(ep_frames), fps=15)
@@ -82,8 +84,8 @@ def eval(cfg: dict):
    metrics = eval_policy(
        env,
        policy=policy,
-       num_episodes=10,
-       save_video=True,
+       num_episodes=20,
+       save_video=False,
        video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
    )
    print(metrics)
@@ -1,11 +1,24 @@
+import pickle
+import time
+from pathlib import Path
+
 import hydra
+import imageio
+import numpy as np
 import torch
+from tensordict.nn import TensorDictModule
 from termcolor import colored
+from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+from torchrl.data.datasets.d4rl import D4RLExperienceReplay
+from torchrl.data.datasets.openx import OpenXExperienceReplay
+from torchrl.data.replay_buffers import PrioritizedSliceSampler
+
+from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
 from lerobot.common.envs.factory import make_env
+from lerobot.common.logger import Logger
 from lerobot.common.tdmpc import TDMPC
-from ..common.utils import set_seed
+from lerobot.common.utils import set_seed
+from lerobot.scripts.eval import eval_policy
 
 
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
@@ -15,22 +28,169 @@ def train(cfg: dict):
    print(colored("Work dir:", "yellow", attrs=["bold"]), cfg.log_dir)
 
    env = make_env(cfg)
-   agent = TDMPC(cfg)
+   policy = TDMPC(cfg)
    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
    ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
-   agent.load(ckpt_path)
+   policy.load(ckpt_path)
 
-   # online training
-   eval_metrics = train_agent(
-       env,
-       agent,
-       num_episodes=10,
-       save_video=True,
-       video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
-   )
-   print(eval_metrics)
+   td_policy = TensorDictModule(
+       policy,
+       in_keys=["observation", "step_count"],
+       out_keys=["action"],
+   )
+
+   # initialize offline dataset
+   dataset_id = f"xarm_{cfg.task}_medium"
+
+   num_traj_per_batch = cfg.batch_size  # // cfg.horizon
+   # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
+   # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
+   sampler = PrioritizedSliceSampler(
+       max_capacity=100_000,
+       alpha=0.7,
+       beta=0.9,
+       num_slices=num_traj_per_batch,
+       strict_length=False,
+   )
+
+   # TODO(rcadene): use PrioritizedReplayBuffer
+   offline_buffer = SimxarmExperienceReplay(
+       dataset_id,
+       # download="force",
+       download=True,
+       streaming=False,
+       root="data",
+       sampler=sampler,
+   )
+
+   num_steps = len(offline_buffer)
+   index = torch.arange(0, num_steps, 1)
+   sampler.extend(index)
+
+   # offline_buffer._storage.device = torch.device("cuda")
+   # offline_buffer._storage._storage.to(torch.device("cuda"))
+   # TODO(rcadene): add online_buffer
+
+   # Observation encoder
+   # Dynamics predictor
+   # Reward predictor
+   # Policy
+   # Qs state-action value predictor
+   # V state value predictor
+
+   L = Logger(cfg.log_dir, cfg)
+
+   episode_idx = 0
+   start_time = time.time()
+   step = 0
+   last_log_step = 0
+   last_save_step = 0
+
+   while step < cfg.train_steps:
+       is_offline = True
+       num_updates = cfg.episode_length
+       _step = step + num_updates
+       rollout_metrics = {}
+
+       # if step >= cfg.offline_steps:
+       #     is_offline = False
+
+       #     # Collect trajectory
+       #     obs = env.reset()
+       #     episode = Episode(cfg, obs)
+       #     success = False
+       #     while not episode.done:
+       #         action = policy.act(obs, step=step, t0=episode.first)
+       #         obs, reward, done, info = env.step(action.cpu().numpy())
+       #         reward = reward_normalizer(reward)
+       #         mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0
+       #         success = info.get('success', False)
+       #         episode += (obs, action, reward, done, mask, success)
+       #     assert len(episode) <= cfg.episode_length
+       #     buffer += episode
+       #     episode_idx += 1
+       #     rollout_metrics = {
+       #         'episode_reward': episode.cumulative_reward,
+       #         'episode_success': float(success),
+       #         'episode_length': len(episode)
+       #     }
+       #     num_updates = len(episode) * cfg.utd
+       #     _step = min(step + len(episode), cfg.train_steps)
+
+       # Update model
+       train_metrics = {}
+       if is_offline:
+           for i in range(num_updates):
+               train_metrics.update(policy.update(offline_buffer, step + i))
+       # else:
+       #     for i in range(num_updates):
+       #         train_metrics.update(
+       #             policy.update(buffer, step + i // cfg.utd,
+       #                           demo_buffer=offline_buffer if cfg.balanced_sampling else None)
+       #         )
+
+       # Log training metrics
+       env_step = int(_step * cfg.action_repeat)
+       common_metrics = {
+           "episode": episode_idx,
+           "step": _step,
+           "env_step": env_step,
+           "total_time": time.time() - start_time,
+           "is_offline": float(is_offline),
+       }
+       train_metrics.update(common_metrics)
+       train_metrics.update(rollout_metrics)
+       L.log(train_metrics, category="train")
+
+       # Evaluate policy periodically
+       if step == 0 or env_step - last_log_step >= cfg.eval_freq:
+
+           eval_metrics = eval_policy(
+               env,
+               td_policy,
+               num_episodes=cfg.eval_episodes,
+               # TODO(rcadene): add step, env_step, L.video
+           )
+
+           # TODO(rcadene):
+           # if hasattr(env, "get_normalized_score"):
+           #     eval_metrics['normalized_score'] = env.get_normalized_score(eval_metrics["episode_reward"]) * 100.0
+
+           common_metrics.update(eval_metrics)
+
+           L.log(common_metrics, category="eval")
+           last_log_step = env_step - env_step % cfg.eval_freq
+
+       # Save model periodically
+       # if cfg.save_model and env_step - last_save_step >= cfg.save_freq:
+       #     L.save_model(policy, identifier=env_step)
+       #     print(f"Model has been checkpointed at step {env_step}")
+       #     last_save_step = env_step - env_step % cfg.save_freq
+
+       # if cfg.save_model and is_offline and _step >= cfg.offline_steps:
+       #     # save the model after offline training
+       #     L.save_model(policy, identifier="offline")
+
+       step = _step
+
+   # dataset_d4rl = D4RLExperienceReplay(
+   #     dataset_id="maze2d-umaze-v1",
+   #     split_trajs=False,
+   #     batch_size=1,
+   #     sampler=SamplerWithoutReplacement(drop_last=False),
+   #     prefetch=4,
+   #     direct_download=True,
+   # )
+
+   # dataset_openx = OpenXExperienceReplay(
+   #     "cmu_stretch",
+   #     batch_size=1,
+   #     num_slices=1,
+   #     # download="force",
+   #     streaming=False,
+   #     root="data",
+   # )
 
 
 if __name__ == "__main__":
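The offline loop above never steps the environment: each iteration runs `cfg.episode_length` model updates, and progress is converted to environment steps with `cfg.action_repeat` to decide when evaluation fires. A tiny standalone sketch of that bookkeeping, using the values from `default.yaml` above (with `train_steps` shortened for the example):

```python
episode_length = 25      # cfg.episode_length
action_repeat = 2        # cfg.action_repeat
eval_freq = 1000         # cfg.eval_freq
train_steps = 5000       # shortened stand-in for cfg.train_steps

step, last_log_step, evals = 0, 0, []
while step < train_steps:
    _step = step + episode_length            # one iteration = episode_length updates
    env_step = int(_step * action_repeat)    # "environment steps" used for logging
    if step == 0 or env_step - last_log_step >= eval_freq:
        evals.append(env_step)
        last_log_step = env_step - env_step % eval_freq
    step = _step

print(evals[:5])  # [50, 1000, 2000, 3000, 4000]
```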
@@ -1,80 +0,0 @@
import pickle
from pathlib import Path

import imageio
import simxarm

if __name__ == "__main__":

    task = "lift"
    dataset_dir = Path(f"data/xarm_{task}_medium")
    dataset_path = dataset_dir / f"buffer.pkl"
    print(f"Using offline dataset '{dataset_path}'")
    with open(dataset_path, "rb") as f:
        dataset_dict = pickle.load(f)

    required_keys = [
        "observations",
        "next_observations",
        "actions",
        "rewards",
        "dones",
        "masks",
    ]
    for k in required_keys:
        if k not in dataset_dict and k[:-1] in dataset_dict:
            dataset_dict[k] = dataset_dict.pop(k[:-1])

    out_dir = Path("tmp/2023_01_26_xarm_lift_medium")
    out_dir.mkdir(parents=True, exist_ok=True)

    frames = dataset_dict["observations"]["rgb"][:100]
    frames = frames.transpose(0, 2, 3, 1)
    imageio.mimsave(out_dir / "test.mp4", frames, fps=30)

    frames = []
    cfg = {}

    env = simxarm.make(
        task=task,
        obs_mode="all",
        image_size=84,
        action_repeat=cfg.get("action_repeat", 1),
        frame_stack=cfg.get("frame_stack", 1),
        seed=1,
    )

    obs = env.reset()
    frame = env.render(mode="rgb_array", width=384, height=384)
    frames.append(frame)

    # def is_first_obs(obs):
    #     nonlocal first_obs
    #     print(((dataset_dict["observations"]["state"][i]-obs["state"])**2).sum())
    #     print(((dataset_dict["observations"]["rgb"][i]-obs["rgb"])**2).sum())

    for i in range(25):
        action = dataset_dict["actions"][i]

        print(f"#{i}")
        # print(obs["state"])
        # print(dataset_dict["observations"]["state"][i])
        print(((dataset_dict["observations"]["state"][i] - obs["state"]) ** 2).sum())
        print(((dataset_dict["observations"]["rgb"][i] - obs["rgb"]) ** 2).sum())

        obs, reward, done, info = env.step(action)
        frame = env.render(mode="rgb_array", width=384, height=384)
        frames.append(frame)

        print(reward)
        print(dataset_dict["rewards"][i])

        print(done)
        print(dataset_dict["dones"][i])

        if dataset_dict["dones"][i]:
            obs = env.reset()
            frame = env.render(mode="rgb_array", width=384, height=384)
            frames.append(frame)

    # imageio.mimsave(out_dir / 'test_rollout.mp4', frames, fps=60)
@@ -0,0 +1,59 @@
import pickle
from pathlib import Path

import imageio
import simxarm
import torch
from torchrl.data.replay_buffers import (
    SamplerWithoutReplacement,
    SliceSampler,
    SliceSamplerWithoutReplacement,
)

from lerobot.common.datasets.simxarm import SimxarmExperienceReplay


def visualize_simxarm_dataset(dataset_id="xarm_lift_medium"):
    sampler = SliceSamplerWithoutReplacement(
        num_slices=1,
        strict_length=False,
        shuffle=False,
    )

    dataset = SimxarmExperienceReplay(
        dataset_id,
        # download="force",
        download=True,
        streaming=False,
        root="data",
        sampler=sampler,
    )

    NUM_EPISODES_TO_RENDER = 10
    MAX_NUM_STEPS = 50
    FIRST_FRAME = 0
    for _ in range(NUM_EPISODES_TO_RENDER):
        episode = dataset.sample(MAX_NUM_STEPS)

        ep_idx = episode["episode"][FIRST_FRAME].item()
        ep_frames = torch.cat(
            [
                episode["observation"]["image"][FIRST_FRAME][None, ...],
                episode["next", "observation"]["image"],
            ],
            dim=0,
        )

        video_dir = Path("tmp/2024_02_03_xarm_lift_medium")
        video_dir.mkdir(parents=True, exist_ok=True)
        # TODO(rcadene): make fps configurable
        video_path = video_dir / f"eval_episode_{ep_idx}.mp4"
        imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=15)

        # ran out of episodes
        if dataset._sampler._sample_list.numel() == 0:
            break


if __name__ == "__main__":
    visualize_simxarm_dataset()