# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import argparse
import pickle

import torch

import robowaiter.llm_client.retrieval_lm.src.contriever
import robowaiter.llm_client.retrieval_lm.src.data
import robowaiter.llm_client.retrieval_lm.src.normalize_text
import robowaiter.llm_client.retrieval_lm.src.slurm


def embed_passages(args, passages, model, tokenizer):
    """Encode passages in batches and return their ids and embeddings."""
    total = 0
    allids, allembeddings = [], []
    batch_ids, batch_text = [], []
    with torch.no_grad():
        for k, p in enumerate(passages):
            batch_ids.append(p["id"])
            # Prepend the title to the passage body unless --no_title is set
            # or the passage has no title field.
            if args.no_title or "title" not in p:
                text = p["text"]
            else:
                text = p["title"] + " " + p["text"]
            if args.lowercase:
                text = text.lower()
            if args.normalize_text:
                text = robowaiter.llm_client.retrieval_lm.src.normalize_text.normalize(text)
            batch_text.append(text)

            # Encode a full batch, or the final partial batch at the end of the corpus.
            if len(batch_text) == args.per_gpu_batch_size or k == len(passages) - 1:
                encoded_batch = tokenizer.batch_encode_plus(
                    batch_text,
                    return_tensors="pt",
                    max_length=args.passage_maxlength,
                    padding=True,
                    truncation=True,
                )
                encoded_batch = {key: value.cuda() for key, value in encoded_batch.items()}
                embeddings = model(**encoded_batch)
                embeddings = embeddings.cpu()

                total += len(batch_ids)
                allids.extend(batch_ids)
                allembeddings.append(embeddings)

                batch_text = []
                batch_ids = []
                if k % 100000 == 0 and k > 0:
                    print(f"Encoded passages {total}")

    allembeddings = torch.cat(allembeddings, dim=0).numpy()
    return allids, allembeddings


def main(args):
    model, tokenizer, _ = robowaiter.llm_client.retrieval_lm.src.contriever.load_retriever(args.model_name_or_path)
    print(f"Model loaded from {args.model_name_or_path}.", flush=True)
    model.eval()
    model = model.cuda()
    if not args.no_fp16:
        model = model.half()

    passages = robowaiter.llm_client.retrieval_lm.src.data.load_passages(args.passages)

    # Select this shard's slice of the corpus; the last shard absorbs the remainder.
    shard_size = len(passages) // args.num_shards
    start_idx = args.shard_id * shard_size
    end_idx = start_idx + shard_size
    if args.shard_id == args.num_shards - 1:
        end_idx = len(passages)

    passages = passages[start_idx:end_idx]
    print(f"Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}.")

    allids, allembeddings = embed_passages(args, passages, model, tokenizer)

    save_file = os.path.join(args.output_dir, args.prefix + f"_{args.shard_id:02d}")
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Saving {len(allids)} passage embeddings to {save_file}.")
    with open(save_file, mode="wb") as f:
        pickle.dump((allids, allembeddings), f)

    print(f"Total passages processed {len(allids)}. Written to {save_file}.")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
    parser.add_argument("--output_dir", type=str, default="wikipedia_embeddings", help="dir path to save embeddings")
    parser.add_argument("--prefix", type=str, default="passages", help="prefix path to save embeddings")
    parser.add_argument("--shard_id", type=int, default=0, help="Id of the current shard")
    parser.add_argument("--num_shards", type=int, default=1, help="Total number of shards")
    parser.add_argument(
        "--per_gpu_batch_size", type=int, default=512, help="Batch size for the passage encoder forward pass"
    )
    parser.add_argument("--passage_maxlength", type=int, default=512, help="Maximum number of tokens in a passage")
    parser.add_argument(
        "--model_name_or_path", type=str, help="path to directory containing model weights and config file"
    )
    parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
    parser.add_argument("--no_title", action="store_true", help="title not added to the passage body")
    parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
    parser.add_argument("--normalize_text", action="store_true", help="normalize text before encoding")

    args = parser.parse_args()

    robowaiter.llm_client.retrieval_lm.src.slurm.init_distributed_mode(args)

    main(args)
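
# Example invocation (a sketch only: the script filename, passage file, and model
# path below are illustrative placeholders, not values shipped with this repository;
# all flags shown are defined by the argument parser above):
#
#   python generate_passage_embeddings.py \
#       --model_name_or_path /path/to/contriever_checkpoint \
#       --passages psgs_w100.tsv \
#       --output_dir wikipedia_embeddings \
#       --prefix passages \
#       --shard_id 0 --num_shards 1 \
#       --per_gpu_batch_size 512
#
# Each shard writes a pickle of (ids, embeddings) to
# <output_dir>/<prefix>_<shard_id>, e.g. wikipedia_embeddings/passages_00.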