# RoboWaiter/robowaiter/llm_client/passage_retrieval3.py
# (web-viewer header residue: 312 lines, 12 KiB, Python)

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import argparse
import json
import pickle
import time
import glob
import numpy as np
import torch
from robowaiter.algos.retrieval.retrieval_lm.src.slurm import init_distributed_mode
from robowaiter.algos.retrieval.retrieval_lm.src.normalize_text import normalize
from robowaiter.algos.retrieval.retrieval_lm.src.contriever import load_retriever
from robowaiter.algos.retrieval.retrieval_lm.src.index import Indexer
from robowaiter.algos.retrieval.retrieval_lm.src.data import load_passages
from robowaiter.algos.retrieval.retrieval_lm.src.evaluation import calculate_matches
import warnings
# Suppress every warning emitted in this process (no category filter is given).
warnings.filterwarnings('ignore')
# Export TOKENIZERS_PARALLELISM="true" for this process.
# NOTE(review): presumably read by the HuggingFace tokenizers library to allow
# parallel tokenization - confirm against the tokenizer actually loaded.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
def embed_queries(args, queries, model, tokenizer):
    """Encode query strings into a NumPy array of embeddings.

    Each query is optionally lowercased (``args.lowercase``) and normalized
    (``args.normalize_text``), then tokenized and passed through ``model`` in
    batches of ``args.per_gpu_batch_size`` with gradients disabled.

    Args:
        args: Namespace providing ``lowercase``, ``normalize_text``,
            ``per_gpu_batch_size`` and ``question_maxlength``.
        queries: Sized sequence of query strings. It must be non-empty,
            because the per-batch outputs are concatenated with ``torch.cat``.
        model: Module called as ``model(**encoded_batch)``; its output tensors
            are concatenated along dimension 0.
        tokenizer: Object exposing ``batch_encode_plus``.

    Returns:
        The concatenated batch outputs converted with ``Tensor.numpy()``.
    """
    model.eval()

    # Move inputs to the device the model actually lives on instead of
    # assuming CUDA. When the model exposes no parameters, keep the previous
    # behaviour and target the current CUDA device.
    params = getattr(model, "parameters", None)
    first_param = next(iter(params()), None) if callable(params) else None
    device = first_param.device if first_param is not None else torch.device("cuda")

    embeddings, batch_question = [], []
    last_index = len(queries) - 1
    with torch.no_grad():
        for position, question in enumerate(queries):
            if args.lowercase:
                question = question.lower()
            if args.normalize_text:
                question = normalize(question)
            batch_question.append(question)

            # Flush when the batch is full or when the final query was added.
            if len(batch_question) == args.per_gpu_batch_size or position == last_index:
                encoded_batch = tokenizer.batch_encode_plus(
                    batch_question,
                    return_tensors="pt",
                    max_length=args.question_maxlength,
                    padding=True,
                    truncation=True,
                )
                encoded_batch = {name: tensor.to(device) for name, tensor in encoded_batch.items()}
                output = model(**encoded_batch)
                embeddings.append(output.cpu())
                batch_question = []

    embeddings = torch.cat(embeddings, dim=0)
    return embeddings.numpy()
def index_encoded_data(index, embedding_files, indexing_batch_size):
    """Load pickled ``(ids, embeddings)`` shards and add them to ``index``.

    Shards are accumulated in memory; whenever more than
    ``indexing_batch_size`` rows are pending, batches are handed to
    ``add_embeddings``. Whatever is left after the last shard is flushed at
    the end.

    Args:
        index: Object forwarded to ``add_embeddings``.
        embedding_files: Iterable of paths to pickle files, each holding an
            ``(ids, embeddings)`` pair.
        indexing_batch_size: Maximum number of rows handed over per batch.
    """
    pending_ids = []
    pending_vectors = np.array([])
    for shard_path in embedding_files:
        # NOTE: pickle can execute arbitrary code while loading; only point
        # this at embedding files you produced yourself.
        with open(shard_path, "rb") as shard:
            shard_ids, shard_vectors = pickle.load(shard)

        if pending_vectors.size:
            pending_vectors = np.vstack((pending_vectors, shard_vectors))
        else:
            # Nothing pending yet: adopt the shard's array as is.
            pending_vectors = shard_vectors
        pending_ids.extend(shard_ids)

        # Drain batches while more than one batch worth of rows is pending.
        while pending_vectors.shape[0] > indexing_batch_size:
            pending_vectors, pending_ids = add_embeddings(
                index, pending_vectors, pending_ids, indexing_batch_size
            )

    # Flush the remainder (at most one batch).
    while pending_vectors.shape[0] > 0:
        pending_vectors, pending_ids = add_embeddings(
            index, pending_vectors, pending_ids, indexing_batch_size
        )
