import io
import logging
import zipfile
from copy import deepcopy
from math import ceil
from pathlib import Path

import einops
import requests
import torch
import tqdm


def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
    print(f"downloading from {url}")
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        total_size = int(response.headers.get("content-length", 0))
        progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)

        zip_file = io.BytesIO()
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                zip_file.write(chunk)
                progress_bar.update(len(chunk))

        progress_bar.close()

        zip_file.seek(0)

        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            zip_ref.extractall(destination_folder)
        return True
    else:
        return False
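
# Example usage (sketch; the URL and destination below are illustrative placeholders,
# not real assets):
# ok = download_and_extract_zip("https://example.com/dataset.zip", Path("data/raw"))
# if not ok:
#     logging.error("download failed")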


def euclidean_distance_matrix(mat0, mat1):
    # Compute the square of the distance matrix
    sq0 = torch.sum(mat0**2, dim=1, keepdim=True)
    sq1 = torch.sum(mat1**2, dim=1, keepdim=True)
    distance_sq = sq0 + sq1.transpose(0, 1) - 2 * mat0 @ mat1.transpose(0, 1)

    # Take the square root to get the Euclidean distance
    distance = torch.sqrt(torch.clamp(distance_sq, min=0))
    return distance
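
# Worked example (sketch, not executed): pairwise distances between two 2-D point sets.
# Entry (i, j) is the distance from row i of mat0 to row j of mat1.
# mat0 = torch.tensor([[0.0, 0.0], [3.0, 4.0]])
# mat1 = torch.tensor([[0.0, 0.0], [6.0, 8.0]])
# euclidean_distance_matrix(mat0, mat1)
# -> tensor([[ 0., 10.],
#            [ 5.,  5.]])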


def is_contiguously_true_or_false(bool_vector):
    assert bool_vector.ndim == 1
    assert bool_vector.dtype == torch.bool

    # Compare each element with its neighbor to find changes
    changes = bool_vector[1:] != bool_vector[:-1]

    # Count the number of changes
    num_changes = changes.sum().item()

    # If there's more than one change, the vector is not contiguous
    return num_changes <= 1


# examples = [
#     ([True, False, True, False, False, False], False),
#     ([True, True, True, False, False, False], True),
#     ([False, False, False, False, False, False], True),
# ]
# for bool_list, expected in examples:
#     result = is_contiguously_true_or_false(torch.tensor(bool_list, dtype=torch.bool))
#     assert result == expected


def load_data_with_delta_timestamps(
    data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode
):
    # get indices of the frames associated with the episode, and their timestamps
    ep_data_ids = data_ids_per_episode[episode]
    ep_timestamps = data_dict["timestamp"][ep_data_ids]

    # get timestamps used as query to retrieve data of previous/future frames
    delta_ts = delta_timestamps[key]
    query_ts = current_ts + torch.tensor(delta_ts)

    # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
    dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None])
    min_, argmin_ = dist.min(1)

    # get the indices of the data that are closest to the query timestamps
    data_ids = ep_data_ids[argmin_]
    # closest_ts = ep_timestamps[argmin_]

    # get the data
    data = data_dict[key][data_ids].clone()

    # TODO(rcadene): synchronize timestamps + interpolation if needed

    tol = 0.02
    is_pad = min_ > tol

    assert is_contiguously_true_or_false(is_pad), (
        "One or several timestamps unexpectedly violate the tolerance. "
        "This might be due to synchronization issues with timestamps during data collection."
    )

    return data, is_pad
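
# Example (sketch; the key and values below are illustrative): with
# delta_timestamps = {"action": [-0.1, 0.0, 0.1]} and current_ts = 2.0, the query timestamps
# are [1.9, 2.0, 2.1]. Each query is matched to the closest frame of the episode; queries
# farther than `tol` seconds from any frame are flagged in `is_pad`.
# data, is_pad = load_data_with_delta_timestamps(
#     data_dict, data_ids_per_episode, {"action": [-0.1, 0.0, 0.1]}, "action", 2.0, episode=0
# )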


def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
    stats_path = dataset.data_dir / "stats.pth"
    if stats_path.exists():
        return torch.load(stats_path)

    logging.info(f"compute_stats and save to {stats_path}")

    if max_num_samples is None:
        max_num_samples = len(dataset)
    else:
        raise NotImplementedError("We need to set shuffle=True, but this violates an assert for now.")

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        batch_size=batch_size,
        shuffle=False,
        # pin_memory=cfg.device != "cpu",
        drop_last=False,
    )

    # these einops patterns will be used to aggregate batches and compute statistics
    stats_patterns = {
        "action": "b c -> c",
        "observation.state": "b c -> c",
    }
    for key in dataset.image_keys:
        stats_patterns[key] = "b c h w -> c 1 1"
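
    # For instance, einops.reduce(x, "b c h w -> c 1 1", "mean") averages over the batch and
    # spatial dims, leaving one value per channel that broadcasts against (c, h, w) images.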

    # mean and std will be computed incrementally while max and min will track the running value.
    mean, std, max, min = {}, {}, {}, {}
    for key in stats_patterns:
        mean[key] = torch.tensor(0.0).float()
        std[key] = torch.tensor(0.0).float()
        max[key] = torch.tensor(-float("inf")).float()
        min[key] = torch.tensor(float("inf")).float()

    # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
    # surprises when rerunning the sampler.
    first_batch = None
    running_item_count = 0  # for online mean computation
    for i, batch in enumerate(
        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
    ):
        this_batch_size = len(batch["index"])
        running_item_count += this_batch_size
        if first_batch is None:
            first_batch = deepcopy(batch)
        for key, pattern in stats_patterns.items():
            batch[key] = batch[key].float()
            # Numerically stable update step for mean computation.
            batch_mean = einops.reduce(batch[key], pattern, "mean")
            # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
            # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
            # and x is the current batch mean. Some rearrangement is then required to avoid risking
            # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
            # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
            mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
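            # Numeric sanity check of the update rule (illustrative, not executed): streaming
            # batches [1.0, 2.0] then [3.0] gives
            #   step 1: N=2, x̄ = 0 + 2 * (1.5 - 0) / 2 = 1.5
            #   step 2: N=3, x̄ = 1.5 + 1 * (3.0 - 1.5) / 3 = 2.0
            # which matches the plain mean of [1.0, 2.0, 3.0].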
            max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
            min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))

        if i == ceil(max_num_samples / batch_size) - 1:
            break

    first_batch_ = None
    running_item_count = 0  # for online std computation
    for i, batch in enumerate(
        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
    ):
        this_batch_size = len(batch["index"])
        running_item_count += this_batch_size
        # Sanity check to make sure the batches are still in the same order as before.
        if first_batch_ is None:
            first_batch_ = deepcopy(batch)
            for key in stats_patterns:
                assert torch.equal(first_batch_[key], first_batch[key])
        for key, pattern in stats_patterns.items():
            batch[key] = batch[key].float()
            # Numerically stable update step for mean computation (where the mean is over squared
            # residuals). See notes in the mean computation loop above.
            batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
            std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count

        if i == ceil(max_num_samples / batch_size) - 1:
            break

    for key in stats_patterns:
        std[key] = torch.sqrt(std[key])

    stats = {}
    for key in stats_patterns:
        stats[key] = {
            "mean": mean[key],
            "std": std[key],
            "max": max[key],
            "min": min[key],
        }

    torch.save(stats, stats_path)
    return stats
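
# Example usage (sketch; assumes `dataset` exposes the `data_dir` and `image_keys` attributes
# used above, and that batches carry the "index", "action" and "observation.state" keys):
# stats = compute_or_load_stats(dataset)
# normalized = (batch["action"] - stats["action"]["mean"]) / stats["action"]["std"]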