diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py
index 8ea64f86..11569ee2 100644
--- a/lerobot/common/datasets/pusht.py
+++ b/lerobot/common/datasets/pusht.py
@@ -136,14 +136,14 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
 
         stats = self._compute_or_load_stats(storage)
-        stats["next", "observation", "image"] = stats["observation", "image"]
-        stats["next", "observation", "state"] = stats["observation", "state"]
         transform = NormalizeTransform(
             stats,
             in_keys=[
+                # TODO(rcadene): imagenet normalization is applied inside diffusion policy
+                # We need to automate this for tdmpc and others
                 # ("observation", "image"),
                 ("observation", "state"),
-                # TODO(rcadene): for tdmpc, we might want image and state
+                # TODO(rcadene): for tdmpc, we might want next image and state
                 # ("next", "observation", "image"),
                 # ("next", "observation", "state"),
                 ("action"),
@@ -151,7 +151,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
             mode="min_max",
         )
 
-        # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, min_max_spec
+        # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
         transform.stats["observation", "state", "min"] = torch.tensor(
             [13.456424, 32.938293], dtype=torch.float32
         )
@@ -302,29 +302,43 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
             prefetch=True,
         )
         batch = rb.sample()
-        image_mean = torch.zeros(batch["observation", "image"].shape[1])
-        image_std = torch.zeros(batch["observation", "image"].shape[1])
-        image_max = -math.inf
-        image_min = math.inf
-        state_mean = torch.zeros(batch["observation", "state"].shape[1])
-        state_std = torch.zeros(batch["observation", "state"].shape[1])
-        state_max = -math.inf
-        state_min = math.inf
-        action_mean = torch.zeros(batch["action"].shape[1])
-        action_std = torch.zeros(batch["action"].shape[1])
-        action_max = -math.inf
-        action_min = math.inf
+
+        image_channels = batch["observation", "image"].shape[1]
+        image_mean = torch.zeros(image_channels)
+        image_std = torch.zeros(image_channels)
+        image_max = torch.tensor([-math.inf] * image_channels)
+        image_min = torch.tensor([math.inf] * image_channels)
+
+        state_channels = batch["observation", "state"].shape[1]
+        state_mean = torch.zeros(state_channels)
+        state_std = torch.zeros(state_channels)
+        state_max = torch.tensor([-math.inf] * state_channels)
+        state_min = torch.tensor([math.inf] * state_channels)
+
+        action_channels = batch["action"].shape[1]
+        action_mean = torch.zeros(action_channels)
+        action_std = torch.zeros(action_channels)
+        action_max = torch.tensor([-math.inf] * action_channels)
+        action_min = torch.tensor([math.inf] * action_channels)
 
         for _ in tqdm.tqdm(range(num_batch)):
-            image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
-            state_mean += batch["observation", "state"].mean(dim=0)
-            action_mean += batch["action"].mean(dim=0)
-            image_max = max(image_max, batch["observation", "image"].max().item())
-            image_min = min(image_min, batch["observation", "image"].min().item())
-            state_max = max(state_max, batch["observation", "state"].max().item())
-            state_min = min(state_min, batch["observation", "state"].min().item())
-            action_max = max(action_max, batch["action"].max().item())
-            action_min = min(action_min, batch["action"].min().item())
+            image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
+            state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
einops.reduce(batch["observation", "state"], "b c -> c", "mean") + action_mean += einops.reduce(batch["action"], "b c -> c", "mean") + + b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max") + b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min") + b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max") + b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min") + b_action_max = einops.reduce(batch["action"], "b c -> c", "max") + b_action_min = einops.reduce(batch["action"], "b c -> c", "min") + image_max = torch.maximum(image_max, b_image_max) + image_min = torch.maximum(image_min, b_image_min) + state_max = torch.maximum(state_max, b_state_max) + state_min = torch.maximum(state_min, b_state_min) + action_max = torch.maximum(action_max, b_action_max) + action_min = torch.maximum(action_min, b_action_min) + batch = rb.sample() image_mean /= num_batch @@ -332,16 +346,26 @@ class PushtExperienceReplay(TensorDictReplayBuffer): action_mean /= num_batch for i in tqdm.tqdm(range(num_batch)): - image_mean_batch = einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean") - image_std += (image_mean_batch - image_mean) ** 2 - state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2 - action_std += (batch["action"].mean(dim=0) - action_mean) ** 2 - image_max = max(image_max, batch["observation", "image"].max().item()) - image_min = min(image_min, batch["observation", "image"].min().item()) - state_max = max(state_max, batch["observation", "state"].max().item()) - state_min = min(state_min, batch["observation", "state"].min().item()) - action_max = max(action_max, batch["action"].max().item()) - action_min = min(action_min, batch["action"].min().item()) + b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean") + b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean") + b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean") + image_std += (b_image_mean - image_mean) ** 2 + state_std += (b_state_mean - state_mean) ** 2 + action_std += (b_action_mean - action_mean) ** 2 + + b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max") + b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min") + b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max") + b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min") + b_action_max = einops.reduce(batch["action"], "b c -> c", "max") + b_action_min = einops.reduce(batch["action"], "b c -> c", "min") + image_max = torch.maximum(image_max, b_image_max) + image_min = torch.maximum(image_min, b_image_min) + state_max = torch.maximum(state_max, b_state_max) + state_min = torch.maximum(state_min, b_state_min) + action_max = torch.maximum(action_max, b_action_max) + action_min = torch.maximum(action_min, b_action_min) + if i < num_batch - 1: batch = rb.sample() @@ -353,19 +377,21 @@ class PushtExperienceReplay(TensorDictReplayBuffer): { ("observation", "image", "mean"): image_mean[None, :, None, None], ("observation", "image", "std"): image_std[None, :, None, None], - ("observation", "image", "max"): torch.tensor(image_max), - ("observation", "image", "min"): torch.tensor(image_min), + ("observation", "image", "max"): image_max[None, :, None, None], + ("observation", "image", "min"): image_min[None, :, None, None], ("observation", "state", "mean"): state_mean[None, :], ("observation", 
"state", "std"): state_std[None, :], - ("observation", "state", "max"): torch.tensor(state_max), - ("observation", "state", "min"): torch.tensor(state_min), + ("observation", "state", "max"): state_max[None, :], + ("observation", "state", "min"): state_min[None, :], ("action", "mean"): action_mean[None, :], ("action", "std"): action_std[None, :], - ("action", "max"): torch.tensor(action_max), - ("action", "min"): torch.tensor(action_min), + ("action", "max"): action_max[None, :], + ("action", "min"): action_min[None, :], }, batch_size=[], ) + stats["next", "observation", "image"] = stats["observation", "image"] + stats["next", "observation", "state"] = stats["observation", "state"] return stats def _compute_or_load_stats(self, storage) -> TensorDict: diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index da1b6545..0ea8f638 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -59,7 +59,7 @@ policy: use_ema: true lr_scheduler: cosine lr_warmup_steps: 500 - grad_clip_norm: 0 + grad_clip_norm: 10 noise_scheduler: _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler