fix Unet global_cond_dim to use state dim, not action dim (#278)
This commit is contained in:
parent
15dd682714
commit
b72d574891
|
@ -165,7 +165,9 @@ class DiffusionModel(nn.Module):
|
|||
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
||||
self.unet = DiffusionConditionalUnet1d(
|
||||
config,
|
||||
global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim * num_images)
|
||||
global_cond_dim=(
|
||||
config.input_shapes["observation.state"][0] + self.rgb_encoder.feature_dim * num_images
|
||||
)
|
||||
* config.n_obs_steps,
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue