backup wip

Alexander Soare 2024-04-02 19:13:49 +01:00
parent 2b928eedd4
commit 65ef8c30d0
3 changed files with 80 additions and 19 deletions

View File

@@ -125,32 +125,92 @@ def make_offline_buffer(
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
# TODO(now): These stats are needed to use their pretrained model for sim_transfer_cube_human.
# (Pdb) stats['observation']['state']['mean']
# tensor([-0.0071, -0.6293, 1.0351, -0.0517, -0.4642, -0.0754, 0.4751, -0.0373,
#         -0.3324, 0.9034, -0.2258, -0.3127, -0.2412, 0.6866])
-stats['observation', 'state', 'mean'] = torch.tensor([-0.00740268, -0.63187766, 1.0356655, -0.05027218, -0.46199223,
-                                                      -0.07467502, 0.47467607, -0.03615446, -0.33203387, 0.9038929,
-                                                      -0.22060776, -0.31011587, -0.23484458, 0.6842416])
+stats["observation", "state", "mean"] = torch.tensor(
+    [
+        -0.00740268,
+        -0.63187766,
+        1.0356655,
+        -0.05027218,
+        -0.46199223,
+        -0.07467502,
+        0.47467607,
+        -0.03615446,
+        -0.33203387,
+        0.9038929,
+        -0.22060776,
+        -0.31011587,
+        -0.23484458,
+        0.6842416,
+    ]
+)
# (Pdb) stats['observation']['state']['std']
# tensor([0.0022, 0.0520, 0.0291, 0.0092, 0.0267, 0.0145, 0.0563, 0.0179, 0.0494,
#         0.0326, 0.0476, 0.0535, 0.0956, 0.0513])
-stats['observation', 'state', 'std'] = torch.tensor([0.01219023, 0.2975381, 0.16728032, 0.04733803, 0.1486037,
-                                                     0.08788499, 0.31752336, 0.1049916, 0.27933604, 0.18094037,
-                                                     0.26604933, 0.30466506, 0.5298686, 0.25505227])
+stats["observation", "state", "std"] = torch.tensor(
+    [
+        0.01219023,
+        0.2975381,
+        0.16728032,
+        0.04733803,
+        0.1486037,
+        0.08788499,
+        0.31752336,
+        0.1049916,
+        0.27933604,
+        0.18094037,
+        0.26604933,
+        0.30466506,
+        0.5298686,
+        0.25505227,
+    ]
+)
# (Pdb) stats['action']['mean']
# tensor([-0.0075, -0.6346, 1.0353, -0.0465, -0.4686, -0.0738, 0.3723, -0.0396,
#         -0.3184, 0.8991, -0.2065, -0.3182, -0.2338, 0.5593])
-stats['action']['mean'] = torch.tensor([-0.00756444, -0.6281845, 1.0312834, -0.04664314, -0.47211358,
-                                        -0.074527, 0.37389806, -0.03718753, -0.3261143, 0.8997205,
-                                        -0.21371077, -0.31840396, -0.23360962, 0.551947])
+stats["action"]["mean"] = torch.tensor(
+    [
+        -0.00756444,
+        -0.6281845,
+        1.0312834,
+        -0.04664314,
+        -0.47211358,
+        -0.074527,
+        0.37389806,
+        -0.03718753,
+        -0.3261143,
+        0.8997205,
+        -0.21371077,
+        -0.31840396,
+        -0.23360962,
+        0.551947,
+    ]
+)
# (Pdb) stats['action']['std']
# tensor([0.0023, 0.0514, 0.0290, 0.0086, 0.0263, 0.0143, 0.0593, 0.0185, 0.0510,
#         0.0328, 0.0478, 0.0531, 0.0945, 0.0794])
-stats['action']['std'] = torch.tensor([0.01252818, 0.2957442, 0.16701928, 0.04584508, 0.14833844,
-                                       0.08763024, 0.30665937, 0.10600077, 0.27572668, 0.1805853,
-                                       0.26304692, 0.30708534, 0.5305411, 0.38381037])
+stats["action"]["std"] = torch.tensor(
+    [
+        0.01252818,
+        0.2957442,
+        0.16701928,
+        0.04584508,
+        0.14833844,
+        0.08763024,
+        0.30665937,
+        0.10600077,
+        0.27572668,
+        0.1805853,
+        0.26304692,
+        0.30708534,
+        0.5305411,
+        0.38381037,
+    ]
+)
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
offline_buffer.set_transform(transforms)
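For reference, a minimal sketch of what the two normalization modes selected above are assumed to do with these stats. This is illustrative only, not the actual NormalizeTransform implementation; the min/max keys in the min_max branch are an assumption.

import torch

def normalize(x, mean=None, std=None, min_val=None, max_val=None, mode="mean_std"):
    """Illustrative per-dimension normalization (sketch, not library code)."""
    if mode == "mean_std":
        # Standardize each state/action dimension to zero mean and unit variance.
        return (x - mean) / (std + 1e-8)
    if mode == "min_max":
        # Rescale each dimension from [min, max] to [-1, 1].
        return (x - min_val) / (max_val - min_val) * 2 - 1
    raise ValueError(f"Unknown normalization mode: {mode}")

# e.g. with the hard-coded aloha stats above, something like:
# state_norm = normalize(state, mean=stats["observation", "state", "mean"],
#                        std=stats["observation", "state", "std"], mode="mean_std")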

View File

@@ -2,7 +2,6 @@ import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
-from transformers import DetrForObjectDetection
from .backbone import build_backbone
from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer
@@ -74,7 +73,7 @@ class ActionChunkingTransformer(nn.Module):
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, action_dim)
self.is_pad_head = nn.Linear(hidden_dim, 1)
# Positional embedding to be used as input to the latent vae_encoder (if applicable) and for the
self.pos_embed = nn.Embedding(horizon, hidden_dim)
if backbones is not None:
    self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
@@ -134,7 +133,9 @@ class ActionChunkingTransformer(nn.Module):
pos_embed = self.pos_table.clone().detach()
pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
# query model
-vae_encoder_output = self.vae_encoder(vae_encoder_input, pos=pos_embed)  # , src_key_padding_mask=is_pad)
+vae_encoder_output = self.vae_encoder(
+    vae_encoder_input, pos=pos_embed
+)  # , src_key_padding_mask=is_pad)
vae_encoder_output = vae_encoder_output[0]  # take cls output only
latent_info = self.latent_proj(vae_encoder_output)
mu = latent_info[:, : self.latent_dim]
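For context on the mu line above: latent_proj presumably packs the latent mean and log-variance into a single vector, with the usual CVAE reparameterization applied afterwards. A rough sketch under that assumption (hypothetical helper, not this file's actual code):

import torch

def sample_latent(latent_info: torch.Tensor, latent_dim: int) -> torch.Tensor:
    """Split [mu | logvar] and draw z = mu + sigma * eps (reparameterization trick)."""
    mu = latent_info[:, :latent_dim]
    logvar = latent_info[:, latent_dim:]
    std = torch.exp(0.5 * logvar)  # log-variance -> standard deviation
    return mu + std * torch.randn_like(std)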
@@ -219,7 +220,7 @@ def build(args):
backbones.append(backbone)
transformer = build_transformer(args)
vae_encoder = build_vae_encoder(args)
model = ActionChunkingTransformer(

View File

@@ -54,7 +54,7 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
Args:
    vae: Whether to use the variational objective. TODO(now): Give more details.
    temporal_agg: Whether to do temporal aggregation. For each timestep during rollout, the action
        is returned as an exponential moving average of previously generated actions for that timestep.
    n_obs_steps: Number of time steps worth of observation to use as input.
    horizon: The number of actions to generate in one forward pass.
    kl_weight: Weight for KL divergence. Defaults to None. Only applicable when using the variational
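A rough sketch of the temporal aggregation described in the temporal_agg docstring above: every past forward pass that predicted an action for the current timestep contributes to an exponentially weighted average. Illustrative only; the oldest-first ordering and the weighting constant (values around 0.01 are common in ACT-style implementations) are assumptions here, not this policy's exact code.

import torch

def aggregate_actions(actions_for_t: torch.Tensor, k: float = 0.01) -> torch.Tensor:
    """Exponentially weighted average over all predictions made for one timestep.

    actions_for_t: (num_predictions, action_dim), oldest prediction first.
    """
    num_predictions = actions_for_t.shape[0]
    weights = torch.exp(-k * torch.arange(num_predictions, dtype=torch.float32))
    weights = weights / weights.sum()  # normalize weights to sum to 1
    return (weights.unsqueeze(-1) * actions_for_t).sum(dim=0)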
@@ -120,7 +120,7 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
    "action": action.to(self.device, non_blocking=True),
}
return out

start_time = time.time()
batch = replay_buffer.sample(batch_size)