update traj generation + save tensors

Akshay Kashyap 2024-06-04 13:14:52 -07:00
parent 0d6fde3ba1
commit c98d9dbd08
5 changed files with 44 additions and 25 deletions

View File

@@ -264,24 +264,32 @@ class OctoModel(nn.Module):
         # Sample prior.
         sample = torch.randn(
-            size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
+            size=(
+                batch_size,
+                self.config.n_obs_steps,
+                self.config.horizon,
+                self.config.output_shapes["action"][0],
+            ),
             dtype=dtype,
             device=device,
             generator=generator,
         )
-        sample = rearrange(sample, "b t d -> b (t d)")
+        sample = rearrange(sample, "b o t d -> b o (t d)")
         self.noise_scheduler.set_timesteps(self.num_inference_steps)
         for t in self.noise_scheduler.timesteps:
             # Predict model output.
-            t_ = t.repeat((batch_size, 1)).to(device)
-            model_output = self.action_head(readout_embeds, t_, sample)
+            model_output = self.action_head(
+                readout_embeds,
+                torch.full((sample.shape[0], 1), t, dtype=torch.long, device=sample.device),
+                sample,
+            )
             # Compute previous image: x_t -> x_t-1
             sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
-        sample = rearrange(sample, "b (t d) -> b t d", t=self.config.horizon)
+        sample = rearrange(sample, "b o (t d) -> b o t d", t=self.config.horizon)
         return sample
     def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
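For reference, a minimal shape sketch of the new sampling layout (sizes are illustrative, not taken from the config): the prior now carries an explicit observation-step axis, each step's trajectory is flattened for the action head, and the scheduler timestep is broadcast with torch.full rather than tensor.repeat.

import torch
from einops import rearrange

batch_size, n_obs_steps, horizon, action_dim = 2, 2, 16, 7  # assumed toy sizes

# Prior noise with an explicit observation-step axis.
sample = torch.randn(batch_size, n_obs_steps, horizon, action_dim)
sample = rearrange(sample, "b o t d -> b o (t d)")  # (2, 2, 112): one flat trajectory per obs step

# Scheduler timestep broadcast to one column per batch item, as in the new call above.
t = 50  # stand-in for a scheduler timestep
t_cond = torch.full((sample.shape[0], 1), t, dtype=torch.long)  # (2, 1)

# After the denoising loop the flat trajectories are unflattened again.
sample = rearrange(sample, "b o (t d) -> b o t d", t=horizon)  # (2, 2, 16, 7)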
@@ -305,12 +313,8 @@ class OctoModel(nn.Module):
         # run sampling
         sample = self.conditional_sample(batch_size, readout_embeds)
-        # `horizon` steps worth of actions (from the first observation).
-        actions = sample[..., : self.config.output_shapes["action"][0]]
-        # Extract `n_action_steps` steps worth of actions (from the current observation).
-        start = n_obs_steps - 1
-        end = start + self.config.n_action_steps
-        actions = actions[:, start:end]
+        # `horizon` steps worth of actions (from the last observation).
+        actions = sample[:, -1, : self.config.output_shapes["action"][0]]
         return actions
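Illustrative only, with toy sizes: the sample returned by conditional_sample now has shape (batch, n_obs_steps, horizon, action_dim), and indexing -1 on the observation-step axis keeps the trajectory predicted from the newest observation.

import torch

sample = torch.randn(2, 2, 16, 7)  # (batch, n_obs_steps, horizon, action_dim), toy sizes
latest_chunk = sample[:, -1]       # (2, 16, 7): the horizon predicted from the last obs step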
@@ -327,7 +331,7 @@ class OctoModel(nn.Module):
         # Input validation.
         assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
-        horizon = batch["action"].shape[1]
+        horizon = batch["action"].shape[1] - self.config.n_obs_steps + 1
         assert horizon == self.config.horizon
         assert n_obs_steps == self.config.n_obs_steps
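A quick worked check of the new horizon arithmetic, using assumed values n_obs_steps=2 and horizon=16: the dataloader now serves n_obs_steps + horizon - 1 action frames per sample, so the horizon is recovered from the action length by subtracting n_obs_steps - 1.

n_obs_steps, horizon = 2, 16                    # assumed values, not the config's
action_len = n_obs_steps + horizon - 1          # frames served per sample -> 17
assert action_len - n_obs_steps + 1 == horizon  # the relation asserted above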
@@ -339,30 +343,36 @@ class OctoModel(nn.Module):
         readout_embeds = self.transformer(batch["observation.state"], img_features)
         trajectory = batch["action"]
+        trajectory_per_obs_step = torch.zeros(
+            (batch_size, n_obs_steps, horizon, self.config.output_shapes["action"][0]),
+            device=trajectory.device,
+        )
+        for i in range(n_obs_steps):
+            trajectory_per_obs_step[:, i, :horizon] = trajectory[:, i : i + horizon]
+        trajectory_per_obs_step = rearrange(trajectory_per_obs_step, "b o t d -> b o (t d)")
         # Forward diffusion.
         # Sample noise to add to the trajectory.
-        eps = torch.randn(trajectory.shape, device=trajectory.device)
+        eps = torch.randn(trajectory_per_obs_step.shape, device=trajectory_per_obs_step.device)
         # Sample a random noising timestep for each item in the batch.
         timesteps = torch.randint(
             low=0,
             high=self.noise_scheduler.config.num_train_timesteps,
-            size=(trajectory.shape[0], 1),
-            device=trajectory.device,
+            size=(trajectory_per_obs_step.shape[0], 1),
+            device=trajectory_per_obs_step.device,
         ).long()
         # Add noise to the clean trajectories according to the noise magnitude at each timestep.
-        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
+        noisy_trajectories = self.noise_scheduler.add_noise(trajectory_per_obs_step, eps, timesteps)
         # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
-        pred = self.action_head(readout_embeds, timesteps, rearrange(noisy_trajectory, "b t d -> b (t d)"))
-        pred = rearrange(pred, "b (t d) -> b t d", t=horizon)
+        pred = self.action_head(readout_embeds, timesteps, noisy_trajectories)
         # Compute the loss.
         # The target is either the original trajectory, or the noise.
         if self.config.prediction_type == "epsilon":
             target = eps
         elif self.config.prediction_type == "sample":
-            target = batch["action"]
+            target = trajectory_per_obs_step
         else:
             raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
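A minimal sketch (toy sizes, standalone tensors in place of the batch) of the per-observation-step windows built above: observation step i is supervised with the horizon-long chunk of actions starting at frame i, and each chunk is flattened before being handed to the diffusion target and action head.

import torch
from einops import rearrange

batch_size, n_obs_steps, horizon, action_dim = 2, 2, 4, 3                     # toy sizes
trajectory = torch.randn(batch_size, n_obs_steps + horizon - 1, action_dim)   # stand-in for batch["action"]

windows = torch.zeros(batch_size, n_obs_steps, horizon, action_dim)
for i in range(n_obs_steps):
    windows[:, i] = trajectory[:, i : i + horizon]   # chunk starting at frame i

flat = rearrange(windows, "b o t d -> b o (t d)")    # (2, 2, 12): one flat trajectory per obs step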
@@ -371,7 +381,15 @@ class OctoModel(nn.Module):
         # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
         if self.config.do_mask_loss_for_padding and "action_is_pad" in batch:
             in_episode_bound = ~batch["action_is_pad"]
-            loss = loss * in_episode_bound.unsqueeze(-1)
+            in_episode_bound_per_step = torch.zeros(
+                (batch_size, n_obs_steps, horizon),
+                device=trajectory.device,
+            )
+            for i in range(n_obs_steps):
+                in_episode_bound_per_step[:, i, :horizon] = in_episode_bound[:, i : i + horizon]
+            loss = rearrange(loss, "b o (t d) -> b o t d", t=horizon)
+            loss = loss * in_episode_bound_per_step.unsqueeze(-1)
         return loss.mean()
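The padding mask gets the same sliding-window treatment; a short sketch with toy sizes (the .float() cast is only for the illustration, and the random tensor stands in for the element-wise loss):

import torch
from einops import rearrange

batch_size, n_obs_steps, horizon, action_dim = 2, 2, 4, 3                      # toy sizes
in_episode_bound = torch.ones(batch_size, n_obs_steps + horizon - 1, dtype=torch.bool)

mask = torch.zeros(batch_size, n_obs_steps, horizon)
for i in range(n_obs_steps):
    mask[:, i] = in_episode_bound[:, i : i + horizon].float()   # window matching obs step i

loss = torch.randn(batch_size, n_obs_steps, horizon * action_dim)   # flattened per-step loss
loss = rearrange(loss, "b o (t d) -> b o t d", t=horizon)
loss = loss * mask.unsqueeze(-1)                                     # zero out padded frames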
@@ -743,14 +761,15 @@ class OctoDiffusionActionHead(nn.Module):
         Args:
             readout_embeds: torch.Tensor, shape [batch_size, n_obs_steps, n_readouts_per_step, embed_dim]
             time: torch.Tensor, shape [batch_size, 1]
-            actions: torch.Tensor, shape [batch_size, pred_horizon * action_dim]
+            actions: torch.Tensor, shape [batch_size, n_obs_steps, pred_horizon * action_dim]
         Returns:
-            eps_pred: torch.Tensor, shape [batch_size, pred_horizon * action_dim]
+            eps_pred: torch.Tensor, shape [batch_size, n_obs_steps, pred_horizon * action_dim]
         """
         # we use the mean of all readout tokens for now but there is room for experimentation.
-        mean_readouts_embed = readout_embeds.mean(dim=(1, 2))
+        mean_readouts_embed = readout_embeds.mean(dim=-2)
         time_feature = self.fourier_feature_embedder(time)
         time_cond = self.time_feature_encoder(time_feature)
+        time_cond = repeat(time_cond, "b f -> b t f", t=readout_embeds.size(1))
         x = torch.cat([time_cond, mean_readouts_embed, actions], dim=-1)
         eps_pred = self.net(x)
         return eps_pred
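A shape sketch of the per-step conditioning (toy sizes; the random tensors stand in for the fourier/time encoders and the flattened noisy actions): the readout mean now keeps the observation-step axis, so the time embedding is repeated once per step before concatenation.

import torch
from einops import repeat

batch_size, n_obs_steps, n_readouts, embed_dim = 2, 2, 1, 8      # toy sizes
readout_embeds = torch.randn(batch_size, n_obs_steps, n_readouts, embed_dim)

mean_readouts_embed = readout_embeds.mean(dim=-2)                 # (2, 2, 8): obs-step axis preserved
time_cond = torch.randn(batch_size, 16)                           # stand-in for the encoded timestep
time_cond = repeat(time_cond, "b f -> b t f", t=readout_embeds.size(1))   # (2, 2, 16)

actions = torch.randn(batch_size, n_obs_steps, 12)                # stand-in for flattened noisy actions
x = torch.cat([time_cond, mean_readouts_embed, actions], dim=-1)  # (2, 2, 36) fed to the MLP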

View File

@@ -22,7 +22,7 @@ training:
   delta_timestamps:
     observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
     observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
-    action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
+    action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.horizon})]"
 eval:
   n_episodes: 50
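What the updated `action` entry expands to under assumed values fps=10, n_obs_steps=2, horizon=16 (not necessarily this config's numbers): enough frames to slice one horizon-long window per observation step.

fps, n_obs_steps, horizon = 10, 2, 16                        # assumed values
action = [i / fps for i in range(1 - n_obs_steps, horizon)]
# [-0.1, 0.0, 0.1, ..., 1.5]: n_obs_steps + horizon - 1 = 17 timestamps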
@@ -58,7 +58,7 @@ policy:
   use_group_norm: True
   # OctoTransformer.
   embed_dim: 384
-  n_readouts: 1
+  n_readouts_per_step: 1
   n_layers: 12
   n_heads: 6
   dim_feedforward: 1536