Use HWC for images
This commit is contained in:
parent
1f13bda25b
commit
6203641710
|
@ -13,7 +13,6 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
@ -138,6 +137,11 @@ class LeRobotDatasetMetadata:
|
||||||
"""Formattable string for the video files."""
|
"""Formattable string for the video files."""
|
||||||
return self.info["video_path"]
|
return self.info["video_path"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def robot_type(self) -> str | None:
|
||||||
|
"""Robot type used in recording this dataset."""
|
||||||
|
return self.info["robot_type"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def fps(self) -> int:
|
def fps(self) -> int:
|
||||||
"""Frames per second used during data collection."""
|
"""Frames per second used during data collection."""
|
||||||
|
@ -258,10 +262,14 @@ class LeRobotDatasetMetadata:
|
||||||
write_json(self.info, self.root / INFO_PATH)
|
write_json(self.info, self.root / INFO_PATH)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
feature_keys = list(self.features)
|
||||||
return (
|
return (
|
||||||
f"{self.__class__.__name__}\n"
|
f"{self.__class__.__name__}({{\n"
|
||||||
f"Repository ID: '{self.repo_id}',\n"
|
f" Repository ID: '{self.repo_id}',\n"
|
||||||
f"{json.dumps(self.meta.info, indent=4)}\n"
|
f" Total episodes: '{self.total_episodes}',\n"
|
||||||
|
f" Total frames: '{self.total_frames}',\n"
|
||||||
|
f" Features: '{feature_keys}',\n"
|
||||||
|
"})',\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -657,13 +665,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
return item
|
return item
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
feature_keys = list(self.features)
|
||||||
return (
|
return (
|
||||||
f"{self.__class__.__name__}\n"
|
f"{self.__class__.__name__}({{\n"
|
||||||
f" Repository ID: '{self.repo_id}',\n"
|
f" Repository ID: '{self.repo_id}',\n"
|
||||||
f" Selected episodes: {self.episodes},\n"
|
f" Number of selected episodes: '{self.num_episodes}',\n"
|
||||||
f" Number of selected episodes: {self.num_episodes},\n"
|
f" Number of selected samples: '{self.num_frames}',\n"
|
||||||
f" Number of selected samples: {self.num_frames},\n"
|
f" Features: '{feature_keys}',\n"
|
||||||
f"\n{json.dumps(self.meta.info, indent=4)}\n"
|
"})',\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||||
|
|
|
@ -468,6 +468,7 @@ def create_lerobot_dataset_card(
|
||||||
text: str | None = None,
|
text: str | None = None,
|
||||||
info: dict | None = None,
|
info: dict | None = None,
|
||||||
license: str | None = None,
|
license: str | None = None,
|
||||||
|
url: str | None = None,
|
||||||
citation: str | None = None,
|
citation: str | None = None,
|
||||||
arxiv: str | None = None,
|
arxiv: str | None = None,
|
||||||
) -> DatasetCard:
|
) -> DatasetCard:
|
||||||
|
@ -488,6 +489,8 @@ def create_lerobot_dataset_card(
|
||||||
card.data.license = license
|
card.data.license = license
|
||||||
if tags:
|
if tags:
|
||||||
card.data.tags += tags
|
card.data.tags += tags
|
||||||
|
if url:
|
||||||
|
card.text += f"## Homepage:\n{url}\n"
|
||||||
if text:
|
if text:
|
||||||
card.text += f"{text}\n"
|
card.text += f"{text}\n"
|
||||||
if info:
|
if info:
|
||||||
|
|
|
@ -222,12 +222,12 @@ def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = N
|
||||||
dtype = "image"
|
dtype = "image"
|
||||||
image = dataset[0][key] # Assuming first row
|
image = dataset[0][key] # Assuming first row
|
||||||
channels = get_image_pixel_channels(image)
|
channels = get_image_pixel_channels(image)
|
||||||
shape = (image.width, image.height, channels)
|
shape = (image.height, image.width, channels)
|
||||||
names = ["width", "height", "channel"]
|
names = ["height", "width", "channel"]
|
||||||
elif ft._type == "VideoFrame":
|
elif ft._type == "VideoFrame":
|
||||||
dtype = "video"
|
dtype = "video"
|
||||||
shape = None # Add shape later
|
shape = None # Add shape later
|
||||||
names = ["width", "height", "channel"]
|
names = ["height", "width", "channel"]
|
||||||
|
|
||||||
features[key] = {
|
features[key] = {
|
||||||
"dtype": dtype,
|
"dtype": dtype,
|
||||||
|
@ -437,8 +437,9 @@ def convert_dataset(
|
||||||
tasks_col: Path | None = None,
|
tasks_col: Path | None = None,
|
||||||
robot_config: dict | None = None,
|
robot_config: dict | None = None,
|
||||||
license: str | None = None,
|
license: str | None = None,
|
||||||
citation: str | None = None,
|
url: str | None = None,
|
||||||
arxiv: str | None = None,
|
arxiv: str | None = None,
|
||||||
|
citation: str | None = None,
|
||||||
test_branch: str | None = None,
|
test_branch: str | None = None,
|
||||||
):
|
):
|
||||||
v1 = get_hub_safe_version(repo_id, V16)
|
v1 = get_hub_safe_version(repo_id, V16)
|
||||||
|
@ -518,8 +519,8 @@ def convert_dataset(
|
||||||
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
|
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
|
||||||
for key in video_keys:
|
for key in video_keys:
|
||||||
features[key]["shape"] = (
|
features[key]["shape"] = (
|
||||||
videos_info[key].pop("video.width"),
|
|
||||||
videos_info[key].pop("video.height"),
|
videos_info[key].pop("video.height"),
|
||||||
|
videos_info[key].pop("video.width"),
|
||||||
videos_info[key].pop("video.channels"),
|
videos_info[key].pop("video.channels"),
|
||||||
)
|
)
|
||||||
features[key]["video_info"] = videos_info[key]
|
features[key]["video_info"] = videos_info[key]
|
||||||
|
@ -566,7 +567,7 @@ def convert_dataset(
|
||||||
write_json(metadata_v2_0, v20_dir / INFO_PATH)
|
write_json(metadata_v2_0, v20_dir / INFO_PATH)
|
||||||
convert_stats_to_json(v1x_dir, v20_dir)
|
convert_stats_to_json(v1x_dir, v20_dir)
|
||||||
card = create_lerobot_dataset_card(
|
card = create_lerobot_dataset_card(
|
||||||
tags=repo_tags, info=metadata_v2_0, license=license, citation=citation, arxiv=arxiv
|
tags=repo_tags, info=metadata_v2_0, license=license, url=url, citation=citation, arxiv=arxiv
|
||||||
)
|
)
|
||||||
|
|
||||||
with contextlib.suppress(EntryNotFoundError):
|
with contextlib.suppress(EntryNotFoundError):
|
||||||
|
|
|
@ -279,8 +279,8 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||||
|
|
||||||
video_info = {
|
video_info = {
|
||||||
"video.fps": fps,
|
"video.fps": fps,
|
||||||
"video.width": video_stream_info["width"],
|
|
||||||
"video.height": video_stream_info["height"],
|
"video.height": video_stream_info["height"],
|
||||||
|
"video.width": video_stream_info["width"],
|
||||||
"video.channels": pixel_channels,
|
"video.channels": pixel_channels,
|
||||||
"video.codec": video_stream_info["codec_name"],
|
"video.codec": video_stream_info["codec_name"],
|
||||||
"video.pix_fmt": video_stream_info["pix_fmt"],
|
"video.pix_fmt": video_stream_info["pix_fmt"],
|
||||||
|
|
|
@ -235,8 +235,8 @@ class ManipulatorRobot:
|
||||||
for cam_key, cam in self.cameras.items():
|
for cam_key, cam in self.cameras.items():
|
||||||
key = f"observation.images.{cam_key}"
|
key = f"observation.images.{cam_key}"
|
||||||
cam_ft[key] = {
|
cam_ft[key] = {
|
||||||
"shape": (cam.width, cam.height, cam.channels),
|
"shape": (cam.height, cam.width, cam.channels),
|
||||||
"names": ["width", "height", "channels"],
|
"names": ["height", "width", "channels"],
|
||||||
"info": None,
|
"info": None,
|
||||||
}
|
}
|
||||||
return cam_ft
|
return cam_ft
|
||||||
|
|
|
@ -27,15 +27,6 @@ from tests.fixtures.defaults import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def make_dummy_shapes(keys: list[str] | None = None, camera_keys: list[str] | None = None) -> dict:
|
|
||||||
shapes = {}
|
|
||||||
if keys:
|
|
||||||
shapes.update({key: 10 for key in keys})
|
|
||||||
if camera_keys:
|
|
||||||
shapes.update({key: {"width": 100, "height": 70, "channels": 3} for key in camera_keys})
|
|
||||||
return shapes
|
|
||||||
|
|
||||||
|
|
||||||
def get_task_index(task_dicts: dict, task: str) -> int:
|
def get_task_index(task_dicts: dict, task: str) -> int:
|
||||||
tasks = {d["task_index"]: d["task"] for d in task_dicts}
|
tasks = {d["task_index"]: d["task"] for d in task_dicts}
|
||||||
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
|
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
|
||||||
|
@ -44,7 +35,7 @@ def get_task_index(task_dicts: dict, task: str) -> int:
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def img_tensor_factory():
|
def img_tensor_factory():
|
||||||
def _create_img_tensor(width=100, height=100, channels=3, dtype=torch.float32) -> torch.Tensor:
|
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
|
||||||
return torch.rand((channels, height, width), dtype=dtype)
|
return torch.rand((channels, height, width), dtype=dtype)
|
||||||
|
|
||||||
return _create_img_tensor
|
return _create_img_tensor
|
||||||
|
@ -52,7 +43,7 @@ def img_tensor_factory():
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def img_array_factory():
|
def img_array_factory():
|
||||||
def _create_img_array(width=100, height=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
||||||
if np.issubdtype(dtype, np.unsignedinteger):
|
if np.issubdtype(dtype, np.unsignedinteger):
|
||||||
# Int array in [0, 255] range
|
# Int array in [0, 255] range
|
||||||
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
|
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
|
||||||
|
@ -68,8 +59,8 @@ def img_array_factory():
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def img_factory(img_array_factory):
|
def img_factory(img_array_factory):
|
||||||
def _create_img(width=100, height=100) -> PIL.Image.Image:
|
def _create_img(height=100, width=100) -> PIL.Image.Image:
|
||||||
img_array = img_array_factory(width=width, height=height)
|
img_array = img_array_factory(height=height, width=width)
|
||||||
return PIL.Image.fromarray(img_array)
|
return PIL.Image.fromarray(img_array)
|
||||||
|
|
||||||
return _create_img
|
return _create_img
|
||||||
|
@ -259,7 +250,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
||||||
for key, ft in features.items():
|
for key, ft in features.items():
|
||||||
if ft["dtype"] == "image":
|
if ft["dtype"] == "image":
|
||||||
robot_cols[key] = [
|
robot_cols[key] = [
|
||||||
img_array_factory(width=ft["shapes"][0], height=ft["shapes"][1])
|
img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
|
||||||
for _ in range(len(index_col))
|
for _ in range(len(index_col))
|
||||||
]
|
]
|
||||||
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
||||||
|
|
|
@ -16,8 +16,8 @@ DUMMY_MOTOR_FEATURES = {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
DUMMY_CAMERA_FEATURES = {
|
DUMMY_CAMERA_FEATURES = {
|
||||||
"laptop": {"shape": (640, 480, 3), "names": ["width", "height", "channels"], "info": None},
|
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||||
"phone": {"shape": (640, 480, 3), "names": ["width", "height", "channels"], "info": None},
|
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||||
}
|
}
|
||||||
DEFAULT_FPS = 30
|
DEFAULT_FPS = 30
|
||||||
DUMMY_VIDEO_INFO = {
|
DUMMY_VIDEO_INFO = {
|
||||||
|
|
|
@ -265,7 +265,7 @@ def test_wait_until_done(tmp_path, img_array_factory):
|
||||||
writer = AsyncImageWriter(num_processes=0, num_threads=4)
|
writer = AsyncImageWriter(num_processes=0, num_threads=4)
|
||||||
try:
|
try:
|
||||||
num_images = 100
|
num_images = 100
|
||||||
image_arrays = [img_array_factory(width=500, height=500) for _ in range(num_images)]
|
image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)]
|
||||||
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
|
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
|
||||||
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
|
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
Loading…
Reference in New Issue