backup wip

Alexander Soare 2024-06-07 19:11:43 +01:00
parent 1eb4bfe2e4
commit c5d50f42b9
4 changed files with 31 additions and 23 deletions

View File

@@ -119,18 +119,19 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
             action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
             return action

-        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
-        # querying the policy.
-        if len(self._action_queue) == 0:
-            actions = self.model(batch)[0][:, : self.config.n_action_steps]
+        # # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
+        # # querying the policy.
+        # if len(self._action_queue) == 0:
+        actions = self.model(batch)[0][:, : self.config.n_action_steps]

-            # TODO(rcadene): make _forward return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+        # TODO(rcadene): make _forward return output dictionary?
+        actions = self.unnormalize_outputs({"action": actions})["action"]

-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
-            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(actions.transpose(0, 1))
-        return self._action_queue.popleft()
+        return actions
+        # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
+        # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
+        # self._action_queue.extend(actions.transpose(0, 1))
+        # return self._action_queue.popleft()

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""

View File

@@ -102,7 +102,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         }

     @torch.no_grad
-    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor], update_queue: bool = False) -> Tensor:
         """Select a single action given environment observations.

         This method handles caching a history of observations and an action trajectory generated by the
@@ -128,18 +128,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         self._queues = populate_queues(self._queues, batch)

-        if len(self._queues["action"]) == 0:
-            # stack n latest observations from the queue
-            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
-            actions = self.diffusion.generate_actions(batch)
+        # if len(self._queues["action"]) == 0:
+        # stack n latest observations from the queue
+        batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
+        actions = self.diffusion.generate_actions(batch)

-            # TODO(rcadene): make above methods return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+        # TODO(rcadene): make above methods return output dictionary?
+        actions = self.unnormalize_outputs({"action": actions})["action"]

-            self._queues["action"].extend(actions.transpose(0, 1))
-
-        action = self._queues["action"].popleft()
-        return action
+        return actions
+        # self._queues["action"].extend(actions.transpose(0, 1))
+        # action = self._queues["action"].popleft()
+        # return action

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
@@ -243,8 +243,9 @@ class DiffusionModel(nn.Module):
         # Extract `n_action_steps` steps worth of actions (from the current observation).
         start = n_obs_steps - 1
-        end = start + self.config.n_action_steps
-        actions = actions[:, start:end]
+        # end = start + self.config.n_action_steps
+        # actions = actions[:, start:end]
+        actions = actions[:, start:]

         return actions
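
The slicing above is easy to check with concrete numbers (illustrative values only): with n_obs_steps = 2 and n_action_steps = 8, start = 1 and end = 9, so the original code keeps actions[:, 1:9] while the work-in-progress version keeps every action from the current observation onwards.

import torch

# Assumed values for illustration only.
n_obs_steps, n_action_steps, horizon, action_dim = 2, 8, 16, 14
actions = torch.randn(1, horizon, action_dim)  # (batch_size, horizon, action_dim)

start = n_obs_steps - 1        # 1: step aligned with the current observation
end = start + n_action_steps   # 9

print(actions[:, start:end].shape)  # torch.Size([1, 8, 14])  -- original slice
print(actions[:, start:].shape)     # torch.Size([1, 15, 14]) -- wip keeps the remaining horizon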

View File

@@ -8,6 +8,9 @@ env:
   state_dim: 14
   action_dim: 14
   fps: ${fps}
+  # This environment runs actions and observations on the same clock cycle. Therefore the minimum latency
+  # that a policy would have to account for is one clock cycle.
+  min_observation_action_latency: 1 / ${fps}
   episode_length: 400
   gym:
     fps: ${fps}
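
As a quick sanity check of the new option (assuming this environment's usual fps of 50, an illustrative value here), one clock cycle works out to 0.02 s:

# Assuming fps = 50 for this environment (illustrative value).
fps = 50
min_observation_action_latency = 1 / fps
print(min_observation_action_latency)  # 0.02 seconds, i.e. one clock cycle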

View File

@@ -24,12 +24,15 @@ training:
   grad_clip_norm: 10
   online_steps_between_rollouts: 1

+  # Learn to predict actions starting from the "current" timestep
   delta_timestamps:
     action: "[i / ${fps} for i in range(${policy.chunk_size})]"

 eval:
   n_episodes: 50
   batch_size: 50
+  # When rolling out in a real-time environment, set this to the minimum inference latency you expect to have.
+  min_observation_action_latency: 0

 # See `configuration_act.py` for more details.
 policy:
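
For reference, the delta_timestamps expression above expands into per-action time offsets starting at the current timestep (offset 0.0), which is what the new comment describes. Assuming, for illustration, fps = 50 and policy.chunk_size = 100:

# Illustrative expansion of the delta_timestamps expression, assuming fps = 50 and chunk_size = 100.
fps = 50
chunk_size = 100

action_delta_timestamps = [i / fps for i in range(chunk_size)]
print(action_delta_timestamps[:4])  # [0.0, 0.02, 0.04, 0.06]
print(action_delta_timestamps[-1])  # 1.98 -- the chunk spans roughly two seconds of actions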