# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import glob
import json
import os
import pickle
import time
import warnings

import numpy as np
import torch

from robowaiter.algos.retrieval.retrieval_lm.src.contriever import load_retriever
from robowaiter.algos.retrieval.retrieval_lm.src.data import load_passages
from robowaiter.algos.retrieval.retrieval_lm.src.evaluation import calculate_matches
from robowaiter.algos.retrieval.retrieval_lm.src.index import Indexer
from robowaiter.algos.retrieval.retrieval_lm.src.normalize_text import normalize
from robowaiter.algos.retrieval.retrieval_lm.src.slurm import init_distributed_mode

warnings.filterwarnings('ignore')
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
def embed_queries(args, queries, model, tokenizer):
    """Encode query strings into dense vectors with the retriever model.

    Queries are optionally lowercased / normalized (per ``args``), batched up
    to ``args.per_gpu_batch_size``, tokenized, and run through ``model`` on
    the GPU with gradients disabled.

    Returns a numpy array of shape (len(queries), embedding_dim).
    """
    model.eval()
    chunk_outputs = []
    pending = []
    last_idx = len(queries) - 1
    with torch.no_grad():
        for idx, question in enumerate(queries):
            if args.lowercase:
                question = question.lower()
            if args.normalize_text:
                question = normalize(question)
            pending.append(question)

            # Flush when the batch is full or on the final query.
            if len(pending) != args.per_gpu_batch_size and idx != last_idx:
                continue
            encoded = tokenizer.batch_encode_plus(
                pending,
                return_tensors="pt",
                max_length=args.question_maxlength,
                padding=True,
                truncation=True,
            )
            encoded = {name: tensor.cuda() for name, tensor in encoded.items()}
            chunk_outputs.append(model(**encoded).cpu())
            pending = []

    return torch.cat(chunk_outputs, dim=0).numpy()
|
def index_encoded_data(index, embedding_files, indexing_batch_size):
    """Load pickled (ids, embeddings) shards and feed them into ``index``.

    Vectors are buffered across shard files and flushed to the index in
    chunks of at most ``indexing_batch_size`` via ``add_embeddings``.
    """
    buffered_ids = []
    buffered_vecs = np.array([])
    for shard_path in embedding_files:
        #print(f"Loading file {shard_path}")
        with open(shard_path, "rb") as fh:
            shard_ids, shard_vecs = pickle.load(fh)

        if buffered_vecs.size:
            buffered_vecs = np.vstack((buffered_vecs, shard_vecs))
        else:
            buffered_vecs = shard_vecs
        buffered_ids.extend(shard_ids)
        # Flush full batches as soon as enough vectors are buffered.
        while buffered_vecs.shape[0] > indexing_batch_size:
            buffered_vecs, buffered_ids = add_embeddings(
                index, buffered_vecs, buffered_ids, indexing_batch_size
            )

    # Flush whatever remains after the last shard.
    while buffered_vecs.shape[0] > 0:
        buffered_vecs, buffered_ids = add_embeddings(
            index, buffered_vecs, buffered_ids, indexing_batch_size
        )

    #print("Data indexing completed.")
|
def add_embeddings(index, embeddings, ids, indexing_batch_size):
    """Index the first ``indexing_batch_size`` vectors; return the leftovers.

    Splits ``embeddings``/``ids`` at the batch boundary, submits the head to
    ``index.index_data``, and returns the (embeddings, ids) tail so the
    caller can keep draining its buffer.
    """
    cut = min(indexing_batch_size, embeddings.shape[0])
    head_ids, tail_ids = ids[:cut], ids[cut:]
    head_vecs, tail_vecs = embeddings[:cut], embeddings[cut:]
    index.index_data(head_ids, head_vecs)
    return tail_vecs, tail_ids
|
def validate(data, workers_num):
    """Score retrieved contexts against gold answers.

    Delegates matching to ``calculate_matches`` (run with ``workers_num``
    parallel workers) and returns the per-question document-hit matrix.
    The recall@k summary string is computed but only printed when the
    commented-out print is re-enabled.
    """
    stats = calculate_matches(data, workers_num)
    hit_counts = stats.top_k_hits

    # print("Validation results: top k documents hits %s", hit_counts)
    recall_at_k = [count / len(data) for count in hit_counts]
    message = ""
    for k in (5, 10, 20, 100):
        if k <= len(recall_at_k):
            message += f"R@{k}: {recall_at_k[k-1]} "
    #print(message)
    return stats.questions_doc_hits
|
def add_passages(data, passages, top_passages_and_scores):
    """Attach the retrieved passages to each query record, in place.

    Args:
        data: list of query dicts; each gains a ``"ctxs"`` list.
        passages: mapping from passage id (int) to a dict with at least
            ``"title"`` and ``"text"`` keys.
        top_passages_and_scores: per-query ``(doc_ids, scores)`` pairs,
            aligned index-for-index with ``data``.

    Notes:
        Fixed: the original declared an unused local ``merged_data`` that
        was never populated or returned; it has been removed.
    """
    assert len(data) == len(top_passages_and_scores)
    for entry, (doc_ids, doc_scores) in zip(data, top_passages_and_scores):
        # doc ids come back as strings from the index; passage map keys are ints
        docs = [passages[int(doc_id)] for doc_id in doc_ids]
        entry["ctxs"] = [
            {
                "id": doc_ids[c],
                "title": docs[c]["title"],
                "text": docs[c]["text"],
                "score": str(doc_scores[c]),
            }
            for c in range(len(docs))
        ]
