# RoboWaiter/robowaiter/algos/retrieval/retrieval_lm/utils.py
#
# Listing metadata: 195 lines, 7.6 KiB, Python.

import jsonlines
import json
import copy
import re
# Prompt templates, filled with str.format(instruction=..., input=...).
# "prompt_input" is for examples that carry an input field;
# "prompt_no_input" is for instruction-only examples.
PROMPT_DICT = {
    "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
}
# Instruction shared by the four-option multiple-choice tasks (obqa, arc_easy, arc_c).
_FOUR_CHOICE_INST = "Given four answer candidates, A, B, C and D, choose the best answer choice."

# Task name -> natural-language instruction for that task
# (e.g. preprocess_input prepends it to the question for "asqa" / "eli5").
TASK_INST = {
    "wow": "Given a chat history separated by new lines, generates an informative, knowledgeable and engaging response. ",
    "fever": "Is the following statement correct or not? Say true if it's correct; otherwise say false.",
    "eli5": "Provide a paragraph-length response using simple words to answer the following question.",
    "obqa": _FOUR_CHOICE_INST,
    "arc_easy": _FOUR_CHOICE_INST,
    "arc_c": _FOUR_CHOICE_INST,
    "trex": "Given the input format 'Subject Entity [SEP] Relationship Type,' predict the target entity.",
    "asqa": "Answer the following question. The question may be ambiguous and have multiple correct answers, and in that case, you have to provide a long-form answer including all correct answers.",
}
# Special-token vocabularies, grouped by role.
rel_tokens_names = ["[Irrelevant]", "[Relevant]"]
retrieval_tokens_names = [
    "[No Retrieval]",
    "[Retrieval]",
    "[Continue to Use Evidence]",
]
utility_tokens_names = [
    "[Utility:1]",
    "[Utility:2]",
    "[Utility:3]",
    "[Utility:4]",
    "[Utility:5]",
]
ground_tokens_names = [
    "[Fully supported]",
    "[Partially supported]",
    "[No support / Contradictory]",
]
other_special_tokens = [
    "<s>",
    "</s>",
    "[PAD]",
    "<unk>",
    "<paragraph>",
    "</paragraph>",
]
# Concatenation (a new list) of the control tokens in a fixed order:
# grounding, the first two retrieval tokens, relevance, the paragraph
# delimiters, then utility. "[Continue to Use Evidence]", "<s>", "</s>",
# "[PAD]" and "<unk>" are deliberately not part of it.
control_tokens = (
    ground_tokens_names
    + retrieval_tokens_names[:2]
    + rel_tokens_names
    + other_special_tokens[4:6]
    + utility_tokens_names
)
def load_special_tokens(tokenizer, use_grounding=False, use_utility=False):
    """Look up the vocabulary ids of the special control tokens.

    Args:
        tokenizer: Any object with a ``convert_tokens_to_ids(token)`` method
            (presumably a Hugging Face tokenizer -- confirm against callers).
        use_grounding: If truthy, also look up the grounding tokens.
        use_utility: If truthy, also look up the utility tokens.

    Returns:
        A 4-tuple ``(ret_tokens, rel_tokens, grd_tokens, ut_tokens)``, each a
        ``{token: id}`` dict. ``grd_tokens`` / ``ut_tokens`` are ``None`` when
        the corresponding flag is falsy.
    """
    def _ids(names):
        # Map each token string to the id the tokenizer assigns it.
        return {token: tokenizer.convert_tokens_to_ids(token) for token in names}

    ret_tokens = _ids(retrieval_tokens_names)
    rel_tokens = _ids(rel_tokens_names)
    # Test truthiness rather than ``is True`` so flags such as 1 or
    # numpy.bool_ (which are not the ``True`` singleton) still take effect.
    grd_tokens = _ids(ground_tokens_names) if use_grounding else None
    ut_tokens = _ids(utility_tokens_names) if use_utility else None
    return ret_tokens, rel_tokens, grd_tokens, ut_tokens
# A '.', '!' or '?' with a word character (letter, digit or underscore)
# directly on both sides.
_MISSING_SPACE_AFTER_PUNCT = re.compile(r'(?<=\w)([.!?])(?=\w)')


