backup wip
This commit is contained in:
parent
9c28ac8aa4
commit
1e71196fe3
|
@ -158,7 +158,7 @@ class AlohaDataset(torch.utils.data.Dataset):
|
|||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
logging.info("Initialize and feed offline buffer")
|
||||
frame_idx = 0
|
||||
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
|
@ -190,8 +190,14 @@ class AlohaDataset(torch.utils.data.Dataset):
|
|||
ep_dict[f"observation.images.{cam}"] = image[:-1]
|
||||
# ep_dict[f"next.observation.images.{cam}"] = image[1:]
|
||||
|
||||
assert isinstance(ep_id, int)
|
||||
self.data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1)
|
||||
assert len(self.data_ids_per_episode[ep_id]) == num_frames
|
||||
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
frame_idx += num_frames
|
||||
|
||||
self.data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
|
|
|
@ -59,96 +59,95 @@ def make_dataset(
|
|||
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
)
|
||||
stats = compute_or_load_stats(stats_dataset)
|
||||
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
|
||||
# TODO(now): These stats are needed to use their pretrained model for sim_transfer_cube_human.
|
||||
# (Pdb) stats['observation']['state']['mean']
|
||||
# tensor([-0.0071, -0.6293, 1.0351, -0.0517, -0.4642, -0.0754, 0.4751, -0.0373,
|
||||
# -0.3324, 0.9034, -0.2258, -0.3127, -0.2412, 0.6866])
|
||||
stats["observation", "state", "mean"] = torch.tensor(
|
||||
[
|
||||
-0.00740268,
|
||||
-0.63187766,
|
||||
1.0356655,
|
||||
-0.05027218,
|
||||
-0.46199223,
|
||||
-0.07467502,
|
||||
0.47467607,
|
||||
-0.03615446,
|
||||
-0.33203387,
|
||||
0.9038929,
|
||||
-0.22060776,
|
||||
-0.31011587,
|
||||
-0.23484458,
|
||||
0.6842416,
|
||||
]
|
||||
)
|
||||
# (Pdb) stats['observation']['state']['std']
|
||||
# tensor([0.0022, 0.0520, 0.0291, 0.0092, 0.0267, 0.0145, 0.0563, 0.0179, 0.0494,
|
||||
# 0.0326, 0.0476, 0.0535, 0.0956, 0.0513])
|
||||
stats["observation", "state", "std"] = torch.tensor(
|
||||
[
|
||||
0.01219023,
|
||||
0.2975381,
|
||||
0.16728032,
|
||||
0.04733803,
|
||||
0.1486037,
|
||||
0.08788499,
|
||||
0.31752336,
|
||||
0.1049916,
|
||||
0.27933604,
|
||||
0.18094037,
|
||||
0.26604933,
|
||||
0.30466506,
|
||||
0.5298686,
|
||||
0.25505227,
|
||||
]
|
||||
)
|
||||
# (Pdb) stats['action']['mean']
|
||||
# tensor([-0.0075, -0.6346, 1.0353, -0.0465, -0.4686, -0.0738, 0.3723, -0.0396,
|
||||
# -0.3184, 0.8991, -0.2065, -0.3182, -0.2338, 0.5593])
|
||||
stats["action"]["mean"] = torch.tensor(
|
||||
[
|
||||
-0.00756444,
|
||||
-0.6281845,
|
||||
1.0312834,
|
||||
-0.04664314,
|
||||
-0.47211358,
|
||||
-0.074527,
|
||||
0.37389806,
|
||||
-0.03718753,
|
||||
-0.3261143,
|
||||
0.8997205,
|
||||
-0.21371077,
|
||||
-0.31840396,
|
||||
-0.23360962,
|
||||
0.551947,
|
||||
]
|
||||
)
|
||||
# (Pdb) stats['action']['std']
|
||||
# tensor([0.0023, 0.0514, 0.0290, 0.0086, 0.0263, 0.0143, 0.0593, 0.0185, 0.0510,
|
||||
# 0.0328, 0.0478, 0.0531, 0.0945, 0.0794])
|
||||
stats["action"]["std"] = torch.tensor(
|
||||
[
|
||||
0.01252818,
|
||||
0.2957442,
|
||||
0.16701928,
|
||||
0.04584508,
|
||||
0.14833844,
|
||||
0.08763024,
|
||||
0.30665937,
|
||||
0.10600077,
|
||||
0.27572668,
|
||||
0.1805853,
|
||||
0.26304692,
|
||||
0.30708534,
|
||||
0.5305411,
|
||||
0.38381037,
|
||||
]
|
||||
)
|
||||
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode)) # noqa: F821
|
||||
# # TODO(now): These stats are needed to use their pretrained model for sim_transfer_cube_human.
|
||||
# # (Pdb) stats['observation']['state']['mean']
|
||||
# # tensor([-0.0071, -0.6293, 1.0351, -0.0517, -0.4642, -0.0754, 0.4751, -0.0373,
|
||||
# # -0.3324, 0.9034, -0.2258, -0.3127, -0.2412, 0.6866])
|
||||
# stats["observation", "state", "mean"] = torch.tensor(
|
||||
# [
|
||||
# -0.00740268,
|
||||
# -0.63187766,
|
||||
# 1.0356655,
|
||||
# -0.05027218,
|
||||
# -0.46199223,
|
||||
# -0.07467502,
|
||||
# 0.47467607,
|
||||
# -0.03615446,
|
||||
# -0.33203387,
|
||||
# 0.9038929,
|
||||
# -0.22060776,
|
||||
# -0.31011587,
|
||||
# -0.23484458,
|
||||
# 0.6842416,
|
||||
# ]
|
||||
# )
|
||||
# # (Pdb) stats['observation']['state']['std']
|
||||
# # tensor([0.0022, 0.0520, 0.0291, 0.0092, 0.0267, 0.0145, 0.0563, 0.0179, 0.0494,
|
||||
# # 0.0326, 0.0476, 0.0535, 0.0956, 0.0513])
|
||||
# stats["observation", "state", "std"] = torch.tensor(
|
||||
# [
|
||||
# 0.01219023,
|
||||
# 0.2975381,
|
||||
# 0.16728032,
|
||||
# 0.04733803,
|
||||
# 0.1486037,
|
||||
# 0.08788499,
|
||||
# 0.31752336,
|
||||
# 0.1049916,
|
||||
# 0.27933604,
|
||||
# 0.18094037,
|
||||
# 0.26604933,
|
||||
# 0.30466506,
|
||||
# 0.5298686,
|
||||
# 0.25505227,
|
||||
# ]
|
||||
# )
|
||||
# # (Pdb) stats['action']['mean']
|
||||
# # tensor([-0.0075, -0.6346, 1.0353, -0.0465, -0.4686, -0.0738, 0.3723, -0.0396,
|
||||
# # -0.3184, 0.8991, -0.2065, -0.3182, -0.2338, 0.5593])
|
||||
# stats["action"]["mean"] = torch.tensor(
|
||||
# [
|
||||
# -0.00756444,
|
||||
# -0.6281845,
|
||||
# 1.0312834,
|
||||
# -0.04664314,
|
||||
# -0.47211358,
|
||||
# -0.074527,
|
||||
# 0.37389806,
|
||||
# -0.03718753,
|
||||
# -0.3261143,
|
||||
# 0.8997205,
|
||||
# -0.21371077,
|
||||
# -0.31840396,
|
||||
# -0.23360962,
|
||||
# 0.551947,
|
||||
# ]
|
||||
# )
|
||||
# # (Pdb) stats['action']['std']
|
||||
# # tensor([0.0023, 0.0514, 0.0290, 0.0086, 0.0263, 0.0143, 0.0593, 0.0185, 0.0510,
|
||||
# # 0.0328, 0.0478, 0.0531, 0.0945, 0.0794])
|
||||
# stats["action"]["std"] = torch.tensor(
|
||||
# [
|
||||
# 0.01252818,
|
||||
# 0.2957442,
|
||||
# 0.16701928,
|
||||
# 0.04584508,
|
||||
# 0.14833844,
|
||||
# 0.08763024,
|
||||
# 0.30665937,
|
||||
# 0.10600077,
|
||||
# 0.27572668,
|
||||
# 0.1805853,
|
||||
# 0.26304692,
|
||||
# 0.30708534,
|
||||
# 0.5305411,
|
||||
# 0.38381037,
|
||||
# ]
|
||||
# )
|
||||
# transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode)) # noqa: F821
|
||||
|
||||
transforms = v2.Compose(
|
||||
[
|
||||
|
@ -173,7 +172,11 @@ def make_dataset(
|
|||
"action": [-0.1] + [i / clsfunc.fps for i in range(15)],
|
||||
}
|
||||
else:
|
||||
delta_timestamps = None
|
||||
delta_timestamps = {
|
||||
"observation.images.top": [0],
|
||||
"observation.state": [0],
|
||||
"action": [i / clsfunc.fps for i in range(cfg.policy.horizon)],
|
||||
}
|
||||
|
||||
dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
|
|
|
@ -19,11 +19,10 @@ from torch import Tensor, nn
|
|||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.utils import get_safe_torch_device
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
"""
|
||||
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
||||
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||
|
@ -61,205 +60,20 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
|||
"""
|
||||
|
||||
name = "act"
|
||||
_multiple_obs_steps_not_handled_msg = (
|
||||
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
|
||||
)
|
||||
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
"""
|
||||
TODO(alexander-soare): Add documentation for all parameters.
|
||||
"""
|
||||
super().__init__(n_action_steps)
|
||||
super().__init__()
|
||||
if getattr(cfg, "n_obs_steps", 1) != 1:
|
||||
raise ValueError(self._multiple_obs_steps_not_handled_msg)
|
||||
self.cfg = cfg
|
||||
self.n_action_steps = n_action_steps
|
||||
self.device = get_safe_torch_device(device)
|
||||
|
||||
self.model = _ActionChunkingTransformer(cfg)
|
||||
self._create_optimizer()
|
||||
self.to(self.device)
|
||||
|
||||
def _create_optimizer(self):
|
||||
optimizer_params_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in self.model.named_parameters()
|
||||
if not n.startswith("backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in self.model.named_parameters()
|
||||
if n.startswith("backbone") and p.requires_grad
|
||||
],
|
||||
"lr": self.cfg.lr_backbone,
|
||||
},
|
||||
]
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
|
||||
)
|
||||
|
||||
def update(self, replay_buffer, step):
|
||||
del step
|
||||
|
||||
self.train()
|
||||
|
||||
num_slices = self.cfg.batch_size
|
||||
batch_size = self.cfg.horizon * num_slices
|
||||
|
||||
assert batch_size % self.cfg.horizon == 0
|
||||
assert batch_size % num_slices == 0
|
||||
|
||||
def process_batch(batch, horizon, num_slices):
|
||||
# trajectory t = 64, horizon h = 16
|
||||
# (t h) ... -> t h ...
|
||||
batch = batch.reshape(num_slices, horizon)
|
||||
|
||||
image = batch["observation", "image", "top"]
|
||||
image = image[:, 0] # first observation t=0
|
||||
# batch, num_cam, channel, height, width
|
||||
image = image.unsqueeze(1)
|
||||
assert image.ndim == 5
|
||||
image = image.float()
|
||||
|
||||
state = batch["observation", "state"]
|
||||
state = state[:, 0] # first observation t=0
|
||||
# batch, qpos_dim
|
||||
assert state.ndim == 2
|
||||
|
||||
action = batch["action"]
|
||||
# batch, seq, action_dim
|
||||
assert action.ndim == 3
|
||||
assert action.shape[1] == horizon
|
||||
|
||||
if self.cfg.n_obs_steps > 1:
|
||||
raise NotImplementedError()
|
||||
# # keep first n observations of the slice corresponding to t=[-1,0]
|
||||
# image = image[:, : self.cfg.n_obs_steps]
|
||||
# state = state[:, : self.cfg.n_obs_steps]
|
||||
|
||||
out = {
|
||||
"obs": {
|
||||
"image": image.to(self.device, non_blocking=True),
|
||||
"agent_pos": state.to(self.device, non_blocking=True),
|
||||
},
|
||||
"action": action.to(self.device, non_blocking=True),
|
||||
}
|
||||
return out
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||
|
||||
data_s = time.time() - start_time
|
||||
|
||||
loss = self.compute_loss(batch)
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
"lr": self.cfg.lr,
|
||||
"data_s": data_s,
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
def compute_loss(self, batch):
|
||||
loss_dict = self._forward(
|
||||
qpos=batch["obs"]["agent_pos"],
|
||||
image=batch["obs"]["image"],
|
||||
actions=batch["action"],
|
||||
)
|
||||
loss = loss_dict["loss"]
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
def select_actions(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
self.eval()
|
||||
|
||||
# TODO(rcadene): remove hack
|
||||
# add 1 camera dimension
|
||||
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
|
||||
|
||||
obs_dict = {
|
||||
"image": observation["image", "top"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
action = self._forward(qpos=obs_dict["agent_pos"] * 0.182, image=obs_dict["image"])
|
||||
|
||||
if self.cfg.temporal_agg:
|
||||
# TODO(rcadene): implement temporal aggregation
|
||||
raise NotImplementedError()
|
||||
# all_time_actions[[t], t:t+num_queries] = action
|
||||
# actions_for_curr_step = all_time_actions[:, t]
|
||||
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
|
||||
# actions_for_curr_step = actions_for_curr_step[actions_populated]
|
||||
# k = 0.01
|
||||
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
|
||||
# exp_weights = exp_weights / exp_weights.sum()
|
||||
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
|
||||
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
|
||||
|
||||
# take first predicted action or n first actions
|
||||
action = action[: self.n_action_steps]
|
||||
return action
|
||||
|
||||
def _forward(self, qpos, image, actions=None):
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
image = normalize(image)
|
||||
|
||||
is_training = actions is not None
|
||||
if is_training: # training time
|
||||
actions = actions[:, : self.model.horizon]
|
||||
|
||||
a_hat, (mu, log_sigma_x2) = self.model(qpos, image, actions)
|
||||
|
||||
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||
l1 = all_l1.mean()
|
||||
|
||||
loss_dict = {}
|
||||
loss_dict["l1"] = l1
|
||||
if self.cfg.use_vae:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
# each dimension independently, we sum over the latent dimension to get the total
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||
mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - (log_sigma_x2).exp())).sum(-1).mean()
|
||||
loss_dict["kl"] = mean_kld
|
||||
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = loss_dict["l1"]
|
||||
return loss_dict
|
||||
else:
|
||||
action, _ = self.model(qpos, image) # no action, sample from prior
|
||||
return action
|
||||
|
||||
|
||||
# TODO(alexander-soare) move all this code into the policy when we have the policy API established.
|
||||
class _ActionChunkingTransformer(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
|
||||
self.camera_names = cfg.camera_names
|
||||
self.use_vae = cfg.use_vae
|
||||
self.horizon = cfg.horizon
|
||||
|
@ -326,26 +140,179 @@ class _ActionChunkingTransformer(nn.Module):
|
|||
|
||||
self._reset_parameters()
|
||||
|
||||
self._create_optimizer()
|
||||
self.to(self.device)
|
||||
|
||||
def _create_optimizer(self):
|
||||
optimizer_params_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad
|
||||
],
|
||||
"lr": self.cfg.lr_backbone,
|
||||
},
|
||||
]
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
|
||||
)
|
||||
|
||||
def _reset_parameters(self):
|
||||
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
|
||||
for p in chain(self.encoder.parameters(), self.decoder.parameters()):
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_actions(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
self.eval()
|
||||
|
||||
# TODO(rcadene): remove hack
|
||||
# add 1 camera dimension
|
||||
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
|
||||
|
||||
obs_dict = {
|
||||
"image": observation["image", "top"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
action = self._forward(qpos=obs_dict["agent_pos"] * 0.182, image=obs_dict["image"])
|
||||
|
||||
if self.cfg.temporal_agg:
|
||||
# TODO(rcadene): implement temporal aggregation
|
||||
raise NotImplementedError()
|
||||
# all_time_actions[[t], t:t+num_queries] = action
|
||||
# actions_for_curr_step = all_time_actions[:, t]
|
||||
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
|
||||
# actions_for_curr_step = actions_for_curr_step[actions_populated]
|
||||
# k = 0.01
|
||||
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
|
||||
# exp_weights = exp_weights / exp_weights.sum()
|
||||
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
|
||||
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
|
||||
|
||||
# take first predicted action or n first actions
|
||||
action = action[: self.n_action_steps]
|
||||
return action
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# TODO(now): Temporary bridge.
|
||||
return self.update(*args, **kwargs)
|
||||
|
||||
def _preprocess_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""
|
||||
Expects batch to have (at least):
|
||||
{
|
||||
"observation.state": (B, 1, J) tensor of robot states (joint configuration)
|
||||
|
||||
"observation.images.top": (B, 1, C, H, W) tensor of images.
|
||||
"action": (B, H, J) tensor of actions (positional target for robot joint configuration)
|
||||
"action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds.
|
||||
}
|
||||
"""
|
||||
if batch["observation.state"].shape[1] != 1:
|
||||
raise ValueError(self._multiple_obs_steps_not_handled_msg)
|
||||
batch["observation.state"] = batch["observation.state"].squeeze(1)
|
||||
# TODO(alexander-soare): generalize this to multiple images. Note: no squeeze is required for
|
||||
# "observation.images.top" because then we'd have to unsqueeze to get get the image index dimension.
|
||||
|
||||
def update(self, batch, *_):
|
||||
start_time = time.time()
|
||||
self._preprocess_batch(batch)
|
||||
|
||||
self.train()
|
||||
|
||||
num_slices = self.cfg.batch_size
|
||||
batch_size = self.cfg.horizon * num_slices
|
||||
|
||||
assert batch_size % self.cfg.horizon == 0
|
||||
assert batch_size % num_slices == 0
|
||||
|
||||
loss = self.compute_loss(batch)
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
"lr": self.cfg.lr,
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
def compute_loss(self, batch):
|
||||
loss_dict = self.forward(
|
||||
robot_state=batch["observation.state"],
|
||||
image=batch["observation.images.top"],
|
||||
actions=batch["action"],
|
||||
)
|
||||
loss = loss_dict["loss"]
|
||||
return loss
|
||||
|
||||
def forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
|
||||
# TODO(now): Maybe this shouldn't be here?
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
image = normalize(image)
|
||||
|
||||
is_training = actions is not None
|
||||
if is_training: # training time
|
||||
actions = actions[:, : self.horizon]
|
||||
|
||||
a_hat, (mu, log_sigma_x2) = self._forward(robot_state, image, actions)
|
||||
|
||||
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||
l1 = all_l1.mean()
|
||||
|
||||
loss_dict = {}
|
||||
loss_dict["l1"] = l1
|
||||
if self.cfg.use_vae:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
# each dimension independently, we sum over the latent dimension to get the total
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||
mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - (log_sigma_x2).exp())).sum(-1).mean()
|
||||
loss_dict["kl"] = mean_kld
|
||||
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = loss_dict["l1"]
|
||||
return loss_dict
|
||||
else:
|
||||
action, _ = self._forward(robot_state, image) # no action, sample from prior
|
||||
return action
|
||||
|
||||
def _forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
|
||||
"""
|
||||
Args:
|
||||
robot_state: (B, J) batch of robot joint configurations.
|
||||
image: (B, N, C, H, W) batch of N camera frames.
|
||||
actions: (B, S, A) batch of actions from the target dataset which must be provided if the
|
||||
VAE is enabled and the model is in training mode.
|
||||
Returns:
|
||||
(B, S, A) batch of action sequences
|
||||
Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
|
||||
latent dimension.
|
||||
"""
|
||||
if self.use_vae and self.training:
|
||||
assert (
|
||||
actions is not None
|
||||
), "actions must be provided when using the variational objective in training mode."
|
||||
|
||||
batch_size, _ = robot_state.shape
|
||||
batch_size = robot_state.shape[0]
|
||||
|
||||
# Prepare the latent for input to the transformer encoder.
|
||||
if self.use_vae and actions is not None:
|
||||
|
@ -428,6 +395,13 @@ class _ActionChunkingTransformer(nn.Module):
|
|||
|
||||
return actions, [mu, log_sigma_x2]
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
|
||||
class _TransformerEncoder(nn.Module):
|
||||
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
|
||||
|
|
|
@ -152,7 +152,6 @@ class DiffusionPolicy(nn.Module):
|
|||
self.diffusion.train()
|
||||
|
||||
data_s = time.time() - start_time
|
||||
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
loss.backward()
|
||||
|
||||
|
|
|
@ -41,7 +41,6 @@ def log_train_info(logger, info, step, cfg, dataset, is_offline):
|
|||
loss = info["loss"]
|
||||
grad_norm = info["grad_norm"]
|
||||
lr = info["lr"]
|
||||
data_s = info["data_s"]
|
||||
update_s = info["update_s"]
|
||||
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
|
@ -62,7 +61,6 @@ def log_train_info(logger, info, step, cfg, dataset, is_offline):
|
|||
f"grdn:{grad_norm:.3f}",
|
||||
f"lr:{lr:0.1e}",
|
||||
# in seconds
|
||||
f"data_s:{data_s:.3f}",
|
||||
f"updt_s:{update_s:.3f}",
|
||||
]
|
||||
logging.info(" ".join(log_items))
|
||||
|
@ -200,7 +198,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
is_offline = True
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
num_workers=0,
|
||||
batch_size=cfg.policy.batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=cfg.device != "cpu",
|
||||
|
|
|
@ -880,6 +880,29 @@ files = [
|
|||
[package.extras]
|
||||
protobuf = ["grpcio-tools (>=1.62.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "gym-pusht"
|
||||
version = "0.1.0"
|
||||
description = "PushT environment for LeRobot"
|
||||
optional = true
|
||||
python-versions = "^3.10"
|
||||
files = []
|
||||
develop = false
|
||||
|
||||
[package.dependencies]
|
||||
gymnasium = "^0.29.1"
|
||||
opencv-python = "^4.9.0.80"
|
||||
pygame = "^2.5.2"
|
||||
pymunk = "^6.6.0"
|
||||
scikit-image = "^0.22.0"
|
||||
shapely = "^2.0.3"
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "git@github.com:huggingface/gym-pusht.git"
|
||||
reference = "HEAD"
|
||||
resolved_reference = "0fe4449cca5a2b08f529f7a07fbf5b9df24962ec"
|
||||
|
||||
[[package]]
|
||||
name = "gymnasium"
|
||||
version = "0.29.1"
|
||||
|
@ -1261,17 +1284,21 @@ setuptools = "!=50.0.0"
|
|||
|
||||
[[package]]
|
||||
name = "lazy-loader"
|
||||
version = "0.3"
|
||||
description = "lazy_loader"
|
||||
version = "0.4"
|
||||
description = "Makes it easy to load subpackages and functions on demand."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "lazy_loader-0.3-py3-none-any.whl", hash = "sha256:1e9e76ee8631e264c62ce10006718e80b2cfc74340d17d1031e0f84af7478554"},
|
||||
{file = "lazy_loader-0.3.tar.gz", hash = "sha256:3b68898e34f5b2a29daaaac172c6555512d0f32074f147e2254e4a6d9d838f37"},
|
||||
{file = "lazy_loader-0.4-py3-none-any.whl", hash = "sha256:342aa8e14d543a154047afb4ba8ef17f5563baad3fc610d7b15b213b0f119efc"},
|
||||
{file = "lazy_loader-0.4.tar.gz", hash = "sha256:47c75182589b91a4e1a85a136c074285a5ad4d9f39c63e0d7fb76391c4574cd1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
packaging = "*"
|
||||
|
||||
[package.extras]
|
||||
lint = ["pre-commit (>=3.3)"]
|
||||
dev = ["changelist (==0.5)"]
|
||||
lint = ["pre-commit (==3.7.0)"]
|
||||
test = ["pytest (>=7.4)", "pytest-cov (>=4.1)"]
|
||||
|
||||
[[package]]
|
||||
|
@ -3274,7 +3301,7 @@ protobuf = ">=3.20"
|
|||
|
||||
[[package]]
|
||||
name = "tensordict"
|
||||
version = "0.4.0+b4c91e8"
|
||||
version = "0.4.0+f622b2f"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
@ -3518,13 +3545,13 @@ tutorials = ["matplotlib", "pandas", "tabulate", "torch"]
|
|||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.10.0"
|
||||
version = "4.11.0"
|
||||
description = "Backported and Experimental Type Hints for Python 3.8+"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"},
|
||||
{file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"},
|
||||
{file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
|
||||
{file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -3667,9 +3694,9 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
|
|||
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
|
||||
|
||||
[extras]
|
||||
pusht = []
|
||||
pusht = ["gym_pusht"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "04b17fa57f189ad63181611d2e724d7fbdfb3485bc1a587b259d0a3751db918d"
|
||||
content-hash = "3eee17e4bf2b7a570f41ef9c400ec5a24a3113f62a13162229cf43504ca0d005"
|
||||
|
|
|
@ -52,6 +52,7 @@ robomimic = "0.2.0"
|
|||
gymnasium-robotics = "^1.2.4"
|
||||
gymnasium = "^0.29.1"
|
||||
cmake = "^3.29.0.1"
|
||||
gym_pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
|
||||
|
||||
[tool.poetry.extras]
|
||||
pusht = ["gym_pusht"]
|
||||
|
|
Loading…
Reference in New Issue