fix Unet global_cond_dim to use state dim, not action dim (#278)

This commit is contained in:
Jihoon Oh 2024-06-17 23:17:28 +09:00 committed by GitHub
parent 15dd682714
commit b72d574891
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 1 deletions

View File

@ -165,7 +165,9 @@ class DiffusionModel(nn.Module):
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")]) num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
self.unet = DiffusionConditionalUnet1d( self.unet = DiffusionConditionalUnet1d(
config, config,
global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim * num_images) global_cond_dim=(
config.input_shapes["observation.state"][0] + self.rgb_encoder.feature_dim * num_images
)
* config.n_obs_steps, * config.n_obs_steps,
) )