fix pusht data_dir path
This commit is contained in:
parent
54b05bfb77
commit
f1e2837d63
|
@ -8,13 +8,13 @@ import pymunk
|
|||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||
|
||||
|
@ -111,7 +111,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
|||
)
|
||||
|
||||
def _download_and_preproc(self):
|
||||
raw_dir = self.data_dir / "raw"
|
||||
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
|
||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||
if not zarr_path.is_dir():
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
|
Loading…
Reference in New Issue