Add '--independent' flag

Simon Alibert 2024-05-16 19:31:57 +02:00
parent fe31b7f4b7
commit eb530fa595
1 changed file with 44 additions and 29 deletions


@@ -11,8 +11,16 @@ python lerobot/scripts/compare_policies.py \
output/eval/new_policy/eval_info.json
```
- This script can accept `eval_info.json` dicts with identical seeds between each eval episode of ref_policy and new_policy
- (paired-samples) or from evals performed with different seeds (independent samples).
+ This script can accept `eval_info.json` dicts with identical seeds between each eval episode of ref_policy and
+ new_policy (paired samples) or from evals performed with different seeds (independent samples).
The script will first perform normality tests to determine whether parametric tests can be used, then
evaluate whether the policy metrics are significantly different using the appropriate tests.
+ CAVEATS: by default, this script compares seed numbers to determine if samples can be considered paired.
+ If changes have been made to the environment in between the ref_policy eval and the new_policy eval, you
+ should use the `--independent` flag to override this and not pair the samples even if they have identical
+ seeds.
"""
import argparse
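The pairing decision drives which statistical tests run later in the script: matched seeds allow paired tests, anything else falls back to independent-sample tests. A minimal sketch of that branch, using hypothetical metric arrays (the real wiring lives in `perform_tests` and the `*_sample_tests` functions):

```python
import numpy as np
from scipy.stats import mannwhitneyu, ttest_ind, ttest_rel, wilcoxon

# Hypothetical per-episode metrics; in the script these come from eval_info.json.
ref = np.array([0.81, 0.75, 0.90, 0.62, 0.88])
new = np.array([0.85, 0.79, 0.93, 0.70, 0.86])

paired = True  # identical seeds and --independent not passed
if paired:
    t_stat, p_t = ttest_rel(ref, new)     # parametric, matched episodes
    w_stat, p_w = wilcoxon(ref - new)     # non-parametric counterpart
else:
    t_stat, p_t = ttest_ind(ref, new)     # parametric, unmatched samples
    u_stat, p_u = mannwhitneyu(ref, new)  # non-parametric counterpart
```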
@@ -23,12 +31,12 @@ from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
- from scipy.stats import anderson, kstest, mannwhitneyu, shapiro, ttest_ind, ttest_rel, wilcoxon
+ from scipy.stats import anderson, kstest, mannwhitneyu, normaltest, shapiro, ttest_ind, ttest_rel, wilcoxon
from statsmodels.stats.contingency_tables import mcnemar
from termcolor import colored
- def init_logging(output_dir: Path) -> None:
+ def init_logging() -> None:
logging.basicConfig(
level=logging.INFO,
format="%(message)s",
@@ -42,6 +50,18 @@ def log_section(title: str) -> None:
logging.info(section_title)
+ def log_test(msg: str, p_value: float):
+     if p_value < 0.01:
+         color, interpretation = "red", "H_0 Rejected"
+     elif 0.01 <= p_value < 0.05:
+         color, interpretation = "yellow", "Inconclusive"
+     else:
+         color, interpretation = "green", "H_0 Not Rejected"
+     logging.info(
+         f"{msg}, p-value = {colored(f'{p_value:.3f}', color)} -> {colored(f'{interpretation}', color, attrs=['bold'])}"
+     )
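The bands map a p-value to a colored verdict: below 0.01 the null hypothesis (no difference between policies) is rejected, between 0.01 and 0.05 the result is treated as inconclusive, and at 0.05 or above it is not rejected. For instance, a hypothetical call:

```python
log_test("Paired t-test for Max Reward: t-statistic = 2.310", 0.021)
# Logs: "Paired t-test for Max Reward: t-statistic = 2.310, p-value = 0.021 -> Inconclusive"
# with the p-value and verdict colored yellow.
```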
def get_eval_info_episodes(eval_info_path: Path) -> dict:
with open(eval_info_path) as f:
eval_info = json.load(f)
@@ -55,7 +75,7 @@ def get_eval_info_episodes(eval_info_path: Path) -> dict:
}
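This hunk elides most of `get_eval_info_episodes`, but the keys read elsewhere in the script pin down the shape it must return; a hypothetical instance for illustration:

```python
import numpy as np

# Assumed shape, inferred from the keys used across this diff
# (num_episodes, seeds, successes, max_rewards, sum_rewards).
eval_sample = {
    "num_episodes": 3,
    "seeds": [1000, 1001, 1002],
    "successes": np.array([1, 0, 1]),
    "max_rewards": np.array([4.2, 0.8, 3.9]),
    "sum_rewards": np.array([37.5, 6.1, 31.0]),
}
```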
- def descriptive_stats(ref_sample: dict, new_sample: dict, metric_name: str):
+ def describe_samples(ref_sample: dict, new_sample: dict, metric_name: str):
ref_mean, ref_std = np.mean(ref_sample[metric_name]), np.std(ref_sample[metric_name])
new_mean, new_std = np.mean(new_sample[metric_name]), np.std(new_sample[metric_name])
logging.info(f"{metric_name} - Ref sample: mean = {ref_mean:.3f}, std = {ref_std:.3f}")
@@ -67,18 +87,20 @@ def cohens_d(x, y):
def normality_tests(array: np.ndarray, name: str):
-     shapiro_stat, shapiro_p = shapiro(array)
+     ap_stat, ap_p = normaltest(array)
+     sw_stat, sw_p = shapiro(array)
ks_stat, ks_p = kstest(array, "norm", args=(np.mean(array), np.std(array)))
ad_stat = anderson(array)
log_test(f"{name} - Shapiro-Wilk Test: statistic = {shapiro_stat:.3f}", shapiro_p)
log_test(f"{name} - Kolmogorov-Smirnov Test: statistic = {ks_stat:.3f}", ks_p)
logging.info(f"{name} - Anderson-Darling Test: statistic = {ad_stat.statistic:.3f}")
log_test(f"{name} - D'Agostino and Pearson test: statistic = {ap_stat:.3f}", ap_p)
log_test(f"{name} - Shapiro-Wilk test: statistic = {sw_stat:.3f}", sw_p)
log_test(f"{name} - Kolmogorov-Smirnov test: statistic = {ks_stat:.3f}", ks_p)
logging.info(f"{name} - Anderson-Darling test: statistic = {ad_stat.statistic:.3f}")
for i in range(len(ad_stat.critical_values)):
cv, sl = ad_stat.critical_values[i], ad_stat.significance_level[i]
logging.info(f" Critical value at {sl}%: {cv:.3f}")
-     return shapiro_p > 0.05 and ks_p > 0.05
+     return sw_p > 0.05 and ks_p > 0.05
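Note that the returned boolean folds in only the Shapiro-Wilk and Kolmogorov-Smirnov p-values; the Anderson-Darling statistic is logged against its critical values but not used in the decision. One hedged way to turn that result into a reject/not-reject call, if that were desired:

```python
import numpy as np
from scipy.stats import anderson

rng = np.random.default_rng(0)
sample = rng.normal(size=50)  # hypothetical metric sample

result = anderson(sample)
# Reject normality at the 5% level if the statistic exceeds the critical value.
idx = list(result.significance_level).index(5.0)
rejects_normality = result.statistic > result.critical_values[idx]
```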
def plot_boxplot(data_a: np.ndarray, data_b: np.ndarray, labels: list[str], title: str, filename: str):
@@ -104,18 +126,6 @@ def plot_qqplot(data: np.ndarray, title: str, filename: str):
plt.close()
- def log_test(msg, p_value):
-     if p_value < 0.01:
-         color, interpretation = "red", "H_0 Rejected"
-     elif 0.01 <= p_value < 0.05:
-         color, interpretation = "orange", "Inconclusive"
-     else:
-         color, interpretation = "green", "H_0 Not Rejected"
-     logging.info(
-         f"{msg}, p-value = {colored(f'{p_value:.3f}', color)} -> {colored(f'{interpretation}', color, attrs=['bold'])}"
-     )
def paired_sample_tests(ref_sample: dict, new_sample: dict):
log_section("Normality tests")
max_reward_diff = ref_sample["max_rewards"] - new_sample["max_rewards"]
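The hunk cuts off before the paired tests themselves, but the `mcnemar` import above suggests paired success rates are compared via a 2x2 table of per-episode outcomes. A sketch under that assumption, with hypothetical success vectors:

```python
import numpy as np
from statsmodels.stats.contingency_tables import mcnemar

ref_succ = np.array([1, 1, 0, 1, 0, 1])  # hypothetical paired successes
new_succ = np.array([1, 0, 1, 1, 1, 1])

# Counts of (ref outcome, new outcome) pairs over matched episodes.
table = np.array([
    [np.sum((ref_succ == 1) & (new_succ == 1)), np.sum((ref_succ == 1) & (new_succ == 0))],
    [np.sum((ref_succ == 0) & (new_succ == 1)), np.sum((ref_succ == 0) & (new_succ == 0))],
])
result = mcnemar(table, exact=True)  # exact binomial test on the discordant pairs
print(f"McNemar p-value: {result.pvalue:.3f}")
```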
@@ -185,22 +195,22 @@ def independent_sample_tests(ref_sample: dict, new_sample: dict):
log_test(f"Mann-Whitney U test for Sum Reward: U-statistic = {u_stat_sum_reward:.3f}", p_u_sum_reward)
- def perform_tests(ref_sample: dict, new_sample: dict, output_dir: Path):
+ def perform_tests(ref_sample: dict, new_sample: dict, output_dir: Path, independent: bool = False):
log_section("Descriptive Stats")
logging.info(f"Number of episode - Ref Sample: {ref_sample['num_episodes']}")
logging.info(f"Number of episode - New Sample: {new_sample['num_episodes']}")
seeds_a, seeds_b = ref_sample["seeds"], new_sample["seeds"]
-     if seeds_a == seeds_b:
+     if (seeds_a == seeds_b) and not independent:
logging.info("Samples are paired (identical seeds).")
paired = True
else:
logging.info("Samples are considered independent (seeds are different).")
paired = False
descriptive_stats(ref_sample, new_sample, "successes")
descriptive_stats(ref_sample, new_sample, "max_rewards")
descriptive_stats(ref_sample, new_sample, "sum_rewards")
describe_samples(ref_sample, new_sample, "successes")
describe_samples(ref_sample, new_sample, "max_rewards")
describe_samples(ref_sample, new_sample, "sum_rewards")
log_section("Effect Size")
d_max_reward = cohens_d(ref_sample["max_rewards"], new_sample["max_rewards"])
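`cohens_d` itself sits outside this diff; the standard pooled-standard-deviation definition is a reasonable guess at what it computes, sketched here as an assumption rather than the script's confirmed implementation:

```python
import numpy as np

def cohens_d(x: np.ndarray, y: np.ndarray) -> float:
    # Effect size: mean difference scaled by the pooled standard deviation.
    nx, ny = len(x), len(y)
    pooled_var = ((nx - 1) * np.var(x, ddof=1) + (ny - 1) * np.var(y, ddof=1)) / (nx + ny - 2)
    return (np.mean(x) - np.mean(y)) / np.sqrt(pooled_var)
```

By the usual convention, |d| around 0.2 is a small effect, 0.5 medium, and 0.8 large.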
@@ -273,6 +283,11 @@ if __name__ == "__main__":
)
parser.add_argument("ref_sample_path", type=Path, help="Path to the reference sample JSON file.")
parser.add_argument("new_sample_path", type=Path, help="Path to the new sample JSON file.")
+     parser.add_argument(
+         "--independent",
+         action="store_true",
+         help="Ignore seeds and consider samples to be independent (unpaired).",
+     )
parser.add_argument(
"--output_dir",
type=Path,
@@ -280,8 +295,8 @@ if __name__ == "__main__":
help="Directory to save the output results. Defaults to outputs/compare/",
)
args = parser.parse_args()
-     init_logging(args.output_dir)
+     init_logging()
ref_sample = get_eval_info_episodes(args.ref_sample_path)
new_sample = get_eval_info_episodes(args.new_sample_path)
-     perform_tests(ref_sample, new_sample, args.output_dir)
+     perform_tests(ref_sample, new_sample, args.output_dir, args.independent)
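With the flag wired through, forcing an unpaired comparison even when the two evals share seeds would look like this (mirroring the usage block at the top of the script; the ref path is hypothetical):

```
python lerobot/scripts/compare_policies.py \
    output/eval/ref_policy/eval_info.json \
    output/eval/new_policy/eval_info.json \
    --independent
```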