Refactor configs to have env in separate yaml + Fix training
commit b16c334825
parent eec134d72b
@@ -10,28 +10,29 @@ def make_offline_buffer(cfg, sampler=None):
     overwrite_sampler = sampler is not None
 
     if not overwrite_sampler:
-        num_traj_per_batch = cfg.batch_size  # // cfg.horizon
+        # TODO(rcadene): move batch_size outside
+        num_traj_per_batch = cfg.policy.batch_size  # // cfg.horizon
         # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
         # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
         sampler = PrioritizedSliceSampler(
             max_capacity=100_000,
-            alpha=cfg.per_alpha,
-            beta=cfg.per_beta,
+            alpha=cfg.policy.per_alpha,
+            beta=cfg.policy.per_beta,
             num_slices=num_traj_per_batch,
             strict_length=False,
         )
 
-    if cfg.env == "simxarm":
+    if cfg.env.name == "simxarm":
         # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
         offline_buffer = SimxarmExperienceReplay(
-            f"xarm_{cfg.task}_medium",
+            f"xarm_{cfg.env.task}_medium",
             # download="force",
             download=True,
             streaming=False,
             root="data",
             sampler=sampler,
         )
-    elif cfg.env == "pusht":
+    elif cfg.env.name == "pusht":
         offline_buffer = PushtExperienceReplay(
             "pusht",
             # download="force",
@@ -41,7 +42,7 @@ def make_offline_buffer(cfg, sampler=None):
             sampler=sampler,
         )
     else:
-        raise ValueError(cfg.env)
+        raise ValueError(cfg.env.name)
 
     if not overwrite_sampler:
         num_steps = len(offline_buffer)
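
For context, the sampler kwargs above mirror torchrl's PrioritizedSliceSampler, which combines prioritized experience replay (alpha/beta weighting) with trajectory-slice sampling (num_slices slices per batch). A minimal sketch of how such a sampler plugs into a replay buffer; the numeric values are illustrative, not from this commit, which reads them from cfg.policy:

    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
    from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler

    # Same kwargs the commit now reads from cfg.policy.* (values here are
    # illustrative):
    sampler = PrioritizedSliceSampler(
        max_capacity=100_000,
        alpha=0.7,
        beta=0.9,
        num_slices=8,          # trajectories (slices) per sampled batch
        strict_length=False,   # tolerate episodes shorter than the slice length
    )
    buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(100_000),
        sampler=sampler,
        batch_size=256,
    )
    # Slice samplers locate episode boundaries via a trajectory key
    # (by default "episode") stored alongside the transitions.
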
@@ -1,7 +1,5 @@
 from torchrl.envs.transforms import StepCounter, TransformedEnv
 
-from lerobot.common.envs.pusht import PushtEnv
-from lerobot.common.envs.simxarm import SimxarmEnv
 from lerobot.common.envs.transforms import Prod
 
 
@@ -14,9 +12,13 @@ def make_env(cfg):
     }
 
     if cfg.env.name == "simxarm":
+        from lerobot.common.envs.simxarm import SimxarmEnv
+
         kwargs["task"] = cfg.env.task
         clsfunc = SimxarmEnv
     elif cfg.env.name == "pusht":
+        from lerobot.common.envs.pusht import PushtEnv
+
        clsfunc = PushtEnv
     else:
         raise ValueError(cfg.env.name)
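
The env imports are now deferred into the branches, so selecting one env no longer requires the other env's dependencies to be importable. Usage, mirroring the updated tests below (a sketch; the module path for make_env is an assumption, since the diff shows the function but not its file name):

    from tests.utils import init_config
    # Assumed module path; not stated in the diff:
    from lerobot.common.envs.factory import make_env

    cfg = init_config(overrides=["env=pusht"])  # default.yaml + configs/env/pusht.yaml
    env = make_env(cfg)  # PushtEnv is imported only inside the "pusht" branch
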
@@ -50,7 +50,7 @@ def print_run(cfg, reward=None):
     )
 
     kvs = [
-        ("task", cfg.task),
+        ("task", cfg.env.task),
         ("train steps", f"{int(cfg.train_steps * cfg.env.action_repeat):,}"),
         # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
         # ('actions', cfg.action_dim),
@@ -72,7 +72,7 @@ def cfg_to_group(cfg, return_list=False):
     """Return a wandb-safe group name for logging. Optionally returns group name as list."""
     # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
     lst = [
-        f"env:{cfg.env}",
+        f"env:{cfg.env.name}",
         f"seed:{cfg.seed}",
     ]
     return lst if return_list else "-".join(lst)
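
With the nested config, the wandb group name is built from cfg.env.name rather than the old flat cfg.env. For example (values illustrative):

    # cfg.env.name = "pusht", cfg.seed = 1337
    lst = ["env:pusht", "seed:1337"]
    "-".join(lst)  # -> "env:pusht-seed:1337"
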
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy
+
+
+class DiffusionPolicy(nn.Module):
+    def __init__(
+        self,
+        shape_meta: dict,
+        noise_scheduler: DDPMScheduler,
+        obs_encoder: MultiImageObsEncoder,
+        horizon,
+        n_action_steps,
+        n_obs_steps,
+        num_inference_steps=None,
+        obs_as_global_cond=True,
+        diffusion_step_embed_dim=256,
+        down_dims=(256, 512, 1024),
+        kernel_size=5,
+        n_groups=8,
+        cond_predict_scale=True,
+        # parameters passed to step
+        **kwargs,
+    ):
+        super().__init__()
+        self.diffusion = DiffusionUnetImagePolicy(
+            shape_meta=shape_meta,
+            noise_scheduler=noise_scheduler,
+            obs_encoder=obs_encoder,
+            horizon=horizon,
+            n_action_steps=n_action_steps,
+            n_obs_steps=n_obs_steps,
+            num_inference_steps=num_inference_steps,
+            obs_as_global_cond=obs_as_global_cond,
+            diffusion_step_embed_dim=diffusion_step_embed_dim,
+            down_dims=down_dims,
+            kernel_size=kernel_size,
+            n_groups=n_groups,
+            cond_predict_scale=cond_predict_scale,
+            # parameters passed to step
+            **kwargs,
+        )
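
As committed, this new file annotates its parameters with DDPMScheduler and MultiImageObsEncoder without importing them; since annotations are evaluated at definition time by default, the module would raise a NameError on import. A sketch of the imports it presumably needs, assuming the upstream diffusion_policy package layout and HuggingFace diffusers; these paths are an assumption, not part of the commit:

    # Assumed sources for the missing names (not in this commit):
    from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
    from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
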
@@ -1,9 +1,12 @@
-from lerobot.common.policies.tdmpc import TDMPC
-
-
 def make_policy(cfg):
     if cfg.policy.name == "tdmpc":
+        from lerobot.common.policies.tdmpc import TDMPC
+
         policy = TDMPC(cfg.policy)
+    elif cfg.policy.name == "diffusion":
+        from lerobot.common.policies.diffusion import DiffusionPolicy
+
+        policy = DiffusionPolicy(cfg.policy)
     else:
         raise ValueError(cfg.policy.name)
 
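
Note that the new branch calls DiffusionPolicy(cfg.policy), while the class added above declares several required arguments (shape_meta, noise_scheduler, obs_encoder, ...); as committed, this call would raise a TypeError until either the constructor accepts a single config object or the factory unpacks one. A sketch of the latter, assuming cfg.policy carried exactly the constructor's keyword arguments:

    policy = DiffusionPolicy(**cfg.policy)  # hypothetical; keys must match the signature
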
@@ -1,31 +1,24 @@
+defaults:
+  - _self_
+  - env: simxarm
+
 hydra:
   run:
     dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}_${hydra.job.name}
-  job:
-    name: default
 
 seed: 1337
 device: cuda
 buffer_device: cuda
-eval_freq: 1000
-save_freq: 10000
-eval_episodes: 20
+eval_freq: ???
+save_freq: ???
+eval_episodes: ???
 save_video: false
 save_model: false
 save_buffer: false
-train_steps: 50000
-fps: 15
+train_steps: ???
+fps: ???
 
-env:
-  name: simxarm
-  task: lift
-  from_pixels: True
-  pixels_only: False
-  image_size: 84
-  action_repeat: 2
-  episode_length: 25
-  fps: ${fps}
-
+env: ???
 
 policy:
   name: tdmpc
@@ -42,8 +35,8 @@ policy:
  frame_stack: 1
  num_channels: 32
  img_size: ${env.image_size}
- state_dim: 4
- action_dim: 4
+ state_dim: ???
+ action_dim: ???
 
  # planning
  mpc: true
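
default.yaml now declares a Hydra defaults list (composing configs/env/simxarm.yaml by default) and marks every value an env config must supply as "???", OmegaConf's mandatory-value marker: composition still succeeds, but reading an unfilled field raises. A sketch of that behavior:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"train_steps": "???"})
    OmegaConf.is_missing(cfg, "train_steps")  # True
    cfg.train_steps  # raises omegaconf.errors.MissingMandatoryValue
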
@@ -1,6 +1,4 @@
-defaults:
-  - default
-  - _self_
+# @package _global_
 
 hydra:
   job:
@@ -9,11 +7,15 @@ hydra:
 eval_episodes: 50
 eval_freq: 7500
 save_freq: 75000
+train_steps: 50000 # TODO: same as simxarm, need to adjust
+
 fps: 10
 
 env:
   name: pusht
   task: pusht
+  from_pixels: True
+  pixels_only: False
   image_size: 96
   action_repeat: 1
   episode_length: 300
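
The env configs switch from composing default.yaml themselves to the "# @package _global_" header, which tells Hydra to merge the file's keys at the root of the composed config, so eval_freq, fps, env, and the rest fill the "???" placeholders in default.yaml instead of being nested under env. Selecting an env then works via an override, exactly as the updated tests do:

    from hydra import compose, initialize

    initialize(config_path="../lerobot/configs")
    cfg = compose(config_name="default", overrides=["env=pusht"])
    print(cfg.env.name, cfg.fps)  # "pusht", 10
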
@@ -0,0 +1,26 @@
+# @package _global_
+
+hydra:
+  job:
+    name: simxarm
+
+eval_episodes: 20
+eval_freq: 1000
+save_freq: 10000
+train_steps: 50000
+
+fps: 15
+
+env:
+  name: simxarm
+  task: lift
+  from_pixels: True
+  pixels_only: False
+  image_size: 84
+  action_repeat: 2
+  episode_length: 25
+  fps: ${fps}
+
+policy:
+  state_dim: 4
+  action_dim: 4
@@ -62,13 +62,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     offline_buffer = make_offline_buffer(cfg)
 
-    if cfg.balanced_sampling:
-        num_traj_per_batch = cfg.batch_size
+    # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
+    if cfg.policy.balanced_sampling:
+        num_traj_per_batch = cfg.policy.batch_size
 
         online_sampler = PrioritizedSliceSampler(
             max_capacity=100_000,
-            alpha=cfg.per_alpha,
-            beta=cfg.per_beta,
+            alpha=cfg.policy.per_alpha,
+            beta=cfg.policy.per_beta,
             num_slices=num_traj_per_batch,
             strict_length=True,
         )
@@ -92,7 +93,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
         _step = step + num_updates
         rollout_metrics = {}
 
-        if step >= cfg.offline_steps:
+        # TODO(rcadene): move offline_steps outside policy
+        if step >= cfg.policy.offline_steps:
             is_offline = False
 
             # TODO: use SyncDataCollector for that?
|
@ -118,7 +120,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
"avg_max_reward": np.nanmean(ep_max_reward),
|
"avg_max_reward": np.nanmean(ep_max_reward),
|
||||||
"pc_success": np.nanmean(ep_success) * 100,
|
"pc_success": np.nanmean(ep_success) * 100,
|
||||||
}
|
}
|
||||||
num_updates = len(rollout) * cfg.utd
|
num_updates = len(rollout) * cfg.policy.utd
|
||||||
_step = min(step + len(rollout), cfg.train_steps)
|
_step = min(step + len(rollout), cfg.train_steps)
|
||||||
|
|
||||||
# Update model
|
# Update model
|
||||||
|
@ -128,8 +130,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
else:
|
else:
|
||||||
train_metrics = policy.update(
|
train_metrics = policy.update(
|
||||||
online_buffer,
|
online_buffer,
|
||||||
step + i // cfg.utd,
|
step + i // cfg.policy.utd,
|
||||||
demo_buffer=offline_buffer if cfg.balanced_sampling else None,
|
demo_buffer=(
|
||||||
|
offline_buffer if cfg.policy.balanced_sampling else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Log training metrics
|
# Log training metrics
|
||||||
|
|
|
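
Putting the train.py hunks together: the loop runs offline updates from offline_buffer until cfg.policy.offline_steps, then switches to collecting rollouts into online_buffer, optionally mixing demonstrations back in via demo_buffer when balanced_sampling is on. A simplified schematic of the control flow (rollout collection and bookkeeping elided):

    is_offline = True
    for step in range(cfg.train_steps):
        if step >= cfg.policy.offline_steps:
            is_offline = False
        if is_offline:
            train_metrics = policy.update(offline_buffer, step)
        else:
            # collect a rollout, extend online_buffer, then update
            train_metrics = policy.update(
                online_buffer,
                step,
                demo_buffer=offline_buffer if cfg.policy.balanced_sampling else None,
            )
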
@@ -0,0 +1,17 @@
+import pytest
+
+from lerobot.common.datasets.factory import make_offline_buffer
+
+from .utils import init_config
+
+
+@pytest.mark.parametrize(
+    "env_name",
+    [
+        "simxarm",
+        "pusht",
+    ],
+)
+def test_factory(env_name):
+    cfg = init_config(overrides=[f"env={env_name}"])
+    offline_buffer = make_offline_buffer(cfg)
@@ -78,13 +78,13 @@ def test_pusht(from_pixels, pixels_only):
 
 
 @pytest.mark.parametrize(
-    "config_name",
+    "env_name",
     [
-        "default",
+        "simxarm",
         "pusht",
     ],
 )
-def test_factory(config_name):
-    cfg = init_config(config_name)
+def test_factory(env_name):
+    cfg = init_config(overrides=[f"env={env_name}"])
     env = make_env(cfg)
     check_env_specs(env)
@@ -6,12 +6,12 @@ from .utils import init_config
 
 
 @pytest.mark.parametrize(
-    "config_name",
+    "env_name",
     [
-        "default",
+        "simxarm",
         "pusht",
     ],
 )
-def test_factory(config_name):
-    cfg = init_config(config_name)
+def test_factory(env_name):
+    cfg = init_config(overrides=[f"env={env_name}"])
     policy = make_policy(cfg)
@@ -4,8 +4,8 @@ from hydra import compose, initialize
 CONFIG_PATH = "../lerobot/configs"
 
 
-def init_config(config_name):
+def init_config(config_name="default", overrides=None):
     hydra.core.global_hydra.GlobalHydra.instance().clear()
     initialize(config_path=CONFIG_PATH)
-    cfg = compose(config_name=config_name)
+    cfg = compose(config_name=config_name, overrides=overrides)
     return cfg
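
init_config clears GlobalHydra before re-initializing because Hydra keeps global state; without the clear, a second initialize() in the same process (for example across parametrized tests) would fail. Typical use, as in the tests above:

    cfg = init_config()                         # composes default.yaml (env=simxarm by default)
    cfg = init_config(overrides=["env=pusht"])  # safe to call again thanks to the clear()
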