Solve conflicts + pre-commit run -a

Cadene 2024-02-29 23:31:32 +00:00
parent 0b9027f05e
commit ae050d2e94
8 changed files with 26 additions and 41 deletions

View File

@@ -108,7 +108,10 @@ eval_episodes=7
**Style**
```
# install if needed
pre-commit install
# apply style and linter checks before git commit
pre-commit run -a
```
**Tests**

View File

@@ -9,19 +9,14 @@ import pymunk
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import (
Sampler,
)
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
# as defined in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
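
As an aside, a minimal sketch of how a success flag could be derived from this threshold; the coverage values below are made up for illustration and are not the env's actual computation:
```
import torch

SUCCESS_THRESHOLD = 0.95  # 95% of the goal area must be covered

# hypothetical per-step coverage values from one rollout
coverage = torch.tensor([0.20, 0.61, 0.93, 0.96])
success = coverage >= SUCCESS_THRESHOLD  # tensor([False, False, False, True])
```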

View File

@@ -8,9 +8,7 @@ import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import (
Sampler,
SliceSampler,

View File

@@ -1,13 +1,12 @@
import contextlib
import datetime
import os
from pathlib import Path
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from termcolor import colored
def make_dir(dir_path):
"""Create directory if it does not already exist."""
with contextlib.suppress(OSError):
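
The hunk cuts off inside `make_dir`; a plausible completion, assuming the suppressed body is a plain `os.makedirs` call (so an already-existing directory raises no error):
```
import contextlib
import os

def make_dir(dir_path):
    """Create directory if it does not already exist."""
    with contextlib.suppress(OSError):
        # any OSError (e.g. the directory already exists) is silently ignored
        os.makedirs(dir_path)
    return dir_path
```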

View File

@@ -5,7 +5,6 @@ import hydra
import torch
import torch.nn as nn
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.common.lr_scheduler import get_scheduler
from diffusion_policy.model.vision.model_getter import get_resnet
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
@@ -128,12 +127,8 @@ class DiffusionPolicy(nn.Module):
out = {
"obs": {
"image": batch["observation", "image"].to(
self.device, non_blocking=True
),
"agent_pos": batch["observation", "state"].to(
self.device, non_blocking=True
),
"image": batch["observation", "image"].to(self.device, non_blocking=True),
"agent_pos": batch["observation", "state"].to(self.device, non_blocking=True),
},
"action": batch["action"].to(self.device, non_blocking=True),
}
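
The tuple indexing above is TensorDict's nested-key access; a self-contained sketch with dummy shapes (the sizes are illustrative, not the policy's actual ones):
```
import torch
from tensordict import TensorDict

batch = TensorDict(
    {"observation": {"image": torch.zeros(8, 3, 96, 96), "state": torch.zeros(8, 2)}},
    batch_size=[8],
)
# a tuple key addresses a nested field in a single lookup
image = batch["observation", "image"]  # shape (8, 3, 96, 96)
```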

View File

@@ -33,7 +33,7 @@ def init_logging():
logging.getLogger().addHandler(console_handler)
def format_number_KMB(num):
def format_big_number(num):
suffixes = ["", "K", "M", "B", "T", "Q"]
divisor = 1000.0
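
A plausible implementation of the renamed helper, inferred from the suffix list and divisor shown in the hunk (the exact number formatting is an assumption):
```
def format_big_number(num):
    suffixes = ["", "K", "M", "B", "T", "Q"]
    divisor = 1000.0
    # divide by 1000 until the value fits, then attach the matching suffix
    for suffix in suffixes[:-1]:
        if abs(num) < divisor:
            return f"{num:.1f}{suffix}"
        num /= divisor
    return f"{num:.1f}{suffixes[-1]}"  # anything this large gets "Q"

print(format_big_number(12_500))  # 12.5K
```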

View File

@@ -1,5 +1,4 @@
import logging
import time
import hydra
import numpy as np
@@ -13,7 +12,7 @@ from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import format_number_KMB, init_logging, set_seed
from lerobot.common.utils import format_big_number, init_logging, set_seed
from lerobot.scripts.eval import eval_policy
@@ -49,11 +48,11 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / offline_buffer.num_samples
log_items = [
f"step:{format_number_KMB(step)}",
f"step:{format_big_number(step)}",
# number of samples seen during training
f"smpl:{format_number_KMB(num_samples)}",
f"smpl:{format_big_number(num_samples)}",
# number of episodes seen during training
f"ep:{format_number_KMB(num_episodes)}",
f"ep:{format_big_number(num_episodes)}",
# number of times all unique samples are seen
f"epch:{num_epochs:.2f}",
f"loss:{loss:.3f}",
@@ -86,11 +85,11 @@ def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline):
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / offline_buffer.num_samples
log_items = [
f"step:{format_number_KMB(step)}",
f"step:{format_big_number(step)}",
# number of samples seen during training
f"smpl:{format_number_KMB(num_samples)}",
f"smpl:{format_big_number(num_samples)}",
# number of episodes seen during training
f"ep:{format_number_KMB(num_episodes)}",
f"ep:{format_big_number(num_episodes)}",
# number of times all unique samples are seen
f"epch:{num_epochs:.2f}",
f"∑rwrd:{avg_sum_reward:.3f}",
@@ -156,7 +155,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
logger = Logger(out_dir, job_name, cfg)
online_ep_idx = 0
step = 0 # number of policy updates
is_offline = True
@@ -200,7 +198,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
auto_cast_to_device=True,
)
assert len(rollout) <= cfg.env.episode_length
rollout["episode"] = torch.tensor([online_ep_idx] * len(rollout), dtype=torch.int)
# set same episode index for all time steps contained in this rollout
rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
online_buffer.extend(rollout)
ep_sum_reward = rollout["next", "reward"].sum()
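
A minimal, self-contained sketch of the episode-tagging line above, using a dummy rollout (the shapes and step counter are made up):
```
import torch
from tensordict import TensorDict

env_step = 7  # hypothetical counter identifying the current online episode
rollout = TensorDict({"action": torch.zeros(50, 2)}, batch_size=[50])

# every time step in this rollout gets the same episode index
rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
assert (rollout["episode"] == env_step).all()
```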
@@ -210,12 +209,10 @@
"avg_sum_reward": np.nanmean(ep_sum_reward),
"avg_max_reward": np.nanmean(ep_max_reward),
"pc_success": np.nanmean(ep_success) * 100,
"online_ep_idx": online_ep_idx,
"env_step": env_step,
"ep_length": len(rollout),
}
online_ep_idx += 1
for _ in range(cfg.policy.utd):
train_info = policy.update(
online_buffer,
@@ -233,7 +230,7 @@
num_episodes=cfg.eval_episodes,
return_first_video=True,
)
log_eval_info(L, eval_info, step, cfg, offline_buffer, is_offline)
log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
if cfg.wandb.enable:
logger.log_video(first_video, step, mode="eval")

View File

@@ -3,9 +3,7 @@ from pathlib import Path
import hydra
import imageio
import torch
from torchrl.data.replay_buffers import (
SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers import SliceSamplerWithoutReplacement
from lerobot.common.datasets.factory import make_offline_buffer