split img emb / state

jayLEE0301 2024-05-13 20:33:58 -04:00
parent 02d55c0b9a
commit 311f79a874
3 changed files with 204 additions and 54 deletions

View File

@@ -71,7 +71,9 @@ class VQBeTConfig:
# Inputs / output structure.
n_obs_steps: int = 5
n_action_steps: int = 5
# n_action_steps: int = 7
n_action_pred_token: int = 3
n_action_pred_chunk: int = 5
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
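# Relation between the two new fields and the old single n_action_steps horizon (see the YAML note
# below, n_action_pred_token + n_action_pred_chunk - 1): the model predicts n_action_pred_token
# action tokens, each decoded into an overlapping chunk of n_action_pred_chunk steps with stride 1,
# so the flat prediction horizon with these defaults is 3 + 5 - 1 = 7 steps.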

View File

@@ -77,8 +77,10 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
"observation.image": deque(maxlen=self.config.n_obs_steps),
"observation.state": deque(maxlen=self.config.n_obs_steps),
}
if self.config.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.config.n_action_steps)
if self.config.n_action_pred_chunk is not None:
self._action_queue = deque([], maxlen=self.config.n_action_pred_chunk) # original one
# self._action_queue = deque([], maxlen=self.config.n_action_pred_token) # jay temp2
self._action_history_queue = deque([], maxlen=self.config.n_obs_steps - 1)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -96,6 +98,10 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
batch = self.normalize_inputs(batch)
self._obs_queues = populate_queues(self._obs_queues, batch)
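# Before the first prediction of an episode, the (n_obs_steps - 1)-long action history below is
# filled with copies of the current observation.state; this presumably works because the PushT
# action (a target x/y position) has the same dimensionality as the 2-D state.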
if len(self._action_history_queue) == 0:
while len(self._action_history_queue) != self._action_history_queue.maxlen:
self._action_history_queue.append(batch["observation.state"])
if not self.check_discretized():
self.vqbet._action_head._vqvae_model.discretized = True
# raise NotImplementedError(
@@ -109,16 +115,35 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
# for act seq pred, we should provide averaged act over horizon.
if len(self._action_queue) == 0:
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
# original one
batch = {key: torch.stack(list(self._obs_queues[key]), dim=1) for key in batch}
actions = self.vqbet(batch)[:, : self.config.n_action_steps]
batch["action"] = torch.stack(list(self._action_history_queue), dim=1)
actions = self.vqbet(batch, rollout=True)[:, : self.config.n_action_pred_chunk]
# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
# jay temp2
# batch = {key: torch.stack(list(self._obs_queues[key]), dim=1) for key in batch}
# batch["action"] = torch.stack(list(self._action_history_queue), dim=1)
# action_predicted = []
# for i in range(self.config.n_action_pred_token):
# actions = self.vqbet(batch, rollout=True, input_predicted = action_predicted)[:, i]
# action_predicted.append(actions)
# actions = self.unnormalize_outputs({"action": torch.stack(action_predicted)})["action"]
# self._action_queue.extend(actions)
#############
action = self._action_queue.popleft()
self._action_history_queue.append(self.normalize_targets({"action": action})["action"])
return action
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
@@ -127,7 +152,60 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
if not self.check_discretized():
loss = self.vqbet.discretize(self.config.discretize_step, batch['action'])
return {"loss": loss}
_, loss = self.vqbet(batch)
# original one
_, loss = self.vqbet(batch, rollout=False)
# jay temp2
# action_predicted = []
# pred_values = []
# for i in range(self.config.n_action_pred_token):
# pred = self.vqbet(batch, rollout=False, input_predicted = action_predicted, recurrent = True)
# actions = pred["predicted_action"][:, i]
# action_predicted.append(actions)
# pred_values.append(pred)
# # Stack the per-token outputs (predicted_action, decoded_action, sampled_centers, cbet_logits) from each prediction dict in the list
# # jay TODO
# predicted_action = torch.stack([d["predicted_action"][:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1)
# decoded_action = torch.stack([einops.rearrange(d["decoded_action"], "(B T) 1 N->B T N", T=5)[:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1).reshape(-1, 1, 2)
# sampled_centers = torch.stack([einops.rearrange(d["sampled_centers"], "(B T) N->B T N", T=5)[:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1).reshape(-1, 2)
# cbet_logits = torch.stack([einops.rearrange(d["cbet_logits"], "(B T) G N->B T G N", T=5)[:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1).reshape(-1, 2, 16)
# gpt_output, G, NT, N, T = pred_values[-1]["input"], pred_values[-1]["G"], pred_values[-1]["NT"], pred_values[-1]["N"], pred_values[-1]["T"]
# pred_action = {
# "input": gpt_output,
# "cbet_logits1": None,
# "cbet_logits2": None,
# "cbet_logits": cbet_logits if "cbet_logits" in locals() else None,
# "predicted_action": predicted_action,
# "decoded_action": decoded_action,
# "sampled_centers": sampled_centers,
# "G": G,
# "NT": NT,
# "N": N,
# "T": T,
# }
# action = batch["action"]
# n, total_w, act_dim = action.shape
# act_w = self.config.n_action_pred_chunk
# num_token = total_w + 1 - act_w
# output_shape = (n, num_token, act_w, act_dim)
# output = torch.empty(output_shape).to(action.device)
# for i in range(num_token):
# output[:, i, :, :] = action[:, i : i + act_w, :]
# action = output # batch size, 7, 5, 2
# loss = self.vqbet._action_head.loss_fn(
# pred_action,
# action[:, self.config.n_obs_steps-1:],
# reduction="mean",
# )[0]
#############
return {"loss": loss['actor_loss']}
@@ -138,11 +216,22 @@ class VQBeTModel(nn.Module):
self.rgb_encoder = DiffusionRgbEncoder(config)
global_cond_dim = (config.output_shapes["action"][0] + self.rgb_encoder.feature_dim)
self.global_cond_dim = self.rgb_encoder.feature_dim
# action token and EOS token
self._action_token = nn.Parameter(torch.randn(1, 1, self.global_cond_dim)) # Batch, Timestep, Data type, GPT input dim
self._eos_token = nn.Parameter(torch.randn(1, 1, self.global_cond_dim))
self.state_projector = MLP(
config.output_shapes["action"][0], hidden_channels=[self.global_cond_dim]
)
self.action_projector = MLP(
config.output_shapes["action"][0], hidden_channels=[self.global_cond_dim]
)
self._policy = GPT(
GPTConfig(
block_size=self.config.block_size,
input_dim=global_cond_dim,
input_dim=self.global_cond_dim,
output_dim=self.config.output_dim,
n_layer=self.config.n_layer,
n_head=self.config.n_head,
@@ -157,14 +246,14 @@ class VQBeTModel(nn.Module):
vqvae_groups=config.vqvae_groups,
vqvae_n_embed=config.vqvae_n_embed,
vqvae_embedding_dim=config.vqvae_embedding_dim,
n_action_steps=config.n_action_steps
n_action_pred_chunk=config.n_action_pred_chunk
)
def discretize(self, discretize_step, actions):
return self._action_head.discretize(discretize_step, actions)
# ========= inference ============
def forward(self, batch: dict[str, Tensor]) -> Tensor:
def forward(self, batch: dict[str, Tensor], rollout: bool, input_predicted:list = None, recurrent = False) -> Tensor: # jay temp2
# Input validation.
assert set(batch).issuperset({"observation.state", "observation.image"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
@@ -175,13 +264,8 @@ class VQBeTModel(nn.Module):
# Separate batch and sequence dims.
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
# Concatenate state and image features then flatten to (B, global_cond_dim).
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1)
obs = global_cond
if 'action' in batch.keys():
action = batch["action"]
else:
action = None
if 'goal' in batch.keys():
goal = batch["goal"]
num_goal_token = goal.shape[1]
@@ -190,21 +274,70 @@ class VQBeTModel(nn.Module):
num_goal_token = 0
if rollout:
gpt_input_action = torch.cat((batch["action"], torch.zeros(batch_size, 1, batch["action"].shape[-1]).to(get_device_from_parameters(self))), dim=1)
else:
gpt_input_action = batch["action"][..., :n_obs_steps, :]
for i, val in enumerate(batch["frame_index"]):
if val < n_obs_steps - 1:
# For frames earlier than n_obs_steps - 1, fill the missing action history with observation.state
gpt_input_action[i, :-(val+1)] = batch["observation.state"][i, :-(val+1)]
gpt_input_action[:, -1] = 0
global_cond = torch.cat([
torch.unsqueeze(img_features, dim=2),
torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
# torch.unsqueeze(self.action_projector(gpt_input_action), dim=2) # jay temp
], dim=-2).view(batch_size, -1, self.global_cond_dim) # batch size, num tokens, GPT input dim
# global_cond = global_cond[:, :-1] # get rid of final action # jay temp
# We could either condition on the preceding actions and predict the rest, or just predict everything -> for now, feed the actions up to the previous step
# Each token could be a single action or a multi-step chunk -> for now, use chunks
eos_token = self._eos_token.repeat(batch_size, 1, 1)
action_token = self._action_token.repeat(batch_size, self.config.n_action_pred_token, 1)
# jay temp2
# if (input_predicted is not None) and (len(input_predicted) > 0):
# len_input = len(input_predicted)
# action_token[:, :len_input] = self.action_projector(torch.stack(input_predicted, dim=1).to(get_device_from_parameters(self)))
prompt_length = global_cond.shape[1]+1
global_cond = torch.cat([global_cond, eos_token, action_token], dim=1)
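# Resulting GPT input sequence per batch element:
#   [img_0, state_0, ..., img_{n_obs-1}, state_{n_obs-1}, EOS, act_tok_0, ..., act_tok_{n_action_pred_token-1}]
# i.e. the image embedding and the projected state are now separate tokens (2 * n_obs_steps
# observation tokens), followed by one EOS token and n_action_pred_token learned query tokens.
# With n_obs_steps=5 and n_action_pred_token=3 that is 14 tokens, and prompt_length = 11.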
# get action features
features = self._policy(obs)
features = features[:, num_goal_token:]
features = self._policy(global_cond)
features = features[:, prompt_length:]
# action head
pred_action = self._action_head(
features,
**{"action_seq": action},
# **{"action_seq": action},
)
if action is None:
return pred_action["predicted_action"][:, -1, :].reshape(batch_size, self.config.n_action_steps, -1)
if rollout:
return pred_action["predicted_action"][:, -1, :].reshape(batch_size, self.config.n_action_pred_chunk, -1) # original one
# return pred_action["predicted_action"] # jay temp2
# this block should sit between the "if" and "else" branches
# elif recurrent: # jay temp2 => training
# return pred_action
else:
action = batch["action"]
n, total_w, act_dim = action.shape
act_w = self.config.n_action_pred_chunk
num_token = total_w + 1 - act_w
output_shape = (n, num_token, act_w, act_dim)
output = torch.empty(output_shape).to(action.device)
for i in range(num_token):
output[:, i, :, :] = action[:, i : i + act_w, :]
action = output # batch size, 7, 5, 2
loss = self._action_head.loss_fn(
pred_action,
action,
action[:, n_obs_steps-1:],
reduction="mean",
)
return pred_action, loss[0] if isinstance(loss, tuple) else loss
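The sliding-window target construction above (also duplicated, commented out, in VQBeTPolicy.forward) builds a (batch, num_token, chunk, action_dim) tensor of overlapping action chunks. A minimal self-contained sketch of the same operation follows; the vectorized form via torch.Tensor.unfold is added only as an illustration and is not part of this commit:

import torch

def chunk_actions(action: torch.Tensor, chunk: int) -> torch.Tensor:
    """(B, total_w, act_dim) -> (B, total_w + 1 - chunk, chunk, act_dim), stride 1."""
    b, total_w, act_dim = action.shape
    num_token = total_w + 1 - chunk
    out = action.new_empty((b, num_token, chunk, act_dim))
    for i in range(num_token):  # loop version, mirroring the code in this commit
        out[:, i] = action[:, i : i + chunk]
    # vectorized equivalent: unfold yields (B, num_token, act_dim, chunk), so swap the last two dims
    assert torch.equal(out, action.unfold(1, chunk, 1).permute(0, 1, 3, 2))
    return out

# With the PushT defaults (11 action frames per sample, n_action_pred_chunk=5, act_dim=2) this yields
# the "batch size, 7, 5, 2" shape noted above; the loss then keeps action[:, n_obs_steps - 1:],
# i.e. the last n_action_pred_token = 3 chunks, one per predicted action token.
chunks = chunk_actions(torch.randn(4, 11, 2), chunk=5)
print(chunks.shape)  # torch.Size([4, 7, 5, 2])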
@@ -223,7 +356,7 @@ class VQBeTHead(nn.Module):
vqvae_groups=2, # G(number of groups)
vqvae_n_embed=16, # C(number of code integers)
vqvae_embedding_dim=512, # D(embedding dims)
n_action_steps=1, # action chunk size
n_action_pred_chunk=1, # action chunk size
):
super().__init__()
self.input_size = input_size
@@ -236,7 +369,7 @@ class VQBeTHead(nn.Module):
self._G = vqvae_groups # G(number of groups)
self._C = vqvae_n_embed # C(number of code integers)
self._D = vqvae_embedding_dim # D(embedding dims)
self.n_action_steps = n_action_steps # action chunk size
self.n_action_pred_chunk = n_action_pred_chunk # action chunk size
if self.sequentially_select:
@@ -259,12 +392,12 @@ class VQBeTHead(nn.Module):
hidden_channels=[
self.hidden_size,
self.hidden_size,
self._G * self._C * n_action_steps * self.output_size,
self._G * self._C * n_action_pred_chunk * self.output_size,
],
)
# init vqvae
vqvae_config = {
"action_chunk": self.n_action_steps,
"action_chunk": self.n_action_pred_chunk,
"action_dim": self.output_size,
"vqvae_n_latent_dims": self._D,
"vqvae_n_embed": self._C,
@@ -404,14 +537,14 @@ class VQBeTHead(nn.Module):
predicted_action, "N T (W A) -> (N T) W A", W=self._vqvae_model.input_dim_h
)
n, total_w, act_dim = action_seq.shape
act_w = self._vqvae_model.input_dim_h
obs_w = total_w + 1 - act_w
output_shape = (n, obs_w, act_w, act_dim)
output = torch.empty(output_shape).to(action_seq.device)
for i in range(obs_w):
output[:, i, :, :] = action_seq[:, i : i + act_w, :]
action_seq = einops.rearrange(output, "N T W A -> (N T) W A")
# n, total_w, act_dim = action_seq.shape
# act_w = self._vqvae_model.input_dim_h
# obs_w = total_w + 1 - act_w
# output_shape = (n, obs_w, act_w, act_dim)
# output = torch.empty(output_shape).to(action_seq.device)
# for i in range(obs_w):
# output[:, i, :, :] = action_seq[:, i : i + act_w, :]
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
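# NOTE: the rearrange above implies action_seq now arrives already windowed as (N, T, W, A)
# (the sliding-window construction moved upstream into VQBeTModel.forward); the head only
# flattens the batch and token dimensions before computing the loss.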
# Figure out the loss for the actions.
# First, we need to find the closest cluster center for each action.
state_vq, action_bins = self._vqvae_model.get_code(
@@ -425,43 +558,43 @@ class VQBeTHead(nn.Module):
offset_loss = torch.nn.L1Loss()(action_seq, predicted_action)
action_diff = F.mse_loss(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[:, -1, 0, :],
einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=obs_w)[
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, -1, 0, :],
einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
:, -1, 0, :
],
) # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout
action_diff_tot = F.mse_loss(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[:, -1, :, :],
einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=obs_w)[
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, -1, :, :],
einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
:, -1, :, :
],
) # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout
action_diff_mean_res1 = (
abs(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[
:, -1, 0, :
]
- einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=obs_w)[
- einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T)[
:, -1, 0, :
]
)
).mean()
action_diff_mean_res2 = (
abs(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[
:, -1, 0, :
]
- einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=obs_w)[
- einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
:, -1, 0, :
]
)
).mean()
action_diff_max = (
abs(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[
:, -1, 0, :
]
- einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=obs_w)[
- einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
:, -1, 0, :
]
)
@@ -692,6 +825,19 @@ class VQBeTOptimizer:
learning_rate=cfg.training.bet_learning_rate,
betas=cfg.training.bet_betas,
)
self.bet_optimizer1.add_param_group(
{"params": policy.vqbet._action_token}
)
self.bet_optimizer1.add_param_group(
{"params": policy.vqbet._eos_token}
)
self.bet_optimizer1.add_param_group(
{"params": policy.vqbet.state_projector.parameters()}
)
self.bet_optimizer1.add_param_group(
{"params": policy.vqbet.action_projector.parameters()}
)
# if policy.vqbet._action_head.sequentially_select:
# self.bet_optimizer1.add_param_group(
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}

View File

@@ -22,10 +22,10 @@ override_dataset_stats:
max: [511.0, 511.0]
training:
offline_steps: 200000
offline_steps: 800000
online_steps: 0
eval_freq: 10000
save_freq: 5000
eval_freq: 20000 # jay
save_freq: 20000
log_freq: 250
save_model: true
@@ -41,7 +41,7 @@ training:
# VQ-BeT specific
vqvae_lr: 1.0e-3
discretize_step: 3000
discretize_step: 20000 # jay
bet_weight_decay: 2e-4
bet_learning_rate: 5.5e-5
bet_betas: [0.9, 0.999]
@@ -49,7 +49,7 @@ training:
delta_timestamps:
observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_steps})]"
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.n_action_pred_chunk} - 1)]"
eval:
n_episodes: 50
@@ -60,7 +60,9 @@ policy:
# Input / output structure.
n_obs_steps: 5
n_action_steps: 5
# n_action_steps: 7 # n_action_pred_token + n_action_pred_chunk - 1
n_action_pred_token: 3 # 3 5
n_action_pred_chunk: 5 # 5 1 jay temp2
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
@@ -90,7 +92,7 @@ policy:
vqvae_n_embed: 16
vqvae_embedding_dim: 256
# VQ-BeT
block_size: 50
block_size: 500
output_dim: 256 # 512
n_layer: 6 # 8
n_head: 6 # 4