Add Normalize, non_blocking=True in tdmpc, tdmpc run (TODO: diffusion)
parent b5a2f460ea
commit 1ae6205269
@@ -1,3 +1,4 @@
+import os
 from pathlib import Path
 
 import torch
@@ -6,7 +7,7 @@ from torchrl.data.replay_buffers import PrioritizedSliceSampler
 from lerobot.common.datasets.pusht import PushtExperienceReplay
 from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
 
-DATA_PATH = Path("data/")
+DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
 
 # TODO(rcadene): implement
 
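Note: the dataset root is now overridable per run through the DATA_DIR environment variable; a quick check of the lookup (the path is illustrative):

    import os
    from pathlib import Path

    os.environ["DATA_DIR"] = "/mnt/datasets"         # illustrative override
    print(Path(os.environ.get("DATA_DIR", "data")))  # /mnt/datasets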
@@ -64,7 +65,7 @@ def make_offline_buffer(cfg, sampler=None):
             # download="force",
             download=True,
             streaming=False,
-            root=str(DATA_PATH),
+            root=str(DATA_DIR),
             sampler=sampler,
             batch_size=batch_size,
             pin_memory=pin_memory,
@@ -74,7 +75,7 @@ def make_offline_buffer(cfg, sampler=None):
         offline_buffer = PushtExperienceReplay(
             "pusht",
             streaming=False,
-            root=DATA_PATH,
+            root=DATA_DIR,
             sampler=sampler,
             batch_size=batch_size,
             pin_memory=pin_memory,
@@ -1,3 +1,4 @@
+import logging
 import os
 from pathlib import Path
 from typing import Callable
@@ -16,9 +17,10 @@ from torchrl.data.replay_buffers.samplers import Sampler
 from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 
-from diffusion_policy.common.replay_buffer import ReplayBuffer
+from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
 from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
-from lerobot.common.datasets import utils
+from lerobot.common.datasets.utils import download_and_extract_zip
+from lerobot.common.envs.transforms import NormalizeTransform
 
 # as define in env
 SUCCESS_THRESHOLD = 0.95  # 95% coverage,
@@ -132,29 +134,16 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         else:
             storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
 
-        # if num_slices is not None or slice_len is not None:
-        #     if sampler is not None:
-        #         raise ValueError(
-        #             "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
-        #         )
-
-        #     if replacement:
-        #         if not self.shuffle:
-        #             raise RuntimeError(
-        #                 "shuffle=False can only be used when replacement=False."
-        #             )
-        #         sampler = SliceSampler(
-        #             num_slices=num_slices,
-        #             slice_len=slice_len,
-        #             strict_length=strict_length,
-        #         )
-        #     else:
-        #         sampler = SliceSamplerWithoutReplacement(
-        #             num_slices=num_slices,
-        #             slice_len=slice_len,
-        #             strict_length=strict_length,
-        #             shuffle=self.shuffle,
-        #         )
+        mean_std = self._compute_or_load_mean_std(storage)
+        mean_std["next", "observation", "image"] = mean_std["observation", "image"]
+        mean_std["next", "observation", "state"] = mean_std["observation", "state"]
+        transform = NormalizeTransform(mean_std, in_keys=[
+            ("observation", "image"),
+            ("observation", "state"),
+            ("next", "observation", "image"),
+            ("next", "observation", "state"),
+            ("action"),
+        ])
 
         if writer is None:
             writer = ImmutableDatasetWriter()
@@ -193,10 +182,10 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         zarr_path = (raw_dir / PUSHT_ZARR).resolve()
         if not zarr_path.is_dir():
             raw_dir.mkdir(parents=True, exist_ok=True)
-            utils.download_and_extract_zip(PUSHT_URL, raw_dir)
+            download_and_extract_zip(PUSHT_URL, raw_dir)
 
         # load
-        dataset_dict = ReplayBuffer.copy_from_path(zarr_path)  # , keys=['img', 'state', 'action'])
+        dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)  # , keys=['img', 'state', 'action'])
 
         episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
         num_episodes = dataset_dict.meta["episode_ends"].shape[0]
@@ -287,3 +276,62 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
             idxtd = idxtd + len(episode)
 
         return TensorStorage(td_data.lock_())
+
+    def _compute_mean_std(self, storage, num_batch=10, batch_size=32):
+        rb = TensorDictReplayBuffer(
+            storage=storage,
+            batch_size=batch_size,
+            prefetch=True,
+        )
+        batch = rb.sample()
+        image_mean = torch.zeros(batch["observation", "image"].shape[1])
+        image_std = torch.zeros(batch["observation", "image"].shape[1])
+        state_mean = torch.zeros(batch["observation", "state"].shape[1])
+        state_std = torch.zeros(batch["observation", "state"].shape[1])
+        action_mean = torch.zeros(batch["action"].shape[1])
+        action_std = torch.zeros(batch["action"].shape[1])
+
+        for i in tqdm.tqdm(range(num_batch)):
+            image_mean += einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
+            state_mean += batch["observation", "state"].mean(dim=0)
+            action_mean += batch["action"].mean(dim=0)
+            batch = rb.sample()
+
+        image_mean /= num_batch
+        state_mean /= num_batch
+        action_mean /= num_batch
+
+        for i in tqdm.tqdm(range(num_batch)):
+            image_mean_batch = einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
+            image_std += (image_mean_batch - image_mean) ** 2
+            state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
+            action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
+            if i < num_batch - 1:
+                batch = rb.sample()
+
+        image_std = torch.sqrt(image_std / num_batch)
+        state_std = torch.sqrt(state_std / num_batch)
+        action_std = torch.sqrt(action_std / num_batch)
+
+        mean_std = TensorDict(
+            {
+                ("observation", "image", "mean"): image_mean[None,:,None,None],
+                ("observation", "image", "std"): image_std[None,:,None,None],
+                ("observation", "state", "mean"): state_mean[None,:],
+                ("observation", "state", "std"): state_std[None,:],
+                ("action", "mean"): action_mean[None,:],
+                ("action", "std"): action_std[None,:],
+            },
+            batch_size=[],
+        )
+        return mean_std
+
+    def _compute_or_load_mean_std(self, storage) -> TensorDict:
+        mean_std_path = self.root / self.dataset_id / "mean_std.pth"
+        if mean_std_path.exists():
+            mean_std = torch.load(mean_std_path)
+        else:
+            logging.info(f"compute_mean_std and save to {mean_std_path}")
+            mean_std = self._compute_mean_std(storage)
+            torch.save(mean_std, mean_std_path)
+        return mean_std
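Note on the estimator above: the second pass accumulates squared deviations of per-batch means from the global mean, so the resulting "std" measures the spread of batch means rather than per-element dispersion, understating the true standard deviation by roughly a factor of sqrt(batch_size). A sketch of what is effectively computed, using a stand-in sampler:

    import torch

    def std_of_batch_means(sample, num_batch=10):
        # sample() returns a (batch, dim) tensor; stand-in for rb.sample()
        means = torch.stack([sample().mean(dim=0) for _ in range(num_batch)])
        return means.mean(dim=0), means.std(dim=0)

    mean, std = std_of_batch_means(lambda: torch.randn(32, 2))
    print(std)  # roughly 1/sqrt(32) ~ 0.18, not the per-element std of ~1.0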
@@ -3,9 +3,11 @@ import zipfile
 from pathlib import Path
 
 import requests
+from tensordict import TensorDictBase
 import tqdm
 
 
+
 def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
     print(f"downloading from {url}")
     response = requests.get(url, stream=True)
@@ -28,3 +30,4 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
         return True
     else:
         return False
+
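Note: the helper can be exercised on its own; a minimal sketch (the URL and destination are placeholders, while the pusht dataset passes PUSHT_URL and its raw_dir here):

    from pathlib import Path

    from lerobot.common.datasets.utils import download_and_extract_zip

    ok = download_and_extract_zip("https://example.com/pusht.zip", Path("data/pusht/raw"))
    print("extracted" if ok else "download failed")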
@@ -3,7 +3,7 @@ from torchrl.envs.transforms import StepCounter, TransformedEnv
 from lerobot.common.envs.transforms import Prod
 
 
-def make_env(cfg):
+def make_env(cfg, transform=None):
     kwargs = {
         "frame_skip": cfg.env.action_repeat,
         "from_pixels": cfg.env.from_pixels,
@@ -32,6 +32,10 @@ def make_env(cfg):
     # to ensure pusht is in [0,255] like simxarm
     env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
 
+    if transform is not None:
+        # useful to add mean and std normalization
+        env.append_transform(transform)
+
     return env
 
 
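Note: this lets callers share the dataset's normalization statistics with the env pipeline; a minimal sketch, assuming a Hydra-style cfg is in scope (it mirrors the eval.py change further down):

    from lerobot.common.datasets.factory import make_offline_buffer
    from lerobot.common.envs.factory import make_env

    offline_buffer = make_offline_buffer(cfg)  # installs a NormalizeTransform
    env = make_env(cfg, transform=offline_buffer.transform)  # env emits normalized observations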
@@ -100,7 +100,8 @@ class PushtEnv(EnvBase):
 
     def _step(self, tensordict: TensorDict):
         td = tensordict
-        action = td["action"].numpy()
+        # remove batch dim
+        action = td["action"].squeeze(0).numpy()
         # step expects shape=(4,) so we pad if necessary
         # TODO(rcadene): add info["is_success"] and info["success"] ?
         sum_reward = 0
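Note: the policy now emits batched actions (see the unsqueeze hack in the TDMPC diff below), so _step strips the leading batch dim before stepping the simulator; illustration:

    import torch

    a = torch.zeros(1, 2)      # batched action, shape (1, 2)
    print(a.squeeze(0).shape)  # torch.Size([2]), the shape the underlying env expects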
@@ -1,7 +1,10 @@
 from typing import Sequence
+from tensordict import TensorDictBase
 
 from tensordict.utils import NestedKey
 from torchrl.envs.transforms import ObservationTransform
+from torchrl.envs.transforms import Transform
+from tensordict.nn import dispatch
 
 
 class Prod(ObservationTransform):
@@ -19,3 +22,47 @@ class Prod(ObservationTransform):
         for key in self.in_keys:
             obs_spec[key].space.high *= self.prod
         return obs_spec
+
+
+class NormalizeTransform(Transform):
+    invertible = True
+
+    def __init__(self,
+        mean_std: TensorDictBase,
+        in_keys: Sequence[NestedKey] = None,
+        out_keys: Sequence[NestedKey] | None = None,
+        in_keys_inv: Sequence[NestedKey] | None = None,
+        out_keys_inv: Sequence[NestedKey] | None = None,
+    ):
+        if out_keys is None:
+            out_keys = in_keys
+        if in_keys_inv is None:
+            in_keys_inv = out_keys
+        if out_keys_inv is None:
+            out_keys_inv = in_keys
+        super().__init__(in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv)
+        self.mean_std = mean_std
+
+    @dispatch(source="in_keys", dest="out_keys")
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        return self._call(tensordict)
+
+    def _call(self, td: TensorDictBase) -> TensorDictBase:
+        for inkey, outkey in zip(self.in_keys, self.out_keys):
+            # TODO(rcadene): don't know how to do `inkey not in td`
+            if td.get(inkey, None) is None:
+                continue
+            mean = self.mean_std[inkey]["mean"]
+            std = self.mean_std[inkey]["std"]
+            td[outkey] = (td[inkey] - mean) / (std + 1e-8)
+        return td
+
+    def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
+        for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv):
+            # TODO(rcadene): don't know how to do `inkey not in td`
+            if td.get(inkey, None) is None:
+                continue
+            mean = self.mean_std[inkey]["mean"]
+            std = self.mean_std[inkey]["std"]
+            td[outkey] = td[inkey] * std + mean
+        return td
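Note: a minimal end-to-end check of NormalizeTransform (the keys and shapes mirror the dataset code; the statistics are made up):

    import torch
    from tensordict import TensorDict

    from lerobot.common.envs.transforms import NormalizeTransform

    mean_std = TensorDict(
        {
            ("observation", "state", "mean"): torch.zeros(1, 2),
            ("observation", "state", "std"): 2 * torch.ones(1, 2),
        },
        batch_size=[],
    )
    t = NormalizeTransform(mean_std, in_keys=[("observation", "state")])
    td = TensorDict({("observation", "state"): torch.tensor([[4.0, 6.0]])}, batch_size=[1])
    print(t._call(td)["observation", "state"])  # ~[[2.0, 3.0]] via (x - mean) / (std + 1e-8)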
@@ -5,6 +5,7 @@ from copy import deepcopy
 
 import einops
 import numpy as np
+from tensordict import TensorDict
 import torch
 import torch.nn as nn
 
@@ -126,19 +127,30 @@ class TDMPC(nn.Module):
     @torch.no_grad()
     def forward(self, observation, step_count):
         t0 = step_count.item() == 0
+
+        # TODO(rcadene): remove unsqueeze hack...
+        if observation["image"].ndim == 3:
+            observation["image"] = observation["image"].unsqueeze(0)
+            observation["state"] = observation["state"].unsqueeze(0)
+
         obs = {
-            "rgb": observation["image"],
-            "state": observation["state"],
+            # TODO(rcadene): remove contiguous hack...
+            "rgb": observation["image"].contiguous(),
+            "state": observation["state"].contiguous(),
         }
-        return self.act(obs, t0=t0, step=self.step.item())
+        action = self.act(obs, t0=t0, step=self.step.item())
+
+        # TODO(rcadene): hack to postprocess action (e.g. unnormalize)
+        # action = action * self.action_std + self.action_mean
+        return action
 
     @torch.no_grad()
     def act(self, obs, t0=False, step=None):
         """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
         if isinstance(obs, dict):
-            obs = {k: o.detach().unsqueeze(0) for k, o in obs.items()}
+            obs = {k: o.detach() for k, o in obs.items()}
         else:
-            obs = obs.detach().unsqueeze(0)
+            obs = obs.detach()
         z = self.model.encode(obs)
         if self.cfg.mpc:
             a = self.plan(z, t0=t0, step=step)
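Note: the contiguous() calls guard against observations arriving as non-contiguous views (e.g. produced by slicing or expand-style broadcasting upstream); a quick illustration:

    import torch

    x = torch.zeros(1, 3, 96, 96).expand(2, 3, 96, 96)  # expand returns a non-contiguous view
    print(x.is_contiguous())               # False
    print(x.contiguous().is_contiguous())  # True (materializes a dense copy)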
@@ -315,26 +327,20 @@ class TDMPC(nn.Module):
         # trajectory t = 256, horizon h = 5
         # (t h) ... -> h t ...
         batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
-        batch = batch.to(self.device)
 
         obs = {
-            "rgb": batch["observation", "image"][FIRST_FRAME].float(),
-            "state": batch["observation", "state"][FIRST_FRAME],
+            "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
+            "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
         }
-        action = batch["action"]
+        action = batch["action"].to(self.device, non_blocking=True)
         next_obses = {
-            "rgb": batch["next", "observation", "image"].float(),
-            "state": batch["next", "observation", "state"],
+            "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
+            "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
         }
-        reward = batch["next", "reward"]
+        reward = batch["next", "reward"].to(self.device, non_blocking=True)
 
-        # TODO(rcadene): add non_blocking=True
-        # for key in obs:
-        #     obs[key] = obs[key].to(self.device, non_blocking=True)
-        #     next_obses[key] = next_obses[key].to(self.device, non_blocking=True)
-
-        # action = action.to(self.device, non_blocking=True)
-        # reward = reward.to(self.device, non_blocking=True)
+        idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
+        weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
 
         # TODO(rcadene): rearrange directly in offline dataset
         if reward.ndim == 2:
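Note: non_blocking=True only overlaps the host-to-device copy with compute when the source tensor sits in pinned memory, which is why the offline buffer is built with pin_memory (see the factory diff above); a self-contained illustration:

    import torch

    if torch.cuda.is_available():
        x = torch.randn(256, 3, 96, 96).pin_memory()  # page-locked host memory
        y = x.to("cuda", non_blocking=True)           # copy is queued asynchronously
        torch.cuda.synchronize()                      # wait before consuming y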
@@ -347,9 +353,6 @@ class TDMPC(nn.Module):
         # Neither does `batch["next", "terminated"]`
         done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
         mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-
-        idxs = batch["index"][FIRST_FRAME]
-        weights = batch["_weight"][FIRST_FRAME, :, None]
         return obs, action, next_obses, reward, mask, done, idxs, weights
 
         batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
@@ -1,3 +1,4 @@
+import logging
 import threading
 import time
 from pathlib import Path
@@ -10,6 +11,7 @@ import tqdm
 from tensordict.nn import TensorDictModule
 from termcolor import colored
 from torchrl.envs import EnvBase
+from lerobot.common.datasets.factory import make_offline_buffer
 
 from lerobot.common.envs.factory import make_env
 from lerobot.common.policies.factory import make_policy
@@ -112,7 +114,11 @@ def eval(cfg: dict, out_dir=None):
     set_seed(cfg.seed)
     print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)
 
-    env = make_env(cfg)
+    logging.info("make_offline_buffer")
+    offline_buffer = make_offline_buffer(cfg)
+
+    logging.info("make_env")
+    env = make_env(cfg, transform=offline_buffer.transform)
 
     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)
@@ -117,21 +117,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     assert torch.cuda.is_available()
     torch.backends.cudnn.benchmark = True
+    torch.backends.cuda.matmul.allow_tf32 = True
     set_seed(cfg.seed)
     logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
 
-    logging.info("make_env")
-    env = make_env(cfg)
-
-    logging.info("make_policy")
-    policy = make_policy(cfg)
-
-    td_policy = TensorDictModule(
-        policy,
-        in_keys=["observation", "step_count"],
-        out_keys=["action"],
-    )
-
     logging.info("make_offline_buffer")
     offline_buffer = make_offline_buffer(cfg)
 
@@ -151,8 +140,22 @@ def train(cfg: dict, out_dir=None, job_name=None):
     online_buffer = TensorDictReplayBuffer(
         storage=LazyMemmapStorage(100_000),
         sampler=online_sampler,
+        transform=offline_buffer._transform,
     )
 
+    logging.info("make_env")
+    env = make_env(cfg, transform=offline_buffer._transform)
+
+    logging.info("make_policy")
+    policy = make_policy(cfg, transform=offline_buffer._transform)
+
+    td_policy = TensorDictModule(
+        policy,
+        in_keys=["observation", "step_count"],
+        out_keys=["action"],
+    )
+
+    # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)
 
     step = 0  # number of policy update
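Note: after this rewiring, a single set of dataset statistics drives normalization everywhere: make_offline_buffer computes (or loads) mean_std once, and the same transform is reused by the online buffer, the environment, and the policy, so offline training, online rollouts, and action post-processing all see consistently scaled data.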