diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index d8da85d6..488b7696 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -36,8 +36,11 @@ from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
 from lerobot.common.datasets.utils import (
     DEFAULT_FEATURES,
     DEFAULT_IMAGE_PATH,
+    DEFAULT_RAW_AUDIO_PATH,
+    DEFAULT_COMPRESSED_AUDIO_PATH,
     INFO_PATH,
     TASKS_PATH,
+    DEFAULT_AUDIO_CHUNK_DURATION,
     append_jsonlines,
     backward_compatible_episodes_stats,
     check_delta_timestamps,
@@ -69,8 +72,11 @@ from lerobot.common.datasets.video_utils import (
     VideoFrame,
     decode_video_frames,
     encode_video_frames,
+    encode_audio,
+    decode_audio,
     get_safe_default_codec,
     get_video_info,
+    get_audio_info,
 )
 from lerobot.common.robot_devices.robots.utils import Robot
 
@@ -141,6 +147,11 @@ class LeRobotDatasetMetadata:
         ep_chunk = self.get_episode_chunk(ep_index)
         fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
         return Path(fpath)
+    
+    def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
+        episode_chunk = self.get_episode_chunk(episode_index)
+        fpath = self.audio_path.format(episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index)
+        return self.root / fpath
 
     def get_episode_chunk(self, ep_index: int) -> int:
         return ep_index // self.chunks_size
@@ -154,6 +165,11 @@ class LeRobotDatasetMetadata:
     def video_path(self) -> str | None:
         """Formattable string for the video files."""
         return self.info["video_path"]
+    
+    @property
+    def audio_path(self) -> str | None:
+        """Formattable string for the audio files."""
+        return self.info["audio_path"]
 
     @property
     def robot_type(self) -> str | None:
@@ -184,6 +200,11 @@ class LeRobotDatasetMetadata:
     def camera_keys(self) -> list[str]:
         """Keys to access visual modalities (regardless of their storage method)."""
         return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
+    
+    @property
+    def audio_keys(self) -> list[str]:
+        """Keys to access audio modalities."""
+        return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
 
     @property
     def names(self) -> dict[str, list | dict]:
@@ -264,6 +285,9 @@ class LeRobotDatasetMetadata:
         if len(self.video_keys) > 0:
             self.update_video_info()
 
+        if len(self.audio_keys) > 0:
+            self.update_audio_info()
+
         write_info(self.info, self.root)
 
         episode_dict = {
@@ -288,6 +312,17 @@ class LeRobotDatasetMetadata:
                 video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
                 self.info["features"][key]["info"] = get_video_info(video_path)
 
+    def update_audio_info(self) -> None:
+        """
+        Warning: this function writes info from first episode audio, implicitly assuming that all audio have
+        been encoded the same way. Also, this means it assumes the first episode exists.
+        """
+        bound_audio_keys = {self.features[video_key]["audio"] for video_key in self.video_keys if self.features[video_key]["audio"] is not None}
+        for key in set(self.audio_keys) - bound_audio_keys:
+            if not self.features[key].get("info", None):
+                audio_path = self.root / self.get_compressed_audio_file_path(0, key)
+                self.info["features"][key]["info"] = get_audio_info(audio_path)
+
     def __repr__(self):
         feature_keys = list(self.features)
         return (
@@ -364,6 +399,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         force_cache_sync: bool = False,
         download_videos: bool = True,
         video_backend: str | None = None,
+        audio_backend: str | None = None,
     ):
         """
         2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -465,6 +501,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 True.
             video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
                 You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
+            audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg'.
         """
         super().__init__()
         self.repo_id = repo_id
@@ -475,6 +512,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.tolerance_s = tolerance_s
         self.revision = revision if revision else CODEBASE_VERSION
         self.video_backend = video_backend if video_backend else get_safe_default_codec()
+        self.audio_backend = audio_backend if audio_backend else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal)
         self.delta_indices = None
 
         # Unused attributes
@@ -499,7 +537,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             self.hf_dataset = self.load_hf_dataset()
         except (AssertionError, FileNotFoundError, NotADirectoryError):
             self.revision = get_safe_version(self.repo_id, self.revision)
-            self.download_episodes(download_videos)
+            self.download_episodes(download_videos) #Sould load audio as well #TODO(CarolinePascal): separate audio from video
             self.hf_dataset = self.load_hf_dataset()
 
         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
@@ -677,7 +715,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         }
         return query_indices, padding
 
