transform -> image_transforms

This commit is contained in:
Simon Alibert 2024-06-06 16:53:37 +00:00
parent c45dd8f848
commit faacb36271
1 changed files with 7 additions and 7 deletions

View File

@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.version = version
self.root = root
self.split = split
self.transform = image_transforms
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
# TODO(rcadene, aliberts): implement faster transfer
@ -151,9 +151,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.tolerance_s,
)
if self.transform is not None:
if self.image_transforms is not None:
for cam in self.camera_keys:
item[cam] = self.transform(item[cam])
item[cam] = self.image_transforms(item[cam])
return item
@ -169,7 +169,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.transform},\n"
f" Transformations: {self.image_transforms},\n"
f")"
)
@ -203,7 +203,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.version = version
obj.root = root
obj.split = split
obj.transform = transform
obj.image_transforms = transform
obj.delta_timestamps = delta_timestamps
obj.hf_dataset = hf_dataset
obj.episode_data_index = episode_data_index
@ -275,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.version = version
self.root = root
self.split = split
self.transform = image_transforms
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets)
@ -396,6 +396,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.transform},\n"
f" Transformations: {self.image_transforms},\n"
f")"
)