Add pusht dataset (TODO verify reward is aligned), Refactor visualize_dataset, Add video_dir, fps, state_dim, action_dim to config (Training works)

This commit is contained in:
Cadene 2024-02-21 00:49:40 +00:00
parent 3dc14b5576
commit ece89730e6
8 changed files with 222 additions and 111 deletions

View File

@ -5,18 +5,21 @@ from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
def make_offline_buffer(cfg): def make_offline_buffer(cfg, sampler=None):
num_traj_per_batch = cfg.batch_size # // cfg.horizon overwrite_sampler = sampler is not None
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size. if not overwrite_sampler:
sampler = PrioritizedSliceSampler( num_traj_per_batch = cfg.batch_size # // cfg.horizon
max_capacity=100_000, # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
alpha=cfg.per_alpha, # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
beta=cfg.per_beta, sampler = PrioritizedSliceSampler(
num_slices=num_traj_per_batch, max_capacity=100_000,
strict_length=False, alpha=cfg.per_alpha,
) beta=cfg.per_beta,
num_slices=num_traj_per_batch,
strict_length=False,
)
if cfg.env == "simxarm": if cfg.env == "simxarm":
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
@ -30,9 +33,9 @@ def make_offline_buffer(cfg):
) )
elif cfg.env == "pusht": elif cfg.env == "pusht":
offline_buffer = PushtExperienceReplay( offline_buffer = PushtExperienceReplay(
f"xarm_{cfg.task}_medium", "pusht",
# download="force", # download="force",
download=True, download=False,
streaming=False, streaming=False,
root="data", root="data",
sampler=sampler, sampler=sampler,
@ -40,8 +43,9 @@ def make_offline_buffer(cfg):
else: else:
raise ValueError(cfg.env) raise ValueError(cfg.env)
num_steps = len(offline_buffer) if not overwrite_sampler:
index = torch.arange(0, num_steps, 1) num_steps = len(offline_buffer)
sampler.extend(index) index = torch.arange(0, num_steps, 1)
sampler.extend(index)
return offline_buffer return offline_buffer

View File

@ -3,9 +3,15 @@ import pickle
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Tuple from typing import Any, Callable, Dict, Tuple
import einops
import numpy as np
import pygame
import pymunk
import torch import torch
import torchrl import torchrl
import tqdm import tqdm
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import ( from torchrl.data.replay_buffers.replay_buffers import (
@ -20,12 +26,71 @@ from torchrl.data.replay_buffers.samplers import (
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
# as defined in the env
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
def get_goal_pose_body(pose):
    """Build a stand-alone pymunk body placed at the given goal pose.

    Args:
        pose: Indexable pose. ``pose[:2]`` must support ``.tolist()`` and is
            used as the body position; ``pose[2]`` is used as the body angle.

    Returns:
        A ``pymunk.Body`` of unit mass whose moment is that of a 50 x 100 box.
        The body is not added to any space by this function.
    """
    unit_mass = 1
    box_moment = pymunk.moment_for_box(unit_mass, (50, 100))
    goal_body = pymunk.Body(unit_mass, box_moment)
    # Position is set before the angle to preserve the legacy assignment order.
    # The order apparently does not matter here, possibly because the centre of
    # mass is aligned with the body origin.
    goal_body.position = pose[:2].tolist()
    goal_body.angle = pose[2]
    return goal_body
def add_segment(space, a, b, radius):
    """Create a light-gray segment shape attached to the space's static body.

    Despite the name, the segment is only created and returned here; it is not
    added to ``space`` by this function.

    Args:
        space: Object exposing a ``static_body`` to attach the segment to.
        a: First endpoint of the segment.
        b: Second endpoint of the segment.
        radius: Radius (thickness) passed to ``pymunk.Segment``.

    Returns:
        The newly created ``pymunk.Segment``.
    """
    wall = pymunk.Segment(space.static_body, a, b, radius)
    # Color names reference: https://htmlcolorcodes.com/color-names
    wall.color = pygame.Color("LightGray")
    return wall
def add_tee(
    space,
    position,
    angle,
    scale=30,
    color="LightSlateGray",
    mask=pymunk.ShapeFilter.ALL_MASKS(),
):
    """Create a T-shaped body made of two rectangles and add it to ``space``.

    The tee is built from a horizontal bar (``vertices1``) and a vertical bar
    (``vertices2``), both attached to a single body of unit mass.

    Args:
        space: Space the body and its two shapes are added to.
        position: Value assigned to ``body.position``.
        angle: Value assigned to ``body.angle``.
        scale: Size factor applied to every vertex of the tee.
        color: Color name passed to ``pygame.Color`` for both shapes.
        mask: Collision mask used to build the ``pymunk.ShapeFilter`` of both
            shapes.

    Returns:
        The ``pymunk.Body`` carrying the two polygon shapes.
    """
    mass = 1
    length = 4
    # Horizontal bar of the tee.
    vertices1 = [
        (-length * scale / 2, scale),
        (length * scale / 2, scale),
        (length * scale / 2, 0),
        (-length * scale / 2, 0),
    ]
    inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
    # Vertical bar of the tee.
    vertices2 = [
        (-scale / 2, scale),
        (-scale / 2, length * scale),
        (scale / 2, length * scale),
        (scale / 2, scale),
    ]
    # Fix: the moment of the vertical bar must be computed from its own
    # vertices (the previous code reused ``vertices1`` here).
    inertia2 = pymunk.moment_for_poly(mass, vertices=vertices2)
    body = pymunk.Body(mass, inertia1 + inertia2)
    shape1 = pymunk.Poly(body, vertices1)
    shape2 = pymunk.Poly(body, vertices2)
    shape1.color = pygame.Color(color)
    shape2.color = pygame.Color(color)
    shape1.filter = pymunk.ShapeFilter(mask=mask)
    shape2.filter = pymunk.ShapeFilter(mask=mask)
    # Body centre of gravity is taken as the midpoint of the two shapes'
    # centres of gravity.
    body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
    body.position = position
    body.angle = angle
    # NOTE(review): friction is set on the body here; in pymunk friction is
    # normally a shape property -- confirm this has the intended effect.
    body.friction = 1
    space.add(body, shape1, shape2)
    return body
class PushtExperienceReplay(TensorDictReplayBuffer): class PushtExperienceReplay(TensorDictReplayBuffer):
available_datasets = [ # available_datasets = [
"xarm_lift_medium", # "xarm_lift_medium",
] # ]
def __init__( def __init__(
self, self,
@ -49,8 +114,6 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
split_trajs: bool = False, split_trajs: bool = False,
strict_length: bool = True, strict_length: bool = True,
): ):
# TODO
raise NotImplementedError()
self.download = download self.download = download
if streaming: if streaming:
raise NotImplementedError raise NotImplementedError
@ -68,8 +131,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
if split_trajs: if split_trajs:
raise NotImplementedError raise NotImplementedError
if self.download == True:
raise NotImplementedError()
if root is None: if root is None:
root = _get_root_dir("simxarm") root = _get_root_dir("pusht")
os.makedirs(root, exist_ok=True) os.makedirs(root, exist_ok=True)
self.root = Path(root) self.root = Path(root)
if self.download == "force" or (self.download and not self._is_downloaded()): if self.download == "force" or (self.download and not self._is_downloaded()):
@ -77,29 +143,29 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
else: else:
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id)) storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
if num_slices is not None or slice_len is not None: # if num_slices is not None or slice_len is not None:
if sampler is not None: # if sampler is not None:
raise ValueError( # raise ValueError(
"`num_slices` and `slice_len` are exclusive with the `sampler` argument." # "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
) # )
if replacement: # if replacement:
if not self.shuffle: # if not self.shuffle:
raise RuntimeError( # raise RuntimeError(
"shuffle=False can only be used when replacement=False." # "shuffle=False can only be used when replacement=False."
) # )
sampler = SliceSampler( # sampler = SliceSampler(
num_slices=num_slices, # num_slices=num_slices,
slice_len=slice_len, # slice_len=slice_len,
strict_length=strict_length, # strict_length=strict_length,
) # )
else: # else:
sampler = SliceSamplerWithoutReplacement( # sampler = SliceSamplerWithoutReplacement(
num_slices=num_slices, # num_slices=num_slices,
slice_len=slice_len, # slice_len=slice_len,
strict_length=strict_length, # strict_length=strict_length,
shuffle=self.shuffle, # shuffle=self.shuffle,
) # )
if writer is None: if writer is None:
writer = ImmutableDatasetWriter() writer = ImmutableDatasetWriter()
@ -131,49 +197,82 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
# TODO(rcadene) # TODO(rcadene)
# load # load
dataset_dir = Path("data") / self.dataset_id zarr_path = (
dataset_path = dataset_dir / f"buffer.pkl" "/home/rcadene/code/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
print(f"Using offline dataset '{dataset_path}'") )
with open(dataset_path, "rb") as f: dataset_dict = ReplayBuffer.copy_from_path(
dataset_dict = pickle.load(f) zarr_path
) # , keys=['img', 'state', 'action'])
total_frames = dataset_dict["actions"].shape[0] episode_ids = dataset_dict.get_episode_idxs()
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0]
assert len(
set([dataset_dict[key].shape[0] for key in dataset_dict.keys()])
), "Some data type dont have the same number of total frames."
# TODO: verify that goal pose is expected to be fixed
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
goal_body = get_goal_pose_body(goal_pos_angle)
idx0 = 0 idx0 = 0
idx1 = 0 idxtd = 0
episode_id = 0 for episode_id in tqdm.tqdm(range(num_episodes)):
for i in tqdm.tqdm(range(total_frames)): idx1 = dataset_dict.meta["episode_ends"][episode_id]
idx1 += 1
if not dataset_dict["dones"][i]:
continue
num_frames = idx1 - idx0 num_frames = idx1 - idx0
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1]) assert (episode_ids[idx0:idx1] == episode_id).all()
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
next_image = torch.tensor( image = torch.from_numpy(dataset_dict["img"][idx0:idx1])
dataset_dict["next_observations"]["rgb"][idx0:idx1] image = einops.rearrange(image, "b h w c -> b c h w")
)
next_state = torch.tensor( state = torch.from_numpy(dataset_dict["state"][idx0:idx1])
dataset_dict["next_observations"]["state"][idx0:idx1] agent_pos = state[:, :2]
) block_pos = state[:, 2:4]
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1]) block_angle = state[:, 4]
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
reward = torch.zeros(num_frames, 1)
done = torch.zeros(num_frames, 1, dtype=torch.bool)
for i in range(num_frames):
space = pymunk.Space()
space.gravity = 0, 0
space.damping = 0
# Add walls.
walls = [
add_segment(space, (5, 506), (5, 5), 2),
add_segment(space, (5, 5), (506, 5), 2),
add_segment(space, (506, 5), (506, 506), 2),
add_segment(space, (5, 506), (506, 506), 2),
]
space.add(*walls)
block_body = add_tee(
space, block_pos[i].tolist(), block_angle[i].item()
)
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
goal_area = goal_geom.area
coverage = intersection_area / goal_area
reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
done[i] = coverage > SUCCESS_THRESHOLD
episode = TensorDict( episode = TensorDict(
{ {
("observation", "image"): image, ("observation", "image"): image[:-1],
("observation", "state"): state, ("observation", "state"): agent_pos[:-1],
"action": torch.tensor(dataset_dict["actions"][idx0:idx1]), "action": torch.from_numpy(dataset_dict["action"][idx0:idx1])[:-1],
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int), "episode": torch.from_numpy(episode_ids[idx0:idx1])[:-1],
"frame_id": torch.arange(0, num_frames, 1), "frame_id": torch.arange(0, num_frames - 1, 1),
("next", "observation", "image"): next_image, ("next", "observation", "image"): image[1:],
("next", "observation", "state"): next_state, ("next", "observation", "state"): agent_pos[1:],
("next", "observation", "reward"): next_reward, # TODO: verify that reward and done are aligned with image and agent_pos
("next", "observation", "done"): next_done, ("next", "reward"): reward[1:],
("next", "done"): done[1:],
}, },
batch_size=num_frames, batch_size=num_frames - 1,
) )
if episode_id == 0: if episode_id == 0:
@ -184,9 +283,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
.memmap_like(self.root / self.dataset_id) .memmap_like(self.root / self.dataset_id)
) )
td_data[idx0:idx1] = episode td_data[idxtd : idxtd + len(episode)] = episode
episode_id += 1
idx0 = idx1 idx0 = idx1
idxtd = idxtd + len(episode)
return TensorStorage(td_data.lock_()) return TensorStorage(td_data.lock_())