-    def _get_query_timestamps(
+    def _get_query_timestamps_video(
         self,
         current_ts: float,
         query_indices: dict[str, list[int]] | None = None,
@@ -691,6 +729,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 query_timestamps[key] = [current_ts]
 
         return query_timestamps
+    
+    #TODO(CarolinePascal): add variable query durations
+    def _get_query_timestamps_audio(
+        self,
+        current_ts: float,
+        query_indices: dict[str, list[int]] | None = None,
+    ) -> dict[str, list[float]]:   
+        query_timestamps = {}
+        for key in self.meta.audio_keys:
+            if query_indices is not None and key in query_indices:
+                timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
+                query_timestamps[key] = torch.stack(timestamps).tolist()
+            else:
+                query_timestamps[key] = [current_ts]
+
+        return query_timestamps
 
     def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
         return {
@@ -713,6 +767,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         return item
 
+    #TODO(CarolinePascal): add variable query durations 
+    def _query_audio(self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int) -> dict[str, torch.Tensor]:
+        item = {}
+        bound_audio_keys_mapping = {self.meta.features[video_key]["audio"]:video_key for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None}
+        for audio_key, query_ts in query_timestamps.items():
+            #Audio stored with video in a single .mp4 file
+            if audio_key in bound_audio_keys_mapping.keys():
+                audio_path = self.root / self.meta.get_video_file_path(ep_idx, bound_audio_keys_mapping[audio_key])
+            #Audio stored alone in a separate .m4a file
+            else:
+                audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
+            audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
+            item[audio_key] = audio_chunk.squeeze(0)
+        return item
+
     def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
         for key, val in padding.items():
             item[key] = torch.BoolTensor(val)
@@ -733,11 +802,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
             for key, val in query_result.items():
                 item[key] = val
 
-        if len(self.meta.video_keys) > 0:
+        if len(self.meta.video_keys) > 0 or len(self.meta.audio_keys) > 0:
             current_ts = item["timestamp"].item()
-            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
+
+            query_timestamps = self._get_query_timestamps_video(current_ts, query_indices)
             video_frames = self._query_videos(query_timestamps, ep_idx)
-            item = {**video_frames, **item}
+            item = {**item, **video_frames}
+
+            query_timestamps = self._get_query_timestamps_audio(current_ts, query_indices)
+            audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx)
+            item = {**item, **audio_chunks}
 
         if self.image_transforms is not None:
             image_keys = self.meta.camera_keys
@@ -776,6 +850,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
             image_key=image_key, episode_index=episode_index, frame_index=frame_index
         )
         return self.root / fpath
+    
+    def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
+        fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
+        return self.root / fpath
 
     def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
         if self.image_writer is None:
@@ -867,7 +945,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         for key, ft in self.features.items():
             # index, episode_index, task_index are already processed above, and image and video
             # are processed separately by storing image path and frame info as meta data
-            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
+            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video", "audio"]:
                 continue
             episode_buffer[key] = np.stack(episode_buffer[key])
 
@@ -880,6 +958,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
             for key in self.meta.video_keys:
                 episode_buffer[key] = video_paths[key]
 
+        if len(self.meta.audio_keys) > 0:
+            _ = self.encode_episode_audio(episode_index)
+            
         # `meta.save_episode` be executed after encoding the videos
         self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
 
@@ -904,6 +985,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if img_dir.is_dir():
             shutil.rmtree(self.root / "images")
 
+        # delete raw audio
+        raw_audio_files = list(self.root.rglob("*.wav"))
+        for raw_audio_file in raw_audio_files:
+            raw_audio_file.unlink()
+            if len(list(raw_audio_file.parent.iterdir())) == 0:
+                raw_audio_file.parent.rmdir()
+
         if not episode_data:  # Reset the buffer
             self.episode_buffer = self.create_episode_buffer()
 
@@ -971,18 +1059,45 @@ class LeRobotDataset(torch.utils.data.Dataset):
         since video encoding with ffmpeg is already using multithreading.
         """
         video_paths = {}
-        for key in self.meta.video_keys:
-            video_path = self.root / self.meta.get_video_file_path(episode_index, key)
-            video_paths[key] = str(video_path)
+        for video_key in self.meta.video_keys:
+            video_path = self.root / self.meta.get_video_file_path(episode_index, video_key)
+            video_paths[video_key] = str(video_path)
             if video_path.is_file():
                 # Skip if video is already encoded. Could be the case when resuming data recording.
                 continue
             img_dir = self._get_image_file_path(
-                episode_index=episode_index, image_key=key, frame_index=0
+                episode_index=episode_index, image_key=video_key, frame_index=0
             ).parent
-            encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
+
+            audio_path = None
+            if self.meta.features[video_key]["audio"] is not None:
+                audio_key = self.meta.features[video_key]["audio"]
+                audio_path = self._get_raw_audio_file_path(episode_index, audio_key)
+
+            encode_video_frames(img_dir, video_path, self.fps, audio_path=audio_path, overwrite=True)
 
         return video_paths
+    
+    def encode_episode_audio(self, episode_index: int) -> dict:
+        """
+        Use ffmpeg to convert .wav raw audio files into .m4a audio files.
+        Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speedup encoding,
+        since video encoding with ffmpeg is already using multithreading.
+        """
+        audio_paths = {}
+        bound_audio_keys = {self.meta.features[video_key]["audio"] for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None}
+        for audio_key in set(self.meta.audio_keys) - bound_audio_keys:
+            input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
+            output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)
+            
+            audio_paths[audio_key] = str(output_audio_path)
+            if output_audio_path.is_file():
+                # Skip if video is already encoded. Could be the case when resuming data recording.
+                continue
+
+            encode_audio(input_audio_path, output_audio_path, overwrite=True)
+
+        return audio_paths
 
     @classmethod
     def create(
@@ -998,6 +1113,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         image_writer_processes: int = 0,
         image_writer_threads: int = 0,
         video_backend: str | None = None,
+        audio_backend: str | None = None,
     ) -> "LeRobotDataset":
         """Create a LeRobot Dataset from scratch in order to record data."""
         obj = cls.__new__(cls)
@@ -1029,6 +1145,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.delta_indices = None
         obj.episode_data_index = None
         obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
+        obj.audio_backend = audio_backend if audio_backend is not None else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal)
         return obj
 
 
@@ -1049,6 +1166,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         tolerances_s: dict | None = None,
         download_videos: bool = True,
         video_backend: str | None = None,
+        audio_backend: str | None = None,
     ):
         super().__init__()
         self.repo_ids = repo_ids
@@ -1066,6 +1184,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                 tolerance_s=self.tolerances_s[repo_id],
                 download_videos=download_videos,
                 video_backend=video_backend,
+                audio_backend=audio_backend,
             )
             for repo_id in repo_ids
         ]
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 9d8a54db..827e711b 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -55,6 +55,10 @@ TASKS_PATH = "meta/tasks.jsonl"
 DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
 DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
 DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
+DEFAULT_RAW_AUDIO_PATH = "audio/{audio_key}/episode_{episode_index:06d}.wav"
+DEFAULT_COMPRESSED_AUDIO_PATH = "audio/chunk-{episode_chunk:03d}/{audio_key}/episode_{episode_index:06d}.m4a"
+
+DEFAULT_AUDIO_CHUNK_DURATION = 0.5  # seconds
 
 DATASET_CARD_TEMPLATE = """
 ---
@@ -363,7 +367,7 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
 def get_hf_features_from_features(features: dict) -> datasets.Features:
     hf_features = {}
     for key, ft in features.items():
-        if ft["dtype"] == "video":
+        if ft["dtype"] == "video" or ft["dtype"] == "audio":
             continue
         elif ft["dtype"] == "image":
             hf_features[key] = datasets.Image()
@@ -394,7 +398,13 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
             key: {"dtype": "video" if use_videos else "image", **ft}
             for key, ft in robot.camera_features.items()
         }
