Save action, display next_action

This commit is contained in:
Remi Cadene 2025-02-18 22:12:59 +01:00
parent a66a792029
commit 5f32d75b58
3 changed files with 47 additions and 25 deletions

View File

@ -159,9 +159,10 @@ class ACTPolicy(PreTrainedPolicy):
l1_loss *= ~batch["action_is_pad"].unsqueeze(-1)
bsize, seqlen, num_motors = l1_loss.shape
loss_dict = {
output_dict = {
"l1_loss": l1_loss.mean().item(),
"l1_loss_per_item": l1_loss.view(bsize, seqlen * num_motors).mean(dim=1),
"action": self.unnormalize_outputs({"action": actions_hat})["action"],
}
if self.config.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
@ -169,19 +170,19 @@ class ACTPolicy(PreTrainedPolicy):
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
loss_dict["kld_loss_per_item"] = mean_kld
output_dict["kld_loss_per_item"] = mean_kld
mean_kld = mean_kld.mean()
loss_dict["kld_loss"] = mean_kld.item()
output_dict["kld_loss"] = mean_kld.item()
loss = l1_loss + mean_kld * self.config.kl_weight
loss_dict["loss_per_item"] = (
loss_dict["l1_loss_per_item"] + loss_dict["kld_loss_per_item"] * self.config.kl_weight
output_dict["loss_per_item"] = (
output_dict["l1_loss_per_item"] + output_dict["kld_loss_per_item"] * self.config.kl_weight
)
else:
loss = l1_loss
return loss, loss_dict
return loss, output_dict
class ACTTemporalEnsembler:

View File

@ -89,11 +89,8 @@ def save_inference(cfg: SaveInferenceConfig):
# Create a temporary directory that will be automatically cleaned up
output_dir = tempfile.mkdtemp(prefix="lerobot_save_inference_")
elif Path(output_dir).exists():
if cfg.force_override:
shutil.rmtree(cfg.output_dir)
else:
raise NotImplementedError(f"Output directory already exists: {cfg.output_dir}")
elif Path(output_dir).exists() and cfg.force_override:
shutil.rmtree(cfg.output_dir)
output_dir = Path(output_dir)
@ -119,27 +116,33 @@ def save_inference(cfg: SaveInferenceConfig):
with torch.no_grad(), torch.autocast(device_type=cfg.device) if cfg.use_amp else nullcontext():
_, output_dict = policy.forward(batch)
bsize = batch["episode_index"].shape[0]
batch_size = batch["episode_index"].shape[0]
episode_indices.append(batch["episode_index"])
frame_indices.append(batch["frame_index"])
for key in output_dict:
if "loss_per_item" not in key:
for key, value in output_dict.items():
if not isinstance(value, torch.Tensor) or value.shape[0] != batch_size:
print(f"Skipping {key}")
continue
if key not in feats:
feats[key] = []
if not (output_dict[key].ndim == 1 and output_dict[key].shape[0] == bsize):
raise ValueError(output_dict[key].shape)
feats[key].append(value)
feats[key].append(output_dict[key])
episode_indices = torch.cat(episode_indices).cpu()
frame_indices = torch.cat(frame_indices).cpu()
episode_indices = torch.cat(episode_indices)
frame_indices = torch.cat(frame_indices)
for key in feats:
feats[key] = torch.cat(feats[key])
# TODO(rcadene): use collate?
for key, value in feats.items():
if isinstance(value[0], (float, int)):
feats[key] = torch.tensor(value)
elif isinstance(value[0], torch.Tensor):
feats[key] = torch.cat(value, dim=0).cpu()
elif isinstance(value[0], str):
pass
else:
raise NotImplementedError(f"{key}: {value}")
# Find unique episode indices
unique_episodes = torch.unique(episode_indices)
@ -147,7 +150,7 @@ def save_inference(cfg: SaveInferenceConfig):
for episode in unique_episodes:
ep_feats = {}
for key in feats:
ep_feats[key] = feats[key][episode_indices == episode].data.cpu()
ep_feats[key] = feats[key][episode_indices == episode]
output_dir.mkdir(parents=True, exist_ok=True)
torch.save(ep_feats, output_dir / f"output_features_episode_{episode}.pth")

View File

@ -295,8 +295,26 @@ def get_episode_data(
if inference_dir is not None:
feats = torch.load(inference_dir / f"output_features_episode_{episode_index}.pth")
for key in feats:
header.append(key.replace("loss_per_item", "loss"))
rows = np.concatenate([rows, feats[key][:, None]], axis=1)
if "loss_per_item" in key:
if feats[key].ndim != 1:
raise ValueError()
header.append(key.replace("loss_per_item", "loss"))
rows = np.concatenate([rows, feats[key][:, None]], axis=1)
elif key == "action":
if feats[key].ndim != 3:
raise ValueError()
next_action = feats[key][:, 0, :]
num_motors = next_action.shape[1]
for i in range(num_motors):
header.append(f"action_{i}")
rows = np.concatenate([rows, next_action], axis=1)
else:
raise NotImplementedError(key)
rows = rows.tolist()