Clean up + make alpha/beta come from the config (previously hardcoded to 0.7 and 0.9)
parent 0cdd23dcac
commit 0b4084f0f8
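The title refers to the two `PrioritizedSliceSampler` calls further down: `alpha` and `beta` were hardcoded to 0.7 and 0.9 and are now read from the Hydra config as `cfg.per_alpha` and `cfg.per_beta`. The config file itself is not part of this diff; as a rough illustration only, the corresponding entries could look like this once loaded through OmegaConf:

```python
# Illustrative sketch, not the project's actual config: only the key names
# per_alpha and per_beta are visible in this diff; the values 0.7 and 0.9
# simply mirror the previously hardcoded numbers.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "per_alpha": 0.7,  # priority exponent for prioritized experience replay
        "per_beta": 0.9,   # importance-sampling exponent
    }
)
assert cfg.per_alpha == 0.7 and cfg.per_beta == 0.9
```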
@@ -21,6 +21,15 @@ python setup.py develop
- [x] self.step=100000 should be updated at every step to adjust to the planner's horizon
- [ ] prefetch replay buffer to speed up training (see sketch below)
- [ ] parallelize env to speed up eval (see sketch below)
- [ ] clean checkpointing / loading
- [ ] clean logging
- [ ] clean config
- [ ] clean hyperparameter tuning
- [ ] add pusht
- [ ] add aloha
- [ ] add act
- [ ] add diffusion
- [ ] add aloha 2
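Neither speed-up item is touched by this commit (the online buffer below still has `prefetch` commented out). A sketch of one possible direction using torchrl's stock features; everything in it, including the placeholder Gym task, is an assumption rather than project code:

```python
# Sketch only: one way the two speed-up TODOs could be tackled with torchrl's
# stock features. Nothing here is project code; the Gym task is a placeholder
# (the real script builds its env with make_env(cfg)).
import torch
from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

# "prefetch replay buffer": prepare the next batches in background threads.
buffer = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(100_000),
    batch_size=256,
    prefetch=3,  # number of batches queued ahead of the training loop
    pin_memory=torch.cuda.is_available(),
)

# "parallelize env": run several env copies in subprocesses so rollouts are batched.
eval_env = ParallelEnv(4, lambda: GymEnv("Pendulum-v1"))
```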

## Contribute
@@ -1,9 +1,6 @@
import pickle
import time
from pathlib import Path

import hydra
import imageio
import numpy as np
import torch
from tensordict.nn import TensorDictModule
@@ -19,7 +16,6 @@ from lerobot.common.logger import Logger
from lerobot.common.tdmpc import TDMPC
from lerobot.common.utils import set_seed
from lerobot.scripts.eval import eval_policy
from rl.torchrl.collectors.collectors import SyncDataCollector


@hydra.main(version_base=None, config_name="default", config_path="../configs")
@@ -30,11 +26,11 @@ def train(cfg: dict):
env = make_env(cfg)
policy = TDMPC(cfg)
-ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
-policy.step = 25000
-# ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
-# policy.step = 100000
-policy.load(ckpt_path)
+# ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
+# policy.step = 25000
+# # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
+# # policy.step = 100000
+# policy.load(ckpt_path)

td_policy = TensorDictModule(
    policy,
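The hunk above is cut off inside the `TensorDictModule(...)` call, so its `in_keys`/`out_keys` are not visible in this diff. For orientation only, this is the general shape of such a wrapper; the network and key names below are placeholders, not the ones used by this script:

```python
# General shape of a TensorDictModule wrapper. The module and the key names are
# placeholders; the actual in_keys/out_keys of td_policy lie outside the hunk above.
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

net = torch.nn.Linear(4, 2)
td_policy = TensorDictModule(net, in_keys=["observation"], out_keys=["action"])

td = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
td = td_policy(td)           # reads "observation", writes "action" into the same TensorDict
print(td["action"].shape)    # torch.Size([8, 2])
```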
@@ -51,8 +47,8 @@ def train(cfg: dict):
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
sampler = PrioritizedSliceSampler(
    max_capacity=100_000,
-   alpha=0.7,
-   beta=0.9,
+   alpha=cfg.per_alpha,
+   beta=cfg.per_beta,
    num_slices=num_traj_per_batch,
    strict_length=False,
)
@@ -74,8 +70,8 @@ def train(cfg: dict):
if cfg.balanced_sampling:
    online_sampler = PrioritizedSliceSampler(
        max_capacity=100_000,
-       alpha=0.7,
-       beta=0.9,
+       alpha=cfg.per_alpha,
+       beta=cfg.per_beta,
        num_slices=num_traj_per_batch,
        strict_length=False,
    )
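For context on what `per_alpha` and `per_beta` control: in prioritized replay, `alpha` sets how strongly sampling follows the stored priorities and `beta` sets the strength of the importance-sampling correction. A minimal, self-contained sketch using torchrl's plain `PrioritizedSampler` (the slice variant used above layers trajectory slicing on the same mechanism; storage size and data here are made up):

```python
# Minimal sketch of the roles of alpha and beta in prioritized replay;
# PrioritizedSliceSampler (used above) adds trajectory slicing on top of this.
import torch
from torchrl.data.replay_buffers import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import PrioritizedSampler

rb = ReplayBuffer(
    storage=LazyTensorStorage(100),
    sampler=PrioritizedSampler(
        max_capacity=100,
        alpha=0.7,  # P(i) ~ priority_i ** alpha; alpha=0 degenerates to uniform sampling
        beta=0.9,   # importance weight ~ (N * P(i)) ** -beta; beta=1 fully corrects the bias
    ),
    batch_size=8,
)
rb.extend(torch.arange(32))
sample, info = rb.sample(return_info=True)  # info carries indices and importance-sampling weights
```

With the two values in the config, they can be tuned per run through Hydra overrides (e.g. `per_alpha=0.6` on the command line) instead of editing the script.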
@@ -83,18 +79,8 @@ def train(cfg: dict):
online_buffer = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(100_000),
    sampler=online_sampler,
    # batch_size=3,
    # pin_memory=False,
    # prefetch=3,
)

# Observation encoder
# Dynamics predictor
# Reward predictor
# Policy
# Qs state-action value predictor
# V state value predictor

L = Logger(cfg.log_dir, cfg)

online_episode_idx = 0
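How a buffer like `online_buffer` is fed and read: a later hunk calls `online_buffer.extend(rollout)` and reads `rollout["next", "reward"]`, so the stored data is a TensorDict with nested `next` fields. A toy round trip, with invented fields and shapes standing in for the collector's real rollout:

```python
# Toy extend/sample round trip on a TensorDictReplayBuffer; field names and
# shapes are invented stand-ins for the collector's rollout, not the real layout.
import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer

buffer = TensorDictReplayBuffer(storage=LazyMemmapStorage(1_000), batch_size=4)

rollout = TensorDict(
    {
        "observation": torch.randn(16, 3),
        "next": {"reward": torch.randn(16, 1), "success": torch.zeros(16, 1, dtype=torch.bool)},
    },
    batch_size=[16],
)
buffer.extend(rollout)                # append 16 transitions
batch = buffer.sample()               # TensorDict with batch_size=[4]
print(batch["next", "reward"].shape)  # torch.Size([4, 1])
```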
@@ -103,9 +89,6 @@ def train(cfg: dict):
last_log_step = 0
last_save_step = 0

# TODO(rcadene): remove
step = 25000

while step < cfg.train_steps:
    is_offline = True
    num_updates = cfg.episode_length
@@ -126,26 +109,11 @@ def train(cfg: dict):
)
online_buffer.extend(rollout)

# Collect trajectory
# obs = env.reset()
# episode = Episode(cfg, obs)
# success = False
# while not episode.done:
#     action = policy.act(obs, step=step, t0=episode.first)
#     obs, reward, done, info = env.step(action.cpu().numpy())
#     reward = reward_normalizer(reward)
#     mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0
#     success = info.get('success', False)
#     episode += (obs, action, reward, done, mask, success)

ep_reward = rollout["next", "reward"].sum()
ep_success = rollout["next", "success"].any()

online_episode_idx += 1
rollout_metrics = {
    # 'episode_reward': episode.cumulative_reward,
    # 'episode_success': float(success),
    # 'episode_length': len(episode)
    "avg_reward": np.nanmean(ep_reward),
    "pc_success": np.nanmean(ep_success) * 100,
}
@@ -190,10 +158,6 @@ def train(cfg: dict):
    # TODO(rcadene): add step, env_step, L.video
)

# TODO(rcadene):
# if hasattr(env, "get_normalized_score"):
#     eval_metrics['normalized_score'] = env.get_normalized_score(eval_metrics["episode_reward"]) * 100.0

common_metrics.update(eval_metrics)

L.log(common_metrics, category="eval")