WIP: still need to verify these changes with a full training run
This commit is contained in:
parent
304355c917
commit
87fcc536f9
|
@ -25,7 +25,7 @@ class PushTImageEnv(PushTEnv):
|
||||||
img = super()._render_frame(mode="rgb_array")
|
img = super()._render_frame(mode="rgb_array")
|
||||||
|
|
||||||
agent_pos = np.array(self.agent.position)
|
agent_pos = np.array(self.agent.position)
|
||||||
img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0)
|
img_obs = np.moveaxis(img.astype(np.float32), -1, 0)
|
||||||
obs = {"image": img_obs, "agent_pos": agent_pos}
|
obs = {"image": img_obs, "agent_pos": agent_pos}
|
||||||
|
|
||||||
# draw action
|
# draw action
|
||||||
|
|
|
@ -123,6 +123,8 @@ class MultiImageObsEncoder(ModuleAttrMixin):
|
||||||
if imagenet_norm:
|
if imagenet_norm:
|
||||||
# TODO(rcadene): move normalizer to dataset and env
|
# TODO(rcadene): move normalizer to dataset and env
|
||||||
this_normalizer = torchvision.transforms.Normalize(
|
this_normalizer = torchvision.transforms.Normalize(
|
||||||
|
# Note: This matches the normalization used in the original implementation for PushT Image. This may not be
|
||||||
|
# the case for other tasks.
|
||||||
mean=[127.5, 127.5, 127.5],
|
mean=[127.5, 127.5, 127.5],
|
||||||
std=[127.5, 127.5, 127.5],
|
std=[127.5, 127.5, 127.5],
|
||||||
)
|
)
|
||||||
|
|
|
@ -42,8 +42,8 @@ policy:
|
||||||
num_inference_steps: 100
|
num_inference_steps: 100
|
||||||
obs_as_global_cond: ${obs_as_global_cond}
|
obs_as_global_cond: ${obs_as_global_cond}
|
||||||
# crop_shape: null
|
# crop_shape: null
|
||||||
diffusion_step_embed_dim: 256 # before 128
|
diffusion_step_embed_dim: 128
|
||||||
down_dims: [256, 512, 1024] # before [512, 1024, 2048]
|
down_dims: [512, 1024, 2048]
|
||||||
kernel_size: 5
|
kernel_size: 5
|
||||||
n_groups: 8
|
n_groups: 8
|
||||||
cond_predict_scale: True
|
cond_predict_scale: True
|
||||||
|
@ -109,13 +109,13 @@ training:
|
||||||
debug: False
|
debug: False
|
||||||
resume: True
|
resume: True
|
||||||
# optimization
|
# optimization
|
||||||
# lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
# lr_warmup_steps: 500
|
lr_warmup_steps: 500
|
||||||
num_epochs: 8000
|
num_epochs: 500
|
||||||
# gradient_accumulate_every: 1
|
# gradient_accumulate_every: 1
|
||||||
# EMA destroys performance when used with BatchNorm
|
# EMA destroys performance when used with BatchNorm
|
||||||
# replace BatchNorm with GroupNorm.
|
# replace BatchNorm with GroupNorm.
|
||||||
# use_ema: True
|
use_ema: True
|
||||||
freeze_encoder: False
|
freeze_encoder: False
|
||||||
# training loop control
|
# training loop control
|
||||||
# in epochs
|
# in epochs
|
||||||
|
|
Loading…
Reference in New Issue