backup wip

Alexander Soare 2024-03-19 16:02:09 +00:00
parent 88347965c2
commit ea17f4ce50
11 changed files with 71 additions and 46 deletions

View File

@@ -49,9 +49,9 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
     @property
     def stats_patterns(self) -> dict:
         return {
-            ("observation", "state"): "b c -> 1 c",
-            ("observation", "image"): "b c h w -> 1 c 1 1",
-            ("action",): "b c -> 1 c",
+            ("observation", "state"): "b c -> c",
+            ("observation", "image"): "b c h w -> c",
+            ("action",): "b c -> c",
         }
 
     @property

View File

@@ -113,11 +113,11 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
     @property
     def stats_patterns(self) -> dict:
         d = {
-            ("observation", "state"): "b c -> 1 c",
-            ("action",): "b c -> 1 c",
+            ("observation", "state"): "b c -> c",
+            ("action",): "b c -> c",
         }
         for cam in CAMERAS[self.dataset_id]:
-            d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
+            d[("observation", "image", cam)] = "b c h w -> c"
         return d
 
     @property
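Both replay buffers use these einops-style strings as reduction patterns when computing dataset statistics. The change from "b c -> 1 c" (keep a singleton batch dim) to "b c -> c" (one value per channel, no singleton dims) is easiest to see with einops.reduce directly; a minimal sketch with dummy tensors, not taken from this commit:

    import torch
    from einops import reduce

    states = torch.randn(64, 4)          # (batch, channels)
    images = torch.randn(64, 3, 96, 96)  # (batch, channels, height, width)

    # New patterns: reduce over batch (and spatial) dims, one scalar per channel.
    state_mean = reduce(states, "b c -> c", "mean")      # shape (4,)
    image_mean = reduce(images, "b c h w -> c", "mean")  # shape (3,)

    # Old patterns kept singleton dims, presumably for broadcasting:
    state_mean_old = reduce(states, "b c -> 1 c", "mean")          # shape (1, 4)
    image_mean_old = reduce(images, "b c h w -> 1 c 1 1", "mean")  # shape (1, 3, 1, 1)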

View File

@@ -1,17 +1,31 @@
 from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
 
 
-def make_env(cfg, seed=None, transform=None):
+def make_env(cfg, transform=None):
     """
-    Provide seed to override the seed in the cfg (useful for batched environments).
     """
+    # assert cfg.rollout_batch_size == 1, \
+    #     """
+    #     For the time being, rollout batch sizes of > 1 are not supported. This is because the SerialEnv rollout does not
+    #     correctly handle terminated environments. If you really want to use a larger batch size, read on...
+    #     When calling `EnvBase.rollout` with `break_when_any_done == True` all environments stop rolling out as soon as the
+    #     first is terminated or truncated. This almost certainly results in incorrect success metrics, as all but the first
+    #     environment never get an opportunity to reach the goal. A possible workaround is to comment out `if any_done: break`
+    #     in `EnvBase._rollout_stop_early`. One potential downside is that the environments' `step` function will continue
+    #     to be called and the outputs will continue to be added to the rollout.
+    #     When calling `EnvBase.rollout` with `break_when_any_done == False` environments are reset when done.
+    #     """
     kwargs = {
         "frame_skip": cfg.env.action_repeat,
         "from_pixels": cfg.env.from_pixels,
         "pixels_only": cfg.env.pixels_only,
         "image_size": cfg.env.image_size,
         "num_prev_obs": cfg.n_obs_steps - 1,
-        "seed": seed if seed is not None else cfg.seed,
+        "seed": cfg.seed,
     }
 
     if cfg.env.name == "simxarm":
@@ -33,22 +47,33 @@ def make_env(cfg, seed=None, transform=None):
     else:
         raise ValueError(cfg.env.name)
 
-    env = clsfunc(**kwargs)
+    def _make_env(seed):
+        nonlocal kwargs
+        kwargs["seed"] = seed
+        env = clsfunc(**kwargs)
 
-    # limit rollout to max_steps
-    env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
+        # limit rollout to max_steps
+        env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
 
-    if transform is not None:
-        # useful to add normalization
-        if isinstance(transform, Compose):
-            for tf in transform:
-                env.append_transform(tf.clone())
-        elif isinstance(transform, Transform):
-            env.append_transform(transform.clone())
-        else:
-            raise NotImplementedError()
+        if transform is not None:
+            # useful to add normalization
+            if isinstance(transform, Compose):
+                for tf in transform:
+                    env.append_transform(tf.clone())
+            elif isinstance(transform, Transform):
+                env.append_transform(transform.clone())
+            else:
+                raise NotImplementedError()
 
-    return env
+        return env
+
+    # return SerialEnv(
+    #     cfg.rollout_batch_size,
+    #     create_env_fn=_make_env,
+    #     create_env_kwargs={
+    #         "seed": env_seed for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
+    #     },
+    # )
 
 
 # def make_env(env_name, frame_skip, device, is_test=False):
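The commented-out return hints at where this refactor is headed: wrapping `_make_env` in a torchrl `SerialEnv` with one seed per underlying environment. Note that as commented, `create_env_kwargs` is written as a dict comprehension, which would collapse to a single `{"seed": <last seed>}` applied to every worker; `SerialEnv` also accepts a list with one kwargs dict per worker. A hedged sketch of what the uncommented version might look like (the helper name is hypothetical, not part of this commit):

    from torchrl.envs import SerialEnv

    # Assumes `_make_env` and `cfg` as defined above inside make_env's scope.
    def make_batched_env(cfg, _make_env):
        return SerialEnv(
            cfg.rollout_batch_size,
            create_env_fn=_make_env,
            create_env_kwargs=[
                {"seed": env_seed}
                for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
            ],
        )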

View File

@@ -30,7 +30,7 @@ class AbstractPolicy(nn.Module, ABC):
         Should return a (batch_size, n_action_steps, *) tensor of actions.
         """
 
-    def forward(self, *args, **kwargs):
+    def forward(self, *args, **kwargs) -> Tensor:
         """Inference step that makes multi-step policies compatible with their single-step environments.
 
         WARNING: In general, this should not be overridden.
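The body of this method is not shown in the hunk, but the docstring describes the usual action-chunking pattern: run the policy once to predict `n_action_steps` actions, queue them, and pop one per environment step. A hypothetical illustration of that pattern (names and structure are assumptions, not lerobot's actual code):

    from collections import deque

    import torch
    from torch import Tensor

    class MultiStepAdapter:
        """Illustrative only: serves a chunk of predicted actions one step at a time."""

        def __init__(self, select_action, n_action_steps: int):
            self.select_action = select_action  # maps obs -> (batch, n_action_steps, *)
            self._queue: deque = deque(maxlen=n_action_steps)

        def forward(self, observation: Tensor) -> Tensor:
            # Only query the policy when the previous chunk is exhausted.
            if len(self._queue) == 0:
                actions = self.select_action(observation)
                # Queue along the action-step dim: n_action_steps items of (batch, *).
                self._queue.extend(actions.transpose(0, 1))
            return self._queue.popleft()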

View File

@@ -11,14 +11,16 @@ hydra:
 
 seed: 1337
 # batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
-rollout_batch_size: 10
+# NOTE: batch sizes of > 1 are not yet supported! This is just a placeholder for future support. See
+# `lerobot.common.envs.factory.make_env` for more information.
+rollout_batch_size: 1
 device: cuda # cpu
 prefetch: 4
 eval_freq: ???
 save_freq: ???
 eval_episodes: ???
 save_video: false
-save_model: false
+save_model: true
 save_buffer: false
 train_steps: ???
 fps: ???
@@ -31,7 +33,7 @@ env: ???
 policy: ???
 
 wandb:
-  enable: true
+  enable: false
   # Set to true to disable saving an artifact despite save_model == True
   disable_artifact: false
   project: lerobot

View File

@@ -22,8 +22,8 @@ keypoint_visible_rate: 1.0
 obs_as_global_cond: True
 
 eval_episodes: 1
-eval_freq: 10000
-save_freq: 100000
+eval_freq: 5000
+save_freq: 5000
 log_freq: 250
 
 offline_steps: 1344000

View File

@@ -9,7 +9,7 @@ import numpy as np
 import torch
 import tqdm
 from tensordict.nn import TensorDictModule
-from torchrl.envs import EnvBase, SerialEnv
+from torchrl.envs import EnvBase
 from torchrl.envs.batched_envs import BatchedEnvBase
 
 from lerobot.common.datasets.factory import make_offline_buffer
@@ -131,14 +131,7 @@ def eval(cfg: dict, out_dir=None):
     offline_buffer = make_offline_buffer(cfg)
 
     logging.info("make_env")
-    env = SerialEnv(
-        cfg.rollout_batch_size,
-        create_env_fn=make_env,
-        create_env_kwargs=[
-            {"cfg": cfg, "seed": env_seed, "transform": offline_buffer.transform}
-            for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
-        ],
-    )
+    env = make_env(cfg, transform=offline_buffer.transform)
 
     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)
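With the SerialEnv construction gone, `make_env` returns a single env for now, and the rollout downstream is where the `break_when_any_done` trade-off from the factory comment bites. A sketch of the relevant call, using torchrl's public `EnvBase.rollout` signature (`env`, `policy`, and `cfg` as in the surrounding eval script; the keyword values are illustrative, not taken from this commit):

    rollout = env.rollout(
        max_steps=cfg.env.episode_length,
        policy=policy,
        auto_cast_to_device=True,
        # True (torchrl's default) stops every env in a batch once the first one
        # finishes; False resets finished envs and keeps stepping to max_steps.
        break_when_any_done=True,
    )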

View File

@@ -7,7 +7,6 @@ import torch
 from tensordict.nn import TensorDictModule
 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers import PrioritizedSliceSampler
-from torchrl.envs import SerialEnv
 
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
@@ -149,14 +148,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     logging.info("make_env")
     env = make_env(cfg, transform=offline_buffer.transform)
-    env = SerialEnv(
-        cfg.rollout_batch_size,
-        create_env_fn=make_env,
-        create_env_kwargs=[
-            {"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
-            for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
-        ],
-    )
 
     logging.info("make_policy")
     policy = make_policy(cfg)

Binary file not shown.

View File

@@ -1,4 +1,5 @@
+from omegaconf import open_dict
 import pytest
 from tensordict import TensorDict
 from tensordict.nn import TensorDictModule
 
@@ -7,7 +8,8 @@ from torchrl.data import UnboundedContinuousTensorSpec
 from torchrl.envs import EnvBase
 
 from lerobot.common.policies.factory import make_policy
+from lerobot.common.envs.factory import make_env
+from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.policies.abstract import AbstractPolicy
 
 from .utils import DEVICE, init_config
@@ -30,7 +32,19 @@ def test_factory(env_name, policy_name):
             f"device={DEVICE}",
         ]
     )
+    # Check that we can make the policy object.
     policy = make_policy(cfg)
+    # Check that we can run select_action and get the appropriate output.
+    if env_name == "simxarm":
+        # TODO(rcadene): Not implemented
+        return
+    if policy_name == "tdmpc":
+        # TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this.
+        with open_dict(cfg):
+            cfg['n_obs_steps'] = 1
+    offline_buffer = make_offline_buffer(cfg)
+    env = make_env(cfg, transform=offline_buffer.transform)
+    policy.select_action(env.observation_spec.rand()['observation'].to(DEVICE), torch.tensor(0, device=DEVICE))
 
 
 def test_abstract_policy_forward():
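`open_dict` is what lets the test add `n_obs_steps` to a Hydra/OmegaConf config: Hydra-composed configs are in struct mode, which rejects new keys outside such a block. A self-contained example, independent of lerobot:

    from omegaconf import OmegaConf, open_dict

    cfg = OmegaConf.create({"policy": {"name": "tdmpc"}})
    OmegaConf.set_struct(cfg, True)  # Hydra-composed configs behave like this

    # cfg.n_obs_steps = 1  # would raise: struct mode rejects unknown keys

    with open_dict(cfg):
        cfg.n_obs_steps = 1  # struct mode is temporarily relaxed here

    assert cfg.n_obs_steps == 1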