Merge branch 'huggingface:main' into 2024_05_30_add_data_augmentation

Marina Barannikov 2024-06-05 13:56:47 +02:00 committed by GitHub
commit 8b134725d5
2 changed files with 33 additions and 17 deletions

View File

@@ -209,7 +209,7 @@ def eval_policy(
     policy: torch.nn.Module,
     n_episodes: int,
     max_episodes_rendered: int = 0,
-    video_dir: Path | None = None,
+    videos_dir: Path | None = None,
     return_episode_data: bool = False,
     start_seed: int | None = None,
     enable_progbar: bool = False,
@@ -221,7 +221,7 @@ def eval_policy(
         policy: The policy.
         n_episodes: The number of episodes to evaluate.
         max_episodes_rendered: Maximum number of episodes to render into videos.
-        video_dir: Where to save rendered videos.
+        videos_dir: Where to save rendered videos.
         return_episode_data: Whether to return episode data for online training. Incorporates the data into
             the "episodes" key of the returned dictionary.
         start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
@@ -347,8 +347,8 @@ def eval_policy(
         ):
             if n_episodes_rendered >= max_episodes_rendered:
                 break
-            video_dir.mkdir(parents=True, exist_ok=True)
-            video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"
+            videos_dir.mkdir(parents=True, exist_ok=True)
+            video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
             video_paths.append(str(video_path))
             thread = threading.Thread(
                 target=write_video,
@@ -503,9 +503,10 @@ def _compile_episode_data(
     }


-def eval(
+def main(
     pretrained_policy_path: str | None = None,
     hydra_cfg_path: str | None = None,
+    out_dir: str | None = None,
     config_overrides: list[str] | None = None,
 ):
     assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
@@ -513,12 +514,8 @@ def eval(
         hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
     else:
         hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
-    out_dir = (
-        f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
-    )
-
     if out_dir is None:
-        raise NotImplementedError()
+        out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"

     # Check device is available
     device = get_safe_torch_device(hydra_cfg.device, log=True)
@@ -546,7 +543,7 @@ def eval(
         policy,
         hydra_cfg.eval.n_episodes,
         max_episodes_rendered=10,
-        video_dir=Path(out_dir) / "eval",
+        videos_dir=Path(out_dir) / "videos",
         start_seed=hydra_cfg.seed,
         enable_progbar=True,
         enable_inner_progbar=True,
@@ -586,6 +583,13 @@ if __name__ == "__main__":
         ),
     )
     parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
+    parser.add_argument(
+        "--out-dir",
+        help=(
+            "Where to save the evaluation outputs. If not provided, outputs are saved in "
+            "outputs/eval/{timestamp}_{env_name}_{policy_name}"
+        ),
+    )
     parser.add_argument(
         "overrides",
         nargs="*",
@@ -594,7 +598,7 @@ if __name__ == "__main__":
     args = parser.parse_args()

     if args.pretrained_policy_name_or_path is None:
-        eval(hydra_cfg_path=args.config, config_overrides=args.overrides)
+        main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
     else:
         try:
             pretrained_policy_path = Path(
@@ -618,4 +622,8 @@ if __name__ == "__main__":
                 "repo ID, nor is it an existing local directory."
             )

-        eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
+        main(
+            pretrained_policy_path=pretrained_policy_path,
+            out_dir=args.out_dir,
+            config_overrides=args.overrides,
+        )

View File

@@ -150,6 +150,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
     grad_norm = info["grad_norm"]
     lr = info["lr"]
     update_s = info["update_s"]
+    dataloading_s = info["dataloading_s"]

     # A sample is an (observation,action) pair, where observation and action
     # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
@ -170,6 +171,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
f"lr:{lr:0.1e}", f"lr:{lr:0.1e}",
# in seconds # in seconds
f"updt_s:{update_s:.3f}", f"updt_s:{update_s:.3f}",
f"data_s:{dataloading_s:.3f}", # if not ~0, you are bottlenecked by cpu or io
] ]
logging.info(" ".join(log_items)) logging.info(" ".join(log_items))
@@ -325,6 +327,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     # Note: this helper will be used in offline and online training loops.
     def evaluate_and_checkpoint_if_needed(step):
+        _num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
+        step_identifier = f"{step:0{_num_digits}d}"
+
         if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
@ -332,7 +337,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
eval_env, eval_env,
policy, policy,
cfg.eval.n_episodes, cfg.eval.n_episodes,
video_dir=Path(out_dir) / "eval", videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}",
max_episodes_rendered=4, max_episodes_rendered=4,
start_seed=cfg.seed, start_seed=cfg.seed,
) )
@@ -350,9 +355,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
                 policy,
                 optimizer,
                 lr_scheduler,
-                identifier=str(step).zfill(
-                    max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
-                ),
+                identifier=step_identifier,
             )
             logging.info("Resume training")
@@ -382,7 +385,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     for _ in range(step, cfg.training.offline_steps):
         if step == 0:
             logging.info("Start offline training on a fixed dataset")
+
+        start_time = time.perf_counter()
         batch = next(dl_iter)
+        dataloading_s = time.perf_counter() - start_time

         for key in batch:
             batch[key] = batch[key].to(device, non_blocking=True)
@ -397,6 +403,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
use_amp=cfg.use_amp, use_amp=cfg.use_amp,
) )
train_info["dataloading_s"] = dataloading_s
if step % cfg.training.log_freq == 0: if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True) log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)