Add common, refactor eval with eval_policy
parent 1e52499490
commit 5a5b190f70
@@ -0,0 +1,42 @@
from torchrl.envs.transforms import StepCounter, TransformedEnv

from lerobot.common.envs.simxarm import SimxarmEnv


def make_env(cfg):
    assert cfg.env == "simxarm"
    env = SimxarmEnv(
        task=cfg.task,
        from_pixels=cfg.from_pixels,
        pixels_only=cfg.pixels_only,
        image_size=cfg.image_size,
    )

    # limit rollout to max_steps
    env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length))

    return env


# def make_env(env_name, frame_skip, device, is_test=False):
#     env = GymEnv(
#         env_name,
#         frame_skip=frame_skip,
#         from_pixels=True,
#         pixels_only=False,
#         device=device,
#     )
#     env = TransformedEnv(env)
#     env.append_transform(NoopResetEnv(noops=30, random=True))
#     if not is_test:
#         env.append_transform(EndOfLifeTransform())
#         env.append_transform(RewardClipping(-1, 1))
#     env.append_transform(ToTensorImage())
#     env.append_transform(GrayScale())
#     env.append_transform(Resize(84, 84))
#     env.append_transform(CatFrames(N=4, dim=-3))
#     env.append_transform(RewardSum())
#     env.append_transform(StepCounter(max_steps=4500))
#     env.append_transform(DoubleToFloat())
#     env.append_transform(VecNorm(in_keys=["pixels"]))
#     return env
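
A minimal usage sketch (illustration only, not part of the commit), assuming simxarm is installed; the config values below are hypothetical stand-ins for the fields make_env reads above:

from types import SimpleNamespace

cfg = SimpleNamespace(
    env="simxarm",
    task="lift",  # hypothetical; must be a key of simxarm.TASKS
    from_pixels=True,
    pixels_only=False,
    image_size=84,
    episode_length=25,
)
env = make_env(cfg)
rollout = env.rollout(max_steps=5)  # random actions drawn from the action spec
print(rollout["next", "step_count"])  # StepCounter caps episodes at episode_length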
@@ -0,0 +1,183 @@
import importlib
from typing import Optional

import numpy as np
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
    BoundedTensorSpec,
    CompositeSpec,
    DiscreteTensorSpec,
    UnboundedContinuousTensorSpec,
)
from torchrl.envs import EnvBase
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform

from lerobot.common.utils import set_seed

_has_gym = importlib.util.find_spec("gym") is not None
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym


class SimxarmEnv(EnvBase):

    def __init__(
        self,
        task,
        from_pixels: bool = False,
        pixels_only: bool = False,
        image_size=None,
        seed=1337,
        device="cpu",
    ):
        super().__init__(device=device, batch_size=[])
        self.task = task
        self.from_pixels = from_pixels
        self.pixels_only = pixels_only
        self.image_size = image_size

        if pixels_only:
            assert from_pixels
        if from_pixels:
            assert image_size

        if not _has_simxarm:
            raise ImportError("Cannot import simxarm.")
        if not _has_gym:
            raise ImportError("Cannot import gym.")

        import gym
        from gym.wrappers import TimeLimit
        from simxarm import TASKS

        if self.task not in TASKS:
            raise ValueError(
                f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}"
            )

        self._env = TASKS[self.task]["env"]()
        self._env = TimeLimit(self._env, TASKS[self.task]["episode_length"])

        MAX_NUM_ACTIONS = 4
        num_actions = len(TASKS[self.task]["action_space"])
        self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
        self._action_padding = np.zeros(
            (MAX_NUM_ACTIONS - num_actions), dtype=np.float32
        )
        if "w" not in TASKS[self.task]["action_space"]:
            self._action_padding[-1] = 1.0

        self._make_spec()
        self.set_seed(seed)

    def render(self, mode="rgb_array", width=384, height=384):
        return self._env.render(mode, width=width, height=height)

    def _format_raw_obs(self, raw_obs):
        if self.from_pixels:
            camera = self.render(
                mode="rgb_array", width=self.image_size, height=self.image_size
            )
            camera = camera.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
            camera = torch.tensor(camera.copy(), dtype=torch.uint8)

            obs = {"camera": camera}

            if not self.pixels_only:
                obs["robot_state"] = torch.tensor(
                    self._env.robot_state, dtype=torch.float32
                )
        else:
            obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}

        obs = TensorDict(obs, batch_size=[])
        return obs

    def _reset(self, tensordict: Optional[TensorDict] = None):
        td = tensordict
        if td is None or td.is_empty():
            raw_obs = self._env.reset()

            td = TensorDict(
                {
                    "observation": self._format_raw_obs(raw_obs),
                    "done": torch.tensor([False], dtype=torch.bool),
                },
                batch_size=[],
            )
        else:
            raise NotImplementedError()
        return td

    def _step(self, tensordict: TensorDict):
        td = tensordict
        action = td["action"].numpy()
        # step expects shape=(4,) so we pad if necessary
        action = np.concatenate([action, self._action_padding])
        # TODO(rcadene): add info["is_success"] and info["success"] ?
        raw_obs, reward, done, info = self._env.step(action)

        td = TensorDict(
            {
                "observation": self._format_raw_obs(raw_obs),
                "reward": torch.tensor([reward], dtype=torch.float32),
                "done": torch.tensor([done], dtype=torch.bool),
                "success": torch.tensor([info["success"]], dtype=torch.bool),
            },
            batch_size=[],
        )
        return td

    def _make_spec(self):
        obs = {}
        if self.from_pixels:
            obs["camera"] = BoundedTensorSpec(
                low=0,
                high=255,
                shape=(3, self.image_size, self.image_size),
                dtype=torch.uint8,
                device=self.device,
            )
            if not self.pixels_only:
                obs["robot_state"] = UnboundedContinuousTensorSpec(
                    shape=(len(self._env.robot_state),),
                    dtype=torch.float32,
                    device=self.device,
                )
        else:
            # TODO(rcadene): add observation_space achieved_goal and desired_goal?
            obs["state"] = UnboundedContinuousTensorSpec(
                shape=self._env.observation_space["observation"].shape,
                dtype=torch.float32,
                device=self.device,
            )
        self.observation_spec = CompositeSpec({"observation": obs})

        self.action_spec = _gym_to_torchrl_spec_transform(
            self._action_space,
            device=self.device,
        )

        self.reward_spec = UnboundedContinuousTensorSpec(
            shape=(1,),
            dtype=torch.float32,
            device=self.device,
        )

        self.done_spec = DiscreteTensorSpec(
            2,
            shape=(1,),
            dtype=torch.bool,
            device=self.device,
        )

        self.success_spec = DiscreteTensorSpec(
            2,
            shape=(1,),
            dtype=torch.bool,
            device=self.device,
        )

    def _set_seed(self, seed: Optional[int]):
        set_seed(seed)
        self._env.seed(seed)
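
As a rough illustration (assumed, not from the commit), the environment can also be driven directly; with from_pixels=True the reset TensorDict nests the camera image and the proprioceptive state under "observation". The task name is hypothetical and must exist in simxarm.TASKS:

env = SimxarmEnv(task="lift", from_pixels=True, pixels_only=False, image_size=84)
td = env.reset()
print(td["observation", "camera"].shape)       # torch.Size([3, 84, 84]), dtype uint8
print(td["observation", "robot_state"].shape)  # robot proprioceptive state vector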
@@ -0,0 +1,450 @@
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn

import lerobot.common.tdmpc_helper as h


class TOLD(nn.Module):
    """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""

    def __init__(self, cfg):
        super().__init__()
        action_dim = 4

        self.cfg = cfg
        self._encoder = h.enc(cfg)
        self._dynamics = h.dynamics(
            cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim
        )
        self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1)
        self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim)
        self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)])
        self._V = h.v(cfg)
        self.apply(h.orthogonal_init)
        for m in [self._reward, *self._Qs]:
            m[-1].weight.data.fill_(0)
            m[-1].bias.data.fill_(0)

    def track_q_grad(self, enable=True):
        """Utility function. Enables/disables gradient tracking of Q-networks."""
        for m in self._Qs:
            h.set_requires_grad(m, enable)

    def track_v_grad(self, enable=True):
        """Utility function. Enables/disables gradient tracking of Q-networks."""
        if hasattr(self, "_V"):
            h.set_requires_grad(self._V, enable)

    def encode(self, obs):
        """Encodes an observation into its latent representation."""
        out = self._encoder(obs)
        if isinstance(obs, dict):
            # fusion
            out = torch.stack([v for k, v in out.items()]).mean(dim=0)
        return out

    def next(self, z, a):
        """Predicts next latent state (d) and single-step reward (R)."""
        x = torch.cat([z, a], dim=-1)
        return self._dynamics(x), self._reward(x)

    def pi(self, z, std=0):
        """Samples an action from the learned policy (pi)."""
        mu = torch.tanh(self._pi(z))
        if std > 0:
            std = torch.ones_like(mu) * std
            return h.TruncatedNormal(mu, std).sample(clip=0.3)
        return mu

    def V(self, z):
        """Predict state value (V)."""
        return self._V(z)

    def Q(self, z, a, return_type):
        """Predict state-action value (Q)."""
        assert return_type in {"min", "avg", "all"}
        x = torch.cat([z, a], dim=-1)

        if return_type == "all":
            return torch.stack(list(q(x) for q in self._Qs), dim=0)

        idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
        Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
        return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2


class TDMPC(nn.Module):
    """Implementation of TD-MPC learning + inference."""

    def __init__(self, cfg):
        super().__init__()
        self.action_dim = 4

        self.cfg = cfg
        self.device = torch.device("cuda")
        self.std = h.linear_schedule(cfg.std_schedule, 0)
        self.model = TOLD(cfg).cuda()
        self.model_target = deepcopy(self.model)
        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
        self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
        self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
        self.model.eval()
        self.model_target.eval()
        self.batch_size = cfg.batch_size

        # TODO(rcadene): clean
        self.step = 100000

    def state_dict(self):
        """Retrieve state dict of TOLD model, including slow-moving target network."""
        return {
            "model": self.model.state_dict(),
            "model_target": self.model_target.state_dict(),
        }

    def save(self, fp):
        """Save state dict of TOLD model to filepath."""
        torch.save(self.state_dict(), fp)

    def load(self, fp):
        """Load a saved state dict from filepath into current agent."""
        d = torch.load(fp)
        self.model.load_state_dict(d["model"])
        self.model_target.load_state_dict(d["model_target"])

    @torch.no_grad()
    def forward(self, observation, step_count):
        t0 = step_count.item() == 0
        obs = {
            "rgb": observation["camera"],
            "state": observation["robot_state"],
        }
        return self.act(obs, t0=t0, step=self.step)

    @torch.no_grad()
    def act(self, obs, t0=False, step=None):
        """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
        if isinstance(obs, dict):
            obs = {
                k: torch.tensor(o, dtype=torch.float32, device=self.device).unsqueeze(0)
                for k, o in obs.items()
            }
        else:
            obs = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(
                0
            )
        z = self.model.encode(obs)
        if self.cfg.mpc:
            a = self.plan(z, t0=t0, step=step)
        else:
            a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
        return a.cpu()

    @torch.no_grad()
    def estimate_value(self, z, actions, horizon):
        """Estimate value of a trajectory starting at latent state z and executing given actions."""
        G, discount = 0, 1
        for t in range(horizon):
            if self.cfg.uncertainty_cost > 0:
                G -= (
                    discount
                    * self.cfg.uncertainty_cost
                    * self.model.Q(z, actions[t], return_type="all").std(dim=0)
                )
            z, reward = self.model.next(z, actions[t])
            G += discount * reward
            discount *= self.cfg.discount
        pi = self.model.pi(z, self.cfg.min_std)
        G += discount * self.model.Q(z, pi, return_type="min")
        if self.cfg.uncertainty_cost > 0:
            G -= (
                discount
                * self.cfg.uncertainty_cost
                * self.model.Q(z, pi, return_type="all").std(dim=0)
            )
        return G

    @torch.no_grad()
    def plan(self, z, step=None, t0=True):
        """
        Plan next action using TD-MPC inference.
        z: latent state.
        step: current time step. determines e.g. planning horizon.
        t0: whether current step is the first step of an episode.
        """
        # during eval: eval_mode: uniform sampling and action noise is disabled during evaluation.

        assert step is not None
        # Seed steps
        if step < self.cfg.seed_steps and self.model.training:
            return torch.empty(
                self.action_dim, dtype=torch.float32, device=self.device
            ).uniform_(-1, 1)

        # Sample policy trajectories
        horizon = int(
            min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))
        )
        num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
        if num_pi_trajs > 0:
            pi_actions = torch.empty(
                horizon, num_pi_trajs, self.action_dim, device=self.device
            )
            _z = z.repeat(num_pi_trajs, 1)
            for t in range(horizon):
                pi_actions[t] = self.model.pi(_z, self.cfg.min_std)
                _z, _ = self.model.next(_z, pi_actions[t])

        # Initialize state and parameters
        z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1)
        mean = torch.zeros(horizon, self.action_dim, device=self.device)
        std = self.cfg.max_std * torch.ones(
            horizon, self.action_dim, device=self.device
        )
        if not t0 and hasattr(self, "_prev_mean"):
            mean[:-1] = self._prev_mean[1:]

        # Iterate CEM
        for i in range(self.cfg.iterations):
            actions = torch.clamp(
                mean.unsqueeze(1)
                + std.unsqueeze(1)
                * torch.randn(
                    horizon, self.cfg.num_samples, self.action_dim, device=std.device
                ),
                -1,
                1,
            )
            if num_pi_trajs > 0:
                actions = torch.cat([actions, pi_actions], dim=1)

            # Compute elite actions
            value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
            elite_idxs = torch.topk(
                value.squeeze(1), self.cfg.num_elites, dim=0
            ).indices
            elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]

            # Update parameters
            max_value = elite_value.max(0)[0]
            score = torch.exp(self.cfg.temperature * (elite_value - max_value))
            score /= score.sum(0)
            _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (
                score.sum(0) + 1e-9
            )
            _std = torch.sqrt(
                torch.sum(
                    score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2,
                    dim=1,
                )
                / (score.sum(0) + 1e-9)
            )
            _std = _std.clamp_(self.std, self.cfg.max_std)
            mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std

        # Outputs
        score = score.squeeze(1).cpu().numpy()
        actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
        self._prev_mean = mean
        mean, std = actions[0], _std[0]
        a = mean
        if self.model.training:
            a += std * torch.randn(self.action_dim, device=std.device)
        return torch.clamp(a, -1, 1)

    def update_pi(self, zs, acts=None):
        """Update policy using a sequence of latent states."""
        self.pi_optim.zero_grad(set_to_none=True)
        self.model.track_q_grad(False)
        self.model.track_v_grad(False)

        info = {}
        # Advantage Weighted Regression
        assert acts is not None
        vs = self.model.V(zs)
        qs = self.model_target.Q(zs, acts, return_type="min")
        adv = qs - vs
        exp_a = torch.exp(adv * self.cfg.A_scaling)
        exp_a = torch.clamp(exp_a, max=100.0)
        log_probs = h.gaussian_logprob(self.model.pi(zs) - acts, 0)
        rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
        pi_loss = -((exp_a * log_probs).mean(dim=(1, 2)) * rho).mean()
        info["adv"] = adv[0]

        pi_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            self.model._pi.parameters(),
            self.cfg.grad_clip_norm,
            error_if_nonfinite=False,
        )
        self.pi_optim.step()
        self.model.track_q_grad(True)
        self.model.track_v_grad(True)

        info["pi_loss"] = pi_loss.item()
        return pi_loss.item(), info

    @torch.no_grad()
    def _td_target(self, next_z, reward, mask):
        """Compute the TD-target from a reward and the observation at the following time step."""
        next_v = self.model.V(next_z)
        td_target = reward + self.cfg.discount * mask * next_v
        return td_target

    def update(self, replay_buffer, step, demo_buffer=None):
        """Main update function. Corresponds to one iteration of the model learning."""

        if demo_buffer is not None:
            # Update oversampling ratio
            self.demo_batch_size = int(
                h.linear_schedule(self.cfg.demo_schedule, step) * self.batch_size
            )
            replay_buffer.cfg.batch_size = self.batch_size - self.demo_batch_size
            demo_buffer.cfg.batch_size = self.demo_batch_size
        else:
            self.demo_batch_size = 0

        # Sample from interaction dataset
        obs, next_obses, action, reward, mask, done, idxs, weights = (
            replay_buffer.sample()
        )

        # Sample from demonstration dataset
        if self.demo_batch_size > 0:
            (
                demo_obs,
                demo_next_obses,
                demo_action,
                demo_reward,
                demo_mask,
                demo_done,
                demo_idxs,
                demo_weights,
            ) = demo_buffer.sample()

            if isinstance(obs, dict):
                obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
                next_obses = {
                    k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1)
                    for k in next_obses
                }
            else:
                obs = torch.cat([obs, demo_obs])
                next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
            action = torch.cat([action, demo_action], dim=1)
            reward = torch.cat([reward, demo_reward], dim=1)
            mask = torch.cat([mask, demo_mask], dim=1)
            done = torch.cat([done, demo_done], dim=1)
            idxs = torch.cat([idxs, demo_idxs])
            weights = torch.cat([weights, demo_weights])

        horizon = self.cfg.horizon
        loss_mask = torch.ones_like(mask, device=self.device)
        for t in range(1, horizon):
            loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])

        self.optim.zero_grad(set_to_none=True)
        self.std = h.linear_schedule(self.cfg.std_schedule, step)
        self.model.train()

        # Compute targets
        with torch.no_grad():
            next_z = self.model.encode(next_obses)
            z_targets = self.model_target.encode(next_obses)
            td_targets = self._td_target(next_z, reward, mask)

        # Latent rollout
        zs = torch.empty(
            horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device
        )
        reward_preds = torch.empty_like(reward, device=self.device)
        assert reward.shape[0] == horizon
        z = self.model.encode(obs)
        zs[0] = z
        value_info = {"Q": 0.0, "V": 0.0}
        for t in range(horizon):
            z, reward_pred = self.model.next(z, action[t])
            zs[t + 1] = z
            reward_preds[t] = reward_pred

        with torch.no_grad():
            v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")

        # Predictions
        qs = self.model.Q(zs[:-1], action, return_type="all")
        value_info["Q"] = qs.mean().item()
        v = self.model.V(zs[:-1])
        value_info["V"] = v.mean().item()

        # Losses
        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(
            -1, 1, 1
        )
        consistency_loss = (
            rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask
        ).sum(dim=0)
        reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
        q_value_loss, priority_loss = 0, 0
        for q in range(self.cfg.num_q):
            q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0)
            priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)

        self.expectile = h.linear_schedule(self.cfg.expectile, step)
        v_value_loss = (
            rho * h.l2_expectile(v_target - v, expectile=self.expectile) * loss_mask
        ).sum(dim=0)

        total_loss = (
            self.cfg.consistency_coef * consistency_loss
            + self.cfg.reward_coef * reward_loss
            + self.cfg.value_coef * q_value_loss
            + self.cfg.value_coef * v_value_loss
        )

        weighted_loss = (total_loss.squeeze(1) * weights).mean()
        weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
        weighted_loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
        )
        self.optim.step()

        if self.cfg.per:
            # Update priorities
            priorities = priority_loss.clamp(max=1e4).detach()
            replay_buffer.update_priorities(
                idxs[: replay_buffer.cfg.batch_size],
                priorities[: replay_buffer.cfg.batch_size],
            )
            if self.demo_batch_size > 0:
                demo_buffer.update_priorities(
                    demo_idxs, priorities[replay_buffer.cfg.batch_size :]
                )

        # Update policy + target network
        _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)

        if step % self.cfg.update_freq == 0:
            h.ema(self.model._encoder, self.model_target._encoder, self.cfg.tau)
            h.ema(self.model._Qs, self.model_target._Qs, self.cfg.tau)

        self.model.eval()
        metrics = {
            "consistency_loss": float(consistency_loss.mean().item()),
            "reward_loss": float(reward_loss.mean().item()),
            "Q_value_loss": float(q_value_loss.mean().item()),
            "V_value_loss": float(v_value_loss.mean().item()),
            "total_loss": float(total_loss.mean().item()),
            "weighted_loss": float(weighted_loss.mean().item()),
            "grad_norm": float(grad_norm),
        }
        for key in ["demo_batch_size", "expectile"]:
            if hasattr(self, key):
                metrics[key] = getattr(self, key)
        metrics.update(value_info)
        metrics.update(pi_update_info)

        return metrics
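
A CPU-only sketch of the TOLD component in isolation (illustration, not part of the commit); the config values are hypothetical, but every field name is one the code above actually reads:

from types import SimpleNamespace

import torch

cfg = SimpleNamespace(
    modality="all", img_size=84, frame_stack=1, num_channels=32,
    enc_dim=256, latent_dim=50, mlp_dim=512, num_q=5,
)
told = TOLD(cfg)
obs = {"rgb": torch.zeros(2, 3, 84, 84), "state": torch.zeros(2, 4)}
z = told.encode(obs)                     # fused latent, shape (2, 50)
a = told.pi(z, std=0.1)                  # clipped truncated-normal sample, shape (2, 4)
z_next, reward = told.next(z, a)         # latent dynamics and reward heads
q_min = told.Q(z, a, return_type="min")  # min over two randomly chosen Q heads, (2, 1)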
@ -0,0 +1,829 @@
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import re
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import distributions as pyd
|
||||||
|
from torch.distributions.utils import _standard_normal
|
||||||
|
|
||||||
|
__REDUCE__ = lambda b: "mean" if b else "none"
|
||||||
|
|
||||||
|
|
||||||
|
def l1(pred, target, reduce=False):
|
||||||
|
"""Computes the L1-loss between predictions and targets."""
|
||||||
|
return F.l1_loss(pred, target, reduction=__REDUCE__(reduce))
|
||||||
|
|
||||||
|
|
||||||
|
def mse(pred, target, reduce=False):
|
||||||
|
"""Computes the MSE loss between predictions and targets."""
|
||||||
|
return F.mse_loss(pred, target, reduction=__REDUCE__(reduce))
|
||||||
|
|
||||||
|
|
||||||
|
def l2_expectile(diff, expectile=0.7, reduce=False):
|
||||||
|
weight = torch.where(diff > 0, expectile, (1 - expectile))
|
||||||
|
loss = weight * (diff**2)
|
||||||
|
reduction = __REDUCE__(reduce)
|
||||||
|
if reduction == "mean":
|
||||||
|
return torch.mean(loss)
|
||||||
|
elif reduction == "sum":
|
||||||
|
return torch.sum(loss)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def _get_out_shape(in_shape, layers):
|
||||||
|
"""Utility function. Returns the output shape of a network for a given input shape."""
|
||||||
|
x = torch.randn(*in_shape).unsqueeze(0)
|
||||||
|
return (
|
||||||
|
(nn.Sequential(*layers) if isinstance(layers, list) else layers)(x)
|
||||||
|
.squeeze(0)
|
||||||
|
.shape
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def gaussian_logprob(eps, log_std):
|
||||||
|
"""Compute Gaussian log probability."""
|
||||||
|
residual = (-0.5 * eps.pow(2) - log_std).sum(-1, keepdim=True)
|
||||||
|
return residual - 0.5 * np.log(2 * np.pi) * eps.size(-1)
|
||||||
|
|
||||||
|
|
||||||
|
def squash(mu, pi, log_pi):
|
||||||
|
"""Apply squashing function."""
|
||||||
|
mu = torch.tanh(mu)
|
||||||
|
pi = torch.tanh(pi)
|
||||||
|
log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
|
||||||
|
return mu, pi, log_pi
|
||||||
|
|
||||||
|
|
||||||
|
def orthogonal_init(m):
|
||||||
|
"""Orthogonal layer initialization."""
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
nn.init.orthogonal_(m.weight.data)
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
elif isinstance(m, nn.Conv2d):
|
||||||
|
gain = nn.init.calculate_gain("relu")
|
||||||
|
nn.init.orthogonal_(m.weight.data, gain)
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
|
||||||
|
|
||||||
|
def ema(m, m_target, tau):
|
||||||
|
"""Update slow-moving average of online network (target network) at rate tau."""
|
||||||
|
with torch.no_grad():
|
||||||
|
for p, p_target in zip(m.parameters(), m_target.parameters()):
|
||||||
|
p_target.data.lerp_(p.data, tau)
|
||||||
|
|
||||||
|
|
||||||
|
def set_requires_grad(net, value):
|
||||||
|
"""Enable/disable gradients for a given (sub)network."""
|
||||||
|
for param in net.parameters():
|
||||||
|
param.requires_grad_(value)
|
||||||
|
|
||||||
|
|
||||||
|
class TruncatedNormal(pyd.Normal):
|
||||||
|
"""Utility class implementing the truncated normal distribution."""
|
||||||
|
|
||||||
|
def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
|
||||||
|
super().__init__(loc, scale, validate_args=False)
|
||||||
|
self.low = low
|
||||||
|
self.high = high
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def _clamp(self, x):
|
||||||
|
clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
|
||||||
|
x = x - x.detach() + clamped_x.detach()
|
||||||
|
return x
|
||||||
|
|
||||||
|
def sample(self, clip=None, sample_shape=torch.Size()):
|
||||||
|
shape = self._extended_shape(sample_shape)
|
||||||
|
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
|
||||||
|
eps *= self.scale
|
||||||
|
if clip is not None:
|
||||||
|
eps = torch.clamp(eps, -clip, clip)
|
||||||
|
x = self.loc + eps
|
||||||
|
return self._clamp(x)
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizeImg(nn.Module):
|
||||||
|
"""Normalizes pixel observations to [0,1) range."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x.div(255.0)
|
||||||
|
|
||||||
|
|
||||||
|
class Flatten(nn.Module):
|
||||||
|
"""Flattens its input to a (batched) vector."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x.view(x.size(0), -1)
|
||||||
|
|
||||||
|
|
||||||
|
def enc(cfg):
|
||||||
|
obs_shape = {
|
||||||
|
"rgb": (3, cfg.img_size, cfg.img_size),
|
||||||
|
"state": (4,),
|
||||||
|
}
|
||||||
|
|
||||||
|
"""Returns a TOLD encoder."""
|
||||||
|
pixels_enc_layers, state_enc_layers = None, None
|
||||||
|
if cfg.modality in {"pixels", "all"}:
|
||||||
|
C = int(3 * cfg.frame_stack)
|
||||||
|
pixels_enc_layers = [
|
||||||
|
NormalizeImg(),
|
||||||
|
nn.Conv2d(C, cfg.num_channels, 7, stride=2),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
|
||||||
|
nn.ReLU(),
|
||||||
|
]
|
||||||
|
out_shape = _get_out_shape((C, cfg.img_size, cfg.img_size), pixels_enc_layers)
|
||||||
|
pixels_enc_layers.extend(
|
||||||
|
[
|
||||||
|
Flatten(),
|
||||||
|
nn.Linear(np.prod(out_shape), cfg.latent_dim),
|
||||||
|
nn.LayerNorm(cfg.latent_dim),
|
||||||
|
nn.Sigmoid(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if cfg.modality == "pixels":
|
||||||
|
return ConvExt(nn.Sequential(*pixels_enc_layers))
|
||||||
|
if cfg.modality in {"state", "all"}:
|
||||||
|
state_dim = obs_shape[0] if cfg.modality == "state" else obs_shape["state"][0]
|
||||||
|
state_enc_layers = [
|
||||||
|
nn.Linear(state_dim, cfg.enc_dim),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.Linear(cfg.enc_dim, cfg.latent_dim),
|
||||||
|
nn.LayerNorm(cfg.latent_dim),
|
||||||
|
nn.Sigmoid(),
|
||||||
|
]
|
||||||
|
if cfg.modality == "state":
|
||||||
|
return nn.Sequential(*state_enc_layers)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
encoders = {}
|
||||||
|
for k in obs_shape:
|
||||||
|
if k == "state":
|
||||||
|
encoders[k] = nn.Sequential(*state_enc_layers)
|
||||||
|
elif k.endswith("rgb"):
|
||||||
|
encoders[k] = ConvExt(nn.Sequential(*pixels_enc_layers))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return Multiplexer(nn.ModuleDict(encoders))
|
||||||
|
|
||||||
|
|
||||||
|
def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
|
||||||
|
"""Returns an MLP."""
|
||||||
|
if isinstance(mlp_dim, int):
|
||||||
|
mlp_dim = [mlp_dim, mlp_dim]
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.Linear(in_dim, mlp_dim[0]),
|
||||||
|
nn.LayerNorm(mlp_dim[0]),
|
||||||
|
act_fn,
|
||||||
|
nn.Linear(mlp_dim[0], mlp_dim[1]),
|
||||||
|
nn.LayerNorm(mlp_dim[1]),
|
||||||
|
act_fn,
|
||||||
|
nn.Linear(mlp_dim[1], out_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
|
||||||
|
"""Returns a dynamics network."""
|
||||||
|
return nn.Sequential(
|
||||||
|
mlp(in_dim, mlp_dim, out_dim, act_fn),
|
||||||
|
nn.LayerNorm(out_dim),
|
||||||
|
nn.Sigmoid(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def q(cfg):
|
||||||
|
action_dim = 4
|
||||||
|
"""Returns a Q-function that uses Layer Normalization."""
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim),
|
||||||
|
nn.LayerNorm(cfg.mlp_dim),
|
||||||
|
nn.Tanh(),
|
||||||
|
nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.Linear(cfg.mlp_dim, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def v(cfg):
|
||||||
|
"""Returns a state value function that uses Layer Normalization."""
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.Linear(cfg.latent_dim, cfg.mlp_dim),
|
||||||
|
nn.LayerNorm(cfg.mlp_dim),
|
||||||
|
nn.Tanh(),
|
||||||
|
nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.Linear(cfg.mlp_dim, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def aug(cfg):
|
||||||
|
obs_shape = {
|
||||||
|
"rgb": (3, cfg.img_size, cfg.img_size),
|
||||||
|
"state": (4,),
|
||||||
|
}
|
||||||
|
|
||||||
|
"""Multiplex augmentation"""
|
||||||
|
if cfg.modality == "state":
|
||||||
|
return nn.Identity()
|
||||||
|
elif cfg.modality == "pixels":
|
||||||
|
return RandomShiftsAug(cfg)
|
||||||
|
else:
|
||||||
|
augs = {}
|
||||||
|
for k in obs_shape:
|
||||||
|
if k == "state":
|
||||||
|
augs[k] = nn.Identity()
|
||||||
|
elif k.endswith("rgb"):
|
||||||
|
augs[k] = RandomShiftsAug(cfg)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return Multiplexer(nn.ModuleDict(augs))
|
||||||
|
|
||||||
|
|
||||||
|
class ConvExt(nn.Module):
|
||||||
|
"""Auxiliary conv net accommodating high-dim input"""
|
||||||
|
|
||||||
|
def __init__(self, conv):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = conv
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if x.ndim > 4:
|
||||||
|
batch_shape = x.shape[:-3]
|
||||||
|
out = self.conv(x.view(-1, *x.shape[-3:]))
|
||||||
|
out = out.view(*batch_shape, *out.shape[1:])
|
||||||
|
else:
|
||||||
|
out = self.conv(x)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Multiplexer(nn.Module):
|
||||||
|
"""Model multiplexer"""
|
||||||
|
|
||||||
|
def __init__(self, choices):
|
||||||
|
super().__init__()
|
||||||
|
self.choices = choices
|
||||||
|
|
||||||
|
def forward(self, x, key=None):
|
||||||
|
if isinstance(x, dict):
|
||||||
|
if key is not None:
|
||||||
|
return self.choices[key](x)
|
||||||
|
return {k: self.choices[k](_x) for k, _x in x.items()}
|
||||||
|
return self.choices(x)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomShiftsAug(nn.Module):
|
||||||
|
"""
|
||||||
|
Random shift image augmentation.
|
||||||
|
Adapted from https://github.com/facebookresearch/drqv2
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__()
|
||||||
|
assert cfg.modality in {"pixels", "all"}
|
||||||
|
self.pad = int(cfg.img_size / 21)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
n, c, h, w = x.size()
|
||||||
|
assert h == w
|
||||||
|
padding = tuple([self.pad] * 4)
|
||||||
|
x = F.pad(x, padding, "replicate")
|
||||||
|
eps = 1.0 / (h + 2 * self.pad)
|
||||||
|
arange = torch.linspace(
|
||||||
|
-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype
|
||||||
|
)[:h]
|
||||||
|
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
|
||||||
|
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
|
||||||
|
base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
|
||||||
|
shift = torch.randint(
|
||||||
|
0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype
|
||||||
|
)
|
||||||
|
shift *= 2.0 / (h + 2 * self.pad)
|
||||||
|
grid = base_grid + shift
|
||||||
|
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Episode(object):
|
||||||
|
"""Storage object for a single episode."""
|
||||||
|
|
||||||
|
def __init__(self, cfg, init_obs):
|
||||||
|
action_dim = 4
|
||||||
|
|
||||||
|
self.cfg = cfg
|
||||||
|
self.device = torch.device(cfg.buffer_device)
|
||||||
|
if cfg.modality in {"pixels", "state"}:
|
||||||
|
dtype = torch.float32 if cfg.modality == "state" else torch.uint8
|
||||||
|
self.obses = torch.empty(
|
||||||
|
(cfg.episode_length + 1, *init_obs.shape),
|
||||||
|
dtype=dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
|
||||||
|
elif cfg.modality == "all":
|
||||||
|
self.obses = {}
|
||||||
|
for k, v in init_obs.items():
|
||||||
|
assert k in {"rgb", "state"}
|
||||||
|
dtype = torch.float32 if k == "state" else torch.uint8
|
||||||
|
self.obses[k] = torch.empty(
|
||||||
|
(cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
|
||||||
|
)
|
||||||
|
self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
self.actions = torch.empty(
|
||||||
|
(cfg.episode_length, action_dim), dtype=torch.float32, device=self.device
|
||||||
|
)
|
||||||
|
self.rewards = torch.empty(
|
||||||
|
(cfg.episode_length,), dtype=torch.float32, device=self.device
|
||||||
|
)
|
||||||
|
self.dones = torch.empty(
|
||||||
|
(cfg.episode_length,), dtype=torch.bool, device=self.device
|
||||||
|
)
|
||||||
|
self.masks = torch.empty(
|
||||||
|
(cfg.episode_length,), dtype=torch.float32, device=self.device
|
||||||
|
)
|
||||||
|
self.cumulative_reward = 0
|
||||||
|
self.done = False
|
||||||
|
self.success = False
|
||||||
|
self._idx = 0
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._idx
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
|
||||||
|
"""Constructs an episode from a trajectory."""
|
||||||
|
|
||||||
|
if cfg.modality in {"pixels", "state"}:
|
||||||
|
episode = cls(cfg, obses[0])
|
||||||
|
episode.obses[1:] = torch.tensor(
|
||||||
|
obses[1:], dtype=episode.obses.dtype, device=episode.device
|
||||||
|
)
|
||||||
|
elif cfg.modality == "all":
|
||||||
|
episode = cls(cfg, {k: v[0] for k, v in obses.items()})
|
||||||
|
for k, v in obses.items():
|
||||||
|
episode.obses[k][1:] = torch.tensor(
|
||||||
|
obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
episode.actions = torch.tensor(
|
||||||
|
actions, dtype=episode.actions.dtype, device=episode.device
|
||||||
|
)
|
||||||
|
episode.rewards = torch.tensor(
|
||||||
|
rewards, dtype=episode.rewards.dtype, device=episode.device
|
||||||
|
)
|
||||||
|
episode.dones = (
|
||||||
|
torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
|
||||||
|
if dones is not None
|
||||||
|
else torch.zeros_like(episode.dones)
|
||||||
|
)
|
||||||
|
episode.masks = (
|
||||||
|
torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
|
||||||
|
if masks is not None
|
||||||
|
else torch.ones_like(episode.masks)
|
||||||
|
)
|
||||||
|
episode.cumulative_reward = torch.sum(episode.rewards)
|
||||||
|
episode.done = True
|
||||||
|
episode._idx = cfg.episode_length
|
||||||
|
return episode
|
||||||
|
|
||||||
|
@property
|
||||||
|
def first(self):
|
||||||
|
return len(self) == 0
|
||||||
|
|
||||||
|
def __add__(self, transition):
|
||||||
|
self.add(*transition)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def add(self, obs, action, reward, done, mask=1.0, success=False):
|
||||||
|
"""Add a transition into the episode."""
|
||||||
|
if isinstance(obs, dict):
|
||||||
|
for k, v in obs.items():
|
||||||
|
self.obses[k][self._idx + 1] = torch.tensor(
|
||||||
|
v, dtype=self.obses[k].dtype, device=self.obses[k].device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.obses[self._idx + 1] = torch.tensor(
|
||||||
|
obs, dtype=self.obses.dtype, device=self.obses.device
|
||||||
|
)
|
||||||
|
self.actions[self._idx] = action
|
||||||
|
self.rewards[self._idx] = reward
|
||||||
|
self.dones[self._idx] = done
|
||||||
|
self.masks[self._idx] = mask
|
||||||
|
self.cumulative_reward += reward
|
||||||
|
self.done = done
|
||||||
|
self.success = self.success or success
|
||||||
|
self._idx += 1
|
||||||
|
|
||||||
|
|
||||||
|
class ReplayBuffer:
|
||||||
|
"""
|
||||||
|
Storage and sampling functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg, dataset=None):
|
||||||
|
action_dim = 4
|
||||||
|
obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (4,)}
|
||||||
|
|
||||||
|
self.cfg = cfg
|
||||||
|
self.device = torch.device(cfg.buffer_device)
|
||||||
|
print("Replay buffer device: ", self.device)
|
||||||
|
|
||||||
|
if dataset is not None:
|
||||||
|
self.capacity = max(dataset["rewards"].shape[0], cfg.max_buffer_size)
|
||||||
|
else:
|
||||||
|
self.capacity = min(cfg.train_steps, cfg.max_buffer_size)
|
||||||
|
|
||||||
|
if cfg.modality in {"pixels", "state"}:
|
||||||
|
dtype = torch.float32 if cfg.modality == "state" else torch.uint8
|
||||||
|
# Note self.obs_shape always has single frame, which is different from cfg.obs_shape
|
||||||
|
self.obs_shape = (
|
||||||
|
obs_shape if cfg.modality == "state" else (3, *obs_shape[-2:])
|
||||||
|
)
|
||||||
|
self._obs = torch.zeros(
|
||||||
|
(self.capacity + cfg.horizon - 1, *self.obs_shape),
|
||||||
|
dtype=dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
self._next_obs = torch.zeros(
|
||||||
|
(self.capacity + cfg.horizon - 1, *self.obs_shape),
|
||||||
|
dtype=dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
elif cfg.modality == "all":
|
||||||
|
self.obs_shape = {}
|
||||||
|
self._obs, self._next_obs = {}, {}
|
||||||
|
for k, v in obs_shape.items():
|
||||||
|
assert k in {"rgb", "state"}
|
||||||
|
dtype = torch.float32 if k == "state" else torch.uint8
|
||||||
|
self.obs_shape[k] = v if k == "state" else (3, *v[-2:])
|
||||||
|
self._obs[k] = torch.zeros(
|
||||||
|
(self.capacity + cfg.horizon - 1, *self.obs_shape[k]),
|
||||||
|
dtype=dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
self._next_obs[k] = self._obs[k].clone()
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
self._action = torch.zeros(
|
||||||
|
(self.capacity + cfg.horizon - 1, action_dim),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
self._reward = torch.zeros(
|
||||||
|
(self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
|
||||||
|
)
|
||||||
|
self._mask = torch.zeros(
|
||||||
|
(self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
|
||||||
|
)
|
||||||
|
self._done = torch.zeros(
|
||||||
|
(self.capacity + cfg.horizon - 1,), dtype=torch.bool, device=self.device
|
||||||
|
)
|
||||||
|
self._priorities = torch.ones(
|
||||||
|
(self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
|
||||||
|
)
|
||||||
|
self._eps = 1e-6
|
||||||
|
self._full = False
|
||||||
|
self.idx = 0
|
||||||
|
if dataset is not None:
|
||||||
|
self.init_from_offline_dataset(dataset)
|
||||||
|
|
||||||
|
self._aug = aug(cfg)
|
||||||
|
|
||||||
|
def init_from_offline_dataset(self, dataset):
|
||||||
|
"""Initialize the replay buffer from an offline dataset."""
|
||||||
|
assert self.idx == 0 and not self._full
|
||||||
|
n_transitions = int(len(dataset["rewards"]) * self.cfg.data_first_percent)
|
||||||
|
|
||||||
|
def copy_data(dst, src, n):
|
||||||
|
assert isinstance(dst, dict) == isinstance(src, dict)
|
||||||
|
if isinstance(dst, dict):
|
||||||
|
for k in dst:
|
||||||
|
copy_data(dst[k], src[k], n)
|
||||||
|
else:
|
||||||
|
dst[:n] = torch.from_numpy(src[:n])
|
||||||
|
|
||||||
|
copy_data(self._obs, dataset["observations"], n_transitions)
|
||||||
|
copy_data(self._next_obs, dataset["next_observations"], n_transitions)
|
||||||
|
copy_data(self._action, dataset["actions"], n_transitions)
|
||||||
|
copy_data(self._reward, dataset["rewards"], n_transitions)
|
||||||
|
copy_data(self._mask, dataset["masks"], n_transitions)
|
||||||
|
copy_data(self._done, dataset["dones"], n_transitions)
|
||||||
|
self.idx = (self.idx + n_transitions) % self.capacity
|
||||||
|
self._full = n_transitions >= self.capacity
|
||||||
|
|
||||||
|
def __add__(self, episode: Episode):
|
||||||
|
self.add(episode)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def add(self, episode: Episode):
|
||||||
|
"""Add an episode to the replay buffer."""
|
||||||
|
if self.idx + len(episode) > self.capacity:
|
||||||
|
print("Warning: episode got truncated")
|
||||||
|
ep_len = min(len(episode), self.capacity - self.idx)
|
||||||
|
idxs = slice(self.idx, self.idx + ep_len)
|
||||||
|
assert self.idx + ep_len <= self.capacity
|
||||||
|
if self.cfg.modality in {"pixels", "state"}:
|
||||||
|
self._obs[idxs] = (
|
||||||
|
episode.obses[:ep_len]
|
||||||
|
if self.cfg.modality == "state"
|
||||||
|
else episode.obses[:ep_len, -3:]
|
||||||
|
)
|
||||||
|
self._next_obs[idxs] = (
|
||||||
|
episode.obses[1 : ep_len + 1]
|
||||||
|
if self.cfg.modality == "state"
|
||||||
|
else episode.obses[1 : ep_len + 1, -3:]
|
||||||
|
)
|
||||||
|
elif self.cfg.modality == "all":
|
||||||
|
for k, v in episode.obses.items():
|
||||||
|
assert k in {"rgb", "state"}
|
||||||
|
assert k in self._obs
|
||||||
|
assert k in self._next_obs
|
||||||
|
if k == "rgb":
|
||||||
|
self._obs[k][idxs] = episode.obses[k][:ep_len, -3:]
|
||||||
|
self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1, -3:]
|
||||||
|
else:
|
||||||
|
self._obs[k][idxs] = episode.obses[k][:ep_len]
|
||||||
|
self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1]
|
||||||
|
self._action[idxs] = episode.actions[:ep_len]
|
||||||
|
self._reward[idxs] = episode.rewards[:ep_len]
|
||||||
|
self._mask[idxs] = episode.masks[:ep_len]
|
||||||
|
self._done[idxs] = episode.dones[:ep_len]
|
||||||
|
self._done[self.idx + ep_len - 1] = True # in case truncated
|
||||||
|
if self._full:
|
||||||
|
max_priority = (
|
||||||
|
self._priorities[: self.capacity].max().to(self.device).item()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
max_priority = (
|
||||||
|
1.0
|
||||||
|
if self.idx == 0
|
||||||
|
else self._priorities[: self.idx].max().to(self.device).item()
|
||||||
|
)
|
||||||
|
new_priorities = torch.full((ep_len,), max_priority, device=self.device)
|
||||||
|
self._priorities[idxs] = new_priorities
|
||||||
|
self.idx = (self.idx + ep_len) % self.capacity
|
||||||
|
self._full = self._full or self.idx == 0
|
||||||
|
|
||||||
|
def update_priorities(self, idxs, priorities):
|
||||||
|
"""Update priorities for Prioritized Experience Replay (PER)"""
|
||||||
|
self._priorities[idxs] = priorities.squeeze(1).to(self.device) + self._eps
|
||||||
|
|
||||||
|
def _get_obs(self, arr, idxs):
|
||||||
|
"""Retrieve observations by indices"""
|
||||||
|
if isinstance(arr, dict):
|
||||||
|
return {k: self._get_obs(v, idxs) for k, v in arr.items()}
|
||||||
|
if arr.ndim <= 2: # if self.cfg.modality == 'state':
|
||||||
|
return arr[idxs].cuda()
|
||||||
|
obs = torch.empty(
|
||||||
|
(self.cfg.batch_size, 3 * self.cfg.frame_stack, *arr.shape[-2:]),
|
||||||
|
dtype=arr.dtype,
|
||||||
|
device=torch.device("cuda"),
|
||||||
|
)
|
||||||
|
obs[:, -3:] = arr[idxs].cuda()
|
||||||
|
_idxs = idxs.clone()
|
||||||
|
mask = torch.ones_like(_idxs, dtype=torch.bool)
|
||||||
|
for i in range(1, self.cfg.frame_stack):
|
||||||
|
mask[_idxs % self.cfg.episode_length == 0] = False
|
||||||
|
_idxs[mask] -= 1
|
||||||
|
obs[:, -(i + 1) * 3 : -i * 3] = arr[_idxs].cuda()
|
||||||
|
return obs.float()
|
||||||
|
|
||||||
|
def sample(self):
|
||||||
|
"""Sample transitions from the replay buffer."""
|
||||||
|
probs = (
|
||||||
|
self._priorities[: self.capacity]
|
||||||
|
if self._full
|
||||||
|
else self._priorities[: self.idx]
|
||||||
|
) ** self.cfg.per_alpha
|
||||||
|
probs /= probs.sum()
|
||||||
|
total = len(probs)
|
||||||
|
idxs = torch.from_numpy(
|
||||||
|
np.random.choice(
|
||||||
|
total,
|
||||||
|
self.cfg.batch_size,
|
||||||
|
p=probs.cpu().numpy(),
|
||||||
|
replace=not self._full,
|
||||||
|
)
|
||||||
|
).to(self.device)
|
||||||
|
weights = (total * probs[idxs]) ** (-self.cfg.per_beta)
|
||||||
|
weights /= weights.max()
|
||||||
|
|
||||||
|
idxs_in_horizon = torch.stack([idxs + t for t in range(self.cfg.horizon)])
|
||||||
|
|
||||||
|
obs = self._aug(self._get_obs(self._obs, idxs))
|
||||||
|
next_obs = [
|
||||||
|
self._aug(self._get_obs(self._next_obs, _idxs)) for _idxs in idxs_in_horizon
|
||||||
|
]
|
||||||
|
if isinstance(next_obs[0], dict):
|
||||||
|
next_obs = {k: torch.stack([o[k] for o in next_obs]) for k in next_obs[0]}
|
||||||
|
else:
|
||||||
|
next_obs = torch.stack(next_obs)
|
||||||
|
action = self._action[idxs_in_horizon]
|
||||||
|
reward = self._reward[idxs_in_horizon]
|
||||||
|
mask = self._mask[idxs_in_horizon]
|
||||||
|
done = self._done[idxs_in_horizon]
|
||||||
|
|
||||||
|
if not action.is_cuda:
|
||||||
|
action, reward, mask, done, idxs, weights = (
|
||||||
|
action.cuda(),
|
||||||
|
reward.cuda(),
|
||||||
|
mask.cuda(),
|
||||||
|
done.cuda(),
|
||||||
|
idxs.cuda(),
|
||||||
|
weights.cuda(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
obs,
|
||||||
|
next_obs,
|
||||||
|
action,
|
||||||
|
reward.unsqueeze(2),
|
||||||
|
mask.unsqueeze(2),
|
||||||
|
done.unsqueeze(2),
|
||||||
|
idxs,
|
||||||
|
weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
def save(self, path):
|
||||||
|
"""Save the replay buffer to path"""
|
||||||
|
print(f"saving replay buffer to '{path}'...")
|
||||||
|
sz = self.capacity if self._full else self.idx
|
||||||
|
dataset = {
|
||||||
|
"observations": (
|
||||||
|
{k: v[:sz].cpu().numpy() for k, v in self._obs.items()}
|
||||||
|
if isinstance(self._obs, dict)
|
||||||
|
else self._obs[:sz].cpu().numpy()
|
||||||
|
),
|
||||||
|
"next_observations": (
|
||||||
|
{k: v[:sz].cpu().numpy() for k, v in self._next_obs.items()}
|
||||||
|
if isinstance(self._next_obs, dict)
|
||||||
|
else self._next_obs[:sz].cpu().numpy()
|
||||||
|
),
|
||||||
|
"actions": self._action[:sz].cpu().numpy(),
|
||||||
|
"rewards": self._reward[:sz].cpu().numpy(),
|
||||||
|
"dones": self._done[:sz].cpu().numpy(),
|
||||||
|
"masks": self._mask[:sz].cpu().numpy(),
|
||||||
|
}
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
pickle.dump(dataset, f)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset_dict(cfg, env, return_reward_normalizer=False):
|
||||||
|
"""Construct a dataset for env"""
|
||||||
|
required_keys = [
|
||||||
|
"observations",
|
||||||
|
"next_observations",
|
||||||
|
"actions",
|
||||||
|
"rewards",
|
||||||
|
"dones",
|
||||||
|
"masks",
|
||||||
|
]
|
||||||
|
|
||||||
|
if cfg.task.startswith("xarm"):
|
||||||
|
dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl")
|
||||||
|
print(f"Using offline dataset '{dataset_path}'")
|
||||||
|
with open(dataset_path, "rb") as f:
|
||||||
|
dataset_dict = pickle.load(f)
|
||||||
|
for k in required_keys:
|
||||||
|
if k not in dataset_dict and k[:-1] in dataset_dict:
|
||||||
|
dataset_dict[k] = dataset_dict.pop(k[:-1])
|
||||||
|
elif cfg.task.startswith("legged"):
|
||||||
|
dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl")
|
||||||
|
print(f"Using offline dataset '{dataset_path}'")
|
||||||
|
with open(dataset_path, "rb") as f:
|
||||||
|
dataset_dict = pickle.load(f)
|
||||||
|
dataset_dict["actions"] /= env.unwrapped.clip_actions
|
||||||
|
print(f"clip_actions={env.unwrapped.clip_actions}")
|
||||||
|
else:
|
||||||
|
import d4rl
|
||||||
|
|
||||||
|
dataset_dict = d4rl.qlearning_dataset(env)
|
||||||
|
dones = np.full_like(dataset_dict["rewards"], False, dtype=bool)
|
||||||
|
|
||||||
|
for i in range(len(dones) - 1):
|
||||||
|
if (
|
||||||
|
np.linalg.norm(
|
||||||
|
dataset_dict["observations"][i + 1]
|
||||||
|
- dataset_dict["next_observations"][i]
|
||||||
|
)
|
||||||
|
> 1e-6
|
||||||
|
or dataset_dict["terminals"][i] == 1.0
|
||||||
|
):
|
||||||
|
dones[i] = True
|
||||||
|
|
||||||
|
dones[-1] = True
|
||||||
|
|
||||||
|
dataset_dict["masks"] = 1.0 - dataset_dict["terminals"]
|
||||||
|
del dataset_dict["terminals"]
|
||||||
|
|
||||||
|
for k, v in dataset_dict.items():
|
||||||
|
dataset_dict[k] = v.astype(np.float32)
|
||||||
|
|
||||||
|
dataset_dict["dones"] = dones
|
||||||
|
|
||||||
|
if cfg.is_data_clip:
|
||||||
|
lim = 1 - cfg.data_clip_eps
|
||||||
|
dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim)
|
||||||
|
reward_normalizer = get_reward_normalizer(cfg, dataset_dict)
|
||||||
|
dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"])
|
||||||
|
|
||||||
|
for key in required_keys:
|
||||||
|
assert key in dataset_dict.keys(), f"Missing `{key}` in dataset."
|
||||||
|
|
||||||
|
if return_reward_normalizer:
|
||||||
|
return dataset_dict, reward_normalizer
|
||||||
|
return dataset_dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_trajectory_boundaries_and_returns(dataset):
|
||||||
|
"""
|
||||||
|
Split dataset into trajectories and compute returns
|
||||||
|
"""
|
||||||
|
episode_starts = [0]
|
||||||
|
episode_ends = []
|
||||||
|
|
||||||
|
episode_return = 0
|
||||||
|
episode_returns = []
|
||||||
|
|
||||||
|
n_transitions = len(dataset["rewards"])
|
||||||
|
|
||||||
|
for i in range(n_transitions):
|
||||||
|
episode_return += dataset["rewards"][i]
|
||||||
|
|
||||||
|
if dataset["dones"][i]:
|
||||||
|
episode_returns.append(episode_return)
|
||||||
|
episode_ends.append(i + 1)
|
||||||
|
if i + 1 < n_transitions:
|
||||||
|
episode_starts.append(i + 1)
|
||||||
|
episode_return = 0.0
|
||||||
|
|
||||||
|
return episode_starts, episode_ends, episode_returns
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_returns(dataset, scaling=1000):
|
||||||
|
"""
|
||||||
|
Normalize returns in the dataset
|
||||||
|
"""
|
||||||
|
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
|
||||||
|
dataset["rewards"] /= np.max(episode_returns) - np.min(episode_returns)
|
||||||
|
dataset["rewards"] *= scaling
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def get_reward_normalizer(cfg, dataset):
|
||||||
|
"""
|
||||||
|
Get a reward normalizer for the dataset
|
||||||
|
"""
|
||||||
|
if cfg.task.startswith("xarm"):
|
||||||
|
return lambda x: x
|
||||||
|
elif "maze" in cfg.task:
|
||||||
|
return lambda x: x - 1.0
|
||||||
|
elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]:
|
||||||
|
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
|
||||||
|
return (
|
||||||
|
lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0
|
||||||
|
)
|
||||||
|
elif hasattr(cfg, "reward_scale"):
|
||||||
|
return lambda x: x * cfg.reward_scale
|
||||||
|
return lambda x: x
|
||||||
|
|
||||||
|
|
||||||
|
def linear_schedule(schdl, step):
|
||||||
|
"""
|
||||||
|
Outputs values following a linear decay schedule.
|
||||||
|
Adapted from https://github.com/facebookresearch/drqv2
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return float(schdl)
|
||||||
|
except ValueError:
|
||||||
|
match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl)
|
||||||
|
if match:
|
||||||
|
init, final, start, end = [float(g) for g in match.groups()]
|
||||||
|
mix = np.clip((step - start) / (end - start), 0.0, 1.0)
|
||||||
|
return (1.0 - mix) * init + mix * final
|
||||||
|
match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
|
||||||
|
if match:
|
||||||
|
init, final, duration = [float(g) for g in match.groups()]
|
||||||
|
mix = np.clip(step / duration, 0.0, 1.0)
|
||||||
|
return (1.0 - mix) * init + mix * final
|
||||||
|
raise NotImplementedError(schdl)
|
|
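These helpers are easiest to see end to end on a toy example. The sketch below is not part of the diff: the SimpleNamespace config and the hand-built arrays are assumptions, and it assumes the definitions above are already in scope; only get_reward_normalizer and linear_schedule come from the code in this commit.

    # Toy usage sketch; assumes the definitions above are importable/in scope.
    from types import SimpleNamespace

    import numpy as np

    cfg = SimpleNamespace(task="xarm_lift")  # hypothetical config with only the field used here
    dataset = {
        "rewards": np.array([0.0, 1.0, 0.0, 1.0], dtype=np.float32),
        "dones": np.array([0.0, 1.0, 0.0, 1.0], dtype=np.float32),
    }

    normalizer = get_reward_normalizer(cfg, dataset)
    print(normalizer(dataset["rewards"]))  # identity for xarm* tasks

    # "linear(1.0,0.1,10000)" decays from 1.0 to 0.1 over 10000 steps
    print(linear_schedule("linear(1.0,0.1,10000)", step=5000))  # -> 0.55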
@@ -0,0 +1,12 @@
import random

import numpy as np
import torch


def set_seed(seed):
    """Set seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
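A usage note, not part of the commit: set_seed covers the Python, NumPy and PyTorch RNGs, but fully deterministic GPU runs usually also pin cuDNN. A minimal sketch, where the extra flags are an assumption about what a caller might want rather than anything this file does:

    import torch

    from lerobot.common.utils import set_seed  # module path introduced by this commit

    set_seed(1337)
    # Not done by set_seed(); stricter determinism at some speed cost.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False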
@@ -5,70 +5,54 @@ import imageio
 import numpy as np
 import torch
 from tensordict import TensorDict
+from tensordict.nn import TensorDictModule
 from termcolor import colored
 
-from lerobot.lib.envs.factory import make_env
-from lerobot.lib.tdmpc import TDMPC
-from lerobot.lib.utils import set_seed
+from lerobot.common.envs.factory import make_env
+from lerobot.common.tdmpc import TDMPC
+from lerobot.common.utils import set_seed
 
 
-def eval_agent(
-    env, agent, num_episodes: int, save_video: bool = False, video_path: Path = None
+def eval_policy(
+    env, policy, num_episodes: int, save_video: bool = False, video_dir: Path = None
 ):
-    """Evaluate a trained agent and optionally save a video."""
-    if save_video:
-        assert video_path is not None
-        assert video_path.suffix == ".mp4"
-    episode_rewards = []
-    episode_successes = []
-    episode_lengths = []
+    rewards = []
+    successes = []
     for i in range(num_episodes):
-        td = env.reset()
-        obs = {}
-        obs["rgb"] = td["observation"]["camera"]
-        obs["state"] = td["observation"]["robot_state"]
-
-        done = False
-        ep_reward = 0
-        t = 0
-        ep_success = False
-        if save_video:
-            frames = []
-        while not done:
-            action = agent.act(obs, t0=t == 0, eval_mode=True, step=100000)
-            td = TensorDict({"action": action}, batch_size=[])
-
-            td = env.step(td)
-
-            reward = td["next", "reward"].item()
-            success = td["next", "success"].item()
-            done = td["next", "done"].item()
-
-            obs = {}
-            obs["rgb"] = td["next", "observation"]["camera"]
-            obs["state"] = td["next", "observation"]["robot_state"]
-
-            ep_reward += reward
-            if success:
-                ep_success = True
-            if save_video:
-                frame = env.render()
-                frames.append(frame)
-            t += 1
-        episode_rewards.append(float(ep_reward))
-        episode_successes.append(float(ep_success))
-        episode_lengths.append(t)
-    if save_video:
-        video_path.parent.mkdir(parents=True, exist_ok=True)
-        frames = np.stack(frames)  # .transpose(0, 3, 1, 2)
-        # TODO(rcadene): make fps configurable
-        imageio.mimsave(video_path, frames, fps=15)
-    return {
-        "episode_reward": np.nanmean(episode_rewards),
-        "episode_success": np.nanmean(episode_successes),
-        "episode_length": np.nanmean(episode_lengths),
-    }
+        ep_frames = []
+
+        def rendering_callback(env, td=None):
+            nonlocal ep_frames
+            frame = env.render()
+            ep_frames.append(frame)
+
+        tensordict = env.reset()
+        # render first frame before rollout
+        rendering_callback(env)
+
+        rollout = env.rollout(
+            max_steps=30,
+            policy=policy,
+            callback=rendering_callback,
+            auto_reset=False,
+            tensordict=tensordict,
+        )
+        ep_reward = rollout["next", "reward"].sum()
+        ep_success = rollout["next", "success"].any()
+        rewards.append(ep_reward.item())
+        successes.append(ep_success.item())
+
+        if save_video:
+            video_dir.parent.mkdir(parents=True, exist_ok=True)
+            # TODO(rcadene): make fps configurable
+            video_path = video_dir / f"eval_episode_{i}.mp4"
+            imageio.mimsave(video_path, np.stack(ep_frames), fps=15)
+
+    metrics = {
+        "avg_reward": np.nanmean(rewards),
+        "pc_success": np.nanmean(successes) * 100,
+    }
+    return metrics
 
 
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
@@ -78,20 +62,25 @@ def eval(cfg: dict):
     print(colored("Log dir:", "yellow", attrs=["bold"]), cfg.log_dir)
 
     env = make_env(cfg)
-    agent = TDMPC(cfg)
+    policy = TDMPC(cfg)
     # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
     ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
-    agent.load(ckpt_path)
+    policy.load(ckpt_path)
 
-    eval_metrics = eval_agent(
-        env,
-        agent,
-        num_episodes=10,
-        save_video=True,
-        video_path=Path("tmp/2023_01_29_xarm_lift_final/eval.mp4"),
-    )
+    policy = TensorDictModule(
+        policy,
+        in_keys=["observation", "step_count"],
+        out_keys=["action"],
+    )
 
-    print(eval_metrics)
+    metrics = eval_policy(
+        env,
+        policy,
+        num_episodes=10,
+        save_video=True,
+        video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
+    )
+    print(metrics)
 
 
 if __name__ == "__main__":
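For context on the refactor, a sketch that is not code from this repo: eval_policy now leans on torchrl's EnvBase.rollout, which can drive any TensorDictModule that maps observation keys to an "action" key. The toy RandomPolicy and the Pendulum-v1 environment below are assumptions used only to illustrate the wrapping pattern; in eval() the wrapped module is the loaded TDMPC checkpoint.

    import torch
    from tensordict.nn import TensorDictModule
    from torchrl.envs.libs.gym import GymEnv


    class RandomPolicy(torch.nn.Module):
        """Hypothetical stand-in policy: ignores the observation, samples an action."""

        def __init__(self, action_dim):
            super().__init__()
            self.action_dim = action_dim

        def forward(self, observation):
            return torch.rand(self.action_dim) * 2 - 1


    env = GymEnv("Pendulum-v1")
    policy = TensorDictModule(
        RandomPolicy(action_dim=env.action_spec.shape[-1]),
        in_keys=["observation"],
        out_keys=["action"],
    )
    # rollout() reads "observation" from the tensordict, calls the policy,
    # writes "action", and steps the env, exactly as eval_policy does above.
    rollout = env.rollout(max_steps=10, policy=policy)
    print(rollout["next", "reward"].sum())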
@@ -2,7 +2,10 @@ import hydra
 import torch
 from termcolor import colored
 
-from ..lib.utils import set_seed
+from lerobot.common.envs.factory import make_env
+from lerobot.common.tdmpc import TDMPC
+
+from ..common.utils import set_seed
 
 
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
@@ -11,6 +14,24 @@ def train(cfg: dict):
     set_seed(cfg.seed)
     print(colored("Work dir:", "yellow", attrs=["bold"]), cfg.log_dir)
 
+    env = make_env(cfg)
+    agent = TDMPC(cfg)
+    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
+    ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
+    agent.load(ckpt_path)
+
+    # online training
+
+    eval_metrics = train_agent(
+        env,
+        agent,
+        num_episodes=10,
+        save_video=True,
+        video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
+    )
+
+    print(eval_metrics)
+
+
 if __name__ == "__main__":
     train()
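Note that this hunk calls train_agent and Path without importing or defining either, so the online-training step is still a placeholder. Purely as a hypothetical sketch of the shape such a loop could take (this is a guess under stated assumptions, not the project's implementation; the update step is only referenced in comments):

    from pathlib import Path


    def train_agent(env, agent, num_episodes, save_video=False, video_dir: Path = None):
        """Hypothetical online loop: roll out, (eventually) update, report a metric."""
        # assumes `agent` is already TensorDict-compatible, e.g. wrapped in a
        # TensorDictModule as done for evaluation in eval.py; save_video/video_dir
        # are accepted only to match the call site above and are unused here
        rewards = []
        for _ in range(num_episodes):
            rollout = env.rollout(max_steps=30, policy=agent)
            rewards.append(rollout["next", "reward"].sum().item())
            # a real implementation would push the rollout into a replay buffer
            # and run the agent's update/optimization step here
        return {"avg_reward": sum(rewards) / len(rewards)}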
@@ -2,7 +2,7 @@ import pytest
 from tensordict import TensorDict
 from torchrl.envs.utils import check_env_specs, step_mdp
 
-from lerobot.lib.envs import SimxarmEnv
+from lerobot.common.envs import SimxarmEnv
 
 
 @pytest.mark.parametrize(
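The diff is cut off right after @pytest.mark.parametrize(, so the actual parameters are not visible here. A hypothetical test in the same style, where the task name and image size are guesses and only the SimxarmEnv constructor arguments come from the new environment class:

    import pytest
    from torchrl.envs.utils import check_env_specs

    from lerobot.common.envs import SimxarmEnv


    @pytest.mark.parametrize(
        "task,from_pixels,pixels_only",
        [
            ("lift", False, False),
            ("lift", True, False),
            ("lift", True, True),
        ],
    )
    def test_simxarm(task, from_pixels, pixels_only):
        env = SimxarmEnv(
            task=task,
            from_pixels=from_pixels,
            pixels_only=pixels_only,
            image_size=84 if from_pixels else None,
        )
        # torchrl helper that resets/steps the env and checks spec consistency
        check_env_specs(env)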