commit 45b4ecb727 (parent 1ae6205269)
Author: Remi Cadene
Date:   2024-03-02 15:58:21 +00:00

    pre-commit run -a

6 changed files with 44 additions and 43 deletions

File 1 of 6

@@ -137,13 +137,16 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         mean_std = self._compute_or_load_mean_std(storage)
         mean_std["next", "observation", "image"] = mean_std["observation", "image"]
         mean_std["next", "observation", "state"] = mean_std["observation", "state"]
-        transform = NormalizeTransform(mean_std, in_keys=[
-            ("observation", "image"),
-            ("observation", "state"),
-            ("next", "observation", "image"),
-            ("next", "observation", "state"),
-            ("action"),
-        ])
+        transform = NormalizeTransform(
+            mean_std,
+            in_keys=[
+                ("observation", "image"),
+                ("observation", "state"),
+                ("next", "observation", "image"),
+                ("next", "observation", "state"),
+                ("action"),
+            ],
+        )

         if writer is None:
             writer = ImmutableDatasetWriter()
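
For context, the mean_std argument that the reformatted NormalizeTransform call receives is a nested TensorDict of per-key statistics, built later in this same file. A minimal sketch of that layout, with illustrative shapes (the 2-D state is an assumption for the example, not the dataset's real dimensionality):

import torch
from tensordict import TensorDict

# Illustrative mean/std stats keyed by nested tuples, as consumed above.
mean_std = TensorDict(
    {
        ("observation", "state", "mean"): torch.zeros(1, 2),
        ("observation", "state", "std"): torch.ones(1, 2),
    },
    batch_size=[],
)
print(mean_std["observation", "state", "mean"].shape)  # torch.Size([1, 2])
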

@@ -185,7 +188,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
             download_and_extract_zip(PUSHT_URL, raw_dir)

         # load
-        dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)  # , keys=['img', 'state', 'action'])
+        dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
+            zarr_path
+        )  # , keys=['img', 'state', 'action'])

         episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
         num_episodes = dataset_dict.meta["episode_ends"].shape[0]

@@ -291,8 +296,8 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         action_mean = torch.zeros(batch["action"].shape[1])
         action_std = torch.zeros(batch["action"].shape[1])

-        for i in tqdm.tqdm(range(num_batch)):
-            image_mean += einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
+        for _ in tqdm.tqdm(range(num_batch)):
+            image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
             state_mean += batch["observation", "state"].mean(dim=0)
             action_mean += batch["action"].mean(dim=0)
             batch = rb.sample()

@@ -302,25 +307,25 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         action_mean /= num_batch

         for i in tqdm.tqdm(range(num_batch)):
-            image_mean_batch = einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
+            image_mean_batch = einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
             image_std += (image_mean_batch - image_mean) ** 2
             state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
             action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
             if i < num_batch - 1:
                 batch = rb.sample()

         image_std = torch.sqrt(image_std / num_batch)
         state_std = torch.sqrt(state_std / num_batch)
         action_std = torch.sqrt(action_std / num_batch)

         mean_std = TensorDict(
             {
-                ("observation", "image", "mean"): image_mean[None,:,None,None],
-                ("observation", "image", "std"): image_std[None,:,None,None],
-                ("observation", "state", "mean"): state_mean[None,:],
-                ("observation", "state", "std"): state_std[None,:],
-                ("action", "mean"): action_mean[None,:],
-                ("action", "std"): action_std[None,:],
+                ("observation", "image", "mean"): image_mean[None, :, None, None],
+                ("observation", "image", "std"): image_std[None, :, None, None],
+                ("observation", "state", "mean"): state_mean[None, :],
+                ("observation", "state", "std"): state_std[None, :],
+                ("action", "mean"): action_mean[None, :],
+                ("action", "std"): action_std[None, :],
             },
             batch_size=[],
         )
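
The two loops above estimate statistics in two passes over sampled batches: accumulate per-batch channel means, then take the spread of those batch means as the std. A self-contained sketch of the same pattern, with a random tensor standing in for rb.sample() (all names and shapes here are illustrative, not the repo's):

import einops
import torch

num_batch = 10

def sample_batch():
    # stand-in for rb.sample(): a batch of 32 RGB images, 96x96
    return torch.rand(32, 3, 96, 96)

# pass 1: average of per-batch channel means
image_mean = torch.zeros(3)
for _ in range(num_batch):
    image_mean += einops.reduce(sample_batch(), "b c h w -> c", reduction="mean")
image_mean /= num_batch

# pass 2: std estimated from the spread of per-batch channel means, as in the diff
image_std = torch.zeros(3)
for _ in range(num_batch):
    batch_mean = einops.reduce(sample_batch(), "b c h w -> c", reduction="mean")
    image_std += (batch_mean - image_mean) ** 2
image_std = torch.sqrt(image_std / num_batch)
print(image_mean, image_std)
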

File 2 of 6

@@ -3,11 +3,9 @@ import zipfile
 from pathlib import Path

 import requests
-from tensordict import TensorDictBase
 import tqdm

-

 def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
     print(f"downloading from {url}")
     response = requests.get(url, stream=True)

@@ -30,4 +28,3 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
         return True
     else:
         return False
-
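
For reference, the function being tidied here streams a zip archive and extracts it. A hedged sketch of that pattern, assuming a chunked requests download buffered in memory with a tqdm progress bar (chunk size and buffering are assumptions, not the repo's exact implementation):

import io
import zipfile
from pathlib import Path

import requests
import tqdm

def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
    print(f"downloading from {url}")
    response = requests.get(url, stream=True)
    if response.status_code != 200:
        return False
    total = int(response.headers.get("content-length", 0))
    buffer = io.BytesIO()
    with tqdm.tqdm(total=total, unit="B", unit_scale=True) as pbar:
        for chunk in response.iter_content(chunk_size=1024):
            buffer.write(chunk)
            pbar.update(len(chunk))
    with zipfile.ZipFile(buffer) as zf:
        zf.extractall(destination_folder)
    return True
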

File 3 of 6

@@ -1,10 +1,9 @@
 from typing import Sequence

 from tensordict import TensorDictBase
-from tensordict.utils import NestedKey
-from torchrl.envs.transforms import ObservationTransform
-from torchrl.envs.transforms import Transform
 from tensordict.nn import dispatch
+from tensordict.utils import NestedKey
+from torchrl.envs.transforms import ObservationTransform, Transform


 class Prod(ObservationTransform):

@@ -27,28 +26,31 @@ class Prod(ObservationTransform):
 class NormalizeTransform(Transform):
     invertible = True

-    def __init__(self,
-                 mean_std: TensorDictBase,
-                 in_keys: Sequence[NestedKey] = None,
-                 out_keys: Sequence[NestedKey] | None = None,
-                 in_keys_inv: Sequence[NestedKey] | None = None,
-                 out_keys_inv: Sequence[NestedKey] | None = None,
-                 ):
+    def __init__(
+        self,
+        mean_std: TensorDictBase,
+        in_keys: Sequence[NestedKey] = None,
+        out_keys: Sequence[NestedKey] | None = None,
+        in_keys_inv: Sequence[NestedKey] | None = None,
+        out_keys_inv: Sequence[NestedKey] | None = None,
+    ):
         if out_keys is None:
             out_keys = in_keys
         if in_keys_inv is None:
             in_keys_inv = out_keys
         if out_keys_inv is None:
             out_keys_inv = in_keys
-        super().__init__(in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv)
+        super().__init__(
+            in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
+        )
         self.mean_std = mean_std

     @dispatch(source="in_keys", dest="out_keys")
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         return self._call(tensordict)

     def _call(self, td: TensorDictBase) -> TensorDictBase:
-        for inkey, outkey in zip(self.in_keys, self.out_keys):
+        for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
             # TODO(rcadene): don't know how to do `inkey not in td`
             if td.get(inkey, None) is None:
                 continue

@@ -58,7 +60,7 @@ class NormalizeTransform(Transform):
         return td

     def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
-        for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv):
+        for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
             # TODO(rcadene): don't know how to do `inkey not in td`
             if td.get(inkey, None) is None:
                 continue
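
The strict=False added to both zip() calls makes the truncation behavior explicit: since Python 3.10, zip() accepts a strict flag, and linters flag bare calls because a length mismatch otherwise truncates silently. A quick demonstration:

in_keys = [("observation", "image"), ("observation", "state")]
out_keys = [("observation", "image")]

# strict=False keeps the old behavior: the longer iterable is truncated
print(list(zip(in_keys, out_keys, strict=False)))
# [(('observation', 'image'), ('observation', 'image'))]

# strict=True raises on mismatched lengths instead
try:
    list(zip(in_keys, out_keys, strict=True))
except ValueError as e:
    print(e)  # zip() argument 2 is shorter than argument 1
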

File 4 of 6

@@ -5,6 +5,7 @@ import hydra
 import torch
 import torch.nn as nn
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+
 from diffusion_policy.model.common.lr_scheduler import get_scheduler
 from diffusion_policy.model.vision.model_getter import get_resnet
 from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder

File 5 of 6

@@ -5,7 +5,6 @@ from copy import deepcopy

 import einops
 import numpy as np
-from tensordict import TensorDict
 import torch
 import torch.nn as nn

@@ -127,7 +126,7 @@ class TDMPC(nn.Module):
     @torch.no_grad()
     def forward(self, observation, step_count):
         t0 = step_count.item() == 0

         # TODO(rcadene): remove unsqueeze hack...
         if observation["image"].ndim == 3:
             observation["image"] = observation["image"].unsqueeze(0)

@@ -147,10 +146,7 @@ class TDMPC(nn.Module):
     @torch.no_grad()
     def act(self, obs, t0=False, step=None):
         """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
-        if isinstance(obs, dict):
-            obs = {k: o.detach() for k, o in obs.items()}
-        else:
-            obs = obs.detach()
+        obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
         z = self.model.encode(obs)
         if self.cfg.mpc:
             a = self.plan(z, t0=t0, step=step)
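
The four-line if/else collapses into one conditional expression that detaches either a dict of tensors or a single tensor. A toy demonstration with stand-in observations (not the TDMPC data structures):

import torch

def detach_obs(obs):
    # same shape as the refactored line in act()
    return {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()

single = torch.ones(3, requires_grad=True)
print(detach_obs(single).requires_grad)  # False

as_dict = {"image": torch.ones(3, requires_grad=True)}
print(detach_obs(as_dict)["image"].requires_grad)  # False
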

File 6 of 6

@@ -11,8 +11,8 @@ import tqdm
 from tensordict.nn import TensorDictModule
 from termcolor import colored
 from torchrl.envs import EnvBase
-from lerobot.common.datasets.factory import make_offline_buffer

+from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils import set_seed
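
The moved blank line reflects the import grouping the hooks enforce: standard library, then third-party, then first-party (lerobot), with a blank line between groups. An illustrative sketch of that convention (module choices are examples only):

# group 1: standard library
import zipfile
from pathlib import Path

# group 2: third-party
import torch
import tqdm

# group 3: first-party, e.g. lerobot.*, follows after another blank line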