import re
import string
from collections import Counter

import numpy as np

# Characters stripped by normalize_answer; built once instead of on every call.
_PUNCTUATION = frozenset(string.punctuation)
# English articles removed (as whole words) by normalize_answer.
_ARTICLES_RE = re.compile(r'\b(a|an|the)\b')
# find_entity_tags: lazily capture text up to the next " <" or the end of the string.
_ENTITY_RE = re.compile(r'(.+?)(?=\s<|$)')
# find_entity_tags: the contents of each "<...>" tag.
_TAG_RE = re.compile(r'<(.+?)>')


def exact_match_score(prediction, ground_truth):
    """Return True if *prediction* equals *ground_truth* after normalize_answer."""
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Return the best ``metric_fn(prediction, gt)`` over all *ground_truths*.

    Raises:
        ValueError: if *ground_truths* is empty (propagated from ``max``).
    """
    return max(metric_fn(prediction, gt) for gt in ground_truths)


def accuracy(preds, labels):
    """Return the percentage of predictions equal to the first element of their label.

    Each ``pred`` is compared with ``label[0]`` using plain ``==`` (no
    normalization). Pairs are formed with ``zip``, so surplus items in the
    longer sequence are ignored, but the denominator is ``len(preds)``.

    Returns:
        A percentage in [0, 100]; 0.0 when *preds* is empty.
    """
    # len() rather than truthiness so numpy arrays are accepted too.
    if len(preds) == 0:
        return 0.0
    match_count = sum(1 for pred, label in zip(preds, labels) if pred == label[0])
    return 100 * (match_count / len(preds))


def f1(decoded_preds, decoded_labels):
    """Return the mean token-level F1 (as a percentage) over paired predictions/labels.

    Each label may be a single reference string, or a list/tuple of reference
    strings, in which case the best-matching reference counts. An example whose
    reference list is empty scores 0 for that example only.

    Returns:
        A percentage in [0, 100]; 0.0 when there are no examples.
    """
    f1_all = []
    for prediction, answers in zip(decoded_preds, decoded_labels):
        if isinstance(answers, (list, tuple)):
            # An empty reference list scores 0 for this example; it must not
            # abort the whole evaluation and discard the other examples.
            f1_all.append(max((qa_f1_score(prediction, gt) for gt in answers), default=0.0))
        else:
            f1_all.append(qa_f1_score(prediction, answers))
    if not f1_all:
        # np.mean([]) would yield nan plus a RuntimeWarning.
        return 0.0
    return 100 * np.mean(f1_all)


def qa_f1_score(prediction, ground_truth):
    """Return the token-overlap F1 between the normalized prediction and ground truth.

    Overlap is a multiset intersection: a repeated token counts at most as many
    times as it appears on both sides. Returns 0 when there is no overlap,
    including when either side normalizes to no tokens.
    """
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)


def normalize_answer(s):
    """Lowercase *s*, drop punctuation and the articles a/an/the, and collapse whitespace."""
    text = s.lower()
    text = ''.join(ch for ch in text if ch not in _PUNCTUATION)
    # Articles are replaced by a space, so adjacent words never fuse together.
    text = _ARTICLES_RE.sub(' ', text)
    return ' '.join(text.split())


def find_entity_tags(sentence):
    """Map entity text to its tag for a sentence shaped like ``"ent1 <tag1> ent2 <tag2>"``.

    Entity spans and tags are paired positionally with ``zip``, so any surplus
    on either side is ignored. Every entity span after the first also contains
    the preceding tag (e.g. ``" <tag1> ent2"``), so for those spans the text
    after ``"> "`` is used as the key.
    """
    entity_names = _ENTITY_RE.findall(sentence)
    tags = _TAG_RE.findall(sentence)
    results = {}
    for entity, tag in zip(entity_names, tags):
        if "<" in entity:
            # NOTE(review): raises IndexError if the span has "<" but no "> "
            # (e.g. a tag not followed by a space); confirm input format.
            results[entity.split("> ")[1]] = tag
        else:
            results[entity] = tag
    return results


def match(prediction, ground_truth):
    """Return 1 if any item of *ground_truth* is a substring of *prediction*, else 0."""
    return int(any(gt in prediction for gt in ground_truth))