pre-commit run -a

Remi Cadene 2024-03-02 15:58:21 +00:00
parent 1ae6205269
commit 45b4ecb727
6 changed files with 44 additions and 43 deletions


@@ -137,13 +137,16 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         mean_std = self._compute_or_load_mean_std(storage)
         mean_std["next", "observation", "image"] = mean_std["observation", "image"]
         mean_std["next", "observation", "state"] = mean_std["observation", "state"]
-        transform = NormalizeTransform(mean_std, in_keys=[
+        transform = NormalizeTransform(
+            mean_std,
+            in_keys=[
                 ("observation", "image"),
                 ("observation", "state"),
                 ("next", "observation", "image"),
                 ("next", "observation", "state"),
                 ("action"),
-        ])
+            ],
+        )
 
         if writer is None:
             writer = ImmutableDatasetWriter()
@@ -185,7 +188,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
             download_and_extract_zip(PUSHT_URL, raw_dir)
 
         # load
-        dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)  # , keys=['img', 'state', 'action'])
+        dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
+            zarr_path
+        )  # , keys=['img', 'state', 'action'])
 
         episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
         num_episodes = dataset_dict.meta["episode_ends"].shape[0]
@@ -291,8 +296,8 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         action_mean = torch.zeros(batch["action"].shape[1])
         action_std = torch.zeros(batch["action"].shape[1])
 
-        for i in tqdm.tqdm(range(num_batch)):
-            image_mean += einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
+        for _ in tqdm.tqdm(range(num_batch)):
+            image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
             state_mean += batch["observation", "state"].mean(dim=0)
             action_mean += batch["action"].mean(dim=0)
             batch = rb.sample()
@@ -302,7 +307,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         action_mean /= num_batch
 
         for i in tqdm.tqdm(range(num_batch)):
-            image_mean_batch = einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
+            image_mean_batch = einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
             image_std += (image_mean_batch - image_mean) ** 2
             state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
             action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
@@ -315,12 +320,12 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
 
         mean_std = TensorDict(
             {
-                ("observation", "image", "mean"): image_mean[None,:,None,None],
-                ("observation", "image", "std"): image_std[None,:,None,None],
-                ("observation", "state", "mean"): state_mean[None,:],
-                ("observation", "state", "std"): state_std[None,:],
-                ("action", "mean"): action_mean[None,:],
-                ("action", "std"): action_std[None,:],
+                ("observation", "image", "mean"): image_mean[None, :, None, None],
+                ("observation", "image", "std"): image_std[None, :, None, None],
+                ("observation", "state", "mean"): state_mean[None, :],
+                ("observation", "state", "std"): state_std[None, :],
+                ("action", "mean"): action_mean[None, :],
+                ("action", "std"): action_std[None, :],
             },
             batch_size=[],
         )
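The hunks above only reformat the dataset's running mean/std computation, but the pattern is easier to follow in one piece: a first pass accumulates per-channel means over sampled batches, a second pass accumulates squared deviations of each batch mean and takes the root. A minimal standalone sketch, with a hypothetical `sample()` standing in for the replay buffer's `rb.sample()`:

```python
import einops
import torch


# Hypothetical stand-in for rb.sample(); returns a batch of CHW images.
def sample(batch_size=32, c=3, h=96, w=96):
    return {"image": torch.rand(batch_size, c, h, w)}


num_batch = 10

# Pass 1: accumulate per-channel means across batches.
batch = sample()
image_mean = torch.zeros(batch["image"].shape[1])
for _ in range(num_batch):
    image_mean += einops.reduce(batch["image"], "b c h w -> c", reduction="mean")
    batch = sample()
image_mean /= num_batch

# Pass 2: accumulate squared deviations of each batch mean, then take the root.
# As in the code above, this measures the spread of batch means.
image_std = torch.zeros_like(image_mean)
batch = sample()
for _ in range(num_batch):
    image_mean_batch = einops.reduce(batch["image"], "b c h w -> c", reduction="mean")
    image_std += (image_mean_batch - image_mean) ** 2
    batch = sample()
image_std = (image_std / num_batch).sqrt()
```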


@@ -3,11 +3,9 @@ import zipfile
 from pathlib import Path
 
 import requests
-from tensordict import TensorDictBase
 import tqdm
 
-
 
 def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
     print(f"downloading from {url}")
     response = requests.get(url, stream=True)
@@ -30,4 +28,3 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
         return True
     else:
         return False
-
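The hunks above only touch this helper's imports and trailing newline; the function body follows the familiar streaming pattern. A sketch of that pattern under the same signature, with the buffering and progress-bar details assumed rather than copied from the repository:

```python
import io
import zipfile
from pathlib import Path

import requests
import tqdm


def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
    print(f"downloading from {url}")
    response = requests.get(url, stream=True)
    if response.status_code != 200:
        return False
    # Stream the archive into memory, reporting progress by content length.
    total = int(response.headers.get("content-length", 0))
    buffer = io.BytesIO()
    with tqdm.tqdm(total=total, unit="B", unit_scale=True) as pbar:
        for chunk in response.iter_content(chunk_size=1024):
            buffer.write(chunk)
            pbar.update(len(chunk))
    destination_folder.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(buffer) as zf:
        zf.extractall(destination_folder)
    return True
```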


@@ -1,10 +1,9 @@
 from typing import Sequence
 
 from tensordict import TensorDictBase
 from tensordict.nn import dispatch
 from tensordict.utils import NestedKey
-from torchrl.envs.transforms import ObservationTransform
-from torchrl.envs.transforms import Transform
+from torchrl.envs.transforms import ObservationTransform, Transform
 
 
 class Prod(ObservationTransform):
@@ -27,7 +26,8 @@ class Prod(ObservationTransform):
 class NormalizeTransform(Transform):
     invertible = True
 
-    def __init__(self,
+    def __init__(
+        self,
         mean_std: TensorDictBase,
         in_keys: Sequence[NestedKey] = None,
         out_keys: Sequence[NestedKey] | None = None,
@@ -40,7 +40,9 @@ class NormalizeTransform(Transform):
             in_keys_inv = out_keys
         if out_keys_inv is None:
             out_keys_inv = in_keys
-        super().__init__(in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv)
+        super().__init__(
+            in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
+        )
         self.mean_std = mean_std
 
     @dispatch(source="in_keys", dest="out_keys")
@@ -48,7 +50,7 @@ class NormalizeTransform(Transform):
         return self._call(tensordict)
 
     def _call(self, td: TensorDictBase) -> TensorDictBase:
-        for inkey, outkey in zip(self.in_keys, self.out_keys):
+        for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
             # TODO(rcadene): don't know how to do `inkey not in td`
             if td.get(inkey, None) is None:
                 continue
@@ -58,7 +60,7 @@ class NormalizeTransform(Transform):
         return td
 
     def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
-        for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv):
+        for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
             # TODO(rcadene): don't know how to do `inkey not in td`
             if td.get(inkey, None) is None:
                 continue
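The hunks above never show the normalization arithmetic itself, only the reformatted signature and loops. Assuming the convention implied by the `mean`/`std` leaves built in the dataset code, the transform amounts to `(x - mean) / std` forward and `x * std + mean` inverse. A minimal sketch of that convention, without torchrl:

```python
import torch
from tensordict import TensorDict

# Stats laid out the way the dataset code builds them:
# a "mean" and a "std" leaf per normalized key.
mean_std = TensorDict(
    {
        ("observation", "state", "mean"): torch.zeros(1, 2),
        ("observation", "state", "std"): torch.ones(1, 2),
    },
    batch_size=[],
)

td = TensorDict({("observation", "state"): torch.randn(4, 2)}, batch_size=[4])

for key in [("observation", "state")]:
    x = td.get(key, None)
    if x is None:  # mirror the skip-missing-keys behavior in _call
        continue
    mean = mean_std[(*key, "mean")]
    std = mean_std[(*key, "std")]
    td[key] = (x - mean) / std  # inverse direction: x * std + mean
```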


@@ -5,6 +5,7 @@ import hydra
 import torch
 import torch.nn as nn
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+
 from diffusion_policy.model.common.lr_scheduler import get_scheduler
 from diffusion_policy.model.vision.model_getter import get_resnet
 from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder


@@ -5,7 +5,6 @@ from copy import deepcopy
 
 import einops
 import numpy as np
-from tensordict import TensorDict
 import torch
 import torch.nn as nn
 
@@ -147,10 +146,7 @@ class TDMPC(nn.Module):
     @torch.no_grad()
     def act(self, obs, t0=False, step=None):
         """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
-        if isinstance(obs, dict):
-            obs = {k: o.detach() for k, o in obs.items()}
-        else:
-            obs = obs.detach()
+        obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
         z = self.model.encode(obs)
         if self.cfg.mpc:
             a = self.plan(z, t0=t0, step=step)
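The hunk above collapses an if/else into a conditional expression; behavior is unchanged. A quick check of the pattern (names and shapes here are illustrative, not the repository's):

```python
import torch


def detach_obs(obs):
    # Dict observations detach per entry; plain tensors detach directly.
    return {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()


print(detach_obs(torch.ones(3, requires_grad=True)).requires_grad)  # False
print(detach_obs({"image": torch.ones(3, requires_grad=True)})["image"].requires_grad)  # False
```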


@@ -11,8 +11,8 @@ import tqdm
 from tensordict.nn import TensorDictModule
 from termcolor import colored
 from torchrl.envs import EnvBase
-from lerobot.common.datasets.factory import make_offline_buffer
+
+from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils import set_seed