# RoboWaiter/robowaiter/algos/retrieval/retrieval_lm/metrics.py
#
# 87 lines
# 2.5 KiB
# Python

import numpy as np
import string
import re
from collections import Counter
import re
def exact_match_score(prediction, ground_truth):
    """Return True when prediction and ground truth are equal after normalization.

    Both strings are passed through ``normalize_answer`` before being
    compared, so the comparison ignores whatever that function strips.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        bool: Whether the two normalized strings are identical.
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Return the best score of a prediction against several references.

    Args:
        metric_fn: Callable invoked as ``metric_fn(prediction, ground_truth)``
            that returns a comparable score.
        prediction: The prediction passed unchanged to every call.
        ground_truths: Iterable of references to score against.

    Returns:
        The largest value returned by ``metric_fn`` across all references.

    Raises:
        ValueError: If ``ground_truths`` is empty (``max`` of no values).
    """
    return max(metric_fn(prediction, reference) for reference in ground_truths)
def accuracy(preds, labels):
    """Return the percentage of predictions equal to their first reference.

    Each prediction is compared with ``label[0]`` of the label at the same
    position; any further elements of a label are ignored. Pairs are formed
    with ``zip``, so surplus labels are ignored, while the denominator is
    always ``len(preds)``.

    Args:
        preds: Sized sequence of predictions.
        labels: Sequence of indexable labels, one per prediction.

    Returns:
        The share of matching predictions scaled to the range 0-100, or
        ``0.0`` when ``preds`` is empty.
    """
    total = len(preds)
    if total == 0:
        # Nothing to score: report 0 instead of dividing by zero.
        return 0.0
    match_count = sum(1 for pred, label in zip(preds, labels) if pred == label[0])
    return 100 * (match_count / total)
def f1(decoded_preds, decoded_labels):
    """Return the mean per-example F1 over a corpus, as a percentage.

    Each prediction is scored with ``qa_f1_score``. When the matching entry
    of ``decoded_labels`` is a list, the prediction is scored against every
    reference in it and the best score is kept; otherwise the entry is
    treated as a single reference.

    Args:
        decoded_preds: Iterable of predicted answers.
        decoded_labels: Iterable of references, each either a single
            reference or a list of references.

    Returns:
        ``100 * mean`` of the per-example scores, or ``0.0`` when there are
        no examples to score.
    """
    f1_all = []
    for prediction, answers in zip(decoded_preds, decoded_labels):
        if isinstance(answers, list):
            if not answers:
                # An example without references scores 0 on its own; it must
                # not discard the scores of every other example.
                f1_all.append(0)
                continue
            f1_all.append(max(qa_f1_score(prediction, gt) for gt in answers))
        else:
            f1_all.append(qa_f1_score(prediction, answers))
    if not f1_all:
        # np.mean of an empty list is nan (with a RuntimeWarning).
        return 0.0
    return 100 * np.mean(f1_all)
def qa_f1_score(prediction, ground_truth):
    """Return the token-overlap F1 between a prediction and one reference.

    Both strings are normalized with ``normalize_answer`` and split on
    whitespace. Overlap is counted as a multiset intersection, so a repeated
    token only counts as often as it appears on both sides.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        The harmonic mean of token precision and recall, or ``0`` when the
        two token lists share no token.
    """
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    shared = Counter(pred_tokens) & Counter(gold_tokens)
    overlap = sum(shared.values())
    if overlap == 0:
        return 0
    precision = 1.0 * overlap / len(pred_tokens)
    recall = 1.0 * overlap / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)
def normalize_answer(s):
    """Return a canonical form of an answer string for comparison.

    The steps run in this order: lowercase the text, delete every character
    in ``string.punctuation``, replace the standalone words ``a``, ``an``
    and ``the`` with a space, then collapse runs of whitespace into single
    spaces and trim the ends.

    Args:
        s: Text to normalize.

    Returns:
        str: The normalized text.
    """
    lowered = s.lower()
    # Punctuation goes before articles so that e.g. "the," is still matched
    # as a whole word by the article pattern.
    without_punctuation = lowered.translate(
        str.maketrans('', '', string.punctuation)
    )
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punctuation)
    return ' '.join(without_articles.split())
def find_entity_tags(sentence):
    """Map text spans of a sentence to the angle-bracket tags found in it.

    Two independent scans are made: one collects spans of text that end just
    before a whitespace-plus-``<`` sequence or at the end of the string, the
    other collects the contents of every ``<...>`` tag. Spans and tags are
    then paired by position. A span that itself contains ``<`` (it begins
    with the previous tag) is reduced to the text after its first ``"> "``.

    NOTE(review): presumably the input looks like
    ``"entity <tag> entity <tag>"`` -- confirm against the callers.

    Args:
        sentence: Text containing ``<tag>`` markers.

    Returns:
        dict: Span text mapped to tag text; a later duplicate span overwrites
        an earlier one.

    Raises:
        IndexError: If a paired span contains ``<`` but no ``"> "``.
    """
    spans = re.findall(r'(.+?)(?=\s<|$)', sentence)
    labels = re.findall(r'<(.+?)>', sentence)
    tagged = {}
    for span, label in zip(spans, labels):
        # Drop the leading "<previous-tag> " that the span scan swept up.
        key = span.split("> ")[1] if "<" in span else span
        tagged[key] = label
    return tagged
def match(prediction, ground_truth):
    """Return 1 if any reference is contained in the prediction, else 0.

    Containment is tested with the ``in`` operator, so for strings this is a
    case-sensitive substring check.

    Args:
        prediction: Value searched for each reference.
        ground_truth: Iterable of references to look for.

    Returns:
        int: ``1`` on the first reference found in ``prediction``, ``0`` if
        none is found or ``ground_truth`` is empty.
    """
    return 1 if any(reference in prediction for reference in ground_truth) else 0