pre-commit run -a
parent 1ae6205269
commit 45b4ecb727
@@ -137,13 +137,16 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         mean_std = self._compute_or_load_mean_std(storage)
         mean_std["next", "observation", "image"] = mean_std["observation", "image"]
         mean_std["next", "observation", "state"] = mean_std["observation", "state"]
-        transform = NormalizeTransform(mean_std, in_keys=[
+        transform = NormalizeTransform(
+            mean_std,
+            in_keys=[
                 ("observation", "image"),
                 ("observation", "state"),
                 ("next", "observation", "image"),
                 ("next", "observation", "state"),
                 ("action"),
-        ])
+            ],
+        )
 
         if writer is None:
             writer = ImmutableDatasetWriter()
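Aside (not part of the commit): the transform assembled in the hunk above is a torchrl-style Transform that PushtExperienceReplay presumably hands to its TensorDictReplayBuffer base, so sampled batches come back normalized on the fly. A minimal sketch of that behaviour follows; the import path and the toy statistics are assumptions, only NormalizeTransform, the nested key layout, and the mean/std naming come from this diff.

import torch
from tensordict import TensorDict

from lerobot.common.datasets.transforms import NormalizeTransform  # assumed module path

# Toy statistics using the nested ("<key>", "mean"/"std") layout built later in this file.
mean_std = TensorDict(
    {
        ("observation", "state", "mean"): torch.zeros(1, 2),
        ("observation", "state", "std"): torch.ones(1, 2),
    },
    batch_size=[],
)
transform = NormalizeTransform(mean_std, in_keys=[("observation", "state")])

batch = TensorDict({("observation", "state"): torch.randn(8, 2)}, batch_size=[8])
batch = transform(batch)  # forward() dispatches to _call(); presumably (x - mean) / std per in_key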
@@ -185,7 +188,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         download_and_extract_zip(PUSHT_URL, raw_dir)
 
         # load
-        dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)  # , keys=['img', 'state', 'action'])
+        dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
+            zarr_path
+        )  # , keys=['img', 'state', 'action'])
 
         episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
         num_episodes = dataset_dict.meta["episode_ends"].shape[0]
@@ -291,8 +296,8 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         action_mean = torch.zeros(batch["action"].shape[1])
         action_std = torch.zeros(batch["action"].shape[1])
 
-        for i in tqdm.tqdm(range(num_batch)):
-            image_mean += einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
+        for _ in tqdm.tqdm(range(num_batch)):
+            image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
             state_mean += batch["observation", "state"].mean(dim=0)
             action_mean += batch["action"].mean(dim=0)
             batch = rb.sample()
@@ -302,7 +307,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         action_mean /= num_batch
 
         for i in tqdm.tqdm(range(num_batch)):
-            image_mean_batch = einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
+            image_mean_batch = einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
             image_std += (image_mean_batch - image_mean) ** 2
             state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
             action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
@@ -315,12 +320,12 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
 
         mean_std = TensorDict(
             {
-                ("observation", "image", "mean"): image_mean[None,:,None,None],
-                ("observation", "image", "std"): image_std[None,:,None,None],
-                ("observation", "state", "mean"): state_mean[None,:],
-                ("observation", "state", "std"): state_std[None,:],
-                ("action", "mean"): action_mean[None,:],
-                ("action", "std"): action_std[None,:],
+                ("observation", "image", "mean"): image_mean[None, :, None, None],
+                ("observation", "image", "std"): image_std[None, :, None, None],
+                ("observation", "state", "mean"): state_mean[None, :],
+                ("observation", "state", "std"): state_std[None, :],
+                ("action", "mean"): action_mean[None, :],
+                ("action", "std"): action_std[None, :],
             },
             batch_size=[],
         )
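Aside (not part of the commit): the last three hunks above all touch the dataset-statistics pass. Collapsed into one hedged sketch for readability: the variable names, the einops reductions, and the two-pass structure follow the diff; the image/state accumulator initialisation, the final square root, and the `rb` replay-buffer handle are assumptions filled in for completeness.

import einops
import torch
import tqdm
from tensordict import TensorDict


def estimate_mean_std(rb, num_batch: int) -> TensorDict:
    # rb.sample() is assumed to return a TensorDict with ("observation", "image") in
    # "b c h w" layout, ("observation", "state"), and "action" entries.
    batch = rb.sample()
    image_mean = torch.zeros(batch["observation", "image"].shape[1])
    image_std = torch.zeros(batch["observation", "image"].shape[1])
    state_mean = torch.zeros(batch["observation", "state"].shape[1])
    state_std = torch.zeros(batch["observation", "state"].shape[1])
    action_mean = torch.zeros(batch["action"].shape[1])
    action_std = torch.zeros(batch["action"].shape[1])

    # First pass: accumulate per-batch means, then average them.
    for _ in tqdm.tqdm(range(num_batch)):
        image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
        state_mean += batch["observation", "state"].mean(dim=0)
        action_mean += batch["action"].mean(dim=0)
        batch = rb.sample()
    image_mean /= num_batch
    state_mean /= num_batch
    action_mean /= num_batch

    # Second pass: accumulate squared deviations of the per-batch means (as in the diff),
    # then average and take the square root.
    for _ in tqdm.tqdm(range(num_batch)):
        image_mean_batch = einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
        image_std += (image_mean_batch - image_mean) ** 2
        state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
        action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
        batch = rb.sample()
    image_std = (image_std / num_batch).sqrt()
    state_std = (state_std / num_batch).sqrt()
    action_std = (action_std / num_batch).sqrt()

    # Pack everything in the nested layout that NormalizeTransform expects.
    return TensorDict(
        {
            ("observation", "image", "mean"): image_mean[None, :, None, None],
            ("observation", "image", "std"): image_std[None, :, None, None],
            ("observation", "state", "mean"): state_mean[None, :],
            ("observation", "state", "std"): state_std[None, :],
            ("action", "mean"): action_mean[None, :],
            ("action", "std"): action_std[None, :],
        },
        batch_size=[],
    )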
@@ -3,11 +3,9 @@ import zipfile
 from pathlib import Path
 
 import requests
-from tensordict import TensorDictBase
 import tqdm
-
 
 
 def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
     print(f"downloading from {url}")
     response = requests.get(url, stream=True)
@@ -30,4 +28,3 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
         return True
     else:
         return False
-
@@ -1,10 +1,9 @@
 from typing import Sequence
-from tensordict import TensorDictBase
 
-from tensordict.utils import NestedKey
-from torchrl.envs.transforms import ObservationTransform
-from torchrl.envs.transforms import Transform
-from tensordict.nn import dispatch
+from tensordict import TensorDictBase
+from tensordict.nn import dispatch
+from tensordict.utils import NestedKey
+from torchrl.envs.transforms import ObservationTransform, Transform
 
 
 class Prod(ObservationTransform):
@@ -27,7 +26,8 @@ class Prod(ObservationTransform):
 class NormalizeTransform(Transform):
     invertible = True
 
-    def __init__(self,
+    def __init__(
+        self,
         mean_std: TensorDictBase,
         in_keys: Sequence[NestedKey] = None,
         out_keys: Sequence[NestedKey] | None = None,
@@ -40,7 +40,9 @@ class NormalizeTransform(Transform):
             in_keys_inv = out_keys
         if out_keys_inv is None:
             out_keys_inv = in_keys
-        super().__init__(in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv)
+        super().__init__(
+            in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
+        )
         self.mean_std = mean_std
 
     @dispatch(source="in_keys", dest="out_keys")
@@ -48,7 +50,7 @@ class NormalizeTransform(Transform):
         return self._call(tensordict)
 
     def _call(self, td: TensorDictBase) -> TensorDictBase:
-        for inkey, outkey in zip(self.in_keys, self.out_keys):
+        for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
             # TODO(rcadene): don't know how to do `inkey not in td`
             if td.get(inkey, None) is None:
                 continue
@@ -58,7 +60,7 @@ class NormalizeTransform(Transform):
         return td
 
     def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
-        for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv):
+        for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
             # TODO(rcadene): don't know how to do `inkey not in td`
             if td.get(inkey, None) is None:
                 continue
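Aside (not part of the commit): the bodies of the _call and _inv_call loops above are outside this diff, but given the mean/std layout built in the dataset code they presumably normalise and denormalise each configured key. A rough per-key sketch, with the key handling and lookup purely illustrative:

def _normalize_key(td, mean_std, inkey, outkey):
    # Hypothetical helper: fetch the nested ("<key>", "mean"/"std") entries and apply them.
    key = (inkey,) if isinstance(inkey, str) else tuple(inkey)
    mean = mean_std[key + ("mean",)]
    std = mean_std[key + ("std",)]
    td.set(outkey, (td.get(inkey) - mean) / std)  # _inv_call would instead apply x * std + mean
    return td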
@@ -5,6 +5,7 @@ import hydra
 import torch
 import torch.nn as nn
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+
 from diffusion_policy.model.common.lr_scheduler import get_scheduler
 from diffusion_policy.model.vision.model_getter import get_resnet
 from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
@@ -5,7 +5,6 @@ from copy import deepcopy
 
 import einops
 import numpy as np
-from tensordict import TensorDict
 import torch
 import torch.nn as nn
 
@@ -147,10 +146,7 @@ class TDMPC(nn.Module):
     @torch.no_grad()
     def act(self, obs, t0=False, step=None):
         """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
-        if isinstance(obs, dict):
-            obs = {k: o.detach() for k, o in obs.items()}
-        else:
-            obs = obs.detach()
+        obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
         z = self.model.encode(obs)
         if self.cfg.mpc:
             a = self.plan(z, t0=t0, step=step)
@@ -11,8 +11,8 @@ import tqdm
 from tensordict.nn import TensorDictModule
 from termcolor import colored
 from torchrl.envs import EnvBase
-from lerobot.common.datasets.factory import make_offline_buffer
 
+from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils import set_seed