Wandb works, One output dir

Cadene 2024-02-22 12:14:12 +00:00
parent ece89730e6
commit e3643d6146
11 changed files with 200 additions and 100 deletions

View File

@@ -15,6 +15,31 @@ conda activate lerobot
 python setup.py develop
 ```
+
+## Usage
+
+### Train
+
+```
+python lerobot/scripts/train.py \
+--config-name=pusht hydra.job.name=pusht
+```
+
+### Visualize offline buffer
+
+```
+python lerobot/scripts/visualize_dataset.py \
+--config-name=pusht hydra.run.dir=tmp/$(date +"%Y_%m_%d")
+```
+
+### Visualize online buffer / Eval
+
+```
+python lerobot/scripts/eval.py \
+--config-name=pusht hydra.run.dir=tmp/$(date +"%Y_%m_%d")
+```
+
 ## TODO
 - [x] priority update doesnt match FOWM or original paper

View File

@@ -11,6 +11,7 @@ def make_env(cfg):
         "from_pixels": cfg.from_pixels,
         "pixels_only": cfg.pixels_only,
         "image_size": cfg.image_size,
+        "max_episode_length": cfg.episode_length,
     }

     if cfg.env == "simxarm":

View File

@@ -29,7 +29,7 @@ class PushtEnv(EnvBase):
         image_size=None,
         seed=1337,
         device="cpu",
-        max_episode_length=25,  # TODO: verify
+        max_episode_length=300,
     ):
         super().__init__(device=device, batch_size=[])
         self.frame_skip = frame_skip
@@ -53,13 +53,11 @@ class PushtEnv(EnvBase):
         if not from_pixels:
             raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
         from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
-        from gym.wrappers import TimeLimit

         self._env = PushTImageEnv(render_size=self.image_size)
-        self._env = TimeLimit(self._env, self.max_episode_length)

         self._make_spec()
-        self.set_seed(seed)
+        self._current_seed = self.set_seed(seed)

     def render(self, mode="rgb_array", width=384, height=384):
         if width != height:
@@ -90,7 +88,11 @@ class PushtEnv(EnvBase):
     def _reset(self, tensordict: Optional[TensorDict] = None):
         td = tensordict
         if td is None or td.is_empty():
+            # we need to handle seed iteration, since self._env.reset() relies on an internal _seed
+            self._current_seed += 1
+            self.set_seed(self._current_seed)
             raw_obs = self._env.reset()
+            assert self._current_seed == self._env._seed
             td = TensorDict(
                 {

View File

@@ -49,7 +49,6 @@ class SimxarmEnv(EnvBase):
             raise ImportError("Cannot import gym.")

         import gym
-        from gym.wrappers import TimeLimit
         from simxarm import TASKS

         if self.task not in TASKS:
@@ -58,7 +57,6 @@ class SimxarmEnv(EnvBase):
             )

         self._env = TASKS[self.task]["env"]()
-        self._env = TimeLimit(self._env, TASKS[self.task]["episode_length"])

         MAX_NUM_ACTIONS = 4
         num_actions = len(TASKS[self.task]["action_space"])
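With gym's TimeLimit wrapper removed from both environments, episode length is no longer capped inside the env. Instead, make_env forwards cfg.episode_length as max_episode_length, and eval.py below passes max_steps=cfg.episode_length to env.rollout().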

View File

@@ -11,8 +11,9 @@ from termcolor import colored

 CONSOLE_FORMAT = [
     ("episode", "E", "int"),
     ("env_step", "S", "int"),
-    ("avg_reward", "R", "float"),
-    ("pc_success", "R", "float"),
+    ("avg_sum_reward", "RS", "float"),
+    ("avg_max_reward", "RM", "float"),
+    ("pc_success", "S", "float"),
     ("total_time", "T", "time"),
 ]
 AGENT_METRICS = [
@@ -69,7 +70,11 @@ def print_run(cfg, reward=None):

 def cfg_to_group(cfg, return_list=False):
     """Return a wandb-safe group name for logging. Optionally returns group name as list."""
-    lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
+    # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
+    lst = [
+        f"env:{cfg.env}",
+        f"seed:{cfg.seed}",
+    ]
     return lst if return_list else "-".join(lst)
@@ -120,8 +125,9 @@ class VideoRecorder:

 class Logger(object):
     """Primary logger object. Logs either locally or using wandb."""

-    def __init__(self, log_dir, cfg):
+    def __init__(self, log_dir, job_name, cfg):
         self._log_dir = make_dir(Path(log_dir))
+        self._job_name = job_name
         self._model_dir = make_dir(self._log_dir / "models")
         self._buffer_dir = make_dir(self._log_dir / "buffers")
         self._save_model = cfg.save_model
@@ -131,9 +137,8 @@ class Logger(object):
         self._cfg = cfg
         self._eval = []
         print_run(cfg)
-        project, entity = cfg.get("wandb_project", "none"), cfg.get(
-            "wandb_entity", "none"
-        )
+        project = cfg.get("wandb_project", "none")
+        entity = cfg.get("wandb_entity", "none")
         run_offline = (
             not cfg.get("use_wandb", False) or project == "none" or entity == "none"
         )
@@ -141,35 +146,39 @@ class Logger(object):
             print(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
             self._wandb = None
         else:
-            try:
-                os.environ["WANDB_SILENT"] = "true"
-                import wandb
-
-                wandb.init(
-                    project=project,
-                    entity=entity,
-                    name=str(cfg.seed),
-                    notes=cfg.notes,
-                    group=self._group,
-                    tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"],
-                    dir=self._log_dir,
-                    config=OmegaConf.to_container(cfg, resolve=True),
-                )
-                print(
-                    colored("Logs will be synced with wandb.", "blue", attrs=["bold"])
-                )
-                self._wandb = wandb
-            except:
-                print(
-                    colored(
-                        "Warning: failed to init wandb. Make sure `wandb_entity` is set to your username in `config.yaml`. Logs will be saved locally.",
-                        "yellow",
-                        attrs=["bold"],
-                    )
-                )
-                self._wandb = None
+            # try:
+            os.environ["WANDB_SILENT"] = "true"
+            import wandb
+
+            wandb.init(
+                project=project,
+                entity=entity,
+                name=job_name,
+                notes=cfg.notes,
+                # group=self._group,
+                tags=cfg_to_group(cfg, return_list=True),
+                dir=self._log_dir,
+                config=OmegaConf.to_container(cfg, resolve=True),
+                # TODO(rcadene): try set to True
+                save_code=False,
+                # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
+                job_type="train_eval",
+                # TODO(rcadene): add resume option
+                resume=None,
+            )
+            print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
+            self._wandb = wandb
+            # except:
+            #     print(
+            #         colored(
+            #             "Warning: failed to init wandb. Make sure `wandb_entity` is set to your username in `config.yaml`. Logs will be saved locally.",
+            #             "yellow",
+            #             attrs=["bold"],
+            #         )
+            #     )
+            #     self._wandb = None
         self._video = (
-            VideoRecorder(log_dir, self._wandb)
+            VideoRecorder(self._log_dir, self._wandb)
             if self._wandb and cfg.save_video
             else None
         )
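A standalone sketch of the wandb.init() pattern above, for readers unfamiliar with the API. All keyword arguments are standard wandb.init() parameters; the project, entity, and logged values are placeholders, and mode="offline" is an assumption added only so the sketch runs without an account:

```python
import os

os.environ["WANDB_SILENT"] = "true"  # keep wandb quiet on the console
import wandb

wandb.init(
    project="my-project",    # placeholder, wandb_project in the config
    entity="my-entity",      # placeholder, wandb_entity in the config
    name="pusht",            # run name, taken from the hydra job name
    job_type="train_eval",   # one run type until train and eval are split
    save_code=False,
    config={"seed": 1337},   # the resolved hydra config in the real Logger
    mode="offline",          # assumption: lets the sketch run without login
)
wandb.log({"eval/avg_sum_reward": 0.0}, step=0)  # placeholder metric
```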
@@ -235,7 +244,7 @@ class Logger(object):
                 self._wandb.log({category + "/" + k: v}, step=d["env_step"])
         if category == "eval":
             # keys = ['env_step', 'avg_reward']
-            keys = ["env_step", "avg_reward", "pc_success"]
+            keys = ["env_step", "avg_sum_reward", "avg_max_reward", "pc_success"]
             self._eval.append(np.array([d[key] for key in keys]))
             pd.DataFrame(np.array(self._eval)).to_csv(
                 self._log_dir / "eval.log", header=keys, index=None

View File

@@ -96,7 +96,7 @@ class TDMPC(nn.Module):
         self.model_target.eval()
         self.batch_size = cfg.batch_size

-        self.step = 0
+        self.register_buffer("step", torch.zeros(1))

     def state_dict(self):
         """Retrieve state dict of TOLD model, including slow-moving target network."""
@@ -122,7 +122,7 @@ class TDMPC(nn.Module):
             "rgb": observation["image"],
             "state": observation["state"],
         }
-        return self.act(obs, t0=t0, step=self.step)
+        return self.act(obs, t0=t0, step=self.step.item())

     @torch.no_grad()
     def act(self, obs, t0=False, step=None):
@@ -513,5 +513,5 @@ class TDMPC(nn.Module):
         metrics.update(value_info)
         metrics.update(pi_update_info)

-        self.step = step
+        self.step[0] = step

         return metrics
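Why register_buffer instead of a plain integer attribute: buffers are included in state_dict() and follow the module across device moves, so the training step now survives checkpointing. A minimal sketch with a hypothetical module (not TDMPC itself):

```python
import torch
from torch import nn


class Policy(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent counter, saved and restored with the model weights
        self.register_buffer("step", torch.zeros(1))


policy = Policy()
policy.step[0] = 25000
ckpt = policy.state_dict()  # contains "step"

restored = Policy()
restored.load_state_dict(ckpt)
assert restored.step.item() == 25000
```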

View File

@@ -1,7 +1,10 @@
+hydra:
+  run:
+    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}_${hydra.job.name}
+  job:
+    name: default
+
 seed: 1337
-log_dir: logs/2024_01_26_train
-video_dir: tmp/2024_01_26_xarm_lift_medium
-exp_name: default
 device: cuda
 buffer_device: cuda
 eval_freq: 1000
@@ -21,6 +24,7 @@ fps: 15
 reward_scale: 1.0

 # xarm_lift
 episode_length: 25
 modality: 'all'
@@ -97,6 +101,7 @@ mlp_dim: 512
 latent_dim: 50

 # wandb
-use_wandb: false
-wandb_project: FOWM
+use_wandb: true
+wandb_project: lerobot
 wandb_entity: rcadene # insert your own
+notes: ""

View File

@@ -5,12 +5,16 @@ hydra:
   job:
     name: pusht

-video_dir: tmp/2024_02_21_pusht
-
 # env
 env: pusht
+task: pusht
 image_size: 96
 frame_skip: 1
 state_dim: 2
 action_dim: 2
 fps: 10
+
+eval_episodes: 50
+episode_length: 300
+eval_freq: 7500
+save_freq: 75000

View File

@@ -21,19 +21,22 @@ def eval_policy(
     save_video: bool = False,
     video_dir: Path = None,
     fps: int = 15,
+    env_step: int = None,
+    wandb=None,
 ):
-    rewards = []
+    if wandb is not None:
+        assert env_step is not None
+    sum_rewards = []
+    max_rewards = []
     successes = []
     for i in range(num_episodes):
         ep_frames = []

         def rendering_callback(env, td=None):
-            nonlocal ep_frames
-            frame = env.render()
-            ep_frames.append(frame)
+            ep_frames.append(env.render())

         tensordict = env.reset()
-        if save_video:
+        if save_video or wandb:
             # render first frame before rollout
             rendering_callback(env)
@@ -41,35 +44,54 @@ def eval_policy(
         rollout = env.rollout(
             max_steps=max_steps,
             policy=policy,
-            callback=rendering_callback if save_video else None,
+            callback=rendering_callback if save_video or wandb else None,
             auto_reset=False,
             tensordict=tensordict,
             auto_cast_to_device=True,
         )
         # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
-        ep_reward = rollout["next", "reward"].sum()
+        ep_sum_reward = rollout["next", "reward"].sum()
+        ep_max_reward = rollout["next", "reward"].max()
         ep_success = rollout["next", "success"].any()
-        rewards.append(ep_reward.item())
+        sum_rewards.append(ep_sum_reward.item())
+        max_rewards.append(ep_max_reward.item())
         successes.append(ep_success.item())
-        if save_video:
-            video_dir.mkdir(parents=True, exist_ok=True)
-            # TODO(rcadene): make fps configurable
-            video_path = video_dir / f"eval_episode_{i}.mp4"
-            imageio.mimsave(video_path, np.stack(ep_frames), fps=fps)
+        if save_video or wandb:
+            stacked_frames = np.stack(ep_frames)
+
+            if save_video:
+                video_dir.mkdir(parents=True, exist_ok=True)
+                video_path = video_dir / f"eval_episode_{i}.mp4"
+                imageio.mimsave(video_path, stacked_frames, fps=fps)
+
+            first_episode = i == 0
+            if wandb and first_episode:
+                eval_video = wandb.Video(
+                    stacked_frames.transpose(0, 3, 1, 2), fps=fps, format="mp4"
+                )
+                wandb.log({"eval_video": eval_video}, step=env_step)

     metrics = {
-        "avg_reward": np.nanmean(rewards),
+        "avg_sum_reward": np.nanmean(sum_rewards),
+        "avg_max_reward": np.nanmean(max_rewards),
         "pc_success": np.nanmean(successes) * 100,
     }
     return metrics


 @hydra.main(version_base=None, config_name="default", config_path="../configs")
-def eval(cfg: dict):
+def eval_cli(cfg: dict):
+    eval(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
+
+
+def eval(cfg: dict, out_dir=None):
+    if out_dir is None:
+        raise NotImplementedError()
     assert torch.cuda.is_available()
     set_seed(cfg.seed)
-    print(colored("Log dir:", "yellow", attrs=["bold"]), cfg.log_dir)
+    print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)

     env = make_env(cfg)
@@ -95,13 +117,14 @@ def eval(cfg: dict):
     metrics = eval_policy(
         env,
         policy=policy,
-        num_episodes=20,
         save_video=True,
-        video_dir=Path(cfg.video_dir),
+        video_dir=Path(out_dir) / "eval",
         fps=cfg.fps,
+        max_steps=cfg.episode_length,
+        num_episodes=cfg.eval_episodes,
     )
     print(metrics)


 if __name__ == "__main__":
-    eval()
+    eval_cli()
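A note on the transpose in eval_policy above: wandb.Video expects arrays shaped (time, channel, height, width), while the stacked env.render() frames come out as (time, height, width, channel). A minimal sketch with synthetic frames (shapes are placeholders):

```python
import numpy as np
import wandb

frames = np.zeros((300, 96, 96, 3), dtype=np.uint8)  # (T, H, W, C) from env.render()
video = wandb.Video(frames.transpose(0, 3, 1, 2), fps=10, format="mp4")  # (T, C, H, W)
```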

View File

@@ -20,24 +20,47 @@ from lerobot.scripts.eval import eval_policy

 @hydra.main(version_base=None, config_name="default", config_path="../configs")
-def train(cfg: dict):
+def train_cli(cfg: dict):
+    train(
+        cfg,
+        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
+        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
+    )
+
+
+def train_notebook(
+    out_dir=None, job_name=None, config_name="default", config_path="../configs"
+):
+    from hydra import compose, initialize
+
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    initialize(config_path=config_path)
+    cfg = compose(config_name=config_name)
+    train(cfg, out_dir=out_dir, job_name=job_name)
+
+
+def train(cfg: dict, out_dir=None, job_name=None):
+    if out_dir is None:
+        raise NotImplementedError()
+    if job_name is None:
+        raise NotImplementedError()
     assert torch.cuda.is_available()
     set_seed(cfg.seed)
-    print(colored("Work dir:", "yellow", attrs=["bold"]), cfg.log_dir)
+    print(colored("Work dir:", "yellow", attrs=["bold"]), out_dir)

     env = make_env(cfg)
     policy = TDMPC(cfg)
     if cfg.pretrained_model_path:
-        ckpt_path = (
-            "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
-        )
-        if "offline" in cfg.pretrained_model_path:
-            policy.step = 25000
-        elif "final" in cfg.pretrained_model_path:
-            policy.step = 100000
-        else:
-            raise NotImplementedError()
-        policy.load(ckpt_path)
+        # TODO(rcadene): hack for old pretrained models from fowm
+        if "fowm" in cfg.pretrained_model_path:
+            if "offline" in cfg.pretrained_model_path:
+                policy.step = 25000
+            elif "final" in cfg.pretrained_model_path:
+                policy.step = 100000
+            else:
+                raise NotImplementedError()
+        policy.load(cfg.pretrained_model_path)

     td_policy = TensorDictModule(
         policy,
@@ -65,7 +88,7 @@ def train(cfg: dict):
         sampler=online_sampler,
     )

-    L = Logger(cfg.log_dir, cfg)
+    L = Logger(out_dir, job_name, cfg)

     online_episode_idx = 0
     start_time = time.time()
@@ -95,12 +118,14 @@ def train(cfg: dict):
             )
             online_buffer.extend(rollout)

-            ep_reward = rollout["next", "reward"].sum()
+            ep_sum_reward = rollout["next", "reward"].sum()
+            ep_max_reward = rollout["next", "reward"].max()
             ep_success = rollout["next", "success"].any()
             online_episode_idx += 1
             rollout_metrics = {
-                "avg_reward": np.nanmean(ep_reward),
+                "avg_sum_reward": np.nanmean(ep_sum_reward),
+                "avg_max_reward": np.nanmean(ep_max_reward),
                 "pc_success": np.nanmean(ep_success) * 100,
             }
             num_updates = len(rollout) * cfg.utd
@@ -137,23 +162,23 @@ def train(cfg: dict):
                 env,
                 td_policy,
                 num_episodes=cfg.eval_episodes,
-                # TODO(rcadene): add step, env_step, L.video
+                env_step=env_step,
+                wandb=L._wandb,
             )
             common_metrics.update(eval_metrics)
             L.log(common_metrics, category="eval")
             last_log_step = env_step - env_step % cfg.eval_freq

         # Save model periodically
-        # if cfg.save_model and env_step - last_save_step >= cfg.save_freq:
-        #     L.save_model(policy, identifier=env_step)
-        #     print(f"Model has been checkpointed at step {env_step}")
-        #     last_save_step = env_step - env_step % cfg.save_freq
+        if cfg.save_model and env_step - last_save_step >= cfg.save_freq:
+            L.save_model(policy, identifier=env_step)
+            print(f"Model has been checkpointed at step {env_step}")
+            last_save_step = env_step - env_step % cfg.save_freq

-        # if cfg.save_model and is_offline and _step >= cfg.offline_steps:
-        #     # save the model after offline training
-        #     L.save_model(policy, identifier="offline")
+        if cfg.save_model and is_offline and _step >= cfg.offline_steps:
+            # save the model after offline training
+            L.save_model(policy, identifier="offline")

         step = _step
@@ -177,4 +202,4 @@ def train(cfg: dict):

 if __name__ == "__main__":
-    train()
+    train_cli()
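A hedged usage sketch for the new train_notebook() entry point, e.g. from a Jupyter cell; the out_dir value is an assumption, and "pusht" matches the config shown in the README above:

```python
from lerobot.scripts.train import train_notebook

train_notebook(
    out_dir="outputs/notebook",  # assumed path; any writable directory works
    job_name="pusht",
    config_name="pusht",
)
```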

View File

@@ -15,7 +15,15 @@ from lerobot.common.datasets.factory import make_offline_buffer

 @hydra.main(version_base=None, config_name="default", config_path="../configs")
-def visualize_dataset(cfg: dict):
+def visualize_dataset_cli(cfg: dict):
+    visualize_dataset(
+        cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
+    )
+
+
+def visualize_dataset(cfg: dict, out_dir=None):
+    if out_dir is None:
+        raise NotImplementedError()

     sampler = SliceSamplerWithoutReplacement(
         num_slices=1,
@@ -40,10 +48,10 @@ def visualize_dataset(cfg: dict):
         dim=0,
     )

-    video_dir = Path(cfg.video_dir)
+    video_dir = Path(out_dir) / "visualize_dataset"
     video_dir.mkdir(parents=True, exist_ok=True)
     # TODO(rcadene): make fps configurable
-    video_path = video_dir / f"eval_episode_{ep_idx}.mp4"
+    video_path = video_dir / f"episode_{ep_idx}.mp4"

     assert ep_frames.min().item() >= 0
     assert ep_frames.max().item() > 1, "Not mendatory, but sanity check"
@@ -59,4 +67,4 @@ def visualize_dataset(cfg: dict):

 if __name__ == "__main__":
-    visualize_dataset()
+    visualize_dataset_cli()