Fix unit tests, Refactor, Add pusht env, (TODO pusht replay buffer, image preprocessing)

This commit is contained in:
Cadene 2024-02-20 12:26:57 +00:00
parent fdfb2010fd
commit 3da6ffb2cb
10 changed files with 559 additions and 89 deletions

View File

@ -0,0 +1,47 @@
import torch
from lerobot.common.datasets.pusht import PushtExperienceReplay
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
def make_offline_buffer(cfg):
num_traj_per_batch = cfg.batch_size # // cfg.horizon
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
sampler = PrioritizedSliceSampler(
max_capacity=100_000,
alpha=cfg.per_alpha,
beta=cfg.per_beta,
num_slices=num_traj_per_batch,
strict_length=False,
)
if cfg.env == "simxarm":
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
offline_buffer = SimxarmExperienceReplay(
f"xarm_{cfg.task}_medium",
# download="force",
download=True,
streaming=False,
root="data",
sampler=sampler,
)
elif cfg.env == "pusht":
offline_buffer = PushtExperienceReplay(
f"xarm_{cfg.task}_medium",
# download="force",
download=True,
streaming=False,
root="data",
sampler=sampler,
)
else:
raise ValueError(cfg.env)
num_steps = len(offline_buffer)
index = torch.arange(0, num_steps, 1)
sampler.extend(index)
return offline_buffer

View File

