wip: still needs batch logic for act and tdmpc

commit ba91976944 (parent 8c56770318)
Author: Alexander Soare
Date: 2024-03-14 15:22:55 +00:00
11 changed files with 240 additions and 100 deletions

View File

@ -168,42 +168,31 @@ class AlohaEnv(AbstractEnv):
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
assert action.ndim == 1
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0
if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()
_, reward, _, raw_obs = self._env.step(action)
num_action_steps = action.shape[0]
for i in range(num_action_steps):
_, reward, discount, raw_obs = self._env.step(action[i])
del discount # not used
# TODO(rcadene): add an enum
success = done = reward == 4
obs = self._format_raw_obs(raw_obs)
# TOOD(rcadene): add an enum
success = done = reward == 4
sum_reward += reward
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"]["top"])
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"]["top"])
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
self.call_rendering_hooks()
self.call_rendering_hooks()
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
"reward": torch.tensor([reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([success], dtype=torch.bool),

View File

@ -1,15 +1,17 @@
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
def make_env(cfg, transform=None):
def make_env(cfg, seed=None, transform=None):
"""
Provide seed to override the seed in the cfg (useful for batched environments).
"""
kwargs = {
"frame_skip": cfg.env.action_repeat,
"from_pixels": cfg.env.from_pixels,
"pixels_only": cfg.env.pixels_only,
"image_size": cfg.env.image_size,
# TODO(rcadene): do we want a specific eval_env_seed?
"seed": cfg.seed,
"num_prev_obs": cfg.n_obs_steps - 1,
"seed": seed if seed is not None else cfg.seed,
}
if cfg.env.name == "simxarm":

View File

@ -2,7 +2,6 @@ import importlib
from collections import deque
from typing import Optional
import einops
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
@ -120,40 +119,30 @@ class PushtEnv(AbstractEnv):
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
assert action.ndim == 1
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0
if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()
raw_obs, reward, done, info = self._env.step(action)
num_action_steps = action.shape[0]
for i in range(num_action_steps):
raw_obs, reward, done, info = self._env.step(action[i])
sum_reward += reward
obs = self._format_raw_obs(raw_obs)
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"])
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"])
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
self.call_rendering_hooks()
self.call_rendering_hooks()
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
"reward": torch.tensor([reward], dtype=torch.float32),
# success and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([done], dtype=torch.bool),
},

View File

@ -0,0 +1,54 @@
from abc import abstractmethod
from collections import deque
import torch
from torch import Tensor, nn
class AbstractPolicy(nn.Module):
@abstractmethod
def update(self, replay_buffer, step):
"""One step of the policy's learning algorithm."""
pass
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
@abstractmethod
def select_action(self, observation) -> Tensor:
"""Select an action (or trajectory of actions) based on an observation during rollout.
Should return a (batch_size, n_action_steps, *) tensor of actions.
"""
pass
def forward(self, *args, **kwargs):
"""Inference step that makes multi-step policies compatible with their single-step environments.
WARNING: In general, this should not be overridden.
Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
the subclass doesn't have to.
This method effectively wraps the `select_action` method of the subclass. The following assumptions are made:
1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
the action trajectory horizon and * is the action dimensions.
2. Prior to the `select_action` method being called, an `n_action_steps` instance attribute is defined.
"""
n_action_steps_attr = "n_action_steps"
if not hasattr(self, n_action_steps_attr):
raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
if not hasattr(self, "_action_queue"):
self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
if len(self._action_queue) == 0:
# Each element in the queue has shape (B, *).
self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1))
return self._action_queue.popleft()
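For reference, a minimal runnable sketch of the queueing logic described in the docstring above; the names (`trajectory`, `queue`) and the toy shapes are illustrative and not part of the diff, but the transpose/extend/popleft pattern is the one used in `forward`:

from collections import deque

import torch

n_action_steps = 3
batch_size = 2
action_dim = 1

# Pretend select_action returned a (B, H, action_dim) trajectory.
trajectory = torch.arange(
    batch_size * n_action_steps * action_dim, dtype=torch.float32
).reshape(batch_size, n_action_steps, action_dim)

queue = deque([], maxlen=n_action_steps)
# transpose(0, 1) turns (B, H, *) into H tensors of shape (B, *), one per queued action step.
queue.extend(trajectory.transpose(0, 1))

while queue:
    action = queue.popleft()  # shape (B, action_dim): one action per sub-environment for this step
    print(action.shape)  # torch.Size([2, 1])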

View File

@ -2,10 +2,10 @@ import logging
import time
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision.transforms as transforms
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.act.detr_vae import build
@ -40,7 +40,7 @@ def kl_divergence(mu, logvar):
return total_kld, dimension_wise_kld, mean_kld
class ActionChunkingTransformerPolicy(nn.Module):
class ActionChunkingTransformerPolicy(AbstractPolicy):
def __init__(self, cfg, device, n_action_steps=1):
super().__init__()
self.cfg = cfg
@ -147,7 +147,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
return loss
@torch.no_grad()
def forward(self, observation, step_count):
def select_action(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count

View File

@ -3,14 +3,14 @@ import time
import hydra
import torch
import torch.nn as nn
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
class DiffusionPolicy(nn.Module):
class DiffusionPolicy(AbstractPolicy):
def __init__(
self,
cfg,
@ -44,6 +44,7 @@ class DiffusionPolicy(nn.Module):
**cfg_obs_encoder,
)
self.n_action_steps = n_action_steps # needed for the parent class
self.diffusion = DiffusionUnetImagePolicy(
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
@ -93,21 +94,16 @@ class DiffusionPolicy(nn.Module):
)
@torch.no_grad()
def forward(self, observation, step_count):
def select_action(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count
# TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)
obs_dict = {
"image": observation["image"],
"agent_pos": observation["state"],
}
out = self.diffusion.predict_action(obs_dict)
action = out["action"].squeeze(0)
action = out["action"]
return action
def update(self, replay_buffer, step):
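A shape sketch of the change above, assuming `predict_action` returns a batched action trajectory under its "action" key (an illustration, not part of the diff):

# observation["image"]: (B, C, H, W), observation["state"]: (B, state_dim)  -- B=1 via the unsqueeze hack
# out["action"]: (B, n_action_steps, action_dim)
# Keeping the batch dimension (no squeeze) matches the (B, H, *) contract of AbstractPolicy.select_action,
# so AbstractPolicy.forward can split the trajectory into n_action_steps queued actions of shape (B, action_dim).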

View File

@ -9,6 +9,7 @@ import torch
import torch.nn as nn
import lerobot.common.policies.tdmpc.helper as h
from lerobot.common.policies.abstract import AbstractPolicy
FIRST_FRAME = 0
@ -85,7 +86,7 @@ class TOLD(nn.Module):
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
class TDMPC(nn.Module):
class TDMPC(AbstractPolicy):
"""Implementation of TD-MPC learning + inference."""
def __init__(self, cfg, device):
@ -124,7 +125,7 @@ class TDMPC(nn.Module):
self.model_target.load_state_dict(d["model_target"])
@torch.no_grad()
def forward(self, observation, step_count):
def select_action(self, observation, step_count):
t0 = step_count.item() == 0
# TODO(rcadene): remove unsqueeze hack...

View File

@ -10,6 +10,8 @@ hydra:
name: default
seed: 1337
# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
rollout_batch_size: 10
device: cuda # cpu
prefetch: 4
eval_freq: ???
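A concrete reading of the comment above, using the values from this config (the mapping mirrors the `create_env_kwargs` lists built in eval.py and train.py below):

# seed: 1337, rollout_batch_size: 10 -> sub-environment seeds 1337, 1338, ..., 1346
seed = 1337
rollout_batch_size = 10
env_seeds = list(range(seed, seed + rollout_batch_size))
assert env_seeds[0] == 1337 and env_seeds[-1] == 1346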

View File

@ -9,7 +9,8 @@ import numpy as np
import torch
import tqdm
from tensordict.nn import TensorDictModule
from torchrl.envs import EnvBase
from torchrl.envs import EnvBase, SerialEnv
from torchrl.envs.batched_envs import BatchedEnvBase
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
@ -23,7 +24,7 @@ def write_video(video_path, stacked_frames, fps):
def eval_policy(
env: EnvBase,
env: BatchedEnvBase,
policy: TensorDictModule = None,
num_episodes: int = 10,
max_steps: int = 30,
@ -36,45 +37,55 @@ def eval_policy(
sum_rewards = []
max_rewards = []
successes = []
threads = []
for i in tqdm.tqdm(range(num_episodes)):
threads = [] # for video saving threads
episode_counter = 0 # for saving the correct number of videos
# TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
# needed as I'm currently taking a ceil.
for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))):
ep_frames = []
if save_video or (return_first_video and i == 0):
def render_frame(env):
def maybe_render_frame(env: EnvBase, _):
if save_video or (return_first_video and i == 0): # noqa: B023
ep_frames.append(env.render()) # noqa: B023
env.register_rendering_hook(render_frame)
with torch.inference_mode():
rollout = env.rollout(
max_steps=max_steps,
policy=policy,
auto_cast_to_device=True,
callback=maybe_render_frame,
)
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
ep_sum_reward = rollout["next", "reward"].sum()
ep_max_reward = rollout["next", "reward"].max()
ep_success = rollout["next", "success"].any()
sum_rewards.append(ep_sum_reward.item())
max_rewards.append(ep_max_reward.item())
successes.append(ep_success.item())
batch_sum_reward = rollout["next", "reward"].flatten(start_dim=1).sum(dim=-1)
batch_max_reward = rollout["next", "reward"].flatten(start_dim=1).max(dim=-1)[0]
batch_success = rollout["next", "success"].flatten(start_dim=1).any(dim=-1)
sum_rewards.extend(batch_sum_reward.tolist())
max_rewards.extend(batch_max_reward.tolist())
successes.extend(batch_success.tolist())
if save_video or (return_first_video and i == 0):
stacked_frames = np.stack(ep_frames)
batch_stacked_frames = np.stack(ep_frames) # (t, b, *)
batch_stacked_frames = batch_stacked_frames.transpose(
1, 0, *range(2, batch_stacked_frames.ndim)
) # (b, t, *)
if save_video:
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{i}.mp4"
thread = threading.Thread(
target=write_video,
args=(str(video_path), stacked_frames, fps),
)
thread.start()
threads.append(thread)
for stacked_frames in batch_stacked_frames:
if episode_counter >= num_episodes:
continue
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
thread = threading.Thread(
target=write_video,
args=(str(video_path), stacked_frames, fps),
)
thread.start()
threads.append(thread)
episode_counter += 1
if return_first_video and i == 0:
first_video = stacked_frames.transpose(0, 3, 1, 2)
first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
env.reset_rendering_hooks()
@ -82,9 +93,9 @@ def eval_policy(
thread.join()
info = {
"avg_sum_reward": np.nanmean(sum_rewards),
"avg_max_reward": np.nanmean(max_rewards),
"pc_success": np.nanmean(successes) * 100,
"avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
"avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
"pc_success": np.nanmean(successes[:num_episodes]) * 100,
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
}
@ -119,7 +130,14 @@ def eval(cfg: dict, out_dir=None):
offline_buffer = make_offline_buffer(cfg)
logging.info("make_env")
env = make_env(cfg, transform=offline_buffer.transform)
env = SerialEnv(
cfg.rollout_batch_size,
create_env_fn=make_env,
create_env_kwargs=[
{"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
],
)
if cfg.policy.pretrained_model_path:
policy = make_policy(cfg)
@ -138,7 +156,7 @@ def eval(cfg: dict, out_dir=None):
save_video=True,
video_dir=Path(out_dir) / "eval",
fps=cfg.env.fps,
max_steps=cfg.env.episode_length // cfg.n_action_steps,
max_steps=cfg.env.episode_length,
num_episodes=cfg.eval_episodes,
)
print(metrics)
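A small sketch of the batched bookkeeping introduced above; the reward tensor is random stand-in data and the batch/episode counts are illustrative, but the reductions and the ceil division are the ones used in `eval_policy`:

import torch

num_episodes, batch_size, max_steps = 10, 4, 30

# A batched rollout's rewards come back with shape (B, T, 1).
rewards = torch.rand(batch_size, max_steps, 1)

# Per-episode reductions over time, as in eval_policy.
batch_sum_reward = rewards.flatten(start_dim=1).sum(dim=-1)     # (B,)
batch_max_reward = rewards.flatten(start_dim=1).max(dim=-1)[0]  # (B,)

# Ceil division without math.ceil: -(-10 // 4) == 3 rollout batches cover 12 episodes,
# which is why the metrics are trimmed to the first num_episodes entries afterwards.
n_batches = -(-num_episodes // batch_size)
assert n_batches == 3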

View File

@ -7,6 +7,7 @@ import torch
from tensordict.nn import TensorDictModule
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import PrioritizedSliceSampler
from torchrl.envs import SerialEnv
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
@ -148,6 +149,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
logging.info("make_env")
env = make_env(cfg, transform=offline_buffer.transform)
env = SerialEnv(
cfg.rollout_batch_size,
create_env_fn=make_env,
create_env_kwargs=[
{"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
],
)
logging.info("make_policy")
policy = make_policy(cfg)
@ -191,7 +200,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
env,
td_policy,
num_episodes=cfg.eval_episodes,
max_steps=cfg.env.episode_length // cfg.n_action_steps,
max_steps=cfg.env.episode_length,
return_first_video=True,
video_dir=Path(out_dir) / "eval",
save_video=True,

View File

@ -1,7 +1,15 @@
import pytest
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
import torch
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.abstract import AbstractPolicy
from .utils import DEVICE, init_config
@ -23,3 +31,75 @@ def test_factory(env_name, policy_name):
]
)
policy = make_policy(cfg)
def test_abstract_policy_forward():
"""
Given an underlying policy that produces an action trajectory with n_action_steps actions, checks that:
- The policy is invoked the expected number of times during a rollout.
- The environment's termination condition is respected even when part way through an action trajectory.
- The observations are returned correctly.
"""
n_action_steps = 8 # our test policy will output 8 action step horizons
terminate_at = 10 # some number that is more than n_action_steps but not a multiple
rollout_max_steps = terminate_at + 1 # some number greater than terminate_at
# A minimal environment for testing.
class StubEnv(EnvBase):
def __init__(self):
super().__init__()
self.action_spec = UnboundedContinuousTensorSpec(shape=(1,))
self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
def _step(self, tensordict: TensorDict) -> TensorDict:
self.invocation_count += 1
return TensorDict(
{
"observation": torch.tensor([self.invocation_count]),
"reward": torch.tensor([self.invocation_count]),
"terminated": torch.tensor(
tensordict["action"].item() == terminate_at
),
}
)
def _reset(self, tensordict: TensorDict) -> TensorDict:
self.invocation_count = 0
return TensorDict(
{
"observation": torch.tensor([self.invocation_count]),
"reward": torch.tensor([self.invocation_count]),
}
)
def _set_seed(self, seed: int | None):
return
class StubPolicy(AbstractPolicy):
def __init__(self):
super().__init__()
self.n_action_steps = n_action_steps
self.n_policy_invocations = 0
def select_action(self):
self.n_policy_invocations += 1
return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0)
env = StubEnv()
policy = StubPolicy()
policy = TensorDictModule(
policy,
in_keys=[],
out_keys=["action"],
)
# Keep track to make sure the policy is called the expected number of times
rollout = env.rollout(rollout_max_steps, policy)
assert len(rollout) == terminate_at + 1 # +1 for the reset observation
assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
assert torch.equal(rollout['observation'].flatten(), torch.arange(terminate_at + 1))
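Working through the final assertions with the values above (an illustration of the expected call pattern, not part of the test):

# select_action call 1 supplies actions 0..7, used for the first 8 env steps;
# call 2 supplies actions 8..15, of which 8, 9 and 10 are used before action == terminate_at (10) terminates.
# Hence n_policy_invocations == (10 // 8) + 1 == 2, the rollout holds terminate_at + 1 == 11 steps,
# and the recorded observations are exactly 0..10.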