Eval reproduced! Train running (but not reproduced)

Cadene 2024-02-10 15:46:24 +00:00
parent 937b2f8cba
commit 228c045674
14 changed files with 787 additions and 118 deletions

.gitignore

@ -5,6 +5,7 @@ wandb
data
outputs
.vscode
rl
# HPC
nautilus/*.yaml


@ -15,6 +15,7 @@ conda activate lerobot
python setup.py develop
```
## Contribute
**style**


@ -0,0 +1,190 @@
import os
import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import (
Sampler,
SliceSampler,
SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
class SimxarmExperienceReplay(TensorDictReplayBuffer):
available_datasets = [
"xarm_lift_medium",
]
def __init__(
self,
dataset_id,
batch_size: int = None,
*,
shuffle: bool = True,
num_slices: int = None,
slice_len: int = None,
pad: float = None,
replacement: bool = None,
streaming: bool = False,
root: Path = None,
download: bool = False,
sampler: Sampler = None,
writer: Writer = None,
collate_fn: Callable = None,
pin_memory: bool = False,
prefetch: int = None,
transform: "torchrl.envs.Transform" = None, # noqa: F821
split_trajs: bool = False,
strict_length: bool = True,
):
self.download = download
if streaming:
raise NotImplementedError
self.streaming = streaming
self.dataset_id = dataset_id
self.split_trajs = split_trajs
self.shuffle = shuffle
self.num_slices = num_slices
self.slice_len = slice_len
self.pad = pad
self.strict_length = strict_length
if (self.num_slices is not None) and (self.slice_len is not None):
raise ValueError("num_slices or slice_len can be not None, but not both.")
if split_trajs:
raise NotImplementedError
if root is None:
root = _get_root_dir("simxarm")
os.makedirs(root, exist_ok=True)
self.root = Path(root)
if self.download == "force" or (self.download and not self._is_downloaded()):
storage = self._download_and_preproc()
else:
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
if num_slices is not None or slice_len is not None:
if sampler is not None:
raise ValueError(
"`num_slices` and `slice_len` are exclusive with the `sampler` argument."
)
if replacement:
if not self.shuffle:
raise RuntimeError(
"shuffle=False can only be used when replacement=False."
)
sampler = SliceSampler(
num_slices=num_slices,
slice_len=slice_len,
strict_length=strict_length,
)
else:
sampler = SliceSamplerWithoutReplacement(
num_slices=num_slices,
slice_len=slice_len,
strict_length=strict_length,
shuffle=self.shuffle,
)
if writer is None:
writer = ImmutableDatasetWriter()
if collate_fn is None:
collate_fn = _collate_id
super().__init__(
storage=storage,
sampler=sampler,
writer=writer,
collate_fn=collate_fn,
pin_memory=pin_memory,
prefetch=prefetch,
batch_size=batch_size,
transform=transform,
)
@property
def data_path_root(self):
if self.streaming:
return None
return self.root / self.dataset_id
def _is_downloaded(self):
return os.path.exists(self.data_path_root)
def _download_and_preproc(self):
# TODO(rcadene): download the dataset from a remote host; for now we
# load the raw pickle from the local data directory
dataset_dir = Path("data") / self.dataset_id
dataset_path = dataset_dir / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
total_frames = dataset_dict["actions"].shape[0]
idx0 = 0
idx1 = 0
episode_id = 0
for i in tqdm.tqdm(range(total_frames)):
idx1 += 1
if not dataset_dict["dones"][i]:
continue
num_frames = idx1 - idx0
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
next_image = torch.tensor(
dataset_dict["next_observations"]["rgb"][idx0:idx1]
)
next_state = torch.tensor(
dataset_dict["next_observations"]["state"][idx0:idx1]
)
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
episode = TensorDict(
{
("observation", "image"): image,
("observation", "state"): state,
"action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
("next", "observation", "image"): next_image,
("next", "observation", "state"): next_state,
("next", "observation", "reward"): next_reward,
("next", "observation", "done"): next_done,
},
batch_size=num_frames,
)
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = (
episode[0]
.expand(total_frames)
.memmap_like(self.root / self.dataset_id)
)
td_data[idx0:idx1] = episode
episode_id += 1
idx0 = idx1
return TensorStorage(td_data.lock_())
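A minimal usage sketch (not part of the diff), assuming `data/xarm_lift_medium/buffer.pkl` already exists: the buffer converts the pickle to a memmap on first use, then serves trajectory slices.

```python
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay

# 8 slices of 5 consecutive frames each -> sample 40 frames at a time
buffer = SimxarmExperienceReplay(
    "xarm_lift_medium",
    download=True,  # runs _download_and_preproc on first use
    root="data",
    num_slices=8,
)
batch = buffer.sample(40)
print(batch["observation", "image"].shape)
```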


@ -7,6 +7,7 @@ def make_env(cfg):
assert cfg.env == "simxarm"
env = SimxarmEnv(
task=cfg.task,
frame_skip=cfg.action_repeat,
from_pixels=cfg.from_pixels,
pixels_only=cfg.pixels_only,
image_size=cfg.image_size,


@ -24,6 +24,7 @@ class SimxarmEnv(EnvBase):
def __init__(
self,
task,
frame_skip: int = 1,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
@ -32,6 +33,7 @@ class SimxarmEnv(EnvBase):
):
super().__init__(device=device, batch_size=[])
self.task = task
self.frame_skip = frame_skip
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
@ -115,12 +117,15 @@ class SimxarmEnv(EnvBase):
# step expects shape=(4,) so we pad if necessary
action = np.concatenate([action, self._action_padding])
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0
for t in range(self.frame_skip):
raw_obs, reward, done, info = self._env.step(action)
sum_reward += reward
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),
"reward": torch.tensor([reward], dtype=torch.float32),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([info["success"]], dtype=torch.bool),
},

lerobot/common/logger.py

@ -0,0 +1,243 @@
import datetime
import os
import re
from pathlib import Path
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from termcolor import colored
CONSOLE_FORMAT = [
("episode", "E", "int"),
("env_step", "S", "int"),
("avg_reward", "R", "float"),
("pc_success", "R", "float"),
("total_time", "T", "time"),
]
AGENT_METRICS = [
"consistency_loss",
"reward_loss",
"value_loss",
"total_loss",
"weighted_loss",
"pi_loss",
"grad_norm",
]
def make_dir(dir_path):
"""Create directory if it does not already exist."""
try:
dir_path.mkdir(parents=True, exist_ok=True)
except OSError:
pass
return dir_path
def print_run(cfg, reward=None):
"""Pretty-printing of run information. Call at start of training."""
prefix, color, attrs = " ", "green", ["bold"]
def limstr(s, maxlen=32):
s = str(s)
return s[:maxlen] + "..." if len(s) > maxlen else s
def pprint(k, v):
print(
prefix + colored(f'{k.capitalize() + ":":<16}', color, attrs=attrs),
limstr(v),
)
kvs = [
("task", cfg.task),
("train steps", f"{int(cfg.train_steps * cfg.action_repeat):,}"),
# ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
# ('actions', cfg.action_dim),
# ('experiment', cfg.exp_name),
]
if reward is not None:
kvs.append(
("episode reward", colored(str(int(reward)), "white", attrs=["bold"]))
)
w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21
div = "-" * w
print(div)
for k, v in kvs:
pprint(k, v)
print(div)
def cfg_to_group(cfg, return_list=False):
"""Return a wandb-safe group name for logging. Optionally returns group name as list."""
lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
return lst if return_list else "-".join(lst)
class VideoRecorder:
"""Utility class for logging evaluation videos."""
def __init__(self, root_dir, wandb, render_size=384, fps=15):
self.save_dir = (root_dir / "eval_video") if root_dir else None
self._wandb = wandb
self.render_size = render_size
self.fps = fps
self.frames = []
self.enabled = False
self.camera_id = 0
def init(self, env, enabled=True):
self.frames = []
self.enabled = self.save_dir and self._wandb and enabled
try:
env_name = env.unwrapped.spec.id
except Exception:
env_name = ""
if "maze2d" in env_name:
self.camera_id = -1
elif "quadruped" in env_name:
self.camera_id = 2
self.record(env)
def record(self, env):
if self.enabled:
frame = env.render(
mode="rgb_array",
height=self.render_size,
width=self.render_size,
camera_id=self.camera_id,
)
self.frames.append(frame)
def save(self, step):
if self.enabled:
frames = np.stack(self.frames).transpose(0, 3, 1, 2)
self._wandb.log(
{"eval_video": self._wandb.Video(frames, fps=self.fps, format="mp4")},
step=step,
)
class Logger(object):
"""Primary logger object. Logs either locally or using wandb."""
def __init__(self, log_dir, cfg):
self._log_dir = make_dir(Path(log_dir))
self._model_dir = make_dir(self._log_dir / "models")
self._buffer_dir = make_dir(self._log_dir / "buffers")
self._save_model = cfg.save_model
self._save_buffer = cfg.save_buffer
self._group = cfg_to_group(cfg)
self._seed = cfg.seed
self._cfg = cfg
self._eval = []
print_run(cfg)
project, entity = cfg.get("wandb_project", "none"), cfg.get(
"wandb_entity", "none"
)
run_offline = (
not cfg.get("use_wandb", False) or project == "none" or entity == "none"
)
if run_offline:
print(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
self._wandb = None
else:
try:
os.environ["WANDB_SILENT"] = "true"
import wandb
wandb.init(
project=project,
entity=entity,
name=str(cfg.seed),
notes=cfg.notes,
group=self._group,
tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"],
dir=self._log_dir,
config=OmegaConf.to_container(cfg, resolve=True),
)
print(
colored("Logs will be synced with wandb.", "blue", attrs=["bold"])
)
self._wandb = wandb
except Exception:
print(
colored(
"Warning: failed to init wandb. Make sure `wandb_entity` is set to your username in `config.yaml`. Logs will be saved locally.",
"yellow",
attrs=["bold"],
)
)
self._wandb = None
self._video = (
VideoRecorder(log_dir, self._wandb)
if self._wandb and cfg.save_video
else None
)
@property
def video(self):
return self._video
def save_model(self, agent, identifier):
if self._save_model:
fp = self._model_dir / f"{str(identifier)}.pt"
agent.save(fp)
if self._wandb:
artifact = self._wandb.Artifact(
self._group + "-" + str(self._seed) + "-" + str(identifier),
type="model",
)
artifact.add_file(fp)
self._wandb.log_artifact(artifact)
def save_buffer(self, buffer, identifier):
fp = self._buffer_dir / f"{str(identifier)}.pkl"
buffer.save(fp)
if self._wandb:
artifact = self._wandb.Artifact(
self._group + "-" + str(self._seed) + "-" + str(identifier),
type="buffer",
)
artifact.add_file(fp)
self._wandb.log_artifact(artifact)
def finish(self, agent, buffer):
if self._save_model:
self.save_model(agent, identifier="final")
if self._save_buffer:
self.save_buffer(buffer, identifier="buffer")
if self._wandb:
self._wandb.finish()
print_run(self._cfg, self._eval[-1][-1])
def _format(self, key, value, ty):
if ty == "int":
return f'{colored(key + ":", "grey")} {int(value):,}'
elif ty == "float":
return f'{colored(key + ":", "grey")} {value:.01f}'
elif ty == "time":
value = str(datetime.timedelta(seconds=int(value)))
return f'{colored(key + ":", "grey")} {value}'
else:
raise ValueError(f"invalid log format type: {ty}")
def _print(self, d, category):
category = colored(category, "blue" if category == "train" else "green")
pieces = [f" {category:<14}"]
for k, disp_k, ty in CONSOLE_FORMAT:
pieces.append(f"{self._format(disp_k, d.get(k, 0), ty):<26}")
print(" ".join(pieces))
def log(self, d, category="train"):
assert category in {"train", "eval"}
if self._wandb is not None:
for k, v in d.items():
self._wandb.log({category + "/" + k: v}, step=d["env_step"])
if category == "eval":
# keys = ['env_step', 'avg_reward']
keys = ["env_step", "avg_reward", "pc_success"]
self._eval.append(np.array([d[key] for key in keys]))
pd.DataFrame(np.array(self._eval)).to_csv(
self._log_dir / "eval.log", header=keys, index=None
)
self._print(d, category)
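For reference, a minimal sketch of how this logger is meant to be driven (mirroring the training script later in this commit), assuming `cfg` is an OmegaConf config carrying the fields `Logger` reads (`log_dir`, `save_model`, `save_buffer`, `seed`, `task`, `modality`, `exp_name`, ...); the config path is assumed:

```python
from omegaconf import OmegaConf
from lerobot.common.logger import Logger

cfg = OmegaConf.load("lerobot/configs/default.yaml")  # path assumed
L = Logger(cfg.log_dir, cfg)
L.log(
    {"episode": 0, "env_step": 0, "avg_reward": 0.0, "pc_success": 0.0,
     "total_time": 1.0},
    category="eval",  # also appends a row to <log_dir>/eval.log
)
```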


@ -1,5 +1,6 @@
from copy import deepcopy
import einops
import numpy as np
import torch
import torch.nn as nn
@ -90,7 +91,7 @@ class TDMPC(nn.Module):
self.model_target = deepcopy(self.model)
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
- self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
+ # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.model.eval()
self.model_target.eval()
self.batch_size = cfg.batch_size
@ -308,9 +309,41 @@ class TDMPC(nn.Module):
self.demo_batch_size = 0
# Sample from interaction dataset
- obs, next_obses, action, reward, mask, done, idxs, weights = (
- replay_buffer.sample()
- )
# sample a whole number of horizon-length slices so no masking is needed
# batch_size = (self.cfg.batch_size // self.cfg.horizon) * self.cfg.horizon
batch_size = self.cfg.horizon * self.cfg.batch_size
batch = replay_buffer.sample(batch_size)
# trajectory t = 256, horizon h = 5
# (t h) ... -> h t ...
batch = (
batch.reshape(self.cfg.batch_size, self.cfg.horizon)
.transpose(1, 0)
.contiguous()
)
batch = batch.to("cuda")
FIRST_FRAME = 0
obs = {
"rgb": batch["observation", "image"][FIRST_FRAME].float(),
"state": batch["observation", "state"][FIRST_FRAME],
}
action = batch["action"]
next_obses = {
"rgb": batch["next", "observation", "image"].float(),
"state": batch["next", "observation", "state"],
}
reward = batch["next", "reward"]
reward = einops.rearrange(reward, "h t -> h t 1")
# We don't use `batch["next", "done"]`, since it only flags the end of an
# episode, not the end of a sampled trajectory within an episode.
# The same holds for `batch["next", "terminated"]`.
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
idxs = batch["frame_id"][FIRST_FRAME]
weights = batch["_weight"][FIRST_FRAME, :, None]
# Sample from demonstration dataset
if self.demo_batch_size > 0:
@ -341,6 +374,21 @@ class TDMPC(nn.Module):
idxs = torch.cat([idxs, demo_idxs])
weights = torch.cat([weights, demo_weights])
# Apply augmentations
aug_tf = h.aug(self.cfg)
obs = aug_tf(obs)
for k in next_obses:
next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...")
next_obses = aug_tf(next_obses)
for k in next_obses:
next_obses[k] = einops.rearrange(
next_obses[k],
"(h t) ... -> h t ...",
h=self.cfg.horizon,
t=self.cfg.batch_size,
)
horizon = self.cfg.horizon
loss_mask = torch.ones_like(mask, device=self.device)
for t in range(1, horizon):
@ -407,6 +455,7 @@ class TDMPC(nn.Module):
weighted_loss = (total_loss.squeeze(1) * weights).mean()
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
weighted_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
)
@ -415,13 +464,16 @@ class TDMPC(nn.Module):
if self.cfg.per:
# Update priorities
priorities = priority_loss.clamp(max=1e4).detach()
- replay_buffer.update_priorities(
- idxs[: replay_buffer.cfg.batch_size],
- priorities[: replay_buffer.cfg.batch_size],
- )
+ # normalize to [0, 1] to fit the torchrl specification
+ priorities /= 1e4
+ priorities = priorities.clamp(max=1.0)
+ replay_buffer.update_priority(
+ idxs[: self.cfg.batch_size],
+ priorities[: self.cfg.batch_size],
+ )
if self.demo_batch_size > 0:
- demo_buffer.update_priorities(
- demo_idxs, priorities[replay_buffer.cfg.batch_size :]
- )
+ demo_buffer.update_priority(
+ demo_idxs, priorities[self.cfg.batch_size :]
+ )
# Update policy + target network
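
A standalone sanity check (plain tensors, not part of the diff) of the `(t h) ... -> h t ...` reshaping used when sampling above, assuming the sampler returns `t` trajectories of `h` consecutive frames laid out flat:

```python
import torch

t, h = 256, 5                # trajectories per batch, horizon
flat = torch.arange(t * h)   # stands in for the flat sampled batch
grid = flat.reshape(t, h).transpose(1, 0).contiguous()
assert grid.shape == (h, t)
# column j holds the h consecutive frames of trajectory j, in time order
assert torch.equal(grid[:, 0], torch.arange(h))
```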


@ -306,13 +306,21 @@ class RandomShiftsAug(nn.Module):
x = F.pad(x, padding, "replicate")
eps = 1.0 / (h + 2 * self.pad)
arange = torch.linspace(
- -1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype
+ -1.0 + eps,
+ 1.0 - eps,
+ h + 2 * self.pad,
+ device=x.device,
+ dtype=torch.float32,
)[:h]
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
shift = torch.randint(
- 0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype
+ 0,
+ 2 * self.pad + 1,
+ size=(n, 1, 1, 2),
+ device=x.device,
+ dtype=torch.float32,
)
shift *= 2.0 / (h + 2 * self.pad)
grid = base_grid + shift


@ -1,5 +1,14 @@
seed: 1337
log_dir: logs/2024_01_26_train
exp_name: default
device: cuda
buffer_device: cuda
eval_freq: 1000
save_freq: 10000
eval_episodes: 20
save_video: false
save_model: false
save_buffer: false
# env
env: simxarm
@ -8,6 +17,14 @@ from_pixels: True
pixels_only: False
image_size: 84
reward_scale: 1.0
# xarm_lift
episode_length: 25
modality: 'all'
action_repeat: 2 # TODO(rcadene): verify we use this
discount: 0.9
train_steps: 50000
# pixels
frame_stack: 1
@ -54,6 +71,19 @@ update_freq: 2
tau: 0.01
utd: 1
# offline rl
# dataset_dir: ???
data_first_percent: 1.0
is_data_clip: true
data_clip_eps: 1e-5
expectile: 0.9
A_scaling: 3.0
# offline->online
offline_steps: ${train_steps}/2
pretrained_model_path: ""
balanced_sampling: true
demo_schedule: 0.5
# architecture
enc_dim: 256
@ -61,11 +91,8 @@ num_q: 5
mlp_dim: 512
latent_dim: 50
# wandb
use_wandb: false
wandb_project: FOWM
wandb_entity: rcadene # insert your own
- # xarm_lift
- A_scaling: 3.0
- expectile: 0.9
- episode_length: 25
- modality: 'all'
- action_repeat: 2
- discount: 0.9
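
One caveat in this config: OmegaConf interpolation does not evaluate arithmetic, so `offline_steps: ${train_steps}/2` resolves to the string `50000/2` rather than the integer 25000 (harmless for now, since the code consuming it is still commented out). A quick check:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"train_steps": 50000, "offline_steps": "${train_steps}/2"})
print(OmegaConf.to_container(cfg, resolve=True)["offline_steps"])  # '50000/2'
```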


@ -32,23 +32,25 @@ def eval_policy(
ep_frames.append(frame)
tensordict = env.reset()
if save_video:
# render first frame before rollout
rendering_callback(env)
rollout = env.rollout(
max_steps=max_steps,
policy=policy,
- callback=rendering_callback,
+ callback=rendering_callback if save_video else None,
auto_reset=False,
tensordict=tensordict,
)
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
ep_reward = rollout["next", "reward"].sum()
ep_success = rollout["next", "success"].any()
rewards.append(ep_reward.item())
successes.append(ep_success.item())
if save_video:
- video_dir.parent.mkdir(parents=True, exist_ok=True)
+ video_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make fps configurable
video_path = video_dir / f"eval_episode_{i}.mp4"
imageio.mimsave(video_path, np.stack(ep_frames), fps=15)
@ -82,8 +84,8 @@ def eval(cfg: dict):
metrics = eval_policy(
env,
policy=policy,
- num_episodes=10,
- save_video=True,
+ num_episodes=20,
+ save_video=False,
video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
)
print(metrics)


@ -1,11 +1,24 @@
import pickle
import time
from pathlib import Path
import hydra
import imageio
import numpy as np
import torch
from tensordict.nn import TensorDictModule
from termcolor import colored
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.datasets.openx import OpenXExperienceReplay
from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger
from lerobot.common.tdmpc import TDMPC
- from ..common.utils import set_seed
+ from lerobot.common.utils import set_seed
from lerobot.scripts.eval import eval_policy
@hydra.main(version_base=None, config_name="default", config_path="../configs")
@ -15,22 +28,169 @@ def train(cfg: dict):
print(colored("Work dir:", "yellow", attrs=["bold"]), cfg.log_dir)
env = make_env(cfg)
- agent = TDMPC(cfg)
+ policy = TDMPC(cfg)
# ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
- agent.load(ckpt_path)
+ policy.load(ckpt_path)
# online training
- eval_metrics = train_agent(
- env,
- agent,
- num_episodes=10,
- save_video=True,
- video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
- )
- print(eval_metrics)
+ td_policy = TensorDictModule(
+ policy,
+ in_keys=["observation", "step_count"],
+ out_keys=["action"],
+ )
# initialize offline dataset
dataset_id = f"xarm_{cfg.task}_medium"
num_traj_per_batch = cfg.batch_size # // cfg.horizon
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
sampler = PrioritizedSliceSampler(
max_capacity=100_000,
alpha=0.7,
beta=0.9,
num_slices=num_traj_per_batch,
strict_length=False,
)
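# NOTE: with strict_length=False, slices that straddle an episode boundary come
# back shorter than the horizon, which is what the padding TODO above is about.
# alpha skews sampling toward high-priority frames; beta scales the
# importance-sampling weights that correct for that skew (standard PER).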
# TODO(rcadene): use PrioritizedReplayBuffer
offline_buffer = SimxarmExperienceReplay(
dataset_id,
# download="force",
download=True,
streaming=False,
root="data",
sampler=sampler,
)
num_steps = len(offline_buffer)
index = torch.arange(0, num_steps, 1)
sampler.extend(index)
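# NOTE: the sampler is populated manually because the storage was built
# offline; extend() registers every frame index in the priority structure
# (new entries start at the sampler's default priority).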
# offline_buffer._storage.device = torch.device("cuda")
# offline_buffer._storage._storage.to(torch.device("cuda"))
# TODO(rcadene): add online_buffer
# Observation encoder
# Dynamics predictor
# Reward predictor
# Policy
# Qs state-action value predictor
# V state value predictor
L = Logger(cfg.log_dir, cfg)
episode_idx = 0
start_time = time.time()
step = 0
last_log_step = 0
last_save_step = 0
while step < cfg.train_steps:
is_offline = True
num_updates = cfg.episode_length
_step = step + num_updates
rollout_metrics = {}
# if step >= cfg.offline_steps:
# is_offline = False
# # Collect trajectory
# obs = env.reset()
# episode = Episode(cfg, obs)
# success = False
# while not episode.done:
# action = policy.act(obs, step=step, t0=episode.first)
# obs, reward, done, info = env.step(action.cpu().numpy())
# reward = reward_normalizer(reward)
# mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0
# success = info.get('success', False)
# episode += (obs, action, reward, done, mask, success)
# assert len(episode) <= cfg.episode_length
# buffer += episode
# episode_idx += 1
# rollout_metrics = {
# 'episode_reward': episode.cumulative_reward,
# 'episode_success': float(success),
# 'episode_length': len(episode)
# }
# num_updates = len(episode) * cfg.utd
# _step = min(step + len(episode), cfg.train_steps)
# Update model
train_metrics = {}
if is_offline:
for i in range(num_updates):
train_metrics.update(policy.update(offline_buffer, step + i))
# else:
# for i in range(num_updates):
# train_metrics.update(
# policy.update(buffer, step + i // cfg.utd,
# demo_buffer=offline_buffer if cfg.balanced_sampling else None)
# )
# Log training metrics
env_step = int(_step * cfg.action_repeat)
common_metrics = {
"episode": episode_idx,
"step": _step,
"env_step": env_step,
"total_time": time.time() - start_time,
"is_offline": float(is_offline),
}
train_metrics.update(common_metrics)
train_metrics.update(rollout_metrics)
L.log(train_metrics, category="train")
# Evaluate policy periodically
if step == 0 or env_step - last_log_step >= cfg.eval_freq:
eval_metrics = eval_policy(
env,
td_policy,
num_episodes=cfg.eval_episodes,
# TODO(rcadene): add step, env_step, L.video
)
# TODO(rcadene):
# if hasattr(env, "get_normalized_score"):
# eval_metrics['normalized_score'] = env.get_normalized_score(eval_metrics["episode_reward"]) * 100.0
common_metrics.update(eval_metrics)
L.log(common_metrics, category="eval")
last_log_step = env_step - env_step % cfg.eval_freq
# Save model periodically
# if cfg.save_model and env_step - last_save_step >= cfg.save_freq:
# L.save_model(policy, identifier=env_step)
# print(f"Model has been checkpointed at step {env_step}")
# last_save_step = env_step - env_step % cfg.save_freq
# if cfg.save_model and is_offline and _step >= cfg.offline_steps:
# # save the model after offline training
# L.save_model(policy, identifier="offline")
step = _step
# dataset_d4rl = D4RLExperienceReplay(
# dataset_id="maze2d-umaze-v1",
# split_trajs=False,
# batch_size=1,
# sampler=SamplerWithoutReplacement(drop_last=False),
# prefetch=4,
# direct_download=True,
# )
# dataset_openx = OpenXExperienceReplay(
# "cmu_stretch",
# batch_size=1,
# num_slices=1,
# #download="force",
# streaming=False,
# root="data",
# )
if __name__ == "__main__":
train()


@ -1,80 +0,0 @@
import pickle
from pathlib import Path
import imageio
import simxarm
if __name__ == "__main__":
task = "lift"
dataset_dir = Path(f"data/xarm_{task}_medium")
dataset_path = dataset_dir / f"buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
required_keys = [
"observations",
"next_observations",
"actions",
"rewards",
"dones",
"masks",
]
for k in required_keys:
if k not in dataset_dict and k[:-1] in dataset_dict:
dataset_dict[k] = dataset_dict.pop(k[:-1])
out_dir = Path("tmp/2023_01_26_xarm_lift_medium")
out_dir.mkdir(parents=True, exist_ok=True)
frames = dataset_dict["observations"]["rgb"][:100]
frames = frames.transpose(0, 2, 3, 1)
imageio.mimsave(out_dir / "test.mp4", frames, fps=30)
frames = []
cfg = {}
env = simxarm.make(
task=task,
obs_mode="all",
image_size=84,
action_repeat=cfg.get("action_repeat", 1),
frame_stack=cfg.get("frame_stack", 1),
seed=1,
)
obs = env.reset()
frame = env.render(mode="rgb_array", width=384, height=384)
frames.append(frame)
# def is_first_obs(obs):
# nonlocal first_obs
# print(((dataset_dict["observations"]["state"][i]-obs["state"])**2).sum())
# print(((dataset_dict["observations"]["rgb"][i]-obs["rgb"])**2).sum())
for i in range(25):
action = dataset_dict["actions"][i]
print(f"#{i}")
# print(obs["state"])
# print(dataset_dict["observations"]["state"][i])
print(((dataset_dict["observations"]["state"][i] - obs["state"]) ** 2).sum())
print(((dataset_dict["observations"]["rgb"][i] - obs["rgb"]) ** 2).sum())
obs, reward, done, info = env.step(action)
frame = env.render(mode="rgb_array", width=384, height=384)
frames.append(frame)
print(reward)
print(dataset_dict["rewards"][i])
print(done)
print(dataset_dict["dones"][i])
if dataset_dict["dones"][i]:
obs = env.reset()
frame = env.render(mode="rgb_array", width=384, height=384)
frames.append(frame)
# imageio.mimsave(out_dir / 'test_rollout.mp4', frames, fps=60)


@ -0,0 +1,59 @@
import pickle
from pathlib import Path
import imageio
import simxarm
import torch
from torchrl.data.replay_buffers import (
SamplerWithoutReplacement,
SliceSampler,
SliceSamplerWithoutReplacement,
)
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
def visualize_simxarm_dataset(dataset_id="xarm_lift_medium"):
sampler = SliceSamplerWithoutReplacement(
num_slices=1,
strict_length=False,
shuffle=False,
)
dataset = SimxarmExperienceReplay(
dataset_id,
# download="force",
download=True,
streaming=False,
root="data",
sampler=sampler,
)
NUM_EPISODES_TO_RENDER = 10
MAX_NUM_STEPS = 50
FIRST_FRAME = 0
for _ in range(NUM_EPISODES_TO_RENDER):
episode = dataset.sample(MAX_NUM_STEPS)
ep_idx = episode["episode"][FIRST_FRAME].item()
ep_frames = torch.cat(
[
episode["observation"]["image"][FIRST_FRAME][None, ...],
episode["next", "observation"]["image"],
],
dim=0,
)
video_dir = Path("tmp/2024_02_03_xarm_lift_medium")
video_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make fps configurable
video_path = video_dir / f"eval_episode_{ep_idx}.mp4"
imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=15)
# ran out of episodes
if dataset._sampler._sample_list.numel() == 0:
break
if __name__ == "__main__":
visualize_simxarm_dataset()