Refactor configs to have env in seperate yaml + Fix training

This commit is contained in:
Cadene 2024-02-25 17:42:47 +00:00
parent eec134d72b
commit b16c334825
13 changed files with 146 additions and 54 deletions

View File

@ -10,28 +10,29 @@ def make_offline_buffer(cfg, sampler=None):
overwrite_sampler = sampler is not None overwrite_sampler = sampler is not None
if not overwrite_sampler: if not overwrite_sampler:
num_traj_per_batch = cfg.batch_size # // cfg.horizon # TODO(rcadene): move batch_size outside
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size. # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size. # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
sampler = PrioritizedSliceSampler( sampler = PrioritizedSliceSampler(
max_capacity=100_000, max_capacity=100_000,
alpha=cfg.per_alpha, alpha=cfg.policy.per_alpha,
beta=cfg.per_beta, beta=cfg.policy.per_beta,
num_slices=num_traj_per_batch, num_slices=num_traj_per_batch,
strict_length=False, strict_length=False,
) )
if cfg.env == "simxarm": if cfg.env.name == "simxarm":
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
offline_buffer = SimxarmExperienceReplay( offline_buffer = SimxarmExperienceReplay(
f"xarm_{cfg.task}_medium", f"xarm_{cfg.env.task}_medium",
# download="force", # download="force",
download=True, download=True,
streaming=False, streaming=False,
root="data", root="data",
sampler=sampler, sampler=sampler,
) )
elif cfg.env == "pusht": elif cfg.env.name == "pusht":
offline_buffer = PushtExperienceReplay( offline_buffer = PushtExperienceReplay(
"pusht", "pusht",
# download="force", # download="force",
@ -41,7 +42,7 @@ def make_offline_buffer(cfg, sampler=None):
sampler=sampler, sampler=sampler,
) )
else: else:
raise ValueError(cfg.env) raise ValueError(cfg.env.name)
if not overwrite_sampler: if not overwrite_sampler:
num_steps = len(offline_buffer) num_steps = len(offline_buffer)

View File

@ -1,7 +1,5 @@
from torchrl.envs.transforms import StepCounter, TransformedEnv from torchrl.envs.transforms import StepCounter, TransformedEnv
from lerobot.common.envs.pusht import PushtEnv
from lerobot.common.envs.simxarm import SimxarmEnv
from lerobot.common.envs.transforms import Prod from lerobot.common.envs.transforms import Prod
@ -14,9 +12,13 @@ def make_env(cfg):
} }
if cfg.env.name == "simxarm": if cfg.env.name == "simxarm":
from lerobot.common.envs.simxarm import SimxarmEnv
kwargs["task"] = cfg.env.task kwargs["task"] = cfg.env.task
clsfunc = SimxarmEnv clsfunc = SimxarmEnv
elif cfg.env.name == "pusht": elif cfg.env.name == "pusht":
from lerobot.common.envs.pusht import PushtEnv
clsfunc = PushtEnv clsfunc = PushtEnv
else: else:
raise ValueError(cfg.env.name) raise ValueError(cfg.env.name)

View File