|
def add_hasanswer(data, hasanswer):
    """Annotate each retrieved context with an answer-presence flag, in place.

    ``hasanswer[i][k]`` says whether context ``k`` of example ``i`` contains
    the gold answer; it is copied onto the matching ctx dict.
    """
    for example_idx, example in enumerate(data):
        flags = hasanswer[example_idx]
        for ctx_idx, ctx in enumerate(example["ctxs"]):
            ctx["hasanswer"] = flags[ctx_idx]
|
def load_data(data_path):
    """Load query records from a ``.json`` (list) or ``.jsonl`` file.

    A ``.json`` file is parsed as one JSON document (expected to be a list);
    a ``.jsonl`` file yields one parsed object per line.

    Returns:
        list of record dicts.

    Raises:
        ValueError: if the path has neither extension. (Fixed: the original
        fell through with ``data`` unbound, raising a confusing NameError.)
    """
    # Check .jsonl first; note ".jsonl" does not match endswith(".json").
    if data_path.endswith(".jsonl"):
        data = []
        with open(data_path, "r", encoding='utf-8') as fin:
            for line in fin:
                #print("example:", line)
                data.append(json.loads(line))
        return data
    if data_path.endswith(".json"):
        with open(data_path, "r", encoding='utf-8') as fin:
            return json.load(fin)
    raise ValueError(f"Unsupported data file (expected .json or .jsonl): {data_path}")
|
def test(args):  # args.data is the path (glob) of the query file
    """Run the full retrieval pipeline for the query file(s) in ``args.data``.

    Loads the contriever model, builds (or deserializes) the FAISS index over
    the pre-computed passage embeddings, embeds the queries, retrieves the
    top ``args.n_docs`` passages per query, attaches them to each record, and
    writes one JSON record per line to ``args.output_dir``.

    Returns:
        The list of query records for the first processed data path, each
        augmented with a "ctxs" list (see NOTE on the early return below).
    """
    # Typical argument values this is run with:
    # args = {"model_name_or_path":"contriever-msmarco","passages":"train_robot.jsonl"\
    # passages_embeddings = "robot_embeddings/*"
    # data = "test_robot.jsonl"
    # output_dir = "robot_result"
    # n_docs = 1

    #print(f"Loading model from: {args.model_name_or_path}")
    model, tokenizer, _ = load_retriever(args.model_name_or_path)
    model.eval()
    model = model.cuda()  # requires a CUDA device
    if not args.no_fp16:
        model = model.half()  # fp16 inference unless --no_fp16 is set

    index = Indexer(args.projection_size, args.n_subquantizers, args.n_bits)

    # index all passages
    input_paths = glob.glob(args.passages_embeddings)
    input_paths = sorted(input_paths)
    # NOTE(review): fails with IndexError if the embeddings glob matches nothing.
    embeddings_dir = os.path.dirname(input_paths[0])
    index_path = os.path.join(embeddings_dir, "index.faiss")
    if args.save_or_load_index and os.path.exists(index_path):
        # Reuse a previously serialized index.
        index.deserialize_from(embeddings_dir)
    else:
        #print(f"Indexing passages from files {input_paths}")
        start_time_indexing = time.time()
        index_encoded_data(index, input_paths, args.indexing_batch_size)
        #print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.")
        if args.save_or_load_index:
            index.serialize(embeddings_dir)

    # load passages and map them by id for fast lookup after search
    passages = load_passages(args.passages)
    passage_id_map = {x["id"]: x for x in passages}

    data_paths = glob.glob(args.data)
    alldata = []  # NOTE(review): never used; candidate for removal
    for path in data_paths:
        data = load_data(path)
        #print("data:",data)
        output_path = os.path.join(args.output_dir, os.path.basename(path))

        queries = [ex["question"] for ex in data]
        questions_embedding = embed_queries(args, queries, model, tokenizer)

        # get top k results
        start_time_retrieval = time.time()
        top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs)
        #print(f"Search time: {time.time()-start_time_retrieval:.1f} s.")

        add_passages(data, passage_id_map, top_ids_and_scores)
        #hasanswer = validate(data, args.validation_workers)
        #add_hasanswer(data, hasanswer)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        ret_list = []
        # Write the enriched records as JSONL while also collecting them.
        with open(output_path, "w",encoding='utf-8') as fout:
            for ex in data:
                json.dump(ex, fout, ensure_ascii=False)
                ret_list.append(ex)
                fout.write("\n")
        # NOTE(review): returning inside the loop means only the FIRST matched
        # data path is processed — confirm this is intended when args.data is
        # a glob that can match several files.
        return ret_list
        #print(f"Saved results to {output_path}")
