Fix grad_clip_norm 0 -> 10, Fix normalization min_max to be per channel
This commit is contained in:
parent
cfc304e870
commit
e29fbb50e8
|
@ -136,14 +136,14 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
||||||
|
|
||||||
stats = self._compute_or_load_stats(storage)
|
stats = self._compute_or_load_stats(storage)
|
||||||
stats["next", "observation", "image"] = stats["observation", "image"]
|
|
||||||
stats["next", "observation", "state"] = stats["observation", "state"]
|
|
||||||
transform = NormalizeTransform(
|
transform = NormalizeTransform(
|
||||||
stats,
|
stats,
|
||||||
in_keys=[
|
in_keys=[
|
||||||
|
# TODO(rcadene): imagenet normalization is applied inside diffusion policy
|
||||||
|
# We need to automate this for tdmpc and others
|
||||||
# ("observation", "image"),
|
# ("observation", "image"),
|
||||||
("observation", "state"),
|
("observation", "state"),
|
||||||
# TODO(rcadene): for tdmpc, we might want image and state
|
# TODO(rcadene): for tdmpc, we might want next image and state
|
||||||
# ("next", "observation", "image"),
|
# ("next", "observation", "image"),
|
||||||
# ("next", "observation", "state"),
|
# ("next", "observation", "state"),
|
||||||
("action"),
|
("action"),
|
||||||
|
@ -151,7 +151,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
mode="min_max",
|
mode="min_max",
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, min_max_spec
|
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
||||||
transform.stats["observation", "state", "min"] = torch.tensor(
|
transform.stats["observation", "state", "min"] = torch.tensor(
|
||||||
[13.456424, 32.938293], dtype=torch.float32
|
[13.456424, 32.938293], dtype=torch.float32
|
||||||
)
|
)
|
||||||
|
@ -302,29 +302,43 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
prefetch=True,
|
prefetch=True,
|
||||||
)
|
)
|
||||||
batch = rb.sample()
|
batch = rb.sample()
|
||||||
image_mean = torch.zeros(batch["observation", "image"].shape[1])
|
|
||||||
image_std = torch.zeros(batch["observation", "image"].shape[1])
|
image_channels = batch["observation", "image"].shape[1]
|
||||||
image_max = -math.inf
|
image_mean = torch.zeros(image_channels)
|
||||||
image_min = math.inf
|
image_std = torch.zeros(image_channels)
|
||||||
state_mean = torch.zeros(batch["observation", "state"].shape[1])
|
image_max = torch.tensor([-math.inf] * image_channels)
|
||||||
state_std = torch.zeros(batch["observation", "state"].shape[1])
|
image_min = torch.tensor([math.inf] * image_channels)
|
||||||
state_max = -math.inf
|
|
||||||
state_min = math.inf
|
state_channels = batch["observation", "state"].shape[1]
|
||||||
action_mean = torch.zeros(batch["action"].shape[1])
|
state_mean = torch.zeros(state_channels)
|
||||||
action_std = torch.zeros(batch["action"].shape[1])
|
state_std = torch.zeros(state_channels)
|
||||||
action_max = -math.inf
|
state_max = torch.tensor([-math.inf] * state_channels)
|
||||||
action_min = math.inf
|
state_min = torch.tensor([math.inf] * state_channels)
|
||||||
|
|
||||||
|
action_channels = batch["action"].shape[1]
|
||||||
|
action_mean = torch.zeros(action_channels)
|
||||||
|
action_std = torch.zeros(action_channels)
|
||||||
|
action_max = torch.tensor([-math.inf] * action_channels)
|
||||||
|
action_min = torch.tensor([math.inf] * action_channels)
|
||||||
|
|
||||||
for _ in tqdm.tqdm(range(num_batch)):
|
for _ in tqdm.tqdm(range(num_batch)):
|
||||||
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
|
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
||||||
state_mean += batch["observation", "state"].mean(dim=0)
|
state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
||||||
action_mean += batch["action"].mean(dim=0)
|
action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
|
||||||
image_max = max(image_max, batch["observation", "image"].max().item())
|
|
||||||
image_min = min(image_min, batch["observation", "image"].min().item())
|
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
||||||
state_max = max(state_max, batch["observation", "state"].max().item())
|
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
||||||
state_min = min(state_min, batch["observation", "state"].min().item())
|
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
||||||
action_max = max(action_max, batch["action"].max().item())
|
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
||||||
action_min = min(action_min, batch["action"].min().item())
|
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
||||||
|
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
||||||
|
image_max = torch.maximum(image_max, b_image_max)
|
||||||
|
image_min = torch.maximum(image_min, b_image_min)
|
||||||
|
state_max = torch.maximum(state_max, b_state_max)
|
||||||
|
state_min = torch.maximum(state_min, b_state_min)
|
||||||
|
action_max = torch.maximum(action_max, b_action_max)
|
||||||
|
action_min = torch.maximum(action_min, b_action_min)
|
||||||
|
|
||||||
batch = rb.sample()
|
batch = rb.sample()
|
||||||
|
|
||||||
image_mean /= num_batch
|
image_mean /= num_batch
|
||||||
|
@ -332,16 +346,26 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
action_mean /= num_batch
|
action_mean /= num_batch
|
||||||
|
|
||||||
for i in tqdm.tqdm(range(num_batch)):
|
for i in tqdm.tqdm(range(num_batch)):
|
||||||
image_mean_batch = einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
|
b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
||||||
image_std += (image_mean_batch - image_mean) ** 2
|
b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
||||||
state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
|
b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
|
||||||
action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
|
image_std += (b_image_mean - image_mean) ** 2
|
||||||
image_max = max(image_max, batch["observation", "image"].max().item())
|
state_std += (b_state_mean - state_mean) ** 2
|
||||||
image_min = min(image_min, batch["observation", "image"].min().item())
|
action_std += (b_action_mean - action_mean) ** 2
|
||||||
state_max = max(state_max, batch["observation", "state"].max().item())
|
|
||||||
state_min = min(state_min, batch["observation", "state"].min().item())
|
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
||||||
action_max = max(action_max, batch["action"].max().item())
|
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
||||||
action_min = min(action_min, batch["action"].min().item())
|
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
||||||
|
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
||||||
|
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
||||||
|
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
||||||
|
image_max = torch.maximum(image_max, b_image_max)
|
||||||
|
image_min = torch.maximum(image_min, b_image_min)
|
||||||
|
state_max = torch.maximum(state_max, b_state_max)
|
||||||
|
state_min = torch.maximum(state_min, b_state_min)
|
||||||
|
action_max = torch.maximum(action_max, b_action_max)
|
||||||
|
action_min = torch.maximum(action_min, b_action_min)
|
||||||
|
|
||||||
if i < num_batch - 1:
|
if i < num_batch - 1:
|
||||||
batch = rb.sample()
|
batch = rb.sample()
|
||||||
|
|
||||||
|
@ -353,19 +377,21 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
{
|
{
|
||||||
("observation", "image", "mean"): image_mean[None, :, None, None],
|
("observation", "image", "mean"): image_mean[None, :, None, None],
|
||||||
("observation", "image", "std"): image_std[None, :, None, None],
|
("observation", "image", "std"): image_std[None, :, None, None],
|
||||||
("observation", "image", "max"): torch.tensor(image_max),
|
("observation", "image", "max"): image_max[None, :, None, None],
|
||||||
("observation", "image", "min"): torch.tensor(image_min),
|
("observation", "image", "min"): image_min[None, :, None, None],
|
||||||
("observation", "state", "mean"): state_mean[None, :],
|
("observation", "state", "mean"): state_mean[None, :],
|
||||||
("observation", "state", "std"): state_std[None, :],
|
("observation", "state", "std"): state_std[None, :],
|
||||||
("observation", "state", "max"): torch.tensor(state_max),
|
("observation", "state", "max"): state_max[None, :],
|
||||||
("observation", "state", "min"): torch.tensor(state_min),
|
("observation", "state", "min"): state_min[None, :],
|
||||||
("action", "mean"): action_mean[None, :],
|
("action", "mean"): action_mean[None, :],
|
||||||
("action", "std"): action_std[None, :],
|
("action", "std"): action_std[None, :],
|
||||||
("action", "max"): torch.tensor(action_max),
|
("action", "max"): action_max[None, :],
|
||||||
("action", "min"): torch.tensor(action_min),
|
("action", "min"): action_min[None, :],
|
||||||
},
|
},
|
||||||
batch_size=[],
|
batch_size=[],
|
||||||
)
|
)
|
||||||
|
stats["next", "observation", "image"] = stats["observation", "image"]
|
||||||
|
stats["next", "observation", "state"] = stats["observation", "state"]
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
def _compute_or_load_stats(self, storage) -> TensorDict:
|
def _compute_or_load_stats(self, storage) -> TensorDict:
|
||||||
|
|
|
@ -59,7 +59,7 @@ policy:
|
||||||
use_ema: true
|
use_ema: true
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
lr_warmup_steps: 500
|
lr_warmup_steps: 500
|
||||||
grad_clip_norm: 0
|
grad_clip_norm: 10
|
||||||
|
|
||||||
noise_scheduler:
|
noise_scheduler:
|
||||||
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
||||||
|
|
Loading…
Reference in New Issue