From a8e245fb31c0cba9d83cb687625251e31071dccf Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 6 May 2024 07:27:01 +0100 Subject: [PATCH] Remove loss masking from diffusion policy (#135) --- .../diffusion/configuration_diffusion.py | 6 ++++++ .../policies/diffusion/modeling_diffusion.py | 2 +- lerobot/configs/policy/diffusion.yaml | 3 +++ .../pusht_diffusion/grad_stats.safetensors | Bin 47424 -> 47424 bytes .../pusht_diffusion/output_dict.safetensors | Bin 68 -> 68 bytes 5 files changed, 10 insertions(+), 1 deletion(-) diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 73fabefa..d7341c33 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -64,6 +64,9 @@ class DiffusionConfig: clip_sample_range: The magnitude of the clipping range as described above. num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly spaced). If not provided, this defaults to be the same as `num_train_timesteps`. + do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See + `LeRobotDataset` and `load_previous_and_future_frames` for mor information. Note, this defaults + to False as the original Diffusion Policy implementation does the same. """ # Inputs / output structure. @@ -118,6 +121,9 @@ class DiffusionConfig: # Inference num_inference_steps: int | None = None + # Loss computation + do_mask_loss_for_padding: bool = False + def __post_init__(self): """Input validation (not exhaustive).""" if not self.vision_backbone.startswith("resnet"): diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index f5f64d80..91cf6dd0 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -268,7 +268,7 @@ class DiffusionModel(nn.Module): loss = F.mse_loss(pred, target, reduction="none") # Mask loss wherever the action is padded with copies (edges of the dataset trajectory). - if "action_is_pad" in batch: + if self.config.do_mask_loss_for_padding and "action_is_pad" in batch: in_episode_bound = ~batch["action_is_pad"] loss = loss * in_episode_bound.unsqueeze(-1) diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 60061c38..2d611c88 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -95,3 +95,6 @@ policy: # Inference num_inference_steps: 100 + + # Loss computation + do_mask_loss_for_padding: false diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors index 1b1142b2df89ad1d618c895c3bf0313cb8173125..f27cd6780a7c2b399ec47e7a27d0022ba30c5045 100644 GIT binary patch delta 1719 zcmX|=FZ%Ts}s(r%@ctUTNxBoQ++_csHpQSn$ri!Pw3z@mVGB2_pdT1l!1 zU6hxAgg|)&L6k=zl4Js+0TB@jg-RbrWZ8<}bPE3k%456~IhJuLO}3ZrYK@q>Y-c2To)`eb)(AExJpp-iABym}58q@p zq6POSU>Ypd7V#Tkg~e`7*ScDWzIU4#3)2HONwpk16kIq4Qr@nBpa-`! z6L$OIqwb&ZCySzBWNIHnM~K;B2Hpn-$qA^mM`SK15l-)QC%eSi5V!t}W@r3ia0tJS zPp+rHbjOmYERp`r#V;nAVE1^K^gKcazrTeoN>Fw}1lW9+Q9CcRTuqJEAw0-u|PizIPi$CDAS&q<-S&>IS5i>T1lB2<{~L`4Pv1mWN9 zp^u$^PHY}ZWcxqxJ=Hx8(oTu9-}I>jxUn^ckvw#UxgYYF<(C{;K+5NEEd1|oYi=Kl znSz2OXKnpt10BBfuf^EH)t{pBKCYmv-w#$ z#1sswGmU`|IA_D~?#!UD%LxqMbA}xd&BIIQ=?Tdcca^bl9Nzi8pV(w-fT44qWUuyH z_@c?1oZom3nI;N}3nDpM{oi|Jp{qGH7Hy5&uC-wcMplG0Hle_8y8Lc_XNOvjtMN^e z4)S~MA)&EEN3t*E`?*&YrMhTT#h*+=sdXFiNQVIQWhkck#ZiQ2FN+C#$VHOgI(v1} zNozEDB2gQnz6;$}mVR-wKeGiG7SI)h=KtgBH9$4W0nZA}EsIlZUeVIz zSwwO4;?6{nX^ae`u=h}VwgjBes&znt|n4BL<4Ny zb&gU557X1H#wB#@rY-bttjhoN@Mi7T9~Z-pTTFZTfl5~3*0}gcKRN$m!=-^;3Nm5Q zM0?HN&E&DX?uD#D_M9N2A!|7}_D@Rpp>z|^SojpW7-HQTac*N4;0n}SkLWE01Z-F`~f*5aI3D<3( zGkMXO)K`6rG^;QEpDmDjE1ZhZk4O6j?!0|yRm10M(GUMbGV3s5>4@K-0-s-LU;1!h znBT#zHgI-lCy_NijLwcZllage$~p9i6flp$9F&CpS9$Q8y{xl4L*RcpLc!?51r&y5 zDJ5p1$yAi+gy!am8#R=oVp5VP!P59xdu`be`-?vnG^A?%Q=W;{${7V!-g-gzGCqTf zm{?$P#xkhLH{a^u;4M^j(h}Xhb}99*g{IDybSwWg8GTH?(2f;Y@j|i%)bY2S2vemA z<~9?C>( Pt?Bc-gS6{CoasLR6e#%s delta 1719 zcmX|=eLPh89>*uV(3I7NY>H(m&tp8yIp_C1XRMW)t0cXyo^0#d$Y`X+-d&f}kfCg( zh7GfxAG(sKFmuj$ShgX%O08Y1x+IS%tFeySJBX^C$PGz z26@V6Ah}3EPug6BBO|+26ItcpgZJSY{U2CfcsY8k7V4fGqg&hZq5COcy`0B@r%3@m z-rf$P2S>4`T~c_q z#y=8LyeT}xEcY)?&K+YC>2KlIf}46>YXxvP8vgj_=RwjNz@+x(fob0eF`rQmer*G) zu#hhxFliL4aSMUX5(#s$UBdEXF2%>CGAJ7+iK@9s_+HzNh#WWx%dTdt`p(5eF1H)A znkV7H4t;#(n&iK@oNQp`j3VG0!vvyeXa`#qQ7^k41Wz*3@!=N-z#}yWvki;|ubVy1 zw{AMavrc^_-<}71+w7SNM@Qg9pXIxY5lEW+h|AKz)u7iOOFRm2gN$%1+*nA$32z(K zfv$NZGV2tCMr`~)$a3HP$B;Pxo+Uh`CzR`tn6u)|^}|mrfaaUy-jW_PZJSI_Z!(7N zKdh9|I&=JX|% z5$ud9#!Aopf;iM0w`+w@5Qmdg7qI-_RC}Z@)}&toCC+l(VVM{{2u{JSI2?lb@3jcK zfChLVI4D&?(iz|;y==kvZ=m9LE@wXC@!kn)!^EF}#bl}SvrCrHBl?(Dbao=Sbc5$wLkjW5O?*JDT4al~0a&)E3nDjY225J8d zrC{U@9x)op_FFpdJ@FiL^`)!bcg9GdqAQqLrC9^z)|rfB-uoShOqxr>!2e=Mg$(udzP-k1L zfGq3l%4f4%phmfp3BEan>aIjF6NyuQF4U0tmik4if?u5@AgDFvL)qkus71ZEyklmGLcm5XQ0$n zE=Py$MRbtweGu+XaZlgz2V2nTXejbSZQE_#OYc1pyo%q&_t)#xw<1RD$dTXmq1Uuc#^* zxLiT5sOhI74?mYuap?h6)IgDNAblULot6W#^!Da&;)+-S$3gtOhZt`OZMtVwMl9;h zH`iXyamNUUZO($kxrc<4w?82;$%mS} zJ%S?Ra+YtXhJp)3wnWE+t5-@`ns6}Pb<)2RnNwAQxqZ7>AcR-d;7fXacC)el`g(M@ zY7+DO$0!lg^cI`;A0;xw50%#2zGh!|w&re85}KTCRK9LIhSonU7K|_3 zfKCkkhUs24VslPCHDf_3=LQ>f)w77dwFmEV=|Vw{IDTKzi<$=%xUY1RDR@~Pz^mKb z;EsikIx|m3*2fyFRJYRKP3W|*M~;5yWV((Ii_zfe3rf$vwWuxbOTh=LJ=pv;@7_Tk zrDYoKKKB(D-ng{k$xAm->d$Kk^{Z=0wP%V5t(k_{nFd_(@C+23Jm0#k$yK;QtCXqy z+Lh#Rs_UgMGQJ=qa%C#hK}#~gtBL9DegF#pX}V=^-@9MnY7<5ax48Q$MT=#mOUp&| zm%QU7H*AUdq5C+gNHJ}hT31IN{K-hYqc({=YF^XorqCC5v~)8QUMpF_&pi1fSCXIhKBN!^*>hBQS`Mi*)~^nuqDXmqyLx0p2x+r+ SFSD?MOWDM4p>I4kq5cO59`5P@ diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors index 77472bb50b2896c48a0a739eee9848e821769021..5f33535d5ad316f4cef55925a215428643ed2f52 100644 GIT binary patch delta 9 QcmZ>9nc%>3p0ULq01lS|@Bjb+ delta 9 QcmZ>9nc%?UzP#BU01jsZ4FCWD