"""Prompt templates, control-token tables and pre/post-processing helpers."""
import copy
import json
import re

try:
    import jsonlines
except ImportError:  # pragma: no cover - depends on the environment
    # Only the JSON Lines I/O helpers need this package; keeping the import
    # soft lets the pure text helpers be imported (and tested) without it.
    jsonlines = None

# Prompt templates, filled in with ``str.format``.
PROMPT_DICT = {
    "prompt_input": (
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    ),
    "prompt_no_input": (
        "### Instruction:\n{instruction}\n\n### Response:\n"
    ),
}

# Task name -> instruction text placed in front of the task input.
TASK_INST = {
    "wow": "Given a chat history separated by new lines, generates an informative, knowledgeable and engaging response. ",
    "fever": "Is the following statement correct or not? Say true if it's correct; otherwise say false.",
    "eli5": "Provide a paragraph-length response using simple words to answer the following question.",
    "obqa": "Given four answer candidates, A, B, C and D, choose the best answer choice.",
    "arc_easy": "Given four answer candidates, A, B, C and D, choose the best answer choice.",
    "arc_c": "Given four answer candidates, A, B, C and D, choose the best answer choice.",
    "trex": "Given the input format 'Subject Entity [SEP] Relationship Type,' predict the target entity.",
    "asqa": "Answer the following question. The question may be ambiguous and have multiple correct answers, and in that case, you have to provide a long-form answer including all correct answers.",
}

# Control-token names, grouped by category.
rel_tokens_names = ["[Irrelevant]", "[Relevant]"]
retrieval_tokens_names = ["[No Retrieval]", "[Retrieval]",
                          "[Continue to Use Evidence]"]
utility_tokens_names = ["[Utility:1]", "[Utility:2]",
                        "[Utility:3]", "[Utility:4]", "[Utility:5]"]
ground_tokens_names = ["[Fully supported]", "[Partially supported]",
                       "[No support / Contradictory]"]

# NOTE(review): the empty-string entries in the two lists below are reproduced
# exactly as found. They look like token names whose text was lost (possibly
# markup-like tokens stripped during extraction) - confirm against the
# upstream source before relying on them.
other_special_tokens = ["", "", "[PAD]", "", "", ""]
control_tokens = ["[Fully supported]", "[Partially supported]",
                  "[No support / Contradictory]", "[No Retrieval]",
                  "[Retrieval]", "[Irrelevant]", "[Relevant]", "", "",
                  "[Utility:1]", "[Utility:2]", "[Utility:3]", "[Utility:4]",
                  "[Utility:5]"]

# Accepted ARC choice labels: numeric labels are mapped to letters, letters
# map to themselves.
_ARC_LABEL_ALIASES = {
    "1": "A", "2": "B", "3": "C", "4": "D", "5": "E",
    "A": "A", "B": "B", "C": "C", "D": "D", "E": "E",
}


def _require_jsonlines():
    """Raise a clear ImportError when the ``jsonlines`` package is missing."""
    if jsonlines is None:
        raise ImportError(
            "The 'jsonlines' package is required to read or write JSON Lines files."
        )


def load_special_tokens(tokenizer, use_grounding=False, use_utility=False):
    """Map control-token names to ids using ``tokenizer``.

    Args:
        tokenizer: Object exposing ``convert_tokens_to_ids(token)``.
        use_grounding: When truthy, also resolve ``ground_tokens_names``.
        use_utility: When truthy, also resolve ``utility_tokens_names``.

    Returns:
        A 4-tuple ``(ret_tokens, rel_tokens, grd_tokens, ut_tokens)`` of
        ``{token_name: token_id}`` dicts. ``grd_tokens`` / ``ut_tokens`` are
        ``None`` when the corresponding flag is falsy.
    """
    def to_ids(names):
        return {name: tokenizer.convert_tokens_to_ids(name) for name in names}

    ret_tokens = to_ids(retrieval_tokens_names)
    rel_tokens = to_ids(rel_tokens_names)
    grd_tokens = to_ids(ground_tokens_names) if use_grounding else None
    ut_tokens = to_ids(utility_tokens_names) if use_utility else None
    return ret_tokens, rel_tokens, grd_tokens, ut_tokens


def fix_spacing(input_text):
    """Insert a space after ``.``, ``!`` or ``?`` sitting between two word characters.

    Note that the pattern does not special-case digits, so ``"3.5"`` becomes
    ``"3. 5"``.
    """
    return re.sub(r'(?<=\w)([.!?])(?=\w)', r'\1 ', input_text)


def postprocess(pred):
    """Remove every ``control_tokens`` entry from ``pred`` and drop one leading space.

    Returns the cleaned string (``""`` when nothing is left).
    """
    for token in control_tokens:
        # Empty entries are harmless: replacing "" with "" changes nothing.
        pred = pred.replace(token, "")
    # NOTE(review): the original code had one more ``pred.replace("", "")``
    # here, which is a no-op. It presumably removed a token whose text was
    # lost - confirm against the upstream source.
    if pred.startswith(" "):
        pred = pred[1:]
    return pred


def load_jsonlines(file):
    """Read a JSON Lines file and return its records as a list.

    Raises:
        ImportError: If the ``jsonlines`` package is not installed.
    """
    _require_jsonlines()
    with jsonlines.open(file, 'r') as reader:
        return list(reader)


def load_file(input_fp):
    """Load ``input_fp`` as JSON when it ends with ``.json``, else as JSON Lines."""
    if input_fp.endswith(".json"):
        # Context manager so the handle is closed even if parsing fails.
        with open(input_fp) as json_f:
            return json.load(json_f)
    return load_jsonlines(input_fp)


