rm EpisodeSampler from viz (#389)
This commit is contained in:
parent
04a995e7d1
commit
114e09f570
|
@ -57,7 +57,6 @@ import logging
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
|
||||||
import tqdm
|
import tqdm
|
||||||
from flask import Flask, redirect, render_template, url_for
|
from flask import Flask, redirect, render_template, url_for
|
||||||
|
|
||||||
|
@ -65,19 +64,6 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.utils.utils import init_logging
|
from lerobot.common.utils.utils import init_logging
|
||||||
|
|
||||||
|
|
||||||
class EpisodeSampler(torch.utils.data.Sampler):
|
|
||||||
def __init__(self, dataset, episode_index):
|
|
||||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
|
||||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
|
||||||
self.frame_ids = range(from_idx, to_idx)
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return iter(self.frame_ids)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.frame_ids)
|
|
||||||
|
|
||||||
|
|
||||||
def run_server(
|
def run_server(
|
||||||
dataset: LeRobotDataset,
|
dataset: LeRobotDataset,
|
||||||
episodes: list[int],
|
episodes: list[int],
|
||||||
|
|
Loading…
Reference in New Issue