Stable version of rlpd + drq

This commit is contained in:
AdilZouitine 2025-01-22 09:00:16 +00:00
parent 5b92465e38
commit 83dc00683c
6 changed files with 460 additions and 174 deletions
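For context on the title: RLPD-style updates mix each gradient batch half-and-half from the online replay buffer and an offline (demonstration) buffer, which is what `concatenate_batch_transitions` is used for in the training loop below, and DrQ regularizes the critic by randomly shifting sampled image observations (see `random_shift` later in this commit). A minimal sketch of the batch-mixing idea; `mix_batches` is illustrative and not the repository's API:

import torch

def mix_batches(online_batch: dict, offline_batch: dict) -> dict:
    """Concatenate an online and an offline batch along the batch dimension (RLPD-style 50/50 mixing)."""
    return {key: torch.cat([online_batch[key], offline_batch[key]], dim=0) for key in online_batch}

# Hypothetical usage: both buffers return dicts of tensors with matching keys and shapes.
online_batch = {"action": torch.randn(256, 7), "reward": torch.randn(256)}
offline_batch = {"action": torch.randn(256, 7), "reward": torch.randn(256)}
mixed = mix_batches(online_batch, offline_batch)
print(mixed["action"].shape)  # torch.Size([512, 7])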

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from collections import deque

import gymnasium as gym
@ -67,3 +68,86 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
    )
    return env
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
    """Make ManiSkill3 gym environment"""
    from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

    env = gym.make(
        cfg.env.task,
        obs_mode=cfg.env.obs,
        control_mode=cfg.env.control_mode,
        render_mode=cfg.env.render_mode,
        sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
        num_envs=n_envs,
    )
    # cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
    env = ManiSkillVectorEnv(env, ignore_terminations=True)
    # state should have the size of 25
    # env = ConvertToLeRobotEnv(env, n_envs)
    # env = PixelWrapper(cfg, env, n_envs)
    env._max_episode_steps = env.max_episode_steps = 50  # gym_utils.find_max_episode_steps_value(env)
    env.unwrapped.metadata["render_fps"] = 20
    return env


class PixelWrapper(gym.Wrapper):
    """
    Wrapper for pixel observations. Works with Maniskill vectorized environments
    """

    def __init__(self, cfg, env, num_envs, num_frames=3):
        super().__init__(env)
        self.cfg = cfg
        self.env = env
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
            dtype=np.uint8,
        )
        self._frames = deque([], maxlen=num_frames)
        self._render_size = cfg.env.render_size

    def _get_obs(self, obs):
        frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
        self._frames.append(frame)
        return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}

    def reset(self, seed):
        obs, info = self.env.reset()  # (seed=seed)
        for _ in range(self._frames.maxlen):
            obs_frames = self._get_obs(obs)
        return obs_frames, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._get_obs(obs), reward, terminated, truncated, info


class ConvertToLeRobotEnv(gym.Wrapper):
    def __init__(self, env, num_envs):
        super().__init__(env)

    def reset(self, seed=None, options=None):
        obs, info = self.env.reset(seed=seed, options={})
        return self._get_obs(obs), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._get_obs(obs), reward, terminated, truncated, info

    def _get_obs(self, observation):
        sensor_data = observation.pop("sensor_data")
        del observation["sensor_param"]
        images = []
        for cam_data in sensor_data.values():
            images.append(cam_data["rgb"])
        images = torch.concat(images, axis=-1)
        # flatten the rest of the data which should just be state data
        observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
        ret = dict()
        ret["state"] = observation
        ret["pixels"] = images
        return ret
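A minimal usage sketch for the factory above, assuming ManiSkill3 is installed and using illustrative values for the task, camera, and config fields it reads (`env.task`, `env.obs`, `env.control_mode`, `env.render_mode`, `env.image_size`):

import mani_skill.envs  # noqa: F401  -- registers ManiSkill3 tasks with gymnasium (assumption)
from omegaconf import OmegaConf

# Hypothetical config; field names follow the attributes read by make_maniskill_env.
cfg = OmegaConf.create(
    {
        "env": {
            "task": "PushCube-v1",
            "obs": "rgb",
            "control_mode": "pd_ee_delta_pose",
            "render_mode": "rgb_array",
            "image_size": 64,
        }
    }
)

env = make_maniskill_env(cfg, n_envs=1)
obs, info = env.reset()
print(obs["sensor_data"]["base_camera"]["rgb"].shape)  # expected (1, 64, 64, 3), channel-last uint8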

View File

@ -33,6 +33,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
""" """
# map to expected inputs for the policy # map to expected inputs for the policy
return_observations = {} return_observations = {}
# TODO: You have to merge all tensors from agent key and extra key
# You don't keep sensor param key in the observation
# And you keep sensor data rgb
if "pixels" in observations: if "pixels" in observations:
if isinstance(observations["pixels"], dict): if isinstance(observations["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()} imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
@ -56,6 +59,8 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
            img /= 255
            return_observations[imgkey] = img

    # obs state agent qpos and qvel
    # image
    if "environment_state" in observations:
        return_observations["observation.environment_state"] = torch.from_numpy(
@ -86,3 +91,38 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
        policy_features[policy_key] = feature
    return policy_features
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
    """Convert a ManiSkill environment observation to the LeRobot format observation.

    Args:
        observations: Dictionary of observation batches from a Gym vector environment.

    Returns:
        Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
    """
    # map to expected inputs for the policy
    return_observations = {}
    # TODO: You have to merge all tensors from agent key and extra key
    # You don't keep sensor param key in the observation
    # And you keep sensor data rgb
    q_pos = observations["agent"]["qpos"]
    q_vel = observations["agent"]["qvel"]
    tcp_pos = observations["extra"]["tcp_pose"]
    img = observations["sensor_data"]["base_camera"]["rgb"]

    _, h, w, c = img.shape
    assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"

    # sanity check that images are uint8
    assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

    # convert to channel-first float32 in range [0, 1]
    img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
    img = img.type(torch.float32)
    img /= 255

    state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)

    return_observations["observation.image"] = img
    return_observations["observation.state"] = state
    return return_observations
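As a quick worked check of the state vector built above: for the default ManiSkill Panda arm (an assumption), `qpos` and `qvel` are 9-dimensional and `tcp_pose` is a 7-dimensional position plus quaternion, which lines up with the "state should have the size of 25" note in the env factory:

import torch

q_pos = torch.zeros(1, 9)    # illustrative shapes for a single env
q_vel = torch.zeros(1, 9)
tcp_pos = torch.zeros(1, 7)
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
print(state.shape)  # torch.Size([1, 25])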

View File

@ -19,34 +19,6 @@ from dataclasses import dataclass, field
from typing import Any

-@dataclass
-class SACConfig:
-    input_shapes: dict[str, list[int]] = field(
-        default_factory=lambda: {
-            "observation.image": [3, 84, 84],
-            "observation.state": [4],
-        }
-    )
-    output_shapes: dict[str, list[int]] = field(
-        default_factory=lambda: {
-            "action": [2],
-        }
-    )
-    # Normalization / Unnormalization
-    input_normalization_modes: dict[str, str] = field(
-        default_factory=lambda: {
-            "observation.image": "mean_std",
-            "observation.state": "min_max",
-            "observation.environment_state": "min_max",
-        }
-    )
-    output_normalization_modes: dict[str, str] = field(
-        default_factory=lambda: {"action": "min_max"},
-    )
-from dataclasses import dataclass, field
@dataclass
class SACConfig:
    input_shapes: dict[str, list[int]] = field(
@ -67,10 +39,13 @@ class SACConfig:
"observation.environment_state": "min_max", "observation.environment_state": "min_max",
} }
) )
-    output_normalization_modes: dict[str, str] = field(
-        default_factory=lambda: {"action": "min_max"}
-    )
+    output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
+    output_normalization_params: dict[str, dict[str, list[float]]] = field(
+        default_factory=lambda: {
+            "action": {"min": [-1, -1], "max": [1, 1]},
+        }
+    )
+    camera_number: int = 1

    # Add type annotations for these fields:
    image_encoder_hidden_dim: int = 32
    shared_encoder: bool = False
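As a usage sketch, the new fields can be overridden per task, e.g. for a 7-dimensional action space matching the ManiSkill YAML added later in this commit (the import path is an assumption about where the dataclass lives):

from lerobot.common.policies.sac.configuration_sac import SACConfig

config = SACConfig(
    output_shapes={"action": [7]},
    output_normalization_params={"action": {"min": [-1.0] * 7, "max": [1.0] * 7}},
    camera_number=1,
)
print(config.output_normalization_params["action"]["max"])  # [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]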

View File

@ -42,37 +42,31 @@ class SACPolicy(
name = "sac" name = "sac"
def __init__( def __init__(
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None self,
config: SACConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
device: str = "cpu",
): ):
super().__init__() super().__init__()
if config is None: if config is None:
config = SACConfig() config = SACConfig()
self.config = config self.config = config
if config.input_normalization_modes is not None: if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize( self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats config.input_shapes, config.input_normalization_modes, dataset_stats
) )
else: else:
self.normalize_inputs = nn.Identity() self.normalize_inputs = nn.Identity()
# HACK: we need to pass the dataset_stats to the normalization functions
# NOTE: This is for biwalker environment output_normalization_params = {}
dataset_stats = dataset_stats or { for outer_key, inner_dict in config.output_normalization_params.items():
"action": { output_normalization_params[outer_key] = {}
"min": torch.tensor([-1.0, -1.0, -1.0, -1.0]), for key, value in inner_dict.items():
"max": torch.tensor([1.0, 1.0, 1.0, 1.0]), output_normalization_params[outer_key][key] = torch.tensor(value)
}
}
# NOTE: This is for pusht environment # HACK: This is hacky and should be removed
# dataset_stats = dataset_stats or { dataset_stats = dataset_stats or output_normalization_params
# "action": {
# "min": torch.tensor([0, 0]),
# "max": torch.tensor([512, 512]),
# }
# }
self.normalize_targets = Normalize( self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats config.output_shapes, config.output_normalization_modes, dataset_stats
) )
@ -82,7 +76,7 @@ class SACPolicy(
        if config.shared_encoder:
            encoder_critic = SACObservationEncoder(config)
-            encoder_actor = encoder_critic
+            encoder_actor: SACObservationEncoder = encoder_critic
        else:
            encoder_critic = SACObservationEncoder(config)
            encoder_actor = SACObservationEncoder(config)
@ -95,6 +89,7 @@ class SACPolicy(
                    input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
                    **config.critic_network_kwargs,
                ),
                device=device,
            )
            critic_nets.append(critic_net)
@ -106,40 +101,35 @@ class SACPolicy(
                    input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
                    **config.critic_network_kwargs,
                ),
+                device=device,
            )
            target_critic_nets.append(target_critic_net)

-        self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
-        self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
+        self.critic_ensemble = create_critic_ensemble(
+            critics=critic_nets, num_critics=config.num_critics, device=device
+        )
+        self.critic_target = create_critic_ensemble(
+            critics=target_critic_nets, num_critics=config.num_critics, device=device
+        )
        self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
        self.actor = Policy(
            encoder=encoder_actor,
            network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
            action_dim=config.output_shapes["action"][0],
+            device=device,
+            encoder_is_shared=config.shared_encoder,
            **config.policy_kwargs,
        )

        if config.target_entropy is None:
            config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2  # (-dim(A)/2)

-        # TODO: fix later device
        # TODO: Handle the case where the temperature parameter is fixed
-        self.log_alpha = torch.zeros(1, requires_grad=True, device="cpu")
+        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
        self.temperature = self.log_alpha.exp().item()

    def reset(self):
-        """
-        Clear observation and action queues. Should be called on `env.reset()`
-        queues are populated during rollout of the policy, they contain the n latest observations and actions
-        """
-        self._queues = {
-            "observation.state": deque(maxlen=1),
-            "action": deque(maxlen=1),
-        }
-        if "observation.image" in self.config.input_shapes:
-            self._queues["observation.image"] = deque(maxlen=1)
-        if "observation.environment_state" in self.config.input_shapes:
-            self._queues["observation.environment_state"] = deque(maxlen=1)
+        """Reset the policy"""
+        pass

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@ -334,6 +324,7 @@ class Policy(nn.Module):
        init_final: Optional[float] = None,
        use_tanh_squash: bool = False,
        device: str = "cpu",
        encoder_is_shared: bool = False,
    ):
        super().__init__()
        self.device = torch.device(device)
@ -344,7 +335,12 @@ class Policy(nn.Module):
        self.log_std_max = log_std_max
        self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
        self.use_tanh_squash = use_tanh_squash
        self.parameters_to_optimize = []
        self.parameters_to_optimize += list(self.network.parameters())
        if self.encoder is not None and not encoder_is_shared:
            self.parameters_to_optimize += list(self.encoder.parameters())

        # Find the last Linear layer's output dimension
        for layer in reversed(network.net):
            if isinstance(layer, nn.Linear):
@ -358,6 +354,7 @@ class Policy(nn.Module):
        else:
            orthogonal_init()(self.mean_layer.weight)
        self.parameters_to_optimize += list(self.mean_layer.parameters())

        # Standard deviation layer or parameter
        if fixed_std is None:
            self.std_layer = nn.Linear(out_features, action_dim)
@ -366,6 +363,7 @@ class Policy(nn.Module):
                nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
            else:
                orthogonal_init()(self.std_layer.weight)
        self.parameters_to_optimize += list(self.std_layer.parameters())

        self.to(self.device)
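The `parameters_to_optimize` bookkeeping above exists so that, when the encoder is shared with the critics, the actor's optimizer does not also update the encoder (the train script below builds the actor optimizer from this list). A minimal sketch of the pattern, with toy modules standing in for the encoder and actor head:

import torch
import torch.nn as nn

encoder = nn.Linear(10, 32)      # stands in for a shared SACObservationEncoder
actor_head = nn.Linear(32, 4)    # stands in for the actor's MLP and output layers
encoder_is_shared = True

# Mirror of the Policy bookkeeping: only give the encoder to the actor optimizer when it is not shared.
parameters_to_optimize = list(actor_head.parameters())
if not encoder_is_shared:
    parameters_to_optimize += list(encoder.parameters())

optimizer_actor = torch.optim.Adam(params=parameters_to_optimize, lr=3e-4)
# With a shared encoder, only the critic optimizer ends up updating the encoder weights.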
@ -428,44 +426,78 @@ class SACObservationEncoder(nn.Module):
""" """
super().__init__() super().__init__()
self.config = config self.config = config
if "observation.image" in config.input_shapes: if "observation.image" in config.input_shapes:
self.image_enc_layers = nn.Sequential( self.image_enc_layers = nn.Sequential(
nn.Conv2d( nn.Conv2d(
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2 in_channels=config.input_shapes["observation.image"][0],
out_channels=config.image_encoder_hidden_dim,
kernel_size=7,
stride=2,
), ),
nn.ReLU(), nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=5,
stride=2,
),
nn.ReLU(), nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(), nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(), nn.ReLU(),
) )
self.camera_number = config.camera_number
self.aggregation_size: int = 0
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"]) dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode(): with torch.inference_mode():
out_shape = self.image_enc_layers(dummy_batch).shape[1:] out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend( self.image_enc_layers.extend(
nn.Sequential( sequential=nn.Sequential(
nn.Flatten(), nn.Flatten(),
nn.Linear(np.prod(out_shape), config.latent_dim), nn.Linear(
nn.LayerNorm(config.latent_dim), in_features=np.prod(out_shape) * self.camera_number, out_features=config.latent_dim
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(), nn.Tanh(),
) )
) )
self.aggregation_size += config.latent_dim * self.camera_number
if "observation.state" in config.input_shapes: if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential( self.state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.state"][0], config.latent_dim), nn.Linear(
nn.LayerNorm(config.latent_dim), in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(), nn.Tanh(),
) )
self.aggregation_size += config.latent_dim
if "observation.environment_state" in config.input_shapes: if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential( self.env_state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.environment_state"][0], config.latent_dim), nn.Linear(
nn.LayerNorm(config.latent_dim), in_features=config.input_shapes["observation.environment_state"][0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(), nn.Tanh(),
) )
self.aggregation_size += config.latent_dim
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector. """Encode the image and/or state vector.
@ -482,7 +514,11 @@ class SACObservationEncoder(nn.Module):
if "observation.state" in self.config.input_shapes: if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"])) feat.append(self.state_enc_layers(obs_dict["observation.state"]))
# TODO(ke-wang): currently average over all features, concatenate all features maybe a better way # TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
return torch.stack(feat, dim=0).mean(0) # return torch.stack(feat, dim=0).mean(0)
features = torch.cat(tensors=feat, dim=-1)
features = self.aggregation_layer(features)
return features
@property @property
def output_dim(self) -> int: def output_dim(self) -> int:
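The encoder change above replaces mean-pooling of the per-modality features with concatenation followed by a learned projection. A standalone sketch of the idea, with illustrative dimensions:

import torch
import torch.nn as nn

latent_dim = 256  # illustrative; the real value comes from SACConfig.latent_dim
image_feat = torch.randn(8, latent_dim)   # output of the image encoder for a batch of 8
state_feat = torch.randn(8, latent_dim)   # output of the state encoder

# Old behavior: average the per-modality features.
mean_pooled = torch.stack([image_feat, state_feat], dim=0).mean(0)            # (8, latent_dim)

# New behavior: concatenate, then project back to latent_dim with a learned layer.
aggregation_layer = nn.Linear(in_features=2 * latent_dim, out_features=latent_dim)
aggregated = aggregation_layer(torch.cat([image_feat, state_feat], dim=-1))   # (8, latent_dim)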

View File

@ -0,0 +1,97 @@
# @package _global_

# Train with:
#
#   python lerobot/scripts/train.py \
#     +dataset=lerobot/pusht_keypoints
#     env=pusht \
#     env.gym.obs_type=environment_state_agent_pos \

seed: 1
dataset_repo_id: null

training:
  # Offline training dataloader
  num_workers: 4

  # batch_size: 256
  batch_size: 512
  grad_clip_norm: 10.0
  lr: 3e-4

  eval_freq: 2500
  log_freq: 500
  save_freq: 50000

  online_steps: 1000000
  online_rollout_n_episodes: 10
  online_rollout_batch_size: 10
  online_steps_between_rollouts: 1000
  online_sampling_ratio: 1.0
  online_env_seed: 10000
  online_buffer_capacity: 1000000
  online_buffer_seed_size: 0
  online_step_before_learning: 5000
  do_online_rollout_async: false
  policy_update_freq: 1

  # delta_timestamps:
  #   observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
  #   observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
  #   action: "[i / ${fps} for i in range(${policy.horizon})]"
  #   next.reward: "[i / ${fps} for i in range(${policy.horizon})]"

policy:
  name: sac

  pretrained_model_path:

  # Input / output structure.
  n_action_repeats: 1
  horizon: 1
  n_action_steps: 1

  shared_encoder: true
  input_shapes:
    # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.state: ["${env.state_dim}"]
    observation.image: [3, 64, 64]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes: null
  output_normalization_modes:
    action: min_max
  output_normalization_params:
    action:
      min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
      max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

  # Architecture / modeling.
  # Neural networks.
  image_encoder_hidden_dim: 32
  # discount: 0.99
  discount: 0.80
  temperature_init: 1.0
  num_critics: 2
  num_subsample_critics: null
  critic_lr: 3e-4
  actor_lr: 3e-4
  temperature_lr: 3e-4
  # critic_target_update_weight: 0.005
  critic_target_update_weight: 0.01
  utd_ratio: 1

  # # Loss coefficients.
  # reward_coeff: 0.5
  # expectile_weight: 0.9
  # value_coeff: 0.1
  # consistency_coeff: 20.0
  # advantage_scaling: 3.0
  # pi_coeff: 0.5
  # temporal_decay_coeff: 0.5
  # # Target model.
  # target_model_momentum: 0.995
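For reference, the `${env.state_dim}` and `${env.action_dim}` entries are Hydra/OmegaConf interpolations resolved against the selected env config. A minimal sketch of that resolution, with assumed values matching the 25-dimensional state and 7-dimensional action discussed above:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "env": {"state_dim": 25, "action_dim": 7},
        "policy": {
            "input_shapes": {"observation.state": ["${env.state_dim}"]},
            "output_shapes": {"action": ["${env.action_dim}"]},
        },
    }
)
print(OmegaConf.to_container(cfg.policy.output_shapes, resolve=True))  # {'action': [7]}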

View File

@ -14,34 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import time
-from contextlib import nullcontext
-from copy import deepcopy
-from pathlib import Path
+import functools
from pprint import pformat
import random
-from typing import Optional, Sequence, TypedDict
+from typing import Optional, Sequence, TypedDict, Callable

import hydra
-import numpy as np
import torch
-from deepdiff import DeepDiff
-from omegaconf import DictConfig, ListConfig, OmegaConf
-from termcolor import colored
+import torch.nn.functional as F
from torch import nn
-from torch.cuda.amp import GradScaler
from tqdm import tqdm
+from deepdiff import DeepDiff
+from omegaconf import DictConfig, OmegaConf

-from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
-from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset, LeRobotDataset
-from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
-from lerobot.common.datasets.sampler import EpisodeAwareSampler
-from lerobot.common.datasets.utils import cycle
-from lerobot.common.envs.factory import make_env
-from lerobot.common.envs.utils import preprocess_observation
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+# TODO: Remove the import of maniskill
+from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.envs.factory import make_env, make_maniskill_env
+from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
-from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
@ -56,7 +49,8 @@ from lerobot.scripts.eval import eval_policy
def make_optimizers_and_scheduler(cfg, policy):
    optimizer_actor = torch.optim.Adam(
-        params=policy.actor.parameters(),
+        # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
+        params=policy.actor.parameters_to_optimize,
        lr=policy.config.actor_lr,
    )
    optimizer_critic = torch.optim.Adam(
@ -73,11 +67,6 @@ def make_optimizers_and_scheduler(cfg, policy):
    return optimizers, lr_scheduler


-# def update_policy(policy, batch, optimizers, grad_clip_norm):
-# NOTE: This is temporary, online buffer or query lerobot dataset is not performant enough yet
class Transition(TypedDict):
    state: dict[str, torch.Tensor]
    action: torch.Tensor
@ -95,13 +84,62 @@ class BatchTransition(TypedDict):
    done: torch.Tensor


def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
    """
    Perform a per-image random crop over a batch of images in a vectorized way.
    (Same as shown previously.)
    """
    B, C, H, W = images.shape
    crop_h, crop_w = output_size

    if crop_h > H or crop_w > W:
        raise ValueError(
            f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
        )

    tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
    lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)

    rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
    cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)

    rows = rows.unsqueeze(2).expand(-1, -1, crop_w)  # (B, crop_h, crop_w)
    cols = cols.unsqueeze(1).expand(-1, crop_h, -1)  # (B, crop_h, crop_w)

    images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)

    # Gather pixels
    cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
    # cropped_hwcn => (B, crop_h, crop_w, C)

    cropped = cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)
    return cropped


def random_shift(images: torch.Tensor, pad: int = 4):
    """Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
    _, _, h, w = images.shape
    images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
    return random_crop_vectorized(images=images, output_size=(h, w))
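The two helpers above implement the DrQ-style augmentation used when sampling from the replay buffer: pad each image by a few pixels with edge replication, then randomly crop back to the original size, which amounts to a small random translation per image. A quick standalone check with illustrative shapes:

import torch

imgs = torch.rand(32, 3, 64, 64)     # a batch of normalized (B, C, H, W) observations
shifted = random_shift(imgs, pad=4)  # pad to 72x72, then randomly crop back to 64x64
assert shifted.shape == imgs.shape   # spatial size preserved; content translated by up to 4 px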
class ReplayBuffer:
-    def __init__(self, capacity: int, device: str = "cuda:0", state_keys: Optional[Sequence[str]] = None):
+    def __init__(
+        self,
+        capacity: int,
+        device: str = "cuda:0",
+        state_keys: Optional[Sequence[str]] = None,
+        image_augmentation_function: Optional[Callable] = None,
+        use_drq: bool = True,
+    ):
        """
        Args:
            capacity (int): Maximum number of transitions to store in the buffer.
            device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
            state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
+            image_augmentation_function (Optional[Callable]): A function that takes a batch of images
+                and returns a batch of augmented images. If None, a default augmentation function is used.
+            use_drq (bool): Whether to apply the default DrQ image augmentation when sampling from the buffer.
        """
        self.capacity = capacity
        self.device = device
@ -111,6 +149,9 @@ class ReplayBuffer:
        # If no state_keys provided, default to an empty list
        # (you can handle this differently if needed)
        self.state_keys = state_keys if state_keys is not None else []
        if image_augmentation_function is None:
            self.image_augmentation_function = functools.partial(random_shift, pad=4)
        self.use_drq = use_drq
    def add(
        self,
@ -134,7 +175,7 @@ class ReplayBuffer:
            done=done,
            complementary_info=complementary_info,
        )
-        self.position = (self.position + 1) % self.capacity
+        self.position: int = (self.position + 1) % self.capacity
    @classmethod
    def from_lerobot_dataset(
@ -143,6 +184,18 @@ class ReplayBuffer:
        device: str = "cuda:0",
        state_keys: Optional[Sequence[str]] = None,
    ) -> "ReplayBuffer":
"""
Convert a LeRobotDataset into a ReplayBuffer.
Args:
lerobot_dataset (LeRobotDataset): The dataset to convert.
device (str): The device . Defaults to "cuda:0".
state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`.
Defaults to None.
Returns:
ReplayBuffer: The replay buffer with offline dataset transitions.
"""
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from # We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
# a replay buffer than from a lerobot dataset. # a replay buffer than from a lerobot dataset.
replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys) replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
@ -248,6 +301,8 @@ class ReplayBuffer:
            batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
                self.device
            )
            if key.startswith("observation.image") and self.use_drq:
                batch_state[key] = self.image_augmentation_function(batch_state[key])

        # -- Build batched actions --
        batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
@ -263,6 +318,8 @@ class ReplayBuffer:
            batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
                self.device
            )
            if key.startswith("observation.image") and self.use_drq:
                batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])

        # -- Build batched dones --
        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
@ -285,7 +342,7 @@ class ReplayBuffer:
def concatenate_batch_transitions(
    left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
) -> BatchTransition:
-    """Be careful it change the left_batch_transitions in place"""
+    """NOTE: Be careful, it changes the left_batch_transitions in place"""
    left_batch_transitions["state"] = {
        key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
        for key in left_batch_transitions["state"]
@ -321,11 +378,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
    # online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
    # NOTE: Off-policy algorithms are efficient enough to use a single environment
    logging.info("make_env online")
-    online_env = make_env(cfg, n_envs=1)
+    # online_env = make_env(cfg, n_envs=1)
+    # TODO: Remove the import of maniskill and unify with make_env
+    online_env = make_maniskill_env(cfg, n_envs=1)
    if cfg.training.eval_freq > 0:
        logging.info("make_env eval")
-        eval_env = make_env(cfg, n_envs=1)
+        # eval_env = make_env(cfg, n_envs=1)
+        # TODO: Remove the import of maniskill and unify with make_env
+        eval_env = make_maniskill_env(cfg, n_envs=1)
    # TODO: Add a way to resume training
@ -348,6 +408,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
        # Hack: But if we do online training, we do not need dataset_stats
        dataset_stats=None,
        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
        device=device,
    )
    assert isinstance(policy, nn.Module)
@ -360,17 +421,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
    log_output_dir(out_dir)
    logging.info(f"{cfg.env.task=}")
-    # TODO: Handle offline steps
-    # logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
    logging.info(f"{cfg.training.online_steps=}")
-    # logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
-    # logging.info(f"{offline_dataset.num_episodes=}")
    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

    obs, info = online_env.reset()

-    obs = preprocess_observation(obs)
+    # HACK for maniskill
+    # obs = preprocess_observation(obs)
+    obs = preprocess_maniskill_observation(obs)
    obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
    replay_buffer = ReplayBuffer(
@ -378,8 +437,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
    )
    batch_size = cfg.training.batch_size

-    # if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
-    #     raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
    if cfg.dataset_repo_id is not None:
        logging.info("make_dataset offline buffer")
        offline_dataset = make_dataset(cfg)
@ -404,7 +462,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
        # HACK
        action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
-        next_obs = preprocess_observation(next_obs)
+        # HACK: For maniskill
+        # next_obs = preprocess_observation(next_obs)
+        next_obs = preprocess_maniskill_observation(next_obs)
        next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
        sum_reward_episode += float(reward[0])
        # Because we are using a single environment
@ -413,16 +473,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
            logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
            logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
            sum_reward_episode = 0
-            if "final_info" in info:
-                if "is_success" in info["final_info"][0]:
-                    logging.info(
-                        f"Global step {interaction_step}: Episode success: {info['final_info'][0]['is_success']}"
-                    )
-                if "coverage" in info["final_info"][0]:
-                    logging.info(
-                        f"Global step {interaction_step}: Episode final coverage: {info['final_info'][0]['coverage']} \n"
-                    )
-                    logger.log_dict({"Final coverage": info["final_info"][0]["coverage"]}, interaction_step)
+            # HACK: This is for maniskill
+            logging.info(
+                f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
+            )
+            logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step)
        replay_buffer.add(
            state=obs,
@ -433,38 +488,13 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
        )
        obs = next_obs
-        if interaction_step >= cfg.training.online_step_before_learning:
-            for _ in range(cfg.policy.utd_ratio - 1):
-                batch = replay_buffer.sample(batch_size)
-                if cfg.dataset_repo_id is not None:
-                    batch_offline = offline_replay_buffer.sample(batch_size)
-                    batch = concatenate_batch_transitions(batch, batch_offline)
-                actions = batch["action"]
-                rewards = batch["reward"]
-                observations = batch["state"]
-                next_observations = batch["next_state"]
-                done = batch["done"]
-                loss_critic = policy.compute_loss_critic(
-                    observations=observations,
-                    actions=actions,
-                    rewards=rewards,
-                    next_observations=next_observations,
-                    done=done,
-                )
-                optimizers["critic"].zero_grad()
-                loss_critic.backward()
-                optimizers["critic"].step()
-            batch = replay_buffer.sample(batch_size)
-            if cfg.dataset_repo_id is not None:
-                batch_offline = offline_replay_buffer.sample(batch_size)
-                batch = concatenate_batch_transitions(
-                    left_batch_transitions=batch, right_batch_transition=batch_offline
-                )
-            # NOTE: We have to handle the normalization for the batch
-            # batch = policy.normalize_inputs(batch)
-            actions = batch["action"]
-            rewards = batch["reward"]
+        if interaction_step < cfg.training.online_step_before_learning:
+            continue
+        for _ in range(cfg.policy.utd_ratio - 1):
+            batch = replay_buffer.sample(batch_size)
+            if cfg.dataset_repo_id is not None:
+                batch_offline = offline_replay_buffer.sample(batch_size)
+                batch = concatenate_batch_transitions(batch, batch_offline)
+            actions = batch["action"]
+            rewards = batch["reward"]
@ -483,31 +513,55 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
-            loss_critic.backward()
-            optimizers["critic"].step()
-            training_infos = {}
-            training_infos["loss_critic"] = loss_critic.item()
-            if interaction_step % cfg.training.policy_update_freq == 0:
-                # TD3 Trick
-                for _ in range(cfg.training.policy_update_freq):
-                    loss_actor = policy.compute_loss_actor(observations=observations)
-                    optimizers["actor"].zero_grad()
-                    loss_actor.backward()
-                    optimizers["actor"].step()
-                    training_infos["loss_actor"] = loss_actor.item()
-                    loss_temperature = policy.compute_loss_temperature(observations=observations)
-                    optimizers["temperature"].zero_grad()
-                    loss_temperature.backward()
-                    optimizers["temperature"].step()
-                    training_infos["loss_temperature"] = loss_temperature.item()
-            if interaction_step % cfg.training.log_freq == 0:
-                logger.log_dict(training_infos, interaction_step, mode="train")
-            policy.update_target_networks()
+            loss_critic.backward()
+            optimizers["critic"].step()
+        batch = replay_buffer.sample(batch_size)
+        if cfg.dataset_repo_id is not None:
+            batch_offline = offline_replay_buffer.sample(batch_size)
+            batch = concatenate_batch_transitions(
+                left_batch_transitions=batch, right_batch_transition=batch_offline
+            )
+        actions = batch["action"]
+        rewards = batch["reward"]
+        observations = batch["state"]
+        next_observations = batch["next_state"]
+        done = batch["done"]
+        loss_critic = policy.compute_loss_critic(
+            observations=observations,
+            actions=actions,
+            rewards=rewards,
+            next_observations=next_observations,
+            done=done,
+        )
+        optimizers["critic"].zero_grad()
+        loss_critic.backward()
+        optimizers["critic"].step()
+        training_infos = {}
+        training_infos["loss_critic"] = loss_critic.item()
+        if interaction_step % cfg.training.policy_update_freq == 0:
+            # TD3 Trick
+            for _ in range(cfg.training.policy_update_freq):
+                loss_actor = policy.compute_loss_actor(observations=observations)
+                optimizers["actor"].zero_grad()
+                loss_actor.backward()
+                optimizers["actor"].step()
+                training_infos["loss_actor"] = loss_actor.item()
+                loss_temperature = policy.compute_loss_temperature(observations=observations)
+                optimizers["temperature"].zero_grad()
+                loss_temperature.backward()
+                optimizers["temperature"].step()
+                training_infos["loss_temperature"] = loss_temperature.item()
+        if interaction_step % cfg.training.log_freq == 0:
+            logger.log_dict(training_infos, interaction_step, mode="train")
+        policy.update_target_networks()
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")