Add common, refactor eval with eval_policy
This commit is contained in:
parent
1e52499490
commit
5a5b190f70
|
@ -0,0 +1,42 @@
|
|||
from torchrl.envs.transforms import StepCounter, TransformedEnv
|
||||
|
||||
from lerobot.common.envs.simxarm import SimxarmEnv
|
||||
|
||||
|
||||
def make_env(cfg):
|
||||
assert cfg.env == "simxarm"
|
||||
env = SimxarmEnv(
|
||||
task=cfg.task,
|
||||
from_pixels=cfg.from_pixels,
|
||||
pixels_only=cfg.pixels_only,
|
||||
image_size=cfg.image_size,
|
||||
)
|
||||
|
||||
# limit rollout to max_steps
|
||||
env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length))
|
||||
|
||||
return env
|
||||
|
||||
|
||||
# def make_env(env_name, frame_skip, device, is_test=False):
|
||||
# env = GymEnv(
|
||||
# env_name,
|
||||
# frame_skip=frame_skip,
|
||||
# from_pixels=True,
|
||||
# pixels_only=False,
|
||||
# device=device,
|
||||
# )
|
||||
# env = TransformedEnv(env)
|
||||
# env.append_transform(NoopResetEnv(noops=30, random=True))
|
||||
# if not is_test:
|
||||
# env.append_transform(EndOfLifeTransform())
|
||||
# env.append_transform(RewardClipping(-1, 1))
|
||||
# env.append_transform(ToTensorImage())
|
||||
# env.append_transform(GrayScale())
|
||||
# env.append_transform(Resize(84, 84))
|
||||
# env.append_transform(CatFrames(N=4, dim=-3))
|
||||
# env.append_transform(RewardSum())
|
||||
# env.append_transform(StepCounter(max_steps=4500))
|
||||
# env.append_transform(DoubleToFloat())
|
||||
# env.append_transform(VecNorm(in_keys=["pixels"]))
|
||||
# return env
|
|
@ -0,0 +1,183 @@
|
|||
import importlib
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
BoundedTensorSpec,
|
||||
CompositeSpec,
|
||||
DiscreteTensorSpec,
|
||||
UnboundedContinuousTensorSpec,
|
||||
)
|
||||
from torchrl.envs import EnvBase
|
||||
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
_has_gym = importlib.util.find_spec("gym") is not None
|
||||
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym
|
||||
|
||||
|
||||
class SimxarmEnv(EnvBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.task = task
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
if from_pixels:
|
||||
assert image_size
|
||||
|
||||
if not _has_simxarm:
|
||||
raise ImportError("Cannot import simxarm.")
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
|
||||
import gym
|
||||
from gym.wrappers import TimeLimit
|
||||
from simxarm import TASKS
|
||||
|
||||
if self.task not in TASKS:
|
||||
raise ValueError(
|
||||
f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}"
|
||||
)
|
||||
|
||||
self._env = TASKS[self.task]["env"]()
|
||||
self._env = TimeLimit(self._env, TASKS[self.task]["episode_length"])
|
||||
|
||||
MAX_NUM_ACTIONS = 4
|
||||
num_actions = len(TASKS[self.task]["action_space"])
|
||||
self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
|
||||
self._action_padding = np.zeros(
|
||||
(MAX_NUM_ACTIONS - num_actions), dtype=np.float32
|
||||
)
|
||||
if "w" not in TASKS[self.task]["action_space"]:
|
||||
self._action_padding[-1] = 1.0
|
||||
|
||||
self._make_spec()
|
||||
self.set_seed(seed)
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
return self._env.render(mode, width=width, height=height)
|
||||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
if self.from_pixels:
|
||||
camera = self.render(
|
||||
mode="rgb_array", width=self.image_size, height=self.image_size
|
||||
)
|
||||
camera = camera.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
|
||||
camera = torch.tensor(camera.copy(), dtype=torch.uint8)
|
||||
|
||||
obs = {"camera": camera}
|
||||
|
||||
if not self.pixels_only:
|
||||
obs["robot_state"] = torch.tensor(
|
||||
self._env.robot_state, dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
|
||||
|
||||
obs = TensorDict(obs, batch_size=[])
|
||||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
td = tensordict
|
||||
if td is None or td.is_empty():
|
||||
raw_obs = self._env.reset()
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": self._format_raw_obs(raw_obs),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return td
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
td = tensordict
|
||||
action = td["action"].numpy()
|
||||
# step expects shape=(4,) so we pad if necessary
|
||||
action = np.concatenate([action, self._action_padding])
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": self._format_raw_obs(raw_obs),
|
||||
"reward": torch.tensor([reward], dtype=torch.float32),
|
||||
"done": torch.tensor([done], dtype=torch.bool),
|
||||
"success": torch.tensor([info["success"]], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
return td
|
||||
|
||||
def _make_spec(self):
|
||||
obs = {}
|
||||
if self.from_pixels:
|
||||
obs["camera"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(3, self.image_size, self.image_size),
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
if not self.pixels_only:
|
||||
obs["robot_state"] = UnboundedContinuousTensorSpec(
|
||||
shape=(len(self._env.robot_state),),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
shape=self._env.observation_space["observation"].shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.observation_spec = CompositeSpec({"observation": obs})
|
||||
|
||||
self.action_spec = _gym_to_torchrl_spec_transform(
|
||||
self._action_space,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.reward_spec = UnboundedContinuousTensorSpec(
|
||||
shape=(1,),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.done_spec = DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.success_spec = DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
set_seed(seed)
|
||||
self._env.seed(seed)
|
|
@ -0,0 +1,450 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import lerobot.common.tdmpc_helper as h
|
||||
|
||||
|
||||
class TOLD(nn.Module):
|
||||
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
action_dim = 4
|
||||
|
||||
self.cfg = cfg
|
||||
self._encoder = h.enc(cfg)
|
||||
self._dynamics = h.dynamics(
|
||||
cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim
|
||||
)
|
||||
self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1)
|
||||
self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim)
|
||||
self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)])
|
||||
self._V = h.v(cfg)
|
||||
self.apply(h.orthogonal_init)
|
||||
for m in [self._reward, *self._Qs]:
|
||||
m[-1].weight.data.fill_(0)
|
||||
m[-1].bias.data.fill_(0)
|
||||
|
||||
def track_q_grad(self, enable=True):
|
||||
"""Utility function. Enables/disables gradient tracking of Q-networks."""
|
||||
for m in self._Qs:
|
||||
h.set_requires_grad(m, enable)
|
||||
|
||||
def track_v_grad(self, enable=True):
|
||||
"""Utility function. Enables/disables gradient tracking of Q-networks."""
|
||||
if hasattr(self, "_V"):
|
||||
h.set_requires_grad(self._V, enable)
|
||||
|
||||
def encode(self, obs):
|
||||
"""Encodes an observation into its latent representation."""
|
||||
out = self._encoder(obs)
|
||||
if isinstance(obs, dict):
|
||||
# fusion
|
||||
out = torch.stack([v for k, v in out.items()]).mean(dim=0)
|
||||
return out
|
||||
|
||||
def next(self, z, a):
|
||||
"""Predicts next latent state (d) and single-step reward (R)."""
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
return self._dynamics(x), self._reward(x)
|
||||
|
||||
def pi(self, z, std=0):
|
||||
"""Samples an action from the learned policy (pi)."""
|
||||
mu = torch.tanh(self._pi(z))
|
||||
if std > 0:
|
||||
std = torch.ones_like(mu) * std
|
||||
return h.TruncatedNormal(mu, std).sample(clip=0.3)
|
||||
return mu
|
||||
|
||||
def V(self, z):
|
||||
"""Predict state value (V)."""
|
||||
return self._V(z)
|
||||
|
||||
def Q(self, z, a, return_type):
|
||||
"""Predict state-action value (Q)."""
|
||||
assert return_type in {"min", "avg", "all"}
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
|
||||
if return_type == "all":
|
||||
return torch.stack(list(q(x) for q in self._Qs), dim=0)
|
||||
|
||||
idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
|
||||
Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
|
||||
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
|
||||
|
||||
|
||||
class TDMPC(nn.Module):
|
||||
"""Implementation of TD-MPC learning + inference."""
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.action_dim = 4
|
||||
|
||||
self.cfg = cfg
|
||||
self.device = torch.device("cuda")
|
||||
self.std = h.linear_schedule(cfg.std_schedule, 0)
|
||||
self.model = TOLD(cfg).cuda()
|
||||
self.model_target = deepcopy(self.model)
|
||||
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
|
||||
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
|
||||
self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
|
||||
self.model.eval()
|
||||
self.model_target.eval()
|
||||
self.batch_size = cfg.batch_size
|
||||
|
||||
# TODO(rcadene): clean
|
||||
self.step = 100000
|
||||
|
||||
def state_dict(self):
|
||||
"""Retrieve state dict of TOLD model, including slow-moving target network."""
|
||||
return {
|
||||
"model": self.model.state_dict(),
|
||||
"model_target": self.model_target.state_dict(),
|
||||
}
|
||||
|
||||
def save(self, fp):
|
||||
"""Save state dict of TOLD model to filepath."""
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
"""Load a saved state dict from filepath into current agent."""
|
||||
d = torch.load(fp)
|
||||
self.model.load_state_dict(d["model"])
|
||||
self.model_target.load_state_dict(d["model_target"])
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
t0 = step_count.item() == 0
|
||||
obs = {
|
||||
"rgb": observation["camera"],
|
||||
"state": observation["robot_state"],
|
||||
}
|
||||
return self.act(obs, t0=t0, step=self.step)
|
||||
|
||||
@torch.no_grad()
|
||||
def act(self, obs, t0=False, step=None):
|
||||
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
|
||||
if isinstance(obs, dict):
|
||||
obs = {
|
||||
k: torch.tensor(o, dtype=torch.float32, device=self.device).unsqueeze(0)
|
||||
for k, o in obs.items()
|
||||
}
|
||||
else:
|
||||
obs = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(
|
||||
0
|
||||
)
|
||||
z = self.model.encode(obs)
|
||||
if self.cfg.mpc:
|
||||
a = self.plan(z, t0=t0, step=step)
|
||||
else:
|
||||
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
|
||||
return a.cpu()
|
||||
|
||||
@torch.no_grad()
|
||||
def estimate_value(self, z, actions, horizon):
|
||||
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
|
||||
G, discount = 0, 1
|
||||
for t in range(horizon):
|
||||
if self.cfg.uncertainty_cost > 0:
|
||||
G -= (
|
||||
discount
|
||||
* self.cfg.uncertainty_cost
|
||||
* self.model.Q(z, actions[t], return_type="all").std(dim=0)
|
||||
)
|
||||
z, reward = self.model.next(z, actions[t])
|
||||
G += discount * reward
|
||||
discount *= self.cfg.discount
|
||||
pi = self.model.pi(z, self.cfg.min_std)
|
||||
G += discount * self.model.Q(z, pi, return_type="min")
|
||||
if self.cfg.uncertainty_cost > 0:
|
||||
G -= (
|
||||
discount
|
||||
* self.cfg.uncertainty_cost
|
||||
* self.model.Q(z, pi, return_type="all").std(dim=0)
|
||||
)
|
||||
return G
|
||||
|
||||
@torch.no_grad()
|
||||
def plan(self, z, step=None, t0=True):
|
||||
"""
|
||||
Plan next action using TD-MPC inference.
|
||||
z: latent state.
|
||||
step: current time step. determines e.g. planning horizon.
|
||||
t0: whether current step is the first step of an episode.
|
||||
"""
|
||||
# during eval: eval_mode: uniform sampling and action noise is disabled during evaluation.
|
||||
|
||||
assert step is not None
|
||||
# Seed steps
|
||||
if step < self.cfg.seed_steps and self.model.training:
|
||||
return torch.empty(
|
||||
self.action_dim, dtype=torch.float32, device=self.device
|
||||
).uniform_(-1, 1)
|
||||
|
||||
# Sample policy trajectories
|
||||
horizon = int(
|
||||
min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))
|
||||
)
|
||||
num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
|
||||
if num_pi_trajs > 0:
|
||||
pi_actions = torch.empty(
|
||||
horizon, num_pi_trajs, self.action_dim, device=self.device
|
||||
)
|
||||
_z = z.repeat(num_pi_trajs, 1)
|
||||
for t in range(horizon):
|
||||
pi_actions[t] = self.model.pi(_z, self.cfg.min_std)
|
||||
_z, _ = self.model.next(_z, pi_actions[t])
|
||||
|
||||
# Initialize state and parameters
|
||||
z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1)
|
||||
mean = torch.zeros(horizon, self.action_dim, device=self.device)
|
||||
std = self.cfg.max_std * torch.ones(
|
||||
horizon, self.action_dim, device=self.device
|
||||
)
|
||||
if not t0 and hasattr(self, "_prev_mean"):
|
||||
mean[:-1] = self._prev_mean[1:]
|
||||
|
||||
# Iterate CEM
|
||||
for i in range(self.cfg.iterations):
|
||||
actions = torch.clamp(
|
||||
mean.unsqueeze(1)
|
||||
+ std.unsqueeze(1)
|
||||
* torch.randn(
|
||||
horizon, self.cfg.num_samples, self.action_dim, device=std.device
|
||||
),
|
||||
-1,
|
||||
1,
|
||||
)
|
||||
if num_pi_trajs > 0:
|
||||
actions = torch.cat([actions, pi_actions], dim=1)
|
||||
|
||||
# Compute elite actions
|
||||
value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
|
||||
elite_idxs = torch.topk(
|
||||
value.squeeze(1), self.cfg.num_elites, dim=0
|
||||
).indices
|
||||
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
||||
|
||||
# Update parameters
|
||||
max_value = elite_value.max(0)[0]
|
||||
score = torch.exp(self.cfg.temperature * (elite_value - max_value))
|
||||
score /= score.sum(0)
|
||||
_mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (
|
||||
score.sum(0) + 1e-9
|
||||
)
|
||||
_std = torch.sqrt(
|
||||
torch.sum(
|
||||
score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2,
|
||||
dim=1,
|
||||
)
|
||||
/ (score.sum(0) + 1e-9)
|
||||
)
|
||||
_std = _std.clamp_(self.std, self.cfg.max_std)
|
||||
mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std
|
||||
|
||||
# Outputs
|
||||
score = score.squeeze(1).cpu().numpy()
|
||||
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
|
||||
self._prev_mean = mean
|
||||
mean, std = actions[0], _std[0]
|
||||
a = mean
|
||||
if self.model.training:
|
||||
a += std * torch.randn(self.action_dim, device=std.device)
|
||||
return torch.clamp(a, -1, 1)
|
||||
|
||||
def update_pi(self, zs, acts=None):
|
||||
"""Update policy using a sequence of latent states."""
|
||||
self.pi_optim.zero_grad(set_to_none=True)
|
||||
self.model.track_q_grad(False)
|
||||
self.model.track_v_grad(False)
|
||||
|
||||
info = {}
|
||||
# Advantage Weighted Regression
|
||||
assert acts is not None
|
||||
vs = self.model.V(zs)
|
||||
qs = self.model_target.Q(zs, acts, return_type="min")
|
||||
adv = qs - vs
|
||||
exp_a = torch.exp(adv * self.cfg.A_scaling)
|
||||
exp_a = torch.clamp(exp_a, max=100.0)
|
||||
log_probs = h.gaussian_logprob(self.model.pi(zs) - acts, 0)
|
||||
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
|
||||
pi_loss = -((exp_a * log_probs).mean(dim=(1, 2)) * rho).mean()
|
||||
info["adv"] = adv[0]
|
||||
|
||||
pi_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
self.model._pi.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
self.pi_optim.step()
|
||||
self.model.track_q_grad(True)
|
||||
self.model.track_v_grad(True)
|
||||
|
||||
info["pi_loss"] = pi_loss.item()
|
||||
return pi_loss.item(), info
|
||||
|
||||
@torch.no_grad()
|
||||
def _td_target(self, next_z, reward, mask):
|
||||
"""Compute the TD-target from a reward and the observation at the following time step."""
|
||||
next_v = self.model.V(next_z)
|
||||
td_target = reward + self.cfg.discount * mask * next_v
|
||||
return td_target
|
||||
|
||||
def update(self, replay_buffer, step, demo_buffer=None):
|
||||
"""Main update function. Corresponds to one iteration of the model learning."""
|
||||
|
||||
if demo_buffer is not None:
|
||||
# Update oversampling ratio
|
||||
self.demo_batch_size = int(
|
||||
h.linear_schedule(self.cfg.demo_schedule, step) * self.batch_size
|
||||
)
|
||||
replay_buffer.cfg.batch_size = self.batch_size - self.demo_batch_size
|
||||
demo_buffer.cfg.batch_size = self.demo_batch_size
|
||||
else:
|
||||
self.demo_batch_size = 0
|
||||
|
||||
# Sample from interaction dataset
|
||||
obs, next_obses, action, reward, mask, done, idxs, weights = (
|
||||
replay_buffer.sample()
|
||||
)
|
||||
|
||||
# Sample from demonstration dataset
|
||||
if self.demo_batch_size > 0:
|
||||
(
|
||||
demo_obs,
|
||||
demo_next_obses,
|
||||
demo_action,
|
||||
demo_reward,
|
||||
demo_mask,
|
||||
demo_done,
|
||||
demo_idxs,
|
||||
demo_weights,
|
||||
) = demo_buffer.sample()
|
||||
|
||||
if isinstance(obs, dict):
|
||||
obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
|
||||
next_obses = {
|
||||
k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1)
|
||||
for k in next_obses
|
||||
}
|
||||
else:
|
||||
obs = torch.cat([obs, demo_obs])
|
||||
next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
|
||||
action = torch.cat([action, demo_action], dim=1)
|
||||
reward = torch.cat([reward, demo_reward], dim=1)
|
||||
mask = torch.cat([mask, demo_mask], dim=1)
|
||||
done = torch.cat([done, demo_done], dim=1)
|
||||
idxs = torch.cat([idxs, demo_idxs])
|
||||
weights = torch.cat([weights, demo_weights])
|
||||
|
||||
horizon = self.cfg.horizon
|
||||
loss_mask = torch.ones_like(mask, device=self.device)
|
||||
for t in range(1, horizon):
|
||||
loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])
|
||||
|
||||
self.optim.zero_grad(set_to_none=True)
|
||||
self.std = h.linear_schedule(self.cfg.std_schedule, step)
|
||||
self.model.train()
|
||||
|
||||
# Compute targets
|
||||
with torch.no_grad():
|
||||
next_z = self.model.encode(next_obses)
|
||||
z_targets = self.model_target.encode(next_obses)
|
||||
td_targets = self._td_target(next_z, reward, mask)
|
||||
|
||||
# Latent rollout
|
||||
zs = torch.empty(
|
||||
horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device
|
||||
)
|
||||
reward_preds = torch.empty_like(reward, device=self.device)
|
||||
assert reward.shape[0] == horizon
|
||||
z = self.model.encode(obs)
|
||||
zs[0] = z
|
||||
value_info = {"Q": 0.0, "V": 0.0}
|
||||
for t in range(horizon):
|
||||
z, reward_pred = self.model.next(z, action[t])
|
||||
zs[t + 1] = z
|
||||
reward_preds[t] = reward_pred
|
||||
|
||||
with torch.no_grad():
|
||||
v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")
|
||||
|
||||
# Predictions
|
||||
qs = self.model.Q(zs[:-1], action, return_type="all")
|
||||
value_info["Q"] = qs.mean().item()
|
||||
v = self.model.V(zs[:-1])
|
||||
value_info["V"] = v.mean().item()
|
||||
|
||||
# Losses
|
||||
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(
|
||||
-1, 1, 1
|
||||
)
|
||||
consistency_loss = (
|
||||
rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask
|
||||
).sum(dim=0)
|
||||
reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
|
||||
q_value_loss, priority_loss = 0, 0
|
||||
for q in range(self.cfg.num_q):
|
||||
q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0)
|
||||
priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
|
||||
|
||||
self.expectile = h.linear_schedule(self.cfg.expectile, step)
|
||||
v_value_loss = (
|
||||
rho * h.l2_expectile(v_target - v, expectile=self.expectile) * loss_mask
|
||||
).sum(dim=0)
|
||||
|
||||
total_loss = (
|
||||
self.cfg.consistency_coef * consistency_loss
|
||||
+ self.cfg.reward_coef * reward_loss
|
||||
+ self.cfg.value_coef * q_value_loss
|
||||
+ self.cfg.value_coef * v_value_loss
|
||||
)
|
||||
|
||||
weighted_loss = (total_loss.squeeze(1) * weights).mean()
|
||||
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
|
||||
weighted_loss.backward()
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
|
||||
)
|
||||
self.optim.step()
|
||||
|
||||
if self.cfg.per:
|
||||
# Update priorities
|
||||
priorities = priority_loss.clamp(max=1e4).detach()
|
||||
replay_buffer.update_priorities(
|
||||
idxs[: replay_buffer.cfg.batch_size],
|
||||
priorities[: replay_buffer.cfg.batch_size],
|
||||
)
|
||||
if self.demo_batch_size > 0:
|
||||
demo_buffer.update_priorities(
|
||||
demo_idxs, priorities[replay_buffer.cfg.batch_size :]
|
||||
)
|
||||
|
||||
# Update policy + target network
|
||||
_, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
|
||||
|
||||
if step % self.cfg.update_freq == 0:
|
||||
h.ema(self.model._encoder, self.model_target._encoder, self.cfg.tau)
|
||||
h.ema(self.model._Qs, self.model_target._Qs, self.cfg.tau)
|
||||
|
||||
self.model.eval()
|
||||
metrics = {
|
||||
"consistency_loss": float(consistency_loss.mean().item()),
|
||||
"reward_loss": float(reward_loss.mean().item()),
|
||||
"Q_value_loss": float(q_value_loss.mean().item()),
|
||||
"V_value_loss": float(v_value_loss.mean().item()),
|
||||
"total_loss": float(total_loss.mean().item()),
|
||||
"weighted_loss": float(weighted_loss.mean().item()),
|
||||
"grad_norm": float(grad_norm),
|
||||
}
|
||||
for key in ["demo_batch_size", "expectile"]:
|
||||
if hasattr(self, key):
|
||||
metrics[key] = getattr(self, key)
|
||||
metrics.update(value_info)
|
||||
metrics.update(pi_update_info)
|
||||
|
||||
return metrics
|
|
@ -0,0 +1,829 @@
|
|||
import os
|
||||
import pickle
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import distributions as pyd
|
||||
from torch.distributions.utils import _standard_normal
|
||||
|
||||
__REDUCE__ = lambda b: "mean" if b else "none"
|
||||
|
||||
|
||||
def l1(pred, target, reduce=False):
|
||||
"""Computes the L1-loss between predictions and targets."""
|
||||
return F.l1_loss(pred, target, reduction=__REDUCE__(reduce))
|
||||
|
||||
|
||||
def mse(pred, target, reduce=False):
|
||||
"""Computes the MSE loss between predictions and targets."""
|
||||
return F.mse_loss(pred, target, reduction=__REDUCE__(reduce))
|
||||
|
||||
|
||||
def l2_expectile(diff, expectile=0.7, reduce=False):
|
||||
weight = torch.where(diff > 0, expectile, (1 - expectile))
|
||||
loss = weight * (diff**2)
|
||||
reduction = __REDUCE__(reduce)
|
||||
if reduction == "mean":
|
||||
return torch.mean(loss)
|
||||
elif reduction == "sum":
|
||||
return torch.sum(loss)
|
||||
return loss
|
||||
|
||||
|
||||
def _get_out_shape(in_shape, layers):
|
||||
"""Utility function. Returns the output shape of a network for a given input shape."""
|
||||
x = torch.randn(*in_shape).unsqueeze(0)
|
||||
return (
|
||||
(nn.Sequential(*layers) if isinstance(layers, list) else layers)(x)
|
||||
.squeeze(0)
|
||||
.shape
|
||||
)
|
||||
|
||||
|
||||
def gaussian_logprob(eps, log_std):
|
||||
"""Compute Gaussian log probability."""
|
||||
residual = (-0.5 * eps.pow(2) - log_std).sum(-1, keepdim=True)
|
||||
return residual - 0.5 * np.log(2 * np.pi) * eps.size(-1)
|
||||
|
||||
|
||||
def squash(mu, pi, log_pi):
|
||||
"""Apply squashing function."""
|
||||
mu = torch.tanh(mu)
|
||||
pi = torch.tanh(pi)
|
||||
log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
|
||||
return mu, pi, log_pi
|
||||
|
||||
|
||||
def orthogonal_init(m):
|
||||
"""Orthogonal layer initialization."""
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.orthogonal_(m.weight.data)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
gain = nn.init.calculate_gain("relu")
|
||||
nn.init.orthogonal_(m.weight.data, gain)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
|
||||
def ema(m, m_target, tau):
|
||||
"""Update slow-moving average of online network (target network) at rate tau."""
|
||||
with torch.no_grad():
|
||||
for p, p_target in zip(m.parameters(), m_target.parameters()):
|
||||
p_target.data.lerp_(p.data, tau)
|
||||
|
||||
|
||||
def set_requires_grad(net, value):
|
||||
"""Enable/disable gradients for a given (sub)network."""
|
||||
for param in net.parameters():
|
||||
param.requires_grad_(value)
|
||||
|
||||
|
||||
class TruncatedNormal(pyd.Normal):
|
||||
"""Utility class implementing the truncated normal distribution."""
|
||||
|
||||
def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
|
||||
super().__init__(loc, scale, validate_args=False)
|
||||
self.low = low
|
||||
self.high = high
|
||||
self.eps = eps
|
||||
|
||||
def _clamp(self, x):
|
||||
clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
|
||||
x = x - x.detach() + clamped_x.detach()
|
||||
return x
|
||||
|
||||
def sample(self, clip=None, sample_shape=torch.Size()):
|
||||
shape = self._extended_shape(sample_shape)
|
||||
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
|
||||
eps *= self.scale
|
||||
if clip is not None:
|
||||
eps = torch.clamp(eps, -clip, clip)
|
||||
x = self.loc + eps
|
||||
return self._clamp(x)
|
||||
|
||||
|
||||
class NormalizeImg(nn.Module):
|
||||
"""Normalizes pixel observations to [0,1) range."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x.div(255.0)
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
"""Flattens its input to a (batched) vector."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x.view(x.size(0), -1)
|
||||
|
||||
|
||||
def enc(cfg):
|
||||
obs_shape = {
|
||||
"rgb": (3, cfg.img_size, cfg.img_size),
|
||||
"state": (4,),
|
||||
}
|
||||
|
||||
"""Returns a TOLD encoder."""
|
||||
pixels_enc_layers, state_enc_layers = None, None
|
||||
if cfg.modality in {"pixels", "all"}:
|
||||
C = int(3 * cfg.frame_stack)
|
||||
pixels_enc_layers = [
|
||||
NormalizeImg(),
|
||||
nn.Conv2d(C, cfg.num_channels, 7, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
]
|
||||
out_shape = _get_out_shape((C, cfg.img_size, cfg.img_size), pixels_enc_layers)
|
||||
pixels_enc_layers.extend(
|
||||
[
|
||||
Flatten(),
|
||||
nn.Linear(np.prod(out_shape), cfg.latent_dim),
|
||||
nn.LayerNorm(cfg.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
]
|
||||
)
|
||||
if cfg.modality == "pixels":
|
||||
return ConvExt(nn.Sequential(*pixels_enc_layers))
|
||||
if cfg.modality in {"state", "all"}:
|
||||
state_dim = obs_shape[0] if cfg.modality == "state" else obs_shape["state"][0]
|
||||
state_enc_layers = [
|
||||
nn.Linear(state_dim, cfg.enc_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(cfg.enc_dim, cfg.latent_dim),
|
||||
nn.LayerNorm(cfg.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
]
|
||||
if cfg.modality == "state":
|
||||
return nn.Sequential(*state_enc_layers)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
encoders = {}
|
||||
for k in obs_shape:
|
||||
if k == "state":
|
||||
encoders[k] = nn.Sequential(*state_enc_layers)
|
||||
elif k.endswith("rgb"):
|
||||
encoders[k] = ConvExt(nn.Sequential(*pixels_enc_layers))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return Multiplexer(nn.ModuleDict(encoders))
|
||||
|
||||
|
||||
def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
|
||||
"""Returns an MLP."""
|
||||
if isinstance(mlp_dim, int):
|
||||
mlp_dim = [mlp_dim, mlp_dim]
|
||||
return nn.Sequential(
|
||||
nn.Linear(in_dim, mlp_dim[0]),
|
||||
nn.LayerNorm(mlp_dim[0]),
|
||||
act_fn,
|
||||
nn.Linear(mlp_dim[0], mlp_dim[1]),
|
||||
nn.LayerNorm(mlp_dim[1]),
|
||||
act_fn,
|
||||
nn.Linear(mlp_dim[1], out_dim),
|
||||
)
|
||||
|
||||
|
||||
def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
|
||||
"""Returns a dynamics network."""
|
||||
return nn.Sequential(
|
||||
mlp(in_dim, mlp_dim, out_dim, act_fn),
|
||||
nn.LayerNorm(out_dim),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
|
||||
def q(cfg):
|
||||
action_dim = 4
|
||||
"""Returns a Q-function that uses Layer Normalization."""
|
||||
return nn.Sequential(
|
||||
nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim),
|
||||
nn.LayerNorm(cfg.mlp_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(cfg.mlp_dim, 1),
|
||||
)
|
||||
|
||||
|
||||
def v(cfg):
|
||||
"""Returns a state value function that uses Layer Normalization."""
|
||||
return nn.Sequential(
|
||||
nn.Linear(cfg.latent_dim, cfg.mlp_dim),
|
||||
nn.LayerNorm(cfg.mlp_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(cfg.mlp_dim, 1),
|
||||
)
|
||||
|
||||
|
||||
def aug(cfg):
|
||||
obs_shape = {
|
||||
"rgb": (3, cfg.img_size, cfg.img_size),
|
||||
"state": (4,),
|
||||
}
|
||||
|
||||
"""Multiplex augmentation"""
|
||||
if cfg.modality == "state":
|
||||
return nn.Identity()
|
||||
elif cfg.modality == "pixels":
|
||||
return RandomShiftsAug(cfg)
|
||||
else:
|
||||
augs = {}
|
||||
for k in obs_shape:
|
||||
if k == "state":
|
||||
augs[k] = nn.Identity()
|
||||
elif k.endswith("rgb"):
|
||||
augs[k] = RandomShiftsAug(cfg)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return Multiplexer(nn.ModuleDict(augs))
|
||||
|
||||
|
||||
class ConvExt(nn.Module):
|
||||
"""Auxiliary conv net accommodating high-dim input"""
|
||||
|
||||
def __init__(self, conv):
|
||||
super().__init__()
|
||||
self.conv = conv
|
||||
|
||||
def forward(self, x):
|
||||
if x.ndim > 4:
|
||||
batch_shape = x.shape[:-3]
|
||||
out = self.conv(x.view(-1, *x.shape[-3:]))
|
||||
out = out.view(*batch_shape, *out.shape[1:])
|
||||
else:
|
||||
out = self.conv(x)
|
||||
return out
|
||||
|
||||
|
||||
class Multiplexer(nn.Module):
|
||||
"""Model multiplexer"""
|
||||
|
||||
def __init__(self, choices):
|
||||
super().__init__()
|
||||
self.choices = choices
|
||||
|
||||
def forward(self, x, key=None):
|
||||
if isinstance(x, dict):
|
||||
if key is not None:
|
||||
return self.choices[key](x)
|
||||
return {k: self.choices[k](_x) for k, _x in x.items()}
|
||||
return self.choices(x)
|
||||
|
||||
|
||||
class RandomShiftsAug(nn.Module):
|
||||
"""
|
||||
Random shift image augmentation.
|
||||
Adapted from https://github.com/facebookresearch/drqv2
|
||||
"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
assert cfg.modality in {"pixels", "all"}
|
||||
self.pad = int(cfg.img_size / 21)
|
||||
|
||||
def forward(self, x):
|
||||
n, c, h, w = x.size()
|
||||
assert h == w
|
||||
padding = tuple([self.pad] * 4)
|
||||
x = F.pad(x, padding, "replicate")
|
||||
eps = 1.0 / (h + 2 * self.pad)
|
||||
arange = torch.linspace(
|
||||
-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype
|
||||
)[:h]
|
||||
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
|
||||
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
|
||||
base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
|
||||
shift = torch.randint(
|
||||
0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype
|
||||
)
|
||||
shift *= 2.0 / (h + 2 * self.pad)
|
||||
grid = base_grid + shift
|
||||
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
|
||||
|
||||
|
||||
class Episode(object):
|
||||
"""Storage object for a single episode."""
|
||||
|
||||
def __init__(self, cfg, init_obs):
|
||||
action_dim = 4
|
||||
|
||||
self.cfg = cfg
|
||||
self.device = torch.device(cfg.buffer_device)
|
||||
if cfg.modality in {"pixels", "state"}:
|
||||
dtype = torch.float32 if cfg.modality == "state" else torch.uint8
|
||||
self.obses = torch.empty(
|
||||
(cfg.episode_length + 1, *init_obs.shape),
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
|
||||
elif cfg.modality == "all":
|
||||
self.obses = {}
|
||||
for k, v in init_obs.items():
|
||||
assert k in {"rgb", "state"}
|
||||
dtype = torch.float32 if k == "state" else torch.uint8
|
||||
self.obses[k] = torch.empty(
|
||||
(cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
|
||||
)
|
||||
self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
|
||||
else:
|
||||
raise ValueError
|
||||
self.actions = torch.empty(
|
||||
(cfg.episode_length, action_dim), dtype=torch.float32, device=self.device
|
||||
)
|
||||
self.rewards = torch.empty(
|
||||
(cfg.episode_length,), dtype=torch.float32, device=self.device
|
||||
)
|
||||
self.dones = torch.empty(
|
||||
(cfg.episode_length,), dtype=torch.bool, device=self.device
|
||||
)
|
||||
self.masks = torch.empty(
|
||||
(cfg.episode_length,), dtype=torch.float32, device=self.device
|
||||
)
|
||||
self.cumulative_reward = 0
|
||||
self.done = False
|
||||
self.success = False
|
||||
self._idx = 0
|
||||
|
||||
def __len__(self):
|
||||
return self._idx
|
||||
|
||||
@classmethod
|
||||
def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
|
||||
"""Constructs an episode from a trajectory."""
|
||||
|
||||
if cfg.modality in {"pixels", "state"}:
|
||||
episode = cls(cfg, obses[0])
|
||||
episode.obses[1:] = torch.tensor(
|
||||
obses[1:], dtype=episode.obses.dtype, device=episode.device
|
||||
)
|
||||
elif cfg.modality == "all":
|
||||
episode = cls(cfg, {k: v[0] for k, v in obses.items()})
|
||||
for k, v in obses.items():
|
||||
episode.obses[k][1:] = torch.tensor(
|
||||
obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
episode.actions = torch.tensor(
|
||||
actions, dtype=episode.actions.dtype, device=episode.device
|
||||
)
|
||||
episode.rewards = torch.tensor(
|
||||
rewards, dtype=episode.rewards.dtype, device=episode.device
|
||||
)
|
||||
episode.dones = (
|
||||
torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
|
||||
if dones is not None
|
||||
else torch.zeros_like(episode.dones)
|
||||
)
|
||||
episode.masks = (
|
||||
torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
|
||||
if masks is not None
|
||||
else torch.ones_like(episode.masks)
|
||||
)
|
||||
episode.cumulative_reward = torch.sum(episode.rewards)
|
||||
episode.done = True
|
||||
episode._idx = cfg.episode_length
|
||||
return episode
|
||||
|
||||
@property
|
||||
def first(self):
|
||||
return len(self) == 0
|
||||
|
||||
def __add__(self, transition):
|
||||
self.add(*transition)
|
||||
return self
|
||||
|
||||
def add(self, obs, action, reward, done, mask=1.0, success=False):
|
||||
"""Add a transition into the episode."""
|
||||
if isinstance(obs, dict):
|
||||
for k, v in obs.items():
|
||||
self.obses[k][self._idx + 1] = torch.tensor(
|
||||
v, dtype=self.obses[k].dtype, device=self.obses[k].device
|
||||
)
|
||||
else:
|
||||
self.obses[self._idx + 1] = torch.tensor(
|
||||
obs, dtype=self.obses.dtype, device=self.obses.device
|
||||
)
|
||||
self.actions[self._idx] = action
|
||||
self.rewards[self._idx] = reward
|
||||
self.dones[self._idx] = done
|
||||
self.masks[self._idx] = mask
|
||||
self.cumulative_reward += reward
|
||||
self.done = done
|
||||
self.success = self.success or success
|
||||
self._idx += 1
|
||||
|
||||
|
||||
class ReplayBuffer:
|
||||
"""
|
||||
Storage and sampling functionality.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, dataset=None):
|
||||
action_dim = 4
|
||||
obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (4,)}
|
||||
|
||||
self.cfg = cfg
|
||||
self.device = torch.device(cfg.buffer_device)
|
||||
print("Replay buffer device: ", self.device)
|
||||
|
||||
if dataset is not None:
|
||||
self.capacity = max(dataset["rewards"].shape[0], cfg.max_buffer_size)
|
||||
else:
|
||||
self.capacity = min(cfg.train_steps, cfg.max_buffer_size)
|
||||
|
||||
if cfg.modality in {"pixels", "state"}:
|
||||
dtype = torch.float32 if cfg.modality == "state" else torch.uint8
|
||||
# Note self.obs_shape always has single frame, which is different from cfg.obs_shape
|
||||
self.obs_shape = (
|
||||
obs_shape if cfg.modality == "state" else (3, *obs_shape[-2:])
|
||||
)
|
||||
self._obs = torch.zeros(
|
||||
(self.capacity + cfg.horizon - 1, *self.obs_shape),
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self._next_obs = torch.zeros(
|
||||
(self.capacity + cfg.horizon - 1, *self.obs_shape),
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
)
|
||||
elif cfg.modality == "all":
|
||||
self.obs_shape = {}
|
||||
self._obs, self._next_obs = {}, {}
|
||||
for k, v in obs_shape.items():
|
||||
assert k in {"rgb", "state"}
|
||||
dtype = torch.float32 if k == "state" else torch.uint8
|
||||
self.obs_shape[k] = v if k == "state" else (3, *v[-2:])
|
||||
self._obs[k] = torch.zeros(
|
||||
(self.capacity + cfg.horizon - 1, *self.obs_shape[k]),
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self._next_obs[k] = self._obs[k].clone()
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
self._action = torch.zeros(
|
||||
(self.capacity + cfg.horizon - 1, action_dim),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self._reward = torch.zeros(
|
||||
(self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
|
||||
)
|
||||
self._mask = torch.zeros(
|
||||
(self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
|
||||
)
|
||||
self._done = torch.zeros(
|
||||
(self.capacity + cfg.horizon - 1,), dtype=torch.bool, device=self.device
|
||||
)
|
||||
self._priorities = torch.ones(
|
||||
(self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
|
||||
)
|
||||
self._eps = 1e-6
|
||||
self._full = False
|
||||
self.idx = 0
|
||||
if dataset is not None:
|
||||
self.init_from_offline_dataset(dataset)
|
||||
|
||||
self._aug = aug(cfg)
|
||||
|
||||
def init_from_offline_dataset(self, dataset):
|
||||
"""Initialize the replay buffer from an offline dataset."""
|
||||
assert self.idx == 0 and not self._full
|
||||
n_transitions = int(len(dataset["rewards"]) * self.cfg.data_first_percent)
|
||||
|
||||
def copy_data(dst, src, n):
|
||||
assert isinstance(dst, dict) == isinstance(src, dict)
|
||||
if isinstance(dst, dict):
|
||||
for k in dst:
|
||||
copy_data(dst[k], src[k], n)
|
||||
else:
|
||||
dst[:n] = torch.from_numpy(src[:n])
|
||||
|
||||
copy_data(self._obs, dataset["observations"], n_transitions)
|
||||
copy_data(self._next_obs, dataset["next_observations"], n_transitions)
|
||||
copy_data(self._action, dataset["actions"], n_transitions)
|
||||
copy_data(self._reward, dataset["rewards"], n_transitions)
|
||||
copy_data(self._mask, dataset["masks"], n_transitions)
|
||||
copy_data(self._done, dataset["dones"], n_transitions)
|
||||
self.idx = (self.idx + n_transitions) % self.capacity
|
||||
self._full = n_transitions >= self.capacity
|
||||
|
||||
def __add__(self, episode: Episode):
|
||||
self.add(episode)
|
||||
return self
|
||||
|
||||
def add(self, episode: Episode):
|
||||
"""Add an episode to the replay buffer."""
|
||||
if self.idx + len(episode) > self.capacity:
|
||||
print("Warning: episode got truncated")
|
||||
ep_len = min(len(episode), self.capacity - self.idx)
|
||||
idxs = slice(self.idx, self.idx + ep_len)
|
||||
assert self.idx + ep_len <= self.capacity
|
||||
if self.cfg.modality in {"pixels", "state"}:
|
||||
self._obs[idxs] = (
|
||||
episode.obses[:ep_len]
|
||||
if self.cfg.modality == "state"
|
||||
else episode.obses[:ep_len, -3:]
|
||||
)
|
||||
self._next_obs[idxs] = (
|
||||
episode.obses[1 : ep_len + 1]
|
||||
if self.cfg.modality == "state"
|
||||
else episode.obses[1 : ep_len + 1, -3:]
|
||||
)
|
||||
elif self.cfg.modality == "all":
|
||||
for k, v in episode.obses.items():
|
||||
assert k in {"rgb", "state"}
|
||||
assert k in self._obs
|
||||
assert k in self._next_obs
|
||||
if k == "rgb":
|
||||
self._obs[k][idxs] = episode.obses[k][:ep_len, -3:]
|
||||
self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1, -3:]
|
||||
else:
|
||||
self._obs[k][idxs] = episode.obses[k][:ep_len]
|
||||
self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1]
|
||||
self._action[idxs] = episode.actions[:ep_len]
|
||||
self._reward[idxs] = episode.rewards[:ep_len]
|
||||
self._mask[idxs] = episode.masks[:ep_len]
|
||||
self._done[idxs] = episode.dones[:ep_len]
|
||||
self._done[self.idx + ep_len - 1] = True # in case truncated
|
||||
if self._full:
|
||||
max_priority = (
|
||||
self._priorities[: self.capacity].max().to(self.device).item()
|
||||
)
|
||||
else:
|
||||
max_priority = (
|
||||
1.0
|
||||
if self.idx == 0
|
||||
else self._priorities[: self.idx].max().to(self.device).item()
|
||||
)
|
||||
new_priorities = torch.full((ep_len,), max_priority, device=self.device)
|
||||
self._priorities[idxs] = new_priorities
|
||||
self.idx = (self.idx + ep_len) % self.capacity
|
||||
self._full = self._full or self.idx == 0
|
||||
|
||||
def update_priorities(self, idxs, priorities):
|
||||
"""Update priorities for Prioritized Experience Replay (PER)"""
|
||||
self._priorities[idxs] = priorities.squeeze(1).to(self.device) + self._eps
|
||||
|
||||
def _get_obs(self, arr, idxs):
|
||||
"""Retrieve observations by indices"""
|
||||
if isinstance(arr, dict):
|
||||
return {k: self._get_obs(v, idxs) for k, v in arr.items()}
|
||||
if arr.ndim <= 2: # if self.cfg.modality == 'state':
|
||||
return arr[idxs].cuda()
|
||||
obs = torch.empty(
|
||||
(self.cfg.batch_size, 3 * self.cfg.frame_stack, *arr.shape[-2:]),
|
||||
dtype=arr.dtype,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
obs[:, -3:] = arr[idxs].cuda()
|
||||
_idxs = idxs.clone()
|
||||
mask = torch.ones_like(_idxs, dtype=torch.bool)
|
||||
for i in range(1, self.cfg.frame_stack):
|
||||
mask[_idxs % self.cfg.episode_length == 0] = False
|
||||
_idxs[mask] -= 1
|
||||
obs[:, -(i + 1) * 3 : -i * 3] = arr[_idxs].cuda()
|
||||
return obs.float()
|
||||
|
||||
def sample(self):
|
||||
"""Sample transitions from the replay buffer."""
|
||||
probs = (
|
||||
self._priorities[: self.capacity]
|
||||
if self._full
|
||||
else self._priorities[: self.idx]
|
||||
) ** self.cfg.per_alpha
|
||||
probs /= probs.sum()
|
||||
total = len(probs)
|
||||
idxs = torch.from_numpy(
|
||||
np.random.choice(
|
||||
total,
|
||||
self.cfg.batch_size,
|
||||
p=probs.cpu().numpy(),
|
||||
replace=not self._full,
|
||||
)
|
||||
).to(self.device)
|
||||
weights = (total * probs[idxs]) ** (-self.cfg.per_beta)
|
||||
weights /= weights.max()
|
||||
|
||||
idxs_in_horizon = torch.stack([idxs + t for t in range(self.cfg.horizon)])
|
||||
|
||||
obs = self._aug(self._get_obs(self._obs, idxs))
|
||||
next_obs = [
|
||||
self._aug(self._get_obs(self._next_obs, _idxs)) for _idxs in idxs_in_horizon
|
||||
]
|
||||
if isinstance(next_obs[0], dict):
|
||||
next_obs = {k: torch.stack([o[k] for o in next_obs]) for k in next_obs[0]}
|
||||
else:
|
||||
next_obs = torch.stack(next_obs)
|
||||
action = self._action[idxs_in_horizon]
|
||||
reward = self._reward[idxs_in_horizon]
|
||||
mask = self._mask[idxs_in_horizon]
|
||||
done = self._done[idxs_in_horizon]
|
||||
|
||||
if not action.is_cuda:
|
||||
action, reward, mask, done, idxs, weights = (
|
||||
action.cuda(),
|
||||
reward.cuda(),
|
||||
mask.cuda(),
|
||||
done.cuda(),
|
||||
idxs.cuda(),
|
||||
weights.cuda(),
|
||||
)
|
||||
|
||||
return (
|
||||
obs,
|
||||
next_obs,
|
||||
action,
|
||||
reward.unsqueeze(2),
|
||||
mask.unsqueeze(2),
|
||||
done.unsqueeze(2),
|
||||
idxs,
|
||||
weights,
|
||||
)
|
||||
|
||||
def save(self, path):
|
||||
"""Save the replay buffer to path"""
|
||||
print(f"saving replay buffer to '{path}'...")
|
||||
sz = self.capacity if self._full else self.idx
|
||||
dataset = {
|
||||
"observations": (
|
||||
{k: v[:sz].cpu().numpy() for k, v in self._obs.items()}
|
||||
if isinstance(self._obs, dict)
|
||||
else self._obs[:sz].cpu().numpy()
|
||||
),
|
||||
"next_observations": (
|
||||
{k: v[:sz].cpu().numpy() for k, v in self._next_obs.items()}
|
||||
if isinstance(self._next_obs, dict)
|
||||
else self._next_obs[:sz].cpu().numpy()
|
||||
),
|
||||
"actions": self._action[:sz].cpu().numpy(),
|
||||
"rewards": self._reward[:sz].cpu().numpy(),
|
||||
"dones": self._done[:sz].cpu().numpy(),
|
||||
"masks": self._mask[:sz].cpu().numpy(),
|
||||
}
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump(dataset, f)
|
||||
return dataset
|
||||
|
||||
|
||||
def get_dataset_dict(cfg, env, return_reward_normalizer=False):
|
||||
"""Construct a dataset for env"""
|
||||
required_keys = [
|
||||
"observations",
|
||||
"next_observations",
|
||||
"actions",
|
||||
"rewards",
|
||||
"dones",
|
||||
"masks",
|
||||
]
|
||||
|
||||
if cfg.task.startswith("xarm"):
|
||||
dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl")
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
for k in required_keys:
|
||||
if k not in dataset_dict and k[:-1] in dataset_dict:
|
||||
dataset_dict[k] = dataset_dict.pop(k[:-1])
|
||||
elif cfg.task.startswith("legged"):
|
||||
dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl")
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
dataset_dict["actions"] /= env.unwrapped.clip_actions
|
||||
print(f"clip_actions={env.unwrapped.clip_actions}")
|
||||
else:
|
||||
import d4rl
|
||||
|
||||
dataset_dict = d4rl.qlearning_dataset(env)
|
||||
dones = np.full_like(dataset_dict["rewards"], False, dtype=bool)
|
||||
|
||||
for i in range(len(dones) - 1):
|
||||
if (
|
||||
np.linalg.norm(
|
||||
dataset_dict["observations"][i + 1]
|
||||
- dataset_dict["next_observations"][i]
|
||||
)
|
||||
> 1e-6
|
||||
or dataset_dict["terminals"][i] == 1.0
|
||||
):
|
||||
dones[i] = True
|
||||
|
||||
dones[-1] = True
|
||||
|
||||
dataset_dict["masks"] = 1.0 - dataset_dict["terminals"]
|
||||
del dataset_dict["terminals"]
|
||||
|
||||
for k, v in dataset_dict.items():
|
||||
dataset_dict[k] = v.astype(np.float32)
|
||||
|
||||
dataset_dict["dones"] = dones
|
||||
|
||||
if cfg.is_data_clip:
|
||||
lim = 1 - cfg.data_clip_eps
|
||||
dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim)
|
||||
reward_normalizer = get_reward_normalizer(cfg, dataset_dict)
|
||||
dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"])
|
||||
|
||||
for key in required_keys:
|
||||
assert key in dataset_dict.keys(), f"Missing `{key}` in dataset."
|
||||
|
||||
if return_reward_normalizer:
|
||||
return dataset_dict, reward_normalizer
|
||||
return dataset_dict
|
||||
|
||||
|
||||
def get_trajectory_boundaries_and_returns(dataset):
|
||||
"""
|
||||
Split dataset into trajectories and compute returns
|
||||
"""
|
||||
episode_starts = [0]
|
||||
episode_ends = []
|
||||
|
||||
episode_return = 0
|
||||
episode_returns = []
|
||||
|
||||
n_transitions = len(dataset["rewards"])
|
||||
|
||||
for i in range(n_transitions):
|
||||
episode_return += dataset["rewards"][i]
|
||||
|
||||
if dataset["dones"][i]:
|
||||
episode_returns.append(episode_return)
|
||||
episode_ends.append(i + 1)
|
||||
if i + 1 < n_transitions:
|
||||
episode_starts.append(i + 1)
|
||||
episode_return = 0.0
|
||||
|
||||
return episode_starts, episode_ends, episode_returns
|
||||
|
||||
|
||||
def normalize_returns(dataset, scaling=1000):
|
||||
"""
|
||||
Normalize returns in the dataset
|
||||
"""
|
||||
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
|
||||
dataset["rewards"] /= np.max(episode_returns) - np.min(episode_returns)
|
||||
dataset["rewards"] *= scaling
|
||||
return dataset
|
||||
|
||||
|
||||
def get_reward_normalizer(cfg, dataset):
|
||||
"""
|
||||
Get a reward normalizer for the dataset
|
||||
"""
|
||||
if cfg.task.startswith("xarm"):
|
||||
return lambda x: x
|
||||
elif "maze" in cfg.task:
|
||||
return lambda x: x - 1.0
|
||||
elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]:
|
||||
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
|
||||
return (
|
||||
lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0
|
||||
)
|
||||
elif hasattr(cfg, "reward_scale"):
|
||||
return lambda x: x * cfg.reward_scale
|
||||
return lambda x: x
|
||||
|
||||
|
||||
def linear_schedule(schdl, step):
|
||||
"""
|
||||
Outputs values following a linear decay schedule.
|
||||
Adapted from https://github.com/facebookresearch/drqv2
|
||||
"""
|
||||
try:
|
||||
return float(schdl)
|
||||
except ValueError:
|
||||
match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl)
|
||||
if match:
|
||||
init, final, start, end = [float(g) for g in match.groups()]
|
||||
mix = np.clip((step - start) / (end - start), 0.0, 1.0)
|
||||
return (1.0 - mix) * init + mix * final
|
||||
match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
|
||||
if match:
|
||||
init, final, duration = [float(g) for g in match.groups()]
|
||||
mix = np.clip(step / duration, 0.0, 1.0)
|
||||
return (1.0 - mix) * init + mix * final
|
||||
raise NotImplementedError(schdl)
|
|
@ -0,0 +1,12 @@
|
|||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
"""Set seed for reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
|
@ -5,70 +5,54 @@ import imageio
|
|||
import numpy as np
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from tensordict.nn import TensorDictModule
|
||||
from termcolor import colored
|
||||
|
||||
from lerobot.lib.envs.factory import make_env
|
||||
from lerobot.lib.tdmpc import TDMPC
|
||||
from lerobot.lib.utils import set_seed
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.tdmpc import TDMPC
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
|
||||
def eval_agent(
|
||||
env, agent, num_episodes: int, save_video: bool = False, video_path: Path = None
|
||||
def eval_policy(
|
||||
env, policy, num_episodes: int, save_video: bool = False, video_dir: Path = None
|
||||
):
|
||||
"""Evaluate a trained agent and optionally save a video."""
|
||||
if save_video:
|
||||
assert video_path is not None
|
||||
assert video_path.suffix == ".mp4"
|
||||
episode_rewards = []
|
||||
episode_successes = []
|
||||
episode_lengths = []
|
||||
rewards = []
|
||||
successes = []
|
||||
for i in range(num_episodes):
|
||||
td = env.reset()
|
||||
obs = {}
|
||||
obs["rgb"] = td["observation"]["camera"]
|
||||
obs["state"] = td["observation"]["robot_state"]
|
||||
ep_frames = []
|
||||
|
||||
done = False
|
||||
ep_reward = 0
|
||||
t = 0
|
||||
ep_success = False
|
||||
def rendering_callback(env, td=None):
|
||||
nonlocal ep_frames
|
||||
frame = env.render()
|
||||
ep_frames.append(frame)
|
||||
|
||||
tensordict = env.reset()
|
||||
# render first frame before rollout
|
||||
rendering_callback(env)
|
||||
|
||||
rollout = env.rollout(
|
||||
max_steps=30,
|
||||
policy=policy,
|
||||
callback=rendering_callback,
|
||||
auto_reset=False,
|
||||
tensordict=tensordict,
|
||||
)
|
||||
ep_reward = rollout["next", "reward"].sum()
|
||||
ep_success = rollout["next", "success"].any()
|
||||
rewards.append(ep_reward.item())
|
||||
successes.append(ep_success.item())
|
||||
|
||||
if save_video:
|
||||
frames = []
|
||||
while not done:
|
||||
action = agent.act(obs, t0=t == 0, eval_mode=True, step=100000)
|
||||
td = TensorDict({"action": action}, batch_size=[])
|
||||
|
||||
td = env.step(td)
|
||||
|
||||
reward = td["next", "reward"].item()
|
||||
success = td["next", "success"].item()
|
||||
done = td["next", "done"].item()
|
||||
|
||||
obs = {}
|
||||
obs["rgb"] = td["next", "observation"]["camera"]
|
||||
obs["state"] = td["next", "observation"]["robot_state"]
|
||||
|
||||
ep_reward += reward
|
||||
if success:
|
||||
ep_success = True
|
||||
if save_video:
|
||||
frame = env.render()
|
||||
frames.append(frame)
|
||||
t += 1
|
||||
episode_rewards.append(float(ep_reward))
|
||||
episode_successes.append(float(ep_success))
|
||||
episode_lengths.append(t)
|
||||
if save_video:
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
frames = np.stack(frames) # .transpose(0, 3, 1, 2)
|
||||
video_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
# TODO(rcadene): make fps configurable
|
||||
imageio.mimsave(video_path, frames, fps=15)
|
||||
return {
|
||||
"episode_reward": np.nanmean(episode_rewards),
|
||||
"episode_success": np.nanmean(episode_successes),
|
||||
"episode_length": np.nanmean(episode_lengths),
|
||||
video_path = video_dir / f"eval_episode_{i}.mp4"
|
||||
imageio.mimsave(video_path, np.stack(ep_frames), fps=15)
|
||||
|
||||
metrics = {
|
||||
"avg_reward": np.nanmean(rewards),
|
||||
"pc_success": np.nanmean(successes) * 100,
|
||||
}
|
||||
return metrics
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_name="default", config_path="../configs")
|
||||
|
@ -78,20 +62,25 @@ def eval(cfg: dict):
|
|||
print(colored("Log dir:", "yellow", attrs=["bold"]), cfg.log_dir)
|
||||
|
||||
env = make_env(cfg)
|
||||
agent = TDMPC(cfg)
|
||||
policy = TDMPC(cfg)
|
||||
# ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
|
||||
ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
|
||||
agent.load(ckpt_path)
|
||||
policy.load(ckpt_path)
|
||||
|
||||
eval_metrics = eval_agent(
|
||||
env,
|
||||
agent,
|
||||
num_episodes=10,
|
||||
save_video=True,
|
||||
video_path=Path("tmp/2023_01_29_xarm_lift_final/eval.mp4"),
|
||||
policy = TensorDictModule(
|
||||
policy,
|
||||
in_keys=["observation", "step_count"],
|
||||
out_keys=["action"],
|
||||
)
|
||||
|
||||
print(eval_metrics)
|
||||
metrics = eval_policy(
|
||||
env,
|
||||
policy,
|
||||
num_episodes=10,
|
||||
save_video=True,
|
||||
video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
|
||||
)
|
||||
print(metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -2,7 +2,10 @@ import hydra
|
|||
import torch
|
||||
from termcolor import colored
|
||||
|
||||
from ..lib.utils import set_seed
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.tdmpc import TDMPC
|
||||
|
||||
from ..common.utils import set_seed
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_name="default", config_path="../configs")
|
||||
|
@ -11,6 +14,24 @@ def train(cfg: dict):
|
|||
set_seed(cfg.seed)
|
||||
print(colored("Work dir:", "yellow", attrs=["bold"]), cfg.log_dir)
|
||||
|
||||
env = make_env(cfg)
|
||||
agent = TDMPC(cfg)
|
||||
# ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
|
||||
ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
|
||||
agent.load(ckpt_path)
|
||||
|
||||
# online training
|
||||
|
||||
eval_metrics = train_agent(
|
||||
env,
|
||||
agent,
|
||||
num_episodes=10,
|
||||
save_video=True,
|
||||
video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
|
||||
)
|
||||
|
||||
print(eval_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
from tensordict import TensorDict
|
||||
from torchrl.envs.utils import check_env_specs, step_mdp
|
||||
|
||||
from lerobot.lib.envs import SimxarmEnv
|
||||
from lerobot.common.envs import SimxarmEnv
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
Loading…
Reference in New Issue