2023-11-23 23:05:23 +08:00
import jsonlines
import json
import copy
import re
# Prompt templates, meant to be filled in with ``str.format``.
# Each template expects an ``instruction`` field; the first one also
# expects an ``input`` field.
PROMPT_DICT = {
    # Instruction followed by a separate input section.
    " prompt_input ": " ### Instruction: \n {instruction} \n \n ### Input: \n {input} \n \n ### Response: \n ",
    # Instruction only.
    " prompt_no_input ": " ### Instruction: \n {instruction} \n \n ### Response: \n ",
    # Instruction only; same text as the template above, kept under its own key.
    " prompt_no_input_retrieval ": " ### Instruction: \n {instruction} \n \n ### Response: \n ",
}
# Per-task instruction text, keyed by task name.
TASK_INST = {
    " wow ": " Given a chat history separated by new lines, generates an informative, knowledgeable and engaging response. ",
    " fever ": " Is the following statement correct or not? Say true if it ' s correct; otherwise say false. ",
    " eli5 ": " Provide a paragraph-length response using simple words to answer the following question. ",
    # The three multiple-choice tasks share one instruction string.
    **dict.fromkeys(
        (" obqa ", " arc_easy ", " arc_c "),
        " Given four answer candidates, A, B, C and D, choose the best answer choice. ",
    ),
    " trex ": " Given the input format ' Subject Entity [SEP] Relationship Type, ' predict the target entity. ",
    " asqa ": " Answer the following question. The question may be ambiguous and have multiple correct answers, and in that case, you have to provide a long-form answer including all correct answers. ",
}
# Names of the special tokens, grouped by the kind of signal they carry.
rel_tokens_names = [" [Irrelevant] ", " [Relevant] "]
retrieval_tokens_names = [
    " [No Retrieval] ",
    " [Retrieval] ",
    " [Continue to Use Evidence] ",
]
# " [Utility:1] " through " [Utility:5] ".
utility_tokens_names = [" [Utility:{}] ".format(score) for score in range(1, 6)]
ground_tokens_names = [
    " [Fully supported] ",
    " [Partially supported] ",
    " [No support / Contradictory] ",
]
other_special_tokens = [
    " <s> ",
    " </s> ",
    " [PAD] ",
    " <unk> ",
    " <paragraph> ",
    " </paragraph> ",
]
# Flat list assembled from the groups above: grounding tokens, the first two
# retrieval tokens, relevance tokens, the paragraph markers, utility tokens.
control_tokens = (
    ground_tokens_names
    + retrieval_tokens_names[:2]
    + rel_tokens_names
    + [" <paragraph> ", " </paragraph> "]
    + utility_tokens_names
)
def load_special_tokens(tokenizer, use_grounding=False, use_utility=False):
    """Look up the tokenizer ids of the special tokens.

    Args:
        tokenizer: Object exposing ``convert_tokens_to_ids(token)``.
        use_grounding: The grounding-token table is built only when this is
            exactly ``True``.
        use_utility: The utility-token table is built only when this is
            exactly ``True``.

    Returns:
        A 4-tuple ``(ret_tokens, rel_tokens, grd_tokens, ut_tokens)``. Each
        entry maps a token string to the value returned by the tokenizer;
        the last two entries are ``None`` when their flag is not ``True``.
    """

    def _ids_for(names):
        # One tokenizer lookup per token, in the order the names are listed.
        return {name: tokenizer.convert_tokens_to_ids(name) for name in names}

    ret_tokens = _ids_for(retrieval_tokens_names)
    rel_tokens = _ids_for([" [Irrelevant] ", " [Relevant] "])
    # ``is True`` is an identity test: truthy non-bool values leave the
    # optional tables as None.
    grd_tokens = _ids_for(ground_tokens_names) if use_grounding is True else None
    ut_tokens = _ids_for(utility_tokens_names) if use_utility is True else None
    return ret_tokens, rel_tokens, grd_tokens, ut_tokens
def fix_spacing(input_text):
    """Add a space after sentence punctuation that lacks whitespace.

    A ``.``, ``!`` or ``?`` that sits directly between two word characters
    gets a single space inserted after it. Punctuation already followed by
    whitespace, or at the very end of the text, is left alone. Digits count
    as word characters, so ``"3.14"`` becomes ``"3. 14"``.

    Args:
        input_text: Text to repair.

    Returns:
        The text with the missing spaces inserted.
    """
    # The lookbehind/lookahead keep the neighbouring characters out of the
    # match, so only the punctuation mark itself is rewritten.
    return re.sub(r"(?<=\w)([.!?])(?=\w)", r"\1 ", input_text)
