# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
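"""Embedding-based passage retrieval for RoboWaiter.

Loads a Contriever-style retriever, builds (or reloads) a FAISS index over
precomputed passage embeddings, retrieves the top n_docs passages for each
query, and writes the results to .jsonl files in the output directory.
"""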
import os
import argparse
import json
import pickle
import time
import glob
import numpy as np
import torch
from robowaiter.algos.retrieval.retrieval_lm.src.slurm import init_distributed_mode
from robowaiter.algos.retrieval.retrieval_lm.src.normalize_text import normalize
from robowaiter.algos.retrieval.retrieval_lm.src.contriever import load_retriever
from robowaiter.algos.retrieval.retrieval_lm.src.index import Indexer
from robowaiter.algos.retrieval.retrieval_lm.src.data import load_passages
from robowaiter.algos.retrieval.retrieval_lm.src.evaluation import calculate_matches
import warnings
from robowaiter.utils.basic import get_root_path

root_path = get_root_path()
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "true"


def embed_queries(args, queries, model, tokenizer):
    model.eval()
    embeddings, batch_question = [], []
    with torch.no_grad():
        for k, q in enumerate(queries):
            if args.lowercase:
                q = q.lower()
            if args.normalize_text:
                q = normalize(q)
            batch_question.append(q)

            if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1:
                encoded_batch = tokenizer.batch_encode_plus(
                    batch_question,
                    return_tensors="pt",
                    max_length=args.question_maxlength,
                    padding=True,
                    truncation=True,
                )
                encoded_batch = {key: v.cuda() for key, v in encoded_batch.items()}
                output = model(**encoded_batch)
                embeddings.append(output.cpu())
                batch_question = []

    embeddings = torch.cat(embeddings, dim=0)
    # print(f"Questions embeddings shape: {embeddings.size()}")
    return embeddings.numpy()
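
# Minimal usage sketch for embed_queries (the output width matches the
# retriever's hidden size, 768 for the default contriever-msmarco model):
#   q_emb = embed_queries(args, ["your question"], model, tokenizer)
#   q_emb.shape  # -> (1, 768)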


def index_encoded_data(index, embedding_files, indexing_batch_size):
    allids = []
    allembeddings = np.array([])
    for i, file_path in enumerate(embedding_files):
        # print(f"Loading file {file_path}")
        with open(file_path, "rb") as fin:
            ids, embeddings = pickle.load(fin)

        allembeddings = np.vstack((allembeddings, embeddings)) if allembeddings.size else embeddings
        allids.extend(ids)
        # Flush full batches to the index as the buffer fills up.
        while allembeddings.shape[0] > indexing_batch_size:
            allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)

    # Index whatever remains in the buffer.
    while allembeddings.shape[0] > 0:
        allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
    # print("Data indexing completed.")


def add_embeddings(index, embeddings, ids, indexing_batch_size):
    end_idx = min(indexing_batch_size, embeddings.shape[0])
    ids_toadd = ids[:end_idx]
    embeddings_toadd = embeddings[:end_idx]
    ids = ids[end_idx:]
    embeddings = embeddings[end_idx:]
    index.index_data(ids_toadd, embeddings_toadd)
    return embeddings, ids
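
# add_embeddings consumes at most indexing_batch_size rows from the front of
# the buffer, indexes them, and returns the remainder, so index_encoded_data
# can stream large embedding files through a fixed-size buffer.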


def validate(data, workers_num):
    match_stats = calculate_matches(data, workers_num)
    top_k_hits = match_stats.top_k_hits
    top_k_hits = [v / len(data) for v in top_k_hits]
    message = ""
    for k in [5, 10, 20, 100]:
        if k <= len(top_k_hits):
            message += f"R@{k}: {top_k_hits[k-1]} "
    # print(message)
    return match_stats.questions_doc_hits


def add_passages(data, passages, top_passages_and_scores):
    # Attach the retrieved passages to the original examples.
    assert len(data) == len(top_passages_and_scores)
    for i, d in enumerate(data):
        results_and_scores = top_passages_and_scores[i]
        docs = [passages[int(doc_id)] for doc_id in results_and_scores[0]]
        scores = [str(score) for score in results_and_scores[1]]
        ctxs_num = len(docs)
        d["ctxs"] = [
            {
                "id": results_and_scores[0][c],
                "title": docs[c]["title"],
                "text": docs[c]["text"],
                "score": scores[c],
            }
            for c in range(ctxs_num)
        ]
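
# After add_passages, every example carries a "ctxs" list; a single entry
# looks roughly like this (values are illustrative):
#   {"id": "42", "title": "...", "text": "...", "score": "1.4142"}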


def add_hasanswer(data, hasanswer):
    # Attach a hasanswer flag to every retrieved context.
    for i, ex in enumerate(data):
        for k, d in enumerate(ex["ctxs"]):
            d["hasanswer"] = hasanswer[i][k]


def load_data(data_path):
    if data_path.endswith(".json"):
        with open(data_path, "r", encoding="utf-8") as fin:
            data = json.load(fin)
    elif data_path.endswith(".jsonl"):
        data = []
        with open(data_path, "r", encoding="utf-8") as fin:
            for k, example in enumerate(fin):
                example = json.loads(example)
                data.append(example)
    return data
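
# Expected .jsonl input: one JSON object per line with at least a "question"
# field, e.g. (illustrative): {"id": 1, "question": "..."}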


def test(args):
    # Example configuration:
    #   model_name_or_path = "contriever-msmarco"
    #   passages = "train_robot.jsonl"
    #   passages_embeddings = "robot_embeddings/*"
    #   data = "test_robot.jsonl"  (the query file)
    #   output_dir = "robot_result"
    #   n_docs = 1
    model, tokenizer, _ = load_retriever(args.model_name_or_path)
    model.eval()
    model = model.cuda()
    if not args.no_fp16:
        model = model.half()

    index = Indexer(args.projection_size, args.n_subquantizers, args.n_bits)

    # Index all passages.
    input_paths = sorted(glob.glob(args.passages_embeddings))
    embeddings_dir = os.path.dirname(input_paths[0])
    index_path = os.path.join(embeddings_dir, "index.faiss")
    if args.save_or_load_index and os.path.exists(index_path):
        index.deserialize_from(embeddings_dir)
    else:
        start_time_indexing = time.time()
        index_encoded_data(index, input_paths, args.indexing_batch_size)
        # print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.")
        if args.save_or_load_index:
            index.serialize(embeddings_dir)

    # Load passages and map them by id.
    passages = load_passages(args.passages)
    passage_id_map = {x["id"]: x for x in passages}

    data_paths = glob.glob(args.data)
    ret_list = []
    for path in data_paths:
        data = load_data(path)
        output_path = os.path.join(args.output_dir, os.path.basename(path))

        queries = [ex["question"] for ex in data]
        questions_embedding = embed_queries(args, queries, model, tokenizer)

        # Retrieve the top n_docs passages for every query.
        start_time_retrieval = time.time()
        top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs)
        # print(f"Search time: {time.time()-start_time_retrieval:.1f} s.")

        add_passages(data, passage_id_map, top_ids_and_scores)
        # Optional answer validation:
        # hasanswer = validate(data, args.validation_workers)
        # add_hasanswer(data, hasanswer)

        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as fout:
            for ex in data:
                json.dump(ex, fout, ensure_ascii=False)
                ret_list.append(ex)
                fout.write("\n")
        # print(f"Saved results to {output_path}")
    return ret_list


# Write the query to test_robot.jsonl.
def get_json(query):
    dic = {"id": 1, "question": query}
    with open("test_robot.jsonl", "w", encoding="utf-8") as fout:
        json.dump(dic, fout, ensure_ascii=False)


def get_answer():
    # Read the first retrieval result written by test().
    with open(os.path.join("robot_result", "test_robot.jsonl"), "r", encoding="utf-8") as fin:
        for k, example in enumerate(fin):
            example = json.loads(example)
            answer = example["ctxs"][0]["text"]
            score = example["ctxs"][0]["score"]
            return score, answer
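
# Usage sketch: get_json("...") writes the query file consumed by test();
# once test() has written robot_result/test_robot.jsonl, get_answer()
# returns the score and text of the top retrieved passage.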


def retri(query):
    get_json(query)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data",
        type=str,
        default="test_robot.jsonl",
        help=".json file containing question and answers, similar format to reader data",
    )
    parser.add_argument("--passages", type=str, default=f"{root_path}/robowaiter/llm_client/train_robot.jsonl", help="Path to passages (.tsv file)")
    parser.add_argument("--passages_embeddings", type=str, default=f"{root_path}/robowaiter/algos/retrieval/robot_embeddings/*", help="Glob path to encoded passages")
    parser.add_argument(
        "--output_dir", type=str, default="robot_result", help="Results are written to output_dir with data suffix"
    )
    # n_docs controls how many retrieval results are returned per query.
    parser.add_argument("--n_docs", type=int, default=5, help="Number of documents to retrieve per question")
    parser.add_argument(
        "--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
    )
    parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding")
    parser.add_argument(
        "--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists"
    )
    parser.add_argument(
        "--model_name_or_path", type=str, default=f"{root_path}/robowaiter/algos/retrieval/contriever-msmarco", help="path to directory containing model weights and config file"
    )
    parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
    parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
    parser.add_argument(
        "--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed"
    )
    parser.add_argument("--projection_size", type=int, default=768)
    parser.add_argument(
        "--n_subquantizers",
        type=int,
        default=0,
        help="Number of subquantizers used for vector quantization; if 0, a flat index is used",
    )
    parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer")
    parser.add_argument("--lang", nargs="+")
    parser.add_argument("--dataset", type=str, default="none")
    parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
    parser.add_argument("--normalize_text", action="store_true", help="normalize text")

    args = parser.parse_args()
    init_distributed_mode(args)
    ret = test(args)
    return ret[0]
    # Alternatively, return only the top answer and its score:
    # example = ret[0]
    # answer = example["ctxs"][0]["text"]
    # score = example["ctxs"][0]["score"]
    # return score, answer


if __name__ == "__main__":
    # query = "请你拿一下软饮料到第三张桌子位置。"
    # score, answer = retri(query)
    # print(score, answer)
    query = "你能把空调打开一下吗?"
    all_ret = retri(query)
    for i, example in enumerate(all_ret["ctxs"]):
        answer = example["text"]
        score = example["score"]
        doc_id = example["id"]
        print(i, answer, score, "id=", doc_id)