Online finetuning runs (sometimes crashes because of NaNs)

Author: Cadene
Date:   2024-02-16 15:13:24 +00:00
Parent: 228c045674
Commit: c202c2b3c2

5 changed files with 165 additions and 110 deletions

View File

@@ -15,12 +15,20 @@ conda activate lerobot
 python setup.py develop
 ```
 
+## TODO
+- [ ] priority update doesnt match FOWM or original paper
+- [ ] self.step=100000 should be updated at every step to adjust to horizon of planner
+- [ ] prefetch replay buffer to speedup training
+- [ ] parallelize env to speedup eval
+
 ## Contribute
 
 **style**
 ```
-isort .
-black .
+isort lerobot
+black lerobot
+isort test
+black test
 pylint lerobot
 ```
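Editor's note: the two speed-oriented TODO items above line up with existing torchrl features, replay-buffer prefetching and `ParallelEnv` for batched rollouts. Below is a minimal, self-contained sketch of the prefetch idea only; the buffer size, batch size, and keys are made up for illustration and are not part of this commit.

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

# Hypothetical numbers and keys, just to show the `prefetch` argument:
# torchrl pre-loads up to `prefetch` batches in background threads.
buffer = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(10_000),
    batch_size=256,
    prefetch=3,
)
buffer.extend(TensorDict({"observation": torch.randn(512, 4)}, batch_size=[512]))
batch = buffer.sample()  # uses the default batch_size of 256
print(batch.shape)       # torch.Size([256])
```

Parallelizing the eval env would presumably go through torchrl's `ParallelEnv`, which wraps env factories like this repo's `make_env(cfg)` in worker processes.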

View File

@@ -77,18 +77,16 @@ class SimxarmEnv(EnvBase):
 
     def _format_raw_obs(self, raw_obs):
         if self.from_pixels:
-            camera = self.render(
+            image = self.render(
                 mode="rgb_array", width=self.image_size, height=self.image_size
             )
-            camera = camera.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
-            camera = torch.tensor(camera.copy(), dtype=torch.uint8)
+            image = image.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
+            image = torch.tensor(image.copy(), dtype=torch.uint8)
 
-            obs = {"camera": camera}
+            obs = {"image": image}
 
             if not self.pixels_only:
-                obs["robot_state"] = torch.tensor(
-                    self._env.robot_state, dtype=torch.float32
-                )
+                obs["state"] = torch.tensor(self._env.robot_state, dtype=torch.float32)
         else:
             obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
@@ -136,7 +134,7 @@ class SimxarmEnv(EnvBase):
     def _make_spec(self):
         obs = {}
         if self.from_pixels:
-            obs["camera"] = BoundedTensorSpec(
+            obs["image"] = BoundedTensorSpec(
                 low=0,
                 high=255,
                 shape=(3, self.image_size, self.image_size),
@@ -144,7 +142,7 @@ class SimxarmEnv(EnvBase):
                 device=self.device,
             )
             if not self.pixels_only:
-                obs["robot_state"] = UnboundedContinuousTensorSpec(
+                obs["state"] = UnboundedContinuousTensorSpec(
                     shape=(len(self._env.robot_state),),
                     dtype=torch.float32,
                     device=self.device,

View File

@@ -96,8 +96,7 @@ class TDMPC(nn.Module):
         self.model_target.eval()
         self.batch_size = cfg.batch_size
 
-        # TODO(rcadene): clean
-        self.step = 100000
+        self.step = 0
 
     def state_dict(self):
         """Retrieve state dict of TOLD model, including slow-moving target network."""
@@ -120,8 +119,8 @@ class TDMPC(nn.Module):
     def forward(self, observation, step_count):
         t0 = step_count.item() == 0
         obs = {
-            "rgb": observation["camera"],
-            "state": observation["robot_state"],
+            "rgb": observation["image"],
+            "state": observation["state"],
         }
         return self.act(obs, t0=t0, step=self.step)
@@ -298,65 +297,81 @@ class TDMPC(nn.Module):
 
     def update(self, replay_buffer, step, demo_buffer=None):
         """Main update function. Corresponds to one iteration of the model learning."""
-        if demo_buffer is not None:
-            # Update oversampling ratio
-            self.demo_batch_size = int(
-                h.linear_schedule(self.cfg.demo_schedule, step) * self.batch_size
-            )
-            replay_buffer.cfg.batch_size = self.batch_size - self.demo_batch_size
-            demo_buffer.cfg.batch_size = self.demo_batch_size
-        else:
-            self.demo_batch_size = 0
-
-        # Sample from interaction dataset
-        # to not have to mask
-        # batch_size = (self.cfg.batch_size // self.cfg.horizon) * self.cfg.horizon
-        batch_size = self.cfg.horizon * self.cfg.batch_size
-        batch = replay_buffer.sample(batch_size)
-
-        # trajectory t = 256, horizon h = 5
-        # (t h) ... -> h t ...
-        batch = (
-            batch.reshape(self.cfg.batch_size, self.cfg.horizon)
-            .transpose(1, 0)
-            .contiguous()
-        )
-        batch = batch.to("cuda")
-
-        FIRST_FRAME = 0
-        obs = {
-            "rgb": batch["observation", "image"][FIRST_FRAME].float(),
-            "state": batch["observation", "state"][FIRST_FRAME],
-        }
-        action = batch["action"]
-        next_obses = {
-            "rgb": batch["next", "observation", "image"].float(),
-            "state": batch["next", "observation", "state"],
-        }
-        reward = batch["next", "reward"]
-        reward = einops.rearrange(reward, "h t -> h t 1")
-
-        # We dont use `batch["next", "done"]` since it only indicates the end of an
-        # episode, but not the end of the trajectory of an episode.
-        # Neither does `batch["next", "terminated"]`
-        done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
-        mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-
-        idxs = batch["frame_id"][FIRST_FRAME]
-        weights = batch["_weight"][FIRST_FRAME, :, None]
+        num_slices = self.cfg.batch_size
+        batch_size = self.cfg.horizon * num_slices
+
+        if demo_buffer is None:
+            demo_batch_size = 0
+        else:
+            # Update oversampling ratio
+            demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
+            demo_num_slices = int(demo_pc_batch * self.batch_size)
+            demo_batch_size = self.cfg.horizon * demo_num_slices
+            batch_size -= demo_batch_size
+            num_slices -= demo_num_slices
+            replay_buffer._sampler.num_slices = num_slices
+            demo_buffer._sampler.num_slices = demo_num_slices
+
+            assert demo_batch_size % self.cfg.horizon == 0
+            assert demo_batch_size % demo_num_slices == 0
+            assert batch_size % self.cfg.horizon == 0
+            assert batch_size % num_slices == 0
+
+        # Sample from interaction dataset
+
+        def process_batch(batch, horizon, num_slices):
+            # trajectory t = 256, horizon h = 5
+            # (t h) ... -> h t ...
+            batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
+            batch = batch.to("cuda")
+
+            FIRST_FRAME = 0
+            obs = {
+                "rgb": batch["observation", "image"][FIRST_FRAME].float(),
+                "state": batch["observation", "state"][FIRST_FRAME],
+            }
+            action = batch["action"]
+            next_obses = {
+                "rgb": batch["next", "observation", "image"].float(),
+                "state": batch["next", "observation", "state"],
+            }
+            reward = batch["next", "reward"]
+
+            # TODO(rcadene): rearrange directly in offline dataset
+            if reward.ndim == 2:
+                reward = einops.rearrange(reward, "h t -> h t 1")
+
+            assert reward.ndim == 3
+            assert reward.shape == (horizon, num_slices, 1)
+
+            # We dont use `batch["next", "done"]` since it only indicates the end of an
+            # episode, but not the end of the trajectory of an episode.
+            # Neither does `batch["next", "terminated"]`
+            done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
+            mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+
+            idxs = batch["index"][FIRST_FRAME]
+            weights = batch["_weight"][FIRST_FRAME, :, None]
+            return obs, action, next_obses, reward, mask, done, idxs, weights
+
+        batch = replay_buffer.sample(batch_size)
+        obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
+            batch, self.cfg.horizon, num_slices
+        )
 
         # Sample from demonstration dataset
-        if self.demo_batch_size > 0:
+        if demo_batch_size > 0:
+            demo_batch = demo_buffer.sample(demo_batch_size)
             (
                 demo_obs,
-                demo_next_obses,
                 demo_action,
+                demo_next_obses,
                 demo_reward,
                 demo_mask,
                 demo_done,
                 demo_idxs,
                 demo_weights,
-            ) = demo_buffer.sample()
+            ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
 
             if isinstance(obs, dict):
                 obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
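Editor's note: for readers less familiar with the slice-sampler layout that `process_batch` assumes, here is a tiny, self-contained illustration of the `(t h) ... -> h t ...` reshape above; the numbers are arbitrary.

```python
import torch

# A slice sampler returns `num_slices` sub-trajectories of length `horizon`,
# flattened slice-major. Reshaping then transposing puts time first, so index 0
# is the first frame of every sampled slice (what `FIRST_FRAME` selects above).
num_slices, horizon = 4, 5
flat = torch.arange(num_slices * horizon)  # fake frame ids: 0..19
batch = flat.reshape(num_slices, horizon).transpose(1, 0).contiguous()
print(batch.shape)  # torch.Size([5, 4])
print(batch[0])     # tensor([ 0,  5, 10, 15]) -> first frame of each slice
```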
@@ -440,9 +455,9 @@ class TDMPC(nn.Module):
             q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0)
             priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
 
-        self.expectile = h.linear_schedule(self.cfg.expectile, step)
+        expectile = h.linear_schedule(self.cfg.expectile, step)
         v_value_loss = (
-            rho * h.l2_expectile(v_target - v, expectile=self.expectile) * loss_mask
+            rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask
        ).sum(dim=0)
 
         total_loss = (
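Editor's note: as a reference for the value loss above, `h.l2_expectile` presumably implements the standard expectile-regression loss (an assumption based on its name and its use in FOWM-style TD-MPC). A minimal version, with `expectile` playing the role of the scheduled value:

```python
import torch

def l2_expectile_ref(diff: torch.Tensor, expectile: float) -> torch.Tensor:
    # Asymmetric squared error: weight `expectile` when the target exceeds the
    # prediction (diff > 0), and `1 - expectile` otherwise.
    weight = torch.where(diff > 0, expectile, 1.0 - expectile)
    return weight * diff.pow(2)

diff = torch.tensor([-1.0, 2.0])  # v_target - v
print(l2_expectile_ref(diff, expectile=0.9))  # tensor([0.1000, 3.6000])
```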
@@ -464,17 +479,12 @@ class TDMPC(nn.Module):
         if self.cfg.per:
             # Update priorities
             priorities = priority_loss.clamp(max=1e4).detach()
-            # normalize between [0,1] to fit torchrl specification
-            priorities /= 1e4
-            priorities = priorities.clamp(max=1.0)
             replay_buffer.update_priority(
-                idxs[: self.cfg.batch_size],
-                priorities[: self.cfg.batch_size],
+                idxs[:num_slices],
+                priorities[:num_slices],
             )
-            if self.demo_batch_size > 0:
-                demo_buffer.update_priority(
-                    demo_idxs, priorities[self.cfg.batch_size :]
-                )
+            if demo_batch_size > 0:
+                demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
 
         # Update policy + target network
         _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
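Editor's note, not part of this commit but related to the "sometimes crashes because of NaNs" remark in the title: if the TD losses ever produce NaNs, they propagate into these sampler priorities. One possible guard (purely a suggestion; every bound except the existing `1e4` clamp is an assumption) would sanitize them before `update_priority` is called:

```python
import torch

# Hypothetical guard: replace NaNs and keep priorities strictly positive before
# handing them to the prioritized slice samplers.
priority_loss = torch.tensor([0.3, float("nan"), 2.5e4])
priorities = torch.nan_to_num(priority_loss.detach(), nan=1.0).clamp(min=1e-6, max=1e4)
print(priorities)  # -> [0.3, 1.0, 10000.0]
```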
@@ -493,10 +503,12 @@ class TDMPC(nn.Module):
             "weighted_loss": float(weighted_loss.mean().item()),
             "grad_norm": float(grad_norm),
         }
-        for key in ["demo_batch_size", "expectile"]:
-            if hasattr(self, key):
-                metrics[key] = getattr(self, key)
+        # for key in ["demo_batch_size", "expectile"]:
+        #     if hasattr(self, key):
+        metrics["demo_batch_size"] = demo_batch_size
+        metrics["expectile"] = expectile
         metrics.update(value_info)
         metrics.update(pi_update_info)
 
+        self.step = step
         return metrics

View File

@@ -80,7 +80,7 @@ expectile: 0.9
 A_scaling: 3.0
 
 # offline->online
-offline_steps: ${train_steps}/2
+offline_steps: 25000 # ${train_steps}/2
 pretrained_model_path: ""
 balanced_sampling: true
 demo_schedule: 0.5
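Editor's note: a plausible reason for hard-coding `offline_steps: 25000` here is that with Hydra/OmegaConf, `${train_steps}/2` is string interpolation rather than arithmetic, so it resolves to a string instead of half the step count. A quick self-contained check (the `50000` is an arbitrary example value):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"train_steps": 50000, "offline_steps": "${train_steps}/2"})
# Interpolation substitutes the value into the string; no division happens.
print(OmegaConf.to_container(cfg, resolve=True)["offline_steps"])  # prints: 50000/2
```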

View File

@@ -19,6 +19,7 @@ from lerobot.common.logger import Logger
 from lerobot.common.tdmpc import TDMPC
 from lerobot.common.utils import set_seed
 from lerobot.scripts.eval import eval_policy
+from rl.torchrl.collectors.collectors import SyncDataCollector
 
 
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
@@ -29,8 +30,10 @@ def train(cfg: dict):
     env = make_env(cfg)
     policy = TDMPC(cfg)
 
-    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
-    ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
+    ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
+    policy.step = 25000
+    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
+    # policy.step = 100000
     policy.load(ckpt_path)
 
     td_policy = TensorDictModule(
@@ -54,7 +57,7 @@ def train(cfg: dict):
         strict_length=False,
     )
 
-    # TODO(rcadene): use PrioritizedReplayBuffer
+    # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
     offline_buffer = SimxarmExperienceReplay(
         dataset_id,
         # download="force",
@@ -68,9 +71,22 @@ def train(cfg: dict):
     index = torch.arange(0, num_steps, 1)
     sampler.extend(index)
 
-    # offline_buffer._storage.device = torch.device("cuda")
-    # offline_buffer._storage._storage.to(torch.device("cuda"))
-    # TODO(rcadene): add online_buffer
+    if cfg.balanced_sampling:
+        online_sampler = PrioritizedSliceSampler(
+            max_capacity=100_000,
+            alpha=0.7,
+            beta=0.9,
+            num_slices=num_traj_per_batch,
+            strict_length=False,
+        )
+
+        online_buffer = TensorDictReplayBuffer(
+            storage=LazyMemmapStorage(100_000),
+            sampler=online_sampler,
+            # batch_size=3,
+            # pin_memory=False,
+            # prefetch=3,
+        )
 
     # Observation encoder
     # Dynamics predictor
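Editor's note: to see how a slice-sampling buffer like `online_buffer` behaves end to end, here is a self-contained sketch using torchrl's non-prioritized `SliceSampler` and dummy data; the `episode` key mirrors the one tagged onto rollouts further down, and all sizes are made up.

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, SliceSampler, TensorDictReplayBuffer

horizon, num_slices = 5, 4
buffer = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(1_000),
    sampler=SliceSampler(num_slices=num_slices, traj_key="episode", strict_length=False),
)

# Two fake episodes of 50 steps each, tagged with an "episode" id like the rollouts below.
data = TensorDict(
    {
        "observation": torch.randn(100, 3),
        "episode": torch.repeat_interleave(torch.tensor([0, 1]), 50),
    },
    batch_size=[100],
)
buffer.extend(data)

# The sampler returns num_slices sub-trajectories of length horizon, flattened,
# which `policy.update` later reshapes to (horizon, num_slices, ...).
batch = buffer.sample(horizon * num_slices)
print(batch.shape)  # torch.Size([20])
```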
@@ -81,59 +97,80 @@ def train(cfg: dict):
     L = Logger(cfg.log_dir, cfg)
 
-    episode_idx = 0
+    online_episode_idx = 0
     start_time = time.time()
     step = 0
     last_log_step = 0
     last_save_step = 0
 
+    # TODO(rcadene): remove
+    step = 25000
+
     while step < cfg.train_steps:
         is_offline = True
         num_updates = cfg.episode_length
         _step = step + num_updates
         rollout_metrics = {}
 
-        # if step >= cfg.offline_steps:
-        #     is_offline = False
-
-        #     # Collect trajectory
-        #     obs = env.reset()
-        #     episode = Episode(cfg, obs)
-        #     success = False
-        #     while not episode.done:
-        #         action = policy.act(obs, step=step, t0=episode.first)
-        #         obs, reward, done, info = env.step(action.cpu().numpy())
-        #         reward = reward_normalizer(reward)
-        #         mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0
-        #         success = info.get('success', False)
-        #         episode += (obs, action, reward, done, mask, success)
-        #     assert len(episode) <= cfg.episode_length
-        #     buffer += episode
-        #     episode_idx += 1
-        #     rollout_metrics = {
-        #         'episode_reward': episode.cumulative_reward,
-        #         'episode_success': float(success),
-        #         'episode_length': len(episode)
-        #     }
-        #     num_updates = len(episode) * cfg.utd
-        #     _step = min(step + len(episode), cfg.train_steps)
+        if step >= cfg.offline_steps:
+            is_offline = False
+
+            # TODO: use SyncDataCollector for that?
+            rollout = env.rollout(
+                max_steps=cfg.episode_length,
+                policy=td_policy,
+            )
+            assert len(rollout) <= cfg.episode_length
+            rollout["episode"] = torch.tensor(
+                [online_episode_idx] * len(rollout), dtype=torch.int
+            )
+            online_buffer.extend(rollout)
+
+            # Collect trajectory
+            # obs = env.reset()
+            # episode = Episode(cfg, obs)
+            # success = False
+            # while not episode.done:
+            #     action = policy.act(obs, step=step, t0=episode.first)
+            #     obs, reward, done, info = env.step(action.cpu().numpy())
+            #     reward = reward_normalizer(reward)
+            #     mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0
+            #     success = info.get('success', False)
+            #     episode += (obs, action, reward, done, mask, success)
+
+            ep_reward = rollout["next", "reward"].sum()
+            ep_success = rollout["next", "success"].any()
+            online_episode_idx += 1
+            rollout_metrics = {
+                # 'episode_reward': episode.cumulative_reward,
+                # 'episode_success': float(success),
+                # 'episode_length': len(episode)
+                "avg_reward": np.nanmean(ep_reward),
+                "pc_success": np.nanmean(ep_success) * 100,
+            }
+            num_updates = len(rollout) * cfg.utd
+            _step = min(step + len(rollout), cfg.train_steps)
 
         # Update model
         train_metrics = {}
         if is_offline:
             for i in range(num_updates):
                 train_metrics.update(policy.update(offline_buffer, step + i))
-        # else:
-        #     for i in range(num_updates):
-        #         train_metrics.update(
-        #             policy.update(buffer, step + i // cfg.utd,
-        #             demo_buffer=offline_buffer if cfg.balanced_sampling else None)
-        #         )
+        else:
+            for i in range(num_updates):
+                train_metrics.update(
+                    policy.update(
+                        online_buffer,
+                        step + i // cfg.utd,
+                        demo_buffer=offline_buffer if cfg.balanced_sampling else None,
+                    )
+                )
 
         # Log training metrics
         env_step = int(_step * cfg.action_repeat)
         common_metrics = {
-            "episode": episode_idx,
+            "episode": online_episode_idx,
             "step": _step,
             "env_step": env_step,
             "total_time": time.time() - start_time,