fix simxarm factory

This commit is contained in:
Cadene 2024-02-22 13:04:24 +00:00
parent 96c53ad06f
commit 63d18475cc
3 changed files with 1 additions and 4 deletions

View File

@ -11,7 +11,6 @@ def make_env(cfg):
"from_pixels": cfg.from_pixels,
"pixels_only": cfg.pixels_only,
"image_size": cfg.image_size,
"max_episode_length": cfg.episode_length,
}
if cfg.env == "simxarm":

View File

@ -29,14 +29,12 @@ class PushtEnv(EnvBase):
image_size=None,
seed=1337,
device="cpu",
max_episode_length=300,
):
super().__init__(device=device, batch_size=[])
self.frame_skip = frame_skip
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
self.max_episode_length = max_episode_length
if pixels_only:
assert from_pixels

View File

@ -80,7 +80,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
alpha=cfg.per_alpha,
beta=cfg.per_beta,
num_slices=num_traj_per_batch,
strict_length=False,
strict_length=True,
)
online_buffer = TensorDictReplayBuffer(