Save action, display next_action

Remi Cadene 2025-02-18 22:12:59 +01:00
parent a66a792029
commit 5f32d75b58
3 changed files with 47 additions and 25 deletions

View File

@@ -159,9 +159,10 @@ class ACTPolicy(PreTrainedPolicy):
         l1_loss *= ~batch["action_is_pad"].unsqueeze(-1)
         bsize, seqlen, num_motors = l1_loss.shape
-        loss_dict = {
+        output_dict = {
             "l1_loss": l1_loss.mean().item(),
             "l1_loss_per_item": l1_loss.view(bsize, seqlen * num_motors).mean(dim=1),
+            "action": self.unnormalize_outputs({"action": actions_hat})["action"],
         }
         if self.config.use_vae:
             # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
@@ -169,19 +170,19 @@ class ACTPolicy(PreTrainedPolicy):
             # KL-divergence per batch element, then take the mean over the batch.
             # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
             mean_kld = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
-            loss_dict["kld_loss_per_item"] = mean_kld
+            output_dict["kld_loss_per_item"] = mean_kld
             mean_kld = mean_kld.mean()
-            loss_dict["kld_loss"] = mean_kld.item()
+            output_dict["kld_loss"] = mean_kld.item()
             loss = l1_loss + mean_kld * self.config.kl_weight
-            loss_dict["loss_per_item"] = (
-                loss_dict["l1_loss_per_item"] + loss_dict["kld_loss_per_item"] * self.config.kl_weight
+            output_dict["loss_per_item"] = (
+                output_dict["l1_loss_per_item"] + output_dict["kld_loss_per_item"] * self.config.kl_weight
             )
         else:
             loss = l1_loss
-        return loss, loss_dict
+        return loss, output_dict

 class ACTTemporalEnsembler:
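
For reference, the mean_kld line above implements the closed-form KL divergence between the diagonal-Gaussian latent posterior and a standard normal (App. B of https://arxiv.org/abs/1312.6114), where log_sigma_x2_hat holds $\log \sigma^2$:

$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2 I) \,\|\, \mathcal{N}(0, I)\big) = -\frac{1}{2} \sum_j \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$

The sum over latent dimensions yields one value per batch element (stored as "kld_loss_per_item"); its batch mean is stored as "kld_loss" and weighted by kl_weight in the total loss.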

View File

@@ -89,11 +89,8 @@ def save_inference(cfg: SaveInferenceConfig):
         # Create a temporary directory that will be automatically cleaned up
         output_dir = tempfile.mkdtemp(prefix="lerobot_save_inference_")
-    elif Path(output_dir).exists():
-        if cfg.force_override:
-            shutil.rmtree(cfg.output_dir)
-        else:
-            raise NotImplementedError(f"Output directory already exists: {cfg.output_dir}")
+    elif Path(output_dir).exists() and cfg.force_override:
+        shutil.rmtree(cfg.output_dir)

     output_dir = Path(output_dir)

@@ -119,27 +116,33 @@ def save_inference(cfg: SaveInferenceConfig):
         with torch.no_grad(), torch.autocast(device_type=cfg.device) if cfg.use_amp else nullcontext():
             _, output_dict = policy.forward(batch)

-        bsize = batch["episode_index"].shape[0]
+        batch_size = batch["episode_index"].shape[0]
         episode_indices.append(batch["episode_index"])
         frame_indices.append(batch["frame_index"])

-        for key in output_dict:
-            if "loss_per_item" not in key:
-                print(f"Skipping {key}")
+        for key, value in output_dict.items():
+            if not isinstance(value, torch.Tensor) or value.shape[0] != batch_size:
                 continue
             if key not in feats:
                 feats[key] = []
-            if not (output_dict[key].ndim == 1 and output_dict[key].shape[0] == bsize):
-                raise ValueError(output_dict[key].shape)
-            feats[key].append(output_dict[key])
+            feats[key].append(value)

-    episode_indices = torch.cat(episode_indices)
-    frame_indices = torch.cat(frame_indices)
+    episode_indices = torch.cat(episode_indices).cpu()
+    frame_indices = torch.cat(frame_indices).cpu()

-    for key in feats:
-        feats[key] = torch.cat(feats[key])
+    # TODO(rcadene): use collate?
+    for key, value in feats.items():
+        if isinstance(value[0], (float, int)):
+            feats[key] = torch.tensor(value)
+        elif isinstance(value[0], torch.Tensor):
+            feats[key] = torch.cat(value, dim=0).cpu()
+        elif isinstance(value[0], str):
+            pass
+        else:
+            raise NotImplementedError(f"{key}: {value}")

     # Find unique episode indices
     unique_episodes = torch.unique(episode_indices)

@@ -147,7 +150,7 @@ def save_inference(cfg: SaveInferenceConfig):
     for episode in unique_episodes:
         ep_feats = {}
         for key in feats:
-            ep_feats[key] = feats[key][episode_indices == episode].data.cpu()
+            ep_feats[key] = feats[key][episode_indices == episode]

         output_dir.mkdir(parents=True, exist_ok=True)
         torch.save(ep_feats, output_dir / f"output_features_episode_{episode}.pth")
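
The per-episode files written by torch.save above can be read straight back with torch.load, which is what the visualization change in the next file does. A minimal sketch of inspecting one saved file, with an illustrative path and episode index (not taken from the commit):

import torch
from pathlib import Path

# Illustrative location; a real run would use the output_dir from save_inference.
output_dir = Path("/tmp/lerobot_save_inference_example")
episode = 0

feats = torch.load(output_dir / f"output_features_episode_{episode}.pth")

# feats maps output names to per-frame values for this episode, e.g.
# "l1_loss_per_item" with shape (num_frames,) and, after this commit,
# "action" with shape (num_frames, chunk_size, num_motors).
for key, value in feats.items():
    desc = tuple(value.shape) if isinstance(value, torch.Tensor) else type(value).__name__
    print(key, desc)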

View File

@@ -295,9 +295,27 @@ def get_episode_data(
     if inference_dir is not None:
         feats = torch.load(inference_dir / f"output_features_episode_{episode_index}.pth")
         for key in feats:
-            header.append(key.replace("loss_per_item", "loss"))
-            rows = np.concatenate([rows, feats[key][:, None]], axis=1)
+            if "loss_per_item" in key:
+                if feats[key].ndim != 1:
+                    raise ValueError()
+                header.append(key.replace("loss_per_item", "loss"))
+                rows = np.concatenate([rows, feats[key][:, None]], axis=1)
+            elif key == "action":
+                if feats[key].ndim != 3:
+                    raise ValueError()
+                next_action = feats[key][:, 0, :]
+                num_motors = next_action.shape[1]
+                for i in range(num_motors):
+                    header.append(f"action_{i}")
+                rows = np.concatenate([rows, next_action], axis=1)
+            else:
+                raise NotImplementedError(key)

     rows = rows.tolist()

     # Convert data to CSV string
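
The new "action" branch keeps only the first step of each predicted action chunk, i.e. the next_action from the commit title, and expands it into one CSV column per motor. A standalone sketch of that slicing, with the layout assumed from the ndim == 3 check above and illustrative sizes:

import torch

# Assumed layout: (num_frames, chunk_size, num_motors), per the ndim == 3 check.
action_chunks = torch.randn(100, 50, 6)  # illustrative sizes only
next_action = action_chunks[:, 0, :]  # first predicted step for each frame
assert next_action.shape == (100, 6)  # columns action_0 .. action_5 in the CSV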