diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index f41b8f7a..d680b987 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.version = version self.root = root self.split = split - self.transform = image_transforms + self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps # load data from hub or locally when root is provided # TODO(rcadene, aliberts): implement faster transfer @@ -151,9 +151,9 @@ class LeRobotDataset(torch.utils.data.Dataset): self.tolerance_s, ) - if self.transform is not None: + if self.image_transforms is not None: for cam in self.camera_keys: - item[cam] = self.transform(item[cam]) + item[cam] = self.image_transforms(item[cam]) return item @@ -169,7 +169,7 @@ class LeRobotDataset(torch.utils.data.Dataset): f" Recorded Frames per Second: {self.fps},\n" f" Camera Keys: {self.camera_keys},\n" f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" - f" Transformations: {self.transform},\n" + f" Transformations: {self.image_transforms},\n" f")" ) @@ -203,7 +203,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.version = version obj.root = root obj.split = split - obj.transform = transform + obj.image_transforms = transform obj.delta_timestamps = delta_timestamps obj.hf_dataset = hf_dataset obj.episode_data_index = episode_data_index @@ -275,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): self.version = version self.root = root self.split = split - self.transform = image_transforms + self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps self.stats = aggregate_stats(self._datasets) @@ -396,6 +396,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): f" Recorded Frames per Second: {self.fps},\n" f" Camera Keys: {self.camera_keys},\n" f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" - f" Transformations: {self.transform},\n" + f" Transformations: {self.image_transforms},\n" f")" )