import jsonlines
import json
import copy
import re

PROMPT_DICT = {
    "prompt_input": (
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    ),
    "prompt_no_input": (
        "### Instruction:\n{instruction}\n\n### Response:\n"
    ),
}

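# Illustrative usage (example values, not part of the original file):
# PROMPT_DICT["prompt_input"].format(
#     instruction="Answer the question.", input="Who wrote Hamlet?")
# PROMPT_DICT["prompt_no_input"].format(instruction="Answer the question.")
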
TASK_INST = {"wow": "Given a chat history separated by new lines, generates an informative, knowledgeable and engaging response. ",
|
|
"fever": "Is the following statement correct or not? Say true if it's correct; otherwise say false.",
|
|
"eli5": "Provide a paragraph-length response using simple words to answer the following question.",
|
|
"obqa": "Given four answer candidates, A, B, C and D, choose the best answer choice.",
|
|
"arc_easy": "Given four answer candidates, A, B, C and D, choose the best answer choice.",
|
|
"arc_c": "Given four answer candidates, A, B, C and D, choose the best answer choice.",
|
|
"trex": "Given the input format 'Subject Entity [SEP] Relationship Type,' predict the target entity.",
|
|
"asqa": "Answer the following question. The question may be ambiguous and have multiple correct answers, and in that case, you have to provide a long-form answer including all correct answers."}
|
|
|
|
rel_tokens_names = ["[Irrelevant]", "[Relevant]"]
retrieval_tokens_names = ["[No Retrieval]",
                          "[Retrieval]", "[Continue to Use Evidence]"]
utility_tokens_names = ["[Utility:1]", "[Utility:2]",
                        "[Utility:3]", "[Utility:4]", "[Utility:5]"]
ground_tokens_names = ["[Fully supported]",
                       "[Partially supported]", "[No support / Contradictory]"]
other_special_tokens = ["<s>", "</s>", "[PAD]",
                        "<unk>", "<paragraph>", "</paragraph>"]
control_tokens = ["[Fully supported]", "[Partially supported]", "[No support / Contradictory]", "[No Retrieval]", "[Retrieval]",
                  "[Irrelevant]", "[Relevant]", "<paragraph>", "</paragraph>", "[Utility:1]", "[Utility:2]", "[Utility:3]", "[Utility:4]", "[Utility:5]"]

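# Illustrative raw generation interleaving the tokens above (the exact shape
# is model-dependent; this string is made up for documentation purposes):
# "[Retrieval]<paragraph>...</paragraph>[Relevant]Paris is the capital of France.[Fully supported][Utility:5]</s>"
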
def load_special_tokens(tokenizer, use_grounding=False, use_utility=False):
    """Map the reflection special tokens to their ids in the tokenizer vocabulary.

    Grounding and utility token maps are returned as None unless requested.
    """
    ret_tokens = {token: tokenizer.convert_tokens_to_ids(token)
                  for token in retrieval_tokens_names}
    rel_tokens = {token: tokenizer.convert_tokens_to_ids(token)
                  for token in rel_tokens_names}

    grd_tokens = None
    if use_grounding:
        grd_tokens = {token: tokenizer.convert_tokens_to_ids(token)
                      for token in ground_tokens_names}

    ut_tokens = None
    if use_utility:
        ut_tokens = {token: tokenizer.convert_tokens_to_ids(token)
                     for token in utility_tokens_names}

    return ret_tokens, rel_tokens, grd_tokens, ut_tokens

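# Illustrative usage, assuming a tokenizer whose vocabulary already contains
# the special tokens (the ids shown are made up; real ids vary by model):
# ret_tokens, rel_tokens, grd_tokens, ut_tokens = load_special_tokens(
#     tokenizer, use_grounding=True, use_utility=True)
# rel_tokens  # -> {"[Irrelevant]": 32001, "[Relevant]": 32002}
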
def fix_spacing(input_text):
    # Add a space after sentence-ending punctuation (., !, ?) that lacks whitespace.
    output_text = re.sub(r'(?<=\w)([.!?])(?=\w)', r'\1 ', input_text)
    return output_text

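# Illustrative behavior:
# fix_spacing("First sentence.Second sentence")  # -> "First sentence. Second sentence"
# Note: \w also matches digits, so decimals like "3.14" become "3. 14".
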
def postprocess(pred):
    # Strip every control token (see the module-level control_tokens list
    # above) plus the EOS marker, then trim a single leading space if one remains.
    for item in control_tokens:
        pred = pred.replace(item, "")
    pred = pred.replace("</s>", "")

    if len(pred) == 0:
        return ""
    if pred[0] == " ":
        pred = pred[1:]
    return pred

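# Illustrative behavior:
# postprocess("[Relevant]Paris is the capital of France.[Utility:5]</s>")
# -> "Paris is the capital of France."
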
def load_jsonlines(file):
    with jsonlines.open(file, 'r') as jsonl_f:
        lst = [obj for obj in jsonl_f]
    return lst

def load_file(input_fp):
    if input_fp.endswith(".json"):
        with open(input_fp) as f:
            input_data = json.load(f)
    else:
        input_data = load_jsonlines(input_fp)
    return input_data

def save_file_jsonl(data, fp):
    with jsonlines.open(fp, mode='w') as writer:
        writer.write_all(data)

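# Illustrative round trip (hypothetical file names):
# data = load_file("eval_data.jsonl")   # .json uses json.load; anything else is read as JSON Lines
# save_file_jsonl(data, "predictions.jsonl")
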
def preprocess_input(input_data, task):
    if task == "factscore":
        for item in input_data:
            item["instruction"] = item["input"]
            item["output"] = [item["output"]] if "output" in item else [item["topic"]]
        return input_data

    elif task == "qa":
        for item in input_data:
            if "instruction" not in item:
                item["instruction"] = item["question"]
            if "answers" not in item and "output" in item:
                # Expose gold answers under the key downstream evaluation expects.
                item["answers"] = item["output"]
        return input_data

    elif task in ["asqa", "eli5"]:
        processed_input_data = []
        for item in input_data["data"]:
            prompt = item["question"]
            instructions = TASK_INST[task]
            prompt = instructions + "## Input:\n\n" + prompt
            entry = copy.deepcopy(item)
            entry["instruction"] = prompt
            processed_input_data.append(entry)
        return processed_input_data

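# Illustrative behavior for the "qa" task (hypothetical instance):
# preprocess_input([{"question": "Who wrote Hamlet?", "output": ["Shakespeare"]}], "qa")
# -> adds "instruction": "Who wrote Hamlet?" and "answers": ["Shakespeare"]
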
def postprocess_output(input_instance, prediction, task, intermediate_results=None):
    if task == "factscore":
        return {"input": input_instance["input"], "output": prediction, "topic": input_instance["topic"], "cat": input_instance["cat"]}

    elif task == "qa":
        input_instance["pred"] = prediction
        return input_instance

    elif task in ["asqa", "eli5"]:
        # ALCE datasets require additional postprocessing to compute citation accuracy.
        final_output = ""
        docs = []
        if "splitted_sentences" not in intermediate_results:
            input_instance["output"] = postprocess(prediction)
        else:
            for idx, (sent, doc) in enumerate(zip(intermediate_results["splitted_sentences"][0], intermediate_results["ctxs"][0])):
                if len(sent) == 0:
                    continue
                postprocessed_result = postprocess(sent)
                # Replace each sentence's trailing character with a bracketed
                # citation index pointing at its evidence passage.
                final_output += postprocessed_result[:-1] + " [{}]".format(idx) + ". "
                docs.append(doc)
            # Guard against an empty final_output (e.g., all sentences were empty).
            if len(final_output) > 0 and final_output[-1] == " ":
                final_output = final_output[:-1]
            input_instance["output"] = final_output
            input_instance["docs"] = docs
        return input_instance

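# Illustrative "asqa"/"eli5" output when per-sentence evidence is available:
# each kept sentence ends with a bracketed citation index, e.g.
# "First answer sentence [0]. Second answer sentence [1]."
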
def process_arc_instruction(item, instruction):
    choices = item["choices"]
    answer_labels = {}
    for i in range(len(choices["label"])):
        answer_key = choices["label"][i]
        text = choices["text"][i]
        if answer_key == "1":
            answer_labels["A"] = text
        if answer_key == "2":
            answer_labels["B"] = text
        if answer_key == "3":
            answer_labels["C"] = text
        if answer_key == "4":
            answer_labels["D"] = text
        # Some ARC items use letter labels directly; a few carry a fifth
        # option "E", which the formatting below expects to be captured.
        if answer_key in ["A", "B", "C", "D", "E"]:
            answer_labels[answer_key] = text

    if "D" not in answer_labels:
        answer_labels["D"] = ""
    choices = "\nA: {0}\nB: {1}\nC: {2}\nD: {3}".format(answer_labels["A"], answer_labels["B"], answer_labels["C"], answer_labels["D"])
    if "E" in answer_labels:
        choices += "\nE: {}".format(answer_labels["E"])
    processed_instruction = instruction + "\n\n### Input:\n" + item["instruction"] + choices
    return processed_instruction

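# Illustrative usage (hypothetical item following the HF ai2_arc schema):
# item = {"instruction": "Which gas do plants absorb from the air?",
#         "choices": {"label": ["A", "B", "C", "D"],
#                     "text": ["Oxygen", "Carbon dioxide", "Nitrogen", "Helium"]}}
# process_arc_instruction(item, TASK_INST["arc_c"])
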
def postprocess_answers_closed(output, task, choices=None):
    final_output = None
    if choices is not None:
        for c in choices.split(" "):
            if c in output:
                final_output = c
    if task == "fever" and output in ["REFUTES", "SUPPORTS"]:
        # Map FEVER's label vocabulary onto the true/false wording used in TASK_INST.
        final_output = "true" if output == "SUPPORTS" else "false"
    if task == "fever" and output.lower() in ["true", "false"]:
        final_output = output.lower()
    if final_output is None:
        return output
    else:
        return final_output

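# Illustrative behavior:
# postprocess_answers_closed("SUPPORTS", "fever")                     # -> "true"
# postprocess_answers_closed("The answer is B", "arc_c", "A B C D")   # -> "B"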