backup wip

This commit is contained in:
Alexander Soare 2024-04-23 16:22:49 +01:00
parent c355737f3d
commit e69cd99f33
5 changed files with 40 additions and 26 deletions

View File

@@ -4,7 +4,7 @@ from pathlib import Path
 import torch
 from torchvision.transforms import v2
-from lerobot.common.transforms import NormalizeTransform
+from lerobot.common.transforms import IdentityTransform
 DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
@@ -38,7 +38,7 @@ def make_dataset(
     # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
     # min_max_from_spec
     # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
-    normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
+    # normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
     if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
         stats = {}
@@ -62,14 +62,15 @@ def make_dataset(
     transforms = v2.Compose(
         [
-            NormalizeTransform(
-                stats,
-                in_keys=[
-                    "observation.state",
-                    "action",
-                ],
-                mode=normalization_mode,
-            ),
+            # TODO(now): Use the transform
+            # NormalizeTransform(
+            #     in_keys=[
+            #         "observation.state",
+            #         "action",
+            #     ],
+            #     mode=normalization_mode,
+            # ),
+            IdentityTransform()
         ]
     )
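An illustrative, standalone sketch (not part of this commit) of how the composed pipeline behaves once the normalization step is commented out: the sample dict passes through IdentityTransform unchanged. torch.nn.Module stands in here for lerobot's Transform base class, which this diff does not show.

import torch
from torchvision.transforms import v2


class IdentityTransform(torch.nn.Module):
    # Stand-in for lerobot.common.transforms.IdentityTransform (assumption).
    def forward(self, item):
        return item


transforms = v2.Compose([IdentityTransform()])
item = {"observation.state": torch.zeros(2), "action": torch.ones(2)}
assert transforms(item) is item  # the sample dict is returned untouched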

View File

@@ -63,3 +63,13 @@ class NormalizeTransform(Transform):
                 item[outkey] = (item[inkey] + 1) / 2
                 item[outkey] = item[outkey] * (max - min) + min
         return item
+
+
+class IdentityTransform(Transform):
+    invertible = True
+
+    def forward(self, item):
+        return item
+
+    def inverse_transform(self, item):
+        return item
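A small worked example (illustrative only, with assumed stats) of the min_max inverse mapping visible in the context lines above: values normalized to [-1, 1] are first shifted to [0, 1], then rescaled to the original [min, max] range.

import torch

normalized = torch.tensor([-1.0, 0.0, 1.0])
min_, max_ = torch.tensor(2.0), torch.tensor(10.0)  # example stats (assumed)

x = (normalized + 1) / 2      # [0.0, 0.5, 1.0]
x = x * (max_ - min_) + min_  # tensor([ 2.,  6., 10.]) -- back in [min, max]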

View File

@@ -46,7 +46,6 @@ policy:
   kappa: 0.1
   lr: 3e-4
   std_schedule: ${policy.min_std}
-  horizon_schedule: ${policy.horizon}
   per: true
   per_alpha: 0.6
   per_beta: 0.4
@@ -79,7 +78,7 @@ policy:
   latent_dim: 50
   delta_timestamps:
-    observation.image: "[i / ${fps} for i in range(6)]"
-    observation.state: "[i / ${fps} for i in range(6)]"
-    action: "[i / ${fps} for i in range(5)]"
-    next.reward: "[i / ${fps} for i in range(5)]"
+    observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
+    observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
+    action: "[i / ${fps} for i in range(${policy.horizon})]"
+    next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
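A worked example of what these interpolations expand to, using assumed values (fps = 15, policy.horizon = 5; neither is shown in this hunk) and assuming the dataset factory evaluates the strings as Python list comprehensions. With horizon = 5 this reproduces the old hard-coded range(6)/range(5) lists, but now tracks the horizon if it changes.

fps = 15     # assumed value of ${fps}
horizon = 5  # assumed value of ${policy.horizon}

observation_timestamps = [i / fps for i in range(horizon + 1)]  # 6 entries: 0.0, 0.067, ..., 0.333
action_timestamps = [i / fps for i in range(horizon)]           # 5 entries: 0.0, 0.067, ..., 0.267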

View File

@@ -157,7 +157,11 @@ def eval_policy(
         # get the next action for the environment
         with torch.inference_mode():
-            action = policy.select_action(observation, step=step)
+            # TODO(now): restore
+            observation["observation.image"] *= 255
+            # TODO(now): train_step
+            action = policy.select_action(observation)
+            observation["observation.image"] /= 255
         # Send action:
         while True:

View File

@@ -270,7 +270,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # create dataloader for offline training
     dataloader = torch.utils.data.DataLoader(
         offline_dataset,
-        num_workers=4,
+        num_workers=32,
         batch_size=cfg.policy.batch_size,
         shuffle=True,
         pin_memory=cfg.device != "cpu",
@@ -316,7 +316,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     )
     dataloader = torch.utils.data.DataLoader(
         concat_dataset,
-        num_workers=4,
+        num_workers=32,
         batch_size=cfg.policy.batch_size,
         sampler=sampler,
         pin_memory=cfg.device != "cpu",