Save action, display next_action
This commit is contained in:
parent
a66a792029
commit
5f32d75b58
|
@ -159,9 +159,10 @@ class ACTPolicy(PreTrainedPolicy):
|
|||
l1_loss *= ~batch["action_is_pad"].unsqueeze(-1)
|
||||
|
||||
bsize, seqlen, num_motors = l1_loss.shape
|
||||
loss_dict = {
|
||||
output_dict = {
|
||||
"l1_loss": l1_loss.mean().item(),
|
||||
"l1_loss_per_item": l1_loss.view(bsize, seqlen * num_motors).mean(dim=1),
|
||||
"action": self.unnormalize_outputs({"action": actions_hat})["action"],
|
||||
}
|
||||
if self.config.use_vae:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
|
@ -169,19 +170,19 @@ class ACTPolicy(PreTrainedPolicy):
|
|||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||
mean_kld = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
|
||||
loss_dict["kld_loss_per_item"] = mean_kld
|
||||
output_dict["kld_loss_per_item"] = mean_kld
|
||||
|
||||
mean_kld = mean_kld.mean()
|
||||
loss_dict["kld_loss"] = mean_kld.item()
|
||||
output_dict["kld_loss"] = mean_kld.item()
|
||||
|
||||
loss = l1_loss + mean_kld * self.config.kl_weight
|
||||
loss_dict["loss_per_item"] = (
|
||||
loss_dict["l1_loss_per_item"] + loss_dict["kld_loss_per_item"] * self.config.kl_weight
|
||||
output_dict["loss_per_item"] = (
|
||||
output_dict["l1_loss_per_item"] + output_dict["kld_loss_per_item"] * self.config.kl_weight
|
||||
)
|
||||
else:
|
||||
loss = l1_loss
|
||||
|
||||
return loss, loss_dict
|
||||
return loss, output_dict
|
||||
|
||||
|
||||
class ACTTemporalEnsembler:
|
||||
|
|
|
@ -89,11 +89,8 @@ def save_inference(cfg: SaveInferenceConfig):
|
|||
# Create a temporary directory that will be automatically cleaned up
|
||||
output_dir = tempfile.mkdtemp(prefix="lerobot_save_inference_")
|
||||
|
||||
elif Path(output_dir).exists():
|
||||
if cfg.force_override:
|
||||
shutil.rmtree(cfg.output_dir)
|
||||
else:
|
||||
raise NotImplementedError(f"Output directory already exists: {cfg.output_dir}")
|
||||
elif Path(output_dir).exists() and cfg.force_override:
|
||||
shutil.rmtree(cfg.output_dir)
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
|
@ -119,27 +116,33 @@ def save_inference(cfg: SaveInferenceConfig):
|
|||
with torch.no_grad(), torch.autocast(device_type=cfg.device) if cfg.use_amp else nullcontext():
|
||||
_, output_dict = policy.forward(batch)
|
||||
|
||||
bsize = batch["episode_index"].shape[0]
|
||||
batch_size = batch["episode_index"].shape[0]
|
||||
episode_indices.append(batch["episode_index"])
|
||||
frame_indices.append(batch["frame_index"])
|
||||
|
||||
for key in output_dict:
|
||||
if "loss_per_item" not in key:
|
||||
for key, value in output_dict.items():
|
||||
if not isinstance(value, torch.Tensor) or value.shape[0] != batch_size:
|
||||
print(f"Skipping {key}")
|
||||
continue
|
||||
|
||||
if key not in feats:
|
||||
feats[key] = []
|
||||
|
||||
if not (output_dict[key].ndim == 1 and output_dict[key].shape[0] == bsize):
|
||||
raise ValueError(output_dict[key].shape)
|
||||
feats[key].append(value)
|
||||
|
||||
feats[key].append(output_dict[key])
|
||||
episode_indices = torch.cat(episode_indices).cpu()
|
||||
frame_indices = torch.cat(frame_indices).cpu()
|
||||
|
||||
episode_indices = torch.cat(episode_indices)
|
||||
frame_indices = torch.cat(frame_indices)
|
||||
|
||||
for key in feats:
|
||||
feats[key] = torch.cat(feats[key])
|
||||
# TODO(rcadene): use collate?
|
||||
for key, value in feats.items():
|
||||
if isinstance(value[0], (float, int)):
|
||||
feats[key] = torch.tensor(value)
|
||||
elif isinstance(value[0], torch.Tensor):
|
||||
feats[key] = torch.cat(value, dim=0).cpu()
|
||||
elif isinstance(value[0], str):
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError(f"{key}: {value}")
|
||||
|
||||
# Find unique episode indices
|
||||
unique_episodes = torch.unique(episode_indices)
|
||||
|
@ -147,7 +150,7 @@ def save_inference(cfg: SaveInferenceConfig):
|
|||
for episode in unique_episodes:
|
||||
ep_feats = {}
|
||||
for key in feats:
|
||||
ep_feats[key] = feats[key][episode_indices == episode].data.cpu()
|
||||
ep_feats[key] = feats[key][episode_indices == episode]
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(ep_feats, output_dir / f"output_features_episode_{episode}.pth")
|
||||
|
|
|
@ -295,8 +295,26 @@ def get_episode_data(
|
|||
if inference_dir is not None:
|
||||
feats = torch.load(inference_dir / f"output_features_episode_{episode_index}.pth")
|
||||
for key in feats:
|
||||
header.append(key.replace("loss_per_item", "loss"))
|
||||
rows = np.concatenate([rows, feats[key][:, None]], axis=1)
|
||||
if "loss_per_item" in key:
|
||||
if feats[key].ndim != 1:
|
||||
raise ValueError()
|
||||
|
||||
header.append(key.replace("loss_per_item", "loss"))
|
||||
rows = np.concatenate([rows, feats[key][:, None]], axis=1)
|
||||
|
||||
elif key == "action":
|
||||
if feats[key].ndim != 3:
|
||||
raise ValueError()
|
||||
|
||||
next_action = feats[key][:, 0, :]
|
||||
num_motors = next_action.shape[1]
|
||||
|
||||
for i in range(num_motors):
|
||||
header.append(f"action_{i}")
|
||||
|
||||
rows = np.concatenate([rows, next_action], axis=1)
|
||||
else:
|
||||
raise NotImplementedError(key)
|
||||
|
||||
rows = rows.tolist()
|
||||
|
||||
|
|
Loading…
Reference in New Issue