From 83dc00683c5d595f54bec74c97e54de7c9f8502c Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Wed, 22 Jan 2025 09:00:16 +0000 Subject: [PATCH] Stable version of rlpd + drq --- lerobot/common/envs/factory.py | 84 ++++++ lerobot/common/envs/utils.py | 40 +++ .../common/policies/sac/configuration_sac.py | 37 +-- lerobot/common/policies/sac/modeling_sac.py | 132 ++++++---- lerobot/configs/policy/sac_manyskill.yaml | 97 +++++++ lerobot/scripts/train_sac.py | 244 +++++++++++------- 6 files changed, 460 insertions(+), 174 deletions(-) create mode 100644 lerobot/configs/policy/sac_manyskill.yaml diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 8450f84b..8aec915c 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import importlib +from collections import deque import gymnasium as gym @@ -67,3 +68,86 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g ) return env + + +def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None: + """Make ManiSkill3 gym environment""" + from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv + + env = gym.make( + cfg.env.task, + obs_mode=cfg.env.obs, + control_mode=cfg.env.control_mode, + render_mode=cfg.env.render_mode, + sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size), + num_envs=n_envs, + ) + # cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode + env = ManiSkillVectorEnv(env, ignore_terminations=True) + # state should have the size of 25 + # env = ConvertToLeRobotEnv(env, n_envs) + # env = PixelWrapper(cfg, env, n_envs) + env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env) + env.unwrapped.metadata["render_fps"] = 20 + + return env + + +class PixelWrapper(gym.Wrapper): + """ + Wrapper for pixel observations. Works with Maniskill vectorized environments + """ + + def __init__(self, cfg, env, num_envs, num_frames=3): + super().__init__(env) + self.cfg = cfg + self.env = env + self.observation_space = gym.spaces.Box( + low=0, + high=255, + shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size), + dtype=np.uint8, + ) + self._frames = deque([], maxlen=num_frames) + self._render_size = cfg.env.render_size + + def _get_obs(self, obs): + frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2) + self._frames.append(frame) + return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)} + + def reset(self, seed): + obs, info = self.env.reset() # (seed=seed) + for _ in range(self._frames.maxlen): + obs_frames = self._get_obs(obs) + return obs_frames, info + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + return self._get_obs(obs), reward, terminated, truncated, info + +class ConvertToLeRobotEnv(gym.Wrapper): + def __init__(self, env, num_envs): + super().__init__(env) + def reset(self, seed=None, options=None): + obs, info = self.env.reset(seed=seed, options={}) + return self._get_obs(obs), info + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + return self._get_obs(obs), reward, terminated, truncated, info + def _get_obs(self, observation): + sensor_data = observation.pop("sensor_data") + del observation["sensor_param"] + images = [] + for cam_data in sensor_data.values(): + images.append(cam_data["rgb"]) + + images = torch.concat(images, axis=-1) + # flatten the rest of the data which should just be state data + observation = common.flatten_state_dict( + observation, use_torch=True, device=self.base_env.device + ) + ret = dict() + ret["state"] = observation + ret["pixels"] = images + return ret \ No newline at end of file diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 30bbaf39..ead6bf45 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -33,6 +33,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten """ # map to expected inputs for the policy return_observations = {} + # TODO: You have to merge all tensors from agent key and extra key + # You don't keep sensor param key in the observation + # And you keep sensor data rgb if "pixels" in observations: if isinstance(observations["pixels"], dict): imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()} @@ -56,6 +59,8 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten img /= 255 return_observations[imgkey] = img + # obs state agent qpos and qvel + # image if "environment_state" in observations: return_observations["observation.environment_state"] = torch.from_numpy( @@ -86,3 +91,38 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: policy_features[policy_key] = feature return policy_features + + +def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]: + """Convert environment observation to LeRobot format observation. + Args: + observation: Dictionary of observation batches from a Gym vector environment. + Returns: + Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. + """ + # map to expected inputs for the policy + return_observations = {} + # TODO: You have to merge all tensors from agent key and extra key + # You don't keep sensor param key in the observation + # And you keep sensor data rgb + q_pos = observations["agent"]["qpos"] + q_vel = observations["agent"]["qvel"] + tcp_pos = observations["extra"]["tcp_pose"] + img = observations["sensor_data"]["base_camera"]["rgb"] + + _, h, w, c = img.shape + assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}" + + # sanity check that images are uint8 + assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" + + # convert to channel first of type float32 in range [0,1] + img = einops.rearrange(img, "b h w c -> b c h w").contiguous() + img = img.type(torch.float32) + img /= 255 + + state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1) + + return_observations["observation.image"] = img + return_observations["observation.state"] = state + return return_observations diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 97ba04b1..62f35ed5 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -19,34 +19,6 @@ from dataclasses import dataclass, field from typing import Any -@dataclass -class SACConfig: - input_shapes: dict[str, list[int]] = field( - default_factory=lambda: { - "observation.image": [3, 84, 84], - "observation.state": [4], - } - ) - - output_shapes: dict[str, list[int]] = field( - default_factory=lambda: { - "action": [2], - } - ) - - # Normalization / Unnormalization - input_normalization_modes: dict[str, str] = field( - default_factory=lambda: { - "observation.image": "mean_std", - "observation.state": "min_max", - "observation.environment_state": "min_max", - } - ) - output_normalization_modes: dict[str, str] = field( - default_factory=lambda: {"action": "min_max"}, - ) -from dataclasses import dataclass, field - @dataclass class SACConfig: input_shapes: dict[str, list[int]] = field( @@ -67,10 +39,13 @@ class SACConfig: "observation.environment_state": "min_max", } ) - output_normalization_modes: dict[str, str] = field( - default_factory=lambda: {"action": "min_max"} + output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"}) + output_normalization_params: dict[str, dict[str, list[float]]] = field( + default_factory=lambda: { + "action": {"min": [-1, -1], "max": [1, 1]}, + } ) - + camera_number: int = 1 # Add type annotations for these fields: image_encoder_hidden_dim: int = 32 shared_encoder: bool = False diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 35b1bd5a..8fb46199 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -42,37 +42,31 @@ class SACPolicy( name = "sac" def __init__( - self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None + self, + config: SACConfig | None = None, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + device: str = "cpu", ): super().__init__() if config is None: config = SACConfig() self.config = config - if config.input_normalization_modes is not None: self.normalize_inputs = Normalize( config.input_shapes, config.input_normalization_modes, dataset_stats ) else: self.normalize_inputs = nn.Identity() - # HACK: we need to pass the dataset_stats to the normalization functions - # NOTE: This is for biwalker environment - dataset_stats = dataset_stats or { - "action": { - "min": torch.tensor([-1.0, -1.0, -1.0, -1.0]), - "max": torch.tensor([1.0, 1.0, 1.0, 1.0]), - } - } + output_normalization_params = {} + for outer_key, inner_dict in config.output_normalization_params.items(): + output_normalization_params[outer_key] = {} + for key, value in inner_dict.items(): + output_normalization_params[outer_key][key] = torch.tensor(value) - # NOTE: This is for pusht environment - # dataset_stats = dataset_stats or { - # "action": { - # "min": torch.tensor([0, 0]), - # "max": torch.tensor([512, 512]), - # } - # } + # HACK: This is hacky and should be removed + dataset_stats = dataset_stats or output_normalization_params self.normalize_targets = Normalize( config.output_shapes, config.output_normalization_modes, dataset_stats ) @@ -82,7 +76,7 @@ class SACPolicy( if config.shared_encoder: encoder_critic = SACObservationEncoder(config) - encoder_actor = encoder_critic + encoder_actor: SACObservationEncoder = encoder_critic else: encoder_critic = SACObservationEncoder(config) encoder_actor = SACObservationEncoder(config) @@ -95,6 +89,7 @@ class SACPolicy( input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], **config.critic_network_kwargs, ), + device=device, ) critic_nets.append(critic_net) @@ -106,40 +101,35 @@ class SACPolicy( input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], **config.critic_network_kwargs, ), + device=device, ) target_critic_nets.append(target_critic_net) - self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics) - self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics) + self.critic_ensemble = create_critic_ensemble( + critics=critic_nets, num_critics=config.num_critics, device=device + ) + self.critic_target = create_critic_ensemble( + critics=target_critic_nets, num_critics=config.num_critics, device=device + ) self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) self.actor = Policy( encoder=encoder_actor, network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs), action_dim=config.output_shapes["action"][0], + device=device, + encoder_is_shared=config.shared_encoder, **config.policy_kwargs, ) if config.target_entropy is None: config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2) - # TODO: fix later device # TODO: Handle the case where the temparameter is a fixed - self.log_alpha = torch.zeros(1, requires_grad=True, device="cpu") + self.log_alpha = torch.zeros(1, requires_grad=True, device=device) self.temperature = self.log_alpha.exp().item() def reset(self): - """ - Clear observation and action queues. Should be called on `env.reset()` - queues are populated during rollout of the policy, they contain the n latest observations and actions - """ - - self._queues = { - "observation.state": deque(maxlen=1), - "action": deque(maxlen=1), - } - if "observation.image" in self.config.input_shapes: - self._queues["observation.image"] = deque(maxlen=1) - if "observation.environment_state" in self.config.input_shapes: - self._queues["observation.environment_state"] = deque(maxlen=1) + """Reset the policy""" + pass @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: @@ -334,6 +324,7 @@ class Policy(nn.Module): init_final: Optional[float] = None, use_tanh_squash: bool = False, device: str = "cpu", + encoder_is_shared: bool = False, ): super().__init__() self.device = torch.device(device) @@ -344,7 +335,12 @@ class Policy(nn.Module): self.log_std_max = log_std_max self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None self.use_tanh_squash = use_tanh_squash + self.parameters_to_optimize = [] + self.parameters_to_optimize += list(self.network.parameters()) + + if self.encoder is not None and not encoder_is_shared: + self.parameters_to_optimize += list(self.encoder.parameters()) # Find the last Linear layer's output dimension for layer in reversed(network.net): if isinstance(layer, nn.Linear): @@ -358,6 +354,7 @@ class Policy(nn.Module): else: orthogonal_init()(self.mean_layer.weight) + self.parameters_to_optimize += list(self.mean_layer.parameters()) # Standard deviation layer or parameter if fixed_std is None: self.std_layer = nn.Linear(out_features, action_dim) @@ -366,6 +363,7 @@ class Policy(nn.Module): nn.init.uniform_(self.std_layer.bias, -init_final, init_final) else: orthogonal_init()(self.std_layer.weight) + self.parameters_to_optimize += list(self.std_layer.parameters()) self.to(self.device) @@ -428,44 +426,78 @@ class SACObservationEncoder(nn.Module): """ super().__init__() self.config = config - if "observation.image" in config.input_shapes: self.image_enc_layers = nn.Sequential( nn.Conv2d( - config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2 + in_channels=config.input_shapes["observation.image"][0], + out_channels=config.image_encoder_hidden_dim, + kernel_size=7, + stride=2, ), nn.ReLU(), - nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), + nn.Conv2d( + in_channels=config.image_encoder_hidden_dim, + out_channels=config.image_encoder_hidden_dim, + kernel_size=5, + stride=2, + ), nn.ReLU(), - nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.Conv2d( + in_channels=config.image_encoder_hidden_dim, + out_channels=config.image_encoder_hidden_dim, + kernel_size=3, + stride=2, + ), nn.ReLU(), - nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.Conv2d( + in_channels=config.image_encoder_hidden_dim, + out_channels=config.image_encoder_hidden_dim, + kernel_size=3, + stride=2, + ), nn.ReLU(), ) + self.camera_number = config.camera_number + self.aggregation_size: int = 0 + dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"]) with torch.inference_mode(): out_shape = self.image_enc_layers(dummy_batch).shape[1:] self.image_enc_layers.extend( - nn.Sequential( + sequential=nn.Sequential( nn.Flatten(), - nn.Linear(np.prod(out_shape), config.latent_dim), - nn.LayerNorm(config.latent_dim), + nn.Linear( + in_features=np.prod(out_shape) * self.camera_number, out_features=config.latent_dim + ), + nn.LayerNorm(normalized_shape=config.latent_dim), nn.Tanh(), ) ) + + self.aggregation_size += config.latent_dim * self.camera_number if "observation.state" in config.input_shapes: self.state_enc_layers = nn.Sequential( - nn.Linear(config.input_shapes["observation.state"][0], config.latent_dim), - nn.LayerNorm(config.latent_dim), + nn.Linear( + in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim + ), + nn.LayerNorm(normalized_shape=config.latent_dim), nn.Tanh(), ) + self.aggregation_size += config.latent_dim + if "observation.environment_state" in config.input_shapes: self.env_state_enc_layers = nn.Sequential( - nn.Linear(config.input_shapes["observation.environment_state"][0], config.latent_dim), - nn.LayerNorm(config.latent_dim), + nn.Linear( + in_features=config.input_shapes["observation.environment_state"][0], + out_features=config.latent_dim, + ), + nn.LayerNorm(normalized_shape=config.latent_dim), nn.Tanh(), ) + self.aggregation_size += config.latent_dim + self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim) + def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: """Encode the image and/or state vector. @@ -482,7 +514,11 @@ class SACObservationEncoder(nn.Module): if "observation.state" in self.config.input_shapes: feat.append(self.state_enc_layers(obs_dict["observation.state"])) # TODO(ke-wang): currently average over all features, concatenate all features maybe a better way - return torch.stack(feat, dim=0).mean(0) + # return torch.stack(feat, dim=0).mean(0) + features = torch.cat(tensors=feat, dim=-1) + features = self.aggregation_layer(features) + + return features @property def output_dim(self) -> int: diff --git a/lerobot/configs/policy/sac_manyskill.yaml b/lerobot/configs/policy/sac_manyskill.yaml new file mode 100644 index 00000000..e4c3f17d --- /dev/null +++ b/lerobot/configs/policy/sac_manyskill.yaml @@ -0,0 +1,97 @@ +# @package _global_ + +# Train with: +# +# python lerobot/scripts/train.py \ +# +dataset=lerobot/pusht_keypoints +# env=pusht \ +# env.gym.obs_type=environment_state_agent_pos \ + +seed: 1 +dataset_repo_id: null + + +training: + # Offline training dataloader + num_workers: 4 + + # batch_size: 256 + batch_size: 512 + grad_clip_norm: 10.0 + lr: 3e-4 + + eval_freq: 2500 + log_freq: 500 + save_freq: 50000 + + online_steps: 1000000 + online_rollout_n_episodes: 10 + online_rollout_batch_size: 10 + online_steps_between_rollouts: 1000 + online_sampling_ratio: 1.0 + online_env_seed: 10000 + online_buffer_capacity: 1000000 + online_buffer_seed_size: 0 + online_step_before_learning: 5000 + do_online_rollout_async: false + policy_update_freq: 1 + + # delta_timestamps: + # observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" + # observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" + # action: "[i / ${fps} for i in range(${policy.horizon})]" + # next.reward: "[i / ${fps} for i in range(${policy.horizon})]" + +policy: + name: sac + + pretrained_model_path: + + # Input / output structure. + n_action_repeats: 1 + horizon: 1 + n_action_steps: 1 + + shared_encoder: true + input_shapes: + # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? + observation.state: ["${env.state_dim}"] + observation.image: [3, 64, 64] + output_shapes: + action: ["${env.action_dim}"] + + # Normalization / Unnormalization + input_normalization_modes: null + output_normalization_modes: + action: min_max + output_normalization_params: + action: + min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0] + max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + + # Architecture / modeling. + # Neural networks. + image_encoder_hidden_dim: 32 + # discount: 0.99 + discount: 0.80 + temperature_init: 1.0 + num_critics: 2 + num_subsample_critics: null + critic_lr: 3e-4 + actor_lr: 3e-4 + temperature_lr: 3e-4 + # critic_target_update_weight: 0.005 + critic_target_update_weight: 0.01 + utd_ratio: 1 + + + # # Loss coefficients. + # reward_coeff: 0.5 + # expectile_weight: 0.9 + # value_coeff: 0.1 + # consistency_coeff: 20.0 + # advantage_scaling: 3.0 + # pi_coeff: 0.5 + # temporal_decay_coeff: 0.5 + # # Target model. + # target_model_momentum: 0.995 diff --git a/lerobot/scripts/train_sac.py b/lerobot/scripts/train_sac.py index bb9b51d5..866415d0 100644 --- a/lerobot/scripts/train_sac.py +++ b/lerobot/scripts/train_sac.py @@ -14,34 +14,27 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import time -from contextlib import nullcontext -from copy import deepcopy -from pathlib import Path +import functools from pprint import pformat import random -from typing import Optional, Sequence, TypedDict +from typing import Optional, Sequence, TypedDict, Callable import hydra -import numpy as np import torch -from deepdiff import DeepDiff -from omegaconf import DictConfig, ListConfig, OmegaConf -from termcolor import colored +import torch.nn.functional as F from torch import nn -from torch.cuda.amp import GradScaler from tqdm import tqdm +from deepdiff import DeepDiff +from omegaconf import DictConfig, OmegaConf -from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps -from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset, LeRobotDataset -from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights -from lerobot.common.datasets.sampler import EpisodeAwareSampler -from lerobot.common.datasets.utils import cycle -from lerobot.common.envs.factory import make_env -from lerobot.common.envs.utils import preprocess_observation +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + +# TODO: Remove the import of maniskill +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.envs.factory import make_env, make_maniskill_env +from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation from lerobot.common.logger import Logger, log_output_dir from lerobot.common.policies.factory import make_policy -from lerobot.common.policies.policy_protocol import PolicyWithUpdate from lerobot.common.policies.sac.modeling_sac import SACPolicy from lerobot.common.policies.utils import get_device_from_parameters from lerobot.common.utils.utils import ( @@ -56,7 +49,8 @@ from lerobot.scripts.eval import eval_policy def make_optimizers_and_scheduler(cfg, policy): optimizer_actor = torch.optim.Adam( - params=policy.actor.parameters(), + # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor + params=policy.actor.parameters_to_optimize, lr=policy.config.actor_lr, ) optimizer_critic = torch.optim.Adam( @@ -73,11 +67,6 @@ def make_optimizers_and_scheduler(cfg, policy): return optimizers, lr_scheduler -# def update_policy(policy, batch, optimizers, grad_clip_norm): - -# NOTE: This is temporary, online buffer or query lerobot dataset is not performant enough yet - - class Transition(TypedDict): state: dict[str, torch.Tensor] action: torch.Tensor @@ -95,13 +84,62 @@ class BatchTransition(TypedDict): done: torch.Tensor +def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor: + """ + Perform a per-image random crop over a batch of images in a vectorized way. + (Same as shown previously.) + """ + B, C, H, W = images.shape + crop_h, crop_w = output_size + + if crop_h > H or crop_w > W: + raise ValueError( + f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})." + ) + + tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device) + lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device) + + rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1) + cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1) + + rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w) + cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w) + + images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C) + + # Gather pixels + cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :] + # cropped_hwcn => (B, crop_h, crop_w, C) + + cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w) + return cropped + + +def random_shift(images: torch.Tensor, pad: int = 4): + """Vectorized random shift, imgs: (B,C,H,W), pad: #pixels""" + _, _, h, w = images.shape + images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate") + return random_crop_vectorized(images=images, output_size=(h, w)) + + class ReplayBuffer: - def __init__(self, capacity: int, device: str = "cuda:0", state_keys: Optional[Sequence[str]] = None): + def __init__( + self, + capacity: int, + device: str = "cuda:0", + state_keys: Optional[Sequence[str]] = None, + image_augmentation_function: Optional[Callable] = None, + use_drq: bool = True, + ): """ Args: capacity (int): Maximum number of transitions to store in the buffer. device (str): The device where the tensors will be moved ("cuda:0" or "cpu"). state_keys (List[str]): The list of keys that appear in `state` and `next_state`. + image_augmentation_function (Optional[Callable]): A function that takes a batch of images + and returns a batch of augmented images. If None, a default augmentation function is used. + use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer. """ self.capacity = capacity self.device = device @@ -111,6 +149,9 @@ class ReplayBuffer: # If no state_keys provided, default to an empty list # (you can handle this differently if needed) self.state_keys = state_keys if state_keys is not None else [] + if image_augmentation_function is None: + self.image_augmentation_function = functools.partial(random_shift, pad=4) + self.use_drq = use_drq def add( self, @@ -134,7 +175,7 @@ class ReplayBuffer: done=done, complementary_info=complementary_info, ) - self.position = (self.position + 1) % self.capacity + self.position: int = (self.position + 1) % self.capacity @classmethod def from_lerobot_dataset( @@ -143,6 +184,18 @@ class ReplayBuffer: device: str = "cuda:0", state_keys: Optional[Sequence[str]] = None, ) -> "ReplayBuffer": + """ + Convert a LeRobotDataset into a ReplayBuffer. + + Args: + lerobot_dataset (LeRobotDataset): The dataset to convert. + device (str): The device . Defaults to "cuda:0". + state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`. + Defaults to None. + + Returns: + ReplayBuffer: The replay buffer with offline dataset transitions. + """ # We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from # a replay buffer than from a lerobot dataset. replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys) @@ -248,6 +301,8 @@ class ReplayBuffer: batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to( self.device ) + if key.startswith("observation.image") and self.use_drq: + batch_state[key] = self.image_augmentation_function(batch_state[key]) # -- Build batched actions -- batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device) @@ -263,6 +318,8 @@ class ReplayBuffer: batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to( self.device ) + if key.startswith("observation.image") and self.use_drq: + batch_next_state[key] = self.image_augmentation_function(batch_next_state[key]) # -- Build batched dones -- batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to( @@ -285,7 +342,7 @@ class ReplayBuffer: def concatenate_batch_transitions( left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition ) -> BatchTransition: - """Be careful it change the left_batch_transitions in place""" + """NOTE: Be careful it change the left_batch_transitions in place""" left_batch_transitions["state"] = { key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0) for key in left_batch_transitions["state"] @@ -321,11 +378,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size) # NOTE: Off policy algorithm are efficient enought to use a single environment logging.info("make_env online") - online_env = make_env(cfg, n_envs=1) - + # online_env = make_env(cfg, n_envs=1) + # TODO: Remove the import of maniskill and unifiy with make env + online_env = make_maniskill_env(cfg, n_envs=1) if cfg.training.eval_freq > 0: logging.info("make_env eval") - eval_env = make_env(cfg, n_envs=1) + # eval_env = make_env(cfg, n_envs=1) + # TODO: Remove the import of maniskill and unifiy with make env + eval_env = make_maniskill_env(cfg, n_envs=1) # TODO: Add a way to resume training @@ -348,6 +408,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # Hack: But if we do online traning, we do not need dataset_stats dataset_stats=None, pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None, + device=device, ) assert isinstance(policy, nn.Module) @@ -360,17 +421,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No log_output_dir(out_dir) logging.info(f"{cfg.env.task=}") - # TODO: Handle offline steps - # logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})") logging.info(f"{cfg.training.online_steps=}") - # logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})") - # logging.info(f"{offline_dataset.num_episodes=}") logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") obs, info = online_env.reset() - obs = preprocess_observation(obs) + # HACK for maniskill + # obs = preprocess_observation(obs) + obs = preprocess_maniskill_observation(obs) obs = {key: obs[key].to(device, non_blocking=True) for key in obs} replay_buffer = ReplayBuffer( @@ -378,8 +437,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No ) batch_size = cfg.training.batch_size - # if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig): - # raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.") + if cfg.dataset_repo_id is not None: logging.info("make_dataset offline buffer") offline_dataset = make_dataset(cfg) @@ -404,7 +462,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # HACK action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True) - next_obs = preprocess_observation(next_obs) + # HACK: For maniskill + # next_obs = preprocess_observation(next_obs) + next_obs = preprocess_maniskill_observation(next_obs) next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs} sum_reward_episode += float(reward[0]) # Because we are using a single environment @@ -413,16 +473,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}") logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step) sum_reward_episode = 0 - if "final_info" in info: - if "is_success" in info["final_info"][0]: - logging.info( - f"Global step {interaction_step}: Episode success: {info['final_info'][0]['is_success']}" - ) - if "coverage" in info["final_info"][0]: - logging.info( - f"Global step {interaction_step}: Episode final coverage: {info['final_info'][0]['coverage']} \n" - ) - logger.log_dict({"Final coverage": info["final_info"][0]["coverage"]}, interaction_step) + # HACK: This is for maniskill + logging.info( + f"global step {interaction_step}: episode success: {info['success'].float().item()} \n" + ) + logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step) replay_buffer.add( state=obs, @@ -433,38 +488,13 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No ) obs = next_obs - if interaction_step >= cfg.training.online_step_before_learning: - for _ in range(cfg.policy.utd_ratio - 1): - batch = replay_buffer.sample(batch_size) - if cfg.dataset_repo_id is not None: - batch_offline = offline_replay_buffer.sample(batch_size) - batch = concatenate_batch_transitions(batch, batch_offline) - - actions = batch["action"] - rewards = batch["reward"] - observations = batch["state"] - next_observations = batch["next_state"] - done = batch["done"] - - loss_critic = policy.compute_loss_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - ) - optimizers["critic"].zero_grad() - loss_critic.backward() - optimizers["critic"].step() - + if interaction_step < cfg.training.online_step_before_learning: + continue + for _ in range(cfg.policy.utd_ratio - 1): batch = replay_buffer.sample(batch_size) if cfg.dataset_repo_id is not None: batch_offline = offline_replay_buffer.sample(batch_size) - batch = concatenate_batch_transitions( - left_batch_transitions=batch, right_batch_transition=batch_offline - ) - # NOTE: We have to handle the normalization for the batch - # batch = policy.normalize_inputs(batch) + batch = concatenate_batch_transitions(batch, batch_offline) actions = batch["action"] rewards = batch["reward"] @@ -483,31 +513,55 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No loss_critic.backward() optimizers["critic"].step() - training_infos = {} - training_infos["loss_critic"] = loss_critic.item() + batch = replay_buffer.sample(batch_size) + if cfg.dataset_repo_id is not None: + batch_offline = offline_replay_buffer.sample(batch_size) + batch = concatenate_batch_transitions( + left_batch_transitions=batch, right_batch_transition=batch_offline + ) - if interaction_step % cfg.training.policy_update_freq == 0: - # TD3 Trick - for _ in range(cfg.training.policy_update_freq): - loss_actor = policy.compute_loss_actor(observations=observations) + actions = batch["action"] + rewards = batch["reward"] + observations = batch["state"] + next_observations = batch["next_state"] + done = batch["done"] - optimizers["actor"].zero_grad() - loss_actor.backward() - optimizers["actor"].step() + loss_critic = policy.compute_loss_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + ) + optimizers["critic"].zero_grad() + loss_critic.backward() + optimizers["critic"].step() - training_infos["loss_actor"] = loss_actor.item() + training_infos = {} + training_infos["loss_critic"] = loss_critic.item() - loss_temperature = policy.compute_loss_temperature(observations=observations) - optimizers["temperature"].zero_grad() - loss_temperature.backward() - optimizers["temperature"].step() + if interaction_step % cfg.training.policy_update_freq == 0: + # TD3 Trick + for _ in range(cfg.training.policy_update_freq): + loss_actor = policy.compute_loss_actor(observations=observations) - training_infos["loss_temperature"] = loss_temperature.item() + optimizers["actor"].zero_grad() + loss_actor.backward() + optimizers["actor"].step() - if interaction_step % cfg.training.log_freq == 0: - logger.log_dict(training_infos, interaction_step, mode="train") + training_infos["loss_actor"] = loss_actor.item() - policy.update_target_networks() + loss_temperature = policy.compute_loss_temperature(observations=observations) + optimizers["temperature"].zero_grad() + loss_temperature.backward() + optimizers["temperature"].step() + + training_infos["loss_temperature"] = loss_temperature.item() + + if interaction_step % cfg.training.log_freq == 0: + logger.log_dict(training_infos, interaction_step, mode="train") + + policy.update_target_networks() @hydra.main(version_base="1.2", config_name="default", config_path="../configs")