Refactor env queue; diffusion training works (still not converging)

Remi Cadene 2024-03-04 10:59:43 +00:00
parent fddd9f0311
commit cfc304e870
11 changed files with 96 additions and 111 deletions

View File

@@ -69,7 +69,7 @@ def make_offline_buffer(cfg, sampler=None):
            sampler=sampler,
            batch_size=batch_size,
            pin_memory=pin_memory,
-           prefetch=prefetch,
+           prefetch=prefetch if isinstance(prefetch, int) else None,
        )
    elif cfg.env.name == "pusht":
        offline_buffer = PushtExperienceReplay(
@@ -79,7 +79,7 @@ def make_offline_buffer(cfg, sampler=None):
            sampler=sampler,
            batch_size=batch_size,
            pin_memory=pin_memory,
-           prefetch=prefetch,
+           prefetch=prefetch if isinstance(prefetch, int) else None,
        )
    else:
        raise ValueError(cfg.env.name)
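The new guard avoids handing a non-integer config value to the replay buffer: TorchRL replay buffers expect `prefetch` to be an int (number of batches prefetched by background threads) or None. A minimal sketch of the same coercion; the helper name and test values are illustrative, not from the repo:

# Minimal sketch of the guard above (helper name is illustrative).
def coerce_prefetch(prefetch):
    # Replay buffers expect an int (batches to prefetch) or None (disable prefetching).
    return prefetch if isinstance(prefetch, int) else None

print(coerce_prefetch(3))     # 3 -> forwarded unchanged
print(coerce_prefetch(None))  # None -> prefetching disabled
print(coerce_prefetch("no"))  # other config values also fall back to None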

View File

@@ -143,13 +143,24 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
                in_keys=[
                    # ("observation", "image"),
                    ("observation", "state"),
+                   # TODO(rcadene): for tdmpc, we might want image and state
                    # ("next", "observation", "image"),
-                   ("next", "observation", "state"),
+                   # ("next", "observation", "state"),
                    ("action"),
                ],
                mode="min_max",
            )
+           # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, min_max_spec
+           transform.stats["observation", "state", "min"] = torch.tensor(
+               [13.456424, 32.938293], dtype=torch.float32
+           )
+           transform.stats["observation", "state", "max"] = torch.tensor(
+               [496.14618, 510.9579], dtype=torch.float32
+           )
+           transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
+           transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)

        if writer is None:
            writer = ImmutableDatasetWriter()
        if collate_fn is None:
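The hard-coded tensors above pin the dataset statistics consumed by the `min_max` normalization mode. A rough standalone sketch of min/max rescaling with these statistics, assuming the transform maps values into [-1, 1] (the actual transform elsewhere in the repo may target a different range):

import torch

# Hard-coded PushT statistics from the diff above.
state_min = torch.tensor([13.456424, 32.938293])
state_max = torch.tensor([496.14618, 510.9579])

def min_max_normalize(x, x_min, x_max):
    # Rescale to [0, 1], then shift to [-1, 1] (assumed convention).
    x01 = (x - x_min) / (x_max - x_min)
    return x01 * 2.0 - 1.0

state = torch.tensor([254.8, 271.9])  # illustrative raw agent position in pixels
print(min_max_normalize(state, state_min, state_max))  # ~[0.0, 0.0] for a mid-workspace state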

View File

@@ -7,6 +7,8 @@ def make_env(cfg, transform=None):
        "from_pixels": cfg.env.from_pixels,
        "pixels_only": cfg.env.pixels_only,
        "image_size": cfg.env.image_size,
+       # TODO(rcadene): do we want a specific eval_env_seed?
+       "seed": cfg.seed,
    }

    if cfg.env.name == "simxarm":
@@ -17,6 +19,8 @@ def make_env(cfg, transform=None):
    elif cfg.env.name == "pusht":
        from lerobot.common.envs.pusht import PushtEnv

+       # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
+
        clsfunc = PushtEnv
    else:
        raise ValueError(cfg.env.name)

View File

@@ -101,14 +101,18 @@ class PushtEnv(EnvBase):
        obs = self._format_raw_obs(raw_obs)

        if self.num_prev_obs > 0:
-           # remove all previous observations
+           stacked_obs = {}
            if "image" in obs:
-               self._prev_obs_image_queue.clear()
+               self._prev_obs_image_queue = deque(
+                   [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+               )
+               stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
            if "state" in obs:
-               self._prev_obs_state_queue.clear()
-
-           # copy the current observation n times
-           obs = self._stack_prev_obs(obs)
+               self._prev_obs_state_queue = deque(
+                   [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+               )
+               stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+           obs = stacked_obs

        td = TensorDict(
            {
@@ -121,40 +125,6 @@ class PushtEnv(EnvBase):
            raise NotImplementedError()
        return td

-   def _stack_prev_obs(self, obs):
-       """When the queue is empty, copy the current observation n times."""
-       assert self.num_prev_obs > 0
-
-       def stack_update_queue(prev_obs_queue, obs, num_prev_obs):
-           # get n most recent observations
-           prev_obs = list(prev_obs_queue)[-num_prev_obs:]
-
-           # if not enough observations, copy the oldest observation until we obtain n observations
-           if len(prev_obs) == 0:
-               prev_obs = [obs] * num_prev_obs  # queue is empty when env reset
-           elif len(prev_obs) < num_prev_obs:
-               prev_obs = [prev_obs[0] for _ in range(num_prev_obs - len(prev_obs))] + prev_obs
-
-           # stack n most recent observations with the current observation
-           stacked_obs = torch.stack(prev_obs + [obs], dim=0)
-
-           # add current observation to the queue
-           # automatically remove oldest observation when queue is full
-           prev_obs_queue.appendleft(obs)
-
-           return stacked_obs
-
-       stacked_obs = {}
-       if "image" in obs:
-           stacked_obs["image"] = stack_update_queue(
-               self._prev_obs_image_queue, obs["image"], self.num_prev_obs
-           )
-       if "state" in obs:
-           stacked_obs["state"] = stack_update_queue(
-               self._prev_obs_state_queue, obs["state"], self.num_prev_obs
-           )
-       return stacked_obs
-
    def _step(self, tensordict: TensorDict):
        td = tensordict
        action = td["action"].numpy()
@@ -176,7 +146,14 @@ class PushtEnv(EnvBase):
        obs = self._format_raw_obs(raw_obs)

        if self.num_prev_obs > 0:
-           obs = self._stack_prev_obs(obs)
+           stacked_obs = {}
+           if "image" in obs:
+               self._prev_obs_image_queue.append(obs["image"])
+               stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
+           if "state" in obs:
+               self._prev_obs_state_queue.append(obs["state"])
+               stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+           obs = stacked_obs

        td = TensorDict(
            {
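The refactor drops the `_stack_prev_obs` helper in favor of plain `collections.deque` buffers: `_reset` rebuilds each queue pre-filled with `num_prev_obs + 1` copies of the first observation, and `_step` appends the newest observation (`maxlen` evicts the oldest) before stacking. A self-contained sketch of the same pattern with made-up tensor shapes:

from collections import deque

import torch

num_prev_obs = 1  # stack the current obs with 1 previous obs (2 frames total)

# On reset: fill the queue with copies of the first observation so the stack
# already has num_prev_obs + 1 entries at t=0.
first_obs = torch.zeros(2)  # e.g. a 2-D agent position (illustrative shape)
queue = deque([first_obs] * (num_prev_obs + 1), maxlen=num_prev_obs + 1)
print(torch.stack(list(queue)).shape)  # torch.Size([2, 2])

# On step: append the newest observation; maxlen evicts the oldest one.
for t in range(3):
    new_obs = torch.full((2,), float(t + 1))
    queue.append(new_obs)
    stacked = torch.stack(list(queue))  # shape (num_prev_obs + 1, 2), oldest first
print(stacked)
# tensor([[2., 2.],
#         [3., 3.]])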

View File

@@ -1,51 +1,11 @@
-import contextlib
+import logging
 import os
 from pathlib import Path

-import numpy as np
 from omegaconf import OmegaConf
 from termcolor import colored

-def make_dir(dir_path):
-    """Create directory if it does not already exist."""
-    with contextlib.suppress(OSError):
-        dir_path.mkdir(parents=True, exist_ok=True)
-    return dir_path
-
-def print_run(cfg, reward=None):
-    """Pretty-printing of run information. Call at start of training."""
-    prefix, color, attrs = " ", "green", ["bold"]
-
-    def limstr(s, maxlen=32):
-        return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s
-
-    def pprint(k, v):
-        print(
-            prefix + colored(f'{k.capitalize() + ":":<16}', color, attrs=attrs),
-            limstr(v),
-        )
-
-    kvs = [
-        ("task", cfg.env.task),
-        ("offline_steps", f"{cfg.offline_steps}"),
-        ("online_steps", f"{cfg.online_steps}"),
-        ("action_repeat", f"{cfg.env.action_repeat}"),
-        # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
-        # ('actions', cfg.action_dim),
-        # ('experiment', cfg.exp_name),
-    ]
-
-    if reward is not None:
-        kvs.append(("episode reward", colored(str(int(reward)), "white", attrs=["bold"])))
-
-    w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21
-    div = "-" * w
-    print(div)
-    for k, v in kvs:
-        pprint(k, v)
-    print(div)
-
 def cfg_to_group(cfg, return_list=False):
     """Return a wandb-safe group name for logging. Optionally returns group name as list."""
     # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
@@ -71,13 +31,12 @@ class Logger:
        self._seed = cfg.seed
        self._cfg = cfg
        self._eval = []
-       print_run(cfg)
        project = cfg.get("wandb", {}).get("project")
        entity = cfg.get("wandb", {}).get("entity")
        enable_wandb = cfg.get("wandb", {}).get("enable", False)
        run_offline = not enable_wandb or not project or not entity
        if run_offline:
-           print(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
+           logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
            self._wandb = None
        else:
            os.environ["WANDB_SILENT"] = "true"
@@ -134,7 +93,6 @@ class Logger:
                self.save_buffer(buffer, identifier="buffer")
            if self._wandb:
                self._wandb.finish()
-           print_run(self._cfg, self._eval[-1][-1])

    def log_dict(self, d, step, mode="train"):
        assert mode in {"train", "eval"}

View File

@@ -4,10 +4,8 @@ import time
 import hydra
 import torch
 import torch.nn as nn
-from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from diffusion_policy.model.common.lr_scheduler import get_scheduler
-from diffusion_policy.model.vision.model_getter import get_resnet

 from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from .multi_image_obs_encoder import MultiImageObsEncoder
@@ -39,8 +37,8 @@ class DiffusionPolicy(nn.Module):
        super().__init__()
        self.cfg = cfg

-       noise_scheduler = DDPMScheduler(**cfg_noise_scheduler)
-       rgb_model = get_resnet(**cfg_rgb_model)
+       noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
+       rgb_model = hydra.utils.instantiate(cfg_rgb_model)
        obs_encoder = MultiImageObsEncoder(
            rgb_model=rgb_model,
            **cfg_obs_encoder,
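Switching from direct constructor calls to Hydra instantiation lets the YAML config name the class through a `_target_` key (see the uncommented `_target_` entries in the diffusion config further down). A minimal sketch of the mechanism, using the same scheduler target as the config with illustrative parameter values:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Equivalent of the `noise_scheduler` block in the YAML config: `_target_`
# names the class to build and the remaining keys become constructor kwargs.
cfg_noise_scheduler = OmegaConf.create(
    {
        "_target_": "diffusers.schedulers.scheduling_ddpm.DDPMScheduler",
        "num_train_timesteps": 100,
        "beta_start": 0.0001,
        "beta_end": 0.02,
    }
)
noise_scheduler = instantiate(cfg_noise_scheduler)
print(type(noise_scheduler).__name__)  # DDPMScheduler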
@@ -127,16 +125,36 @@
            # (t h) ... -> t h ...
            batch = batch.reshape(num_slices, horizon)  # .transpose(1, 0).contiguous()

+           # |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14|  timestamps: 16
+           # |o|o|                                     observations: 2
+           # | |a|a|a|a|a|a|a|a|                       actions executed: 8
+           # |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p|    actions predicted: 16
+           # note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model
+
+           image = batch["observation", "image"]
+           state = batch["observation", "state"]
+           action = batch["action"]
+           assert image.shape[1] == horizon
+           assert state.shape[1] == horizon
+           assert action.shape[1] == horizon
+           if not (horizon == 16 and self.cfg.n_obs_steps == 2):
+               raise NotImplementedError()
+
+           # keep first 2 observations of the slice corresponding to t=[-1,0]
+           image = image[:, : self.cfg.n_obs_steps]
+           state = state[:, : self.cfg.n_obs_steps]
+
            out = {
                "obs": {
-                   "image": batch["observation", "image"].to(self.device, non_blocking=True),
-                   "agent_pos": batch["observation", "state"].to(self.device, non_blocking=True),
+                   "image": image.to(self.device, non_blocking=True),
+                   "agent_pos": state.to(self.device, non_blocking=True),
                },
-               "action": batch["action"].to(self.device, non_blocking=True),
+               "action": action.to(self.device, non_blocking=True),
            }
            return out

-       batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
+       batch = replay_buffer.sample(batch_size)
        batch = process_batch(batch, self.cfg.horizon, num_slices)

        data_s = time.time() - start_time
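As the diagram notes, each sampled slice covers horizon=16 timesteps; only the first n_obs_steps=2 frames condition the model, while all 16 predicted actions are supervised. A shape-level sketch of that slicing, with made-up batch size and image/state/action dimensions:

import torch

batch_size, horizon, n_obs_steps = 4, 16, 2

# Dummy tensors standing in for one sampled slice of the replay buffer
# (channel, resolution and state/action dims are illustrative).
image = torch.zeros(batch_size, horizon, 3, 96, 96)
state = torch.zeros(batch_size, horizon, 2)
action = torch.zeros(batch_size, horizon, 2)

# Keep only the first n_obs_steps frames (t = -1 and t = 0) as conditioning;
# the action tensor keeps the full horizon as the prediction target.
obs = {
    "image": image[:, :n_obs_steps],      # (4, 2, 3, 96, 96)
    "agent_pos": state[:, :n_obs_steps],  # (4, 2, 2)
}
target = action                           # (4, 16, 2)
print(obs["image"].shape, obs["agent_pos"].shape, target.shape)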

View File

@@ -1,7 +1,7 @@
 defaults:
   - _self_
-  - env: simxarm
-  - policy: tdmpc
+  - env: pusht
+  - policy: diffusion

 hydra:
   run:
@@ -22,6 +22,7 @@ save_buffer: false
 train_steps: ???
 fps: ???

+n_action_steps: ???

 env: ???
 policy: ???

View File

@@ -21,7 +21,7 @@ past_action_visible: False
 keypoint_visible_rate: 1.0
 obs_as_global_cond: True

-eval_episodes: 50
+eval_episodes: 1
 eval_freq: 10000
 save_freq: 100000
 log_freq: 250
@@ -40,8 +40,8 @@ policy:
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null
-  diffusion_step_embed_dim: 128
-  down_dims: [512, 1024, 2048]
+  diffusion_step_embed_dim: 256 # before 128
+  down_dims: [256, 512, 1024] # before [512, 1024, 2048]
   kernel_size: 5
   n_groups: 8
   cond_predict_scale: True
@@ -62,7 +62,7 @@ policy:
   grad_clip_norm: 0

 noise_scheduler:
-  # _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
+  _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
   num_train_timesteps: 100
   beta_start: 0.0001
   beta_end: 0.02
@@ -74,16 +74,16 @@ noise_scheduler:

 obs_encoder:
   # _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder
   shape_meta: ${shape_meta}
-  resize_shape: null
-  crop_shape: [76, 76]
+  # resize_shape: null
+  # crop_shape: [76, 76]
   # constant center crop
-  random_crop: True
+  # random_crop: True
   use_group_norm: True
   share_rgb_model: False
-  imagenet_norm: False # TODO(rcadene): was set to True
+  imagenet_norm: True

 rgb_model:
-  #_target_: diffusion_policy.model.vision.model_getter.get_resnet
+  _target_: diffusion_policy.model.vision.model_getter.get_resnet
   name: resnet18
   weights: null

View File

@@ -1,5 +1,7 @@
 # @package _global_

+n_action_steps: 1
+
 policy:
   name: tdmpc

View File

@@ -137,7 +137,7 @@ def eval(cfg: dict, out_dir=None):
        save_video=True,
        video_dir=Path(out_dir) / "eval",
        fps=cfg.env.fps,
-       max_steps=cfg.env.episode_length,
+       max_steps=cfg.env.episode_length // cfg.n_action_steps,
        num_episodes=cfg.eval_episodes,
    )
    print(metrics)
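Dividing the episode length by n_action_steps keeps the total number of simulated environment steps roughly constant now that each policy call executes a chunk of actions rather than a single one. A small worked example with illustrative config values:

episode_length = 300  # illustrative value of cfg.env.episode_length
n_action_steps = 8    # illustrative: actions executed per policy call
max_steps = episode_length // n_action_steps
print(max_steps)  # 37 policy calls, covering at most 296 of the 300 low-level steps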

View File

@@ -119,7 +119,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    set_seed(cfg.seed)
-   logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")

    logging.info("make_offline_buffer")
    offline_buffer = make_offline_buffer(cfg)
@@ -149,6 +148,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
    logging.info("make_policy")
    policy = make_policy(cfg)

+   num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+   num_total_params = sum(p.numel() for p in policy.parameters())
+
    td_policy = TensorDictModule(
        policy,
        in_keys=["observation", "step_count"],
@@ -158,6 +160,16 @@ def train(cfg: dict, out_dir=None, job_name=None):
    # log metrics to terminal and wandb
    logger = Logger(out_dir, job_name, cfg)

+   logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
+   logging.info(f"{cfg.env.task=}")
+   logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
+   logging.info(f"{cfg.online_steps=}")
+   logging.info(f"{cfg.env.action_repeat=}")
+   logging.info(f"{offline_buffer.num_samples=} ({format_big_number(offline_buffer.num_samples)})")
+   logging.info(f"{offline_buffer.num_episodes=}")
+   logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
+   logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
+
    step = 0  # number of policy update
    is_offline = True
@@ -175,6 +187,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
            env,
            td_policy,
            num_episodes=cfg.eval_episodes,
+           max_steps=cfg.env.episode_length // cfg.n_action_steps,
            return_first_video=True,
        )
        log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
@@ -199,11 +212,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
            # TODO: add configurable number of rollout? (default=1)
            with torch.no_grad():
                rollout = env.rollout(
-                   max_steps=cfg.env.episode_length,
+                   max_steps=cfg.env.episode_length // cfg.n_action_steps,
                    policy=td_policy,
                    auto_cast_to_device=True,
                )
-           assert len(rollout) <= cfg.env.episode_length
+           assert len(rollout) <= cfg.env.episode_length // cfg.n_action_steps
            # set same episode index for all time steps contained in this rollout
            rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
            online_buffer.extend(rollout)
@@ -235,6 +248,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
                    env,
                    td_policy,
                    num_episodes=cfg.eval_episodes,
+                   max_steps=cfg.env.episode_length // cfg.n_action_steps,
                    return_first_video=True,
                )
                log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
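The new logging block relies on a format_big_number helper (imported elsewhere in the script) to make parameter and sample counts readable. A hypothetical stand-in implementation, assuming K/M/B-style suffixes; the repo's actual helper may format differently:

def format_big_number(num: float) -> str:
    # Hypothetical stand-in: render 25_650_000 as "25.65M", 45_000 as "45.00K", etc.
    suffixes = ["", "K", "M", "B", "T"]
    divisor = 1000.0
    for suffix in suffixes:
        if abs(num) < divisor:
            return f"{num:.2f}{suffix}"
        num /= divisor
    return f"{num:.2f}Q"

print(format_big_number(96))        # 96.00
print(format_big_number(25650000))  # 25.65M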