Stable version of rlpd + drq
This commit is contained in:
parent
5b92465e38
commit
83dc00683c
|
@ -14,6 +14,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import importlib
|
import importlib
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
|
@ -67,3 +68,86 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
||||||
)
|
)
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||||
|
"""Make ManiSkill3 gym environment"""
|
||||||
|
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||||
|
|
||||||
|
env = gym.make(
|
||||||
|
cfg.env.task,
|
||||||
|
obs_mode=cfg.env.obs,
|
||||||
|
control_mode=cfg.env.control_mode,
|
||||||
|
render_mode=cfg.env.render_mode,
|
||||||
|
sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
|
||||||
|
num_envs=n_envs,
|
||||||
|
)
|
||||||
|
# cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
|
||||||
|
env = ManiSkillVectorEnv(env, ignore_terminations=True)
|
||||||
|
# state should have the size of 25
|
||||||
|
# env = ConvertToLeRobotEnv(env, n_envs)
|
||||||
|
# env = PixelWrapper(cfg, env, n_envs)
|
||||||
|
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
|
||||||
|
env.unwrapped.metadata["render_fps"] = 20
|
||||||
|
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
class PixelWrapper(gym.Wrapper):
|
||||||
|
"""
|
||||||
|
Wrapper for pixel observations. Works with Maniskill vectorized environments
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg, env, num_envs, num_frames=3):
|
||||||
|
super().__init__(env)
|
||||||
|
self.cfg = cfg
|
||||||
|
self.env = env
|
||||||
|
self.observation_space = gym.spaces.Box(
|
||||||
|
low=0,
|
||||||
|
high=255,
|
||||||
|
shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
|
||||||
|
dtype=np.uint8,
|
||||||
|
)
|
||||||
|
self._frames = deque([], maxlen=num_frames)
|
||||||
|
self._render_size = cfg.env.render_size
|
||||||
|
|
||||||
|
def _get_obs(self, obs):
|
||||||
|
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
|
||||||
|
self._frames.append(frame)
|
||||||
|
return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
|
||||||
|
|
||||||
|
def reset(self, seed):
|
||||||
|
obs, info = self.env.reset() # (seed=seed)
|
||||||
|
for _ in range(self._frames.maxlen):
|
||||||
|
obs_frames = self._get_obs(obs)
|
||||||
|
return obs_frames, info
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||||
|
return self._get_obs(obs), reward, terminated, truncated, info
|
||||||
|
|
||||||
|
class ConvertToLeRobotEnv(gym.Wrapper):
|
||||||
|
def __init__(self, env, num_envs):
|
||||||
|
super().__init__(env)
|
||||||
|
def reset(self, seed=None, options=None):
|
||||||
|
obs, info = self.env.reset(seed=seed, options={})
|
||||||
|
return self._get_obs(obs), info
|
||||||
|
def step(self, action):
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||||
|
return self._get_obs(obs), reward, terminated, truncated, info
|
||||||
|
def _get_obs(self, observation):
|
||||||
|
sensor_data = observation.pop("sensor_data")
|
||||||
|
del observation["sensor_param"]
|
||||||
|
images = []
|
||||||
|
for cam_data in sensor_data.values():
|
||||||
|
images.append(cam_data["rgb"])
|
||||||
|
|
||||||
|
images = torch.concat(images, axis=-1)
|
||||||
|
# flatten the rest of the data which should just be state data
|
||||||
|
observation = common.flatten_state_dict(
|
||||||
|
observation, use_torch=True, device=self.base_env.device
|
||||||
|
)
|
||||||
|
ret = dict()
|
||||||
|
ret["state"] = observation
|
||||||
|
ret["pixels"] = images
|
||||||
|
return ret
|
|
@ -33,6 +33,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||||
"""
|
"""
|
||||||
# map to expected inputs for the policy
|
# map to expected inputs for the policy
|
||||||
return_observations = {}
|
return_observations = {}
|
||||||
|
# TODO: You have to merge all tensors from agent key and extra key
|
||||||
|
# You don't keep sensor param key in the observation
|
||||||
|
# And you keep sensor data rgb
|
||||||
if "pixels" in observations:
|
if "pixels" in observations:
|
||||||
if isinstance(observations["pixels"], dict):
|
if isinstance(observations["pixels"], dict):
|
||||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||||
|
@ -56,6 +59,8 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||||
img /= 255
|
img /= 255
|
||||||
|
|
||||||
return_observations[imgkey] = img
|
return_observations[imgkey] = img
|
||||||
|
# obs state agent qpos and qvel
|
||||||
|
# image
|
||||||
|
|
||||||
if "environment_state" in observations:
|
if "environment_state" in observations:
|
||||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||||
|
@ -86,3 +91,38 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||||
policy_features[policy_key] = feature
|
policy_features[policy_key] = feature
|
||||||
|
|
||||||
return policy_features
|
return policy_features
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||||
|
"""Convert environment observation to LeRobot format observation.
|
||||||
|
Args:
|
||||||
|
observation: Dictionary of observation batches from a Gym vector environment.
|
||||||
|
Returns:
|
||||||
|
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||||
|
"""
|
||||||
|
# map to expected inputs for the policy
|
||||||
|
return_observations = {}
|
||||||
|
# TODO: You have to merge all tensors from agent key and extra key
|
||||||
|
# You don't keep sensor param key in the observation
|
||||||
|
# And you keep sensor data rgb
|
||||||
|
q_pos = observations["agent"]["qpos"]
|
||||||
|
q_vel = observations["agent"]["qvel"]
|
||||||
|
tcp_pos = observations["extra"]["tcp_pose"]
|
||||||
|
img = observations["sensor_data"]["base_camera"]["rgb"]
|
||||||
|
|
||||||
|
_, h, w, c = img.shape
|
||||||
|
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||||
|
|
||||||
|
# sanity check that images are uint8
|
||||||
|
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||||
|
|
||||||
|
# convert to channel first of type float32 in range [0,1]
|
||||||
|
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||||
|
img = img.type(torch.float32)
|
||||||
|
img /= 255
|
||||||
|
|
||||||
|
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
|
||||||
|
|
||||||
|
return_observations["observation.image"] = img
|
||||||
|
return_observations["observation.state"] = state
|
||||||
|
return return_observations
|
||||||
|
|
|
@ -19,34 +19,6 @@ from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SACConfig:
|
|
||||||
input_shapes: dict[str, list[int]] = field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
"observation.image": [3, 84, 84],
|
|
||||||
"observation.state": [4],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
output_shapes: dict[str, list[int]] = field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
"action": [2],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
|
||||||
input_normalization_modes: dict[str, str] = field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
"observation.image": "mean_std",
|
|
||||||
"observation.state": "min_max",
|
|
||||||
"observation.environment_state": "min_max",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
output_normalization_modes: dict[str, str] = field(
|
|
||||||
default_factory=lambda: {"action": "min_max"},
|
|
||||||
)
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SACConfig:
|
class SACConfig:
|
||||||
input_shapes: dict[str, list[int]] = field(
|
input_shapes: dict[str, list[int]] = field(
|
||||||
|
@ -67,10 +39,13 @@ class SACConfig:
|
||||||
"observation.environment_state": "min_max",
|
"observation.environment_state": "min_max",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
output_normalization_modes: dict[str, str] = field(
|
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||||
default_factory=lambda: {"action": "min_max"}
|
output_normalization_params: dict[str, dict[str, list[float]]] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"action": {"min": [-1, -1], "max": [1, 1]},
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
camera_number: int = 1
|
||||||
# Add type annotations for these fields:
|
# Add type annotations for these fields:
|
||||||
image_encoder_hidden_dim: int = 32
|
image_encoder_hidden_dim: int = 32
|
||||||
shared_encoder: bool = False
|
shared_encoder: bool = False
|
||||||
|
|
|
@ -42,37 +42,31 @@ class SACPolicy(
|
||||||
name = "sac"
|
name = "sac"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
self,
|
||||||
|
config: SACConfig | None = None,
|
||||||
|
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
device: str = "cpu",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = SACConfig()
|
config = SACConfig()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
if config.input_normalization_modes is not None:
|
if config.input_normalization_modes is not None:
|
||||||
self.normalize_inputs = Normalize(
|
self.normalize_inputs = Normalize(
|
||||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.normalize_inputs = nn.Identity()
|
self.normalize_inputs = nn.Identity()
|
||||||
# HACK: we need to pass the dataset_stats to the normalization functions
|
|
||||||
|
|
||||||
# NOTE: This is for biwalker environment
|
output_normalization_params = {}
|
||||||
dataset_stats = dataset_stats or {
|
for outer_key, inner_dict in config.output_normalization_params.items():
|
||||||
"action": {
|
output_normalization_params[outer_key] = {}
|
||||||
"min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
|
for key, value in inner_dict.items():
|
||||||
"max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
|
output_normalization_params[outer_key][key] = torch.tensor(value)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# NOTE: This is for pusht environment
|
# HACK: This is hacky and should be removed
|
||||||
# dataset_stats = dataset_stats or {
|
dataset_stats = dataset_stats or output_normalization_params
|
||||||
# "action": {
|
|
||||||
# "min": torch.tensor([0, 0]),
|
|
||||||
# "max": torch.tensor([512, 512]),
|
|
||||||
# }
|
|
||||||
# }
|
|
||||||
self.normalize_targets = Normalize(
|
self.normalize_targets = Normalize(
|
||||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||||
)
|
)
|
||||||
|
@ -82,7 +76,7 @@ class SACPolicy(
|
||||||
|
|
||||||
if config.shared_encoder:
|
if config.shared_encoder:
|
||||||
encoder_critic = SACObservationEncoder(config)
|
encoder_critic = SACObservationEncoder(config)
|
||||||
encoder_actor = encoder_critic
|
encoder_actor: SACObservationEncoder = encoder_critic
|
||||||
else:
|
else:
|
||||||
encoder_critic = SACObservationEncoder(config)
|
encoder_critic = SACObservationEncoder(config)
|
||||||
encoder_actor = SACObservationEncoder(config)
|
encoder_actor = SACObservationEncoder(config)
|
||||||
|
@ -95,6 +89,7 @@ class SACPolicy(
|
||||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||||
**config.critic_network_kwargs,
|
**config.critic_network_kwargs,
|
||||||
),
|
),
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
critic_nets.append(critic_net)
|
critic_nets.append(critic_net)
|
||||||
|
|
||||||
|
@ -106,40 +101,35 @@ class SACPolicy(
|
||||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||||
**config.critic_network_kwargs,
|
**config.critic_network_kwargs,
|
||||||
),
|
),
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
target_critic_nets.append(target_critic_net)
|
target_critic_nets.append(target_critic_net)
|
||||||
|
|
||||||
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
|
self.critic_ensemble = create_critic_ensemble(
|
||||||
self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
|
critics=critic_nets, num_critics=config.num_critics, device=device
|
||||||
|
)
|
||||||
|
self.critic_target = create_critic_ensemble(
|
||||||
|
critics=target_critic_nets, num_critics=config.num_critics, device=device
|
||||||
|
)
|
||||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||||
|
|
||||||
self.actor = Policy(
|
self.actor = Policy(
|
||||||
encoder=encoder_actor,
|
encoder=encoder_actor,
|
||||||
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
|
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
|
||||||
action_dim=config.output_shapes["action"][0],
|
action_dim=config.output_shapes["action"][0],
|
||||||
|
device=device,
|
||||||
|
encoder_is_shared=config.shared_encoder,
|
||||||
**config.policy_kwargs,
|
**config.policy_kwargs,
|
||||||
)
|
)
|
||||||
if config.target_entropy is None:
|
if config.target_entropy is None:
|
||||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
|
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
|
||||||
# TODO: fix later device
|
|
||||||
# TODO: Handle the case where the temparameter is a fixed
|
# TODO: Handle the case where the temparameter is a fixed
|
||||||
self.log_alpha = torch.zeros(1, requires_grad=True, device="cpu")
|
self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
|
||||||
self.temperature = self.log_alpha.exp().item()
|
self.temperature = self.log_alpha.exp().item()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""Reset the policy"""
|
||||||
Clear observation and action queues. Should be called on `env.reset()`
|
pass
|
||||||
queues are populated during rollout of the policy, they contain the n latest observations and actions
|
|
||||||
"""
|
|
||||||
|
|
||||||
self._queues = {
|
|
||||||
"observation.state": deque(maxlen=1),
|
|
||||||
"action": deque(maxlen=1),
|
|
||||||
}
|
|
||||||
if "observation.image" in self.config.input_shapes:
|
|
||||||
self._queues["observation.image"] = deque(maxlen=1)
|
|
||||||
if "observation.environment_state" in self.config.input_shapes:
|
|
||||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
|
@ -334,6 +324,7 @@ class Policy(nn.Module):
|
||||||
init_final: Optional[float] = None,
|
init_final: Optional[float] = None,
|
||||||
use_tanh_squash: bool = False,
|
use_tanh_squash: bool = False,
|
||||||
device: str = "cpu",
|
device: str = "cpu",
|
||||||
|
encoder_is_shared: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = torch.device(device)
|
self.device = torch.device(device)
|
||||||
|
@ -344,7 +335,12 @@ class Policy(nn.Module):
|
||||||
self.log_std_max = log_std_max
|
self.log_std_max = log_std_max
|
||||||
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
||||||
self.use_tanh_squash = use_tanh_squash
|
self.use_tanh_squash = use_tanh_squash
|
||||||
|
self.parameters_to_optimize = []
|
||||||
|
|
||||||
|
self.parameters_to_optimize += list(self.network.parameters())
|
||||||
|
|
||||||
|
if self.encoder is not None and not encoder_is_shared:
|
||||||
|
self.parameters_to_optimize += list(self.encoder.parameters())
|
||||||
# Find the last Linear layer's output dimension
|
# Find the last Linear layer's output dimension
|
||||||
for layer in reversed(network.net):
|
for layer in reversed(network.net):
|
||||||
if isinstance(layer, nn.Linear):
|
if isinstance(layer, nn.Linear):
|
||||||
|
@ -358,6 +354,7 @@ class Policy(nn.Module):
|
||||||
else:
|
else:
|
||||||
orthogonal_init()(self.mean_layer.weight)
|
orthogonal_init()(self.mean_layer.weight)
|
||||||
|
|
||||||
|
self.parameters_to_optimize += list(self.mean_layer.parameters())
|
||||||
# Standard deviation layer or parameter
|
# Standard deviation layer or parameter
|
||||||
if fixed_std is None:
|
if fixed_std is None:
|
||||||
self.std_layer = nn.Linear(out_features, action_dim)
|
self.std_layer = nn.Linear(out_features, action_dim)
|
||||||
|
@ -366,6 +363,7 @@ class Policy(nn.Module):
|
||||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||||
else:
|
else:
|
||||||
orthogonal_init()(self.std_layer.weight)
|
orthogonal_init()(self.std_layer.weight)
|
||||||
|
self.parameters_to_optimize += list(self.std_layer.parameters())
|
||||||
|
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
|
||||||
|
@ -428,44 +426,78 @@ class SACObservationEncoder(nn.Module):
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
if "observation.image" in config.input_shapes:
|
if "observation.image" in config.input_shapes:
|
||||||
self.image_enc_layers = nn.Sequential(
|
self.image_enc_layers = nn.Sequential(
|
||||||
nn.Conv2d(
|
nn.Conv2d(
|
||||||
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
|
in_channels=config.input_shapes["observation.image"][0],
|
||||||
|
out_channels=config.image_encoder_hidden_dim,
|
||||||
|
kernel_size=7,
|
||||||
|
stride=2,
|
||||||
),
|
),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
|
nn.Conv2d(
|
||||||
|
in_channels=config.image_encoder_hidden_dim,
|
||||||
|
out_channels=config.image_encoder_hidden_dim,
|
||||||
|
kernel_size=5,
|
||||||
|
stride=2,
|
||||||
|
),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
nn.Conv2d(
|
||||||
|
in_channels=config.image_encoder_hidden_dim,
|
||||||
|
out_channels=config.image_encoder_hidden_dim,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
nn.Conv2d(
|
||||||
|
in_channels=config.image_encoder_hidden_dim,
|
||||||
|
out_channels=config.image_encoder_hidden_dim,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
)
|
)
|
||||||
|
self.camera_number = config.camera_number
|
||||||
|
self.aggregation_size: int = 0
|
||||||
|
|
||||||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||||
self.image_enc_layers.extend(
|
self.image_enc_layers.extend(
|
||||||
nn.Sequential(
|
sequential=nn.Sequential(
|
||||||
nn.Flatten(),
|
nn.Flatten(),
|
||||||
nn.Linear(np.prod(out_shape), config.latent_dim),
|
nn.Linear(
|
||||||
nn.LayerNorm(config.latent_dim),
|
in_features=np.prod(out_shape) * self.camera_number, out_features=config.latent_dim
|
||||||
|
),
|
||||||
|
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||||
nn.Tanh(),
|
nn.Tanh(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.aggregation_size += config.latent_dim * self.camera_number
|
||||||
if "observation.state" in config.input_shapes:
|
if "observation.state" in config.input_shapes:
|
||||||
self.state_enc_layers = nn.Sequential(
|
self.state_enc_layers = nn.Sequential(
|
||||||
nn.Linear(config.input_shapes["observation.state"][0], config.latent_dim),
|
nn.Linear(
|
||||||
nn.LayerNorm(config.latent_dim),
|
in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim
|
||||||
|
),
|
||||||
|
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||||
nn.Tanh(),
|
nn.Tanh(),
|
||||||
)
|
)
|
||||||
|
self.aggregation_size += config.latent_dim
|
||||||
|
|
||||||
if "observation.environment_state" in config.input_shapes:
|
if "observation.environment_state" in config.input_shapes:
|
||||||
self.env_state_enc_layers = nn.Sequential(
|
self.env_state_enc_layers = nn.Sequential(
|
||||||
nn.Linear(config.input_shapes["observation.environment_state"][0], config.latent_dim),
|
nn.Linear(
|
||||||
nn.LayerNorm(config.latent_dim),
|
in_features=config.input_shapes["observation.environment_state"][0],
|
||||||
|
out_features=config.latent_dim,
|
||||||
|
),
|
||||||
|
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||||
nn.Tanh(),
|
nn.Tanh(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.aggregation_size += config.latent_dim
|
||||||
|
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
|
||||||
|
|
||||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||||
"""Encode the image and/or state vector.
|
"""Encode the image and/or state vector.
|
||||||
|
|
||||||
|
@ -482,7 +514,11 @@ class SACObservationEncoder(nn.Module):
|
||||||
if "observation.state" in self.config.input_shapes:
|
if "observation.state" in self.config.input_shapes:
|
||||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||||
# TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
|
# TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
|
||||||
return torch.stack(feat, dim=0).mean(0)
|
# return torch.stack(feat, dim=0).mean(0)
|
||||||
|
features = torch.cat(tensors=feat, dim=-1)
|
||||||
|
features = self.aggregation_layer(features)
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_dim(self) -> int:
|
def output_dim(self) -> int:
|
||||||
|
|
|
@ -0,0 +1,97 @@
|
||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# Train with:
|
||||||
|
#
|
||||||
|
# python lerobot/scripts/train.py \
|
||||||
|
# +dataset=lerobot/pusht_keypoints
|
||||||
|
# env=pusht \
|
||||||
|
# env.gym.obs_type=environment_state_agent_pos \
|
||||||
|
|
||||||
|
seed: 1
|
||||||
|
dataset_repo_id: null
|
||||||
|
|
||||||
|
|
||||||
|
training:
|
||||||
|
# Offline training dataloader
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
# batch_size: 256
|
||||||
|
batch_size: 512
|
||||||
|
grad_clip_norm: 10.0
|
||||||
|
lr: 3e-4
|
||||||
|
|
||||||
|
eval_freq: 2500
|
||||||
|
log_freq: 500
|
||||||
|
save_freq: 50000
|
||||||
|
|
||||||
|
online_steps: 1000000
|
||||||
|
online_rollout_n_episodes: 10
|
||||||
|
online_rollout_batch_size: 10
|
||||||
|
online_steps_between_rollouts: 1000
|
||||||
|
online_sampling_ratio: 1.0
|
||||||
|
online_env_seed: 10000
|
||||||
|
online_buffer_capacity: 1000000
|
||||||
|
online_buffer_seed_size: 0
|
||||||
|
online_step_before_learning: 5000
|
||||||
|
do_online_rollout_async: false
|
||||||
|
policy_update_freq: 1
|
||||||
|
|
||||||
|
# delta_timestamps:
|
||||||
|
# observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||||
|
# observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||||
|
# action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||||
|
# next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||||
|
|
||||||
|
policy:
|
||||||
|
name: sac
|
||||||
|
|
||||||
|
pretrained_model_path:
|
||||||
|
|
||||||
|
# Input / output structure.
|
||||||
|
n_action_repeats: 1
|
||||||
|
horizon: 1
|
||||||
|
n_action_steps: 1
|
||||||
|
|
||||||
|
shared_encoder: true
|
||||||
|
input_shapes:
|
||||||
|
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||||
|
observation.state: ["${env.state_dim}"]
|
||||||
|
observation.image: [3, 64, 64]
|
||||||
|
output_shapes:
|
||||||
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
|
# Normalization / Unnormalization
|
||||||
|
input_normalization_modes: null
|
||||||
|
output_normalization_modes:
|
||||||
|
action: min_max
|
||||||
|
output_normalization_params:
|
||||||
|
action:
|
||||||
|
min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
|
||||||
|
max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
|
||||||
|
|
||||||
|
# Architecture / modeling.
|
||||||
|
# Neural networks.
|
||||||
|
image_encoder_hidden_dim: 32
|
||||||
|
# discount: 0.99
|
||||||
|
discount: 0.80
|
||||||
|
temperature_init: 1.0
|
||||||
|
num_critics: 2
|
||||||
|
num_subsample_critics: null
|
||||||
|
critic_lr: 3e-4
|
||||||
|
actor_lr: 3e-4
|
||||||
|
temperature_lr: 3e-4
|
||||||
|
# critic_target_update_weight: 0.005
|
||||||
|
critic_target_update_weight: 0.01
|
||||||
|
utd_ratio: 1
|
||||||
|
|
||||||
|
|
||||||
|
# # Loss coefficients.
|
||||||
|
# reward_coeff: 0.5
|
||||||
|
# expectile_weight: 0.9
|
||||||
|
# value_coeff: 0.1
|
||||||
|
# consistency_coeff: 20.0
|
||||||
|
# advantage_scaling: 3.0
|
||||||
|
# pi_coeff: 0.5
|
||||||
|
# temporal_decay_coeff: 0.5
|
||||||
|
# # Target model.
|
||||||
|
# target_model_momentum: 0.995
|
|
@ -14,34 +14,27 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import time
|
import functools
|
||||||
from contextlib import nullcontext
|
|
||||||
from copy import deepcopy
|
|
||||||
from pathlib import Path
|
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
import random
|
import random
|
||||||
from typing import Optional, Sequence, TypedDict
|
from typing import Optional, Sequence, TypedDict, Callable
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from deepdiff import DeepDiff
|
import torch.nn.functional as F
|
||||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
|
||||||
from termcolor import colored
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.cuda.amp import GradScaler
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from deepdiff import DeepDiff
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset, LeRobotDataset
|
|
||||||
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
|
# TODO: Remove the import of maniskill
|
||||||
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.datasets.utils import cycle
|
from lerobot.common.envs.factory import make_env, make_maniskill_env
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
|
||||||
from lerobot.common.envs.utils import preprocess_observation
|
|
||||||
from lerobot.common.logger import Logger, log_output_dir
|
from lerobot.common.logger import Logger, log_output_dir
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
|
|
||||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||||
from lerobot.common.policies.utils import get_device_from_parameters
|
from lerobot.common.policies.utils import get_device_from_parameters
|
||||||
from lerobot.common.utils.utils import (
|
from lerobot.common.utils.utils import (
|
||||||
|
@ -56,7 +49,8 @@ from lerobot.scripts.eval import eval_policy
|
||||||
|
|
||||||
def make_optimizers_and_scheduler(cfg, policy):
|
def make_optimizers_and_scheduler(cfg, policy):
|
||||||
optimizer_actor = torch.optim.Adam(
|
optimizer_actor = torch.optim.Adam(
|
||||||
params=policy.actor.parameters(),
|
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
||||||
|
params=policy.actor.parameters_to_optimize,
|
||||||
lr=policy.config.actor_lr,
|
lr=policy.config.actor_lr,
|
||||||
)
|
)
|
||||||
optimizer_critic = torch.optim.Adam(
|
optimizer_critic = torch.optim.Adam(
|
||||||
|
@ -73,11 +67,6 @@ def make_optimizers_and_scheduler(cfg, policy):
|
||||||
return optimizers, lr_scheduler
|
return optimizers, lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
# def update_policy(policy, batch, optimizers, grad_clip_norm):
|
|
||||||
|
|
||||||
# NOTE: This is temporary, online buffer or query lerobot dataset is not performant enough yet
|
|
||||||
|
|
||||||
|
|
||||||
class Transition(TypedDict):
|
class Transition(TypedDict):
|
||||||
state: dict[str, torch.Tensor]
|
state: dict[str, torch.Tensor]
|
||||||
action: torch.Tensor
|
action: torch.Tensor
|
||||||
|
@ -95,13 +84,62 @@ class BatchTransition(TypedDict):
|
||||||
done: torch.Tensor
|
done: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Perform a per-image random crop over a batch of images in a vectorized way.
|
||||||
|
(Same as shown previously.)
|
||||||
|
"""
|
||||||
|
B, C, H, W = images.shape
|
||||||
|
crop_h, crop_w = output_size
|
||||||
|
|
||||||
|
if crop_h > H or crop_w > W:
|
||||||
|
raise ValueError(
|
||||||
|
f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
|
||||||
|
)
|
||||||
|
|
||||||
|
tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
|
||||||
|
lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
|
||||||
|
|
||||||
|
rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
|
||||||
|
cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)
|
||||||
|
|
||||||
|
rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w)
|
||||||
|
cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w)
|
||||||
|
|
||||||
|
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
|
||||||
|
|
||||||
|
# Gather pixels
|
||||||
|
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
|
||||||
|
# cropped_hwcn => (B, crop_h, crop_w, C)
|
||||||
|
|
||||||
|
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
|
||||||
|
return cropped
|
||||||
|
|
||||||
|
|
||||||
|
def random_shift(images: torch.Tensor, pad: int = 4):
|
||||||
|
"""Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
|
||||||
|
_, _, h, w = images.shape
|
||||||
|
images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
|
||||||
|
return random_crop_vectorized(images=images, output_size=(h, w))
|
||||||
|
|
||||||
|
|
||||||
class ReplayBuffer:
|
class ReplayBuffer:
|
||||||
def __init__(self, capacity: int, device: str = "cuda:0", state_keys: Optional[Sequence[str]] = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
capacity: int,
|
||||||
|
device: str = "cuda:0",
|
||||||
|
state_keys: Optional[Sequence[str]] = None,
|
||||||
|
image_augmentation_function: Optional[Callable] = None,
|
||||||
|
use_drq: bool = True,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
capacity (int): Maximum number of transitions to store in the buffer.
|
capacity (int): Maximum number of transitions to store in the buffer.
|
||||||
device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
|
device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
|
||||||
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
|
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
|
||||||
|
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
|
||||||
|
and returns a batch of augmented images. If None, a default augmentation function is used.
|
||||||
|
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
|
||||||
"""
|
"""
|
||||||
self.capacity = capacity
|
self.capacity = capacity
|
||||||
self.device = device
|
self.device = device
|
||||||
|
@ -111,6 +149,9 @@ class ReplayBuffer:
|
||||||
# If no state_keys provided, default to an empty list
|
# If no state_keys provided, default to an empty list
|
||||||
# (you can handle this differently if needed)
|
# (you can handle this differently if needed)
|
||||||
self.state_keys = state_keys if state_keys is not None else []
|
self.state_keys = state_keys if state_keys is not None else []
|
||||||
|
if image_augmentation_function is None:
|
||||||
|
self.image_augmentation_function = functools.partial(random_shift, pad=4)
|
||||||
|
self.use_drq = use_drq
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
|
@ -134,7 +175,7 @@ class ReplayBuffer:
|
||||||
done=done,
|
done=done,
|
||||||
complementary_info=complementary_info,
|
complementary_info=complementary_info,
|
||||||
)
|
)
|
||||||
self.position = (self.position + 1) % self.capacity
|
self.position: int = (self.position + 1) % self.capacity
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_lerobot_dataset(
|
def from_lerobot_dataset(
|
||||||
|
@ -143,6 +184,18 @@ class ReplayBuffer:
|
||||||
device: str = "cuda:0",
|
device: str = "cuda:0",
|
||||||
state_keys: Optional[Sequence[str]] = None,
|
state_keys: Optional[Sequence[str]] = None,
|
||||||
) -> "ReplayBuffer":
|
) -> "ReplayBuffer":
|
||||||
|
"""
|
||||||
|
Convert a LeRobotDataset into a ReplayBuffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lerobot_dataset (LeRobotDataset): The dataset to convert.
|
||||||
|
device (str): The device . Defaults to "cuda:0".
|
||||||
|
state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReplayBuffer: The replay buffer with offline dataset transitions.
|
||||||
|
"""
|
||||||
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
|
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
|
||||||
# a replay buffer than from a lerobot dataset.
|
# a replay buffer than from a lerobot dataset.
|
||||||
replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
|
replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
|
||||||
|
@ -248,6 +301,8 @@ class ReplayBuffer:
|
||||||
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
|
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
|
||||||
self.device
|
self.device
|
||||||
)
|
)
|
||||||
|
if key.startswith("observation.image") and self.use_drq:
|
||||||
|
batch_state[key] = self.image_augmentation_function(batch_state[key])
|
||||||
|
|
||||||
# -- Build batched actions --
|
# -- Build batched actions --
|
||||||
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
|
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
|
||||||
|
@ -263,6 +318,8 @@ class ReplayBuffer:
|
||||||
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
|
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
|
||||||
self.device
|
self.device
|
||||||
)
|
)
|
||||||
|
if key.startswith("observation.image") and self.use_drq:
|
||||||
|
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
|
||||||
|
|
||||||
# -- Build batched dones --
|
# -- Build batched dones --
|
||||||
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||||
|
@ -285,7 +342,7 @@ class ReplayBuffer:
|
||||||
def concatenate_batch_transitions(
|
def concatenate_batch_transitions(
|
||||||
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
|
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
|
||||||
) -> BatchTransition:
|
) -> BatchTransition:
|
||||||
"""Be careful it change the left_batch_transitions in place"""
|
"""NOTE: Be careful it change the left_batch_transitions in place"""
|
||||||
left_batch_transitions["state"] = {
|
left_batch_transitions["state"] = {
|
||||||
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
|
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
|
||||||
for key in left_batch_transitions["state"]
|
for key in left_batch_transitions["state"]
|
||||||
|
@ -321,11 +378,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
# online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
|
# online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
|
||||||
# NOTE: Off policy algorithm are efficient enought to use a single environment
|
# NOTE: Off policy algorithm are efficient enought to use a single environment
|
||||||
logging.info("make_env online")
|
logging.info("make_env online")
|
||||||
online_env = make_env(cfg, n_envs=1)
|
# online_env = make_env(cfg, n_envs=1)
|
||||||
|
# TODO: Remove the import of maniskill and unifiy with make env
|
||||||
|
online_env = make_maniskill_env(cfg, n_envs=1)
|
||||||
if cfg.training.eval_freq > 0:
|
if cfg.training.eval_freq > 0:
|
||||||
logging.info("make_env eval")
|
logging.info("make_env eval")
|
||||||
eval_env = make_env(cfg, n_envs=1)
|
# eval_env = make_env(cfg, n_envs=1)
|
||||||
|
# TODO: Remove the import of maniskill and unifiy with make env
|
||||||
|
eval_env = make_maniskill_env(cfg, n_envs=1)
|
||||||
|
|
||||||
# TODO: Add a way to resume training
|
# TODO: Add a way to resume training
|
||||||
|
|
||||||
|
@ -348,6 +408,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
# Hack: But if we do online traning, we do not need dataset_stats
|
# Hack: But if we do online traning, we do not need dataset_stats
|
||||||
dataset_stats=None,
|
dataset_stats=None,
|
||||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
assert isinstance(policy, nn.Module)
|
assert isinstance(policy, nn.Module)
|
||||||
|
|
||||||
|
@ -360,17 +421,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
|
|
||||||
log_output_dir(out_dir)
|
log_output_dir(out_dir)
|
||||||
logging.info(f"{cfg.env.task=}")
|
logging.info(f"{cfg.env.task=}")
|
||||||
# TODO: Handle offline steps
|
|
||||||
# logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
|
|
||||||
logging.info(f"{cfg.training.online_steps=}")
|
logging.info(f"{cfg.training.online_steps=}")
|
||||||
# logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
|
|
||||||
# logging.info(f"{offline_dataset.num_episodes=}")
|
|
||||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||||
|
|
||||||
obs, info = online_env.reset()
|
obs, info = online_env.reset()
|
||||||
|
|
||||||
obs = preprocess_observation(obs)
|
# HACK for maniskill
|
||||||
|
# obs = preprocess_observation(obs)
|
||||||
|
obs = preprocess_maniskill_observation(obs)
|
||||||
obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
|
obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
|
||||||
|
|
||||||
replay_buffer = ReplayBuffer(
|
replay_buffer = ReplayBuffer(
|
||||||
|
@ -378,8 +437,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_size = cfg.training.batch_size
|
batch_size = cfg.training.batch_size
|
||||||
# if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
|
|
||||||
# raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
|
|
||||||
if cfg.dataset_repo_id is not None:
|
if cfg.dataset_repo_id is not None:
|
||||||
logging.info("make_dataset offline buffer")
|
logging.info("make_dataset offline buffer")
|
||||||
offline_dataset = make_dataset(cfg)
|
offline_dataset = make_dataset(cfg)
|
||||||
|
@ -404,7 +462,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
# HACK
|
# HACK
|
||||||
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
|
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
|
||||||
|
|
||||||
next_obs = preprocess_observation(next_obs)
|
# HACK: For maniskill
|
||||||
|
# next_obs = preprocess_observation(next_obs)
|
||||||
|
next_obs = preprocess_maniskill_observation(next_obs)
|
||||||
next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
|
next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
|
||||||
sum_reward_episode += float(reward[0])
|
sum_reward_episode += float(reward[0])
|
||||||
# Because we are using a single environment
|
# Because we are using a single environment
|
||||||
|
@ -413,16 +473,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||||
logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
|
logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
|
||||||
sum_reward_episode = 0
|
sum_reward_episode = 0
|
||||||
if "final_info" in info:
|
# HACK: This is for maniskill
|
||||||
if "is_success" in info["final_info"][0]:
|
logging.info(
|
||||||
logging.info(
|
f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
|
||||||
f"Global step {interaction_step}: Episode success: {info['final_info'][0]['is_success']}"
|
)
|
||||||
)
|
logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step)
|
||||||
if "coverage" in info["final_info"][0]:
|
|
||||||
logging.info(
|
|
||||||
f"Global step {interaction_step}: Episode final coverage: {info['final_info'][0]['coverage']} \n"
|
|
||||||
)
|
|
||||||
logger.log_dict({"Final coverage": info["final_info"][0]["coverage"]}, interaction_step)
|
|
||||||
|
|
||||||
replay_buffer.add(
|
replay_buffer.add(
|
||||||
state=obs,
|
state=obs,
|
||||||
|
@ -433,38 +488,13 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
)
|
)
|
||||||
obs = next_obs
|
obs = next_obs
|
||||||
|
|
||||||
if interaction_step >= cfg.training.online_step_before_learning:
|
if interaction_step < cfg.training.online_step_before_learning:
|
||||||
for _ in range(cfg.policy.utd_ratio - 1):
|
continue
|
||||||
batch = replay_buffer.sample(batch_size)
|
for _ in range(cfg.policy.utd_ratio - 1):
|
||||||
if cfg.dataset_repo_id is not None:
|
|
||||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
|
||||||
batch = concatenate_batch_transitions(batch, batch_offline)
|
|
||||||
|
|
||||||
actions = batch["action"]
|
|
||||||
rewards = batch["reward"]
|
|
||||||
observations = batch["state"]
|
|
||||||
next_observations = batch["next_state"]
|
|
||||||
done = batch["done"]
|
|
||||||
|
|
||||||
loss_critic = policy.compute_loss_critic(
|
|
||||||
observations=observations,
|
|
||||||
actions=actions,
|
|
||||||
rewards=rewards,
|
|
||||||
next_observations=next_observations,
|
|
||||||
done=done,
|
|
||||||
)
|
|
||||||
optimizers["critic"].zero_grad()
|
|
||||||
loss_critic.backward()
|
|
||||||
optimizers["critic"].step()
|
|
||||||
|
|
||||||
batch = replay_buffer.sample(batch_size)
|
batch = replay_buffer.sample(batch_size)
|
||||||
if cfg.dataset_repo_id is not None:
|
if cfg.dataset_repo_id is not None:
|
||||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||||
batch = concatenate_batch_transitions(
|
batch = concatenate_batch_transitions(batch, batch_offline)
|
||||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
|
||||||
)
|
|
||||||
# NOTE: We have to handle the normalization for the batch
|
|
||||||
# batch = policy.normalize_inputs(batch)
|
|
||||||
|
|
||||||
actions = batch["action"]
|
actions = batch["action"]
|
||||||
rewards = batch["reward"]
|
rewards = batch["reward"]
|
||||||
|
@ -483,31 +513,55 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
loss_critic.backward()
|
loss_critic.backward()
|
||||||
optimizers["critic"].step()
|
optimizers["critic"].step()
|
||||||
|
|
||||||
training_infos = {}
|
batch = replay_buffer.sample(batch_size)
|
||||||
training_infos["loss_critic"] = loss_critic.item()
|
if cfg.dataset_repo_id is not None:
|
||||||
|
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||||
|
batch = concatenate_batch_transitions(
|
||||||
|
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||||
|
)
|
||||||
|
|
||||||
if interaction_step % cfg.training.policy_update_freq == 0:
|
actions = batch["action"]
|
||||||
# TD3 Trick
|
rewards = batch["reward"]
|
||||||
for _ in range(cfg.training.policy_update_freq):
|
observations = batch["state"]
|
||||||
loss_actor = policy.compute_loss_actor(observations=observations)
|
next_observations = batch["next_state"]
|
||||||
|
done = batch["done"]
|
||||||
|
|
||||||
optimizers["actor"].zero_grad()
|
loss_critic = policy.compute_loss_critic(
|
||||||
loss_actor.backward()
|
observations=observations,
|
||||||
optimizers["actor"].step()
|
actions=actions,
|
||||||
|
rewards=rewards,
|
||||||
|
next_observations=next_observations,
|
||||||
|
done=done,
|
||||||
|
)
|
||||||
|
optimizers["critic"].zero_grad()
|
||||||
|
loss_critic.backward()
|
||||||
|
optimizers["critic"].step()
|
||||||
|
|
||||||
training_infos["loss_actor"] = loss_actor.item()
|
training_infos = {}
|
||||||
|
training_infos["loss_critic"] = loss_critic.item()
|
||||||
|
|
||||||
loss_temperature = policy.compute_loss_temperature(observations=observations)
|
if interaction_step % cfg.training.policy_update_freq == 0:
|
||||||
optimizers["temperature"].zero_grad()
|
# TD3 Trick
|
||||||
loss_temperature.backward()
|
for _ in range(cfg.training.policy_update_freq):
|
||||||
optimizers["temperature"].step()
|
loss_actor = policy.compute_loss_actor(observations=observations)
|
||||||
|
|
||||||
training_infos["loss_temperature"] = loss_temperature.item()
|
optimizers["actor"].zero_grad()
|
||||||
|
loss_actor.backward()
|
||||||
|
optimizers["actor"].step()
|
||||||
|
|
||||||
if interaction_step % cfg.training.log_freq == 0:
|
training_infos["loss_actor"] = loss_actor.item()
|
||||||
logger.log_dict(training_infos, interaction_step, mode="train")
|
|
||||||
|
|
||||||
policy.update_target_networks()
|
loss_temperature = policy.compute_loss_temperature(observations=observations)
|
||||||
|
optimizers["temperature"].zero_grad()
|
||||||
|
loss_temperature.backward()
|
||||||
|
optimizers["temperature"].step()
|
||||||
|
|
||||||
|
training_infos["loss_temperature"] = loss_temperature.item()
|
||||||
|
|
||||||
|
if interaction_step % cfg.training.log_freq == 0:
|
||||||
|
logger.log_dict(training_infos, interaction_step, mode="train")
|
||||||
|
|
||||||
|
policy.update_target_networks()
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||||
|
|
Loading…
Reference in New Issue