add rollout part
This commit is contained in:
parent
f6a5f9643f
commit
f2d9b70f46
|
@ -89,7 +89,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
queue is empty.
|
||||
"""
|
||||
|
||||
# jay TODO
|
||||
# seungjae TODO: implement averaging action over horizons
|
||||
|
||||
self.eval()
|
||||
|
||||
|
@ -110,8 +110,8 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
if len(self._action_queue) == 0:
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
actions = self.vqbet(batch)[0, : self.config.n_action_steps]
|
||||
batch = {key: torch.stack(list(self._obs_queues[key]), dim=1) for key in batch}
|
||||
actions = self.vqbet(batch)[:, : self.config.n_action_steps]
|
||||
|
||||
# TODO(rcadene): make _forward return output dictionary?
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
@ -165,7 +165,7 @@ class VQBeTModel(nn.Module):
|
|||
# ========= inference ============
|
||||
def forward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({"observation.state", "observation.image", "action"})
|
||||
assert set(batch).issuperset({"observation.state", "observation.image"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
|
@ -199,7 +199,7 @@ class VQBeTModel(nn.Module):
|
|||
)
|
||||
|
||||
if action is None:
|
||||
return pred_action
|
||||
return pred_action["predicted_action"][:, -1, :].reshape(batch_size, self.config.n_action_steps, -1)
|
||||
else:
|
||||
loss = self._action_head.loss_fn(
|
||||
pred_action,
|
||||
|
|
|
@ -10,7 +10,7 @@ dataset_repo_id: lerobot/pusht
|
|||
training:
|
||||
offline_steps: 200000
|
||||
online_steps: 0
|
||||
eval_freq: 100 # jay
|
||||
eval_freq: 10000
|
||||
save_freq: 5000
|
||||
log_freq: 250
|
||||
save_model: true
|
||||
|
@ -27,7 +27,7 @@ training:
|
|||
|
||||
# VQ-BeT specific
|
||||
vqvae_lr: 1.0e-3
|
||||
discretize_step: 30 # jay
|
||||
discretize_step: 3000
|
||||
bet_weight_decay: 2e-4
|
||||
bet_learning_rate: 5.5e-5
|
||||
bet_betas: [0.9, 0.999]
|
||||
|
|
Loading…
Reference in New Issue