def add_embeddings(index, embeddings, ids, indexing_batch_size):
    """Index the leading batch of rows and return what is left over.

    Args:
        index: Object exposing ``index_data(ids, embeddings)``.
        embeddings: Array-like with a ``shape`` attribute, sliced by row.
        ids: Sequence of ids aligned with the rows of ``embeddings``.
        indexing_batch_size: Upper bound on the number of rows to index.

    Returns:
        ``(remaining_embeddings, remaining_ids)`` after removing the batch.
    """
    cut = min(indexing_batch_size, embeddings.shape[0])
    batch_ids, remaining_ids = ids[:cut], ids[cut:]
    batch_vectors, remaining_vectors = embeddings[:cut], embeddings[cut:]
    index.index_data(batch_ids, batch_vectors)
    return remaining_vectors, remaining_ids
def validate(data, workers_num):
    """Run ``calculate_matches`` on ``data`` and return the per-question hits.

    Args:
        data: Sized collection of examples forwarded to ``calculate_matches``.
        workers_num: Worker count forwarded to ``calculate_matches``.

    Returns:
        The ``questions_doc_hits`` attribute of the match statistics.
    """
    stats = calculate_matches(data, workers_num)
    hit_rates = [hits / len(data) for hits in stats.top_k_hits]
    # Recall summary for selected cut-offs; it is built but not printed
    # (the print statements are disabled in this version).
    message = "".join(
        f"R@{k}: {hit_rates[k-1]} " for k in (5, 10, 20, 100) if k <= len(hit_rates)
    )
    return stats.questions_doc_hits
def add_passages(data, passages, top_passages_and_scores):
    """Attach the retrieved passages to each example of ``data`` in place.

    For example ``i``, ``top_passages_and_scores[i]`` is read as
    ``(doc_ids, scores)``. Every doc id is converted with ``int`` to look up
    ``passages``; the resulting contexts are stored under ``"ctxs"``.

    Args:
        data: List of example dicts; each gets a ``"ctxs"`` list.
        passages: Mapping from integer passage id to a dict with ``"title"``
            and ``"text"``.
        top_passages_and_scores: One ``(doc_ids, scores)`` pair per example.
    """
    assert len(data) == len(top_passages_and_scores)
    for position, example in enumerate(data):
        retrieved = top_passages_and_scores[position]
        doc_ids = retrieved[0]
        found_docs = [passages[int(doc_id)] for doc_id in doc_ids]
        score_strings = [str(value) for value in retrieved[1]]

        contexts = []
        for rank, doc in enumerate(found_docs):
            contexts.append(
                {
                    # The id is stored exactly as returned by the search.
                    "id": doc_ids[rank],
                    "title": doc["title"],
                    "text": doc["text"],
                    "score": score_strings[rank],
                }
            )
        example["ctxs"] = contexts
def add_hasanswer(data, hasanswer):
    """Copy the ``hasanswer`` flags onto the contexts of ``data`` in place.

    Args:
        data: List of examples, each with a ``"ctxs"`` list of dicts.
        hasanswer: Nested sequence; ``hasanswer[i][k]`` is stored under
            ``"hasanswer"`` in context ``k`` of example ``i``.
    """
    for row, example in enumerate(data):
        for col, ctx in enumerate(example["ctxs"]):
            ctx["hasanswer"] = hasanswer[row][col]