-    return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
+    microphones_ft = {}
+    if robot.microphones:
+        microphones_ft = {
+            key: {"dtype": "audio", **ft}
+            for key, ft in robot.microphones_features.items()
+        }
+    return {**robot.motor_features, **camera_ft, **microphones_ft, **DEFAULT_FEATURES}
 
 
 def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
@@ -448,6 +458,7 @@ def create_empty_dataset_info(
         "splits": {},
         "data_path": DEFAULT_PARQUET_PATH,
         "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
+        "audio_path": DEFAULT_COMPRESSED_AUDIO_PATH,
         "features": features,
     }
 
@@ -721,6 +732,7 @@ def validate_features_presence(
 ):
     error_message = ""
     missing_features = expected_features - actual_features
+    missing_features =  {feature for feature in missing_features if "observation.audio" not in feature}
     extra_features = actual_features - (expected_features | optional_features)
 
     if missing_features or extra_features:
diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py
index c38d570d..44d5a1a5 100644
--- a/lerobot/common/datasets/video_utils.py
+++ b/lerobot/common/datasets/video_utils.py
@@ -26,9 +26,12 @@ from typing import Any, ClassVar
 import pyarrow as pa
 import torch
 import torchvision
+import torchaudio
 from datasets.features.features import register_feature
 from PIL import Image
 
