Add pusht dataset (TODO verify reward is aligned), Refactor visualize_dataset, Add video_dir, fps, state_dim, action_dim to config (Training works)

Cadene 2024-02-21 00:49:40 +00:00
parent 3dc14b5576
commit ece89730e6
8 changed files with 222 additions and 111 deletions

View File

@@ -5,18 +5,21 @@ from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
def make_offline_buffer(cfg):
def make_offline_buffer(cfg, sampler=None):
num_traj_per_batch = cfg.batch_size # // cfg.horizon
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
sampler = PrioritizedSliceSampler(
max_capacity=100_000,
alpha=cfg.per_alpha,
beta=cfg.per_beta,
num_slices=num_traj_per_batch,
strict_length=False,
)
overwrite_sampler = sampler is not None
if not overwrite_sampler:
num_traj_per_batch = cfg.batch_size # // cfg.horizon
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
sampler = PrioritizedSliceSampler(
max_capacity=100_000,
alpha=cfg.per_alpha,
beta=cfg.per_beta,
num_slices=num_traj_per_batch,
strict_length=False,
)
if cfg.env == "simxarm":
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
@@ -30,9 +33,9 @@ def make_offline_buffer(cfg):
)
elif cfg.env == "pusht":
offline_buffer = PushtExperienceReplay(
f"xarm_{cfg.task}_medium",
"pusht",
# download="force",
download=True,
download=False,
streaming=False,
root="data",
sampler=sampler,
@@ -40,8 +43,9 @@ def make_offline_buffer(cfg):
else:
raise ValueError(cfg.env)
num_steps = len(offline_buffer)
index = torch.arange(0, num_steps, 1)
sampler.extend(index)
if not overwrite_sampler:
num_steps = len(offline_buffer)
index = torch.arange(0, num_steps, 1)
sampler.extend(index)
return offline_buffer

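The optional `sampler` argument lets one factory serve both training (no sampler passed: a PrioritizedSliceSampler is built and extended with the dataset indices) and visualization (a caller-provided sampler is used as-is). A minimal sketch of the second call pattern, mirroring the refactored visualize_dataset script further down; the `replay_episodes` name is illustrative only:

import hydra
from torchrl.data.replay_buffers import SliceSamplerWithoutReplacement

from lerobot.common.datasets.factory import make_offline_buffer


@hydra.main(version_base=None, config_name="default", config_path="../configs")
def replay_episodes(cfg):
    # Deterministic, episode-by-episode sampling; training simply omits the
    # sampler and keeps the default PrioritizedSliceSampler.
    sampler = SliceSamplerWithoutReplacement(num_slices=1, strict_length=False, shuffle=False)
    offline_buffer = make_offline_buffer(cfg, sampler)
    batch = offline_buffer.sample(50)
    print(batch["episode"][0].item(), batch["observation", "image"].shape)


if __name__ == "__main__":
    replay_episodes()
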
View File

@@ -3,9 +3,15 @@ import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
import einops
import numpy as np
import pygame
import pymunk
import torch
import torchrl
import tqdm
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
@@ -20,12 +26,71 @@ from torchrl.data.replay_buffers.samplers import (
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
# as defined in the env
SUCCESS_THRESHOLD = 0.95  # 95% coverage
def get_goal_pose_body(pose):
mass = 1
inertia = pymunk.moment_for_box(mass, (50, 100))
body = pymunk.Body(mass, inertia)
# preserving the legacy assignment order for compatibility
# the order here doesn't matter somehow, maybe because CoM is aligned with body origin
body.position = pose[:2].tolist()
body.angle = pose[2]
return body
def add_segment(space, a, b, radius):
shape = pymunk.Segment(space.static_body, a, b, radius)
shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
return shape
def add_tee(
space,
position,
angle,
scale=30,
color="LightSlateGray",
mask=pymunk.ShapeFilter.ALL_MASKS(),
):
mass = 1
length = 4
vertices1 = [
(-length * scale / 2, scale),
(length * scale / 2, scale),
(length * scale / 2, 0),
(-length * scale / 2, 0),
]
inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
vertices2 = [
(-scale / 2, scale),
(-scale / 2, length * scale),
(scale / 2, length * scale),
(scale / 2, scale),
]
inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
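# NOTE: reusing vertices1 here (rather than vertices2) is presumably intentional,
# matching the add_tee helper in the upstream pusht env that this code mirrors.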
body = pymunk.Body(mass, inertia1 + inertia2)
shape1 = pymunk.Poly(body, vertices1)
shape2 = pymunk.Poly(body, vertices2)
shape1.color = pygame.Color(color)
shape2.color = pygame.Color(color)
shape1.filter = pymunk.ShapeFilter(mask=mask)
shape2.filter = pymunk.ShapeFilter(mask=mask)
body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
body.position = position
body.angle = angle
body.friction = 1
space.add(body, shape1, shape2)
return body
class PushtExperienceReplay(TensorDictReplayBuffer):
available_datasets = [
"xarm_lift_medium",
]
# available_datasets = [
# "xarm_lift_medium",
# ]
def __init__(
self,
@@ -49,8 +114,6 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
split_trajs: bool = False,
strict_length: bool = True,
):
# TODO
raise NotImplementedError()
self.download = download
if streaming:
raise NotImplementedError
@@ -68,8 +131,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
if split_trajs:
raise NotImplementedError
if self.download == True:
raise NotImplementedError()
if root is None:
root = _get_root_dir("simxarm")
root = _get_root_dir("pusht")
os.makedirs(root, exist_ok=True)
self.root = Path(root)
if self.download == "force" or (self.download and not self._is_downloaded()):
@@ -77,29 +143,29 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
else:
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
if num_slices is not None or slice_len is not None:
if sampler is not None:
raise ValueError(
"`num_slices` and `slice_len` are exclusive with the `sampler` argument."
)
# if num_slices is not None or slice_len is not None:
# if sampler is not None:
# raise ValueError(
# "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
# )
if replacement:
if not self.shuffle:
raise RuntimeError(
"shuffle=False can only be used when replacement=False."
)
sampler = SliceSampler(
num_slices=num_slices,
slice_len=slice_len,
strict_length=strict_length,
)
else:
sampler = SliceSamplerWithoutReplacement(
num_slices=num_slices,
slice_len=slice_len,
strict_length=strict_length,
shuffle=self.shuffle,
)
# if replacement:
# if not self.shuffle:
# raise RuntimeError(
# "shuffle=False can only be used when replacement=False."
# )
# sampler = SliceSampler(
# num_slices=num_slices,
# slice_len=slice_len,
# strict_length=strict_length,
# )
# else:
# sampler = SliceSamplerWithoutReplacement(
# num_slices=num_slices,
# slice_len=slice_len,
# strict_length=strict_length,
# shuffle=self.shuffle,
# )
if writer is None:
writer = ImmutableDatasetWriter()
@@ -131,49 +197,82 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
# TODO(rcadene)
# load
dataset_dir = Path("data") / self.dataset_id
dataset_path = dataset_dir / f"buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
zarr_path = (
"/home/rcadene/code/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
)
dataset_dict = ReplayBuffer.copy_from_path(
zarr_path
) # , keys=['img', 'state', 'action'])
total_frames = dataset_dict["actions"].shape[0]
episode_ids = dataset_dict.get_episode_idxs()
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0]
assert len(
set([dataset_dict[key].shape[0] for key in dataset_dict.keys()])
), "Some data type dont have the same number of total frames."
# TODO: verify that goal pose is expected to be fixed
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
goal_body = get_goal_pose_body(goal_pos_angle)
idx0 = 0
idx1 = 0
episode_id = 0
for i in tqdm.tqdm(range(total_frames)):
idx1 += 1
if not dataset_dict["dones"][i]:
continue
idxtd = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
idx1 = dataset_dict.meta["episode_ends"][episode_id]
num_frames = idx1 - idx0
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
next_image = torch.tensor(
dataset_dict["next_observations"]["rgb"][idx0:idx1]
)
next_state = torch.tensor(
dataset_dict["next_observations"]["state"][idx0:idx1]
)
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
assert (episode_ids[idx0:idx1] == episode_id).all()
image = torch.from_numpy(dataset_dict["img"][idx0:idx1])
image = einops.rearrange(image, "b h w c -> b c h w")
state = torch.from_numpy(dataset_dict["state"][idx0:idx1])
agent_pos = state[:, :2]
block_pos = state[:, 2:4]
block_angle = state[:, 4]
reward = torch.zeros(num_frames, 1)
done = torch.zeros(num_frames, 1, dtype=torch.bool)
for i in range(num_frames):
space = pymunk.Space()
space.gravity = 0, 0
space.damping = 0
# Add walls.
walls = [
add_segment(space, (5, 506), (5, 5), 2),
add_segment(space, (5, 5), (506, 5), 2),
add_segment(space, (506, 5), (506, 506), 2),
add_segment(space, (5, 506), (506, 506), 2),
]
space.add(*walls)
block_body = add_tee(
space, block_pos[i].tolist(), block_angle[i].item()
)
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
goal_area = goal_geom.area
coverage = intersection_area / goal_area
reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
done[i] = coverage > SUCCESS_THRESHOLD
episode = TensorDict(
{
("observation", "image"): image,
("observation", "state"): state,
"action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
("next", "observation", "image"): next_image,
("next", "observation", "state"): next_state,
("next", "observation", "reward"): next_reward,
("next", "observation", "done"): next_done,
("observation", "image"): image[:-1],
("observation", "state"): agent_pos[:-1],
"action": torch.from_numpy(dataset_dict["action"][idx0:idx1])[:-1],
"episode": torch.from_numpy(episode_ids[idx0:idx1])[:-1],
"frame_id": torch.arange(0, num_frames - 1, 1),
("next", "observation", "image"): image[1:],
("next", "observation", "state"): agent_pos[1:],
# TODO: verify that reward and done are aligned with image and agent_pos
("next", "reward"): reward[1:],
("next", "done"): done[1:],
},
batch_size=num_frames,
batch_size=num_frames - 1,
)
if episode_id == 0:
@@ -184,9 +283,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
.memmap_like(self.root / self.dataset_id)
)
td_data[idx0:idx1] = episode
td_data[idxtd : idxtd + len(episode)] = episode
episode_id += 1
idx0 = idx1
idxtd = idxtd + len(episode)
return TensorStorage(td_data.lock_())

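Since the commit message flags reward alignment as unverified, the slicing convention used above is worth spelling out: frame t provides ("observation", ...) and "action", frame t+1 provides ("next", "observation", ...), ("next", "reward") and ("next", "done"), so an episode of T frames yields T - 1 transitions. A self-contained toy example of that convention (names and values are illustrative, not from the codebase):

import torch

T = 4
obs = torch.arange(T).float()          # stands in for image / agent_pos at frame t
reward = torch.arange(T).float() * 10  # stands in for the clipped coverage reward at frame t

transitions = {
    "observation": obs[:-1],        # frames 0 .. T-2
    "next_observation": obs[1:],    # frames 1 .. T-1
    "next_reward": reward[1:],      # reward computed on the frame that follows the action
}
for t in range(T - 1):
    print(t, transitions["observation"][t].item(),
          transitions["next_observation"][t].item(),
          transitions["next_reward"][t].item())
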
View File

@@ -13,7 +13,7 @@ class TOLD(nn.Module):
def __init__(self, cfg):
super().__init__()
action_dim = 4
action_dim = cfg.action_dim
self.cfg = cfg
self._encoder = h.enc(cfg)
@@ -82,7 +82,7 @@ class TDMPC(nn.Module):
def __init__(self, cfg):
super().__init__()
self.action_dim = 4
self.action_dim = cfg.action_dim
self.cfg = cfg
self.device = torch.device("cuda")

View File

@@ -130,7 +130,7 @@ class Flatten(nn.Module):
def enc(cfg):
obs_shape = {
"rgb": (3, cfg.img_size, cfg.img_size),
"state": (4,),
"state": (cfg.state_dim,),
}
"""Returns a TOLD encoder."""
@@ -209,7 +209,7 @@ def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
def q(cfg):
action_dim = 4
action_dim = cfg.action_dim
"""Returns a Q-function that uses Layer Normalization."""
return nn.Sequential(
nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim),
@@ -331,7 +331,7 @@ class Episode(object):
"""Storage object for a single episode."""
def __init__(self, cfg, init_obs):
action_dim = 4
action_dim = cfg.action_dim
self.cfg = cfg
self.device = torch.device(cfg.buffer_device)
@@ -447,8 +447,8 @@ class ReplayBuffer:
"""
def __init__(self, cfg, dataset=None):
action_dim = 4
obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (4,)}
action_dim = cfg.action_dim
obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (cfg.state_dim,)}
self.cfg = cfg
self.device = torch.device(cfg.buffer_device)

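The hard-coded simxarm dimensions (state and action both 4) are now read from the config, so the same model code covers pusht (state_dim: 2, action_dim: 2 in the pusht config). A minimal sketch of how those fields drive the shapes; the SimpleNamespace cfg (including mlp_dim) is an illustrative stand-in for the Hydra config, not the real object:

from types import SimpleNamespace

import torch
import torch.nn as nn

cfg = SimpleNamespace(img_size=96, state_dim=2, action_dim=2, latent_dim=50, mlp_dim=512)

# Shapes that enc() and ReplayBuffer now derive from the config instead of hard-coding 4.
obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (cfg.state_dim,)}

# First layer of the Q-function, as in q(cfg): its input width follows cfg.action_dim.
q_in = nn.Linear(cfg.latent_dim + cfg.action_dim, cfg.mlp_dim)
z = torch.randn(1, cfg.latent_dim)
a = torch.randn(1, cfg.action_dim)
print(obs_shape["state"], q_in(torch.cat([z, a], dim=-1)).shape)  # (2,) torch.Size([1, 512])
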
View File

@@ -1,5 +1,6 @@
seed: 1337
log_dir: logs/2024_01_26_train
video_dir: tmp/2024_01_26_xarm_lift_medium
exp_name: default
device: cuda
buffer_device: cuda
@@ -16,6 +17,7 @@ task: lift
from_pixels: True
pixels_only: False
image_size: 84
fps: 15
reward_scale: 1.0
@@ -30,7 +32,8 @@ train_steps: 50000
frame_stack: 1
num_channels: 32
img_size: ${image_size}
state_dim: 4
action_dim: 4
# TDMPC
@@ -97,4 +100,3 @@ latent_dim: 50
use_wandb: false
wandb_project: FOWM
wandb_entity: rcadene # insert your own

View File

@@ -5,8 +5,12 @@ hydra:
job:
name: pusht
video_dir: tmp/2024_02_21_pusht
# env
env: pusht
image_size: 96
frame_skip: 1
state_dim: 2
action_dim: 2
fps: 10

View File

@@ -20,6 +20,7 @@ def eval_policy(
max_steps: int = 30,
save_video: bool = False,
video_dir: Path = None,
fps: int = 15,
):
rewards = []
successes = []
@@ -55,7 +56,7 @@ def eval_policy(
video_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make fps configurable
video_path = video_dir / f"eval_episode_{i}.mp4"
imageio.mimsave(video_path, np.stack(ep_frames), fps=15)
imageio.mimsave(video_path, np.stack(ep_frames), fps=fps)
metrics = {
"avg_reward": np.nanmean(rewards),
@@ -74,16 +75,13 @@ def eval(cfg: dict):
if cfg.pretrained_model_path:
policy = TDMPC(cfg)
ckpt_path = (
"/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
)
if "offline" in cfg.pretrained_model_path:
policy.step = 25000
elif "final" in cfg.pretrained_model_path:
policy.step = 100000
else:
raise NotImplementedError()
policy.load(ckpt_path)
policy.load(cfg.pretrained_model_path)
policy = TensorDictModule(
policy,
@@ -99,7 +97,8 @@ def eval(cfg: dict):
policy=policy,
num_episodes=20,
save_video=True,
video_dir=Path("tmp/2023_02_19_pusht"),
video_dir=Path(cfg.video_dir),
fps=cfg.fps,
)
print(metrics)

View File

@@ -1,6 +1,7 @@
import pickle
from pathlib import Path
import hydra
import imageio
import simxarm
import torch
@@ -10,30 +11,25 @@ from torchrl.data.replay_buffers import (
SliceSamplerWithoutReplacement,
)
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from lerobot.common.datasets.factory import make_offline_buffer
def visualize_simxarm_dataset(dataset_id="xarm_lift_medium"):
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def visualize_dataset(cfg: dict):
sampler = SliceSamplerWithoutReplacement(
num_slices=1,
strict_length=False,
shuffle=False,
)
dataset = SimxarmExperienceReplay(
dataset_id,
# download="force",
download=True,
streaming=False,
root="data",
sampler=sampler,
)
offline_buffer = make_offline_buffer(cfg, sampler)
NUM_EPISODES_TO_RENDER = 10
MAX_NUM_STEPS = 50
MAX_NUM_STEPS = 1000
FIRST_FRAME = 0
for _ in range(NUM_EPISODES_TO_RENDER):
episode = dataset.sample(MAX_NUM_STEPS)
episode = offline_buffer.sample(MAX_NUM_STEPS)
ep_idx = episode["episode"][FIRST_FRAME].item()
ep_frames = torch.cat(
@@ -44,16 +40,23 @@ def visualize_simxarm_dataset(dataset_id="xarm_lift_medium"):
dim=0,
)
video_dir = Path("tmp/2024_02_03_xarm_lift_medium")
video_dir = Path(cfg.video_dir)
video_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make fps configurable
video_path = video_dir / f"eval_episode_{ep_idx}.mp4"
imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=15)
assert ep_frames.min().item() >= 0
assert ep_frames.max().item() > 1, "Not mendatory, but sanity check"
assert ep_frames.max().item() <= 255
ep_frames = ep_frames.type(torch.uint8)
imageio.mimsave(
video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps
)
# ran out of episodes
if dataset._sampler._sample_list.numel() == 0:
if offline_buffer._sampler._sample_list.numel() == 0:
break
if __name__ == "__main__":
visualize_simxarm_dataset()
visualize_dataset()