pre-commit run -a
This commit is contained in:
parent
1ae6205269
commit
45b4ecb727
|
@ -137,13 +137,16 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
mean_std = self._compute_or_load_mean_std(storage)
|
mean_std = self._compute_or_load_mean_std(storage)
|
||||||
mean_std["next", "observation", "image"] = mean_std["observation", "image"]
|
mean_std["next", "observation", "image"] = mean_std["observation", "image"]
|
||||||
mean_std["next", "observation", "state"] = mean_std["observation", "state"]
|
mean_std["next", "observation", "state"] = mean_std["observation", "state"]
|
||||||
transform = NormalizeTransform(mean_std, in_keys=[
|
transform = NormalizeTransform(
|
||||||
("observation", "image"),
|
mean_std,
|
||||||
("observation", "state"),
|
in_keys=[
|
||||||
("next", "observation", "image"),
|
("observation", "image"),
|
||||||
("next", "observation", "state"),
|
("observation", "state"),
|
||||||
("action"),
|
("next", "observation", "image"),
|
||||||
])
|
("next", "observation", "state"),
|
||||||
|
("action"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
if writer is None:
|
if writer is None:
|
||||||
writer = ImmutableDatasetWriter()
|
writer = ImmutableDatasetWriter()
|
||||||
|
@ -185,7 +188,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
download_and_extract_zip(PUSHT_URL, raw_dir)
|
download_and_extract_zip(PUSHT_URL, raw_dir)
|
||||||
|
|
||||||
# load
|
# load
|
||||||
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
|
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
|
||||||
|
zarr_path
|
||||||
|
) # , keys=['img', 'state', 'action'])
|
||||||
|
|
||||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||||
|
@ -291,8 +296,8 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
action_mean = torch.zeros(batch["action"].shape[1])
|
action_mean = torch.zeros(batch["action"].shape[1])
|
||||||
action_std = torch.zeros(batch["action"].shape[1])
|
action_std = torch.zeros(batch["action"].shape[1])
|
||||||
|
|
||||||
for i in tqdm.tqdm(range(num_batch)):
|
for _ in tqdm.tqdm(range(num_batch)):
|
||||||
image_mean += einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
|
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
|
||||||
state_mean += batch["observation", "state"].mean(dim=0)
|
state_mean += batch["observation", "state"].mean(dim=0)
|
||||||
action_mean += batch["action"].mean(dim=0)
|
action_mean += batch["action"].mean(dim=0)
|
||||||
batch = rb.sample()
|
batch = rb.sample()
|
||||||
|
@ -302,25 +307,25 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
action_mean /= num_batch
|
action_mean /= num_batch
|
||||||
|
|
||||||
for i in tqdm.tqdm(range(num_batch)):
|
for i in tqdm.tqdm(range(num_batch)):
|
||||||
image_mean_batch = einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
|
image_mean_batch = einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
|
||||||
image_std += (image_mean_batch - image_mean) ** 2
|
image_std += (image_mean_batch - image_mean) ** 2
|
||||||
state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
|
state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
|
||||||
action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
|
action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
|
||||||
if i < num_batch - 1:
|
if i < num_batch - 1:
|
||||||
batch = rb.sample()
|
batch = rb.sample()
|
||||||
|
|
||||||
image_std = torch.sqrt(image_std / num_batch)
|
image_std = torch.sqrt(image_std / num_batch)
|
||||||
state_std = torch.sqrt(state_std / num_batch)
|
state_std = torch.sqrt(state_std / num_batch)
|
||||||
action_std = torch.sqrt(action_std / num_batch)
|
action_std = torch.sqrt(action_std / num_batch)
|
||||||
|
|
||||||
mean_std = TensorDict(
|
mean_std = TensorDict(
|
||||||
{
|
{
|
||||||
("observation", "image", "mean"): image_mean[None,:,None,None],
|
("observation", "image", "mean"): image_mean[None, :, None, None],
|
||||||
("observation", "image", "std"): image_std[None,:,None,None],
|
("observation", "image", "std"): image_std[None, :, None, None],
|
||||||
("observation", "state", "mean"): state_mean[None,:],
|
("observation", "state", "mean"): state_mean[None, :],
|
||||||
("observation", "state", "std"): state_std[None,:],
|
("observation", "state", "std"): state_std[None, :],
|
||||||
("action", "mean"): action_mean[None,:],
|
("action", "mean"): action_mean[None, :],
|
||||||
("action", "std"): action_std[None,:],
|
("action", "std"): action_std[None, :],
|
||||||
},
|
},
|
||||||
batch_size=[],
|
batch_size=[],
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,11 +3,9 @@ import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from tensordict import TensorDictBase
|
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||||
print(f"downloading from {url}")
|
print(f"downloading from {url}")
|
||||||
response = requests.get(url, stream=True)
|
response = requests.get(url, stream=True)
|
||||||
|
@ -30,4 +28,3 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
from tensordict import TensorDictBase
|
|
||||||
|
|
||||||
from tensordict.utils import NestedKey
|
from tensordict import TensorDictBase
|
||||||
from torchrl.envs.transforms import ObservationTransform
|
|
||||||
from torchrl.envs.transforms import Transform
|
|
||||||
from tensordict.nn import dispatch
|
from tensordict.nn import dispatch
|
||||||
|
from tensordict.utils import NestedKey
|
||||||
|
from torchrl.envs.transforms import ObservationTransform, Transform
|
||||||
|
|
||||||
|
|
||||||
class Prod(ObservationTransform):
|
class Prod(ObservationTransform):
|
||||||
|
@ -27,28 +26,31 @@ class Prod(ObservationTransform):
|
||||||
class NormalizeTransform(Transform):
|
class NormalizeTransform(Transform):
|
||||||
invertible = True
|
invertible = True
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
mean_std: TensorDictBase,
|
self,
|
||||||
in_keys: Sequence[NestedKey] = None,
|
mean_std: TensorDictBase,
|
||||||
out_keys: Sequence[NestedKey] | None = None,
|
in_keys: Sequence[NestedKey] = None,
|
||||||
in_keys_inv: Sequence[NestedKey] | None = None,
|
out_keys: Sequence[NestedKey] | None = None,
|
||||||
out_keys_inv: Sequence[NestedKey] | None = None,
|
in_keys_inv: Sequence[NestedKey] | None = None,
|
||||||
):
|
out_keys_inv: Sequence[NestedKey] | None = None,
|
||||||
|
):
|
||||||
if out_keys is None:
|
if out_keys is None:
|
||||||
out_keys = in_keys
|
out_keys = in_keys
|
||||||
if in_keys_inv is None:
|
if in_keys_inv is None:
|
||||||
in_keys_inv = out_keys
|
in_keys_inv = out_keys
|
||||||
if out_keys_inv is None:
|
if out_keys_inv is None:
|
||||||
out_keys_inv = in_keys
|
out_keys_inv = in_keys
|
||||||
super().__init__(in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv)
|
super().__init__(
|
||||||
|
in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
|
||||||
|
)
|
||||||
self.mean_std = mean_std
|
self.mean_std = mean_std
|
||||||
|
|
||||||
@dispatch(source="in_keys", dest="out_keys")
|
@dispatch(source="in_keys", dest="out_keys")
|
||||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||||
return self._call(tensordict)
|
return self._call(tensordict)
|
||||||
|
|
||||||
def _call(self, td: TensorDictBase) -> TensorDictBase:
|
def _call(self, td: TensorDictBase) -> TensorDictBase:
|
||||||
for inkey, outkey in zip(self.in_keys, self.out_keys):
|
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
|
||||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
# TODO(rcadene): don't know how to do `inkey not in td`
|
||||||
if td.get(inkey, None) is None:
|
if td.get(inkey, None) is None:
|
||||||
continue
|
continue
|
||||||
|
@ -58,7 +60,7 @@ class NormalizeTransform(Transform):
|
||||||
return td
|
return td
|
||||||
|
|
||||||
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
|
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
|
||||||
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv):
|
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
|
||||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
# TODO(rcadene): don't know how to do `inkey not in td`
|
||||||
if td.get(inkey, None) is None:
|
if td.get(inkey, None) is None:
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -5,6 +5,7 @@ import hydra
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
|
|
||||||
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
||||||
from diffusion_policy.model.vision.model_getter import get_resnet
|
from diffusion_policy.model.vision.model_getter import get_resnet
|
||||||
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
|
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
|
||||||
|
|
|
@ -5,7 +5,6 @@ from copy import deepcopy
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensordict import TensorDict
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
@ -127,7 +126,7 @@ class TDMPC(nn.Module):
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(self, observation, step_count):
|
def forward(self, observation, step_count):
|
||||||
t0 = step_count.item() == 0
|
t0 = step_count.item() == 0
|
||||||
|
|
||||||
# TODO(rcadene): remove unsqueeze hack...
|
# TODO(rcadene): remove unsqueeze hack...
|
||||||
if observation["image"].ndim == 3:
|
if observation["image"].ndim == 3:
|
||||||
observation["image"] = observation["image"].unsqueeze(0)
|
observation["image"] = observation["image"].unsqueeze(0)
|
||||||
|
@ -147,10 +146,7 @@ class TDMPC(nn.Module):
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def act(self, obs, t0=False, step=None):
|
def act(self, obs, t0=False, step=None):
|
||||||
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
|
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
|
||||||
if isinstance(obs, dict):
|
obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
|
||||||
obs = {k: o.detach() for k, o in obs.items()}
|
|
||||||
else:
|
|
||||||
obs = obs.detach()
|
|
||||||
z = self.model.encode(obs)
|
z = self.model.encode(obs)
|
||||||
if self.cfg.mpc:
|
if self.cfg.mpc:
|
||||||
a = self.plan(z, t0=t0, step=step)
|
a = self.plan(z, t0=t0, step=step)
|
||||||
|
|
|
@ -11,8 +11,8 @@ import tqdm
|
||||||
from tensordict.nn import TensorDictModule
|
from tensordict.nn import TensorDictModule
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from torchrl.envs import EnvBase
|
from torchrl.envs import EnvBase
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
|
||||||
|
|
||||||
|
from lerobot.common.datasets.factory import make_offline_buffer
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.utils import set_seed
|
from lerobot.common.utils import set_seed
|
||||||
|
|
Loading…
Reference in New Issue