View File

@ -13,7 +13,7 @@ class TOLD(nn.Module):
def __init__(self, cfg): def __init__(self, cfg):
super().__init__() super().__init__()
action_dim = 4 action_dim = cfg.action_dim
self.cfg = cfg self.cfg = cfg
self._encoder = h.enc(cfg) self._encoder = h.enc(cfg)
@ -82,7 +82,7 @@ class TDMPC(nn.Module):
def __init__(self, cfg): def __init__(self, cfg):
super().__init__() super().__init__()
self.action_dim = 4 self.action_dim = cfg.action_dim
self.cfg = cfg self.cfg = cfg
self.device = torch.device("cuda") self.device = torch.device("cuda")

View File

@ -130,7 +130,7 @@ class Flatten(nn.Module):
def enc(cfg): def enc(cfg):
obs_shape = { obs_shape = {
"rgb": (3, cfg.img_size, cfg.img_size), "rgb": (3, cfg.img_size, cfg.img_size),
"state": (4,), "state": (cfg.state_dim,),
} }
"""Returns a TOLD encoder.""" """Returns a TOLD encoder."""
@ -209,7 +209,7 @@ def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
def q(cfg): def q(cfg):
action_dim = 4 action_dim = cfg.action_dim
"""Returns a Q-function that uses Layer Normalization.""" """Returns a Q-function that uses Layer Normalization."""
return nn.Sequential( return nn.Sequential(
nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim), nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim),
@ -331,7 +331,7 @@ class Episode(object):
"""Storage object for a single episode.""" """Storage object for a single episode."""
def __init__(self, cfg, init_obs): def __init__(self, cfg, init_obs):
action_dim = 4 action_dim = cfg.action_dim
self.cfg = cfg self.cfg = cfg
self.device = torch.device(cfg.buffer_device) self.device = torch.device(cfg.buffer_device)
@ -447,8 +447,8 @@ class ReplayBuffer:
""" """
def __init__(self, cfg, dataset=None): def __init__(self, cfg, dataset=None):
action_dim = 4 action_dim = cfg.action_dim
obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (4,)} obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (cfg.state_dim,)}
self.cfg = cfg self.cfg = cfg
self.device = torch.device(cfg.buffer_device) self.device = torch.device(cfg.buffer_device)