def fix_spacing(input_text):
    """Insert a space after '.', '!' or '?' that sits directly between word characters.

    Example: "Hello.World" becomes "Hello. World". Because digits count as
    word characters, decimal numbers are split as well ("3.14" -> "3. 14").
    Punctuation already followed by whitespace is left untouched.
    """
    return _MISSING_SPACE_AFTER_PUNCT.sub(r'\1 ', input_text)
def postprocess(pred):
    """Remove control tokens and "</s>" from a generation, then drop one leading space.

    Markers are stripped one kind at a time in the fixed order below (so
    removing one marker can expose another that is removed later). Only a
    single leading space is trimmed, not all leading whitespace.
    """
    removable = (
        "[Fully supported]",
        "[Partially supported]",
        "[No support / Contradictory]",
        "[No Retrieval]",
        "[Retrieval]",
        "[Irrelevant]",
        "[Relevant]",
        "<paragraph>",
        "</paragraph>",
        "[Utility:1]",
        "[Utility:2]",
        "[Utility:3]",
        "[Utility:4]",
        "[Utility:5]",
        "</s>",  # removed last, after all control tokens
    )
    cleaned = pred
    for marker in removable:
        cleaned = cleaned.replace(marker, "")
    return cleaned[1:] if cleaned.startswith(" ") else cleaned
def load_jsonlines(file):
    """Read a JSON Lines file with ``jsonlines`` and return all records as a list."""
    with jsonlines.open(file, 'r') as reader:
        return list(reader)
def load_file(input_fp):
    """Load a dataset file.

    Paths ending in ``.json`` are parsed as a single JSON document; any other
    path is read as JSON Lines via ``load_jsonlines``.

    Returns:
        The parsed JSON value for ``.json`` files, otherwise the list of
        records produced by ``load_jsonlines``.
    """
    if input_fp.endswith(".json"):
        # Close the handle deterministically instead of leaving it to the GC;
        # JSON text is UTF-8 (RFC 8259), so don't depend on the locale encoding.
        with open(input_fp, encoding="utf-8") as fh:
            return json.load(fh)
    return load_jsonlines(input_fp)
def save_file_jsonl(data, fp):
    """Write every element of ``data`` as one JSON line to ``fp`` (overwriting it)."""
    with jsonlines.open(fp, mode="w") as out:
        out.write_all(data)
def preprocess_input(input_data, task):
    """Attach an ``"instruction"`` field to each example for the given task.

    Args:
        input_data: For ``"factscore"`` and ``"qa"``, a list of example dicts,
            which are modified in place. For ``"asqa"`` / ``"eli5"``, a dict
            whose ``"data"`` key holds the example list; those examples are
            deep-copied rather than modified.
        task: Task name selecting the preprocessing.

    Returns:
        The list of processed examples. For any other task, ``input_data`` is
        returned unchanged.
    """
    if task == "factscore":
        for item in input_data:
            item["instruction"] = item["input"]
            # Wrap the reference output in a list; fall back to the topic.
            item["output"] = [item["output"]] if "output" in item else [item["topic"]]
        return input_data

    if task == "qa":
        for item in input_data:
            if "instruction" not in item:
                item["instruction"] = item["question"]
            if "answers" not in item and "output" in item:
                # Reuse the example's own gold output as its answers.
                item["answers"] = item["output"]
        return input_data

    if task in ("asqa", "eli5"):
        instructions = TASK_INST[task]
        processed_input_data = []
        for item in input_data["data"]:
            prompt = instructions + "## Input:\n\n" + item["question"]
            entry = copy.deepcopy(item)
            entry["instruction"] = prompt
            processed_input_data.append(entry)
        return processed_input_data

    # Other tasks need no preprocessing; hand the data back instead of None.
    return input_data
