update traj generation + save tensors
parent 0d6fde3ba1
commit c98d9dbd08
@@ -264,24 +264,32 @@ class OctoModel(nn.Module):
         # Sample prior.
         sample = torch.randn(
-            size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
+            size=(
+                batch_size,
+                self.config.n_obs_steps,
+                self.config.horizon,
+                self.config.output_shapes["action"][0],
+            ),
             dtype=dtype,
             device=device,
             generator=generator,
         )
-        sample = rearrange(sample, "b t d -> b (t d)")
+        sample = rearrange(sample, "b o t d -> b o (t d)")

         self.noise_scheduler.set_timesteps(self.num_inference_steps)

         for t in self.noise_scheduler.timesteps:
             # Predict model output.
-            t_ = t.repeat((batch_size, 1)).to(device)
-            model_output = self.action_head(readout_embeds, t_, sample)
+            model_output = self.action_head(
+                readout_embeds,
+                torch.full((sample.shape[0], 1), t, dtype=torch.long, device=sample.device),
+                sample,
+            )

             # Compute previous image: x_t -> x_t-1
             sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample

-        sample = rearrange(sample, "b (t d) -> b t d", t=self.config.horizon)
+        sample = rearrange(sample, "b o (t d) -> b o t d", t=self.config.horizon)
         return sample

     def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:

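Note: the prior is now one flattened noise trajectory per observation step. A minimal shape check of the flatten/unflatten round trip used above (all sizes here are hypothetical):

    import torch
    from einops import rearrange

    sample = torch.randn(2, 2, 16, 7)                 # (b, o, t, d)
    flat = rearrange(sample, "b o t d -> b o (t d)")  # (2, 2, 112): one flat trajectory per obs step
    restored = rearrange(flat, "b o (t d) -> b o t d", t=16)
    assert torch.equal(sample, restored)
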
@@ -305,12 +313,8 @@ class OctoModel(nn.Module):
         # run sampling
         sample = self.conditional_sample(batch_size, readout_embeds)

-        # `horizon` steps worth of actions (from the first observation).
-        actions = sample[..., : self.config.output_shapes["action"][0]]
-        # Extract `n_action_steps` steps worth of actions (from the current observation).
-        start = n_obs_steps - 1
-        end = start + self.config.n_action_steps
-        actions = actions[:, start:end]
+        # `horizon` steps worth of actions (from the last observation).
+        actions = sample[:, -1, :, : self.config.output_shapes["action"][0]]

         return actions

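Note: only the trajectory conditioned on the most recent observation is executed; the slice keeps every horizon step and trims the last axis to the action dims. A sketch with hypothetical sizes:

    import torch

    b, o, t, d = 2, 2, 16, 7
    sample = torch.randn(b, o, t, d)   # output of conditional_sample: (b, o, t, d)
    actions = sample[:, -1, :, :d]     # last obs step, all horizon steps, action dims only
    assert actions.shape == (b, t, d)
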
@@ -327,7 +331,7 @@ class OctoModel(nn.Module):
         # Input validation.
         assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
-        horizon = batch["action"].shape[1]
+        horizon = batch["action"].shape[1] - self.config.n_obs_steps + 1
         assert horizon == self.config.horizon
         assert n_obs_steps == self.config.n_obs_steps

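Note: with a window of `horizon` actions starting at each of the `n_obs_steps` observations, the dataloader supplies `n_obs_steps + horizon - 1` action steps, and the new arithmetic recovers the configured horizon. A quick check with hypothetical values:

    n_obs_steps, horizon = 2, 16
    n_loaded = n_obs_steps + horizon - 1          # 17 action steps from the dataloader
    assert n_loaded - n_obs_steps + 1 == horizon
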
@@ -339,30 +343,36 @@ class OctoModel(nn.Module):
         readout_embeds = self.transformer(batch["observation.state"], img_features)

         trajectory = batch["action"]
+        trajectory_per_obs_step = torch.zeros(
+            (batch_size, n_obs_steps, horizon, self.config.output_shapes["action"][0]),
+            device=trajectory.device,
+        )
+        for i in range(n_obs_steps):
+            trajectory_per_obs_step[:, i, :horizon] = trajectory[:, i : i + horizon]
+        trajectory_per_obs_step = rearrange(trajectory_per_obs_step, "b o t d -> b o (t d)")

         # Forward diffusion.
         # Sample noise to add to the trajectory.
-        eps = torch.randn(trajectory.shape, device=trajectory.device)
+        eps = torch.randn(trajectory_per_obs_step.shape, device=trajectory_per_obs_step.device)
         # Sample a random noising timestep for each item in the batch.
         timesteps = torch.randint(
             low=0,
             high=self.noise_scheduler.config.num_train_timesteps,
-            size=(trajectory.shape[0], 1),
-            device=trajectory.device,
+            size=(trajectory_per_obs_step.shape[0], 1),
+            device=trajectory_per_obs_step.device,
         ).long()
         # Add noise to the clean trajectories according to the noise magnitude at each timestep.
-        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
+        noisy_trajectories = self.noise_scheduler.add_noise(trajectory_per_obs_step, eps, timesteps)

         # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
-        pred = self.action_head(readout_embeds, timesteps, rearrange(noisy_trajectory, "b t d -> b (t d)"))
-        pred = rearrange(pred, "b (t d) -> b t d", t=horizon)
+        pred = self.action_head(readout_embeds, timesteps, noisy_trajectories)

         # Compute the loss.
         # The target is either the original trajectory, or the noise.
         if self.config.prediction_type == "epsilon":
             target = eps
         elif self.config.prediction_type == "sample":
-            target = batch["action"]
+            target = trajectory_per_obs_step
         else:
             raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")

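Note: the per-obs-step targets are sliding windows over the loaded action sequence. A loop-free equivalent using Tensor.unfold, shown only as a sketch against hypothetical sizes (not the committed code):

    import torch
    from einops import rearrange

    b, n_obs_steps, horizon, d = 2, 2, 16, 7
    trajectory = torch.randn(b, n_obs_steps + horizon - 1, d)   # batch["action"]

    # One window of `horizon` actions per observation step, without the python loop.
    windows = trajectory.unfold(1, horizon, 1).permute(0, 1, 3, 2)  # (b, o, t, d)

    ref = torch.stack([trajectory[:, i : i + horizon] for i in range(n_obs_steps)], dim=1)
    assert torch.equal(windows, ref)

    flat = rearrange(windows, "b o t d -> b o (t d)")  # the shape the action head consumes
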
@@ -371,7 +381,15 @@ class OctoModel(nn.Module):
         # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
         if self.config.do_mask_loss_for_padding and "action_is_pad" in batch:
             in_episode_bound = ~batch["action_is_pad"]
-            loss = loss * in_episode_bound.unsqueeze(-1)
+            in_episode_bound_per_step = torch.zeros(
+                (batch_size, n_obs_steps, horizon),
+                device=trajectory.device,
+            )
+            for i in range(n_obs_steps):
+                in_episode_bound_per_step[:, i, :horizon] = in_episode_bound[:, i : i + horizon]
+
+            loss = rearrange(loss, "b o (t d) -> b o t d", t=horizon)
+            loss = loss * in_episode_bound_per_step.unsqueeze(-1)

         return loss.mean()

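Note: the flat (t d) loss is unflattened so each action step can be masked per observation step; the mask broadcasts over the action dims. A toy demonstration with hypothetical sizes:

    import torch
    from einops import rearrange

    b, o, t, d = 2, 2, 16, 7
    loss = torch.randn(b, o, t * d).abs()   # flat per-element loss
    mask = torch.ones(b, o, t)              # 1 = inside episode, 0 = padded
    mask[:, :, -3:] = 0                     # pretend the last 3 steps are padding

    loss = rearrange(loss, "b o (t d) -> b o t d", t=t)
    loss = loss * mask.unsqueeze(-1)        # (b, o, t, 1) broadcasts across action dims
    assert loss[:, :, -3:].sum() == 0
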
@@ -743,14 +761,15 @@ class OctoDiffusionActionHead(nn.Module):
         Args:
             readout_embeds: torch.Tensor, shape [batch_size, n_obs_steps, n_readouts_per_step, embed_dim]
             time: torch.Tensor, shape [batch_size, 1]
-            actions: torch.Tensor, shape [batch_size, pred_horizon * action_dim]
+            actions: torch.Tensor, shape [batch_size, n_obs_steps, pred_horizon * action_dim]
         Returns:
-            eps_pred: torch.Tensor, shape [batch_size, pred_horizon * action_dim]
+            eps_pred: torch.Tensor, shape [batch_size, n_obs_steps, pred_horizon * action_dim]
         """
         # we use the mean of all readout tokens for now but there is room for experimentation.
-        mean_readouts_embed = readout_embeds.mean(dim=(1, 2))
+        mean_readouts_embed = readout_embeds.mean(dim=-2)
         time_feature = self.fourier_feature_embedder(time)
         time_cond = self.time_feature_encoder(time_feature)
+        time_cond = repeat(time_cond, "b f -> b t f", t=readout_embeds.size(1))
         x = torch.cat([time_cond, mean_readouts_embed, actions], dim=-1)
         eps_pred = self.net(x)
         return eps_pred

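Note: averaging over dim=-2 collapses only the readout tokens, so the obs-step axis survives into the conditioning, and the time embedding is repeated to match it. A shape walk-through with hypothetical sizes (the time-feature width is an assumption):

    import torch
    from einops import repeat

    b, o, r, e = 2, 2, 1, 384                 # batch, obs steps, readouts per step, embed dim
    t_feat, td = 256, 16 * 7                  # assumed time-feature size; horizon * action_dim
    readout_embeds = torch.randn(b, o, r, e)
    actions = torch.randn(b, o, td)

    mean_readouts = readout_embeds.mean(dim=-2)         # (b, o, e): obs-step axis kept
    time_cond = torch.randn(b, t_feat)                  # stand-in for the time encoder output
    time_cond = repeat(time_cond, "b f -> b t f", t=o)  # (b, o, t_feat): shared across steps
    x = torch.cat([time_cond, mean_readouts, actions], dim=-1)
    assert x.shape == (b, o, t_feat + e + td)
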
@@ -22,7 +22,7 @@ training:
   delta_timestamps:
     observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
     observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
-    action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
+    action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.horizon})]"

 eval:
   n_episodes: 50
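Note: what the new `action` delta_timestamps expand to, for hypothetical fps=10, n_obs_steps=2, horizon=16:

    fps, n_obs_steps, horizon = 10, 2, 16
    action_ts = [i / fps for i in range(1 - n_obs_steps, horizon)]
    assert len(action_ts) == n_obs_steps + horizon - 1   # 17 timestamps, matching the horizon math
    # action_ts == [-0.1, 0.0, 0.1, ..., 1.5]
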
@@ -58,7 +58,7 @@ policy:
   use_group_norm: True
   # OctoTransformer.
   embed_dim: 384
-  n_readouts: 1
+  n_readouts_per_step: 1
   n_layers: 12
   n_heads: 6
   dim_feedforward: 1536
3 binary files not shown (the tensors saved by this commit).