View File

@ -1,5 +1,6 @@
seed: 1337 seed: 1337
log_dir: logs/2024_01_26_train log_dir: logs/2024_01_26_train
video_dir: tmp/2024_01_26_xarm_lift_medium
exp_name: default exp_name: default
device: cuda device: cuda
buffer_device: cuda buffer_device: cuda
@ -16,6 +17,7 @@ task: lift
from_pixels: True from_pixels: True
pixels_only: False pixels_only: False
image_size: 84 image_size: 84
fps: 15
reward_scale: 1.0 reward_scale: 1.0
@ -30,7 +32,8 @@ train_steps: 50000
frame_stack: 1 frame_stack: 1
num_channels: 32 num_channels: 32
img_size: ${image_size} img_size: ${image_size}
state_dim: 4
action_dim: 4
# TDMPC # TDMPC
@ -97,4 +100,3 @@ latent_dim: 50
use_wandb: false use_wandb: false
wandb_project: FOWM wandb_project: FOWM
wandb_entity: rcadene # insert your own wandb_entity: rcadene # insert your own

View File

@ -5,8 +5,12 @@ hydra:
job: job:
name: pusht name: pusht
video_dir: tmp/2024_02_21_pusht
# env # env
env: pusht env: pusht
image_size: 96 image_size: 96
frame_skip: 1 frame_skip: 1
state_dim: 2
action_dim: 2
fps: 10

