wip - still need to verify full training run

Alexander Soare 2024-03-11 18:45:21 +00:00
parent 304355c917
commit 87fcc536f9
3 changed files with 9 additions and 7 deletions


@@ -25,7 +25,7 @@ class PushTImageEnv(PushTEnv):
         img = super()._render_frame(mode="rgb_array")
         agent_pos = np.array(self.agent.position)
-        img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0)
+        img_obs = np.moveaxis(img.astype(np.float32), -1, 0)
         obs = {"image": img_obs, "agent_pos": agent_pos}
         # draw action
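
Why the `/ 255` is dropped: the encoder (changed below) normalizes with a mean/std of 127.5, which assumes raw pixel values in [0, 255]. With the old division in place, every pixel would land near -1 after normalization, wiping out the image signal. A quick arithmetic check, illustrative only (the pixel value 200.0 is made up, not from the commit):

px = 200.0                        # a hypothetical pixel value
old = (px / 255 - 127.5) / 127.5  # old pipeline: ~ -0.994, roughly the same for any pixel
new = (px - 127.5) / 127.5        # new pipeline: ~ 0.569, image signal preserved
print(old, new)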


@@ -123,6 +123,8 @@ class MultiImageObsEncoder(ModuleAttrMixin):
         if imagenet_norm:
             # TODO(rcadene): move normalizer to dataset and env
             this_normalizer = torchvision.transforms.Normalize(
+                # Note: This matches the normalization in the original impl. for PushT Image. This may not be
+                # the case for other tasks.
                 mean=[127.5, 127.5, 127.5],
                 std=[127.5, 127.5, 127.5],
             )
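
For reference, a mean/std of 127.5 maps uint8-range inputs from [0, 255] to [-1, 1], matching the original PushT Image convention noted in the added comment. A minimal sketch of that mapping (the 96x96 image shape is an assumption, not from the diff):

import torch
import torchvision

normalize = torchvision.transforms.Normalize(mean=[127.5] * 3, std=[127.5] * 3)
img = torch.randint(0, 256, (3, 96, 96)).float()  # raw [0, 255] pixels, CHW
out = normalize(img)
assert out.min() >= -1.0 and out.max() <= 1.0     # mapped into [-1, 1]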


@@ -42,8 +42,8 @@ policy:
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null
-  diffusion_step_embed_dim: 256 # before 128
-  down_dims: [256, 512, 1024] # before [512, 1024, 2048]
+  diffusion_step_embed_dim: 128
+  down_dims: [512, 1024, 2048]
   kernel_size: 5
   n_groups: 8
   cond_predict_scale: True
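
The hunk above reverts diffusion_step_embed_dim and down_dims to the "before" values recorded in the old inline comments, i.e. the larger network. A rough sketch of why down_dims dominates model size (my arithmetic with a hypothetical conv1d_params helper, not code from the repo): conv parameter counts scale with in_channels * out_channels, so doubling every stage width roughly quadruples them.

def conv1d_params(c_in, c_out, k=5):  # kernel_size: 5, as in the config
    return c_in * c_out * k + c_out   # weights + bias

small = sum(conv1d_params(a, b) for a, b in zip([256, 512], [512, 1024]))
large = sum(conv1d_params(a, b) for a, b in zip([512, 1024], [1024, 2048]))
print(large / small)  # ~4x more conv parameters after doubling the widths
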
@@ -109,13 +109,13 @@ training:
   debug: False
   resume: True
   # optimization
-  # lr_scheduler: cosine
-  # lr_warmup_steps: 500
-  num_epochs: 8000
+  lr_scheduler: cosine
+  lr_warmup_steps: 500
+  num_epochs: 500
   # gradient_accumulate_every: 1
   # EMA destroys performance when used with BatchNorm
   # replace BatchNorm with GroupNorm.
-  # use_ema: True
+  use_ema: True
   freeze_encoder: False
   # training loop control
   # in epochs
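
This training hunk re-enables the cosine schedule with warmup and EMA (per the existing comment, EMA is safe here only with GroupNorm rather than BatchNorm), and shortens the run from 8000 to 500 epochs. A hedged sketch of a linear-warmup-then-cosine schedule; the repo's actual scheduler may differ (e.g. it could come from diffusers' get_scheduler), and the total-step horizon below is made up:

import math
import torch

def warmup_cosine(step, warmup=500, total=50_000):  # total: assumed horizon
    if step < warmup:
        return step / max(1, warmup)  # linear warmup over lr_warmup_steps
    progress = (step - warmup) / max(1, total - warmup)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for model parameters
optimizer = torch.optim.AdamW(params, lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_cosine)
for _ in range(1000):
    optimizer.step()
    scheduler.step()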