add rollout part

This commit is contained in:
jayLEE0301 2024-05-08 15:41:34 -04:00
parent f6a5f9643f
commit f2d9b70f46
2 changed files with 7 additions and 7 deletions

View File

@@ -89,7 +89,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
queue is empty.
"""
-# jay TODO
+# seungjae TODO: implement averaging action over horizons
self.eval()
@@ -110,8 +110,8 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
if len(self._action_queue) == 0:
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
-batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
-actions = self.vqbet(batch)[0, : self.config.n_action_steps]
+batch = {key: torch.stack(list(self._obs_queues[key]), dim=1) for key in batch}
+actions = self.vqbet(batch)[:, : self.config.n_action_steps]
# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
@@ -165,7 +165,7 @@ class VQBeTModel(nn.Module):
# ========= inference ============
def forward(self, batch: dict[str, Tensor]) -> Tensor:
# Input validation.
-assert set(batch).issuperset({"observation.state", "observation.image", "action"})
+assert set(batch).issuperset({"observation.state", "observation.image"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
@@ -199,7 +199,7 @@ class VQBeTModel(nn.Module):
)
if action is None:
-return pred_action
+return pred_action["predicted_action"][:, -1, :].reshape(batch_size, self.config.n_action_steps, -1)
else:
loss = self._action_head.loss_fn(
pred_action,

View File

@@ -10,7 +10,7 @@ dataset_repo_id: lerobot/pusht
training:
offline_steps: 200000
online_steps: 0
-eval_freq: 100 # jay
+eval_freq: 10000
save_freq: 5000
log_freq: 250
save_model: true
@@ -27,7 +27,7 @@ training:
# VQ-BeT specific
vqvae_lr: 1.0e-3
-discretize_step: 30 # jay
+discretize_step: 3000
bet_weight_decay: 2e-4
bet_learning_rate: 5.5e-5
bet_betas: [0.9, 0.999]