@ -0,0 +1,192 @@
import os
import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import (
Sampler,
SliceSampler,
SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
class PushtExperienceReplay(TensorDictReplayBuffer):
available_datasets = [
"xarm_lift_medium",
]
def __init__(
self,
dataset_id,
batch_size: int = None,
*,
shuffle: bool = True,
num_slices: int = None,
slice_len: int = None,
pad: float = None,
replacement: bool = None,
streaming: bool = False,
root: Path = None,
download: bool = False,
sampler: Sampler = None,
writer: Writer = None,
collate_fn: Callable = None,
pin_memory: bool = False,
prefetch: int = None,
transform: "torchrl.envs.Transform" = None, # noqa-F821
split_trajs: bool = False,
strict_length: bool = True,
):
# TODO
raise NotImplementedError()
self.download = download
if streaming:
raise NotImplementedError
self.streaming = streaming
self.dataset_id = dataset_id
self.split_trajs = split_trajs
self.shuffle = shuffle
self.num_slices = num_slices
self.slice_len = slice_len
self.pad = pad
self.strict_length = strict_length
if (self.num_slices is not None) and (self.slice_len is not None):
raise ValueError("num_slices or slice_len can be not None, but not both.")
if split_trajs:
raise NotImplementedError
if root is None:
root = _get_root_dir("simxarm")
os.makedirs(root, exist_ok=True)
self.root = Path(root)
if self.download == "force" or (self.download and not self._is_downloaded()):
storage = self._download_and_preproc()
else:
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
if num_slices is not None or slice_len is not None:
if sampler is not None:
raise ValueError(
"`num_slices` and `slice_len` are exclusive with the `sampler` argument."
)
if replacement:
if not self.shuffle:
raise RuntimeError(
"shuffle=False can only be used when replacement=False."
)
sampler = SliceSampler(
num_slices=num_slices,
slice_len=slice_len,
strict_length=strict_length,
)
else:
sampler = SliceSamplerWithoutReplacement(
num_slices=num_slices,
slice_len=slice_len,
strict_length=strict_length,
shuffle=self.shuffle,
)
if writer is None:
writer = ImmutableDatasetWriter()
if collate_fn is None:
collate_fn = _collate_id
super().__init__(
storage=storage,
sampler=sampler,
writer=writer,
collate_fn=collate_fn,
pin_memory=pin_memory,
prefetch=prefetch,
batch_size=batch_size,
transform=transform,
)
@property
def data_path_root(self):
if self.streaming:
return None
return self.root / self.dataset_id
def _is_downloaded(self):
return os.path.exists(self.data_path_root)
def _download_and_preproc(self):
# download
# TODO(rcadene)
# load
dataset_dir = Path("data") / self.dataset_id
dataset_path = dataset_dir / f"buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
total_frames = dataset_dict["actions"].shape[0]
idx0 = 0
idx1 = 0
episode_id = 0
for i in tqdm.tqdm(range(total_frames)):
idx1 += 1
if not dataset_dict["dones"][i]:
continue
num_frames = idx1 - idx0
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
next_image = torch.tensor(
dataset_dict["next_observations"]["rgb"][idx0:idx1]
)
next_state = torch.tensor(
dataset_dict["next_observations"]["state"][idx0:idx1]
)
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
episode = TensorDict(
{
("observation", "image"): image,
("observation", "state"): state,
"action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
("next", "observation", "image"): next_image,
("next", "observation", "state"): next_state,
("next", "observation", "reward"): next_reward,
("next", "observation", "done"): next_done,
},
batch_size=num_frames,
)
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = (
episode[0]
.expand(total_frames)
.memmap_like(self.root / self.dataset_id)
)
td_data[idx0:idx1] = episode
episode_id += 1
idx0 = idx1
return TensorStorage(td_data.lock_())

View File

@ -1,17 +1,26 @@
from torchrl.envs.transforms import StepCounter, TransformedEnv from torchrl.envs.transforms import StepCounter, TransformedEnv
from lerobot.common.envs.pusht import PushtEnv
from lerobot.common.envs.simxarm import SimxarmEnv from lerobot.common.envs.simxarm import SimxarmEnv
def make_env(cfg): def make_env(cfg):
assert cfg.env == "simxarm" kwargs = {
env = SimxarmEnv( "frame_skip": cfg.action_repeat,
task=cfg.task, "from_pixels": cfg.from_pixels,
frame_skip=cfg.action_repeat, "pixels_only": cfg.pixels_only,
from_pixels=cfg.from_pixels, "image_size": cfg.image_size,
pixels_only=cfg.pixels_only, }
image_size=cfg.image_size,
) if cfg.env == "simxarm":
kwargs["task"] = cfg.task
clsfunc = SimxarmEnv
elif cfg.env == "pusht":
clsfunc = PushtEnv
else:
raise ValueError(cfg.env)
env = clsfunc(**kwargs)
# limit rollout to max_steps # limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length)) env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length))

View File

@ -0,0 +1,193 @@
import importlib
from typing import Optional
import numpy as np
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import EnvBase
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
from lerobot.common.utils import set_seed
_has_gym = importlib.util.find_spec("gym") is not None
_has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _has_gym
class PushtEnv(EnvBase):
def __init__(
self,
frame_skip: int = 1,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
seed=1337,
device="cpu",
max_episode_length=25, # TODO: verify
):
super().__init__(device=device, batch_size=[])
self.frame_skip = frame_skip
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
self.max_episode_length = max_episode_length
if pixels_only:
assert from_pixels
if from_pixels:
assert image_size
if not _has_diffpolicy:
raise ImportError("Cannot import diffusion_policy.")
if not _has_gym:
raise ImportError("Cannot import gym.")
import gym
from diffusion_policy.env.pusht.pusht_env import PushTEnv
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
from gym.wrappers import TimeLimit
self._env = PushTImageEnv(render_size=self.image_size)
self._env = TimeLimit(self._env, self.max_episode_length)
# MAX_NUM_ACTIONS = 4
# num_actions = len(TASKS[self.task]["action_space"])
# self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
# self._action_padding = np.zeros(
# (MAX_NUM_ACTIONS - num_actions), dtype=np.float32
# )
# if "w" not in TASKS[self.task]["action_space"]:
# self._action_padding[-1] = 1.0
self._make_spec()
self.set_seed(seed)
def render(self, mode="rgb_array", width=384, height=384):
if width != height:
raise NotImplementedError()
tmp = self._env.render_size
self._env.render_size = width
out = self._env.render(mode)
self._env.render_size = tmp
return out
def _format_raw_obs(self, raw_obs):
if self.from_pixels:
obs = {"image": torch.from_numpy(raw_obs["image"])}
if not self.pixels_only:
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(
torch.float32
)
else:
# TODO:
obs = {
"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)
}
obs = TensorDict(obs, batch_size=[])
return obs
def _reset(self, tensordict: Optional[TensorDict] = None):
td = tensordict
if td is None or td.is_empty():
raw_obs = self._env.reset()
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
return td
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0
for t in range(self.frame_skip):
raw_obs, reward, done, info = self._env.step(action)
sum_reward += reward
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([done], dtype=torch.bool),
},
batch_size=[],
)
return td
def _make_spec(self):
obs = {}
if self.from_pixels:
obs["image"] = BoundedTensorSpec(
low=0,
high=1,
shape=(3, self.image_size, self.image_size),
dtype=torch.float32,
device=self.device,
)
if not self.pixels_only:
obs["state"] = BoundedTensorSpec(
low=0,
high=512,
shape=self._env.observation_space["agent_pos"].shape,
dtype=torch.float32,
device=self.device,
)
else:
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
obs["state"] = UnboundedContinuousTensorSpec(
# TODO:
shape=self._env.observation_space["observation"].shape,
dtype=torch.float32,
device=self.device,
)
self.observation_spec = CompositeSpec({"observation": obs})
self.action_spec = _gym_to_torchrl_spec_transform(
self._env.action_space,
device=self.device,
)
self.reward_spec = UnboundedContinuousTensorSpec(
shape=(1,),
dtype=torch.float32,
device=self.device,
)
self.done_spec = CompositeSpec(
{
"done": DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
),
"success": DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
),
}
)
def _set_seed(self, seed: Optional[int]):
set_seed(seed)
self._env.seed(seed)

