Update time to monotonic

This commit is contained in:
Cadene 2024-03-18 16:26:07 +00:00
parent 2bef00c317
commit a346469a5a
5 changed files with 18 additions and 18 deletions

View File

@ -54,7 +54,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
def update(self, replay_buffer, step): def update(self, replay_buffer, step):
del step del step
start_time = time.time() start_time = time.monotonic()
self.train() self.train()
@ -104,7 +104,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
batch = replay_buffer.sample(batch_size) batch = replay_buffer.sample(batch_size)
batch = process_batch(batch, self.cfg.horizon, num_slices) batch = process_batch(batch, self.cfg.horizon, num_slices)
data_s = time.time() - start_time data_s = time.monotonic() - start_time
loss = self.compute_loss(batch) loss = self.compute_loss(batch)
loss.backward() loss.backward()
@ -125,7 +125,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# "lr": self.lr_scheduler.get_last_lr()[0], # "lr": self.lr_scheduler.get_last_lr()[0],
"lr": self.cfg.lr, "lr": self.cfg.lr,
"data_s": data_s, "data_s": data_s,
"update_s": time.time() - start_time, "update_s": time.monotonic() - start_time,
} }
return info return info

View File

@ -188,8 +188,8 @@ class MetricLogger:
def log_every(self, iterable, print_freq, header=None): def log_every(self, iterable, print_freq, header=None):
if not header: if not header:
header = "" header = ""
start_time = time.time() start_time = time.monotonic()
end = time.time() end = time.monotonic()
iter_time = SmoothedValue(fmt="{avg:.4f}") iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}") data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d" space_fmt = ":" + str(len(str(len(iterable)))) + "d"
@ -218,9 +218,9 @@ class MetricLogger:
) )
mega_b = 1024.0 * 1024.0 mega_b = 1024.0 * 1024.0
for i, obj in enumerate(iterable): for i, obj in enumerate(iterable):
data_time.update(time.time() - end) data_time.update(time.monotonic() - end)
yield obj yield obj
iter_time.update(time.time() - end) iter_time.update(time.monotonic() - end)
if i % print_freq == 0 or i == len(iterable) - 1: if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
@ -247,8 +247,8 @@ class MetricLogger:
data=str(data_time), data=str(data_time),
) )
) )
end = time.time() end = time.monotonic()
total_time = time.time() - start_time total_time = time.monotonic() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable))) print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))

View File

@ -111,7 +111,7 @@ class DiffusionPolicy(nn.Module):
return action return action
def update(self, replay_buffer, step): def update(self, replay_buffer, step):
start_time = time.time() start_time = time.monotonic()
self.diffusion.train() self.diffusion.train()
@ -158,7 +158,7 @@ class DiffusionPolicy(nn.Module):
batch = replay_buffer.sample(batch_size) batch = replay_buffer.sample(batch_size)
batch = process_batch(batch, self.cfg.horizon, num_slices) batch = process_batch(batch, self.cfg.horizon, num_slices)
data_s = time.time() - start_time data_s = time.monotonic() - start_time
loss = self.diffusion.compute_loss(batch) loss = self.diffusion.compute_loss(batch)
loss.backward() loss.backward()
@ -181,7 +181,7 @@ class DiffusionPolicy(nn.Module):
"grad_norm": float(grad_norm), "grad_norm": float(grad_norm),
"lr": self.lr_scheduler.get_last_lr()[0], "lr": self.lr_scheduler.get_last_lr()[0],
"data_s": data_s, "data_s": data_s,
"update_s": time.time() - start_time, "update_s": time.monotonic() - start_time,
} }
# TODO(rcadene): remove hardcoding # TODO(rcadene): remove hardcoding

View File

@ -291,7 +291,7 @@ class TDMPC(nn.Module):
def update(self, replay_buffer, step, demo_buffer=None): def update(self, replay_buffer, step, demo_buffer=None):
"""Main update function. Corresponds to one iteration of the model learning.""" """Main update function. Corresponds to one iteration of the model learning."""
start_time = time.time() start_time = time.monotonic()
num_slices = self.cfg.batch_size num_slices = self.cfg.batch_size
batch_size = self.cfg.horizon * num_slices batch_size = self.cfg.horizon * num_slices
@ -405,7 +405,7 @@ class TDMPC(nn.Module):
self.std = h.linear_schedule(self.cfg.std_schedule, step) self.std = h.linear_schedule(self.cfg.std_schedule, step)
self.model.train() self.model.train()
data_s = time.time() - start_time data_s = time.monotonic() - start_time
# Compute targets # Compute targets
with torch.no_grad(): with torch.no_grad():
@ -501,7 +501,7 @@ class TDMPC(nn.Module):
"grad_norm": float(grad_norm), "grad_norm": float(grad_norm),
"lr": self.cfg.lr, "lr": self.cfg.lr,
"data_s": data_s, "data_s": data_s,
"update_s": time.time() - start_time, "update_s": time.monotonic() - start_time,
} }
info["demo_batch_size"] = demo_batch_size info["demo_batch_size"] = demo_batch_size
info["expectile"] = expectile info["expectile"] = expectile

View File

@ -32,7 +32,7 @@ def eval_policy(
fps: int = 15, fps: int = 15,
return_first_video: bool = False, return_first_video: bool = False,
): ):
start = time.time() start = time.monotonic()
sum_rewards = [] sum_rewards = []
max_rewards = [] max_rewards = []
successes = [] successes = []
@ -85,8 +85,8 @@ def eval_policy(
"avg_sum_reward": np.nanmean(sum_rewards), "avg_sum_reward": np.nanmean(sum_rewards),
"avg_max_reward": np.nanmean(max_rewards), "avg_max_reward": np.nanmean(max_rewards),
"pc_success": np.nanmean(successes) * 100, "pc_success": np.nanmean(successes) * 100,
"eval_s": time.time() - start, "eval_s": time.monotonic() - start,
"eval_ep_s": (time.time() - start) / num_episodes, "eval_ep_s": (time.monotonic() - start) / num_episodes,
} }
if return_first_video: if return_first_video:
return info, first_video return info, first_video