# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import argparse
import os
import pickle

import torch

import robowaiter.llm_client.retrieval_lm.src.contriever
import robowaiter.llm_client.retrieval_lm.src.data
import robowaiter.llm_client.retrieval_lm.src.normalize_text
import robowaiter.llm_client.retrieval_lm.src.slurm
|
|
|
|
|
|
def embed_passages(args, passages, model, tokenizer):
    """Encode passages into dense embeddings, batch by batch.

    Args:
        args: namespace with ``per_gpu_batch_size``, ``passage_maxlength``,
            ``lowercase`` and ``normalize_text`` attributes.
        passages: iterable of dicts; each must carry "id" and "title" keys.
        model: encoder returning one embedding row per input; assumed to
            already live on the GPU (inputs are moved via ``.cuda()``).
        tokenizer: HuggingFace-style tokenizer exposing ``batch_encode_plus``.

    Returns:
        Tuple ``(allids, allembeddings)``: the passage ids and a 2-D numpy
        array of embeddings, aligned row-for-row (empty array if no passages).
    """
    total = 0
    allids, allembeddings = [], []
    batch_ids, batch_text = [], []
    with torch.no_grad():
        for k, p in enumerate(passages):
            batch_ids.append(p["id"])

            # NOTE(review): only the title is embedded. A removed variant also
            # appended p["text"] and honored the --no_title flag; that flag is
            # currently ignored here -- confirm title-only input is intended.
            text = p["title"]
            if args.lowercase:
                text = text.lower()
            if args.normalize_text:
                text = robowaiter.llm_client.retrieval_lm.src.normalize_text.normalize(text)
            batch_text.append(text)

            # Flush on a full batch, or on the final (possibly partial) batch.
            if len(batch_text) == args.per_gpu_batch_size or k == len(passages) - 1:
                encoded_batch = tokenizer.batch_encode_plus(
                    batch_text,
                    return_tensors="pt",
                    max_length=args.passage_maxlength,
                    padding=True,
                    truncation=True,
                )

                encoded_batch = {name: t.cuda() for name, t in encoded_batch.items()}
                embeddings = model(**encoded_batch)

                embeddings = embeddings.cpu()
                total += len(batch_ids)
                allids.extend(batch_ids)
                allembeddings.append(embeddings)

                batch_text = []
                batch_ids = []
                if k % 100000 == 0 and k > 0:
                    print(f"Encoded passages {total}")

    # Guard the empty-input case: torch.cat on an empty list raises.
    if allembeddings:
        allembeddings = torch.cat(allembeddings, dim=0).numpy()
    else:
        allembeddings = torch.empty(0).numpy()
    return allids, allembeddings
|
|
|
|
|
|
def main(args):
    """Load the retriever, embed this worker's shard of passages, and pickle the result."""
    model, tokenizer, _ = robowaiter.llm_client.retrieval_lm.src.contriever.load_retriever(args.model_name_or_path)
    print(f"Model loaded from {args.model_name_or_path}.", flush=True)

    # Inference-only: eval mode, on GPU, fp16 unless explicitly disabled.
    model.eval()
    model = model.cuda()
    if not args.no_fp16:
        model = model.half()

    passages = robowaiter.llm_client.retrieval_lm.src.data.load_passages(args.passages)

    # Pick this shard's contiguous slice; the last shard absorbs the remainder.
    shard_size = len(passages) // args.num_shards
    start_idx = args.shard_id * shard_size
    end_idx = len(passages) if args.shard_id == args.num_shards - 1 else start_idx + shard_size
    passages = passages[start_idx:end_idx]
    print(f"Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}.")

    allids, allembeddings = embed_passages(args, passages, model, tokenizer)

    # Persist (ids, embeddings) for this shard as a pickle.
    save_file = os.path.join(args.output_dir, args.prefix + f"_{args.shard_id:02d}")
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Saving {len(allids)} passage embeddings to {save_file}.")
    with open(save_file, mode="wb") as f:
        pickle.dump((allids, allembeddings), f)

    print(f"Total passages processed {len(allids)}. Written to {save_file}.")
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI for sharded passage-embedding generation.
    parser = argparse.ArgumentParser()

    parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
    parser.add_argument("--output_dir", type=str, default="wikipedia_embeddings", help="dir path to save embeddings")
    parser.add_argument("--prefix", type=str, default="passages", help="prefix path to save embeddings")
    parser.add_argument("--shard_id", type=int, default=0, help="Id of the current shard")
    parser.add_argument("--num_shards", type=int, default=1, help="Total number of shards")
    parser.add_argument(
        "--per_gpu_batch_size", type=int, default=512, help="Batch size for the passage encoder forward pass"
    )
    parser.add_argument("--passage_maxlength", type=int, default=512, help="Maximum number of tokens in a passage")
    parser.add_argument(
        "--model_name_or_path", type=str, help="path to directory containing model weights and config file"
    )
    parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
    parser.add_argument("--no_title", action="store_true", help="title not added to the passage body")
    parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
    # Fixed: help text was a copy-paste of --lowercase's description.
    parser.add_argument("--normalize_text", action="store_true", help="normalize text before encoding")

    args = parser.parse_args()

    robowaiter.llm_client.retrieval_lm.src.slurm.init_distributed_mode(args)

    main(args)
|