View File

@ -167,18 +167,21 @@ class SimxarmEnv(EnvBase):
device=self.device, device=self.device,
) )
self.done_spec = DiscreteTensorSpec( self.done_spec = CompositeSpec(
{
"done": DiscreteTensorSpec(
2, 2,
shape=(1,), shape=(1,),
dtype=torch.bool, dtype=torch.bool,
device=self.device, device=self.device,
) ),
"success": DiscreteTensorSpec(
self.success_spec = DiscreteTensorSpec(
2, 2,
shape=(1,), shape=(1,),
dtype=torch.bool, dtype=torch.bool,
device=self.device, device=self.device,
),
}
) )
def _set_seed(self, seed: Optional[int]): def _set_seed(self, seed: Optional[int]):

View File

@ -29,7 +29,7 @@ train_steps: 50000
# pixels # pixels
frame_stack: 1 frame_stack: 1
num_channels: 32 num_channels: 32
img_size: 84 img_size: ${image_size}
# TDMPC # TDMPC
@ -82,6 +82,8 @@ A_scaling: 3.0
# offline->online # offline->online
offline_steps: 25000 # ${train_steps}/2 offline_steps: 25000 # ${train_steps}/2
pretrained_model_path: "" pretrained_model_path: ""
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
balanced_sampling: true balanced_sampling: true
demo_schedule: 0.5 demo_schedule: 0.5

View File

@ -0,0 +1,12 @@
defaults:
- default
hydra:
job:
name: pusht
# env
env: pusht
image_size: 96
frame_skip: 1

View File

@ -71,9 +71,18 @@ def eval(cfg: dict):
print(colored("Log dir:", "yellow", attrs=["bold"]), cfg.log_dir) print(colored("Log dir:", "yellow", attrs=["bold"]), cfg.log_dir)
env = make_env(cfg) env = make_env(cfg)
if cfg.pretrained_model_path:
policy = TDMPC(cfg) policy = TDMPC(cfg)
# ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" ckpt_path = (
ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt" "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
)
if "offline" in cfg.pretrained_model_path:
policy.step = 25000
elif "final" in cfg.pretrained_model_path:
policy.step = 100000
else:
raise NotImplementedError()
policy.load(ckpt_path) policy.load(ckpt_path)
policy = TensorDictModule( policy = TensorDictModule(
@ -81,14 +90,16 @@ def eval(cfg: dict):
in_keys=["observation", "step_count"], in_keys=["observation", "step_count"],
out_keys=["action"], out_keys=["action"],
) )
else:
# when policy is None, rollout a random policy
policy = None
# policy can be None to rollout a random policy
metrics = eval_policy( metrics = eval_policy(
env, env,
policy=policy, policy=policy,
num_episodes=20, num_episodes=20,
save_video=False, save_video=True,
video_dir=Path("tmp/2023_01_29_xarm_lift_final"), video_dir=Path("tmp/2023_02_19_pusht"),
) )
print(metrics) print(metrics)

