Rename num_samples -> num_frames for consistency

parent 2650872b76
commit 79d114cc1f
@@ -37,7 +37,7 @@ print(dataset)
 print(dataset.hf_dataset)
 
 # And provides additional utilities for robotics and compatibility with Pytorch
-print(f"\naverage number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
+print(f"\naverage number of frames per episode: {dataset.num_frames / dataset.num_episodes:.3f}")
 print(f"frames per second used during data collection: {dataset.fps=}")
 print(f"keys to access images from cameras: {dataset.camera_keys=}\n")
 
@@ -180,13 +180,13 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
                 "n ... -> ...",
                 stat_key,
             )
-        total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
+        total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.stats)
         # Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
         # dataset, then divide by total_samples to get the overall "mean".
-        # NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
+        # NOTE: the brackets around (d.num_frames / total_samples) are needed to minimize the risk of
         # numerical overflow!
         stats[data_key]["mean"] = sum(
-            d.stats[data_key]["mean"] * (d.num_samples / total_samples)
+            d.stats[data_key]["mean"] * (d.num_frames / total_samples)
             for d in ls_datasets
             if data_key in d.stats
         )
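
The weighted-mean aggregation in this hunk can be sanity-checked in isolation. A minimal sketch with made-up frame counts and per-dataset means (plain dicts stand in for `ls_datasets`):

import torch

# Two hypothetical datasets: 100 frames with mean 2.0, 300 frames with mean 4.0.
datasets = [
    {"num_frames": 100, "mean": torch.tensor(2.0)},
    {"num_frames": 300, "mean": torch.tensor(4.0)},
]
total_frames = sum(d["num_frames"] for d in datasets)
# Weight each mean by its share of frames, mirroring the hunk above.
mean = sum(d["mean"] * (d["num_frames"] / total_frames) for d in datasets)
print(mean)  # tensor(3.5000) == (100 * 2.0 + 300 * 4.0) / 400

Weighting each mean by its fraction of the total, instead of summing raw per-dataset totals first, keeps intermediate values small; that is the overflow concern the NOTE comment refers to.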
@@ -195,12 +195,12 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
         # Given two sets of data where the statistics are known:
         # σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
         # where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
-        # NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
+        # NOTE: the brackets around (d.num_frames / total_samples) are needed to minimize the risk of
         # numerical overflow!
         stats[data_key]["std"] = torch.sqrt(
             sum(
                 (d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
-                * (d.num_samples / total_samples)
+                * (d.num_frames / total_samples)
                 for d in ls_datasets
                 if data_key in d.stats
             )
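
The σ_combined identity quoted in the comment can likewise be checked numerically. A minimal sketch comparing the pooled formula against a direct computation on the concatenated data (values are made up; population/biased variance is used, for which the identity is exact):

import torch

a = torch.tensor([1.0, 2.0, 3.0, 4.0])
b = torch.tensor([10.0, 12.0])
combined = torch.cat([a, b])
mu = combined.mean()

# Pooled variance: weight each dataset's (sigma_i^2 + (mu_i - mu)^2) by its frame share.
pooled_var = sum(
    (x.var(unbiased=False) + (x.mean() - mu) ** 2) * (x.numel() / combined.numel())
    for x in (a, b)
)
print(torch.allclose(torch.sqrt(pooled_var), combined.std(unbiased=False)))  # True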
@@ -357,8 +357,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return self.info["names"]
 
     @property
-    def num_samples(self) -> int:
-        """Number of samples/frames in selected episodes."""
+    def num_frames(self) -> int:
+        """Number of frames in selected episodes."""
         return len(self.hf_dataset) if self.hf_dataset is not None else self.total_frames
 
     @property
@@ -510,7 +510,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return item
 
     def __len__(self):
-        return self.num_samples
+        return self.num_frames
 
     def __getitem__(self, idx) -> dict:
         item = self.hf_dataset[idx]
@@ -544,7 +544,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             f"  Repository ID: '{self.repo_id}',\n"
             f"  Selected episodes: {self.episodes},\n"
             f"  Number of selected episodes: {self.num_episodes},\n"
-            f"  Number of selected samples: {self.num_samples},\n"
+            f"  Number of selected samples: {self.num_frames},\n"
             f"\n{json.dumps(self.info, indent=4)}\n"
         )
 
@@ -981,9 +981,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         return video_frame_keys
 
     @property
-    def num_samples(self) -> int:
+    def num_frames(self) -> int:
         """Number of samples/frames."""
-        return sum(d.num_samples for d in self._datasets)
+        return sum(d.num_frames for d in self._datasets)
 
     @property
     def num_episodes(self) -> int:
@@ -1000,7 +1000,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         return 1 / self.fps - 1e-4
 
     def __len__(self):
-        return self.num_samples
+        return self.num_frames
 
     def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
         if idx >= len(self):
@@ -1009,8 +1009,8 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         start_idx = 0
         dataset_idx = 0
         for dataset in self._datasets:
-            if idx >= start_idx + dataset.num_samples:
-                start_idx += dataset.num_samples
+            if idx >= start_idx + dataset.num_frames:
+                start_idx += dataset.num_frames
                 dataset_idx += 1
                 continue
             break
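
The loop in this hunk converts a global frame index into a dataset index plus a local offset by walking the per-dataset frame counts. A standalone sketch of the same logic (the `locate` helper and the sizes are hypothetical, for illustration only):

def locate(idx: int, sizes: list[int]) -> tuple[int, int]:
    """Map a global frame index to (dataset_idx, local_idx) across concatenated datasets."""
    start_idx = 0
    for dataset_idx, size in enumerate(sizes):
        if idx < start_idx + size:
            return dataset_idx, idx - start_idx
        start_idx += size
    raise IndexError(f"index {idx} out of range for {sum(sizes)} total frames")

print(locate(99, [100, 250]))   # (0, 99) -- last frame of the first dataset
print(locate(100, [100, 250]))  # (1, 0)  -- first frame of the second dataset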
@@ -1028,7 +1028,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         return (
             f"{self.__class__.__name__}(\n"
             f"  Repository IDs: '{self.repo_ids}',\n"
-            f"  Number of Samples: {self.num_samples},\n"
+            f"  Number of Samples: {self.num_frames},\n"
             f"  Number of Episodes: {self.num_episodes},\n"
             f"  Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
             f"  Recorded Frames per Second: {self.fps},\n"
@@ -187,7 +187,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
         assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
 
         # Shift the incoming indices if necessary.
-        if self.num_samples > 0:
+        if self.num_frames > 0:
             last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
             last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
             data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
@@ -227,11 +227,11 @@ class OnlineBuffer(torch.utils.data.Dataset):
         )
 
     @property
-    def num_samples(self) -> int:
+    def num_frames(self) -> int:
         return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
 
     def __len__(self):
-        return self.num_samples
+        return self.num_frames
 
     def _item_to_tensors(self, item: dict) -> dict:
         item_ = {}
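
Unlike the HF-backed datasets, `OnlineBuffer.num_frames` counts occupied slots in a preallocated buffer rather than taking a container length. A minimal sketch of that bookkeeping (the capacity and writes are made up):

import numpy as np

capacity = 8
occupancy_mask = np.zeros(capacity, dtype=bool)

# Writing three frames marks their slots as occupied.
occupancy_mask[:3] = True
print(np.count_nonzero(occupancy_mask))  # 3 -- occupied slots, not the full capacity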
@@ -343,7 +343,7 @@ def replay(
     robot.connect()
 
     log_say("Replaying episode", play_sounds, blocking=True)
-    for idx in range(dataset.num_samples):
+    for idx in range(dataset.num_frames):
         start_episode_t = time.perf_counter()
 
         action = actions[idx]["action"]
@@ -171,9 +171,9 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
     # A sample is an (observation,action) pair, where observation and action
     # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
     num_samples = (step + 1) * cfg.training.batch_size
-    avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
+    avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
     num_episodes = num_samples / avg_samples_per_ep
-    num_epochs = num_samples / dataset.num_samples
+    num_epochs = num_samples / dataset.num_frames
     log_items = [
         f"step:{format_big_number(step)}",
         # number of samples seen during training
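
The derived quantities in this hunk (and the identical block in `log_eval_info` below) are simple ratios over `num_frames`. A worked example with assumed numbers:

# Assumed values, for illustration only.
step = 999
batch_size = 64
num_frames = 50_000         # dataset.num_frames
total_episodes = 1_000      # dataset.num_episodes

num_samples = (step + 1) * batch_size              # 64000 samples seen so far
avg_samples_per_ep = num_frames / total_episodes   # 50.0 frames per episode
num_episodes = num_samples / avg_samples_per_ep    # 1280.0 episode-equivalents seen
num_epochs = num_samples / num_frames              # 1.28 passes over the dataset
print(num_samples, num_episodes, num_epochs)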
@@ -208,9 +208,9 @@ def log_eval_info(logger, info, step, cfg, dataset, is_online):
     # A sample is an (observation,action) pair, where observation and action
     # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
     num_samples = (step + 1) * cfg.training.batch_size
-    avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
+    avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
     num_episodes = num_samples / avg_samples_per_ep
-    num_epochs = num_samples / dataset.num_samples
+    num_epochs = num_samples / dataset.num_frames
     log_items = [
         f"step:{format_big_number(step)}",
         # number of samples seen during training
@@ -349,7 +349,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
    logging.info(f"{cfg.env.task=}")
    logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
    logging.info(f"{cfg.training.online_steps=}")
-   logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
+   logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
    logging.info(f"{offline_dataset.num_episodes=}")
    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
@@ -573,7 +573,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
                online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
                online_sampling_ratio=cfg.training.online_sampling_ratio,
            )
-           sampler.num_samples = len(concat_dataset)
+           sampler.num_frames = len(concat_dataset)
 
            update_online_buffer_s = time.perf_counter() - start_update_buffer_time
 
@@ -93,7 +93,7 @@ def run_server(
     def show_episode(dataset_namespace, dataset_name, episode_id):
         dataset_info = {
             "repo_id": dataset.repo_id,
-            "num_samples": dataset.num_samples,
+            "num_frames": dataset.num_frames,
             "num_episodes": dataset.num_episodes,
             "fps": dataset.fps,
         }
@@ -35,7 +35,7 @@
 
         <ul>
           <li>
-            Number of samples/frames: {{ dataset_info.num_samples }}
+            Number of samples/frames: {{ dataset_info.num_frames }}
          </li>
          <li>
            Number of episodes: {{ dataset_info.num_episodes }}