View File

@ -20,6 +20,7 @@ def eval_policy(
max_steps: int = 30, max_steps: int = 30,
save_video: bool = False, save_video: bool = False,
video_dir: Path = None, video_dir: Path = None,
fps: int = 15,
): ):
rewards = [] rewards = []
successes = [] successes = []
@ -55,7 +56,7 @@ def eval_policy(
video_dir.mkdir(parents=True, exist_ok=True) video_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make fps configurable # TODO(rcadene): make fps configurable
video_path = video_dir / f"eval_episode_{i}.mp4" video_path = video_dir / f"eval_episode_{i}.mp4"
imageio.mimsave(video_path, np.stack(ep_frames), fps=15) imageio.mimsave(video_path, np.stack(ep_frames), fps=fps)
metrics = { metrics = {
"avg_reward": np.nanmean(rewards), "avg_reward": np.nanmean(rewards),
@ -74,16 +75,13 @@ def eval(cfg: dict):
if cfg.pretrained_model_path: if cfg.pretrained_model_path:
policy = TDMPC(cfg) policy = TDMPC(cfg)
ckpt_path = (
"/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
)
if "offline" in cfg.pretrained_model_path: if "offline" in cfg.pretrained_model_path:
policy.step = 25000 policy.step = 25000
elif "final" in cfg.pretrained_model_path: elif "final" in cfg.pretrained_model_path:
policy.step = 100000 policy.step = 100000
else: else:
raise NotImplementedError() raise NotImplementedError()
policy.load(ckpt_path) policy.load(cfg.pretrained_model_path)
policy = TensorDictModule( policy = TensorDictModule(
policy, policy,
@ -99,7 +97,8 @@ def eval(cfg: dict):
policy=policy, policy=policy,
num_episodes=20, num_episodes=20,
save_video=True, save_video=True,
video_dir=Path("tmp/2023_02_19_pusht"), video_dir=Path(cfg.video_dir),
fps=cfg.fps,
) )
print(metrics) print(metrics)

View File

@ -1,6 +1,7 @@
import pickle import pickle
from pathlib import Path from pathlib import Path
import hydra
import imageio import imageio
import simxarm import simxarm
import torch import torch
@ -10,30 +11,25 @@ from torchrl.data.replay_buffers import (
SliceSamplerWithoutReplacement, SliceSamplerWithoutReplacement,
) )
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay from lerobot.common.datasets.factory import make_offline_buffer
def visualize_simxarm_dataset(dataset_id="xarm_lift_medium"): @hydra.main(version_base=None, config_name="default", config_path="../configs")
def visualize_dataset(cfg: dict):
sampler = SliceSamplerWithoutReplacement( sampler = SliceSamplerWithoutReplacement(
num_slices=1, num_slices=1,
strict_length=False, strict_length=False,
shuffle=False, shuffle=False,
) )
dataset = SimxarmExperienceReplay( offline_buffer = make_offline_buffer(cfg, sampler)
dataset_id,
# download="force",
download=True,
streaming=False,
root="data",
sampler=sampler,
)
NUM_EPISODES_TO_RENDER = 10 NUM_EPISODES_TO_RENDER = 10
MAX_NUM_STEPS = 50 MAX_NUM_STEPS = 1000
FIRST_FRAME = 0 FIRST_FRAME = 0
for _ in range(NUM_EPISODES_TO_RENDER): for _ in range(NUM_EPISODES_TO_RENDER):
episode = dataset.sample(MAX_NUM_STEPS) episode = offline_buffer.sample(MAX_NUM_STEPS)
ep_idx = episode["episode"][FIRST_FRAME].item() ep_idx = episode["episode"][FIRST_FRAME].item()
ep_frames = torch.cat( ep_frames = torch.cat(
@ -44,16 +40,23 @@ def visualize_simxarm_dataset(dataset_id="xarm_lift_medium"):
dim=0, dim=0,
) )
video_dir = Path("tmp/2024_02_03_xarm_lift_medium") video_dir = Path(cfg.video_dir)
video_dir.mkdir(parents=True, exist_ok=True) video_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make fps configurable # TODO(rcadene): make fps configurable
video_path = video_dir / f"eval_episode_{ep_idx}.mp4" video_path = video_dir / f"eval_episode_{ep_idx}.mp4"
imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=15)
assert ep_frames.min().item() >= 0
assert ep_frames.max().item() > 1, "Not mendatory, but sanity check"
assert ep_frames.max().item() <= 255
ep_frames = ep_frames.type(torch.uint8)
imageio.mimsave(
video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps
)
# ran out of episodes # ran out of episodes
if dataset._sampler._sample_list.numel() == 0: if offline_buffer._sampler._sample_list.numel() == 0:
break break
if __name__ == "__main__": if __name__ == "__main__":
visualize_simxarm_dataset() visualize_dataset()