backup wip

Alexander Soare 2024-04-23 16:22:49 +01:00
parent c355737f3d
commit e69cd99f33
5 changed files with 40 additions and 26 deletions

View File

@@ -4,7 +4,7 @@ from pathlib import Path
import torch
from torchvision.transforms import v2
from lerobot.common.transforms import NormalizeTransform
from lerobot.common.transforms import IdentityTransform
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
@@ -38,7 +38,7 @@ def make_dataset(
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
# min_max_from_spec
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
# normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
stats = {}
@@ -62,14 +62,15 @@ def make_dataset(
transforms = v2.Compose(
[
NormalizeTransform(
stats,
in_keys=[
"observation.state",
"action",
],
mode=normalization_mode,
),
# TODO(now): Use the transform
# NormalizeTransform(
# in_keys=[
# "observation.state",
# "action",
# ],
# mode=normalization_mode,
# ),
IdentityTransform()
]
)
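For reference, a minimal sketch of the two normalization modes the TODO above mentions, assuming stats holds per-key mean/std and min/max tensors (illustrative only, not the exact NormalizeTransform internals):

# Hedged sketch of the forward direction for the two modes discussed above.
# The min_max branch is inferred from the inverse shown in transforms.py below,
# which computes (x + 1) / 2 and then rescales by (max - min) + min.
def normalize(x, key, stats, mode="mean_std"):
    if mode == "mean_std":
        return (x - stats[key]["mean"]) / stats[key]["std"]
    if mode == "min_max":
        min_, max_ = stats[key]["min"], stats[key]["max"]
        return 2 * (x - min_) / (max_ - min_) - 1  # maps values into [-1, 1]
    raise ValueError(f"unknown normalization mode: {mode}")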

View File

@@ -63,3 +63,13 @@ class NormalizeTransform(Transform):
item[outkey] = (item[inkey] + 1) / 2
item[outkey] = item[outkey] * (max - min) + min
return item
class IdentityTransform(Transform):
invertible = True
def forward(self, item):
return item
def inverse_transform(self, item):
return item
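IdentityTransform mirrors the NormalizeTransform interface (a forward pass plus inverse_transform), so it can stand in for it in the pipeline above while normalization is disabled. A rough usage sketch, with illustrative variable names:

# Illustrative only: both transforms expose the same two entry points, so the
# dataset pipeline and any un-normalization call sites stay unchanged.
transform = IdentityTransform()  # stand-in for NormalizeTransform(stats, in_keys=[...], mode=...)
item = transform.forward(item)            # no-op here; normalizes selected keys otherwise
item = transform.inverse_transform(item)  # no-op here; maps predictions back to raw units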

View File

@@ -46,7 +46,6 @@ policy:
kappa: 0.1
lr: 3e-4
std_schedule: ${policy.min_std}
horizon_schedule: ${policy.horizon}
per: true
per_alpha: 0.6
per_beta: 0.4
@@ -79,7 +78,7 @@ policy:
latent_dim: 50
delta_timestamps:
observation.image: "[i / ${fps} for i in range(6)]"
observation.state: "[i / ${fps} for i in range(6)]"
action: "[i / ${fps} for i in range(5)]"
next.reward: "[i / ${fps} for i in range(5)]"
observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
action: "[i / ${fps} for i in range(${policy.horizon})]"
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
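To make the interpolation concrete: with hypothetical values fps = 10 and policy.horizon = 5 (illustrative numbers, not the repo's defaults), the new expressions resolve as follows:

fps, horizon = 10, 5  # hypothetical values for illustration only
observation_timestamps = [i / fps for i in range(horizon + 1)]  # [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
action_timestamps = [i / fps for i in range(horizon)]           # [0.0, 0.1, 0.2, 0.3, 0.4]
# i.e. horizon + 1 observation/state frames and horizon action/reward steps,
# instead of the previously hard-coded range(6) / range(5).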

View File

@@ -157,7 +157,11 @@ def eval_policy(
# get the next action for the environment
with torch.inference_mode():
action = policy.select_action(observation, step=step)
# TODO(now): restore
observation["observation.image"] *= 255
# TODO(now): train_step
action = policy.select_action(observation)
observation["observation.image"] /= 255
# Send action:
while True:
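The temporary *255 / /255 dance above appears to bridge a pixel-range mismatch (the pipeline yielding images in [0, 1] while the policy expects [0, 255]); that reading is an inference from the diff, not confirmed by the commit. A non-mutating sketch of the same idea:

# Hedged sketch: rescale a copy instead of mutating the shared observation dict,
# assuming the policy expects pixel values in [0, 255].
obs_for_policy = dict(observation)
obs_for_policy["observation.image"] = observation["observation.image"] * 255
with torch.inference_mode():
    action = policy.select_action(obs_for_policy)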

View File

@@ -270,7 +270,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
# create dataloader for offline training
dataloader = torch.utils.data.DataLoader(
offline_dataset,
num_workers=4,
num_workers=32,
batch_size=cfg.policy.batch_size,
shuffle=True,
pin_memory=cfg.device != "cpu",
@@ -316,7 +316,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
)
dataloader = torch.utils.data.DataLoader(
concat_dataset,
num_workers=4,
num_workers=32,
batch_size=cfg.policy.batch_size,
sampler=sampler,
pin_memory=cfg.device != "cpu",
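Hard-coding num_workers=32 assumes a machine with many cores; a hedged alternative (not in this commit) would derive the value from the host instead:

import os

# Hypothetical alternative to the hard-coded value: cap workers at the CPU count.
num_workers = min(32, os.cpu_count() or 4)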
@@ -339,14 +339,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
seed=cfg.seed,
)
add_episodes_inplace(
online_dataset,
concat_dataset,
sampler,
hf_dataset=eval_info["episodes"]["hf_dataset"],
episode_data_index=eval_info["episodes"]["episode_data_index"],
pc_online_samples=cfg.get("demo_schedule", 0.5),
)
for _ in range(cfg.policy.utd):
policy.train()
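For context, a rough sketch of the online/offline mixing that pc_online_samples controls, assuming a weighted sampler over the concatenated dataset (an illustration of the idea, not the actual add_episodes_inplace implementation):

# Hedged sketch: weight sampler indices so that roughly pc_online_samples of each
# batch comes from newly collected online episodes and the rest from offline demos.
def mix_weights(num_offline, num_online, pc_online_samples=0.5):
    offline_w = (1 - pc_online_samples) / max(num_offline, 1)
    online_w = pc_online_samples / max(num_online, 1)
    return [offline_w] * num_offline + [online_w] * num_online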