Merge remote-tracking branch 'upstream/main' into finish_examples
This commit is contained in:
commit
120f0aef5c
|
@ -86,13 +86,8 @@ def make_offline_buffer(
|
||||||
else:
|
else:
|
||||||
raise ValueError(cfg.env.name)
|
raise ValueError(cfg.env.name)
|
||||||
|
|
||||||
# TODO(rcadene): backward compatiblity to load pretrained pusht policy
|
|
||||||
dataset_id = cfg.get("dataset_id")
|
|
||||||
if dataset_id is None and cfg.env.name == "pusht":
|
|
||||||
dataset_id = "pusht"
|
|
||||||
|
|
||||||
offline_buffer = clsfunc(
|
offline_buffer = clsfunc(
|
||||||
dataset_id=dataset_id,
|
dataset_id=cfg.dataset_id,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
root=DATA_DIR,
|
root=DATA_DIR,
|
||||||
|
|
|
@ -103,29 +103,3 @@ optimizer:
|
||||||
betas: [0.95, 0.999]
|
betas: [0.95, 0.999]
|
||||||
eps: 1.0e-8
|
eps: 1.0e-8
|
||||||
weight_decay: 1.0e-6
|
weight_decay: 1.0e-6
|
||||||
|
|
||||||
training:
|
|
||||||
device: "cuda:0"
|
|
||||||
seed: 42
|
|
||||||
debug: False
|
|
||||||
resume: True
|
|
||||||
# optimization
|
|
||||||
# lr_scheduler: cosine
|
|
||||||
# lr_warmup_steps: 500
|
|
||||||
num_epochs: 8000
|
|
||||||
# gradient_accumulate_every: 1
|
|
||||||
# EMA destroys performance when used with BatchNorm
|
|
||||||
# replace BatchNorm with GroupNorm.
|
|
||||||
# use_ema: True
|
|
||||||
freeze_encoder: False
|
|
||||||
# training loop control
|
|
||||||
# in epochs
|
|
||||||
rollout_every: 50
|
|
||||||
checkpoint_every: 50
|
|
||||||
val_every: 1
|
|
||||||
sample_every: 5
|
|
||||||
# steps per epoch
|
|
||||||
max_train_steps: null
|
|
||||||
max_val_steps: null
|
|
||||||
# misc
|
|
||||||
tqdm_interval_sec: 1.0
|
|
||||||
|
|
Loading…
Reference in New Issue