backup wip
This commit is contained in:
parent
c355737f3d
commit
e69cd99f33
|
@ -4,7 +4,7 @@ from pathlib import Path
|
|||
import torch
|
||||
from torchvision.transforms import v2
|
||||
|
||||
from lerobot.common.transforms import NormalizeTransform
|
||||
from lerobot.common.transforms import IdentityTransform
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
@ -38,7 +38,7 @@ def make_dataset(
|
|||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||
# min_max_from_spec
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
# normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
stats = {}
|
||||
|
@ -62,14 +62,15 @@ def make_dataset(
|
|||
|
||||
transforms = v2.Compose(
|
||||
[
|
||||
NormalizeTransform(
|
||||
stats,
|
||||
in_keys=[
|
||||
"observation.state",
|
||||
"action",
|
||||
],
|
||||
mode=normalization_mode,
|
||||
),
|
||||
# TODO(now): Use the transform
|
||||
# NormalizeTransform(
|
||||
# in_keys=[
|
||||
# "observation.state",
|
||||
# "action",
|
||||
# ],
|
||||
# mode=normalization_mode,
|
||||
# ),
|
||||
IdentityTransform()
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -63,3 +63,13 @@ class NormalizeTransform(Transform):
|
|||
item[outkey] = (item[inkey] + 1) / 2
|
||||
item[outkey] = item[outkey] * (max - min) + min
|
||||
return item
|
||||
|
||||
|
||||
class IdentityTransform(Transform):
|
||||
invertible = True
|
||||
|
||||
def forward(self, item):
|
||||
return item
|
||||
|
||||
def inverse_transform(self, item):
|
||||
return item
|
||||
|
|
|
@ -46,7 +46,6 @@ policy:
|
|||
kappa: 0.1
|
||||
lr: 3e-4
|
||||
std_schedule: ${policy.min_std}
|
||||
horizon_schedule: ${policy.horizon}
|
||||
per: true
|
||||
per_alpha: 0.6
|
||||
per_beta: 0.4
|
||||
|
@ -79,7 +78,7 @@ policy:
|
|||
latent_dim: 50
|
||||
|
||||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(6)]"
|
||||
observation.state: "[i / ${fps} for i in range(6)]"
|
||||
action: "[i / ${fps} for i in range(5)]"
|
||||
next.reward: "[i / ${fps} for i in range(5)]"
|
||||
observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
|
|
|
@ -157,7 +157,11 @@ def eval_policy(
|
|||
|
||||
# get the next action for the environment
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation, step=step)
|
||||
# TODO(now): restore
|
||||
observation["observation.image"] *= 255
|
||||
# TODO(now): train_step
|
||||
action = policy.select_action(observation)
|
||||
observation["observation.image"] /= 255
|
||||
|
||||
# Send action:
|
||||
while True:
|
||||
|
|
|
@ -270,7 +270,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
# create dataloader for offline training
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
offline_dataset,
|
||||
num_workers=4,
|
||||
num_workers=32,
|
||||
batch_size=cfg.policy.batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=cfg.device != "cpu",
|
||||
|
@ -316,7 +316,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
concat_dataset,
|
||||
num_workers=4,
|
||||
num_workers=32,
|
||||
batch_size=cfg.policy.batch_size,
|
||||
sampler=sampler,
|
||||
pin_memory=cfg.device != "cpu",
|
||||
|
@ -339,14 +339,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
seed=cfg.seed,
|
||||
)
|
||||
|
||||
add_episodes_inplace(
|
||||
online_dataset,
|
||||
concat_dataset,
|
||||
sampler,
|
||||
hf_dataset=eval_info["episodes"]["hf_dataset"],
|
||||
episode_data_index=eval_info["episodes"]["episode_data_index"],
|
||||
pc_online_samples=cfg.get("demo_schedule", 0.5),
|
||||
)
|
||||
add_episodes_inplace(
|
||||
online_dataset,
|
||||
concat_dataset,
|
||||
sampler,
|
||||
hf_dataset=eval_info["episodes"]["hf_dataset"],
|
||||
episode_data_index=eval_info["episodes"]["episode_data_index"],
|
||||
pc_online_samples=cfg.get("demo_schedule", 0.5),
|
||||
)
|
||||
|
||||
for _ in range(cfg.policy.utd):
|
||||
policy.train()
|
||||
|
|
Loading…
Reference in New Issue