From d3239935691307c8f8778fe26ebd9c908899f5df Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Wed, 20 Mar 2024 15:01:27 +0000
Subject: [PATCH] backup wip

---
 lerobot/common/envs/pusht/pusht_image_env.py  | 17 +++--
 .../model/multi_image_obs_encoder.py          |  9 +--
 lerobot/common/policies/diffusion/policy.py   | 33 ++++++----
 lerobot/common/policies/factory.py            |  3 +-
 lerobot/configs/default.yaml                  |  6 +-
 lerobot/configs/policy/diffusion.yaml         | 21 +++----
 lerobot/scripts/train.py                      | 63 ++++++++-----------
 7 files changed, 71 insertions(+), 81 deletions(-)

diff --git a/lerobot/common/envs/pusht/pusht_image_env.py b/lerobot/common/envs/pusht/pusht_image_env.py
index 0807e849..b30ad874 100644
--- a/lerobot/common/envs/pusht/pusht_image_env.py
+++ b/lerobot/common/envs/pusht/pusht_image_env.py
@@ -1,4 +1,3 @@
-import cv2
 import numpy as np
 from gym import spaces
 
@@ -34,14 +33,14 @@ class PushTImageEnv(PushTEnv):
             coord = (action / 512 * 96).astype(np.int32)
             marker_size = int(8 / 96 * self.render_size)
             thickness = int(1 / 96 * self.render_size)
-            cv2.drawMarker(
-                img,
-                coord,
-                color=(255, 0, 0),
-                markerType=cv2.MARKER_CROSS,
-                markerSize=marker_size,
-                thickness=thickness,
-            )
+            # cv2.drawMarker(
+            #     img,
+            #     coord,
+            #     color=(255, 0, 0),
+            #     markerType=cv2.MARKER_CROSS,
+            #     markerSize=marker_size,
+            #     thickness=thickness,
+            # )
         self.render_cache = img
 
         return obs
diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
index 17252c1c..c7b9807d 100644
--- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
+++ b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
@@ -15,11 +15,12 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
 class RgbEncoder(nn.Module):
     """Following `VisualCore` from Robomimic 0.2.0."""
 
-    def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
+    def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32):
         """
         input_shape: channel-first input shape (C, H, W)
         resnet_name: a timm model name.
         pretrained: whether to use timm pretrained weights.
+        relu: whether to use relu as a final step.
         num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
         """
         super().__init__()
@@ -30,9 +31,11 @@ class RgbEncoder(nn.Module):
         feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
         self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
         self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
+        self.relu = nn.ReLU() if relu else nn.Identity()
 
     def forward(self, x):
-        return self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))
+        # TODO(now): make nonlinearity optional
+        return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
 
 
 class MultiImageObsEncoder(ModuleAttrMixin):
@@ -182,7 +185,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
             feature = torch.moveaxis(feature, 0, 1)
             # (B,N*D)
             feature = feature.reshape(batch_size, -1)
-            # feature = torch.nn.functional.relu(feature) # TODO: make optional
             features.append(feature)
         else:
             # run each rgb obs to independent models
@@ -195,7 +197,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
                 assert img.shape[1:] == self.key_shape_map[key]
                 img = self.key_transform_map[key](img)
                 feature = self.key_model_map[key](img)
-                # feature = torch.nn.functional.relu(feature) # TODO: make optional
                 features.append(feature)
 
         # concatenate all features
diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py
index f68ffb8e..a4185afc 100644
--- a/lerobot/common/policies/diffusion/policy.py
+++ b/lerobot/common/policies/diffusion/policy.py
@@ -1,9 +1,11 @@
 import copy
+import logging
 import time
 
 import hydra
 import torch
 
+from lerobot.common.ema import update_ema_parameters
 from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
@@ -19,7 +21,6 @@ class DiffusionPolicy(AbstractPolicy):
         cfg_rgb_model,
         cfg_obs_encoder,
         cfg_optimizer,
-        cfg_ema,
         shape_meta: dict,
         horizon,
         n_action_steps,
@@ -42,7 +43,6 @@ class DiffusionPolicy(AbstractPolicy):
         if cfg_obs_encoder.crop_shape is not None:
             rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
         rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
-        rgb_model = hydra.utils.instantiate(cfg_rgb_model)
         obs_encoder = MultiImageObsEncoder(
             rgb_model=rgb_model,
             **cfg_obs_encoder,
@@ -70,12 +70,9 @@ class DiffusionPolicy(AbstractPolicy):
         if torch.cuda.is_available() and cfg_device == "cuda":
             self.diffusion.cuda()
 
-        self.ema = None
-        if self.cfg.use_ema:
-            self.ema = hydra.utils.instantiate(
-                cfg_ema,
-                model=copy.deepcopy(self.diffusion),
-            )
+        self.ema_diffusion = None
+        if self.cfg.ema.enable:
+            self.ema_diffusion = copy.deepcopy(self.diffusion)
 
         self.optimizer = hydra.utils.instantiate(
             cfg_optimizer,
@@ -98,6 +95,9 @@ class DiffusionPolicy(AbstractPolicy):
 
     @torch.no_grad()
     def select_actions(self, observation, step_count):
+        """
+        Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
+        """
         # TODO(rcadene): remove unused step_count
         del step_count
 
@@ -105,7 +105,10 @@ class DiffusionPolicy(AbstractPolicy):
             "image": observation["image"],
             "agent_pos": observation["state"],
         }
-        out = self.diffusion.predict_action(obs_dict)
+        if self.training:
+            out = self.diffusion.predict_action(obs_dict)
+        else:
+            out = self.ema_diffusion.predict_action(obs_dict)
         action = out["action"]
         return action
 
@@ -172,8 +175,8 @@ class DiffusionPolicy(AbstractPolicy):
         self.optimizer.zero_grad()
         self.lr_scheduler.step()
 
-        if self.ema is not None:
-            self.ema.step(self.diffusion)
+        if self.cfg.ema.enable:
+            update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate)
 
         info = {
             "loss": loss.item(),
@@ -195,4 +198,10 @@ class DiffusionPolicy(AbstractPolicy):
 
     def load(self, fp):
         d = torch.load(fp)
-        self.load_state_dict(d)
+        missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
+        if len(missing_keys) > 0:
+            assert all(k.startswith("ema_diffusion.") for k in missing_keys)
+            logging.warning(
+                "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
+            )
+        assert len(unexpected_keys) == 0
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 7961beed..32a366b3 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -16,7 +16,6 @@ def make_policy(cfg):
             cfg_rgb_model=cfg.rgb_model,
             cfg_obs_encoder=cfg.obs_encoder,
             cfg_optimizer=cfg.optimizer,
-            cfg_ema=cfg.ema,
             n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
             **cfg.policy,
         )
@@ -41,7 +40,7 @@ def make_policy(cfg):
         policy.load(cfg.policy.pretrained_model_path)
 
     # import torch
-    # loaded = torch.load('/home/alexander/Downloads/dp_ema.pth')
+    # loaded = torch.load('/home/alexander/Downloads/dp.pth')
 
     # aligned = {}
     # their_prefix = "obs_encoder.obs_nets.image.backbone"
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index 52fd1d60..90d4c06b 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -12,14 +12,14 @@ hydra:
 seed: 1337
 # batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
 # NOTE: only diffusion policy supports rollout_batch_size > 1
-rollout_batch_size: 1
+rollout_batch_size: 10
 device: cuda # cpu
 prefetch: 4
 eval_freq: ???
 save_freq: ???
 eval_episodes: ???
 save_video: false
-save_model: false
+save_model: true
 save_buffer: false
 train_steps: ???
 fps: ???
@@ -34,6 +34,6 @@ policy: ???
 wandb:
   enable: true
   # Set to true to disable saving an artifact despite save_model == True
-  disable_artifact: false
+  disable_artifact: true
   project: lerobot
   notes: ""
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index 2b63f7e1..a81952e0 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -21,12 +21,12 @@ past_action_visible: False
 keypoint_visible_rate: 1.0
 obs_as_global_cond: True
 
-eval_episodes: 1
-eval_freq: 10000
-save_freq: 100000
+eval_episodes: 50
+eval_freq: 5000
+save_freq: 5000
 log_freq: 250
 
-offline_steps: 1344000
+offline_steps: 50000
 online_steps: 0
 
 offline_prioritized_sampler: true
@@ -58,7 +58,9 @@ policy:
   balanced_sampling: false
   utd: 1
   offline_steps: ${offline_steps}
-  use_ema: true
+  ema:
+    enable: true
+    rate: 0.999
   lr_scheduler: cosine
   lr_warmup_steps: 500
   grad_clip_norm: 10
@@ -87,14 +89,7 @@ rgb_model:
   model_name: resnet18
   pretrained: false
   num_keypoints: 32
-
-ema:
-  _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
-  update_after_step: 0
-  inv_gamma: 1.0
-  power: 0.75
-  min_value: 0.0
-  max_value: 0.9999
+  relu: true
 
 optimizer:
   _target_: torch.optim.AdamW
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 5ecd616d..a2039006 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -155,11 +155,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
 
-    td_policy = TensorDictModule(
-        policy,
-        in_keys=["observation", "step_count"],
-        out_keys=["action"],
-    )
+    td_policy = TensorDictModule(policy, in_keys=["observation", "step_count"], out_keys=["action"])
 
     # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)
@@ -174,19 +170,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
 
-    step = 0 # number of policy update (forward + backward + optim)
-
-    is_offline = True
-    for offline_step in range(cfg.offline_steps):
-        if offline_step == 0:
-            logging.info("Start offline training on a fixed dataset")
-        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
-        policy.train()
-        train_info = policy.update(offline_buffer, step)
-        if step % cfg.log_freq == 0:
-            log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
-
-        if step > 0 and step % cfg.eval_freq == 0:
+    # Note: this helper will be used in offline and online training loops.
+    def _maybe_eval_and_maybe_save(step):
+        if step % cfg.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             eval_info, first_video = eval_policy(
                 env,
@@ -202,11 +188,27 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 logger.log_video(first_video, step, mode="eval")
             logging.info("Resume training")
 
-        if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-            logging.info(f"Checkpoint policy at step {step}")
+        if cfg.save_model and step % cfg.save_freq == 0:
+            logging.info(f"Checkpoint policy after step {step}")
             logger.save_model(policy, identifier=step)
             logging.info("Resume training")
 
+    step = 0 # number of policy update (forward + backward + optim)
+
+    is_offline = True
+    for offline_step in range(cfg.offline_steps):
+        if offline_step == 0:
+            logging.info("Start offline training on a fixed dataset")
+        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
+        policy.train()
+        train_info = policy.update(offline_buffer, step)
+        if step % cfg.log_freq == 0:
+            log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
+
+        # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
+        # step + 1.
+        _maybe_eval_and_maybe_save(step + 1)
+
         step += 1
 
     demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None
@@ -248,24 +250,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
             train_info.update(rollout_info)
             log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
 
-        if step > 0 and step % cfg.eval_freq == 0:
-            logging.info(f"Eval policy at step {step}")
-            eval_info, first_video = eval_policy(
-                env,
-                td_policy,
-                num_episodes=cfg.eval_episodes,
-                max_steps=cfg.env.episode_length // cfg.n_action_steps,
-                return_first_video=True,
-            )
-            log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
-            if cfg.wandb.enable:
-                logger.log_video(first_video, step, mode="eval")
-            logging.info("Resume training")
-
-        if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-            logging.info(f"Checkpoint policy at step {step}")
-            logger.save_model(policy, identifier=step)
-            logging.info("Resume training")
+        # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
+        # in step + 1.
+        _maybe_eval_and_maybe_save(step + 1)
 
         step += 1
         online_step += 1
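The patch imports `update_ema_parameters` from `lerobot.common.ema` and calls it as `update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate)`, but that module is not part of this diff. Below is a minimal sketch of what such a helper could look like, assuming it performs a plain in-place exponential moving average at the configured rate; the function name and call signature are taken from the diff, while the body and the buffer handling are assumptions rather than the author's actual implementation.

import torch


@torch.no_grad()
def update_ema_parameters(ema_model: torch.nn.Module, model: torch.nn.Module, rate: float):
    """Blend the online weights into the EMA copy: ema = rate * ema + (1 - rate) * online."""
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.mul_(rate).add_(param, alpha=1 - rate)
    # Non-trainable buffers (e.g. BatchNorm running stats) are copied directly rather than averaged.
    for ema_buffer, buffer in zip(ema_model.buffers(), model.buffers()):
        ema_buffer.copy_(buffer)

With `rate: 0.999` as set in lerobot/configs/policy/diffusion.yaml, the EMA weights track the online weights with a time constant of roughly 1000 optimizer steps, which is consistent with `select_actions` using `self.ema_diffusion` outside of training.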