View File

@ -10,6 +10,7 @@ from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.datasets.openx import OpenXExperienceReplay from torchrl.data.datasets.openx import OpenXExperienceReplay
from torchrl.data.replay_buffers import PrioritizedSliceSampler from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger from lerobot.common.logger import Logger
@ -26,11 +27,17 @@ def train(cfg: dict):
env = make_env(cfg) env = make_env(cfg)
policy = TDMPC(cfg) policy = TDMPC(cfg)
# ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" if cfg.pretrained_model_path:
# policy.step = 25000 ckpt_path = (
# # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt" "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
# # policy.step = 100000 )
# policy.load(ckpt_path) if "offline" in cfg.pretrained_model_path:
policy.step = 25000
elif "final" in cfg.pretrained_model_path:
policy.step = 100000
else:
raise NotImplementedError()
policy.load(ckpt_path)
td_policy = TensorDictModule( td_policy = TensorDictModule(
policy, policy,
@ -40,32 +47,7 @@ def train(cfg: dict):
# initialize offline dataset # initialize offline dataset
dataset_id = f"xarm_{cfg.task}_medium" offline_buffer = make_offline_buffer(cfg)
num_traj_per_batch = cfg.batch_size # // cfg.horizon
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
sampler = PrioritizedSliceSampler(
max_capacity=100_000,
alpha=cfg.per_alpha,
beta=cfg.per_beta,
num_slices=num_traj_per_batch,
strict_length=False,
)
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
offline_buffer = SimxarmExperienceReplay(
dataset_id,
# download="force",
download=True,
streaming=False,
root="data",
sampler=sampler,
)
num_steps = len(offline_buffer)
index = torch.arange(0, num_steps, 1)
sampler.extend(index)
if cfg.balanced_sampling: if cfg.balanced_sampling:
online_sampler = PrioritizedSliceSampler( online_sampler = PrioritizedSliceSampler(

View File

@ -2,7 +2,35 @@ import pytest
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.envs.utils import check_env_specs, step_mdp from torchrl.envs.utils import check_env_specs, step_mdp
from lerobot.common.envs import SimxarmEnv from lerobot.common.envs.pusht import PushtEnv
from lerobot.common.envs.simxarm import SimxarmEnv
def print_spec_rollout(env):
print("observation_spec:", env.observation_spec)
print("action_spec:", env.action_spec)
print("reward_spec:", env.reward_spec)
print("done_spec:", env.done_spec)
td = env.reset()
print("reset tensordict", td)
td = env.rand_step(td)
print("random step tensordict", td)
def simple_rollout(steps=100):
# preallocate:
data = TensorDict({}, [steps])
# reset
_data = env.reset()
for i in range(steps):
_data["action"] = env.action_spec.rand()
_data = env.step(_data)
data[i] = _data
_data = step_mdp(_data, keep_other=True)
return data
print("data from rollout:", simple_rollout(100))
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -26,30 +54,21 @@ def test_simxarm(task, from_pixels, pixels_only):
pixels_only=pixels_only, pixels_only=pixels_only,
image_size=84 if from_pixels else None, image_size=84 if from_pixels else None,
) )
print_spec_rollout(env)
check_env_specs(env) check_env_specs(env)
print("observation_spec:", env.observation_spec)
print("action_spec:", env.action_spec)
print("reward_spec:", env.reward_spec)
print("done_spec:", env.done_spec)
print("success_spec:", env.success_spec)
td = env.reset() @pytest.mark.parametrize(
print("reset tensordict", td) "from_pixels,pixels_only",
[
td = env.rand_step(td) (True, False),
print("random step tensordict", td) ],
)
def simple_rollout(steps=100): def test_pusht(from_pixels, pixels_only):
# preallocate: env = PushtEnv(
data = TensorDict({}, [steps]) from_pixels=from_pixels,
# reset pixels_only=pixels_only,
_data = env.reset() image_size=96 if from_pixels else None,
for i in range(steps): )
_data["action"] = env.action_spec.rand() print_spec_rollout(env)
_data = env.step(_data) check_env_specs(env)
data[i] = _data
_data = step_mdp(_data, keep_other=True)
return data
print("data from rollout:", simple_rollout(100))