fix simxarm factory

This commit is contained in:
Cadene 2024-02-22 13:04:24 +00:00
parent 96c53ad06f
commit 63d18475cc
3 changed files with 1 additions and 4 deletions

View File

@ -11,7 +11,6 @@ def make_env(cfg):
     "from_pixels": cfg.from_pixels,
     "pixels_only": cfg.pixels_only,
     "image_size": cfg.image_size,
-    "max_episode_length": cfg.episode_length,
 }
 if cfg.env == "simxarm":

View File

@ -29,14 +29,12 @@ class PushtEnv(EnvBase):
     image_size=None,
     seed=1337,
     device="cpu",
-    max_episode_length=300,
 ):
     super().__init__(device=device, batch_size=[])
     self.frame_skip = frame_skip
     self.from_pixels = from_pixels
     self.pixels_only = pixels_only
     self.image_size = image_size
-    self.max_episode_length = max_episode_length
     if pixels_only:
         assert from_pixels

View File

@ -80,7 +80,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     alpha=cfg.per_alpha,
     beta=cfg.per_beta,
     num_slices=num_traj_per_batch,
-    strict_length=False,
+    strict_length=True,
 )
 online_buffer = TensorDictReplayBuffer(