Add pusht dataset (TODO verify reward is aligned), Refactor visualize_dataset, Add video_dir, fps, state_dim, action_dim to config (Training works)
parent 3dc14b5576
commit ece89730e6
@@ -5,8 +5,11 @@ from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
 from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
 
 
-def make_offline_buffer(cfg):
+def make_offline_buffer(cfg, sampler=None):
 
-    num_traj_per_batch = cfg.batch_size  # // cfg.horizon
-    # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
-    # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
+    overwrite_sampler = sampler is not None
+
+    if not overwrite_sampler:
+        num_traj_per_batch = cfg.batch_size  # // cfg.horizon
+        # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
+        # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.

@@ -30,9 +33,9 @@ def make_offline_buffer(cfg):
         )
     elif cfg.env == "pusht":
         offline_buffer = PushtExperienceReplay(
-            f"xarm_{cfg.task}_medium",
+            "pusht",
             # download="force",
-            download=True,
+            download=False,
             streaming=False,
             root="data",
             sampler=sampler,

@@ -40,6 +43,7 @@ def make_offline_buffer(cfg):
     else:
         raise ValueError(cfg.env)
 
-    num_steps = len(offline_buffer)
-    index = torch.arange(0, num_steps, 1)
-    sampler.extend(index)
+    if not overwrite_sampler:
+        num_steps = len(offline_buffer)
+        index = torch.arange(0, num_steps, 1)
+        sampler.extend(index)
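The hunks above refactor `make_offline_buffer` in `lerobot/common/datasets/factory.py` (the path is confirmed by the import added to the visualization script later in this commit) so that callers can inject their own sampler. A minimal usage sketch, assuming an OmegaConf/Hydra `cfg` carrying the fields the factory reads; the concrete values are illustrative, not from the commit:

```python
from omegaconf import OmegaConf
from torchrl.data.replay_buffers import SliceSamplerWithoutReplacement

from lerobot.common.datasets.factory import make_offline_buffer

cfg = OmegaConf.create({"env": "pusht", "batch_size": 256})

# Training path: no sampler is passed, so the factory builds its own
# prioritized slice sampler and extends it over the whole dataset.
offline_buffer = make_offline_buffer(cfg)

# Visualization path: an externally built sampler is passed in, so the
# factory skips sampler construction and the final sampler.extend(index).
sampler = SliceSamplerWithoutReplacement(
    num_slices=1, strict_length=False, shuffle=False
)
offline_buffer = make_offline_buffer(cfg, sampler)
```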
@@ -3,9 +3,15 @@ import pickle
 from pathlib import Path
 from typing import Any, Callable, Dict, Tuple
 
+import einops
+import numpy as np
+import pygame
+import pymunk
 import torch
 import torchrl
 import tqdm
+from diffusion_policy.common.replay_buffer import ReplayBuffer
+from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
 from tensordict import TensorDict
 from torchrl.data.datasets.utils import _get_root_dir
 from torchrl.data.replay_buffers.replay_buffers import (

@@ -20,12 +26,71 @@ from torchrl.data.replay_buffers.samplers import (
 from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 
+# as defined in the env
+SUCCESS_THRESHOLD = 0.95  # 95% coverage
+
+
+def get_goal_pose_body(pose):
+    mass = 1
+    inertia = pymunk.moment_for_box(mass, (50, 100))
+    body = pymunk.Body(mass, inertia)
+    # preserving the legacy assignment order for compatibility
+    # the order here doesn't matter somehow, maybe because CoM is aligned with body origin
+    body.position = pose[:2].tolist()
+    body.angle = pose[2]
+    return body
+
+
+def add_segment(space, a, b, radius):
+    shape = pymunk.Segment(space.static_body, a, b, radius)
+    shape.color = pygame.Color("LightGray")  # https://htmlcolorcodes.com/color-names
+    return shape
+
+
+def add_tee(
+    space,
+    position,
+    angle,
+    scale=30,
+    color="LightSlateGray",
+    mask=pymunk.ShapeFilter.ALL_MASKS(),
+):
+    mass = 1
+    length = 4
+    vertices1 = [
+        (-length * scale / 2, scale),
+        (length * scale / 2, scale),
+        (length * scale / 2, 0),
+        (-length * scale / 2, 0),
+    ]
+    inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
+    vertices2 = [
+        (-scale / 2, scale),
+        (-scale / 2, length * scale),
+        (scale / 2, length * scale),
+        (scale / 2, scale),
+    ]
+    inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
+    body = pymunk.Body(mass, inertia1 + inertia2)
+    shape1 = pymunk.Poly(body, vertices1)
+    shape2 = pymunk.Poly(body, vertices2)
+    shape1.color = pygame.Color(color)
+    shape2.color = pygame.Color(color)
+    shape1.filter = pymunk.ShapeFilter(mask=mask)
+    shape2.filter = pymunk.ShapeFilter(mask=mask)
+    body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
+    body.position = position
+    body.angle = angle
+    body.friction = 1
+    space.add(body, shape1, shape2)
+    return body
+
+
 class PushtExperienceReplay(TensorDictReplayBuffer):
 
-    available_datasets = [
-        "xarm_lift_medium",
-    ]
+    # available_datasets = [
+    #     "xarm_lift_medium",
+    # ]
 
     def __init__(
         self,

@@ -49,8 +114,6 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         split_trajs: bool = False,
         strict_length: bool = True,
     ):
-        # TODO
-        raise NotImplementedError()
         self.download = download
         if streaming:
             raise NotImplementedError

@@ -68,8 +131,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         if split_trajs:
             raise NotImplementedError
 
+        if self.download == True:
+            raise NotImplementedError()
+
         if root is None:
-            root = _get_root_dir("simxarm")
+            root = _get_root_dir("pusht")
         os.makedirs(root, exist_ok=True)
         self.root = Path(root)
         if self.download == "force" or (self.download and not self._is_downloaded()):

@@ -77,29 +143,29 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         else:
             storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
 
-        if num_slices is not None or slice_len is not None:
-            if sampler is not None:
-                raise ValueError(
-                    "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
-                )
+        # if num_slices is not None or slice_len is not None:
+        #     if sampler is not None:
+        #         raise ValueError(
+        #             "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
+        #         )
 
-            if replacement:
-                if not self.shuffle:
-                    raise RuntimeError(
-                        "shuffle=False can only be used when replacement=False."
-                    )
-                sampler = SliceSampler(
-                    num_slices=num_slices,
-                    slice_len=slice_len,
-                    strict_length=strict_length,
-                )
-            else:
-                sampler = SliceSamplerWithoutReplacement(
-                    num_slices=num_slices,
-                    slice_len=slice_len,
-                    strict_length=strict_length,
-                    shuffle=self.shuffle,
-                )
+        #     if replacement:
+        #         if not self.shuffle:
+        #             raise RuntimeError(
+        #                 "shuffle=False can only be used when replacement=False."
+        #             )
+        #         sampler = SliceSampler(
+        #             num_slices=num_slices,
+        #             slice_len=slice_len,
+        #             strict_length=strict_length,
+        #         )
+        #     else:
+        #         sampler = SliceSamplerWithoutReplacement(
+        #             num_slices=num_slices,
+        #             slice_len=slice_len,
+        #             strict_length=strict_length,
+        #             shuffle=self.shuffle,
+        #         )
 
         if writer is None:
             writer = ImmutableDatasetWriter()

@@ -131,49 +197,82 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         # TODO(rcadene)
 
         # load
-        dataset_dir = Path("data") / self.dataset_id
-        dataset_path = dataset_dir / f"buffer.pkl"
-        print(f"Using offline dataset '{dataset_path}'")
-        with open(dataset_path, "rb") as f:
-            dataset_dict = pickle.load(f)
+        zarr_path = (
+            "/home/rcadene/code/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
+        )
+        dataset_dict = ReplayBuffer.copy_from_path(
+            zarr_path
+        )  # , keys=['img', 'state', 'action'])
 
-        total_frames = dataset_dict["actions"].shape[0]
+        episode_ids = dataset_dict.get_episode_idxs()
+        num_episodes = dataset_dict.meta["episode_ends"].shape[0]
+        total_frames = dataset_dict["action"].shape[0]
+        assert len(
+            set([dataset_dict[key].shape[0] for key in dataset_dict.keys()])
+        ), "Some data types don't have the same number of total frames."
+
+        # TODO: verify that goal pose is expected to be fixed
+        goal_pos_angle = np.array([256, 256, np.pi / 4])  # x, y, theta (in radians)
+        goal_body = get_goal_pose_body(goal_pos_angle)
 
         idx0 = 0
-        idx1 = 0
-        episode_id = 0
-        for i in tqdm.tqdm(range(total_frames)):
-            idx1 += 1
-
-            if not dataset_dict["dones"][i]:
-                continue
+        idxtd = 0
+        for episode_id in tqdm.tqdm(range(num_episodes)):
+            idx1 = dataset_dict.meta["episode_ends"][episode_id]
 
             num_frames = idx1 - idx0
 
-            image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
-            state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
-            next_image = torch.tensor(
-                dataset_dict["next_observations"]["rgb"][idx0:idx1]
-            )
-            next_state = torch.tensor(
-                dataset_dict["next_observations"]["state"][idx0:idx1]
-            )
-            next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
-            next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
+            assert (episode_ids[idx0:idx1] == episode_id).all()
+
+            image = torch.from_numpy(dataset_dict["img"][idx0:idx1])
+            image = einops.rearrange(image, "b h w c -> b c h w")
+
+            state = torch.from_numpy(dataset_dict["state"][idx0:idx1])
+            agent_pos = state[:, :2]
+            block_pos = state[:, 2:4]
+            block_angle = state[:, 4]
+
+            reward = torch.zeros(num_frames, 1)
+            done = torch.zeros(num_frames, 1, dtype=torch.bool)
+            for i in range(num_frames):
+                space = pymunk.Space()
+                space.gravity = 0, 0
+                space.damping = 0
+
+                # Add walls.
+                walls = [
+                    add_segment(space, (5, 506), (5, 5), 2),
+                    add_segment(space, (5, 5), (506, 5), 2),
+                    add_segment(space, (506, 5), (506, 506), 2),
+                    add_segment(space, (5, 506), (506, 506), 2),
+                ]
+                space.add(*walls)
+
+                block_body = add_tee(
+                    space, block_pos[i].tolist(), block_angle[i].item()
+                )
+                goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
+                block_geom = pymunk_to_shapely(block_body, block_body.shapes)
+                intersection_area = goal_geom.intersection(block_geom).area
+                goal_area = goal_geom.area
+                coverage = intersection_area / goal_area
+                reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
+                done[i] = coverage > SUCCESS_THRESHOLD
 
             episode = TensorDict(
                 {
-                    ("observation", "image"): image,
-                    ("observation", "state"): state,
-                    "action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
-                    "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
-                    "frame_id": torch.arange(0, num_frames, 1),
-                    ("next", "observation", "image"): next_image,
-                    ("next", "observation", "state"): next_state,
-                    ("next", "observation", "reward"): next_reward,
-                    ("next", "observation", "done"): next_done,
+                    ("observation", "image"): image[:-1],
+                    ("observation", "state"): agent_pos[:-1],
+                    "action": torch.from_numpy(dataset_dict["action"][idx0:idx1])[:-1],
+                    "episode": torch.from_numpy(episode_ids[idx0:idx1])[:-1],
+                    "frame_id": torch.arange(0, num_frames - 1, 1),
+                    ("next", "observation", "image"): image[1:],
+                    ("next", "observation", "state"): agent_pos[1:],
+                    # TODO: verify that reward and done are aligned with image and agent_pos
+                    ("next", "reward"): reward[1:],
+                    ("next", "done"): done[1:],
                 },
-                batch_size=num_frames,
+                batch_size=num_frames - 1,
             )

@@ -184,9 +283,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
                 .memmap_like(self.root / self.dataset_id)
             )
 
-            td_data[idx0:idx1] = episode
+            td_data[idxtd : idxtd + len(episode)] = episode
 
-            episode_id += 1
             idx0 = idx1
+            idxtd = idxtd + len(episode)
 
         return TensorStorage(td_data.lock_())
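The commit message flags the new reward as unverified ("TODO verify reward is aligned"). The loop above recomputes it per frame from the replayed block pose; here is the same logic distilled into a standalone sketch for inspection. `compute_coverage_reward` is a hypothetical helper (not in the commit), it reuses `add_tee` and `SUCCESS_THRESHOLD` from the dataset module above (presumably `lerobot/common/datasets/pusht.py`), and the walls are omitted because only the block's own shapes enter the shapely geometry:

```python
import numpy as np
import pymunk

from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely


def compute_coverage_reward(block_pos, block_angle, goal_body):
    # Rebuild the T-block at the recorded pose in a throwaway physics space.
    space = pymunk.Space()
    space.gravity = 0, 0
    space.damping = 0
    block_body = add_tee(space, list(block_pos), float(block_angle))

    # Reward is the fraction of the goal region covered by the block,
    # rescaled so that reaching SUCCESS_THRESHOLD saturates at 1.0.
    goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
    block_geom = pymunk_to_shapely(block_body, block_body.shapes)
    coverage = goal_geom.intersection(block_geom).area / goal_geom.area
    reward = float(np.clip(coverage / SUCCESS_THRESHOLD, 0, 1))
    done = coverage > SUCCESS_THRESHOLD
    return reward, done
```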
@@ -13,7 +13,7 @@ class TOLD(nn.Module):
 
     def __init__(self, cfg):
         super().__init__()
-        action_dim = 4
+        action_dim = cfg.action_dim
 
         self.cfg = cfg
         self._encoder = h.enc(cfg)

@@ -82,7 +82,7 @@ class TDMPC(nn.Module):
 
     def __init__(self, cfg):
         super().__init__()
-        self.action_dim = 4
+        self.action_dim = cfg.action_dim
 
         self.cfg = cfg
         self.device = torch.device("cuda")

@@ -130,7 +130,7 @@ class Flatten(nn.Module):
 def enc(cfg):
     obs_shape = {
         "rgb": (3, cfg.img_size, cfg.img_size),
-        "state": (4,),
+        "state": (cfg.state_dim,),
     }
 
     """Returns a TOLD encoder."""

@@ -209,7 +209,7 @@ def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
 
 
 def q(cfg):
-    action_dim = 4
+    action_dim = cfg.action_dim
     """Returns a Q-function that uses Layer Normalization."""
     return nn.Sequential(
         nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim),

@@ -331,7 +331,7 @@ class Episode(object):
     """Storage object for a single episode."""
 
     def __init__(self, cfg, init_obs):
-        action_dim = 4
+        action_dim = cfg.action_dim
 
         self.cfg = cfg
         self.device = torch.device(cfg.buffer_device)

@@ -447,8 +447,8 @@ class ReplayBuffer:
     """
 
     def __init__(self, cfg, dataset=None):
-        action_dim = 4
-        obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (4,)}
+        action_dim = cfg.action_dim
+        obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (cfg.state_dim,)}
 
         self.cfg = cfg
         self.device = torch.device(cfg.buffer_device)
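The model-side hunks above replace every hardcoded `action_dim = 4` / `"state": (4,)` with `cfg.action_dim` / `cfg.state_dim`, so the same TOLD/TDMPC code can serve both environments. A small sketch of the effect, with illustrative mini-configs mirroring the yaml diffs that follow:

```python
from omegaconf import OmegaConf

# Illustrative mini-configs; the real ones carry many more fields.
xarm_cfg = OmegaConf.create({"state_dim": 4, "action_dim": 4, "img_size": 84})
pusht_cfg = OmegaConf.create({"state_dim": 2, "action_dim": 2, "img_size": 96})

for cfg in (xarm_cfg, pusht_cfg):
    # Same shape logic as enc()/ReplayBuffer above, now environment-agnostic.
    obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (cfg.state_dim,)}
    print(obs_shape, "action_dim:", cfg.action_dim)
```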
@@ -1,5 +1,6 @@
 seed: 1337
 log_dir: logs/2024_01_26_train
+video_dir: tmp/2024_01_26_xarm_lift_medium
 exp_name: default
 device: cuda
 buffer_device: cuda

@@ -16,6 +17,7 @@ task: lift
 from_pixels: True
 pixels_only: False
 image_size: 84
+fps: 15
 
 reward_scale: 1.0
 

@@ -30,7 +32,8 @@ train_steps: 50000
 frame_stack: 1
 num_channels: 32
 img_size: ${image_size}
+state_dim: 4
+action_dim: 4
 
 # TDMPC
 

@@ -97,4 +100,3 @@ latent_dim: 50
 use_wandb: false
 wandb_project: FOWM
 wandb_entity: rcadene # insert your own
-
@@ -5,8 +5,12 @@ hydra:
   job:
     name: pusht
 
+video_dir: tmp/2024_02_21_pusht
+
 # env
 env: pusht
 image_size: 96
 frame_skip: 1
+state_dim: 2
+action_dim: 2
+fps: 10
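With `state_dim`, `action_dim`, `fps`, and `video_dir` now living in the yaml files, selecting an environment is a matter of composing the right config. A hedged sketch using Hydra's compose API; the `../configs` layout is taken from the `@hydra.main` decorator added later in this commit, but the compose entry point itself is an assumption, not code from the commit:

```python
from hydra import compose, initialize

# Compose the pusht variant instead of the default (xarm lift) config.
with initialize(version_base=None, config_path="../configs"):
    cfg = compose(config_name="pusht")

assert cfg.state_dim == 2 and cfg.action_dim == 2 and cfg.fps == 10
```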
@@ -20,6 +20,7 @@ def eval_policy(
     max_steps: int = 30,
     save_video: bool = False,
     video_dir: Path = None,
+    fps: int = 15,
 ):
     rewards = []
     successes = []

@@ -55,7 +56,7 @@ def eval_policy(
             video_dir.mkdir(parents=True, exist_ok=True)
             # TODO(rcadene): make fps configurable
             video_path = video_dir / f"eval_episode_{i}.mp4"
-            imageio.mimsave(video_path, np.stack(ep_frames), fps=15)
+            imageio.mimsave(video_path, np.stack(ep_frames), fps=fps)
 
     metrics = {
         "avg_reward": np.nanmean(rewards),

@@ -74,16 +75,13 @@ def eval(cfg: dict):
 
     if cfg.pretrained_model_path:
         policy = TDMPC(cfg)
-        ckpt_path = (
-            "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
-        )
         if "offline" in cfg.pretrained_model_path:
             policy.step = 25000
         elif "final" in cfg.pretrained_model_path:
             policy.step = 100000
         else:
             raise NotImplementedError()
-        policy.load(ckpt_path)
+        policy.load(cfg.pretrained_model_path)
 
         policy = TensorDictModule(
             policy,

@@ -99,7 +97,8 @@ def eval(cfg: dict):
         policy=policy,
         num_episodes=20,
         save_video=True,
-        video_dir=Path("tmp/2023_02_19_pusht"),
+        video_dir=Path(cfg.video_dir),
+        fps=cfg.fps,
     )
     print(metrics)
 
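`eval_policy` now threads the configured frame rate through to the video writer instead of the hardcoded `fps=15`. A self-contained sketch of that final call, assuming imageio with the ffmpeg plugin installed; the clip content and output path are placeholders:

```python
import imageio
import numpy as np

fps = 10  # e.g. cfg.fps for the pusht config
frames = np.zeros((2 * fps, 96, 96, 3), dtype=np.uint8)  # 2-second black clip, (t, h, w, c)
imageio.mimsave("eval_episode_demo.mp4", frames, fps=fps)
```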
@@ -1,6 +1,7 @@
 import pickle
 from pathlib import Path
 
+import hydra
 import imageio
 import simxarm
 import torch

@@ -10,30 +11,25 @@ from torchrl.data.replay_buffers import (
     SliceSamplerWithoutReplacement,
 )
 
-from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
+from lerobot.common.datasets.factory import make_offline_buffer
 
 
-def visualize_simxarm_dataset(dataset_id="xarm_lift_medium"):
+@hydra.main(version_base=None, config_name="default", config_path="../configs")
+def visualize_dataset(cfg: dict):
     sampler = SliceSamplerWithoutReplacement(
         num_slices=1,
         strict_length=False,
         shuffle=False,
     )
 
-    dataset = SimxarmExperienceReplay(
-        dataset_id,
-        # download="force",
-        download=True,
-        streaming=False,
-        root="data",
-        sampler=sampler,
-    )
+    offline_buffer = make_offline_buffer(cfg, sampler)
 
     NUM_EPISODES_TO_RENDER = 10
-    MAX_NUM_STEPS = 50
+    MAX_NUM_STEPS = 1000
     FIRST_FRAME = 0
     for _ in range(NUM_EPISODES_TO_RENDER):
-        episode = dataset.sample(MAX_NUM_STEPS)
+        episode = offline_buffer.sample(MAX_NUM_STEPS)
 
         ep_idx = episode["episode"][FIRST_FRAME].item()
         ep_frames = torch.cat(

@@ -44,16 +40,23 @@ def visualize_simxarm_dataset(dataset_id="xarm_lift_medium"):
             dim=0,
         )
 
-        video_dir = Path("tmp/2024_02_03_xarm_lift_medium")
+        video_dir = Path(cfg.video_dir)
         video_dir.mkdir(parents=True, exist_ok=True)
         # TODO(rcadene): make fps configurable
         video_path = video_dir / f"eval_episode_{ep_idx}.mp4"
-        imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=15)
+
+        assert ep_frames.min().item() >= 0
+        assert ep_frames.max().item() > 1, "Not mandatory, but a sanity check"
+        assert ep_frames.max().item() <= 255
+        ep_frames = ep_frames.type(torch.uint8)
+        imageio.mimsave(
+            video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps
+        )
 
         # ran out of episodes
-        if dataset._sampler._sample_list.numel() == 0:
+        if offline_buffer._sampler._sample_list.numel() == 0:
             break
 
 
 if __name__ == "__main__":
-    visualize_simxarm_dataset()
+    visualize_dataset()
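The new assertions before encoding document the expected pixel convention: frames sampled from the buffer are stored as 0-255 values, not normalized floats, so the uint8 cast is lossless. A synthetic reproduction of those checks:

```python
import torch

ep_frames = torch.randint(0, 256, (50, 3, 96, 96))  # (b, c, h, w), synthetic
assert ep_frames.min().item() >= 0
assert ep_frames.max().item() <= 255
ep_frames = ep_frames.type(torch.uint8)
video = ep_frames.numpy().transpose(0, 2, 3, 1)  # imageio expects (t, h, w, c)
```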