wip: still needs batch logic for act and tdmp

parent 8c56770318
commit ba91976944
@@ -168,42 +168,31 @@ class AlohaEnv(AbstractEnv):
     def _step(self, tensordict: TensorDict):
         td = tensordict
         action = td["action"].numpy()
-        # step expects shape=(4,) so we pad if necessary
+        assert action.ndim == 1
         # TODO(rcadene): add info["is_success"] and info["success"] ?
-        sum_reward = 0
-
-        if action.ndim == 1:
-            action = einops.repeat(action, "c -> t c", t=self.frame_skip)
-        else:
-            if self.frame_skip > 1:
-                raise NotImplementedError()
-
-        num_action_steps = action.shape[0]
-        for i in range(num_action_steps):
-            _, reward, discount, raw_obs = self._env.step(action[i])
-            del discount  # not used
-
-            # TOOD(rcadene): add an enum
-            success = done = reward == 4
-            sum_reward += reward
-            obs = self._format_raw_obs(raw_obs)
-
-            if self.num_prev_obs > 0:
-                stacked_obs = {}
-                if "image" in obs:
-                    self._prev_obs_image_queue.append(obs["image"]["top"])
-                    stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
-                if "state" in obs:
-                    self._prev_obs_state_queue.append(obs["state"])
-                    stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
-                obs = stacked_obs
-
-            self.call_rendering_hooks()
+
+        _, reward, _, raw_obs = self._env.step(action)
+
+        # TODO(rcadene): add an enum
+        success = done = reward == 4
+        obs = self._format_raw_obs(raw_obs)
+
+        if self.num_prev_obs > 0:
+            stacked_obs = {}
+            if "image" in obs:
+                self._prev_obs_image_queue.append(obs["image"]["top"])
+                stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
+            if "state" in obs:
+                self._prev_obs_state_queue.append(obs["state"])
+                stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+            obs = stacked_obs
+
+        self.call_rendering_hooks()
 
         td = TensorDict(
             {
                 "observation": TensorDict(obs, batch_size=[]),
-                "reward": torch.tensor([sum_reward], dtype=torch.float32),
+                "reward": torch.tensor([reward], dtype=torch.float32),
                 # succes and done are true when coverage > self.success_threshold in env
                 "done": torch.tensor([done], dtype=torch.bool),
                 "success": torch.tensor([success], dtype=torch.bool),
@@ -1,15 +1,17 @@
 from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
 
 
-def make_env(cfg, transform=None):
+def make_env(cfg, seed=None, transform=None):
+    """
+    Provide seed to override the seed in the cfg (useful for batched environments).
+    """
     kwargs = {
         "frame_skip": cfg.env.action_repeat,
         "from_pixels": cfg.env.from_pixels,
         "pixels_only": cfg.env.pixels_only,
         "image_size": cfg.env.image_size,
-        # TODO(rcadene): do we want a specific eval_env_seed?
-        "seed": cfg.seed,
         "num_prev_obs": cfg.n_obs_steps - 1,
+        "seed": seed if seed is not None else cfg.seed,
     }
 
     if cfg.env.name == "simxarm":
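A minimal sketch of how the new seed argument is meant to be used when building a batch of environments: each sub-environment gets a distinct seed so batched rollouts do not replay identical episodes. The make_batched_env helper name below is hypothetical; the construction simply mirrors the SerialEnv setup added to the eval and train scripts later in this diff.

# Sketch only: the i-th sub-env is seeded with cfg.seed + i.
from torchrl.envs import SerialEnv

from lerobot.common.envs.factory import make_env


def make_batched_env(cfg, transform=None):  # hypothetical helper, not part of this commit
    return SerialEnv(
        cfg.rollout_batch_size,
        create_env_fn=make_env,
        create_env_kwargs=[
            {"cfg": cfg, "seed": cfg.seed + i, "transform": transform}
            for i in range(cfg.rollout_batch_size)
        ],
    )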
@@ -2,7 +2,6 @@ import importlib
 from collections import deque
 from typing import Optional
 
-import einops
 import torch
 from tensordict import TensorDict
 from torchrl.data.tensor_specs import (
@@ -120,40 +119,30 @@ class PushtEnv(AbstractEnv):
     def _step(self, tensordict: TensorDict):
         td = tensordict
         action = td["action"].numpy()
-        # step expects shape=(4,) so we pad if necessary
+        assert action.ndim == 1
         # TODO(rcadene): add info["is_success"] and info["success"] ?
-        sum_reward = 0
-
-        if action.ndim == 1:
-            action = einops.repeat(action, "c -> t c", t=self.frame_skip)
-        else:
-            if self.frame_skip > 1:
-                raise NotImplementedError()
-
-        num_action_steps = action.shape[0]
-        for i in range(num_action_steps):
-            raw_obs, reward, done, info = self._env.step(action[i])
-            sum_reward += reward
-
-            obs = self._format_raw_obs(raw_obs)
-
-            if self.num_prev_obs > 0:
-                stacked_obs = {}
-                if "image" in obs:
-                    self._prev_obs_image_queue.append(obs["image"])
-                    stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
-                if "state" in obs:
-                    self._prev_obs_state_queue.append(obs["state"])
-                    stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
-                obs = stacked_obs
-
-            self.call_rendering_hooks()
+
+        raw_obs, reward, done, info = self._env.step(action)
+
+        obs = self._format_raw_obs(raw_obs)
+
+        if self.num_prev_obs > 0:
+            stacked_obs = {}
+            if "image" in obs:
+                self._prev_obs_image_queue.append(obs["image"])
+                stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
+            if "state" in obs:
+                self._prev_obs_state_queue.append(obs["state"])
+                stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+            obs = stacked_obs
+
+        self.call_rendering_hooks()
 
         td = TensorDict(
             {
                 "observation": TensorDict(obs, batch_size=[]),
-                "reward": torch.tensor([sum_reward], dtype=torch.float32),
-                # succes and done are true when coverage > self.success_threshold in env
+                "reward": torch.tensor([reward], dtype=torch.float32),
+                # success and done are true when coverage > self.success_threshold in env
                 "done": torch.tensor([done], dtype=torch.bool),
                 "success": torch.tensor([done], dtype=torch.bool),
             },
@@ -0,0 +1,54 @@
+from abc import abstractmethod
+from collections import deque
+
+import torch
+from torch import Tensor, nn
+
+
+class AbstractPolicy(nn.Module):
+    @abstractmethod
+    def update(self, replay_buffer, step):
+        """One step of the policy's learning algorithm."""
+        pass
+
+    def save(self, fp):
+        torch.save(self.state_dict(), fp)
+
+    def load(self, fp):
+        d = torch.load(fp)
+        self.load_state_dict(d)
+
+    @abstractmethod
+    def select_action(self, observation) -> Tensor:
+        """Select an action (or trajectory of actions) based on an observation during rollout.
+
+        Should return a (batch_size, n_action_steps, *) tensor of actions.
+        """
+        pass
+
+    def forward(self, *args, **kwargs):
+        """Inference step that makes multi-step policies compatible with their single-step environments.
+
+        WARNING: In general, this should not be overriden.
+
+        Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
+        into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
+        observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
+        observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
+        the subclass doesn't have to.
+
+        This method effectively wraps the `select_action` method of the subclass. The following assumptions are made:
+        1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
+           the action trajectory horizon and * is the action dimensions.
+        2. Prior to the `select_action` method being called, theres is an `n_action_steps` instance attribute defined.
+        """
+        n_action_steps_attr = "n_action_steps"
+        if not hasattr(self, n_action_steps_attr):
+            raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
+        if not hasattr(self, "_action_queue"):
+            self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
+        if len(self._action_queue) == 0:
+            # Each element in the queue has shape (B, *).
+            self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1))
+
+        return self._action_queue.popleft()
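To make the action-queue behaviour of forward concrete, here is a minimal sketch. Only AbstractPolicy itself comes from this commit; the ToyPolicy subclass and its constant trajectory are made up for illustration.

import torch

from lerobot.common.policies.abstract import AbstractPolicy


class ToyPolicy(AbstractPolicy):  # hypothetical subclass, for illustration only
    def __init__(self, n_action_steps: int = 4):
        super().__init__()
        self.n_action_steps = n_action_steps  # required by AbstractPolicy.forward

    def update(self, replay_buffer, step):
        pass  # no learning needed for this sketch

    def select_action(self, observation) -> torch.Tensor:
        # Return a (batch_size=1, horizon=n_action_steps, action_dim=2) trajectory.
        return torch.arange(self.n_action_steps * 2, dtype=torch.float32).reshape(1, self.n_action_steps, 2)


policy = ToyPolicy()
obs = torch.zeros(1, 3)
for step in range(6):
    # Each call pops one (1, 2) action; select_action only runs at steps 0 and 4,
    # i.e. whenever the 4-element queue has been emptied.
    action = policy(obs)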
@@ -2,10 +2,10 @@ import logging
 import time
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
 import torchvision.transforms as transforms
 
+from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.act.detr_vae import build
 
 
@@ -40,7 +40,7 @@ def kl_divergence(mu, logvar):
     return total_kld, dimension_wise_kld, mean_kld
 
 
-class ActionChunkingTransformerPolicy(nn.Module):
+class ActionChunkingTransformerPolicy(AbstractPolicy):
     def __init__(self, cfg, device, n_action_steps=1):
         super().__init__()
         self.cfg = cfg
@@ -147,7 +147,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return loss
 
     @torch.no_grad()
-    def forward(self, observation, step_count):
+    def select_action(self, observation, step_count):
         # TODO(rcadene): remove unused step_count
         del step_count
 
@@ -3,14 +3,14 @@ import time
 
 import hydra
 import torch
-import torch.nn as nn
 
+from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
 from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
 
 
-class DiffusionPolicy(nn.Module):
+class DiffusionPolicy(AbstractPolicy):
     def __init__(
         self,
         cfg,
@@ -44,6 +44,7 @@ class DiffusionPolicy(nn.Module):
             **cfg_obs_encoder,
         )
 
+        self.n_action_steps = n_action_steps  # needed for the parent class
         self.diffusion = DiffusionUnetImagePolicy(
             shape_meta=shape_meta,
             noise_scheduler=noise_scheduler,
@@ -93,21 +94,16 @@ class DiffusionPolicy(nn.Module):
         )
 
     @torch.no_grad()
-    def forward(self, observation, step_count):
+    def select_action(self, observation, step_count):
         # TODO(rcadene): remove unused step_count
         del step_count
 
-        # TODO(rcadene): remove unsqueeze hack to add bsize=1
-        observation["image"] = observation["image"].unsqueeze(0)
-        observation["state"] = observation["state"].unsqueeze(0)
-
         obs_dict = {
             "image": observation["image"],
             "agent_pos": observation["state"],
         }
         out = self.diffusion.predict_action(obs_dict)
-
-        action = out["action"].squeeze(0)
+        action = out["action"]
         return action
 
     def update(self, replay_buffer, step):
@@ -9,6 +9,7 @@ import torch
 import torch.nn as nn
 
 import lerobot.common.policies.tdmpc.helper as h
+from lerobot.common.policies.abstract import AbstractPolicy
 
 FIRST_FRAME = 0
 
@@ -85,7 +86,7 @@ class TOLD(nn.Module):
         return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
 
 
-class TDMPC(nn.Module):
+class TDMPC(AbstractPolicy):
     """Implementation of TD-MPC learning + inference."""
 
     def __init__(self, cfg, device):
@@ -124,7 +125,7 @@ class TDMPC(nn.Module):
         self.model_target.load_state_dict(d["model_target"])
 
     @torch.no_grad()
-    def forward(self, observation, step_count):
+    def select_action(self, observation, step_count):
         t0 = step_count.item() == 0
 
         # TODO(rcadene): remove unsqueeze hack...
@@ -10,6 +10,8 @@ hydra:
     name: default
 
 seed: 1337
+# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
+rollout_batch_size: 10
 device: cuda # cpu
 prefetch: 4
 eval_freq: ???
@@ -9,7 +9,8 @@ import numpy as np
 import torch
 import tqdm
 from tensordict.nn import TensorDictModule
-from torchrl.envs import EnvBase
+from torchrl.envs import EnvBase, SerialEnv
+from torchrl.envs.batched_envs import BatchedEnvBase
 
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
@@ -23,7 +24,7 @@ def write_video(video_path, stacked_frames, fps):
 
 
 def eval_policy(
-    env: EnvBase,
+    env: BatchedEnvBase,
     policy: TensorDictModule = None,
     num_episodes: int = 10,
     max_steps: int = 30,
@@ -36,45 +37,55 @@ def eval_policy(
     sum_rewards = []
     max_rewards = []
     successes = []
-    threads = []
-    for i in tqdm.tqdm(range(num_episodes)):
+    threads = []  # for video saving threads
+    episode_counter = 0  # for saving the correct number of videos
+
+    # TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
+    # needed as I'm currently taking a ceil.
+    for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))):
         ep_frames = []
-        if save_video or (return_first_video and i == 0):
 
-            def render_frame(env):
-                ep_frames.append(env.render())  # noqa: B023
+        def maybe_render_frame(env: EnvBase, _):
+            if save_video or (return_first_video and i == 0):  # noqa: B023
+                ep_frames.append(env.render())  # noqa: B023
 
-            env.register_rendering_hook(render_frame)
-
         with torch.inference_mode():
             rollout = env.rollout(
                 max_steps=max_steps,
                 policy=policy,
                 auto_cast_to_device=True,
+                callback=maybe_render_frame,
             )
         # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
-        ep_sum_reward = rollout["next", "reward"].sum()
-        ep_max_reward = rollout["next", "reward"].max()
-        ep_success = rollout["next", "success"].any()
-        sum_rewards.append(ep_sum_reward.item())
-        max_rewards.append(ep_max_reward.item())
-        successes.append(ep_success.item())
+        batch_sum_reward = rollout["next", "reward"].flatten(start_dim=1).sum(dim=-1)
+        batch_max_reward = rollout["next", "reward"].flatten(start_dim=1).max(dim=-1)[0]
+        batch_success = rollout["next", "success"].flatten(start_dim=1).any(dim=-1)
+        sum_rewards.extend(batch_sum_reward.tolist())
+        max_rewards.extend(batch_max_reward.tolist())
+        successes.extend(batch_success.tolist())
 
         if save_video or (return_first_video and i == 0):
-            stacked_frames = np.stack(ep_frames)
+            batch_stacked_frames = np.stack(ep_frames)  # (t, b, *)
+            batch_stacked_frames = batch_stacked_frames.transpose(
+                1, 0, *range(2, batch_stacked_frames.ndim)
+            )  # (b, t, *)
 
             if save_video:
-                video_dir.mkdir(parents=True, exist_ok=True)
-                video_path = video_dir / f"eval_episode_{i}.mp4"
-                thread = threading.Thread(
-                    target=write_video,
-                    args=(str(video_path), stacked_frames, fps),
-                )
-                thread.start()
-                threads.append(thread)
+                for stacked_frames in batch_stacked_frames:
+                    if episode_counter >= num_episodes:
+                        continue
+                    video_dir.mkdir(parents=True, exist_ok=True)
+                    video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
+                    thread = threading.Thread(
+                        target=write_video,
+                        args=(str(video_path), stacked_frames, fps),
+                    )
+                    thread.start()
+                    threads.append(thread)
+                    episode_counter += 1
 
             if return_first_video and i == 0:
-                first_video = stacked_frames.transpose(0, 3, 1, 2)
+                first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
 
         env.reset_rendering_hooks()
 
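The new loop bound uses negative floor division as a ceiling: for positive integers, -(-a // b) equals ceil(a / b), so enough rollout batches are run to cover num_episodes and the surplus is dropped by the episode_counter guard above and the [:num_episodes] slices in the next hunk. A quick check of that arithmetic:

# ceil division via negation, as used in the rollout loop above
num_episodes, batch_size = 10, 4
n_batches = -(-num_episodes // batch_size)
assert n_batches == 3  # 3 batches x 4 envs = 12 rollouts; the extra 2 episodes are discarded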
@@ -82,9 +93,9 @@ def eval_policy(
         thread.join()
 
     info = {
-        "avg_sum_reward": np.nanmean(sum_rewards),
-        "avg_max_reward": np.nanmean(max_rewards),
-        "pc_success": np.nanmean(successes) * 100,
+        "avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
+        "avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
+        "pc_success": np.nanmean(successes[:num_episodes]) * 100,
         "eval_s": time.time() - start,
         "eval_ep_s": (time.time() - start) / num_episodes,
     }
@@ -119,7 +130,14 @@ def eval(cfg: dict, out_dir=None):
     offline_buffer = make_offline_buffer(cfg)
 
     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer.transform)
+    env = SerialEnv(
+        cfg.rollout_batch_size,
+        create_env_fn=make_env,
+        create_env_kwargs=[
+            {"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
+            for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
+        ],
+    )
 
     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)
@@ -138,7 +156,7 @@ def eval(cfg: dict, out_dir=None):
         save_video=True,
         video_dir=Path(out_dir) / "eval",
         fps=cfg.env.fps,
-        max_steps=cfg.env.episode_length // cfg.n_action_steps,
+        max_steps=cfg.env.episode_length,
         num_episodes=cfg.eval_episodes,
     )
     print(metrics)
@@ -7,6 +7,7 @@ import torch
 from tensordict.nn import TensorDictModule
 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers import PrioritizedSliceSampler
+from torchrl.envs import SerialEnv
 
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
@@ -148,6 +149,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     logging.info("make_env")
     env = make_env(cfg, transform=offline_buffer.transform)
+    env = SerialEnv(
+        cfg.rollout_batch_size,
+        create_env_fn=make_env,
+        create_env_kwargs=[
+            {"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
+            for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
+        ],
+    )
 
     logging.info("make_policy")
     policy = make_policy(cfg)
@@ -191,7 +200,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 env,
                 td_policy,
                 num_episodes=cfg.eval_episodes,
-                max_steps=cfg.env.episode_length // cfg.n_action_steps,
+                max_steps=cfg.env.episode_length,
                 return_first_video=True,
                 video_dir=Path(out_dir) / "eval",
                 save_video=True,
@@ -1,7 +1,15 @@
 import pytest
+from tensordict import TensorDict
+from tensordict.nn import TensorDictModule
+
+import torch
+from torchrl.data import UnboundedContinuousTensorSpec
+from torchrl.envs import EnvBase
 
 from lerobot.common.policies.factory import make_policy
 
+from lerobot.common.policies.abstract import AbstractPolicy
+
 from .utils import DEVICE, init_config
 
 
@@ -23,3 +31,75 @@ def test_factory(env_name, policy_name):
         ]
     )
     policy = make_policy(cfg)
+
+
+def test_abstract_policy_forward():
+    """
+    Given an underlying policy that produces an action trajectory with n_action_steps actions, checks that:
+    - The policy is invoked the expected number of times during a rollout.
+    - The environment's termination condition is respected even when part way through an action trajectory.
+    - The observations are returned correctly.
+    """
+
+    n_action_steps = 8  # our test policy will output 8 action step horizons
+    terminate_at = 10  # some number that is more than n_action_steps but not a multiple
+    rollout_max_steps = terminate_at + 1  # some number greater than terminate_at
+
+    # A minimal environment for testing.
+    class StubEnv(EnvBase):
+
+        def __init__(self):
+            super().__init__()
+            self.action_spec = UnboundedContinuousTensorSpec(shape=(1,))
+            self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
+
+        def _step(self, tensordict: TensorDict) -> TensorDict:
+            self.invocation_count += 1
+            return TensorDict(
+                {
+                    "observation": torch.tensor([self.invocation_count]),
+                    "reward": torch.tensor([self.invocation_count]),
+                    "terminated": torch.tensor(
+                        tensordict["action"].item() == terminate_at
+                    ),
+                }
+            )
+
+        def _reset(self, tensordict: TensorDict) -> TensorDict:
+            self.invocation_count = 0
+            return TensorDict(
+                {
+                    "observation": torch.tensor([self.invocation_count]),
+                    "reward": torch.tensor([self.invocation_count]),
+                }
+            )
+
+        def _set_seed(self, seed: int | None):
+            return
+
+
+    class StubPolicy(AbstractPolicy):
+        def __init__(self):
+            super().__init__()
+            self.n_action_steps = n_action_steps
+            self.n_policy_invocations = 0
+
+        def select_action(self):
+            self.n_policy_invocations += 1
+            return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0)
+
+
+    env = StubEnv()
+    policy = StubPolicy()
+    policy = TensorDictModule(
+        policy,
+        in_keys=[],
+        out_keys=["action"],
+    )
+
+    # Keep track to make sure the policy is called the expected number of times
+    rollout = env.rollout(rollout_max_steps, policy)
+
+    assert len(rollout) == terminate_at + 1  # +1 for the reset observation
+    assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
+    assert torch.equal(rollout['observation'].flatten(), torch.arange(terminate_at + 1))
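A quick sanity check of the invocation count asserted above: with n_action_steps = 8 and terminate_at = 10, the wrapped policy plans at step 0 and again at step 8 when the 8-action queue is exhausted, so over the rollout it is invoked terminate_at // n_action_steps + 1 times.

n_action_steps, terminate_at = 8, 10
# Planning happens at rollout steps 0 and 8, so two invocations in total.
assert terminate_at // n_action_steps + 1 == 2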