+from numpy import ceil
+
 
 def get_safe_default_codec():
     if importlib.util.find_spec("torchcodec"):
@@ -39,7 +42,72 @@ def get_safe_default_codec():
         )
         return "pyav"
 
+def decode_audio(
+    audio_path: Path | str,
+    timestamps: list[float],
+    duration: float,
+    backend: str | None = "ffmpeg",
+) -> torch.Tensor:
+    """
+    Decodes audio using the specified backend.
+    Args:
+        audio_path (Path): Path to the audio file.
+        timestamps (list[float]): List of timestamps to extract frames.
+        tolerance_s (float): Allowed deviation in seconds for frame retrieval.
+        backend (str, optional): Backend to use for decoding. Defaults to "pyav".
 
+    Returns:
+        torch.Tensor: Decoded frames.
+
+    Currently supports pyav.
+    """
+    if backend == "torchcodec":
+        raise NotImplementedError("torchcodec is not yet supported for audio decoding")
+    elif backend == "ffmpeg":
+        return decode_audio_torchvision(audio_path, timestamps, duration)
+    else:
+        raise ValueError(f"Unsupported video backend: {backend}")
+
+def decode_audio_torchvision(
+    audio_path: Path | str,
+    timestamps: list[float],   
+    duration: float,   
+    log_loaded_timestamps: bool = False,
+) -> torch.Tensor:
+    
+    #TODO(CarolinePascal) : add channels selection
+    audio_path = str(audio_path)
+
+    reader = torchaudio.io.StreamReader(src=audio_path)
+    audio_sampling_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
+
+    #TODO(CarolinePascal) : sort timestamps ?
+
+    reader.add_basic_audio_stream(
+        frames_per_chunk = int(ceil(duration * audio_sampling_rate)),    #Too much is better than not enough
+        buffer_chunk_size = -1, #No dropping frames
+    )
+
+    audio_chunks = []
+    for ts in timestamps:
+        reader.seek(ts) #Default to closest audio sample
+        status = reader.fill_buffer()
+        if status != 0:
+            logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
+
+        current_audio_chunk = reader.pop_chunks()[0]
+
+        if log_loaded_timestamps:
+            logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk["pts"]:.4f} with duration={len(current_audio_chunk) / audio_sampling_rate:.4f}")
+        
+        audio_chunks.append(current_audio_chunk)
+
+    audio_chunks = torch.stack(audio_chunks)
+    #TODO(CarolinePascal) : pytorch format conversion ?
+
+    assert len(timestamps) == len(audio_chunks)
+    return audio_chunks
+    
 def decode_video_frames(
     video_path: Path | str,
     timestamps: list[float],
@@ -69,7 +137,6 @@ def decode_video_frames(
     else:
         raise ValueError(f"Unsupported video backend: {backend}")
 
-
 def decode_video_frames_torchvision(
     video_path: Path | str,
     timestamps: list[float],
@@ -167,7 +234,6 @@ def decode_video_frames_torchvision(
     assert len(timestamps) == len(closest_frames)
     return closest_frames
 
-
 def decode_video_frames_torchcodec(
     video_path: Path | str,
     timestamps: list[float],
@@ -242,15 +308,52 @@ def decode_video_frames_torchcodec(
     assert len(timestamps) == len(closest_frames)
     return closest_frames
 
+def encode_audio(
+    input_path: Path | str,
+    output_path: Path | str,
+    codec: str = "aac",
+    log_level: str | None = "error",
+    overwrite: bool = False,
+) -> None:
+    """Encodes an audio file using ffmpeg."""
+    output_path = Path(output_path)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    ffmpeg_args = OrderedDict(
+        [
+            ("-i", str(input_path)),
+            ("-acodec", codec),
+        ]
+    )
+
+    if log_level is not None:
+        ffmpeg_args["-loglevel"] = str(log_level)
+
+    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
+    if overwrite:
+        ffmpeg_args.append("-y")
+
+    ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)]
+
+    # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
+    subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
+
+    if not output_path.exists():
+        raise OSError(
+            f"Video encoding did not work. File not found: {output_path}. "
+            f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
+        )
 
 def encode_video_frames(
     imgs_dir: Path | str,
     video_path: Path | str,
     fps: int,
+    audio_path: Path | str | None = None,
     vcodec: str = "libsvtav1",
     pix_fmt: str = "yuv420p",
     g: int | None = 2,
     crf: int | None = 30,
+    acodec: str = "aac",    #TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
     fast_decode: int = 0,
     log_level: str | None = "error",
     overwrite: bool = False,
@@ -260,35 +363,53 @@ def encode_video_frames(
     imgs_dir = Path(imgs_dir)
     video_path.parent.mkdir(parents=True, exist_ok=True)
 
-    ffmpeg_args = OrderedDict(
+    ffmpeg_video_args = OrderedDict(
         [
             ("-f", "image2"),
             ("-r", str(fps)),
-            ("-i", str(imgs_dir / "frame_%06d.png")),
-            ("-vcodec", vcodec),
-            ("-pix_fmt", pix_fmt),
+            ("-i", str(Path(imgs_dir) / "frame_%06d.png")),
         ]
     )
 
+    ffmpeg_audio_args = OrderedDict()
+    if audio_path is not None:
+        audio_path = Path(audio_path)
+        audio_path.parent.mkdir(parents=True, exist_ok=True)
+        ffmpeg_audio_args.update(OrderedDict(
+            [
+                ("-i", str(audio_path)),
+            ]
+        ))
+
+    ffmpeg_encoding_args = OrderedDict(
+        [
+            ("-pix_fmt", pix_fmt),
+            ("-vcodec", vcodec),    
+        ]
+    )
     if g is not None:
-        ffmpeg_args["-g"] = str(g)
-
+        ffmpeg_encoding_args["-g"] = str(g)
     if crf is not None:
-        ffmpeg_args["-crf"] = str(crf)
-
+        ffmpeg_encoding_args["-crf"] = str(crf)
     if fast_decode:
         key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
         value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
-        ffmpeg_args[key] = value
+        ffmpeg_encoding_args[key] = value
 
+    if audio_path is not None:
+        ffmpeg_encoding_args["-acodec"] = acodec
+    
     if log_level is not None:
-        ffmpeg_args["-loglevel"] = str(log_level)
+        ffmpeg_encoding_args["-loglevel"] = str(log_level)
 
-    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
+    ffmpeg_args = [item for pair in ffmpeg_video_args.items() for item in pair]
+    ffmpeg_args += [item for pair in ffmpeg_audio_args.items() for item in pair]
+    ffmpeg_args += [item for pair in ffmpeg_encoding_args.items() for item in pair]
     if overwrite:
         ffmpeg_args.append("-y")
 
     ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
+
     # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
     subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
 
@@ -366,7 +487,6 @@ def get_audio_info(video_path: Path | str) -> dict:
         "audio.channel_layout": audio_stream_info.get("channel_layout", None),
     }
 
-
 def get_video_info(video_path: Path | str) -> dict:
     ffprobe_video_cmd = [
         "ffprobe",
diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py
index 443466ed..afa4006a 100644
--- a/lerobot/common/robot_devices/robots/manipulator.py
+++ b/lerobot/common/robot_devices/robots/manipulator.py
@@ -216,7 +216,7 @@ class ManipulatorRobot:
 
     @property
     def features(self):
-        return {**self.motor_features, **self.camera_features}
+        return {**self.motor_features, **self.camera_features, **self.microphones_features}
 
     @property
     def has_camera(self):
diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index 531977da..fbd7480f 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -29,6 +29,7 @@ from lerobot.common.datasets.utils import (
     DEFAULT_FEATURES,
     DEFAULT_PARQUET_PATH,
     DEFAULT_VIDEO_PATH,
+    DEFAULT_COMPRESSED_AUDIO_PATH,
     get_hf_features_from_features,
     hf_transform_to_torch,
 )
@@ -121,6 +122,7 @@ def info_factory(features_factory):
         chunks_size: int = DEFAULT_CHUNK_SIZE,
         data_path: str = DEFAULT_PARQUET_PATH,
         video_path: str = DEFAULT_VIDEO_PATH,
+        audio_path: str = DEFAULT_COMPRESSED_AUDIO_PATH,
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
         use_videos: bool = True,
@@ -139,6 +141,7 @@ def info_factory(features_factory):
             "splits": {},
             "data_path": data_path,
             "video_path": video_path if use_videos else None,
+            "audio_path": audio_path,
             "features": features,
         }