Fix unit tests, refactor, add pusht env (TODO: pusht replay buffer, image preprocessing)
This commit is contained in:
parent fdfb2010fd
commit 3da6ffb2cb
@@ -0,0 +1,47 @@
import torch

from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler

from lerobot.common.datasets.pusht import PushtExperienceReplay
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay


def make_offline_buffer(cfg):

    num_traj_per_batch = cfg.batch_size  # // cfg.horizon
    # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
    # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
    sampler = PrioritizedSliceSampler(
        max_capacity=100_000,
        alpha=cfg.per_alpha,
        beta=cfg.per_beta,
        num_slices=num_traj_per_batch,
        strict_length=False,
    )

    if cfg.env == "simxarm":
        # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
        offline_buffer = SimxarmExperienceReplay(
            f"xarm_{cfg.task}_medium",
            # download="force",
            download=True,
            streaming=False,
            root="data",
            sampler=sampler,
        )
    elif cfg.env == "pusht":
        offline_buffer = PushtExperienceReplay(
            f"xarm_{cfg.task}_medium",
            # download="force",
            download=True,
            streaming=False,
            root="data",
            sampler=sampler,
        )
    else:
        raise ValueError(cfg.env)

    num_steps = len(offline_buffer)
    index = torch.arange(0, num_steps, 1)
    sampler.extend(index)

    return offline_buffer
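Usage sketch (illustration only, not part of this commit): building the offline buffer from a minimal config and sampling one prioritized batch of trajectory slices. The config fields are the ones the factory reads above; their values, and the printed key, are assumptions.

from types import SimpleNamespace

from lerobot.common.datasets.factory import make_offline_buffer

# Hypothetical config values; in the repo these come from the Hydra config.
cfg = SimpleNamespace(
    env="simxarm",
    task="lift",      # -> dataset id "xarm_lift_medium"
    batch_size=256,   # also used as the number of trajectory slices per batch
    per_alpha=0.6,    # prioritization exponent
    per_beta=0.4,     # importance-sampling exponent
)

offline_buffer = make_offline_buffer(cfg)
batch = offline_buffer.sample(cfg.batch_size)  # TensorDict of sampled slices
print(batch["observation", "state"].shape)  # assumes the simxarm buffer uses this key layout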
@@ -0,0 +1,192 @@
import os
import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Tuple

import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
    TensorDictPrioritizedReplayBuffer,
    TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import (
    Sampler,
    SliceSampler,
    SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer


class PushtExperienceReplay(TensorDictReplayBuffer):

    available_datasets = [
        "xarm_lift_medium",
    ]

    def __init__(
        self,
        dataset_id,
        batch_size: int = None,
        *,
        shuffle: bool = True,
        num_slices: int = None,
        slice_len: int = None,
        pad: float = None,
        replacement: bool = None,
        streaming: bool = False,
        root: Path = None,
        download: bool = False,
        sampler: Sampler = None,
        writer: Writer = None,
        collate_fn: Callable = None,
        pin_memory: bool = False,
        prefetch: int = None,
        transform: "torchrl.envs.Transform" = None,  # noqa-F821
        split_trajs: bool = False,
        strict_length: bool = True,
    ):
        # TODO
        raise NotImplementedError()
        self.download = download
        if streaming:
            raise NotImplementedError
        self.streaming = streaming
        self.dataset_id = dataset_id
        self.split_trajs = split_trajs
        self.shuffle = shuffle
        self.num_slices = num_slices
        self.slice_len = slice_len
        self.pad = pad

        self.strict_length = strict_length
        if (self.num_slices is not None) and (self.slice_len is not None):
            raise ValueError("num_slices or slice_len can be not None, but not both.")
        if split_trajs:
            raise NotImplementedError

        if root is None:
            root = _get_root_dir("simxarm")
            os.makedirs(root, exist_ok=True)
        self.root = Path(root)
        if self.download == "force" or (self.download and not self._is_downloaded()):
            storage = self._download_and_preproc()
        else:
            storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))

        if num_slices is not None or slice_len is not None:
            if sampler is not None:
                raise ValueError(
                    "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
                )

            if replacement:
                if not self.shuffle:
                    raise RuntimeError(
                        "shuffle=False can only be used when replacement=False."
                    )
                sampler = SliceSampler(
                    num_slices=num_slices,
                    slice_len=slice_len,
                    strict_length=strict_length,
                )
            else:
                sampler = SliceSamplerWithoutReplacement(
                    num_slices=num_slices,
                    slice_len=slice_len,
                    strict_length=strict_length,
                    shuffle=self.shuffle,
                )

        if writer is None:
            writer = ImmutableDatasetWriter()
        if collate_fn is None:
            collate_fn = _collate_id

        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=writer,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            batch_size=batch_size,
            transform=transform,
        )

    @property
    def data_path_root(self):
        if self.streaming:
            return None
        return self.root / self.dataset_id

    def _is_downloaded(self):
        return os.path.exists(self.data_path_root)

    def _download_and_preproc(self):
        # download
        # TODO(rcadene)

        # load
        dataset_dir = Path("data") / self.dataset_id
        dataset_path = dataset_dir / "buffer.pkl"
        print(f"Using offline dataset '{dataset_path}'")
        with open(dataset_path, "rb") as f:
            dataset_dict = pickle.load(f)

        total_frames = dataset_dict["actions"].shape[0]

        idx0 = 0
        idx1 = 0
        episode_id = 0
        for i in tqdm.tqdm(range(total_frames)):
            idx1 += 1

            if not dataset_dict["dones"][i]:
                continue

            num_frames = idx1 - idx0

            image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
            state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
            next_image = torch.tensor(
                dataset_dict["next_observations"]["rgb"][idx0:idx1]
            )
            next_state = torch.tensor(
                dataset_dict["next_observations"]["state"][idx0:idx1]
            )
            next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
            next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])

            episode = TensorDict(
                {
                    ("observation", "image"): image,
                    ("observation", "state"): state,
                    "action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
                    "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
                    "frame_id": torch.arange(0, num_frames, 1),
                    ("next", "observation", "image"): next_image,
                    ("next", "observation", "state"): next_state,
                    ("next", "observation", "reward"): next_reward,
                    ("next", "observation", "done"): next_done,
                },
                batch_size=num_frames,
            )

            if episode_id == 0:
                # hack to initialize tensordict data structure to store episodes
                td_data = (
                    episode[0]
                    .expand(total_frames)
                    .memmap_like(self.root / self.dataset_id)
                )

            td_data[idx0:idx1] = episode

            episode_id += 1
            idx0 = idx1

        return TensorStorage(td_data.lock_())
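The storage trick in _download_and_preproc above (expand the first frame to the full dataset length, memmap it to disk, then fill episodes into contiguous index ranges) can be tried in isolation. A toy sketch with made-up shapes, assuming only torch and tensordict are installed:

import tempfile

import torch
from tensordict import TensorDict

total_frames = 10

# One frame with the same nested key layout as the pusht episodes above.
frame = TensorDict(
    {
        ("observation", "image"): torch.zeros(3, 96, 96),
        ("observation", "state"): torch.zeros(2),
        "action": torch.zeros(2),
    },
    batch_size=[],
)

# Preallocate a disk-backed TensorDict for the whole dataset in one shot.
td_data = frame.expand(total_frames).memmap_like(tempfile.mkdtemp())

# Episodes are then written into contiguous slices of the preallocated storage.
episode = frame.expand(4).clone()
td_data[0:4] = episode
print(td_data["observation", "image"].shape)  # torch.Size([10, 3, 96, 96])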
@@ -1,17 +1,26 @@
 from torchrl.envs.transforms import StepCounter, TransformedEnv
 
+from lerobot.common.envs.pusht import PushtEnv
 from lerobot.common.envs.simxarm import SimxarmEnv
 
 
 def make_env(cfg):
-    assert cfg.env == "simxarm"
-    env = SimxarmEnv(
-        task=cfg.task,
-        frame_skip=cfg.action_repeat,
-        from_pixels=cfg.from_pixels,
-        pixels_only=cfg.pixels_only,
-        image_size=cfg.image_size,
-    )
+    kwargs = {
+        "frame_skip": cfg.action_repeat,
+        "from_pixels": cfg.from_pixels,
+        "pixels_only": cfg.pixels_only,
+        "image_size": cfg.image_size,
+    }
+
+    if cfg.env == "simxarm":
+        kwargs["task"] = cfg.task
+        clsfunc = SimxarmEnv
+    elif cfg.env == "pusht":
+        clsfunc = PushtEnv
+    else:
+        raise ValueError(cfg.env)
+
+    env = clsfunc(**kwargs)
 
     # limit rollout to max_steps
     env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length))
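Usage sketch (illustration only): driving the env factory with a minimal config. The fields are the ones make_env reads above; the values and the rollout length are assumptions, and the pusht branch additionally requires gym and diffusion_policy to be installed.

from types import SimpleNamespace

from lerobot.common.envs.factory import make_env

# Hypothetical config values; in the repo these come from the Hydra config.
cfg = SimpleNamespace(
    env="pusht",
    action_repeat=1,
    from_pixels=True,
    pixels_only=False,
    image_size=96,
    episode_length=300,
)

env = make_env(cfg)
# StepCounter terminates rollouts after cfg.episode_length steps.
td = env.rollout(max_steps=10)  # random actions when no policy is provided
print(td)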
@@ -0,0 +1,193 @@
import importlib
from typing import Optional

import numpy as np
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
    BoundedTensorSpec,
    CompositeSpec,
    DiscreteTensorSpec,
    UnboundedContinuousTensorSpec,
)
from torchrl.envs import EnvBase
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform

from lerobot.common.utils import set_seed

_has_gym = importlib.util.find_spec("gym") is not None
_has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _has_gym


class PushtEnv(EnvBase):

    def __init__(
        self,
        frame_skip: int = 1,
        from_pixels: bool = False,
        pixels_only: bool = False,
        image_size=None,
        seed=1337,
        device="cpu",
        max_episode_length=25,  # TODO: verify
    ):
        super().__init__(device=device, batch_size=[])
        self.frame_skip = frame_skip
        self.from_pixels = from_pixels
        self.pixels_only = pixels_only
        self.image_size = image_size
        self.max_episode_length = max_episode_length

        if pixels_only:
            assert from_pixels
        if from_pixels:
            assert image_size

        if not _has_diffpolicy:
            raise ImportError("Cannot import diffusion_policy.")
        if not _has_gym:
            raise ImportError("Cannot import gym.")

        import gym
        from diffusion_policy.env.pusht.pusht_env import PushTEnv
        from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
        from gym.wrappers import TimeLimit

        self._env = PushTImageEnv(render_size=self.image_size)
        self._env = TimeLimit(self._env, self.max_episode_length)

        # MAX_NUM_ACTIONS = 4
        # num_actions = len(TASKS[self.task]["action_space"])
        # self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
        # self._action_padding = np.zeros(
        #     (MAX_NUM_ACTIONS - num_actions), dtype=np.float32
        # )
        # if "w" not in TASKS[self.task]["action_space"]:
        #     self._action_padding[-1] = 1.0

        self._make_spec()
        self.set_seed(seed)

    def render(self, mode="rgb_array", width=384, height=384):
        if width != height:
            raise NotImplementedError()
        tmp = self._env.render_size
        self._env.render_size = width
        out = self._env.render(mode)
        self._env.render_size = tmp
        return out

    def _format_raw_obs(self, raw_obs):
        if self.from_pixels:
            obs = {"image": torch.from_numpy(raw_obs["image"])}

            if not self.pixels_only:
                obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(
                    torch.float32
                )
        else:
            # TODO:
            obs = {
                "state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)
            }

        obs = TensorDict(obs, batch_size=[])
        return obs

    def _reset(self, tensordict: Optional[TensorDict] = None):
        td = tensordict
        if td is None or td.is_empty():
            raw_obs = self._env.reset()

            td = TensorDict(
                {
                    "observation": self._format_raw_obs(raw_obs),
                    "done": torch.tensor([False], dtype=torch.bool),
                },
                batch_size=[],
            )
        else:
            raise NotImplementedError()
        return td

    def _step(self, tensordict: TensorDict):
        td = tensordict
        action = td["action"].numpy()
        # step expects shape=(4,) so we pad if necessary
        # TODO(rcadene): add info["is_success"] and info["success"] ?
        sum_reward = 0
        for t in range(self.frame_skip):
            raw_obs, reward, done, info = self._env.step(action)
            sum_reward += reward

        td = TensorDict(
            {
                "observation": self._format_raw_obs(raw_obs),
                "reward": torch.tensor([sum_reward], dtype=torch.float32),
                # success and done are true when coverage > self.success_threshold in env
                "done": torch.tensor([done], dtype=torch.bool),
                "success": torch.tensor([done], dtype=torch.bool),
            },
            batch_size=[],
        )
        return td

    def _make_spec(self):
        obs = {}
        if self.from_pixels:
            obs["image"] = BoundedTensorSpec(
                low=0,
                high=1,
                shape=(3, self.image_size, self.image_size),
                dtype=torch.float32,
                device=self.device,
            )
            if not self.pixels_only:
                obs["state"] = BoundedTensorSpec(
                    low=0,
                    high=512,
                    shape=self._env.observation_space["agent_pos"].shape,
                    dtype=torch.float32,
                    device=self.device,
                )
        else:
            # TODO(rcadene): add observation_space achieved_goal and desired_goal?
            obs["state"] = UnboundedContinuousTensorSpec(
                # TODO:
                shape=self._env.observation_space["observation"].shape,
                dtype=torch.float32,
                device=self.device,
            )
        self.observation_spec = CompositeSpec({"observation": obs})

        self.action_spec = _gym_to_torchrl_spec_transform(
            self._env.action_space,
            device=self.device,
        )

        self.reward_spec = UnboundedContinuousTensorSpec(
            shape=(1,),
            dtype=torch.float32,
            device=self.device,
        )

        self.done_spec = CompositeSpec(
            {
                "done": DiscreteTensorSpec(
                    2,
                    shape=(1,),
                    dtype=torch.bool,
                    device=self.device,
                ),
                "success": DiscreteTensorSpec(
                    2,
                    shape=(1,),
                    dtype=torch.bool,
                    device=self.device,
                ),
            }
        )

    def _set_seed(self, seed: Optional[int]):
        set_seed(seed)
        self._env.seed(seed)
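A short sketch of exercising the new wrapper through the TorchRL EnvBase API, assuming gym and diffusion_policy are available; the seed and constructor values are illustrative:

from torchrl.envs.utils import check_env_specs

from lerobot.common.envs.pusht import PushtEnv

env = PushtEnv(from_pixels=True, pixels_only=False, image_size=96, seed=1337)
check_env_specs(env)  # checks that _reset/_step outputs match the declared specs

td = env.reset()
td["action"] = env.action_spec.rand()
td = env.step(td)
# step() places the new observation, reward, done and success under "next".
print(td["next", "observation", "image"].shape)
print(td["next", "reward"], td["next", "done"])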
@@ -167,18 +167,21 @@ class SimxarmEnv(EnvBase):
             device=self.device,
         )
 
-        self.done_spec = DiscreteTensorSpec(
-            2,
-            shape=(1,),
-            dtype=torch.bool,
-            device=self.device,
-        )
-
-        self.success_spec = DiscreteTensorSpec(
-            2,
-            shape=(1,),
-            dtype=torch.bool,
-            device=self.device,
+        self.done_spec = CompositeSpec(
+            {
+                "done": DiscreteTensorSpec(
+                    2,
+                    shape=(1,),
+                    dtype=torch.bool,
+                    device=self.device,
+                ),
+                "success": DiscreteTensorSpec(
+                    2,
+                    shape=(1,),
+                    dtype=torch.bool,
+                    device=self.device,
+                ),
+            }
         )
 
     def _set_seed(self, seed: Optional[int]):
@@ -29,7 +29,7 @@ train_steps: 50000
 # pixels
 frame_stack: 1
 num_channels: 32
-img_size: 84
+img_size: ${image_size}
 
 
 # TDMPC
@@ -82,6 +82,8 @@ A_scaling: 3.0
 # offline->online
 offline_steps: 25000 # ${train_steps}/2
 pretrained_model_path: ""
+# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
+# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
 balanced_sampling: true
 demo_schedule: 0.5
 
@@ -0,0 +1,12 @@
defaults:
  - default

hydra:
  job:
    name: pusht

# env
env: pusht
image_size: 96
frame_skip: 1
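This config only overrides the env-specific fields and inherits the rest from the default config; together with the `img_size: ${image_size}` interpolation added above, img_size resolves to 96 for pusht. A rough way to check the composition, assuming the YAML files live in a lerobot/configs directory (the path is not visible in this diff) and Hydra is installed:

from hydra import compose, initialize

# config_path is relative to the calling script; adjust to the real location.
with initialize(version_base=None, config_path="lerobot/configs"):
    cfg = compose(config_name="pusht")

print(cfg.env)         # "pusht"
print(cfg.image_size)  # 96
print(cfg.img_size)    # 96, via the ${image_size} interpolation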
@@ -71,9 +71,18 @@ def eval(cfg: dict):
     print(colored("Log dir:", "yellow", attrs=["bold"]), cfg.log_dir)
 
     env = make_env(cfg)
-    policy = TDMPC(cfg)
-    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
-    ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
-    policy.load(ckpt_path)
 
-    policy = TensorDictModule(
+    if cfg.pretrained_model_path:
+        policy = TDMPC(cfg)
+        ckpt_path = (
+            "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
+        )
+        if "offline" in cfg.pretrained_model_path:
+            policy.step = 25000
+        elif "final" in cfg.pretrained_model_path:
+            policy.step = 100000
+        else:
+            raise NotImplementedError()
+        policy.load(ckpt_path)
+
+        policy = TensorDictModule(
@@ -81,14 +90,16 @@ def eval(cfg: dict):
-        in_keys=["observation", "step_count"],
-        out_keys=["action"],
-    )
-
-    # policy can be None to rollout a random policy
+            in_keys=["observation", "step_count"],
+            out_keys=["action"],
+        )
+    else:
+        # when policy is None, rollout a random policy
+        policy = None
+
     metrics = eval_policy(
         env,
         policy=policy,
         num_episodes=20,
-        save_video=False,
-        video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
+        save_video=True,
+        video_dir=Path("tmp/2023_02_19_pusht"),
     )
     print(metrics)
 
@@ -10,6 +10,7 @@ from torchrl.data.datasets.d4rl import D4RLExperienceReplay
 from torchrl.data.datasets.openx import OpenXExperienceReplay
 from torchrl.data.replay_buffers import PrioritizedSliceSampler
 
+from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger
@@ -26,11 +27,17 @@ def train(cfg: dict):
 
     env = make_env(cfg)
     policy = TDMPC(cfg)
-    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
-    # policy.step = 25000
-    # # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
-    # # policy.step = 100000
-    # policy.load(ckpt_path)
+    if cfg.pretrained_model_path:
+        ckpt_path = (
+            "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
+        )
+        if "offline" in cfg.pretrained_model_path:
+            policy.step = 25000
+        elif "final" in cfg.pretrained_model_path:
+            policy.step = 100000
+        else:
+            raise NotImplementedError()
+        policy.load(ckpt_path)
 
     td_policy = TensorDictModule(
         policy,
@@ -40,32 +47,7 @@ def train(cfg: dict):
 
     # initialize offline dataset
-
-    dataset_id = f"xarm_{cfg.task}_medium"
-
-    num_traj_per_batch = cfg.batch_size  # // cfg.horizon
-    # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
-    # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
-    sampler = PrioritizedSliceSampler(
-        max_capacity=100_000,
-        alpha=cfg.per_alpha,
-        beta=cfg.per_beta,
-        num_slices=num_traj_per_batch,
-        strict_length=False,
-    )
-
-    # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
-    offline_buffer = SimxarmExperienceReplay(
-        dataset_id,
-        # download="force",
-        download=True,
-        streaming=False,
-        root="data",
-        sampler=sampler,
-    )
-
-    num_steps = len(offline_buffer)
-    index = torch.arange(0, num_steps, 1)
-    sampler.extend(index)
+    offline_buffer = make_offline_buffer(cfg)
+
 
     if cfg.balanced_sampling:
         online_sampler = PrioritizedSliceSampler(
@@ -2,7 +2,35 @@ import pytest
 from tensordict import TensorDict
 from torchrl.envs.utils import check_env_specs, step_mdp
 
-from lerobot.common.envs import SimxarmEnv
+from lerobot.common.envs.pusht import PushtEnv
+from lerobot.common.envs.simxarm import SimxarmEnv
+
+
+def print_spec_rollout(env):
+    print("observation_spec:", env.observation_spec)
+    print("action_spec:", env.action_spec)
+    print("reward_spec:", env.reward_spec)
+    print("done_spec:", env.done_spec)
+
+    td = env.reset()
+    print("reset tensordict", td)
+
+    td = env.rand_step(td)
+    print("random step tensordict", td)
+
+    def simple_rollout(steps=100):
+        # preallocate:
+        data = TensorDict({}, [steps])
+        # reset
+        _data = env.reset()
+        for i in range(steps):
+            _data["action"] = env.action_spec.rand()
+            _data = env.step(_data)
+            data[i] = _data
+            _data = step_mdp(_data, keep_other=True)
+        return data
+
+    print("data from rollout:", simple_rollout(100))
 
 
 @pytest.mark.parametrize(
@@ -26,30 +54,21 @@ def test_simxarm(task, from_pixels, pixels_only):
         pixels_only=pixels_only,
         image_size=84 if from_pixels else None,
     )
+    print_spec_rollout(env)
     check_env_specs(env)
 
-    print("observation_spec:", env.observation_spec)
-    print("action_spec:", env.action_spec)
-    print("reward_spec:", env.reward_spec)
-    print("done_spec:", env.done_spec)
-    print("success_spec:", env.success_spec)
-
-    td = env.reset()
-    print("reset tensordict", td)
-
-    td = env.rand_step(td)
-    print("random step tensordict", td)
-
-    def simple_rollout(steps=100):
-        # preallocate:
-        data = TensorDict({}, [steps])
-        # reset
-        _data = env.reset()
-        for i in range(steps):
-            _data["action"] = env.action_spec.rand()
-            _data = env.step(_data)
-            data[i] = _data
-            _data = step_mdp(_data, keep_other=True)
-        return data
-
-    print("data from rollout:", simple_rollout(100))
+
+@pytest.mark.parametrize(
+    "from_pixels,pixels_only",
+    [
+        (True, False),
+    ],
+)
+def test_pusht(from_pixels, pixels_only):
+    env = PushtEnv(
+        from_pixels=from_pixels,
+        pixels_only=pixels_only,
+        image_size=96 if from_pixels else None,
+    )
+    print_spec_rollout(env)
+    check_env_specs(env)
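Assuming the file shown here is tests/test_envs.py (the path is not visible in this diff), the new test can be run on its own with pytest's keyword filter:

pytest tests/test_envs.py -k test_pusht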