wip: still needs batch logic for act and tdmp
This commit is contained in:
parent
8c56770318
commit
ba91976944
lerobot
common
configs
scripts
tests
|
@ -168,42 +168,31 @@ class AlohaEnv(AbstractEnv):
|
|||
def _step(self, tensordict: TensorDict):
|
||||
td = tensordict
|
||||
action = td["action"].numpy()
|
||||
# step expects shape=(4,) so we pad if necessary
|
||||
assert action.ndim == 1
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
|
||||
if action.ndim == 1:
|
||||
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
|
||||
else:
|
||||
if self.frame_skip > 1:
|
||||
raise NotImplementedError()
|
||||
_, reward, _, raw_obs = self._env.step(action)
|
||||
|
||||
num_action_steps = action.shape[0]
|
||||
for i in range(num_action_steps):
|
||||
_, reward, discount, raw_obs = self._env.step(action[i])
|
||||
del discount # not used
|
||||
# TODO(rcadene): add an enum
|
||||
success = done = reward == 4
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
# TOOD(rcadene): add an enum
|
||||
success = done = reward == 4
|
||||
sum_reward += reward
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue.append(obs["image"]["top"])
|
||||
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue.append(obs["state"])
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue.append(obs["image"]["top"])
|
||||
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue.append(obs["state"])
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
self.call_rendering_hooks()
|
||||
self.call_rendering_hooks()
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"reward": torch.tensor([sum_reward], dtype=torch.float32),
|
||||
"reward": torch.tensor([reward], dtype=torch.float32),
|
||||
# succes and done are true when coverage > self.success_threshold in env
|
||||
"done": torch.tensor([done], dtype=torch.bool),
|
||||
"success": torch.tensor([success], dtype=torch.bool),
|
||||
|
|
|
@ -1,15 +1,17 @@
|
|||
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
|
||||
|
||||
|
||||
def make_env(cfg, transform=None):
|
||||
def make_env(cfg, seed=None, transform=None):
|
||||
"""
|
||||
Provide seed to override the seed in the cfg (useful for batched environments).
|
||||
"""
|
||||
kwargs = {
|
||||
"frame_skip": cfg.env.action_repeat,
|
||||
"from_pixels": cfg.env.from_pixels,
|
||||
"pixels_only": cfg.env.pixels_only,
|
||||
"image_size": cfg.env.image_size,
|
||||
# TODO(rcadene): do we want a specific eval_env_seed?
|
||||
"seed": cfg.seed,
|
||||
"num_prev_obs": cfg.n_obs_steps - 1,
|
||||
"seed": seed if seed is not None else cfg.seed,
|
||||
}
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
|
|
|
@ -2,7 +2,6 @@ import importlib
|
|||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
|
@ -120,40 +119,30 @@ class PushtEnv(AbstractEnv):
|
|||
def _step(self, tensordict: TensorDict):
|
||||
td = tensordict
|
||||
action = td["action"].numpy()
|
||||
# step expects shape=(4,) so we pad if necessary
|
||||
assert action.ndim == 1
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
|
||||
if action.ndim == 1:
|
||||
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
|
||||
else:
|
||||
if self.frame_skip > 1:
|
||||
raise NotImplementedError()
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
|
||||
num_action_steps = action.shape[0]
|
||||
for i in range(num_action_steps):
|
||||
raw_obs, reward, done, info = self._env.step(action[i])
|
||||
sum_reward += reward
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue.append(obs["image"])
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue.append(obs["state"])
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue.append(obs["image"])
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue.append(obs["state"])
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
self.call_rendering_hooks()
|
||||
self.call_rendering_hooks()
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"reward": torch.tensor([sum_reward], dtype=torch.float32),
|
||||
# succes and done are true when coverage > self.success_threshold in env
|
||||
"reward": torch.tensor([reward], dtype=torch.float32),
|
||||
# success and done are true when coverage > self.success_threshold in env
|
||||
"done": torch.tensor([done], dtype=torch.bool),
|
||||
"success": torch.tensor([done], dtype=torch.bool),
|
||||
},
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
from abc import abstractmethod
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class AbstractPolicy(nn.Module):
|
||||
@abstractmethod
|
||||
def update(self, replay_buffer, step):
|
||||
"""One step of the policy's learning algorithm."""
|
||||
pass
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
@abstractmethod
|
||||
def select_action(self, observation) -> Tensor:
|
||||
"""Select an action (or trajectory of actions) based on an observation during rollout.
|
||||
|
||||
Should return a (batch_size, n_action_steps, *) tensor of actions.
|
||||
"""
|
||||
pass
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""Inference step that makes multi-step policies compatible with their single-step environments.
|
||||
|
||||
WARNING: In general, this should not be overriden.
|
||||
|
||||
Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
|
||||
into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
|
||||
observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
|
||||
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
|
||||
the subclass doesn't have to.
|
||||
|
||||
This method effectively wraps the `select_action` method of the subclass. The following assumptions are made:
|
||||
1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
|
||||
the action trajectory horizon and * is the action dimensions.
|
||||
2. Prior to the `select_action` method being called, theres is an `n_action_steps` instance attribute defined.
|
||||
"""
|
||||
n_action_steps_attr = "n_action_steps"
|
||||
if not hasattr(self, n_action_steps_attr):
|
||||
raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
|
||||
if not hasattr(self, "_action_queue"):
|
||||
self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
|
||||
if len(self._action_queue) == 0:
|
||||
# Each element in the queue has shape (B, *).
|
||||
self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1))
|
||||
|
||||
return self._action_queue.popleft()
|
|
@ -2,10 +2,10 @@ import logging
|
|||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.act.detr_vae import build
|
||||
|
||||
|
||||
|
@ -40,7 +40,7 @@ def kl_divergence(mu, logvar):
|
|||
return total_kld, dimension_wise_kld, mean_kld
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
@ -147,7 +147,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
def select_action(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
|
|
|
@ -3,14 +3,14 @@ import time
|
|||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
|
||||
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
|
||||
|
||||
class DiffusionPolicy(nn.Module):
|
||||
class DiffusionPolicy(AbstractPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
|
@ -44,6 +44,7 @@ class DiffusionPolicy(nn.Module):
|
|||
**cfg_obs_encoder,
|
||||
)
|
||||
|
||||
self.n_action_steps = n_action_steps # needed for the parent class
|
||||
self.diffusion = DiffusionUnetImagePolicy(
|
||||
shape_meta=shape_meta,
|
||||
noise_scheduler=noise_scheduler,
|
||||
|
@ -93,21 +94,16 @@ class DiffusionPolicy(nn.Module):
|
|||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
def select_action(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack to add bsize=1
|
||||
observation["image"] = observation["image"].unsqueeze(0)
|
||||
observation["state"] = observation["state"].unsqueeze(0)
|
||||
|
||||
obs_dict = {
|
||||
"image": observation["image"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
out = self.diffusion.predict_action(obs_dict)
|
||||
|
||||
action = out["action"].squeeze(0)
|
||||
action = out["action"]
|
||||
return action
|
||||
|
||||
def update(self, replay_buffer, step):
|
||||
|
|
|
@ -9,6 +9,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
import lerobot.common.policies.tdmpc.helper as h
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
|
||||
FIRST_FRAME = 0
|
||||
|
||||
|
@ -85,7 +86,7 @@ class TOLD(nn.Module):
|
|||
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
|
||||
|
||||
|
||||
class TDMPC(nn.Module):
|
||||
class TDMPC(AbstractPolicy):
|
||||
"""Implementation of TD-MPC learning + inference."""
|
||||
|
||||
def __init__(self, cfg, device):
|
||||
|
@ -124,7 +125,7 @@ class TDMPC(nn.Module):
|
|||
self.model_target.load_state_dict(d["model_target"])
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
def select_action(self, observation, step_count):
|
||||
t0 = step_count.item() == 0
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack...
|
||||
|
|
|
@ -10,6 +10,8 @@ hydra:
|
|||
name: default
|
||||
|
||||
seed: 1337
|
||||
# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
|
||||
rollout_batch_size: 10
|
||||
device: cuda # cpu
|
||||
prefetch: 4
|
||||
eval_freq: ???
|
||||
|
|
|
@ -9,7 +9,8 @@ import numpy as np
|
|||
import torch
|
||||
import tqdm
|
||||
from tensordict.nn import TensorDictModule
|
||||
from torchrl.envs import EnvBase
|
||||
from torchrl.envs import EnvBase, SerialEnv
|
||||
from torchrl.envs.batched_envs import BatchedEnvBase
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.envs.factory import make_env
|
||||
|
@ -23,7 +24,7 @@ def write_video(video_path, stacked_frames, fps):
|
|||
|
||||
|
||||
def eval_policy(
|
||||
env: EnvBase,
|
||||
env: BatchedEnvBase,
|
||||
policy: TensorDictModule = None,
|
||||
num_episodes: int = 10,
|
||||
max_steps: int = 30,
|
||||
|
@ -36,45 +37,55 @@ def eval_policy(
|
|||
sum_rewards = []
|
||||
max_rewards = []
|
||||
successes = []
|
||||
threads = []
|
||||
for i in tqdm.tqdm(range(num_episodes)):
|
||||
threads = [] # for video saving threads
|
||||
episode_counter = 0 # for saving the correct number of videos
|
||||
|
||||
# TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
|
||||
# needed as I'm currently taking a ceil.
|
||||
for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))):
|
||||
ep_frames = []
|
||||
if save_video or (return_first_video and i == 0):
|
||||
|
||||
def render_frame(env):
|
||||
def maybe_render_frame(env: EnvBase, _):
|
||||
if save_video or (return_first_video and i == 0): # noqa: B023
|
||||
ep_frames.append(env.render()) # noqa: B023
|
||||
|
||||
env.register_rendering_hook(render_frame)
|
||||
|
||||
with torch.inference_mode():
|
||||
rollout = env.rollout(
|
||||
max_steps=max_steps,
|
||||
policy=policy,
|
||||
auto_cast_to_device=True,
|
||||
callback=maybe_render_frame,
|
||||
)
|
||||
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
|
||||
ep_sum_reward = rollout["next", "reward"].sum()
|
||||
ep_max_reward = rollout["next", "reward"].max()
|
||||
ep_success = rollout["next", "success"].any()
|
||||
sum_rewards.append(ep_sum_reward.item())
|
||||
max_rewards.append(ep_max_reward.item())
|
||||
successes.append(ep_success.item())
|
||||
batch_sum_reward = rollout["next", "reward"].flatten(start_dim=1).sum(dim=-1)
|
||||
batch_max_reward = rollout["next", "reward"].flatten(start_dim=1).max(dim=-1)[0]
|
||||
batch_success = rollout["next", "success"].flatten(start_dim=1).any(dim=-1)
|
||||
sum_rewards.extend(batch_sum_reward.tolist())
|
||||
max_rewards.extend(batch_max_reward.tolist())
|
||||
successes.extend(batch_success.tolist())
|
||||
|
||||
if save_video or (return_first_video and i == 0):
|
||||
stacked_frames = np.stack(ep_frames)
|
||||
batch_stacked_frames = np.stack(ep_frames) # (t, b, *)
|
||||
batch_stacked_frames = batch_stacked_frames.transpose(
|
||||
1, 0, *range(2, batch_stacked_frames.ndim)
|
||||
) # (b, t, *)
|
||||
|
||||
if save_video:
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_path = video_dir / f"eval_episode_{i}.mp4"
|
||||
thread = threading.Thread(
|
||||
target=write_video,
|
||||
args=(str(video_path), stacked_frames, fps),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
for stacked_frames in batch_stacked_frames:
|
||||
if episode_counter >= num_episodes:
|
||||
continue
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
|
||||
thread = threading.Thread(
|
||||
target=write_video,
|
||||
args=(str(video_path), stacked_frames, fps),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
episode_counter += 1
|
||||
|
||||
if return_first_video and i == 0:
|
||||
first_video = stacked_frames.transpose(0, 3, 1, 2)
|
||||
first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
|
||||
|
||||
env.reset_rendering_hooks()
|
||||
|
||||
|
@ -82,9 +93,9 @@ def eval_policy(
|
|||
thread.join()
|
||||
|
||||
info = {
|
||||
"avg_sum_reward": np.nanmean(sum_rewards),
|
||||
"avg_max_reward": np.nanmean(max_rewards),
|
||||
"pc_success": np.nanmean(successes) * 100,
|
||||
"avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
|
||||
"avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
|
||||
"pc_success": np.nanmean(successes[:num_episodes]) * 100,
|
||||
"eval_s": time.time() - start,
|
||||
"eval_ep_s": (time.time() - start) / num_episodes,
|
||||
}
|
||||
|
@ -119,7 +130,14 @@ def eval(cfg: dict, out_dir=None):
|
|||
offline_buffer = make_offline_buffer(cfg)
|
||||
|
||||
logging.info("make_env")
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
env = SerialEnv(
|
||||
cfg.rollout_batch_size,
|
||||
create_env_fn=make_env,
|
||||
create_env_kwargs=[
|
||||
{"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
|
||||
for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
|
||||
],
|
||||
)
|
||||
|
||||
if cfg.policy.pretrained_model_path:
|
||||
policy = make_policy(cfg)
|
||||
|
@ -138,7 +156,7 @@ def eval(cfg: dict, out_dir=None):
|
|||
save_video=True,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
fps=cfg.env.fps,
|
||||
max_steps=cfg.env.episode_length // cfg.n_action_steps,
|
||||
max_steps=cfg.env.episode_length,
|
||||
num_episodes=cfg.eval_episodes,
|
||||
)
|
||||
print(metrics)
|
||||
|
|
|
@ -7,6 +7,7 @@ import torch
|
|||
from tensordict.nn import TensorDictModule
|
||||
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
||||
from torchrl.envs import SerialEnv
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.envs.factory import make_env
|
||||
|
@ -148,6 +149,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
|
||||
logging.info("make_env")
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
env = SerialEnv(
|
||||
cfg.rollout_batch_size,
|
||||
create_env_fn=make_env,
|
||||
create_env_kwargs=[
|
||||
{"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
|
||||
for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
|
||||
],
|
||||
)
|
||||
|
||||
logging.info("make_policy")
|
||||
policy = make_policy(cfg)
|
||||
|
@ -191,7 +200,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
env,
|
||||
td_policy,
|
||||
num_episodes=cfg.eval_episodes,
|
||||
max_steps=cfg.env.episode_length // cfg.n_action_steps,
|
||||
max_steps=cfg.env.episode_length,
|
||||
return_first_video=True,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
save_video=True,
|
||||
|
|
|
@ -1,7 +1,15 @@
|
|||
|
||||
import pytest
|
||||
from tensordict import TensorDict
|
||||
from tensordict.nn import TensorDictModule
|
||||
import torch
|
||||
from torchrl.data import UnboundedContinuousTensorSpec
|
||||
from torchrl.envs import EnvBase
|
||||
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
|
||||
from .utils import DEVICE, init_config
|
||||
|
||||
|
||||
|
@ -23,3 +31,75 @@ def test_factory(env_name, policy_name):
|
|||
]
|
||||
)
|
||||
policy = make_policy(cfg)
|
||||
|
||||
|
||||
def test_abstract_policy_forward():
|
||||
"""
|
||||
Given an underlying policy that produces an action trajectory with n_action_steps actions, checks that:
|
||||
- The policy is invoked the expected number of times during a rollout.
|
||||
- The environment's termination condition is respected even when part way through an action trajectory.
|
||||
- The observations are returned correctly.
|
||||
"""
|
||||
|
||||
n_action_steps = 8 # our test policy will output 8 action step horizons
|
||||
terminate_at = 10 # some number that is more than n_action_steps but not a multiple
|
||||
rollout_max_steps = terminate_at + 1 # some number greater than terminate_at
|
||||
|
||||
# A minimal environment for testing.
|
||||
class StubEnv(EnvBase):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.action_spec = UnboundedContinuousTensorSpec(shape=(1,))
|
||||
self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
|
||||
|
||||
def _step(self, tensordict: TensorDict) -> TensorDict:
|
||||
self.invocation_count += 1
|
||||
return TensorDict(
|
||||
{
|
||||
"observation": torch.tensor([self.invocation_count]),
|
||||
"reward": torch.tensor([self.invocation_count]),
|
||||
"terminated": torch.tensor(
|
||||
tensordict["action"].item() == terminate_at
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def _reset(self, tensordict: TensorDict) -> TensorDict:
|
||||
self.invocation_count = 0
|
||||
return TensorDict(
|
||||
{
|
||||
"observation": torch.tensor([self.invocation_count]),
|
||||
"reward": torch.tensor([self.invocation_count]),
|
||||
}
|
||||
)
|
||||
|
||||
def _set_seed(self, seed: int | None):
|
||||
return
|
||||
|
||||
|
||||
class StubPolicy(AbstractPolicy):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.n_action_steps = n_action_steps
|
||||
self.n_policy_invocations = 0
|
||||
|
||||
def select_action(self):
|
||||
self.n_policy_invocations += 1
|
||||
return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0)
|
||||
|
||||
|
||||
env = StubEnv()
|
||||
policy = StubPolicy()
|
||||
policy = TensorDictModule(
|
||||
policy,
|
||||
in_keys=[],
|
||||
out_keys=["action"],
|
||||
)
|
||||
|
||||
# Keep track to make sure the policy is called the expected number of times
|
||||
rollout = env.rollout(rollout_max_steps, policy)
|
||||
|
||||
assert len(rollout) == terminate_at + 1 # +1 for the reset observation
|
||||
assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
|
||||
assert torch.equal(rollout['observation'].flatten(), torch.arange(terminate_at + 1))
|
||||
|
|
Loading…
Reference in New Issue