Wandb works, One output dir

Cadene 2024-02-22 12:14:12 +00:00
parent ece89730e6
commit e3643d6146
11 changed files with 200 additions and 100 deletions

View File

@ -15,6 +15,31 @@ conda activate lerobot
python setup.py develop
```
## Usage
### Train
```
python lerobot/scripts/train.py \
--config-name=pusht hydra.job.name=pusht
```
### Visualize offline buffer
```
python lerobot/scripts/visualize_dataset.py \
--config-name=pusht hydra.run.dir=tmp/$(date +"%Y_%m_%d")
```
### Visualize online buffer / Eval
```
python lerobot/scripts/eval.py \
--config-name=pusht hydra.run.dir=tmp/$(date +"%Y_%m_%d")
```
## TODO
- [x] priority update doesn't match FOWM or the original paper

View File

@ -11,6 +11,7 @@ def make_env(cfg):
"from_pixels": cfg.from_pixels,
"pixels_only": cfg.pixels_only,
"image_size": cfg.image_size,
"max_episode_length": cfg.episode_length,
}
if cfg.env == "simxarm":

View File

@ -29,7 +29,7 @@ class PushtEnv(EnvBase):
image_size=None,
seed=1337,
device="cpu",
max_episode_length=25, # TODO: verify
max_episode_length=300,
):
super().__init__(device=device, batch_size=[])
self.frame_skip = frame_skip
@ -53,13 +53,11 @@ class PushtEnv(EnvBase):
if not from_pixels:
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
from gym.wrappers import TimeLimit
self._env = PushTImageEnv(render_size=self.image_size)
self._env = TimeLimit(self._env, self.max_episode_length)
self._make_spec()
self.set_seed(seed)
self._current_seed = self.set_seed(seed)
def render(self, mode="rgb_array", width=384, height=384):
if width != height:
@ -90,7 +88,11 @@ class PushtEnv(EnvBase):
def _reset(self, tensordict: Optional[TensorDict] = None):
td = tensordict
if td is None or td.is_empty():
# we need to handle seed iteration, since self._env.reset() relies on an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
raw_obs = self._env.reset()
assert self._current_seed == self._env._seed
td = TensorDict(
{
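The reset logic above advances the seed on every call so consecutive episodes are reproducible but not identical; the assert checks that the wrapped PushTImageEnv actually picked up the new value through its internal `_seed`. A minimal standalone sketch of the same pattern, assuming a gym-style env that exposes `seed()` and stores the last value in `_seed`:
```python
# Minimal sketch (hypothetical wrapper, not the repo's PushtEnv): bump the seed
# before each reset so episodes are deterministic but distinct, then verify the
# wrapped env really consumed it.
class SeedIteratingReset:
    def __init__(self, env, seed: int = 1337):
        self._env = env
        self._current_seed = seed
        self._env.seed(seed)

    def reset(self):
        self._current_seed += 1
        self._env.seed(self._current_seed)
        obs = self._env.reset()
        # assumption: the wrapped env exposes the last seed it was given as `_seed`
        assert self._current_seed == self._env._seed
        return obs
```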

View File

@ -49,7 +49,6 @@ class SimxarmEnv(EnvBase):
raise ImportError("Cannot import gym.")
import gym
from gym.wrappers import TimeLimit
from simxarm import TASKS
if self.task not in TASKS:
@ -58,7 +57,6 @@ class SimxarmEnv(EnvBase):
)
self._env = TASKS[self.task]["env"]()
self._env = TimeLimit(self._env, TASKS[self.task]["episode_length"])
MAX_NUM_ACTIONS = 4
num_actions = len(TASKS[self.task]["action_space"])

View File

@ -11,8 +11,9 @@ from termcolor import colored
CONSOLE_FORMAT = [
("episode", "E", "int"),
("env_step", "S", "int"),
("avg_reward", "R", "float"),
("pc_success", "R", "float"),
("avg_sum_reward", "RS", "float"),
("avg_max_reward", "RM", "float"),
("pc_success", "S", "float"),
("total_time", "T", "time"),
]
AGENT_METRICS = [
@ -69,7 +70,11 @@ def print_run(cfg, reward=None):
def cfg_to_group(cfg, return_list=False):
"""Return a wandb-safe group name for logging. Optionally returns group name as list."""
lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
# lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
lst = [
f"env:{cfg.env}",
f"seed:{cfg.seed}",
]
return lst if return_list else "-".join(lst)
@ -120,8 +125,9 @@ class VideoRecorder:
class Logger(object):
"""Primary logger object. Logs either locally or using wandb."""
def __init__(self, log_dir, cfg):
def __init__(self, log_dir, job_name, cfg):
self._log_dir = make_dir(Path(log_dir))
self._job_name = job_name
self._model_dir = make_dir(self._log_dir / "models")
self._buffer_dir = make_dir(self._log_dir / "buffers")
self._save_model = cfg.save_model
@ -131,9 +137,8 @@ class Logger(object):
self._cfg = cfg
self._eval = []
print_run(cfg)
project, entity = cfg.get("wandb_project", "none"), cfg.get(
"wandb_entity", "none"
)
project = cfg.get("wandb_project", "none")
entity = cfg.get("wandb_entity", "none")
run_offline = (
not cfg.get("use_wandb", False) or project == "none" or entity == "none"
)
@ -141,35 +146,39 @@ class Logger(object):
print(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
self._wandb = None
else:
try:
os.environ["WANDB_SILENT"] = "true"
import wandb
# try:
os.environ["WANDB_SILENT"] = "true"
import wandb
wandb.init(
project=project,
entity=entity,
name=str(cfg.seed),
notes=cfg.notes,
group=self._group,
tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"],
dir=self._log_dir,
config=OmegaConf.to_container(cfg, resolve=True),
)
print(
colored("Logs will be synced with wandb.", "blue", attrs=["bold"])
)
self._wandb = wandb
except:
print(
colored(
"Warning: failed to init wandb. Make sure `wandb_entity` is set to your username in `config.yaml`. Logs will be saved locally.",
"yellow",
attrs=["bold"],
)
)
self._wandb = None
wandb.init(
project=project,
entity=entity,
name=job_name,
notes=cfg.notes,
# group=self._group,
tags=cfg_to_group(cfg, return_list=True),
dir=self._log_dir,
config=OmegaConf.to_container(cfg, resolve=True),
# TODO(rcadene): try set to True
save_code=False,
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
job_type="train_eval",
# TODO(rcadene): add resume option
resume=None,
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
self._wandb = wandb
# except:
# print(
# colored(
# "Warning: failed to init wandb. Make sure `wandb_entity` is set to your username in `config.yaml`. Logs will be saved locally.",
# "yellow",
# attrs=["bold"],
# )
# )
# self._wandb = None
self._video = (
VideoRecorder(log_dir, self._wandb)
VideoRecorder(self._log_dir, self._wandb)
if self._wandb and cfg.save_video
else None
)
@ -235,7 +244,7 @@ class Logger(object):
self._wandb.log({category + "/" + k: v}, step=d["env_step"])
if category == "eval":
# keys = ['env_step', 'avg_reward']
keys = ["env_step", "avg_reward", "pc_success"]
keys = ["env_step", "avg_sum_reward", "avg_max_reward", "pc_success"]
self._eval.append(np.array([d[key] for key in keys]))
pd.DataFrame(np.array(self._eval)).to_csv(
self._log_dir / "eval.log", header=keys, index=None
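With the new grouping scheme, wandb runs are tagged and grouped by environment and seed rather than by task/modality/exp_name. A minimal sketch of what `cfg_to_group` now produces (standalone re-implementation with hypothetical values, mirroring the function above):
```python
# Standalone copy of the new cfg_to_group behaviour, for illustration only.
def cfg_to_group(cfg: dict, return_list: bool = False):
    lst = [f"env:{cfg['env']}", f"seed:{cfg['seed']}"]
    return lst if return_list else "-".join(lst)

cfg = {"env": "pusht", "seed": 1337}  # hypothetical config values
assert cfg_to_group(cfg) == "env:pusht-seed:1337"                         # group name
assert cfg_to_group(cfg, return_list=True) == ["env:pusht", "seed:1337"]  # wandb tags
```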

View File

@ -96,7 +96,7 @@ class TDMPC(nn.Module):
self.model_target.eval()
self.batch_size = cfg.batch_size
self.step = 0
self.register_buffer("step", torch.zeros(1))
def state_dict(self):
"""Retrieve state dict of TOLD model, including slow-moving target network."""
@ -122,7 +122,7 @@ class TDMPC(nn.Module):
"rgb": observation["image"],
"state": observation["state"],
}
return self.act(obs, t0=t0, step=self.step)
return self.act(obs, t0=t0, step=self.step.item())
@torch.no_grad()
def act(self, obs, t0=False, step=None):
@ -513,5 +513,5 @@ class TDMPC(nn.Module):
metrics.update(value_info)
metrics.update(pi_update_info)
self.step = step
self.step[0] = step
return metrics
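Registering `step` as a buffer instead of keeping a plain Python int means the counter is included in the module's `state_dict`, so it is saved and restored with checkpoints; `.item()` and indexed assignment are then used to read and write it, as in the two hunks above. A minimal sketch of the mechanism, independent of TDMPC:
```python
import torch
from torch import nn

class Counter(nn.Module):
    def __init__(self):
        super().__init__()
        # buffers travel with state_dict() but are not trainable parameters
        self.register_buffer("step", torch.zeros(1))

m = Counter()
m.step[0] = 25000                              # write, as in `self.step[0] = step`
assert m.step.item() == 25000                  # read, as in `self.step.item()`
assert m.state_dict()["step"].item() == 25000  # persisted with the checkpoint
```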

View File

@ -1,7 +1,10 @@
hydra:
run:
dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}_${hydra.job.name}
job:
name: default
seed: 1337
log_dir: logs/2024_01_26_train
video_dir: tmp/2024_01_26_xarm_lift_medium
exp_name: default
device: cuda
buffer_device: cuda
eval_freq: 1000
@ -21,6 +24,7 @@ fps: 15
reward_scale: 1.0
# xarm_lift
episode_length: 25
modality: 'all'
@ -97,6 +101,7 @@ mlp_dim: 512
latent_dim: 50
# wandb
use_wandb: false
wandb_project: FOWM
use_wandb: true
wandb_project: lerobot
wandb_entity: rcadene # insert your own
notes: ""
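The new `hydra.run.dir` pattern places every run in a timestamped folder suffixed with the job name, e.g. something like `outputs/2024-02-22/12-14-12_pusht` when launched with `hydra.job.name=pusht`. Scripts decorated with `@hydra.main` can read the resolved directory back at runtime, which is how `train.py` and `eval.py` recover their output dir below; a minimal sketch:
```python
# Minimal sketch; only valid inside a running Hydra app, where HydraConfig is populated.
import hydra
from hydra.core.hydra_config import HydraConfig

@hydra.main(version_base=None, config_name="default", config_path="../configs")
def main(cfg):
    out_dir = HydraConfig.get().run.dir    # resolved outputs/<date>/<time>_<job name>
    job_name = HydraConfig.get().job.name  # "default" unless overridden via hydra.job.name=...
    print(out_dir, job_name)

if __name__ == "__main__":
    main()
```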

View File

@ -5,12 +5,16 @@ hydra:
job:
name: pusht
video_dir: tmp/2024_02_21_pusht
# env
env: pusht
task: pusht
image_size: 96
frame_skip: 1
state_dim: 2
action_dim: 2
fps: 10
fps: 10
eval_episodes: 50
episode_length: 300
eval_freq: 7500
save_freq: 75000

View File

@ -21,19 +21,22 @@ def eval_policy(
save_video: bool = False,
video_dir: Path = None,
fps: int = 15,
env_step: int = None,
wandb=None,
):
rewards = []
if wandb is not None:
assert env_step is not None
sum_rewards = []
max_rewards = []
successes = []
for i in range(num_episodes):
ep_frames = []
def rendering_callback(env, td=None):
nonlocal ep_frames
frame = env.render()
ep_frames.append(frame)
ep_frames.append(env.render())
tensordict = env.reset()
if save_video:
if save_video or wandb:
# render first frame before rollout
rendering_callback(env)
@ -41,35 +44,54 @@ def eval_policy(
rollout = env.rollout(
max_steps=max_steps,
policy=policy,
callback=rendering_callback if save_video else None,
callback=rendering_callback if save_video or wandb else None,
auto_reset=False,
tensordict=tensordict,
auto_cast_to_device=True,
)
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
ep_reward = rollout["next", "reward"].sum()
ep_sum_reward = rollout["next", "reward"].sum()
ep_max_reward = rollout["next", "reward"].max()
ep_success = rollout["next", "success"].any()
rewards.append(ep_reward.item())
sum_rewards.append(ep_sum_reward.item())
max_rewards.append(ep_max_reward.item())
successes.append(ep_success.item())
if save_video:
video_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make fps configurable
video_path = video_dir / f"eval_episode_{i}.mp4"
imageio.mimsave(video_path, np.stack(ep_frames), fps=fps)
if save_video or wandb:
stacked_frames = np.stack(ep_frames)
if save_video:
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{i}.mp4"
imageio.mimsave(video_path, stacked_frames, fps=fps)
first_episode = i == 0
if wandb and first_episode:
eval_video = wandb.Video(
stacked_frames.transpose(0, 3, 1, 2), fps=fps, format="mp4"
)
wandb.log({"eval_video": eval_video}, step=env_step)
metrics = {
"avg_reward": np.nanmean(rewards),
"avg_sum_reward": np.nanmean(sum_rewards),
"avg_max_reward": np.nanmean(max_rewards),
"pc_success": np.nanmean(successes) * 100,
}
return metrics
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def eval(cfg: dict):
def eval_cli(cfg: dict):
eval(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
def eval(cfg: dict, out_dir=None):
if out_dir is None:
raise NotImplementedError()
assert torch.cuda.is_available()
set_seed(cfg.seed)
print(colored("Log dir:", "yellow", attrs=["bold"]), cfg.log_dir)
print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)
env = make_env(cfg)
@ -95,13 +117,14 @@ def eval(cfg: dict):
metrics = eval_policy(
env,
policy=policy,
num_episodes=20,
save_video=True,
video_dir=Path(cfg.video_dir),
video_dir=Path(out_dir) / "eval",
fps=cfg.fps,
max_steps=cfg.episode_length,
num_episodes=cfg.eval_episodes,
)
print(metrics)
if __name__ == "__main__":
eval()
eval_cli()
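For the wandb upload, only the first episode's frames are logged, and they are transposed because `wandb.Video` expects numpy input shaped (time, channels, height, width) while the renderer returns (height, width, channels) frames. A minimal sketch with dummy frames (shapes are placeholders):
```python
# Minimal sketch, assuming env.render() yields HxWxC uint8 frames as in eval_policy above.
import numpy as np
import wandb

ep_frames = [np.zeros((96, 96, 3), dtype=np.uint8) for _ in range(10)]  # dummy episode
stacked_frames = np.stack(ep_frames)                       # (time, H, W, C)
video = wandb.Video(stacked_frames.transpose(0, 3, 1, 2),  # -> (time, C, H, W)
                    fps=10, format="mp4")
# wandb.log({"eval_video": video}, step=env_step)  # needs an active wandb run
```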

View File

@ -20,24 +20,47 @@ from lerobot.scripts.eval import eval_policy
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def train(cfg: dict):
def train_cli(cfg: dict):
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
def train_notebook(
out_dir=None, job_name=None, config_name="default", config_path="../configs"
):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=config_path)
cfg = compose(config_name=config_name)
train(cfg, out_dir=out_dir, job_name=job_name)
def train(cfg: dict, out_dir=None, job_name=None):
if out_dir is None:
raise NotImplementedError()
if job_name is None:
raise NotImplementedError()
assert torch.cuda.is_available()
set_seed(cfg.seed)
print(colored("Work dir:", "yellow", attrs=["bold"]), cfg.log_dir)
print(colored("Work dir:", "yellow", attrs=["bold"]), out_dir)
env = make_env(cfg)
policy = TDMPC(cfg)
if cfg.pretrained_model_path:
ckpt_path = (
"/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
)
if "offline" in cfg.pretrained_model_path:
policy.step = 25000
elif "final" in cfg.pretrained_model_path:
policy.step = 100000
else:
raise NotImplementedError()
policy.load(ckpt_path)
# TODO(rcadene): hack for old pretrained models from fowm
if "fowm" in cfg.pretrained_model_path:
if "offline" in cfg.pretrained_model_path:
policy.step = 25000
elif "final" in cfg.pretrained_model_path:
policy.step = 100000
else:
raise NotImplementedError()
policy.load(cfg.pretrained_model_path)
td_policy = TensorDictModule(
policy,
@ -65,7 +88,7 @@ def train(cfg: dict):
sampler=online_sampler,
)
L = Logger(cfg.log_dir, cfg)
L = Logger(out_dir, job_name, cfg)
online_episode_idx = 0
start_time = time.time()
@ -95,12 +118,14 @@ def train(cfg: dict):
)
online_buffer.extend(rollout)
ep_reward = rollout["next", "reward"].sum()
ep_sum_reward = rollout["next", "reward"].sum()
ep_max_reward = rollout["next", "reward"].max()
ep_success = rollout["next", "success"].any()
online_episode_idx += 1
rollout_metrics = {
"avg_reward": np.nanmean(ep_reward),
"avg_sum_reward": np.nanmean(ep_sum_reward),
"avg_max_reward": np.nanmean(ep_max_reward),
"pc_success": np.nanmean(ep_success) * 100,
}
num_updates = len(rollout) * cfg.utd
@ -137,23 +162,23 @@ def train(cfg: dict):
env,
td_policy,
num_episodes=cfg.eval_episodes,
# TODO(rcadene): add step, env_step, L.video
env_step=env_step,
wandb=L._wandb,
)
common_metrics.update(eval_metrics)
L.log(common_metrics, category="eval")
last_log_step = env_step - env_step % cfg.eval_freq
# Save model periodically
# if cfg.save_model and env_step - last_save_step >= cfg.save_freq:
# L.save_model(policy, identifier=env_step)
# print(f"Model has been checkpointed at step {env_step}")
# last_save_step = env_step - env_step % cfg.save_freq
if cfg.save_model and env_step - last_save_step >= cfg.save_freq:
L.save_model(policy, identifier=env_step)
print(f"Model has been checkpointed at step {env_step}")
last_save_step = env_step - env_step % cfg.save_freq
# if cfg.save_model and is_offline and _step >= cfg.offline_steps:
# # save the model after offline training
# L.save_model(policy, identifier="offline")
if cfg.save_model and is_offline and _step >= cfg.offline_steps:
# save the model after offline training
L.save_model(policy, identifier="offline")
step = _step
@ -177,4 +202,4 @@ def train(cfg: dict):
if __name__ == "__main__":
train()
train_cli()
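The new `train_notebook` entry point composes the config with Hydra's Python API so training can be started without the CLI; because there is no Hydra run directory in that case, `out_dir` and `job_name` must be passed explicitly. A hypothetical call from a notebook (all values are placeholders):
```python
# Hypothetical usage sketch; paths and names are placeholders, not values from the repo.
from lerobot.scripts.train import train_notebook

train_notebook(
    out_dir="outputs/notebook_run",  # placeholder output directory
    job_name="pusht_notebook",       # placeholder job name
    config_name="pusht",
    config_path="../configs",
)
```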

View File

@ -15,7 +15,15 @@ from lerobot.common.datasets.factory import make_offline_buffer
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def visualize_dataset(cfg: dict):
def visualize_dataset_cli(cfg: dict):
visualize_dataset(
cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
)
def visualize_dataset(cfg: dict, out_dir=None):
if out_dir is None:
raise NotImplementedError()
sampler = SliceSamplerWithoutReplacement(
num_slices=1,
@ -40,10 +48,10 @@ def visualize_dataset(cfg: dict):
dim=0,
)
video_dir = Path(cfg.video_dir)
video_dir = Path(out_dir) / "visualize_dataset"
video_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make fps configurable
video_path = video_dir / f"eval_episode_{ep_idx}.mp4"
video_path = video_dir / f"episode_{ep_idx}.mp4"
assert ep_frames.min().item() >= 0
assert ep_frames.max().item() > 1, "Not mandatory, but sanity check"
@ -59,4 +67,4 @@ def visualize_dataset(cfg: dict):
if __name__ == "__main__":
visualize_dataset()
visualize_dataset_cli()