def postprocess(pred):
    """Remove special tokens from a prediction string.

    Each listed token is replaced, in order, by the string ``" "``. An empty
    result is returned as ``" "``; otherwise one leading space, if present,
    is dropped.

    Args:
        pred: Prediction text.

    Returns:
        The cleaned prediction text.
    """
    # Replaced sequentially in this exact order; the end-of-sequence marker
    # goes last.
    markers = (
        " [Fully supported] ",
        " [Partially supported] ",
        " [No support / Contradictory] ",
        " [No Retrieval] ",
        " [Retrieval] ",
        " [Irrelevant] ",
        " [Relevant] ",
        " <paragraph> ",
        " </paragraph> ",
        " [Utility:1] ",
        " [Utility:2] ",
        " [Utility:3] ",
        " [Utility:4] ",
        " [Utility:5] ",
        " </s> ",
    )
    for marker in markers:
        pred = pred.replace(marker, " ")
    if not pred:
        return " "
    return pred[1:] if pred[0] == " " else pred
def load_jsonlines(file):
    """Read every record of a JSON Lines file into a list.

    Args:
        file: Path handed to ``jsonlines.open``.

    Returns:
        A list with one entry per record, in file order.
    """
    # The mode must be the bare "r"; a mode string padded with spaces is
    # not one of the modes ``jsonlines.open`` accepts.
    with jsonlines.open(file, "r") as jsonl_f:
        return list(jsonl_f)
def load_file(input_fp):
    """Load a data file, choosing the parser from the path suffix.

    Paths matching the JSON suffix are parsed with ``json.load``; every
    other path is read with ``load_jsonlines``.

    Args:
        input_fp: Path of the file to read.

    Returns:
        The parsed JSON document, or the list returned by ``load_jsonlines``.
    """
    # NOTE(review): the suffix literal is reproduced exactly as it appears in
    # the source, surrounding spaces included - confirm this is the intended
    # match.
    if input_fp.endswith(" .json "):
        # Context manager so the handle is closed even if parsing raises.
        with open(input_fp) as json_f:
            return json.load(json_f)
    return load_jsonlines(input_fp)
def save_file_jsonl(data, fp):
    """Write every item of ``data`` to ``fp`` in JSON Lines format.

    Args:
        data: Iterable of items passed to the writer's ``write_all``.
        fp: Destination path handed to ``jsonlines.open``.
    """
    # The mode must be the bare "w"; a mode string padded with spaces is
    # not one of the modes ``jsonlines.open`` accepts.
    with jsonlines.open(fp, mode="w") as writer:
        writer.write_all(data)
def preprocess_input(input_data, task):
    """Normalise raw task data so every item carries an instruction.

    Args:
        input_data: For the factscore and qa tasks, an iterable of dict
            items that are modified in place. For the asqa/eli5 tasks, a
            mapping whose data entry holds the items; those are deep-copied.
        task: Task name selecting the branch.

    Returns:
        The (mutated) ``input_data`` for factscore and qa, a new list of
        copied entries for asqa/eli5, and ``None`` for any other task.
    """
    if task == " factscore ":
        for item in input_data:
            item[" instruction "] = item[" input "]
            # The reference is always wrapped in a list; the topic stands in
            # when the item has no output.
            reference = item[" output "] if " output " in item else item[" topic "]
            item[" output "] = [reference]
        return input_data

    if task == " qa ":
        for item in input_data:
            if " instruction " not in item:
                item[" instruction "] = item[" question "]
            if " answers " not in item and " output " in item:
                # Store the item's output value; the previous code stored the
                # literal key name instead.
                item[" answers "] = item[" output "]
        return input_data

    if task in [" asqa ", " eli5 "]:
        instructions = TASK_INST[task]
        processed_input_data = []
        for item in input_data[" data "]:
            # Work on a copy so the caller's items stay untouched.
            entry = copy.deepcopy(item)
            entry[" instruction "] = instructions + " ## Input: \n \n " + item[" question "]
            processed_input_data.append(entry)
        return processed_input_data

    return None
def postprocess_output(input_instance, prediction, task, intermediate_results=None):
    """Attach a prediction to its input instance in the shape the task needs.

    Args:
        input_instance: Dict describing the input. It is modified in place
            for the qa and asqa/eli5 tasks.
        prediction: Generated text.
        task: Task name selecting the branch.
        intermediate_results: Optional dict. When it holds split sentences,
            the first sentence list is zipped with the first context list
            to build a citation-annotated answer. ``None`` is treated like
            a dict without split sentences.

    Returns:
        A new dict for factscore, the updated ``input_instance`` for qa and
        asqa/eli5, and ``None`` for any other task.
    """
    if task == " factscore ":
        return {
            " input ": input_instance[" input "],
            " output ": prediction,
            " topic ": input_instance[" topic "],
            " cat ": input_instance[" cat "],
        }

    if task == " qa ":
        input_instance[" pred "] = prediction
        return input_instance

    if task in [" asqa ", " eli5 "]:
        # ALCE datasets require additional postprocessing to compute citation accuracy.
        docs = []
        # Guard the default: a membership test on None would raise TypeError.
        if intermediate_results is None or " splitted_sentences " not in intermediate_results:
            input_instance[" output "] = postprocess(prediction)
        else:
            final_output = " "
            sentences = intermediate_results[" splitted_sentences "][0]
            contexts = intermediate_results[" ctxs "][0]
            for idx, (sent, doc) in enumerate(zip(sentences, contexts)):
                if len(sent) == 0:
                    continue
                postprocessed_result = postprocess(sent)
                # Drop the sentence's last character, then append the
                # citation index and the closing period.
                final_output += postprocessed_result[:-1] + " [ {} ] ".format(idx) + " . "
                docs.append(doc)
            # endswith is safe on an empty string, unlike indexing [-1].
            if final_output.endswith(" "):
                final_output = final_output[:-1]
            input_instance[" output "] = final_output
        input_instance[" docs "] = docs
        return input_instance

    return None
