Online finetuning runs (sometimes crash because of NaNs)
parent 228c045674
commit c202c2b3c2
@@ -15,12 +15,20 @@ conda activate lerobot
 python setup.py develop
 ```
 
+## TODO
+
+- [ ] priority update doesn't match FOWM or the original paper
+- [ ] self.step=100000 should be updated at every step to adjust to the planner's horizon
+- [ ] prefetch replay buffer to speed up training
+- [ ] parallelize env to speed up eval
+
 ## Contribute
 
 **style**
 ```
-isort .
-black .
+isort lerobot
+black lerobot
+isort test
+black test
 pylint lerobot
 ```
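Note on the "prefetch replay buffer" TODO above: assuming torchrl's `TensorDictReplayBuffer` API (the same buffer class used later in this diff, which already carries a commented-out `# prefetch=3`), prefetching is a constructor argument; this is an illustration only, not code from the commit.

```
# Illustration for the "prefetch replay buffer" TODO, assuming torchrl's
# TensorDictReplayBuffer: `prefetch` pre-loads upcoming batches in background
# threads while the current batch is being consumed.
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

buffer = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(100_000),
    batch_size=256,  # assumed batch size, for illustration
    prefetch=3,      # number of batches prepared ahead of time
)
```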
@@ -77,18 +77,16 @@ class SimxarmEnv(EnvBase):
 
     def _format_raw_obs(self, raw_obs):
         if self.from_pixels:
-            camera = self.render(
+            image = self.render(
                 mode="rgb_array", width=self.image_size, height=self.image_size
             )
-            camera = camera.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
-            camera = torch.tensor(camera.copy(), dtype=torch.uint8)
+            image = image.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
+            image = torch.tensor(image.copy(), dtype=torch.uint8)
 
-            obs = {"camera": camera}
+            obs = {"image": image}
 
             if not self.pixels_only:
-                obs["robot_state"] = torch.tensor(
-                    self._env.robot_state, dtype=torch.float32
-                )
+                obs["state"] = torch.tensor(self._env.robot_state, dtype=torch.float32)
         else:
             obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
 
@@ -136,7 +134,7 @@ class SimxarmEnv(EnvBase):
     def _make_spec(self):
         obs = {}
         if self.from_pixels:
-            obs["camera"] = BoundedTensorSpec(
+            obs["image"] = BoundedTensorSpec(
                 low=0,
                 high=255,
                 shape=(3, self.image_size, self.image_size),
@@ -144,7 +142,7 @@ class SimxarmEnv(EnvBase):
                 device=self.device,
             )
         if not self.pixels_only:
-            obs["robot_state"] = UnboundedContinuousTensorSpec(
+            obs["state"] = UnboundedContinuousTensorSpec(
                 shape=(len(self._env.robot_state),),
                 dtype=torch.float32,
                 device=self.device,
@@ -96,8 +96,7 @@ class TDMPC(nn.Module):
         self.model_target.eval()
         self.batch_size = cfg.batch_size
 
-        # TODO(rcadene): clean
-        self.step = 100000
+        self.step = 0
 
     def state_dict(self):
         """Retrieve state dict of TOLD model, including slow-moving target network."""
@@ -120,8 +119,8 @@ class TDMPC(nn.Module):
     def forward(self, observation, step_count):
         t0 = step_count.item() == 0
         obs = {
-            "rgb": observation["camera"],
-            "state": observation["robot_state"],
+            "rgb": observation["image"],
+            "state": observation["state"],
         }
         return self.act(obs, t0=t0, step=self.step)
 
@@ -298,65 +297,81 @@ class TDMPC(nn.Module):
     def update(self, replay_buffer, step, demo_buffer=None):
         """Main update function. Corresponds to one iteration of the model learning."""
 
-        if demo_buffer is not None:
-            # Update oversampling ratio
-            self.demo_batch_size = int(
-                h.linear_schedule(self.cfg.demo_schedule, step) * self.batch_size
-            )
-            replay_buffer.cfg.batch_size = self.batch_size - self.demo_batch_size
-            demo_buffer.cfg.batch_size = self.demo_batch_size
+        num_slices = self.cfg.batch_size
+        batch_size = self.cfg.horizon * num_slices
+
+        if demo_buffer is None:
+            demo_batch_size = 0
         else:
-            self.demo_batch_size = 0
+            # Update oversampling ratio
+            demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
+            demo_num_slices = int(demo_pc_batch * self.batch_size)
+            demo_batch_size = self.cfg.horizon * demo_num_slices
+            batch_size -= demo_batch_size
+            num_slices -= demo_num_slices
+            replay_buffer._sampler.num_slices = num_slices
+            demo_buffer._sampler.num_slices = demo_num_slices
+
+            assert demo_batch_size % self.cfg.horizon == 0
+            assert demo_batch_size % demo_num_slices == 0
+
+        assert batch_size % self.cfg.horizon == 0
+        assert batch_size % num_slices == 0
 
         # Sample from interaction dataset
 
-        # to not have to mask
-        # batch_size = (self.cfg.batch_size // self.cfg.horizon) * self.cfg.horizon
-        batch_size = self.cfg.horizon * self.cfg.batch_size
+        def process_batch(batch, horizon, num_slices):
+            # trajectory t = 256, horizon h = 5
+            # (t h) ... -> h t ...
+            batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
+            batch = batch.to("cuda")
+
+            FIRST_FRAME = 0
+            obs = {
+                "rgb": batch["observation", "image"][FIRST_FRAME].float(),
+                "state": batch["observation", "state"][FIRST_FRAME],
+            }
+            action = batch["action"]
+            next_obses = {
+                "rgb": batch["next", "observation", "image"].float(),
+                "state": batch["next", "observation", "state"],
+            }
+            reward = batch["next", "reward"]
+
+            # TODO(rcadene): rearrange directly in offline dataset
+            if reward.ndim == 2:
+                reward = einops.rearrange(reward, "h t -> h t 1")
+
+            assert reward.ndim == 3
+            assert reward.shape == (horizon, num_slices, 1)
+            # We dont use `batch["next", "done"]` since it only indicates the end of an
+            # episode, but not the end of the trajectory of an episode.
+            # Neither does `batch["next", "terminated"]`
+            done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
+            mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+
+            idxs = batch["index"][FIRST_FRAME]
+            weights = batch["_weight"][FIRST_FRAME, :, None]
+            return obs, action, next_obses, reward, mask, done, idxs, weights
 
         batch = replay_buffer.sample(batch_size)
-
-        # trajectory t = 256, horizon h = 5
-        # (t h) ... -> h t ...
-        batch = (
-            batch.reshape(self.cfg.batch_size, self.cfg.horizon)
-            .transpose(1, 0)
-            .contiguous()
-        )
-        batch = batch.to("cuda")
-
-        FIRST_FRAME = 0
-        obs = {
-            "rgb": batch["observation", "image"][FIRST_FRAME].float(),
-            "state": batch["observation", "state"][FIRST_FRAME],
-        }
-        action = batch["action"]
-        next_obses = {
-            "rgb": batch["next", "observation", "image"].float(),
-            "state": batch["next", "observation", "state"],
-        }
-        reward = batch["next", "reward"]
-        reward = einops.rearrange(reward, "h t -> h t 1")
-        # We dont use `batch["next", "done"]` since it only indicates the end of an
-        # episode, but not the end of the trajectory of an episode.
-        # Neither does `batch["next", "terminated"]`
-        done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
-        mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-
-        idxs = batch["frame_id"][FIRST_FRAME]
-        weights = batch["_weight"][FIRST_FRAME, :, None]
+        obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
+            batch, self.cfg.horizon, num_slices
+        )
 
         # Sample from demonstration dataset
-        if self.demo_batch_size > 0:
+        if demo_batch_size > 0:
+            demo_batch = demo_buffer.sample(demo_batch_size)
             (
                 demo_obs,
-                demo_next_obses,
                 demo_action,
+                demo_next_obses,
                 demo_reward,
                 demo_mask,
                 demo_done,
                 demo_idxs,
                 demo_weights,
-            ) = demo_buffer.sample()
+            ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
 
             if isinstance(obs, dict):
                 obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
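The new `process_batch` helper assumes the slice sampler returns `num_slices` trajectories of `horizon` frames concatenated along the first dimension, and rearranges them to time-major order. A minimal toy sketch of that `(t h) ... -> h t ...` reshape, on plain tensors rather than a TensorDict (illustration only):

```
# Illustration of the reshape performed in process_batch, assuming the sampler
# output is `num_slices` trajectories of length `horizon`, flattened back to back.
import torch

num_slices, horizon = 4, 5
flat = torch.arange(num_slices * horizon)                      # sampler output, shape (t * h,)
per_step = flat.reshape(num_slices, horizon).transpose(1, 0)   # -> shape (h, t)

assert per_step.shape == (horizon, num_slices)
assert per_step[0, 1].item() == horizon  # first frame of the second slice
```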
@@ -440,9 +455,9 @@ class TDMPC(nn.Module):
             q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0)
             priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
 
-        self.expectile = h.linear_schedule(self.cfg.expectile, step)
+        expectile = h.linear_schedule(self.cfg.expectile, step)
         v_value_loss = (
-            rho * h.l2_expectile(v_target - v, expectile=self.expectile) * loss_mask
+            rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask
         ).sum(dim=0)
 
         total_loss = (
@@ -464,17 +479,12 @@ class TDMPC(nn.Module):
         if self.cfg.per:
             # Update priorities
             priorities = priority_loss.clamp(max=1e4).detach()
-            # normalize between [0,1] to fit torchrl specification
-            priorities /= 1e4
-            priorities = priorities.clamp(max=1.0)
             replay_buffer.update_priority(
-                idxs[: self.cfg.batch_size],
-                priorities[: self.cfg.batch_size],
+                idxs[:num_slices],
+                priorities[:num_slices],
             )
-            if self.demo_batch_size > 0:
-                demo_buffer.update_priority(
-                    demo_idxs, priorities[self.cfg.batch_size :]
-                )
+            if demo_batch_size > 0:
+                demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
 
         # Update policy + target network
         _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
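The hunk above drops the [0, 1] normalization of the priorities and hands the clamped L1 TD-error to `update_priority` directly. As a hypothetical guard against the NaN crashes mentioned in the commit title (not part of this commit), the priorities could be sanitized before the update; `sanitize_priorities` below is an illustrative helper, not lerobot or torchrl API:

```
# Hypothetical guard, not in this commit: make sure the sampler never receives
# NaN/inf or zero priorities, which can corrupt prioritized sampling.
import torch


def sanitize_priorities(priorities: torch.Tensor, max_priority: float = 1e4) -> torch.Tensor:
    """Replace non-finite values and clamp to a strictly positive, bounded range."""
    priorities = torch.nan_to_num(priorities, nan=1e-6, posinf=max_priority, neginf=1e-6)
    return priorities.clamp(min=1e-6, max=max_priority)
```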
@@ -493,10 +503,12 @@ class TDMPC(nn.Module):
             "weighted_loss": float(weighted_loss.mean().item()),
             "grad_norm": float(grad_norm),
         }
-        for key in ["demo_batch_size", "expectile"]:
-            if hasattr(self, key):
-                metrics[key] = getattr(self, key)
+        # for key in ["demo_batch_size", "expectile"]:
+        #     if hasattr(self, key):
+        metrics["demo_batch_size"] = demo_batch_size
+        metrics["expectile"] = expectile
         metrics.update(value_info)
         metrics.update(pi_update_info)
 
+        self.step = step
         return metrics
@@ -80,7 +80,7 @@ expectile: 0.9
 A_scaling: 3.0
 
 # offline->online
-offline_steps: ${train_steps}/2
+offline_steps: 25000 # ${train_steps}/2
 pretrained_model_path: ""
 balanced_sampling: true
 demo_schedule: 0.5
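To make `demo_schedule: 0.5` concrete, here is a worked example of the balanced-sampling arithmetic from `TDMPC.update` above, assuming `h.linear_schedule` resolves the constant `0.5` to itself and using the `t = 256`, `h = 5` values quoted in the `process_batch` comment (illustration only):

```
# Worked example of the demo-oversampling split, under the assumptions stated above.
cfg_batch_size = 256   # trajectory slices per update (t)
horizon = 5            # frames per slice (h)
demo_pc_batch = 0.5    # assumed output of h.linear_schedule(0.5, step)

demo_num_slices = int(demo_pc_batch * cfg_batch_size)        # 128 demo slices
demo_batch_size = horizon * demo_num_slices                  # 640 demo frames
num_slices = cfg_batch_size - demo_num_slices                # 128 online slices
batch_size = horizon * cfg_batch_size - demo_batch_size      # 640 online frames

assert batch_size % horizon == 0 and batch_size % num_slices == 0
```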
@@ -19,6 +19,7 @@ from lerobot.common.logger import Logger
 from lerobot.common.tdmpc import TDMPC
 from lerobot.common.utils import set_seed
 from lerobot.scripts.eval import eval_policy
+from rl.torchrl.collectors.collectors import SyncDataCollector
 
 
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
@@ -29,8 +30,10 @@ def train(cfg: dict):
 
     env = make_env(cfg)
     policy = TDMPC(cfg)
-    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
-    ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
+    ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
+    policy.step = 25000
+    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
+    # policy.step = 100000
     policy.load(ckpt_path)
 
     td_policy = TensorDictModule(
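Note on how `TDMPC.forward(observation, step_count)` is reached during rollouts: the `td_policy` wrapper in the context line above is presumably a torchrl `TensorDictModule` around `policy`; the sketch below is written under that assumption, with in/out key names taken from keys used elsewhere in this diff (the actual wiring is truncated in this hunk):

```
# Sketch only: wrap the TDMPC policy so env.rollout() can call it on tensordicts.
# `policy` is the TDMPC instance created above; key names are assumptions.
from tensordict.nn import TensorDictModule

td_policy = TensorDictModule(
    policy,
    in_keys=["observation", "step_count"],  # passed to TDMPC.forward(...)
    out_keys=["action"],                    # written back into the rollout tensordict
)
```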
@@ -54,7 +57,7 @@ def train(cfg: dict):
         strict_length=False,
     )
 
-    # TODO(rcadene): use PrioritizedReplayBuffer
+    # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
     offline_buffer = SimxarmExperienceReplay(
         dataset_id,
         # download="force",
@@ -68,9 +71,22 @@ def train(cfg: dict):
     index = torch.arange(0, num_steps, 1)
     sampler.extend(index)
 
-    # offline_buffer._storage.device = torch.device("cuda")
-    # offline_buffer._storage._storage.to(torch.device("cuda"))
-    # TODO(rcadene): add online_buffer
+    if cfg.balanced_sampling:
+        online_sampler = PrioritizedSliceSampler(
+            max_capacity=100_000,
+            alpha=0.7,
+            beta=0.9,
+            num_slices=num_traj_per_batch,
+            strict_length=False,
+        )
+
+    online_buffer = TensorDictReplayBuffer(
+        storage=LazyMemmapStorage(100_000),
+        sampler=online_sampler,
+        # batch_size=3,
+        # pin_memory=False,
+        # prefetch=3,
+    )
 
     # Observation encoder
     # Dynamics predictor
@@ -81,59 +97,80 @@ def train(cfg: dict):
 
     L = Logger(cfg.log_dir, cfg)
 
-    episode_idx = 0
+    online_episode_idx = 0
     start_time = time.time()
     step = 0
     last_log_step = 0
     last_save_step = 0
 
+    # TODO(rcadene): remove
+    step = 25000
+
     while step < cfg.train_steps:
         is_offline = True
         num_updates = cfg.episode_length
         _step = step + num_updates
         rollout_metrics = {}
 
-        # if step >= cfg.offline_steps:
-        #     is_offline = False
+        if step >= cfg.offline_steps:
+            is_offline = False
 
-        #     # Collect trajectory
-        #     obs = env.reset()
-        #     episode = Episode(cfg, obs)
-        #     success = False
-        #     while not episode.done:
-        #         action = policy.act(obs, step=step, t0=episode.first)
-        #         obs, reward, done, info = env.step(action.cpu().numpy())
-        #         reward = reward_normalizer(reward)
-        #         mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0
-        #         success = info.get('success', False)
-        #         episode += (obs, action, reward, done, mask, success)
-        #     assert len(episode) <= cfg.episode_length
-        #     buffer += episode
-        #     episode_idx += 1
-        #     rollout_metrics = {
-        #         'episode_reward': episode.cumulative_reward,
-        #         'episode_success': float(success),
-        #         'episode_length': len(episode)
-        #     }
-        #     num_updates = len(episode) * cfg.utd
-        #     _step = min(step + len(episode), cfg.train_steps)
+            # TODO: use SyncDataCollector for that?
+            rollout = env.rollout(
+                max_steps=cfg.episode_length,
+                policy=td_policy,
+            )
+            assert len(rollout) <= cfg.episode_length
+            rollout["episode"] = torch.tensor(
+                [online_episode_idx] * len(rollout), dtype=torch.int
+            )
+            online_buffer.extend(rollout)
+
+            # Collect trajectory
+            # obs = env.reset()
+            # episode = Episode(cfg, obs)
+            # success = False
+            # while not episode.done:
+            #     action = policy.act(obs, step=step, t0=episode.first)
+            #     obs, reward, done, info = env.step(action.cpu().numpy())
+            #     reward = reward_normalizer(reward)
+            #     mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0
+            #     success = info.get('success', False)
+            #     episode += (obs, action, reward, done, mask, success)
+
+            ep_reward = rollout["next", "reward"].sum()
+            ep_success = rollout["next", "success"].any()
+
+            online_episode_idx += 1
+            rollout_metrics = {
+                # 'episode_reward': episode.cumulative_reward,
+                # 'episode_success': float(success),
+                # 'episode_length': len(episode)
+                "avg_reward": np.nanmean(ep_reward),
+                "pc_success": np.nanmean(ep_success) * 100,
+            }
+            num_updates = len(rollout) * cfg.utd
+            _step = min(step + len(rollout), cfg.train_steps)
 
         # Update model
         train_metrics = {}
         if is_offline:
             for i in range(num_updates):
                 train_metrics.update(policy.update(offline_buffer, step + i))
-        # else:
-        #     for i in range(num_updates):
-        #         train_metrics.update(
-        #             policy.update(buffer, step + i // cfg.utd,
-        #             demo_buffer=offline_buffer if cfg.balanced_sampling else None)
-        #         )
+        else:
+            for i in range(num_updates):
+                train_metrics.update(
+                    policy.update(
+                        online_buffer,
+                        step + i // cfg.utd,
+                        demo_buffer=offline_buffer if cfg.balanced_sampling else None,
+                    )
+                )
 
         # Log training metrics
         env_step = int(_step * cfg.action_repeat)
         common_metrics = {
-            "episode": episode_idx,
+            "episode": online_episode_idx,
             "step": _step,
             "env_step": env_step,
            "total_time": time.time() - start_time,
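Illustration for the "parallelize env to speed up eval" TODO in the README (not part of this commit), assuming torchrl's `ParallelEnv` and reusing the `make_env` / `td_policy` names from `train.py`:

```
# Hypothetical sketch: step several environment copies in subprocesses so
# evaluation rollouts are collected in parallel.
from torchrl.envs import ParallelEnv

parallel_env = ParallelEnv(4, lambda: make_env(cfg))
rollout = parallel_env.rollout(max_steps=cfg.episode_length, policy=td_policy)
```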