diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 853acbc3..71d961a1 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -233,9 +233,6 @@ class Logger:
         if self._wandb is not None:
             for k, v in d.items():
                 if not isinstance(v, (int, float, str)):
-                    logging.warning(
-                        f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
-                    )
                     continue
                 self._wandb.log({f"{mode}/{k}": v}, step=step)
 
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index bef59bec..a3dc4ccb 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -139,25 +139,26 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
+        bsize = actions_hat.shape[0]
+        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
+        l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
+        l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
+
+        out_dict = {}
+        out_dict["l1_loss"] = l1_loss
 
-        loss_dict = {"l1_loss": l1_loss.item()}
         if self.config.use_vae:
             # Calculate Dā‚–ā‚—(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
             # KL-divergence per batch element, then take the mean over the batch.
             # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-            mean_kld = (
-                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-            )
-            loss_dict["kld_loss"] = mean_kld.item()
-            loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
+            kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
+            out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
         else:
-            loss_dict["loss"] = l1_loss
+            out_dict["loss"] = l1_loss
 
-        return loss_dict
+        out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
+        return out_dict
 
 
 class ACT(nn.Module):
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 860412bd..e63a5633 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -107,7 +107,7 @@ def update_policy(
     with torch.autocast(device_type=device.type) if use_amp else nullcontext():
         output_dict = policy.forward(batch)
         # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-        loss = output_dict["loss"]
+        loss = output_dict["loss"].mean()
     grad_scaler.scale(loss).backward()
 
     # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index 58da6a47..6534343d 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -30,48 +30,46 @@ Examples:
 - Visualize data stored on a local machine:
 ```
 local$ python lerobot/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht \
-    --episode-index 0
+    --repo-id lerobot/pusht
+
+local$ open http://localhost:9090
 ```
 
 - Visualize data stored on a distant machine with a local viewer:
 ```
 distant$ python lerobot/scripts/visualize_dataset.py \
+    --repo-id lerobot/pusht
+
+local$ ssh -L 9090:localhost:9090 distant  # create a ssh tunnel
+local$ open http://localhost:9090
+```
+
+- Select episodes to visualize:
+```
+python lerobot/scripts/visualize_dataset.py \
     --repo-id lerobot/pusht \
-    --episode-index 0 \
-    --save 1 \
-    --output-dir path/to/directory
-
-local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
-local$ rerun lerobot_pusht_episode_0.rrd
+    --episode-indices 7 3 5 1 4
 ```
-
-- Visualize data stored on a distant machine through streaming:
-(You need to forward the websocket port to the distant machine, with
-`ssh -L 9087:localhost:9087 username@remote-host`)
-```
-distant$ python lerobot/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht \
-    --episode-index 0 \
-    --mode distant \
-    --ws-port 9087
-
-local$ rerun ws://localhost:9087
-```
-
 """
 
 import argparse
-import gc
+import http.server
 import logging
-import time
+import os
+import shutil
+import socketserver
 from pathlib import Path
 
-import rerun as rr
 import torch
 import tqdm
+import yaml
+from bs4 import BeautifulSoup
+from huggingface_hub import snapshot_download
+from safetensors.torch import load_file, save_file
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.policies.act.modeling_act import ACTPolicy
+from lerobot.common.utils.utils import init_logging
 
 
 class EpisodeSampler(torch.utils.data.Sampler):
@@ -87,33 +85,307 @@ class EpisodeSampler(torch.utils.data.Sampler):
         return len(self.frame_ids)
 
 
-def to_hwc_uint8_numpy(chw_float32_torch):
-    assert chw_float32_torch.dtype == torch.float32
-    assert chw_float32_torch.ndim == 3
-    c, h, w = chw_float32_torch.shape
-    assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
-    hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
-    return hwc_uint8_numpy
+class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
+    def end_headers(self):
+        self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
+        self.send_header("Pragma", "no-cache")
+        self.send_header("Expires", "0")
+        super().end_headers()
 
 
-def visualize_dataset(
-    repo_id: str,
-    episode_index: int,
-    batch_size: int = 32,
-    num_workers: int = 0,
-    mode: str = "local",
-    web_port: int = 9090,
-    ws_port: int = 9087,
-    save: bool = False,
-    output_dir: Path | None = None,
-) -> Path | None:
-    if save:
-        assert (
-            output_dir is not None
-        ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+def run_server(path, port):
+    # Change directory to serve 'index.html` as front page
+    os.chdir(path)
 
-    logging.info("Loading dataset")
-    dataset = LeRobotDataset(repo_id)
+    with socketserver.TCPServer(("", port), NoCacheHTTPRequestHandler) as httpd:
+        logging.info(f"Serving HTTP on 0.0.0.0 port {port} (http://0.0.0.0:{port}/) ...")
+        httpd.serve_forever()
+
+
+def create_html_page(page_title: str):
+    """Create a html page with beautiful soop with default doctype, meta, header and title."""
+    soup = BeautifulSoup("", "html.parser")
+
+    doctype = soup.new_tag("!DOCTYPE html")
+    soup.append(doctype)
+
+    html = soup.new_tag("html", lang="en")
+    soup.append(html)
+
+    head = soup.new_tag("head")
+    html.append(head)
+
+    meta_charset = soup.new_tag("meta", charset="UTF-8")
+    head.append(meta_charset)
+
+    meta_viewport = soup.new_tag(
+        "meta", attrs={"name": "viewport", "content": "width=device-width, initial-scale=1.0"}
+    )
+    head.append(meta_viewport)
+
+    title = soup.new_tag("title")
+    title.string = page_title
+    head.append(title)
+
+    body = soup.new_tag("body")
+    html.append(body)
+
+    main_div = soup.new_tag("div")
+    body.append(main_div)
+    return soup, head, body
+
+
+def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
+    """Write a csv file containg timeseries data of an episode (e.g. state and action).
+    This file will be loaded by Dygraph javascript to plot data in real time."""
+    from_idx = dataset.episode_data_index["from"][episode_index]
+    to_idx = dataset.episode_data_index["to"][episode_index]
+
+    has_state = "observation.state" in dataset.hf_dataset.features
+    has_action = "action" in dataset.hf_dataset.features
+    has_inference = inference_results is not None
+
+    # init header of csv with state and action names
+    header = ["timestamp"]
+    if has_state:
+        dim_state = len(dataset.hf_dataset["observation.state"][0])
+        header += [f"state_{i}" for i in range(dim_state)]
+    if has_action:
+        dim_action = len(dataset.hf_dataset["action"][0])
+        header += [f"action_{i}" for i in range(dim_action)]
+    if has_inference:
+        assert "actions" in inference_results
+        assert "loss" in inference_results
+        dim_pred_action = inference_results["actions"].shape[2]
+        header += [f"pred_action_{i}" for i in range(dim_pred_action)]
+        header += ["loss"]
+
+    columns = ["timestamp"]
+    if has_state:
+        columns += ["observation.state"]
+    if has_action:
+        columns += ["action"]
+
+    rows = []
+    data = dataset.hf_dataset.select_columns(columns)
+    for i in range(from_idx, to_idx):
+        row = [data[i]["timestamp"].item()]
+        if has_state:
+            row += data[i]["observation.state"].tolist()
+        if has_action:
+            row += data[i]["action"].tolist()
+        rows.append(row)
+
+    if has_inference:
+        num_frames = len(rows)
+        assert num_frames == inference_results["actions"].shape[0]
+        assert num_frames == inference_results["loss"].shape[0]
+        for i in range(num_frames):
+            rows[i] += inference_results["actions"][i, 0].tolist()
+            rows[i] += [inference_results["loss"][i].item()]
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    with open(output_dir / file_name, "w") as f:
+        f.write(",".join(header) + "\n")
+        for row in rows:
+            row_str = [str(col) for col in row]
+            f.write(",".join(row_str) + "\n")
+
+
+def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset):
+    """Write a javascript file containing logic to synchronize camera feeds and timeseries."""
+    s = ""
+    s += "document.addEventListener('DOMContentLoaded', function () {\n"
+    for i, key in enumerate(dataset.video_frame_keys):
+        s += f"  const video{i} = document.getElementById('video_{key}');\n"
+    s += "  const slider = document.getElementById('videoControl');\n"
+    s += "  const playButton = document.getElementById('playButton');\n"
+    s += f"  const dygraph = new Dygraph(document.getElementById('graph'), '{ep_csv_fname}', " + "{\n"
+    s += "    pixelsPerPoint: 0.01,\n"
+    s += "    legend: 'always',\n"
+    s += "    labelsDiv: document.getElementById('labels'),\n"
+    s += "    labelsSeparateLines: true,\n"
+    s += "    labelsKMB: true,\n"
+    s += "    highlightCircleSize: 1.5,\n"
+    s += "    highlightSeriesOpts: {\n"
+    s += "        strokeWidth: 1.5,\n"
+    s += "        strokeBorderWidth: 1,\n"
+    s += "        highlightCircleSize: 3\n"
+    s += "    }\n"
+    s += "  });\n"
+    s += "\n"
+    s += "  // Function to play both videos\n"
+    s += "  playButton.addEventListener('click', function () {\n"
+    for i in range(len(dataset.video_frame_keys)):
+        s += f"    video{i}.play();\n"
+    s += "    // playButton.disabled = true; // Optional: disable button after playing\n"
+    s += "  });\n"
+    s += "\n"
+    s += "  // Update the video time when the slider value changes\n"
+    s += "  slider.addEventListener('input', function () {\n"
+    s += "    const sliderValue = slider.value;\n"
+    for i in range(len(dataset.video_frame_keys)):
+        s += f"    const time{i} = (video{i}.duration * sliderValue) / 100;\n"
+    for i in range(len(dataset.video_frame_keys)):
+        s += f"    video{i}.currentTime = time{i};\n"
+    s += "  });\n"
+    s += "\n"
+    s += "  // Synchronize slider with the video's current time\n"
+    s += "  const syncSlider = (video) => {\n"
+    s += "    video.addEventListener('timeupdate', function () {\n"
+    s += "      if (video.duration) {\n"
+    s += "        const pc = (100 / video.duration) * video.currentTime;\n"
+    s += "        slider.value = pc;\n"
+    s += "        const index = Math.floor(pc * dygraph.numRows() / 100);\n"
+    s += "        dygraph.setSelection(index, undefined, true, true);\n"
+    s += "      }\n"
+    s += "    });\n"
+    s += "  };\n"
+    s += "\n"
+    for i in range(len(dataset.video_frame_keys)):
+        s += f"  syncSlider(video{i});\n"
+    s += "\n"
+    s += "});\n"
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    with open(output_dir / file_name, "w", encoding="utf-8") as f:
+        f.write(s)
+
+
+def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset):
+    """Write an html file containg video feeds and timeseries associated to an episode."""
+    soup, head, body = create_html_page("")
+
+    css_style = soup.new_tag("style")
+    css_style.string = ""
+    css_style.string += "#labels > span.highlight {\n"
+    css_style.string += "  border: 1px solid grey;\n"
+    css_style.string += "}"
+    head.append(css_style)
+
+    # Add videos from camera feeds
+
+    videos_control_div = soup.new_tag("div")
+    body.append(videos_control_div)
+
+    videos_div = soup.new_tag("div")
+    videos_control_div.append(videos_div)
+
+    def create_video(id, src):
+        video = soup.new_tag("video", id=id, width="320", height="240", controls="")
+        source = soup.new_tag("source", src=src, type="video/mp4")
+        video.string = "Your browser does not support the video tag."
+        video.append(source)
+        return video
+
+    # get first frame of episode (hack to get video_path of the episode)
+    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+
+    for key in dataset.video_frame_keys:
+        # Example of video_path: 'videos/observation.image_episode_000004.mp4'
+        video_path = dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
+        videos_div.append(create_video(f"video_{key}", video_path))
+
+    # Add controls for videos and graph
+
+    control_div = soup.new_tag("div")
+    videos_control_div.append(control_div)
+
+    button_div = soup.new_tag("div")
+    control_div.append(button_div)
+
+    button = soup.new_tag("button", id="playButton")
+    button.string = "Play Videos"
+    button_div.append(button)
+
+    slider_div = soup.new_tag("div")
+    control_div.append(slider_div)
+
+    slider = soup.new_tag("input", type="range", id="videoControl", min="0", max="100", value="0", step="1")
+    control_div.append(slider)
+
+    # Add graph of states/actions, and its labels
+
+    graph_labels_div = soup.new_tag("div", style="display: flex;")
+    body.append(graph_labels_div)
+
+    graph_div = soup.new_tag("div", id="graph", style="flex: 1; width: 85%")
+    graph_labels_div.append(graph_div)
+
+    labels_div = soup.new_tag("div", id="labels", style="flex: 1; width: 15%")
+    graph_labels_div.append(labels_div)
+
+    # add dygraph library
+    script = soup.new_tag("script", type="text/javascript", src=js_fname)
+    body.append(script)
+
+    script_dygraph = soup.new_tag(
+        "script",
+        type="text/javascript",
+        src="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.js",
+    )
+    body.append(script_dygraph)
+
+    link_dygraph = soup.new_tag(
+        "link", rel="stylesheet", href="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.css"
+    )
+    body.append(link_dygraph)
+
+    # Write as a html file
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    with open(output_dir / file_name, "w", encoding="utf-8") as f:
+        f.write(soup.prettify())
+
+
+def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames, dataset):
+    """Write an html file containing information related to the dataset and a list of links to
+    html pages of episodes."""
+    soup, head, body = create_html_page("TODO")
+
+    h3 = soup.new_tag("h3")
+    h3.string = "TODO"
+    body.append(h3)
+
+    ul_info = soup.new_tag("ul")
+    body.append(ul_info)
+
+    li_info = soup.new_tag("li")
+    li_info.string = f"Number of samples/frames: {dataset.num_samples}"
+    ul_info.append(li_info)
+
+    li_info = soup.new_tag("li")
+    li_info.string = f"Number of episodes: {dataset.num_episodes}"
+    ul_info.append(li_info)
+
+    li_info = soup.new_tag("li")
+    li_info.string = f"Frames per second: {dataset.fps}"
+    ul_info.append(li_info)
+
+    # li_info = soup.new_tag("li")
+    # li_info.string = f"Size: {format_big_number(dataset.hf_dataset.info.size_in_bytes)}B"
+    # ul_info.append(li_info)
+
+    ul = soup.new_tag("ul")
+    body.append(ul)
+
+    for ep_idx, ep_html_fname in zip(ep_indices, ep_html_fnames, strict=False):
+        li = soup.new_tag("li")
+        ul.append(li)
+
+        a = soup.new_tag("a", href=ep_html_fname)
+        a.string = f"Episode number {ep_idx}"
+
+        li.append(a)
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    with open(output_dir / file_name, "w", encoding="utf-8") as f:
+        f.write(soup.prettify())
+
+
+def run_inference(dataset, episode_index, policy, num_workers=4, batch_size=32, device="cuda"):
+    policy.eval()
+    policy.to(device)
 
     logging.info("Loading dataloader")
     episode_sampler = EpisodeSampler(dataset, episode_index)
@@ -124,70 +396,104 @@ def visualize_dataset(
         sampler=episode_sampler,
     )
 
-    logging.info("Starting Rerun")
-
-    if mode not in ["local", "distant"]:
-        raise ValueError(mode)
-
-    spawn_local_viewer = mode == "local" and not save
-    rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
-
-    # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
-    # when iterating on a dataloader with `num_workers` > 0
-    # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
-    gc.collect()
-
-    if mode == "distant":
-        rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
-
-    logging.info("Logging to Rerun")
-
+    logging.info("Running inference")
+    inference_results = {}
     for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
-        # iterate over the batch
-        for i in range(len(batch["index"])):
-            rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
-            rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
+        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
+        with torch.inference_mode():
+            output_dict = policy.forward(batch)
 
-            # display each camera image
-            for key in dataset.camera_keys:
-                # TODO(rcadene): add `.compress()`? is it lossless?
-                rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
+        for key in output_dict:
+            if key not in inference_results:
+                inference_results[key] = []
+            inference_results[key].append(output_dict[key].to("cpu"))
 
-            # display each dimension of action space (e.g. actuators command)
-            if "action" in batch:
-                for dim_idx, val in enumerate(batch["action"][i]):
-                    rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
+    for key in inference_results:
+        inference_results[key] = torch.cat(inference_results[key])
 
-            # display each dimension of observed state space (e.g. agent position in joint space)
-            if "observation.state" in batch:
-                for dim_idx, val in enumerate(batch["observation.state"][i]):
-                    rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
+    return inference_results
 
-            if "next.done" in batch:
-                rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
 
-            if "next.reward" in batch:
-                rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
+def visualize_dataset(
+    repo_id: str,
+    episode_indices: list[int] = None,
+    output_dir: Path | None = None,
+    serve: bool = True,
+    port: int = 9090,
+    force_overwrite: bool = True,
+    policy_repo_id: str | None = None,
+    policy_ckpt_path: Path | None = None,
+    batch_size: int = 32,
+    num_workers: int = 4,
+) -> Path | None:
+    init_logging()
 
-            if "next.success" in batch:
-                rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
+    has_policy = policy_repo_id or policy_ckpt_path
 
-    if mode == "local" and save:
-        # save .rrd locally
-        output_dir = Path(output_dir)
-        output_dir.mkdir(parents=True, exist_ok=True)
-        repo_id_str = repo_id.replace("/", "_")
-        rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
-        rr.save(rrd_path)
-        return rrd_path
+    if has_policy:
+        logging.info("Loading policy")
+        if policy_repo_id:
+            pretrained_policy_path = Path(snapshot_download(policy_repo_id))
+        elif policy_ckpt_path:
+            pretrained_policy_path = Path(policy_ckpt_path)
+        policy = ACTPolicy.from_pretrained(pretrained_policy_path)
+        with open(pretrained_policy_path / "config.yaml") as f:
+            cfg = yaml.safe_load(f)
+        delta_timestamps = cfg["training"]["delta_timestamps"]
+    else:
+        delta_timestamps = None
 
-    elif mode == "distant":
-        # stop the process from exiting since it is serving the websocket connection
-        try:
-            while True:
-                time.sleep(1)
-        except KeyboardInterrupt:
-            print("Ctrl-C received. Exiting.")
+    logging.info("Loading dataset")
+    dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
+
+    if not dataset.video:
+        raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
+
+    if output_dir is None:
+        output_dir = f"outputs/visualize_dataset/{repo_id}"
+
+    output_dir = Path(output_dir)
+    if force_overwrite and output_dir.exists():
+        shutil.rmtree(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Create a simlink from the dataset video folder containg mp4 files to the output directory
+    # so that the http server can get access to the mp4 files.
+    ln_videos_dir = output_dir / "videos"
+    if not ln_videos_dir.exists():
+        ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
+
+    if episode_indices is None:
+        episode_indices = list(range(dataset.num_episodes))
+
+    logging.info("Writing html")
+    ep_html_fnames = []
+    for episode_index in tqdm.tqdm(episode_indices):
+        inference_results = None
+        if has_policy:
+            inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
+            if inference_results_path.exists():
+                inference_results = load_file(inference_results_path)
+            else:
+                inference_results = run_inference(dataset, episode_index, policy)
+                save_file(inference_results, inference_results_path)
+
+        # write states and actions in a csv
+        ep_csv_fname = f"episode_{episode_index}.csv"
+        write_episode_data_csv(output_dir, ep_csv_fname, episode_index, dataset, inference_results)
+
+        js_fname = f"episode_{episode_index}.js"
+        write_episode_data_js(output_dir, js_fname, ep_csv_fname, dataset)
+
+        # write a html page to view videos and timeseries
+        ep_html_fname = f"episode_{episode_index}.html"
+        write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_index, dataset)
+        ep_html_fnames.append(ep_html_fname)
+
+    write_episodes_list_html(output_dir, "index.html", episode_indices, ep_html_fnames, dataset)
+
+    if serve:
+        run_server(output_dir, port)
 
 
 def main():
@@ -197,13 +503,51 @@ def main():
         "--repo-id",
         type=str,
         required=True,
-        help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
+        help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
     )
     parser.add_argument(
-        "--episode-index",
+        "--episode-indices",
         type=int,
-        required=True,
-        help="Episode to visualize.",
+        nargs="*",
+        default=None,
+        help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default=None,
+        help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
+    )
+    parser.add_argument(
+        "--serve",
+        type=int,
+        default=1,
+        help="Launch web server.",
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=9090,
+        help="Web port used by the http server.",
+    )
+    parser.add_argument(
+        "--force-overwrite",
+        type=int,
+        default=1,
+        help="Delete the output directory if it exists already.",
+    )
+
+    parser.add_argument(
+        "--policy-repo-id",
+        type=str,
+        default=None,
+        help="Name of hugging face repositery containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).",
+    )
+    parser.add_argument(
+        "--policy-ckpt-path",
+        type=str,
+        default=None,
+        help="Name of hugging face repositery containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).",
     )
     parser.add_argument(
         "--batch-size",
@@ -217,43 +561,6 @@ def main():
         default=4,
         help="Number of processes of Dataloader for loading the data.",
     )
-    parser.add_argument(
-        "--mode",
-        type=str,
-        default="local",
-        help=(
-            "Mode of viewing between 'local' or 'distant'. "
-            "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
-            "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
-        ),
-    )
-    parser.add_argument(
-        "--web-port",
-        type=int,
-        default=9090,
-        help="Web port for rerun.io when `--mode distant` is set.",
-    )
-    parser.add_argument(
-        "--ws-port",
-        type=int,
-        default=9087,
-        help="Web socket port for rerun.io when `--mode distant` is set.",
-    )
-    parser.add_argument(
-        "--save",
-        type=int,
-        default=0,
-        help=(
-            "Save a .rrd file in the directory provided by `--output-dir`. "
-            "It also deactivates the spawning of a viewer. ",
-            "Visualize the data by running `rerun path/to/file.rrd` on your local machine.",
-        ),
-    )
-    parser.add_argument(
-        "--output-dir",
-        type=str,
-        help="Directory path to write a .rrd file when `--save 1` is set.",
-    )
 
     args = parser.parse_args()
     visualize_dataset(**vars(args))
diff --git a/lerobot/scripts/visualize_dataset_rerun.py b/lerobot/scripts/visualize_dataset_rerun.py
new file mode 100644
index 00000000..58da6a47
--- /dev/null
+++ b/lerobot/scripts/visualize_dataset_rerun.py
@@ -0,0 +1,263 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
+
+Note: The last frame of the episode doesnt always correspond to a final state.
+That's because our datasets are composed of transition from state to state up to
+the antepenultimate state associated to the ultimate action to arrive in the final state.
+However, there might not be a transition from a final state to another state.
+
+Note: This script aims to visualize the data used to train the neural networks.
+~What you see is what you get~. When visualizing image modality, it is often expected to observe
+lossly compression artifacts since these images have been decoded from compressed mp4 videos to
+save disk space. The compression factor applied has been tuned to not affect success rate.
+
+Examples:
+
+- Visualize data stored on a local machine:
+```
+local$ python lerobot/scripts/visualize_dataset.py \
+    --repo-id lerobot/pusht \
+    --episode-index 0
+```
+
+- Visualize data stored on a distant machine with a local viewer:
+```
+distant$ python lerobot/scripts/visualize_dataset.py \
+    --repo-id lerobot/pusht \
+    --episode-index 0 \
+    --save 1 \
+    --output-dir path/to/directory
+
+local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
+local$ rerun lerobot_pusht_episode_0.rrd
+```
+
+- Visualize data stored on a distant machine through streaming:
+(You need to forward the websocket port to the distant machine, with
+`ssh -L 9087:localhost:9087 username@remote-host`)
+```
+distant$ python lerobot/scripts/visualize_dataset.py \
+    --repo-id lerobot/pusht \
+    --episode-index 0 \
+    --mode distant \
+    --ws-port 9087
+
+local$ rerun ws://localhost:9087
+```
+
+"""
+
+import argparse
+import gc
+import logging
+import time
+from pathlib import Path
+
+import rerun as rr
+import torch
+import tqdm
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+
+class EpisodeSampler(torch.utils.data.Sampler):
+    def __init__(self, dataset, episode_index):
+        from_idx = dataset.episode_data_index["from"][episode_index].item()
+        to_idx = dataset.episode_data_index["to"][episode_index].item()
+        self.frame_ids = range(from_idx, to_idx)
+
+    def __iter__(self):
+        return iter(self.frame_ids)
+
+    def __len__(self):
+        return len(self.frame_ids)
+
+
+def to_hwc_uint8_numpy(chw_float32_torch):
+    assert chw_float32_torch.dtype == torch.float32
+    assert chw_float32_torch.ndim == 3
+    c, h, w = chw_float32_torch.shape
+    assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
+    hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
+    return hwc_uint8_numpy
+
+
+def visualize_dataset(
+    repo_id: str,
+    episode_index: int,
+    batch_size: int = 32,
+    num_workers: int = 0,
+    mode: str = "local",
+    web_port: int = 9090,
+    ws_port: int = 9087,
+    save: bool = False,
+    output_dir: Path | None = None,
+) -> Path | None:
+    if save:
+        assert (
+            output_dir is not None
+        ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+
+    logging.info("Loading dataset")
+    dataset = LeRobotDataset(repo_id)
+
+    logging.info("Loading dataloader")
+    episode_sampler = EpisodeSampler(dataset, episode_index)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=num_workers,
+        batch_size=batch_size,
+        sampler=episode_sampler,
+    )
+
+    logging.info("Starting Rerun")
+
+    if mode not in ["local", "distant"]:
+        raise ValueError(mode)
+
+    spawn_local_viewer = mode == "local" and not save
+    rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
+
+    # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
+    # when iterating on a dataloader with `num_workers` > 0
+    # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
+    gc.collect()
+
+    if mode == "distant":
+        rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
+
+    logging.info("Logging to Rerun")
+
+    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
+        # iterate over the batch
+        for i in range(len(batch["index"])):
+            rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
+            rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
+
+            # display each camera image
+            for key in dataset.camera_keys:
+                # TODO(rcadene): add `.compress()`? is it lossless?
+                rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
+
+            # display each dimension of action space (e.g. actuators command)
+            if "action" in batch:
+                for dim_idx, val in enumerate(batch["action"][i]):
+                    rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
+
+            # display each dimension of observed state space (e.g. agent position in joint space)
+            if "observation.state" in batch:
+                for dim_idx, val in enumerate(batch["observation.state"][i]):
+                    rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
+
+            if "next.done" in batch:
+                rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
+
+            if "next.reward" in batch:
+                rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
+
+            if "next.success" in batch:
+                rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
+
+    if mode == "local" and save:
+        # save .rrd locally
+        output_dir = Path(output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        repo_id_str = repo_id.replace("/", "_")
+        rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
+        rr.save(rrd_path)
+        return rrd_path
+
+    elif mode == "distant":
+        # stop the process from exiting since it is serving the websocket connection
+        try:
+            while True:
+                time.sleep(1)
+        except KeyboardInterrupt:
+            print("Ctrl-C received. Exiting.")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
+    )
+    parser.add_argument(
+        "--episode-index",
+        type=int,
+        required=True,
+        help="Episode to visualize.",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=32,
+        help="Batch size loaded by DataLoader.",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=4,
+        help="Number of processes of Dataloader for loading the data.",
+    )
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="local",
+        help=(
+            "Mode of viewing between 'local' or 'distant'. "
+            "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
+            "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
+        ),
+    )
+    parser.add_argument(
+        "--web-port",
+        type=int,
+        default=9090,
+        help="Web port for rerun.io when `--mode distant` is set.",
+    )
+    parser.add_argument(
+        "--ws-port",
+        type=int,
+        default=9087,
+        help="Web socket port for rerun.io when `--mode distant` is set.",
+    )
+    parser.add_argument(
+        "--save",
+        type=int,
+        default=0,
+        help=(
+            "Save a .rrd file in the directory provided by `--output-dir`. "
+            "It also deactivates the spawning of a viewer. ",
+            "Visualize the data by running `rerun path/to/file.rrd` on your local machine.",
+        ),
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        help="Directory path to write a .rrd file when `--save 1` is set.",
+    )
+
+    args = parser.parse_args()
+    visualize_dataset(**vars(args))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/test_visualize_dataset.py b/tests/test_visualize_dataset.py
index 99954040..71819568 100644
--- a/tests/test_visualize_dataset.py
+++ b/tests/test_visualize_dataset.py
@@ -25,9 +25,8 @@ from lerobot.scripts.visualize_dataset import visualize_dataset
 def test_visualize_dataset(tmpdir, repo_id):
     rrd_path = visualize_dataset(
         repo_id,
-        episode_index=0,
-        batch_size=32,
-        save=True,
+        episode_indices=[0],
         output_dir=tmpdir,
+        serve=False,
     )
     assert rrd_path.exists()