def postprocess_output(input_instance, prediction, task, intermediate_results=None):
    """Package a model prediction into the output format for ``task``.

    Args:
        input_instance: The example dict. It is mutated in place for
            ``"qa"``, ``"asqa"`` and ``"eli5"``.
        prediction: The generated text.
        task: Task name selecting the output format.
        intermediate_results: Optional, used for ``"asqa"`` / ``"eli5"``. When
            it contains ``"splitted_sentences"``, element ``[0]`` of that entry
            and of ``"ctxs"`` are zipped as (sentence, supporting doc) pairs.

    Returns:
        The output dict for ``"factscore"``, ``"qa"``, ``"asqa"`` and
        ``"eli5"``; ``None`` for any other task.
    """
    if task == "factscore":
        return {"input": input_instance["input"], "output": prediction, "topic": input_instance["topic"], "cat": input_instance["cat"]}

    if task == "qa":
        input_instance["pred"] = prediction
        return input_instance

    if task in ("asqa", "eli5"):
        # ALCE datasets require additional postprocessing to compute citation accuracy.
        docs = []
        if intermediate_results is None or "splitted_sentences" not in intermediate_results:
            input_instance["output"] = postprocess(prediction)
        else:
            sentences = intermediate_results["splitted_sentences"][0]
            contexts = intermediate_results["ctxs"][0]
            pieces = []
            for idx, (sent, doc) in enumerate(zip(sentences, contexts)):
                if len(sent) == 0:
                    continue
                # Drop the cleaned sentence's last character and append the
                # marker " [idx]. ". NOTE(review): idx also counts skipped empty
                # sentences, so it can differ from the doc's position in `docs`.
                pieces.append(postprocess(sent)[:-1] + " [{}]".format(idx) + ". ")
                docs.append(doc)
            final_output = "".join(pieces)
            # Trim the trailing separator space; endswith() also copes with an
            # empty result when every sentence was empty.
            if final_output.endswith(" "):
                final_output = final_output[:-1]
            input_instance["output"] = final_output
        input_instance["docs"] = docs
        return input_instance

    return None
def process_arc_instruction(item, instruction):
    """Build a multiple-choice prompt for an ARC/OBQA-style example.

    Args:
        item: Dict with ``"instruction"`` (the question text) and ``"choices"``,
            a dict of parallel lists ``"label"`` and ``"text"``. Labels may be
            letters "A"-"E" or digits "1"-"5"; other labels are ignored.
        instruction: Task instruction placed before the input section.

    Returns:
        The instruction, then a "### Input:" header, the question, and one
        "X: text" line per option A-D (D is rendered empty when absent),
        followed by an E line when a fifth option is present.

    Raises:
        KeyError: if no option maps to A, B or C.
    """
    digit_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
    choices = item["choices"]
    answer_labels = {}
    for label, text in zip(choices["label"], choices["text"]):
        letter = digit_to_letter.get(label, label)
        # Accept E / 5 as well; otherwise the fifth option would be dropped.
        if letter in ("A", "B", "C", "D", "E"):
            answer_labels[letter] = text
    # With fewer than four options, render D as an empty choice.
    answer_labels.setdefault("D", "")
    choices_text = "\nA: {0}\nB: {1}\nC: {2}\nD: {3}".format(
        answer_labels["A"], answer_labels["B"], answer_labels["C"], answer_labels["D"]
    )
    if "E" in answer_labels:
        choices_text += "\nE: {}".format(answer_labels["E"])
    return instruction + "\n\n### Input:\n" + item["instruction"] + choices_text
def postprocess_answers_closed(output, task, choices=None):
    """Normalize a closed-form answer.

    Args:
        output: Raw model answer string.
        task: Task name; ``"fever"`` answers get label normalization.
        choices: Optional space-separated candidate labels (e.g. "A B C D").
            When given, the last candidate that occurs as a substring of
            ``output`` is selected.

    Returns:
        For ``"fever"``, "SUPPORTS"/"REFUTES" become "true"/"false", and
        "true"/"false" in any case are lower-cased; this overrides a choice
        match. Otherwise the matched choice, or ``output`` unchanged when
        nothing matched.
    """
    final_output = None
    if choices is not None:
        for c in choices.split(" "):
            # Substring match; later candidates override earlier ones.
            if c in output:
                final_output = c
    if task == "fever":
        if output in ("REFUTES", "SUPPORTS"):
            # Map to the true/false vocabulary the FEVER instruction asks for.
            final_output = "true" if output == "SUPPORTS" else "false"
        elif output.lower() in ("true", "false"):
            final_output = output.lower()
    return output if final_output is None else final_output