From 36d9e885ef1f198ee42a56e01d06e57e07356e8e Mon Sep 17 00:00:00 2001
From: Cadene
Date: Tue, 16 Apr 2024 17:14:40 +0000
Subject: [PATCH] Address comments

---
 download_and_upload_dataset.py                 |  88 +++++++++---------
 lerobot/common/datasets/utils.py               |  14 +--
 lerobot/scripts/eval.py                        |   4 +-
 lerobot/scripts/train.py                       |   8 +-
 lerobot/scripts/visualize_dataset.py           |   2 +-
 .../train/data-00000-of-00001.arrow            | Bin 14792336 -> 14792344 bytes
 .../train/dataset_info.json                    |  10 +-
 .../train/state.json                           |   2 +-
 .../train/data-00000-of-00001.arrow            | Bin 10420448 -> 10420456 bytes
 .../train/dataset_info.json                    |  10 +-
 .../train/state.json                           |   2 +-
 .../train/data-00000-of-00001.arrow            | Bin 10468384 -> 10468392 bytes
 .../train/dataset_info.json                    |  10 +-
 .../train/state.json                           |   2 +-
 .../train/data-00000-of-00001.arrow            | Bin 11702168 -> 11702176 bytes
 .../train/dataset_info.json                    |  10 +-
 .../train/state.json                           |   2 +-
 .../pusht/train/data-00000-of-00001.arrow      | Bin 200704 -> 200712 bytes
 tests/data/pusht/train/dataset_info.json       |   4 +-
 tests/data/pusht/train/state.json              |   2 +-
 .../train/data-00000-of-00001.arrow            | Bin 104360 -> 104368 bytes
 .../xarm_lift_medium/train/dataset_info.json   |   4 +-
 tests/data/xarm_lift_medium/train/state.json   |   2 +-
 tests/test_datasets.py                         |  18 ++--
 24 files changed, 100 insertions(+), 94 deletions(-)

diff --git a/download_and_upload_dataset.py b/download_and_upload_dataset.py
index 0ff86697..2e5c806c 100644
--- a/download_and_upload_dataset.py
+++ b/download_and_upload_dataset.py
@@ -17,6 +17,17 @@
 from datasets import Dataset, Features, Image, Sequence, Value
 from PIL import Image as PILImage
 
 
+def download_and_upload(root, root_tests, dataset_id):
+    if "pusht" in dataset_id:
+        download_and_upload_pusht(root, root_tests, dataset_id)
+    elif "xarm" in dataset_id:
+        download_and_upload_xarm(root, root_tests, dataset_id)
+    elif "aloha" in dataset_id:
+        download_and_upload_aloha(root, root_tests, dataset_id)
+    else:
+        raise ValueError(dataset_id)
+
+
 def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
     import zipfile
 
@@ -87,7 +98,6 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
     states = torch.from_numpy(dataset_dict["state"])
     actions = torch.from_numpy(dataset_dict["action"])
 
-    data_ids_per_episode = {}
     ep_dicts = []
     id_from = 0
 
@@ -150,15 +160,11 @@
             "next.reward": torch.cat([reward[1:], reward[[-1]]]),
             "next.done": torch.cat([done[1:], done[[-1]]]),
             "next.success": torch.cat([success[1:], success[[-1]]]),
-            "episode_data_id_from": torch.tensor([id_from] * num_frames),
-            "episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
+            "episode_data_index_from": torch.tensor([id_from] * num_frames),
+            "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
         }
         ep_dicts.append(ep_dict)
 
-        assert isinstance(episode_id, int)
-        data_ids_per_episode[episode_id] = torch.arange(id_from, id_to, 1)
-        assert len(data_ids_per_episode[episode_id]) == num_frames
-
         id_from += num_frames
 
     data_dict = {}
@@ -190,8 +196,8 @@
         "next.done": Value(dtype="bool", id=None),
         "next.success": Value(dtype="bool", id=None),
         "index": Value(dtype="int64", id=None),
-        "episode_data_id_from": Value(dtype="int64", id=None),
-        "episode_data_id_to": Value(dtype="int64", id=None),
+        "episode_data_index_from": Value(dtype="int64", id=None),
+        "episode_data_index_to": Value(dtype="int64", id=None),
     }
     features = Features(features)
     dataset = Dataset.from_dict(data_dict, features=features)
@@ -265,8 +271,8 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
             # "next.observation.state": next_state,
             "next.reward": next_reward,
             "next.done": next_done,
-            "episode_data_id_from": torch.tensor([id_from] * num_frames),
-            "episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
+            "episode_data_index_from": torch.tensor([id_from] * num_frames),
+            "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
         }
         ep_dicts.append(ep_dict)
 
@@ -301,8 +307,8 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
         "next.done": Value(dtype="bool", id=None),
         #'next.success': Value(dtype='bool', id=None),
         "index": Value(dtype="int64", id=None),
-        "episode_data_id_from": Value(dtype="int64", id=None),
-        "episode_data_id_to": Value(dtype="int64", id=None),
+        "episode_data_index_from": Value(dtype="int64", id=None),
+        "episode_data_index_to": Value(dtype="int64", id=None),
     }
     features = Features(features)
     dataset = Dataset.from_dict(data_dict, features=features)
@@ -390,20 +396,7 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
         state = torch.from_numpy(ep["/observations/qpos"][:])
         action = torch.from_numpy(ep["/action"][:])
 
-        ep_dict = {
-            "observation.state": state,
-            "action": action,
-            "episode_id": torch.tensor([ep_id] * num_frames),
-            "frame_id": torch.arange(0, num_frames, 1),
-            "timestamp": torch.arange(0, num_frames, 1) / fps,
-            # "next.observation.state": state,
-            # TODO(rcadene): compute reward and success
-            # "next.reward": reward,
-            "next.done": done,
-            # "next.success": success,
-            "episode_data_id_from": torch.tensor([id_from] * num_frames),
-            "episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
-        }
+        ep_dict = {}
 
         for cam in cameras[dataset_id]:
             image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])  # b h w c
@@ -411,6 +404,23 @@
             ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
             # ep_dict[f"next.observation.images.{cam}"] = image
 
+        ep_dict.update(
+            {
+                "observation.state": state,
+                "action": action,
+                "episode_id": torch.tensor([ep_id] * num_frames),
+                "frame_id": torch.arange(0, num_frames, 1),
+                "timestamp": torch.arange(0, num_frames, 1) / fps,
+                # "next.observation.state": state,
+                # TODO(rcadene): compute reward and success
+                # "next.reward": reward,
+                "next.done": done,
+                # "next.success": success,
+                "episode_data_index_from": torch.tensor([id_from] * num_frames),
+                "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
+            }
+        )
+
         assert isinstance(ep_id, int)
         ep_dicts.append(ep_dict)
 
@@ -446,8 +456,8 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
         "next.done": Value(dtype="bool", id=None),
         #'next.success': Value(dtype='bool', id=None),
         "index": Value(dtype="int64", id=None),
-        "episode_data_id_from": Value(dtype="int64", id=None),
-        "episode_data_id_to": Value(dtype="int64", id=None),
+        "episode_data_index_from": Value(dtype="int64", id=None),
+        "episode_data_index_to": Value(dtype="int64", id=None),
     }
     features = Features(features)
     dataset = Dataset.from_dict(data_dict, features=features)
@@ -461,23 +471,17 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
 
 if __name__ == "__main__":
     root = "data"
-    root_tests = "{root_tests}"
-
-    download_and_upload_pusht(root, root_tests, dataset_id="pusht")
-    download_and_upload_xarm(root, root_tests, dataset_id="xarm_lift_medium")
-    download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_human")
-    download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_scripted")
-    download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_human")
-    download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_scripted")
+    root_tests = "tests/data"
 
     dataset_ids = [
-        "pusht",
-        "xarm_lift_medium",
-        "aloha_sim_insertion_human",
-        "aloha_sim_insertion_scripted",
-        "aloha_sim_transfer_cube_human",
+        # "pusht",
+        # "xarm_lift_medium",
+        # "aloha_sim_insertion_human",
+        # "aloha_sim_insertion_scripted",
+        # "aloha_sim_transfer_cube_human",
         "aloha_sim_transfer_cube_scripted",
     ]
     for dataset_id in dataset_ids:
+        download_and_upload(root, root_tests, dataset_id)
         # assume stats have been precomputed
         shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth")
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 154dcb68..1b353e69 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -7,9 +7,9 @@ import tqdm
 
 
 def load_previous_and_future_frames(
-    item: dict[torch.Tensor],
-    data_dict: dict[torch.Tensor],
-    delta_timestamps: dict[list[float]],
+    item: dict[str, torch.Tensor],
+    data_dict: dict[str, torch.Tensor],
+    delta_timestamps: dict[str, list[float]],
     tol: float = 0.04,
 ) -> dict[torch.Tensor]:
     """
@@ -35,12 +35,12 @@
     - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
     """
     # get indices of the frames associated to the episode, and their timestamps
-    ep_data_id_from = item["episode_data_id_from"].item()
-    ep_data_id_to = item["episode_data_id_to"].item()
-    ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to + 1, 1)
+    ep_data_id_from = item["episode_data_index_from"].item()
+    ep_data_id_to = item["episode_data_index_to"].item()
+    ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
 
     # load timestamps
-    ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from : ep_data_id_to + 1]["timestamp"]
+    ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
 
     # we make the assumption that the timestamps are sorted
     ep_first_ts = ep_timestamps[0]
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index d8c697c2..6c5b757d 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -215,8 +215,8 @@ def eval_policy(
             "timestamp": torch.arange(0, num_frames, 1) / fps,
             "next.done": dones[ep_id, :num_frames],
             "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
-            "episode_data_id_from": torch.tensor([idx_from] * num_frames),
-            "episode_data_id_to": torch.tensor([idx_from + num_frames - 1] * num_frames),
+            "episode_data_index_from": torch.tensor([idx_from] * num_frames),
+            "episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
         }
         for key in observations:
             ep_dict[key] = observations[key][ep_id][:num_frames]
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 71bc1e72..19218355 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -141,15 +141,15 @@ def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_
         online_dataset.data_dict = data_dict
     else:
        # find episode index and data frame indices according to previous episode in online_dataset
-        start_episode = online_dataset.data_dict["episode_id"][-1].item() + 1
-        start_index = online_dataset.data_dict["index"][-1].item() + 1
+        start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
+        start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
 
         def shift_indices(example):
             # note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
             example["episode_id"] += start_episode
             example["index"] += start_index
-            example["episode_data_id_from"] += start_index
-            example["episode_data_id_to"] += start_index
+            example["episode_data_index_from"] += start_index
+            example["episode_data_index_to"] += start_index
             return example
 
         disable_progress_bar()  # map has a tqdm progress bar
diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index 10ed98d5..e7bd0693 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -77,7 +77,7 @@ def render_dataset(dataset, out_dir, max_num_episodes):
             # add current frame to list of frames to render
             frames[im_key].append(item[im_key])
 
-        end_of_episode = item["index"].item() == item["episode_data_id_to"].item()
+        end_of_episode = item["index"].item() == item["episode_data_index_to"].item()
 
     out_dir.mkdir(parents=True, exist_ok=True)
     for im_key in dataset.image_keys:
diff --git a/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow
index 4d357e343b933615f748c31e5f38c0dfaf7731b5..165298cfb9b96413d383ba05c7b9d5a88bdf2542 100644
GIT binary patch
[base85-encoded binary delta data omitted]
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json b/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json
index 473812f3..542c7bf1 100644
--- a/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json
+++ b/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json
@@ -2,6 +2,9 @@
   "citation": "",
   "description": "",
   "features": {
+    "observation.images.top": {
+      "_type": "Image"
+    },
     "observation.state": {
       "feature": {
         "dtype": "float32",
@@ -34,17 +37,14 @@
       "dtype": "bool",
       "_type": "Value"
     },
-    "episode_data_id_from": {
+    "episode_data_index_from": {
       "dtype": "int64",
       "_type": "Value"
     },
-    "episode_data_id_to": {
+    "episode_data_index_to": {
       "dtype": "int64",
       "_type": "Value"
     },
-    "observation.images.top": {
-      "_type": "Image"
-    },
     "index": {
       "dtype": "int64",
       "_type": "Value"
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/state.json b/tests/data/aloha_sim_transfer_cube_scripted/train/state.json
index ee3cc1fe..56005bc9 100644
--- a/tests/data/aloha_sim_transfer_cube_scripted/train/state.json
+++ b/tests/data/aloha_sim_transfer_cube_scripted/train/state.json
@@ -4,7 +4,7 @@
       "filename": "data-00000-of-00001.arrow"
     }
   ],
-  "_fingerprint": "43f176a3740fe622",
+  "_fingerprint": "93e03c6320c7d56e",
   "_format_columns": null,
   "_format_kwargs": {},
   "_format_type": "torch",
diff --git a/tests/data/pusht/train/data-00000-of-00001.arrow b/tests/data/pusht/train/data-00000-of-00001.arrow
index 71f657ab836a7d936f9e8bd1d31f959f17c4ef52..9a36a8db0fef662afd92ec87297e5d552fead22e 100644
GIT binary patch
[base85-encoded binary delta data omitted]
diff --git a/tests/data/xarm_lift_medium/train/dataset_info.json b/tests/data/xarm_lift_medium/train/dataset_info.json
index 81ba7c8c..bb647c41 100644
--- a/tests/data/xarm_lift_medium/train/dataset_info.json
+++ b/tests/data/xarm_lift_medium/train/dataset_info.json
@@ -41,11 +41,11 @@
       "dtype": "bool",
       "_type": "Value"
     },
-    "episode_data_id_from": {
+    "episode_data_index_from": {
       "dtype": "int64",
       "_type": "Value"
     },
-    "episode_data_id_to": {
+    "episode_data_index_to": {
       "dtype": "int64",
       "_type": "Value"
     },
diff --git a/tests/data/xarm_lift_medium/train/state.json b/tests/data/xarm_lift_medium/train/state.json
index 500bdb85..c930c52c 100644
--- a/tests/data/xarm_lift_medium/train/state.json
+++ b/tests/data/xarm_lift_medium/train/state.json
@@ -4,7 +4,7 @@
       "filename": "data-00000-of-00001.arrow"
     }
   ],
-  "_fingerprint": "7dcd82fc3815bba6",
+  "_fingerprint": "a95cbec45e3bb9d6",
   "_format_columns": null,
   "_format_kwargs": {},
   "_format_type": "torch",
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index c40d478a..85ddb00f 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -95,12 +95,14 @@
     """
     from lerobot.common.datasets.xarm import XarmDataset
 
+    DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
     # get transform to convert images from uint8 [0,255] to float32 [0,1]
     transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
 
     dataset = XarmDataset(
         dataset_id="xarm_lift_medium",
+        root=DATA_DIR,
         transform=transform,
     )
@@ -115,11 +117,11 @@
     # get all frames from the dataset in the same dtype and range as during compute_stats
     dataloader = torch.utils.data.DataLoader(
         dataset,
-        num_workers=16,
+        num_workers=8,
         batch_size=len(dataset),
         shuffle=False,
     )
-    data_dict = next(iter(dataloader))  # takes 23 seconds
+    data_dict = next(iter(dataloader))
 
     # compute stats based on all frames from the dataset without any batching
     expected_stats = {}
@@ -154,8 +156,8 @@ def test_load_previous_and_future_frames_within_tolerance():
     data_dict = Dataset.from_dict({
         "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
         "index": [0, 1, 2, 3, 4],
-        "episode_data_id_from": [0, 0, 0, 0, 0],
-        "episode_data_id_to": [4, 4, 4, 4, 4],
+        "episode_data_index_from": [0, 0, 0, 0, 0],
+        "episode_data_index_to": [4, 4, 4, 4, 4],
     })
     data_dict = data_dict.with_format("torch")
     item = data_dict[2]
@@ -170,8 +172,8 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
     data_dict = Dataset.from_dict({
         "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
         "index": [0, 1, 2, 3, 4],
-        "episode_data_id_from": [0, 0, 0, 0, 0],
-        "episode_data_id_to": [4, 4, 4, 4, 4],
+        "episode_data_index_from": [0, 0, 0, 0, 0],
+        "episode_data_index_to": [4, 4, 4, 4, 4],
     })
     data_dict = data_dict.with_format("torch")
     item = data_dict[2]
@@ -184,8 +186,8 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
     data_dict = Dataset.from_dict({
         "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
         "index": [0, 1, 2, 3, 4],
-        "episode_data_id_from": [0, 0, 0, 0, 0],
-        "episode_data_id_to": [4, 4, 4, 4, 4],
+        "episode_data_index_from": [0, 0, 0, 0, 0],
+        "episode_data_index_to": [4, 4, 4, 4, 4],
    })
     data_dict = data_dict.with_format("torch")
     item = data_dict[2]
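
Note on the renamed fields: the writer scripts above now store id_from + num_frames in "episode_data_index_to" (previously num_frames - 1), and load_previous_and_future_frames reads it as an exclusive bound (torch.arange(from, to) and [from:to] slicing, without the former "+ 1"). Below is a minimal sketch of selecting one episode's rows with these columns; the column names follow the patch, but the snippet and its toy values are illustrative only and are not part of the patch.

import torch
from datasets import Dataset

# toy dataset with the columns the patch writes; values are made up for illustration
data_dict = Dataset.from_dict(
    {
        "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
        "index": [0, 1, 2, 3, 4],
        "episode_data_index_from": [0, 0, 0, 0, 0],
        "episode_data_index_to": [5, 5, 5, 5, 5],  # exclusive bound: id_from + num_frames
    }
).with_format("torch")

item = data_dict[2]
ep_from = item["episode_data_index_from"].item()
ep_to = item["episode_data_index_to"].item()

# half-open range, so no "+ 1" is needed anymore
ep_data_ids = torch.arange(ep_from, ep_to, 1)
ep_timestamps = data_dict.select_columns("timestamp")[ep_from:ep_to]["timestamp"]
assert len(ep_data_ids) == len(ep_timestamps) == ep_to - ep_from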