Sanitize cfg.env

Cadene 2024-02-25 12:02:29 +00:00
parent 9b469c4232
commit ed80db2846
6 changed files with 46 additions and 42 deletions

View File

@@ -7,26 +7,26 @@ from lerobot.common.envs.transforms import Prod
 def make_env(cfg):
     kwargs = {
-        "frame_skip": cfg.action_repeat,
-        "from_pixels": cfg.from_pixels,
-        "pixels_only": cfg.pixels_only,
-        "image_size": cfg.image_size,
+        "frame_skip": cfg.env.action_repeat,
+        "from_pixels": cfg.env.from_pixels,
+        "pixels_only": cfg.env.pixels_only,
+        "image_size": cfg.env.image_size,
     }
-    if cfg.env == "simxarm":
-        kwargs["task"] = cfg.task
+    if cfg.env.name == "simxarm":
+        kwargs["task"] = cfg.env.task
         clsfunc = SimxarmEnv
-    elif cfg.env == "pusht":
+    elif cfg.env.name == "pusht":
         clsfunc = PushtEnv
     else:
-        raise ValueError(cfg.env)
+        raise ValueError(cfg.env.name)
     env = clsfunc(**kwargs)
     # limit rollout to max_steps
-    env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length))
-    if cfg.env == "pusht":
+    env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
+    if cfg.env.name == "pusht":
         # to ensure pusht is in [0,255] like simxarm
         env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
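Note: the factory change is purely an access-pattern move; every env-specific field now lives under a cfg.env sub-config instead of at the top level. A minimal sketch of the new pattern, using OmegaConf (the config library underlying Hydra) with illustrative values, not the repo's real config:

from omegaconf import OmegaConf

# Illustrative nested config mirroring the new layout (values are examples).
cfg = OmegaConf.create(
    {"env": {"name": "pusht", "image_size": 96, "action_repeat": 1}}
)

# Old access pattern: cfg.image_size, cfg.action_repeat, cfg.env == "pusht"
# New access pattern: the same information, namespaced one level down.
kwargs = {
    "frame_skip": cfg.env.action_repeat,
    "image_size": cfg.env.image_size,
}
assert cfg.env.name == "pusht"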

View File

@@ -51,7 +51,7 @@ def print_run(cfg, reward=None):
     kvs = [
         ("task", cfg.task),
-        ("train steps", f"{int(cfg.train_steps * cfg.action_repeat):,}"),
+        ("train steps", f"{int(cfg.train_steps * cfg.env.action_repeat):,}"),
         # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
         # ('actions', cfg.action_dim),
         # ('experiment', cfg.exp_name),
@@ -117,7 +117,11 @@ class VideoRecorder:
         if self.enabled:
             frames = np.stack(self.frames).transpose(0, 3, 1, 2)
             self._wandb.log(
-                {"eval_video": self._wandb.Video(frames, fps=self.fps, format="mp4")},
+                {
+                    "eval_video": self._wandb.Video(
+                        frames, fps=self.env.fps, format="mp4"
+                    )
+                },
                 step=step,
             )
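For context, a self-contained sketch of the wandb call being reshaped above: wandb.Video accepts a numpy array shaped (time, channel, height, width). The project name and random frames below are placeholders, and encoding to mp4 requires moviepy to be installed:

import numpy as np
import wandb

run = wandb.init(project="video-logging-demo", mode="offline")  # placeholder project
# 25 RGB frames of 84x84 pixels, shaped (time, channel, height, width).
frames = np.random.randint(0, 256, size=(25, 3, 84, 84), dtype=np.uint8)
run.log({"eval_video": wandb.Video(frames, fps=15, format="mp4")}, step=0)
run.finish()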

View File

@@ -13,21 +13,18 @@ eval_episodes: 20
 save_video: false
 save_model: false
 save_buffer: false
-
-# env
-env: simxarm
-task: lift
-from_pixels: True
-pixels_only: False
-image_size: 84
+train_steps: 50000
+
 fps: 15
-# xarm_lift
-action_repeat: 2
-episode_length: 25
-train_steps: 50000
+
+env:
+  name: simxarm
+  task: lift
+  from_pixels: True
+  pixels_only: False
+  image_size: 84
+  action_repeat: 2
+  episode_length: 25
+  fps: ${fps}
 
 policy:
@@ -37,14 +34,14 @@ policy:
   # xarm_lift
   train_steps: ${train_steps}
-  episode_length: ${episode_length}
+  episode_length: ${env.episode_length}
   discount: 0.9
   modality: 'all'
   # pixels
   frame_stack: 1
   num_channels: 32
-  img_size: ${image_size}
+  img_size: ${env.image_size}
   state_dim: 4
   action_dim: 4
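The two interpolation styles used here are worth spelling out: ${fps} inside the env group and ${env.episode_length} inside policy are both absolute paths resolved from the config root by OmegaConf. A small sketch with abridged keys (values are examples):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "fps": 15,
        "env": {"image_size": 84, "episode_length": 25, "fps": "${fps}"},
        "policy": {
            "episode_length": "${env.episode_length}",
            "img_size": "${env.image_size}",
        },
    }
)

assert cfg.env.fps == 15          # resolved from the top-level key
assert cfg.policy.img_size == 84  # resolved via an absolute path into env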

View File

@@ -6,16 +6,19 @@ hydra:
   job:
     name: pusht
 
-# env
-env: pusht
-task: pusht
-image_size: 96
-action_repeat: 1
-state_dim: 2
-action_dim: 2
-fps: 10
 eval_episodes: 50
-episode_length: 300
 eval_freq: 7500
 save_freq: 75000
+fps: 10
+
+env:
+  name: pusht
+  task: pusht
+  image_size: 96
+  fps: ${fps}
+  action_repeat: 1
+  episode_length: 300
+
+policy:
+  state_dim: 2
+  action_dim: 2
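One practical consequence of the nesting: Hydra command-line overrides now need dotted paths, for example env.image_size=96 instead of image_size=96. A hedged sketch using Hydra's compose API; the config directory and file name here are assumptions, not the repo's actual layout:

from hydra import compose, initialize

# Assumed layout: ./configs/pusht.yaml containing the nested env group.
with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="pusht", overrides=["env.image_size=96", "fps=10"])

assert cfg.env.image_size == 96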

View File

@@ -126,8 +126,8 @@ def eval(cfg: dict, out_dir=None):
         policy=policy,
         save_video=True,
         video_dir=Path(out_dir) / "eval",
-        fps=cfg.fps,
-        max_steps=cfg.episode_length,
+        fps=cfg.env.fps,
+        max_steps=cfg.env.episode_length,
         num_episodes=cfg.eval_episodes,
     )
     print(metrics)

View File

@@ -88,7 +88,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     while step < cfg.train_steps:
         is_offline = True
-        num_updates = cfg.episode_length
+        num_updates = cfg.env.episode_length
         _step = step + num_updates
         rollout_metrics = {}
@@ -98,11 +98,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
             # TODO: use SyncDataCollector for that?
             with torch.no_grad():
                 rollout = env.rollout(
-                    max_steps=cfg.episode_length,
+                    max_steps=cfg.env.episode_length,
                     policy=td_policy,
                     auto_cast_to_device=True,
                 )
-            assert len(rollout) <= cfg.episode_length
+            assert len(rollout) <= cfg.env.episode_length
             rollout["episode"] = torch.tensor(
                 [online_episode_idx] * len(rollout), dtype=torch.int
             )
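The updated assert leans on torchrl's rollout contract: env.rollout() returns a TensorDict whose leading batch dimension is at most max_steps, and shorter when the episode terminates early. A standalone sketch with a stock Gym environment (assumes torchrl and gymnasium are installed):

import torch
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("Pendulum-v1")
with torch.no_grad():
    # Without a policy argument, rollout() samples random actions.
    rollout = env.rollout(max_steps=25)

assert len(rollout) <= 25  # len() is the number of collected steps
rollout["episode"] = torch.tensor([0] * len(rollout), dtype=torch.int)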
@@ -133,7 +133,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
             )
         # Log training metrics
-        env_step = int(_step * cfg.action_repeat)
+        env_step = int(_step * cfg.env.action_repeat)
         common_metrics = {
             "episode": online_episode_idx,
             "step": _step,