Save action, display next_action
This commit is contained in:
parent a66a792029
commit 5f32d75b58
@@ -159,9 +159,10 @@ class ACTPolicy(PreTrainedPolicy):
         l1_loss *= ~batch["action_is_pad"].unsqueeze(-1)

         bsize, seqlen, num_motors = l1_loss.shape
-        loss_dict = {
+        output_dict = {
             "l1_loss": l1_loss.mean().item(),
             "l1_loss_per_item": l1_loss.view(bsize, seqlen * num_motors).mean(dim=1),
+            "action": self.unnormalize_outputs({"action": actions_hat})["action"],
         }
         if self.config.use_vae:
             # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
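
Note on the hunk above: `unnormalize_outputs` maps the model's normalized action predictions back to dataset units before they are saved. A minimal sketch of what that inversion typically looks like (the mean/std parametrization here is an assumption for illustration, not the lerobot implementation):

    import torch

    def unnormalize(actions_hat: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        # actions_hat: (batch, chunk_size, num_motors), roughly N(0, 1) per motor
        return actions_hat * std + mean

    # Hypothetical per-motor statistics for a 3-motor arm
    mean = torch.tensor([0.1, -0.2, 0.0])
    std = torch.tensor([1.5, 0.8, 2.0])
    raw_actions = unnormalize(torch.randn(4, 100, 3), mean, std)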
@@ -169,19 +170,19 @@ class ACTPolicy(PreTrainedPolicy):
             # KL-divergence per batch element, then take the mean over the batch.
             # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
             mean_kld = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
-            loss_dict["kld_loss_per_item"] = mean_kld
+            output_dict["kld_loss_per_item"] = mean_kld

             mean_kld = mean_kld.mean()
-            loss_dict["kld_loss"] = mean_kld.item()
+            output_dict["kld_loss"] = mean_kld.item()

             loss = l1_loss + mean_kld * self.config.kl_weight
-            loss_dict["loss_per_item"] = (
-                loss_dict["l1_loss_per_item"] + loss_dict["kld_loss_per_item"] * self.config.kl_weight
+            output_dict["loss_per_item"] = (
+                output_dict["l1_loss_per_item"] + output_dict["kld_loss_per_item"] * self.config.kl_weight
             )
         else:
             loss = l1_loss

-        return loss, loss_dict
+        return loss, output_dict


 class ACTTemporalEnsembler:
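
The closed-form term above is D_KL(N(mu, sigma^2) || N(0, 1)) with `log_sigma_x2_hat` = log(sigma^2). A small self-check against torch.distributions (shapes are hypothetical):

    import torch
    from torch.distributions import Normal, kl_divergence

    mu_hat = torch.randn(8, 32)            # (batch, latent_dim)
    log_sigma_x2_hat = torch.randn(8, 32)  # log of the variance

    closed_form = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - log_sigma_x2_hat.exp())).sum(-1)
    reference = kl_divergence(
        Normal(mu_hat, (0.5 * log_sigma_x2_hat).exp()),  # std = exp(log(sigma^2) / 2)
        Normal(torch.zeros_like(mu_hat), torch.ones_like(mu_hat)),
    ).sum(-1)
    assert torch.allclose(closed_form, reference, atol=1e-5)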
@@ -89,11 +89,8 @@ def save_inference(cfg: SaveInferenceConfig):
         # Create a temporary directory that will be automatically cleaned up
         output_dir = tempfile.mkdtemp(prefix="lerobot_save_inference_")

-    elif Path(output_dir).exists():
-        if cfg.force_override:
-            shutil.rmtree(cfg.output_dir)
-        else:
-            raise NotImplementedError(f"Output directory already exists: {cfg.output_dir}")
+    elif Path(output_dir).exists() and cfg.force_override:
+        shutil.rmtree(cfg.output_dir)

     output_dir = Path(output_dir)
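
Note that this hunk changes behavior: an existing output directory with `cfg.force_override` unset no longer raises; it is reused silently. A sketch of the resulting control flow, standalone with the cfg fields inlined as arguments:

    import shutil
    import tempfile
    from pathlib import Path

    def resolve_output_dir(output_dir, force_override: bool) -> Path:
        if output_dir is None:
            # Temporary directory, cleaned up by the caller
            output_dir = tempfile.mkdtemp(prefix="lerobot_save_inference_")
        elif Path(output_dir).exists() and force_override:
            # Wipe only on explicit override; otherwise reuse the directory as-is
            shutil.rmtree(output_dir)
        return Path(output_dir)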
@@ -119,27 +116,33 @@ def save_inference(cfg: SaveInferenceConfig):
         with torch.no_grad(), torch.autocast(device_type=cfg.device) if cfg.use_amp else nullcontext():
             _, output_dict = policy.forward(batch)

-        bsize = batch["episode_index"].shape[0]
+        batch_size = batch["episode_index"].shape[0]
         episode_indices.append(batch["episode_index"])
         frame_indices.append(batch["frame_index"])

-        for key in output_dict:
-            if "loss_per_item" not in key:
+        for key, value in output_dict.items():
+            if not isinstance(value, torch.Tensor) or value.shape[0] != batch_size:
+                print(f"Skipping {key}")
                 continue

             if key not in feats:
                 feats[key] = []

-            if not (output_dict[key].ndim == 1 and output_dict[key].shape[0] == bsize):
-                raise ValueError(output_dict[key].shape)
-
-            feats[key].append(output_dict[key])
+            feats[key].append(value)

-    episode_indices = torch.cat(episode_indices)
-    frame_indices = torch.cat(frame_indices)
+    episode_indices = torch.cat(episode_indices).cpu()
+    frame_indices = torch.cat(frame_indices).cpu()

-    for key in feats:
-        feats[key] = torch.cat(feats[key])
+    # TODO(rcadene): use collate?
+    for key, value in feats.items():
+        if isinstance(value[0], (float, int)):
+            feats[key] = torch.tensor(value)
+        elif isinstance(value[0], torch.Tensor):
+            feats[key] = torch.cat(value, dim=0).cpu()
+        elif isinstance(value[0], str):
+            pass
+        else:
+            raise NotImplementedError(f"{key}: {value}")

     # Find unique episode indices
     unique_episodes = torch.unique(episode_indices)
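
The new collation loop merges per-batch lists by element type: scalars become a 1-D tensor, tensors are concatenated along the batch dimension and moved to CPU, and strings are left alone. A toy run (shapes are hypothetical):

    import torch

    feats = {
        "l1_loss_per_item": [torch.rand(4), torch.rand(4)],       # one entry per batch
        "action": [torch.rand(4, 100, 3), torch.rand(4, 100, 3)],
    }
    for key, value in feats.items():
        if isinstance(value[0], (float, int)):
            feats[key] = torch.tensor(value)
        elif isinstance(value[0], torch.Tensor):
            feats[key] = torch.cat(value, dim=0).cpu()

    assert feats["l1_loss_per_item"].shape == (8,)
    assert feats["action"].shape == (8, 100, 3)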
@@ -147,7 +150,7 @@ def save_inference(cfg: SaveInferenceConfig):
     for episode in unique_episodes:
         ep_feats = {}
         for key in feats:
-            ep_feats[key] = feats[key][episode_indices == episode].data.cpu()
+            ep_feats[key] = feats[key][episode_indices == episode]

         output_dir.mkdir(parents=True, exist_ok=True)
         torch.save(ep_feats, output_dir / f"output_features_episode_{episode}.pth")
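
Each saved .pth then holds a dict of CPU tensors whose first dimension is that episode's frame count. A sketch of reading one back (path and episode 0 are assumptions):

    import torch
    from pathlib import Path

    output_dir = Path("/tmp/lerobot_save_inference_example")  # hypothetical path
    ep_feats = torch.load(output_dir / "output_features_episode_0.pth")
    for key, value in ep_feats.items():
        print(key, tuple(value.shape))  # e.g. action -> (num_frames, chunk_size, num_motors)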
@@ -295,9 +295,27 @@ def get_episode_data(
     if inference_dir is not None:
         feats = torch.load(inference_dir / f"output_features_episode_{episode_index}.pth")
         for key in feats:
-            header.append(key.replace("loss_per_item", "loss"))
-            rows = np.concatenate([rows, feats[key][:, None]], axis=1)
+            if "loss_per_item" in key:
+                if feats[key].ndim != 1:
+                    raise ValueError()
+
+                header.append(key.replace("loss_per_item", "loss"))
+                rows = np.concatenate([rows, feats[key][:, None]], axis=1)
+
+            elif key == "action":
+                if feats[key].ndim != 3:
+                    raise ValueError()
+
+                next_action = feats[key][:, 0, :]
+                num_motors = next_action.shape[1]
+
+                for i in range(num_motors):
+                    header.append(f"action_{i}")
+
+                rows = np.concatenate([rows, next_action], axis=1)
+            else:
+                raise NotImplementedError(key)

     rows = rows.tolist()

     # Convert data to CSV string
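
In the "action" branch above, `feats[key]` is (num_frames, chunk_size, num_motors); slicing `[:, 0, :]` keeps each frame's first predicted action (the "next action" the controller would execute) and appends one CSV column per motor. A sketch with made-up sizes:

    import numpy as np
    import torch

    action = torch.rand(250, 100, 6)  # hypothetical: 250 frames, chunk of 100, 6 motors
    next_action = action[:, 0, :]     # (250, 6): first action of each predicted chunk

    rows = np.zeros((250, 1))         # stand-in for the columns built earlier
    rows = np.concatenate([rows, next_action.numpy()], axis=1)
    header = [f"action_{i}" for i in range(next_action.shape[1])]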