Solve conflicts + pre-commit run -a
parent 0b9027f05e
commit ae050d2e94
@@ -108,7 +108,10 @@ eval_episodes=7
 
 **Style**
 ```
+# install if needed
 pre-commit install
+# apply style and linter checks before git commit
+pre-commit run -a
 ```
 
 **Tests**

@@ -9,19 +9,14 @@ import pymunk
 import torch
 import torchrl
 import tqdm
-from tensordict import TensorDict
-from torchrl.data.datasets.utils import _get_root_dir
-from torchrl.data.replay_buffers.replay_buffers import (
-    TensorDictReplayBuffer,
-)
-from torchrl.data.replay_buffers.samplers import (
-    Sampler,
-)
-from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
-from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
-
 from diffusion_policy.common.replay_buffer import ReplayBuffer
 from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
+from tensordict import TensorDict
+from torchrl.data.datasets.utils import _get_root_dir
+from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import Sampler
+from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
+from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 
 # as define in env
 SUCCESS_THRESHOLD = 0.95  # 95% coverage,

@@ -8,9 +8,7 @@ import torchrl
 import tqdm
 from tensordict import TensorDict
 from torchrl.data.datasets.utils import _get_root_dir
-from torchrl.data.replay_buffers.replay_buffers import (
-    TensorDictReplayBuffer,
-)
+from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import (
     Sampler,
     SliceSampler,

@@ -1,13 +1,12 @@
 import contextlib
-import datetime
 import os
 from pathlib import Path
+
 import numpy as np
-import pandas as pd
 from omegaconf import OmegaConf
 from termcolor import colored
 
 
 def make_dir(dir_path):
     """Create directory if it does not already exist."""
     with contextlib.suppress(OSError):

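The hunk above cuts off inside `make_dir`. A minimal sketch of how such a helper typically completes, assuming `dir_path` is a `pathlib.Path`; the body inside the `suppress` block is an assumption, not the file's actual code:

```python
import contextlib
from pathlib import Path


def make_dir(dir_path: Path):
    """Create directory if it does not already exist."""
    with contextlib.suppress(OSError):
        # Assumed body: create parent directories too; any OSError is suppressed above.
        dir_path.mkdir(parents=True, exist_ok=True)
```
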
@@ -5,7 +5,6 @@ import hydra
 import torch
 import torch.nn as nn
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
-
 from diffusion_policy.model.common.lr_scheduler import get_scheduler
 from diffusion_policy.model.vision.model_getter import get_resnet
 from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder

@@ -128,12 +127,8 @@ class DiffusionPolicy(nn.Module):
 
         out = {
             "obs": {
-                "image": batch["observation", "image"].to(
-                    self.device, non_blocking=True
-                ),
-                "agent_pos": batch["observation", "state"].to(
-                    self.device, non_blocking=True
-                ),
+                "image": batch["observation", "image"].to(self.device, non_blocking=True),
+                "agent_pos": batch["observation", "state"].to(self.device, non_blocking=True),
             },
             "action": batch["action"].to(self.device, non_blocking=True),
         }

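For context, a minimal sketch of the kind of batch this code expects, assuming a TensorDict with the nested `("observation", ...)` keys shown in the hunk; the shapes, the batch size, and the `device = "cpu"` stand-in for `self.device` are illustrative assumptions:

```python
import torch
from tensordict import TensorDict

# Illustrative batch with the nested ("observation", ...) keys used above; shapes are made up.
batch = TensorDict(
    {
        "observation": {
            "image": torch.zeros(4, 3, 96, 96),
            "state": torch.zeros(4, 2),
        },
        "action": torch.zeros(4, 2),
    },
    batch_size=[4],
)

device = "cpu"  # stand-in for self.device
out = {
    "obs": {
        "image": batch["observation", "image"].to(device, non_blocking=True),
        "agent_pos": batch["observation", "state"].to(device, non_blocking=True),
    },
    "action": batch["action"].to(device, non_blocking=True),
}
```
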
@@ -33,7 +33,7 @@ def init_logging():
     logging.getLogger().addHandler(console_handler)
 
 
-def format_number_KMB(num):
+def format_big_number(num):
     suffixes = ["", "K", "M", "B", "T", "Q"]
     divisor = 1000.0
 

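The hunk only shows the first lines of the renamed helper. Here is a sketch of a complete implementation along those lines; the loop and the `:.0f` rounding are assumptions, not necessarily the repository's exact code:

```python
def format_big_number(num):
    """Render a number with K/M/B/T/Q suffixes, e.g. 1_234_000 -> '1M'."""
    suffixes = ["", "K", "M", "B", "T", "Q"]
    divisor = 1000.0
    for suffix in suffixes:
        if abs(num) < divisor:
            return f"{num:.0f}{suffix}"
        num /= divisor
    return f"{num:.0f}{suffixes[-1]}"


print(format_big_number(2_048))      # 2K
print(format_big_number(512_000))    # 512K
print(format_big_number(3_500_000))  # 4M (rounded by the assumed :.0f format)
```
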
@@ -1,5 +1,4 @@
 import logging
-import time
 
 import hydra
 import numpy as np

@@ -13,7 +12,7 @@ from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import format_number_KMB, init_logging, set_seed
+from lerobot.common.utils import format_big_number, init_logging, set_seed
 from lerobot.scripts.eval import eval_policy
 
 

@@ -49,11 +48,11 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
     num_episodes = num_samples / avg_samples_per_ep
     num_epochs = num_samples / offline_buffer.num_samples
     log_items = [
-        f"step:{format_number_KMB(step)}",
+        f"step:{format_big_number(step)}",
         # number of samples seen during training
-        f"smpl:{format_number_KMB(num_samples)}",
+        f"smpl:{format_big_number(num_samples)}",
         # number of episodes seen during training
-        f"ep:{format_number_KMB(num_episodes)}",
+        f"ep:{format_big_number(num_episodes)}",
         # number of time all unique samples are seen
         f"epch:{num_epochs:.2f}",
         f"loss:{loss:.3f}",

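A small sketch of how these items could render once joined into a single log line. The import is the one added in this commit's train.py hunk; the sample values and the `" ".join(...)` call are illustrative assumptions:

```python
from lerobot.common.utils import format_big_number

step, num_samples, num_episodes, num_epochs, loss = 2_000, 512_000, 2_048, 3.14, 0.123
log_items = [
    f"step:{format_big_number(step)}",
    f"smpl:{format_big_number(num_samples)}",
    f"ep:{format_big_number(num_episodes)}",
    f"epch:{num_epochs:.2f}",
    f"loss:{loss:.3f}",
]
# Expected to print something like: step:2K smpl:512K ep:2K epch:3.14 loss:0.123
print(" ".join(log_items))
```
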
@@ -86,11 +85,11 @@ def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline):
     num_episodes = num_samples / avg_samples_per_ep
     num_epochs = num_samples / offline_buffer.num_samples
     log_items = [
-        f"step:{format_number_KMB(step)}",
+        f"step:{format_big_number(step)}",
         # number of samples seen during training
-        f"smpl:{format_number_KMB(num_samples)}",
+        f"smpl:{format_big_number(num_samples)}",
         # number of episodes seen during training
-        f"ep:{format_number_KMB(num_episodes)}",
+        f"ep:{format_big_number(num_episodes)}",
         # number of time all unique samples are seen
         f"epch:{num_epochs:.2f}",
         f"∑rwrd:{avg_sum_reward:.3f}",

@@ -156,7 +155,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     logger = Logger(out_dir, job_name, cfg)
 
-    online_ep_idx = 0
     step = 0  # number of policy update
 
     is_offline = True

@@ -200,7 +198,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 auto_cast_to_device=True,
             )
         assert len(rollout) <= cfg.env.episode_length
-        rollout["episode"] = torch.tensor([online_ep_idx] * len(rollout), dtype=torch.int)
+        # set same episode index for all time steps contained in this rollout
+        rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
         online_buffer.extend(rollout)
 
         ep_sum_reward = rollout["next", "reward"].sum()

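The replacement line tags every time step of a rollout TensorDict with the same episode index before it is written to the online buffer. A minimal, self-contained illustration of that tagging; the three-step rollout and its fields are made up:

```python
import torch
from tensordict import TensorDict

# Hypothetical 3-step rollout; only the batch size matters for the tagging.
rollout = TensorDict({"action": torch.zeros(3, 2)}, batch_size=[3])

env_step = 5  # index of the current online rollout
rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
print(rollout["episode"])  # tensor([5, 5, 5], dtype=torch.int32)
```
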
@@ -210,12 +209,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
             "avg_sum_reward": np.nanmean(ep_sum_reward),
             "avg_max_reward": np.nanmean(ep_max_reward),
             "pc_success": np.nanmean(ep_success) * 100,
-            "online_ep_idx": online_ep_idx,
+            "env_step": env_step,
             "ep_length": len(rollout),
         }
 
-        online_ep_idx += 1
-
         for _ in range(cfg.policy.utd):
             train_info = policy.update(
                 online_buffer,

@@ -233,7 +230,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
             num_episodes=cfg.eval_episodes,
             return_first_video=True,
         )
-        log_eval_info(L, eval_info, step, cfg, offline_buffer, is_offline)
+        log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
         if cfg.wandb.enable:
             logger.log_video(first_video, step, mode="eval")
 

@@ -3,9 +3,7 @@ from pathlib import Path
 import hydra
 import imageio
 import torch
-from torchrl.data.replay_buffers import (
-    SliceSamplerWithoutReplacement,
-)
+from torchrl.data.replay_buffers import SliceSamplerWithoutReplacement
 
 from lerobot.common.datasets.factory import make_offline_buffer
 