backup wip

Alexander Soare 2024-03-20 15:01:27 +00:00
parent 32e3f71dd1
commit d323993569
7 changed files with 71 additions and 81 deletions

View File

@@ -1,4 +1,3 @@
-import cv2
 import numpy as np
 from gym import spaces
@@ -34,14 +33,14 @@ class PushTImageEnv(PushTEnv):
         coord = (action / 512 * 96).astype(np.int32)
         marker_size = int(8 / 96 * self.render_size)
         thickness = int(1 / 96 * self.render_size)
-        cv2.drawMarker(
-            img,
-            coord,
-            color=(255, 0, 0),
-            markerType=cv2.MARKER_CROSS,
-            markerSize=marker_size,
-            thickness=thickness,
-        )
+        # cv2.drawMarker(
+        #     img,
+        #     coord,
+        #     color=(255, 0, 0),
+        #     markerType=cv2.MARKER_CROSS,
+        #     markerSize=marker_size,
+        #     thickness=thickness,
+        # )
         self.render_cache = img
         return obs

View File

@@ -15,11 +15,12 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
 class RgbEncoder(nn.Module):
     """Following `VisualCore` from Robomimic 0.2.0."""

-    def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
+    def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32):
         """
         input_shape: channel-first input shape (C, H, W)
         resnet_name: a timm model name.
         pretrained: whether to use timm pretrained weights.
+        relu: whether to use relu as a final step.
         num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
         """
         super().__init__()
@@ -30,9 +31,11 @@ class RgbEncoder(nn.Module):
         feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
         self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
         self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
+        self.relu = nn.ReLU() if relu else nn.Identity()

     def forward(self, x):
-        return self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))
+        # TODO(now): make nonlinearity optional
+        return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))


 class MultiImageObsEncoder(ModuleAttrMixin):
@@ -182,7 +185,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
                 feature = torch.moveaxis(feature, 0, 1)
                 # (B,N*D)
                 feature = feature.reshape(batch_size, -1)
-                # feature = torch.nn.functional.relu(feature) # TODO: make optional
                 features.append(feature)
         else:
             # run each rgb obs to independent models
@@ -195,7 +197,6 @@
                 assert img.shape[1:] == self.key_shape_map[key]
                 img = self.key_transform_map[key](img)
                 feature = self.key_model_map[key](img)
-                # feature = torch.nn.functional.relu(feature) # TODO: make optional
                 features.append(feature)
         # concatenate all features

View File

@@ -1,9 +1,11 @@
 import copy
+import logging
 import time

 import hydra
 import torch

+from lerobot.common.ema import update_ema_parameters
 from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
@@ -19,7 +21,6 @@ class DiffusionPolicy(AbstractPolicy):
         cfg_rgb_model,
         cfg_obs_encoder,
         cfg_optimizer,
-        cfg_ema,
         shape_meta: dict,
         horizon,
         n_action_steps,
@@ -42,7 +43,6 @@ class DiffusionPolicy(AbstractPolicy):
         if cfg_obs_encoder.crop_shape is not None:
             rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
         rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
-        rgb_model = hydra.utils.instantiate(cfg_rgb_model)
         obs_encoder = MultiImageObsEncoder(
             rgb_model=rgb_model,
             **cfg_obs_encoder,
@@ -70,12 +70,9 @@ class DiffusionPolicy(AbstractPolicy):
         if torch.cuda.is_available() and cfg_device == "cuda":
             self.diffusion.cuda()

-        self.ema = None
-        if self.cfg.use_ema:
-            self.ema = hydra.utils.instantiate(
-                cfg_ema,
-                model=copy.deepcopy(self.diffusion),
-            )
+        self.ema_diffusion = None
+        if self.cfg.ema.enable:
+            self.ema_diffusion = copy.deepcopy(self.diffusion)

         self.optimizer = hydra.utils.instantiate(
             cfg_optimizer,
@@ -98,6 +95,9 @@
     @torch.no_grad()
     def select_actions(self, observation, step_count):
+        """
+        Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
+        """
         # TODO(rcadene): remove unused step_count
         del step_count
@@ -105,7 +105,10 @@
             "image": observation["image"],
             "agent_pos": observation["state"],
         }
-        out = self.diffusion.predict_action(obs_dict)
+        if self.training:
+            out = self.diffusion.predict_action(obs_dict)
+        else:
+            out = self.ema_diffusion.predict_action(obs_dict)
         action = out["action"]
         return action
@@ -172,8 +175,8 @@
         self.optimizer.zero_grad()
         self.lr_scheduler.step()

-        if self.ema is not None:
-            self.ema.step(self.diffusion)
+        if self.cfg.ema.enable:
+            update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate)

         info = {
             "loss": loss.item(),
@@ -195,4 +198,10 @@
     def load(self, fp):
         d = torch.load(fp)
-        self.load_state_dict(d)
+        missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
+        if len(missing_keys) > 0:
+            assert all(k.startswith("ema_diffusion.") for k in missing_keys)
+            logging.warning(
+                "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
+            )
+        assert len(unexpected_keys) == 0
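The new import `from lerobot.common.ema import update_ema_parameters` refers to a helper whose body is not part of this diff. Given the call site `update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate)`, a minimal sketch of what such a helper could look like (hypothetical; the real module may also handle buffers or non-trainable parameters differently):

import torch
from torch import nn

@torch.no_grad()
def update_ema_parameters(ema_model: nn.Module, model: nn.Module, rate: float) -> None:
    # Exponential moving average: ema <- rate * ema + (1 - rate) * online weights.
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.mul_(rate).add_(param.detach(), alpha=1.0 - rate)

With `rate: 0.999` from the config change below, each training update moves the EMA weights 0.1% of the way toward the current weights; `select_actions` then uses those EMA weights whenever `self.training == False`.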

View File

@@ -16,7 +16,6 @@ def make_policy(cfg):
             cfg_rgb_model=cfg.rgb_model,
             cfg_obs_encoder=cfg.obs_encoder,
             cfg_optimizer=cfg.optimizer,
-            cfg_ema=cfg.ema,
             n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
             **cfg.policy,
         )
@@ -41,7 +40,7 @@ def make_policy(cfg):
         policy.load(cfg.policy.pretrained_model_path)

     # import torch
-    # loaded = torch.load('/home/alexander/Downloads/dp_ema.pth')
+    # loaded = torch.load('/home/alexander/Downloads/dp.pth')
     # aligned = {}
     # their_prefix = "obs_encoder.obs_nets.image.backbone"

View File

@@ -12,14 +12,14 @@ hydra:
 seed: 1337
 # batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
 # NOTE: only diffusion policy supports rollout_batch_size > 1
-rollout_batch_size: 1
+rollout_batch_size: 10
 device: cuda # cpu
 prefetch: 4
 eval_freq: ???
 save_freq: ???
 eval_episodes: ???
 save_video: false
-save_model: false
+save_model: true
 save_buffer: false
 train_steps: ???
 fps: ???
@@ -34,6 +34,6 @@ policy: ???
 wandb:
   enable: true
   # Set to true to disable saving an artifact despite save_model == True
-  disable_artifact: false
+  disable_artifact: true
   project: lerobot
   notes: ""

View File

@@ -21,12 +21,12 @@ past_action_visible: False
 keypoint_visible_rate: 1.0
 obs_as_global_cond: True

-eval_episodes: 1
-eval_freq: 10000
-save_freq: 100000
+eval_episodes: 50
+eval_freq: 5000
+save_freq: 5000
 log_freq: 250

-offline_steps: 1344000
+offline_steps: 50000
 online_steps: 0

 offline_prioritized_sampler: true
@@ -58,7 +58,9 @@ policy:
   balanced_sampling: false
   utd: 1
   offline_steps: ${offline_steps}
-  use_ema: true
+  ema:
+    enable: true
+    rate: 0.999
   lr_scheduler: cosine
   lr_warmup_steps: 500
   grad_clip_norm: 10
@@ -87,14 +89,7 @@ rgb_model:
   model_name: resnet18
   pretrained: false
   num_keypoints: 32
-
-ema:
-  _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
-  update_after_step: 0
-  inv_gamma: 1.0
-  power: 0.75
-  min_value: 0.0
-  max_value: 0.9999
+  relu: true

 optimizer:
   _target_: torch.optim.AdamW
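For context, the deleted `ema:` block parameterized a warm-up decay schedule in the diffusion_policy/diffusers-style `EMAModel` rather than a fixed rate. A rough reconstruction of that schedule from memory (treat as an approximation, not the exact removed code):

def ema_decay(step: int, update_after_step: int = 0, inv_gamma: float = 1.0, power: float = 0.75,
              min_value: float = 0.0, max_value: float = 0.9999) -> float:
    # Approximate sketch of the removed EMAModel schedule: the decay ramps up from 0
    # toward max_value as training progresses, instead of being a constant.
    step = max(0, step - update_after_step - 1)
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** (-power)
    return max(min_value, min(value, max_value))

The new `policy.ema.rate: 0.999` replaces this schedule with a constant decay, paired with the simpler `update_ema_parameters` call sketched earlier.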

View File

@@ -155,11 +155,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())

-    td_policy = TensorDictModule(
-        policy,
-        in_keys=["observation", "step_count"],
-        out_keys=["action"],
-    )
+    td_policy = TensorDictModule(policy, in_keys=["observation", "step_count"], out_keys=["action"])

     # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)
@@ -174,19 +170,9 @@
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

-    step = 0  # number of policy update (forward + backward + optim)
-
-    is_offline = True
-    for offline_step in range(cfg.offline_steps):
-        if offline_step == 0:
-            logging.info("Start offline training on a fixed dataset")
-        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
-        policy.train()
-        train_info = policy.update(offline_buffer, step)
-        if step % cfg.log_freq == 0:
-            log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
-        if step > 0 and step % cfg.eval_freq == 0:
+    # Note: this helper will be used in offline and online training loops.
+    def _maybe_eval_and_maybe_save(step):
+        if step % cfg.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             eval_info, first_video = eval_policy(
                 env,
@@ -202,11 +188,27 @@
                 logger.log_video(first_video, step, mode="eval")
             logging.info("Resume training")

-        if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-            logging.info(f"Checkpoint policy at step {step}")
+        if cfg.save_model and step % cfg.save_freq == 0:
+            logging.info(f"Checkpoint policy after step {step}")
             logger.save_model(policy, identifier=step)
             logging.info("Resume training")

+    step = 0  # number of policy update (forward + backward + optim)
+
+    is_offline = True
+    for offline_step in range(cfg.offline_steps):
+        if offline_step == 0:
+            logging.info("Start offline training on a fixed dataset")
+        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
+        policy.train()
+        train_info = policy.update(offline_buffer, step)
+        if step % cfg.log_freq == 0:
+            log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
+
+        # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
+        # step + 1.
+        _maybe_eval_and_maybe_save(step + 1)
+
         step += 1

     demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None
@@ -248,24 +250,9 @@
                 train_info.update(rollout_info)
                 log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)

-            if step > 0 and step % cfg.eval_freq == 0:
-                logging.info(f"Eval policy at step {step}")
-                eval_info, first_video = eval_policy(
-                    env,
-                    td_policy,
-                    num_episodes=cfg.eval_episodes,
-                    max_steps=cfg.env.episode_length // cfg.n_action_steps,
-                    return_first_video=True,
-                )
-                log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
-                if cfg.wandb.enable:
-                    logger.log_video(first_video, step, mode="eval")
-                logging.info("Resume training")
-
-            if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-                logging.info(f"Checkpoint policy at step {step}")
-                logger.save_model(policy, identifier=step)
-                logging.info("Resume training")
+            # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
+            # in step + 1.
+            _maybe_eval_and_maybe_save(step + 1)

             step += 1
         online_step += 1
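A small illustration (not from the diff) of the `step + 1` convention now used in both loops: with the new `eval_freq: 5000`, the first eval/checkpoint fires right after the 5000th parameter update instead of at step 0.

eval_freq = 5000  # matches the new diffusion.yaml value
for step in range(15_000):           # loop index of the update that just finished
    completed_updates = step + 1     # what _maybe_eval_and_maybe_save receives
    if completed_updates % eval_freq == 0:
        print(f"eval + checkpoint after {completed_updates} updates")  # 5000, 10000, 15000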