backup wip

parent 32e3f71dd1
commit d323993569
@@ -1,4 +1,3 @@
-import cv2
 import numpy as np
 from gym import spaces
 
@@ -34,14 +33,14 @@ class PushTImageEnv(PushTEnv):
         coord = (action / 512 * 96).astype(np.int32)
         marker_size = int(8 / 96 * self.render_size)
         thickness = int(1 / 96 * self.render_size)
-        cv2.drawMarker(
-            img,
-            coord,
-            color=(255, 0, 0),
-            markerType=cv2.MARKER_CROSS,
-            markerSize=marker_size,
-            thickness=thickness,
-        )
+        # cv2.drawMarker(
+        #     img,
+        #     coord,
+        #     color=(255, 0, 0),
+        #     markerType=cv2.MARKER_CROSS,
+        #     markerSize=marker_size,
+        #     thickness=thickness,
+        # )
         self.render_cache = img
 
         return obs

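For reference, the call that this hunk comments out overlays a red cross at the pixel location of the latest action. A minimal, self-contained sketch of the same OpenCV call; the dummy image and action values here are illustrative only:

```python
import cv2
import numpy as np

render_size = 96
img = np.zeros((render_size, render_size, 3), dtype=np.uint8)  # stand-in for the rendered frame

action = np.array([256.0, 256.0])              # an action in the 512x512 action space
coord = (action / 512 * 96).astype(np.int32)   # map to 96x96 image coordinates
marker_size = int(8 / 96 * render_size)
thickness = int(1 / 96 * render_size)

# Draw a cross marker at the action location (tuple() because OpenCV expects a point).
cv2.drawMarker(
    img,
    tuple(coord),
    color=(255, 0, 0),
    markerType=cv2.MARKER_CROSS,
    markerSize=marker_size,
    thickness=thickness,
)
```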
@@ -15,11 +15,12 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
 class RgbEncoder(nn.Module):
     """Following `VisualCore` from Robomimic 0.2.0."""
 
-    def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
+    def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32):
         """
         input_shape: channel-first input shape (C, H, W)
         resnet_name: a timm model name.
         pretrained: whether to use timm pretrained weights.
+        relu: whether to use a relu as the final step.
         num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
         """
         super().__init__()
@@ -30,9 +31,11 @@ class RgbEncoder(nn.Module):
         feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
         self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
         self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
+        self.relu = nn.ReLU() if relu else nn.Identity()
 
     def forward(self, x):
-        return self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))
+        # TODO(now): make nonlinearity optional
+        return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
 
 
 class MultiImageObsEncoder(ModuleAttrMixin):

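The `num_keypoints * 2` sizing of the linear layer comes from spatial-softmax pooling: each keypoint is the expected (x, y) position of a softmax distribution over the feature map. A minimal sketch of that pooling idea (not the Robomimic `SpatialSoftmax` used here, which first projects the channels down to `num_kp` with a 1x1 conv; this toy version treats each channel as a keypoint):

```python
import torch
import torch.nn.functional as F


def spatial_softmax_keypoints(feat: torch.Tensor) -> torch.Tensor:
    """Return (B, C * 2) expected (x, y) positions, one per channel ("keypoint")."""
    b, c, h, w = feat.shape
    # Softmax over spatial locations for each channel.
    attention = F.softmax(feat.reshape(b, c, h * w), dim=-1)
    # Normalized pixel coordinate grids in [-1, 1].
    ys, xs = torch.meshgrid(
        torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing="ij"
    )
    grid = torch.stack([xs.reshape(-1), ys.reshape(-1)], dim=-1)  # (H*W, 2)
    keypoints = attention @ grid  # (B, C, 2) expected coordinates
    return keypoints.flatten(start_dim=1)  # (B, C * 2)


# With 32 "keypoints" the pooled feature has 64 dims, matching nn.Linear(num_keypoints * 2, ...).
print(spatial_softmax_keypoints(torch.randn(2, 32, 3, 3)).shape)  # torch.Size([2, 64])
```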
@@ -182,7 +185,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
             feature = torch.moveaxis(feature, 0, 1)
             # (B,N*D)
             feature = feature.reshape(batch_size, -1)
-            # feature = torch.nn.functional.relu(feature) # TODO: make optional
             features.append(feature)
         else:
             # run each rgb obs to independent models
@@ -195,7 +197,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
                 assert img.shape[1:] == self.key_shape_map[key]
                 img = self.key_transform_map[key](img)
                 feature = self.key_model_map[key](img)
-                # feature = torch.nn.functional.relu(feature) # TODO: make optional
                 features.append(feature)
 
         # concatenate all features

@@ -1,9 +1,11 @@
 import copy
+import logging
 import time
 
 import hydra
 import torch
 
+from lerobot.common.ema import update_ema_parameters
 from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
@@ -19,7 +21,6 @@ class DiffusionPolicy(AbstractPolicy):
         cfg_rgb_model,
         cfg_obs_encoder,
         cfg_optimizer,
-        cfg_ema,
         shape_meta: dict,
         horizon,
         n_action_steps,
@@ -42,7 +43,6 @@ class DiffusionPolicy(AbstractPolicy):
         if cfg_obs_encoder.crop_shape is not None:
             rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
         rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
-        rgb_model = hydra.utils.instantiate(cfg_rgb_model)
         obs_encoder = MultiImageObsEncoder(
             rgb_model=rgb_model,
             **cfg_obs_encoder,
@@ -70,12 +70,9 @@ class DiffusionPolicy(AbstractPolicy):
         if torch.cuda.is_available() and cfg_device == "cuda":
             self.diffusion.cuda()
 
-        self.ema = None
-        if self.cfg.use_ema:
-            self.ema = hydra.utils.instantiate(
-                cfg_ema,
-                model=copy.deepcopy(self.diffusion),
-            )
+        self.ema_diffusion = None
+        if self.cfg.ema.enable:
+            self.ema_diffusion = copy.deepcopy(self.diffusion)
 
         self.optimizer = hydra.utils.instantiate(
             cfg_optimizer,
@@ -98,6 +95,9 @@ class DiffusionPolicy(AbstractPolicy):
 
     @torch.no_grad()
     def select_actions(self, observation, step_count):
+        """
+        Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
+        """
         # TODO(rcadene): remove unused step_count
         del step_count
 
@@ -105,7 +105,10 @@ class DiffusionPolicy(AbstractPolicy):
             "image": observation["image"],
             "agent_pos": observation["state"],
         }
-        out = self.diffusion.predict_action(obs_dict)
+        if self.training:
+            out = self.diffusion.predict_action(obs_dict)
+        else:
+            out = self.ema_diffusion.predict_action(obs_dict)
         action = out["action"]
         return action
 
@@ -172,8 +175,8 @@ class DiffusionPolicy(AbstractPolicy):
         self.optimizer.zero_grad()
         self.lr_scheduler.step()
 
-        if self.ema is not None:
-            self.ema.step(self.diffusion)
+        if self.cfg.ema.enable:
+            update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate)
 
         info = {
             "loss": loss.item(),
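This hunk swaps the Robomimic-style `EMAModel` stepping for a plain parameter-space exponential moving average with a fixed rate (0.999 in the new config). The body of `update_ema_parameters` is not part of this diff; a minimal sketch of what a helper with this call signature could look like (the actual `lerobot.common.ema` implementation may differ):

```python
import torch
from torch import nn


@torch.no_grad()
def update_ema_parameters(ema_module: nn.Module, module: nn.Module, rate: float):
    """In-place EMA update: ema_param <- rate * ema_param + (1 - rate) * param."""
    for ema_param, param in zip(ema_module.parameters(), module.parameters()):
        ema_param.mul_(rate).add_(param, alpha=1 - rate)
    # Buffers (e.g. BatchNorm running stats) are typically copied rather than averaged.
    for ema_buffer, buffer in zip(ema_module.buffers(), module.buffers()):
        ema_buffer.copy_(buffer)
```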
@@ -195,4 +198,10 @@ class DiffusionPolicy(AbstractPolicy):
 
     def load(self, fp):
         d = torch.load(fp)
-        self.load_state_dict(d)
+        missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
+        if len(missing_keys) > 0:
+            assert all(k.startswith("ema_diffusion.") for k in missing_keys)
+            logging.warning(
+                "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
+            )
+        assert len(unexpected_keys) == 0

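The new `load` leans on `load_state_dict(strict=False)` returning the keys it could not match, so checkpoints saved before the `ema_diffusion` submodule existed can still be loaded with a warning. A small self-contained illustration of that return value (toy modules, not the policy itself):

```python
import torch
from torch import nn


class Old(nn.Module):
    def __init__(self):
        super().__init__()
        self.diffusion = nn.Linear(2, 2)


class New(nn.Module):
    def __init__(self):
        super().__init__()
        self.diffusion = nn.Linear(2, 2)
        self.ema_diffusion = nn.Linear(2, 2)


state_dict = Old().state_dict()
missing, unexpected = New().load_state_dict(state_dict, strict=False)
print(missing)     # ['ema_diffusion.weight', 'ema_diffusion.bias']
print(unexpected)  # []
```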
@@ -16,7 +16,6 @@ def make_policy(cfg):
             cfg_rgb_model=cfg.rgb_model,
             cfg_obs_encoder=cfg.obs_encoder,
             cfg_optimizer=cfg.optimizer,
-            cfg_ema=cfg.ema,
             n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
             **cfg.policy,
         )
@@ -41,7 +40,7 @@ def make_policy(cfg):
         policy.load(cfg.policy.pretrained_model_path)
 
     # import torch
-    # loaded = torch.load('/home/alexander/Downloads/dp_ema.pth')
+    # loaded = torch.load('/home/alexander/Downloads/dp.pth')
     # aligned = {}
 
     # their_prefix = "obs_encoder.obs_nets.image.backbone"

@@ -12,14 +12,14 @@ hydra:
 seed: 1337
 # batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
 # NOTE: only diffusion policy supports rollout_batch_size > 1
-rollout_batch_size: 1
+rollout_batch_size: 10
 device: cuda # cpu
 prefetch: 4
 eval_freq: ???
 save_freq: ???
 eval_episodes: ???
 save_video: false
-save_model: false
+save_model: true
 save_buffer: false
 train_steps: ???
 fps: ???
@@ -34,6 +34,6 @@ policy: ???
 wandb:
   enable: true
   # Set to true to disable saving an artifact despite save_model == True
-  disable_artifact: false
+  disable_artifact: true
   project: lerobot
   notes: ""

@@ -21,12 +21,12 @@ past_action_visible: False
 keypoint_visible_rate: 1.0
 obs_as_global_cond: True
 
-eval_episodes: 1
-eval_freq: 10000
-save_freq: 100000
+eval_episodes: 50
+eval_freq: 5000
+save_freq: 5000
 log_freq: 250
 
-offline_steps: 1344000
+offline_steps: 50000
 online_steps: 0
 
 offline_prioritized_sampler: true
@@ -58,7 +58,9 @@ policy:
   balanced_sampling: false
   utd: 1
   offline_steps: ${offline_steps}
-  use_ema: true
+  ema:
+    enable: true
+    rate: 0.999
   lr_scheduler: cosine
   lr_warmup_steps: 500
   grad_clip_norm: 10
@@ -87,14 +89,7 @@ rgb_model:
   model_name: resnet18
   pretrained: false
   num_keypoints: 32
-
-ema:
-  _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
-  update_after_step: 0
-  inv_gamma: 1.0
-  power: 0.75
-  min_value: 0.0
-  max_value: 0.9999
+  relu: true
 
 optimizer:
   _target_: torch.optim.AdamW

@@ -155,11 +155,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
 
-    td_policy = TensorDictModule(
-        policy,
-        in_keys=["observation", "step_count"],
-        out_keys=["action"],
-    )
+    td_policy = TensorDictModule(policy, in_keys=["observation", "step_count"], out_keys=["action"])
 
     # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)

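For context, `TensorDictModule` simply routes entries of a `TensorDict` into the wrapped module and writes its output back under `out_keys`; collapsing the constructor call onto one line does not change behavior. A minimal usage sketch with a toy linear module (assuming the tensordict.nn API):

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# Wrap a plain module so it reads "observation" from a TensorDict and writes "action" back.
td_module = TensorDictModule(torch.nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"])

td = TensorDict({"observation": torch.randn(3, 4)}, batch_size=[3])
td_module(td)
print(td["action"].shape)  # torch.Size([3, 2])
```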
@@ -174,19 +170,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
 
-    step = 0  # number of policy update (forward + backward + optim)
-
-    is_offline = True
-    for offline_step in range(cfg.offline_steps):
-        if offline_step == 0:
-            logging.info("Start offline training on a fixed dataset")
-        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
-        policy.train()
-        train_info = policy.update(offline_buffer, step)
-        if step % cfg.log_freq == 0:
-            log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
-
-        if step > 0 and step % cfg.eval_freq == 0:
+    # Note: this helper will be used in offline and online training loops.
+    def _maybe_eval_and_maybe_save(step):
+        if step % cfg.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             eval_info, first_video = eval_policy(
                 env,
@@ -202,11 +188,27 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 logger.log_video(first_video, step, mode="eval")
             logging.info("Resume training")
 
-        if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-            logging.info(f"Checkpoint policy at step {step}")
+        if cfg.save_model and step % cfg.save_freq == 0:
+            logging.info(f"Checkpoint policy after step {step}")
             logger.save_model(policy, identifier=step)
             logging.info("Resume training")
 
+    step = 0  # number of policy update (forward + backward + optim)
+
+    is_offline = True
+    for offline_step in range(cfg.offline_steps):
+        if offline_step == 0:
+            logging.info("Start offline training on a fixed dataset")
+        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
+        policy.train()
+        train_info = policy.update(offline_buffer, step)
+        if step % cfg.log_freq == 0:
+            log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
+
+        # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
+        # step + 1.
+        _maybe_eval_and_maybe_save(step + 1)
+
+        step += 1
 
     demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None

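After this refactor both the offline and online loops share a single eval/checkpoint path instead of duplicating it. A structural sketch of the resulting control flow, with hypothetical stand-in stubs (`do_update`, `evaluate`, `save`) in place of the real policy, buffer, and logger calls:

```python
def train_loop(offline_steps, online_steps, eval_freq, save_freq, do_update, evaluate, save):
    def _maybe_eval_and_maybe_save(step):
        # `step` here is the number of completed updates, hence `step + 1` at the call sites.
        if step % eval_freq == 0:
            evaluate(step)
        if step % save_freq == 0:
            save(step)

    step = 0
    for _ in range(offline_steps):
        do_update(step)
        _maybe_eval_and_maybe_save(step + 1)
        step += 1

    for _ in range(online_steps):
        do_update(step)  # in the real script this also rolls out in the environment first
        _maybe_eval_and_maybe_save(step + 1)
        step += 1


# Example with trivial stubs:
train_loop(
    offline_steps=3,
    online_steps=2,
    eval_freq=2,
    save_freq=5,
    do_update=lambda step: None,
    evaluate=lambda step: print(f"eval at {step}"),
    save=lambda step: print(f"save at {step}"),
)
```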
@@ -248,24 +250,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
             train_info.update(rollout_info)
             log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
 
-        if step > 0 and step % cfg.eval_freq == 0:
-            logging.info(f"Eval policy at step {step}")
-            eval_info, first_video = eval_policy(
-                env,
-                td_policy,
-                num_episodes=cfg.eval_episodes,
-                max_steps=cfg.env.episode_length // cfg.n_action_steps,
-                return_first_video=True,
-            )
-            log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
-            if cfg.wandb.enable:
-                logger.log_video(first_video, step, mode="eval")
-            logging.info("Resume training")
-
-        if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-            logging.info(f"Checkpoint policy at step {step}")
-            logger.save_model(policy, identifier=step)
-            logging.info("Resume training")
+        # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
+        # in step + 1.
+        _maybe_eval_and_maybe_save(step + 1)
 
         step += 1
         online_step += 1