@ -50,7 +50,7 @@ def print_run(cfg, reward=None):
) )
kvs = [ kvs = [
("task", cfg.task), ("task", cfg.env.task),
("train steps", f"{int(cfg.train_steps * cfg.env.action_repeat):,}"), ("train steps", f"{int(cfg.train_steps * cfg.env.action_repeat):,}"),
# ('observations', 'x'.join([str(s) for s in cfg.obs_shape])), # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
# ('actions', cfg.action_dim), # ('actions', cfg.action_dim),
@ -72,7 +72,7 @@ def cfg_to_group(cfg, return_list=False):
"""Return a wandb-safe group name for logging. Optionally returns group name as list.""" """Return a wandb-safe group name for logging. Optionally returns group name as list."""
# lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)] # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
lst = [ lst = [
f"env:{cfg.env}", f"env:{cfg.env.name}",
f"seed:{cfg.seed}", f"seed:{cfg.seed}",
] ]
return lst if return_list else "-".join(lst) return lst if return_list else "-".join(lst)

View File

@ -0,0 +1,44 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy
class DiffusionPolicy(nn.Module):
def __init__(
self,
shape_meta: dict,
noise_scheduler: DDPMScheduler,
obs_encoder: MultiImageObsEncoder,
horizon,
n_action_steps,
n_obs_steps,
num_inference_steps=None,
obs_as_global_cond=True,
diffusion_step_embed_dim=256,
down_dims=(256, 512, 1024),
kernel_size=5,
n_groups=8,
cond_predict_scale=True,
# parameters passed to step
**kwargs,
):
super().__init__()
self.diffusion = DiffusionUnetImagePolicy(
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
obs_encoder=obs_encoder,
horizon=horizon,
n_action_steps=n_action_steps,
n_obs_steps=n_obs_steps,
num_inference_steps=num_inference_steps,
obs_as_global_cond=obs_as_global_cond,
diffusion_step_embed_dim=diffusion_step_embed_dim,
down_dims=down_dims,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
# parameters passed to step
**kwargs,
)

View File

@ -1,9 +1,12 @@
from lerobot.common.policies.tdmpc import TDMPC
def make_policy(cfg): def make_policy(cfg):
if cfg.policy.name == "tdmpc": if cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc import TDMPC
policy = TDMPC(cfg.policy) policy = TDMPC(cfg.policy)
elif cfg.policy.name == "diffusion":
from lerobot.common.policies.diffusion import DiffusionPolicy
policy = DiffusionPolicy(cfg.policy)
else: else:
raise ValueError(cfg.policy.name) raise ValueError(cfg.policy.name)

View File

@ -1,31 +1,24 @@
defaults:
- _self_
- env: simxarm
hydra: hydra:
run: run:
dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}_${hydra.job.name} dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}_${hydra.job.name}
job:
name: default
seed: 1337 seed: 1337
device: cuda device: cuda
buffer_device: cuda buffer_device: cuda
eval_freq: 1000 eval_freq: ???
save_freq: 10000 save_freq: ???
eval_episodes: 20 eval_episodes: ???
save_video: false save_video: false
save_model: false save_model: false
save_buffer: false save_buffer: false
train_steps: 50000 train_steps: ???
fps: 15 fps: ???
env:
name: simxarm
task: lift
from_pixels: True
pixels_only: False
image_size: 84
action_repeat: 2
episode_length: 25
fps: ${fps}
env: ???
policy: policy:
name: tdmpc name: tdmpc
@ -42,8 +35,8 @@ policy:
frame_stack: 1 frame_stack: 1
num_channels: 32 num_channels: 32
img_size: ${env.image_size} img_size: ${env.image_size}
state_dim: 4 state_dim: ???
action_dim: 4 action_dim: ???
# planning # planning
mpc: true mpc: true

View File

@ -1,6 +1,4 @@
defaults: # @package _global_
- default
- _self_
hydra: hydra:
job: job:
@ -9,11 +7,15 @@ hydra:
eval_episodes: 50 eval_episodes: 50
eval_freq: 7500 eval_freq: 7500
save_freq: 75000 save_freq: 75000
train_steps: 50000 # TODO: same as simxarm, need to adjust
fps: 10 fps: 10
env: env:
name: pusht name: pusht
task: pusht task: pusht
from_pixels: True
pixels_only: False
image_size: 96 image_size: 96
action_repeat: 1 action_repeat: 1
episode_length: 300 episode_length: 300

26
lerobot/configs/env/simxarm.yaml vendored Normal file
View File

@ -0,0 +1,26 @@
# @package _global_
hydra:
job:
name: simxarm
eval_episodes: 20
eval_freq: 1000
save_freq: 10000
train_steps: 50000
fps: 15
env:
name: simxarm
task: lift
from_pixels: True
pixels_only: False
image_size: 84
action_repeat: 2
episode_length: 25
fps: ${fps}
policy:
state_dim: 4
action_dim: 4

