diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index b93b519b..5d2f5e38 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -8,13 +8,13 @@ import pymunk import torch import torchrl import tqdm +from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer +from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely from tensordict import TensorDict from torchrl.data.replay_buffers.samplers import SliceSampler from torchrl.data.replay_buffers.storages import TensorStorage from torchrl.data.replay_buffers.writers import Writer -from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer -from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely from lerobot.common.datasets.abstract import AbstractExperienceReplay from lerobot.common.datasets.utils import download_and_extract_zip @@ -111,7 +111,7 @@ class PushtExperienceReplay(AbstractExperienceReplay): ) def _download_and_preproc(self): - raw_dir = self.data_dir / "raw" + raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw" zarr_path = (raw_dir / PUSHT_ZARR).resolve() if not zarr_path.is_dir(): raw_dir.mkdir(parents=True, exist_ok=True)