Add '--independent' flag

parent fe31b7f4b7
commit eb530fa595

@@ -11,8 +11,16 @@ python lerobot/scripts/compare_policies.py \
     output/eval/new_policy/eval_info.json
 ```
 
-This script can accept `eval_info.json` dicts with identical seeds between each eval episode of ref_policy and new_policy
-(paired-samples) or from evals performed with different seeds (independent samples).
+This script can accept `eval_info.json` dicts with identical seeds between each eval episode of ref_policy and
+new_policy (paired samples) or from evals performed with different seeds (independent samples).
+
+The script will first perform normality tests to determine whether parametric tests can be used, then
+evaluate whether the policies' metrics are significantly different using the appropriate tests.
+
+CAVEATS: by default, this script compares seed numbers to determine whether samples can be considered paired.
+If changes have been made to the environment between the ref_policy eval and the new_policy eval, you
+should use the `--independent` flag to override this and leave the samples unpaired even if they have identical
+seeds.
 """
 
 import argparse
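For readers unfamiliar with the distinction the docstring draws, here is a minimal sketch of paired vs. independent comparisons using `ttest_rel` and `ttest_ind` (both already imported by the script); the reward arrays are invented for illustration.

```python
import numpy as np
from scipy.stats import ttest_ind, ttest_rel

# Invented per-episode max rewards for the reference and new policies.
ref_rewards = np.array([0.82, 0.75, 0.91, 0.60, 0.88, 0.70])
new_rewards = np.array([0.85, 0.80, 0.93, 0.58, 0.90, 0.74])

# Identical seeds: episode i of each eval saw the same initial conditions,
# so the samples can be compared pairwise.
t_paired, p_paired = ttest_rel(ref_rewards, new_rewards)

# Different seeds (or the --independent flag): treat the samples as unrelated.
t_indep, p_indep = ttest_ind(ref_rewards, new_rewards)

print(f"paired p = {p_paired:.3f}, independent p = {p_indep:.3f}")
```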
@@ -23,12 +31,12 @@ from pathlib import Path
 import matplotlib.pyplot as plt
 import numpy as np
 import scipy.stats as stats
-from scipy.stats import anderson, kstest, mannwhitneyu, shapiro, ttest_ind, ttest_rel, wilcoxon
+from scipy.stats import anderson, kstest, mannwhitneyu, normaltest, shapiro, ttest_ind, ttest_rel, wilcoxon
 from statsmodels.stats.contingency_tables import mcnemar
 from termcolor import colored
 
 
-def init_logging(output_dir: Path) -> None:
+def init_logging() -> None:
     logging.basicConfig(
         level=logging.INFO,
         format="%(message)s",
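The import block also pulls in `mcnemar` from statsmodels, the standard paired test for binary outcomes such as per-episode success flags. A standalone sketch of that test (the success arrays are invented and this does not reproduce the script's own success comparison):

```python
import numpy as np
from statsmodels.stats.contingency_tables import mcnemar

# Invented per-episode success flags for two policies evaluated on the same seeds.
ref_success = np.array([1, 1, 0, 1, 0, 1, 1, 0], dtype=bool)
new_success = np.array([1, 1, 1, 1, 0, 1, 1, 1], dtype=bool)

# 2x2 table of paired outcomes: rows = ref (fail/success), cols = new (fail/success).
table = np.array([
    [np.sum(~ref_success & ~new_success), np.sum(~ref_success & new_success)],
    [np.sum(ref_success & ~new_success), np.sum(ref_success & new_success)],
])

result = mcnemar(table, exact=True)  # exact binomial form, safer for small episode counts
print(f"McNemar statistic = {result.statistic}, p-value = {result.pvalue:.3f}")
```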
@@ -42,6 +50,18 @@ def log_section(title: str) -> None:
     logging.info(section_title)
 
 
+def log_test(msg: str, p_value: float):
+    if p_value < 0.01:
+        color, interpretation = "red", "H_0 Rejected"
+    elif 0.01 <= p_value < 0.05:
+        color, interpretation = "yellow", "Inconclusive"
+    else:
+        color, interpretation = "green", "H_0 Not Rejected"
+    logging.info(
+        f"{msg}, p-value = {colored(f'{p_value:.3f}', color)} -> {colored(f'{interpretation}', color, attrs=['bold'])}"
+    )
+
+
 def get_eval_info_episodes(eval_info_path: Path) -> dict:
     with open(eval_info_path) as f:
         eval_info = json.load(f)
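To illustrate how the relocated `log_test` helper colours its output, here are a few calls with invented p-values; this assumes `log_test` is in scope (defined as in the hunk above) and only sets up basic logging:

```python
import logging

# Minimal logging setup so the messages are actually printed.
logging.basicConfig(level=logging.INFO, format="%(message)s")

log_test("Ref - Shapiro-Wilk test: statistic = 0.912", 0.004)  # p < 0.01          -> red,    "H_0 Rejected"
log_test("Ref - Shapiro-Wilk test: statistic = 0.951", 0.030)  # 0.01 <= p < 0.05  -> yellow, "Inconclusive"
log_test("Ref - Shapiro-Wilk test: statistic = 0.987", 0.412)  # p >= 0.05         -> green,  "H_0 Not Rejected"
```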
@@ -55,7 +75,7 @@ def get_eval_info_episodes(eval_info_path: Path) -> dict:
     }
 
 
-def descriptive_stats(ref_sample: dict, new_sample: dict, metric_name: str):
+def describe_samples(ref_sample: dict, new_sample: dict, metric_name: str):
     ref_mean, ref_std = np.mean(ref_sample[metric_name]), np.std(ref_sample[metric_name])
     new_mean, new_std = np.mean(new_sample[metric_name]), np.std(new_sample[metric_name])
     logging.info(f"{metric_name} - Ref sample: mean = {ref_mean:.3f}, std = {ref_std:.3f}")
@@ -67,18 +87,20 @@ def cohens_d(x, y):
 
 
 def normality_tests(array: np.ndarray, name: str):
-    shapiro_stat, shapiro_p = shapiro(array)
+    ap_stat, ap_p = normaltest(array)
+    sw_stat, sw_p = shapiro(array)
     ks_stat, ks_p = kstest(array, "norm", args=(np.mean(array), np.std(array)))
     ad_stat = anderson(array)
 
-    log_test(f"{name} - Shapiro-Wilk Test: statistic = {shapiro_stat:.3f}", shapiro_p)
-    log_test(f"{name} - Kolmogorov-Smirnov Test: statistic = {ks_stat:.3f}", ks_p)
-    logging.info(f"{name} - Anderson-Darling Test: statistic = {ad_stat.statistic:.3f}")
+    log_test(f"{name} - D'Agostino and Pearson test: statistic = {ap_stat:.3f}", ap_p)
+    log_test(f"{name} - Shapiro-Wilk test: statistic = {sw_stat:.3f}", sw_p)
+    log_test(f"{name} - Kolmogorov-Smirnov test: statistic = {ks_stat:.3f}", ks_p)
+    logging.info(f"{name} - Anderson-Darling test: statistic = {ad_stat.statistic:.3f}")
     for i in range(len(ad_stat.critical_values)):
         cv, sl = ad_stat.critical_values[i], ad_stat.significance_level[i]
         logging.info(f"  Critical value at {sl}%: {cv:.3f}")
 
-    return shapiro_p > 0.05 and ks_p > 0.05
+    return sw_p > 0.05 and ks_p > 0.05
 
 
 def plot_boxplot(data_a: np.ndarray, data_b: np.ndarray, labels: list[str], title: str, filename: str):
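A quick sketch of what the newly added D'Agostino-Pearson check (`normaltest`) reports next to Shapiro-Wilk, on synthetic data; the sample sizes and distributions are invented:

```python
import numpy as np
from scipy.stats import normaltest, shapiro

rng = np.random.default_rng(0)
gaussian = rng.normal(loc=0.5, scale=0.1, size=50)  # should look normal
skewed = rng.exponential(scale=0.5, size=50)        # should usually be flagged as non-normal

for name, sample in [("gaussian", gaussian), ("skewed", skewed)]:
    ap_stat, ap_p = normaltest(sample)  # D'Agostino and Pearson's K^2 test
    sw_stat, sw_p = shapiro(sample)     # Shapiro-Wilk
    print(f"{name}: K^2 p = {ap_p:.3f}, Shapiro-Wilk p = {sw_p:.3f}")
```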
@@ -104,18 +126,6 @@ def plot_qqplot(data: np.ndarray, title: str, filename: str):
     plt.close()
 
 
-def log_test(msg, p_value):
-    if p_value < 0.01:
-        color, interpretation = "red", "H_0 Rejected"
-    elif 0.01 <= p_value < 0.05:
-        color, interpretation = "orange", "Inconclusive"
-    else:
-        color, interpretation = "green", "H_0 Not Rejected"
-    logging.info(
-        f"{msg}, p-value = {colored(f'{p_value:.3f}', color)} -> {colored(f'{interpretation}', color, attrs=['bold'])}"
-    )
-
-
 def paired_sample_tests(ref_sample: dict, new_sample: dict):
     log_section("Normality tests")
     max_reward_diff = ref_sample["max_rewards"] - new_sample["max_rewards"]
@@ -185,22 +195,22 @@ def independent_sample_tests(ref_sample: dict, new_sample: dict):
     log_test(f"Mann-Whitney U test for Sum Reward: U-statistic = {u_stat_sum_reward:.3f}", p_u_sum_reward)
 
 
-def perform_tests(ref_sample: dict, new_sample: dict, output_dir: Path):
+def perform_tests(ref_sample: dict, new_sample: dict, output_dir: Path, independent: bool = False):
     log_section("Descriptive Stats")
     logging.info(f"Number of episode - Ref Sample: {ref_sample['num_episodes']}")
     logging.info(f"Number of episode - New Sample: {new_sample['num_episodes']}")
 
     seeds_a, seeds_b = ref_sample["seeds"], new_sample["seeds"]
-    if seeds_a == seeds_b:
+    if (seeds_a == seeds_b) and not independent:
         logging.info("Samples are paired (identical seeds).")
         paired = True
     else:
         logging.info("Samples are considered independent (seeds are different).")
         paired = False
 
-    descriptive_stats(ref_sample, new_sample, "successes")
-    descriptive_stats(ref_sample, new_sample, "max_rewards")
-    descriptive_stats(ref_sample, new_sample, "sum_rewards")
+    describe_samples(ref_sample, new_sample, "successes")
+    describe_samples(ref_sample, new_sample, "max_rewards")
+    describe_samples(ref_sample, new_sample, "sum_rewards")
 
     log_section("Effect Size")
     d_max_reward = cohens_d(ref_sample["max_rewards"], new_sample["max_rewards"])
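To make the pairing rule concrete, here is a small standalone sketch of the decision `perform_tests` now makes; the dict keys mirror those read in the diff, while the values are invented:

```python
# Invented samples with the keys perform_tests reads (real ones come from eval_info.json).
ref_sample = {"num_episodes": 3, "seeds": [0, 1, 2], "successes": [1, 0, 1],
              "max_rewards": [0.9, 0.4, 0.8], "sum_rewards": [4.1, 2.0, 3.7]}
new_sample = {"num_episodes": 3, "seeds": [0, 1, 2], "successes": [1, 1, 1],
              "max_rewards": [0.95, 0.6, 0.85], "sum_rewards": [4.4, 2.9, 3.9]}

independent = True  # what passing --independent sets

# Same rule as in the hunk above: identical seeds alone no longer force pairing.
paired = (ref_sample["seeds"] == new_sample["seeds"]) and not independent
print("paired" if paired else "independent")  # -> "independent", despite matching seeds
```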
@@ -273,6 +283,11 @@ if __name__ == "__main__":
     )
     parser.add_argument("ref_sample_path", type=Path, help="Path to the reference sample JSON file.")
     parser.add_argument("new_sample_path", type=Path, help="Path to the new sample JSON file.")
+    parser.add_argument(
+        "--independent",
+        action="store_true",
+        help="Ignore seeds and consider samples to be independent (unpaired).",
+    )
     parser.add_argument(
         "--output_dir",
         type=Path,
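A self-contained sketch of how the new flag parses; the argument names match the diff, while the file paths are placeholders:

```python
import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument("ref_sample_path", type=Path)
parser.add_argument("new_sample_path", type=Path)
parser.add_argument("--independent", action="store_true")

# store_true defaults to False, so existing invocations keep the paired-by-seed behaviour.
args = parser.parse_args(["ref/eval_info.json", "new/eval_info.json", "--independent"])
assert args.independent is True
assert isinstance(args.ref_sample_path, Path)
```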
@@ -280,8 +295,8 @@ if __name__ == "__main__":
         help="Directory to save the output results. Defaults to outputs/compare/",
     )
     args = parser.parse_args()
-    init_logging(args.output_dir)
+    init_logging()
 
     ref_sample = get_eval_info_episodes(args.ref_sample_path)
     new_sample = get_eval_info_episodes(args.new_sample_path)
-    perform_tests(ref_sample, new_sample, args.output_dir)
+    perform_tests(ref_sample, new_sample, args.output_dir, args.independent)