fix hf_dataset.set_transform(hf_transform_to_torch)
This commit is contained in:
parent
7c005c2aa1
commit
71715c3914
|
@ -736,7 +736,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
for key in self.meta.video_keys:
|
for key in self.meta.video_keys:
|
||||||
if query_indices is not None and key in query_indices:
|
if query_indices is not None and key in query_indices:
|
||||||
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
||||||
query_timestamps[key] = timestamps.tolist()
|
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||||
else:
|
else:
|
||||||
query_timestamps[key] = [current_ts]
|
query_timestamps[key] = [current_ts]
|
||||||
|
|
||||||
|
@ -744,7 +744,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
|
|
||||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||||
return {
|
return {
|
||||||
key: self.hf_dataset.select(q_idx)[key]
|
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
||||||
for key, q_idx in query_indices.items()
|
for key, q_idx in query_indices.items()
|
||||||
if key not in self.meta.video_keys
|
if key not in self.meta.video_keys
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,8 +56,8 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
|
||||||
def synced_timestamps_factory(hf_dataset_factory):
|
def synced_timestamps_factory(hf_dataset_factory):
|
||||||
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
hf_dataset = hf_dataset_factory(fps=fps)
|
hf_dataset = hf_dataset_factory(fps=fps)
|
||||||
timestamps = hf_dataset["timestamp"].numpy()
|
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
|
||||||
episode_indices = hf_dataset["episode_index"].numpy()
|
episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
|
||||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||||
return timestamps, episode_indices, episode_data_index
|
return timestamps, episode_indices, episode_data_index
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue