Refactor env queue; training diffusion works (still not converging)
parent fddd9f0311
commit cfc304e870
@@ -69,7 +69,7 @@ def make_offline_buffer(cfg, sampler=None):
             sampler=sampler,
             batch_size=batch_size,
             pin_memory=pin_memory,
-            prefetch=prefetch,
+            prefetch=prefetch if isinstance(prefetch, int) else None,
         )
     elif cfg.env.name == "pusht":
         offline_buffer = PushtExperienceReplay(
@@ -79,7 +79,7 @@ def make_offline_buffer(cfg, sampler=None):
             sampler=sampler,
             batch_size=batch_size,
             pin_memory=pin_memory,
-            prefetch=prefetch,
+            prefetch=prefetch if isinstance(prefetch, int) else None,
         )
     else:
         raise ValueError(cfg.env.name)
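Both hunks apply the same guard: prefetch is forwarded to the replay buffer only when it is an integer, and anything else coming out of the config collapses to None. A minimal sketch of the check in isolation (normalize_prefetch is an illustrative name, not a function in the codebase):

def normalize_prefetch(cfg_prefetch):
    """Pass an int through unchanged; map anything else (e.g. None or a string) to None."""
    return cfg_prefetch if isinstance(cfg_prefetch, int) else None

assert normalize_prefetch(4) == 4
assert normalize_prefetch(None) is None
assert normalize_prefetch("none") is None
# Caveat: bool is a subclass of int, so normalize_prefetch(False) returns False, not None.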
@@ -143,13 +143,24 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
                 in_keys=[
                     # ("observation", "image"),
                     ("observation", "state"),
+                    # TODO(rcadene): for tdmpc, we might want image and state
                     # ("next", "observation", "image"),
-                    ("next", "observation", "state"),
+                    # ("next", "observation", "state"),
                     ("action"),
                 ],
                 mode="min_max",
             )

+            # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, min_max_spec
+            transform.stats["observation", "state", "min"] = torch.tensor(
+                [13.456424, 32.938293], dtype=torch.float32
+            )
+            transform.stats["observation", "state", "max"] = torch.tensor(
+                [496.14618, 510.9579], dtype=torch.float32
+            )
+            transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
+            transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
+
         if writer is None:
             writer = ImmutableDatasetWriter()
         if collate_fn is None:
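The hard-coded tensors above are PushT dataset statistics used for min_max normalization of state and action. As a rough sketch of what that mode does with these numbers, assuming the transform rescales into [-1, 1] as in the upstream Diffusion Policy code (the helper below is illustrative, not the actual transform):

import torch

# Hard-coded PushT dataset statistics from the hunk above.
stats = {
    ("observation", "state", "min"): torch.tensor([13.456424, 32.938293]),
    ("observation", "state", "max"): torch.tensor([496.14618, 510.9579]),
    ("action", "min"): torch.tensor([12.0, 25.0]),
    ("action", "max"): torch.tensor([511.0, 511.0]),
}

def min_max_normalize(x, minimum, maximum):
    """Map x from [min, max] to [-1, 1]."""
    return (x - minimum) / (maximum - minimum) * 2.0 - 1.0

state = torch.tensor([254.8, 271.9])  # roughly mid-range agent position
print(min_max_normalize(state, stats[("observation", "state", "min")],
                        stats[("observation", "state", "max")]))  # values near 0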
@@ -7,6 +7,8 @@ def make_env(cfg, transform=None):
         "from_pixels": cfg.env.from_pixels,
         "pixels_only": cfg.env.pixels_only,
         "image_size": cfg.env.image_size,
+        # TODO(rcadene): do we want a specific eval_env_seed?
+        "seed": cfg.seed,
     }

     if cfg.env.name == "simxarm":
@@ -17,6 +19,8 @@ def make_env(cfg, transform=None):
     elif cfg.env.name == "pusht":
         from lerobot.common.envs.pusht import PushtEnv

+        # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
+
         clsfunc = PushtEnv
     else:
         raise ValueError(cfg.env.name)
@@ -101,14 +101,18 @@ class PushtEnv(EnvBase):
         obs = self._format_raw_obs(raw_obs)

         if self.num_prev_obs > 0:
-            # remove all previous observations
+            stacked_obs = {}
             if "image" in obs:
-                self._prev_obs_image_queue.clear()
+                self._prev_obs_image_queue = deque(
+                    [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+                )
+                stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
             if "state" in obs:
-                self._prev_obs_state_queue.clear()
-
-            # copy the current observation n times
-            obs = self._stack_prev_obs(obs)
+                self._prev_obs_state_queue = deque(
+                    [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+                )
+                stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+            obs = stacked_obs

         td = TensorDict(
             {
@@ -121,40 +125,6 @@ class PushtEnv(EnvBase):
             raise NotImplementedError()
         return td

-    def _stack_prev_obs(self, obs):
-        """When the queue is empty, copy the current observation n times."""
-        assert self.num_prev_obs > 0
-
-        def stack_update_queue(prev_obs_queue, obs, num_prev_obs):
-            # get n most recent observations
-            prev_obs = list(prev_obs_queue)[-num_prev_obs:]
-
-            # if not enough observations, copy the oldest observation until we obtain n observations
-            if len(prev_obs) == 0:
-                prev_obs = [obs] * num_prev_obs  # queue is empty when env reset
-            elif len(prev_obs) < num_prev_obs:
-                prev_obs = [prev_obs[0] for _ in range(num_prev_obs - len(prev_obs))] + prev_obs
-
-            # stack n most recent observations with the current observation
-            stacked_obs = torch.stack(prev_obs + [obs], dim=0)
-
-            # add current observation to the queue
-            # automatically remove oldest observation when queue is full
-            prev_obs_queue.appendleft(obs)
-
-            return stacked_obs
-
-        stacked_obs = {}
-        if "image" in obs:
-            stacked_obs["image"] = stack_update_queue(
-                self._prev_obs_image_queue, obs["image"], self.num_prev_obs
-            )
-        if "state" in obs:
-            stacked_obs["state"] = stack_update_queue(
-                self._prev_obs_state_queue, obs["state"], self.num_prev_obs
-            )
-        return stacked_obs
-
     def _step(self, tensordict: TensorDict):
         td = tensordict
         action = td["action"].numpy()
@@ -176,7 +146,14 @@ class PushtEnv(EnvBase):
         obs = self._format_raw_obs(raw_obs)

         if self.num_prev_obs > 0:
-            obs = self._stack_prev_obs(obs)
+            stacked_obs = {}
+            if "image" in obs:
+                self._prev_obs_image_queue.append(obs["image"])
+                stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
+            if "state" in obs:
+                self._prev_obs_state_queue.append(obs["state"])
+                stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+            obs = stacked_obs

         td = TensorDict(
             {
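Both _reset and _step now manage observation history with a bounded deque: _reset refills it with num_prev_obs + 1 copies of the first observation, _step appends the newest one (the maxlen drops the oldest automatically), and the stacked tensor is rebuilt from the queue each time. A self-contained sketch of the same pattern (ObsQueue is illustrative, not the PushtEnv API):

from collections import deque

import torch

class ObsQueue:
    """Minimal sketch of the queue pattern above."""

    def __init__(self, num_prev_obs):
        self.maxlen = num_prev_obs + 1
        self.queue = deque(maxlen=self.maxlen)

    def reset(self, obs):
        # No history yet, so fill the queue with copies of the first observation.
        self.queue = deque([obs] * self.maxlen, maxlen=self.maxlen)
        return torch.stack(list(self.queue))

    def step(self, obs):
        # maxlen makes the deque drop the oldest observation automatically.
        self.queue.append(obs)
        return torch.stack(list(self.queue))

q = ObsQueue(num_prev_obs=1)
print(q.reset(torch.zeros(2)).shape)  # torch.Size([2, 2]): (num_prev_obs + 1, state_dim)
print(q.step(torch.ones(2)).shape)    # torch.Size([2, 2]), now [prev_obs, current_obs]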
@@ -1,51 +1,11 @@
-import contextlib
+import logging
 import os
 from pathlib import Path

-import numpy as np
 from omegaconf import OmegaConf
 from termcolor import colored


-def make_dir(dir_path):
-    """Create directory if it does not already exist."""
-    with contextlib.suppress(OSError):
-        dir_path.mkdir(parents=True, exist_ok=True)
-    return dir_path
-
-
-def print_run(cfg, reward=None):
-    """Pretty-printing of run information. Call at start of training."""
-    prefix, color, attrs = "  ", "green", ["bold"]
-
-    def limstr(s, maxlen=32):
-        return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s
-
-    def pprint(k, v):
-        print(
-            prefix + colored(f'{k.capitalize() + ":":<16}', color, attrs=attrs),
-            limstr(v),
-        )
-
-    kvs = [
-        ("task", cfg.env.task),
-        ("offline_steps", f"{cfg.offline_steps}"),
-        ("online_steps", f"{cfg.online_steps}"),
-        ("action_repeat", f"{cfg.env.action_repeat}"),
-        # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
-        # ('actions', cfg.action_dim),
-        # ('experiment', cfg.exp_name),
-    ]
-    if reward is not None:
-        kvs.append(("episode reward", colored(str(int(reward)), "white", attrs=["bold"])))
-    w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21
-    div = "-" * w
-    print(div)
-    for k, v in kvs:
-        pprint(k, v)
-    print(div)
-
-
 def cfg_to_group(cfg, return_list=False):
     """Return a wandb-safe group name for logging. Optionally returns group name as list."""
     # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
@@ -71,13 +31,12 @@ class Logger:
         self._seed = cfg.seed
         self._cfg = cfg
         self._eval = []
-        print_run(cfg)
         project = cfg.get("wandb", {}).get("project")
         entity = cfg.get("wandb", {}).get("entity")
         enable_wandb = cfg.get("wandb", {}).get("enable", False)
         run_offline = not enable_wandb or not project or not entity
         if run_offline:
-            print(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
+            logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
            self._wandb = None
        else:
            os.environ["WANDB_SILENT"] = "true"
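With print_run removed, the logger now reports through the standard logging module while keeping termcolor for emphasis. A minimal sketch of the combination, assuming the root logger is configured elsewhere in the entry point (basicConfig here stands in for whatever the scripts actually set up):

import logging

from termcolor import colored

logging.basicConfig(level=logging.INFO)
# logging.info accepts the already-colored string; the ANSI codes pass through untouched.
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))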
@@ -134,7 +93,6 @@ class Logger:
             self.save_buffer(buffer, identifier="buffer")
         if self._wandb:
             self._wandb.finish()
-        print_run(self._cfg, self._eval[-1][-1])

     def log_dict(self, d, step, mode="train"):
         assert mode in {"train", "eval"}
@@ -4,10 +4,8 @@ import time
 import hydra
 import torch
 import torch.nn as nn
-from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

 from diffusion_policy.model.common.lr_scheduler import get_scheduler
-from diffusion_policy.model.vision.model_getter import get_resnet

 from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from .multi_image_obs_encoder import MultiImageObsEncoder
@@ -39,8 +37,8 @@ class DiffusionPolicy(nn.Module):
         super().__init__()
         self.cfg = cfg

-        noise_scheduler = DDPMScheduler(**cfg_noise_scheduler)
-        rgb_model = get_resnet(**cfg_rgb_model)
+        noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
+        rgb_model = hydra.utils.instantiate(cfg_rgb_model)
         obs_encoder = MultiImageObsEncoder(
             rgb_model=rgb_model,
             **cfg_obs_encoder,
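hydra.utils.instantiate builds an object from any config node carrying a _target_ key, passing the remaining keys as keyword arguments; that is why the _target_ lines are uncommented in the diffusion config further down. A minimal sketch of the mechanism with a stand-in target:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Any node with `_target_` can be instantiated; the other keys become constructor kwargs.
cfg = OmegaConf.create({"_target_": "collections.Counter", "red": 2, "blue": 1})
counter = instantiate(cfg)  # equivalent to collections.Counter(red=2, blue=1)
print(counter)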
@@ -127,16 +125,36 @@ class DiffusionPolicy(nn.Module):
             # (t h) ... -> t h ...
             batch = batch.reshape(num_slices, horizon)  # .transpose(1, 0).contiguous()

+            # |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16
+            # |o|o|                                   observations: 2
+            # | |a|a|a|a|a|a|a|a|                     actions executed: 8
+            # |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p|  actions predicted: 16
+            # note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model
+
+            image = batch["observation", "image"]
+            state = batch["observation", "state"]
+            action = batch["action"]
+            assert image.shape[1] == horizon
+            assert state.shape[1] == horizon
+            assert action.shape[1] == horizon
+
+            if not (horizon == 16 and self.cfg.n_obs_steps == 2):
+                raise NotImplementedError()
+
+            # keep first 2 observations of the slice corresponding to t=[-1,0]
+            image = image[:, : self.cfg.n_obs_steps]
+            state = state[:, : self.cfg.n_obs_steps]
+
             out = {
                 "obs": {
-                    "image": batch["observation", "image"].to(self.device, non_blocking=True),
-                    "agent_pos": batch["observation", "state"].to(self.device, non_blocking=True),
+                    "image": image.to(self.device, non_blocking=True),
+                    "agent_pos": state.to(self.device, non_blocking=True),
                 },
-                "action": batch["action"].to(self.device, non_blocking=True),
+                "action": action.to(self.device, non_blocking=True),
             }
             return out

-        batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
+        batch = replay_buffer.sample(batch_size)
         batch = process_batch(batch, self.cfg.horizon, num_slices)

         data_s = time.time() - start_time
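Per the diagram above, each training slice spans horizon = 16 timestamps, only the first n_obs_steps = 2 observations condition the model, and all 16 predicted actions are supervised (8 of them would be executed at rollout time). A shape-only sketch with illustrative sizes (num_slices = 4 and the 2-D state/action dims are just examples):

import torch

num_slices, horizon, n_obs_steps = 4, 16, 2
state = torch.randn(num_slices * horizon, 2)   # flattened (t h) layout, 2-D agent position
action = torch.randn(num_slices * horizon, 2)

state = state.reshape(num_slices, horizon, 2)   # (t h) ... -> t h ...
action = action.reshape(num_slices, horizon, 2)

obs_state = state[:, :n_obs_steps]              # keep only the t = [-1, 0] observations
print(obs_state.shape)                          # torch.Size([4, 2, 2])
print(action.shape)                             # torch.Size([4, 16, 2]): all 16 predicted actions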
@@ -1,7 +1,7 @@
 defaults:
   - _self_
-  - env: simxarm
-  - policy: tdmpc
+  - env: pusht
+  - policy: diffusion

 hydra:
   run:
@@ -22,6 +22,7 @@ save_buffer: false
 train_steps: ???
 fps: ???

+n_action_steps: ???
 env: ???

 policy: ???
@@ -21,7 +21,7 @@ past_action_visible: False
 keypoint_visible_rate: 1.0
 obs_as_global_cond: True

-eval_episodes: 50
+eval_episodes: 1
 eval_freq: 10000
 save_freq: 100000
 log_freq: 250
@@ -40,8 +40,8 @@ policy:
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null
-  diffusion_step_embed_dim: 128
-  down_dims: [512, 1024, 2048]
+  diffusion_step_embed_dim: 256 # before 128
+  down_dims: [256, 512, 1024] # before [512, 1024, 2048]
   kernel_size: 5
   n_groups: 8
   cond_predict_scale: True
@@ -62,7 +62,7 @@ policy:
   grad_clip_norm: 0

 noise_scheduler:
-  # _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
+  _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
   num_train_timesteps: 100
   beta_start: 0.0001
   beta_end: 0.02
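Once _target_ is active, hydra.utils.instantiate builds the scheduler directly from this node. A sketch of the equivalent direct construction with the same values (other DDPMScheduler arguments keep their diffusers defaults):

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=100, beta_start=0.0001, beta_end=0.02)
print(scheduler.config.num_train_timesteps)  # 100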
@@ -74,16 +74,16 @@ noise_scheduler:
 obs_encoder:
   # _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder
   shape_meta: ${shape_meta}
-  resize_shape: null
-  crop_shape: [76, 76]
+  # resize_shape: null
+  # crop_shape: [76, 76]
   # constant center crop
-  random_crop: True
+  # random_crop: True
   use_group_norm: True
   share_rgb_model: False
-  imagenet_norm: False # TODO(rcadene): was set to True
+  imagenet_norm: True

 rgb_model:
-  #_target_: diffusion_policy.model.vision.model_getter.get_resnet
+  _target_: diffusion_policy.model.vision.model_getter.get_resnet
   name: resnet18
   weights: null

@@ -1,5 +1,7 @@
 # @package _global_

+n_action_steps: 1
+
 policy:
   name: tdmpc

@@ -137,7 +137,7 @@ def eval(cfg: dict, out_dir=None):
         save_video=True,
         video_dir=Path(out_dir) / "eval",
         fps=cfg.env.fps,
-        max_steps=cfg.env.episode_length,
+        max_steps=cfg.env.episode_length // cfg.n_action_steps,
         num_episodes=cfg.eval_episodes,
     )
     print(metrics)
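Each policy call returns n_action_steps actions that are all executed in the environment, so an episode of episode_length env steps needs at most episode_length // n_action_steps policy steps; the eval and rollout caps below are scaled the same way. With illustrative numbers (the real values come from the env and policy configs):

episode_length = 300   # illustrative env steps per episode
n_action_steps = 8     # illustrative actions executed per policy call
max_steps = episode_length // n_action_steps
print(max_steps)       # 37 policy calls, covering 296 of the 300 env steps within the episode budget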
@@ -119,7 +119,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
     set_seed(cfg.seed)
-    logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")

     logging.info("make_offline_buffer")
     offline_buffer = make_offline_buffer(cfg)
@@ -149,6 +148,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info("make_policy")
     policy = make_policy(cfg)

+    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+    num_total_params = sum(p.numel() for p in policy.parameters())
+
     td_policy = TensorDictModule(
         policy,
         in_keys=["observation", "step_count"],
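The two sums added here count learnable versus total parameters. The same expressions can be sanity-checked on a stand-in module:

import torch.nn as nn

policy = nn.Linear(10, 5)  # stand-in module, just to check the formula
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
print(num_learnable_params, num_total_params)  # 55 55: 10*5 weights + 5 biases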
@@ -158,6 +160,16 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)

+    logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
+    logging.info(f"{cfg.env.task=}")
+    logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
+    logging.info(f"{cfg.online_steps=}")
+    logging.info(f"{cfg.env.action_repeat=}")
+    logging.info(f"{offline_buffer.num_samples=} ({format_big_number(offline_buffer.num_samples)})")
+    logging.info(f"{offline_buffer.num_episodes=}")
+    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
+    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
+
     step = 0  # number of policy update

     is_offline = True
@@ -175,6 +187,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
             env,
             td_policy,
             num_episodes=cfg.eval_episodes,
+            max_steps=cfg.env.episode_length // cfg.n_action_steps,
             return_first_video=True,
         )
         log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
@@ -199,11 +212,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
         # TODO: add configurable number of rollout? (default=1)
         with torch.no_grad():
             rollout = env.rollout(
-                max_steps=cfg.env.episode_length,
+                max_steps=cfg.env.episode_length // cfg.n_action_steps,
                 policy=td_policy,
                 auto_cast_to_device=True,
             )
-        assert len(rollout) <= cfg.env.episode_length
+        assert len(rollout) <= cfg.env.episode_length // cfg.n_action_steps
         # set same episode index for all time steps contained in this rollout
         rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
         online_buffer.extend(rollout)
@@ -235,6 +248,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
             env,
             td_policy,
             num_episodes=cfg.eval_episodes,
+            max_steps=cfg.env.episode_length // cfg.n_action_steps,
             return_first_video=True,
         )
         log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)