diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 27899870..c21a120f 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -967,17 +967,8 @@ class LeRobotDataset(torch.utils.data.Dataset): Starts recording audio data provided by the microphone and directly writes it in a .wav file. """ - audio_dir = self._get_raw_audio_file_path( - self.num_episodes, "observation.audio." + microphone_key - ).parent - if not audio_dir.is_dir(): - audio_dir.mkdir(parents=True, exist_ok=True) - - microphone.start_recording( - output_file=self._get_raw_audio_file_path( - self.num_episodes, "observation.audio." + microphone_key - ) - ) + audio_file = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key) + microphone.start_recording(output_file=audio_file) def save_episode(self, episode_data: dict | None = None) -> None: """ diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py index a92be011..96d62d9c 100644 --- a/lerobot/common/robot_devices/microphones/microphone.py +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -302,7 +302,12 @@ class Microphone: return audio_readings - def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False) -> None: + def start_recording( + self, + output_file: str | None = None, + multiprocessing: bool | None = False, + overwrite: bool | None = True, + ) -> None: """ Starts the recording of the microphone. If output_file is provided, the audio will be written to this file. """ @@ -323,8 +328,15 @@ class Microphone: # Write recordings into a file if output_file is provided if output_file is not None: output_file = Path(output_file) + output_file.parent.mkdir(parents=True, exist_ok=True) + if output_file.exists(): - output_file.unlink() + if overwrite: + output_file.unlink() + else: + raise FileExistsError( + f"Output file {output_file} already exists. Set overwrite to True to overwrite it." + ) if multiprocessing: self.record_stop_event = process_Event() diff --git a/lerobot/common/robot_devices/microphones/utils.py b/lerobot/common/robot_devices/microphones/utils.py index fb1bac85..4a3df76e 100644 --- a/lerobot/common/robot_devices/microphones/utils.py +++ b/lerobot/common/robot_devices/microphones/utils.py @@ -21,7 +21,12 @@ from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, M class Microphone(Protocol): def connect(self): ... def disconnect(self): ... - def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False): ... + def start_recording( + self, + output_file: str | None = None, + multiprocessing: bool | None = False, + overwrite: bool | None = True, + ): ... def stop_recording(self): ...