# def load_data(data_path):
# if data_path.endswith(".json"):
# with open(data_path, "r",encoding='utf-8') as fin:
# data = json.load(fin)
# elif data_path.endswith(".jsonl"):
# data = []
# with open(data_path, "r",encoding='utf-8') as fin:
# for k, example in enumerate(fin):
# example = json.loads(example)
# data.append(example)
# print("data:",data)
# return data
def load_data(data_path):
    """Load examples from a UTF-8 ``.json`` or ``.jsonl`` file.

    Args:
        data_path: Path whose suffix selects the parser.

    Returns:
        For ``.json``: the parsed document. For ``.jsonl``: a list with one
        parsed object per non-blank line.

    Raises:
        ValueError: If ``data_path`` ends with neither ``.json`` nor
            ``.jsonl`` (previously this surfaced as an UnboundLocalError).
    """
    if data_path.endswith(".json"):
        with open(data_path, "r", encoding='utf-8') as fin:
            return json.load(fin)
    if data_path.endswith(".jsonl"):
        with open(data_path, "r", encoding='utf-8') as fin:
            # Blank lines (for instance a trailing empty line) hold no record
            # and would make json.loads fail, so they are skipped.
            return [json.loads(line) for line in fin if line.strip()]
    raise ValueError(f"Unsupported data file (expected .json or .jsonl): {data_path!r}")
def test(args):
    """Retrieve passages for every query file matched by ``args.data``.

    Steps: load the retriever, build (or load) the passage index from the
    files matched by ``args.passages_embeddings``, embed the ``"question"``
    field of every example, search the index for ``args.n_docs`` neighbours,
    attach the passages via ``add_passages`` and write one JSON object per
    line to ``<args.output_dir>/<basename of the data file>``.

    Args:
        args: Namespace providing ``model_name_or_path``, ``no_fp16``,
            ``projection_size``, ``n_subquantizers``, ``n_bits``,
            ``passages_embeddings``, ``save_or_load_index``,
            ``indexing_batch_size``, ``passages``, ``data``, ``output_dir``
            and ``n_docs``, plus the fields read by ``embed_queries``.

    Returns:
        The list of examples (each with ``"ctxs"`` added) of the last data
        file processed, or an empty list when ``args.data`` matched nothing.
        NOTE(review): with several data files only the last file's examples
        are returned - confirm that is the intended contract.

    Raises:
        FileNotFoundError: If ``args.passages_embeddings`` matches no file.
    """
    model, tokenizer, _ = load_retriever(args.model_name_or_path)
    model.eval()
    model = model.cuda()
    if not args.no_fp16:
        model = model.half()

    index = Indexer(args.projection_size, args.n_subquantizers, args.n_bits)

    # Index all passages, reusing a serialized index when allowed and present.
    input_paths = sorted(glob.glob(args.passages_embeddings))
    if not input_paths:
        # Fail with an explicit message instead of an IndexError below.
        raise FileNotFoundError(
            f"No passage embedding files match {args.passages_embeddings!r}"
        )
    embeddings_dir = os.path.dirname(input_paths[0])
    index_path = os.path.join(embeddings_dir, "index.faiss")
    if args.save_or_load_index and os.path.exists(index_path):
        index.deserialize_from(embeddings_dir)
    else:
        index_encoded_data(index, input_paths, args.indexing_batch_size)
        if args.save_or_load_index:
            index.serialize(embeddings_dir)

    # Load passages and key them by their "id" field.
    passages = load_passages(args.passages)
    passage_id_map = {x["id"]: x for x in passages}

    # Defined up front so an empty glob yields [] rather than an unbound name.
    ret_list = []
    for path in glob.glob(args.data):
        data = load_data(path)
        output_path = os.path.join(args.output_dir, os.path.basename(path))

        queries = [ex["question"] for ex in data]
        questions_embedding = embed_queries(args, queries, model, tokenizer)

        # Get the top-k results and merge them into the examples.
        top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs)
        add_passages(data, passage_id_map, top_ids_and_scores)
        # The validate()/add_hasanswer() step is intentionally not run here.

        output_dir = os.path.dirname(output_path)
        if output_dir:
            # os.makedirs("") raises, so only create a non-empty directory.
            os.makedirs(output_dir, exist_ok=True)

        ret_list = []
        with open(output_path, "w", encoding='utf-8') as fout:
            for ex in data:
                json.dump(ex, fout, ensure_ascii=False)
                ret_list.append(ex)
                fout.write("\n")
    return ret_list
# Write the query to test_robot.jsonl.
def get_json(query):
    """Overwrite ``test_robot.jsonl`` with a single record for ``query``.

    The record is ``{"id": 1, "question": query}``, written as UTF-8 JSON
    without ASCII escaping and without a trailing newline.
    """
    record = {"id": 1, "question": query}
    with open('test_robot.jsonl', "w", encoding='utf-8') as fout:
        fout.write(json.dumps(record, ensure_ascii=False))