2023-11-24 22:10:58 +08:00
def process_instruction(item, task):
    """Build the multiple-choice instruction text for one item.

    Args:
        item: Dict holding the question text and a choices mapping with
            parallel label and text lists.
        task: Key into ``TASK_INST`` selecting the instruction text.

    Returns:
        The instruction, an input header, the question and the lettered
        answer options joined into one string.

    Raises:
        KeyError: If ``task`` is not in ``TASK_INST`` or option A, B or C
            is missing from the choices.
    """
    choices = item[" choices "]
    instruction = TASK_INST[task]

    # Numeric labels are translated to letters; letter labels are kept.
    numeric_to_letter = (
        (" 1 ", " A "),
        (" 2 ", " B "),
        (" 3 ", " C "),
        (" 4 ", " D "),
    )
    letters = [" A ", " B ", " C ", " D "]

    option_text = {}
    for position, label in enumerate(choices[" label "]):
        text = choices[" text "][position]
        for numeric, letter in numeric_to_letter:
            if label == numeric:
                option_text[letter] = text
        if label in letters:
            option_text[label] = text

    # Items with only three options get a blank fourth one.
    if " D " not in option_text:
        option_text[" D "] = " "

    rendered = " \n A: {0} \n B: {1} \n C: {2} \n D: {3} ".format(
        *(option_text[letter] for letter in letters)
    )
    # The loop above never stores an E option, so this line is only emitted
    # if that mapping is extended.
    if " E " in option_text:
        rendered += " \n E: {} ".format(option_text[" E "])

    return instruction + " \n \n ### Input: \n " + item[" question "] + rendered
def preprocess_input_data(item, task=None):
    """Turn one item into the instruction string for the given task.

    Args:
        item: Dict holding the question text. For the arc_c task it must
            also hold the choices mapping and the answer key, and its
            answers entry is set as a side effect.
        task: Optional task name. Tasks missing from ``TASK_INST`` get the
            bare question back.

    Returns:
        The instruction string for the item.
    """
    instruction = TASK_INST[task] if task in TASK_INST else None

    if task != " arc_c ":
        # Every other task: prefix the question with the task instruction
        # when there is one.
        if instruction is None:
            return item[" question "]
        return instruction + " \n \n ## Input: \n \n " + item[" question "]

    choices = item[" choices "]
    letters = [" A ", " B ", " C ", " D "]
    numerals = [" 1 ", " 2 ", " 3 ", " 4 "]

    answer_labels = {}
    for position, answer_key in enumerate(choices[" label "]):
        text = choices[" text "][position]
        # Numeric labels are stored under the matching letter.
        for numeral, letter in zip(numerals, letters):
            if answer_key == numeral:
                answer_labels[letter] = text
        if answer_key in letters:
            answer_labels[answer_key] = text

    # Items with only three options get a blank fourth one.
    if " D " not in answer_labels:
        answer_labels[" D "] = " "

    option_block = " \n A: {0} \n B: {1} \n C: {2} \n D: {3} ".format(
        answer_labels[" A "],
        answer_labels[" B "],
        answer_labels[" C "],
        answer_labels[" D "],
    )
    # The loop above never stores an E option, so this line is only emitted
    # if that mapping is extended.
    if " E " in answer_labels:
        option_block += " \n E: {} ".format(answer_labels[" E "])

    processed_instruction = (
        instruction + " \n \n ### Input: \n " + item[" question "] + option_block
    )
    item[" answers "] = [item[" answerKey "]]
    return processed_instruction
def postprocess_answers_closed(output, task, choices=None):
    """Normalise the output of a closed-set task to one of its labels.

    Args:
        output: Raw model output.
        task: Task name; the fever task gets label normalisation.
        choices: Optional space-separated candidate answers. The last
            candidate found inside ``output`` is selected.

    Returns:
        The normalised label, or ``output`` unchanged when nothing matched.
    """
    final_output = None
    if choices is not None:
        for candidate in choices.split(" "):
            # Splitting on a single space yields empty strings for leading,
            # trailing or doubled spaces; an empty string is "in" every
            # output and would wipe the answer, so skip it.
            if candidate and candidate in output:
                final_output = candidate
    if task == " fever ":
        if output in [" REFUTES ", " SUPPORTS "]:
            # A refuted claim maps to the false label, not back to the raw
            # dataset label.
            final_output = " true " if output == " SUPPORTS " else " false "
        lowered = output.lower()
        if lowered in [" true ", " false "]:
            final_output = lowered
    return output if final_output is None else final_output