fix simxarm factory
This commit is contained in:
parent
96c53ad06f
commit
63d18475cc
|
@ -11,7 +11,6 @@ def make_env(cfg):
|
|||
"from_pixels": cfg.from_pixels,
|
||||
"pixels_only": cfg.pixels_only,
|
||||
"image_size": cfg.image_size,
|
||||
"max_episode_length": cfg.episode_length,
|
||||
}
|
||||
|
||||
if cfg.env == "simxarm":
|
||||
|
|
|
@ -29,14 +29,12 @@ class PushtEnv(EnvBase):
|
|||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
max_episode_length=300,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
self.max_episode_length = max_episode_length
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
|
|
|
@ -80,7 +80,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
alpha=cfg.per_alpha,
|
||||
beta=cfg.per_beta,
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
strict_length=True,
|
||||
)
|
||||
|
||||
online_buffer = TensorDictReplayBuffer(
|
||||
|
|
Loading…
Reference in New Issue