Merge pull request #6 from Cadene/user/rcadene/2024_03_04_diffusion
Make diffusion work
commit e990f3e148
@@ -69,7 +69,7 @@ def make_offline_buffer(cfg, sampler=None):
             sampler=sampler,
             batch_size=batch_size,
             pin_memory=pin_memory,
-            prefetch=prefetch,
+            prefetch=prefetch if isinstance(prefetch, int) else None,
         )
     elif cfg.env.name == "pusht":
         offline_buffer = PushtExperienceReplay(

@@ -79,7 +79,7 @@ def make_offline_buffer(cfg, sampler=None):
             sampler=sampler,
             batch_size=batch_size,
             pin_memory=pin_memory,
-            prefetch=prefetch,
+            prefetch=prefetch if isinstance(prefetch, int) else None,
         )
     else:
         raise ValueError(cfg.env.name)
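Both hunks guard the prefetch argument so that only a genuine integer reaches the replay buffer, which expects an int number of prefetched batches or None. A minimal sketch of the guard with illustrative inputs:

```python
def sanitize_prefetch(prefetch):
    # keep an integer batch count, map anything else (None, a string, ...) to None
    # note: booleans are ints in Python, so True would pass through unchanged
    return prefetch if isinstance(prefetch, int) else None

assert sanitize_prefetch(3) == 3
assert sanitize_prefetch("3") is None
assert sanitize_prefetch(None) is None
```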
@@ -1,4 +1,5 @@
 import logging
+import math
 import os
 from pathlib import Path
 from typing import Callable

@@ -134,20 +135,32 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         else:
             storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))

-        mean_std = self._compute_or_load_mean_std(storage)
-        mean_std["next", "observation", "image"] = mean_std["observation", "image"]
-        mean_std["next", "observation", "state"] = mean_std["observation", "state"]
+        stats = self._compute_or_load_stats(storage)
         transform = NormalizeTransform(
-            mean_std,
+            stats,
             in_keys=[
-                ("observation", "image"),
+                # TODO(rcadene): imagenet normalization is applied inside diffusion policy
+                # We need to automate this for tdmpc and others
+                # ("observation", "image"),
                 ("observation", "state"),
-                ("next", "observation", "image"),
-                ("next", "observation", "state"),
+                # TODO(rcadene): for tdmpc, we might want next image and state
+                # ("next", "observation", "image"),
+                # ("next", "observation", "state"),
                 ("action"),
             ],
+            mode="min_max",
         )

+        # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
+        transform.stats["observation", "state", "min"] = torch.tensor(
+            [13.456424, 32.938293], dtype=torch.float32
+        )
+        transform.stats["observation", "state", "max"] = torch.tensor(
+            [496.14618, 510.9579], dtype=torch.float32
+        )
+        transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
+        transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
+
         if writer is None:
             writer = ImmutableDatasetWriter()
         if collate_fn is None:
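The dataset transform now uses mode="min_max" with hand-measured bounds for the PushT state and action ranges. For reference, that mode rescales each key into [-1, 1] using exactly these bounds (the NormalizeTransform change itself appears further down in this diff); a small sketch of the arithmetic:

```python
import torch

# bounds copied from the hunk above (PushT agent position range, in pixels)
state_min = torch.tensor([13.456424, 32.938293])
state_max = torch.tensor([496.14618, 510.9579])

def normalize_min_max(x, min_, max_):
    x = (x - min_) / (max_ - min_)  # map to [0, 1]
    return x * 2 - 1                # map to [-1, 1]

print(normalize_min_max(torch.tensor([250.0, 250.0]), state_min, state_max))  # values near 0
```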
@@ -282,24 +295,50 @@ class PushtExperienceReplay(TensorDictReplayBuffer):

         return TensorStorage(td_data.lock_())

-    def _compute_mean_std(self, storage, num_batch=10, batch_size=32):
+    def _compute_stats(self, storage, num_batch=100, batch_size=32):
         rb = TensorDictReplayBuffer(
             storage=storage,
             batch_size=batch_size,
             prefetch=True,
         )
         batch = rb.sample()
-        image_mean = torch.zeros(batch["observation", "image"].shape[1])
-        image_std = torch.zeros(batch["observation", "image"].shape[1])
-        state_mean = torch.zeros(batch["observation", "state"].shape[1])
-        state_std = torch.zeros(batch["observation", "state"].shape[1])
-        action_mean = torch.zeros(batch["action"].shape[1])
-        action_std = torch.zeros(batch["action"].shape[1])
+
+        image_channels = batch["observation", "image"].shape[1]
+        image_mean = torch.zeros(image_channels)
+        image_std = torch.zeros(image_channels)
+        image_max = torch.tensor([-math.inf] * image_channels)
+        image_min = torch.tensor([math.inf] * image_channels)
+
+        state_channels = batch["observation", "state"].shape[1]
+        state_mean = torch.zeros(state_channels)
+        state_std = torch.zeros(state_channels)
+        state_max = torch.tensor([-math.inf] * state_channels)
+        state_min = torch.tensor([math.inf] * state_channels)
+
+        action_channels = batch["action"].shape[1]
+        action_mean = torch.zeros(action_channels)
+        action_std = torch.zeros(action_channels)
+        action_max = torch.tensor([-math.inf] * action_channels)
+        action_min = torch.tensor([math.inf] * action_channels)

         for _ in tqdm.tqdm(range(num_batch)):
-            image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
-            state_mean += batch["observation", "state"].mean(dim=0)
-            action_mean += batch["action"].mean(dim=0)
+            image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
+            state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
+            action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
+
+            b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
+            b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
+            b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
+            b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
+            b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
+            b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
+            image_max = torch.maximum(image_max, b_image_max)
+            image_min = torch.maximum(image_min, b_image_min)
+            state_max = torch.maximum(state_max, b_state_max)
+            state_min = torch.maximum(state_min, b_state_min)
+            action_max = torch.maximum(action_max, b_action_max)
+            action_min = torch.maximum(action_min, b_action_min)
+
             batch = rb.sample()

         image_mean /= num_batch

@@ -307,10 +346,26 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         action_mean /= num_batch

         for i in tqdm.tqdm(range(num_batch)):
-            image_mean_batch = einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
-            image_std += (image_mean_batch - image_mean) ** 2
-            state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
-            action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
+            b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
+            b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
+            b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
+            image_std += (b_image_mean - image_mean) ** 2
+            state_std += (b_state_mean - state_mean) ** 2
+            action_std += (b_action_mean - action_mean) ** 2
+
+            b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
+            b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
+            b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
+            b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
+            b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
+            b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
+            image_max = torch.maximum(image_max, b_image_max)
+            image_min = torch.maximum(image_min, b_image_min)
+            state_max = torch.maximum(state_max, b_state_max)
+            state_min = torch.maximum(state_min, b_state_min)
+            action_max = torch.maximum(action_max, b_action_max)
+            action_min = torch.maximum(action_min, b_action_min)
+
             if i < num_batch - 1:
                 batch = rb.sample()

@@ -318,25 +373,33 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         state_std = torch.sqrt(state_std / num_batch)
         action_std = torch.sqrt(action_std / num_batch)

-        mean_std = TensorDict(
+        stats = TensorDict(
             {
                 ("observation", "image", "mean"): image_mean[None, :, None, None],
                 ("observation", "image", "std"): image_std[None, :, None, None],
+                ("observation", "image", "max"): image_max[None, :, None, None],
+                ("observation", "image", "min"): image_min[None, :, None, None],
                 ("observation", "state", "mean"): state_mean[None, :],
                 ("observation", "state", "std"): state_std[None, :],
+                ("observation", "state", "max"): state_max[None, :],
+                ("observation", "state", "min"): state_min[None, :],
                 ("action", "mean"): action_mean[None, :],
                 ("action", "std"): action_std[None, :],
+                ("action", "max"): action_max[None, :],
+                ("action", "min"): action_min[None, :],
             },
             batch_size=[],
         )
-        return mean_std
+        stats["next", "observation", "image"] = stats["observation", "image"]
+        stats["next", "observation", "state"] = stats["observation", "state"]
+        return stats

-    def _compute_or_load_mean_std(self, storage) -> TensorDict:
-        mean_std_path = self.root / self.dataset_id / "mean_std.pth"
-        if mean_std_path.exists():
-            mean_std = torch.load(mean_std_path)
+    def _compute_or_load_stats(self, storage) -> TensorDict:
+        stats_path = self.root / self.dataset_id / "stats.pth"
+        if stats_path.exists():
+            stats = torch.load(stats_path)
         else:
-            logging.info(f"compute_mean_std and save to {mean_std_path}")
-            mean_std = self._compute_mean_std(storage)
-            torch.save(mean_std, mean_std_path)
-        return mean_std
+            logging.info(f"compute_stats and save to {stats_path}")
+            stats = self._compute_stats(storage)
+            torch.save(stats, stats_path)
+        return stats
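The renamed _compute_stats pass averages per-batch channel means, accumulates squared deviations of those batch means for the std, and tracks running extrema. A condensed sketch of the same idea on a plain stream of (batch, channels) tensors (hypothetical data; the sketch uses torch.minimum for the running minimum):

```python
import torch

def compute_stats(batches):
    # batches: iterable of (B, C) tensors, e.g. states or actions sampled from a buffer
    batch_means, vmax, vmin = [], None, None
    for b in batches:
        batch_means.append(b.mean(dim=0))
        cur_max, cur_min = b.max(dim=0).values, b.min(dim=0).values
        vmax = cur_max if vmax is None else torch.maximum(vmax, cur_max)
        vmin = cur_min if vmin is None else torch.minimum(vmin, cur_min)
    means = torch.stack(batch_means)
    mean = means.mean(dim=0)
    # std over batch means, mirroring the diff (a coarse stand-in for the per-sample std)
    std = torch.sqrt(((means - mean) ** 2).mean(dim=0))
    return {"mean": mean, "std": std, "max": vmax, "min": vmin}

stats = compute_stats(torch.randn(32, 2) for _ in range(10))
```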
@@ -1,7 +1,5 @@
 from torchrl.envs.transforms import StepCounter, TransformedEnv

-from lerobot.common.envs.transforms import Prod
-

 def make_env(cfg, transform=None):
     kwargs = {

@@ -9,6 +7,8 @@ def make_env(cfg, transform=None):
         "from_pixels": cfg.env.from_pixels,
         "pixels_only": cfg.env.pixels_only,
         "image_size": cfg.env.image_size,
+        # TODO(rcadene): do we want a specific eval_env_seed?
+        "seed": cfg.seed,
     }

     if cfg.env.name == "simxarm":

@@ -19,6 +19,8 @@ def make_env(cfg, transform=None):
     elif cfg.env.name == "pusht":
         from lerobot.common.envs.pusht import PushtEnv

+        # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
+
         clsfunc = PushtEnv
     else:
         raise ValueError(cfg.env.name)

@@ -28,12 +30,8 @@ def make_env(cfg, transform=None):
     # limit rollout to max_steps
     env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))

-    if cfg.env.name == "pusht":
-        # to ensure pusht is in [0,255] like simxarm
-        env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
-
     if transform is not None:
-        # useful to add mean and std normalization
+        # useful to add normalization
        env.append_transform(transform)

     return env
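With the Prod rescaling gone from the factory, the environment is wrapped with a step counter plus whatever transform the caller passes in, typically the dataset's normalization transform. A sketch of the wrapping pattern, assuming some TorchRL base_env instance:

```python
from torchrl.envs.transforms import StepCounter, TransformedEnv

def wrap_env(base_env, episode_length, transform=None):
    # limit rollouts and stack the optional normalization transform on top
    env = TransformedEnv(base_env, StepCounter(max_steps=episode_length))
    if transform is not None:
        env.append_transform(transform)
    return env
```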
@@ -1,4 +1,5 @@
 import importlib
+from collections import deque
 from typing import Optional

 import torch

@@ -27,12 +28,16 @@ class PushtEnv(EnvBase):
         image_size=None,
         seed=1337,
         device="cpu",
+        num_prev_obs=1,
+        num_prev_action=0,
     ):
         super().__init__(device=device, batch_size=[])
         self.frame_skip = frame_skip
         self.from_pixels = from_pixels
         self.pixels_only = pixels_only
         self.image_size = image_size
+        self.num_prev_obs = num_prev_obs
+        self.num_prev_action = num_prev_action

         if pixels_only:
             assert from_pixels

@@ -56,6 +61,12 @@ class PushtEnv(EnvBase):
         self._make_spec()
         self._current_seed = self.set_seed(seed)

+        if self.num_prev_obs > 0:
+            self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
+            self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
+        if self.num_prev_action > 0:
+            self._prev_action_queue = deque(maxlen=self.num_prev_action)
+
     def render(self, mode="rgb_array", width=384, height=384):
         if width != height:
             raise NotImplementedError()
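The new num_prev_obs queues implement plain frame stacking: the env keeps the last num_prev_obs + 1 observations and returns them stacked along a leading time dimension. A standalone sketch of the mechanism (the 96x96 image size is just an example value):

```python
from collections import deque

import torch

num_prev_obs = 1
queue = deque(maxlen=num_prev_obs + 1)

# on reset: fill the queue with copies of the first observation
first_obs = torch.zeros(3, 96, 96)
queue.extend([first_obs] * (num_prev_obs + 1))

# on each step: push the newest observation and stack the window
new_obs = torch.ones(3, 96, 96)
queue.append(new_obs)
stacked = torch.stack(list(queue))  # shape (num_prev_obs + 1, 3, 96, 96)
```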
@@ -67,7 +78,8 @@ class PushtEnv(EnvBase):

     def _format_raw_obs(self, raw_obs):
         if self.from_pixels:
-            obs = {"image": torch.from_numpy(raw_obs["image"])}
+            image = torch.from_numpy(raw_obs["image"])
+            obs = {"image": image}

             if not self.pixels_only:
                 obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)

@@ -75,7 +87,6 @@ class PushtEnv(EnvBase):
             # TODO:
             obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}

-        obs = TensorDict(obs, batch_size=[])
         return obs

     def _reset(self, tensordict: Optional[TensorDict] = None):

@@ -87,9 +98,25 @@ class PushtEnv(EnvBase):
             raw_obs = self._env.reset()
             assert self._current_seed == self._env._seed

+            obs = self._format_raw_obs(raw_obs)
+
+            if self.num_prev_obs > 0:
+                stacked_obs = {}
+                if "image" in obs:
+                    self._prev_obs_image_queue = deque(
+                        [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+                    )
+                    stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
+                if "state" in obs:
+                    self._prev_obs_state_queue = deque(
+                        [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+                    )
+                    stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+                obs = stacked_obs
+
             td = TensorDict(
                 {
-                    "observation": self._format_raw_obs(raw_obs),
+                    "observation": TensorDict(obs, batch_size=[]),
                     "done": torch.tensor([False], dtype=torch.bool),
                 },
                 batch_size=[],
@@ -100,18 +127,37 @@ class PushtEnv(EnvBase):

     def _step(self, tensordict: TensorDict):
         td = tensordict
-        # remove batch dim
-        action = td["action"].squeeze(0).numpy()
+        action = td["action"].numpy()
         # step expects shape=(4,) so we pad if necessary
         # TODO(rcadene): add info["is_success"] and info["success"] ?
         sum_reward = 0
-        for _ in range(self.frame_skip):
-            raw_obs, reward, done, info = self._env.step(action)
+
+        if action.ndim == 1:
+            action = action.repeat(self.frame_skip, 1)
+        else:
+            if self.frame_skip > 1:
+                raise NotImplementedError()
+
+        num_action_steps = action.shape[0]
+        for i in range(num_action_steps):
+            raw_obs, reward, done, info = self._env.step(action[i])
             sum_reward += reward
+
+            obs = self._format_raw_obs(raw_obs)
+
+            if self.num_prev_obs > 0:
+                stacked_obs = {}
+                if "image" in obs:
+                    self._prev_obs_image_queue.append(obs["image"])
+                    stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
+                if "state" in obs:
+                    self._prev_obs_state_queue.append(obs["state"])
+                    stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+                obs = stacked_obs
+
         td = TensorDict(
             {
-                "observation": self._format_raw_obs(raw_obs),
+                "observation": TensorDict(obs, batch_size=[]),
                 "reward": torch.tensor([sum_reward], dtype=torch.float32),
                 # succes and done are true when coverage > self.success_threshold in env
                 "done": torch.tensor([done], dtype=torch.bool),
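The stepping logic now accepts either a single action vector or a chunk of actions: a single action is repeated frame_skip times, a chunk is replayed one step at a time with rewards summed. A standalone sketch of that control flow (the inner env is stubbed; np.tile stands in for the repeat, and frame_skip is assumed to be at least 1):

```python
import numpy as np

def step_chunk(inner_step, action, frame_skip=1):
    # inner_step: callable(action) -> (obs, reward, done, info), standing in for the wrapped env
    action = np.asarray(action)
    if action.ndim == 1:
        # a single action: execute it frame_skip times
        action = np.tile(action, (frame_skip, 1))
    elif frame_skip > 1:
        raise NotImplementedError("chunks of actions are not combined with frame_skip")
    sum_reward = 0.0
    for a in action:
        obs, reward, done, info = inner_step(a)
        sum_reward += reward
    return obs, sum_reward, done, info

# toy usage with a fake env that always returns reward 1
obs, r, done, info = step_chunk(lambda a: (a, 1.0, False, {}), np.array([0.1, 0.2]), frame_skip=4)
assert r == 4.0
```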
@@ -124,14 +170,22 @@ class PushtEnv(EnvBase):
     def _make_spec(self):
         obs = {}
         if self.from_pixels:
+            image_shape = (3, self.image_size, self.image_size)
+            if self.num_prev_obs > 0:
+                image_shape = (self.num_prev_obs, *image_shape)
+
             obs["image"] = BoundedTensorSpec(
                 low=0,
                 high=1,
-                shape=(3, self.image_size, self.image_size),
+                shape=image_shape,
                 dtype=torch.float32,
                 device=self.device,
             )
             if not self.pixels_only:
+                state_shape = self._env.observation_space["agent_pos"].shape
+                if self.num_prev_obs > 0:
+                    state_shape = (self.num_prev_obs, *state_shape)
+
                 obs["state"] = BoundedTensorSpec(
                     low=0,
                     high=512,

@@ -141,6 +195,10 @@ class PushtEnv(EnvBase):
                 )
         else:
             # TODO(rcadene): add observation_space achieved_goal and desired_goal?
+            state_shape = self._env.observation_space["observation"].shape
+            if self.num_prev_obs > 0:
+                state_shape = (self.num_prev_obs, *state_shape)
+
             obs["state"] = UnboundedContinuousTensorSpec(
                 # TODO:
                 shape=self._env.observation_space["observation"].shape,
@@ -28,11 +28,12 @@ class NormalizeTransform(Transform):

     def __init__(
         self,
-        mean_std: TensorDictBase,
+        stats: TensorDictBase,
         in_keys: Sequence[NestedKey] = None,
         out_keys: Sequence[NestedKey] | None = None,
         in_keys_inv: Sequence[NestedKey] | None = None,
         out_keys_inv: Sequence[NestedKey] | None = None,
+        mode="mean_std",
     ):
         if out_keys is None:
             out_keys = in_keys

@@ -43,7 +44,14 @@ class NormalizeTransform(Transform):
         super().__init__(
             in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
         )
-        self.mean_std = mean_std
+        self.stats = stats
+        assert mode in ["mean_std", "min_max"]
+        self.mode = mode
+
+    def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
+        # _reset is called once when the environment reset to normalize the first observation
+        tensordict_reset = self._call(tensordict_reset)
+        return tensordict_reset

     @dispatch(source="in_keys", dest="out_keys")
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

@@ -54,9 +62,17 @@ class NormalizeTransform(Transform):
             # TODO(rcadene): don't know how to do `inkey not in td`
             if td.get(inkey, None) is None:
                 continue
-            mean = self.mean_std[inkey]["mean"]
-            std = self.mean_std[inkey]["std"]
-            td[outkey] = (td[inkey] - mean) / (std + 1e-8)
+            if self.mode == "mean_std":
+                mean = self.stats[inkey]["mean"]
+                std = self.stats[inkey]["std"]
+                td[outkey] = (td[inkey] - mean) / (std + 1e-8)
+            else:
+                min = self.stats[inkey]["min"]
+                max = self.stats[inkey]["max"]
+                # normalize to [0,1]
+                td[outkey] = (td[inkey] - min) / (max - min)
+                # normalize to [-1, 1]
+                td[outkey] = td[outkey] * 2 - 1
         return td

     def _inv_call(self, td: TensorDictBase) -> TensorDictBase:

@@ -64,7 +80,13 @@ class NormalizeTransform(Transform):
             # TODO(rcadene): don't know how to do `inkey not in td`
             if td.get(inkey, None) is None:
                 continue
-            mean = self.mean_std[inkey]["mean"]
-            std = self.mean_std[inkey]["std"]
-            td[outkey] = td[inkey] * std + mean
+            if self.mode == "mean_std":
+                mean = self.stats[inkey]["mean"]
+                std = self.stats[inkey]["std"]
+                td[outkey] = td[inkey] * std + mean
+            else:
+                min = self.stats[inkey]["min"]
+                max = self.stats[inkey]["max"]
+                td[outkey] = (td[inkey] + 1) / 2
+                td[outkey] = td[outkey] * (max - min) + min
         return td
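The forward/inverse pair is an affine map and its exact inverse: in min_max mode values go to [0, 1], then to [-1, 1], and _inv_call undoes both steps. A small round-trip check on plain tensors, reusing the action bounds quoted earlier in this diff:

```python
import torch

def norm_min_max(x, min_, max_):
    x = (x - min_) / (max_ - min_)   # to [0, 1]
    return x * 2 - 1                 # to [-1, 1]

def unnorm_min_max(y, min_, max_):
    y = (y + 1) / 2                  # back to [0, 1]
    return y * (max_ - min_) + min_  # back to the original range

min_, max_ = torch.tensor([12.0, 25.0]), torch.tensor([511.0, 511.0])
x = torch.tensor([100.0, 400.0])
assert torch.allclose(unnorm_min_max(norm_min_max(x, min_, max_), min_, max_), x)
```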
@@ -1,51 +1,11 @@
-import contextlib
+import logging
 import os
 from pathlib import Path

-import numpy as np
 from omegaconf import OmegaConf
 from termcolor import colored


-def make_dir(dir_path):
-    """Create directory if it does not already exist."""
-    with contextlib.suppress(OSError):
-        dir_path.mkdir(parents=True, exist_ok=True)
-    return dir_path
-
-
-def print_run(cfg, reward=None):
-    """Pretty-printing of run information. Call at start of training."""
-    prefix, color, attrs = " ", "green", ["bold"]
-
-    def limstr(s, maxlen=32):
-        return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s
-
-    def pprint(k, v):
-        print(
-            prefix + colored(f'{k.capitalize() + ":":<16}', color, attrs=attrs),
-            limstr(v),
-        )
-
-    kvs = [
-        ("task", cfg.env.task),
-        ("offline_steps", f"{cfg.offline_steps}"),
-        ("online_steps", f"{cfg.online_steps}"),
-        ("action_repeat", f"{cfg.env.action_repeat}"),
-        # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
-        # ('actions', cfg.action_dim),
-        # ('experiment', cfg.exp_name),
-    ]
-    if reward is not None:
-        kvs.append(("episode reward", colored(str(int(reward)), "white", attrs=["bold"])))
-    w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21
-    div = "-" * w
-    print(div)
-    for k, v in kvs:
-        pprint(k, v)
-    print(div)
-
-
 def cfg_to_group(cfg, return_list=False):
     """Return a wandb-safe group name for logging. Optionally returns group name as list."""
     # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]

@@ -71,13 +31,12 @@ class Logger:
         self._seed = cfg.seed
         self._cfg = cfg
         self._eval = []
-        print_run(cfg)
         project = cfg.get("wandb", {}).get("project")
         entity = cfg.get("wandb", {}).get("entity")
         enable_wandb = cfg.get("wandb", {}).get("enable", False)
         run_offline = not enable_wandb or not project or not entity
         if run_offline:
-            print(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
+            logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
             self._wandb = None
         else:
             os.environ["WANDB_SILENT"] = "true"

@@ -134,7 +93,6 @@ class Logger:
             self.save_buffer(buffer, identifier="buffer")
         if self._wandb:
             self._wandb.finish()
-        print_run(self._cfg, self._eval[-1][-1])

     def log_dict(self, d, step, mode="train"):
         assert mode in {"train", "eval"}

@@ -144,5 +102,5 @@ class Logger:

     def log_video(self, video, step, mode="train"):
         assert mode in {"train", "eval"}
-        wandb_video = self._wandb.Video(video, fps=self.cfg.fps, format="mp4")
+        wandb_video = self._wandb.Video(video, fps=self._cfg.fps, format="mp4")
         self._wandb.log({f"{mode}/video": wandb_video}, step=step)
@@ -1,20 +1,15 @@
 import copy
 import time

-import einops
 import hydra
 import torch
 import torch.nn as nn
-from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

 from diffusion_policy.model.common.lr_scheduler import get_scheduler
-from diffusion_policy.model.vision.model_getter import get_resnet

 from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from .multi_image_obs_encoder import MultiImageObsEncoder

-FIRST_ACTION = 0
-

 class DiffusionPolicy(nn.Module):
     def __init__(

@@ -42,8 +37,8 @@ class DiffusionPolicy(nn.Module):
         super().__init__()
         self.cfg = cfg

-        noise_scheduler = DDPMScheduler(**cfg_noise_scheduler)
-        rgb_model = get_resnet(**cfg_rgb_model)
+        noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
+        rgb_model = hydra.utils.instantiate(cfg_rgb_model)
         obs_encoder = MultiImageObsEncoder(
             rgb_model=rgb_model,
             **cfg_obs_encoder,

@@ -101,20 +96,17 @@ class DiffusionPolicy(nn.Module):
         # TODO(rcadene): remove unused step_count
         del step_count

-        # TODO(rcadene): remove unsqueeze hack...
-        if observation["image"].ndim == 3:
-            observation["image"] = observation["image"].unsqueeze(0)
-            observation["state"] = observation["state"].unsqueeze(0)
+        # TODO(rcadene): remove unsqueeze hack to add bsize=1
+        observation["image"] = observation["image"].unsqueeze(0)
+        observation["state"] = observation["state"].unsqueeze(0)

         obs_dict = {
-            # TODO(rcadene): hack to add temporal dim
-            "image": einops.rearrange(observation["image"], "b c h w -> b 1 c h w"),
-            "agent_pos": einops.rearrange(observation["state"], "b c -> b 1 c"),
+            "image": observation["image"],
+            "agent_pos": observation["state"],
         }
         out = self.diffusion.predict_action(obs_dict)

-        # TODO(rcadene): add possibility to return >1 timestemps
-        action = out["action"].squeeze(0)[FIRST_ACTION]
+        action = out["action"].squeeze(0)
         return action

     def update(self, replay_buffer, step):
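Since the env already stacks n_obs_steps frames, forward no longer adds a temporal axis; it only adds a batch dimension and now returns the whole predicted action chunk instead of a single action. A shape-level sketch (shapes are illustrative and predict_action is stubbed):

```python
import torch

n_obs_steps, n_action_steps = 2, 8

# observation as produced by the stacking env: (t, c, h, w) and (t, state_dim)
observation = {"image": torch.rand(n_obs_steps, 3, 96, 96),
               "state": torch.rand(n_obs_steps, 2)}

# add the batch dimension, as in forward()
obs_dict = {"image": observation["image"].unsqueeze(0),
            "agent_pos": observation["state"].unsqueeze(0)}

# stand-in for self.diffusion.predict_action(obs_dict)
def predict_action(obs_dict):
    b = obs_dict["image"].shape[0]
    return {"action": torch.rand(b, n_action_steps, 2)}

action = predict_action(obs_dict)["action"].squeeze(0)  # (n_action_steps, action_dim)
```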
@@ -133,16 +125,36 @@ class DiffusionPolicy(nn.Module):
             # (t h) ... -> t h ...
             batch = batch.reshape(num_slices, horizon)  # .transpose(1, 0).contiguous()

+            # |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16
+            # |o|o|                                   observations: 2
+            # | |a|a|a|a|a|a|a|a|                     actions executed: 8
+            # |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p|  actions predicted: 16
+            # note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model
+
+            image = batch["observation", "image"]
+            state = batch["observation", "state"]
+            action = batch["action"]
+            assert image.shape[1] == horizon
+            assert state.shape[1] == horizon
+            assert action.shape[1] == horizon
+
+            if not (horizon == 16 and self.cfg.n_obs_steps == 2):
+                raise NotImplementedError()
+
+            # keep first 2 observations of the slice corresponding to t=[-1,0]
+            image = image[:, : self.cfg.n_obs_steps]
+            state = state[:, : self.cfg.n_obs_steps]
+
             out = {
                 "obs": {
-                    "image": batch["observation", "image"].to(self.device, non_blocking=True),
-                    "agent_pos": batch["observation", "state"].to(self.device, non_blocking=True),
+                    "image": image.to(self.device, non_blocking=True),
+                    "agent_pos": state.to(self.device, non_blocking=True),
                 },
-                "action": batch["action"].to(self.device, non_blocking=True),
+                "action": action.to(self.device, non_blocking=True),
             }
             return out

-        batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
+        batch = replay_buffer.sample(batch_size)
         batch = process_batch(batch, self.cfg.horizon, num_slices)

         data_s = time.time() - start_time
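The ASCII timeline above is the core of the change: each sampled slice is 16 steps long, the policy conditions on the first n_obs_steps = 2 observations, and the loss supervises all 16 predicted actions, of which 8 are executed at rollout time. A sketch of the slicing on dummy tensors:

```python
import torch

horizon, n_obs_steps = 16, 2
num_slices = 4

image = torch.rand(num_slices, horizon, 3, 96, 96)
state = torch.rand(num_slices, horizon, 2)
action = torch.rand(num_slices, horizon, 2)

# keep only the first two observations of each slice (t = -1 and t = 0)
obs = {"image": image[:, :n_obs_steps], "agent_pos": state[:, :n_obs_steps]}
batch = {"obs": obs, "action": action}  # actions are supervised over the full horizon
```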
@@ -1,7 +1,7 @@
 defaults:
   - _self_
-  - env: simxarm
-  - policy: tdmpc
+  - env: pusht
+  - policy: diffusion

 hydra:
   run:

@@ -21,6 +21,7 @@ save_buffer: false
 train_steps: ???
 fps: ???

+n_action_steps: ???
 env: ???

 policy: ???

@@ -13,7 +13,7 @@ shape_meta:
     shape: [2]

 horizon: 16
-n_obs_steps: 1 # TODO(rcadene): before 2
+n_obs_steps: 2
 n_action_steps: 8
 n_latency_steps: 0
 dataset_obs_steps: ${n_obs_steps}

@@ -21,7 +21,7 @@ past_action_visible: False
 keypoint_visible_rate: 1.0
 obs_as_global_cond: True

-eval_episodes: 50
+eval_episodes: 1
 eval_freq: 10000
 save_freq: 100000
 log_freq: 250

@@ -40,8 +40,8 @@ policy:
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null
-  diffusion_step_embed_dim: 128
-  down_dims: [512, 1024, 2048]
+  diffusion_step_embed_dim: 256 # before 128
+  down_dims: [256, 512, 1024] # before [512, 1024, 2048]
   kernel_size: 5
   n_groups: 8
   cond_predict_scale: True

@@ -59,10 +59,10 @@ policy:
   use_ema: true
   lr_scheduler: cosine
   lr_warmup_steps: 500
-  grad_clip_norm: 0
+  grad_clip_norm: 10

 noise_scheduler:
-  # _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
+  _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
   num_train_timesteps: 100
   beta_start: 0.0001
   beta_end: 0.02

@@ -74,16 +74,16 @@ noise_scheduler:
 obs_encoder:
   # _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder
   shape_meta: ${shape_meta}
-  resize_shape: null
-  crop_shape: [76, 76]
+  # resize_shape: null
+  # crop_shape: [76, 76]
   # constant center crop
-  random_crop: True
+  # random_crop: True
   use_group_norm: True
   share_rgb_model: False
-  imagenet_norm: False # TODO(rcadene): was set to True
+  imagenet_norm: True

 rgb_model:
-  #_target_: diffusion_policy.model.vision.model_getter.get_resnet
+  _target_: diffusion_policy.model.vision.model_getter.get_resnet
   name: resnet18
   weights: null
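Uncommenting the _target_ keys is what lets the policy build these objects with hydra.utils.instantiate instead of hard-coded constructors (see the policy hunk above). A minimal sketch, assuming the diffusers package is installed:

```python
import hydra
from omegaconf import OmegaConf

cfg_noise_scheduler = OmegaConf.create(
    {
        "_target_": "diffusers.schedulers.scheduling_ddpm.DDPMScheduler",
        "num_train_timesteps": 100,
        "beta_start": 0.0001,
        "beta_end": 0.02,
    }
)
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)  # a DDPMScheduler instance
```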
@@ -1,5 +1,7 @@
 # @package _global_

+n_action_steps: 1
+
 policy:
   name: tdmpc

@@ -118,7 +118,7 @@ def eval(cfg: dict, out_dir=None):
     offline_buffer = make_offline_buffer(cfg)

     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer.transform)
+    env = make_env(cfg, transform=offline_buffer._transform)

     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)

@@ -137,7 +137,7 @@ def eval(cfg: dict, out_dir=None):
         save_video=True,
         video_dir=Path(out_dir) / "eval",
         fps=cfg.env.fps,
-        max_steps=cfg.env.episode_length,
+        max_steps=cfg.env.episode_length // cfg.n_action_steps,
         num_episodes=cfg.eval_episodes,
     )
     print(metrics)
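Because each policy call now executes n_action_steps env actions, rollout budgets are counted in policy calls, hence episode_length // n_action_steps. With illustrative values:

```python
episode_length = 300   # env steps per episode (illustrative value)
n_action_steps = 8     # actions executed per policy call

max_steps = episode_length // n_action_steps  # 37 policy calls
assert max_steps * n_action_steps <= episode_length
```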
@@ -123,7 +123,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
     set_seed(cfg.seed)
-    logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")

     logging.info("make_offline_buffer")
     offline_buffer = make_offline_buffer(cfg)

@@ -153,6 +152,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info("make_policy")
     policy = make_policy(cfg)

+    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+    num_total_params = sum(p.numel() for p in policy.parameters())
+
     td_policy = TensorDictModule(
         policy,
         in_keys=["observation", "step_count"],

@@ -162,6 +164,16 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)

+    logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
+    logging.info(f"{cfg.env.task=}")
+    logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
+    logging.info(f"{cfg.online_steps=}")
+    logging.info(f"{cfg.env.action_repeat=}")
+    logging.info(f"{offline_buffer.num_samples=} ({format_big_number(offline_buffer.num_samples)})")
+    logging.info(f"{offline_buffer.num_episodes=}")
+    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
+    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
+
     step = 0  # number of policy update

     is_offline = True
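The new startup logs report learnable versus total parameter counts; the counting is the usual one-liner over policy.parameters(). format_big_number is a repo helper not shown in this diff, so the sketch below stubs it with a rough equivalent:

```python
import torch.nn as nn

policy = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 2))

num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())

def format_big_number(n):
    # rough stand-in for the repo helper: 1234567 -> "1.2M"
    for unit in ["", "K", "M", "B"]:
        if abs(n) < 1000:
            return f"{n:.1f}{unit}" if unit else f"{n}"
        n /= 1000
    return f"{n:.1f}T"

print(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
```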
@@ -174,19 +186,23 @@ def train(cfg: dict, out_dir=None, job_name=None):
            log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)

         if step > 0 and step % cfg.eval_freq == 0:
+            logging.info(f"Eval policy at step {step}")
             eval_info, first_video = eval_policy(
                 env,
                 td_policy,
                 num_episodes=cfg.eval_episodes,
+                max_steps=cfg.env.episode_length // cfg.n_action_steps,
                 return_first_video=True,
             )
             log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
             if cfg.wandb.enable:
                 logger.log_video(first_video, step, mode="eval")
+            logging.info("Resume training")

         if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-            logging.info(f"Checkpoint model at step {step}")
+            logging.info(f"Checkpoint policy at step {step}")
             logger.save_model(policy, identifier=step)
+            logging.info("Resume training")

         step += 1

@@ -200,11 +216,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
         # TODO: add configurable number of rollout? (default=1)
         with torch.no_grad():
             rollout = env.rollout(
-                max_steps=cfg.env.episode_length,
+                max_steps=cfg.env.episode_length // cfg.n_action_steps,
                 policy=td_policy,
                 auto_cast_to_device=True,
             )
-        assert len(rollout) <= cfg.env.episode_length
+        assert len(rollout) <= cfg.env.episode_length // cfg.n_action_steps
         # set same episode index for all time steps contained in this rollout
         rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
         online_buffer.extend(rollout)

@@ -231,19 +247,23 @@ def train(cfg: dict, out_dir=None, job_name=None):
            log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)

         if step > 0 and step % cfg.eval_freq == 0:
+            logging.info(f"Eval policy at step {step}")
             eval_info, first_video = eval_policy(
                 env,
                 td_policy,
                 num_episodes=cfg.eval_episodes,
+                max_steps=cfg.env.episode_length // cfg.n_action_steps,
                 return_first_video=True,
             )
             log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
             if cfg.wandb.enable:
                 logger.log_video(first_video, step, mode="eval")
+            logging.info("Resume training")

         if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-            logging.info(f"Checkpoint model at step {step}")
+            logging.info(f"Checkpoint policy at step {step}")
             logger.save_model(policy, identifier=step)
+            logging.info("Resume training")

         step += 1
         online_step += 1