Merge pull request #46 from huggingface/user/rcadene/2024_03_23_update_stats_v1.2
Fix bug with stats.pth + Move from cadene to lerobot + Update datasets to v1.2
This commit is contained in:
commit
f3cfc8b3b4
|
@ -14,6 +14,8 @@ from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||||
from torchrl.envs.transforms.transforms import Compose
|
from torchrl.envs.transforms.transforms import Compose
|
||||||
|
|
||||||
|
HF_USER = "lerobot"
|
||||||
|
|
||||||
|
|
||||||
class AbstractExperienceReplay(TensorDictReplayBuffer):
|
class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -106,7 +108,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||||
if self.root is None:
|
if self.root is None:
|
||||||
self.data_dir = Path(
|
self.data_dir = Path(
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id=f"cadene/{self.dataset_id}", repo_type="dataset", revision=self.version
|
repo_id=f"{HF_USER}/{self.dataset_id}", repo_type="dataset", revision=self.version
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -84,7 +84,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
version: str | None = "v1.1",
|
version: str | None = "v1.2",
|
||||||
batch_size: int = None,
|
batch_size: int = None,
|
||||||
*,
|
*,
|
||||||
shuffle: bool = True,
|
shuffle: bool = True,
|
||||||
|
|
|
@ -87,7 +87,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
version: str | None = "v1.1",
|
version: str | None = "v1.2",
|
||||||
batch_size: int = None,
|
batch_size: int = None,
|
||||||
*,
|
*,
|
||||||
shuffle: bool = True,
|
shuffle: bool = True,
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -22,6 +22,9 @@ from .utils import DEVICE, init_config
|
||||||
("simxarm", "diffusion", []),
|
("simxarm", "diffusion", []),
|
||||||
("pusht", "diffusion", []),
|
("pusht", "diffusion", []),
|
||||||
("aloha", "act", ["env.task=sim_insertion_scripted"]),
|
("aloha", "act", ["env.task=sim_insertion_scripted"]),
|
||||||
|
("aloha", "act", ["env.task=sim_insertion_human"]),
|
||||||
|
("aloha", "act", ["env.task=sim_transfer_cube_scripted"]),
|
||||||
|
("aloha", "act", ["env.task=sim_transfer_cube_human"]),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_concrete_policy(env_name, policy_name, extra_overrides):
|
def test_concrete_policy(env_name, policy_name, extra_overrides):
|
||||||
|
|
Loading…
Reference in New Issue