Sanitize cfg.env
This commit is contained in:
parent
9b469c4232
commit
ed80db2846
|
@ -7,26 +7,26 @@ from lerobot.common.envs.transforms import Prod
|
||||||
|
|
||||||
def make_env(cfg):
|
def make_env(cfg):
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"frame_skip": cfg.action_repeat,
|
"frame_skip": cfg.env.action_repeat,
|
||||||
"from_pixels": cfg.from_pixels,
|
"from_pixels": cfg.env.from_pixels,
|
||||||
"pixels_only": cfg.pixels_only,
|
"pixels_only": cfg.env.pixels_only,
|
||||||
"image_size": cfg.image_size,
|
"image_size": cfg.env.image_size,
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.env == "simxarm":
|
if cfg.env.name == "simxarm":
|
||||||
kwargs["task"] = cfg.task
|
kwargs["task"] = cfg.env.task
|
||||||
clsfunc = SimxarmEnv
|
clsfunc = SimxarmEnv
|
||||||
elif cfg.env == "pusht":
|
elif cfg.env.name == "pusht":
|
||||||
clsfunc = PushtEnv
|
clsfunc = PushtEnv
|
||||||
else:
|
else:
|
||||||
raise ValueError(cfg.env)
|
raise ValueError(cfg.env.name)
|
||||||
|
|
||||||
env = clsfunc(**kwargs)
|
env = clsfunc(**kwargs)
|
||||||
|
|
||||||
# limit rollout to max_steps
|
# limit rollout to max_steps
|
||||||
env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length))
|
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
|
||||||
|
|
||||||
if cfg.env == "pusht":
|
if cfg.env.name == "pusht":
|
||||||
# to ensure pusht is in [0,255] like simxarm
|
# to ensure pusht is in [0,255] like simxarm
|
||||||
env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
|
env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ def print_run(cfg, reward=None):
|
||||||
|
|
||||||
kvs = [
|
kvs = [
|
||||||
("task", cfg.task),
|
("task", cfg.task),
|
||||||
("train steps", f"{int(cfg.train_steps * cfg.action_repeat):,}"),
|
("train steps", f"{int(cfg.train_steps * cfg.env.action_repeat):,}"),
|
||||||
# ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
|
# ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
|
||||||
# ('actions', cfg.action_dim),
|
# ('actions', cfg.action_dim),
|
||||||
# ('experiment', cfg.exp_name),
|
# ('experiment', cfg.exp_name),
|
||||||
|
@ -117,7 +117,11 @@ class VideoRecorder:
|
||||||
if self.enabled:
|
if self.enabled:
|
||||||
frames = np.stack(self.frames).transpose(0, 3, 1, 2)
|
frames = np.stack(self.frames).transpose(0, 3, 1, 2)
|
||||||
self._wandb.log(
|
self._wandb.log(
|
||||||
{"eval_video": self._wandb.Video(frames, fps=self.fps, format="mp4")},
|
{
|
||||||
|
"eval_video": self._wandb.Video(
|
||||||
|
frames, fps=self.env.fps, format="mp4"
|
||||||
|
)
|
||||||
|
},
|
||||||
step=step,
|
step=step,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -13,21 +13,18 @@ eval_episodes: 20
|
||||||
save_video: false
|
save_video: false
|
||||||
save_model: false
|
save_model: false
|
||||||
save_buffer: false
|
save_buffer: false
|
||||||
|
train_steps: 50000
|
||||||
|
fps: 15
|
||||||
|
|
||||||
# env
|
env:
|
||||||
env: simxarm
|
name: simxarm
|
||||||
task: lift
|
task: lift
|
||||||
from_pixels: True
|
from_pixels: True
|
||||||
pixels_only: False
|
pixels_only: False
|
||||||
image_size: 84
|
image_size: 84
|
||||||
fps: 15
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# xarm_lift
|
|
||||||
action_repeat: 2
|
action_repeat: 2
|
||||||
episode_length: 25
|
episode_length: 25
|
||||||
train_steps: 50000
|
fps: ${fps}
|
||||||
|
|
||||||
|
|
||||||
policy:
|
policy:
|
||||||
|
@ -37,14 +34,14 @@ policy:
|
||||||
|
|
||||||
# xarm_lift
|
# xarm_lift
|
||||||
train_steps: ${train_steps}
|
train_steps: ${train_steps}
|
||||||
episode_length: ${episode_length}
|
episode_length: ${env.episode_length}
|
||||||
discount: 0.9
|
discount: 0.9
|
||||||
modality: 'all'
|
modality: 'all'
|
||||||
|
|
||||||
# pixels
|
# pixels
|
||||||
frame_stack: 1
|
frame_stack: 1
|
||||||
num_channels: 32
|
num_channels: 32
|
||||||
img_size: ${image_size}
|
img_size: ${env.image_size}
|
||||||
state_dim: 4
|
state_dim: 4
|
||||||
action_dim: 4
|
action_dim: 4
|
||||||
|
|
||||||
|
|
|
@ -6,16 +6,19 @@ hydra:
|
||||||
job:
|
job:
|
||||||
name: pusht
|
name: pusht
|
||||||
|
|
||||||
# env
|
|
||||||
env: pusht
|
|
||||||
task: pusht
|
|
||||||
image_size: 96
|
|
||||||
action_repeat: 1
|
|
||||||
state_dim: 2
|
|
||||||
action_dim: 2
|
|
||||||
fps: 10
|
|
||||||
eval_episodes: 50
|
eval_episodes: 50
|
||||||
episode_length: 300
|
|
||||||
eval_freq: 7500
|
eval_freq: 7500
|
||||||
save_freq: 75000
|
save_freq: 75000
|
||||||
|
fps: 10
|
||||||
|
|
||||||
|
env:
|
||||||
|
name: pusht
|
||||||
|
task: pusht
|
||||||
|
image_size: 96
|
||||||
|
fps: ${fps}
|
||||||
|
action_repeat: 1
|
||||||
|
episode_length: 300
|
||||||
|
|
||||||
|
policy:
|
||||||
|
state_dim: 2
|
||||||
|
action_dim: 2
|
|
@ -126,8 +126,8 @@ def eval(cfg: dict, out_dir=None):
|
||||||
policy=policy,
|
policy=policy,
|
||||||
save_video=True,
|
save_video=True,
|
||||||
video_dir=Path(out_dir) / "eval",
|
video_dir=Path(out_dir) / "eval",
|
||||||
fps=cfg.fps,
|
fps=cfg.env.fps,
|
||||||
max_steps=cfg.episode_length,
|
max_steps=cfg.env.episode_length,
|
||||||
num_episodes=cfg.eval_episodes,
|
num_episodes=cfg.eval_episodes,
|
||||||
)
|
)
|
||||||
print(metrics)
|
print(metrics)
|
||||||
|
|
|
@ -88,7 +88,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
|
|
||||||
while step < cfg.train_steps:
|
while step < cfg.train_steps:
|
||||||
is_offline = True
|
is_offline = True
|
||||||
num_updates = cfg.episode_length
|
num_updates = cfg.env.episode_length
|
||||||
_step = step + num_updates
|
_step = step + num_updates
|
||||||
rollout_metrics = {}
|
rollout_metrics = {}
|
||||||
|
|
||||||
|
@ -98,11 +98,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
# TODO: use SyncDataCollector for that?
|
# TODO: use SyncDataCollector for that?
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
rollout = env.rollout(
|
rollout = env.rollout(
|
||||||
max_steps=cfg.episode_length,
|
max_steps=cfg.env.episode_length,
|
||||||
policy=td_policy,
|
policy=td_policy,
|
||||||
auto_cast_to_device=True,
|
auto_cast_to_device=True,
|
||||||
)
|
)
|
||||||
assert len(rollout) <= cfg.episode_length
|
assert len(rollout) <= cfg.env.episode_length
|
||||||
rollout["episode"] = torch.tensor(
|
rollout["episode"] = torch.tensor(
|
||||||
[online_episode_idx] * len(rollout), dtype=torch.int
|
[online_episode_idx] * len(rollout), dtype=torch.int
|
||||||
)
|
)
|
||||||
|
@ -133,7 +133,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Log training metrics
|
# Log training metrics
|
||||||
env_step = int(_step * cfg.action_repeat)
|
env_step = int(_step * cfg.env.action_repeat)
|
||||||
common_metrics = {
|
common_metrics = {
|
||||||
"episode": online_episode_idx,
|
"episode": online_episode_idx,
|
||||||
"step": _step,
|
"step": _step,
|
||||||
|
|
Loading…
Reference in New Issue