split img emb / state
parent 02d55c0b9a
commit 311f79a874
@@ -71,7 +71,9 @@ class VQBeTConfig:
     # Inputs / output structure.
     n_obs_steps: int = 5
     n_action_steps: int = 5
     # n_action_steps: int = 7
+    n_action_pred_token: int = 3
+    n_action_pred_chunk: int = 5
 
     input_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
@@ -77,8 +77,10 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
             "observation.image": deque(maxlen=self.config.n_obs_steps),
             "observation.state": deque(maxlen=self.config.n_obs_steps),
         }
-        if self.config.n_action_steps is not None:
-            self._action_queue = deque([], maxlen=self.config.n_action_steps)
+        if self.config.n_action_pred_chunk is not None:
+            self._action_queue = deque([], maxlen=self.config.n_action_pred_chunk)  # original one
+            # self._action_queue = deque([], maxlen=self.config.n_action_pred_token)  # jay temp2
+            self._action_history_queue = deque([], maxlen=self.config.n_obs_steps - 1)
 
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -96,6 +98,10 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_inputs(batch)
         self._obs_queues = populate_queues(self._obs_queues, batch)
 
+        if len(self._action_history_queue) == 0:
+            while len(self._action_history_queue) != self._action_history_queue.maxlen:
+                self._action_history_queue.append(batch["observation.state"])
+
         if not self.check_discretized():
             self.vqbet._action_head._vqvae_model.discretized = True
             # raise NotImplementedError(
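Note on the prefill above: on the first `select_action` call the history deque is empty, so it is padded with copies of the current normalized state until it holds `n_obs_steps - 1` entries, giving the model a full-length (if degenerate) action history from step one. A minimal standalone sketch of that behavior, with made-up shapes:

```python
from collections import deque

import torch

n_obs_steps = 5
state = torch.zeros(2, 2)  # (batch_size, state_dim), e.g. pusht xy state

history = deque([], maxlen=n_obs_steps - 1)
if len(history) == 0:
    # Pad the empty history with the current state so the very first rollout
    # step still sees a full (n_obs_steps - 1)-long history.
    while len(history) != history.maxlen:
        history.append(state)

stacked = torch.stack(list(history), dim=1)
print(stacked.shape)  # torch.Size([2, 4, 2]) -> (batch, n_obs_steps - 1, dim)
```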
@@ -109,16 +115,35 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         # for act seq pred, we should provide averaged act over horizon.
 
         if len(self._action_queue) == 0:
             # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
             # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
 
+            # original one
             batch = {key: torch.stack(list(self._obs_queues[key]), dim=1) for key in batch}
-            actions = self.vqbet(batch)[:, : self.config.n_action_steps]
+            batch["action"] = torch.stack(list(self._action_history_queue), dim=1)
+            actions = self.vqbet(batch, rollout=True)[:, : self.config.n_action_pred_chunk]
 
             # TODO(rcadene): make _forward return output dictionary?
             actions = self.unnormalize_outputs({"action": actions})["action"]
 
             self._action_queue.extend(actions.transpose(0, 1))
-        return self._action_queue.popleft()
 
+            # jay temp2
+            # batch = {key: torch.stack(list(self._obs_queues[key]), dim=1) for key in batch}
+            # batch["action"] = torch.stack(list(self._action_history_queue), dim=1)
+            # action_predicted = []
+            # for i in range(self.config.n_action_pred_token):
+            #     actions = self.vqbet(batch, rollout=True, input_predicted=action_predicted)[:, i]
+            #     action_predicted.append(actions)
+            # actions = self.unnormalize_outputs({"action": torch.stack(action_predicted)})["action"]
+            # self._action_queue.extend(actions)
+            #############
+
+        action = self._action_queue.popleft()
+        self._action_history_queue.append(self.normalize_targets({"action": action})["action"])
+        return action
 
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
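Note on the queue bookkeeping above: `self.vqbet(batch, rollout=True)` returns a `(batch_size, n_action_pred_chunk, action_dim)` tensor, while the deque stores one `(batch_size, action_dim)` entry per timestep, hence the `transpose(0, 1)` before `extend`. A minimal sketch with a dummy tensor in place of the model output:

```python
from collections import deque

import torch

batch_size, n_action_pred_chunk, action_dim = 2, 5, 2
actions = torch.randn(batch_size, n_action_pred_chunk, action_dim)

queue = deque([], maxlen=n_action_pred_chunk)
# transpose(0, 1) -> (n_action_pred_chunk, batch_size, action_dim), so
# extend() enqueues one (batch_size, action_dim) action per timestep.
queue.extend(actions.transpose(0, 1))

step_action = queue.popleft()
print(step_action.shape)  # torch.Size([2, 2]); 4 actions remain for later steps
```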
@@ -127,7 +152,60 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         if not self.check_discretized():
             loss = self.vqbet.discretize(self.config.discretize_step, batch['action'])
             return {"loss": loss}
-        _, loss = self.vqbet(batch)
 
+        # original one
+        _, loss = self.vqbet(batch, rollout=False)
+
+        # jay temp2
+        # action_predicted = []
+        # pred_values = []
+        # for i in range(self.config.n_action_pred_token):
+        #     pred = self.vqbet(batch, rollout=False, input_predicted=action_predicted, recurrent=True)
+        #     actions = pred["predicted_action"][:, i]
+        #     action_predicted.append(actions)
+        #     pred_values.append(pred)
+
+        # # Stack 'a', 'b', and 'c' from each dictionary in the list
+        # # jay TODO
+        # predicted_action = torch.stack([d["predicted_action"][:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1)
+        # decoded_action = torch.stack([einops.rearrange(d["decoded_action"], "(B T) 1 N->B T N", T=5)[:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1).reshape(-1, 1, 2)
+        # sampled_centers = torch.stack([einops.rearrange(d["sampled_centers"], "(B T) N->B T N", T=5)[:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1).reshape(-1, 2)
+        # cbet_logits = torch.stack([einops.rearrange(d["cbet_logits"], "(B T) G N->B T G N", T=5)[:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1).reshape(-1, 2, 16)
+
+        # gpt_output, G, NT, N, T = pred_values[-1]["input"], pred_values[-1]["G"], pred_values[-1]["NT"], pred_values[-1]["N"], pred_values[-1]["T"]
+
+        # pred_action = {
+        #     "input": gpt_output,
+        #     "cbet_logits1": None,
+        #     "cbet_logits2": None,
+        #     "cbet_logits": cbet_logits if "cbet_logits" in locals() else None,
+        #     "predicted_action": predicted_action,
+        #     "decoded_action": decoded_action,
+        #     "sampled_centers": sampled_centers,
+        #     "G": G,
+        #     "NT": NT,
+        #     "N": N,
+        #     "T": T,
+        # }
+
+        # action = batch["action"]
+        # n, total_w, act_dim = action.shape
+        # act_w = self.config.n_action_pred_chunk
+        # num_token = total_w + 1 - act_w
+        # output_shape = (n, num_token, act_w, act_dim)
+        # output = torch.empty(output_shape).to(action.device)
+        # for i in range(num_token):
+        #     output[:, i, :, :] = action[:, i : i + act_w, :]
+        # action = output  # batch size, 7, 5, 2
+
+        # loss = self.vqbet._action_head.loss_fn(
+        #     pred_action,
+        #     action[:, self.config.n_obs_steps-1:],
+        #     reduction="mean",
+        # )[0]
+        #############
 
         return {"loss": loss['actor_loss']}
@@ -138,11 +216,22 @@ class VQBeTModel(nn.Module):
 
         self.rgb_encoder = DiffusionRgbEncoder(config)
 
-        global_cond_dim = (config.output_shapes["action"][0] + self.rgb_encoder.feature_dim)
+        self.global_cond_dim = self.rgb_encoder.feature_dim
 
+        # action token and EOS token
+        self._action_token = nn.Parameter(torch.randn(1, 1, self.global_cond_dim))  # Batch, Timestep, Data type, GPT input dim
+        self._eos_token = nn.Parameter(torch.randn(1, 1, self.global_cond_dim))
+
+        self.state_projector = MLP(
+            config.output_shapes["action"][0], hidden_channels=[self.global_cond_dim]
+        )
+        self.action_projector = MLP(
+            config.output_shapes["action"][0], hidden_channels=[self.global_cond_dim]
+        )
         self._policy = GPT(
             GPTConfig(
                 block_size=self.config.block_size,
-                input_dim=global_cond_dim,
+                input_dim=self.global_cond_dim,
                 output_dim=self.config.output_dim,
                 n_layer=self.config.n_layer,
                 n_head=self.config.n_head,
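This is the core of the "split img emb / state" change: instead of concatenating the raw state onto the image embedding, the state is lifted to the encoder's feature dim by `state_projector`, and the learned `_action_token` / `_eos_token` parameters are appended after the observation prompt. A rough standalone sketch of the intended token layout; dimensions are illustrative, `nn.Linear` stands in for the MLP, and `feature_dim` for `self.rgb_encoder.feature_dim`:

```python
import torch
from torch import nn

batch_size, n_obs_steps, state_dim = 2, 5, 2
feature_dim, n_action_pred_token = 64, 3

state_projector = nn.Linear(state_dim, feature_dim)  # stand-in for the MLP
action_token = nn.Parameter(torch.randn(1, 1, feature_dim))
eos_token = nn.Parameter(torch.randn(1, 1, feature_dim))

img_features = torch.randn(batch_size, n_obs_steps, feature_dim)  # from rgb_encoder
state_tokens = state_projector(torch.randn(batch_size, n_obs_steps, state_dim))

# Interleave (img, state) per observation step, then append EOS + action slots.
prompt = torch.cat(
    [img_features.unsqueeze(2), state_tokens.unsqueeze(2)], dim=2
).view(batch_size, -1, feature_dim)  # (B, 2 * n_obs_steps, D)
seq = torch.cat(
    [prompt, eos_token.repeat(batch_size, 1, 1),
     action_token.repeat(batch_size, n_action_pred_token, 1)], dim=1
)
print(seq.shape)  # torch.Size([2, 14, 64]) = 2*5 obs tokens + 1 EOS + 3 action slots
```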
@@ -157,14 +246,14 @@ class VQBeTModel(nn.Module):
                 vqvae_groups=config.vqvae_groups,
                 vqvae_n_embed=config.vqvae_n_embed,
                 vqvae_embedding_dim=config.vqvae_embedding_dim,
-                n_action_steps=config.n_action_steps
+                n_action_pred_chunk=config.n_action_pred_chunk
             )
 
     def discretize(self, discretize_step, actions):
         return self._action_head.discretize(discretize_step, actions)
 
     # ========= inference ============
-    def forward(self, batch: dict[str, Tensor]) -> Tensor:
+    def forward(self, batch: dict[str, Tensor], rollout: bool, input_predicted: list = None, recurrent=False) -> Tensor:  # jay temp2
         # Input validation.
         assert set(batch).issuperset({"observation.state", "observation.image"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
@@ -175,13 +264,8 @@ class VQBeTModel(nn.Module):
         # Separate batch and sequence dims.
         img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
-        # Concatenate state and image features then flatten to (B, global_cond_dim).
-        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1)
-
-        obs = global_cond
         if 'action' in batch.keys():
             action = batch["action"]
         else:
             action = None
-
 
         if 'goal' in batch.keys():
             goal = batch["goal"]
             num_goal_token = goal.shape[1]
@@ -190,21 +274,70 @@ class VQBeTModel(nn.Module):
             num_goal_token = 0
 
+        if rollout:
+            gpt_input_action = torch.cat((batch["action"], torch.zeros(batch_size, 1, batch["action"].shape[-1]).to(get_device_from_parameters(self))), dim=1)
+        else:
+            gpt_input_action = batch["action"][..., :n_obs_steps, :]
+            for i, val in enumerate(batch["frame_index"]):
+                if val < n_obs_steps - 1:
+                    # Update original_data based on the rules
+                    gpt_input_action[i, :-(val+1)] = batch["observation.state"][i, :-(val+1)]
+
+        gpt_input_action[:, -1] = 0
+
+        global_cond = torch.cat([
+            torch.unsqueeze(img_features, dim=2),
+            torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
+            # torch.unsqueeze(self.action_projector(gpt_input_action), dim=2)  # jay temp
+        ], dim=-2).view(batch_size, -1, self.global_cond_dim)  # batch size, num tokens, GPT input dim
+        # global_cond = global_cond[:, :-1]  # get rid of final action  # jay temp
+
+        # We could feed the earlier actions and have the model predict the rest, or have it predict everything -> for now, condition on actions up to the previous step.
+        # A token could be a single action or a multi-step (chunk) -> for now, go with chunks.
+        eos_token = self._eos_token.repeat(batch_size, 1, 1)
+        action_token = self._action_token.repeat(batch_size, self.config.n_action_pred_token, 1)
+
+        # jay temp2
+        # if (input_predicted is not None) and (len(input_predicted) > 0):
+        #     len_input = len(input_predicted)
+        #     action_token[:, :len_input] = self.action_projector(torch.stack(input_predicted, dim=1).to(get_device_from_parameters(self)))
+
+        prompt_length = global_cond.shape[1] + 1
+        global_cond = torch.cat([global_cond, eos_token, action_token], dim=1)
+
         # get action features
-        features = self._policy(obs)
-        features = features[:, num_goal_token:]
+        features = self._policy(global_cond)
+        features = features[:, prompt_length:]
         # action head
         pred_action = self._action_head(
             features,
-            **{"action_seq": action},
+            # **{"action_seq": action},
         )
 
-        if action is None:
-            return pred_action["predicted_action"][:, -1, :].reshape(batch_size, self.config.n_action_steps, -1)
+        if rollout:
+            return pred_action["predicted_action"][:, -1, :].reshape(batch_size, self.config.n_action_pred_chunk, -1)  # original one
+            # return pred_action["predicted_action"]  # jay temp2
+
+        # should put this part between "if" and "else"
+        # elif recurrent:  # jay temp2 => training
+        #     return pred_action
         else:
+            action = batch["action"]
+            n, total_w, act_dim = action.shape
+            act_w = self.config.n_action_pred_chunk
+            num_token = total_w + 1 - act_w
+            output_shape = (n, num_token, act_w, act_dim)
+            output = torch.empty(output_shape).to(action.device)
+            for i in range(num_token):
+                output[:, i, :, :] = action[:, i : i + act_w, :]
+            action = output  # batch size, 7, 5, 2
 
             loss = self._action_head.loss_fn(
                 pred_action,
-                action,
+                action[:, n_obs_steps-1:],
                 reduction="mean",
             )
             return pred_action, loss[0] if isinstance(loss, tuple) else loss
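The else-branch above turns a flat action sequence (B, total_w, A) into overlapping chunks (B, num_token, act_w, A) with a Python loop. A small standalone check of that reshaping, plus an equivalent one-liner via `Tensor.unfold` (a possible simplification, not what the commit uses):

```python
import torch

n, total_w, act_dim = 2, 7, 2
act_w = 5                        # n_action_pred_chunk
num_token = total_w + 1 - act_w  # 3 overlapping windows

action = torch.randn(n, total_w, act_dim)

# Loop version, as in the commit.
output = torch.empty(n, num_token, act_w, act_dim)
for i in range(num_token):
    output[:, i, :, :] = action[:, i : i + act_w, :]

# unfold gives (n, num_token, act_dim, act_w); swap the last two dims to match.
unfolded = action.unfold(1, act_w, 1).transpose(2, 3)
print(output.shape, torch.allclose(output, unfolded))  # torch.Size([2, 3, 5, 2]) True
```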
@@ -223,7 +356,7 @@ class VQBeTHead(nn.Module):
         vqvae_groups=2,  # G(number of groups)
         vqvae_n_embed=16,  # C(number of code integers)
         vqvae_embedding_dim=512,  # D(embedding dims)
-        n_action_steps=1,  # action chunk size
+        n_action_pred_chunk=1,  # action chunk size
     ):
         super().__init__()
         self.input_size = input_size
@@ -236,7 +369,7 @@ class VQBeTHead(nn.Module):
         self._G = vqvae_groups  # G(number of groups)
         self._C = vqvae_n_embed  # C(number of code integers)
         self._D = vqvae_embedding_dim  # D(embedding dims)
-        self.n_action_steps = n_action_steps  # action chunk size
+        self.n_action_pred_chunk = n_action_pred_chunk  # action chunk size
 
 
         if self.sequentially_select:
@@ -259,12 +392,12 @@ class VQBeTHead(nn.Module):
                 hidden_channels=[
                     self.hidden_size,
                     self.hidden_size,
-                    self._G * self._C * n_action_steps * self.output_size,
+                    self._G * self._C * n_action_pred_chunk * self.output_size,
                 ],
             )
         # init vqvae
         vqvae_config = {
-            "action_chunk": self.n_action_steps,
+            "action_chunk": self.n_action_pred_chunk,
             "action_dim": self.output_size,
             "vqvae_n_latent_dims": self._D,
             "vqvae_n_embed": self._C,
@@ -404,14 +537,14 @@ class VQBeTHead(nn.Module):
             predicted_action, "N T (W A) -> (N T) W A", W=self._vqvae_model.input_dim_h
         )
 
-        n, total_w, act_dim = action_seq.shape
-        act_w = self._vqvae_model.input_dim_h
-        obs_w = total_w + 1 - act_w
-        output_shape = (n, obs_w, act_w, act_dim)
-        output = torch.empty(output_shape).to(action_seq.device)
-        for i in range(obs_w):
-            output[:, i, :, :] = action_seq[:, i : i + act_w, :]
-        action_seq = einops.rearrange(output, "N T W A -> (N T) W A")
+        # n, total_w, act_dim = action_seq.shape
+        # act_w = self._vqvae_model.input_dim_h
+        # obs_w = total_w + 1 - act_w
+        # output_shape = (n, obs_w, act_w, act_dim)
+        # output = torch.empty(output_shape).to(action_seq.device)
+        # for i in range(obs_w):
+        #     output[:, i, :, :] = action_seq[:, i : i + act_w, :]
+        action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
         # Figure out the loss for the actions.
         # First, we need to find the closest cluster center for each action.
         state_vq, action_bins = self._vqvae_model.get_code(
@@ -425,43 +558,43 @@ class VQBeTHead(nn.Module):
         offset_loss = torch.nn.L1Loss()(action_seq, predicted_action)
 
         action_diff = F.mse_loss(
-            einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[:, -1, 0, :],
-            einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=obs_w)[
+            einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, -1, 0, :],
+            einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
                 :, -1, 0, :
             ],
         )  # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout
         action_diff_tot = F.mse_loss(
-            einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[:, -1, :, :],
-            einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=obs_w)[
+            einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, -1, :, :],
+            einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
                 :, -1, :, :
             ],
         )  # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout
         action_diff_mean_res1 = (
             abs(
-                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[
+                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[
                     :, -1, 0, :
                 ]
-                - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=obs_w)[
+                - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T)[
                     :, -1, 0, :
                 ]
             )
         ).mean()
         action_diff_mean_res2 = (
             abs(
-                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[
+                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[
                     :, -1, 0, :
                 ]
-                - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=obs_w)[
+                - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
                     :, -1, 0, :
                 ]
             )
         ).mean()
         action_diff_max = (
             abs(
-                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[
+                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[
                     :, -1, 0, :
                 ]
-                - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=obs_w)[
+                - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
                     :, -1, 0, :
                 ]
             )
@@ -692,6 +825,19 @@ class VQBeTOptimizer:
             learning_rate=cfg.training.bet_learning_rate,
             betas=cfg.training.bet_betas,
         )
+
+        self.bet_optimizer1.add_param_group(
+            {"params": policy.vqbet._action_token}
+        )
+        self.bet_optimizer1.add_param_group(
+            {"params": policy.vqbet._eos_token}
+        )
+        self.bet_optimizer1.add_param_group(
+            {"params": policy.vqbet.state_projector.parameters()}
+        )
+        self.bet_optimizer1.add_param_group(
+            {"params": policy.vqbet.action_projector.parameters()}
+        )
         # if policy.vqbet._action_head.sequentially_select:
         #     self.bet_optimizer1.add_param_group(
         #         {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
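The new tokens and projectors live outside the modules the BeT optimizer was originally built over, so they are registered via `torch.optim.Optimizer.add_param_group` (which accepts a single tensor or an iterable; the new group inherits the optimizer defaults unless overridden). A minimal sketch of the pattern, with `AdamW` standing in for however `bet_optimizer1` is actually constructed:

```python
import torch
from torch import nn

gpt = nn.Linear(8, 8)  # stand-in for the GPT trunk
state_projector = nn.Linear(2, 8)
action_token = nn.Parameter(torch.randn(1, 1, 8))

opt = torch.optim.AdamW(gpt.parameters(), lr=5.5e-5, betas=(0.9, 0.999))
opt.add_param_group({"params": action_token})                 # single Parameter is wrapped in a list
opt.add_param_group({"params": state_projector.parameters()})  # iterable of Parameters
print(len(opt.param_groups))  # 3
```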
@@ -22,10 +22,10 @@ override_dataset_stats:
   max: [511.0, 511.0]
 
 training:
-  offline_steps: 200000
+  offline_steps: 800000
   online_steps: 0
-  eval_freq: 10000
-  save_freq: 5000
+  eval_freq: 20000  # jay
+  save_freq: 20000
   log_freq: 250
   save_model: true
@@ -41,7 +41,7 @@ training:
 
   # VQ-BeT specific
   vqvae_lr: 1.0e-3
-  discretize_step: 3000
+  discretize_step: 20000  # jay
  bet_weight_decay: 2e-4
  bet_learning_rate: 5.5e-5
  bet_betas: [0.9, 0.999]
@@ -49,7 +49,7 @@ training:
 delta_timestamps:
   observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
   observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
-  action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_steps})]"
+  action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.n_action_pred_chunk} - 1)]"
 
 eval:
   n_episodes: 50
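With the policy values below (n_obs_steps: 5, n_action_pred_token: 3, n_action_pred_chunk: 5), the new action range is range(1 - 5, 3 + 5 - 1) = range(-4, 7): 11 action frames instead of the 9 the old n_action_steps range loaded. A quick check, assuming the pusht dataset's fps of 10:

```python
fps = 10
n_obs_steps, n_action_pred_token, n_action_pred_chunk = 5, 3, 5

old = [i / fps for i in range(1 - n_obs_steps, 5)]  # old n_action_steps = 5
new = [i / fps for i in range(1 - n_obs_steps, n_action_pred_token + n_action_pred_chunk - 1)]
print(len(old), len(new))  # 9 11
# 11 frames give 11 + 1 - 5 = 7 overlapping 5-step windows; dropping the 4 that
# start before the current frame leaves exactly n_action_pred_token = 3 chunks.
```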
@@ -60,7 +60,9 @@ policy:
 
   # Input / output structure.
   n_obs_steps: 5
   n_action_steps: 5
   # n_action_steps: 7  # n_action_pred_token + n_action_pred_window - 1
+  n_action_pred_token: 3  # 3 5
+  n_action_pred_chunk: 5  # 5 1 jay temp2
 
   input_shapes:
     # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
@@ -90,7 +92,7 @@ policy:
   vqvae_n_embed: 16
   vqvae_embedding_dim: 256
   # VQ-BeT
-  block_size: 50
+  block_size: 500
   output_dim: 256  # 512
   n_layer: 6  # 8
   n_head: 6  # 4
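block_size is the GPT context length, which must cover the full token sequence. Under the layout in this commit (two tokens per observation step, since the action-projector token is commented out, plus one EOS and n_action_pred_token action slots), that is only 14 tokens, so 500 is generous headroom rather than a tight bound. A quick check of this assumption:

```python
# image embedding + projected state per observation step; action token disabled
n_obs_steps, tokens_per_obs_step, n_action_pred_token = 5, 2, 3
seq_len = n_obs_steps * tokens_per_obs_step + 1 + n_action_pred_token  # +1 for EOS
print(seq_len, seq_len <= 500)  # 14 True
```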