View File

@ -62,13 +62,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
offline_buffer = make_offline_buffer(cfg) offline_buffer = make_offline_buffer(cfg)
if cfg.balanced_sampling: # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
num_traj_per_batch = cfg.batch_size if cfg.policy.balanced_sampling:
num_traj_per_batch = cfg.policy.batch_size
online_sampler = PrioritizedSliceSampler( online_sampler = PrioritizedSliceSampler(
max_capacity=100_000, max_capacity=100_000,
alpha=cfg.per_alpha, alpha=cfg.policy.per_alpha,
beta=cfg.per_beta, beta=cfg.policy.per_beta,
num_slices=num_traj_per_batch, num_slices=num_traj_per_batch,
strict_length=True, strict_length=True,
) )
@ -92,7 +93,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
_step = step + num_updates _step = step + num_updates
rollout_metrics = {} rollout_metrics = {}
if step >= cfg.offline_steps: # TODO(rcadene): move offline_steps outside policy
if step >= cfg.policy.offline_steps:
is_offline = False is_offline = False
# TODO: use SyncDataCollector for that? # TODO: use SyncDataCollector for that?
@ -118,7 +120,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
"avg_max_reward": np.nanmean(ep_max_reward), "avg_max_reward": np.nanmean(ep_max_reward),
"pc_success": np.nanmean(ep_success) * 100, "pc_success": np.nanmean(ep_success) * 100,
} }
num_updates = len(rollout) * cfg.utd num_updates = len(rollout) * cfg.policy.utd
_step = min(step + len(rollout), cfg.train_steps) _step = min(step + len(rollout), cfg.train_steps)
# Update model # Update model
@ -128,8 +130,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
else: else:
train_metrics = policy.update( train_metrics = policy.update(
online_buffer, online_buffer,
step + i // cfg.utd, step + i // cfg.policy.utd,
demo_buffer=offline_buffer if cfg.balanced_sampling else None, demo_buffer=(
offline_buffer if cfg.policy.balanced_sampling else None
),
) )
# Log training metrics # Log training metrics

17
tests/test_datasets.py Normal file
View File

@ -0,0 +1,17 @@
import pytest
from lerobot.common.datasets.factory import make_offline_buffer
from .utils import init_config
@pytest.mark.parametrize(
"env_name",
[
"simxarm",
"pusht",
],
)
def test_factory(env_name):
cfg = init_config(overrides=[f"env={env_name}"])
offline_buffer = make_offline_buffer(cfg)

View File

@ -78,13 +78,13 @@ def test_pusht(from_pixels, pixels_only):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"config_name", "env_name",
[ [
"default", "simxarm",
"pusht", "pusht",
], ],
) )
def test_factory(config_name): def test_factory(env_name):
cfg = init_config(config_name) cfg = init_config(overrides=[f"env={env_name}"])
env = make_env(cfg) env = make_env(cfg)
check_env_specs(env) check_env_specs(env)

View File

@ -6,12 +6,12 @@ from .utils import init_config
@pytest.mark.parametrize( @pytest.mark.parametrize(
"config_name", "env_name",
[ [
"default", "simxarm",
"pusht", "pusht",
], ],
) )
def test_factory(config_name): def test_factory(env_name):
cfg = init_config(config_name) cfg = init_config(overrides=[f"env={env_name}"])
policy = make_policy(cfg) policy = make_policy(cfg)

View File

@ -4,8 +4,8 @@ from hydra import compose, initialize
CONFIG_PATH = "../lerobot/configs" CONFIG_PATH = "../lerobot/configs"
def init_config(config_name): def init_config(config_name="default", overrides=None):
hydra.core.global_hydra.GlobalHydra.instance().clear() hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=CONFIG_PATH) initialize(config_path=CONFIG_PATH)
cfg = compose(config_name=config_name) cfg = compose(config_name=config_name, overrides=overrides)
return cfg return cfg