From c202c2b3c211f96d19af960f2d1d8e8319e28a48 Mon Sep 17 00:00:00 2001 From: Cadene Date: Fri, 16 Feb 2024 15:13:24 +0000 Subject: [PATCH] Online finetuning runs (sometimes crash because of nans) --- README.md | 12 ++- lerobot/common/envs/simxarm.py | 16 ++-- lerobot/common/tdmpc.py | 134 ++++++++++++++++++--------------- lerobot/configs/default.yaml | 2 +- lerobot/scripts/train.py | 111 ++++++++++++++++++--------- 5 files changed, 165 insertions(+), 110 deletions(-) diff --git a/README.md b/README.md index 4bba66df..07da0d9a 100644 --- a/README.md +++ b/README.md @@ -15,12 +15,20 @@ conda activate lerobot python setup.py develop ``` +## TODO + +- [ ] priority update doesnt match FOWM or original paper +- [ ] self.step=100000 should be updated at every step to adjust to horizon of planner +- [ ] prefetch replay buffer to speedup training +- [ ] parallelize env to speedup eval ## Contribute **style** ``` -isort . -black . +isort lerobot +black lerobot +isort test +black test pylint lerobot ``` diff --git a/lerobot/common/envs/simxarm.py b/lerobot/common/envs/simxarm.py index 1470fceb..8d955072 100644 --- a/lerobot/common/envs/simxarm.py +++ b/lerobot/common/envs/simxarm.py @@ -77,18 +77,16 @@ class SimxarmEnv(EnvBase): def _format_raw_obs(self, raw_obs): if self.from_pixels: - camera = self.render( + image = self.render( mode="rgb_array", width=self.image_size, height=self.image_size ) - camera = camera.transpose(2, 0, 1) # (H, W, C) -> (C, H, W) - camera = torch.tensor(camera.copy(), dtype=torch.uint8) + image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W) + image = torch.tensor(image.copy(), dtype=torch.uint8) - obs = {"camera": camera} + obs = {"image": image} if not self.pixels_only: - obs["robot_state"] = torch.tensor( - self._env.robot_state, dtype=torch.float32 - ) + obs["state"] = torch.tensor(self._env.robot_state, dtype=torch.float32) else: obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)} @@ -136,7 +134,7 @@ class SimxarmEnv(EnvBase): def _make_spec(self): obs = {} if self.from_pixels: - obs["camera"] = BoundedTensorSpec( + obs["image"] = BoundedTensorSpec( low=0, high=255, shape=(3, self.image_size, self.image_size), @@ -144,7 +142,7 @@ class SimxarmEnv(EnvBase): device=self.device, ) if not self.pixels_only: - obs["robot_state"] = UnboundedContinuousTensorSpec( + obs["state"] = UnboundedContinuousTensorSpec( shape=(len(self._env.robot_state),), dtype=torch.float32, device=self.device, diff --git a/lerobot/common/tdmpc.py b/lerobot/common/tdmpc.py index da8638dd..d694a06d 100644 --- a/lerobot/common/tdmpc.py +++ b/lerobot/common/tdmpc.py @@ -96,8 +96,7 @@ class TDMPC(nn.Module): self.model_target.eval() self.batch_size = cfg.batch_size - # TODO(rcadene): clean - self.step = 100000 + self.step = 0 def state_dict(self): """Retrieve state dict of TOLD model, including slow-moving target network.""" @@ -120,8 +119,8 @@ class TDMPC(nn.Module): def forward(self, observation, step_count): t0 = step_count.item() == 0 obs = { - "rgb": observation["camera"], - "state": observation["robot_state"], + "rgb": observation["image"], + "state": observation["state"], } return self.act(obs, t0=t0, step=self.step) @@ -298,65 +297,81 @@ class TDMPC(nn.Module): def update(self, replay_buffer, step, demo_buffer=None): """Main update function. 
Corresponds to one iteration of the model learning.""" - if demo_buffer is not None: - # Update oversampling ratio - self.demo_batch_size = int( - h.linear_schedule(self.cfg.demo_schedule, step) * self.batch_size - ) - replay_buffer.cfg.batch_size = self.batch_size - self.demo_batch_size - demo_buffer.cfg.batch_size = self.demo_batch_size + num_slices = self.cfg.batch_size + batch_size = self.cfg.horizon * num_slices + + if demo_buffer is None: + demo_batch_size = 0 else: - self.demo_batch_size = 0 + # Update oversampling ratio + demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step) + demo_num_slices = int(demo_pc_batch * self.batch_size) + demo_batch_size = self.cfg.horizon * demo_num_slices + batch_size -= demo_batch_size + num_slices -= demo_num_slices + replay_buffer._sampler.num_slices = num_slices + demo_buffer._sampler.num_slices = demo_num_slices + + assert demo_batch_size % self.cfg.horizon == 0 + assert demo_batch_size % demo_num_slices == 0 + + assert batch_size % self.cfg.horizon == 0 + assert batch_size % num_slices == 0 # Sample from interaction dataset - # to not have to mask - # batch_size = (self.cfg.batch_size // self.cfg.horizon) * self.cfg.horizon - batch_size = self.cfg.horizon * self.cfg.batch_size + def process_batch(batch, horizon, num_slices): + # trajectory t = 256, horizon h = 5 + # (t h) ... -> h t ... + batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous() + batch = batch.to("cuda") + + FIRST_FRAME = 0 + obs = { + "rgb": batch["observation", "image"][FIRST_FRAME].float(), + "state": batch["observation", "state"][FIRST_FRAME], + } + action = batch["action"] + next_obses = { + "rgb": batch["next", "observation", "image"].float(), + "state": batch["next", "observation", "state"], + } + reward = batch["next", "reward"] + + # TODO(rcadene): rearrange directly in offline dataset + if reward.ndim == 2: + reward = einops.rearrange(reward, "h t -> h t 1") + + assert reward.ndim == 3 + assert reward.shape == (horizon, num_slices, 1) + # We dont use `batch["next", "done"]` since it only indicates the end of an + # episode, but not the end of the trajectory of an episode. + # Neither does `batch["next", "terminated"]` + done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device) + mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device) + + idxs = batch["index"][FIRST_FRAME] + weights = batch["_weight"][FIRST_FRAME, :, None] + return obs, action, next_obses, reward, mask, done, idxs, weights + batch = replay_buffer.sample(batch_size) - - # trajectory t = 256, horizon h = 5 - # (t h) ... -> h t ... - batch = ( - batch.reshape(self.cfg.batch_size, self.cfg.horizon) - .transpose(1, 0) - .contiguous() + obs, action, next_obses, reward, mask, done, idxs, weights = process_batch( + batch, self.cfg.horizon, num_slices ) - batch = batch.to("cuda") - - FIRST_FRAME = 0 - obs = { - "rgb": batch["observation", "image"][FIRST_FRAME].float(), - "state": batch["observation", "state"][FIRST_FRAME], - } - action = batch["action"] - next_obses = { - "rgb": batch["next", "observation", "image"].float(), - "state": batch["next", "observation", "state"], - } - reward = batch["next", "reward"] - reward = einops.rearrange(reward, "h t -> h t 1") - # We dont use `batch["next", "done"]` since it only indicates the end of an - # episode, but not the end of the trajectory of an episode. 
- # Neither does `batch["next", "terminated"]` - done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device) - mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device) - - idxs = batch["frame_id"][FIRST_FRAME] - weights = batch["_weight"][FIRST_FRAME, :, None] # Sample from demonstration dataset - if self.demo_batch_size > 0: + if demo_batch_size > 0: + demo_batch = demo_buffer.sample(demo_batch_size) ( demo_obs, - demo_next_obses, demo_action, + demo_next_obses, demo_reward, demo_mask, demo_done, demo_idxs, demo_weights, - ) = demo_buffer.sample() + ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices) if isinstance(obs, dict): obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs} @@ -440,9 +455,9 @@ class TDMPC(nn.Module): q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0) priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0) - self.expectile = h.linear_schedule(self.cfg.expectile, step) + expectile = h.linear_schedule(self.cfg.expectile, step) v_value_loss = ( - rho * h.l2_expectile(v_target - v, expectile=self.expectile) * loss_mask + rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask ).sum(dim=0) total_loss = ( @@ -464,17 +479,12 @@ class TDMPC(nn.Module): if self.cfg.per: # Update priorities priorities = priority_loss.clamp(max=1e4).detach() - # normalize between [0,1] to fit torchrl specification - priorities /= 1e4 - priorities = priorities.clamp(max=1.0) replay_buffer.update_priority( - idxs[: self.cfg.batch_size], - priorities[: self.cfg.batch_size], + idxs[:num_slices], + priorities[:num_slices], ) - if self.demo_batch_size > 0: - demo_buffer.update_priority( - demo_idxs, priorities[self.cfg.batch_size :] - ) + if demo_batch_size > 0: + demo_buffer.update_priority(demo_idxs, priorities[num_slices:]) # Update policy + target network _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action) @@ -493,10 +503,12 @@ class TDMPC(nn.Module): "weighted_loss": float(weighted_loss.mean().item()), "grad_norm": float(grad_norm), } - for key in ["demo_batch_size", "expectile"]: - if hasattr(self, key): - metrics[key] = getattr(self, key) + # for key in ["demo_batch_size", "expectile"]: + # if hasattr(self, key): + metrics["demo_batch_size"] = demo_batch_size + metrics["expectile"] = expectile metrics.update(value_info) metrics.update(pi_update_info) + self.step = step return metrics diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index ce43b293..f1a014aa 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -80,7 +80,7 @@ expectile: 0.9 A_scaling: 3.0 # offline->online -offline_steps: ${train_steps}/2 +offline_steps: 25000 # ${train_steps}/2 pretrained_model_path: "" balanced_sampling: true demo_schedule: 0.5 diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 55c9c0f8..8ae05cda 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -19,6 +19,7 @@ from lerobot.common.logger import Logger from lerobot.common.tdmpc import TDMPC from lerobot.common.utils import set_seed from lerobot.scripts.eval import eval_policy +from rl.torchrl.collectors.collectors import SyncDataCollector @hydra.main(version_base=None, config_name="default", config_path="../configs") @@ -29,8 +30,10 @@ def train(cfg: dict): env = make_env(cfg) policy = TDMPC(cfg) - # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" - ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt" + 
ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" + policy.step = 25000 + # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt" + # policy.step = 100000 policy.load(ckpt_path) td_policy = TensorDictModule( @@ -54,7 +57,7 @@ def train(cfg: dict): strict_length=False, ) - # TODO(rcadene): use PrioritizedReplayBuffer + # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here offline_buffer = SimxarmExperienceReplay( dataset_id, # download="force", @@ -68,9 +71,22 @@ def train(cfg: dict): index = torch.arange(0, num_steps, 1) sampler.extend(index) - # offline_buffer._storage.device = torch.device("cuda") - # offline_buffer._storage._storage.to(torch.device("cuda")) - # TODO(rcadene): add online_buffer + if cfg.balanced_sampling: + online_sampler = PrioritizedSliceSampler( + max_capacity=100_000, + alpha=0.7, + beta=0.9, + num_slices=num_traj_per_batch, + strict_length=False, + ) + + online_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(100_000), + sampler=online_sampler, + # batch_size=3, + # pin_memory=False, + # prefetch=3, + ) # Observation encoder # Dynamics predictor @@ -81,59 +97,80 @@ def train(cfg: dict): L = Logger(cfg.log_dir, cfg) - episode_idx = 0 + online_episode_idx = 0 start_time = time.time() step = 0 last_log_step = 0 last_save_step = 0 + # TODO(rcadene): remove + step = 25000 + while step < cfg.train_steps: is_offline = True num_updates = cfg.episode_length _step = step + num_updates rollout_metrics = {} - # if step >= cfg.offline_steps: - # is_offline = False + if step >= cfg.offline_steps: + is_offline = False - # # Collect trajectory - # obs = env.reset() - # episode = Episode(cfg, obs) - # success = False - # while not episode.done: - # action = policy.act(obs, step=step, t0=episode.first) - # obs, reward, done, info = env.step(action.cpu().numpy()) - # reward = reward_normalizer(reward) - # mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0 - # success = info.get('success', False) - # episode += (obs, action, reward, done, mask, success) - # assert len(episode) <= cfg.episode_length - # buffer += episode - # episode_idx += 1 - # rollout_metrics = { - # 'episode_reward': episode.cumulative_reward, - # 'episode_success': float(success), - # 'episode_length': len(episode) - # } - # num_updates = len(episode) * cfg.utd - # _step = min(step + len(episode), cfg.train_steps) + # TODO: use SyncDataCollector for that? 
+ rollout = env.rollout( + max_steps=cfg.episode_length, + policy=td_policy, + ) + assert len(rollout) <= cfg.episode_length + rollout["episode"] = torch.tensor( + [online_episode_idx] * len(rollout), dtype=torch.int + ) + online_buffer.extend(rollout) + + # Collect trajectory + # obs = env.reset() + # episode = Episode(cfg, obs) + # success = False + # while not episode.done: + # action = policy.act(obs, step=step, t0=episode.first) + # obs, reward, done, info = env.step(action.cpu().numpy()) + # reward = reward_normalizer(reward) + # mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0 + # success = info.get('success', False) + # episode += (obs, action, reward, done, mask, success) + + ep_reward = rollout["next", "reward"].sum() + ep_success = rollout["next", "success"].any() + + online_episode_idx += 1 + rollout_metrics = { + # 'episode_reward': episode.cumulative_reward, + # 'episode_success': float(success), + # 'episode_length': len(episode) + "avg_reward": np.nanmean(ep_reward), + "pc_success": np.nanmean(ep_success) * 100, + } + num_updates = len(rollout) * cfg.utd + _step = min(step + len(rollout), cfg.train_steps) # Update model train_metrics = {} if is_offline: for i in range(num_updates): train_metrics.update(policy.update(offline_buffer, step + i)) - # else: - # for i in range(num_updates): - # train_metrics.update( - # policy.update(buffer, step + i // cfg.utd, - # demo_buffer=offline_buffer if cfg.balanced_sampling else None) - # ) + else: + for i in range(num_updates): + train_metrics.update( + policy.update( + online_buffer, + step + i // cfg.utd, + demo_buffer=offline_buffer if cfg.balanced_sampling else None, + ) + ) # Log training metrics env_step = int(_step * cfg.action_repeat) common_metrics = { - "episode": episode_idx, + "episode": online_episode_idx, "step": _step, "env_step": env_step, "total_time": time.time() - start_time,
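
Note on the slice layout assumed by `process_batch` in the tdmpc.py hunk above: `replay_buffer.sample(horizon * num_slices)` returns frames flattened as `(t h)`, with the `horizon` frames of each slice stored contiguously, and the update puts the time axis first before indexing `FIRST_FRAME`. Below is a minimal standalone sketch of just that reindexing step; the `flat` tensor is only a stand-in for the sampled TensorDict, and the `256`/`5` values mirror the `trajectory t = 256, horizon h = 5` comment.

```python
import torch

# Stand-in for the flat batch returned by replay_buffer.sample(horizon * num_slices):
# frames are laid out as (t h), i.e. the `horizon` frames of each slice are contiguous.
num_slices, horizon = 256, 5
flat = torch.arange(num_slices * horizon)

# (t h) ... -> h t ... : group frames by slice, then move the time axis first,
# mirroring `batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()`.
per_step = flat.reshape(num_slices, horizon).transpose(1, 0).contiguous()

assert per_step.shape == (horizon, num_slices)
# Row 0 now holds the first frame of every slice, which is what
# `FIRST_FRAME = 0` indexes for `obs`, `idxs`, and `weights`.
assert torch.equal(per_step[0], torch.arange(num_slices) * horizon)
```

The same `(t h) -> (h, t)` convention is why `reward` is asserted to have shape `(horizon, num_slices, 1)` and why `idxs` and `weights` are taken from the first frame of each slice.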