# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
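
# Dense-retrieval helper for RoboWaiter: embed queries with a Contriever
# model, search a FAISS index built from pre-computed passage embeddings,
# and attach the top-scoring passages to each query.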

import argparse
import glob
import json
import os
import pickle
import time
import warnings

import numpy as np
import torch

from robowaiter.algos.retrieval.retrieval_lm.src.slurm import init_distributed_mode
from robowaiter.algos.retrieval.retrieval_lm.src.normalize_text import normalize
from robowaiter.algos.retrieval.retrieval_lm.src.contriever import load_retriever
from robowaiter.algos.retrieval.retrieval_lm.src.index import Indexer
from robowaiter.algos.retrieval.retrieval_lm.src.data import load_passages
from robowaiter.algos.retrieval.retrieval_lm.src.evaluation import calculate_matches

warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "true"


def embed_queries(args, queries, model, tokenizer):
    """Encode a list of query strings into a numpy array of embeddings."""
    model.eval()
    embeddings, batch_question = [], []
    with torch.no_grad():
        for k, q in enumerate(queries):
            if args.lowercase:
                q = q.lower()
            if args.normalize_text:
                q = normalize(q)
            batch_question.append(q)

            # Flush a full batch, or the final partial batch.
            if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1:
                encoded_batch = tokenizer.batch_encode_plus(
                    batch_question,
                    return_tensors="pt",
                    max_length=args.question_maxlength,
                    padding=True,
                    truncation=True,
                )
                encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
                output = model(**encoded_batch)
                embeddings.append(output.cpu())
                batch_question = []

    embeddings = torch.cat(embeddings, dim=0)
    return embeddings.numpy()


def index_encoded_data(index, embedding_files, indexing_batch_size):
    """Load pickled (ids, embeddings) shards and add them to the index in batches."""
    allids = []
    allembeddings = np.array([])
    for i, file_path in enumerate(embedding_files):
        with open(file_path, "rb") as fin:
            ids, embeddings = pickle.load(fin)

        allembeddings = np.vstack((allembeddings, embeddings)) if allembeddings.size else embeddings
        allids.extend(ids)
        while allembeddings.shape[0] > indexing_batch_size:
            allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)

    # Index whatever is left over.
    while allembeddings.shape[0] > 0:
        allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)


def add_embeddings(index, embeddings, ids, indexing_batch_size):
    """Index the first `indexing_batch_size` embeddings and return the remainder."""
    end_idx = min(indexing_batch_size, embeddings.shape[0])
    ids_toadd = ids[:end_idx]
    embeddings_toadd = embeddings[:end_idx]
    ids = ids[end_idx:]
    embeddings = embeddings[end_idx:]
    index.index_data(ids_toadd, embeddings_toadd)
    return embeddings, ids


def validate(data, workers_num):
    """Compute top-k retrieval hit rates and return per-question document hits."""
    match_stats = calculate_matches(data, workers_num)
    top_k_hits = match_stats.top_k_hits

    top_k_hits = [v / len(data) for v in top_k_hits]
    message = ""
    for k in [5, 10, 20, 100]:
        if k <= len(top_k_hits):
            message += f"R@{k}: {top_k_hits[k-1]} "
    # print(message)
    return match_stats.questions_doc_hits


def add_passages(data, passages, top_passages_and_scores):
    """Attach the retrieved passages and their scores to each original example."""
    assert len(data) == len(top_passages_and_scores)
    for i, d in enumerate(data):
        results_and_scores = top_passages_and_scores[i]
        docs = [passages[int(doc_id)] for doc_id in results_and_scores[0]]
        scores = [str(score) for score in results_and_scores[1]]
        ctxs_num = len(docs)
        d["ctxs"] = [
            {
                "id": results_and_scores[0][c],
                "title": docs[c]["title"],
                "text": docs[c]["text"],
                "score": scores[c],
            }
            for c in range(ctxs_num)
        ]


def add_hasanswer(data, hasanswer):
    """Attach a hasanswer flag to every retrieved passage."""
    for i, ex in enumerate(data):
        for k, d in enumerate(ex["ctxs"]):
            d["hasanswer"] = hasanswer[i][k]


def load_data(data_path):
    """Load query examples from a .json or .jsonl file."""
    if data_path.endswith(".json"):
        with open(data_path, "r", encoding="utf-8") as fin:
            data = json.load(fin)
    elif data_path.endswith(".jsonl"):
        data = []
        with open(data_path, "r", encoding="utf-8") as fin:
            for line in fin:
                data.append(json.loads(line))
    return data


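# Expected input formats (inferred from usage in this file): each line of the
# query .jsonl is an object like {"id": 1, "question": "..."}, and each passage
# returned by load_passages provides at least "id", "title", and "text" fields.
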
def test(args):
    """Embed the queries in args.data, retrieve the top args.n_docs passages
    for each from the FAISS index, and write the results to args.output_dir."""
    model, tokenizer, _ = load_retriever(args.model_name_or_path)
    model.eval()
    model = model.cuda()
    if not args.no_fp16:
        model = model.half()

    index = Indexer(args.projection_size, args.n_subquantizers, args.n_bits)

    # Index all passages (or load a previously serialized index).
    input_paths = glob.glob(args.passages_embeddings)
    input_paths = sorted(input_paths)
    embeddings_dir = os.path.dirname(input_paths[0])
    index_path = os.path.join(embeddings_dir, "index.faiss")
    if args.save_or_load_index and os.path.exists(index_path):
        index.deserialize_from(embeddings_dir)
    else:
        index_encoded_data(index, input_paths, args.indexing_batch_size)
        if args.save_or_load_index:
            index.serialize(embeddings_dir)

    # Load passages and map them by id for fast lookup.
    passages = load_passages(args.passages)
    passage_id_map = {x["id"]: x for x in passages}

    data_paths = glob.glob(args.data)
    for path in data_paths:
        data = load_data(path)
        output_path = os.path.join(args.output_dir, os.path.basename(path))

        queries = [ex["question"] for ex in data]
        questions_embedding = embed_queries(args, queries, model, tokenizer)

        # Get the top-k results.
        top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs)

        add_passages(data, passage_id_map, top_ids_and_scores)
        # Optionally score retrieval quality:
        # hasanswer = validate(data, args.validation_workers)
        # add_hasanswer(data, hasanswer)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        ret_list = []
        with open(output_path, "w", encoding="utf-8") as fout:
            for ex in data:
                json.dump(ex, fout, ensure_ascii=False)
                fout.write("\n")
                ret_list.append(ex)
        # Note: returns after the first data file, matching the original control flow.
        return ret_list


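# Each example written by test() gains a "ctxs" list built by add_passages:
# {"id": ..., "question": ..., "ctxs": [{"id", "title", "text", "score"}, ...]}.
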
# Write the query to test_robot.jsonl.
def get_json(query):
    dic = {"id": 1, "question": query}
    with open("test_robot.jsonl", "w", encoding="utf-8") as fout:
        json.dump(dic, fout, ensure_ascii=False)


def get_answer():
    """Read the retrieval result and return the top passage's score and text."""
    # Open for reading ("r"); the original "w" mode truncated the file.
    with open(os.path.join("robot_result", "test_robot.jsonl"), "r", encoding="utf-8") as fin:
        for line in fin:
            example = json.loads(line)
            answer = example["ctxs"][0]["text"]
            score = example["ctxs"][0]["score"]
            return score, answer


def retri(query):
    """Retrieve the top passages for a single query string."""
    get_json(query)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data",
        type=str,
        default="test_robot.jsonl",
        help=".json file containing questions and answers, similar format to reader data",
    )
    parser.add_argument("--passages", type=str, default="D:/AAAAA_EI_LLM/UnrealProject/RobotProject/Plugins/RoboWaiter/robowaiter/llm_client/train_robot.jsonl", help="Path to passages (.tsv file)")
    parser.add_argument("--passages_embeddings", type=str, default="D:/AAAAA_EI_LLM/UnrealProject/RobotProject/Plugins/RoboWaiter/robowaiter/algos/retrieval/robot_embeddings/*", help="Glob path to encoded passages")
    parser.add_argument(
        "--output_dir", type=str, default="robot_result", help="Results are written to output_dir with the data file's name"
    )
    # Increase n_docs to return more retrieval results per question.
    parser.add_argument("--n_docs", type=int, default=5, help="Number of documents to retrieve per question")
    parser.add_argument(
        "--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
    )
    parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding")
    parser.add_argument(
        "--save_or_load_index", action="store_true", help="If enabled, save the index, and load it if it exists"
    )
    parser.add_argument(
        "--model_name_or_path", type=str, default="D:\\AAAAA_EI_LLM\\UnrealProject\\RobotProject\\Plugins\\RoboWaiter\\robowaiter\\algos\\retrieval\\contriever-msmarco", help="Path to directory containing model weights and config file"
    )
    parser.add_argument("--no_fp16", action="store_true", help="Inference in fp32")
    parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
    parser.add_argument(
        "--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed"
    )
    parser.add_argument("--projection_size", type=int, default=768)
    parser.add_argument(
        "--n_subquantizers",
        type=int,
        default=0,
        help="Number of subquantizers used for vector quantization; if 0, a flat index is used",
    )
    parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer")
    parser.add_argument("--lang", nargs="+")
    parser.add_argument("--dataset", type=str, default="none")
    parser.add_argument("--lowercase", action="store_true", help="Lowercase text before encoding")
    parser.add_argument("--normalize_text", action="store_true", help="Normalize text")

    args = parser.parse_args()
    init_distributed_mode(args)
    ret = test(args)

    return ret[0]


if __name__ == "__main__":
    query = "你能把空调打开一下吗?"  # "Could you turn on the air conditioner?"
    all_ret = retri(query)
    for i, example in enumerate(all_ret["ctxs"]):
        answer = example["text"]
        score = example["score"]
        doc_id = example["id"]
        print(i, answer, score, " id=", doc_id)