|
# Write the query into test_robot.jsonl so the retrieval pipeline can read it.
def get_json(query):
    """Serialize ``query`` as a single-record JSONL file named test_robot.jsonl."""
    record = {"id": 1, "question": query}
    with open('test_robot.jsonl', "w", encoding='utf-8') as sink:
        json.dump(record, sink, ensure_ascii=False)
|
def get_answer():
    """Read the retrieval result file and return the top hit of the first record.

    Returns:
        (score, answer) — the score and text of the first context of the
        first JSONL record, or ``None`` implicitly if the file is empty.

    Notes:
        Fixed: the file was previously opened with mode ``"w"``, which
        truncated the results before they could be read; it is now opened
        for reading.
    """
    # NOTE(review): Windows-style backslash path; consider os.path.join for
    # portability.
    with open('robot_result\\test_robot.jsonl', "r", encoding='utf-8') as fin:
        for k, example in enumerate(fin):
            example = json.loads(example)
            answer = example["ctxs"][0]["text"]
            score = example["ctxs"][0]["score"]
            return score, answer
|
def retri(query):
    """Retrieve the top passages for ``query`` and return the first result record.

    Writes the query to test_robot.jsonl, builds the default argument set,
    initializes distributed mode, runs the retrieval pipeline, and returns
    the first enriched record (a dict whose "ctxs" list holds the hits).
    """
    get_json(query)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data",
        #required=True,
        type=str,
        default='test_robot.jsonl',
        help=".json file containing question and answers, similar format to reader data",
    )
    # NOTE(review): the defaults below are absolute machine-specific paths;
    # they should be made configurable or relative to the package root.
    # parser.add_argument("--passages", type=str, default='C:/Users/huangyu/Desktop/RoboWaiter-main/RoboWaiter-main/train_robot.jsonl', help="Path to passages (.tsv file)")
    # parser.add_argument("--passages_embeddings", type=str, default='C:/Users/huangyu/Desktop/RoboWaiter-main/RoboWaiter-main/robot_embeddings/*', help="Glob path to encoded passages")
    parser.add_argument("--passages", type=str, default='D:/AAAAA_EI_LLM/UnrealProject/RobotProject/Plugins/RoboWaiter/robowaiter/llm_client/train_robot.jsonl', help="Path to passages (.tsv file)")
    parser.add_argument("--passages_embeddings", type=str, default='D:/AAAAA_EI_LLM/UnrealProject/RobotProject/Plugins/RoboWaiter/robowaiter/algos/retrieval/robot_embeddings/*', help="Glob path to encoded passages")

    parser.add_argument(
        "--output_dir", type=str, default='robot_result', help="Results are written to outputdir with data suffix"
    )
    parser.add_argument("--n_docs", type=int, default=5, help="Number of documents to retrieve per questions")  # tune this to change how many retrieval results are returned
    parser.add_argument(
        "--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
    )
    parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding")
    parser.add_argument(
        "--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists"
    )
    # parser.add_argument(
    #     "--model_name_or_path", type=str, default='C:\\Users\\huangyu\\Desktop\\RoboWaiter-main\\RoboWaiter-main\\contriever-msmarco',help="path to directory containing model weights and config file"
    # )
    parser.add_argument(
        "--model_name_or_path", type=str, default='D:\\AAAAA_EI_LLM\\UnrealProject\\RobotProject\\Plugins\\RoboWaiter\\robowaiter\\algos\\retrieval\\contriever-msmarco',help="path to directory containing model weights and config file"
    )
    parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
    parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
    parser.add_argument(
        "--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed"
    )
    parser.add_argument("--projection_size", type=int, default=768)
    parser.add_argument(
        "--n_subquantizers",
        type=int,
        default=0,
        help="Number of subquantizer used for vector quantization, if 0 flat index is used",
    )
    parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer")
    parser.add_argument("--lang", nargs="+")
    parser.add_argument("--dataset", type=str, default="none")
    parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
    parser.add_argument("--normalize_text", action="store_true", help="normalize text")

    args = parser.parse_args()
    init_distributed_mode(args)
    #print(args)
    ret = test(args)
    #print(ret)

    # Only the first record is returned; the caller inspects its "ctxs" list.
    return ret[0]

    # example = ret[0]
    # answer = example["ctxs"][0]["text"]
    # score = example["ctxs"][0]["score"]
    # return score, answer
|
if __name__ == "__main__":
    # query = "请你拿一下软饮料到第三张桌子位置。"
    # score,answer = retri(query)
    # print(score,answer)

    # Demo: retrieve passages for a sample query ("Could you turn on the air
    # conditioner?") and print each hit's text, score, and passage id.
    query = "你能把空调打开一下吗?"
    all_ret = retri(query)
    for i,example in enumerate(all_ret["ctxs"]):
        answer = example["text"]
        score = example["score"]
        id = example["id"]  # NOTE(review): shadows the builtin `id`
        print(i,answer,score," id=",id)