backup wip
commit d323993569
parent 32e3f71dd1
@@ -1,4 +1,3 @@
-import cv2
 import numpy as np
 from gym import spaces
 
@@ -34,14 +33,14 @@ class PushTImageEnv(PushTEnv):
         coord = (action / 512 * 96).astype(np.int32)
         marker_size = int(8 / 96 * self.render_size)
         thickness = int(1 / 96 * self.render_size)
-        cv2.drawMarker(
-            img,
-            coord,
-            color=(255, 0, 0),
-            markerType=cv2.MARKER_CROSS,
-            markerSize=marker_size,
-            thickness=thickness,
-        )
+        # cv2.drawMarker(
+        #     img,
+        #     coord,
+        #     color=(255, 0, 0),
+        #     markerType=cv2.MARKER_CROSS,
+        #     markerSize=marker_size,
+        #     thickness=thickness,
+        # )
         self.render_cache = img
 
         return obs
@@ -15,11 +15,12 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
 class RgbEncoder(nn.Module):
     """Following `VisualCore` from Robomimic 0.2.0."""
 
-    def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
+    def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32):
         """
         input_shape: channel-first input shape (C, H, W)
         resnet_name: a timm model name.
         pretrained: whether to use timm pretrained weights.
+        relu: whether to use relu as a final step.
         num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
         """
         super().__init__()
@@ -30,9 +31,11 @@ class RgbEncoder(nn.Module):
         feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
         self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
         self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
+        self.relu = nn.ReLU() if relu else nn.Identity()
 
     def forward(self, x):
-        return self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))
+        # TODO(now): make nonlinearity optional
+        return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
 
 
 class MultiImageObsEncoder(ModuleAttrMixin):
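
A quick usage sketch of the new `relu` flag on RgbEncoder (an illustration, not part of this commit; the instantiation and shapes below are assumed from the signature shown above):

import torch
# assuming RgbEncoder is in scope, i.e. defined in the module edited above
encoder = RgbEncoder(input_shape=(3, 96, 96), model_name="resnet18", pretrained=False, relu=False, num_keypoints=32)
feat = encoder(torch.zeros(1, 3, 96, 96))  # relu=False leaves the linear output unbounded; shape (1, num_keypoints * 2) == (1, 64)
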
@@ -182,7 +185,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
            feature = torch.moveaxis(feature, 0, 1)
            # (B,N*D)
            feature = feature.reshape(batch_size, -1)
-           # feature = torch.nn.functional.relu(feature) # TODO: make optional
            features.append(feature)
        else:
            # run each rgb obs to independent models
@@ -195,7 +197,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
                assert img.shape[1:] == self.key_shape_map[key]
                img = self.key_transform_map[key](img)
                feature = self.key_model_map[key](img)
-               # feature = torch.nn.functional.relu(feature) # TODO: make optional
                features.append(feature)
 
        # concatenate all features
@@ -1,9 +1,11 @@
 import copy
+import logging
 import time
 
 import hydra
 import torch
 
+from lerobot.common.ema import update_ema_parameters
 from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
@@ -19,7 +21,6 @@ class DiffusionPolicy(AbstractPolicy):
         cfg_rgb_model,
         cfg_obs_encoder,
         cfg_optimizer,
-        cfg_ema,
         shape_meta: dict,
         horizon,
         n_action_steps,
@@ -42,7 +43,6 @@ class DiffusionPolicy(AbstractPolicy):
         if cfg_obs_encoder.crop_shape is not None:
             rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
         rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
-        rgb_model = hydra.utils.instantiate(cfg_rgb_model)
         obs_encoder = MultiImageObsEncoder(
             rgb_model=rgb_model,
             **cfg_obs_encoder,
@@ -70,12 +70,9 @@ class DiffusionPolicy(AbstractPolicy):
         if torch.cuda.is_available() and cfg_device == "cuda":
             self.diffusion.cuda()
 
-        self.ema = None
-        if self.cfg.use_ema:
-            self.ema = hydra.utils.instantiate(
-                cfg_ema,
-                model=copy.deepcopy(self.diffusion),
-            )
+        self.ema_diffusion = None
+        if self.cfg.ema.enable:
+            self.ema_diffusion = copy.deepcopy(self.diffusion)
 
         self.optimizer = hydra.utils.instantiate(
             cfg_optimizer,
@@ -98,6 +95,9 @@ class DiffusionPolicy(AbstractPolicy):
 
     @torch.no_grad()
     def select_actions(self, observation, step_count):
+        """
+        Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
+        """
         # TODO(rcadene): remove unused step_count
         del step_count
 
@@ -105,7 +105,10 @@ class DiffusionPolicy(AbstractPolicy):
             "image": observation["image"],
             "agent_pos": observation["state"],
         }
-        out = self.diffusion.predict_action(obs_dict)
+        if self.training:
+            out = self.diffusion.predict_action(obs_dict)
+        else:
+            out = self.ema_diffusion.predict_action(obs_dict)
         action = out["action"]
         return action
 
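
Here `self.training` is the standard torch.nn.Module flag, so which set of weights serves inference is controlled by the usual train()/eval() calls (a hypothetical usage sketch, not part of this commit):

policy.eval()   # select_actions() now routes through self.ema_diffusion
action = policy.select_actions(observation, step_count=0)
policy.train()  # subsequent calls route through self.diffusion again
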
@@ -172,8 +175,8 @@ class DiffusionPolicy(AbstractPolicy):
         self.optimizer.zero_grad()
         self.lr_scheduler.step()
 
-        if self.ema is not None:
-            self.ema.step(self.diffusion)
+        if self.cfg.ema.enable:
+            update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate)
 
         info = {
             "loss": loss.item(),
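
The body of `update_ema_parameters` (imported from lerobot.common.ema) is not shown in this diff. A minimal sketch of a rate-based update matching this call signature might look like the following; this is an assumption, not the actual helper:

import torch

@torch.no_grad()
def update_ema_parameters(ema_model, model, rate):
    # Polyak-style blend: ema <- rate * ema + (1 - rate) * current, parameter by parameter.
    # The real helper may also need to copy buffers or skip non-float parameters.
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.mul_(rate).add_(param, alpha=1 - rate)
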
@@ -195,4 +198,10 @@ class DiffusionPolicy(AbstractPolicy):
 
     def load(self, fp):
         d = torch.load(fp)
-        self.load_state_dict(d)
+        missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
+        if len(missing_keys) > 0:
+            assert all(k.startswith("ema_diffusion.") for k in missing_keys)
+            logging.warning(
+                "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
+            )
+        assert len(unexpected_keys) == 0
@@ -16,7 +16,6 @@ def make_policy(cfg):
            cfg_rgb_model=cfg.rgb_model,
            cfg_obs_encoder=cfg.obs_encoder,
            cfg_optimizer=cfg.optimizer,
-           cfg_ema=cfg.ema,
            n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
            **cfg.policy,
        )
@@ -41,7 +40,7 @@ def make_policy(cfg):
        policy.load(cfg.policy.pretrained_model_path)
 
    # import torch
-   # loaded = torch.load('/home/alexander/Downloads/dp_ema.pth')
+   # loaded = torch.load('/home/alexander/Downloads/dp.pth')
    # aligned = {}
 
    # their_prefix = "obs_encoder.obs_nets.image.backbone"
@@ -12,14 +12,14 @@ hydra:
 seed: 1337
 # batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
 # NOTE: only diffusion policy supports rollout_batch_size > 1
-rollout_batch_size: 1
+rollout_batch_size: 10
 device: cuda # cpu
 prefetch: 4
 eval_freq: ???
 save_freq: ???
 eval_episodes: ???
 save_video: false
-save_model: false
+save_model: true
 save_buffer: false
 train_steps: ???
 fps: ???
@@ -34,6 +34,6 @@ policy: ???
 wandb:
   enable: true
   # Set to true to disable saving an artifact despite save_model == True
-  disable_artifact: false
+  disable_artifact: true
   project: lerobot
   notes: ""
@@ -21,12 +21,12 @@ past_action_visible: False
 keypoint_visible_rate: 1.0
 obs_as_global_cond: True
 
-eval_episodes: 1
-eval_freq: 10000
-save_freq: 100000
+eval_episodes: 50
+eval_freq: 5000
+save_freq: 5000
 log_freq: 250
 
-offline_steps: 1344000
+offline_steps: 50000
 online_steps: 0
 
 offline_prioritized_sampler: true
@@ -58,7 +58,9 @@ policy:
   balanced_sampling: false
   utd: 1
   offline_steps: ${offline_steps}
-  use_ema: true
+  ema:
+    enable: true
+    rate: 0.999
   lr_scheduler: cosine
   lr_warmup_steps: 500
   grad_clip_norm: 10
@@ -87,14 +89,7 @@ rgb_model:
   model_name: resnet18
   pretrained: false
   num_keypoints: 32
+  relu: true
-ema:
-  _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
-  update_after_step: 0
-  inv_gamma: 1.0
-  power: 0.75
-  min_value: 0.0
-  max_value: 0.9999
 
 optimizer:
   _target_: torch.optim.AdamW
@@ -155,11 +155,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
 
-    td_policy = TensorDictModule(
-        policy,
-        in_keys=["observation", "step_count"],
-        out_keys=["action"],
-    )
+    td_policy = TensorDictModule(policy, in_keys=["observation", "step_count"], out_keys=["action"])
 
     # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)
@@ -174,19 +170,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
 
-    step = 0  # number of policy update (forward + backward + optim)
-
-    is_offline = True
-    for offline_step in range(cfg.offline_steps):
-        if offline_step == 0:
-            logging.info("Start offline training on a fixed dataset")
-        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
-        policy.train()
-        train_info = policy.update(offline_buffer, step)
-        if step % cfg.log_freq == 0:
-            log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
-
-        if step > 0 and step % cfg.eval_freq == 0:
+    # Note: this helper will be used in offline and online training loops.
+    def _maybe_eval_and_maybe_save(step):
+        if step % cfg.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             eval_info, first_video = eval_policy(
                 env,
@@ -202,11 +188,27 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 logger.log_video(first_video, step, mode="eval")
             logging.info("Resume training")
 
-        if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-            logging.info(f"Checkpoint policy at step {step}")
+        if cfg.save_model and step % cfg.save_freq == 0:
+            logging.info(f"Checkpoint policy after step {step}")
             logger.save_model(policy, identifier=step)
             logging.info("Resume training")
 
+    step = 0  # number of policy update (forward + backward + optim)
+
+    is_offline = True
+    for offline_step in range(cfg.offline_steps):
+        if offline_step == 0:
+            logging.info("Start offline training on a fixed dataset")
+        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
+        policy.train()
+        train_info = policy.update(offline_buffer, step)
+        if step % cfg.log_freq == 0:
+            log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
+
+        # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
+        # step + 1.
+        _maybe_eval_and_maybe_save(step + 1)
+
         step += 1
 
     demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None
@@ -248,24 +250,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
                train_info.update(rollout_info)
                log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
 
-           if step > 0 and step % cfg.eval_freq == 0:
-               logging.info(f"Eval policy at step {step}")
-               eval_info, first_video = eval_policy(
-                   env,
-                   td_policy,
-                   num_episodes=cfg.eval_episodes,
-                   max_steps=cfg.env.episode_length // cfg.n_action_steps,
-                   return_first_video=True,
-               )
-               log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
-               if cfg.wandb.enable:
-                   logger.log_video(first_video, step, mode="eval")
-               logging.info("Resume training")
-
-           if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-               logging.info(f"Checkpoint policy at step {step}")
-               logger.save_model(policy, identifier=step)
-               logging.info("Resume training")
+           # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
+           # in step + 1.
+           _maybe_eval_and_maybe_save(step + 1)
 
            step += 1
        online_step += 1