backup wip
This commit is contained in:
parent c355737f3d
commit e69cd99f33
@@ -4,7 +4,7 @@ from pathlib import Path
 import torch
 from torchvision.transforms import v2
 
-from lerobot.common.transforms import NormalizeTransform
+from lerobot.common.transforms import IdentityTransform
 
 DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
 
@@ -38,7 +38,7 @@ def make_dataset(
     # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
     # min_max_from_spec
     # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
-    normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
+    # normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
 
     if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
         stats = {}
@@ -62,14 +62,15 @@ def make_dataset(
 
     transforms = v2.Compose(
         [
-            NormalizeTransform(
-                stats,
-                in_keys=[
-                    "observation.state",
-                    "action",
-                ],
-                mode=normalization_mode,
-            ),
+            # TODO(now): Use the transform
+            # NormalizeTransform(
+            #     in_keys=[
+            #         "observation.state",
+            #         "action",
+            #     ],
+            #     mode=normalization_mode,
+            # ),
+            IdentityTransform()
         ]
     )
 
@@ -63,3 +63,13 @@ class NormalizeTransform(Transform):
                 item[outkey] = (item[inkey] + 1) / 2
                 item[outkey] = item[outkey] * (max - min) + min
         return item
+
+
+class IdentityTransform(Transform):
+    invertible = True
+
+    def forward(self, item):
+        return item
+
+    def inverse_transform(self, item):
+        return item
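Note on the hunk above: IdentityTransform is a no-op stand-in for the commented-out NormalizeTransform; both forward and inverse_transform return the item unchanged, so composing it leaves the batch untouched. A minimal, hypothetical sketch of the round-trip contract it satisfies, using a plain class and a plain dict instead of the real Transform base class and batch:

# Hypothetical sketch (not lerobot code): the forward/inverse_transform
# contract that a pass-through transform trivially satisfies.
class SketchIdentity:
    invertible = True  # inverse_transform(forward(item)) == item by construction

    def forward(self, item):
        return item  # pass-through: no normalization applied

    def inverse_transform(self, item):
        return item  # pass-through: nothing to undo


item = {"observation.state": [0.1, -0.2], "action": [0.5]}
t = SketchIdentity()
assert t.inverse_transform(t.forward(item)) == item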
@@ -46,7 +46,6 @@ policy:
   kappa: 0.1
   lr: 3e-4
   std_schedule: ${policy.min_std}
-  horizon_schedule: ${policy.horizon}
   per: true
   per_alpha: 0.6
   per_beta: 0.4
@@ -79,7 +78,7 @@ policy:
   latent_dim: 50
 
   delta_timestamps:
-    observation.image: "[i / ${fps} for i in range(6)]"
-    observation.state: "[i / ${fps} for i in range(6)]"
-    action: "[i / ${fps} for i in range(5)]"
-    next.reward: "[i / ${fps} for i in range(5)]"
+    observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
+    observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
+    action: "[i / ${fps} for i in range(${policy.horizon})]"
+    next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
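Note on the hunk above: each delta_timestamps entry is a string containing config interpolations; after ${fps} and ${policy.horizon} are substituted it is a Python list comprehension that evaluates to a list of time offsets in seconds. A minimal sketch of how the strings expand, using illustrative values fps=15 and horizon=5 (these numbers are assumptions for the example, not taken from the diff):

# Hypothetical sketch (not lerobot code): expanding the interpolated strings
# into the lists of time offsets (in seconds) they represent.
fps = 15      # assumed value for illustration
horizon = 5   # assumed value for illustration

observation_image = eval(f"[i / {fps} for i in range({horizon} + 1)]")
action = eval(f"[i / {fps} for i in range({horizon})]")

print(len(observation_image), observation_image[:3])  # 6 entries: 0.0, 1/15, 2/15, ...
print(len(action), action[:3])                        # 5 entries: 0.0, 1/15, 2/15, ...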
@@ -157,7 +157,11 @@ def eval_policy(
 
         # get the next action for the environment
         with torch.inference_mode():
-            action = policy.select_action(observation, step=step)
+            # TODO(now): restore
+            observation["observation.image"] *= 255
+            # TODO(now): train_step
+            action = policy.select_action(observation)
+            observation["observation.image"] /= 255
 
         # Send action:
         while True:
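Note on the hunk above: the replacement lines look like a temporary workaround in which the image tensor is rescaled in place to a 0-255 range before select_action and scaled back afterwards. A minimal, hypothetical sketch of that pattern with a dummy policy; the in-place *= 255 / /= 255 round trip restores a float tensor only up to floating-point rounding, which torch.allclose tolerates:

# Hypothetical sketch (not lerobot code): in-place rescale around an
# inference call, then undo it so later consumers see the original range.
import torch


def select_action_stub(observation):
    # stand-in for policy.select_action; only reads the observation
    return torch.zeros(1)


observation = {"observation.image": torch.rand(3, 96, 96)}
original = observation["observation.image"].clone()

with torch.inference_mode():
    observation["observation.image"] *= 255  # scale up before the policy call
    action = select_action_stub(observation)
    observation["observation.image"] /= 255  # undo the in-place scaling

# the round trip is approximately, not bit-exactly, the original values
assert torch.allclose(observation["observation.image"], original)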
@@ -270,7 +270,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # create dataloader for offline training
     dataloader = torch.utils.data.DataLoader(
         offline_dataset,
-        num_workers=4,
+        num_workers=32,
         batch_size=cfg.policy.batch_size,
         shuffle=True,
         pin_memory=cfg.device != "cpu",
@@ -316,7 +316,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     )
     dataloader = torch.utils.data.DataLoader(
         concat_dataset,
-        num_workers=4,
+        num_workers=32,
         batch_size=cfg.policy.batch_size,
         sampler=sampler,
         pin_memory=cfg.device != "cpu",