def save_file_jsonl(data, fp):
    """Write the iterable ``data`` to ``fp`` in JSON Lines format.

    Raises:
        ImportError: If the ``jsonlines`` package is not installed.
    """
    _require_jsonlines()
    with jsonlines.open(fp, mode='w') as writer:
        writer.write_all(data)


def preprocess_input(input_data, task):
    """Add the fields later stages read (``instruction``, ``output``/``answers``).

    Args:
        input_data: For ``"factscore"`` and ``"qa"`` an iterable of dicts that
            is modified in place; for ``"asqa"`` / ``"eli5"`` a mapping whose
            ``"data"`` entry holds the items.
        task: One of ``"factscore"``, ``"qa"``, ``"asqa"``, ``"eli5"``.

    Returns:
        The processed items. For ``"asqa"`` / ``"eli5"`` this is a new list of
        deep copies; for any other unrecognised task ``None`` is returned.
    """
    if task == "factscore":
        for item in input_data:
            item["instruction"] = item["input"]
            item["output"] = [item["output"]] if "output" in item else [item["topic"]]
        return input_data

    if task == "qa":
        for item in input_data:
            if "instruction" not in item:
                item["instruction"] = item["question"]
            if "answers" not in item and "output" in item:
                # Copy the value; the original stored the literal "output".
                item["answers"] = item["output"]
        return input_data

    if task in ["asqa", "eli5"]:
        processed_input_data = []
        for item in input_data["data"]:
            entry = copy.deepcopy(item)
            # Prompt layout kept byte-identical to the original.
            entry["instruction"] = TASK_INST[task] + "## Input:\n\n" + item["question"]
            processed_input_data.append(entry)
        return processed_input_data

    return None


def postprocess_output(input_instance, prediction, task, intermediate_results=None):
    """Attach ``prediction`` to ``input_instance`` in the layout ``task`` expects.

    Args:
        input_instance: The input record (mutated for ``"qa"``, ``"asqa"``
            and ``"eli5"``).
        prediction: The generated text.
        task: One of ``"factscore"``, ``"qa"``, ``"asqa"``, ``"eli5"``.
        intermediate_results: Optional mapping. When it contains
            ``"splitted_sentences"`` (and ``"ctxs"``), the first element of
            each is zipped to build a per-sentence, citation-annotated output.

    Returns:
        The output record, or ``None`` for an unrecognised task.
    """
    if task == "factscore":
        return {"input": input_instance["input"], "output": prediction,
                "topic": input_instance["topic"], "cat": input_instance["cat"]}

    if task == "qa":
        input_instance["pred"] = prediction
        return input_instance

    if task in ["asqa", "eli5"]:
        # ALCE datasets require additional postprocessing to compute citation accuracy.
        # ``None`` (the default) is treated like "no intermediate results".
        if not intermediate_results or "splitted_sentences" not in intermediate_results:
            input_instance["output"] = postprocess(prediction)
            return input_instance

        pieces = []
        docs = []
        sentences = intermediate_results["splitted_sentences"][0]
        contexts = intermediate_results["ctxs"][0]
        for idx, (sent, doc) in enumerate(zip(sentences, contexts)):
            if len(sent) == 0:
                continue
            cleaned = postprocess(sent)
            # Drop the sentence's last character, then append the citation
            # marker followed by a period.
            pieces.append(cleaned[:-1] + " [{}]".format(idx) + ". ")
            docs.append(doc)

        final_output = "".join(pieces)
        # ``endswith`` is safe when every sentence was empty (no pieces).
        if final_output.endswith(" "):
            final_output = final_output[:-1]
        input_instance["output"] = final_output
        input_instance["docs"] = docs
        return input_instance

    return None


def process_arc_instruction(item, instruction):
    """Build a multiple-choice prompt from ``item``.

    ``item["choices"]`` must hold parallel ``"label"`` and ``"text"`` lists.
    Labels ``"1"``-``"5"`` are treated as ``"A"``-``"E"``; other labels are
    ignored. A missing ``"D"`` is rendered as an empty choice and ``"E"`` is
    appended only when present.

    Raises:
        KeyError: If choice ``"A"``, ``"B"`` or ``"C"`` is missing.
    """
    choices = item["choices"]
    answer_labels = {}
    for i, raw_label in enumerate(choices["label"]):
        letter = _ARC_LABEL_ALIASES.get(raw_label)
        if letter is not None:
            answer_labels[letter] = choices["text"][i]
    answer_labels.setdefault("D", "")

    rendered = "\nA: {0}\nB: {1}\nC: {2}\nD: {3}".format(
        answer_labels["A"], answer_labels["B"],
        answer_labels["C"], answer_labels["D"])
    if "E" in answer_labels:
        rendered += "\nE: {}".format(answer_labels["E"])

    return instruction + "\n\n### Input:\n" + item["instruction"] + rendered


def postprocess_answers_closed(output, task, choices=None):
    """Normalise a closed-set answer.

    Args:
        output: The raw answer string.
        task: Task name; ``"fever"`` enables true/false normalisation.
        choices: Optional space-separated candidate strings. The last
            candidate that occurs as a substring of ``output`` is selected.

    Returns:
        The normalised answer, or ``output`` unchanged when no rule applies.
    """
    final_output = None
    if choices is not None:
        for candidate in choices.split(" "):
            # Skip empty fragments: "" is a substring of every string.
            if candidate and candidate in output:
                final_output = candidate

    if task == "fever":
        if output in ["REFUTES", "SUPPORTS"]:
            final_output = "true" if output == "SUPPORTS" else "false"
        if output.lower() in ["true", "false"]:
            final_output = output.lower()

    return output if final_output is None else final_output