[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2025-03-24 13:41:27 +00:00 committed by Michel Aractingi
parent 2abbd60a0d
commit 0ea27704f6
123 changed files with 1161 additions and 3425 deletions
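
Every hunk in this commit is a mechanical reformatting: statements that were wrapped across several lines are collapsed onto one longer line, with no behavioral change. That is the signature of a line-length-aware formatter whose configured maximum was raised (the collapsed lines fit within roughly 110 characters). As an illustrative sketch only — the repository's actual hook list and formatter settings are not part of this commit view — a pre-commit setup producing this kind of diff could look like:

    # Hypothetical .pre-commit-config.yaml excerpt (assumed, not taken from this repository)
    repos:
      - repo: https://github.com/astral-sh/ruff-pre-commit
        rev: v0.9.10  # assumed pin; pre-commit.ci bumps it via autoupdate
        hooks:
          - id: ruff         # lint, optionally autofix with args: [--fix]
          - id: ruff-format  # reformat, honoring the line-length set in pyproject.toml

    # Hypothetical pyproject.toml excerpt
    # [tool.ruff]
    # line-length = 110

pre-commit.ci runs the configured hooks against each pull request and, when a hook rewrites files, pushes an automatic commit like this one.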

View File

@@ -85,9 +85,7 @@ def get_directory_size(directory: Path) -> int:
     return total_size


-def load_original_frames(
-    imgs_dir: Path, timestamps: list[float], fps: int
-) -> torch.Tensor:
+def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
     frames = []
     for ts in timestamps:
         idx = int(ts * fps)
@@ -129,9 +127,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
     hf_dataset = dataset.hf_dataset.with_format(None)

     # We only save images from the first camera
-    img_keys = [
-        key for key in hf_dataset.features if key.startswith("observation.image")
-    ]
+    img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
     imgs_dataset = hf_dataset.select_columns(img_keys[0])

     for i, item in enumerate(
@@ -148,9 +144,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
             break


-def sample_timestamps(
-    timestamps_mode: str, ep_num_images: int, fps: int
-) -> list[float]:
+def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
     # Start at 5 to allow for 2_frames_4_space and 6_frames
     idx = random.randint(5, ep_num_images - 1)
     match timestamps_mode:
@@ -175,9 +169,7 @@ def decode_video_frames(
     backend: str,
 ) -> torch.Tensor:
     if backend in ["pyav", "video_reader"]:
-        return decode_video_frames_torchvision(
-            video_path, timestamps, tolerance_s, backend
-        )
+        return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
     else:
         raise NotImplementedError(backend)
@@ -204,9 +196,7 @@ def benchmark_decoding(
         }

         with time_benchmark:
-            frames = decode_video_frames(
-                video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend
-            )
+            frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
         result["load_time_video_ms"] = time_benchmark.result_ms / num_frames

         with time_benchmark:
@@ -215,18 +205,12 @@ def benchmark_decoding(
         frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
         for i in range(num_frames):
-            result["mse_values"].append(
-                mean_squared_error(original_frames_np[i], frames_np[i])
-            )
+            result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
             result["psnr_values"].append(
-                peak_signal_noise_ratio(
-                    original_frames_np[i], frames_np[i], data_range=1.0
-                )
+                peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
             )
             result["ssim_values"].append(
-                structural_similarity(
-                    original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0
-                )
+                structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
             )

         if save_frames and sample == 0:
@@ -246,9 +230,7 @@ def benchmark_decoding(
     # As these samples are independent, we run them in parallel threads to speed up the benchmark.
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
         futures = [executor.submit(process_sample, i) for i in range(num_samples)]
-        for future in tqdm(
-            as_completed(futures), total=num_samples, desc="samples", leave=False
-        ):
+        for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
             result = future.result()
             load_times_video_ms.append(result["load_time_video_ms"])
             load_times_images_ms.append(result["load_time_images_ms"])
@@ -312,9 +294,7 @@ def benchmark_encoding_decoding(
         desc="decodings (timestamps_modes)",
         leave=False,
     ):
-        for backend in tqdm(
-            decoding_cfg["backends"], desc="decodings (backends)", leave=False
-        ):
+        for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
             benchmark_row = benchmark_decoding(
                 imgs_dir,
                 video_path,
@@ -392,23 +372,14 @@ def main(
     imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
     # We only use the first episode
     save_first_episode(imgs_dir, dataset)
-    for key, values in tqdm(
-        encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False
-    ):
+    for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
         for value in tqdm(values, desc=f"encodings ({key})", leave=False):
             encoding_cfg = BASE_ENCODING.copy()
             encoding_cfg["vcodec"] = video_codec
             encoding_cfg["pix_fmt"] = pixel_format
             encoding_cfg[key] = value
-            args_path = Path(
-                "_".join(str(value) for value in encoding_cfg.values())
-            )
-            video_path = (
-                output_dir
-                / "videos"
-                / args_path
-                / f"{repo_id.replace('/', '_')}.mp4"
-            )
+            args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
+            video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
             benchmark_table += benchmark_encoding_decoding(
                 dataset,
                 video_path,
@@ -434,9 +405,7 @@ def main(
     # Concatenate all results
     df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
     concatenated_df = pd.concat(df_list, ignore_index=True)
-    concatenated_path = (
-        output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
-    )
+    concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
     concatenated_df.to_csv(concatenated_path, header=True, index=False)

View File

@@ -43,10 +43,7 @@ pprint(lerobot.available_datasets)

 # You can also browse through the datasets created/ported by the community on the hub using the hub api:
 hub_api = HfApi()
-repo_ids = [
-    info.id
-    for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])
-]
+repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
 pprint(repo_ids)

 # Or simply explore them in your web browser directly at:
@@ -61,9 +58,7 @@ ds_meta = LeRobotDatasetMetadata(repo_id)
 # structure of the dataset without downloading the actual data yet (only metadata files — which are
 # lightweight).
 print(f"Total number of episodes: {ds_meta.total_episodes}")
-print(
-    f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}"
-)
+print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
 print(f"Frames per second used during data collection: {ds_meta.fps}")
 print(f"Robot type: {ds_meta.robot_type}")
 print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")

View File

@@ -51,18 +51,12 @@ def main():
     # - dataset stats: for normalization and denormalization of input/outputs
     dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
     features = dataset_to_policy_features(dataset_metadata.features)
-    output_features = {
-        key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
-    }
-    input_features = {
-        key: ft for key, ft in features.items() if key not in output_features
-    }
+    output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
+    input_features = {key: ft for key, ft in features.items() if key not in output_features}

     # Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
     # we'll just use the defaults and so no arguments other than input/output features need to be passed.
-    cfg = DiffusionConfig(
-        input_features=input_features, output_features=output_features
-    )
+    cfg = DiffusionConfig(input_features=input_features, output_features=output_features)

     # We can now instantiate our policy with this config and the dataset stats.
     policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
@@ -72,12 +66,8 @@ def main():
     # Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames
     # which can differ for inputs, outputs and rewards (if there are some).
     delta_timestamps = {
-        "observation.image": [
-            i / dataset_metadata.fps for i in cfg.observation_delta_indices
-        ],
-        "observation.state": [
-            i / dataset_metadata.fps for i in cfg.observation_delta_indices
-        ],
+        "observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
+        "observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
         "action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
     }
@@ -129,10 +119,7 @@ def main():
     done = False
     while not done:
         for batch in dataloader:
-            batch = {
-                k: (v.to(device) if isinstance(v, torch.Tensor) else v)
-                for k, v in batch.items()
-            }
+            batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
             loss, _ = policy.forward(batch)
             loss.backward()
             optimizer.step()

View File

@@ -48,14 +48,10 @@ transforms = v2.Compose(
 )

 # Create another LeRobotDataset with the defined transformations
-transformed_dataset = LeRobotDataset(
-    dataset_repo_id, episodes=[0], image_transforms=transforms
-)
+transformed_dataset = LeRobotDataset(dataset_repo_id, episodes=[0], image_transforms=transforms)

 # Get a frame from the transformed dataset
-transformed_frame = transformed_dataset[first_idx][
-    transformed_dataset.meta.camera_keys[0]
-]
+transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]

 # Create a directory to store output images
 output_dir = Path("outputs/image_transforms")

View File

@@ -90,9 +90,7 @@ def main():
     train_dataset = LeRobotDataset(
         "lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps
     )
-    val_dataset = LeRobotDataset(
-        "lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps
-    )
+    val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)

     print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
     print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")

View File

@@ -164,11 +164,7 @@ available_real_world_datasets = [
 ]

 available_datasets = sorted(
-    set(
-        itertools.chain(
-            *available_datasets_per_env.values(), available_real_world_datasets
-        )
-    )
+    set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
 )

 # lists all available policies from `lerobot/common/policies`
@@ -209,13 +205,9 @@ available_policies_per_env = {
     "aloha_real": ["act_aloha_real"],
 }

-env_task_pairs = [
-    (env, task) for env, tasks in available_tasks_per_env.items() for task in tasks
-]
+env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
 env_dataset_pairs = [
-    (env, dataset)
-    for env, datasets in available_datasets_per_env.items()
-    for dataset in datasets
+    (env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
 ]
 env_dataset_policy_triplets = [
     (env, dataset, policy)

View File

@@ -46,18 +46,14 @@ def sample_indices(data_len: int) -> list[int]:
     return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()


-def auto_downsample_height_width(
-    img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300
-):
+def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
     _, height, width = img.shape

     if max(width, height) < max_size_threshold:
         # no downsampling needed
         return img

-    downsample_factor = (
-        int(width / target_size) if width > height else int(height / target_size)
-    )
+    downsample_factor = int(width / target_size) if width > height else int(height / target_size)
     return img[:, ::downsample_factor, ::downsample_factor]
@@ -79,9 +75,7 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
     return images


-def get_feature_stats(
-    array: np.ndarray, axis: tuple, keepdims: bool
-) -> dict[str, np.ndarray]:
+def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
     return {
         "min": np.min(array, axis=axis, keepdims=keepdims),
         "max": np.max(array, axis=axis, keepdims=keepdims),
@@ -91,9 +85,7 @@ def get_feature_stats(
     }


-def compute_episode_stats(
-    episode_data: dict[str, list[str] | np.ndarray], features: dict
-) -> dict:
+def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
     ep_stats = {}
     for key, data in episode_data.items():
         if features[key]["dtype"] == "string":
@@ -107,15 +99,12 @@ def compute_episode_stats(
             axes_to_reduce = 0  # compute stats over the first axis
             keepdims = data.ndim == 1  # keep as np.array

-        ep_stats[key] = get_feature_stats(
-            ep_ft_array, axis=axes_to_reduce, keepdims=keepdims
-        )
+        ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)

         # finally, we normalize and remove batch dim for images
         if features[key]["dtype"] in ["image", "video"]:
             ep_stats[key] = {
-                k: v if k == "count" else np.squeeze(v / 255.0, axis=0)
-                for k, v in ep_stats[key].items()
+                k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
             }

     return ep_stats
@@ -130,17 +119,11 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
                     f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
                 )
             if v.ndim == 0:
-                raise ValueError(
-                    "Number of dimensions must be at least 1, and is 0 instead."
-                )
+                raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
             if k == "count" and v.shape != (1,):
-                raise ValueError(
-                    f"Shape of 'count' must be (1), but is {v.shape} instead."
-                )
+                raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
             if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
-                raise ValueError(
-                    f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead."
-                )
+                raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")


 def aggregate_feature_stats(

View File

@@ -58,9 +58,7 @@ def resolve_delta_timestamps(
         if key == "action" and cfg.action_delta_indices is not None:
             delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
         if key.startswith("observation.") and cfg.observation_delta_indices is not None:
-            delta_timestamps[key] = [
-                i / ds_meta.fps for i in cfg.observation_delta_indices
-            ]
+            delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]

     if len(delta_timestamps) == 0:
         delta_timestamps = None
@@ -81,9 +79,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
         LeRobotDataset | MultiLeRobotDataset
     """
     image_transforms = (
-        ImageTransforms(cfg.dataset.image_transforms)
-        if cfg.dataset.image_transforms.enable
-        else None
+        ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
     )

     if isinstance(cfg.dataset.repo_id, str):
@@ -117,8 +113,6 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
     if cfg.dataset.use_imagenet_stats:
         for key in dataset.meta.camera_keys:
             for stats_type, stats in IMAGENET_STATS.items():
-                dataset.meta.stats[key][stats_type] = torch.tensor(
-                    stats, dtype=torch.float32
-                )
+                dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)

     return dataset

View File

@@ -38,14 +38,10 @@ def safe_stop_image_writer(func):
     return wrapper


-def image_array_to_pil_image(
-    image_array: np.ndarray, range_check: bool = True
-) -> PIL.Image.Image:
+def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
     # TODO(aliberts): handle 1 channel and 4 for depth images
     if image_array.ndim != 3:
-        raise ValueError(
-            f"The array has {image_array.ndim} dimensions, but 3 is expected for an image."
-        )
+        raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")

     if image_array.shape[0] == 3:
         # Transpose from pytorch convention (C, H, W) to (H, W, C)
@@ -131,9 +127,7 @@ class AsyncImageWriter:
         self._stopped = False

         if num_threads <= 0 and num_processes <= 0:
-            raise ValueError(
-                "Number of threads and processes must be greater than zero."
-            )
+            raise ValueError("Number of threads and processes must be greater than zero.")

         if self.num_processes == 0:
             # Use threading
@@ -147,16 +141,12 @@ class AsyncImageWriter:
             # Use multiprocessing
             self.queue = multiprocessing.JoinableQueue()
             for _ in range(self.num_processes):
-                p = multiprocessing.Process(
-                    target=worker_process, args=(self.queue, self.num_threads)
-                )
+                p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
                 p.daemon = True
                 p.start()
                 self.processes.append(p)

-    def save_image(
-        self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
-    ):
+    def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
         if isinstance(image, torch.Tensor):
             # Convert tensor to numpy array to minimize main process time
             image = image.cpu().numpy()

View File

@@ -108,9 +108,7 @@ class LeRobotDatasetMetadata:
         self.episodes = load_episodes(self.root)
         if self._version < packaging.version.parse("v2.1"):
             self.stats = load_stats(self.root)
-            self.episodes_stats = backward_compatible_episodes_stats(
-                self.stats, self.episodes
-            )
+            self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
         else:
             self.episodes_stats = load_episodes_stats(self.root)
             self.stats = aggregate_stats(list(self.episodes_stats.values()))
@@ -141,9 +139,7 @@ class LeRobotDatasetMetadata:
     def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
         ep_chunk = self.get_episode_chunk(ep_index)
-        fpath = self.video_path.format(
-            episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index
-        )
+        fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
         return Path(fpath)

     def get_episode_chunk(self, ep_index: int) -> int:
@@ -187,11 +183,7 @@ class LeRobotDatasetMetadata:
     @property
     def camera_keys(self) -> list[str]:
         """Keys to access visual modalities (regardless of their storage method)."""
-        return [
-            key
-            for key, ft in self.features.items()
-            if ft["dtype"] in ["video", "image"]
-        ]
+        return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]

     @property
     def names(self) -> dict[str, list | dict]:
@@ -240,9 +232,7 @@ class LeRobotDatasetMetadata:
         Given a task in natural language, add it to the dictionary of tasks.
         """
         if task in self.task_to_task_index:
-            raise ValueError(
-                f"The task '{task}' already exists and can't be added twice."
-            )
+            raise ValueError(f"The task '{task}' already exists and can't be added twice.")
         task_index = self.info["total_tasks"]
         self.task_to_task_index[task] = task_index
@@ -285,11 +275,7 @@ class LeRobotDatasetMetadata:
         write_episode(episode_dict, self.root)

         self.episodes_stats[episode_index] = episode_stats
-        self.stats = (
-            aggregate_stats([self.stats, episode_stats])
-            if self.stats
-            else episode_stats
-        )
+        self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
         write_episode_stats(episode_index, episode_stats, self.root)

     def update_video_info(self) -> None:
@@ -299,9 +285,7 @@ class LeRobotDatasetMetadata:
         """
         for key in self.video_keys:
             if not self.features[key].get("info", None):
-                video_path = self.root / self.get_video_file_path(
-                    ep_index=0, vid_key=key
-                )
+                video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
                 self.info["features"][key]["info"] = get_video_info(video_path)

     def __repr__(self):
@@ -353,17 +337,13 @@ class LeRobotDatasetMetadata:
         # as this would break the dict flattening in the stats computation, which uses '/' as separator
         for key in features:
             if "/" in key:
-                raise ValueError(
-                    f"Feature names should not contain '/'. Found '/' in feature '{key}'."
-                )
+                raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
         features = {**features, **DEFAULT_FEATURES}

         obj.tasks, obj.task_to_task_index = {}, {}
         obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
-        obj.info = create_empty_dataset_info(
-            CODEBASE_VERSION, fps, robot_type, features, use_videos
-        )
+        obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
         if len(obj.video_keys) > 0 and not use_videos:
             raise ValueError()
         write_json(obj.info, obj.root / INFO_PATH)
@@ -494,9 +474,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episodes = episodes
         self.tolerance_s = tolerance_s
         self.revision = revision if revision else CODEBASE_VERSION
-        self.video_backend = (
-            video_backend if video_backend else get_safe_default_codec()
-        )
+        self.video_backend = video_backend if video_backend else get_safe_default_codec()
         self.delta_indices = None

         # Unused attributes
@@ -509,39 +487,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.meta = LeRobotDatasetMetadata(
             self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
         )
-        if self.episodes is not None and self.meta._version >= packaging.version.parse(
-            "v2.1"
-        ):
-            episodes_stats = [
-                self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes
-            ]
+        if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
+            episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
             self.stats = aggregate_stats(episodes_stats)

         # Load actual data
         try:
             if force_cache_sync:
                 raise FileNotFoundError
-            assert all(
-                (self.root / fpath).is_file()
-                for fpath in self.get_episodes_file_paths()
-            )
+            assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
             self.hf_dataset = self.load_hf_dataset()
         except (AssertionError, FileNotFoundError, NotADirectoryError):
             self.revision = get_safe_version(self.repo_id, self.revision)
             self.download_episodes(download_videos)
             self.hf_dataset = self.load_hf_dataset()

-        self.episode_data_index = get_episode_data_index(
-            self.meta.episodes, self.episodes
-        )
+        self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)

         # Check timestamps
         timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
         episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
         ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
-        check_timestamps_sync(
-            timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s
-        )
+        check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)

         # Setup delta_indices
         if self.delta_timestamps is not None:
@@ -593,9 +560,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         else:
             hub_api.upload_folder(**upload_kwargs)

-        if not hub_api.file_exists(
-            self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch
-        ):
+        if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
             card = create_lerobot_dataset_card(
                 tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
             )
@@ -603,12 +568,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if tag_version:
             with contextlib.suppress(RevisionNotFoundError):
-                hub_api.delete_tag(
-                    self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset"
-                )
-                hub_api.create_tag(
-                    self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
-                )
+                hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
+                hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")

     def pull_from_repo(
         self,
@@ -640,11 +601,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)

     def get_episodes_file_paths(self) -> list[Path]:
-        episodes = (
-            self.episodes
-            if self.episodes is not None
-            else list(range(self.meta.total_episodes))
-        )
+        episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
         fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
         if len(self.meta.video_keys) > 0:
             video_files = [
@@ -662,10 +619,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             path = str(self.root / "data")
             hf_dataset = load_dataset("parquet", data_dir=path, split="train")
         else:
-            files = [
-                str(self.root / self.meta.get_data_file_path(ep_idx))
-                for ep_idx in self.episodes
-            ]
+            files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
             hf_dataset = load_dataset("parquet", data_files=files, split="train")

         # TODO(aliberts): hf_dataset.set_format("torch")
@@ -675,9 +629,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def create_hf_dataset(self) -> datasets.Dataset:
         features = get_hf_features_from_features(self.features)
         ft_dict = {col: [] for col in features}
-        hf_dataset = datasets.Dataset.from_dict(
-            ft_dict, features=features, split="train"
-        )
+        hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")

         # TODO(aliberts): hf_dataset.set_format("torch")
         hf_dataset.set_transform(hf_transform_to_torch)
@@ -691,20 +643,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
     @property
     def num_frames(self) -> int:
         """Number of frames in selected episodes."""
-        return (
-            len(self.hf_dataset)
-            if self.hf_dataset is not None
-            else self.meta.total_frames
-        )
+        return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames

     @property
     def num_episodes(self) -> int:
         """Number of episodes selected."""
-        return (
-            len(self.episodes)
-            if self.episodes is not None
-            else self.meta.total_episodes
-        )
+        return len(self.episodes) if self.episodes is not None else self.meta.total_episodes

     @property
     def features(self) -> dict[str, dict]:
@@ -718,24 +662,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
         else:
             return get_hf_features_from_features(self.features)

-    def _get_query_indices(
-        self, idx: int, ep_idx: int
-    ) -> tuple[dict[str, list[int | bool]]]:
+    def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
         ep_start = self.episode_data_index["from"][ep_idx]
         ep_end = self.episode_data_index["to"][ep_idx]
         query_indices = {
-            key: [
-                max(ep_start.item(), min(ep_end.item() - 1, idx + delta))
-                for delta in delta_idx
-            ]
+            key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
             for key, delta_idx in self.delta_indices.items()
         }
         padding = {  # Pad values outside of current episode range
             f"{key}_is_pad": torch.BoolTensor(
-                [
-                    (idx + delta < ep_start.item()) | (idx + delta >= ep_end.item())
-                    for delta in delta_idx
-                ]
+                [(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
             )
             for key, delta_idx in self.delta_indices.items()
         }
@@ -763,9 +699,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             if key not in self.meta.video_keys
         }

-    def _query_videos(
-        self, query_timestamps: dict[str, list[float]], ep_idx: int
-    ) -> dict[str, torch.Tensor]:
+    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
         """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
         in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
         Segmentation Fault. This probably happens because a memory reference to the video loader is created in
@@ -774,9 +708,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         item = {}
         for vid_key, query_ts in query_timestamps.items():
             video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
-            frames = decode_video_frames(
-                video_path, query_ts, self.tolerance_s, self.video_backend
-            )
+            frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
             item[vid_key] = frames.squeeze(0)

         return item
@@ -830,9 +762,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )

     def create_episode_buffer(self, episode_index: int | None = None) -> dict:
-        current_ep_idx = (
-            self.meta.total_episodes if episode_index is None else episode_index
-        )
+        current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
         ep_buffer = {}
         # size and task are special cases that are not in self.features
         ep_buffer["size"] = 0
@@ -841,17 +771,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
             ep_buffer[key] = current_ep_idx if key == "episode_index" else []
         return ep_buffer

-    def _get_image_file_path(
-        self, episode_index: int, image_key: str, frame_index: int
-    ) -> Path:
+    def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
         fpath = DEFAULT_IMAGE_PATH.format(
             image_key=image_key, episode_index=episode_index, frame_index=frame_index
         )
         return self.root / fpath

-    def _save_image(
-        self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
-    ) -> None:
+    def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
         if self.image_writer is None:
             if isinstance(image, torch.Tensor):
                 image = image.cpu().numpy()
@@ -877,9 +803,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Automatically add frame_index and timestamp to episode buffer
         frame_index = self.episode_buffer["size"]
-        timestamp = (
-            frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
-        )
+        timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
         self.episode_buffer["frame_index"].append(frame_index)
         self.episode_buffer["timestamp"].append(timestamp)
@@ -930,9 +854,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         episode_tasks = list(set(tasks))
         episode_index = episode_buffer["episode_index"]

-        episode_buffer["index"] = np.arange(
-            self.meta.total_frames, self.meta.total_frames + episode_length
-        )
+        episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
         episode_buffer["episode_index"] = np.full((episode_length,), episode_index)

         # Add new tasks to the tasks dictionary
@@ -942,9 +864,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 self.meta.add_task(task)

         # Given tasks in natural language, find their corresponding task indices
-        episode_buffer["task_index"] = np.array(
-            [self.meta.get_task_index(task) for task in tasks]
-        )
+        episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])

         for key, ft in self.features.items():
             # index, episode_index, task_index are already processed above, and image and video
@@ -994,9 +914,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
         episode_dict = {key: episode_buffer[key] for key in self.hf_features}
-        ep_dataset = datasets.Dataset.from_dict(
-            episode_dict, features=self.hf_features, split="train"
-        )
+        ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
         ep_dataset = embed_images(ep_dataset)
         self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
         self.hf_dataset.set_transform(hf_transform_to_torch)
@@ -1115,9 +1033,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.delta_timestamps = None
         obj.delta_indices = None
         obj.episode_data_index = None
-        obj.video_backend = (
-            video_backend if video_backend is not None else get_safe_default_codec()
-        )
+        obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
         return obj
@@ -1142,9 +1058,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         super().__init__()
         self.repo_ids = repo_ids
         self.root = Path(root) if root else HF_LEROBOT_HOME
-        self.tolerances_s = (
-            tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
-        )
+        self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
         # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
         # are handled by this class.
         self._datasets = [
@@ -1223,13 +1137,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
     def features(self) -> datasets.Features:
         features = {}
         for dataset in self._datasets:
-            features.update(
-                {
-                    k: v
-                    for k, v in dataset.hf_features.items()
-                    if k not in self.disabled_features
-                }
-            )
+            features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
         return features

     @property
@@ -1290,9 +1198,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                 continue
             break
         else:
-            raise AssertionError(
-                "We expect the loop to break out as long as the index is within bounds."
-            )
+            raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
         item = self._datasets[dataset_idx][idx - start_idx]
         item["dataset_index"] = torch.tensor(dataset_idx)
         for data_key in self.disabled_features:

View File

@@ -131,9 +131,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
         else:
             self._delta_timestamps = None

-    def _make_data_spec(
-        self, data_spec: dict[str, Any], buffer_capacity: int
-    ) -> dict[str, dict[str, Any]]:
+    def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
         """Makes the data spec for np.memmap."""
         if any(k.startswith("_") for k in data_spec):
             raise ValueError(
@@ -208,9 +206,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
         # Shift the incoming indices if necessary.
         if self.num_frames > 0:
-            last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][
-                next_index - 1
-            ]
+            last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
             last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
             data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
             data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
@@ -245,11 +241,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
     @property
     def num_episodes(self) -> int:
         return len(
-            np.unique(
-                self._data[OnlineBuffer.EPISODE_INDEX_KEY][
-                    self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
-                ]
-            )
+            np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
         )

     @property
@@ -287,9 +279,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
                 self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
             )
         )[0]
-        episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][
-            episode_data_indices
-        ]
+        episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]

         for data_key in self.delta_timestamps:
             # Note: The logic in this loop is copied from `load_previous_and_future_frames`.
@@ -306,8 +296,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
             # Check violated query timestamps are all outside the episode range.
             assert (
-                (query_ts[is_pad] < episode_timestamps[0])
-                | (episode_timestamps[-1] < query_ts[is_pad])
+                (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
             ).all(), (
                 f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
                 ") inside the episode range."
@@ -322,9 +311,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
     def get_data_by_key(self, key: str) -> torch.Tensor:
         """Returns all data for a given data key as a Tensor."""
-        return torch.from_numpy(
-            self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]
-        )
+        return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])


 def compute_sampler_weights(
@@ -355,19 +342,13 @@ def compute_sampler_weights(
     - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
       included here to avoid adding complexity.
     """
-    if len(offline_dataset) == 0 and (
-        online_dataset is None or len(online_dataset) == 0
-    ):
-        raise ValueError(
-            "At least one of `offline_dataset` or `online_dataset` should be contain data."
-        )
+    if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
+        raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
     if (online_dataset is None) ^ (online_sampling_ratio is None):
         raise ValueError(
             "`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
         )
-    offline_sampling_ratio = (
-        0 if online_sampling_ratio is None else 1 - online_sampling_ratio
-    )
+    offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio

     weights = []

View File

@@ -45,9 +45,7 @@ def concatenate_episodes(ep_dicts):
     return data_dict


-def save_images_concurrently(
-    imgs_array: numpy.array, out_dir: Path, max_workers: int = 4
-):
+def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
     out_dir = Path(out_dir)
     out_dir.mkdir(parents=True, exist_ok=True)
@@ -57,10 +55,7 @@ def save_images_concurrently(
     num_images = len(imgs_array)
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        [
-            executor.submit(save_image, imgs_array[i], i, out_dir)
-            for i in range(num_images)
-        ]
+        [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]


 def get_default_encoding() -> dict:
@@ -69,8 +64,7 @@ def get_default_encoding() -> dict:
     return {
         k: v.default
         for k, v in signature.parameters.items()
-        if v.default is not inspect.Parameter.empty
-        and k in ["vcodec", "pix_fmt", "g", "crf"]
+        if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
     }

View File

@@ -58,9 +58,7 @@ class RandomSubsetApply(Transform):
         elif not isinstance(n_subset, int):
             raise TypeError("n_subset should be an int or None")
         elif not (1 <= n_subset <= len(transforms)):
-            raise ValueError(
-                f"n_subset should be in the interval [1, {len(transforms)}]"
-            )
+            raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")

         self.transforms = transforms
         total = sum(p)
@@ -121,36 +119,26 @@ class SharpnessJitter(Transform):
     def _check_input(self, sharpness):
         if isinstance(sharpness, (int, float)):
             if sharpness < 0:
-                raise ValueError(
-                    "If sharpness is a single number, it must be non negative."
-                )
+                raise ValueError("If sharpness is a single number, it must be non negative.")
             sharpness = [1.0 - sharpness, 1.0 + sharpness]
             sharpness[0] = max(sharpness[0], 0.0)
         elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
             sharpness = [float(v) for v in sharpness]
         else:
-            raise TypeError(
-                f"{sharpness=} should be a single number or a sequence with length 2."
-            )
+            raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")

         if not 0.0 <= sharpness[0] <= sharpness[1]:
-            raise ValueError(
-                f"sharpnesss values should be between (0., inf), but got {sharpness}."
-            )
+            raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")

         return float(sharpness[0]), float(sharpness[1])

     def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
-        sharpness_factor = (
-            torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
-        )
+        sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
         return {"sharpness_factor": sharpness_factor}

     def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
         sharpness_factor = params["sharpness_factor"]
-        return self._call_kernel(
-            F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
-        )
+        return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)


 @dataclass

View File

@ -52,15 +52,9 @@ STATS_PATH = "meta/stats.json"
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl" EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl" TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = ( DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
"videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
) DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
DEFAULT_PARQUET_PATH = (
"data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
)
DEFAULT_IMAGE_PATH = (
"images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
)
DATASET_CARD_TEMPLATE = """ DATASET_CARD_TEMPLATE = """
--- ---
@ -135,9 +129,7 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
elif isinstance(value, (int, float)): elif isinstance(value, (int, float)):
serialized_dict[key] = value serialized_dict[key] = value
else: else:
raise NotImplementedError( raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
f"The value '{value}' of type '{type(value)}' is not supported."
)
return unflatten_dict(serialized_dict) return unflatten_dict(serialized_dict)
@ -216,10 +208,7 @@ def write_task(task_index: int, task: dict, local_dir: Path):
def load_tasks(local_dir: Path) -> tuple[dict, dict]: def load_tasks(local_dir: Path) -> tuple[dict, dict]:
tasks = load_jsonlines(local_dir / TASKS_PATH) tasks = load_jsonlines(local_dir / TASKS_PATH)
tasks = { tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
item["task_index"]: item["task"]
for item in sorted(tasks, key=lambda x: x["task_index"])
}
task_to_task_index = {task: task_index for task_index, task in tasks.items()} task_to_task_index = {task: task_index for task_index, task in tasks.items()}
return tasks, task_to_task_index return tasks, task_to_task_index
@ -230,10 +219,7 @@ def write_episode(episode: dict, local_dir: Path):
def load_episodes(local_dir: Path) -> dict: def load_episodes(local_dir: Path) -> dict:
episodes = load_jsonlines(local_dir / EPISODES_PATH) episodes = load_jsonlines(local_dir / EPISODES_PATH)
return { return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
item["episode_index"]: item
for item in sorted(episodes, key=lambda x: x["episode_index"])
}
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path): def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
@ -286,9 +272,7 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
elif first_item is None: elif first_item is None:
pass pass
else: else:
items_dict[key] = [ items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]
]
return items_dict return items_dict
@ -341,9 +325,7 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
Otherwise, will throw a `CompatibilityError`. Otherwise, will throw a `CompatibilityError`.
""" """
target_version = ( target_version = (
packaging.version.parse(version) packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
if not isinstance(version, packaging.version.Version)
else version
) )
hub_versions = get_repo_versions(repo_id) hub_versions = get_repo_versions(repo_id)
@@ -364,16 +346,12 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
         return f"v{target_version}"
-    compatibles = [
-        v
-        for v in hub_versions
-        if v.major == target_version.major and v.minor <= target_version.minor
-    ]
+    compatibles = [
+        v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
+    ]
     if compatibles:
         return_version = max(compatibles)
         if return_version < target_version:
-            logging.warning(
-                f"Revision {version} for {repo_id} not found, using version v{return_version}"
-            )
+            logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
         return f"v{return_version}"
     lower_major = [v for v in hub_versions if v.major < target_version.major]
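The selection rule here is semver-flavored: an exact tag wins; otherwise any hub version with the same major and a minor no greater than the target is acceptable, and the newest such version is used. A standalone sketch of that rule with `packaging` (the hub versions below are made up):

```python
from packaging import version

hub_versions = [version.parse(v) for v in ["2.0", "2.1", "3.0"]]
target = version.parse("2.3")

# Same filter as get_safe_version: same major, minor not newer than the target.
compatibles = [v for v in hub_versions if v.major == target.major and v.minor <= target.minor]
assert max(compatibles) == version.parse("2.1")  # falls back to the newest compatible tag
```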
@@ -480,9 +458,7 @@ def create_empty_dataset_info(
 def get_episode_data_index(
     episode_dicts: dict[dict], episodes: list[int] | None = None
 ) -> dict[str, torch.Tensor]:
-    episode_lengths = {
-        ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()
-    }
+    episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
     if episodes is not None:
         episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
@@ -532,9 +508,7 @@ def check_timestamps_sync(
     # Mask to ignore differences at the boundaries between episodes
     mask = np.ones(len(diffs), dtype=bool)
-    ignored_diffs = (
-        episode_data_index["to"][:-1] - 1
-    )  # indices at the end of each episode
+    ignored_diffs = episode_data_index["to"][:-1] - 1  # indices at the end of each episode
     mask[ignored_diffs] = False
     filtered_within_tolerance = within_tolerance[mask]
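The masking step exists because `diffs` is computed across the whole concatenated dataset, so the one timestamp difference that straddles each episode boundary is meaningless and must be ignored. A small numpy sketch of the indexing, with invented episode boundaries:

```python
import numpy as np

# Two episodes of 3 frames each: episode_data_index["to"] would be [3, 6].
timestamps = np.array([0.0, 0.1, 0.2, 5.0, 5.1, 5.2])
to_index = np.array([3, 6])

diffs = np.diff(timestamps)           # 5 diffs; diffs[2] crosses the episode boundary
mask = np.ones(len(diffs), dtype=bool)
mask[to_index[:-1] - 1] = False       # drop the last diff of every episode but the final one

assert np.allclose(diffs[mask], 0.1)  # only within-episode gaps remain
```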
@@ -580,14 +554,10 @@ def check_delta_timestamps(
     """
     outside_tolerance = {}
     for key, delta_ts in delta_timestamps.items():
-        within_tolerance = [
-            abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts
-        ]
+        within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
         if not all(within_tolerance):
-            outside_tolerance[key] = [
-                ts
-                for ts, is_within in zip(delta_ts, within_tolerance, strict=True)
-                if not is_within
-            ]
+            outside_tolerance[key] = [
+                ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
+            ]
     if len(outside_tolerance) > 0:
@@ -605,9 +575,7 @@ def check_delta_timestamps(
     return True

-def get_delta_indices(
-    delta_timestamps: dict[str, list[float]], fps: int
-) -> dict[str, list[int]]:
+def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
     delta_indices = {}
     for key, delta_ts in delta_timestamps.items():
         delta_indices[key] = [round(d * fps) for d in delta_ts]
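`get_delta_indices` just converts each relative timestamp into a frame offset at the dataset's rate; the earlier tolerance check guarantees the rounding loses less than `tolerance_s`. For instance, assuming `fps=10` and made-up keys:

```python
delta_timestamps = {"observation.image": [-0.2, -0.1, 0.0], "action": [0.0, 0.1]}
fps = 10

# Same conversion as get_delta_indices: seconds -> frame offsets.
delta_indices = {key: [round(d * fps) for d in delta_ts] for key, delta_ts in delta_timestamps.items()}
assert delta_indices == {"observation.image": [-2, -1, 0], "action": [0, 1]}
```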
@@ -672,9 +640,7 @@ def create_lerobot_dataset_card(
         ],
     )
-    card_template = (
-        importlib.resources.files("lerobot.common.datasets") / "card_template.md"
-    ).read_text()
+    card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
     return DatasetCard.from_template(
         card_data=card_data,
@@ -743,18 +709,14 @@ def validate_frame(frame: dict, features: dict):
     expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
     actual_features = set(frame.keys())
-    error_message = validate_features_presence(
-        actual_features, expected_features, optional_features
-    )
+    error_message = validate_features_presence(actual_features, expected_features, optional_features)
     if "task" in frame:
         error_message += validate_feature_string("task", frame["task"])
     common_features = actual_features & (expected_features | optional_features)
     for name in common_features - {"task"}:
-        error_message += validate_feature_dtype_and_shape(
-            name, features[name], frame[name]
-        )
+        error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
     if error_message:
         raise ValueError(error_message)
@@ -777,9 +739,7 @@ def validate_features_presence(
     return error_message

-def validate_feature_dtype_and_shape(
-    name: str, feature: dict, value: np.ndarray | PILImage.Image | str
-):
+def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
     expected_dtype = feature["dtype"]
     expected_shape = feature["shape"]
     if is_valid_numpy_dtype_string(expected_dtype):
@@ -789,9 +749,7 @@ def validate_feature_dtype_and_shape(
     elif expected_dtype == "string":
         return validate_feature_string(name, value)
     else:
-        raise NotImplementedError(
-            f"The feature dtype '{expected_dtype}' is not implemented yet."
-        )
+        raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")

 def validate_feature_numpy_array(
@@ -813,17 +771,13 @@ def validate_feature_numpy_array(
     return error_message

-def validate_feature_image_or_video(
-    name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
-):
+def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
     # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
     error_message = ""
     if isinstance(value, np.ndarray):
         actual_shape = value.shape
         c, h, w = expected_shape
-        if len(actual_shape) != 3 or (
-            actual_shape != (c, h, w) and actual_shape != (h, w, c)
-        ):
+        if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
             error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
     elif isinstance(value, PILImage.Image):
         pass
@@ -854,9 +808,7 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
     )
     if episode_buffer["size"] == 0:
-        raise ValueError(
-            "You must add one or several frames with `add_frame` before calling `add_episode`."
-        )
+        raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
     buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
     if not buffer_keys == set(features):
View File
@@ -218,9 +218,7 @@ def get_features_from_hf_dataset(
         dtype = ft.feature.dtype
         shape = (ft.length,)
-        motor_names = (
-            robot_config["names"][key]
-            if robot_config
-            else [f"motor_{i}" for i in range(ft.length)]
-        )
+        motor_names = (
+            robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
+        )
         assert len(motor_names) == shape[0]
         names = {"motors": motor_names}
@@ -244,15 +242,11 @@ def get_features_from_hf_dataset(
     return features

-def add_task_index_by_episodes(
-    dataset: Dataset, tasks_by_episodes: dict
-) -> tuple[Dataset, list[str]]:
+def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
     df = dataset.to_pandas()
     tasks = list(set(tasks_by_episodes.values()))
     tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
-    episodes_to_task_index = {
-        ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()
-    }
+    episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
     df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
     features = dataset.features
@@ -269,19 +263,10 @@ def add_task_index_from_tasks_col(
     # HACK: This is to clean some of the instructions in our version of Open X datasets
     prefix_to_clean = "tf.Tensor(b'"
     suffix_to_clean = "', shape=(), dtype=string)"
-    df[tasks_col] = (
-        df[tasks_col]
-        .str.removeprefix(prefix_to_clean)
-        .str.removesuffix(suffix_to_clean)
-    )
+    df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
     # Create task_index col
-    tasks_by_episode = (
-        df.groupby("episode_index")[tasks_col]
-        .unique()
-        .apply(lambda x: x.tolist())
-        .to_dict()
-    )
+    tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
     tasks = df[tasks_col].unique().tolist()
     tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
     df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
@@ -306,9 +291,7 @@ def split_parquet_by_episodes(
     for ep_chunk in range(total_chunks):
         ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
         ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
-        chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(
-            episode_chunk=ep_chunk
-        )
+        chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
         (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
         for ep_idx in range(ep_chunk_start, ep_chunk_end):
             ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
@@ -340,9 +323,7 @@ def move_videos(
     videos_moved = False
     video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
     if len(video_files) == 0:
-        video_files = [
-            str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")
-        ]
+        video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
         videos_moved = True  # Videos have already been moved
     assert len(video_files) == total_episodes * len(video_keys)
@@ -373,9 +354,7 @@ def move_videos(
             target_path = DEFAULT_VIDEO_PATH.format(
                 episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
             )
-            video_file = V1_VIDEO_FILE.format(
-                video_key=vid_key, episode_index=ep_idx
-            )
+            video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
             if len(video_dirs) == 1:
                 video_path = video_dirs[0] / video_file
             else:
@@ -392,9 +371,7 @@ def move_videos(
     subprocess.run(["git", "push"], cwd=work_dir, check=True)

-def fix_lfs_video_files_tracking(
-    work_dir: Path, lfs_untracked_videos: list[str]
-) -> None:
+def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
     """
     HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
     there's no other option than to download the actual files and reupload them with lfs tracking.
@@ -418,14 +395,10 @@ def fix_lfs_video_files_tracking(
     subprocess.run(["git", "push"], cwd=work_dir, check=True)

-def fix_gitattributes(
-    work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path
-) -> None:
+def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
     shutil.copyfile(clean_gittatributes, current_gittatributes)
     subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
-    subprocess.run(
-        ["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True
-    )
+    subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
     subprocess.run(["git", "push"], cwd=work_dir, check=True)
@@ -462,9 +435,7 @@ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[st
     return [f for f in video_files if f not in lfs_tracked_files]

-def get_videos_info(
-    repo_id: str, local_dir: Path, video_keys: list[str], branch: str
-) -> dict:
+def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
     # Assumes first episode
     video_files = [
         DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
@@ -539,31 +510,19 @@ def convert_dataset(
     if single_task:
         tasks_by_episodes = dict.fromkeys(episode_indices, single_task)
         dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
-        tasks_by_episodes = {
-            ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
-        }
+        tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
     elif tasks_path:
         tasks_by_episodes = load_json(tasks_path)
-        tasks_by_episodes = {
-            int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()
-        }
+        tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
         dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
-        tasks_by_episodes = {
-            ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
-        }
+        tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
     elif tasks_col:
-        dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(
-            dataset, tasks_col
-        )
+        dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
     else:
         raise ValueError
-    assert set(tasks) == {
-        task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks
-    }
-    tasks = [
-        {"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)
-    ]
+    assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
+    tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
     write_jsonlines(tasks, v20_dir / TASKS_PATH)
     features["task_index"] = {
         "dtype": "int64",
@@ -593,9 +552,7 @@ def convert_dataset(
             clean_gitattr,
             branch,
         )
-        videos_info = get_videos_info(
-            repo_id, v1x_dir, video_keys=video_keys, branch=branch
-        )
+        videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
         for key in video_keys:
             features[key]["shape"] = (
                 videos_info[key].pop("video.height"),
@@ -603,22 +560,15 @@ def convert_dataset(
                 videos_info[key].pop("video.channels"),
             )
             features[key]["video_info"] = videos_info[key]
-            assert math.isclose(
-                videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3
-            )
+            assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
             if "encoding" in metadata_v1:
-                assert (
-                    videos_info[key]["video.pix_fmt"]
-                    == metadata_v1["encoding"]["pix_fmt"]
-                )
+                assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
     else:
         assert metadata_v1.get("video", 0) == 0
         videos_info = None
     # Split data into 1 parquet file by episode
-    episode_lengths = split_parquet_by_episodes(
-        dataset, total_episodes, total_chunks, v20_dir
-    )
+    episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
     if robot_config is not None:
         robot_type = robot_config.type
@@ -656,14 +606,10 @@ def convert_dataset(
     }
     write_json(metadata_v2_0, v20_dir / INFO_PATH)
     convert_stats_to_json(v1x_dir, v20_dir)
-    card = create_lerobot_dataset_card(
-        tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs
-    )
+    card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
     with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
-        hub_api.delete_folder(
-            repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch
-        )
+        hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
     with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
         hub_api.delete_folder(
@@ -674,9 +620,7 @@ def convert_dataset(
         )
     with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
-        hub_api.delete_folder(
-            repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch
-        )
+        hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
     hub_api.upload_folder(
         repo_id=repo_id,
View File
@@ -35,30 +35,22 @@ def fix_dataset(repo_id: str) -> str:
     dataset_info = get_dataset_config_info(repo_id, "default")
     with SuppressWarnings():
-        lerobot_metadata = LeRobotDatasetMetadata(
-            repo_id, revision=V20, force_cache_sync=True
-        )
-    meta_features = {
-        key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"
-    }
+        lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
+    meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
     parquet_features = set(dataset_info.features)
     diff_parquet_meta = parquet_features - meta_features
     diff_meta_parquet = meta_features - parquet_features
     if diff_parquet_meta:
-        raise ValueError(
-            f"In parquet not in info.json: {parquet_features - meta_features}"
-        )
+        raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
     if not diff_meta_parquet:
         return f"{repo_id}: skipped (no diff)"
     if diff_meta_parquet:
-        logging.warning(
-            f"In info.json not in parquet: {meta_features - parquet_features}"
-        )
+        logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
     assert diff_meta_parquet == {"language_instruction"}
     lerobot_metadata.features.pop("language_instruction")
     write_info(lerobot_metadata.info, lerobot_metadata.root)
View File
@@ -99,9 +99,7 @@ def convert_dataset(
         repo_type="dataset",
     )
-    hub_api.create_tag(
-        repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
-    )
+    hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")

 if __name__ == "__main__":
View File
@@ -26,9 +26,7 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.utils import write_episode_stats

-def sample_episode_video_frames(
-    dataset: LeRobotDataset, episode_index: int, ft_key: str
-) -> np.ndarray:
+def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
     ep_len = dataset.meta.episodes[episode_index]["length"]
     sampled_indices = sample_indices(ep_len)
     query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
@@ -51,14 +49,11 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
         axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
         keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
-        ep_stats[key] = get_feature_stats(
-            ep_ft_data, axis=axes_to_reduce, keepdims=keepdims
-        )
+        ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
         if ft["dtype"] in ["image", "video"]:  # remove batch dim
             ep_stats[key] = {
-                k: v if k == "count" else np.squeeze(v, axis=0)
-                for k, v in ep_stats[key].items()
+                k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
             }
     dataset.meta.episodes_stats[ep_idx] = ep_stats
View File
@@ -65,9 +65,7 @@ def decode_video_frames(
     if backend == "torchcodec":
         return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
     elif backend in ["pyav", "video_reader"]:
-        return decode_video_frames_torchvision(
-            video_path, timestamps, tolerance_s, backend
-        )
+        return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
     else:
         raise ValueError(f"Unsupported video backend: {backend}")
@@ -346,9 +344,7 @@ def get_audio_info(video_path: Path | str) -> dict:
         "json",
         str(video_path),
     ]
-    result = subprocess.run(
-        ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
-    )
+    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
     if result.returncode != 0:
         raise RuntimeError(f"Error running ffprobe: {result.stderr}")
@@ -362,9 +358,7 @@ def get_audio_info(video_path: Path | str) -> dict:
         "has_audio": True,
         "audio.channels": audio_stream_info.get("channels", None),
         "audio.codec": audio_stream_info.get("codec_name", None),
-        "audio.bit_rate": int(audio_stream_info["bit_rate"])
-        if audio_stream_info.get("bit_rate")
-        else None,
+        "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
         "audio.sample_rate": int(audio_stream_info["sample_rate"])
         if audio_stream_info.get("sample_rate")
         else None,
@@ -386,9 +380,7 @@ def get_video_info(video_path: Path | str) -> dict:
         "json",
         str(video_path),
     ]
-    result = subprocess.run(
-        ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
-    )
+    result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
     if result.returncode != 0:
         raise RuntimeError(f"Error running ffprobe: {result.stderr}")
View File
@@ -61,16 +61,10 @@ class AlohaEnv(EnvConfig):
     def __post_init__(self):
         if self.obs_type == "pixels":
-            self.features["top"] = PolicyFeature(
-                type=FeatureType.VISUAL, shape=(480, 640, 3)
-            )
+            self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
         elif self.obs_type == "pixels_agent_pos":
-            self.features["agent_pos"] = PolicyFeature(
-                type=FeatureType.STATE, shape=(14,)
-            )
-            self.features["pixels/top"] = PolicyFeature(
-                type=FeatureType.VISUAL, shape=(480, 640, 3)
-            )
+            self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
+            self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))

     @property
     def gym_kwargs(self) -> dict:
@@ -108,13 +102,9 @@ class PushtEnv(EnvConfig):
     def __post_init__(self):
         if self.obs_type == "pixels_agent_pos":
-            self.features["pixels"] = PolicyFeature(
-                type=FeatureType.VISUAL, shape=(384, 384, 3)
-            )
+            self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
         elif self.obs_type == "environment_state_agent_pos":
-            self.features["environment_state"] = PolicyFeature(
-                type=FeatureType.ENV, shape=(16,)
-            )
+            self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))

     @property
     def gym_kwargs(self) -> dict:
@@ -153,9 +143,7 @@ class XarmEnv(EnvConfig):
     def __post_init__(self):
         if self.obs_type == "pixels_agent_pos":
-            self.features["agent_pos"] = PolicyFeature(
-                type=FeatureType.STATE, shape=(4,)
-            )
+            self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))

     @property
     def gym_kwargs(self) -> dict:
View File
@@ -32,9 +32,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
     raise ValueError(f"Policy type '{env_type}' is not available.")

-def make_env(
-    cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
-) -> gym.vector.VectorEnv | None:
+def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
     """Makes a gym vector environment according to the config.

     Args:
@@ -58,9 +56,7 @@ def make_env(
     try:
         importlib.import_module(package_name)
     except ModuleNotFoundError as e:
-        print(
-            f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`"
-        )
+        print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
         raise e

     gym_handle = f"{package_name}/{cfg.task}"
@@ -68,18 +64,13 @@ def make_env(
     # batched version of the env that returns an observation of shape (b, c)
     env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
     env = env_cls(
-        [
-            lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs)
-            for _ in range(n_envs)
-        ]
+        [lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
     )

     return env

-def make_maniskill_env(
-    cfg: DictConfig, n_envs: int | None = None
-) -> gym.vector.VectorEnv | None:
+def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
     """Make ManiSkill3 gym environment"""
     from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
@@ -96,9 +87,7 @@ def make_maniskill_env(
     # state should have the size of 25
     # env = ConvertToLeRobotEnv(env, n_envs)
     # env = PixelWrapper(cfg, env, n_envs)
-    env._max_episode_steps = env.max_episode_steps = (
-        50  # gym_utils.find_max_episode_steps_value(env)
-    )
+    env._max_episode_steps = env.max_episode_steps = 50  # gym_utils.find_max_episode_steps_value(env)
     env.unwrapped.metadata["render_fps"] = 20

     return env
@@ -125,11 +114,7 @@ class PixelWrapper(gym.Wrapper):
     def _get_obs(self, obs):
         frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
         self._frames.append(frame)
-        return {
-            "pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(
-                self.env.device
-            )
-        }
+        return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}

     def reset(self, seed):
         obs, info = self.env.reset()  # (seed=seed)
@@ -164,9 +149,7 @@ class ConvertToLeRobotEnv(gym.Wrapper):
         images = torch.concat(images, axis=-1)
         # flatten the rest of the data which should just be state data
-        observation = common.flatten_state_dict(
-            observation, use_torch=True, device=self.base_env.device
-        )
+        observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
         ret = dict()
         ret["state"] = observation
         ret["pixels"] = images
View File
@@ -50,9 +50,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
         # sanity check that images are channel last
         _, h, w, c = img.shape
-        assert c < h and c < w, (
-            f"expect channel last images, but instead got {img.shape=}"
-        )
+        assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"

         # sanity check that images are uint8
         assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
@@ -85,9 +83,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
     for key, ft in env_cfg.features.items():
         if ft.type is FeatureType.VISUAL:
             if len(ft.shape) != 3:
-                raise ValueError(
-                    f"Number of dimensions of {key} != 3 (shape={ft.shape})"
-                )
+                raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")

             shape = get_channel_first_image_shape(ft.shape)
             feature = PolicyFeature(type=ft.type, shape=shape)
View File
@@ -34,13 +34,7 @@ def make_optimizer_and_scheduler(
     Returns:
         tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
     """
-    params = (
-        policy.get_optim_params()
-        if cfg.use_policy_training_preset
-        else policy.parameters()
-    )
+    params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
     optimizer = cfg.optimizer.build(params)
-    lr_scheduler = (
-        cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
-    )
+    lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
     return optimizer, lr_scheduler
View File
@@ -102,9 +102,7 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No
     write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)

-def load_optimizer_state(
-    optimizer: torch.optim.Optimizer, save_dir: Path
-) -> torch.optim.Optimizer:
+def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
     current_state_dict = optimizer.state_dict()
     flat_state = load_file(save_dir / OPTIMIZER_STATE)
     state = unflatten_dict(flat_state)
View File
@@ -36,9 +36,7 @@ class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
         return self.get_choice_name(self.__class__)

     @abc.abstractmethod
-    def build(
-        self, optimizer: Optimizer, num_training_steps: int
-    ) -> LRScheduler | None:
+    def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
         raise NotImplementedError
@@ -79,11 +77,7 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
             )
             return max(
                 0.0,
-                0.5
-                * (
-                    1.0
-                    + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)
-                ),
+                0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)),
             )

         return LambdaLR(optimizer, lr_lambda, -1)
@@ -111,9 +105,7 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
         def cosine_decay_schedule(current_step):
             step = min(current_step, self.num_decay_steps)
-            cosine_decay = 0.5 * (
-                1 + math.cos(math.pi * step / self.num_decay_steps)
-            )
+            cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
             alpha = self.decay_lr / self.peak_lr
             decayed = (1 - alpha) * cosine_decay + alpha
             return decayed
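The decay branch interpolates between `peak_lr` and `decay_lr` with a half-cosine: at step 0 the multiplier is 1, and at `num_decay_steps` it settles at `alpha = decay_lr / peak_lr`, staying clamped there. A quick standalone check of that behavior with made-up values:

```python
import math

peak_lr, decay_lr, num_decay_steps = 1e-4, 1e-5, 100
alpha = decay_lr / peak_lr

def multiplier(current_step: int) -> float:
    # Same arithmetic as cosine_decay_schedule above.
    step = min(current_step, num_decay_steps)
    cosine_decay = 0.5 * (1 + math.cos(math.pi * step / num_decay_steps))
    return (1 - alpha) * cosine_decay + alpha

assert math.isclose(multiplier(0), 1.0)         # start at peak_lr
assert math.isclose(multiplier(100), alpha)     # end at decay_lr / peak_lr
assert math.isclose(multiplier(10_000), alpha)  # clamped past the decay horizon
```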
@@ -132,8 +124,6 @@ def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
 def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
-    state_dict = deserialize_json_into_object(
-        save_dir / SCHEDULER_STATE, scheduler.state_dict()
-    )
+    state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
     scheduler.load_state_dict(state_dict)
     return scheduler
View File
@@ -171,9 +171,7 @@ class ACTConfig(PreTrainedConfig):
     def validate_features(self) -> None:
         if not self.image_features and not self.env_state_feature:
-            raise ValueError(
-                "You must provide at least one image or the environment state among the inputs."
-            )
+            raise ValueError("You must provide at least one image or the environment state among the inputs.")

     @property
     def observation_delta_indices(self) -> None:
View File
@@ -63,9 +63,7 @@ class ACTPolicy(PreTrainedPolicy):
         config.validate_features()
         self.config = config

-        self.normalize_inputs = Normalize(
-            config.input_features, config.normalization_mapping, dataset_stats
-        )
+        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
         self.normalize_targets = Normalize(
             config.output_features, config.normalization_mapping, dataset_stats
         )
@@ -76,9 +74,7 @@ class ACTPolicy(PreTrainedPolicy):
         self.model = ACT(config)

         if config.temporal_ensemble_coeff is not None:
-            self.temporal_ensembler = ACTTemporalEnsembler(
-                config.temporal_ensemble_coeff, config.chunk_size
-            )
+            self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)

         self.reset()
@@ -122,12 +118,8 @@ class ACTPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
-            batch = dict(
-                batch
-            )  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = [
-                batch[key] for key in self.config.image_features
-            ]
+            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
+            batch["observation.images"] = [batch[key] for key in self.config.image_features]

         # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
         # we are ensembling over.
@@ -154,19 +146,14 @@ class ACTPolicy(PreTrainedPolicy):
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
-            batch = dict(
-                batch
-            )  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = [
-                batch[key] for key in self.config.image_features
-            ]
+            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
+            batch["observation.images"] = [batch[key] for key in self.config.image_features]

         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

         l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none")
-            * ~batch["action_is_pad"].unsqueeze(-1)
+            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
         ).mean()

         loss_dict = {"l1_loss": l1_loss.item()}
@@ -176,12 +163,7 @@ class ACTPolicy(PreTrainedPolicy):
             # KL-divergence per batch element, then take the mean over the batch.
             # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
             mean_kld = (
-                (
-                    -0.5
-                    * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())
-                )
-                .sum(-1)
-                .mean()
+                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
             )
             loss_dict["kld_loss"] = mean_kld.item()
             loss = l1_loss + mean_kld * self.config.kl_weight
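The closed form in that hunk is the standard KL between the encoder's diagonal Gaussian N(mu, sigma^2) and N(0, I), summed over latent dimensions. A quick sanity check against `torch.distributions` (random shapes, not the policy's real tensors):

```python
import torch
from torch.distributions import Normal, kl_divergence

mu_hat = torch.randn(8, 32)
log_sigma_x2_hat = torch.randn(8, 32)  # log(sigma^2), as predicted by the VAE encoder

closed_form = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - log_sigma_x2_hat.exp())).sum(-1).mean()
reference = kl_divergence(
    Normal(mu_hat, (0.5 * log_sigma_x2_hat).exp()),  # std = exp(log(sigma^2) / 2)
    Normal(torch.zeros_like(mu_hat), torch.ones_like(mu_hat)),
).sum(-1).mean()

assert torch.allclose(closed_form, reference, atol=1e-5)
```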
@@ -235,9 +217,7 @@ class ACTTemporalEnsembler:
         ```
         """
         self.chunk_size = chunk_size
-        self.ensemble_weights = torch.exp(
-            -temporal_ensemble_coeff * torch.arange(chunk_size)
-        )
+        self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
         self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
         self.reset()
@@ -253,9 +233,7 @@ class ACTTemporalEnsembler:
         time steps, and pop/return the next batch of actions in the sequence.
         """
         self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
-        self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(
-            device=actions.device
-        )
+        self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
         if self.ensembled_actions is None:
             # Initializes `self._ensembled_action` to the sequence of actions predicted during the first
             # time step of the episode.
@@ -270,22 +248,12 @@ class ACTTemporalEnsembler:
         else:
             # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
             # the online update for those entries.
-            self.ensembled_actions *= self.ensemble_weights_cumsum[
-                self.ensembled_actions_count - 1
-            ]
-            self.ensembled_actions += (
-                actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
-            )
-            self.ensembled_actions /= self.ensemble_weights_cumsum[
-                self.ensembled_actions_count
-            ]
-            self.ensembled_actions_count = torch.clamp(
-                self.ensembled_actions_count + 1, max=self.chunk_size
-            )
+            self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
+            self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
+            self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
+            self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
             # The last action, which has no prior online average, needs to get concatenated onto the end.
-            self.ensembled_actions = torch.cat(
-                [self.ensembled_actions, actions[:, -1:]], dim=1
-            )
+            self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
             self.ensembled_actions_count = torch.cat(
                 [
                     self.ensembled_actions_count,
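The update block keeps a running exponentially-weighted average per time step without storing past chunks: multiply each slot back up by its old normalizer, add the newly predicted action times its weight, then renormalize. A scalar sketch of one slot under the same weights w_i = exp(-coeff * i), with invented action values:

```python
import torch

coeff, chunk_size = 0.01, 4
weights = torch.exp(-coeff * torch.arange(chunk_size))
weights_cumsum = torch.cumsum(weights, dim=0)

# Online average of predictions [1.0, 2.0, 3.0] for one fixed time step.
avg, count = torch.tensor(1.0), 0            # first prediction initializes the slot
for new_pred in [torch.tensor(2.0), torch.tensor(3.0)]:
    avg = avg * weights_cumsum[count]        # undo previous normalization
    avg = avg + new_pred * weights[count + 1]
    count += 1
    avg = avg / weights_cumsum[count]        # renormalize over count + 1 predictions

expected = (weights[:3] * torch.tensor([1.0, 2.0, 3.0])).sum() / weights_cumsum[2]
assert torch.allclose(avg, expected)
```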
@@ -356,9 +324,7 @@ class ACT(nn.Module):
             config.dim_model,
         )
         # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
-        self.vae_encoder_latent_output_proj = nn.Linear(
-            config.dim_model, config.latent_dim * 2
-        )
+        self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
         # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
         # dimension.
         num_input_token_encoder = 1 + config.chunk_size
@@ -366,9 +332,7 @@ class ACT(nn.Module):
             num_input_token_encoder += 1
         self.register_buffer(
             "vae_encoder_pos_enc",
-            create_sinusoidal_pos_embedding(
-                num_input_token_encoder, config.dim_model
-            ).unsqueeze(0),
+            create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
         )

         # Backbone for image feature extraction.
@@ -385,9 +349,7 @@ class ACT(nn.Module):
         # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
         # feature map).
         # Note: The forward method of this returns a dict: {"feature_map": output}.
-        self.backbone = IntermediateLayerGetter(
-            backbone_model, return_layers={"layer4": "feature_map"}
-        )
+        self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})

         # Transformer (acts as VAE decoder when training with the variational objective).
         self.encoder = ACTEncoder(config)
@@ -416,18 +378,14 @@ class ACT(nn.Module):
             n_1d_tokens += 1
         self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
         if self.config.image_features:
-            self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(
-                config.dim_model // 2
-            )
+            self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)

         # Transformer decoder.
         # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
         self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)

         # Final action regression head on the output of the transformer's decoder.
-        self.action_head = nn.Linear(
-            config.dim_model, self.config.action_feature.shape[0]
-        )
+        self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])

         self._reset_parameters()
@@ -437,9 +395,7 @@ class ACT(nn.Module):
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)

-    def forward(
-        self, batch: dict[str, Tensor]
-    ) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
+    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
         """A forward pass through the Action Chunking Transformer (with optional VAE encoder).

         `batch` should have the following structure:
@@ -475,13 +431,9 @@ class ACT(nn.Module):
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
             if self.config.robot_state_feature:
-                robot_state_embed = self.vae_encoder_robot_state_input_proj(
-                    batch["observation.state"]
-                )
+                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
                 robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
-            action_embed = self.vae_encoder_action_input_proj(
-                batch["action"]
-            )  # (B, S, D)
+            action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)

             if self.config.robot_state_feature:
                 vae_encoder_input = [
@@ -526,26 +478,20 @@ class ACT(nn.Module):
             # When not using the VAE encoder, we set the latent to be all zeros.
             mu = log_sigma_x2 = None
             # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
-            latent_sample = torch.zeros(
-                [batch_size, self.config.latent_dim], dtype=torch.float32
-            ).to(batch["observation.state"].device)
+            latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
+                batch["observation.state"].device
+            )

         # Prepare transformer encoder inputs.
         encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
-        encoder_in_pos_embed = list(
-            self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)
-        )
+        encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
         # Robot state token.
         if self.config.robot_state_feature:
-            encoder_in_tokens.append(
-                self.encoder_robot_state_input_proj(batch["observation.state"])
-            )
+            encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
         # Environment state token.
         if self.config.env_state_feature:
             encoder_in_tokens.append(
-                self.encoder_env_state_input_proj(
-                    batch["observation.environment_state"]
-                )
+                self.encoder_env_state_input_proj(batch["observation.environment_state"])
             )

         # Camera observation features and positional embeddings.
@@ -556,9 +502,7 @@ class ACT(nn.Module):
             # For a list of images, the H and W may vary but H*W is constant.
             for img in batch["observation.images"]:
                 cam_features = self.backbone(img)["feature_map"]
-                cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
-                    dtype=cam_features.dtype
-                )
+                cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
                 cam_features = self.encoder_img_feat_input_proj(cam_features)

                 # Rearrange features to (sequence, batch, dim).
@@ -604,14 +548,8 @@ class ACTEncoder(nn.Module):
     def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
         super().__init__()
         self.is_vae_encoder = is_vae_encoder
-        num_layers = (
-            config.n_vae_encoder_layers
-            if self.is_vae_encoder
-            else config.n_encoder_layers
-        )
-        self.layers = nn.ModuleList(
-            [ACTEncoderLayer(config) for _ in range(num_layers)]
-        )
+        num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
+        self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
         self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()

     def forward(
@@ -629,9 +567,7 @@ class ACTEncoder(nn.Module):
 class ACTEncoderLayer(nn.Module):
     def __init__(self, config: ACTConfig):
         super().__init__()
-        self.self_attn = nn.MultiheadAttention(
-            config.dim_model, config.n_heads, dropout=config.dropout
-        )
+        self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)

         # Feed forward layers.
         self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@@ -646,9 +582,7 @@ class ACTEncoderLayer(nn.Module):
         self.activation = get_activation_fn(config.feedforward_activation)
         self.pre_norm = config.pre_norm

-    def forward(
-        self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
-    ) -> Tensor:
+    def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
         skip = x
         if self.pre_norm:
             x = self.norm1(x)
@@ -673,9 +607,7 @@ class ACTDecoder(nn.Module):
     def __init__(self, config: ACTConfig):
         """Convenience module for running multiple decoder layers followed by normalization."""
         super().__init__()
-        self.layers = nn.ModuleList(
-            [ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]
-        )
+        self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
         self.norm = nn.LayerNorm(config.dim_model)

     def forward(
@@ -700,12 +632,8 @@ class ACTDecoder(nn.Module):
 class ACTDecoderLayer(nn.Module):
     def __init__(self, config: ACTConfig):
         super().__init__()
-        self.self_attn = nn.MultiheadAttention(
-            config.dim_model, config.n_heads, dropout=config.dropout
-        )
-        self.multihead_attn = nn.MultiheadAttention(
-            config.dim_model, config.n_heads, dropout=config.dropout
-        )
+        self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
+        self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)

         # Feed forward layers.
         self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@@ -746,9 +674,7 @@ class ACTDecoderLayer(nn.Module):
         if self.pre_norm:
             x = self.norm1(x)
         q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
-        x = self.self_attn(q, k, value=x)[
-            0
-        ]  # select just the output, not the attention weights
+        x = self.self_attn(q, k, value=x)[0]  # select just the output, not the attention weights
         x = skip + self.dropout1(x)
         if self.pre_norm:
             skip = x
@@ -785,14 +711,9 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso
     """

     def get_position_angle_vec(position):
-        return [
-            position / np.power(10000, 2 * (hid_j // 2) / dimension)
-            for hid_j in range(dimension)
-        ]
+        return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]

-    sinusoid_table = np.array(
-        [get_position_angle_vec(pos_i) for pos_i in range(num_positions)]
-    )
+    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
     sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
     sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
     return torch.from_numpy(sinusoid_table).float()
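`create_sinusoidal_pos_embedding` fills row p, column 2i with sin(p / 10000^(2i/d)) and column 2i+1 with the matching cosine, so each (sin, cos) pair at a given frequency has unit norm. A quick property check, assuming the function above is importable in the current scope:

```python
import torch

pos_embed = create_sinusoidal_pos_embedding(num_positions=16, dimension=8)
assert pos_embed.shape == (16, 8)
# Position 0 has angle 0 everywhere: sin columns are 0, cos columns are 1.
assert torch.allclose(pos_embed[0, 0::2], torch.zeros(4))
assert torch.allclose(pos_embed[0, 1::2], torch.ones(4))
# Each (sin, cos) pair shares an angle, so sin^2 + cos^2 == 1 per pair.
assert torch.allclose(pos_embed[:, 0::2] ** 2 + pos_embed[:, 1::2] ** 2, torch.ones(16, 4))
```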
@@ -837,9 +758,7 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
        x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi

        inverse_frequency = self._temperature ** (
            2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
        )

        x_range = x_range.unsqueeze(-1) / inverse_frequency  # (1, H, W, 1)
@@ -847,15 +766,9 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
        # Note: this stack then flatten operation results in interleaved sine and cosine terms.
        # pos_embed_x and pos_embed_y are (1, H, W, C // 2).
        pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
        pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
        pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2)  # (1, C, H, W)

        return pos_embed
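For intuition only: the 2D variant gives the y-axis and x-axis half of the channels each and concatenates them channel-wise. A rough shape-only sketch under that assumption (it ignores the interleaving and frequency spectrum above):

# Rough shape-only sketch of the 2d layout (assumed simplification, illustrative only).
import torch

C, H, W = 64, 5, 7  # hypothetical feature-map size; C must be even
y = torch.linspace(0, 2 * torch.pi, H).view(1, H, 1, 1).expand(1, H, W, C // 2)
x = torch.linspace(0, 2 * torch.pi, W).view(1, 1, W, 1).expand(1, H, W, C // 2)
pos = torch.cat((y.sin(), x.sin()), dim=3).permute(0, 3, 1, 2)
assert pos.shape == (1, C, H, W)  # one position channel per feature channel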

View File

@@ -205,16 +205,11 @@ class DiffusionConfig(PreTrainedConfig):
    def validate_features(self) -> None:
        if len(self.image_features) == 0 and self.env_state_feature is None:
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

        if self.crop_shape is not None:
            for key, image_ft in self.image_features.items():
                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                    raise ValueError(
                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
                        f"for `crop_shape` and {image_ft.shape} for "

View File

@@ -70,9 +70,7 @@ class DiffusionPolicy(PreTrainedPolicy):
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
@@ -99,9 +97,7 @@ class DiffusionPolicy(PreTrainedPolicy):
        if self.config.image_features:
            self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
        if self.config.env_state_feature:
            self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

    @torch.no_grad
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -127,9 +123,7 @@ class DiffusionPolicy(PreTrainedPolicy):
        """
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = torch.stack(
                [batch[key] for key in self.config.image_features], dim=-4
            )
@@ -138,11 +132,7 @@ class DiffusionPolicy(PreTrainedPolicy):
        if len(self._queues["action"]) == 0:
            # stack n latest observations from the queue
            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
            actions = self.diffusion.generate_actions(batch)

            # TODO(rcadene): make above methods return output dictionary?
@@ -157,9 +147,7 @@ class DiffusionPolicy(PreTrainedPolicy):
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = torch.stack(
                [batch[key] for key in self.config.image_features], dim=-4
            )
@@ -201,9 +189,7 @@ class DiffusionModel(nn.Module):
        if self.config.env_state_feature:
            global_cond_dim += self.config.env_state_feature.shape[0]

        self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)

        self.noise_scheduler = _make_noise_scheduler(
            config.noise_scheduler_type,
@@ -249,9 +235,7 @@ class DiffusionModel(nn.Module):
                global_cond=global_cond,
            )
            # Compute previous image: x_t -> x_t-1
            sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample

        return sample
@@ -263,15 +247,11 @@ class DiffusionModel(nn.Module):
        if self.config.image_features:
            if self.config.use_separate_rgb_encoder_per_camera:
                # Combine batch and sequence dims while rearranging to make the camera index dimension first.
                images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
                img_features_list = torch.cat(
                    [
                        encoder(images)
                        for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
                    ]
                )
                # Separate batch and sequence dims back out. The camera index dim gets absorbed into the
@@ -285,9 +265,7 @@ class DiffusionModel(nn.Module):
            else:
                # Combine batch, sequence, and "which camera" dims before passing to shared encoder.
                img_features = self.rgb_encoder(
                    einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
                )
                # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
                # feature dim (effectively concatenating the camera features).
@@ -381,9 +359,7 @@ class DiffusionModel(nn.Module):
        elif self.config.prediction_type == "sample":
            target = batch["action"]
        else:
            raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")

        loss = F.mse_loss(pred, target, reduction="none")
@@ -443,9 +419,7 @@ class SpatialSoftmax(nn.Module):
        # we could use torch.linspace directly but that seems to behave slightly differently than numpy
        # and causes a small degradation in pc_success of pre-trained models.
        pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
        pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
        # register as buffer so it's moved to the correct device.
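SpatialSoftmax turns each feature-map channel into an expected (x, y) keypoint: a softmax over the H*W locations followed by a weighted average of the coordinate grid. A minimal sketch of that expectation (shapes hypothetical, not the class itself):

# Minimal spatial-softmax expectation (hypothetical shapes, illustrative only).
import torch

B, C, H, W = 2, 32, 12, 12
features = torch.randn(B, C, H, W)
xs, ys = torch.meshgrid(torch.linspace(-1.0, 1.0, W), torch.linspace(-1.0, 1.0, H), indexing="xy")
grid = torch.stack([xs.reshape(-1), ys.reshape(-1)], dim=-1)  # (H*W, 2) coordinates

attention = torch.softmax(features.reshape(B * C, H * W), dim=-1)  # prob over locations
keypoints = (attention @ grid).reshape(B, C, 2)  # expected (x, y) per channel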
@@ -487,9 +461,7 @@ class DiffusionRgbEncoder(nn.Module):
            # Always use center crop for eval
            self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
            if config.crop_is_random:
                self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
            else:
                self.maybe_random_crop = self.center_crop
        else:
@@ -510,9 +482,7 @@ class DiffusionRgbEncoder(nn.Module):
            self.backbone = _replace_submodules(
                root_module=self.backbone,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
            )

        # Set up pooling and final layers.
@@ -523,15 +493,11 @@ class DiffusionRgbEncoder(nn.Module):
        # Note: we have a check in the config class to make sure all images have the same shape.
        images_shape = next(iter(config.image_features.values())).shape
        dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
        dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
        feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]

        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
        self.feature_dim = config.spatial_softmax_num_keypoints * 2
        self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
        self.relu = nn.ReLU()
@@ -573,11 +539,7 @@ def _replace_submodules(
    if predicate(root_module):
        return func(root_module)

    replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
    for *parents, k in replace_list:
        parent_module = root_module
        if len(parents) > 0:
@@ -592,9 +554,7 @@ def _replace_submodules(
        else:
            setattr(parent_module, k, tgt_module)
    # verify that all BN are replaced
    assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
    return root_module
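Assuming _replace_submodules behaves as above, swapping every BatchNorm2d in a torchvision backbone for GroupNorm could look like this (illustrative usage, not from the repo):

# Illustrative BatchNorm -> GroupNorm swap on a torchvision backbone.
import torch.nn as nn
import torchvision

backbone = torchvision.models.resnet18(weights=None)
backbone = _replace_submodules(
    root_module=backbone,
    predicate=lambda m: isinstance(m, nn.BatchNorm2d),
    func=lambda m: nn.GroupNorm(num_groups=m.num_features // 16, num_channels=m.num_features),
)
assert not any(isinstance(m, nn.BatchNorm2d) for m in backbone.modules())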
@@ -622,9 +582,7 @@ class DiffusionConv1dBlock(nn.Module):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        )
@@ -647,13 +605,9 @@ class DiffusionConditionalUnet1d(nn.Module):
        # Encoder for the diffusion timestep.
        self.diffusion_step_encoder = nn.Sequential(
            DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
            nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
            nn.Mish(),
            nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
        )

        # The FiLM conditioning dimension.
@@ -678,16 +632,10 @@ class DiffusionConditionalUnet1d(nn.Module):
            self.down_modules.append(
                nn.ModuleList(
                    [
                        DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
                        DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
                        # Downsample as long as it is not the last block.
                        nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
                    ]
                )
            )
@@ -716,24 +664,16 @@ class DiffusionConditionalUnet1d(nn.Module):
                nn.ModuleList(
                    [
                        # dim_in * 2, because it takes the encoder's skip connection as well
                        DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
                        DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
                        # Upsample as long as it is not the last block.
                        nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
                    ]
                )
            )

        self.final_conv = nn.Sequential(
            DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
            nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
        )
@@ -801,23 +741,17 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
        self.use_film_scale_modulation = use_film_scale_modulation
        self.out_channels = out_channels

        self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)

        # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
        cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
        self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))

        self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)

        # A final convolution for dimension matching the residual (if needed).
        self.residual_conv = (
            nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        )

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:

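The FiLM conditioning used by this block reduces to a per-channel affine transform predicted from the conditioning vector. A self-contained sketch with hypothetical dimensions (scale modulation enabled):

# Self-contained FiLM sketch: cond -> (scale, bias), applied per channel (assumed dims).
import torch
import torch.nn as nn

B, C, T, cond_dim = 4, 64, 16, 128  # hypothetical sizes
x = torch.randn(B, C, T)
cond = torch.randn(B, cond_dim)

cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, C * 2))
scale, bias = cond_encoder(cond).chunk(2, dim=-1)  # each (B, C)
out = scale.unsqueeze(-1) * x + bias.unsqueeze(-1)  # broadcast over the time axis
assert out.shape == x.shape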
View File

@@ -111,9 +111,7 @@ def make_policy(
        PreTrainedPolicy: _description_
    """
    if bool(ds_meta) == bool(env_cfg):
        raise ValueError("Either one of a dataset metadata or a sim env must be provided.")

    # NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
    # TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
@@ -143,12 +141,8 @@ def make_policy(
        )
        features = env_to_policy_features(env_cfg)

        cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
        cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
    kwargs["config"] = cfg

    if cfg.pretrained_path:

View File

@@ -7,9 +7,7 @@ from torch import Tensor, nn
from .configuration_classifier import ClassifierConfig

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
@@ -53,9 +51,7 @@ class Classifier(
        super().__init__()
        self.config = config
        # self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
        encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
        # Extract vision model if we're given a multimodal model
        if hasattr(encoder, "vision_model"):
            logging.info("Multimodal model detected - using vision encoder only")
@@ -81,9 +77,7 @@ class Classifier(
                self.feature_dim = self.encoder.fc.in_features
                self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
            elif hasattr(self.encoder.config, "hidden_sizes"):
                self.feature_dim = self.encoder.config.hidden_sizes[-1]  # Last channel dimension
            else:
                raise ValueError("Unsupported CNN architecture")
@@ -103,9 +97,7 @@ class Classifier(
            if hasattr(self.encoder.config, "hidden_size"):
                input_dim = self.encoder.config.hidden_size
            else:
                raise ValueError("Unsupported transformer architecture since hidden_size is not found")

        self.classifier_head = nn.Sequential(
            nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
@@ -141,10 +133,7 @@ class Classifier(
                return features
        else:  # Transformer models
            outputs = self.encoder(processed)
            if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
                return outputs.pooler_output
            return outputs.last_hidden_state[:, 0, :]
@@ -160,9 +149,7 @@ class Classifier(
        else:
            probabilities = torch.softmax(logits, dim=-1)

        return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)

    def predict_reward(self, x, threshold=0.6):
        if self.config.num_classes == 2:

View File

@@ -82,43 +82,25 @@ def create_stats_buffers(
        if stats:
            if isinstance(stats[key]["mean"], np.ndarray):
                if norm_mode is NormalizationMode.MEAN_STD:
                    buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
                    buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
                elif norm_mode is NormalizationMode.MIN_MAX:
                    buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
                    buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
            elif isinstance(stats[key]["mean"], torch.Tensor):
                # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
                # tensors anywhere (for example, when we use the same stats for normalization and
                # unnormalization). See the logic here
                # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
                if norm_mode is NormalizationMode.MEAN_STD:
                    buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
                    buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
                elif norm_mode is NormalizationMode.MIN_MAX:
                    buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
                    buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
            else:
                type_ = type(stats[key]["mean"])
                raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")

        stats_buffers[key] = buffer

    return stats_buffers
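For reference, the buffers above are typically consumed as mean/std standardization or min/max scaling; a small sketch of both formulas (assumed application, the exact target range may differ):

# Assumed application of the two normalization modes (illustrative only).
import torch

x = torch.tensor([0.0, 5.0, 10.0])
x_mean_std = (x - x.mean()) / (x.std() + 1e-8)  # MEAN_STD: zero mean, unit variance

x_min, x_max = x.min(), x.max()
x_min_max = (x - x_min) / (x_max - x_min + 1e-8) * 2 - 1  # MIN_MAX: scale into [-1, 1]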

View File

@@ -44,9 +44,7 @@ def main():
    else:
        dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"

    ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
    ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
    save_dir = Path(f"../openpi/data/{model_name}/save")
@@ -72,9 +70,7 @@ def main():
    # Create LeRobot batch from Jax
    batch = {}
    for cam_key, uint_chw_array in example["images"].items():
        batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
    batch["observation.state"] = torch.from_numpy(example["state"])
    batch["action"] = torch.from_numpy(outputs["actions"])
    batch["task"] = example["prompt"]

View File

@@ -54,9 +54,7 @@ def get_paligemma_config(precision: str):
        "projector_hidden_act": "gelu_fast",
        "vision_use_head": False,
    }

    final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
    return final_config

View File

@@ -322,9 +322,7 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
    return {f"{prefix}{key}": value for key, value in d.items()}


def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
    # Break down orbax ckpts - they are in OCDBT
    initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
    # process projection params
@@ -384,9 +382,7 @@ def convert_pi0_checkpoint(
    # gemma_config=gemma_config, paligemma_config=paligemma_config)
    pi0_model = PI0Policy(pi0_config)

    paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
    gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
    projection_params = update_keys_with_prefix(projection_params, "model.")

View File

@@ -193,9 +193,7 @@ def aloha_gripper_to_angular(value):
    # This is the inverse of the angular to linear transformation inside the Interbotix code.
    def linear_to_radian(linear_position, arm_length, horn_radius):
        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return safe_arcsin(value)
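linear_to_radian is a law-of-cosines inversion of the gripper linkage: with horn radius h, arm length a and linear extension l, the horn angle is arcsin((h^2 + l^2 - a^2) / (2*h*l)). A quick numeric sanity check (link dimensions made up; math.asin clamped as a stand-in for safe_arcsin):

# Numeric sanity check of the linkage inversion (made-up link dimensions).
import math

def linear_to_radian(linear_position, arm_length, horn_radius):
    value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
    return math.asin(max(-1.0, min(1.0, value)))  # clamped stand-in for safe_arcsin

angle = linear_to_radian(linear_position=0.035, arm_length=0.036, horn_radius=0.022)
assert -math.pi / 2 <= angle <= math.pi / 2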
    # The constants are taken from the Interbotix code.
@@ -246,9 +244,7 @@ class PI0Policy(PreTrainedPolicy):
        super().__init__(config)
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
@@ -256,9 +252,7 @@ class PI0Policy(PreTrainedPolicy):
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
        self.model = PI0FlowMatching(config)

        self.reset()
@@ -271,9 +265,7 @@ class PI0Policy(PreTrainedPolicy):
        return self.parameters()

    @torch.no_grad
    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
        """Select a single action given environment observations.

        This method wraps `select_actions` in order to return one action at a time for execution in the
@@ -312,9 +304,7 @@ class PI0Policy(PreTrainedPolicy):
            self._action_queue.extend(actions.transpose(0, 1))
        return self._action_queue.popleft()

    def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
        """Do a full training forward pass to compute the loss"""
        if self.config.adapt_to_pi_aloha:
            batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
@@ -330,9 +320,7 @@ class PI0Policy(PreTrainedPolicy):
        actions_is_pad = batch.get("action_is_pad")

        loss_dict = {}
        losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
        loss_dict["losses_after_forward"] = losses.clone()

        if actions_is_pad is not None:
@@ -359,9 +347,7 @@ class PI0Policy(PreTrainedPolicy):
        img_masks = []

        present_img_keys = [key for key in self.config.image_features if key in batch]
        missing_img_keys = [key for key in self.config.image_features if key not in batch]

        if len(present_img_keys) == 0:
            raise ValueError(
@@ -373,9 +359,7 @@ class PI0Policy(PreTrainedPolicy):
            img = batch[key]

            if self.config.resize_imgs_with_padding is not None:
                img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)

            # Normalize from range [0,1] to [-1,1] as expected by siglip
            img = img * 2.0 - 1.0
@@ -414,9 +398,7 @@ class PI0Policy(PreTrainedPolicy):
            return_tensors="pt",
        )

        lang_tokens = tokenized_prompt["input_ids"].to(device=device)
        lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)

        return lang_tokens, lang_masks
@@ -435,9 +417,7 @@ class PI0Policy(PreTrainedPolicy):
            actions[:, :, motor_idx] *= -1
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
        return actions

    def _pi_aloha_encode_actions_inv(self, actions):
@@ -446,9 +426,7 @@ class PI0Policy(PreTrainedPolicy):
            actions[:, :, motor_idx] *= -1
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
        return actions

    def prepare_state(self, batch):
@@ -498,25 +476,15 @@ class PI0FlowMatching(nn.Module):
            train_expert_only=self.config.train_expert_only,
            attention_implementation=self.config.attention_implementation,
        )
        self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)

        # Projections are float32
        self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
        self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
        self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)

        self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
        self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)

        self.set_requires_grad()
@@ -560,9 +528,7 @@ class PI0FlowMatching(nn.Module):
            # Normalize image embeddings
            img_emb_dim = img_emb.shape[-1]
            img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)

            bsize, num_img_embs = img_emb.shape[:2]
            img_mask = img_mask[:, None].expand(bsize, num_img_embs)
@@ -637,9 +603,7 @@ class PI0FlowMatching(nn.Module):
        embs.append(action_time_emb)

        bsize, action_time_dim = action_time_emb.shape[:2]
        action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
        pad_masks.append(action_time_mask)

        # Set attention masks so that image, language and state inputs do not attend to action tokens
@@ -677,9 +641,7 @@ class PI0FlowMatching(nn.Module):
        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )
        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)

        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
@@ -703,9 +665,7 @@ class PI0FlowMatching(nn.Module):
        losses = F.mse_loss(u_t, v_t, reduction="none")
        return losses

    def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
        """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
        bsize = state.shape[0]
        device = state.device
@@ -763,16 +723,12 @@ class PI0FlowMatching(nn.Module):
        timestep,
    ):
        """Apply one denoising step of the noise `x_t` at a given timestep."""
        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)

        suffix_len = suffix_pad_masks.shape[1]
        batch_size = prefix_pad_masks.shape[0]
        prefix_len = prefix_pad_masks.shape[1]
        prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)

        suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)

View File

@@ -39,13 +39,9 @@ def apply_rope(x, positions, max_wavelength=10_000):
    dtype = x.dtype
    x = x.to(torch.float32)

    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)

    radians = radians[..., None, :]
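apply_rope rotates channel pairs of the queries/keys by position-dependent angles built from the timescales above. A compact sketch of one common pairing convention (shapes hypothetical; the file may pair channels differently):

# Compact RoPE sketch (hypothetical shapes, one common pairing convention).
import torch

B, L, H, D = 1, 6, 2, 8  # batch, seq len, heads, head dim (even)
x = torch.randn(B, L, H, D)
positions = torch.arange(L).expand(B, L)

freq_exponents = (2.0 / D) * torch.arange(D // 2, dtype=torch.float32)
timescale = 10_000**freq_exponents
radians = positions[..., None].float() / timescale[None, None, :]  # (B, L, D/2)
sin, cos = radians.sin()[..., None, :], radians.cos()[..., None, :]

x1, x2 = x[..., ::2], x[..., 1::2]  # split channels into rotation pairs
rotated = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
assert rotated.shape == x.shape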
@@ -178,9 +174,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
    def __init__(self, config: PaliGemmaWithExpertConfig):
        super().__init__(config=config)
        self.config = config
        self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
        self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
        # Remove unused embed_tokens
        self.gemma_expert.model.embed_tokens = None
@@ -297,9 +291,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
                    # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
                    # the max len, then we (for instance) double the cache size. This implementation already exists
                    # in `transformers`. (molbap)
                    key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
                    value_states = torch.cat(
                        [past_key_values[layer_idx]["value_states"], value_states],
                        dim=1,
                    )
@@ -392,9 +384,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
        value_states,
    ):
        num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
        num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
        num_key_value_groups = num_att_heads // num_key_value_heads

        # query_states: batch_size, sequence_length, num_att_head, head_dim
@@ -442,9 +432,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
        att_weights *= head_dim**-0.5
        big_neg = -2.3819763e38  # See gemma/modules.py

        masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)

        probs = nn.functional.softmax(masked_att_weights, dim=-1)
        probs = probs.to(dtype=value_states.dtype)
@@ -456,8 +444,6 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
        att_output = att_output.permute(0, 2, 1, 3)
        # we use -1 because sequence length can change
        att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)

        return att_output

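The num_key_value_groups computed in this attention path is grouped-query attention: fewer key/value heads than query heads, each KV head shared by a group of query heads. A minimal sketch of the expansion (sizes hypothetical):

# Minimal grouped-query-attention KV expansion (hypothetical sizes).
import torch

B, L, head_dim = 2, 10, 16
num_att_heads, num_kv_heads = 8, 2
groups = num_att_heads // num_kv_heads  # query heads sharing each KV head

k = torch.randn(B, num_kv_heads, L, head_dim)
k_expanded = k.repeat_interleave(groups, dim=1)
assert k_expanded.shape == (B, num_att_heads, L, head_dim)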
View File

@@ -71,9 +71,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
    def _save_pretrained(self, save_directory: Path) -> None:
        self.config._save_pretrained(save_directory)
        model_to_save = self.module if hasattr(self, "module") else self
        save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))

    @classmethod
    def from_pretrained(
@@ -112,9 +110,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
            policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
        else:
            try:
                model_file = hf_hub_download(
@@ -128,9 +124,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
                    token=token,
                    local_files_only=local_files_only,
                )
                policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
@@ -141,12 +135,8 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
        return policy

    @classmethod
    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
            load_model_as_safetensor(model, model_file, strict=strict)
            if map_location != "cpu":
                logging.warning(
@@ -157,9 +147,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
                )
                model.to(map_location)
        else:
            safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
        return model

    # def generate_model_card(self, *args, **kwargs) -> ModelCard:

View File

@@ -48,9 +48,7 @@ class SACConfig:
            "observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
        }
    )
    output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
    output_normalization_params: dict[str, dict[str, list[float]]] = field(
        default_factory=lambda: {
            "action": {"min": [-1, -1], "max": [1, 1]},

View File

@@ -18,8 +18,8 @@
# TODO: (1) better device management

import math
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union

import einops
import numpy as np
@@ -124,17 +124,13 @@ class SACPolicy(
        self.actor = Policy(
            encoder=encoder_actor,
            network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
            action_dim=config.output_shapes["action"][0],
            encoder_is_shared=config.shared_encoder,
            **config.policy_kwargs,
        )

        if config.target_entropy is None:
            config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2  # (-dim(A)/2)

        # TODO (azouitine): Handle the case where the temperature parameter is fixed
        # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@@ -146,10 +142,11 @@ class SACPolicy(
    def _save_pretrained(self, save_directory):
        """Custom save method to handle TensorDict properly"""
        import json
        import os
        from dataclasses import asdict

        from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
        from safetensors.torch import save_model

        save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
@@ -177,12 +174,14 @@ class SACPolicy(
        **model_kwargs,
    ) -> "SACPolicy":
        """Custom load method to handle loading SAC policy from saved files"""
        import json
        import os
        from pathlib import Path

        from huggingface_hub import hf_hub_download
        from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
        from safetensors.torch import load_model

        from lerobot.common.policies.sac.configuration_sac import SACConfig

        # Check if model_id is a local path or a hub model ID
@@ -302,14 +301,10 @@ class SACPolicy(
    ) -> Tensor:
        self.temperature = self.log_alpha.exp().item()

        with torch.no_grad():
            next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)

            # TODO: (maractingi, azouitine) This is too slow, we should find a more efficient way to do this
            next_action_preds = self.unnormalize_outputs({"action": next_action_preds})["action"]

            # 2- compute q targets
            q_targets = self.critic_forward(
@@ -353,21 +348,15 @@ class SACPolicy(
        ).sum()
        return critics_loss

    def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
        """Compute the temperature loss"""
        # calculate temperature loss
        with torch.no_grad():
            _, log_probs, _ = self.actor(observations, observation_features)
        temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
        return temperature_loss
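The temperature objective tunes alpha so the policy entropy tracks target_entropy: gradient descent on this loss raises alpha when entropy falls below the target and lowers it when entropy is already high. A tiny directional check (values made up):

# Tiny directional check of the SAC temperature objective (made-up values).
import torch

log_alpha = torch.zeros(1, requires_grad=True)
target_entropy = -4.0            # e.g. the -dim(A)/2 heuristic with dim(A) = 8
log_probs = torch.tensor([2.0])  # entropy ~ -2.0, above the target

loss = (-log_alpha.exp() * (log_probs + target_entropy)).mean()
loss.backward()
assert log_alpha.grad > 0  # descent lowers alpha: entropy is already high enough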
    def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
        self.temperature = self.log_alpha.exp().item()

        actions_pi, log_probs, _ = self.actor(observations, observation_features)
@@ -408,11 +397,7 @@ class MLP(nn.Module):
            if dropout_rate is not None and dropout_rate > 0:
                layers.append(nn.Dropout(p=dropout_rate))
            layers.append(nn.LayerNorm(hidden_dims[0]))
            layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())

        # Rest of the layers
        for i in range(1, len(hidden_dims)):
@@ -424,11 +409,7 @@ class MLP(nn.Module):
                    layers.append(nn.LayerNorm(hidden_dims[i]))

                # If we're at the final layer and a final activation is specified, use it
                if i + 1 == len(hidden_dims) and activate_final and final_activation is not None:
                    layers.append(
                        final_activation
                        if isinstance(final_activation, nn.Module)
@@ -436,9 +417,7 @@ class MLP(nn.Module):
                        else getattr(nn, final_activation)()
                    )
                else:
                    layers.append(
                        activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
                    )

        self.net = nn.Sequential(*layers)
@@ -639,15 +618,11 @@ class Policy(nn.Module):
        # Compute standard deviations
        if self.fixed_std is None:
            log_std = self.std_layer(outputs)
            assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"

            if self.use_tanh_squash:
                log_std = torch.tanh(log_std)
                log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
            else:
                log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        else:
@@ -660,9 +635,7 @@ class Policy(nn.Module):
        if self.use_tanh_squash:
            actions = torch.tanh(x_t)
            log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)  # Adjust log-probs for Tanh
        else:
            actions = x_t  # No Tanh; raw Gaussian sample
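The log(1 - tanh(u)^2) term is the change-of-variables correction for squashing a Gaussian sample through tanh. A standalone sketch (summing over action dimensions is an assumption here):

# Standalone tanh change-of-variables correction (illustrative only).
import torch
from torch.distributions import Normal

dist = Normal(loc=torch.zeros(3), scale=torch.ones(3))
x_t = dist.rsample()       # pre-squash Gaussian sample
actions = torch.tanh(x_t)  # squashed into (-1, 1)

log_probs = dist.log_prob(x_t)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)  # |d tanh(u)/du| = 1 - tanh(u)^2
log_probs = log_probs.sum(-1)  # assumed: sum over action dimensions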
@ -709,9 +682,7 @@ class SACObservationEncoder(nn.Module):
freeze_image_encoder(self.image_enc_layers) freeze_image_encoder(self.image_enc_layers)
else: else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters()) self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [ self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
k for k in config.input_shapes if k.startswith("observation.image")
]
if "observation.state" in config.input_shapes: if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential( self.state_enc_layers = nn.Sequential(
@ -738,9 +709,7 @@ class SACObservationEncoder(nn.Module):
self.aggregation_size += config.latent_dim self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters()) self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear( self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
in_features=self.aggregation_size, out_features=config.latent_dim
)
self.parameters_to_optimize += list(self.aggregation_layer.parameters()) self.parameters_to_optimize += list(self.aggregation_layer.parameters())
    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
@@ -753,19 +722,13 @@ class SACObservationEncoder(nn.Module):
        obs_dict = self.input_normalization(obs_dict)
        # Batch all images along the batch dimension, then encode them.
        if len(self.all_image_keys) > 0:
            images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
            images_batched = self.image_enc_layers(images_batched)
            embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
            feat.extend(embeddings_chunks)

        if "observation.environment_state" in self.config.input_shapes:
            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
        if "observation.state" in self.config.input_shapes:
            feat.append(self.state_enc_layers(obs_dict["observation.state"]))
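A sketch of the single-pass multi-camera trick above (function name and shapes are assumptions): all camera views are concatenated along the batch dimension, encoded once, then the embeddings are chunked back into one tensor per camera, avoiding N separate encoder calls.

import torch

def encode_all_cameras(encoder, obs: dict, image_keys: list[str]) -> list[torch.Tensor]:
    images = torch.cat([obs[k] for k in image_keys], dim=0)  # (num_cams * B, C, H, W)
    embeddings = encoder(images)                             # (num_cams * B, D)
    return list(torch.chunk(embeddings, chunks=len(image_keys), dim=0))  # num_cams x (B, D)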
@@ -833,9 +796,7 @@ class PretrainedImageEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
        self.image_enc_proj = nn.Sequential(
            nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
            nn.LayerNorm(config.latent_dim),
@@ -846,21 +807,15 @@ class PretrainedImageEncoder(nn.Module):
        """Set up CNN encoder"""
        from transformers import AutoModel

        self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
        # self.image_enc_layers.pooler = Identity()
        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
        elif hasattr(self.image_enc_layers, "fc"):
            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
        else:
            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
        return self.image_enc_layers, self.image_enc_out_shape
    def forward(self, x):
@@ -896,9 +851,7 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
        for key, value in inner_dict.items():
            converted_params[outer_key][key] = torch.tensor(value)
            if "image" in outer_key:
                converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)

    return converted_params
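Why view(3, 1, 1) above: per-channel image statistics stored as 3 scalars must broadcast against (C, H, W) images. A toy sketch (the mean values are just the familiar ImageNet ones, used here for illustration):

import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)  # (3,) -> (3, 1, 1)
img = torch.rand(3, 224, 224)
normalized = img - mean  # broadcasts per channel over H and W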
View File
@@ -183,13 +183,9 @@ class TDMPCConfig(PreTrainedConfig):
                    "If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
                )
            if not self.use_mpc:
                raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
            if self.n_action_steps > self.horizon:
                raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
    def get_optimizer_preset(self) -> AdamConfig:
        return AdamConfig(lr=self.optimizer_lr)
@@ -209,9 +205,7 @@ class TDMPCConfig(PreTrainedConfig):
            if image_ft.shape[-2] != image_ft.shape[-1]:
                # TODO(alexander-soare): This limitation is solely because of code in the random shift
                # augmentation. It should be able to be removed.
                raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.")
    @property
    def observation_delta_indices(self) -> list:
View File
@@ -83,9 +83,7 @@ class TDMPCPolicy(PreTrainedPolicy):
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
@@ -110,9 +108,7 @@ class TDMPCPolicy(PreTrainedPolicy):
        """
        self._queues = {
            "observation.state": deque(maxlen=1),
            "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
        }
        if self.config.image_features:
            self._queues["observation.image"] = deque(maxlen=1)
@@ -127,9 +123,7 @@ class TDMPCPolicy(PreTrainedPolicy):
        """Select a single action given environment observations."""
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.image"] = batch[next(iter(self.config.image_features))]

        self._queues = populate_queues(self._queues, batch)
@@ -232,47 +226,35 @@ class TDMPCPolicy(PreTrainedPolicy):
                self.config.action_feature.shape[0],
                device=std.device,
            )
            gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)

            # Compute elite actions.
            actions = torch.cat([gaussian_actions, pi_actions], dim=1)
            value = self.estimate_value(z, actions).nan_to_num_(0)
            elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices  # (n_elites, batch)
            elite_value = value.take_along_dim(elite_idxs, dim=0)  # (n_elites, batch)
            # (horizon, n_elites, batch, action_dim)
            elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)

            # Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
            max_value = elite_value.max(0, keepdim=True)[0]  # (1, batch)
            # The weighting is a softmax over trajectory values. Note that this is not the same as the usage
            # of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
            # makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
            score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
            score /= score.sum(axis=0, keepdim=True)
            # (horizon, batch, action_dim)
            _mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
            _std = torch.sqrt(
                torch.sum(
                    einops.rearrange(score, "n b -> n b 1")
                    * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
                    dim=1,
                )
            )
            # Update mean with an exponential moving average, and std with a direct replacement.
            mean = (
                self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
            )
            std = _std.clamp_(self.config.min_std, self.config.max_std)
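A toy numeric illustration of the softmax weighting described in the comments above, s = Ω/ΣΩ, μ = Σ(s⋅Γ), σ = sqrt(Σ(s⋅(Γ-μ)²)); all values here are made up.

import torch

elite_value = torch.tensor([1.0, 2.0, 4.0])          # (n_elites,)
elite_actions = torch.tensor([[0.1], [0.3], [0.8]])  # (n_elites, action_dim)
temperature = 0.5

score = torch.exp(temperature * (elite_value - elite_value.max()))
score = score / score.sum()                              # s = Ω/ΣΩ
mu = (score.unsqueeze(-1) * elite_actions).sum(dim=0)    # μ = Σ(s⋅Γ)
sigma = ((score.unsqueeze(-1) * (elite_actions - mu) ** 2).sum(dim=0)).sqrt()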
@@ -281,9 +263,7 @@ class TDMPCPolicy(PreTrainedPolicy):
        # Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
        # scores from the last iteration.
        actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]

        return actions
@@ -306,8 +286,7 @@ class TDMPCPolicy(PreTrainedPolicy):
        # of the FOWM paper.
        if self.config.uncertainty_regularizer_coeff > 0:
            regularization = -(
                self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
            )
        else:
            regularization = 0
@@ -328,9 +307,7 @@ class TDMPCPolicy(PreTrainedPolicy):
            G += (
                running_discount
                * torch.min(
                    terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))],
                    dim=0,
                )[0]
            )
@@ -338,11 +315,7 @@ class TDMPCPolicy(PreTrainedPolicy):
            G += running_discount * torch.min(terminal_values, dim=0)[0]
        # Finally, also regularize the terminal value.
        if self.config.uncertainty_regularizer_coeff > 0:
            G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
        return G
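A sketch of the "min over a random pair from the Q-ensemble" trick above, a REDQ-style variant of clipped double-Q; the tensors and shapes are toy stand-ins.

import torch

q_values = torch.randn(5, 4)                               # (ensemble_size, batch)
pair = q_values[torch.randint(0, q_values.shape[0], (2,))]  # random 2-member subset
pessimistic_q = torch.min(pair, dim=0)[0]                   # (batch,) lower-bound estimate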
    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
@@ -354,9 +327,7 @@ class TDMPCPolicy(PreTrainedPolicy):
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.image"] = batch[next(iter(self.config.image_features))]
        batch = self.normalize_targets(batch)
@@ -388,29 +359,21 @@ class TDMPCPolicy(PreTrainedPolicy):
            current_observation[k] = observations[k][0]
            next_observations[k] = observations[k][1:]
        horizon, batch_size = next_observations[
            "observation.image" if self.config.image_features else "observation.environment_state"
        ].shape[:2]

        # Run latent rollout using the latent dynamics model and policy model.
        # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
        # gives us a next `z`.
        batch_size = batch["index"].shape[0]
        z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
        z_preds[0] = self.model.encode(current_observation)
        reward_preds = torch.empty_like(reward, device=device)
        for t in range(horizon):
            z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])

        # Compute Q and V value predictions based on the latent rollout.
        q_preds_ensemble = self.model.Qs(z_preds[:-1], action)  # (ensemble, horizon, batch)
        v_preds = self.model.V(z_preds[:-1])
        info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
@@ -424,14 +387,10 @@ class TDMPCPolicy(PreTrainedPolicy):
            # actions (not actions estimated by π).
            # Note: Here we do not use self.model_target, but self.model. This is to follow the original code
            # and the FOWM paper.
            q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
            # From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
            # are using them to compute loss for V.
            v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
        # Compute losses.
        # Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
@@ -474,9 +433,7 @@ class TDMPCPolicy(PreTrainedPolicy):
            temporal_loss_coeffs
            * F.mse_loss(
                q_preds_ensemble,
                einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
                reduction="none",
            ).sum(0)  # sum over ensemble
            # `q_preds_ensemble` depends on the first observation and the actions.
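A sketch of the exponential temporal weighting mentioned above (the decay value is illustrative, not the repo's default): predictions further in the future contribute less to the loss.

import torch

decay, horizon = 0.5, 5
temporal_loss_coeffs = decay ** torch.arange(horizon)
# tensor([1.0000, 0.5000, 0.2500, 0.1250, 0.0625]); multiply per-step losses by this.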
@@ -514,14 +471,12 @@ class TDMPCPolicy(PreTrainedPolicy):
        z_preds = z_preds.detach()
        # Use stopgrad for the advantage calculation.
        with torch.no_grad():
            advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
                z_preds[:-1]
            )
            info["advantage"] = advantage[0]
            # (t, b)
            exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
        action_preds = self.model.pi(z_preds[:-1])  # (t, b, a)
        # Calculate the MSE between the actions and the action predictions.
        # Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
@@ -575,9 +530,7 @@ class TDMPCPolicy(PreTrainedPolicy):
        # Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
        # update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
        # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
        update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
class TDMPCTOLD(nn.Module):
@@ -588,9 +541,7 @@ class TDMPCTOLD(nn.Module):
        self.config = config
        self._encoder = TDMPCObservationEncoder(config)
        self._dynamics = nn.Sequential(
            nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
            nn.LayerNorm(config.mlp_dim),
            nn.Mish(),
            nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -601,9 +552,7 @@ class TDMPCTOLD(nn.Module):
            nn.Sigmoid(),
        )
        self._reward = nn.Sequential(
            nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
            nn.LayerNorm(config.mlp_dim),
            nn.Mish(),
            nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -671,9 +620,7 @@ class TDMPCTOLD(nn.Module):
                "Sanity check. The last linear layer needs 0 initialization on weights."
            )
            nn.init.zeros_(m[-1].weight)
            nn.init.zeros_(m[-1].bias)  # this has already been done, but keep this line here for good measure
    def encode(self, obs: dict[str, Tensor]) -> Tensor:
        """Encodes an observation into its latent representation."""
@@ -812,9 +759,7 @@ class TDMPCObservationEncoder(nn.Module):
        if config.robot_state_feature:
            self.state_enc_layers = nn.Sequential(
                nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
                nn.ELU(),
                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
                nn.LayerNorm(config.latent_dim),
@@ -823,9 +768,7 @@ class TDMPCObservationEncoder(nn.Module):
        if config.env_state_feature:
            self.env_state_enc_layers = nn.Sequential(
                nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
                nn.ELU(),
                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
                nn.LayerNorm(config.latent_dim),
@@ -898,10 +841,7 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
        assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
        if isinstance(p, dict):
            raise RuntimeError("Dict parameter not supported")
        if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
            # Copy BatchNorm parameters, and non-trainable parameters directly.
            p_ema.copy_(p.to(dtype=p_ema.dtype).data)
        with torch.no_grad():
@@ -909,9 +849,7 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
            p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
    """Helper to temporarily flatten extra dims at the start of the image tensor.

    Args:
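For context, a sketch of what a flatten-forward-unflatten helper like the one above does (an illustrative implementation, not necessarily the repo's exact body): extra leading dims such as (seq, batch) are folded into one batch dim for a 4D-only function, then restored.

import torch

def flatten_forward_unflatten_sketch(fn, image_tensor: torch.Tensor) -> torch.Tensor:
    if image_tensor.ndim == 4:                     # already (b, c, h, w)
        return fn(image_tensor)
    lead_dims = image_tensor.shape[:-3]            # e.g. (seq, batch)
    flat = image_tensor.reshape(-1, *image_tensor.shape[-3:])
    out = fn(flat)
    return out.reshape(*lead_dims, *out.shape[1:])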
View File
@@ -172,10 +172,7 @@ class VQBeTConfig(PreTrainedConfig):
        if self.crop_shape is not None:
            for key, image_ft in self.image_features.items():
                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                    raise ValueError(
                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
                        f"for `crop_shape` and {image_ft.shape} for "
View File
@@ -64,9 +64,7 @@ class VQBeTPolicy(PreTrainedPolicy):
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
@@ -97,17 +95,11 @@ class VQBeTPolicy(PreTrainedPolicy):
        if self.config.sequentially_select:
            decay_params = (
                decay_params
                + list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
                + list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
            )
        else:
            decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())

        return [
            {
@@ -145,12 +137,8 @@ class VQBeTPolicy(PreTrainedPolicy):
        """
        batch = self.normalize_inputs(batch)
        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
        batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
        # Note: It's important that this happens after stacking the images into a single key.
        self._queues = populate_queues(self._queues, batch)
@@ -161,14 +149,8 @@ class VQBeTPolicy(PreTrainedPolicy):
            )
        if len(self._queues["action"]) == 0:
            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
            actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]

            # the dimension of returned action is (batch_size, action_chunk_size, action_dim)
            actions = self.unnormalize_outputs({"action": actions})["action"]
@@ -181,12 +163,8 @@ class VQBeTPolicy(PreTrainedPolicy):
    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
        batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
        batch = self.normalize_targets(batch)
        # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
        if not self.vqbet.action_head.vqvae_model.discretized.item():
@@ -194,9 +172,7 @@ class VQBeTPolicy(PreTrainedPolicy):
            # n_different_codes: how many of the total possible VQ codes are being used in a single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
            # n_different_combinations: how many different code combinations are being used out of all possible combinations in a single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint: consider the RVQ as a decision tree).
            loss, n_different_codes, n_different_combinations, recon_l1_error = (
                self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
            )
            return loss, {
                "n_different_codes": n_different_codes,
@@ -253,9 +229,7 @@ class SpatialSoftmax(nn.Module):
        # we could use torch.linspace directly but that seems to behave slightly differently than numpy
        # and causes a small degradation in pc_success of pre-trained models.
        pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
        pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
        # register as buffer so it's moved to the correct device.
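A minimal sketch of what spatial softmax computes with the grid built above (shapes assumed): a softmax over each feature map gives a spatial distribution, and its expectation over the (x, y) grid yields one keypoint per channel.

import torch
import torch.nn.functional as F

def spatial_softmax_keypoints(features: torch.Tensor) -> torch.Tensor:
    b, c, h, w = features.shape
    attention = F.softmax(features.view(b, c, h * w), dim=-1)  # (b, c, h*w)
    xs = torch.linspace(-1.0, 1.0, w).repeat(h)                # x varies fastest
    ys = torch.linspace(-1.0, 1.0, h).repeat_interleave(w)     # y per row
    kp_x = (attention * xs).sum(-1)                            # (b, c)
    kp_y = (attention * ys).sum(-1)                            # (b, c)
    return torch.stack([kp_x, kp_y], dim=-1)                   # (b, c, 2)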
@@ -370,12 +344,7 @@ class VQBeTModel(nn.Module):
        num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
        self.register_buffer(
            "select_target_actions_indices",
            torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
        )
    def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
@@ -406,19 +375,13 @@ class VQBeTModel(nn.Module):
            input_tokens.append(
                self.state_projector(batch["observation.state"])
            )  # (batch, obs_step, projection dims)
        input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
        # Interleave tokens by stacking and rearranging.
        input_tokens = torch.stack(input_tokens, dim=2)
        input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")

        len_additional_action_token = self.config.n_action_pred_token - 1
        future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)

        # add additional action query tokens for predicting future action chunks
        input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)
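A toy sketch of the interleaving above (shapes are made up): per-timestep modality tokens, e.g. [image, state, action-query] for each step, are stacked and flattened so the sequence reads img_0, state_0, act_0, img_1, state_1, act_1, and so on.

import einops
import torch

b, n_steps, d = 2, 3, 4
img_tok = torch.zeros(b, n_steps, d)
state_tok = torch.ones(b, n_steps, d)
act_tok = torch.full((b, n_steps, d), 2.0)
tokens = torch.stack([img_tok, state_tok, act_tok], dim=2)  # (b, n, 3, d)
tokens = einops.rearrange(tokens, "b n t d -> b (n t) d")   # (b, 3n, d), interleaved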
@@ -427,9 +390,9 @@ class VQBeTModel(nn.Module):
        features = self.policy(input_tokens)
        # len(self.config.input_features) is the number of different observation modes.
        # this line gets the index of action prompt tokens.
        historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
            self.config.input_features
        )

        # only extract the output tokens at the position of action query:
        # Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
@@ -449,15 +412,13 @@ class VQBeTModel(nn.Module):
        action_head_output = self.action_head(features)
        # if rollout, VQ-BeT doesn't calculate loss
        if rollout:
            return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
                batch_size, self.config.action_chunk_size, -1
            )
        # else, it calculates the overall loss (bin prediction loss, and offset loss)
        else:
            output = batch["action"][:, self.select_target_actions_indices]
            loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
            return action_head_output, loss
@@ -492,9 +453,7 @@ class VQBeTHead(nn.Module):
        else:
            self.map_to_cbet_preds_bin = MLP(
                in_channels=config.gpt_output_dim,
                hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
            )
        self.map_to_cbet_preds_offset = MLP(
            in_channels=config.gpt_output_dim,
@@ -521,10 +480,7 @@ class VQBeTHead(nn.Module):
        loss, metric = self.vqvae_model.vqvae_forward(actions)
        n_different_codes = sum(
            [len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
        )
        n_different_combinations = len(torch.unique(metric[2], dim=0))
        recon_l1_error = metric[0].detach().cpu().item()
@@ -585,18 +541,12 @@ class VQBeTHead(nn.Module):
                cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
            )
            sampled_secondary_centers = einops.rearrange(
                torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
                "(NT) 1 -> NT",
                NT=NT,
            )
            sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1)
            cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
        # if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
        else:
            cbet_logits = self.map_to_cbet_preds_bin(x)
@@ -605,9 +555,7 @@ class VQBeTHead(nn.Module):
                "(NT) (G C) -> (NT) G C",
                G=self.vqvae_model.vqvae_num_layers,
            )
            cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
            NT, G, choices = cbet_probs.shape
            sampled_centers = einops.rearrange(
                torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
@@ -627,17 +575,9 @@ class VQBeTHead(nn.Module):
            sampled_offsets = sampled_offsets.sum(dim=1)
        with torch.no_grad():
            # Get the centroids (= vectors corresponding to the codes) of each layer to pass through the RVQ decoder
            return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
            # pass the centroids through the decoder to get actions.
            decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
        # reshape the extracted offsets to match the decoded centroids
        sampled_offsets = einops.rearrange(
            sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
        )
@@ -686,9 +626,7 @@ class VQBeTHead(nn.Module):
        # Figure out the loss for the actions.
        # First, we need to find the closest cluster center for each ground truth action.
        with torch.no_grad():
            state_vq, action_bins = self.vqvae_model.get_code(action_seq)  # action_bins: NT, G
        # Now we can compute the loss.
@@ -711,12 +649,8 @@ class VQBeTHead(nn.Module):
            + cbet_loss2 * self.config.secondary_code_loss_weight
        )

        equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
        equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)

        action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
        vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
@@ -730,9 +664,7 @@ class VQBeTHead(nn.Module):
            "classification_loss": cbet_loss.detach().cpu().item(),
            "offset_loss": offset_loss.detach().cpu().item(),
            "equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
            "equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
            "vq_action_error": vq_action_error.detach().cpu().item(),
            "offset_action_error": offset_action_error.detach().cpu().item(),
            "action_error_max": action_error_max.detach().cpu().item(),
@@ -757,9 +689,7 @@ class VQBeTRgbEncoder(nn.Module):
            # Always use center crop for eval
            self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
            if config.crop_is_random:
                self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
            else:
                self.maybe_random_crop = self.center_crop
        else:
@@ -780,9 +710,7 @@ class VQBeTRgbEncoder(nn.Module):
            self.backbone = _replace_submodules(
                root_module=self.backbone,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
            )

        # Set up pooling and final layers.
@@ -792,15 +720,11 @@ class VQBeTRgbEncoder(nn.Module):
        # height and width from `config.image_features`.

        images_shape = next(iter(config.image_features.values())).shape
        dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
        dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
        feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]

        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
        self.feature_dim = config.spatial_softmax_num_keypoints * 2
        self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
        self.relu = nn.ReLU()
@@ -842,11 +766,7 @@ def _replace_submodules(
    if predicate(root_module):
        return func(root_module)

    replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
    for *parents, k in replace_list:
        parent_module = root_module
        if len(parents) > 0:
@@ -861,9 +781,7 @@ def _replace_submodules(
        else:
            setattr(parent_module, k, tgt_module)
    # verify that all BN are replaced
    assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
    return root_module
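A usage sketch mirroring the call site above, assuming _replace_submodules is importable from this module: swap every BatchNorm2d in a torchvision ResNet for GroupNorm, which behaves better with small batches and EMA-averaged weights.

from torch import nn
import torchvision

backbone = torchvision.models.resnet18(weights=None)
backbone = _replace_submodules(
    root_module=backbone,
    predicate=lambda m: isinstance(m, nn.BatchNorm2d),
    func=lambda m: nn.GroupNorm(num_groups=m.num_features // 16, num_channels=m.num_features),
)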
@@ -896,8 +814,7 @@ class VqVae(nn.Module):
        )

        self.encoder = MLP(
            in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
            hidden_channels=[
                config.vqvae_enc_hidden_dim,
                config.vqvae_enc_hidden_dim,
@@ -925,13 +842,9 @@ class VqVae(nn.Module):
        # given latent vector, this function outputs the decoded action.
        output = self.decoder(latent)
        if self.config.action_chunk_size == 1:
            return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
        else:
            return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
    def get_code(self, state):
        # in phase 2 of VQ-BeT training, we need ground truth labels of the action data to calculate the Focal loss for the code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
View File
@@ -123,15 +123,9 @@ class CausalSelfAttention(nn.Module):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
        k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
@@ -139,9 +133,7 @@ class CausalSelfAttention(nn.Module):
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
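For reference, the manual scale / causal-mask / softmax / matmul pipeline above matches, dropout aside, what PyTorch 2.x exposes as a single fused op; a small sketch with made-up shapes:

import torch
import torch.nn.functional as F

B, nh, T, hs = 2, 4, 8, 16
q = torch.randn(B, nh, T, hs)
k = torch.randn(B, nh, T, hs)
v = torch.randn(B, nh, T, hs)
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (B, nh, T, hs)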
@@ -197,16 +189,12 @@ class GPT(nn.Module):
                "ln_f": nn.LayerNorm(config.gpt_hidden_dim),
            }
        )
        self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False)
        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer))

        # report number of parameters
        n_params = sum(p.numel() for p in self.parameters())
@@ -220,17 +208,11 @@ class GPT(nn.Module):
        )

        # positional encodings that are added to the input embeddings
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(input)  # token embeddings of shape (b, t, gpt_hidden_dim)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, gpt_hidden_dim)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
@@ -255,9 +237,7 @@ class GPT(nn.Module):
        # but want to use a smaller block size for some smaller, simpler model
        assert gpt_block_size <= self.config.gpt_block_size
        self.config.gpt_block_size = gpt_block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
        for block in self.transformer.h:
            block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
@@ -290,11 +270,9 @@ class GPT(nn.Module):
        param_dict = dict(self.named_parameters())
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
            str(inter_params)
        )
        assert len(param_dict.keys() - union_params) == 0, (
            "parameters {} were not separated into either decay/no_decay set!".format(
                str(param_dict.keys() - union_params),
@@ -390,12 +368,8 @@ class ResidualVQ(nn.Module):
        codebook_input_dim = codebook_dim * heads
        requires_projection = codebook_input_dim != dim
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()

        self.num_quantizers = num_quantizers
@@ -477,9 +451,7 @@ class ResidualVQ(nn.Module):
        return all_codes

    def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None):
        """
        For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss.
        First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize.
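A conceptual sketch of residual vector quantization as described in that docstring (illustrative, not the repo's exact forward): each layer quantizes what the previous layers missed, and the quantized outputs accumulate.

import torch

def residual_vq_sketch(x: torch.Tensor, codebooks: list[torch.Tensor]):
    residual, quantized_out, indices = x, torch.zeros_like(x), []
    for codebook in codebooks:                   # each: (codebook_size, dim)
        dists = torch.cdist(residual, codebook)  # nearest-code lookup
        idx = dists.argmin(dim=-1)
        quantized = codebook[idx]
        residual = residual - quantized          # next layer sees the residual
        quantized_out = quantized_out + quantized
        indices.append(idx)
    return quantized_out, torch.stack(indices, dim=-1)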
@@ -508,17 +480,13 @@ class ResidualVQ(nn.Module):
            )

        ce_losses = []

        should_quantize_dropout = self.training and self.quantize_dropout and not return_loss

        # sample a layer index at which to dropout further residual quantization
        # also prepare null indices and loss
        if should_quantize_dropout:
            rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = (
@@ -527,23 +495,14 @@ class ResidualVQ(nn.Module):
                    - 1
                )

            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
            null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long)
            null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)
        # go through the layers
        for quantizer_index, layer in enumerate(self.layers):
            if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                continue
@@ -583,9 +542,7 @@ class ResidualVQ(nn.Module):
        # stack all losses and indices
        all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))

        ret = (quantized_out, all_indices, all_losses)
@@ -645,12 +602,8 @@ class VectorQuantize(nn.Module):
        codebook_input_dim = codebook_dim * heads
        requires_projection = codebook_input_dim != dim
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()

        self.eps = eps
        self.commitment_weight = commitment_weight
@@ -664,14 +617,10 @@ class VectorQuantize(nn.Module):
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update"
        assert 0 <= sync_update_v <= 1.0
        assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on"
        self.sync_update_v = sync_update_v
@ -683,9 +632,7 @@ class VectorQuantize(nn.Module):
) )
if sync_codebook is None: if sync_codebook is None:
sync_codebook = ( sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1
distributed.is_initialized() and distributed.get_world_size() > 1
)
codebook_kwargs = { codebook_kwargs = {
"dim": codebook_dim, "dim": codebook_dim,
@ -850,17 +797,11 @@ class VectorQuantize(nn.Module):
# quantize again # quantize again
quantize, embed_ind, distances = self._codebook( quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
x, **codebook_forward_kwargs
)
if self.training: if self.training:
# determine code to use for commitment loss # determine code to use for commitment loss
maybe_detach = ( maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity
torch.detach
if not self.learnable_codebook or freeze_codebook
else identity
)
commit_quantize = maybe_detach(quantize) commit_quantize = maybe_detach(quantize)
@ -870,9 +811,7 @@ class VectorQuantize(nn.Module):
if self.sync_update_v > 0.0: if self.sync_update_v > 0.0:
# (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
quantize = quantize + self.sync_update_v * ( quantize = quantize + self.sync_update_v * (quantize - quantize.detach())
quantize - quantize.detach()
)
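
The `sync_update_v` line above implements equation (21) from the linked vqtorch draft. A small sketch of why it works: the added term is numerically zero in the forward pass, but it scales the gradient reaching `quantize` by (1 + v).

```python
import torch

v = 0.2
q = torch.randn(4, 8, requires_grad=True)
out = q + v * (q - q.detach())  # forward value identical to q
assert torch.allclose(out, q)

out.sum().backward()
print(q.grad.unique())  # tensor([1.2000]): gradients are scaled by 1 + v
```
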
         # function for calculating cross entropy loss to distance matrix
         # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
@@ -905,9 +844,7 @@ class VectorQuantize(nn.Module):
             embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)
         if self.accept_image_fmap:
-            embed_ind = rearrange(
-                embed_ind, "b (h w) ... -> b h w ...", h=height, w=width
-            )
+            embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width)
         if only_one:
             embed_ind = rearrange(embed_ind, "b 1 -> b")
@@ -961,12 +898,8 @@ class VectorQuantize(nn.Module):
             num_codes = codebook.shape[-2]
-            if (
-                self.orthogonal_reg_max_codes is not None
-            ) and num_codes > self.orthogonal_reg_max_codes:
-                rand_ids = torch.randperm(num_codes, device=device)[
-                    : self.orthogonal_reg_max_codes
-                ]
+            if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes:
+                rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes]
                 codebook = codebook[:, rand_ids]
             orthogonal_reg_loss = orthogonal_loss_fn(codebook)
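
A hedged sketch of what an orthogonal regularizer of this shape computes (the actual `orthogonal_loss_fn` is defined elsewhere in the file, and the real codebook carries an extra heads dimension): penalize off-diagonal cosine similarity between codes, subsampling to keep the O(n^2) cost bounded, as the hunk above does.

```python
import torch
import torch.nn.functional as F

def orthogonal_loss(codebook: torch.Tensor, max_codes: int | None = None) -> torch.Tensor:
    n = codebook.shape[0]
    if max_codes is not None and n > max_codes:
        codebook = codebook[torch.randperm(n)[:max_codes]]  # random subsample
        n = max_codes
    normed = F.normalize(codebook, dim=-1)
    cosine = normed @ normed.T
    # zero loss iff codes are mutually orthogonal
    return ((cosine - torch.eye(n)) ** 2).sum() / (n**2)

loss = orthogonal_loss(torch.randn(512, 64), max_codes=128)
```
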
@ -998,9 +931,7 @@ class VectorQuantize(nn.Module):
# if masking, only return quantized for where mask has True # if masking, only return quantized for where mask has True
if mask is not None: if mask is not None:
quantize = torch.where( quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input)
rearrange(mask, "... -> ... 1"), quantize, orig_input
)
return quantize, embed_ind, loss return quantize, embed_ind, loss
@ -1110,9 +1041,7 @@ def sample_vectors(samples, num):
def batched_sample_vectors(samples, num): def batched_sample_vectors(samples, num):
return torch.stack( return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0)
[sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0
)
def pad_shape(shape, size, dim=0): def pad_shape(shape, size, dim=0):
@ -1163,9 +1092,7 @@ def sample_vectors_distributed(local_samples, num):
all_num_samples = all_gather_sizes(local_samples, dim=0) all_num_samples = all_gather_sizes(local_samples, dim=0)
if rank == 0: if rank == 0:
samples_per_rank = sample_multinomial( samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
num, all_num_samples / all_num_samples.sum()
)
else: else:
samples_per_rank = torch.empty_like(all_num_samples) samples_per_rank = torch.empty_like(all_num_samples)
@ -1278,9 +1205,7 @@ class EuclideanCodebook(nn.Module):
self.eps = eps self.eps = eps
self.threshold_ema_dead_code = threshold_ema_dead_code self.threshold_ema_dead_code = threshold_ema_dead_code
self.reset_cluster_size = ( self.reset_cluster_size = (
reset_cluster_size reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code
if (reset_cluster_size is not None)
else threshold_ema_dead_code
) )
assert callable(gumbel_sample) assert callable(gumbel_sample)
@ -1291,14 +1216,8 @@ class EuclideanCodebook(nn.Module):
"kmeans init is not compatible with multiple codebooks in distributed environment for now" "kmeans init is not compatible with multiple codebooks in distributed environment for now"
) )
self.sample_fn = ( self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
sample_vectors_distributed self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
if use_ddp and sync_kmeans
else batched_sample_vectors
)
self.kmeans_all_reduce_fn = (
distributed.all_reduce if use_ddp and sync_kmeans else noop
)
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
self.register_buffer("initted", torch.Tensor([not kmeans_init])) self.register_buffer("initted", torch.Tensor([not kmeans_init]))
@ -1437,9 +1356,7 @@ class EuclideanCodebook(nn.Module):
distributed.all_reduce(variance_number) distributed.all_reduce(variance_number)
batch_variance = variance_number / num_vectors batch_variance = variance_number / num_vectors
self.update_with_decay( self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
"batch_variance", batch_variance, self.affine_param_batch_decay
)
def replace(self, batch_samples, batch_mask): def replace(self, batch_samples, batch_mask):
for ind, (samples, mask) in enumerate( for ind, (samples, mask) in enumerate(
@ -1448,9 +1365,7 @@ class EuclideanCodebook(nn.Module):
if not torch.any(mask): if not torch.any(mask):
continue continue
sampled = self.sample_fn( sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item())
rearrange(samples, "... -> 1 ..."), mask.sum().item()
)
sampled = rearrange(sampled, "1 ... -> ...") sampled = rearrange(sampled, "1 ... -> ...")
self.embed.data[ind][mask] = sampled self.embed.data[ind][mask] = sampled
@ -1474,9 +1389,7 @@ class EuclideanCodebook(nn.Module):
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False): def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
needs_codebook_dim = x.ndim < 4 needs_codebook_dim = x.ndim < 4
sample_codebook_temp = ( sample_codebook_temp = (
sample_codebook_temp sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
if (sample_codebook_temp is not None)
else self.sample_codebook_temp
) )
x = x.float() x = x.float()
@ -1504,9 +1417,7 @@ class EuclideanCodebook(nn.Module):
if self.affine_param: if self.affine_param:
codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt() codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
batch_std = self.batch_variance.clamp(min=1e-5).sqrt() batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
embed = (embed - self.codebook_mean) * ( embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean
batch_std / codebook_std
) + self.batch_mean
dist = -cdist(flatten, embed) dist = -cdist(flatten, embed)
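
The affine-parameter branch above re-expresses the codes under the batch statistics before computing distances. A minimal numeric sketch, with the buffer names mirrored and the values invented:

```python
import torch

codebook_mean, codebook_var = torch.tensor(0.0), torch.tensor(1.0)
batch_mean, batch_var = torch.tensor(0.5), torch.tensor(4.0)

embed = torch.randn(512, 64)
codebook_std = codebook_var.clamp(min=1e-5).sqrt()
batch_std = batch_var.clamp(min=1e-5).sqrt()

# standardize under codebook stats, then re-express under batch stats
embed_affine = (embed - codebook_mean) * (batch_std / codebook_std) + batch_mean
```

The EMA-update branch in the next hunk applies the inverse map to `flatten`, so codes and inputs are compared in a common frame.
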
@@ -1524,9 +1435,7 @@ class EuclideanCodebook(nn.Module):
         if self.training and self.ema_update and not freeze_codebook:
             if self.affine_param:
-                flatten = (flatten - self.batch_mean) * (
-                    codebook_std / batch_std
-                ) + self.codebook_mean
+                flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
             if mask is not None:
                 embed_onehot[~mask] = 0.0
@@ -1549,9 +1458,7 @@ class EuclideanCodebook(nn.Module):
             self.expire_codes_(x)
         if needs_codebook_dim:
-            quantize, embed_ind = tuple(
-                rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)
-            )
+            quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind))
         dist = unpack_one(dist, ps, "h * d")


@@ -57,9 +57,7 @@ class OpenCVCameraConfig(CameraConfig):
         self.channels = 3
         if self.rotation not in [-90, None, 90, 180]:
-            raise ValueError(
-                f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
-            )
+            raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
 @CameraConfig.register_subclass("intelrealsense")
@@ -104,12 +102,8 @@ class IntelRealSenseCameraConfig(CameraConfig):
         self.channels = 3
-        at_least_one_is_not_none = (
-            self.fps is not None or self.width is not None or self.height is not None
-        )
-        at_least_one_is_none = (
-            self.fps is None or self.width is None or self.height is None
-        )
+        at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
+        at_least_one_is_none = self.fps is None or self.width is None or self.height is None
         if at_least_one_is_not_none and at_least_one_is_none:
             raise ValueError(
                 "For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
@@ -117,6 +111,4 @@ class IntelRealSenseCameraConfig(CameraConfig):
             )
         if self.rotation not in [-90, None, 90, 180]:
-            raise ValueError(
-                f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
-            )
+            raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
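
A standalone sketch of the all-or-none rule enforced above (hypothetical helper, same semantics): `fps`, `width` and `height` must either all be set or all be None.

```python
def validate_capture_settings(fps: int | None, width: int | None, height: int | None) -> None:
    values = (fps, width, height)
    if any(v is not None for v in values) and any(v is None for v in values):
        raise ValueError(
            "For `fps`, `width` and `height`, either all of them need to be set, or none of them."
        )

validate_capture_settings(30, 640, 480)      # ok: all set
validate_capture_settings(None, None, None)  # ok: all None
# validate_capture_settings(30, None, None)  -> ValueError
```
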


@@ -79,9 +79,7 @@ def save_image(img_array, serial_number, frame_index, images_dir):
         img.save(str(path), quality=100)
         logging.info(f"Saved image: {path}")
     except Exception as e:
-        logging.error(
-            f"Failed to save image for camera {serial_number} frame {frame_index}: {e}"
-        )
+        logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
 def save_images_from_cameras(
@@ -159,9 +157,7 @@ def save_images_from_cameras(
             if time.perf_counter() - start_time > record_time_s:
                 break
-            print(
-                f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
-            )
+            print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
             frame_index += 1
     finally:
@@ -279,9 +275,7 @@ class IntelRealSenseCamera:
                 f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them."
             )
-        name_to_serial_dict = {
-            cam["name"]: cam["serial_number"] for cam in camera_infos
-        }
+        name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
         cam_sn = name_to_serial_dict[name]
         return cam_sn
@@ -353,9 +347,7 @@ class IntelRealSenseCamera:
         actual_height = color_profile.height()
         # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
-        if self.fps is not None and not math.isclose(
-            self.fps, actual_fps, rel_tol=1e-3
-        ):
+        if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
             # Using `OSError` since it's a broad error that encompasses issues related to device communication
             raise OSError(
                 f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
@@ -375,9 +367,7 @@ class IntelRealSenseCamera:
         self.is_connected = True
-    def read(
-        self, temporary_color: str | None = None
-    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
+    def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
         """Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3)
         of type `np.uint8`, contrary to the pytorch format, which is float channel first.
@@ -404,15 +394,11 @@ class IntelRealSenseCamera:
         color_frame = frame.get_color_frame()
         if not color_frame:
-            raise OSError(
-                f"Can't capture color image from IntelRealSenseCamera({self.serial_number})."
-            )
+            raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).")
         color_image = np.asanyarray(color_frame.get_data())
-        requested_color_mode = (
-            self.color_mode if temporary_color is None else temporary_color
-        )
+        requested_color_mode = self.color_mode if temporary_color is None else temporary_color
         if requested_color_mode not in ["rgb", "bgr"]:
             raise ValueError(
                 f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
@@ -440,9 +426,7 @@ class IntelRealSenseCamera:
         if self.use_depth:
             depth_frame = frame.get_depth_frame()
             if not depth_frame:
-                raise OSError(
-                    f"Can't capture depth image from IntelRealSenseCamera({self.serial_number})."
-                )
+                raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).")
             depth_map = np.asanyarray(depth_frame.get_data())
@@ -484,9 +468,7 @@ class IntelRealSenseCamera:
         # TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here
         num_tries += 1
         time.sleep(1 / self.fps)
-        if num_tries > self.fps and (
-            self.thread.ident is None or not self.thread.is_alive()
-        ):
+        if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
             raise Exception(
                 "The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
             )


@@ -45,14 +45,10 @@ from lerobot.common.utils.utils import capture_timestamp_utc
 MAX_OPENCV_INDEX = 60
-def find_cameras(
-    raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False
-) -> list[dict]:
+def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
     cameras = []
     if platform.system() == "Linux":
-        print(
-            "Linux detected. Finding available camera indices through scanning '/dev/video*' ports"
-        )
+        print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
         possible_ports = [str(port) for port in Path("/dev").glob("video*")]
         ports = _find_cameras(possible_ports, mock=mock)
         for port in ports:
@@ -144,9 +140,7 @@ def save_images_from_cameras(
     print("Connecting cameras")
     cameras = []
     for cam_idx in camera_ids:
-        config = OpenCVCameraConfig(
-            camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock
-        )
+        config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
         camera = OpenCVCamera(config)
         camera.connect()
         print(
@@ -186,9 +180,7 @@ def save_images_from_cameras(
             dt_s = time.perf_counter() - now
             busy_wait(1 / fps - dt_s)
-            print(
-                f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
-            )
+            print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
             if time.perf_counter() - start_time > record_time_s:
                 break
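
A hedged sketch of the fixed-fps pacing pattern above. `busy_wait` here is a stand-in for the lerobot helper: it spins for the remainder of the frame budget, which is more precise than `time.sleep` for millisecond-scale waits.

```python
import time

def busy_wait(seconds: float) -> None:
    # spin until the deadline; a negative argument returns immediately
    end = time.perf_counter() + seconds
    while time.perf_counter() < end:
        pass

fps = 30
for frame_index in range(5):
    now = time.perf_counter()
    # ... read and save a frame here ...
    dt_s = time.perf_counter() - now
    busy_wait(1 / fps - dt_s)
    print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
```
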
@@ -245,16 +237,12 @@ class OpenCVCamera:
         if platform.system() == "Linux":
             if isinstance(self.camera_index, int):
                 self.port = Path(f"/dev/video{self.camera_index}")
-            elif isinstance(self.camera_index, str) and is_valid_unix_path(
-                self.camera_index
-            ):
+            elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index):
                 self.port = Path(self.camera_index)
                 # Retrieve the camera index from a potentially symlinked path
                 self.camera_index = get_camera_index_from_unix_port(self.port)
             else:
-                raise ValueError(
-                    f"Please check the provided camera_index: {self.camera_index}"
-                )
+                raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
         # Store the raw (capture) resolution from the config.
         self.capture_width = config.width
@@ -295,9 +283,7 @@ class OpenCVCamera:
     def connect(self):
         if self.is_connected:
-            raise RobotDeviceAlreadyConnectedError(
-                f"OpenCVCamera({self.camera_index}) is already connected."
-            )
+            raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
         if self.mock:
             import tests.cameras.mock_cv2 as cv2
@@ -318,11 +304,7 @@ class OpenCVCamera:
             else cv2.CAP_ANY
         )
-        camera_idx = (
-            f"/dev/video{self.camera_index}"
-            if platform.system() == "Linux"
-            else self.camera_index
-        )
+        camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
         # First create a temporary camera trying to access `camera_index`,
         # and verify it is a valid camera by calling `isOpened`.
         tmp_camera = cv2.VideoCapture(camera_idx, backend)
@@ -362,9 +344,7 @@ class OpenCVCamera:
         actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
         # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
-        if self.fps is not None and not math.isclose(
-            self.fps, actual_fps, rel_tol=1e-3
-        ):
+        if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
             # Using `OSError` since it's a broad error that encompasses issues related to device communication
             raise OSError(
                 f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
@@ -406,9 +386,7 @@ class OpenCVCamera:
         if not ret:
             raise OSError(f"Can't capture color image from camera {self.camera_index}.")
-        requested_color_mode = (
-            self.color_mode if temporary_color_mode is None else temporary_color_mode
-        )
+        requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode
         if requested_color_mode not in ["rgb", "bgr"]:
             raise ValueError(


@@ -93,9 +93,7 @@ class RecordControlConfig(ControlConfig):
         policy_path = parser.get_path_arg("control.policy")
         if policy_path:
             cli_overrides = parser.get_cli_overrides("control.policy")
-            self.policy = PreTrainedConfig.from_pretrained(
-                policy_path, cli_overrides=cli_overrides
-            )
+            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
             self.policy.pretrained_path = policy_path


@@ -39,9 +39,7 @@ from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.common.utils.utils import get_safe_torch_device, has_method
-def log_control_info(
-    robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None
-):
+def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
     log_items = []
     if episode_index is not None:
         log_items.append(f"ep:{episode_index}")
@@ -108,9 +106,7 @@ def predict_action(observation, policy, device, use_amp):
     observation = copy(observation)
     with (
         torch.inference_mode(),
-        torch.autocast(device_type=device.type)
-        if device.type == "cuda" and use_amp
-        else nullcontext(),
+        torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
     ):
         # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
         for name in observation:
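
The `with` statement above conditionally enters autocast: on CUDA with AMP enabled it uses `torch.autocast`, otherwise `nullcontext()` keeps the statement shape unchanged. A self-contained sketch of the same pattern:

```python
import torch
from contextlib import nullcontext

def forward_no_grad(model: torch.nn.Module, x: torch.Tensor, use_amp: bool) -> torch.Tensor:
    device = x.device
    with (
        torch.inference_mode(),
        torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
    ):
        return model(x)

out = forward_no_grad(torch.nn.Linear(4, 2), torch.randn(1, 4), use_amp=False)
```
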
@@ -166,9 +162,7 @@ def init_keyboard_listener(assign_rewards=False):
                 print("Right arrow key pressed. Exiting loop...")
                 events["exit_early"] = True
             elif key == keyboard.Key.left:
-                print(
-                    "Left arrow key pressed. Exiting loop and rerecord the last episode..."
-                )
+                print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
                 events["rerecord_episode"] = True
                 events["exit_early"] = True
             elif key == keyboard.Key.esc:
@@ -262,9 +256,7 @@ def control_loop(
         raise ValueError("You need to provide a task as argument in `single_task`.")
     if dataset is not None and fps is not None and dataset.fps != fps:
-        raise ValueError(
-            f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps})."
-        )
+        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
     timestamp = 0
     start_episode_t = time.perf_counter()
@@ -297,9 +289,7 @@ def control_loop(
             dataset.add_frame(frame)
         # TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
-        if (display_data and not is_headless()) or (
-            display_data and robot.robot_type.startswith("lekiwi")
-        ):
+        if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")):
             for k, v in action.items():
                 for i, vv in enumerate(v):
                     rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
@@ -389,14 +379,11 @@ def sanity_check_dataset_robot_compatibility(
     mismatches = []
     for field, dataset_value, present_value in fields:
-        diff = DeepDiff(
-            dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]
-        )
+        diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
         if diff:
             mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
     if mismatches:
         raise ValueError(
-            "Dataset metadata compatibility check failed with mismatches:\n"
-            + "\n".join(mismatches)
+            "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
         )
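
A small sketch of the DeepDiff call above (invented metadata): paths matching `exclude_regex_paths` are pruned from the comparison, so entries under an `info` key do not trigger a mismatch.

```python
from deepdiff import DeepDiff

expected = {"fps": 30, "motors": ["pan", "lift"], "info": {"codec": "av1"}}
actual = {"fps": 30, "motors": ["pan", "lift"], "info": {"codec": "h264"}}

diff = DeepDiff(expected, actual, exclude_regex_paths=[r".*\['info'\]$"])
print(diff or "compatible")  # the differing ['info'] subtree is ignored
```
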


@@ -161,9 +161,7 @@ NUM_READ_RETRY = 10
 NUM_WRITE_RETRY = 10
-def convert_degrees_to_steps(
-    degrees: float | np.ndarray, models: str | list[str]
-) -> np.ndarray:
+def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
     """This function converts the degree range to the step range for indicating motor rotation.
     It assumes a motor achieves a full rotation by going from -180 degree position to +180.
     The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
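
The docstring above pins down the conversion; a worked sketch for a single hypothetical resolution (the real function looks resolution up per motor model):

```python
import numpy as np

def degrees_to_steps(degrees: float | np.ndarray, resolution: int = 4096) -> np.ndarray:
    # a full turn (-180..+180 degrees) spans `resolution` steps
    return (np.asarray(degrees) / 180 * (resolution // 2)).astype(np.int64)

print(degrees_to_steps(90))    # 1024
print(degrees_to_steps(-180))  # -2048
```
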
@@ -389,9 +387,7 @@ class DynamixelMotorsBus:
         indices = []
         for idx in tqdm.tqdm(possible_ids):
             try:
-                present_idx = self.read_with_motor_ids(
-                    self.motor_models, [idx], "ID", num_retry=num_retry
-                )[0]
+                present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
             except ConnectionError:
                 continue
@@ -407,9 +403,7 @@ class DynamixelMotorsBus:
     def set_bus_baudrate(self, baudrate):
         present_bus_baudrate = self.port_handler.getBaudRate()
         if present_bus_baudrate != baudrate:
-            print(
-                f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
-            )
+            print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
             self.port_handler.setBaudRate(baudrate)
             if self.port_handler.getBaudRate() != baudrate:
@@ -430,9 +424,7 @@ class DynamixelMotorsBus:
     def set_calibration(self, calibration: dict[str, list]):
         self.calibration = calibration
-    def apply_calibration_autocorrect(
-        self, values: np.ndarray | list, motor_names: list[str] | None
-    ):
+    def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
         """This function applies the calibration, automatically detects out-of-range errors for motor values and attempts to correct them.
         For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@@ -445,9 +437,7 @@ class DynamixelMotorsBus:
             values = self.apply_calibration(values, motor_names)
         return values
-    def apply_calibration(
-        self, values: np.ndarray | list, motor_names: list[str] | None
-    ):
+    def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
         """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
         a "zero position" at 0 degree.
@@ -522,9 +512,7 @@ class DynamixelMotorsBus:
         return values
-    def autocorrect_calibration(
-        self, values: np.ndarray | list, motor_names: list[str] | None
-    ):
+    def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
         """This function automatically detects issues with values of motors after calibration, and corrects for these issues.
         Some motors might have values outside of expected maximum bounds after calibration.
@@ -566,23 +554,15 @@ class DynamixelMotorsBus:
                     values[i] *= -1
                 # Convert from initial range to range [-180, 180] degrees
-                calib_val = (
-                    (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
-                )
-                in_range = (calib_val > LOWER_BOUND_DEGREE) and (
-                    calib_val < UPPER_BOUND_DEGREE
-                )
+                calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
+                in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
                 # Solve this inequality to find the factor to shift the range into [-180, 180] degrees
                 # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
                 # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
                 # (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
-                low_factor = (
-                    -(resolution // 2) - values[i] - homing_offset
-                ) / resolution
-                upp_factor = (
-                    (resolution // 2) - values[i] - homing_offset
-                ) / resolution
+                low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
+                upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution
             elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
                 start_pos = self.calibration["start_pos"][calib_idx]
@@ -590,9 +570,7 @@ class DynamixelMotorsBus:
                 # Convert from initial range to range [0, 100] in %
                 calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
-                in_range = (calib_val > LOWER_BOUND_LINEAR) and (
-                    calib_val < UPPER_BOUND_LINEAR
-                )
+                in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
                 # Solve this inequality to find the factor to shift the range into [0, 100] %
                 # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@@ -608,27 +586,19 @@ class DynamixelMotorsBus:
                 factor = math.ceil(low_factor)
                 if factor > upp_factor:
-                    raise ValueError(
-                        f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
-                    )
+                    raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
             else:
                 factor = math.ceil(upp_factor)
                 if factor > low_factor:
-                    raise ValueError(
-                        f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
-                    )
+                    raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
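
The inequality comments above deserve a worked example (numbers invented): with `resolution=4096` and `homing_offset=-6144`, a raw reading of 1024 lands at -450 degrees, and shifting by one full turn brings it back into ]-180, 180[.

```python
import math

resolution, homing_offset, value = 4096, -6144, 1024
calib_val = (value + homing_offset) / (resolution // 2) * 180
assert calib_val == -450.0  # out of range

low_factor = (-(resolution // 2) - value - homing_offset) / resolution  # 0.75
upp_factor = ((resolution // 2) - value - homing_offset) / resolution   # 1.75
factor = math.ceil(low_factor)  # 1, valid since 1 <= upp_factor

corrected = (value + homing_offset + resolution * factor) / (resolution // 2) * 180
assert corrected == -90.0  # back in ]-180, 180[
```
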
             if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
                 out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
                 in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
             elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
-                out_of_range_str = (
-                    f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
-                )
-                in_range_str = (
-                    f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
-                )
+                out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
+                in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
             logging.warning(
                 f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
@@ -638,9 +608,7 @@ class DynamixelMotorsBus:
             # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
             self.calibration["homing_offset"][calib_idx] += resolution * factor
-    def revert_calibration(
-        self, values: np.ndarray | list, motor_names: list[str] | None
-    ):
+    def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
         """Inverse of `apply_calibration`."""
         if motor_names is None:
             motor_names = self.motor_names
@@ -679,9 +647,7 @@ class DynamixelMotorsBus:
         values = np.round(values).astype(np.int32)
         return values
-    def read_with_motor_ids(
-        self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
-    ):
+    def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
         if self.mock:
             import tests.motors.mock_dynamixel_sdk as dxl
         else:
@@ -783,9 +749,7 @@ class DynamixelMotorsBus:
             values = self.apply_calibration_autocorrect(values, motor_names)
         # log the number of seconds it took to read the data from the motors
-        delta_ts_name = get_log_name(
-            "delta_timestamp_s", "read", data_name, motor_names
-        )
+        delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
         self.logs[delta_ts_name] = time.perf_counter() - start_time
         # log the utc time at which the data was received
@@ -794,9 +758,7 @@ class DynamixelMotorsBus:
         return values
-    def write_with_motor_ids(
-        self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
-    ):
+    def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
         if self.mock:
             import tests.motors.mock_dynamixel_sdk as dxl
         else:
@@ -891,9 +853,7 @@ class DynamixelMotorsBus:
         )
         # log the number of seconds it took to write the data to the motors
-        delta_ts_name = get_log_name(
-            "delta_timestamp_s", "write", data_name, motor_names
-        )
+        delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
         self.logs[delta_ts_name] = time.perf_counter() - start_time
         # TODO(rcadene): should we log the time before sending the write command?


@@ -140,9 +140,7 @@ NUM_READ_RETRY = 20
 NUM_WRITE_RETRY = 20
-def convert_degrees_to_steps(
-    degrees: float | np.ndarray, models: str | list[str]
-) -> np.ndarray:
+def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
     """This function converts the degree range to the step range for indicating motor rotation.
     It assumes a motor achieves a full rotation by going from -180 degree position to +180.
     The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
@@ -370,9 +368,7 @@ class FeetechMotorsBus:
         indices = []
         for idx in tqdm.tqdm(possible_ids):
             try:
-                present_idx = self.read_with_motor_ids(
-                    self.motor_models, [idx], "ID", num_retry=num_retry
-                )[0]
+                present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
             except ConnectionError:
                 continue
@@ -388,9 +384,7 @@ class FeetechMotorsBus:
     def set_bus_baudrate(self, baudrate):
         present_bus_baudrate = self.port_handler.getBaudRate()
         if present_bus_baudrate != baudrate:
-            print(
-                f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
-            )
+            print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
             self.port_handler.setBaudRate(baudrate)
             if self.port_handler.getBaudRate() != baudrate:
@@ -411,9 +405,7 @@ class FeetechMotorsBus:
     def set_calibration(self, calibration: dict[str, list]):
         self.calibration = calibration
-    def apply_calibration_autocorrect(
-        self, values: np.ndarray | list, motor_names: list[str] | None
-    ):
+    def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
         """This function applies the calibration, automatically detects out-of-range errors for motor values and attempts to correct them.
         For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@@ -426,9 +418,7 @@ class FeetechMotorsBus:
             values = self.apply_calibration(values, motor_names)
         return values
-    def apply_calibration(
-        self, values: np.ndarray | list, motor_names: list[str] | None
-    ):
+    def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
         """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
         a "zero position" at 0 degree.
@@ -502,9 +492,7 @@ class FeetechMotorsBus:
         return values
-    def autocorrect_calibration(
-        self, values: np.ndarray | list, motor_names: list[str] | None
-    ):
+    def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
         """This function automatically detects issues with values of motors after calibration, and corrects for these issues.
         Some motors might have values outside of expected maximum bounds after calibration.
@@ -543,26 +531,18 @@ class FeetechMotorsBus:
                     values[i] *= -1
                 # Convert from initial range to range [-180, 180] degrees
-                calib_val = (
-                    (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
-                )
-                in_range = (calib_val > LOWER_BOUND_DEGREE) and (
-                    calib_val < UPPER_BOUND_DEGREE
-                )
+                calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
+                in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
                 # Solve this inequality to find the factor to shift the range into [-180, 180] degrees
                 # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
                 # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
                 # (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
                 low_factor = (
-                    -HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
-                    - values[i]
-                    - homing_offset
+                    -HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
                 ) / resolution
                 upp_factor = (
-                    HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
-                    - values[i]
-                    - homing_offset
+                    HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
                 ) / resolution
             elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
@@ -571,9 +551,7 @@ class FeetechMotorsBus:
                 # Convert from initial range to range [0, 100] in %
                 calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
-                in_range = (calib_val > LOWER_BOUND_LINEAR) and (
-                    calib_val < UPPER_BOUND_LINEAR
-                )
+                in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
                 # Solve this inequality to find the factor to shift the range into [0, 100] %
                 # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@@ -589,27 +567,19 @@ class FeetechMotorsBus:
                 factor = math.ceil(low_factor)
                 if factor > upp_factor:
-                    raise ValueError(
-                        f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
-                    )
+                    raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
             else:
                 factor = math.ceil(upp_factor)
                 if factor > low_factor:
-                    raise ValueError(
-                        f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
-                    )
+                    raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
             if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
                 out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
                 in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
             elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
-                out_of_range_str = (
-                    f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
-                )
-                in_range_str = (
-                    f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
-                )
+                out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
+                in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
             logging.warning(
                 f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
@@ -619,9 +589,7 @@ class FeetechMotorsBus:
             # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
             self.calibration["homing_offset"][calib_idx] += resolution * factor
-    def revert_calibration(
-        self, values: np.ndarray | list, motor_names: list[str] | None
-    ):
+    def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
         """Inverse of `apply_calibration`."""
         if motor_names is None:
             motor_names = self.motor_names
@@ -697,9 +665,7 @@ class FeetechMotorsBus:
         return values
-    def read_with_motor_ids(
-        self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
-    ):
+    def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
         if self.mock:
             import tests.motors.mock_scservo_sdk as scs
         else:
@@ -808,9 +774,7 @@ class FeetechMotorsBus:
             values = self.apply_calibration_autocorrect(values, motor_names)
         # log the number of seconds it took to read the data from the motors
-        delta_ts_name = get_log_name(
-            "delta_timestamp_s", "read", data_name, motor_names
-        )
+        delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
         self.logs[delta_ts_name] = time.perf_counter() - start_time
         # log the utc time at which the data was received
@@ -819,9 +783,7 @@ class FeetechMotorsBus:
         return values
-    def write_with_motor_ids(
-        self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
-    ):
+    def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
         if self.mock:
             import tests.motors.mock_scservo_sdk as scs
         else:
@@ -916,9 +878,7 @@ class FeetechMotorsBus:
         )
         # log the number of seconds it took to write the data to the motors
-        delta_ts_name = get_log_name(
-            "delta_timestamp_s", "write", data_name, motor_names
-        )
+        delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
         self.logs[delta_ts_name] = time.perf_counter() - start_time
         # TODO(rcadene): should we log the time before sending the write command?


@@ -69,13 +69,9 @@ class ManipulatorRobotConfig(RobotConfig):
                 if not cam.mock:
                     cam.mock = True
-        if self.max_relative_target is not None and isinstance(
-            self.max_relative_target, Sequence
-        ):
+        if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
             for name in self.follower_arms:
-                if len(self.follower_arms[name].motors) != len(
-                    self.max_relative_target
-                ):
+                if len(self.follower_arms[name].motors) != len(self.max_relative_target):
                     raise ValueError(
                         f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
                         f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "


@@ -24,7 +24,9 @@ from lerobot.common.robot_devices.motors.dynamixel import (
 )
 from lerobot.common.robot_devices.motors.utils import MotorsBus
-URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
+URL_TEMPLATE = (
+    "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
+)
 # The following positions are provided in nominal degree range ]-180, +180[
 # For more info on these constants, see comments in the code where they get used.
@@ -35,9 +37,7 @@ ROTATED_POSITION_DEGREE = 90
 def assert_drive_mode(drive_mode):
     # `drive_mode` is in [0,1], where 0 means original rotation direction for the motor and 1 means inverted.
     if not np.all(np.isin(drive_mode, [0, 1])):
-        raise ValueError(
-            f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
-        )
+        raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
 def apply_drive_mode(position, drive_mode):
@@ -78,16 +78,12 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
     ```
     """
     if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
-        raise ValueError(
-            "To run calibration, the torque must be disabled on all motors."
-        )
+        raise ValueError("To run calibration, the torque must be disabled on all motors.")
     print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
     print("\nMove arm to zero position")
-    print(
-        "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
-    )
+    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
     input("Press Enter to continue...")
     # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
@@ -108,15 +104,10 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
     # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
     # of the previous motor in the kinetic chain.
     print("\nMove arm to rotated target position")
-    print(
-        "See: "
-        + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
-    )
+    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
     input("Press Enter to continue...")
-    rotated_target_pos = convert_degrees_to_steps(
-        ROTATED_POSITION_DEGREE, arm.motor_models
-    )
+    rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
     # Find drive mode by rotating each motor by a quarter of a turn.
     # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@@ -125,15 +116,11 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
     # Re-compute homing offset to take into account drive mode
     rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
-    rotated_nearest_pos = compute_nearest_rounded_position(
-        rotated_drived_pos, arm.motor_models
-    )
+    rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models)
    homing_offset = rotated_target_pos - rotated_nearest_pos
     print("\nMove arm to rest position")
-    print(
-        "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
-    )
+    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
     input("Press Enter to continue...")
     print()
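
A hedged sketch of the homing-offset step above (invented readings; `compute_nearest_rounded_position` is approximated by rounding to the nearest quarter turn): the offset is whatever correction maps the drive-mode-adjusted, rounded reading onto the 90-degree target.

```python
import numpy as np

resolution = 4096
rotated_target_pos = np.array([1024, 1024])  # 90 degrees = 1024 steps
rotated_pos = np.array([3080, -1010])        # raw readings after the move
drive_mode = np.array([0, 1])                # 1 -> rotation direction inverted

rotated_drived_pos = np.where(drive_mode == 1, -rotated_pos, rotated_pos)
quarter = resolution // 4
rotated_nearest_pos = (np.round(rotated_drived_pos / quarter) * quarter).astype(int)
homing_offset = rotated_target_pos - rotated_nearest_pos
print(homing_offset)  # [-2048     0]
```
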


@@ -26,7 +26,9 @@ from lerobot.common.robot_devices.motors.feetech import (
 )
 from lerobot.common.robot_devices.motors.utils import MotorsBus
-URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
+URL_TEMPLATE = (
+    "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
+)
 # The following positions are provided in nominal degree range ]-180, +180[
 # For more info on these constants, see comments in the code where they get used.
@@ -37,9 +39,7 @@ ROTATED_POSITION_DEGREE = 90
 def assert_drive_mode(drive_mode):
     # `drive_mode` is in [0,1], where 0 means original rotation direction for the motor and 1 means inverted.
     if not np.all(np.isin(drive_mode, [0, 1])):
-        raise ValueError(
-            f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
-        )
+        raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
 def apply_drive_mode(position, drive_mode):
@@ -140,9 +140,7 @@ def apply_offset(calib, offset):
     return calib
-def run_arm_auto_calibration(
-    arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
-):
+def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
     if robot_type == "so100":
         return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type)
     elif robot_type == "moss":
@@ -151,27 +149,18 @@ def run_arm_auto_calibration(
         raise ValueError(robot_type)
-def run_arm_auto_calibration_so100(
-    arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
-):
+def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
     """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
     if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
-        raise ValueError(
-            "To run calibration, the torque must be disabled on all motors."
-        )
+        raise ValueError("To run calibration, the torque must be disabled on all motors.")
     if not (robot_type == "so100" and arm_type == "follower"):
-        raise NotImplementedError(
-            "Auto calibration only supports the follower of so100 arms for now."
-        )
+        raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.")
     print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
     print("\nMove arm to initial position")
-    print(
-        "See: "
-        + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
-    )
+    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
     input("Press Enter to continue...")
     # Lower the acceleration of the motors (in [0,254])
@@ -225,9 +214,7 @@ def run_arm_auto_calibration_so100(
     )
     calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
-    arm.write(
-        "Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex"
-    )
+    arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
     time.sleep(1)
     def in_between_move_hook():
@@ -261,13 +248,9 @@ def run_arm_auto_calibration_so100(
         "shoulder_lift",
     )
     time.sleep(2)
-    arm.write(
-        "Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex"
-    )
+    arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
     time.sleep(2)
-    arm.write(
-        "Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex"
-    )
+    arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex")
     time.sleep(2)
     arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper")
     time.sleep(2)
@@ -288,9 +271,7 @@ def run_arm_auto_calibration_so100(
     arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
     time.sleep(1)
     arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex")
-    arm.write(
-        "Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift"
-    )
+    arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift")
     time.sleep(1)
     arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
     time.sleep(1)
@@ -319,27 +300,18 @@ def run_arm_auto_calibration_so100(
return calib_dict return calib_dict
def run_arm_auto_calibration_moss( def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms""" """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError( raise ValueError("To run calibration, the torque must be disabled on all motors.")
"To run calibration, the torque must be disabled on all motors."
)
if not (robot_type == "moss" and arm_type == "follower"): if not (robot_type == "moss" and arm_type == "follower"):
raise NotImplementedError( raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.")
"Auto calibration only supports the follower of moss arms for now."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to initial position") print("\nMove arm to initial position")
print( print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
)
input("Press Enter to continue...") input("Press Enter to continue...")
# Lower the acceleration of the motors (in [0,254]) # Lower the acceleration of the motors (in [0,254])
@ -423,12 +395,8 @@ def run_arm_auto_calibration_moss(
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex") arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
time.sleep(1) time.sleep(1)
arm.write( arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift")
"Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift" arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex")
)
arm.write(
"Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex"
)
time.sleep(2) time.sleep(2)
calib_modes = [] calib_modes = []
@ -455,9 +423,7 @@ def run_arm_auto_calibration_moss(
return calib_dict return calib_dict
def run_arm_manual_calibration( def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
"""This function ensures that a neural network trained on data collected on a given robot """This function ensures that a neural network trained on data collected on a given robot
can work on another robot. For instance, before calibration, setting the same goal position can work on another robot. For instance, before calibration, setting the same goal position
for each motor of two different robots will yield two very different positions. But after calibration, for each motor of two different robots will yield two very different positions. But after calibration,
@ -480,16 +446,12 @@ def run_arm_manual_calibration(
``` ```
""" """
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError( raise ValueError("To run calibration, the torque must be disabled on all motors.")
"To run calibration, the torque must be disabled on all motors."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to zero position") print("\nMove arm to zero position")
print( print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
)
input("Press Enter to continue...") input("Press Enter to continue...")
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed. # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
@ -509,15 +471,10 @@ def run_arm_manual_calibration(
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# of the previous motor in the kinematic chain. # of the previous motor in the kinematic chain.
print("\nMove arm to rotated target position") print("\nMove arm to rotated target position")
print( print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
)
input("Press Enter to continue...") input("Press Enter to continue...")
rotated_target_pos = convert_degrees_to_steps( rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
ROTATED_POSITION_DEGREE, arm.motor_models
)
# Find drive mode by rotating each motor by a quarter of a turn. # Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0). # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@ -529,9 +486,7 @@ def run_arm_manual_calibration(
homing_offset = rotated_target_pos - rotated_drived_pos homing_offset = rotated_target_pos - rotated_drived_pos
print("\nMove arm to rest position") print("\nMove arm to rest position")
print( print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
)
input("Press Enter to continue...") input("Press Enter to continue...")
print() print()

View File

@ -42,9 +42,7 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
local_dict = {} local_dict = {}
for name, cam in cameras.items(): for name, cam in cameras.items():
frame = cam.async_read() frame = cam.async_read()
ret, buffer = cv2.imencode( ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]
)
if ret: if ret:
local_dict[name] = base64.b64encode(buffer).decode("utf-8") local_dict[name] = base64.b64encode(buffer).decode("utf-8")
else: else:
@ -76,9 +74,7 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
print(f"[INFO] Loaded calibration from {calib_file}") print(f"[INFO] Loaded calibration from {calib_file}")
else: else:
print("[INFO] Calibration file not found. Running manual calibration...") print("[INFO] Calibration file not found. Running manual calibration...")
calibration = run_arm_manual_calibration( calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
motors_bus, "lekiwi", "follower_arm", "follower"
)
print(f"[INFO] Calibration complete. Saving to {calib_file}") print(f"[INFO] Calibration complete. Saving to {calib_file}")
with open(calib_file, "w") as f: with open(calib_file, "w") as f:
json.dump(calibration, f) json.dump(calibration, f)
@ -174,9 +170,7 @@ def run_lekiwi(robot_config):
f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}" f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}"
) )
else: else:
for motor, pos in zip( for motor, pos in zip(arm_motor_ids, arm_positions, strict=False):
arm_motor_ids, arm_positions, strict=False
):
motors_bus.write("Goal_Position", pos, motor) motors_bus.write("Goal_Position", pos, motor)
# Process wheel (base) commands. # Process wheel (base) commands.
if "raw_velocity" in data: if "raw_velocity" in data:
@ -207,9 +201,7 @@ def run_lekiwi(robot_config):
try: try:
pos = motors_bus.read("Present_Position", motor) pos = motors_bus.read("Present_Position", motor)
# Convert the position to a float (or use as is if already numeric). # Convert the position to a float (or use as is if already numeric).
follower_arm_state.append( follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos)
float(pos) if not isinstance(pos, (int, float)) else pos
)
except Exception as e: except Exception as e:
print(f"[ERROR] Reading motor {motor} failed: {e}") print(f"[ERROR] Reading motor {motor} failed: {e}")

View File

@ -285,9 +285,7 @@ class ManipulatorRobot:
# to squeeze the gripper and have it spring back to an open position on its own. # to squeeze the gripper and have it spring back to an open position on its own.
for name in self.leader_arms: for name in self.leader_arms:
self.leader_arms[name].write("Torque_Enable", 1, "gripper") self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write( self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
# Check both arms can be read # Check both arms can be read
for name in self.follower_arms: for name in self.follower_arms:
@ -323,22 +321,16 @@ class ManipulatorRobot:
run_arm_calibration, run_arm_calibration,
) )
calibration = run_arm_calibration( calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
arm, self.robot_type, name, arm_type
)
elif self.robot_type in ["so100", "moss", "lekiwi"]: elif self.robot_type in ["so100", "moss", "lekiwi"]:
from lerobot.common.robot_devices.robots.feetech_calibration import ( from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration, run_arm_manual_calibration,
) )
calibration = run_arm_manual_calibration( calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
arm, self.robot_type, name, arm_type
)
print( print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
f"Calibration is done! Saving calibration file '{arm_calib_path}'"
)
arm_calib_path.parent.mkdir(parents=True, exist_ok=True) arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f: with open(arm_calib_path, "w") as f:
json.dump(calibration, f) json.dump(calibration, f)
@ -357,17 +349,13 @@ class ManipulatorRobot:
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError( raise ValueError("To run set robot preset, the torque must be disabled on all motors.")
"To run set robot preset, the torque must be disabled on all motors."
)
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
# rotate more than 360 degrees (from 0 to 4095). Also, some mistakes can happen while assembling the arm; # rotate more than 360 degrees (from 0 to 4095). Also, some mistakes can happen while assembling the arm;
# you could end up with a servo with a position of 0 or 4095 at a crucial point. See [ # you could end up with a servo with a position of 0 or 4095 at a crucial point. See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [ all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
name for name in arm.motor_names if name != "gripper"
]
if len(all_motors_except_gripper) > 0: if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Koch motors # 4 corresponds to Extended Position on Koch motors
arm.write("Operating_Mode", 4, all_motors_except_gripper) arm.write("Operating_Mode", 4, all_motors_except_gripper)
@ -396,9 +384,7 @@ class ManipulatorRobot:
# Enable torque on the gripper of the leader arms, and move it to 45 degrees, # Enable torque on the gripper of the leader arms, and move it to 45 degrees,
# so that we can use it as a trigger to close the gripper of the follower arms. # so that we can use it as a trigger to close the gripper of the follower arms.
self.leader_arms[name].write("Torque_Enable", 1, "gripper") self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write( self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
def set_aloha_robot_preset(self): def set_aloha_robot_preset(self):
def set_shadow_(arm): def set_shadow_(arm):
@ -428,15 +414,11 @@ class ManipulatorRobot:
# you could end up with a servo with a position of 0 or 4095 at a crucial point. See [ # you could end up with a servo with a position of 0 or 4095 at a crucial point. See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [ all_motors_except_gripper = [
name name for name in self.follower_arms[name].motor_names if name != "gripper"
for name in self.follower_arms[name].motor_names
if name != "gripper"
] ]
if len(all_motors_except_gripper) > 0: if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Aloha motors # 4 corresponds to Extended Position on Aloha motors
self.follower_arms[name].write( self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper)
"Operating_Mode", 4, all_motors_except_gripper
)
# Use 'position control current based' for follower gripper to be limited by the limit of the current. # Use 'position control current based' for follower gripper to be limited by the limit of the current.
# It can grasp an object without forcing too much even though, # It can grasp an object without forcing too much even though,
@ -484,9 +466,7 @@ class ManipulatorRobot:
before_lread_t = time.perf_counter() before_lread_t = time.perf_counter()
leader_pos[name] = self.leader_arms[name].read("Present_Position") leader_pos[name] = self.leader_arms[name].read("Present_Position")
leader_pos[name] = torch.from_numpy(leader_pos[name]) leader_pos[name] = torch.from_numpy(leader_pos[name])
self.logs[f"read_leader_{name}_pos_dt_s"] = ( self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
time.perf_counter() - before_lread_t
)
# Send goal position to the follower # Send goal position to the follower
follower_goal_pos = {} follower_goal_pos = {}
@ -507,18 +487,14 @@ class ManipulatorRobot:
if self.config.max_relative_target is not None: if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position") present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos) present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position( goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
goal_pos, present_pos, self.config.max_relative_target
)
# Used when record_data=True # Used when record_data=True
follower_goal_pos[name] = goal_pos follower_goal_pos[name] = goal_pos
goal_pos = goal_pos.numpy().astype(np.float32) goal_pos = goal_pos.numpy().astype(np.float32)
self.follower_arms[name].write("Goal_Position", goal_pos) self.follower_arms[name].write("Goal_Position", goal_pos)
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = ( self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
time.perf_counter() - before_fwrite_t
)
# Early exit when recording data is not requested # Early exit when recording data is not requested
if not record_data: if not record_data:
@ -531,9 +507,7 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter() before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position") follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name]) follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = ( self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
time.perf_counter() - before_fread_t
)
# Create state by concatenating follower current position # Create state by concatenating follower current position
state = [] state = []
@ -555,12 +529,8 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter() before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read() images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name]) images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
"delta_timestamp_s" self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionaries # Populate output dictionaries
obs_dict, action_dict = {}, {} obs_dict, action_dict = {}, {}
@ -584,9 +554,7 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter() before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position") follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name]) follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = ( self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
time.perf_counter() - before_fread_t
)
# Create state by concatenating follower current position # Create state by concatenating follower current position
state = [] state = []
@ -601,12 +569,8 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter() before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read() images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name]) images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
"delta_timestamp_s" self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionaries and format to pytorch # Populate output dictionaries and format to pytorch
obs_dict = {} obs_dict = {}
@ -652,9 +616,7 @@ class ManipulatorRobot:
if self.config.max_relative_target is not None: if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position") present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos) present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position( goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
goal_pos, present_pos, self.config.max_relative_target
)
# Save tensor to concat and return # Save tensor to concat and return
action_sent.append(goal_pos) action_sent.append(goal_pos)
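Several call sites above clamp goals with `ensure_safe_goal_position` before writing to the bus. One way such a clamp could work, consistent with the `max_relative_target` semantics but not necessarily the library's exact implementation:

```python
import torch

def ensure_safe_goal_position(
    goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
) -> torch.Tensor:
    # Cap the per-motor displacement so a single command can't move any joint
    # further than max_relative_target away from its current position.
    diff = goal_pos - present_pos
    max_diff = torch.tensor(max_relative_target)
    safe_diff = torch.clamp(diff, -max_diff, max_diff)
    return present_pos + safe_diff

present = torch.tensor([2048.0, 1024.0])
goal = torch.tensor([3000.0, 1100.0])
print(ensure_safe_goal_position(goal, present, 500.0))  # tensor([2548., 1100.])
```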

View File

@ -271,9 +271,7 @@ class MobileManipulator:
calibration = json.load(f) calibration = json.load(f)
else: else:
print(f"Missing calibration file '{arm_calib_path}'") print(f"Missing calibration file '{arm_calib_path}'")
calibration = run_arm_manual_calibration( calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
arm, self.robot_type, name, arm_type
)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'") print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
arm_calib_path.parent.mkdir(parents=True, exist_ok=True) arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f: with open(arm_calib_path, "w") as f:
@ -303,9 +301,7 @@ class MobileManipulator:
bus.write("Torque_Enable", 0, motor_id) bus.write("Torque_Enable", 0, motor_id)
# Then filter out wheels # Then filter out wheels
arm_only_dict = { arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")}
k: v for k, v in bus.motors.items() if not k.startswith("wheel_")
}
if not arm_only_dict: if not arm_only_dict:
continue continue
@ -377,9 +373,7 @@ class MobileManipulator:
if new_arm_state is not None and frames is not None: if new_arm_state is not None and frames is not None:
self.last_frames = frames self.last_frames = frames
remote_arm_state_tensor = torch.tensor( remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32)
new_arm_state, dtype=torch.float32
)
self.last_remote_arm_state = remote_arm_state_tensor self.last_remote_arm_state = remote_arm_state_tensor
present_speed = new_speed present_speed = new_speed
@ -405,10 +399,7 @@ class MobileManipulator:
def _process_present_speed(self, present_speed: dict) -> torch.Tensor: def _process_present_speed(self, present_speed: dict) -> torch.Tensor:
state_tensor = torch.zeros(3, dtype=torch.int32) state_tensor = torch.zeros(3, dtype=torch.int32)
if present_speed: if present_speed:
decoded = { decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()}
key: MobileManipulator.raw_to_degps(value)
for key, value in present_speed.items()
}
if "1" in decoded: if "1" in decoded:
state_tensor[0] = decoded["1"] state_tensor[0] = decoded["1"]
if "2" in decoded: if "2" in decoded:
@ -421,9 +412,7 @@ class MobileManipulator:
self, record_data: bool = False self, record_data: bool = False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not self.is_connected: if not self.is_connected:
raise RobotDeviceNotConnectedError( raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
"MobileManipulator is not connected. Run `connect()` first."
)
speed_setting = self.speed_levels[self.speed_index] speed_setting = self.speed_levels[self.speed_index]
xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4 xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4
@ -495,9 +484,7 @@ class MobileManipulator:
body_state[2], body_state[2],
) # Convert x,y to mm/s ) # Convert x,y to mm/s
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32) wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
combined_state_tensor = torch.cat( combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
(remote_arm_state_tensor, wheel_state_tensor), dim=0
)
obs_dict = {"observation.state": combined_state_tensor} obs_dict = {"observation.state": combined_state_tensor}

View File

@ -52,9 +52,7 @@ class StretchRobot(StretchAPI):
def connect(self) -> None: def connect(self) -> None:
self.is_connected = self.startup() self.is_connected = self.startup()
if not self.is_connected: if not self.is_connected:
print( print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
"Another process is already using Stretch. Try running 'stretch_free_robot_process.py'"
)
raise ConnectionError() raise ConnectionError()
for name in self.cameras: for name in self.cameras:
@ -62,9 +60,7 @@ class StretchRobot(StretchAPI):
self.is_connected = self.is_connected and self.cameras[name].is_connected self.is_connected = self.is_connected and self.cameras[name].is_connected
if not self.is_connected: if not self.is_connected:
print( print("Could not connect to the cameras, check that all cameras are plugged-in.")
"Could not connect to the cameras, check that all cameras are plugged-in."
)
raise ConnectionError() raise ConnectionError()
self.run_calibration() self.run_calibration()
@ -109,12 +105,8 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter() before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read() images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name]) images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
"delta_timestamp_s" self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionaries # Populate output dictionaries
obs_dict, action_dict = {}, {} obs_dict, action_dict = {}, {}
@ -158,12 +150,8 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter() before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read() images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name]) images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
"delta_timestamp_s" self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionaries # Populate output dictionaries
obs_dict = {} obs_dict = {}
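The camera-read bookkeeping repeated across these robots follows one pattern: snapshot `time.perf_counter()` before the blocking read, then store the delta under a per-camera key. Extracted as a standalone helper mirroring the logs layout above:

```python
import time

logs: dict[str, float] = {}

def timed_read(name: str, read_fn):
    # Measure the blocking async_read() call and record the wall-clock delta
    # under a per-camera log key, as in the robot classes above.
    before = time.perf_counter()
    value = read_fn()
    logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before
    return value

frame = timed_read("wrist", lambda: "fake-frame")
print(logs)  # {'async_read_camera_wrist_dt_s': ...}
```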

View File

@ -69,9 +69,7 @@ class HubMixin:
if push_to_hub: if push_to_hub:
if repo_id is None: if repo_id is None:
repo_id = save_directory.name # Defaults to `save_directory` name repo_id = save_directory.name # Defaults to `save_directory` name
return self.push_to_hub( return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs
)
return None return None
def _save_pretrained(self, save_directory: Path) -> None: def _save_pretrained(self, save_directory: Path) -> None:
@ -177,9 +175,7 @@ class HubMixin:
The url of the commit of your object in the given repository. The url of the commit of your object in the given repository.
""" """
api = HfApi(token=token) api = HfApi(token=token)
repo_id = api.create_repo( repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
repo_id=repo_id, private=private, exist_ok=True
).repo_id
if commit_message is None: if commit_message is None:
if "Policy" in self.__class__.__name__: if "Policy" in self.__class__.__name__:

View File

@ -17,9 +17,7 @@ import importlib
import logging import logging
def is_package_available( def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
pkg_name: str, return_version: bool = False
) -> tuple[bool, str] | bool:
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py """Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
Check if the package spec exists and grab its version to avoid importing a local directory. Check if the package spec exists and grab its version to avoid importing a local directory.
**Note:** this doesn't work for all packages. **Note:** this doesn't work for all packages.
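`is_package_available` follows the transformers pattern its docstring cites: check the import spec first, then read the installed version from package metadata. A standalone sketch under that assumption:

```python
import importlib.metadata
import importlib.util

def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
    # Checking the spec avoids importing (and accidentally picking up a local directory).
    package_exists = importlib.util.find_spec(pkg_name) is not None
    package_version = "N/A"
    if package_exists:
        try:
            package_version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            package_exists = False
    if return_version:
        return package_exists, package_version
    return package_exists

print(is_package_available("numpy", return_version=True))  # (True, '<installed version>')
```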

View File

@ -20,16 +20,7 @@ from typing import TypeVar
import imageio import imageio
JsonLike = ( JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
str
| int
| float
| bool
| None
| list["JsonLike"]
| dict[str, "JsonLike"]
| tuple["JsonLike", ...]
)
T = TypeVar("T", bound=JsonLike) T = TypeVar("T", bound=JsonLike)
@ -85,9 +76,7 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
# Check length # Check length
if len(target) != len(source): if len(target) != len(source):
raise ValueError( raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}")
f"List length mismatch: expected {len(target)}, got {len(source)}"
)
# Recursively update each element. # Recursively update each element.
for i in range(len(target)): for i in range(len(target)):
@ -99,14 +88,10 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
# which we'll convert back to a tuple. # which we'll convert back to a tuple.
elif isinstance(target, tuple): elif isinstance(target, tuple):
if not isinstance(source, list): if not isinstance(source, list):
raise TypeError( raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}")
f"Type mismatch: expected list (for tuple), got {type(source)}"
)
if len(target) != len(source): if len(target) != len(source):
raise ValueError( raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}")
f"Tuple length mismatch: expected {len(target)}, got {len(source)}"
)
# Convert each element, forming a new tuple. # Convert each element, forming a new tuple.
converted_items = [] converted_items = []
@ -120,9 +105,7 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
else: else:
# Check the exact type. If these must match 1:1, do: # Check the exact type. If these must match 1:1, do:
if type(target) is not type(source): if type(target) is not type(source):
raise TypeError( raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}")
f"Type mismatch: expected {type(target)}, got {type(source)}"
)
return source return source
# Perform the in-place/recursive deserialization # Perform the in-place/recursive deserialization
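To make the tuple and length rules above concrete, here is a small usage sketch; the import path for `deserialize_json_into_object` is not shown in this diff, so the call itself is illustrated in comments:

```python
import json
from pathlib import Path

# Template object: the JSON on disk must match its shapes and types.
obj = {"fps": 0, "timestamps": [0.0, 0.0], "resolution": (0, 0)}

fpath = Path("meta.json")
fpath.write_text(json.dumps({"fps": 30, "timestamps": [0.1, 0.2], "resolution": [640, 480]}))

# deserialize_json_into_object(fpath, obj) would return
#   {"fps": 30, "timestamps": [0.1, 0.2], "resolution": (640, 480)}
# - the 2-element list matches the template length, so it is updated element-wise;
# - the JSON list [640, 480] is converted back into the tuple (640, 480);
# - a type mismatch (e.g. "30" instead of 30 for fps) raises TypeError,
#   and a 3-element timestamps list raises ValueError.
```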

View File

@ -107,17 +107,13 @@ class MetricsTracker:
self.episodes = self.samples / self._avg_samples_per_ep self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames self.epochs = self.samples / self._num_frames
def __getattr__( def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
self, name: str
) -> int | dict[str, AverageMeter] | AverageMeter | Any:
if name in self.__dict__: if name in self.__dict__:
return self.__dict__[name] return self.__dict__[name]
elif name in self.metrics: elif name in self.metrics:
return self.metrics[name] return self.metrics[name]
else: else:
raise AttributeError( raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
def __setattr__(self, name: str, value: Any) -> None: def __setattr__(self, name: str, value: Any) -> None:
if name in self.__dict__: if name in self.__dict__:
@ -125,9 +121,7 @@ class MetricsTracker:
elif name in self.metrics: elif name in self.metrics:
self.metrics[name].update(value) self.metrics[name].update(value)
else: else:
raise AttributeError( raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
def step(self) -> None: def step(self) -> None:
""" """

View File

@ -123,9 +123,7 @@ def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
""" """
py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")} py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")}
np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")} np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")}
torch_rng_state_dict = { torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")}
k: v for k, v in rng_state_dict.items() if k.startswith("torch")
}
deserialize_python_rng_state(py_rng_state_dict) deserialize_python_rng_state(py_rng_state_dict)
deserialize_numpy_rng_state(np_rng_state_dict) deserialize_numpy_rng_state(np_rng_state_dict)

View File

@ -48,9 +48,7 @@ def auto_select_torch_device() -> torch.device:
logging.info("Metal backend detected, using cuda.") logging.info("Metal backend detected, using cuda.")
return torch.device("mps") return torch.device("mps")
else: else:
logging.warning( logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
"No accelerated backend detected. Using default cpu, this will be slow."
)
return torch.device("cpu") return torch.device("cpu")
@ -98,9 +96,7 @@ def is_torch_device_available(try_device: str) -> bool:
elif try_device == "cpu": elif try_device == "cpu":
return True return True
else: else:
raise ValueError( raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu."
)
def is_amp_available(device: str): def is_amp_available(device: str):
@ -158,10 +154,7 @@ def _relative_path_between(path1: Path, path2: Path) -> Path:
except ValueError: # most likely because path1 is not a subpath of path2 except ValueError: # most likely because path1 is not a subpath of path2
common_parts = Path(osp.commonpath([path1, path2])).parts common_parts = Path(osp.commonpath([path1, path2])).parts
return Path( return Path(
"/".join( "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
[".."] * (len(path2.parts) - len(common_parts))
+ list(path1.parts[len(common_parts) :])
)
) )
@ -172,26 +165,10 @@ def print_cuda_memory_usage():
gc.collect() gc.collect()
# Also clear the cache if you want to fully release the memory # Also clear the cache if you want to fully release the memory
torch.cuda.empty_cache() torch.cuda.empty_cache()
print( print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
"Current GPU Memory Allocated: {:.2f} MB".format( print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
torch.cuda.memory_allocated(0) / 1024**2 print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
) print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
)
print(
"Maximum GPU Memory Allocated: {:.2f} MB".format(
torch.cuda.max_memory_allocated(0) / 1024**2
)
)
print(
"Current GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.memory_reserved(0) / 1024**2
)
)
print(
"Maximum GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.max_memory_reserved(0) / 1024**2
)
)
def capture_timestamp_utc(): def capture_timestamp_utc():
@ -223,9 +200,7 @@ def say(text, blocking=False):
if blocking: if blocking:
subprocess.run(cmd, check=True) subprocess.run(cmd, check=True)
else: else:
subprocess.Popen( subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0
)
def log_say(text, play_sounds, blocking=False): def log_say(text, play_sounds, blocking=False):
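`auto_select_torch_device` and `is_torch_device_available` above encode one priority order: cuda, then mps, then cpu. A standalone sketch of that selection; the cuda branch falls outside the hunk shown, so its message here is an assumption:

```python
import logging

import torch

def auto_select_torch_device() -> torch.device:
    # Prefer CUDA, then Apple's Metal (mps), and fall back to CPU.
    if torch.cuda.is_available():
        logging.info("Cuda backend detected, using cuda.")
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        logging.info("Metal backend detected, using mps.")
        return torch.device("mps")
    logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
    return torch.device("cpu")

print(auto_select_torch_device())
```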

View File

@ -26,9 +26,7 @@ from lerobot.common.constants import PRETRAINED_MODEL_DIR
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
def cfg_to_group( def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
cfg: TrainPipelineConfig, return_list: bool = False
) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list.""" """Return a group name for logging. Optionally returns group name as list."""
lst = [ lst = [
f"policy:{cfg.policy.type}", f"policy:{cfg.policy.type}",
@ -95,9 +93,7 @@ class WandBLogger:
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online", mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
) )
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info( logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
)
self._wandb = wandb self._wandb = wandb
def log_policy(self, checkpoint_dir: Path): def log_policy(self, checkpoint_dir: Path):
@ -109,9 +105,7 @@ class WandBLogger:
artifact_name = f"{self._group}-{step_id}" artifact_name = f"{self._group}-{step_id}"
artifact_name = get_safe_wandb_artifact_name(artifact_name) artifact_name = get_safe_wandb_artifact_name(artifact_name)
artifact = self._wandb.Artifact(artifact_name, type="model") artifact = self._wandb.Artifact(artifact_name, type="model")
artifact.add_file( artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE
)
self._wandb.log_artifact(artifact) self._wandb.log_artifact(artifact)
def log_dict(self, d: dict, step: int, mode: str = "train"): def log_dict(self, d: dict, step: int, mode: str = "train"):
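`log_policy` wraps a standard wandb artifact flow: create a model-typed artifact, attach the safetensors file, log it. The same flow outside the class; the project name, artifact name, and checkpoint layout below are placeholders standing in for `f"{group}-{step_id}"` and `PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE`:

```python
from pathlib import Path

import wandb

# Placeholder checkpoint file so the sketch runs end to end.
ckpt = Path("checkpoints/000100/pretrained_model/model.safetensors")
ckpt.parent.mkdir(parents=True, exist_ok=True)
ckpt.touch()

run = wandb.init(project="lerobot", mode="offline")  # offline: no account needed
artifact = wandb.Artifact("my-group-000100", type="model")
artifact.add_file(str(ckpt))
run.log_artifact(artifact)
run.finish()
```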

View File

@ -33,9 +33,7 @@ class DatasetConfig:
# Root directory where the dataset will be stored (e.g. 'dataset/path'). # Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | None = None root: str | None = None
episodes: list[int] | None = None episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field( image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
default_factory=ImageTransformsConfig
)
revision: str | None = None revision: str | None = None
use_imagenet_stats: bool = True use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec) video_backend: str = field(default_factory=get_safe_default_codec)

View File

@ -40,9 +40,7 @@ class EvalPipelineConfig:
policy_path = parser.get_path_arg("policy") policy_path = parser.get_path_arg("policy")
if policy_path: if policy_path:
cli_overrides = parser.get_cli_overrides("policy") cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained( self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
policy_path, cli_overrides=cli_overrides
)
self.policy.pretrained_path = policy_path self.policy.pretrained_path = policy_path
else: else:

View File

@ -29,9 +29,7 @@ PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
draccus.set_config_type("json") draccus.set_config_type("json")
def get_cli_overrides( def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
field_name: str, args: Sequence[str] | None = None
) -> list[str] | None:
"""Parses arguments from cli at a given nested attribute level. """Parses arguments from cli at a given nested attribute level.
For example, supposing the main script was called with: For example, supposing the main script was called with:
@ -158,9 +156,7 @@ def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")] return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
def filter_path_args( def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
fields_to_filter: str | list[str], args: Sequence[str] | None = None
) -> list[str]:
""" """
Filters command-line arguments related to fields with specific path arguments. Filters command-line arguments related to fields with specific path arguments.
@ -188,9 +184,7 @@ def filter_path_args(
argument=None, argument=None,
message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}", message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
) )
filtered_args = [ filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")]
arg for arg in filtered_args if not arg.startswith(f"--{field}.")
]
return filtered_args return filtered_args
@ -222,9 +216,7 @@ def wrap(config_path: Path | None = None):
load_plugin(plugin_path) load_plugin(plugin_path)
except PluginLoadError as e: except PluginLoadError as e:
# add the relevant CLI arg to the error message # add the relevant CLI arg to the error message
raise PluginLoadError( raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}"
) from e
cli_args = filter_arg(plugin_cli_arg, cli_args) cli_args = filter_arg(plugin_cli_arg, cli_args)
config_path_cli = parse_arg("config_path", cli_args) config_path_cli = parse_arg("config_path", cli_args)
if has_method(argtype, "__get_path_fields__"): if has_method(argtype, "__get_path_fields__"):
@ -234,9 +226,7 @@ def wrap(config_path: Path | None = None):
cli_args = filter_arg("config_path", cli_args) cli_args = filter_arg("config_path", cli_args)
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args) cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
else: else:
cfg = draccus.parse( cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
config_class=argtype, config_path=config_path, args=cli_args
)
response = fn(cfg, *args, **kwargs) response = fn(cfg, *args, **kwargs)
return response return response
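`get_cli_overrides` and `filter_path_args` above slice a draccus command line by field prefix. A minimal restatement of the override extraction, not the library's exact code (the real version also excludes the draccus choice-type key alongside `.path`):

```python
def get_cli_overrides(field_name: str, args: list[str]) -> list[str]:
    # Keep args nested under `--{field_name}.`, strip the prefix so draccus can
    # parse them at the sub-config level, and drop the `.path` entry.
    prefix = f"--{field_name}."
    out = []
    for arg in args:
        if arg.startswith(prefix) and not arg.startswith(f"{prefix}path="):
            out.append("--" + arg[len(prefix):])
    return out

args = ["--policy.path=lerobot/act", "--policy.n_obs_steps=2", "--dataset.repo_id=lerobot/test"]
print(get_cli_overrides("policy", args))  # ['--n_obs_steps=2']
```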

View File

@ -68,9 +68,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
self.pretrained_path = None self.pretrained_path = None
if not self.device or not is_torch_device_available(self.device): if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device() auto_device = auto_select_torch_device()
logging.warning( logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
f"Device '{self.device}' is not available. Switching to '{auto_device}'."
)
self.device = auto_device.type self.device = auto_device.type
# Automatically deactivate AMP if necessary # Automatically deactivate AMP if necessary
@ -124,11 +122,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
@property @property
def image_features(self) -> dict[str, PolicyFeature]: def image_features(self) -> dict[str, PolicyFeature]:
return { return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
key: ft
for key, ft in self.input_features.items()
if ft.type is FeatureType.VISUAL
}
@property @property
def action_feature(self) -> PolicyFeature | None: def action_feature(self) -> PolicyFeature | None:

View File

@ -73,9 +73,7 @@ class TrainPipelineConfig(HubMixin):
if policy_path: if policy_path:
# Only load the policy config # Only load the policy config
cli_overrides = parser.get_cli_overrides("policy") cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained( self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
policy_path, cli_overrides=cli_overrides
)
self.policy.pretrained_path = policy_path self.policy.pretrained_path = policy_path
elif self.resume: elif self.resume:
# The entire train config is already loaded, we just need to get the checkpoint dir # The entire train config is already loaded, we just need to get the checkpoint dir
@ -99,11 +97,7 @@ class TrainPipelineConfig(HubMixin):
else: else:
self.job_name = f"{self.env.type}_{self.policy.type}" self.job_name = f"{self.env.type}_{self.policy.type}"
if ( if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
not self.resume
and isinstance(self.output_dir, Path)
and self.output_dir.is_dir()
):
raise FileExistsError( raise FileExistsError(
f"Output directory {self.output_dir} already exists and resume is {self.resume}. " f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
f"Please change your output directory so that {self.output_dir} is not overwritten." f"Please change your output directory so that {self.output_dir} is not overwritten."
@ -114,16 +108,10 @@ class TrainPipelineConfig(HubMixin):
self.output_dir = Path("outputs/train") / train_dir self.output_dir = Path("outputs/train") / train_dir
if isinstance(self.dataset.repo_id, list): if isinstance(self.dataset.repo_id, list):
raise NotImplementedError( raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
"LeRobotMultiDataset is not currently implemented."
)
if not self.use_policy_training_preset and ( if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
self.optimizer is None or self.scheduler is None raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
):
raise ValueError(
"Optimizer and Scheduler must be set when the policy presets are not used."
)
elif self.use_policy_training_preset and not self.resume: elif self.use_policy_training_preset and not self.resume:
self.optimizer = self.policy.get_optimizer_preset() self.optimizer = self.policy.get_optimizer_preset()
self.scheduler = self.policy.get_scheduler_preset() self.scheduler = self.policy.get_scheduler_preset()

View File

@ -67,8 +67,8 @@ def get_motor_bus_cls(brand: str) -> tuple:
def configure_motor(port, brand, model, motor_idx_des, baudrate_des): def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = ( motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = get_motor_bus_cls(
get_motor_bus_cls(brand) brand
) )
# Check if the provided model exists in the model_baud_rate_table # Check if the provided model exists in the model_baud_rate_table
@ -82,9 +82,7 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
motor_index_arbitrary = motor_idx_des # Use the motor ID passed via argument motor_index_arbitrary = motor_idx_des # Use the motor ID passed via argument
motor_model = model # Use the motor model passed via argument motor_model = model # Use the motor model passed via argument
config = motor_bus_config_cls( config = motor_bus_config_cls(port=port, motors={motor_name: (motor_index_arbitrary, motor_model)})
port=port, motors={motor_name: (motor_index_arbitrary, motor_model)}
)
# Initialize the MotorBus with the correct port and motor configurations # Initialize the MotorBus with the correct port and motor configurations
motor_bus = motor_bus_cls(config=config) motor_bus = motor_bus_cls(config=config)
@ -120,26 +118,20 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
break break
if motor_index == -1: if motor_index == -1:
raise ValueError( raise ValueError("No motors detected. Please ensure you have one motor connected.")
"No motors detected. Please ensure you have one motor connected."
)
print(f"Motor index found at: {motor_index}") print(f"Motor index found at: {motor_index}")
if brand == "feetech": if brand == "feetech":
# Allows ID and BAUDRATE to be written in memory # Allows ID and BAUDRATE to be written in memory
motor_bus.write_with_motor_ids( motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
motor_bus.motor_models, motor_index, "Lock", 0
)
if baudrate != baudrate_des: if baudrate != baudrate_des:
print(f"Setting its baudrate to {baudrate_des}") print(f"Setting its baudrate to {baudrate_des}")
baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des) baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des)
# The write can fail, so we allow retries # The write can fail, so we allow retries
motor_bus.write_with_motor_ids( motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx)
motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx
)
time.sleep(0.5) time.sleep(0.5)
motor_bus.set_bus_baudrate(baudrate_des) motor_bus.set_bus_baudrate(baudrate_des)
present_baudrate_idx = motor_bus.read_with_motor_ids( present_baudrate_idx = motor_bus.read_with_motor_ids(
@ -151,16 +143,10 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
print(f"Setting its index to desired index {motor_idx_des}") print(f"Setting its index to desired index {motor_idx_des}")
if brand == "feetech": if brand == "feetech":
motor_bus.write_with_motor_ids( motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
motor_bus.motor_models, motor_index, "Lock", 0 motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)
)
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "ID", motor_idx_des
)
present_idx = motor_bus.read_with_motor_ids( present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2)
motor_bus.motor_models, motor_idx_des, "ID", num_retry=2
)
if present_idx != motor_idx_des: if present_idx != motor_idx_des:
raise OSError("Failed to write index.") raise OSError("Failed to write index.")
@ -194,12 +180,8 @@ if __name__ == "__main__":
required=True, required=True,
help="Motors bus port (e.g. dynamixel,feetech)", help="Motors bus port (e.g. dynamixel,feetech)",
) )
parser.add_argument( parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
"--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)" parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
)
parser.add_argument(
"--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)"
)
parser.add_argument( parser.add_argument(
"--ID", "--ID",
type=int, type=int,
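The ID-assignment step above unlocks the EEPROM, writes the new ID, then reads back under the new ID to verify the write. That flow in isolation, with the serial bus replaced by an in-memory stub (`FakeBus` and `assign_motor_id` are illustrative names; `write_with_motor_ids`/`read_with_motor_ids` mirror the calls shown):

```python
class FakeBus:
    """In-memory stand-in for the serial bus, just to exercise the flow below."""

    def __init__(self):
        self.registers = {}

    def write_with_motor_ids(self, models, idx, name, value):
        if name == "ID":
            # After an ID write, the motor answers under its new ID.
            idx = value
        self.registers[(idx, name)] = value

    def read_with_motor_ids(self, models, idx, name, num_retry=0):
        return self.registers.get((idx, name))


def assign_motor_id(motor_bus, motor_models, motor_index: int, motor_idx_des: int) -> None:
    # Mirrors the Feetech branch above: unlock EEPROM, write the new ID,
    # then read back under the new ID to verify the write took effect.
    motor_bus.write_with_motor_ids(motor_models, motor_index, "Lock", 0)
    motor_bus.write_with_motor_ids(motor_models, motor_index, "ID", motor_idx_des)
    present_idx = motor_bus.read_with_motor_ids(motor_models, motor_idx_des, "ID", num_retry=2)
    if present_idx != motor_idx_des:
        raise OSError("Failed to write index.")


assign_motor_id(FakeBus(), ["sts3215"], motor_index=1, motor_idx_des=7)
print("ID assigned and verified")
```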

View File

@ -255,8 +255,7 @@ def record(
if len(robot.cameras) > 0: if len(robot.cameras) > 0:
dataset.start_image_writer( dataset.start_image_writer(
num_processes=cfg.num_image_writer_processes, num_processes=cfg.num_image_writer_processes,
num_threads=cfg.num_image_writer_threads_per_camera num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
* len(robot.cameras),
) )
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video) sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video)
else: else:
@ -269,19 +268,14 @@ def record(
robot=robot, robot=robot,
use_videos=cfg.video, use_videos=cfg.video,
image_writer_processes=cfg.num_image_writer_processes, image_writer_processes=cfg.num_image_writer_processes,
image_writer_threads=cfg.num_image_writer_threads_per_camera image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
* len(robot.cameras),
) )
# Load pretrained policy # Load pretrained policy
policy = ( policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
)
# Load pretrained policy # Load pretrained policy
policy = ( policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
)
if not robot.is_connected: if not robot.is_connected:
robot.connect() robot.connect()

View File

@ -174,10 +174,7 @@ def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
leader_pos = robot.leader_arms.main.read("Present_Position") leader_pos = robot.leader_arms.main.read("Present_Position")
action = process_action_fn(leader_pos) action = process_action_fn(leader_pos)
env.step(np.expand_dims(action, 0)) env.step(np.expand_dims(action, 0))
if ( if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
teleop_time_s is not None
and time.perf_counter() - start_teleop_t > teleop_time_s
):
print("Teleoperation processes finished.") print("Teleoperation processes finished.")
break break
@ -209,27 +206,19 @@ def record(
# Load pretrained policy # Load pretrained policy
extra_features = ( extra_features = (
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
if assign_rewards
else None
) )
policy = None policy = None
if pretrained_policy_name_or_path is not None: if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy( policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
pretrained_policy_name_or_path, policy_overrides
)
if fps is None: if fps is None:
fps = policy_fps fps = policy_fps
logging.warning( logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
f"No fps provided, so using the fps from policy config ({policy_fps})."
)
if policy is None and process_action_from_leader is None: if policy is None and process_action_from_leader is None:
raise ValueError( raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
"Either policy or process_action_fn has to be set to enable control in sim."
)
# initialize listener before sim env # initialize listener before sim env
listener, events = init_keyboard_listener(assign_rewards=assign_rewards) listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
@ -380,9 +369,7 @@ def record(
if events["stop_recording"] or recorded_episodes >= num_episodes: if events["stop_recording"] or recorded_episodes >= num_episodes:
break break
else: else:
logging.info( logging.info("Waiting for a few seconds before starting next episode recording...")
"Waiting for a few seconds before starting next episode recording..."
)
busy_wait(3) busy_wait(3)
log_say("Stop recording", play_sounds, blocking=True) log_say("Stop recording", play_sounds, blocking=True)
@ -481,9 +468,7 @@ if __name__ == "__main__":
required=True, required=True,
help="A description of the task preformed during recording that can be used as a language instruction.", help="A description of the task preformed during recording that can be used as a language instruction.",
) )
parser_record.add_argument( parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
"--num-episodes", type=int, default=50, help="Number of episodes to record."
)
parser_record.add_argument( parser_record.add_argument(
"--run-compute-stats", "--run-compute-stats",
type=int, type=int,
@ -561,9 +546,7 @@ if __name__ == "__main__":
default="lerobot/test", default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).", help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
) )
parser_replay.add_argument( parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
"--episode", type=int, default=0, help="Index of the episode to replay."
)
args = parser.parse_args() args = parser.parse_args()
View File
@ -59,11 +59,7 @@ np_version = np.__version__ if HAS_NP else "N/A"
torch_version = torch.__version__ if HAS_TORCH else "N/A" torch_version = torch.__version__ if HAS_TORCH else "N/A"
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A" torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
cuda_version = ( cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
torch._C._cuda_getCompiledVersion()
if HAS_TORCH and torch.version.cuda is not None
else "N/A"
)
# TODO(aliberts): refactor into an actual command `lerobot env` # TODO(aliberts): refactor into an actual command `lerobot env`
@ -81,9 +77,7 @@ def display_sys_info() -> dict:
"Using GPU in script?": "<fill in>", "Using GPU in script?": "<fill in>",
# "Using distributed or parallel set-up in script?": "<fill in>", # "Using distributed or parallel set-up in script?": "<fill in>",
} }
print( print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
"\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n"
)
print(format_dict(info)) print(format_dict(info))
return info return info
View File
@ -152,8 +152,7 @@ def rollout(
all_observations.append(deepcopy(observation)) all_observations.append(deepcopy(observation))
observation = { observation = {
key: observation[key].to(device, non_blocking=device.type == "cuda") key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
for key in observation
} }
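Hedged sketch of the transfer idiom above; non_blocking only pays off for CPU-to-CUDA copies (ideally from pinned memory) and is harmless on CPU targets. to_device is an illustrative helper:

import torch

def to_device(observation: dict, device: torch.device) -> dict:
    # Mirror the dict comprehension above: async copy only when the target is CUDA.
    return {key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation}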
# Infer "task" from attributes of environments. # Infer "task" from attributes of environments.
@ -175,10 +174,7 @@ def rollout(
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available if none of the envs finished. # available if none of the envs finished.
if "final_info" in info: if "final_info" in info:
successes = [ successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
info["is_success"] if info is not None else False
for info in info["final_info"]
]
else: else:
successes = [False] * env.num_envs successes = [False] * env.num_envs
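Sketch of the final_info convention handled above, assuming a Gymnasium-style vector env; extract_successes is an illustrative name:

def extract_successes(info: dict, num_envs: int) -> list:
    # Per-env final infos are only attached on steps where at least one env finished.
    if "final_info" in info:
        return [fi["is_success"] if fi is not None else False for fi in info["final_info"]]
    return [False] * num_envs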
@ -192,13 +188,9 @@ def rollout(
step += 1 step += 1
running_success_rate = ( running_success_rate = (
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any") einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
.numpy()
.mean()
)
progbar.set_postfix(
{"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}
) )
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
progbar.update() progbar.update()
# Track the final observation. # Track the final observation.
@ -216,9 +208,7 @@ def rollout(
if return_observations: if return_observations:
stacked_observations = {} stacked_observations = {}
for key in all_observations[0]: for key in all_observations[0]:
stacked_observations[key] = torch.stack( stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
[obs[key] for obs in all_observations], dim=1
)
ret["observation"] = stacked_observations ret["observation"] = stacked_observations
if hasattr(policy, "use_original_modules"): if hasattr(policy, "use_original_modules"):
@ -280,9 +270,7 @@ def eval_policy(
return return
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs) n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
if isinstance(env, gym.vector.SyncVectorEnv): if isinstance(env, gym.vector.SyncVectorEnv):
ep_frames.append( ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023
np.stack([env.envs[i].render() for i in range(n_to_render_now)])
) # noqa: B023
elif isinstance(env, gym.vector.AsyncVectorEnv): elif isinstance(env, gym.vector.AsyncVectorEnv):
# Here we must render all frames and discard any we don't need. # Here we must render all frames and discard any we don't need.
ep_frames.append(np.stack(env.call("render")[:n_to_render_now])) ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
@ -294,9 +282,7 @@ def eval_policy(
episode_data: dict | None = None episode_data: dict | None = None
# we don't want a progress bar when we use slurm, since it clutters the logs # we don't want a progress bar when we use slurm, since it clutters the logs
progbar = trange( progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
n_batches, desc="Stepping through eval batches", disable=inside_slurm()
)
for batch_ix in progbar: for batch_ix in progbar:
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
# step. # step.
@ -326,22 +312,13 @@ def eval_policy(
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step. # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
mask = ( mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
torch.arange(n_steps)
<= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)
).int()
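A tiny worked example of the mask construction above, with illustrative values (n_steps=5, done at step index 2):

import einops
import torch

n_steps = 5
done_indices = torch.tensor([2])  # the single batch element is done at step index 2
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
# mask == tensor([[1, 1, 1, 1, 0]]): steps 0..3 are kept, i.e. one past the done index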
# Extend metrics. # Extend metrics.
batch_sum_rewards = einops.reduce( batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
(rollout_data["reward"] * mask), "b n -> b", "sum"
)
sum_rewards.extend(batch_sum_rewards.tolist()) sum_rewards.extend(batch_sum_rewards.tolist())
batch_max_rewards = einops.reduce( batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
(rollout_data["reward"] * mask), "b n -> b", "max"
)
max_rewards.extend(batch_max_rewards.tolist()) max_rewards.extend(batch_max_rewards.tolist())
batch_successes = einops.reduce( batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
(rollout_data["success"] * mask), "b n -> b", "any"
)
all_successes.extend(batch_successes.tolist()) all_successes.extend(batch_successes.tolist())
if seeds: if seeds:
all_seeds.extend(seeds) all_seeds.extend(seeds)
@ -354,27 +331,17 @@ def eval_policy(
rollout_data, rollout_data,
done_indices, done_indices,
start_episode_index=batch_ix * env.num_envs, start_episode_index=batch_ix * env.num_envs,
start_data_index=( start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
0
if episode_data is None
else (episode_data["index"][-1].item() + 1)
),
fps=env.unwrapped.metadata["render_fps"], fps=env.unwrapped.metadata["render_fps"],
) )
if episode_data is None: if episode_data is None:
episode_data = this_episode_data episode_data = this_episode_data
else: else:
# Some sanity checks to make sure we are correctly compiling the data. # Some sanity checks to make sure we are correctly compiling the data.
assert ( assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
episode_data["episode_index"][-1] + 1
== this_episode_data["episode_index"][0]
)
assert episode_data["index"][-1] + 1 == this_episode_data["index"][0] assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
# Concatenate the episode data. # Concatenate the episode data.
episode_data = { episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
k: torch.cat([episode_data[k], this_episode_data[k]])
for k in episode_data
}
# Maybe render video for visualization. # Maybe render video for visualization.
if max_episodes_rendered > 0 and len(ep_frames) > 0: if max_episodes_rendered > 0 and len(ep_frames) > 0:
@ -392,9 +359,7 @@ def eval_policy(
target=write_video, target=write_video,
args=( args=(
str(video_path), str(video_path),
stacked_frames[ stacked_frames[: done_index + 1], # + 1 to capture the last observation
: done_index + 1
], # + 1 to capture the last observation
env.unwrapped.metadata["render_fps"], env.unwrapped.metadata["render_fps"],
), ),
) )
@ -403,9 +368,7 @@ def eval_policy(
n_episodes_rendered += 1 n_episodes_rendered += 1
progbar.set_postfix( progbar.set_postfix(
{ {"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"
}
) )
# Wait till all video rendering threads are done. # Wait till all video rendering threads are done.
@ -473,16 +436,12 @@ def _compile_episode_data(
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet. # Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
ep_dict = { ep_dict = {
"action": rollout_data["action"][ep_ix, : num_frames - 1], "action": rollout_data["action"][ep_ix, : num_frames - 1],
"episode_index": torch.tensor( "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
[start_episode_index + ep_ix] * (num_frames - 1)
),
"frame_index": torch.arange(0, num_frames - 1, 1), "frame_index": torch.arange(0, num_frames - 1, 1),
"timestamp": torch.arange(0, num_frames - 1, 1) / fps, "timestamp": torch.arange(0, num_frames - 1, 1) / fps,
"next.done": rollout_data["done"][ep_ix, : num_frames - 1], "next.done": rollout_data["done"][ep_ix, : num_frames - 1],
"next.success": rollout_data["success"][ep_ix, : num_frames - 1], "next.success": rollout_data["success"][ep_ix, : num_frames - 1],
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type( "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
torch.float32
),
} }
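Quick check of the per-frame fields built above, with illustrative values:

import torch

num_frames, fps = 4, 10
frame_index = torch.arange(0, num_frames - 1, 1)      # tensor([0, 1, 2]); the last frame is held back
timestamp = frame_index / fps                         # tensor([0.0000, 0.1000, 0.2000])
episode_index = torch.tensor([0] * (num_frames - 1))  # tensor([0, 0, 0])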
# For the last observation frame, all other keys will just be copy padded. # For the last observation frame, all other keys will just be copy padded.
@ -498,9 +457,7 @@ def _compile_episode_data(
for key in ep_dicts[0]: for key in ep_dicts[0]:
data_dict[key] = torch.cat([x[key] for x in ep_dicts]) data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange( data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
start_data_index, start_data_index + total_frames, 1
)
return data_dict return data_dict
@ -516,14 +473,10 @@ def eval_main(cfg: EvalPipelineConfig):
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed) set_seed(cfg.seed)
logging.info( logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}"
)
logging.info("Making environment.") logging.info("Making environment.")
env = make_env( env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs
)
logging.info("Making policy.") logging.info("Making policy.")
@ -535,9 +488,7 @@ def eval_main(cfg: EvalPipelineConfig):
with ( with (
torch.no_grad(), torch.no_grad(),
torch.autocast(device_type=device.type) torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
if cfg.policy.use_amp
else nullcontext(),
): ):
info = eval_policy( info = eval_policy(
env, env,
View File
@ -74,9 +74,7 @@ def get_classifier(pretrained_path, config_path):
cfg = init_hydra_config(config_path) cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len( classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
cfg.training.image_keys
) # TODO automate these paths
model = Classifier(classifier_config) model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict()) model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to("mps") model = model.to("mps")
@ -161,17 +159,11 @@ def rollout(
images = [] images = []
for key in image_keys: for key in image_keys:
if display_cameras: if display_cameras:
cv2.imshow( cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1) cv2.waitKey(1)
images.append(observation[key].to("mps")) images.append(observation[key].to("mps"))
reward = ( reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
reward_classifier.predict_reward(images)
if reward_classifier is not None
else 0.0
)
all_rewards.append(reward) all_rewards.append(reward)
# print("REWARD : ", reward) # print("REWARD : ", reward)
@ -235,9 +227,7 @@ def eval_policy(
start_eval = time.perf_counter() start_eval = time.perf_counter()
progbar = trange(n_episodes, desc="Evaluating policy on real robot") progbar = trange(n_episodes, desc="Evaluating policy on real robot")
reward_classifier = get_classifier( reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)
reward_classifier_pretrained_path, reward_classifier_config_file
)
for _ in progbar: for _ in progbar:
rollout_data = rollout( rollout_data = rollout(
@ -313,9 +303,7 @@ def init_keyboard_listener():
print("Right arrow key pressed. Exiting loop...") print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True events["exit_early"] = True
elif key == keyboard.Key.left: elif key == keyboard.Key.left:
print( print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
)
events["rerecord_episode"] = True events["rerecord_episode"] = True
events["exit_early"] = True events["exit_early"] = True
elif key == keyboard.Key.space: elif key == keyboard.Key.space:
@ -380,9 +368,7 @@ if __name__ == "__main__":
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)." "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
), ),
) )
parser.add_argument( parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
"--revision", help="Optionally provide the Hugging Face Hub revision ID."
)
parser.add_argument( parser.add_argument(
"--out-dir", "--out-dir",
help=( help=(
View File
@ -45,13 +45,9 @@ def find_port():
print(f"The port of this MotorsBus is '{port}'") print(f"The port of this MotorsBus is '{port}'")
print("Reconnect the USB cable.") print("Reconnect the USB cable.")
elif len(ports_diff) == 0: elif len(ports_diff) == 0:
raise OSError( raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
f"Could not detect the port. No difference was found ({ports_diff})."
)
else: else:
raise OSError( raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")
f"Could not detect the port. More than one port was found ({ports_diff})."
)
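Hedged sketch of the detection logic above; it assumes the user unplugs the device between two scans, so the port that disappears identifies the bus (detect_port and the port lists are illustrative):

def detect_port(ports_before: list, ports_after: list) -> str:
    ports_diff = list(set(ports_before) - set(ports_after))
    if len(ports_diff) == 1:
        return ports_diff[0]
    elif len(ports_diff) == 0:
        raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
    else:
        raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")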
if __name__ == "__main__": if __name__ == "__main__":
View File
@ -14,18 +14,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from statistics import mean, quantiles import time
from functools import lru_cache from functools import lru_cache
from lerobot.scripts.server.utils import setup_process_handlers from queue import Empty
from statistics import mean, quantiles
# from lerobot.scripts.eval import eval_policy # from lerobot.scripts.eval import eval_policy
import grpc import grpc
import hydra import hydra
import torch import torch
from omegaconf import DictConfig from omegaconf import DictConfig
from torch import nn from torch import nn
import time from torch.multiprocessing import Event, Queue
# TODO: Remove the import of maniskill # TODO: Remove the import of maniskill
# from lerobot.common.envs.factory import make_maniskill_env # from lerobot.common.envs.factory import make_maniskill_env
@ -34,34 +34,28 @@ from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.robots.factory import make_robot from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import (
TimerManager, TimerManager,
get_safe_torch_device, get_safe_torch_device,
init_logging,
set_global_seed, set_global_seed,
) )
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
from lerobot.scripts.server.buffer import ( from lerobot.scripts.server.buffer import (
Transition, Transition,
bytes_to_state_dict,
move_state_dict_to_device, move_state_dict_to_device,
move_transition_to_device, move_transition_to_device,
python_object_to_bytes, python_object_to_bytes,
transitions_to_bytes, transitions_to_bytes,
bytes_to_state_dict,
) )
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
from lerobot.scripts.server.network_utils import ( from lerobot.scripts.server.network_utils import (
receive_bytes_in_chunks, receive_bytes_in_chunks,
send_bytes_in_chunks, send_bytes_in_chunks,
) )
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env from lerobot.scripts.server.utils import get_last_item_from_queue, setup_process_handlers
from lerobot.scripts.server import learner_service
from lerobot.common.robot_devices.utils import busy_wait
from torch.multiprocessing import Queue, Event
from queue import Empty
from lerobot.common.utils.utils import init_logging
from lerobot.scripts.server.utils import get_last_item_from_queue
ACTOR_SHUTDOWN_TIMEOUT = 30 ACTOR_SHUTDOWN_TIMEOUT = 30
@ -102,9 +96,7 @@ def receive_policy(
logging.info("[ACTOR] Received policy loop stopped") logging.info("[ACTOR] Received policy loop stopped")
def transitions_stream( def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty:
shutdown_event: Event, transitions_queue: Queue
) -> hilserl_pb2.Empty:
while not shutdown_event.is_set(): while not shutdown_event.is_set():
try: try:
message = transitions_queue.get(block=True, timeout=5) message = transitions_queue.get(block=True, timeout=5)
@ -169,9 +161,7 @@ def send_transitions(
) )
try: try:
learner_client.SendTransitions( learner_client.SendTransitions(transitions_stream(shutdown_event, transitions_queue))
transitions_stream(shutdown_event, transitions_queue)
)
except grpc.RpcError as e: except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}") logging.error(f"[ACTOR] gRPC error: {e}")
@ -211,9 +201,7 @@ def send_interactions(
) )
try: try:
learner_client.SendInteractions( learner_client.SendInteractions(interactions_stream(shutdown_event, interactions_queue))
interactions_stream(shutdown_event, interactions_queue)
)
except grpc.RpcError as e: except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}") logging.error(f"[ACTOR] gRPC error: {e}")
@ -301,9 +289,7 @@ def act_with_policy(
logging.info("make_env online") logging.info("make_env online")
online_env = make_robot_env( online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
robot=robot, reward_classifier=reward_classifier, cfg=cfg
)
set_global_seed(cfg.seed) set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True) device = get_safe_torch_device(cfg.device, log=True)
@ -355,13 +341,9 @@ def act_with_policy(
action = policy.select_action(batch=obs) action = policy.select_action(batch=obs)
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9) policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
log_policy_frequency_issue( log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step
)
next_obs, reward, done, truncated, info = online_env.step( next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
action.squeeze(dim=0).cpu().numpy()
)
else: else:
# TODO (azouitine): Make a custom space for torch tensor # TODO (azouitine): Make a custom space for torch tensor
action = online_env.action_space.sample() action = online_env.action_space.sample()
@ -369,9 +351,7 @@ def act_with_policy(
# HACK: We have only one env but we want to batch it; this will be resolved with the torch box # HACK: We have only one env but we want to batch it; this will be resolved with the torch box
action = ( action = (
torch.from_numpy(action[0]) torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
.to(device, non_blocking=device.type == "cuda")
.unsqueeze(dim=0)
) )
sum_reward_episode += float(reward) sum_reward_episode += float(reward)
@ -391,9 +371,7 @@ def act_with_policy(
# Check for NaN values in observations # Check for NaN values in observations
for key, tensor in obs.items(): for key, tensor in obs.items():
if torch.isnan(tensor).any(): if torch.isnan(tensor).any():
logging.error( logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}"
)
list_transition_to_send_to_learner.append( list_transition_to_send_to_learner.append(
Transition( Transition(
@ -413,13 +391,9 @@ def act_with_policy(
# Because we are using a single environment we can index at zero # Because we are using a single environment we can index at zero
if done or truncated: if done or truncated:
# TODO: Handle logging for episode information # TODO: Handle logging for episode information
logging.info( logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}"
)
update_policy_parameters( update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
policy=policy.actor, parameters_queue=parameters_queue, device=device
)
if len(list_transition_to_send_to_learner) > 0: if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue( push_transitions_to_transport_queue(
@ -495,9 +469,7 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
return stats return stats
def log_policy_frequency_issue( def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
policy_fps: float, cfg: DictConfig, interaction_step: int
):
if policy_fps < cfg.fps: if policy_fps < cfg.fps:
logging.warning( logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}" f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
View File
@ -14,16 +14,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import functools import functools
import io
import os
import pickle
from typing import Any, Callable, Optional, Sequence, TypedDict from typing import Any, Callable, Optional, Sequence, TypedDict
import io
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
from tqdm import tqdm from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import os
import pickle
class Transition(TypedDict): class Transition(TypedDict):
@ -45,38 +45,27 @@ class BatchTransition(TypedDict):
truncated: torch.Tensor truncated: torch.Tensor
def move_transition_to_device( def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
transition: Transition, device: str = "cpu"
) -> Transition:
# Move state tensors to the target device # Move state tensors to the target device
device = torch.device(device) device = torch.device(device)
transition["state"] = { transition["state"] = {
key: val.to(device, non_blocking=device.type == "cuda") key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
for key, val in transition["state"].items()
} }
# Move action to the target device # Move action to the target device
transition["action"] = transition["action"].to( transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
device, non_blocking=device.type == "cuda"
)
# No need to move reward or done, as they are float and bool # No need to move reward or done, as they are float and bool
if isinstance(transition["reward"], torch.Tensor): if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to( transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
device=device, non_blocking=device.type == "cuda"
)
if isinstance(transition["done"], torch.Tensor): if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to( transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
device, non_blocking=device.type == "cuda"
)
if isinstance(transition["truncated"], torch.Tensor): if isinstance(transition["truncated"], torch.Tensor):
transition["truncated"] = transition["truncated"].to( transition["truncated"] = transition["truncated"].to(device, non_blocking=device.type == "cuda")
device, non_blocking=device.type == "cuda"
)
# Move next_state tensors to the target device # Move next_state tensors to the target device
transition["next_state"] = { transition["next_state"] = {
@ -100,10 +89,7 @@ def move_state_dict_to_device(state_dict, device="cpu"):
if isinstance(state_dict, torch.Tensor): if isinstance(state_dict, torch.Tensor):
return state_dict.to(device) return state_dict.to(device)
elif isinstance(state_dict, dict): elif isinstance(state_dict, dict):
return { return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
k: move_state_dict_to_device(v, device=device)
for k, v in state_dict.items()
}
elif isinstance(state_dict, list): elif isinstance(state_dict, list):
return [move_state_dict_to_device(v, device=device) for v in state_dict] return [move_state_dict_to_device(v, device=device) for v in state_dict]
elif isinstance(state_dict, tuple): elif isinstance(state_dict, tuple):
@ -174,9 +160,7 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C) images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels # Gather pixels
cropped_hwcn = images_hwcn[ cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
]
# cropped_hwcn => (B, crop_h, crop_w, C) # cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w) cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
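A small self-contained demo of the batched gather above; shapes and names are illustrative:

import torch

B, H, W, C = 2, 8, 8, 3
crop_h, crop_w = 4, 4
images_hwcn = torch.randn(B, H, W, C)
top = torch.randint(0, H - crop_h + 1, (B,))
left = torch.randint(0, W - crop_w + 1, (B,))
rows = top.view(B, 1, 1) + torch.arange(crop_h).view(1, crop_h, 1)   # (B, crop_h, 1)
cols = left.view(B, 1, 1) + torch.arange(crop_w).view(1, 1, crop_w)  # (B, 1, crop_w)
cropped = images_hwcn[torch.arange(B).view(B, 1, 1), rows, cols, :]  # broadcasts to (B, crop_h, crop_w, C)
assert cropped.shape == (B, crop_h, crop_w, C)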
@ -223,9 +207,7 @@ class ReplayBuffer:
self.optimize_memory = optimize_memory self.optimize_memory = optimize_memory
# Track episode boundaries for memory optimization # Track episode boundaries for memory optimization
self.episode_ends = torch.zeros( self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
capacity, dtype=torch.bool, device=storage_device
)
# If no state_keys provided, default to an empty list # If no state_keys provided, default to an empty list
self.state_keys = state_keys if state_keys is not None else [] self.state_keys = state_keys if state_keys is not None else []
@ -246,9 +228,7 @@ class ReplayBuffer:
key: torch.empty((self.capacity, *shape), device=self.storage_device) key: torch.empty((self.capacity, *shape), device=self.storage_device)
for key, shape in state_shapes.items() for key, shape in state_shapes.items()
} }
self.actions = torch.empty( self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device)
(self.capacity, *action_shape), device=self.storage_device
)
self.rewards = torch.empty((self.capacity,), device=self.storage_device) self.rewards = torch.empty((self.capacity,), device=self.storage_device)
if not self.optimize_memory: if not self.optimize_memory:
@ -262,12 +242,8 @@ class ReplayBuffer:
# Just create a reference to states for consistent API # Just create a reference to states for consistent API
self.next_states = self.states # Just a reference for API consistency self.next_states = self.states # Just a reference for API consistency
self.dones = torch.empty( self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
(self.capacity,), dtype=torch.bool, device=self.storage_device self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
)
self.truncateds = torch.empty(
(self.capacity,), dtype=torch.bool, device=self.storage_device
)
self.initialized = True self.initialized = True
def __len__(self): def __len__(self):
@ -294,9 +270,7 @@ class ReplayBuffer:
if not self.optimize_memory: if not self.optimize_memory:
# Only store next_states if not optimizing memory # Only store next_states if not optimizing memory
self.next_states[key][self.position].copy_( self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
next_state[key].squeeze(dim=0)
)
self.actions[self.position].copy_(action.squeeze(dim=0)) self.actions[self.position].copy_(action.squeeze(dim=0))
self.rewards[self.position] = reward self.rewards[self.position] = reward
@ -309,23 +283,15 @@ class ReplayBuffer:
def sample(self, batch_size: int) -> BatchTransition: def sample(self, batch_size: int) -> BatchTransition:
"""Sample a random batch of transitions and collate them into batched tensors.""" """Sample a random batch of transitions and collate them into batched tensors."""
if not self.initialized: if not self.initialized:
raise RuntimeError( raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
"Cannot sample from an empty buffer. Add transitions first."
)
batch_size = min(batch_size, self.size) batch_size = min(batch_size, self.size)
# Random indices for sampling - create on the same device as storage # Random indices for sampling - create on the same device as storage
idx = torch.randint( idx = torch.randint(low=0, high=self.size, size=(batch_size,), device=self.storage_device)
low=0, high=self.size, size=(batch_size,), device=self.storage_device
)
# Identify image keys that need augmentation # Identify image keys that need augmentation
image_keys = ( image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
[k for k in self.states if k.startswith("observation.image")]
if self.use_drq
else []
)
# Create batched state and next_state # Create batched state and next_state
batch_state = {} batch_state = {}
@ -358,13 +324,9 @@ class ReplayBuffer:
# Split the augmented images back to their sources # Split the augmented images back to their sources
for i, key in enumerate(image_keys): for i, key in enumerate(image_keys):
# State images are at even indices (0, 2, 4...) # State images are at even indices (0, 2, 4...)
batch_state[key] = augmented_images[ batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
i * 2 * batch_size : (i * 2 + 1) * batch_size
]
# Next state images are at odd indices (1, 3, 5...) # Next state images are at odd indices (1, 3, 5...)
batch_next_state[key] = augmented_images[ batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size
]
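Sketch of the interleaved layout assumed by the slicing above: for each image key, a batch of state images is immediately followed by a batch of next-state images, so key i occupies blocks 2*i and 2*i + 1 (toy tensors for illustration):

import torch

batch_size = 3
image_keys = ["observation.image.cam0", "observation.image.cam1"]
blocks = []
for i, _key in enumerate(image_keys):
    blocks.append(torch.full((batch_size,), 2 * i))      # state block for key i
    blocks.append(torch.full((batch_size,), 2 * i + 1))  # next-state block for key i
augmented_images = torch.cat(blocks)
for i, _key in enumerate(image_keys):
    state_block = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
    next_block = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
    assert (state_block == 2 * i).all() and (next_block == 2 * i + 1).all()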
# Sample other tensors # Sample other tensors
batch_actions = self.actions[idx].to(self.device) batch_actions = self.actions[idx].to(self.device)
@ -434,16 +396,12 @@ class ReplayBuffer:
) )
# Convert dataset to transitions # Convert dataset to transitions
list_transition = cls._lerobotdataset_to_transitions( list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
dataset=lerobot_dataset, state_keys=state_keys
)
# Initialize the buffer with the first transition to set up storage tensors # Initialize the buffer with the first transition to set up storage tensors
if list_transition: if list_transition:
first_transition = list_transition[0] first_transition = list_transition[0]
first_state = { first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
k: v.to(device) for k, v in first_transition["state"].items()
}
first_action = first_transition["action"].to(device) first_action = first_transition["action"].to(device)
# Apply action mask/delta if needed # Apply action mask/delta if needed
@ -541,9 +499,7 @@ class ReplayBuffer:
# Convert transitions into episodes and frames # Convert transitions into episodes and frames
episode_index = 0 episode_index = 0
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer( lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index)
episode_index=episode_index
)
frame_idx_in_episode = 0 frame_idx_in_episode = 0
for idx in range(self.size): for idx in range(self.size):
@ -557,12 +513,8 @@ class ReplayBuffer:
# Fill action, reward, done # Fill action, reward, done
frame_dict["action"] = self.actions[actual_idx].cpu() frame_dict["action"] = self.actions[actual_idx].cpu()
frame_dict["next.reward"] = torch.tensor( frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
[self.rewards[actual_idx]], dtype=torch.float32 frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
).cpu()
frame_dict["next.done"] = torch.tensor(
[self.dones[actual_idx]], dtype=torch.bool
).cpu()
# Add to the dataset's buffer # Add to the dataset's buffer
lerobot_dataset.add_frame(frame_dict) lerobot_dataset.add_frame(frame_dict)
@ -619,9 +571,7 @@ class ReplayBuffer:
A list of Transition dictionaries with the same length as `dataset`. A list of Transition dictionaries with the same length as `dataset`.
""" """
if state_keys is None: if state_keys is None:
raise ValueError( raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
"State keys must be provided when converting LeRobotDataset to Transitions."
)
transitions = [] transitions = []
num_frames = len(dataset) num_frames = len(dataset)
@ -632,9 +582,7 @@ class ReplayBuffer:
# If not, we need to infer it from episode boundaries # If not, we need to infer it from episode boundaries
if not has_done_key: if not has_done_key:
print( print("'next.done' key not found in dataset. Inferring from episode boundaries...")
"'next.done' key not found in dataset. Inferring from episode boundaries..."
)
for i in tqdm(range(num_frames)): for i in tqdm(range(num_frames)):
current_sample = dataset[i] current_sample = dataset[i]
@ -886,8 +834,7 @@ if __name__ == "__main__":
# We need to be careful because we don't know the original index # We need to be careful because we don't know the original index
# So we check if the increment is roughly 0.01 # So we check if the increment is roughly 0.01
next_state_check = ( next_state_check = (
abs(next_state_sig - state_sig - 0.01) < 1e-4 abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
or abs(next_state_sig - state_sig) < 1e-4
) )
# Count correct relationships # Count correct relationships
@ -901,17 +848,11 @@ if __name__ == "__main__":
total_checks += 3 total_checks += 3
alignment_accuracy = 100.0 * correct_relationships / total_checks alignment_accuracy = 100.0 * correct_relationships / total_checks
print( print(f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%")
f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%"
)
if alignment_accuracy > 99.0: if alignment_accuracy > 99.0:
print( print("✅ All relationships verified! Buffer maintains correct temporal relationships.")
"✅ All relationships verified! Buffer maintains correct temporal relationships."
)
else: else:
print( print("⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues.")
"⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues."
)
# Print some debug information about failures # Print some debug information about failures
print("\nDebug information for failed checks:") print("\nDebug information for failed checks:")
@ -973,18 +914,14 @@ if __name__ == "__main__":
# Verify consistency before and after conversion # Verify consistency before and after conversion
original_states = batch["state"]["observation.image"].mean().item() original_states = batch["state"]["observation.image"].mean().item()
reconverted_states = ( reconverted_states = reconverted_batch["state"]["observation.image"].mean().item()
reconverted_batch["state"]["observation.image"].mean().item()
)
print(f"Original buffer state mean: {original_states:.4f}") print(f"Original buffer state mean: {original_states:.4f}")
print(f"Reconverted buffer state mean: {reconverted_states:.4f}") print(f"Reconverted buffer state mean: {reconverted_states:.4f}")
if abs(original_states - reconverted_states) < 1.0: if abs(original_states - reconverted_states) < 1.0:
print("Values are reasonably similar - conversion works as expected") print("Values are reasonably similar - conversion works as expected")
else: else:
print( print("WARNING: Significant difference between original and reconverted values")
"WARNING: Significant difference between original and reconverted values"
)
print("\nAll previous tests completed!") print("\nAll previous tests completed!")
@ -1093,15 +1030,11 @@ if __name__ == "__main__":
all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device) all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device)
# Get state tensors # Get state tensors
batch_state = { batch_state = {"value": test_buffer.states["value"][all_indices].to(test_buffer.device)}
"value": test_buffer.states["value"][all_indices].to(test_buffer.device)
}
# Get next_state using memory-optimized approach (simply index+1) # Get next_state using memory-optimized approach (simply index+1)
next_indices = (all_indices + 1) % test_buffer.capacity next_indices = (all_indices + 1) % test_buffer.capacity
batch_next_state = { batch_next_state = {"value": test_buffer.states["value"][next_indices].to(test_buffer.device)}
"value": test_buffer.states["value"][next_indices].to(test_buffer.device)
}
# Get other tensors # Get other tensors
batch_dones = test_buffer.dones[all_indices].to(test_buffer.device) batch_dones = test_buffer.dones[all_indices].to(test_buffer.device)
@ -1121,9 +1054,7 @@ if __name__ == "__main__":
print("- We always use the next state in the buffer (index+1) as next_state") print("- We always use the next state in the buffer (index+1) as next_state")
print("- For terminal states, this means using the first state of the next episode") print("- For terminal states, this means using the first state of the next episode")
print("- This is a common tradeoff in RL implementations for memory efficiency") print("- This is a common tradeoff in RL implementations for memory efficiency")
print( print("- Since we track done flags, the algorithm can handle these transitions correctly")
"- Since we track done flags, the algorithm can handle these transitions correctly"
)
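A minimal illustration of the index+1 convention described in the printout above, using toy buffer values:

import torch

states = torch.tensor([10.0, 11.0, 12.0, 20.0, 21.0])  # two episodes: [10..12] and [20, 21]
dones = torch.tensor([False, False, True, False, True])
idx = torch.arange(len(states))
next_states = states[(idx + 1) % len(states)]
# next_states[2] == 20.0, the first state of the next episode, but dones[2] is True,
# so a backup like r + gamma * (1 - done) * V(next_state) never actually uses it.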
# Test random sampling # Test random sampling
print("\nVerifying random sampling with simplified memory optimization...") print("\nVerifying random sampling with simplified memory optimization...")
@ -1137,23 +1068,19 @@ if __name__ == "__main__":
# Print a few samples # Print a few samples
print("Random samples - State, Next State, Done (First 10):") print("Random samples - State, Next State, Done (First 10):")
for i in range(10): for i in range(10):
print( print(f" {random_state_values[i]:.1f}{random_next_values[i]:.1f}, Done: {random_done_flags[i]}")
f" {random_state_values[i]:.1f}{random_next_values[i]:.1f}, Done: {random_done_flags[i]}"
)
# Calculate memory savings # Calculate memory savings
# Assume optimized_buffer and standard_buffer have already been initialized and filled # Assume optimized_buffer and standard_buffer have already been initialized and filled
std_mem = ( std_mem = (
sum( sum(
standard_buffer.states[key].nelement() standard_buffer.states[key].nelement() * standard_buffer.states[key].element_size()
* standard_buffer.states[key].element_size()
for key in standard_buffer.states for key in standard_buffer.states
) )
* 2 * 2
) )
opt_mem = sum( opt_mem = sum(
optimized_buffer.states[key].nelement() optimized_buffer.states[key].nelement() * optimized_buffer.states[key].element_size()
* optimized_buffer.states[key].element_size()
for key in optimized_buffer.states for key in optimized_buffer.states
) )
View File
@ -225,9 +225,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
description="Crop rectangular ROIs from a LeRobot dataset."
)
parser.add_argument( parser.add_argument(
"--repo-id", "--repo-id",
type=str, type=str,
@ -249,9 +247,7 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
local_files_only = args.root is not None local_files_only = args.root is not None
dataset = LeRobotDataset( dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
repo_id=args.repo_id, root=args.root, local_files_only=local_files_only
)
images = get_image_from_lerobot_dataset(dataset) images = get_image_from_lerobot_dataset(dataset)
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()} images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
View File
@ -1,13 +1,14 @@
from lerobot.common.robot_devices.robots.factory import make_robot import argparse
from lerobot.common.utils.utils import init_hydra_config
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.scripts.server.kinematics import RobotKinematics
import logging import logging
import time import time
import torch
import numpy as np
import argparse
import numpy as np
import torch
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.server.kinematics import RobotKinematics
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -187,9 +188,7 @@ class KeyboardController(InputController):
class GamepadController(InputController): class GamepadController(InputController):
"""Generate motion deltas from gamepad input.""" """Generate motion deltas from gamepad input."""
def __init__( def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1):
self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1
):
super().__init__(x_step_size, y_step_size, z_step_size) super().__init__(x_step_size, y_step_size, z_step_size)
self.deadzone = deadzone self.deadzone = deadzone
self.joystick = None self.joystick = None
@ -203,9 +202,7 @@ class GamepadController(InputController):
pygame.joystick.init() pygame.joystick.init()
if pygame.joystick.get_count() == 0: if pygame.joystick.get_count() == 0:
logging.error( logging.error("No gamepad detected. Please connect a gamepad and try again.")
"No gamepad detected. Please connect a gamepad and try again."
)
self.running = False self.running = False
return return
@ -338,18 +335,12 @@ class GamepadControllerHID(InputController):
devices = hid.enumerate() devices = hid.enumerate()
for device in devices: for device in devices:
if ( if device["vendor_id"] == self.vendor_id and device["product_id"] == self.product_id:
device["vendor_id"] == self.vendor_id logging.info(f"Found gamepad: {device.get('product_string', 'Unknown')}")
and device["product_id"] == self.product_id
):
logging.info(
f"Found gamepad: {device.get('product_string', 'Unknown')}"
)
return device return device
logging.error( logging.error(
f"No gamepad with vendor ID 0x{self.vendor_id:04X} and " f"No gamepad with vendor ID 0x{self.vendor_id:04X} and product ID 0x{self.product_id:04X} found"
f"product ID 0x{self.product_id:04X} found"
) )
return None return None
@ -381,9 +372,7 @@ class GamepadControllerHID(InputController):
except OSError as e: except OSError as e:
logging.error(f"Error opening gamepad: {e}") logging.error(f"Error opening gamepad: {e}")
logging.error( logging.error("You might need to run this with sudo/admin privileges on some systems")
"You might need to run this with sudo/admin privileges on some systems"
)
self.running = False self.running = False
def stop(self): def stop(self):
@ -421,12 +410,8 @@ class GamepadControllerHID(InputController):
# Apply deadzone # Apply deadzone
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
self.right_x = ( self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
0 if abs(self.right_x) < self.deadzone else self.right_x self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
)
self.right_y = (
0 if abs(self.right_y) < self.deadzone else self.right_y
)
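Hedged one-liner version of the deadzone rule applied above, for a single axis value assumed to lie in [-1, 1]; small values are snapped to zero rather than rescaled:

def apply_deadzone(value: float, deadzone: float = 0.1) -> float:
    return 0.0 if abs(value) < deadzone else value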
# Parse button states (byte 5 in the Logitech RumblePad 2) # Parse button states (byte 5 in the Logitech RumblePad 2)
buttons = data[5] buttons = data[5]
@ -493,9 +478,7 @@ def test_inverse_kinematics(robot, fps=10):
joint_positions = obs["observation.state"].cpu().numpy() joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions) ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
desired_ee_pos = ee_pos desired_ee_pos = ee_pos
target_joint_state = RobotKinematics.ik( target_joint_state = RobotKinematics.ik(joint_positions, desired_ee_pos, position_only=True)
joint_positions, desired_ee_pos, position_only=True
)
robot.send_action(torch.from_numpy(target_joint_state)) robot.send_action(torch.from_numpy(target_joint_state))
logging.info(f"Target Joint State: {target_joint_state}") logging.info(f"Target Joint State: {target_joint_state}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@ -573,17 +556,13 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
robot.send_action(torch.from_numpy(target_joint_state)) robot.send_action(torch.from_numpy(target_joint_state))
# Logging # Logging
logging.info( logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}"
)
logging.info(f"Delta EE: {ee_delta[:3, 3]}") logging.info(f"Delta EE: {ee_delta[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_delta_inverse_kinematics( def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None, fk_func=None):
robot, controller, fps=10, bounds=None, fk_func=None
):
""" """
Control a robot using delta end-effector movements from any input controller. Control a robot using delta end-effector movements from any input controller.
@ -597,9 +576,7 @@ def teleoperate_delta_inverse_kinematics(
if fk_func is None: if fk_func is None:
fk_func = RobotKinematics.fk_gripper_tip fk_func = RobotKinematics.fk_gripper_tip
logging.info( logging.info(f"Testing Delta End-Effector Control with {controller.__class__.__name__}")
f"Testing Delta End-Effector Control with {controller.__class__.__name__}"
)
# Initial position capture # Initial position capture
obs = robot.capture_observation() obs = robot.capture_observation()
@ -631,9 +608,7 @@ def teleoperate_delta_inverse_kinematics(
# Apply bounds if provided # Apply bounds if provided
if bounds is not None: if bounds is not None:
desired_ee_pos[:3, 3] = np.clip( desired_ee_pos[:3, 3] = np.clip(desired_ee_pos[:3, 3], bounds["min"], bounds["max"])
desired_ee_pos[:3, 3], bounds["min"], bounds["max"]
)
# Only send commands if there's actual movement # Only send commands if there's actual movement
if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]): if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
@ -684,14 +659,10 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]): if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
# Step the environment - pass action as a tensor with intervention flag # Step the environment - pass action as a tensor with intervention flag
action_tensor = torch.from_numpy(action.astype(np.float32)) action_tensor = torch.from_numpy(action.astype(np.float32))
obs, reward, terminated, truncated, info = env.step( obs, reward, terminated, truncated, info = env.step((action_tensor, False))
(action_tensor, False)
)
# Log information # Log information
logging.info( logging.info(f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]")
f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]"
)
logging.info(f"Reward: {reward}") logging.info(f"Reward: {reward}")
# Reset if episode ended # Reset if episode ended
@ -761,20 +732,14 @@ if __name__ == "__main__":
# Determine controller type based on mode prefix # Determine controller type based on mode prefix
controller = None controller = None
if args.mode.startswith("keyboard"): if args.mode.startswith("keyboard"):
controller = KeyboardController( controller = KeyboardController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
x_step_size=0.01, y_step_size=0.01, z_step_size=0.05
)
elif args.mode.startswith("gamepad"): elif args.mode.startswith("gamepad"):
controller = GamepadController( controller = GamepadController(x_step_size=0.02, y_step_size=0.02, z_step_size=0.05)
x_step_size=0.02, y_step_size=0.02, z_step_size=0.05
)
# Handle mode categories # Handle mode categories
if args.mode in ["keyboard", "gamepad"]: if args.mode in ["keyboard", "gamepad"]:
# Direct robot control modes # Direct robot control modes
teleoperate_delta_inverse_kinematics( teleoperate_delta_inverse_kinematics(robot, controller, bounds=bounds, fps=10)
robot, controller, bounds=bounds, fps=10
)
elif args.mode in ["keyboard_gym", "gamepad_gym"]: elif args.mode in ["keyboard_gym", "gamepad_gym"]:
# Gym environment control modes # Gym environment control modes
View File
@ -32,9 +32,7 @@ def find_joint_bounds(
if display_cameras and not is_headless(): if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key] image_keys = [key for key in observation if "image" in key]
for key in image_keys: for key in image_keys:
cv2.imshow( cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1) cv2.waitKey(1)
if time.perf_counter() - start_episode_t > control_time_s: if time.perf_counter() - start_episode_t > control_time_s:
@ -69,9 +67,7 @@ def find_ee_bounds(
if display_cameras and not is_headless(): if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key] image_keys = [key for key in observation if "image" in key]
for key in image_keys: for key in image_keys:
cv2.imshow( cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1) cv2.waitKey(1)
if time.perf_counter() - start_episode_t > control_time_s: if time.perf_counter() - start_episode_t > control_time_s:
View File
@ -1,10 +1,10 @@
import argparse import argparse
import sys
import logging import logging
import sys
import time import time
from threading import Lock from threading import Lock
from typing import Annotated, Any, Dict, Tuple from typing import Annotated, Any, Dict, Tuple
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
import torch import torch
@ -18,7 +18,6 @@ from lerobot.common.robot_devices.control_utils import (
) )
from lerobot.common.robot_devices.robots.factory import make_robot from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config, log_say from lerobot.common.utils.utils import init_hydra_config, log_say
from lerobot.scripts.server.kinematics import RobotKinematics from lerobot.scripts.server.kinematics import RobotKinematics
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -67,9 +66,7 @@ class HILSerlRobotEnv(gym.Env):
if not self.robot.is_connected: if not self.robot.is_connected:
self.robot.connect() self.robot.connect()
self.initial_follower_position = robot.follower_arms["main"].read( self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")
"Present_Position"
)
# Episode tracking. # Episode tracking.
self.current_step = 0 self.current_step = 0
@ -77,9 +74,7 @@ class HILSerlRobotEnv(gym.Env):
self.delta = delta self.delta = delta
self.use_delta_action_space = use_delta_action_space self.use_delta_action_space = use_delta_action_space
self.current_joint_positions = self.robot.follower_arms["main"].read( self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
"Present_Position"
)
# Retrieve the size of the joint position interval bound. # Retrieve the size of the joint position interval bound.
self.relative_bounds_size = ( self.relative_bounds_size = (
@ -92,9 +87,7 @@ class HILSerlRobotEnv(gym.Env):
) )
self.robot.config.max_relative_target = ( self.robot.config.max_relative_target = (
self.relative_bounds_size.float() self.relative_bounds_size.float() if self.relative_bounds_size is not None else None
if self.relative_bounds_size is not None
else None
) )
# Dynamically configure the observation and action spaces. # Dynamically configure the observation and action spaces.
@ -119,9 +112,7 @@ class HILSerlRobotEnv(gym.Env):
# Define observation spaces for images and other states. # Define observation spaces for images and other states.
image_keys = [key for key in example_obs if "image" in key] image_keys = [key for key in example_obs if "image" in key]
observation_spaces = { observation_spaces = {
key: gym.spaces.Box( key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8
)
for key in image_keys for key in image_keys
} }
observation_spaces["observation.state"] = gym.spaces.Box( observation_spaces["observation.state"] = gym.spaces.Box(
@ -172,9 +163,7 @@ class HILSerlRobotEnv(gym.Env):
), ),
) )
def reset( def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
self, seed=None, options=None
) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
""" """
Reset the environment to its initial state. Reset the environment to its initial state.
This method resets the step counter and clears any episodic data. This method resets the step counter and clears any episodic data.
@ -231,35 +220,25 @@ class HILSerlRobotEnv(gym.Env):
""" """
policy_action, intervention_bool = action policy_action, intervention_bool = action
teleop_action = None teleop_action = None
self.current_joint_positions = self.robot.follower_arms["main"].read( self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
"Present_Position"
)
if isinstance(policy_action, torch.Tensor): if isinstance(policy_action, torch.Tensor):
policy_action = policy_action.cpu().numpy() policy_action = policy_action.cpu().numpy()
policy_action = np.clip( policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
policy_action, self.action_space[0].low, self.action_space[0].high
)
if not intervention_bool: if not intervention_bool:
if self.use_delta_action_space: if self.use_delta_action_space:
target_joint_positions = ( target_joint_positions = self.current_joint_positions + self.delta * policy_action
self.current_joint_positions + self.delta * policy_action
)
else: else:
target_joint_positions = policy_action target_joint_positions = policy_action
self.robot.send_action(torch.from_numpy(target_joint_positions)) self.robot.send_action(torch.from_numpy(target_joint_positions))
observation = self.robot.capture_observation() observation = self.robot.capture_observation()
else: else:
observation, teleop_action = self.robot.teleop_step(record_data=True) observation, teleop_action = self.robot.teleop_step(record_data=True)
teleop_action = teleop_action[ teleop_action = teleop_action["action"] # Convert tensor to appropriate format
"action"
] # Convert tensor to appropriate format
# When applying the delta action space, convert teleop absolute values to relative differences. # When applying the delta action space, convert teleop absolute values to relative differences.
if self.use_delta_action_space: if self.use_delta_action_space:
teleop_action = ( teleop_action = (teleop_action - self.current_joint_positions) / self.delta
teleop_action - self.current_joint_positions
) / self.delta
if self.relative_bounds_size is not None and ( if self.relative_bounds_size is not None and (
torch.any(teleop_action < -self.relative_bounds_size) torch.any(teleop_action < -self.relative_bounds_size)
and torch.any(teleop_action > self.relative_bounds_size) and torch.any(teleop_action > self.relative_bounds_size)
@ -333,12 +312,8 @@ class AddJointVelocityToObservation(gym.ObservationWrapper):
self.last_joint_positions = np.zeros(old_shape) self.last_joint_positions = np.zeros(old_shape)
new_low = np.concatenate( new_low = np.concatenate([old_low, np.ones_like(old_low) * -joint_velocity_limits])
[old_low, np.ones_like(old_low) * -joint_velocity_limits] new_high = np.concatenate([old_high, np.ones_like(old_high) * joint_velocity_limits])
)
new_high = np.concatenate(
[old_high, np.ones_like(old_high) * joint_velocity_limits]
)
new_shape = (old_shape[0] * 2,) new_shape = (old_shape[0] * 2,)
@ -352,9 +327,7 @@ class AddJointVelocityToObservation(gym.ObservationWrapper):
self.dt = 1.0 / fps self.dt = 1.0 / fps
def observation(self, observation): def observation(self, observation):
joint_velocities = ( joint_velocities = (observation["observation.state"] - self.last_joint_positions) / self.dt
observation["observation.state"] - self.last_joint_positions
) / self.dt
self.last_joint_positions = observation["observation.state"].clone() self.last_joint_positions = observation["observation.state"].clone()
observation["observation.state"] = torch.cat( observation["observation.state"] = torch.cat(
[observation["observation.state"], joint_velocities], dim=-1 [observation["observation.state"], joint_velocities], dim=-1
@ -439,9 +412,7 @@ class JointMaskingActionSpace(gym.Wrapper):
raise ValueError("Mask length must match action space dimensions") raise ValueError("Mask length must match action space dimensions")
low = env.action_space.low[self.active_dims] low = env.action_space.low[self.active_dims]
high = env.action_space.high[self.active_dims] high = env.action_space.high[self.active_dims]
self.action_space = gym.spaces.Box( self.action_space = gym.spaces.Box(low=low, high=high, dtype=env.action_space.dtype)
low=low, high=high, dtype=env.action_space.dtype
)
if isinstance(env.action_space, gym.spaces.Tuple): if isinstance(env.action_space, gym.spaces.Tuple):
if len(mask) != env.action_space[0].shape[0]: if len(mask) != env.action_space[0].shape[0]:
@ -449,12 +420,8 @@ class JointMaskingActionSpace(gym.Wrapper):
low = env.action_space[0].low[self.active_dims] low = env.action_space[0].low[self.active_dims]
high = env.action_space[0].high[self.active_dims] high = env.action_space[0].high[self.active_dims]
action_space_masked = gym.spaces.Box( action_space_masked = gym.spaces.Box(low=low, high=high, dtype=env.action_space[0].dtype)
low=low, high=high, dtype=env.action_space[0].dtype self.action_space = gym.spaces.Tuple((action_space_masked, env.action_space[1]))
)
self.action_space = gym.spaces.Tuple(
(action_space_masked, env.action_space[1])
)
# Create new action space with masked dimensions # Create new action space with masked dimensions
def action(self, action): def action(self, action):
@ -473,18 +440,14 @@ class JointMaskingActionSpace(gym.Wrapper):
# Extract the masked component from the tuple. # Extract the masked component from the tuple.
masked_action = action[0] if isinstance(action, tuple) else action masked_action = action[0] if isinstance(action, tuple) else action
# Create a full action for the Box element. # Create a full action for the Box element.
full_box_action = np.zeros( full_box_action = np.zeros(self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype)
self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype
)
full_box_action[self.active_dims] = masked_action full_box_action[self.active_dims] = masked_action
# Return a tuple with the reconstructed Box action and the unchanged remainder. # Return a tuple with the reconstructed Box action and the unchanged remainder.
return (full_box_action, action[1]) return (full_box_action, action[1])
else: else:
# For Box action spaces. # For Box action spaces.
masked_action = action if not isinstance(action, tuple) else action[0] masked_action = action if not isinstance(action, tuple) else action[0]
full_action = np.zeros( full_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype)
self.env.action_space.shape, dtype=self.env.action_space.dtype
)
full_action[self.active_dims] = masked_action full_action[self.active_dims] = masked_action
return full_action return full_action
@ -493,13 +456,9 @@ class JointMaskingActionSpace(gym.Wrapper):
obs, reward, terminated, truncated, info = self.env.step(action) obs, reward, terminated, truncated, info = self.env.step(action)
if "action_intervention" in info and info["action_intervention"] is not None: if "action_intervention" in info and info["action_intervention"] is not None:
if info["action_intervention"].dim() == 1: if info["action_intervention"].dim() == 1:
info["action_intervention"] = info["action_intervention"][ info["action_intervention"] = info["action_intervention"][self.active_dims]
self.active_dims
]
else: else:
info["action_intervention"] = info["action_intervention"][ info["action_intervention"] = info["action_intervention"][:, self.active_dims]
:, self.active_dims
]
return obs, reward, terminated, truncated, info return obs, reward, terminated, truncated, info
@ -555,9 +514,7 @@ class ImageCropResizeWrapper(gym.Wrapper):
for key in crop_params_dict: for key in crop_params_dict:
top, left, height, width = crop_params_dict[key] top, left, height, width = crop_params_dict[key]
new_shape = (top + height, left + width) new_shape = (top + height, left + width)
self.observation_space[key] = gym.spaces.Box( self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)
low=0, high=255, shape=new_shape
)
self.resize_size = resize_size self.resize_size = resize_size
if self.resize_size is None: if self.resize_size is None:
@ -583,9 +540,7 @@ class ImageCropResizeWrapper(gym.Wrapper):
) )
# Check for NaNs before processing # Check for NaNs before processing
if torch.isnan(obs[k]).any(): if torch.isnan(obs[k]).any():
logging.error( logging.error(f"NaN values detected in observation {k} before crop and resize")
f"NaN values detected in observation {k} before crop and resize"
)
if device == torch.device("mps:0"): if device == torch.device("mps:0"):
obs[k] = obs[k].cpu() obs[k] = obs[k].cpu()
@ -595,9 +550,7 @@ class ImageCropResizeWrapper(gym.Wrapper):
# Check for NaNs after processing # Check for NaNs after processing
if torch.isnan(obs[k]).any(): if torch.isnan(obs[k]).any():
logging.error( logging.error(f"NaN values detected in observation {k} after crop and resize")
f"NaN values detected in observation {k} after crop and resize"
)
obs[k] = obs[k].to(device) obs[k] = obs[k].to(device)
@ -627,14 +580,10 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper):
observation = preprocess_observation(observation) observation = preprocess_observation(observation)
observation = { observation = {
key: observation[key].to( key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
self.device, non_blocking=self.device.type == "cuda"
)
for key in observation for key in observation
} }
observation = { observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
k: torch.tensor(v, device=self.device) for k, v in observation.items()
}
return observation return observation
@ -686,26 +635,16 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
play_sounds=True, play_sounds=True,
) )
return return
if ( if self.events["pause_policy"] and not self.events["human_intervention_step"]:
self.events["pause_policy"]
and not self.events["human_intervention_step"]
):
self.events["human_intervention_step"] = True self.events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.") print("Space key pressed. Human intervention starting.")
log_say( log_say("Starting human intervention.", play_sounds=True)
"Starting human intervention.", play_sounds=True
)
return return
if ( if self.events["pause_policy"] and self.events["human_intervention_step"]:
self.events["pause_policy"]
and self.events["human_intervention_step"]
):
self.events["pause_policy"] = False self.events["pause_policy"] = False
self.events["human_intervention_step"] = False self.events["human_intervention_step"] = False
print("Space key pressed for a third time.") print("Space key pressed for a third time.")
log_say( log_say("Continuing with policy actions.", play_sounds=True)
"Continuing with policy actions.", play_sounds=True
)
return return
except Exception as e: except Exception as e:
print(f"Error handling key press: {e}") print(f"Error handling key press: {e}")
@ -713,9 +652,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
self.listener = keyboard.Listener(on_press=on_press) self.listener = keyboard.Listener(on_press=on_press)
self.listener.start() self.listener.start()
except ImportError: except ImportError:
logging.warning( logging.warning("Could not import pynput. Keyboard interface will not be available.")
"Could not import pynput. Keyboard interface will not be available."
)
self.listener = None self.listener = None
def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]: def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
@ -742,9 +679,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
time.sleep(0.1) # Check more frequently if desired time.sleep(0.1) # Check more frequently if desired
# Execute the step in the underlying environment # Execute the step in the underlying environment
obs, reward, terminated, truncated, info = self.env.step( obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
(policy_action, is_intervention)
)
# Override reward and termination if episode success event triggered # Override reward and termination if episode success event triggered
with self.event_lock: with self.event_lock:
@ -807,9 +742,7 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
def __init__(self, env): def __init__(self, env):
super().__init__(env) super().__init__(env)
def observation( def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
self, observation: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
for key in observation: for key in observation:
if "image" in key and observation[key].dim() == 3: if "image" in key and observation[key].dim() == 3:
observation[key] = observation[key].unsqueeze(0) observation[key] = observation[key].unsqueeze(0)
@ -844,9 +777,7 @@ class EEActionWrapper(gym.ActionWrapper):
dtype=np.float32, dtype=np.float32,
) )
if isinstance(self.action_space, gym.spaces.Tuple): if isinstance(self.action_space, gym.spaces.Tuple):
self.action_space = gym.spaces.Tuple( self.action_space = gym.spaces.Tuple((ee_action_space, self.action_space[1]))
(ee_action_space, self.action_space[1])
)
else: else:
self.action_space = ee_action_space self.action_space = ee_action_space
@ -858,9 +789,7 @@ class EEActionWrapper(gym.ActionWrapper):
if isinstance(action, tuple): if isinstance(action, tuple):
action, _ = action action, _ = action
current_joint_pos = self.unwrapped.robot.follower_arms["main"].read( current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
"Present_Position"
)
current_ee_pos = self.fk_function(current_joint_pos) current_ee_pos = self.fk_function(current_joint_pos)
if isinstance(action, torch.Tensor): if isinstance(action, torch.Tensor):
action = action.cpu().numpy() action = action.cpu().numpy()
@ -898,9 +827,7 @@ class EEObservationWrapper(gym.ObservationWrapper):
self.fk_function = self.kinematics.fk_gripper_tip self.fk_function = self.kinematics.fk_gripper_tip
def observation(self, observation): def observation(self, observation):
current_joint_pos = self.unwrapped.robot.follower_arms["main"].read( current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
"Present_Position"
)
current_ee_pos = self.fk_function(current_joint_pos) current_ee_pos = self.fk_function(current_joint_pos)
observation["observation.state"] = torch.cat( observation["observation.state"] = torch.cat(
[ [
@ -944,8 +871,8 @@ class GamepadControlWrapper(gym.Wrapper):
""" """
super().__init__(env) super().__init__(env)
from lerobot.scripts.server.end_effector_control_utils import ( from lerobot.scripts.server.end_effector_control_utils import (
GamepadControllerHID,
GamepadController, GamepadController,
GamepadControllerHID,
) )
# use HidApi for macos # use HidApi for macos
@ -1027,9 +954,7 @@ class GamepadControlWrapper(gym.Wrapper):
# Update episode ending state if requested # Update episode ending state if requested
if terminate_episode: if terminate_episode:
logging.info( logging.info(f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}")
f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}"
)
# Only override the action if gamepad is active # Only override the action if gamepad is active
if is_intervention: if is_intervention:
@ -1054,9 +979,7 @@ class GamepadControlWrapper(gym.Wrapper):
logging.info("Episode ended successfully with reward 1.0") logging.info("Episode ended successfully with reward 1.0")
info["is_intervention"] = is_intervention info["is_intervention"] = is_intervention
action_intervention = ( action_intervention = final_action[0] if isinstance(final_action, Tuple) else final_action
final_action[0] if isinstance(final_action, Tuple) else final_action
)
if isinstance(action_intervention, np.ndarray): if isinstance(action_intervention, np.ndarray):
action_intervention = torch.from_numpy(action_intervention) action_intervention = torch.from_numpy(action_intervention)
info["action_intervention"] = action_intervention info["action_intervention"] = action_intervention
@ -1087,9 +1010,7 @@ class GamepadControlWrapper(gym.Wrapper):
class ActionScaleWrapper(gym.ActionWrapper): class ActionScaleWrapper(gym.ActionWrapper):
def __init__(self, env, ee_action_space_params=None): def __init__(self, env, ee_action_space_params=None):
super().__init__(env) super().__init__(env)
assert ee_action_space_params is not None, ( assert ee_action_space_params is not None, "TODO: method implemented for ee action space only so far"
"TODO: method implemented for ee action space only so far"
)
self.scale_vector = np.array( self.scale_vector = np.array(
[ [
[ [
@ -1148,9 +1069,7 @@ def make_robot_env(
if cfg.env.wrapper.add_joint_velocity_to_observation: if cfg.env.wrapper.add_joint_velocity_to_observation:
env = AddJointVelocityToObservation(env=env, fps=cfg.fps) env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
if cfg.env.wrapper.add_ee_pose_to_observation: if cfg.env.wrapper.add_ee_pose_to_observation:
env = EEObservationWrapper( env = EEObservationWrapper(env=env, ee_pose_limits=cfg.env.wrapper.ee_action_space_params.bounds)
env=env, ee_pose_limits=cfg.env.wrapper.ee_action_space_params.bounds
)
env = ConvertToLeRobotObservation(env=env, device=cfg.env.device) env = ConvertToLeRobotObservation(env=env, device=cfg.env.device)
@ -1163,13 +1082,9 @@ def make_robot_env(
# Add reward computation and control wrappers # Add reward computation and control wrappers
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper( env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps)
env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps
)
if cfg.env.wrapper.ee_action_space_params is not None: if cfg.env.wrapper.ee_action_space_params is not None:
env = EEActionWrapper( env = EEActionWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params)
env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params
)
if ( if (
cfg.env.wrapper.ee_action_space_params is not None cfg.env.wrapper.ee_action_space_params is not None
and cfg.env.wrapper.ee_action_space_params.use_gamepad and cfg.env.wrapper.ee_action_space_params.use_gamepad
@ -1193,9 +1108,7 @@ def make_robot_env(
cfg.env.wrapper.ee_action_space_params is None cfg.env.wrapper.ee_action_space_params is None
and cfg.env.wrapper.joint_masking_action_space is not None and cfg.env.wrapper.joint_masking_action_space is not None
): ):
env = JointMaskingActionSpace( env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space)
env=env, mask=cfg.env.wrapper.joint_masking_action_space
)
env = BatchCompitableWrapper(env=env) env = BatchCompitableWrapper(env=env)
return env return env
@ -1216,9 +1129,7 @@ def get_classifier(pretrained_path, config_path, device="mps"):
cfg = init_hydra_config(config_path) cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len( classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
cfg.training.image_keys
) # TODO automate these paths
model = Classifier(classifier_config) model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict()) model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to(device) model = model.to(device)
@ -1317,9 +1228,7 @@ def record_dataset(
# For teleop, get action from intervention # For teleop, get action from intervention
if policy is None: if policy is None:
action = { action = {"action": info["action_intervention"].cpu().squeeze(0).float()}
"action": info["action_intervention"].cpu().squeeze(0).float()
}
# Process observation for dataset # Process observation for dataset
obs = {k: v.cpu().squeeze(0).float() for k, v in obs.items()} obs = {k: v.cpu().squeeze(0).float() for k, v in obs.items()}
@ -1357,9 +1266,7 @@ def replay_episode(env, repo_id, root=None, episode=0):
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
local_files_only = root is not None local_files_only = root is not None
dataset = LeRobotDataset( dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
repo_id, root=root, episodes=[episode], local_files_only=local_files_only
)
env.reset() env.reset()
actions = dataset.hf_dataset.select_columns("action") actions = dataset.hf_dataset.select_columns("action")
@ -1414,9 +1321,7 @@ if __name__ == "__main__":
default=None, default=None,
help="Path to a yaml config file that is necessary to build the reward classifier model.", help="Path to a yaml config file that is necessary to build the reward classifier model.",
) )
parser.add_argument( parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file")
"--env-path", type=str, default=None, help="Path to the env yaml file"
)
parser.add_argument( parser.add_argument(
"--env-overrides", "--env-overrides",
type=str, type=str,
@ -1441,12 +1346,8 @@ if __name__ == "__main__":
default=None, default=None,
help="Repo ID of the episode to replay", help="Repo ID of the episode to replay",
) )
parser.add_argument( parser.add_argument("--dataset-root", type=str, default=None, help="Root of the dataset to replay")
"--dataset-root", type=str, default=None, help="Root of the dataset to replay" parser.add_argument("--replay-episode", type=int, default=0, help="Episode to replay")
)
parser.add_argument(
"--replay-episode", type=int, default=0, help="Episode to replay"
)
parser.add_argument( parser.add_argument(
"--record-repo-id", "--record-repo-id",
type=str, type=str,
@ -1534,9 +1435,7 @@ if __name__ == "__main__":
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
# Execute the step: wrap the NumPy action in a torch tensor. # Execute the step: wrap the NumPy action in a torch tensor.
obs, reward, terminated, truncated, info = env.step( obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
(torch.from_numpy(smoothed_action), False)
)
if terminated or truncated: if terminated or truncated:
sucesses.append(reward) sucesses.append(reward)
env.reset() env.reset()
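The delta-action logic collapsed in this file's step() is symmetric: policy actions are scaled relative offsets applied to the current joint positions, and teleop absolute targets are mapped back into that space by the inverse operation. A minimal sketch of the round trip, with delta and the arrays as stand-in values:

import numpy as np

delta = 0.1  # scale between normalized actions and joint-angle changes

def to_delta_action(teleop_abs: np.ndarray, current_pos: np.ndarray) -> np.ndarray:
    # Absolute teleop target -> normalized relative action (the intervention branch).
    return (teleop_abs - current_pos) / delta

def to_target_positions(policy_action: np.ndarray, current_pos: np.ndarray) -> np.ndarray:
    # Normalized relative action -> absolute joint target (the policy branch).
    return current_pos + delta * policy_action

current = np.array([0.0, 10.0, -5.0])
teleop = np.array([1.0, 9.0, -5.5])
action = to_delta_action(teleop, current)
assert np.allclose(to_target_positions(action, current), teleop)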

View File

@ -23,11 +23,7 @@ def screw_axis_to_transform(S, theta):
elif np.linalg.norm(S_w) == 1: # Rotation and translation elif np.linalg.norm(S_w) == 1: # Rotation and translation
w_hat = skew_symmetric(S_w) w_hat = skew_symmetric(S_w)
R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
t = ( t = (np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat) @ S_v
np.eye(3) * theta
+ (1 - np.cos(theta)) * w_hat
+ (theta - np.sin(theta)) * w_hat @ w_hat
) @ S_v
T = np.eye(4) T = np.eye(4)
T[:3, :3] = R T[:3, :3] = R
T[:3, 3] = t T[:3, 3] = t
@ -189,9 +185,7 @@ class RobotKinematics:
# Wrist # Wrist
# Screw axis of wrist frame wrt base frame # Screw axis of wrist frame wrt base frame
self.S_BR = np.array( self.S_BR = np.array([0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]])
[0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]]
)
# 0-position origin to centroid transform # 0-position origin to centroid transform
self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002) self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
@ -284,12 +278,7 @@ class RobotKinematics:
def fk_shoulder(self, robot_pos_deg): def fk_shoulder(self, robot_pos_deg):
"""Forward kinematics for the shoulder frame.""" """Forward kinematics for the shoulder frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi robot_pos_rad = robot_pos_deg / 180 * np.pi
return ( return self.X_WoBo @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) @ self.X_SoSc @ self.X_BS
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ self.X_SoSc
@ self.X_BS
)
def fk_humerus(self, robot_pos_deg): def fk_humerus(self, robot_pos_deg):
"""Forward kinematics for the humerus frame.""" """Forward kinematics for the humerus frame."""
@ -403,15 +392,12 @@ class RobotKinematics:
delta *= 0 delta *= 0
delta[el_ix] = eps / 2 delta[el_ix] = eps / 2
Sdot = ( Sdot = (
fk_func(robot_pos_deg[:-1] + delta)[:3, 3] fk_func(robot_pos_deg[:-1] + delta)[:3, 3] - fk_func(robot_pos_deg[:-1] - delta)[:3, 3]
- fk_func(robot_pos_deg[:-1] - delta)[:3, 3]
) / eps ) / eps
jac[:, el_ix] = Sdot jac[:, el_ix] = Sdot
return jac return jac
def ik( def ik(self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None):
self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None
):
"""Inverse kinematics using gradient descent. """Inverse kinematics using gradient descent.
Args: Args:
@ -457,9 +443,7 @@ if __name__ == "__main__":
# Test 1: Forward kinematics consistency # Test 1: Forward kinematics consistency
print("Test 1: Forward kinematics consistency") print("Test 1: Forward kinematics consistency")
test_angles = np.array( test_angles = np.array([30, 45, -30, 20, 10, 0]) # Example joint angles in degrees
[30, 45, -30, 20, 10, 0]
) # Example joint angles in degrees
# Calculate FK for different joints # Calculate FK for different joints
shoulder_pose = robot.fk_shoulder(test_angles) shoulder_pose = robot.fk_shoulder(test_angles)
@ -480,13 +464,9 @@ if __name__ == "__main__":
] ]
# Check if distances generally increase along the chain # Check if distances generally increase along the chain
is_consistent = all( is_consistent = all(distances[i] <= distances[i + 1] for i in range(len(distances) - 1))
distances[i] <= distances[i + 1] for i in range(len(distances) - 1)
)
print(f" Pose distances from origin: {[round(d, 3) for d in distances]}") print(f" Pose distances from origin: {[round(d, 3) for d in distances]}")
print( print(f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}")
f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}"
)
# Test 2: Jacobian computation # Test 2: Jacobian computation
print("Test 2: Jacobian computation") print("Test 2: Jacobian computation")
@ -498,9 +478,7 @@ if __name__ == "__main__":
pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5) pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5)
print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}") print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}")
print( print(f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}")
f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}"
)
# Test 3: Inverse kinematics # Test 3: Inverse kinematics
print("Test 3: Inverse kinematics (position only)") print("Test 3: Inverse kinematics (position only)")

View File

@ -17,15 +17,8 @@
import logging import logging
import shutil import shutil
import time import time
from pprint import pformat
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pprint import pformat
# from torch.multiprocessing import Event, Queue, Process
# from threading import Event, Thread
# from torch.multiprocessing import Queue, Event
from torch.multiprocessing import Queue
from lerobot.scripts.server.utils import setup_process_handlers
import grpc import grpc
@ -37,6 +30,11 @@ from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from termcolor import colored from termcolor import colored
from torch import nn from torch import nn
# from torch.multiprocessing import Event, Queue, Process
# from threading import Event, Thread
# from torch.multiprocessing import Queue, Event
from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.factory import make_dataset
@ -55,18 +53,17 @@ from lerobot.common.utils.utils import (
set_global_random_state, set_global_random_state,
set_global_seed, set_global_seed,
) )
from lerobot.scripts.server import learner_service
from lerobot.scripts.server.buffer import ( from lerobot.scripts.server.buffer import (
ReplayBuffer, ReplayBuffer,
concatenate_batch_transitions,
move_transition_to_device,
move_state_dict_to_device,
bytes_to_transitions,
state_to_bytes,
bytes_to_python_object, bytes_to_python_object,
bytes_to_transitions,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
state_to_bytes,
) )
from lerobot.scripts.server.utils import setup_process_handlers
from lerobot.scripts.server import learner_service
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig: def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
@ -81,13 +78,9 @@ def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
# if resume == True # if resume == True
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir) checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
if not checkpoint_dir.exists(): if not checkpoint_dir.exists():
raise RuntimeError( raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
f"No model checkpoint found in {checkpoint_dir} for resume=True"
)
checkpoint_cfg_path = str( checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
)
logging.info( logging.info(
colored( colored(
"Resume=True detected, resuming previous run", "Resume=True detected, resuming previous run",
@ -136,9 +129,7 @@ def load_training_state(
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None: def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
num_learnable_params = sum( num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters()) num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir) log_output_dir(out_dir)
@ -210,22 +201,15 @@ def initialize_offline_replay_buffer(
def get_observation_features( def get_observation_features(
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]: ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
if ( if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
policy.config.vision_encoder_name is None
or not policy.config.freeze_vision_encoder
):
return None, None return None, None
with torch.no_grad(): with torch.no_grad():
observation_features = ( observation_features = (
policy.actor.encoder(observations) policy.actor.encoder(observations) if policy.actor.encoder is not None else None
if policy.actor.encoder is not None
else None
) )
next_observation_features = ( next_observation_features = (
policy.actor.encoder(next_observations) policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
if policy.actor.encoder is not None
else None
) )
return observation_features, next_observation_features return observation_features, next_observation_features
@ -452,9 +436,7 @@ def add_actor_information_and_train(
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online training, we do not need dataset_stats # Hack: But if we do online training, we do not need dataset_stats
dataset_stats=None, dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
if cfg.resume
else None,
) )
# Update the policy config with the grad_clip_norm value from training config if it exists # Update the policy config with the grad_clip_norm value from training config if it exists
@ -469,9 +451,7 @@ def add_actor_information_and_train(
last_time_policy_pushed = time.time() last_time_policy_pushed = time.time()
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy) optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
resume_optimization_step, resume_interaction_step = load_training_state( resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
cfg, logger, optimizers
)
log_training_info(cfg, out_dir, policy) log_training_info(cfg, out_dir, policy)
@ -483,9 +463,7 @@ def add_actor_information_and_train(
active_action_dims = None active_action_dims = None
if cfg.env.wrapper.joint_masking_action_space is not None: if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [ active_action_dims = [
i i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
] ]
offline_replay_buffer = initialize_offline_replay_buffer( offline_replay_buffer = initialize_offline_replay_buffer(
cfg=cfg, cfg=cfg,
@ -502,12 +480,8 @@ def add_actor_information_and_train(
time.time() time.time()
logging.info("Starting learner thread") logging.info("Starting learner thread")
interaction_message, transition = None, None interaction_message, transition = None, None
optimization_step = ( optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
resume_optimization_step if resume_optimization_step is not None else 0 interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
)
interaction_step_shift = (
resume_interaction_step if resume_interaction_step is not None else 0
)
# Extract variables from cfg # Extract variables from cfg
online_step_before_learning = cfg.training.online_step_before_learning online_step_before_learning = cfg.training.online_step_before_learning
@ -519,9 +493,7 @@ def add_actor_information_and_train(
device = cfg.device device = cfg.device
storage_device = cfg.training.storage_device storage_device = cfg.training.storage_device
policy_update_freq = cfg.training.policy_update_freq policy_update_freq = cfg.training.policy_update_freq
policy_parameters_push_frequency = ( policy_parameters_push_frequency = cfg.actor_learner_config.policy_parameters_push_frequency
cfg.actor_learner_config.policy_parameters_push_frequency
)
save_checkpoint = cfg.training.save_checkpoint save_checkpoint = cfg.training.save_checkpoint
online_steps = cfg.training.online_steps online_steps = cfg.training.online_steps
@ -544,9 +516,9 @@ def add_actor_information_and_train(
continue continue
replay_buffer.add(**transition) replay_buffer.add(**transition)
if cfg.dataset_repo_id is not None and transition.get( if cfg.dataset_repo_id is not None and transition.get("complementary_info", {}).get(
"complementary_info", {} "is_intervention"
).get("is_intervention"): ):
offline_replay_buffer.add(**transition) offline_replay_buffer.add(**transition)
logging.debug("[LEARNER] Received transitions") logging.debug("[LEARNER] Received transitions")
@ -556,9 +528,7 @@ def add_actor_information_and_train(
interaction_message = bytes_to_python_object(interaction_message) interaction_message = bytes_to_python_object(interaction_message)
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift interaction_message["Interaction step"] += interaction_step_shift
logger.log_dict( logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
interaction_message, mode="train", custom_step_key="Interaction step"
)
logging.debug("[LEARNER] Received interactions") logging.debug("[LEARNER] Received interactions")
@ -579,9 +549,7 @@ def add_actor_information_and_train(
observations = batch["state"] observations = batch["state"]
next_observations = batch["next_state"] next_observations = batch["next_state"]
done = batch["done"] done = batch["done"]
check_nan_in_transition( check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observations=observations, actions=actions, next_state=next_observations
)
observation_features, next_observation_features = get_observation_features( observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations policy, observations, next_observations
@ -619,9 +587,7 @@ def add_actor_information_and_train(
next_observations = batch["next_state"] next_observations = batch["next_state"]
done = batch["done"] done = batch["done"]
check_nan_in_transition( check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observations=observations, actions=actions, next_state=next_observations
)
observation_features, next_observation_features = get_observation_features( observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations policy, observations, next_observations
@ -697,23 +663,15 @@ def add_actor_information_and_train(
if optimization_step % log_freq == 0: if optimization_step % log_freq == 0:
training_infos["replay_buffer_size"] = len(replay_buffer) training_infos["replay_buffer_size"] = len(replay_buffer)
if offline_replay_buffer is not None: if offline_replay_buffer is not None:
training_infos["offline_replay_buffer_size"] = len( training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
offline_replay_buffer
)
training_infos["Optimization step"] = optimization_step training_infos["Optimization step"] = optimization_step
logger.log_dict( logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
d=training_infos, mode="train", custom_step_key="Optimization step"
)
# logging.info(f"Training infos: {training_infos}") # logging.info(f"Training infos: {training_infos}")
time_for_one_optimization_step = time.time() - time_for_one_optimization_step time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / ( frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
time_for_one_optimization_step + 1e-9
)
logging.info( logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}"
)
logger.log_dict( logger.log_dict(
{ {
@ -728,16 +686,12 @@ def add_actor_information_and_train(
if optimization_step % log_freq == 0: if optimization_step % log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}") logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
if save_checkpoint and ( if save_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
optimization_step % save_freq == 0 or optimization_step == online_steps
):
logging.info(f"Checkpoint policy after step {optimization_step}") logging.info(f"Checkpoint policy after step {optimization_step}")
_num_digits = max(6, len(str(online_steps))) _num_digits = max(6, len(str(online_steps)))
step_identifier = f"{optimization_step:0{_num_digits}d}" step_identifier = f"{optimization_step:0{_num_digits}d}"
interaction_step = ( interaction_step = (
interaction_message["Interaction step"] interaction_message["Interaction step"] if interaction_message is not None else 0
if interaction_message is not None
else 0
) )
logger.save_checkpoint( logger.save_checkpoint(
optimization_step, optimization_step,
@ -755,9 +709,7 @@ def add_actor_information_and_train(
shutil.rmtree( shutil.rmtree(
dataset_dir, dataset_dir,
) )
replay_buffer.to_lerobot_dataset( replay_buffer.to_lerobot_dataset(dataset_repo_id, fps=fps, root=logger.log_dir / "dataset")
dataset_repo_id, fps=fps, root=logger.log_dir / "dataset"
)
if offline_replay_buffer is not None: if offline_replay_buffer is not None:
dataset_dir = logger.log_dir / "dataset_offline" dataset_dir = logger.log_dir / "dataset_offline"
@ -809,9 +761,7 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
optimizer_critic = torch.optim.Adam( optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
) )
optimizer_temperature = torch.optim.Adam( optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
params=[policy.log_alpha], lr=policy.config.critic_lr
)
lr_scheduler = None lr_scheduler = None
optimizers = { optimizers = {
"actor": optimizer_actor, "actor": optimizer_actor,

View File

@ -1,10 +1,10 @@
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
import logging import logging
from multiprocessing import Event, Queue from multiprocessing import Event, Queue
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks import hilserl_pb2 # type: ignore
from lerobot.scripts.server.network_utils import send_bytes_in_chunks import hilserl_pb2_grpc # type: ignore
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks, send_bytes_in_chunks
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
@ -64,9 +64,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
def SendInteractions(self, request_iterator, _context): def SendInteractions(self, request_iterator, _context):
# TODO: authorize the request # TODO: authorize the request
logging.info( logging.info("[LEARNER] Received request to receive interactions from the Actor")
"[LEARNER] Received request to receive interactions from the Actor"
)
receive_bytes_in_chunks( receive_bytes_in_chunks(
request_iterator, request_iterator,

View File

@ -1,12 +1,12 @@
import einops
import numpy as np
import gymnasium as gym
import torch
from omegaconf import DictConfig
from typing import Any from typing import Any
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
import einops
import gymnasium as gym
import numpy as np
import torch
from mani_skill.utils.wrappers.record import RecordEpisode from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from omegaconf import DictConfig
def preprocess_maniskill_observation( def preprocess_maniskill_observation(
@ -63,9 +63,7 @@ class ManiSkillCompat(gym.Wrapper):
new_action_space_shape = env.action_space.shape[-1] new_action_space_shape = env.action_space.shape[-1]
new_low = np.squeeze(env.action_space.low, axis=0) new_low = np.squeeze(env.action_space.low, axis=0)
new_high = np.squeeze(env.action_space.high, axis=0) new_high = np.squeeze(env.action_space.high, axis=0)
self.action_space = gym.spaces.Box( self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))
low=new_low, high=new_high, shape=(new_action_space_shape,)
)
def reset( def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None self, *, seed: int | None = None, options: dict[str, Any] | None = None
@ -84,9 +82,7 @@ class ManiSkillCompat(gym.Wrapper):
class ManiSkillActionWrapper(gym.ActionWrapper): class ManiSkillActionWrapper(gym.ActionWrapper):
def __init__(self, env): def __init__(self, env):
super().__init__(env) super().__init__(env)
self.action_space = gym.spaces.Tuple( self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))
spaces=(env.action_space, gym.spaces.Discrete(2))
)
def action(self, action): def action(self, action):
action, telop = action action, telop = action
@ -100,9 +96,7 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
action_space_agent: gym.spaces.Box = env.action_space[0] action_space_agent: gym.spaces.Box = env.action_space[0]
action_space_agent.low = action_space_agent.low * multiply_factor action_space_agent.low = action_space_agent.low * multiply_factor
action_space_agent.high = action_space_agent.high * multiply_factor action_space_agent.high = action_space_agent.high * multiply_factor
self.action_space = gym.spaces.Tuple( self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))
spaces=(action_space_agent, gym.spaces.Discrete(2))
)
def step(self, action): def step(self, action):
if isinstance(action, tuple): if isinstance(action, tuple):
@ -153,9 +147,7 @@ def make_maniskill(
) )
env = ManiSkillObservationWrapper(env, device=cfg.env.device) env = ManiSkillObservationWrapper(env, device=cfg.env.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False) env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = ( env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
50 # gym_utils.find_max_episode_steps_value(env)
)
env.unwrapped.metadata["render_fps"] = 20 env.unwrapped.metadata["render_fps"] = 20
env = ManiSkillCompat(env) env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env) env = ManiSkillActionWrapper(env)
@ -166,12 +158,11 @@ def make_maniskill(
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
import hydra import hydra
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
"--config", type=str, default="lerobot/configs/env/maniskill_example.yaml"
)
args = parser.parse_args() args = parser.parse_args()
# Initialize config # Initialize config

View File

@ -15,12 +15,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from lerobot.scripts.server import hilserl_pb2
import logging
import io import io
from multiprocessing import Queue, Event import logging
from multiprocessing import Event, Queue
from typing import Any from typing import Any
from lerobot.scripts.server import hilserl_pb2
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
@ -31,9 +32,7 @@ def bytes_buffer_size(buffer: io.BytesIO) -> int:
return result return result
def send_bytes_in_chunks( def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True):
buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True
):
buffer = io.BytesIO(buffer) buffer = io.BytesIO(buffer)
size_in_bytes = bytes_buffer_size(buffer) size_in_bytes = bytes_buffer_size(buffer)
@ -56,16 +55,12 @@ def send_bytes_in_chunks(
yield message_class(transfer_state=transfer_state, data=chunk) yield message_class(transfer_state=transfer_state, data=chunk)
sent_bytes += size_to_read sent_bytes += size_to_read
logging_method( logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
)
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks( def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""):
iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""
):
bytes_buffer = io.BytesIO() bytes_buffer = io.BytesIO()
step = 0 step = 0
@ -89,9 +84,7 @@ def receive_bytes_in_chunks(
logging.debug(f"{log_prefix} Received data at step {step}") logging.debug(f"{log_prefix} Received data at step {step}")
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END: elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
bytes_buffer.write(item.data) bytes_buffer.write(item.data)
logging.debug( logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}"
)
queue.put(bytes_buffer.getvalue()) queue.put(bytes_buffer.getvalue())

View File

@ -18,9 +18,10 @@
import logging import logging
import signal import signal
import sys import sys
from torch.multiprocessing import Queue
from queue import Empty from queue import Empty
from torch.multiprocessing import Queue
shutdown_event_counter = 0 shutdown_event_counter = 0

View File

@ -223,18 +223,12 @@ def train(cfg: TrainPipelineConfig):
step = 0 # number of policy updates (forward + backward + optim) step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume: if cfg.resume:
step, optimizer, lr_scheduler = load_training_state( step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
cfg.checkpoint_path, optimizer, lr_scheduler
)
num_learnable_params = sum( num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters()) num_total_params = sum(p.numel() for p in policy.parameters())
logging.info( logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}"
)
if cfg.env is not None: if cfg.env is not None:
logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
@ -335,9 +329,7 @@ def train(cfg: TrainPipelineConfig):
logging.info(f"Eval policy at step {step}") logging.info(f"Eval policy at step {step}")
with ( with (
torch.no_grad(), torch.no_grad(),
torch.autocast(device_type=device.type) torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
if cfg.policy.use_amp
else nullcontext(),
): ):
eval_info = eval_policy( eval_info = eval_policy(
eval_env, eval_env,
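The with-group shown above swaps in contextlib.nullcontext() when AMP is off, so evaluation has a single code path whether or not autocast is active. A compact sketch of the same pattern under toy names:

from contextlib import nullcontext

import torch

def eval_forward(model: torch.nn.Module, x: torch.Tensor, use_amp: bool) -> torch.Tensor:
    # nullcontext() keeps the `with` statement's shape identical when AMP is off.
    with torch.no_grad(), torch.autocast(device_type=x.device.type) if use_amp else nullcontext():
        return model(x)

model = torch.nn.Linear(4, 2)
out = eval_forward(model, torch.randn(3, 4), use_amp=False)
assert out.shape == (3, 2) and not out.requires_grad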

View File

@ -52,19 +52,13 @@ def get_model(cfg, logger): # noqa I001
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
model = Classifier(classifier_config) model = Classifier(classifier_config)
if cfg.resume: if cfg.resume:
model.load_state_dict( model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
Classifier.from_pretrained(
str(logger.last_pretrained_model_dir)
).state_dict()
)
return model return model
def create_balanced_sampler(dataset, cfg): def create_balanced_sampler(dataset, cfg):
# Get underlying dataset if using Subset # Get underlying dataset if using Subset
original_dataset = ( original_dataset = dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset
dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset
)
# Get indices if using Subset (for slicing) # Get indices if using Subset (for slicing)
indices = dataset.indices if isinstance(dataset, torch.utils.data.Subset) else None indices = dataset.indices if isinstance(dataset, torch.utils.data.Subset) else None
@ -83,9 +77,7 @@ def create_balanced_sampler(dataset, cfg):
class_weights = 1.0 / counts.float() class_weights = 1.0 / counts.float()
sample_weights = class_weights[labels] sample_weights = class_weights[labels]
return WeightedRandomSampler( return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
weights=sample_weights, num_samples=len(sample_weights), replacement=True
)
def support_amp(device: torch.device, cfg: DictConfig) -> bool: def support_amp(device: torch.device, cfg: DictConfig) -> bool:
@ -94,9 +86,7 @@ def support_amp(device: torch.device, cfg: DictConfig) -> bool:
return cfg.training.use_amp and device.type in ("cuda", "cpu") return cfg.training.use_amp and device.type in ("cuda", "cpu")
def train_epoch( def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg
):
# Single epoch training loop with AMP support and progress tracking # Single epoch training loop with AMP support and progress tracking
model.train() model.train()
correct = 0 correct = 0
@ -110,11 +100,7 @@ def train_epoch(
labels = batch[cfg.training.label_key].float().to(device) labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP # Forward pass with optional AMP
with ( with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext()
):
outputs = model(images) outputs = model(images)
loss = criterion(outputs.logits, labels) loss = criterion(outputs.logits, labels)
@ -159,9 +145,7 @@ def validate(model, val_loader, criterion, device, logger, cfg):
with ( with (
torch.no_grad(), torch.no_grad(),
torch.autocast(device_type=device.type) torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
if support_amp(device, cfg)
else nullcontext(),
): ):
for batch in tqdm(val_loader, desc="Validation"): for batch in tqdm(val_loader, desc="Validation"):
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys] images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
@ -174,9 +158,7 @@ def validate(model, val_loader, criterion, device, logger, cfg):
): ):
outputs = model(images) outputs = model(images)
inference_times.append( inference_times.append(
next( next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
) )
else: else:
outputs = model(images) outputs = model(images)
@ -194,24 +176,16 @@ def validate(model, val_loader, criterion, device, logger, cfg):
# Log sample predictions for visualization # Log sample predictions for visualization
if len(samples) < cfg.eval.num_samples_to_log: if len(samples) < cfg.eval.num_samples_to_log:
for i in range( for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
min(cfg.eval.num_samples_to_log - len(samples), len(images))
):
if model.config.num_classes == 2: if model.config.num_classes == 2:
confidence = round(outputs.probabilities[i].item(), 3) confidence = round(outputs.probabilities[i].item(), 3)
else: else:
confidence = [ confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
round(prob, 3) for prob in outputs.probabilities[i].tolist()
]
samples.append( samples.append(
{ {
**{ **{
f"image_{img_key}": wandb.Image( f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
images[img_idx][i].cpu() for img_idx, img_key in enumerate(cfg.training.image_keys)
)
for img_idx, img_key in enumerate(
cfg.training.image_keys
)
}, },
"true_label": labels[i].item(), "true_label": labels[i].item(),
"predicted": predictions[i].item(), "predicted": predictions[i].item(),
@ -286,9 +260,7 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
_ = model(x) _ = model(x)
inference_times.append( inference_times.append(
next( next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
) )
inference_times = np.array(inference_times) inference_times = np.array(inference_times)
@@ -314,9 +286,7 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
     return avg, median, std
-def train(
-    cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
-) -> None:
+def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None) -> None:
     if out_dir is None:
         raise NotImplementedError()
     if job_name is None:
@@ -372,9 +342,7 @@ def train(
             "You have set resume=True, but there is no model checkpoint in "
             f"{Logger.get_last_checkpoint_dir(out_dir)}"
         )
-        checkpoint_cfg_path = str(
-            Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
-        )
+        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
         logging.info(
             colored(
                 "You have set resume=True, indicating that you wish to resume a run",
@@ -387,9 +355,7 @@ def train(
         # Check for differences between the checkpoint configuration and provided configuration.
         # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
         resolve_delta_timestamps(cfg)
-        diff = DeepDiff(
-            OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)
-        )
+        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
         # Ignore the `resume` and parameters.
         if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
             del diff["values_changed"]["root['resume']"]
@@ -408,11 +374,7 @@ def train(
     optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
     # Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
-    criterion = (
-        nn.BCEWithLogitsLoss()
-        if model.config.num_classes == 2
-        else nn.CrossEntropyLoss()
-    )
+    criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
     grad_scaler = GradScaler(enabled=cfg.training.use_amp)
     # Log model parameters
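The criterion one-liner pairs a binary head with BCEWithLogitsLoss and a multi-class head with CrossEntropyLoss; the two losses expect differently typed targets, which is worth seeing side by side (toy logits, not from this diff):

import torch
from torch import nn

num_classes = 2
criterion = nn.BCEWithLogitsLoss() if num_classes == 2 else nn.CrossEntropyLoss()
logits = torch.randn(4)  # binary head: one logit per sample
targets = torch.tensor([0.0, 1.0, 1.0, 0.0])  # BCEWithLogitsLoss expects float targets
loss = criterion(logits, targets)
# CrossEntropyLoss would instead take (N, C) logits and long class indices.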

View File

@@ -52,9 +52,7 @@ def make_optimizers_and_scheduler(cfg, policy):
         params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
     )
     # We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
-    optimizer_temperature = torch.optim.Adam(
-        params=[policy.log_alpha], lr=policy.config.critic_lr
-    )
+    optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
     lr_scheduler = None
     optimizers = {
         "actor": optimizer_actor,
@@ -106,9 +104,7 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor
     images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)
     # Gather pixels
-    cropped_hwcn = images_hwcn[
-        torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
-    ]
+    cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
     # cropped_hwcn => (B, crop_h, crop_w, C)
     cropped = cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)
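The gather above crops every image in a batch at its own offset with a single advanced-indexing expression. A self-contained sketch of the full technique under assumed shapes (the rows/cols construction is the part the hunk does not show):

import torch

def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
    # images: (B, C, H, W); one independent crop per image, no Python loop.
    B, C, H, W = images.shape
    crop_h, crop_w = output_size
    top = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
    left = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
    rows = top.view(B, 1, 1) + torch.arange(crop_h, device=images.device).view(1, crop_h, 1)
    cols = left.view(B, 1, 1) + torch.arange(crop_w, device=images.device).view(1, 1, crop_w)
    images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)
    cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
    return cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)

print(random_crop_vectorized(torch.rand(4, 3, 64, 64), (32, 32)).shape)  # torch.Size([4, 3, 32, 32])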
@@ -198,12 +194,8 @@ class ReplayBuffer:
         """
         # We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
         # a replay buffer than from a lerobot dataset.
-        replay_buffer = cls(
-            capacity=len(lerobot_dataset), device=device, state_keys=state_keys
-        )
-        list_transition = cls._lerobotdataset_to_transitions(
-            dataset=lerobot_dataset, state_keys=state_keys
-        )
+        replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
+        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
         # Fill the replay buffer with the lerobot dataset transitions
         for data in list_transition:
             replay_buffer.add(
@@ -248,9 +240,7 @@ class ReplayBuffer:
         # If not provided, you can either raise an error or define a default:
         if state_keys is None:
-            raise ValueError(
-                "You must provide a list of keys in `state_keys` that define your 'state'."
-            )
+            raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
         transitions: list[Transition] = []
         num_frames = len(dataset)
@@ -304,40 +294,36 @@ class ReplayBuffer:
         # -- Build batched states --
         batch_state = {}
         for key in self.state_keys:
-            batch_state[key] = torch.cat(
-                [t["state"][key] for t in list_of_transitions], dim=0
-            ).to(self.device)
+            batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
+                self.device
+            )
             if key.startswith("observation.image") and self.use_drq:
                 batch_state[key] = self.image_augmentation_function(batch_state[key])
         # -- Build batched actions --
-        batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
-            self.device
-        )
+        batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
         # -- Build batched rewards --
-        batch_rewards = torch.tensor(
-            [t["reward"] for t in list_of_transitions], dtype=torch.float32
-        ).to(self.device)
+        batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
+            self.device
+        )
         # -- Build batched next states --
         batch_next_state = {}
         for key in self.state_keys:
-            batch_next_state[key] = torch.cat(
-                [t["next_state"][key] for t in list_of_transitions], dim=0
-            ).to(self.device)
-            if key.startswith("observation.image") and self.use_drq:
-                batch_next_state[key] = self.image_augmentation_function(
-                    batch_next_state[key]
-                )
+            batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
+                self.device
+            )
+            if key.startswith("observation.image") and self.use_drq:
+                batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
         # -- Build batched dones --
-        batch_dones = torch.tensor(
-            [t["done"] for t in list_of_transitions], dtype=torch.float32
-        ).to(self.device)
-        batch_dones = torch.tensor(
-            [t["done"] for t in list_of_transitions], dtype=torch.float32
-        ).to(self.device)
+        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
+            self.device
+        )
+        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
+            self.device
+        )
         # Return a BatchTransition typed dict
         return BatchTransition(
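Every branch of this hunk is the same move: pull one field out of a list of transition dicts and batch it onto the buffer's device. A compact sketch of that move, with batch_field as a hypothetical helper (the real method keeps the branches inline):

import torch

def batch_field(transitions: list, field: str, device: torch.device) -> torch.Tensor:
    values = [t[field] for t in transitions]
    if isinstance(values[0], torch.Tensor):
        # Tensor fields (actions, state entries) are concatenated along dim 0.
        return torch.cat(values, dim=0).to(device)
    # Scalar fields (reward, done) become a 1-D float32 tensor.
    return torch.tensor(values, dtype=torch.float32).to(device)

transitions = [{"action": torch.zeros(1, 4), "reward": 1.0}, {"action": torch.ones(1, 4), "reward": 0.0}]
print(batch_field(transitions, "action", torch.device("cpu")).shape)  # torch.Size([2, 4])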
@@ -427,9 +413,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
         # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
         # Hack: But if we do online traning, we do not need dataset_stats
         dataset_stats=None,
-        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
-        if cfg.resume
-        else None,
+        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
         device=device,
     )
     assert isinstance(policy, nn.Module)
@@ -438,9 +422,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
     # TODO: Handle resume
-    num_learnable_params = sum(
-        p.numel() for p in policy.parameters() if p.requires_grad
-    )
+    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
     log_output_dir(out_dir)
@@ -481,16 +463,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
         if interaction_step >= cfg.training.online_step_before_learning:
             action = policy.select_action(batch=obs)
-            next_obs, reward, done, truncated, info = online_env.step(
-                action.cpu().numpy()
-            )
+            next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
         else:
             action = online_env.action_space.sample()
             next_obs, reward, done, truncated, info = online_env.step(action)
             # HACK
-            action = torch.tensor(action, dtype=torch.float32).to(
-                device, non_blocking=True
-            )
+            action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
         # HACK: For maniskill
         # next_obs = preprocess_observation(next_obs)
@@ -500,20 +478,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
         # Because we are using a single environment
         # we can safely assume that the episode is done
         if done[0] or truncated[0]:
-            logging.info(
-                f"Global step {interaction_step}: Episode reward: {sum_reward_episode}"
-            )
-            logger.log_dict(
-                {"Sum episode reward": sum_reward_episode}, interaction_step
-            )
+            logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
+            logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
             sum_reward_episode = 0
             # HACK: This is for maniskill
             logging.info(
                 f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
             )
-            logger.log_dict(
-                {"Episode success": info["success"].float().item()}, interaction_step
-            )
+            logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step)
         replay_buffer.add(
             state=obs,
@@ -587,9 +559,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
                 training_infos["loss_actor"] = loss_actor.item()
-            loss_temperature = policy.compute_loss_temperature(
-                observations=observations
-            )
+            loss_temperature = policy.compute_loss_temperature(observations=observations)
             optimizers["temperature"].zero_grad()
             loss_temperature.backward()
             optimizers["temperature"].step()
@@ -611,9 +581,7 @@ def train_cli(cfg: dict):
     )
-def train_notebook(
-    out_dir=None, job_name=None, config_name="default", config_path="../configs"
-):
+def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
     from hydra import compose, initialize
     hydra.core.global_hydra.GlobalHydra.instance().clear()

View File

@@ -94,12 +94,8 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
     assert chw_float32_torch.dtype == torch.float32
     assert chw_float32_torch.ndim == 3
     c, h, w = chw_float32_torch.shape
-    assert c < h and c < w, (
-        f"expect channel first images, but instead {chw_float32_torch.shape}"
-    )
-    hwc_uint8_numpy = (
-        (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
-    )
+    assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
+    hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
     return hwc_uint8_numpy
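The reshaped one-liner packs three steps: rescale [0, 1] floats to [0, 255], cast to uint8, and move channels last. A quick check of the conversion on a dummy frame:

import torch

chw = torch.rand(3, 64, 64)  # float32 in [0, 1], channel-first
hwc = (chw * 255).type(torch.uint8).permute(1, 2, 0).numpy()
assert hwc.shape == (64, 64, 3) and hwc.dtype.name == "uint8"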

View File

@@ -142,12 +142,8 @@ def run_server(
             )
         )
-    @app.route(
-        "/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>"
-    )
-    def show_episode(
-        dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes
-    ):
+    @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
+    def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
         repo_id = f"{dataset_namespace}/{dataset_name}"
         try:
             if dataset is None:
@@ -158,9 +154,7 @@ def run_server(
                 400,
             )
         dataset_version = (
-            str(dataset.meta._version)
-            if isinstance(dataset, LeRobotDataset)
-            else dataset.codebase_version
+            str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
         )
         match = re.search(r"v(\d+)\.", dataset_version)
         if match:
@@ -168,9 +162,7 @@ def run_server(
             if major_version < 2:
                 return "Make sure to convert your LeRobotDataset to v2 & above."
-        episode_data_csv_str, columns, ignored_columns = get_episode_data(
-            dataset, episode_id
-        )
+        episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
         dataset_info = {
             "repo_id": f"{dataset_namespace}/{dataset_name}",
             "num_samples": dataset.num_frames
@@ -183,8 +175,7 @@ def run_server(
         }
         if isinstance(dataset, LeRobotDataset):
             video_paths = [
-                dataset.meta.get_video_file_path(episode_id, key)
-                for key in dataset.meta.video_keys
+                dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
             ]
             videos_info = [
                 {
@@ -197,9 +188,7 @@ def run_server(
             ]
             tasks = dataset.meta.episodes[episode_id]["tasks"]
         else:
-            video_keys = [
-                key for key, ft in dataset.features.items() if ft["dtype"] == "video"
-            ]
+            video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
             videos_info = [
                 {
                     "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
@@ -219,24 +208,16 @@ def run_server(
             )
             response.raise_for_status()
             # Split into lines and parse each line as JSON
-            tasks_jsonl = [
-                json.loads(line) for line in response.text.splitlines() if line.strip()
-            ]
-            filtered_tasks_jsonl = [
-                row for row in tasks_jsonl if row["episode_index"] == episode_id
-            ]
+            tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
+            filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
             tasks = filtered_tasks_jsonl[0]["tasks"]
             videos_info[0]["language_instruction"] = tasks
         if episodes is None:
             episodes = list(
-                range(
-                    dataset.num_episodes
-                    if isinstance(dataset, LeRobotDataset)
-                    else dataset.total_episodes
-                )
+                range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
             )
         return render_template(
@@ -263,11 +244,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
     This file will be loaded by Dygraph javascript to plot data in real time."""
     columns = []
-    selected_columns = [
-        col
-        for col, ft in dataset.features.items()
-        if ft["dtype"] in ["float32", "int32"]
-    ]
+    selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
     selected_columns.remove("timestamp")
     ignored_columns = []
@@ -288,10 +265,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
             else dataset.features[column_name].shape[0]
         )
-        if (
-            "names" in dataset.features[column_name]
-            and dataset.features[column_name]["names"]
-        ):
+        if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
             column_names = dataset.features[column_name]["names"]
             while not isinstance(column_names, list):
                 column_names = list(column_names.values())[0]
@@ -314,13 +288,10 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
     else:
         repo_id = dataset.repo_id
-        url = (
-            f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
-            + dataset.data_path.format(
-                episode_chunk=int(episode_index) // dataset.chunks_size,
-                episode_index=episode_index,
-            )
-        )
+        url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
+            episode_chunk=int(episode_index) // dataset.chunks_size,
+            episode_index=episode_index,
+        )
         df = pd.read_parquet(url)
     data = df[selected_columns]  # Select specific columns
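The URL assembled here points pd.read_parquet straight at a parquet shard on the Hub. A sketch with an assumed repo and an assumed v2-style shard template (reading over HTTP requires fsspec with an HTTP backend installed):

import pandas as pd

repo_id = "lerobot/pusht"  # assumed example repo
data_path = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"  # assumed template
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + data_path.format(
    episode_chunk=0, episode_index=0
)
df = pd.read_parquet(url)  # the shard arrives as a regular DataFrame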
@@ -352,9 +323,7 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
     ]
-def get_episode_language_instruction(
-    dataset: LeRobotDataset, ep_index: int
-) -> list[str]:
+def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
     # check if the dataset has language instructions
     if "language_instruction" not in dataset.features:
         return None
@@ -365,9 +334,7 @@ def get_episode_language_instruction(
     language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
     # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
     # with the tf.tensor appearing in the string
-    return language_instruction.removeprefix("tf.Tensor(b'").removesuffix(
-        "', shape=(), dtype=string)"
-    )
+    return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
 def get_dataset_info(repo_id: str) -> IterableNamespace:
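The chained removeprefix/removesuffix (Python 3.9+) in the hunk above strips the tf.Tensor(...) repr that leaks into some OpenX strings; unlike strip, each call removes the exact substring once, and only if present. A toy round trip (the instruction string is an illustrative assumption):

raw = "tf.Tensor(b'pick up the cube', shape=(), dtype=string)"
clean = raw.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
assert clean == "pick up the cube"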
@@ -403,9 +370,7 @@ def visualize_dataset_html(
         if force_override:
             shutil.rmtree(output_dir)
         else:
-            logging.info(
-                f"Output directory already exists. Loading from it: '{output_dir}'"
-            )
+            logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
     output_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -47,9 +47,7 @@ OUTPUT_DIR = Path("outputs/image_transforms")
 to_pil = ToPILImage()
-def save_all_transforms(
-    cfg: ImageTransformsConfig, original_frame, output_dir, n_examples
-):
+def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
     output_dir_all = output_dir / "all"
     output_dir_all.mkdir(parents=True, exist_ok=True)
@@ -62,9 +60,7 @@ def save_all_transforms(
     print(f"    {output_dir_all}")
-def save_each_transform(
-    cfg: ImageTransformsConfig, original_frame, output_dir, n_examples
-):
+def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
     if not cfg.enable:
         logging.warning(
             "No single transforms will be saved, because `image_transforms.enable=False`. To enable, set `enable` to True in `ImageTransformsConfig` or in the command line with `--image_transforms.enable=True`."
@@ -93,15 +89,9 @@ def save_each_transform(
         tf_cfg_kwgs_max[key] = [max_, max_]
         tf_cfg_kwgs_avg[key] = [avg, avg]
-        tf_min = make_transform_from_config(
-            replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min})
-        )
-        tf_max = make_transform_from_config(
-            replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max})
-        )
-        tf_avg = make_transform_from_config(
-            replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg})
-        )
+        tf_min = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min}))
+        tf_max = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max}))
+        tf_avg = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg}))
         tf_frame_min = tf_min(original_frame)
         tf_frame_max = tf_max(original_frame)
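replace here builds min/max/avg variants of one transform config by swapping only the kwargs field. A self-contained sketch with a stand-in dataclass (the repo's actual transform config class is not shown in this diff):

from dataclasses import dataclass, field, replace

@dataclass
class TransformConfig:  # stand-in for the repo's transform config
    type: str = "ColorJitter"
    kwargs: dict = field(default_factory=dict)

base = TransformConfig(kwargs={"brightness": (0.8, 1.2)})
# replace() returns a copy with only the listed fields overridden.
pinned_max = replace(base, **{"kwargs": {"brightness": (1.2, 1.2)}})
print(pinned_max)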
@@ -115,9 +105,7 @@ def save_each_transform(
 @draccus.wrap()
-def visualize_image_transforms(
-    cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5
-):
+def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5):
     dataset = LeRobotDataset(
         repo_id=cfg.repo_id,
         episodes=cfg.episodes,

View File

@@ -52,13 +52,7 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
     save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
     # save 2 frames at the middle of first episode
-    i = int(
-        (
-            dataset.episode_data_index["to"][0].item()
-            - dataset.episode_data_index["from"][0].item()
-        )
-        / 2
-    )
+    i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
     save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
     save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")

View File

@@ -51,9 +51,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
     batch = next(iter(dataloader))
     loss, output_dict = policy.forward(batch)
     if output_dict is not None:
-        output_dict = {
-            k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)
-        }
+        output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
         output_dict["loss"] = loss
     else:
         output_dict = {"loss": loss}
@@ -71,9 +69,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
     param_stats = {}
     for key, param in policy.named_parameters():
         param_stats[f"{key}_mean"] = param.mean()
-        param_stats[f"{key}_std"] = (
-            param.std() if param.numel() > 1 else torch.tensor(float(0.0))
-        )
+        param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
     optimizer.zero_grad()
     policy.reset()
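The param.numel() > 1 guard matters because torch.std applies Bessel's correction by default, so a one-element tensor yields nan. A two-line demonstration:

import torch

bias = torch.ones(1)
std = bias.std() if bias.numel() > 1 else torch.tensor(0.0)
assert not torch.isnan(std)  # bias.std() alone would be nan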
@@ -100,15 +96,11 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
     else:
         actions_queue = train_cfg.policy.n_action_repeats
-    actions = {
-        str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)
-    }
+    actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
     return output_dict, grad_stats, param_stats, actions
-def save_policy_to_safetensors(
-    output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict
-):
+def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict):
     if output_dir.exists():
         print(f"Overwrite existing safetensors in '{output_dir}':")
         print(f"  - Validate with: `git add {output_dir}`")
@@ -116,9 +108,7 @@ def save_policy_to_safetensors(
         shutil.rmtree(output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
-    output_dict, grad_stats, param_stats, actions = get_policy_stats(
-        ds_repo_id, policy_name, policy_kwargs
-    )
+    output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
     save_file(output_dict, output_dir / "output_dict.safetensors")
     save_file(grad_stats, output_dir / "grad_stats.safetensors")
     save_file(param_stats, output_dir / "param_stats.safetensors")
@@ -151,7 +141,5 @@ if __name__ == "__main__":
         raise RuntimeError("No policies were provided!")
     for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
         ds_name = ds_repo_id.split("/")[-1]
-        output_dir = (
-            Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
-        )
+        output_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
         save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)

View File

@@ -30,9 +30,7 @@ class config:  # noqa: N801
     def enable_device(self, device_id: str):
         self.device_enabled = device_id
-    def enable_stream(
-        self, stream_type: stream, width=None, height=None, color_format=None, fps=None
-    ):
+    def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None):
         self.stream_type = stream_type
         # Overwrite default values when possible
         self.width = 848 if width is None else width

View File

@@ -9,9 +9,7 @@ from lerobot.common.envs.configs import EnvConfig
 from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap
-def create_plugin_code(
-    *, base_class: str = "EnvConfig", plugin_name: str = "test_env"
-) -> str:
+def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str:
     """Creates a dummy plugin module that implements its own EnvConfig subclass."""
     return f"""
 from dataclasses import dataclass
Some files were not shown because too many files have changed in this diff.