fix simxarm factory
This commit is contained in:
parent
96c53ad06f
commit
63d18475cc
|
@ -11,7 +11,6 @@ def make_env(cfg):
|
||||||
"from_pixels": cfg.from_pixels,
|
"from_pixels": cfg.from_pixels,
|
||||||
"pixels_only": cfg.pixels_only,
|
"pixels_only": cfg.pixels_only,
|
||||||
"image_size": cfg.image_size,
|
"image_size": cfg.image_size,
|
||||||
"max_episode_length": cfg.episode_length,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.env == "simxarm":
|
if cfg.env == "simxarm":
|
||||||
|
|
|
@ -29,14 +29,12 @@ class PushtEnv(EnvBase):
|
||||||
image_size=None,
|
image_size=None,
|
||||||
seed=1337,
|
seed=1337,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
max_episode_length=300,
|
|
||||||
):
|
):
|
||||||
super().__init__(device=device, batch_size=[])
|
super().__init__(device=device, batch_size=[])
|
||||||
self.frame_skip = frame_skip
|
self.frame_skip = frame_skip
|
||||||
self.from_pixels = from_pixels
|
self.from_pixels = from_pixels
|
||||||
self.pixels_only = pixels_only
|
self.pixels_only = pixels_only
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.max_episode_length = max_episode_length
|
|
||||||
|
|
||||||
if pixels_only:
|
if pixels_only:
|
||||||
assert from_pixels
|
assert from_pixels
|
||||||
|
|
|
@ -80,7 +80,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
alpha=cfg.per_alpha,
|
alpha=cfg.per_alpha,
|
||||||
beta=cfg.per_beta,
|
beta=cfg.per_beta,
|
||||||
num_slices=num_traj_per_batch,
|
num_slices=num_traj_per_batch,
|
||||||
strict_length=False,
|
strict_length=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
online_buffer = TensorDictReplayBuffer(
|
online_buffer = TensorDictReplayBuffer(
|
||||||
|
|
Loading…
Reference in New Issue