def get_answer():
    """Read the best passage and its score from the retrieval output file.

    The file ``robot_result/test_robot.jsonl`` is read line by line; for each
    record the first entry of ``"ctxs"`` is taken. The values of the last
    record are returned.

    Returns:
        ``(score, answer)`` - the ``"score"`` and ``"text"`` of the first
        context of the last record.

    Raises:
        ValueError: If the file contains no records.
    """
    # Build the path with os.path.join so it also works off Windows.
    result_path = os.path.join('robot_result', 'test_robot.jsonl')
    score = answer = None
    found = False
    # Open for reading: mode "w" truncated the results and could not be read.
    with open(result_path, "r", encoding='utf-8') as fin:
        for line in fin:
            if not line.strip():
                continue
            best = json.loads(line)["ctxs"][0]
            answer = best["text"]
            score = best["score"]
            found = True
    if not found:
        raise ValueError(f"No retrieval results found in {result_path!r}")
    return score, answer
def _build_retrieval_parser():
    """Create the argument parser holding the retrieval settings and defaults."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data",
        type=str,
        default='test_robot.jsonl',
        help=".json file containing question and answers, similar format to reader data",
    )
    # NOTE(review): the three path defaults below are absolute paths on one
    # developer machine; pass the options explicitly on any other setup.
    parser.add_argument(
        "--passages",
        type=str,
        default='D:/AAAAA_EI_LLM/UnrealProject/RobotProject/Plugins/RoboWaiter/robowaiter/llm_client/train_robot.jsonl',
        help="Path to passages (.tsv file)",
    )
    parser.add_argument(
        "--passages_embeddings",
        type=str,
        default='D:/AAAAA_EI_LLM/UnrealProject/RobotProject/Plugins/RoboWaiter/robowaiter/algos/retrieval/robot_embeddings/*',
        help="Glob path to encoded passages",
    )
    parser.add_argument(
        "--output_dir", type=str, default='robot_result', help="Results are written to outputdir with data suffix"
    )
    # Change this value to return the top n_docs retrieval results.
    parser.add_argument("--n_docs", type=int, default=5, help="Number of documents to retrieve per questions")
    parser.add_argument(
        "--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
    )
    parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding")
    parser.add_argument(
        "--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists"
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default='D:\\AAAAA_EI_LLM\\UnrealProject\\RobotProject\\Plugins\\RoboWaiter\\robowaiter\\algos\\retrieval\\contriever-msmarco',
        help="path to directory containing model weights and config file",
    )
    parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
    parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
    parser.add_argument(
        "--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed"
    )
    parser.add_argument("--projection_size", type=int, default=768)
    parser.add_argument(
        "--n_subquantizers",
        type=int,
        default=0,
        help="Number of subquantizer used for vector quantization, if 0 flat index is used",
    )
    parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer")
    parser.add_argument("--lang", nargs="+")
    parser.add_argument("--dataset", type=str, default="none")
    parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
    parser.add_argument("--normalize_text", action="store_true", help="normalize text")
    return parser


def retri(query):
    """Retrieve passages for a single ``query`` string.

    The query is written to ``test_robot.jsonl`` via ``get_json``, the
    retrieval settings are taken from the parser defaults (overridable on the
    command line), and ``test`` performs the search.

    Args:
        query: The question text to look up.

    Returns:
        The first example returned by ``test``: a dict that carries the
        retrieved contexts under ``"ctxs"``.

    Raises:
        IndexError: If ``test`` returns an empty list.
    """
    get_json(query)
    parser = _build_retrieval_parser()
    # parse_known_args keeps command-line overrides of the options above but
    # ignores unrelated arguments of the host program; parse_args() would
    # exit the process when this function is called from another application.
    args, _unknown = parser.parse_known_args()
    init_distributed_mode(args)
    ret = test(args)
    return ret[0]
if __name__ == "__main__":
    # Manual smoke test: retrieve passages for one sample request and print
    # every returned context with its rank, text, score and passage id.
    # Another sample request that can be tried instead:
    #   "请你拿一下软饮料到第三张桌子位置。"
    query = "你能把空调打开一下吗?"
    all_ret = retri(query)
    for rank, ctx in enumerate(all_ret["ctxs"]):
        answer = ctx["text"]
        score = ctx["score"]
        passage_id = ctx["id"]
        print(rank, answer, score, " id=", passage_id)