Merge branch 'main' of https://github.com/HPCL-EI/RoboWaiter

commit 089d987c28

@@ -19,6 +19,7 @@ share/python-wheels/
MANIFEST
MO-VLN/
GLIP/
pytorch_model.bin
sub_task.ptml
@@ -15,6 +15,8 @@ pip install -e .

### Install the UI

1. Install [graphviz-9.0.0](https://gitlab.com/api/v4/projects/4207231/packages/generic/graphviz-releases/9.0.0/windows_10_cmake_Release_graphviz-install-9.0.0-win64.exe) (see the [official site](https://www.graphviz.org/download/#windows) for details).
2. Add the `bin` folder of the installation directory to the system environment. For example, on Windows with Graphviz installed at D:\Program Files (x86)\Graphviz2.38, that directory contains a `bin` folder; add that path, i.e. D:\Program Files (x86)\Graphviz2.38\bin, to the system `path` environment variable.
3. Install the vector database (a short verification sketch follows below):

conda install -c conda-forge faiss
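A minimal sanity check for step 3 (a sketch only; it assumes the `graphviz` Python bindings are installed alongside the Graphviz binaries):

```python
# Optional check that faiss and graphviz are importable; the values here are illustrative only.
import numpy as np
import faiss
from graphviz import Digraph  # assumes `pip install graphviz` (Python bindings)

index = faiss.IndexFlatIP(768)                        # 768-dim inner-product index
index.add(np.random.rand(4, 768).astype("float32"))   # add a few random vectors
print(index.ntotal)                                    # -> 4

Digraph(comment="check").node("ok")                    # run `dot -V` in a terminal to confirm the binaries are on PATH
```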

### Quick Start

1. Install UE and the Harix plugin, open the default project, and run it.
@@ -0,0 +1,25 @@
{
  "architectures": [
    "Contriever"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.15.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}
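A minimal sketch of how a configuration like this is typically consumed; the local directory name `contriever-msmarco` is an assumed example, not taken from the diff:

# Loads the BERT-based Contriever encoder described by the config above.
from transformers import AutoConfig, AutoModel, AutoTokenizer

config = AutoConfig.from_pretrained("./contriever-msmarco")       # model_type "bert", hidden_size 768
tokenizer = AutoTokenizer.from_pretrained("./contriever-msmarco")
model = AutoModel.from_pretrained("./contriever-msmarco", config=config)
print(config.hidden_size)  # -> 768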
@@ -0,0 +1 @@
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}

File diff suppressed because one or more lines are too long

@@ -0,0 +1 @@
{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "bert-base-uncased", "tokenizer_class": "BertTokenizer"}

File diff suppressed because it is too large
@@ -0,0 +1,4 @@
pip install gdown
gdown 1IYNAkwawfCDiBL27BlBqGssxFQH9vOux
unzip enwiki_2020_intro_only.zip
rm enwiki_2020_intro_only.zip
@@ -0,0 +1,731 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import datasets
|
||||
import torch
|
||||
import copy
|
||||
from functools import partial
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
from typing import Optional, Dict, Sequence
|
||||
import json
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
LlamaTokenizer,
|
||||
LlamaTokenizerFast,
|
||||
SchedulerType,
|
||||
DataCollatorForSeq2Seq,
|
||||
get_scheduler,
|
||||
GPTNeoXTokenizerFast,
|
||||
GPT2Tokenizer,
|
||||
OPTForCausalLM
|
||||
)
|
||||
from peft import LoraConfig, TaskType, get_peft_model
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
PROMPT_DICT = {
|
||||
"prompt_input": (
|
||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
),
|
||||
"prompt_no_input": (
|
||||
"### Instruction:\n{instruction}\n\n### Response:\n"
|
||||
),
|
||||
}
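# Illustration (hypothetical instruction, not from the training data):
#   PROMPT_DICT["prompt_no_input"].format_map({"instruction": "Bring a dessert to Table1."})
# produces:
#   "### Instruction:\nBring a dessert to Table1.\n\n### Response:\n"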
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the dataset to use (via the datasets library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The configuration name of the dataset to use (via the datasets library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_file", type=str, default=None, help="A csv or a json file containing the training data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
type=str,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained config name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_lora",
|
||||
action="store_true",
|
||||
help="If passed, will use LORA (low-rank parameter-efficient training) to train the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_rank",
|
||||
type=int,
|
||||
default=64,
|
||||
help="The rank of lora.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_alpha",
|
||||
type=float,
|
||||
default=16,
|
||||
help="The alpha parameter of lora.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_dropout",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="The dropout rate of lora modules.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_merged_lora_model",
|
||||
action="store_true",
|
||||
help="If passed, will merge the lora modules and save the entire model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_flash_attn",
|
||||
action="store_true",
|
||||
help="If passed, will use flash attention to train the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_slow_tokenizer",
|
||||
action="store_true",
|
||||
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
type=int,
|
||||
default=512,
|
||||
help="The maximum total sequence length (prompt+completion) of each training example.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Batch size (per device) for the training dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-5,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler_type",
|
||||
type=SchedulerType,
|
||||
default="linear",
|
||||
help="The scheduler type to use.",
|
||||
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup_ratio", type=float, default=0, help="Ratio of total training steps used for warmup."
|
||||
)
|
||||
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--preprocessing_num_workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="The number of processes to use for the preprocessing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Log the training loss and learning rate every logging_steps steps.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help="If the training should continue from a checkpoint folder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_tracking",
|
||||
action="store_true",
|
||||
help="Whether to enable experiment trackers for logging.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="all",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
|
||||
' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.'
|
||||
"Only applicable when `--with_tracking` is passed."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--low_cpu_mem_usage",
|
||||
action="store_true",
|
||||
help=(
|
||||
"It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded."
|
||||
"If passed, LLM loading time and RAM consumption will be benefited."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_special_tokens",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Use special tokens."
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Sanity checks
|
||||
if args.dataset_name is None and args.train_file is None:
|
||||
raise ValueError("Need either a dataset name or a training file.")
|
||||
else:
|
||||
if args.train_file is not None:
|
||||
extension = args.train_file.split(".")[-1]
|
||||
assert extension in ["json", "jsonl"], "`train_file` should be a json/jsonl file."
|
||||
return args
|
||||
|
||||
def _tokenize_fn(text: str, tokenizer: transformers.PreTrainedTokenizer, max_seq_length: int) -> Dict:
|
||||
"""Tokenize a list of strings."""
|
||||
input_ids = labels = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
max_length=max_seq_length,
|
||||
truncation=True,
|
||||
).input_ids
|
||||
input_ids_lens = labels_lens = input_ids.ne(tokenizer.pad_token_id).sum().item()
|
||||
print(input_ids_lens)
|
||||
|
||||
return dict(
|
||||
input_ids=input_ids,
|
||||
labels=labels,
|
||||
input_ids_lens=input_ids_lens,
|
||||
labels_lens=labels_lens,
|
||||
)
|
||||
|
||||
def encode_with_prompt_completion_format(example, tokenizer, max_seq_length, context_markups=None):
|
||||
'''
|
||||
Here we assume each example has 'prompt' and 'completion' fields.
|
||||
We concatenate prompt and completion and tokenize them together because otherwise the prompt will be padded/truncated
|
||||
and it doesn't make sense to follow directly with the completion.
|
||||
'''
|
||||
# if prompt doesn't end with space and completion doesn't start with space, add space
|
||||
|
||||
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
|
||||
source_text = prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
|
||||
target_text = example['output'] + tokenizer.eos_token
|
||||
examples_tokenized = _tokenize_fn(source_text + target_text, tokenizer, max_seq_length)
|
||||
sources_tokenized = _tokenize_fn(source_text, tokenizer, max_seq_length)
|
||||
|
||||
input_ids = examples_tokenized["input_ids"].flatten()
|
||||
source_len = sources_tokenized["input_ids_lens"]
|
||||
labels = copy.deepcopy(input_ids)
|
||||
labels[:source_len-1] = -100
|
||||
|
||||
if context_markups is not None:
|
||||
context_start = False
|
||||
for j, orig_token in enumerate(labels[source_len:]):
|
||||
if context_start is False and orig_token == context_markups[0]:
|
||||
context_start = True
|
||||
assert labels[source_len+j] == context_markups[0]
|
||||
start_idx = j+source_len
|
||||
end_idx = None
|
||||
for k, orig_token_2 in enumerate(labels[start_idx:]):
|
||||
if orig_token_2 == context_markups[1]:
|
||||
end_idx = start_idx + k
|
||||
if end_idx is None:
|
||||
end_idx = start_idx + k
|
||||
else:
|
||||
assert labels[end_idx] == context_markups[1]
|
||||
labels[start_idx+1:end_idx] = -100
|
||||
context_start = False
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
return {
|
||||
'input_ids': input_ids.flatten(),
|
||||
'labels': labels.flatten(),
|
||||
'attention_mask': attention_mask.flatten()
|
||||
}
|
||||
|
||||
def encode_with_messages_format(example, tokenizer, max_seq_length):
|
||||
'''
|
||||
Here we assume each example has a 'messages' field. Each message is a dict with 'role' and 'content' fields.
|
||||
We concatenate all messages with the roles as delimiters and tokenize them together.
|
||||
'''
|
||||
messages = example['messages']
|
||||
if len(messages) == 0:
|
||||
raise ValueError('messages field is empty.')
|
||||
|
||||
def _concat_messages(messages):
|
||||
message_text = ""
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
message_text += "<|system|>\n" + message["content"].strip() + "\n"
|
||||
elif message["role"] == "user":
|
||||
message_text += "<|user|>\n" + message["content"].strip() + "\n"
|
||||
elif message["role"] == "assistant":
|
||||
message_text += "<|assistant|>\n" + message["content"].strip() + tokenizer.eos_token + "\n"
|
||||
else:
|
||||
raise ValueError("Invalid role: {}".format(message["role"]))
|
||||
return message_text
|
||||
|
||||
example_text = _concat_messages(messages).strip()
|
||||
tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
|
||||
input_ids = tokenized_example.input_ids
|
||||
labels = input_ids.clone()
|
||||
|
||||
# mask the non-assistant part so that no loss is computed on it
|
||||
for message_idx, message in enumerate(messages):
|
||||
if message["role"] != "assistant":
|
||||
if message_idx == 0:
|
||||
message_start_idx = 0
|
||||
else:
|
||||
message_start_idx = tokenizer(
|
||||
_concat_messages(messages[:message_idx]), return_tensors='pt', max_length=max_seq_length, truncation=True
|
||||
).input_ids.shape[1]
|
||||
if message_idx < len(messages) - 1 and messages[message_idx+1]["role"] == "assistant":
|
||||
# here we also ignore the role of the assistant
|
||||
messages_so_far = _concat_messages(messages[:message_idx+1]) + "<|assistant|>\n"
|
||||
else:
|
||||
messages_so_far = _concat_messages(messages[:message_idx+1])
|
||||
message_end_idx = tokenizer(
|
||||
messages_so_far,
|
||||
return_tensors='pt',
|
||||
max_length=max_seq_length,
|
||||
truncation=True
|
||||
).input_ids.shape[1]
|
||||
labels[:, message_start_idx:message_end_idx] = -100
|
||||
|
||||
if message_end_idx >= max_seq_length:
|
||||
break
|
||||
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
return {
|
||||
'input_ids': input_ids.flatten(),
|
||||
'labels': labels.flatten(),
|
||||
'attention_mask': attention_mask.flatten(),
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# A hacky way to make llama work with flash attention
|
||||
if args.use_flash_attn:
|
||||
from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
|
||||
replace_llama_attn_with_flash_attn()
|
||||
|
||||
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
|
||||
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
|
||||
# in the environment
|
||||
accelerator_log_kwargs = {}
|
||||
|
||||
if args.with_tracking:
|
||||
accelerator_log_kwargs["log_with"] = args.report_to
|
||||
accelerator_log_kwargs["project_dir"] = args.output_dir
|
||||
|
||||
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
raw_datasets = load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
dataset_args = {}
|
||||
if args.train_file is not None:
|
||||
data_files["train"] = args.train_file
|
||||
raw_datasets = load_dataset(
|
||||
"json",
|
||||
data_files=data_files,
|
||||
**dataset_args,
|
||||
)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
if args.config_name:
|
||||
config = AutoConfig.from_pretrained(args.config_name)
|
||||
elif args.model_name_or_path:
|
||||
config = AutoConfig.from_pretrained(args.model_name_or_path)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new config instance from scratch. This is not supported by this script."
|
||||
)
|
||||
|
||||
if args.tokenizer_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
|
||||
elif args.model_name_or_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
||||
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
||||
)
|
||||
|
||||
if args.model_name_or_path:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
)
|
||||
else:
|
||||
logger.info("Training new model from scratch")
|
||||
model = AutoModelForCausalLM.from_config(config)
|
||||
|
||||
|
||||
# no default pad token for llama!
|
||||
# here we add all special tokens again, because the default ones are not in the special_tokens_map
|
||||
if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast):
|
||||
if args.use_special_tokens is True:
|
||||
special_token_dict = {"additional_special_tokens": ["[No Retrieval]", "[Retrieval]", "[Continue to Use Evidence]", "[Irrelevant]", "[Relevant]", "<paragraph>", "</paragraph>", "[Utility:1]", "[Utility:2]", "[Utility:3]", "[Utility:4]", "[Utility:5]", "[Fully supported]", "[Partially supported]", "[No support / Contradictory]"]}
|
||||
special_token_dict["bos_token"] = "<s>"
|
||||
special_token_dict["eos_token"] = "</s>"
|
||||
special_token_dict["unk_token"] = "<unk>"
|
||||
special_token_dict["pad_token"] = "<pad>"
|
||||
num_added_tokens = tokenizer.add_special_tokens(special_token_dict)
|
||||
|
||||
context_markups = []
|
||||
for token in ["<paragraph>", "</paragraph>"]:
|
||||
context_markups.append(tokenizer.convert_tokens_to_ids(token))
|
||||
if args.use_special_tokens is False:
|
||||
assert num_added_tokens in [0, 1], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present."
|
||||
else:
|
||||
assert num_added_tokens > 10, "special tokens must be added to the original tokenizer."
|
||||
elif isinstance(tokenizer, GPTNeoXTokenizerFast):
|
||||
num_added_tokens = tokenizer.add_special_tokens({
|
||||
"pad_token": "<pad>",
|
||||
})
|
||||
assert num_added_tokens == 1, "GPTNeoXTokenizer should only add one special token - the pad_token."
|
||||
elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM):
|
||||
num_added_tokens = tokenizer.add_special_tokens({'unk_token': '<unk>'})
|
||||
|
||||
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
|
||||
# on a small vocab and want a smaller embedding size, remove this test.
|
||||
embedding_size = model.get_input_embeddings().weight.shape[0]
|
||||
if len(tokenizer) > embedding_size:
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
if args.use_lora:
|
||||
logger.info("Initializing LORA model...")
|
||||
modules_to_save = ["embed_tokens"]
|
||||
peft_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
r=args.lora_rank,
|
||||
#modules_to_save=modules_to_save,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout
|
||||
)
|
||||
model = get_peft_model(model, peft_config)
|
||||
model.print_trainable_parameters()
|
||||
|
||||
|
||||
encode_function = partial(
|
||||
encode_with_prompt_completion_format,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
context_markups=context_markups if args.use_special_tokens is True else None
|
||||
)
|
||||
# elif "messages" in raw_datasets["train"].column_names:
|
||||
# encode_function = partial(
|
||||
# encode_with_messages_format,
|
||||
# tokenizer=tokenizer,
|
||||
# max_seq_length=args.max_seq_length,
|
||||
# )
|
||||
with accelerator.main_process_first():
|
||||
lm_datasets = raw_datasets.map(
|
||||
encode_function,
|
||||
batched=False,
|
||||
num_proc=args.preprocessing_num_workers,
|
||||
load_from_cache_file=not args.overwrite_cache,
|
||||
remove_columns=[name for name in raw_datasets["train"].column_names if name not in ["input_ids", "labels", "attention_mask"]],
|
||||
desc="Tokenizing and reformatting instruction data",
|
||||
)
|
||||
lm_datasets.set_format(type="pt")
|
||||
lm_datasets = lm_datasets.filter(lambda example: (example['labels'] != -100).any())
|
||||
|
||||
train_dataset = lm_datasets["train"]
|
||||
#print(train_dataset[0])
|
||||
#print(train_dataset[1000])
|
||||
#print(train_dataset[500])
|
||||
#print(train_dataset[2000])
|
||||
#print(train_dataset[10000])
|
||||
with open("processed.json", "w") as outfile:
|
||||
new_data = []
|
||||
for item in train_dataset:
|
||||
print(item)
|
||||
labels = [int(i) for i in item["labels"]]
|
||||
input_ids = [int(i) for i in item["input_ids"]]
|
||||
new_data.append({"labels": labels, "input_ids": input_ids})
|
||||
json.dump(new_data, outfile)
|
||||
# Log a few random samples from the training set:
|
||||
for index in random.sample(range(len(train_dataset)), 3):
|
||||
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||
|
||||
# DataLoaders creation:
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
shuffle=True,
|
||||
collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"),
|
||||
batch_size=args.per_device_train_batch_size
|
||||
)
|
||||
|
||||
# Optimizer
|
||||
# Split weights in two groups, one with weight decay and the other not.
|
||||
no_decay = ["bias", "layer_norm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
# Create the learning rate scheduler.
|
||||
# Note: the current accelerator.step() calls the .step() of the real scheduler for the `num_processes` times. This is because they assume
|
||||
# the user initializes the scheduler with the entire training set. In the case of data parallel training, each process only
|
||||
# sees a subset (1/num_processes) of the training set. So each process needs to step the lr multiple times so that the total
|
||||
# number of updates in the end matches the num_training_steps here.
|
||||
# Here we need to set the num_training_steps to either using the entire training set (when epochs is specified) or we need to multiply the
|
||||
# num_training_steps by num_processes so that the total number of updates matches the num_training_steps.
|
||||
num_training_steps_for_scheduler = args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes
|
||||
lr_scheduler = get_scheduler(
|
||||
name=args.lr_scheduler_type,
|
||||
optimizer=optimizer,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio),
|
||||
)
|
||||
|
||||
# Prepare everything with `accelerator`.
|
||||
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# Figure out how many steps we should save the Accelerator states
|
||||
checkpointing_steps = args.checkpointing_steps
|
||||
if checkpointing_steps is not None and checkpointing_steps.isdigit():
|
||||
checkpointing_steps = int(checkpointing_steps)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initialize automatically on the main process.
|
||||
if args.with_tracking:
|
||||
experiment_config = vars(args)
|
||||
# TensorBoard cannot log Enums, need the raw value
|
||||
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
|
||||
accelerator.init_trackers("open_instruct", experiment_config)
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
completed_steps = 0
|
||||
starting_epoch = 0
|
||||
|
||||
# Potentially load in the weights and states from a previous save
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
|
||||
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
|
||||
accelerator.load_state(args.resume_from_checkpoint)
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
|
||||
dirs.sort(key=os.path.getctime)
|
||||
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
|
||||
# Extract `epoch_{i}` or `step_{i}`
|
||||
training_difference = os.path.splitext(path)[0]
|
||||
|
||||
if "epoch" in training_difference:
|
||||
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
|
||||
resume_step = None
|
||||
else:
|
||||
# need to multiply `gradient_accumulation_steps` to reflect real steps
|
||||
resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
|
||||
starting_epoch = resume_step // len(train_dataloader)
|
||||
resume_step -= starting_epoch * len(train_dataloader)
|
||||
|
||||
# update the progress_bar if load from checkpoint
|
||||
progress_bar.update(starting_epoch * num_update_steps_per_epoch)
|
||||
completed_steps = starting_epoch * num_update_steps_per_epoch
|
||||
|
||||
for epoch in range(starting_epoch, args.num_train_epochs):
|
||||
model.train()
|
||||
total_loss = 0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# We need to skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == starting_epoch:
|
||||
if resume_step is not None and completed_steps < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
completed_steps += 1
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(model):
|
||||
outputs = model(**batch, use_cache=False)
|
||||
loss = outputs.loss
|
||||
# We keep track of the loss at each logged step
|
||||
total_loss += loss.detach().float()
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
lr_scheduler.step()
|
||||
|
||||
# # Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
completed_steps += 1
|
||||
|
||||
if args.logging_steps and completed_steps % args.logging_steps == 0:
|
||||
avg_loss = accelerator.gather(total_loss).mean().item() / args.gradient_accumulation_steps / args.logging_steps
|
||||
logger.info(f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}")
|
||||
if args.with_tracking:
|
||||
accelerator.log(
|
||||
{
|
||||
"learning_rate": lr_scheduler.get_last_lr()[0],
|
||||
"train_loss": avg_loss,
|
||||
},
|
||||
step=completed_steps,
|
||||
)
|
||||
total_loss = 0
|
||||
|
||||
if isinstance(checkpointing_steps, int):
|
||||
if completed_steps % checkpointing_steps == 0:
|
||||
output_dir = f"step_{completed_steps}"
|
||||
if args.output_dir is not None:
|
||||
output_dir = os.path.join(args.output_dir, output_dir)
|
||||
accelerator.save_state(output_dir)
|
||||
if completed_steps >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.checkpointing_steps == "epoch":
|
||||
output_dir = f"epoch_{epoch}"
|
||||
if args.output_dir is not None:
|
||||
output_dir = os.path.join(args.output_dir, output_dir)
|
||||
accelerator.save_state(output_dir)
|
||||
|
||||
if args.with_tracking:
|
||||
accelerator.end_training()
|
||||
|
||||
if args.output_dir is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
# When doing multi-gpu training, we need to use accelerator.get_state_dict(model) to get the state_dict.
|
||||
# Otherwise, sometimes the model will be saved with only part of the parameters.
|
||||
# Also, accelerator needs to use the wrapped model to get the state_dict.
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
if args.use_lora:
|
||||
# When using lora, the unwrapped model is a PeftModel, which doesn't support the is_main_process
|
||||
# and has its own save_pretrained function for only saving lora modules.
|
||||
# We have to manually specify the is_main_process outside the save_pretrained function.
|
||||
if accelerator.is_main_process:
|
||||
unwrapped_model.save_pretrained(args.output_dir, state_dict=state_dict)
|
||||
else:
|
||||
unwrapped_model.save_pretrained(
|
||||
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=state_dict
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@@ -0,0 +1,115 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
import argparse
|
||||
import pickle
|
||||
|
||||
import torch

# Imports needed for the fully qualified calls used below; these module paths are
# taken from the calls themselves (robowaiter.llm_client.retrieval_lm.src.*).
import robowaiter.llm_client.retrieval_lm.src.contriever
import robowaiter.llm_client.retrieval_lm.src.data
import robowaiter.llm_client.retrieval_lm.src.normalize_text
import robowaiter.llm_client.retrieval_lm.src.slurm
|
||||
|
||||
|
||||
def embed_passages(args, passages, model, tokenizer):
|
||||
total = 0
|
||||
allids, allembeddings = [], []
|
||||
batch_ids, batch_text = [], []
|
||||
with torch.no_grad():
|
||||
for k, p in enumerate(passages):
|
||||
batch_ids.append(p["id"])
|
||||
|
||||
"""if args.no_title or not "title" in p:
|
||||
text = p["text"]
|
||||
else:
|
||||
text = p["title"] + " " + p["text"]"""
|
||||
text = p["title"]
|
||||
if args.lowercase:
|
||||
text = text.lower()
|
||||
if args.normalize_text:
|
||||
text = robowaiter.llm_client.retrieval_lm.src.normalize_text.normalize(text)
|
||||
batch_text.append(text)
|
||||
|
||||
if len(batch_text) == args.per_gpu_batch_size or k == len(passages) - 1:
|
||||
|
||||
encoded_batch = tokenizer.batch_encode_plus(
|
||||
batch_text,
|
||||
return_tensors="pt",
|
||||
max_length=args.passage_maxlength,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
|
||||
embeddings = model(**encoded_batch)
|
||||
|
||||
embeddings = embeddings.cpu()
|
||||
total += len(batch_ids)
|
||||
allids.extend(batch_ids)
|
||||
allembeddings.append(embeddings)
|
||||
|
||||
batch_text = []
|
||||
batch_ids = []
|
||||
if k % 100000 == 0 and k > 0:
|
||||
print(f"Encoded passages {total}")
|
||||
|
||||
allembeddings = torch.cat(allembeddings, dim=0).numpy()
|
||||
return allids, allembeddings
|
||||
|
||||
|
||||
def main(args):
|
||||
model, tokenizer, _ = robowaiter.llm_client.retrieval_lm.src.contriever.load_retriever(args.model_name_or_path)
|
||||
print(f"Model loaded from {args.model_name_or_path}.", flush=True)
|
||||
model.eval()
|
||||
model = model.cuda()
|
||||
if not args.no_fp16:
|
||||
model = model.half()
|
||||
|
||||
passages = robowaiter.llm_client.retrieval_lm.src.data.load_passages(args.passages)
|
||||
|
||||
shard_size = len(passages) // args.num_shards
|
||||
start_idx = args.shard_id * shard_size
|
||||
end_idx = start_idx + shard_size
|
||||
if args.shard_id == args.num_shards - 1:
|
||||
end_idx = len(passages)
|
||||
|
||||
passages = passages[start_idx:end_idx]
|
||||
print(f"Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}.")
|
||||
|
||||
allids, allembeddings = embed_passages(args, passages, model, tokenizer)
|
||||
|
||||
save_file = os.path.join(args.output_dir, args.prefix + f"_{args.shard_id:02d}")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
print(f"Saving {len(allids)} passage embeddings to {save_file}.")
|
||||
with open(save_file, mode="wb") as f:
|
||||
pickle.dump((allids, allembeddings), f)
|
||||
|
||||
print(f"Total passages processed {len(allids)}. Written to {save_file}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
|
||||
parser.add_argument("--output_dir", type=str, default="wikipedia_embeddings", help="dir path to save embeddings")
|
||||
parser.add_argument("--prefix", type=str, default="passages", help="prefix path to save embeddings")
|
||||
parser.add_argument("--shard_id", type=int, default=0, help="Id of the current shard")
|
||||
parser.add_argument("--num_shards", type=int, default=1, help="Total number of shards")
|
||||
parser.add_argument(
|
||||
"--per_gpu_batch_size", type=int, default=512, help="Batch size for the passage encoder forward pass"
|
||||
)
|
||||
parser.add_argument("--passage_maxlength", type=int, default=512, help="Maximum number of tokens in a passage")
|
||||
parser.add_argument(
|
||||
"--model_name_or_path", type=str, help="path to directory containing model weights and config file"
|
||||
)
|
||||
parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
|
||||
parser.add_argument("--no_title", action="store_true", help="title not added to the passage body")
|
||||
parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
|
||||
parser.add_argument("--normalize_text", action="store_true", help="lowercase text before encoding")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
robowaiter.llm_client.retrieval_lm.src.slurm.init_distributed_mode(args)
|
||||
|
||||
main(args)
|
|
@@ -0,0 +1,119 @@
|
|||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import transformers
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
|
||||
except ImportError:
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
||||
|
||||
from flash_attn.bert_padding import unpad_input, pad_input
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel
|
||||
|
||||
attention_mask: [bsz, q_len]
|
||||
"""
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = (
|
||||
self.q_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
key_states = (
|
||||
self.k_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
# [bsz, q_len, nh, hd]
|
||||
# [bsz, nh, q_len, hd]
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
assert past_key_value is None, "past_key_value is not supported"
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
assert not output_attentions, "output_attentions is not supported"
|
||||
assert not use_cache, "use_cache is not supported"
|
||||
|
||||
# Flash attention codes from
|
||||
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
|
||||
|
||||
# transform the data into the format required by flash attention
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||
# the attention_mask should be the same as the key_padding_mask
|
||||
key_padding_mask = attention_mask
|
||||
|
||||
if key_padding_mask is None:
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
max_s = q_len
|
||||
cu_q_lens = torch.arange(
|
||||
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
|
||||
)
|
||||
output = flash_attn_unpadded_qkvpacked_func(
|
||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
else:
|
||||
nheads = qkv.shape[-2]
|
||||
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
||||
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
||||
x_unpad = rearrange(
|
||||
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
|
||||
)
|
||||
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
||||
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(
|
||||
pad_input(
|
||||
rearrange(output_unpad,
|
||||
"nnz h d -> nnz (h d)"), indices, bsz, q_len
|
||||
),
|
||||
"b s (h d) -> b s h d",
|
||||
h=nheads,
|
||||
)
|
||||
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
):
|
||||
# [bsz, seq_len]
|
||||
return attention_mask
|
||||
|
||||
|
||||
def replace_llama_attn_with_flash_attn():
|
||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
|
||||
_prepare_decoder_attention_mask
|
||||
)
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
|
@@ -0,0 +1,87 @@
|
|||
import numpy as np
|
||||
import string
|
||||
import re
|
||||
from collections import Counter
|
||||
|
||||
|
||||
|
||||
def exact_match_score(prediction, ground_truth):
|
||||
return (normalize_answer(prediction) == normalize_answer(ground_truth))
|
||||
|
||||
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
||||
scores_for_ground_truths = []
|
||||
for ground_truth in ground_truths:
|
||||
score = metric_fn(prediction, ground_truth)
|
||||
scores_for_ground_truths.append(score)
|
||||
return max(scores_for_ground_truths)
|
||||
|
||||
def accuracy(preds, labels):
|
||||
match_count = 0
|
||||
for pred, label in zip(preds, labels):
|
||||
target = label[0]
|
||||
if pred == target:
|
||||
match_count += 1
|
||||
|
||||
return 100 * (match_count / len(preds))
|
||||
|
||||
|
||||
def f1(decoded_preds, decoded_labels):
|
||||
f1_all = []
|
||||
for prediction, answers in zip(decoded_preds, decoded_labels):
|
||||
if type(answers) == list:
|
||||
if len(answers) == 0:
|
||||
return 0
|
||||
f1_all.append(np.max([qa_f1_score(prediction, gt)
|
||||
for gt in answers]))
|
||||
else:
|
||||
f1_all.append(qa_f1_score(prediction, answers))
|
||||
return 100 * np.mean(f1_all)
|
||||
|
||||
|
||||
def qa_f1_score(prediction, ground_truth):
|
||||
prediction_tokens = normalize_answer(prediction).split()
|
||||
ground_truth_tokens = normalize_answer(ground_truth).split()
|
||||
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
|
||||
num_same = sum(common.values())
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(prediction_tokens)
|
||||
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
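# Worked example (made-up strings): qa_f1_score("put the yogurt on the coffee table",
# "yogurt on the coffee table") -> normalize_answer() drops the articles, leaving 5 vs. 4
# tokens with 4 shared, so precision = 4/5, recall = 4/4, and F1 = 2*0.8*1.0/1.8 ≈ 0.89.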
|
||||
|
||||
|
||||
def normalize_answer(s):
|
||||
def remove_articles(text):
|
||||
return re.sub(r'\b(a|an|the)\b', ' ', text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return ' '.join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return ''.join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
def find_entity_tags(sentence):
|
||||
entity_regex = r'(.+?)(?=\s<|$)'
|
||||
tag_regex = r'<(.+?)>'
|
||||
entity_names = re.findall(entity_regex, sentence)
|
||||
tags = re.findall(tag_regex, sentence)
|
||||
|
||||
results = {}
|
||||
for entity, tag in zip(entity_names, tags):
|
||||
if "<" in entity:
|
||||
results[entity.split("> ")[1]] = tag
|
||||
else:
|
||||
results[entity] = tag
|
||||
return results
|
||||
|
||||
def match(prediction, ground_truth):
|
||||
for gt in ground_truth:
|
||||
if gt in prediction:
|
||||
return 1
|
||||
return 0
|
|
@@ -0,0 +1,8 @@
export CUDA_VISIBLE_DEVICES=0
python3 ../generate_passage_embeddings.py \
    --model_name_or_path ../../model/contriever-msmarco \
    --passages train_robot.jsonl \
    --output_dir robot_embeddings \
    --shard_id 0 \
    --num_shards 1 \
    --per_gpu_batch_size 500
File diff suppressed because it is too large
@@ -0,0 +1,250 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
import time
|
||||
import glob
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import src.index
|
||||
import src.contriever
|
||||
import src.utils
|
||||
import src.slurm
|
||||
import src.data
|
||||
from src.evaluation import calculate_matches
|
||||
import src.normalize_text
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
|
||||
def embed_queries(args, queries, model, tokenizer):
|
||||
model.eval()
|
||||
embeddings, batch_question = [], []
|
||||
with torch.no_grad():
|
||||
|
||||
for k, q in enumerate(queries):
|
||||
if args.lowercase:
|
||||
q = q.lower()
|
||||
if args.normalize_text:
|
||||
q = src.normalize_text.normalize(q)
|
||||
batch_question.append(q)
|
||||
|
||||
if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1:
|
||||
|
||||
encoded_batch = tokenizer.batch_encode_plus(
|
||||
batch_question,
|
||||
return_tensors="pt",
|
||||
max_length=args.question_maxlength,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
)
|
||||
encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
|
||||
output = model(**encoded_batch)
|
||||
embeddings.append(output.cpu())
|
||||
|
||||
batch_question = []
|
||||
|
||||
embeddings = torch.cat(embeddings, dim=0)
|
||||
print(f"Questions embeddings shape: {embeddings.size()}")
|
||||
|
||||
return embeddings.numpy()
|
||||
|
||||
|
||||
def index_encoded_data(index, embedding_files, indexing_batch_size):
|
||||
allids = []
|
||||
allembeddings = np.array([])
|
||||
for i, file_path in enumerate(embedding_files):
|
||||
print(f"Loading file {file_path}")
|
||||
with open(file_path, "rb") as fin:
|
||||
ids, embeddings = pickle.load(fin)
|
||||
|
||||
allembeddings = np.vstack((allembeddings, embeddings)) if allembeddings.size else embeddings
|
||||
allids.extend(ids)
|
||||
while allembeddings.shape[0] > indexing_batch_size:
|
||||
allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
|
||||
|
||||
while allembeddings.shape[0] > 0:
|
||||
allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
|
||||
|
||||
print("Data indexing completed.")
|
||||
|
||||
|
||||
def add_embeddings(index, embeddings, ids, indexing_batch_size):
|
||||
end_idx = min(indexing_batch_size, embeddings.shape[0])
|
||||
ids_toadd = ids[:end_idx]
|
||||
embeddings_toadd = embeddings[:end_idx]
|
||||
ids = ids[end_idx:]
|
||||
embeddings = embeddings[end_idx:]
|
||||
index.index_data(ids_toadd, embeddings_toadd)
|
||||
return embeddings, ids
|
||||
|
||||
|
||||
def validate(data, workers_num):
|
||||
match_stats = calculate_matches(data, workers_num)
|
||||
top_k_hits = match_stats.top_k_hits
|
||||
|
||||
print("Validation results: top k documents hits %s", top_k_hits)
|
||||
top_k_hits = [v / len(data) for v in top_k_hits]
|
||||
message = ""
|
||||
for k in [5, 10, 20, 100]:
|
||||
if k <= len(top_k_hits):
|
||||
message += f"R@{k}: {top_k_hits[k-1]} "
|
||||
print(message)
|
||||
return match_stats.questions_doc_hits
|
||||
|
||||
|
||||
def add_passages(data, passages, top_passages_and_scores):
|
||||
# add passages to original data
|
||||
merged_data = []
|
||||
assert len(data) == len(top_passages_and_scores)
|
||||
for i, d in enumerate(data):
|
||||
results_and_scores = top_passages_and_scores[i]
|
||||
#print(passages[2393])
|
||||
docs = [passages[int(doc_id)] for doc_id in results_and_scores[0]]
|
||||
scores = [str(score) for score in results_and_scores[1]]
|
||||
ctxs_num = len(docs)
|
||||
d["ctxs"] = [
|
||||
{
|
||||
"id": results_and_scores[0][c],
|
||||
"title": docs[c]["title"],
|
||||
"text": docs[c]["text"],
|
||||
"score": scores[c],
|
||||
}
|
||||
for c in range(ctxs_num)
|
||||
]
|
||||
|
||||
|
||||
def add_hasanswer(data, hasanswer):
|
||||
# add hasanswer to data
|
||||
for i, ex in enumerate(data):
|
||||
for k, d in enumerate(ex["ctxs"]):
|
||||
d["hasanswer"] = hasanswer[i][k]
|
||||
|
||||
|
||||
def load_data(data_path):
|
||||
if data_path.endswith(".json"):
|
||||
with open(data_path, "r") as fin:
|
||||
data = json.load(fin)
|
||||
elif data_path.endswith(".jsonl"):
|
||||
data = []
|
||||
with open(data_path, "r") as fin:
|
||||
for k, example in enumerate(fin):
|
||||
example = json.loads(example)
|
||||
data.append(example)
|
||||
return data
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
print(f"Loading model from: {args.model_name_or_path}")
|
||||
model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path)
|
||||
model.eval()
|
||||
model = model.cuda()
|
||||
if not args.no_fp16:
|
||||
model = model.half()
|
||||
|
||||
index = src.index.Indexer(args.projection_size, args.n_subquantizers, args.n_bits)
|
||||
|
||||
# index all passages
|
||||
input_paths = glob.glob(args.passages_embeddings)
|
||||
input_paths = sorted(input_paths)
|
||||
embeddings_dir = os.path.dirname(input_paths[0])
|
||||
index_path = os.path.join(embeddings_dir, "index.faiss")
|
||||
if args.save_or_load_index and os.path.exists(index_path):
|
||||
index.deserialize_from(embeddings_dir)
|
||||
else:
|
||||
print(f"Indexing passages from files {input_paths}")
|
||||
start_time_indexing = time.time()
|
||||
index_encoded_data(index, input_paths, args.indexing_batch_size)
|
||||
print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.")
|
||||
if args.save_or_load_index:
|
||||
index.serialize(embeddings_dir)
|
||||
|
||||
# load passages
|
||||
passages = src.data.load_passages(args.passages)
|
||||
passage_id_map = {x["id"]: x for x in passages}
|
||||
|
||||
data_paths = glob.glob(args.data)
|
||||
alldata = []
|
||||
for path in data_paths:
|
||||
data = load_data(path)
|
||||
output_path = os.path.join(args.output_dir, os.path.basename(path))
|
||||
|
||||
queries = [ex["question"] for ex in data]
|
||||
questions_embedding = embed_queries(args, queries, model, tokenizer)
|
||||
|
||||
# get top k results
|
||||
start_time_retrieval = time.time()
|
||||
top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs)
|
||||
print(f"Search time: {time.time()-start_time_retrieval:.1f} s.")
|
||||
|
||||
add_passages(data, passage_id_map, top_ids_and_scores)
|
||||
#hasanswer = validate(data, args.validation_workers)
|
||||
#add_hasanswer(data, hasanswer)
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
with open(output_path, "w") as fout:
|
||||
for ex in data:
|
||||
json.dump(ex, fout, ensure_ascii=False)
|
||||
fout.write("\n")
|
||||
print(f"Saved results to {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--data",
|
||||
required=True,
|
||||
type=str,
|
||||
default=None,
|
||||
help=".json file containing question and answers, similar format to reader data",
|
||||
)
|
||||
parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
|
||||
parser.add_argument("--passages_embeddings", type=str, default=None, help="Glob path to encoded passages")
|
||||
parser.add_argument(
|
||||
"--output_dir", type=str, default=None, help="Results are written to outputdir with data suffix"
|
||||
)
|
||||
parser.add_argument("--n_docs", type=int, default=100, help="Number of documents to retrieve per questions")
|
||||
parser.add_argument(
|
||||
"--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
|
||||
)
|
||||
parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding")
|
||||
parser.add_argument(
|
||||
"--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path", type=str, help="path to directory containing model weights and config file"
|
||||
)
|
||||
parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
|
||||
parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
|
||||
parser.add_argument(
|
||||
"--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed"
|
||||
)
|
||||
parser.add_argument("--projection_size", type=int, default=768)
|
||||
parser.add_argument(
|
||||
"--n_subquantizers",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of subquantizer used for vector quantization, if 0 flat index is used",
|
||||
)
|
||||
parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer")
|
||||
parser.add_argument("--lang", nargs="+")
|
||||
parser.add_argument("--dataset", type=str, default="none")
|
||||
parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
|
||||
parser.add_argument("--normalize_text", action="store_true", help="normalize text")
|
||||
|
||||
args = parser.parse_args()
|
||||
src.slurm.init_distributed_mode(args)
|
||||
main(args)
|
|
@@ -0,0 +1,41 @@
import json
import jsonlines
import argparse

def train(args):
    # Convert raw passage lines into the train_robot.jsonl format used for embedding generation.
    filename = args.passages
    with open(filename, 'r', encoding="utf-8") as f:
        k = 0
        for line in f:
            data = json.loads(line)
            entry = {"id": k, 'title': data['title'], 'text': data['text']}
            k += 1
            with jsonlines.open("train_robot.jsonl", "a") as file_jsonl:
                file_jsonl.write(entry)

def test(args):
    # Take the first 1000 passages and rewrite them as question/answer pairs in test_robot.jsonl.
    filename = args.passages
    with open(filename, 'r', encoding="utf-8") as f:
        k = 0
        for line in f:
            if k < 1000:
                data = json.loads(line)
                entry = {"id": data['id'], 'question': data['title'], 'answers': data['text']}
                k += 1
                with jsonlines.open("test_robot.jsonl", "a") as file_jsonl:
                    file_jsonl.write(entry)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--passages", type=str, default=None, help="Path to passages")
    parser.add_argument("--mode", type=str, default=None, help="train or test")

    args = parser.parse_args()

    if args.mode == 'train':
        train(args)
    elif args.mode == 'test':
        test(args)
    else:
        print("Unknown mode! Use --mode train or --mode test.")
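For reference, a sketch of the line formats this script reads and writes; the field names come from the code above, and the example values are borrowed from the retrieved-results file shown later in this diff:

# Input line in the raw passages file (one JSON object per line):
#   {"id": 5, "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)"}
# Line appended to train_robot.jsonl by train():
#   {"id": 0, "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)"}
# Line appended to test_robot.jsonl by test():
#   {"id": 5, "question": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "answers": "On(Yogurt,Table2),Is(TubeLight,Off)"}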
Binary file not shown.
|
@@ -0,0 +1,2 @@
|
|||
{"id": 0, "question": "请把酸奶放在咖啡台上,并打开窗帘。", "ctxs": [{"id": "0", "title": "请把酸奶放在咖啡台上,并打开窗帘。", "text": "On(Yogurt,CoffeeTable),Is(Curtain,Open)", "score": "1.9694625"}, {"id": "1", "title": "可以把牛奶饮料放在2号桌子上吗?还有关掉灯光。", "text": "On(MilkDrink,Table2),Is(TubeLight,Off)", "score": "1.8284101"}, {"id": "2", "title": "你好,可以给我上一份甜点吗?", "text": "On(Dessert,Table1)", "score": "1.4835652"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "1.4412252"}, {"id": "4", "title": "可以送一瓶牛奶饮料到1号桌吗?", "text": "On(MilkDrink,Table1)", "score": "1.2867957"}, {"id": "3", "title": "你能到另一个吧台这边来吗?空调可以关掉吗?", "text": "At(Robot,Bar2),Is(AC,On)", "score": "1.2599907"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}]}
|
||||
{"id": 1, "question": "可以把牛奶饮料放在2号桌子上吗?还有关掉灯光。", "ctxs": [{"id": "1", "title": "可以把牛奶饮料放在2号桌子上吗?还有关掉灯光。", "text": "On(MilkDrink,Table2),Is(TubeLight,Off)", "score": "2.138029"}, {"id": "0", "title": "请把酸奶放在咖啡台上,并打开窗帘。", "text": "On(Yogurt,CoffeeTable),Is(Curtain,Open)", "score": "1.8282425"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "1.6972268"}, {"id": "2", "title": "你好,可以给我上一份甜点吗?", "text": "On(Dessert,Table1)", "score": "1.4741647"}, {"id": "4", "title": "可以送一瓶牛奶饮料到1号桌吗?", "text": "On(MilkDrink,Table1)", "score": "1.4532053"}, {"id": "3", "title": "你能到另一个吧台这边来吗?空调可以关掉吗?", "text": "At(Robot,Bar2),Is(AC,On)", "score": "1.3438905"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}]}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@@ -0,0 +1,208 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
import os
import glob
|
||||
from collections import defaultdict
|
||||
from typing import List, Dict
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import beir.util
|
||||
from beir.datasets.data_loader import GenericDataLoader
|
||||
from beir.retrieval.evaluation import EvaluateRetrieval
|
||||
from beir.retrieval.search.dense import DenseRetrievalExactSearch
|
||||
|
||||
from beir.reranking.models import CrossEncoder
|
||||
from beir.reranking import Rerank
|
||||
|
||||
import src.dist_utils as dist_utils
|
||||
from src import normalize_text
|
||||
|
||||
|
||||
class DenseEncoderModel:
|
||||
def __init__(
|
||||
self,
|
||||
query_encoder,
|
||||
doc_encoder=None,
|
||||
tokenizer=None,
|
||||
max_length=512,
|
||||
add_special_tokens=True,
|
||||
norm_query=False,
|
||||
norm_doc=False,
|
||||
lower_case=False,
|
||||
normalize_text=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.query_encoder = query_encoder
|
||||
self.doc_encoder = doc_encoder
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.add_special_tokens = add_special_tokens
|
||||
self.norm_query = norm_query
|
||||
self.norm_doc = norm_doc
|
||||
self.lower_case = lower_case
|
||||
self.normalize_text = normalize_text
|
||||
|
||||
def encode_queries(self, queries: List[str], batch_size: int, **kwargs) -> np.ndarray:
|
||||
|
||||
if dist.is_initialized():
|
||||
idx = np.array_split(range(len(queries)), dist.get_world_size())[dist.get_rank()]
|
||||
else:
|
||||
idx = range(len(queries))
|
||||
|
||||
queries = [queries[i] for i in idx]
|
||||
if self.normalize_text:
|
||||
queries = [normalize_text.normalize(q) for q in queries]
|
||||
if self.lower_case:
|
||||
queries = [q.lower() for q in queries]
|
||||
|
||||
allemb = []
|
||||
nbatch = (len(queries) - 1) // batch_size + 1
|
||||
with torch.no_grad():
|
||||
for k in range(nbatch):
|
||||
start_idx = k * batch_size
|
||||
end_idx = min((k + 1) * batch_size, len(queries))
|
||||
|
||||
qencode = self.tokenizer.batch_encode_plus(
|
||||
queries[start_idx:end_idx],
|
||||
max_length=self.max_length,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
return_tensors="pt",
|
||||
)
|
||||
qencode = {key: value.cuda() for key, value in qencode.items()}
|
||||
emb = self.query_encoder(**qencode, normalize=self.norm_query)
|
||||
allemb.append(emb.cpu())
|
||||
|
||||
allemb = torch.cat(allemb, dim=0)
|
||||
allemb = allemb.cuda()
|
||||
if dist.is_initialized():
|
||||
allemb = dist_utils.varsize_gather_nograd(allemb)
|
||||
allemb = allemb.cpu().numpy()
|
||||
return allemb
|
||||
|
||||
def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs):
|
||||
|
||||
if dist.is_initialized():
|
||||
idx = np.array_split(range(len(corpus)), dist.get_world_size())[dist.get_rank()]
|
||||
else:
|
||||
idx = range(len(corpus))
|
||||
corpus = [corpus[i] for i in idx]
|
||||
corpus = [c["title"] + " " + c["text"] if len(c["title"]) > 0 else c["text"] for c in corpus]
|
||||
if self.normalize_text:
|
||||
corpus = [normalize_text.normalize(c) for c in corpus]
|
||||
if self.lower_case:
|
||||
corpus = [c.lower() for c in corpus]
|
||||
|
||||
allemb = []
|
||||
nbatch = (len(corpus) - 1) // batch_size + 1
|
||||
with torch.no_grad():
|
||||
for k in range(nbatch):
|
||||
start_idx = k * batch_size
|
||||
end_idx = min((k + 1) * batch_size, len(corpus))
|
||||
|
||||
cencode = self.tokenizer.batch_encode_plus(
|
||||
corpus[start_idx:end_idx],
|
||||
max_length=self.max_length,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
return_tensors="pt",
|
||||
)
|
||||
cencode = {key: value.cuda() for key, value in cencode.items()}
|
||||
emb = self.doc_encoder(**cencode, normalize=self.norm_doc)
|
||||
allemb.append(emb.cpu())
|
||||
|
||||
allemb = torch.cat(allemb, dim=0)
|
||||
allemb = allemb.cuda()
|
||||
if dist.is_initialized():
|
||||
allemb = dist_utils.varsize_gather_nograd(allemb)
|
||||
allemb = allemb.cpu().numpy()
|
||||
return allemb
|
||||
|
||||
|
||||
def evaluate_model(
|
||||
query_encoder,
|
||||
doc_encoder,
|
||||
tokenizer,
|
||||
dataset,
|
||||
batch_size=128,
|
||||
add_special_tokens=True,
|
||||
norm_query=False,
|
||||
norm_doc=False,
|
||||
is_main=True,
|
||||
split="test",
|
||||
score_function="dot",
|
||||
beir_dir="BEIR/datasets",
|
||||
save_results_path=None,
|
||||
lower_case=False,
|
||||
normalize_text=False,
|
||||
):
|
||||
|
||||
metrics = defaultdict(list) # store final results
|
||||
|
||||
if hasattr(query_encoder, "module"):
|
||||
query_encoder = query_encoder.module
|
||||
query_encoder.eval()
|
||||
|
||||
if doc_encoder is not None:
|
||||
if hasattr(doc_encoder, "module"):
|
||||
doc_encoder = doc_encoder.module
|
||||
doc_encoder.eval()
|
||||
else:
|
||||
doc_encoder = query_encoder
|
||||
|
||||
dmodel = DenseRetrievalExactSearch(
|
||||
DenseEncoderModel(
|
||||
query_encoder=query_encoder,
|
||||
doc_encoder=doc_encoder,
|
||||
tokenizer=tokenizer,
|
||||
add_special_tokens=add_special_tokens,
|
||||
norm_query=norm_query,
|
||||
norm_doc=norm_doc,
|
||||
lower_case=lower_case,
|
||||
normalize_text=normalize_text,
|
||||
),
|
||||
batch_size=batch_size,
|
||||
)
|
||||
retriever = EvaluateRetrieval(dmodel, score_function=score_function)
|
||||
data_path = os.path.join(beir_dir, dataset)
|
||||
|
||||
if not os.path.isdir(data_path) and is_main:
|
||||
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
|
||||
data_path = beir.util.download_and_unzip(url, beir_dir)
|
||||
dist_utils.barrier()
|
||||
|
||||
if not dataset == "cqadupstack":
|
||||
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)
|
||||
results = retriever.retrieve(corpus, queries)
|
||||
if is_main:
|
||||
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
|
||||
for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"):
|
||||
if isinstance(metric, str):
|
||||
metric = retriever.evaluate_custom(qrels, results, retriever.k_values, metric=metric)
|
||||
for key, value in metric.items():
|
||||
metrics[key].append(value)
|
||||
if save_results_path is not None:
|
||||
torch.save(results, f"{save_results_path}")
|
||||
elif dataset == "cqadupstack": # compute macroaverage over datasets
|
||||
paths = glob.glob(os.path.join(data_path, "*"))  # one sub-folder per cqadupstack forum
for path in paths:
corpus, queries, qrels = GenericDataLoader(data_folder=path).load(split=split)
|
||||
results = retriever.retrieve(corpus, queries)
|
||||
if is_main:
|
||||
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
|
||||
for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"):
|
||||
if isinstance(metric, str):
|
||||
metric = retriever.evaluate_custom(qrels, results, retriever.k_values, metric=metric)
|
||||
for key, value in metric.items():
|
||||
metrics[key].append(value)
|
||||
for key, value in metrics.items():
|
||||
assert (
|
||||
len(value) == 12
|
||||
), f"cqadupstack includes 12 datasets, only {len(value)} values were compute for the {key} metric"
|
||||
|
||||
metrics = {key: 100 * np.mean(value) for key, value in metrics.items()}
|
||||
|
||||
return metrics
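
For reference, a hedged sketch of calling the evaluation entry point above on a single BEIR dataset; the model id "facebook/contriever", the dataset name, and the import path of load_retriever (defined later in this diff) are assumptions, and a GPU is required because encoding moves tensors to CUDA:

# Sketch only: run BEIR evaluation with a Contriever-style encoder.
from src.contriever import load_retriever  # import path may differ per install

retriever, tokenizer, _ = load_retriever("facebook/contriever")
retriever = retriever.cuda()
metrics = evaluate_model(
    query_encoder=retriever,
    doc_encoder=retriever,
    tokenizer=tokenizer,
    dataset="scifact",
    batch_size=64,
    beir_dir="BEIR/datasets",
)
print(metrics.get("NDCG@10"))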
|
|
@@ -0,0 +1,139 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
import os
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import BertModel, XLMRobertaModel
|
||||
|
||||
from robowaiter.algos.retrieval.retrieval_lm.src import utils
|
||||
|
||||
|
||||
|
||||
class Contriever(BertModel):
|
||||
def __init__(self, config, pooling="average", **kwargs):
|
||||
super().__init__(config, add_pooling_layer=False)
|
||||
if not hasattr(config, "pooling"):
|
||||
self.config.pooling = pooling
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
normalize=False,
|
||||
):
|
||||
|
||||
model_output = super().forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
last_hidden = model_output["last_hidden_state"]
|
||||
last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
||||
|
||||
if self.config.pooling == "average":
|
||||
emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
||||
elif self.config.pooling == "cls":
|
||||
emb = last_hidden[:, 0]
|
||||
|
||||
if normalize:
|
||||
emb = torch.nn.functional.normalize(emb, dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class XLMRetriever(XLMRobertaModel):
|
||||
def __init__(self, config, pooling="average", **kwargs):
|
||||
super().__init__(config, add_pooling_layer=False)
|
||||
if not hasattr(config, "pooling"):
|
||||
self.config.pooling = pooling
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
normalize=False,
|
||||
):
|
||||
|
||||
model_output = super().forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
last_hidden = model_output["last_hidden_state"]
|
||||
last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
||||
if self.config.pooling == "average":
|
||||
emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
||||
elif self.config.pooling == "cls":
|
||||
emb = last_hidden[:, 0]
|
||||
if normalize:
|
||||
emb = torch.nn.functional.normalize(emb, dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
def load_retriever(model_path, pooling="average", random_init=False):
|
||||
# try: check if model exists locally
|
||||
path = os.path.join(model_path, "checkpoint.pth")
|
||||
if os.path.exists(path):
|
||||
pretrained_dict = torch.load(path, map_location="cpu")
|
||||
opt = pretrained_dict["opt"]
|
||||
if hasattr(opt, "retriever_model_id"):
|
||||
retriever_model_id = opt.retriever_model_id
|
||||
else:
|
||||
# retriever_model_id = "bert-base-uncased"
|
||||
retriever_model_id = "bert-base-multilingual-cased"
|
||||
tokenizer = utils.load_hf(transformers.AutoTokenizer, retriever_model_id)
|
||||
cfg = utils.load_hf(transformers.AutoConfig, retriever_model_id)
|
||||
if "xlm" in retriever_model_id:
|
||||
model_class = XLMRetriever
|
||||
else:
|
||||
model_class = Contriever
|
||||
retriever = model_class(cfg)
|
||||
pretrained_dict = pretrained_dict["model"]
|
||||
|
||||
if any("encoder_q." in key for key in pretrained_dict.keys()): # test if model is defined with moco class
|
||||
pretrained_dict = {k.replace("encoder_q.", ""): v for k, v in pretrained_dict.items() if "encoder_q." in k}
|
||||
elif any("encoder." in key for key in pretrained_dict.keys()): # test if model is defined with inbatch class
|
||||
pretrained_dict = {k.replace("encoder.", ""): v for k, v in pretrained_dict.items() if "encoder." in k}
|
||||
retriever.load_state_dict(pretrained_dict, strict=False)
|
||||
else:
|
||||
retriever_model_id = model_path
|
||||
if "xlm" in retriever_model_id:
|
||||
model_class = XLMRetriever
|
||||
else:
|
||||
model_class = Contriever
|
||||
cfg = utils.load_hf(transformers.AutoConfig, model_path)
|
||||
tokenizer = utils.load_hf(transformers.AutoTokenizer, model_path)
|
||||
retriever = utils.load_hf(model_class, model_path)
|
||||
|
||||
return retriever, tokenizer, retriever_model_id
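
A minimal usage sketch for load_retriever; the Hugging Face model id is an assumption, and any local checkpoint directory works the same way:

# Sketch: load a retriever and embed one query with the classes above.
import torch

retriever, tokenizer, _ = load_retriever("facebook/contriever")
# "Please put the yogurt on the coffee table."
inputs = tokenizer(["请把酸奶放在咖啡台上"], padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    emb = retriever(**inputs, normalize=True)  # pooled embedding, shape (1, hidden_size)
print(emb.shape)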
|
|
@@ -0,0 +1,243 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
import random
|
||||
import json
|
||||
import csv
|
||||
import numpy as np
|
||||
import numpy.random
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
from robowaiter.algos.retrieval.retrieval_lm.src import dist_utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_data(opt, tokenizer):
|
||||
datasets = {}
|
||||
for path in opt.train_data:
|
||||
data = load_dataset(path, opt.loading_mode)
|
||||
if data is not None:
|
||||
datasets[path] = Dataset(data, opt.chunk_length, tokenizer, opt)
|
||||
dataset = MultiDataset(datasets)
|
||||
dataset.set_prob(coeff=opt.sampling_coefficient)
|
||||
return dataset
|
||||
|
||||
|
||||
def load_dataset(data_path, loading_mode):
|
||||
files = glob.glob(os.path.join(data_path, "*.p*"))
|
||||
files.sort()
|
||||
tensors = []
|
||||
if loading_mode == "split":
|
||||
files_split = list(np.array_split(files, dist_utils.get_world_size()))[dist_utils.get_rank()]
|
||||
for filepath in files_split:
|
||||
try:
|
||||
tensors.append(torch.load(filepath, map_location="cpu"))
|
||||
except Exception:
|
||||
logger.warning(f"Unable to load file {filepath}")
|
||||
elif loading_mode == "full":
|
||||
for fin in files:
|
||||
tensors.append(torch.load(fin, map_location="cpu"))
|
||||
elif loading_mode == "single":
|
||||
tensors.append(torch.load(files[0], map_location="cpu"))
|
||||
if len(tensors) == 0:
|
||||
return None
|
||||
tensor = torch.cat(tensors)
|
||||
return tensor
|
||||
|
||||
|
||||
class MultiDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, datasets):
|
||||
|
||||
self.datasets = datasets
|
||||
self.prob = [1 / len(self.datasets) for _ in self.datasets]
|
||||
self.dataset_ids = list(self.datasets.keys())
|
||||
|
||||
def __len__(self):
|
||||
return sum([len(dataset) for dataset in self.datasets.values()])
|
||||
|
||||
def __getitem__(self, index):
|
||||
dataset_idx = numpy.random.choice(range(len(self.prob)), 1, p=self.prob)[0]
|
||||
did = self.dataset_ids[dataset_idx]
|
||||
index = random.randint(0, len(self.datasets[did]) - 1)
|
||||
sample = self.datasets[did][index]
|
||||
sample["dataset_id"] = did
|
||||
return sample
|
||||
|
||||
def generate_offset(self):
|
||||
for dataset in self.datasets.values():
|
||||
dataset.generate_offset()
|
||||
|
||||
def set_prob(self, coeff=0.0):
|
||||
|
||||
prob = np.array([float(len(dataset)) for _, dataset in self.datasets.items()])
|
||||
prob /= prob.sum()
|
||||
prob = np.array([p**coeff for p in prob])
|
||||
prob /= prob.sum()
|
||||
self.prob = prob
|
||||
|
||||
|
||||
class Dataset(torch.utils.data.Dataset):
|
||||
"""Monolingual dataset based on a list of paths"""
|
||||
|
||||
def __init__(self, data, chunk_length, tokenizer, opt):
|
||||
|
||||
self.data = data
|
||||
self.chunk_length = chunk_length
|
||||
self.tokenizer = tokenizer
|
||||
self.opt = opt
|
||||
self.generate_offset()
|
||||
|
||||
def __len__(self):
|
||||
return (self.data.size(0) - self.offset) // self.chunk_length
|
||||
|
||||
def __getitem__(self, index):
|
||||
start_idx = self.offset + index * self.chunk_length
|
||||
end_idx = start_idx + self.chunk_length
|
||||
tokens = self.data[start_idx:end_idx]
|
||||
q_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max)
|
||||
k_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max)
|
||||
q_tokens = apply_augmentation(q_tokens, self.opt)
|
||||
q_tokens = add_bos_eos(q_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id)
|
||||
k_tokens = apply_augmentation(k_tokens, self.opt)
|
||||
k_tokens = add_bos_eos(k_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id)
|
||||
|
||||
return {"q_tokens": q_tokens, "k_tokens": k_tokens}
|
||||
|
||||
def generate_offset(self):
|
||||
self.offset = random.randint(0, self.chunk_length - 1)
|
||||
|
||||
|
||||
class Collator(object):
|
||||
def __init__(self, opt):
|
||||
self.opt = opt
|
||||
|
||||
def __call__(self, batch_examples):
|
||||
|
||||
batch = defaultdict(list)
|
||||
for example in batch_examples:
|
||||
for k, v in example.items():
|
||||
batch[k].append(v)
|
||||
|
||||
q_tokens, q_mask = build_mask(batch["q_tokens"])
|
||||
k_tokens, k_mask = build_mask(batch["k_tokens"])
|
||||
|
||||
batch["q_tokens"] = q_tokens
|
||||
batch["q_mask"] = q_mask
|
||||
batch["k_tokens"] = k_tokens
|
||||
batch["k_mask"] = k_mask
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def randomcrop(x, ratio_min, ratio_max):
|
||||
|
||||
ratio = random.uniform(ratio_min, ratio_max)
|
||||
length = int(len(x) * ratio)
|
||||
start = random.randint(0, len(x) - length)
|
||||
end = start + length
|
||||
crop = x[start:end].clone()
|
||||
return crop
|
||||
|
||||
|
||||
def build_mask(tensors):
|
||||
shapes = [x.shape for x in tensors]
|
||||
maxlength = max([len(x) for x in tensors])
|
||||
returnmasks = []
|
||||
ids = []
|
||||
for k, x in enumerate(tensors):
|
||||
returnmasks.append(torch.tensor([1] * len(x) + [0] * (maxlength - len(x))))
|
||||
ids.append(torch.cat((x, torch.tensor([0] * (maxlength - len(x))))))
|
||||
ids = torch.stack(ids, dim=0).long()
|
||||
returnmasks = torch.stack(returnmasks, dim=0).bool()
|
||||
return ids, returnmasks
|
||||
|
||||
|
||||
def add_token(x, token):
|
||||
x = torch.cat((torch.tensor([token]), x))
|
||||
return x
|
||||
|
||||
|
||||
def deleteword(x, p=0.1):
|
||||
mask = np.random.rand(len(x))
|
||||
x = [e for e, m in zip(x, mask) if m > p]
|
||||
return x
|
||||
|
||||
|
||||
def replaceword(x, min_random, max_random, p=0.1):
|
||||
mask = np.random.rand(len(x))
|
||||
x = [e if m > p else random.randint(min_random, max_random) for e, m in zip(x, mask)]
|
||||
return x
|
||||
|
||||
|
||||
def maskword(x, mask_id, p=0.1):
|
||||
mask = np.random.rand(len(x))
|
||||
x = [e if m > p else mask_id for e, m in zip(x, mask)]
|
||||
return x
|
||||
|
||||
|
||||
def shuffleword(x, p=0.1):
|
||||
"""Shuffle a random subset of the values in a list."""
count = (np.random.rand(len(x)) < p).sum()
|
||||
indices_to_shuffle = random.sample(range(len(x)), k=count)
|
||||
to_shuffle = [x[i] for i in indices_to_shuffle]
|
||||
random.shuffle(to_shuffle)
|
||||
for index, value in enumerate(to_shuffle):
|
||||
old_index = indices_to_shuffle[index]
|
||||
x[old_index] = value
|
||||
return x
|
||||
|
||||
|
||||
def apply_augmentation(x, opt):
|
||||
if opt.augmentation == "mask":
|
||||
return torch.tensor(maskword(x, mask_id=opt.mask_id, p=opt.prob_augmentation))
|
||||
elif opt.augmentation == "replace":
|
||||
return torch.tensor(
|
||||
replaceword(x, min_random=opt.start_id, max_random=opt.vocab_size - 1, p=opt.prob_augmentation)
|
||||
)
|
||||
elif opt.augmentation == "delete":
|
||||
return torch.tensor(deleteword(x, p=opt.prob_augmentation))
|
||||
elif opt.augmentation == "shuffle":
|
||||
return torch.tensor(shuffleword(x, p=opt.prob_augmentation))
|
||||
else:
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.Tensor(x)
|
||||
return x
|
||||
|
||||
|
||||
def add_bos_eos(x, bos_token_id, eos_token_id):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.Tensor(x)
|
||||
if bos_token_id is None and eos_token_id is not None:
|
||||
x = torch.cat([x.clone().detach(), torch.tensor([eos_token_id])])
|
||||
elif bos_token_id is not None and eos_token_id is None:
|
||||
x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach()])
|
||||
elif bos_token_id is None and eos_token_id is None:
|
||||
pass
|
||||
else:
|
||||
x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach(), torch.tensor([eos_token_id])])
|
||||
return x
|
||||
|
||||
|
||||
# Used for passage retrieval
|
||||
def load_passages(path):
|
||||
if not os.path.exists(path):
|
||||
logger.info(f"{path} does not exist")
|
||||
return
|
||||
logger.info(f"Loading passages from: {path}")
|
||||
passages = []
|
||||
with open(path,encoding='UTF-8') as fin:
|
||||
if path.endswith(".jsonl"):
|
||||
for k, line in enumerate(fin):
|
||||
ex = json.loads(line)
|
||||
passages.append(ex)
|
||||
else:
|
||||
reader = csv.reader(fin, delimiter="\t")
|
||||
for k, row in enumerate(reader):
|
||||
if not row[0] == "id":
|
||||
ex = {"id": row[0], "title": row[2], "text": row[1]}
|
||||
passages.append(ex)
|
||||
return passages
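
A short usage sketch for the passage loader above (the RoboWaiter file name is an assumption):

# Sketch: load the robot passage file written by the conversion script.
passages = load_passages("train_robot.jsonl")
if passages:
    print(len(passages), passages[0]["title"])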
|
|
@@ -0,0 +1,128 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class Gather(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x: torch.tensor):
|
||||
output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(output, x)
|
||||
return tuple(output)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grads):
|
||||
all_gradients = torch.stack(grads)
|
||||
dist.all_reduce(all_gradients)
|
||||
return all_gradients[dist.get_rank()]
|
||||
|
||||
|
||||
def gather(x: torch.tensor):
|
||||
if not dist.is_initialized():
|
||||
return x
|
||||
x_gather = Gather.apply(x)
|
||||
x_gather = torch.cat(x_gather, dim=0)
|
||||
return x_gather
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def gather_nograd(x: torch.tensor):
|
||||
if not dist.is_initialized():
|
||||
return x
|
||||
x_gather = [torch.ones_like(x) for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(x_gather, x, async_op=False)
|
||||
|
||||
x_gather = torch.cat(x_gather, dim=0)
|
||||
return x_gather
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def varsize_gather_nograd(x: torch.Tensor):
|
||||
"""gather tensors of different sizes along the first dimension"""
|
||||
if not dist.is_initialized():
|
||||
return x
|
||||
|
||||
# determine max size
|
||||
size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
|
||||
allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(allsizes, size)
|
||||
max_size = max([size.cpu().max() for size in allsizes])
|
||||
|
||||
padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device)
|
||||
padded[: x.shape[0]] = x
|
||||
output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(output, padded)
|
||||
|
||||
output = [tensor[: allsizes[k]] for k, tensor in enumerate(output)]
|
||||
output = torch.cat(output, dim=0)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_varsize(x: torch.Tensor):
|
||||
"""gather tensors of different sizes along the first dimension"""
|
||||
if not dist.is_initialized():
|
||||
return [x.shape[0]]
|
||||
|
||||
# determine max size
|
||||
size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
|
||||
allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(allsizes, size)
|
||||
allsizes = torch.cat(allsizes)
|
||||
return allsizes
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not dist.is_available():
|
||||
return 0
|
||||
if not dist.is_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not dist.is_initialized():
|
||||
return 1
|
||||
else:
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def barrier():
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def average_main(x):
|
||||
if not dist.is_initialized():
|
||||
return x
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
dist.reduce(x, 0, op=dist.ReduceOp.SUM)
|
||||
if is_main():
|
||||
x = x / dist.get_world_size()
|
||||
return x
|
||||
|
||||
|
||||
def sum_main(x):
|
||||
if not dist.is_initialized():
|
||||
return x
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
dist.reduce(x, 0, op=dist.ReduceOp.SUM)
|
||||
return x
|
||||
|
||||
|
||||
def weighted_average(x, count):
|
||||
if not dist.is_initialized():
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = x.item()
|
||||
return x, count
|
||||
t_loss = torch.tensor([x * count]).cuda()
|
||||
t_total = torch.tensor([count]).cuda()
|
||||
t_loss = sum_main(t_loss)
|
||||
t_total = sum_main(t_total)
|
||||
return (t_loss / t_total).item(), t_total.item()
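
These helpers degrade gracefully when torch.distributed has not been initialized, which is the usual single-process case; for example:

# Single-process behaviour: the reductions become pass-throughs.
print(get_world_size())           # 1
print(weighted_average(0.5, 10))  # (0.5, 10)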
|
|
@@ -0,0 +1,190 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import regex
|
||||
import string
|
||||
import unicodedata
|
||||
from functools import partial
|
||||
from multiprocessing import Pool as ProcessPool
|
||||
from typing import Tuple, List, Dict
|
||||
import numpy as np
|
||||
|
||||
"""
|
||||
Evaluation code from DPR: https://github.com/facebookresearch/DPR
|
||||
"""
|
||||
|
||||
class SimpleTokenizer(object):
|
||||
ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
|
||||
NON_WS = r'[^\p{Z}\p{C}]'
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Args:
|
||||
annotators: None or empty set (only tokenizes).
|
||||
"""
|
||||
self._regexp = regex.compile(
|
||||
'(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
|
||||
flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
|
||||
)
|
||||
|
||||
def tokenize(self, text, uncased=False):
|
||||
matches = [m for m in self._regexp.finditer(text)]
|
||||
if uncased:
|
||||
tokens = [m.group().lower() for m in matches]
|
||||
else:
|
||||
tokens = [m.group() for m in matches]
|
||||
return tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits'])
|
||||
|
||||
def calculate_matches(data: List, workers_num: int):
|
||||
"""
|
||||
Evaluates answers presence in the set of documents. This function is supposed to be used with a large collection of
|
||||
documents and results. It internally forks multiple sub-processes for evaluation and then merges results
|
||||
:param all_docs: dictionary of the entire documents database. doc_id -> (doc_text, title)
|
||||
:param answers: list of answers' lists, one list per question
|
||||
:param closest_docs: document ids of the top results along with their scores
|
||||
:param workers_num: amount of parallel threads to process data
|
||||
:param match_type: type of answer matching. Refer to has_answer code for available options
|
||||
:return: matching information tuple.
|
||||
top_k_hits - a list where the index is the amount of top documents retrieved and the value is the total amount of
|
||||
valid matches across an entire dataset.
|
||||
questions_doc_hits - more detailed info with answer matches for every question and every retrieved document
|
||||
"""
|
||||
|
||||
logger.info('Matching answers in top docs...')
|
||||
|
||||
tokenizer = SimpleTokenizer()
|
||||
get_score_partial = partial(check_answer, tokenizer=tokenizer)
|
||||
|
||||
processes = ProcessPool(processes=workers_num)
|
||||
scores = processes.map(get_score_partial, data)
|
||||
|
||||
logger.info('Per question validation results len=%d', len(scores))
|
||||
|
||||
n_docs = len(data[0]['ctxs'])
|
||||
top_k_hits = [0] * n_docs
|
||||
for question_hits in scores:
|
||||
best_hit = next((i for i, x in enumerate(question_hits) if x), None)
|
||||
if best_hit is not None:
|
||||
top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]
|
||||
|
||||
return QAMatchStats(top_k_hits, scores)
|
||||
|
||||
def check_answer(example, tokenizer) -> List[bool]:
|
||||
"""Search through all the top docs to see if they have any of the answers."""
|
||||
answers = example['answers']
|
||||
ctxs = example['ctxs']
|
||||
|
||||
hits = []
|
||||
|
||||
for i, doc in enumerate(ctxs):
|
||||
text = doc['text']
|
||||
|
||||
if text is None: # cannot find the document for some reason
|
||||
logger.warning("no doc in db")
|
||||
hits.append(False)
|
||||
continue
|
||||
|
||||
hits.append(has_answer(answers, text, tokenizer))
|
||||
|
||||
return hits
|
||||
|
||||
def has_answer(answers, text, tokenizer) -> bool:
|
||||
"""Check if a document contains an answer string."""
|
||||
text = _normalize(text)
|
||||
text = tokenizer.tokenize(text, uncased=True)
|
||||
|
||||
for answer in answers:
|
||||
answer = _normalize(answer)
|
||||
answer = tokenizer.tokenize(answer, uncased=True)
|
||||
for i in range(0, len(text) - len(answer) + 1):
|
||||
if answer == text[i: i + len(answer)]:
|
||||
return True
|
||||
return False
|
||||
|
||||
#################################################
|
||||
######## READER EVALUATION ########
|
||||
#################################################
|
||||
|
||||
def _normalize(text):
|
||||
return unicodedata.normalize('NFD', text)
|
||||
|
||||
#Normalization and score functions from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
|
||||
def normalize_answer(s):
|
||||
def remove_articles(text):
|
||||
return regex.sub(r'\b(a|an|the)\b', ' ', text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return ' '.join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return ''.join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
def em(prediction, ground_truth):
|
||||
return normalize_answer(prediction) == normalize_answer(ground_truth)
|
||||
|
||||
def f1(prediction, ground_truth):
|
||||
prediction_tokens = normalize_answer(prediction).split()
|
||||
ground_truth_tokens = normalize_answer(ground_truth).split()
|
||||
common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens)
|
||||
num_same = sum(common.values())
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(prediction_tokens)
|
||||
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
|
||||
|
||||
def f1_score(prediction, ground_truths):
|
||||
return max([f1(prediction, gt) for gt in ground_truths])
|
||||
|
||||
def exact_match_score(prediction, ground_truths):
|
||||
return max([em(prediction, gt) for gt in ground_truths])
|
||||
|
||||
####################################################
|
||||
######## RETRIEVER EVALUATION ########
|
||||
####################################################
|
||||
|
||||
def eval_batch(scores, inversions, avg_topk, idx_topk):
|
||||
for k, s in enumerate(scores):
|
||||
s = s.cpu().numpy()
|
||||
sorted_idx = np.argsort(-s)
|
||||
score(sorted_idx, inversions, avg_topk, idx_topk)
|
||||
|
||||
def count_inversions(arr):
|
||||
inv_count = 0
|
||||
lenarr = len(arr)
|
||||
for i in range(lenarr):
|
||||
for j in range(i + 1, lenarr):
|
||||
if (arr[i] > arr[j]):
|
||||
inv_count += 1
|
||||
return inv_count
|
||||
|
||||
def score(x, inversions, avg_topk, idx_topk):
|
||||
x = np.array(x)
|
||||
inversions.append(count_inversions(x))
|
||||
for k in avg_topk:
|
||||
# ratio of passages in the predicted top-k that are
|
||||
# also in the topk given by gold score
|
||||
avg_pred_topk = (x[:k]<k).mean()
|
||||
avg_topk[k].append(avg_pred_topk)
|
||||
for k in idx_topk:
|
||||
below_k = (x<k)
|
||||
# number of passages required to obtain all passages from gold top-k
|
||||
idx_gold_topk = len(x) - np.argmax(below_k[::-1])
|
||||
idx_topk[k].append(idx_gold_topk)
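
A hedged sketch of scoring one retrieval result with the DPR-style matcher above; the record mirrors the test_result examples earlier in this diff and should run under a __main__ guard because a process pool is created:

# Sketch: one question with one retrieved context that contains the answer string.
example = {
    "answers": ["On(Yogurt,CoffeeTable),Is(Curtain,Open)"],
    "ctxs": [{"text": "On(Yogurt,CoffeeTable),Is(Curtain,Open)"}],
}
stats = calculate_matches([example], workers_num=1)
print(stats.top_k_hits)  # [1]: the first (and only) context is a hit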
|
|
@@ -0,0 +1,171 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
import torch
|
||||
import random
|
||||
import json
|
||||
import sys
|
||||
import numpy as np
|
||||
from src import normalize_text
|
||||
|
||||
|
||||
class Dataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
datapaths,
|
||||
negative_ctxs=1,
|
||||
negative_hard_ratio=0.0,
|
||||
negative_hard_min_idx=0,
|
||||
training=False,
|
||||
global_rank=-1,
|
||||
world_size=-1,
|
||||
maxload=None,
|
||||
normalize=False,
|
||||
):
|
||||
self.negative_ctxs = negative_ctxs
|
||||
self.negative_hard_ratio = negative_hard_ratio
|
||||
self.negative_hard_min_idx = negative_hard_min_idx
|
||||
self.training = training
|
||||
self.normalize_fn = normalize_text.normalize if normalize else lambda x: x  # switch on the normalize flag, not the module
|
||||
self._load_data(datapaths, global_rank, world_size, maxload)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = self.data[index]
|
||||
question = example["question"]
|
||||
if self.training:
|
||||
gold = random.choice(example["positive_ctxs"])
|
||||
|
||||
n_hard_negatives, n_random_negatives = self.sample_n_hard_negatives(example)
|
||||
negatives = []
|
||||
if n_random_negatives > 0:
|
||||
random_negatives = random.sample(example["negative_ctxs"], n_random_negatives)
|
||||
negatives += random_negatives
|
||||
if n_hard_negatives > 0:
|
||||
hard_negatives = random.sample(
|
||||
example["hard_negative_ctxs"][self.negative_hard_min_idx :], n_hard_negatives
|
||||
)
|
||||
negatives += hard_negatives
|
||||
else:
|
||||
gold = example["positive_ctxs"][0]
|
||||
nidx = 0
|
||||
if "negative_ctxs" in example:
|
||||
negatives = [example["negative_ctxs"][nidx]]
|
||||
else:
|
||||
negatives = []
|
||||
|
||||
gold = gold["title"] + " " + gold["text"] if "title" in gold and len(gold["title"]) > 0 else gold["text"]
|
||||
|
||||
negatives = [
|
||||
n["title"] + " " + n["text"] if ("title" in n and len(n["title"]) > 0) else n["text"] for n in negatives
|
||||
]
|
||||
|
||||
example = {
|
||||
"query": self.normalize_fn(question),
|
||||
"gold": self.normalize_fn(gold),
|
||||
"negatives": [self.normalize_fn(n) for n in negatives],
|
||||
}
|
||||
return example
|
||||
|
||||
def _load_data(self, datapaths, global_rank, world_size, maxload):
|
||||
counter = 0
|
||||
self.data = []
|
||||
for path in datapaths:
|
||||
path = str(path)
|
||||
if path.endswith(".jsonl"):
|
||||
file_data, counter = self._load_data_jsonl(path, global_rank, world_size, counter, maxload)
|
||||
elif path.endswith(".json"):
|
||||
file_data, counter = self._load_data_json(path, global_rank, world_size, counter, maxload)
|
||||
self.data.extend(file_data)
|
||||
if maxload is not None and maxload > 0 and counter >= maxload:
|
||||
break
|
||||
|
||||
def _load_data_json(self, path, global_rank, world_size, counter, maxload=None):
|
||||
examples = []
|
||||
with open(path, "r") as fin:
|
||||
data = json.load(fin)
|
||||
for example in data:
|
||||
counter += 1
|
||||
if global_rank > -1 and not counter % world_size == global_rank:
|
||||
continue
|
||||
examples.append(example)
|
||||
if maxload is not None and maxload > 0 and counter == maxload:
|
||||
break
|
||||
|
||||
return examples, counter
|
||||
|
||||
def _load_data_jsonl(self, path, global_rank, world_size, counter, maxload=None):
|
||||
examples = []
|
||||
with open(path, "r") as fin:
|
||||
for line in fin:
|
||||
counter += 1
|
||||
if global_rank > -1 and not counter % world_size == global_rank:
|
||||
continue
|
||||
example = json.loads(line)
|
||||
examples.append(example)
|
||||
if maxload is not None and maxload > 0 and counter == maxload:
|
||||
break
|
||||
|
||||
return examples, counter
|
||||
|
||||
def sample_n_hard_negatives(self, ex):
|
||||
|
||||
if "hard_negative_ctxs" in ex:
|
||||
n_hard_negatives = sum([random.random() < self.negative_hard_ratio for _ in range(self.negative_ctxs)])
|
||||
n_hard_negatives = min(n_hard_negatives, len(ex["hard_negative_ctxs"][self.negative_hard_min_idx :]))
|
||||
else:
|
||||
n_hard_negatives = 0
|
||||
n_random_negatives = self.negative_ctxs - n_hard_negatives
|
||||
if "negative_ctxs" in ex:
|
||||
n_random_negatives = min(n_random_negatives, len(ex["negative_ctxs"]))
|
||||
else:
|
||||
n_random_negatives = 0
|
||||
return n_hard_negatives, n_random_negatives
|
||||
|
||||
|
||||
class Collator(object):
|
||||
def __init__(self, tokenizer, passage_maxlength=200):
|
||||
self.tokenizer = tokenizer
|
||||
self.passage_maxlength = passage_maxlength
|
||||
|
||||
def __call__(self, batch):
|
||||
queries = [ex["query"] for ex in batch]
|
||||
golds = [ex["gold"] for ex in batch]
|
||||
negs = [item for ex in batch for item in ex["negatives"]]
|
||||
allpassages = golds + negs
|
||||
|
||||
qout = self.tokenizer.batch_encode_plus(
|
||||
queries,
|
||||
max_length=self.passage_maxlength,
|
||||
truncation=True,
|
||||
padding=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
kout = self.tokenizer.batch_encode_plus(
|
||||
allpassages,
|
||||
max_length=self.passage_maxlength,
|
||||
truncation=True,
|
||||
padding=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
q_tokens, q_mask = qout["input_ids"], qout["attention_mask"].bool()
|
||||
k_tokens, k_mask = kout["input_ids"], kout["attention_mask"].bool()
|
||||
|
||||
g_tokens, g_mask = k_tokens[: len(golds)], k_mask[: len(golds)]
|
||||
n_tokens, n_mask = k_tokens[len(golds) :], k_mask[len(golds) :]
|
||||
|
||||
batch = {
|
||||
"q_tokens": q_tokens,
|
||||
"q_mask": q_mask,
|
||||
"k_tokens": k_tokens,
|
||||
"k_mask": k_mask,
|
||||
"g_tokens": g_tokens,
|
||||
"g_mask": g_mask,
|
||||
"n_tokens": n_tokens,
|
||||
"n_mask": n_mask,
|
||||
}
|
||||
|
||||
return batch
|
|
@@ -0,0 +1,90 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import math
|
||||
import random
|
||||
import transformers
|
||||
import logging
|
||||
import torch.distributed as dist
|
||||
|
||||
from src import contriever, dist_utils, utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InBatch(nn.Module):
|
||||
def __init__(self, opt, retriever=None, tokenizer=None):
|
||||
super(InBatch, self).__init__()
|
||||
|
||||
self.opt = opt
|
||||
self.norm_doc = opt.norm_doc
|
||||
self.norm_query = opt.norm_query
|
||||
self.label_smoothing = opt.label_smoothing
|
||||
if retriever is None or tokenizer is None:
|
||||
retriever, tokenizer = self._load_retriever(
|
||||
opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
self.encoder = retriever
|
||||
|
||||
def _load_retriever(self, model_id, pooling, random_init):
|
||||
cfg = utils.load_hf(transformers.AutoConfig, model_id)
|
||||
tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id)
|
||||
|
||||
if "xlm" in model_id:
|
||||
model_class = contriever.XLMRetriever
|
||||
else:
|
||||
model_class = contriever.Contriever
|
||||
|
||||
if random_init:
|
||||
retriever = model_class(cfg)
|
||||
else:
|
||||
retriever = utils.load_hf(model_class, model_id)
|
||||
|
||||
if "bert-" in model_id:
|
||||
if tokenizer.bos_token_id is None:
|
||||
tokenizer.bos_token = "[CLS]"
|
||||
if tokenizer.eos_token_id is None:
|
||||
tokenizer.eos_token = "[SEP]"
|
||||
|
||||
retriever.config.pooling = pooling
|
||||
|
||||
return retriever, tokenizer
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
|
||||
def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs):
|
||||
|
||||
bsz = len(q_tokens)
|
||||
labels = torch.arange(0, bsz, dtype=torch.long, device=q_tokens.device)
|
||||
|
||||
qemb = self.encoder(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query)
|
||||
kemb = self.encoder(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc)
|
||||
|
||||
gather_fn = dist_utils.gather
|
||||
|
||||
gather_kemb = gather_fn(kemb)
|
||||
|
||||
labels = labels + dist_utils.get_rank() * len(kemb)
|
||||
|
||||
scores = torch.einsum("id, jd->ij", qemb / self.opt.temperature, gather_kemb)
|
||||
|
||||
loss = torch.nn.functional.cross_entropy(scores, labels, label_smoothing=self.label_smoothing)
|
||||
|
||||
# log stats
|
||||
if len(stats_prefix) > 0:
|
||||
stats_prefix = stats_prefix + "/"
|
||||
iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz)
|
||||
|
||||
predicted_idx = torch.argmax(scores, dim=-1)
|
||||
accuracy = 100 * (predicted_idx == labels).float().mean()
|
||||
stdq = torch.std(qemb, dim=0).mean().item()
|
||||
stdk = torch.std(kemb, dim=0).mean().item()
|
||||
iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz)
|
||||
iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz)
|
||||
iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz)
|
||||
|
||||
return loss, iter_stats
|
|
@@ -0,0 +1,73 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import pickle
|
||||
from typing import List, Tuple
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
class Indexer(object):
|
||||
|
||||
def __init__(self, vector_sz, n_subquantizers=0, n_bits=8):
|
||||
if n_subquantizers > 0:
|
||||
self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
|
||||
else:
|
||||
self.index = faiss.IndexFlatIP(vector_sz)
|
||||
#self.index_id_to_db_id = np.empty((0), dtype=np.int64)
|
||||
self.index_id_to_db_id = []
|
||||
|
||||
def index_data(self, ids, embeddings):
|
||||
self._update_id_mapping(ids)
|
||||
embeddings = embeddings.astype('float32')
|
||||
if not self.index.is_trained:
|
||||
self.index.train(embeddings)
|
||||
self.index.add(embeddings)
|
||||
|
||||
print(f'Total data indexed {len(self.index_id_to_db_id)}')
|
||||
|
||||
def search_knn(self, query_vectors: np.array, top_docs: int, index_batch_size: int = 2048) -> List[Tuple[List[object], List[float]]]:
|
||||
query_vectors = query_vectors.astype('float32')
|
||||
result = []
|
||||
nbatch = (len(query_vectors)-1) // index_batch_size + 1
|
||||
for k in tqdm(range(nbatch)):
|
||||
start_idx = k*index_batch_size
|
||||
end_idx = min((k+1)*index_batch_size, len(query_vectors))
|
||||
q = query_vectors[start_idx: end_idx]
|
||||
scores, indexes = self.index.search(q, top_docs)
|
||||
# convert to external ids
|
||||
db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes]
|
||||
result.extend([(db_ids[i], scores[i]) for i in range(len(db_ids))])
|
||||
return result
|
||||
|
||||
def serialize(self, dir_path):
|
||||
index_file = os.path.join(dir_path, 'index.faiss')
|
||||
meta_file = os.path.join(dir_path, 'index_meta.faiss')
|
||||
print(f'Serializing index to {index_file}, meta data to {meta_file}')
|
||||
|
||||
faiss.write_index(self.index, index_file)
|
||||
with open(meta_file, mode='wb') as f:
|
||||
pickle.dump(self.index_id_to_db_id, f)
|
||||
|
||||
def deserialize_from(self, dir_path):
|
||||
index_file = os.path.join(dir_path, 'index.faiss')
|
||||
meta_file = os.path.join(dir_path, 'index_meta.faiss')
|
||||
print(f'Loading index from {index_file}, meta data from {meta_file}')
|
||||
|
||||
self.index = faiss.read_index(index_file)
|
||||
print('Loaded index of type %s and size %d' % (type(self.index), self.index.ntotal))
|
||||
|
||||
with open(meta_file, "rb") as reader:
|
||||
self.index_id_to_db_id = pickle.load(reader)
|
||||
assert len(
|
||||
self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'
|
||||
|
||||
def _update_id_mapping(self, db_ids: List):
|
||||
#new_ids = np.array(db_ids, dtype=np.int64)
|
||||
#self.index_id_to_db_id = np.concatenate((self.index_id_to_db_id, new_ids), axis=0)
|
||||
self.index_id_to_db_id.extend(db_ids)
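
A usage sketch for the FAISS wrapper above; the vector size matches the Contriever hidden size and the embeddings are random for illustration:

# Sketch: build a flat inner-product index and query it.
import numpy as np

index = Indexer(vector_sz=768)
ids = ["0", "1", "2"]
embeddings = np.random.rand(3, 768).astype("float32")
index.index_data(ids, embeddings)
top_ids_and_scores = index.search_knn(embeddings[:1], top_docs=2)
print(top_ids_and_scores[0][0])  # ids of the two nearest passages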
|
|
@@ -0,0 +1,140 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import logging
|
||||
import copy
|
||||
import transformers
|
||||
|
||||
from src import contriever, dist_utils, utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MoCo(nn.Module):
|
||||
def __init__(self, opt):
|
||||
super(MoCo, self).__init__()
|
||||
|
||||
self.queue_size = opt.queue_size
|
||||
self.momentum = opt.momentum
|
||||
self.temperature = opt.temperature
|
||||
self.label_smoothing = opt.label_smoothing
|
||||
self.norm_doc = opt.norm_doc
|
||||
self.norm_query = opt.norm_query
|
||||
self.moco_train_mode_encoder_k = opt.moco_train_mode_encoder_k # apply the encoder on keys in train mode
|
||||
|
||||
retriever, tokenizer = self._load_retriever(
|
||||
opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init
|
||||
)
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.encoder_q = retriever
|
||||
self.encoder_k = copy.deepcopy(retriever)
|
||||
|
||||
for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
|
||||
param_k.data.copy_(param_q.data)
|
||||
param_k.requires_grad = False
|
||||
|
||||
# create the queue
|
||||
self.register_buffer("queue", torch.randn(opt.projection_size, self.queue_size))
|
||||
self.queue = nn.functional.normalize(self.queue, dim=0)
|
||||
|
||||
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
|
||||
|
||||
def _load_retriever(self, model_id, pooling, random_init):
|
||||
cfg = utils.load_hf(transformers.AutoConfig, model_id)
|
||||
tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id)
|
||||
|
||||
if "xlm" in model_id:
|
||||
model_class = contriever.XLMRetriever
|
||||
else:
|
||||
model_class = contriever.Contriever
|
||||
|
||||
if random_init:
|
||||
retriever = model_class(cfg)
|
||||
else:
|
||||
retriever = utils.load_hf(model_class, model_id)
|
||||
|
||||
if "bert-" in model_id:
|
||||
if tokenizer.bos_token_id is None:
|
||||
tokenizer.bos_token = "[CLS]"
|
||||
if tokenizer.eos_token_id is None:
|
||||
tokenizer.eos_token = "[SEP]"
|
||||
|
||||
retriever.config.pooling = pooling
|
||||
|
||||
return retriever, tokenizer
|
||||
|
||||
def get_encoder(self, return_encoder_k=False):
|
||||
if return_encoder_k:
|
||||
return self.encoder_k
|
||||
else:
|
||||
return self.encoder_q
|
||||
|
||||
def _momentum_update_key_encoder(self):
|
||||
"""
|
||||
Update of the key encoder
|
||||
"""
|
||||
for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
|
||||
param_k.data = param_k.data * self.momentum + param_q.data * (1.0 - self.momentum)
|
||||
|
||||
@torch.no_grad()
|
||||
def _dequeue_and_enqueue(self, keys):
|
||||
# gather keys before updating queue
|
||||
keys = dist_utils.gather_nograd(keys.contiguous())
|
||||
|
||||
batch_size = keys.shape[0]
|
||||
|
||||
ptr = int(self.queue_ptr)
|
||||
assert self.queue_size % batch_size == 0, f"{batch_size}, {self.queue_size}" # for simplicity
|
||||
|
||||
# replace the keys at ptr (dequeue and enqueue)
|
||||
self.queue[:, ptr : ptr + batch_size] = keys.T
|
||||
ptr = (ptr + batch_size) % self.queue_size # move pointer
|
||||
|
||||
self.queue_ptr[0] = ptr
|
||||
|
||||
def _compute_logits(self, q, k):
|
||||
l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
|
||||
l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])
|
||||
|
||||
logits = torch.cat([l_pos, l_neg], dim=1)
|
||||
return logits
|
||||
|
||||
def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs):
|
||||
bsz = q_tokens.size(0)
|
||||
|
||||
q = self.encoder_q(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query)
|
||||
|
||||
# compute key features
|
||||
with torch.no_grad(): # no gradient to keys
|
||||
self._momentum_update_key_encoder() # update the key encoder
|
||||
|
||||
if not self.encoder_k.training and not self.moco_train_mode_encoder_k:
|
||||
self.encoder_k.eval()
|
||||
|
||||
k = self.encoder_k(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc)
|
||||
|
||||
logits = self._compute_logits(q, k) / self.temperature
|
||||
|
||||
# labels: positive key indicators
|
||||
labels = torch.zeros(bsz, dtype=torch.long).cuda()
|
||||
|
||||
loss = torch.nn.functional.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
|
||||
|
||||
self._dequeue_and_enqueue(k)
|
||||
|
||||
# log stats
|
||||
if len(stats_prefix) > 0:
|
||||
stats_prefix = stats_prefix + "/"
|
||||
iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz)
|
||||
|
||||
predicted_idx = torch.argmax(logits, dim=-1)
|
||||
accuracy = 100 * (predicted_idx == labels).float().mean()
|
||||
stdq = torch.std(q, dim=0).mean().item()
|
||||
stdk = torch.std(k, dim=0).mean().item()
|
||||
iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz)
|
||||
iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz)
|
||||
iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz)
|
||||
|
||||
return loss, iter_stats
|
|
@@ -0,0 +1,162 @@
|
|||
"""
|
||||
adapted from chemdataextractor.text.normalize
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
Tools for normalizing text.
|
||||
https://github.com/mcs07/ChemDataExtractor
|
||||
:copyright: Copyright 2016 by Matt Swain.
|
||||
:license: MIT
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining
|
||||
a copy of this software and associated documentation files (the
|
||||
'Software'), to deal in the Software without restriction, including
|
||||
without limitation the rights to use, copy, modify, merge, publish,
|
||||
distribute, sublicense, and/or sell copies of the Software, and to
|
||||
permit persons to whom the Software is furnished to do so, subject to
|
||||
the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be
|
||||
included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
#: Control characters.
|
||||
CONTROLS = {
|
||||
'\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u000e', '\u000f', '\u0011',
|
||||
'\u0012', '\u0013', '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b',
|
||||
}
|
||||
# There are further control characters, but they are instead replaced with a space by unicode normalization
|
||||
# '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c', '\u001d', '\u001e', '\u001f'
|
||||
|
||||
|
||||
#: Hyphen and dash characters.
|
||||
HYPHENS = {
|
||||
'-', # \u002d Hyphen-minus
|
||||
'‐', # \u2010 Hyphen
|
||||
'‑', # \u2011 Non-breaking hyphen
|
||||
'⁃', # \u2043 Hyphen bullet
|
||||
'‒', # \u2012 figure dash
|
||||
'–', # \u2013 en dash
|
||||
'—', # \u2014 em dash
|
||||
'―', # \u2015 horizontal bar
|
||||
}
|
||||
|
||||
#: Minus characters.
|
||||
MINUSES = {
|
||||
'-', # \u002d Hyphen-minus
|
||||
'−', # \u2212 Minus
|
||||
'－', # \uff0d Full-width Hyphen-minus
|
||||
'⁻', # \u207b Superscript minus
|
||||
}
|
||||
|
||||
#: Plus characters.
|
||||
PLUSES = {
|
||||
'+', # \u002b Plus
|
||||
'＋', # \uff0b Full-width Plus
|
||||
'⁺', # \u207a Superscript plus
|
||||
}
|
||||
|
||||
#: Slash characters.
|
||||
SLASHES = {
|
||||
'/', # \u002f Solidus
|
||||
'⁄', # \u2044 Fraction slash
|
||||
'∕', # \u2215 Division slash
|
||||
}
|
||||
|
||||
#: Tilde characters.
|
||||
TILDES = {
|
||||
'~', # \u007e Tilde
|
||||
'˜', # \u02dc Small tilde
|
||||
'⁓', # \u2053 Swung dash
|
||||
'∼', # \u223c Tilde operator #in mbert vocab
|
||||
'∽', # \u223d Reversed tilde
|
||||
'∿', # \u223f Sine wave
|
||||
'〜', # \u301c Wave dash #in mbert vocab
|
||||
'～', # \uff5e Full-width tilde #in mbert vocab
|
||||
}
|
||||
|
||||
#: Apostrophe characters.
|
||||
APOSTROPHES = {
|
||||
"'", # \u0027
|
||||
'’', # \u2019
|
||||
'՚', # \u055a
|
||||
'Ꞌ', # \ua78b
|
||||
'ꞌ', # \ua78c
|
||||
'＇', # \uff07 Full-width apostrophe
|
||||
}
|
||||
|
||||
#: Single quote characters.
|
||||
SINGLE_QUOTES = {
|
||||
"'", # \u0027
|
||||
'‘', # \u2018
|
||||
'’', # \u2019
|
||||
'‚', # \u201a
|
||||
'‛', # \u201b
|
||||
|
||||
}
|
||||
|
||||
#: Double quote characters.
|
||||
DOUBLE_QUOTES = {
|
||||
'"', # \u0022
|
||||
'“', # \u201c
|
||||
'”', # \u201d
|
||||
'„', # \u201e
|
||||
'‟', # \u201f
|
||||
}
|
||||
|
||||
#: Accent characters.
|
||||
ACCENTS = {
|
||||
'`', # \u0060
|
||||
'´', # \u00b4
|
||||
}
|
||||
|
||||
#: Prime characters.
|
||||
PRIMES = {
|
||||
'′', # \u2032
|
||||
'″', # \u2033
|
||||
'‴', # \u2034
|
||||
'‵', # \u2035
|
||||
'‶', # \u2036
|
||||
'‷', # \u2037
|
||||
'⁗', # \u2057
|
||||
}
|
||||
|
||||
#: Quote characters, including apostrophes, single quotes, double quotes, accents and primes.
|
||||
QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES
|
||||
|
||||
def normalize(text):
|
||||
for control in CONTROLS:
|
||||
text = text.replace(control, '')
|
||||
text = text.replace('\u000b', ' ').replace('\u000c', ' ').replace(u'\u0085', ' ')
|
||||
|
||||
for hyphen in HYPHENS | MINUSES:
|
||||
text = text.replace(hyphen, '-')
|
||||
text = text.replace('\u00ad', '')
|
||||
|
||||
for double_quote in DOUBLE_QUOTES:
|
||||
text = text.replace(double_quote, '"') # \u0022
|
||||
for single_quote in (SINGLE_QUOTES | APOSTROPHES | ACCENTS):
|
||||
text = text.replace(single_quote, "'") # \u0027
|
||||
text = text.replace('′', "'") # \u2032 prime
|
||||
text = text.replace('‵', "'") # \u2035 reversed prime
|
||||
text = text.replace('″', "''") # \u2033 double prime
|
||||
text = text.replace('‶', "''") # \u2036 reversed double prime
|
||||
text = text.replace('‴', "'''") # \u2034 triple prime
|
||||
text = text.replace('‷', "'''") # \u2037 reversed triple prime
|
||||
text = text.replace('⁗', "''''") # \u2057 quadruple prime
|
||||
|
||||
text = text.replace('…', '...').replace(' . . . ', ' ... ') # \u2026
|
||||
|
||||
for slash in SLASHES:
|
||||
text = text.replace(slash, '/')
|
||||
|
||||
#for tilde in TILDES:
|
||||
# text = text.replace(tilde, '~')
|
||||
|
||||
return text
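

# A minimal usage sketch (assuming this module is run directly) showing the
# effect of normalize(): typographic punctuation is folded to ASCII and
# control characters are dropped before text is passed to the retriever.
if __name__ == "__main__":
    sample = "Dash\u2014test \u201cquoted\u201d \u2026 caf\u00e9"
    print(normalize(sample))  # -> Dash-test "quoted" ... café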
|
|
@ -0,0 +1,132 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
class Options:
|
||||
def __init__(self):
|
||||
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
self.initialize()
|
||||
|
||||
def initialize(self):
|
||||
# basic parameters
|
||||
self.parser.add_argument(
|
||||
"--output_dir", type=str, default="./checkpoint/my_experiments", help="models are saved here"
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--train_data",
|
||||
nargs="+",
|
||||
default=[],
|
||||
help="Data used for training, passed as a list of directories splitted into tensor files.",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--eval_data",
|
||||
nargs="+",
|
||||
default=[],
|
||||
help="Data used for evaluation during finetuning, this option is not used during contrastive pre-training.",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--eval_datasets", nargs="+", default=[], help="List of datasets used for evaluation, in BEIR format"
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--eval_datasets_dir", type=str, default="./", help="Directory where eval datasets are stored"
|
||||
)
|
||||
self.parser.add_argument("--model_path", type=str, default="none", help="path for retraining")
|
||||
self.parser.add_argument("--continue_training", action="store_true")
|
||||
self.parser.add_argument("--num_workers", type=int, default=5)
|
||||
|
||||
self.parser.add_argument("--chunk_length", type=int, default=256)
|
||||
self.parser.add_argument("--loading_mode", type=str, default="split")
|
||||
self.parser.add_argument("--lower_case", action="store_true", help="perform evaluation after lowercasing")
|
||||
self.parser.add_argument(
|
||||
"--sampling_coefficient",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="coefficient used for sampling between different datasets during training, \
|
||||
by default sampling is uniform over datasets",
|
||||
)
|
||||
self.parser.add_argument("--augmentation", type=str, default="none")
|
||||
self.parser.add_argument("--prob_augmentation", type=float, default=0.0)
|
||||
|
||||
self.parser.add_argument("--dropout", type=float, default=0.1)
|
||||
self.parser.add_argument("--rho", type=float, default=0.05)
|
||||
|
||||
self.parser.add_argument("--contrastive_mode", type=str, default="moco")
|
||||
self.parser.add_argument("--queue_size", type=int, default=65536)
|
||||
self.parser.add_argument("--temperature", type=float, default=1.0)
|
||||
self.parser.add_argument("--momentum", type=float, default=0.999)
|
||||
self.parser.add_argument("--moco_train_mode_encoder_k", action="store_true")
|
||||
self.parser.add_argument("--eval_normalize_text", action="store_true")
|
||||
self.parser.add_argument("--norm_query", action="store_true")
|
||||
self.parser.add_argument("--norm_doc", action="store_true")
|
||||
self.parser.add_argument("--projection_size", type=int, default=768)
|
||||
|
||||
self.parser.add_argument("--ratio_min", type=float, default=0.1)
|
||||
self.parser.add_argument("--ratio_max", type=float, default=0.5)
|
||||
self.parser.add_argument("--score_function", type=str, default="dot")
|
||||
self.parser.add_argument("--retriever_model_id", type=str, default="bert-base-uncased")
|
||||
self.parser.add_argument("--pooling", type=str, default="average")
|
||||
self.parser.add_argument("--random_init", action="store_true", help="init model with random weights")
|
||||
|
||||
# dataset parameters
|
||||
self.parser.add_argument("--per_gpu_batch_size", default=64, type=int, help="Batch size per GPU for training.")
|
||||
self.parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=256, type=int, help="Batch size per GPU for evaluation."
|
||||
)
|
||||
self.parser.add_argument("--total_steps", type=int, default=1000)
|
||||
self.parser.add_argument("--warmup_steps", type=int, default=-1)
|
||||
|
||||
self.parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
self.parser.add_argument("--main_port", type=int, default=10001, help="Master port (for multi-node SLURM jobs)")
|
||||
self.parser.add_argument("--seed", type=int, default=0, help="random seed for initialization")
|
||||
# training parameters
|
||||
self.parser.add_argument("--optim", type=str, default="adamw")
|
||||
self.parser.add_argument("--scheduler", type=str, default="linear")
|
||||
self.parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
|
||||
self.parser.add_argument(
|
||||
"--lr_min_ratio",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="minimum learning rate at the end of the optimization schedule as a ratio of the learning rate",
|
||||
)
|
||||
self.parser.add_argument("--weight_decay", type=float, default=0.01, help="learning rate")
|
||||
self.parser.add_argument("--beta1", type=float, default=0.9, help="beta1")
|
||||
self.parser.add_argument("--beta2", type=float, default=0.98, help="beta2")
|
||||
self.parser.add_argument("--eps", type=float, default=1e-6, help="eps")
|
||||
self.parser.add_argument(
|
||||
"--log_freq", type=int, default=100, help="log train stats every <log_freq> steps during training"
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--eval_freq", type=int, default=500, help="evaluate model every <eval_freq> steps during training"
|
||||
)
|
||||
self.parser.add_argument("--save_freq", type=int, default=50000)
|
||||
self.parser.add_argument("--maxload", type=int, default=None)
|
||||
self.parser.add_argument("--label_smoothing", type=float, default=0.0)
|
||||
|
||||
# finetuning options
|
||||
self.parser.add_argument("--negative_ctxs", type=int, default=1)
|
||||
self.parser.add_argument("--negative_hard_min_idx", type=int, default=0)
|
||||
self.parser.add_argument("--negative_hard_ratio", type=float, default=0.0)
|
||||
|
||||
def print_options(self, opt):
|
||||
message = ""
|
||||
for k, v in sorted(vars(opt).items()):
|
||||
comment = ""
|
||||
default = self.parser.get_default(k)
|
||||
if v != default:
|
||||
comment = f"\t[default: %s]" % str(default)
|
||||
message += f"{str(k):>40}: {str(v):<40}{comment}\n"
|
||||
print(message, flush=True)
|
||||
model_dir = os.path.join(opt.output_dir, "models")
|
||||
if not os.path.exists(model_dir):
|
||||
os.makedirs(os.path.join(opt.output_dir, "models"))
|
||||
file_name = os.path.join(opt.output_dir, "opt.txt")
|
||||
with open(file_name, "wt") as opt_file:
|
||||
opt_file.write(message)
|
||||
opt_file.write("\n")
|
||||
|
||||
def parse(self):
|
||||
opt, _ = self.parser.parse_known_args()
|
||||
# opt = self.parser.parse_args()
|
||||
return opt
|
|
@ -0,0 +1,114 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from logging import getLogger
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import socket
|
||||
import signal
|
||||
import subprocess
|
||||
|
||||
|
||||
logger = getLogger()
|
||||
|
||||
def sig_handler(signum, frame):
|
||||
logger.warning("Signal handler called with signal " + str(signum))
|
||||
prod_id = int(os.environ['SLURM_PROCID'])
|
||||
logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id))
|
||||
if prod_id == 0:
|
||||
logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID'])
|
||||
os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID'])
|
||||
else:
|
||||
logger.warning("Not the main process, no need to requeue.")
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def term_handler(signum, frame):
|
||||
logger.warning("Signal handler called with signal " + str(signum))
|
||||
logger.warning("Bypassing SIGTERM.")
|
||||
|
||||
|
||||
def init_signal_handler():
|
||||
"""
|
||||
Handle signals sent by SLURM for time limit / pre-emption.
|
||||
"""
|
||||
signal.signal(signal.SIGUSR1, sig_handler)
|
||||
signal.signal(signal.SIGTERM, term_handler)
|
||||
|
||||
|
||||
def init_distributed_mode(params):
|
||||
"""
|
||||
Handle single and multi-GPU / multi-node / SLURM jobs.
|
||||
Initialize the following variables:
|
||||
- local_rank
|
||||
- global_rank
|
||||
- world_size
|
||||
"""
|
||||
is_slurm_job = 'SLURM_JOB_ID' in os.environ and 'WORLD_SIZE' not in os.environ
|
||||
has_local_rank = hasattr(params, 'local_rank')
|
||||
|
||||
# SLURM job without torch.distributed.launch
|
||||
if is_slurm_job and has_local_rank:
|
||||
|
||||
assert params.local_rank == -1 # on the cluster, this is handled by SLURM
|
||||
|
||||
# local rank on the current node / global rank
|
||||
params.local_rank = int(os.environ['SLURM_LOCALID'])
|
||||
params.global_rank = int(os.environ['SLURM_PROCID'])
|
||||
params.world_size = int(os.environ['SLURM_NTASKS'])
|
||||
|
||||
# define master address and master port
|
||||
hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']])
|
||||
params.main_addr = hostnames.split()[0].decode('utf-8')
|
||||
assert 10001 <= params.main_port <= 20000 or params.world_size == 1
|
||||
|
||||
# set environment variables for 'env://'
|
||||
os.environ['MASTER_ADDR'] = params.main_addr
|
||||
os.environ['MASTER_PORT'] = str(params.main_port)
|
||||
os.environ['WORLD_SIZE'] = str(params.world_size)
|
||||
os.environ['RANK'] = str(params.global_rank)
|
||||
is_distributed = True
|
||||
|
||||
|
||||
# multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch
|
||||
elif has_local_rank and params.local_rank != -1:
|
||||
|
||||
assert params.main_port == -1
|
||||
|
||||
# read environment variables
|
||||
params.global_rank = int(os.environ['RANK'])
|
||||
params.world_size = int(os.environ['WORLD_SIZE'])
|
||||
|
||||
is_distributed = True
|
||||
|
||||
# local job (single GPU)
|
||||
else:
|
||||
params.local_rank = 0
|
||||
params.global_rank = 0
|
||||
params.world_size = 1
|
||||
is_distributed = False
|
||||
|
||||
# set GPU device
|
||||
torch.cuda.set_device(params.local_rank)
|
||||
|
||||
# initialize multi-GPU
|
||||
if is_distributed:
|
||||
|
||||
# http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization
|
||||
# 'env://' will read these environment variables:
|
||||
# MASTER_PORT - required; has to be a free port on machine with rank 0
|
||||
# MASTER_ADDR - required (except for rank 0); address of rank 0 node
|
||||
# WORLD_SIZE - required; can be set either here, or in a call to init function
|
||||
# RANK - required; can be set either here, or in a call to init function
|
||||
|
||||
#print("Initializing PyTorch distributed ...")
|
||||
torch.distributed.init_process_group(
|
||||
init_method='env://',
|
||||
backend='nccl',
|
||||
#world_size=params.world_size,
|
||||
#rank=params.global_rank,
|
||||
)
|
|
@ -0,0 +1,213 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
import math
|
||||
import torch
|
||||
import errno
|
||||
from typing import Union, Tuple, Dict
|
||||
from collections import defaultdict
|
||||
|
||||
from robowaiter.algos.retrieval.retrieval_lm.src import dist_utils
|
||||
|
||||
Number = Union[float, int]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def init_logger(args, stdout_only=False):
|
||||
if torch.distributed.is_initialized():
|
||||
torch.distributed.barrier()
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
handlers = [stdout_handler]
|
||||
if not stdout_only:
|
||||
file_handler = logging.FileHandler(filename=os.path.join(args.output_dir, "run.log"))
|
||||
handlers.append(file_handler)
|
||||
logging.basicConfig(
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if dist_utils.is_main() else logging.WARN,
|
||||
format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s",
|
||||
handlers=handlers,
|
||||
)
|
||||
return logger
|
||||
|
||||
|
||||
def symlink_force(target, link_name):
|
||||
try:
|
||||
os.symlink(target, link_name)
|
||||
except OSError as e:
|
||||
if e.errno == errno.EEXIST:
|
||||
os.remove(link_name)
|
||||
os.symlink(target, link_name)
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
def save(model, optimizer, scheduler, step, opt, dir_path, name):
|
||||
model_to_save = model.module if hasattr(model, "module") else model
|
||||
path = os.path.join(dir_path, "checkpoint")
|
||||
epoch_path = os.path.join(path, name) # "step-%s" % step)
|
||||
os.makedirs(epoch_path, exist_ok=True)
|
||||
cp = os.path.join(path, "latest")
|
||||
fp = os.path.join(epoch_path, "checkpoint.pth")
|
||||
checkpoint = {
|
||||
"step": step,
|
||||
"model": model_to_save.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"scheduler": scheduler.state_dict(),
|
||||
"opt": opt,
|
||||
}
|
||||
torch.save(checkpoint, fp)
|
||||
symlink_force(epoch_path, cp)
|
||||
if not name == "lastlog":
|
||||
logger.info(f"Saving model to {epoch_path}")
|
||||
|
||||
|
||||
def load(model_class, dir_path, opt, reset_params=False):
|
||||
epoch_path = os.path.realpath(dir_path)
|
||||
checkpoint_path = os.path.join(epoch_path, "checkpoint.pth")
|
||||
logger.info(f"loading checkpoint {checkpoint_path}")
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
opt_checkpoint = checkpoint["opt"]
|
||||
state_dict = checkpoint["model"]
|
||||
|
||||
model = model_class(opt_checkpoint)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
model = model.cuda()
|
||||
step = checkpoint["step"]
|
||||
if not reset_params:
|
||||
optimizer, scheduler = set_optim(opt_checkpoint, model)
|
||||
scheduler.load_state_dict(checkpoint["scheduler"])
|
||||
optimizer.load_state_dict(checkpoint["optimizer"])
|
||||
else:
|
||||
optimizer, scheduler = set_optim(opt, model)
|
||||
|
||||
return model, optimizer, scheduler, opt_checkpoint, step
|
||||
|
||||
|
||||
############ OPTIM
|
||||
|
||||
|
||||
class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR):
|
||||
def __init__(self, optimizer, warmup, total, ratio, last_epoch=-1):
|
||||
self.warmup = warmup
|
||||
self.total = total
|
||||
self.ratio = ratio
|
||||
super(WarmupLinearScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
|
||||
|
||||
def lr_lambda(self, step):
|
||||
if step < self.warmup:
|
||||
return (1 - self.ratio) * step / float(max(1, self.warmup))
|
||||
|
||||
return max(
|
||||
0.0,
|
||||
1.0 + (self.ratio - 1) * (step - self.warmup) / float(max(1.0, self.total - self.warmup)),
|
||||
)
|
||||
|
||||
|
||||
class CosineScheduler(torch.optim.lr_scheduler.LambdaLR):
|
||||
def __init__(self, optimizer, warmup, total, ratio=0.1, last_epoch=-1):
|
||||
self.warmup = warmup
|
||||
self.total = total
|
||||
self.ratio = ratio
|
||||
super(CosineScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
|
||||
|
||||
def lr_lambda(self, step):
|
||||
if step < self.warmup:
|
||||
return float(step) / self.warmup
|
||||
s = float(step - self.warmup) / (self.total - self.warmup)
|
||||
return self.ratio + (1.0 - self.ratio) * math.cos(0.5 * math.pi * s)
|
||||
|
||||
|
||||
def set_optim(opt, model):
|
||||
if opt.optim == "adamw":
|
||||
optimizer = torch.optim.AdamW(
|
||||
model.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), eps=opt.eps, weight_decay=opt.weight_decay
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("optimizer class not implemented")
|
||||
|
||||
scheduler_args = {
|
||||
"warmup": opt.warmup_steps,
|
||||
"total": opt.total_steps,
|
||||
"ratio": opt.lr_min_ratio,
|
||||
}
|
||||
if opt.scheduler == "linear":
|
||||
scheduler_class = WarmupLinearScheduler
|
||||
elif opt.scheduler == "cosine":
|
||||
scheduler_class = CosineScheduler
|
||||
else:
|
||||
raise ValueError
|
||||
scheduler = scheduler_class(optimizer, **scheduler_args)
|
||||
return optimizer, scheduler
|
||||
|
||||
|
||||
def get_parameters(net, verbose=False):
|
||||
num_params = 0
|
||||
for param in net.parameters():
|
||||
num_params += param.numel()
|
||||
message = "[Network] Total number of parameters : %.6f M" % (num_params / 1e6)
|
||||
return message
|
||||
|
||||
|
||||
class WeightedAvgStats:
|
||||
"""provides an average over a bunch of stats"""
|
||||
|
||||
def __init__(self):
|
||||
self.raw_stats: Dict[str, float] = defaultdict(float)
|
||||
self.total_weights: Dict[str, float] = defaultdict(float)
|
||||
|
||||
def update(self, vals: Dict[str, Tuple[Number, Number]]) -> None:
|
||||
for key, (value, weight) in vals.items():
|
||||
self.raw_stats[key] += value * weight
|
||||
self.total_weights[key] += weight
|
||||
|
||||
@property
|
||||
def stats(self) -> Dict[str, float]:
|
||||
return {x: self.raw_stats[x] / self.total_weights[x] for x in self.raw_stats.keys()}
|
||||
|
||||
@property
|
||||
def tuple_stats(self) -> Dict[str, Tuple[float, float]]:
|
||||
return {x: (self.raw_stats[x] / self.total_weights[x], self.total_weights[x]) for x in self.raw_stats.keys()}
|
||||
|
||||
def reset(self) -> None:
|
||||
self.raw_stats = defaultdict(float)
|
||||
self.total_weights = defaultdict(float)
|
||||
|
||||
@property
|
||||
def average_stats(self) -> Dict[str, float]:
|
||||
keys = sorted(self.raw_stats.keys())
|
||||
if torch.distributed.is_initialized():
|
||||
torch.distributed.broadcast_object_list(keys, src=0)
|
||||
global_dict = {}
|
||||
for k in keys:
|
||||
if k not in self.total_weights:
|
||||
v = 0.0
|
||||
else:
|
||||
v = self.raw_stats[k] / self.total_weights[k]
|
||||
v, _ = dist_utils.weighted_average(v, self.total_weights[k])
|
||||
global_dict[k] = v
|
||||
return global_dict
|
||||
|
||||
|
||||
def load_hf(object_class, model_name):
|
||||
try:
|
||||
obj = object_class.from_pretrained(model_name, local_files_only=True)
|
||||
except:
|
||||
obj = object_class.from_pretrained(model_name, local_files_only=False)
|
||||
return obj
|
||||
|
||||
|
||||
def init_tb_logger(output_dir):
|
||||
try:
|
||||
from torch.utils import tensorboard
|
||||
|
||||
if dist_utils.is_main():
|
||||
tb_logger = tensorboard.SummaryWriter(output_dir)
|
||||
else:
|
||||
tb_logger = None
|
||||
except:
|
||||
logger.warning("Tensorboard is not available.")
|
||||
tb_logger = None
|
||||
|
||||
return tb_logger
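

# A small sketch (assuming this module is run directly) of WeightedAvgStats,
# which keeps per-key weighted running averages for logging training stats.
if __name__ == "__main__":
    stats = WeightedAvgStats()
    stats.update({"loss": (2.0, 8), "acc": (0.50, 8)})  # key -> (value, weight)
    stats.update({"loss": (1.0, 8), "acc": (0.75, 8)})
    print(stats.stats)  # {'loss': 1.5, 'acc': 0.625}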
|
|
@ -0,0 +1,23 @@
|
|||
{
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e9,
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"steps_per_print": 1e5,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
|
@ -0,0 +1,194 @@
|
|||
import jsonlines
|
||||
import json
|
||||
import copy
|
||||
import re
|
||||
|
||||
PROMPT_DICT = {
|
||||
"prompt_input": (
|
||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
),
|
||||
"prompt_no_input": (
|
||||
"### Instruction:\n{instruction}\n\n### Response:\n"
|
||||
),
|
||||
}
|
||||
|
||||
TASK_INST = {"wow": "Given a chat history separated by new lines, generates an informative, knowledgeable and engaging response. ",
|
||||
"fever": "Is the following statement correct or not? Say true if it's correct; otherwise say false.",
|
||||
"eli5": "Provide a paragraph-length response using simple words to answer the following question.",
|
||||
"obqa": "Given four answer candidates, A, B, C and D, choose the best answer choice.",
|
||||
"arc_easy": "Given four answer candidates, A, B, C and D, choose the best answer choice.",
|
||||
"arc_c": "Given four answer candidates, A, B, C and D, choose the best answer choice.",
|
||||
"trex": "Given the input format 'Subject Entity [SEP] Relationship Type,' predict the target entity.",
|
||||
"asqa": "Answer the following question. The question may be ambiguous and have multiple correct answers, and in that case, you have to provide a long-form answer including all correct answers."}
|
||||
|
||||
rel_tokens_names = ["[Irrelevant]", "[Relevant]"]
|
||||
retrieval_tokens_names = ["[No Retrieval]",
|
||||
"[Retrieval]", "[Continue to Use Evidence]"]
|
||||
utility_tokens_names = ["[Utility:1]", "[Utility:2]",
|
||||
"[Utility:3]", "[Utility:4]", "[Utility:5]"]
|
||||
ground_tokens_names = ["[Fully supported]",
|
||||
"[Partially supported]", "[No support / Contradictory]"]
|
||||
other_special_tokens = ["<s>", "</s>", "[PAD]",
|
||||
"<unk>", "<paragraph>", "</paragraph>"]
|
||||
control_tokens = ["[Fully supported]", "[Partially supported]", "[No support / Contradictory]", "[No Retrieval]", "[Retrieval]",
|
||||
"[Irrelevant]", "[Relevant]", "<paragraph>", "</paragraph>", "[Utility:1]", "[Utility:2]", "[Utility:3]", "[Utility:4]", "[Utility:5]"]
|
||||
|
||||
|
||||
def load_special_tokens(tokenizer, use_grounding=False, use_utility=False):
|
||||
ret_tokens = {token: tokenizer.convert_tokens_to_ids(
|
||||
token) for token in retrieval_tokens_names}
|
||||
rel_tokens = {}
|
||||
for token in ["[Irrelevant]", "[Relevant]"]:
|
||||
rel_tokens[token] = tokenizer.convert_tokens_to_ids(token)
|
||||
|
||||
grd_tokens = None
|
||||
if use_grounding is True:
|
||||
grd_tokens = {}
|
||||
for token in ground_tokens_names:
|
||||
grd_tokens[token] = tokenizer.convert_tokens_to_ids(token)
|
||||
|
||||
ut_tokens = None
|
||||
if use_utility is True:
|
||||
ut_tokens = {}
|
||||
for token in utility_tokens_names:
|
||||
ut_tokens[token] = tokenizer.convert_tokens_to_ids(token)
|
||||
|
||||
return ret_tokens, rel_tokens, grd_tokens, ut_tokens
|
||||
|
||||
|
||||
def fix_spacing(input_text):
|
||||
# Add a space after periods that lack whitespace
|
||||
output_text = re.sub(r'(?<=\w)([.!?])(?=\w)', r'\1 ', input_text)
|
||||
return output_text
|
||||
|
||||
|
||||
def postprocess(pred):
|
||||
special_tokens = ["[Fully supported]", "[Partially supported]", "[No support / Contradictory]", "[No Retrieval]", "[Retrieval]",
|
||||
"[Irrelevant]", "[Relevant]", "<paragraph>", "</paragraph>", "[Utility:1]", "[Utility:2]", "[Utility:3]", "[Utility:4]", "[Utility:5]"]
|
||||
for item in special_tokens:
|
||||
pred = pred.replace(item, "")
|
||||
pred = pred.replace("</s>", "")
|
||||
|
||||
if len(pred) == 0:
|
||||
return ""
|
||||
if pred[0] == " ":
|
||||
pred = pred[1:]
|
||||
return pred
|
||||
|
||||
|
||||
def load_jsonlines(file):
|
||||
with jsonlines.open(file, 'r') as jsonl_f:
|
||||
lst = [obj for obj in jsonl_f]
|
||||
return lst
|
||||
|
||||
|
||||
def load_file(input_fp):
|
||||
if input_fp.endswith(".json"):
|
||||
input_data = json.load(open(input_fp))
|
||||
else:
|
||||
input_data = load_jsonlines(input_fp)
|
||||
return input_data
|
||||
|
||||
|
||||
def save_file_jsonl(data, fp):
|
||||
with jsonlines.open(fp, mode='w') as writer:
|
||||
writer.write_all(data)
|
||||
|
||||
|
||||
def preprocess_input(input_data, task):
|
||||
if task == "factscore":
|
||||
for item in input_data:
|
||||
item["instruction"] = item["input"]
|
||||
item["output"] = [item["output"]
|
||||
] if "output" in item else [item["topic"]]
|
||||
return input_data
|
||||
|
||||
elif task == "qa":
|
||||
for item in input_data:
|
||||
if "instruction" not in item:
|
||||
item["instruction"] = item["question"]
|
||||
if "answers" not in item and "output" in item:
|
||||
item["answers"] = "output"
|
||||
return input_data
|
||||
|
||||
elif task in ["asqa", "eli5"]:
|
||||
processed_input_data = []
|
||||
for instance_idx, item in enumerate(input_data["data"]):
|
||||
prompt = item["question"]
|
||||
instructions = TASK_INST[task]
|
||||
prompt = instructions + "## Input:\n\n" + prompt
|
||||
entry = copy.deepcopy(item)
|
||||
entry["instruction"] = prompt
|
||||
processed_input_data.append(entry)
|
||||
return processed_input_data
|
||||
|
||||
|
||||
def postprocess_output(input_instance, prediction, task, intermediate_results=None):
|
||||
if task == "factscore":
|
||||
return {"input": input_instance["input"], "output": prediction, "topic": input_instance["topic"], "cat": input_instance["cat"]}
|
||||
|
||||
elif task == "qa":
|
||||
input_instance["pred"] = prediction
|
||||
return input_instance
|
||||
|
||||
elif task in ["asqa", "eli5"]:
|
||||
# ALCE datasets require additional postprocessing to compute citation accuracy.
|
||||
final_output = ""
|
||||
docs = []
|
||||
if "splitted_sentences" not in intermediate_results:
|
||||
input_instance["output"] = postprocess(prediction)
|
||||
|
||||
else:
|
||||
for idx, (sent, doc) in enumerate(zip(intermediate_results["splitted_sentences"][0], intermediate_results["ctxs"][0])):
|
||||
if len(sent) == 0:
|
||||
continue
|
||||
postprocessed_result = postprocess(sent)
|
||||
final_output += postprocessed_result[:-
|
||||
1] + " [{}]".format(idx) + ". "
|
||||
docs.append(doc)
|
||||
if final_output[-1] == " ":
|
||||
final_output = final_output[:-1]
|
||||
input_instance["output"] = final_output
|
||||
input_instance["docs"] = docs
|
||||
return input_instance
|
||||
|
||||
def process_arc_instruction(item, instruction):
|
||||
choices = item["choices"]
|
||||
answer_labels = {}
|
||||
for i in range(len(choices["label"])):
|
||||
answer_key = choices["label"][i]
|
||||
text = choices["text"][i]
|
||||
if answer_key == "1":
|
||||
answer_labels["A"] = text
|
||||
if answer_key == "2":
|
||||
answer_labels["B"] = text
|
||||
if answer_key == "3":
|
||||
answer_labels["C"] = text
|
||||
if answer_key == "4":
|
||||
answer_labels["D"] = text
|
||||
if answer_key in ["A", "B", "C", "D"]:
|
||||
answer_labels[answer_key] = text
|
||||
|
||||
if "D" not in answer_labels:
|
||||
answer_labels["D"] = ""
|
||||
choices = "\nA: {0}\nB: {1}\nC: {2}\nD: {3}".format(answer_labels["A"], answer_labels["B"], answer_labels["C"], answer_labels["D"])
|
||||
if "E" in answer_labels:
|
||||
choices += "\nE: {}".format(answer_labels["E"])
|
||||
processed_instruction = instruction + "\n\n### Input:\n" + item["instruction"] + choices
|
||||
return processed_instruction
|
||||
|
||||
|
||||
def postprocess_answers_closed(output, task, choices=None):
|
||||
final_output = None
|
||||
if choices is not None:
|
||||
for c in choices.split(" "):
|
||||
if c in output:
|
||||
final_output = c
|
||||
if task == "fever" and output in ["REFUTES", "SUPPORTS"]:
|
||||
final_output = "true" if output == "SUPPORTS" else "REFUTES"
|
||||
if task == "fever" and output.lower() in ["true", "false"]:
|
||||
final_output = output.lower()
|
||||
if final_output is None:
|
||||
return output
|
||||
else:
|
||||
return final_output
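

# A hedged sketch of how the prompt templates above are typically filled in
# before generation; the exact call sites live in the repo's run scripts.
if __name__ == "__main__":
    example = {
        "instruction": TASK_INST["arc_c"],
        "input": "Which gas do plants absorb?\nA: O2\nB: CO2\nC: N2\nD: He",
    }
    print(PROMPT_DICT["prompt_input"].format_map(example))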
|
Binary file not shown.
|
@ -0,0 +1,10 @@
|
|||
python passage_retrieval3.py \
|
||||
--model_name_or_path ../model/contriever-msmarco \
|
||||
--passages train_robot.jsonl \
|
||||
--passages_embeddings "robot_embeddings/*" \
|
||||
--data test_robot.jsonl \
|
||||
--output_dir robot_result \
|
||||
--n_docs 2
|
||||
|
||||
|
||||
#python passage_retrieval3.py --model_name_or_path contriever-msmarco --passages train_robot.jsonl --passages_embeddings "robot_embeddings/*" --data test_robot.jsonl --output_dir robot_result --n_docs 2
|
|
@ -195,6 +195,7 @@ get_object_info
|
|||
我带着孩子呢,想要宽敞亮堂的地方。
|
||||
好的,我明白了,那么我们推荐您到大厅的桌子,那里的空间比较宽敞,环境也比较明亮,适合带着孩子一起用餐。
|
||||
|
||||
|
||||
冰红茶
|
||||
好的
|
||||
create_sub_task
|
||||
|
|
|
@ -0,0 +1,312 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
import pickle
|
||||
import time
|
||||
import glob
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from robowaiter.algos.retrieval.retrieval_lm.src.slurm import init_distributed_mode
|
||||
from robowaiter.algos.retrieval.retrieval_lm.src.normalize_text import normalize
|
||||
from robowaiter.algos.retrieval.retrieval_lm.src.contriever import load_retriever
|
||||
from robowaiter.algos.retrieval.retrieval_lm.src.index import Indexer
|
||||
from robowaiter.algos.retrieval.retrieval_lm.src.data import load_passages
|
||||
|
||||
from robowaiter.algos.retrieval.retrieval_lm.src.evaluation import calculate_matches
|
||||
import warnings
|
||||
from robowaiter.utils.basic import get_root_path
|
||||
root_path = get_root_path()
|
||||
warnings.filterwarnings('ignore')
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
|
||||
def embed_queries(args, queries, model, tokenizer):
|
||||
model.eval()
|
||||
embeddings, batch_question = [], []
|
||||
with torch.no_grad():
|
||||
|
||||
for k, q in enumerate(queries):
|
||||
if args.lowercase:
|
||||
q = q.lower()
|
||||
if args.normalize_text:
|
||||
q = normalize(q)
|
||||
batch_question.append(q)
|
||||
|
||||
if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1:
|
||||
|
||||
encoded_batch = tokenizer.batch_encode_plus(
|
||||
batch_question,
|
||||
return_tensors="pt",
|
||||
max_length=args.question_maxlength,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
)
|
||||
encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
|
||||
output = model(**encoded_batch)
|
||||
embeddings.append(output.cpu())
|
||||
|
||||
batch_question = []
|
||||
|
||||
embeddings = torch.cat(embeddings, dim=0)
|
||||
#print(f"Questions embeddings shape: {embeddings.size()}")
|
||||
|
||||
return embeddings.numpy()
|
||||
|
||||
|
||||
def index_encoded_data(index, embedding_files, indexing_batch_size):
|
||||
allids = []
|
||||
allembeddings = np.array([])
|
||||
for i, file_path in enumerate(embedding_files):
|
||||
#print(f"Loading file {file_path}")
|
||||
with open(file_path, "rb") as fin:
|
||||
ids, embeddings = pickle.load(fin)
|
||||
|
||||
allembeddings = np.vstack((allembeddings, embeddings)) if allembeddings.size else embeddings
|
||||
allids.extend(ids)
|
||||
while allembeddings.shape[0] > indexing_batch_size:
|
||||
allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
|
||||
|
||||
while allembeddings.shape[0] > 0:
|
||||
allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
|
||||
|
||||
#print("Data indexing completed.")
|
||||
|
||||
|
||||
def add_embeddings(index, embeddings, ids, indexing_batch_size):
|
||||
end_idx = min(indexing_batch_size, embeddings.shape[0])
|
||||
ids_toadd = ids[:end_idx]
|
||||
embeddings_toadd = embeddings[:end_idx]
|
||||
ids = ids[end_idx:]
|
||||
embeddings = embeddings[end_idx:]
|
||||
index.index_data(ids_toadd, embeddings_toadd)
|
||||
return embeddings, ids
|
||||
|
||||
|
||||
def validate(data, workers_num):
|
||||
match_stats = calculate_matches(data, workers_num)
|
||||
top_k_hits = match_stats.top_k_hits
|
||||
|
||||
# print("Validation results: top k documents hits %s", top_k_hits)
|
||||
top_k_hits = [v / len(data) for v in top_k_hits]
|
||||
message = ""
|
||||
for k in [5, 10, 20, 100]:
|
||||
if k <= len(top_k_hits):
|
||||
message += f"R@{k}: {top_k_hits[k-1]} "
|
||||
#print(message)
|
||||
return match_stats.questions_doc_hits
|
||||
|
||||
|
||||
def add_passages(data, passages, top_passages_and_scores):
|
||||
# add passages to original data
|
||||
merged_data = []
|
||||
assert len(data) == len(top_passages_and_scores)
|
||||
for i, d in enumerate(data):
|
||||
results_and_scores = top_passages_and_scores[i]
|
||||
#print(passages[2393])
|
||||
docs = [passages[int(doc_id)] for doc_id in results_and_scores[0]]
|
||||
scores = [str(score) for score in results_and_scores[1]]
|
||||
ctxs_num = len(docs)
|
||||
d["ctxs"] = [
|
||||
{
|
||||
"id": results_and_scores[0][c],
|
||||
"title": docs[c]["title"],
|
||||
"text": docs[c]["text"],
|
||||
"score": scores[c],
|
||||
}
|
||||
for c in range(ctxs_num)
|
||||
]
|
||||
|
||||
|
||||
def add_hasanswer(data, hasanswer):
|
||||
# add hasanswer to data
|
||||
for i, ex in enumerate(data):
|
||||
for k, d in enumerate(ex["ctxs"]):
|
||||
d["hasanswer"] = hasanswer[i][k]
|
||||
|
||||
|
||||
# def load_data(data_path):
|
||||
# if data_path.endswith(".json"):
|
||||
# with open(data_path, "r",encoding='utf-8') as fin:
|
||||
# data = json.load(fin)
|
||||
# elif data_path.endswith(".jsonl"):
|
||||
# data = []
|
||||
# with open(data_path, "r",encoding='utf-8') as fin:
|
||||
# for k, example in enumerate(fin):
|
||||
# example = json.loads(example)
|
||||
# data.append(example)
|
||||
# print("data:",data)
|
||||
# return data
|
||||
def load_data(data_path):
|
||||
if data_path.endswith(".json"):
|
||||
with open(data_path, "r",encoding='utf-8') as fin:
|
||||
data = json.load(fin)
|
||||
elif data_path.endswith(".jsonl"):
|
||||
data = []
|
||||
with open(data_path, "r",encoding='utf-8') as fin:
|
||||
for k, example in enumerate(fin):
|
||||
example = json.loads(example)
|
||||
#print("example:",example)
|
||||
data.append(example)
|
||||
return data
|
||||
|
||||
|
||||
def test(args):  # args.data points to the query file
|
||||
# args = {"model_name_or_path":"contriever-msmarco","passages":"train_robot.jsonl"\
|
||||
# passages_embeddings = "robot_embeddings/*"
|
||||
# data = "test_robot.jsonl"
|
||||
# output_dir = "robot_result"
|
||||
# n_docs = 1
|
||||
|
||||
#print(f"Loading model from: {args.model_name_or_path}")
|
||||
model, tokenizer, _ = load_retriever(args.model_name_or_path)
|
||||
model.eval()
|
||||
model = model.cuda()
|
||||
if not args.no_fp16:
|
||||
model = model.half()
|
||||
|
||||
index = Indexer(args.projection_size, args.n_subquantizers, args.n_bits)
|
||||
|
||||
# index all passages
|
||||
input_paths = glob.glob(args.passages_embeddings)
|
||||
input_paths = sorted(input_paths)
|
||||
embeddings_dir = os.path.dirname(input_paths[0])
|
||||
index_path = os.path.join(embeddings_dir, "index.faiss")
|
||||
if args.save_or_load_index and os.path.exists(index_path):
|
||||
index.deserialize_from(embeddings_dir)
|
||||
else:
|
||||
#print(f"Indexing passages from files {input_paths}")
|
||||
start_time_indexing = time.time()
|
||||
index_encoded_data(index, input_paths, args.indexing_batch_size)
|
||||
#print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.")
|
||||
if args.save_or_load_index:
|
||||
index.serialize(embeddings_dir)
|
||||
|
||||
# load passages
|
||||
passages = load_passages(args.passages)
|
||||
passage_id_map = {x["id"]: x for x in passages}
|
||||
|
||||
data_paths = glob.glob(args.data)
|
||||
alldata = []
|
||||
for path in data_paths:
|
||||
data = load_data(path)
|
||||
#print("data:",data)
|
||||
output_path = os.path.join(args.output_dir, os.path.basename(path))
|
||||
|
||||
queries = [ex["question"] for ex in data]
|
||||
questions_embedding = embed_queries(args, queries, model, tokenizer)
|
||||
|
||||
# get top k results
|
||||
start_time_retrieval = time.time()
|
||||
top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs)
|
||||
#print(f"Search time: {time.time()-start_time_retrieval:.1f} s.")
|
||||
|
||||
add_passages(data, passage_id_map, top_ids_and_scores)
|
||||
#hasanswer = validate(data, args.validation_workers)
|
||||
#add_hasanswer(data, hasanswer)
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
ret_list = []
|
||||
with open(output_path, "w",encoding='utf-8') as fout:
|
||||
for ex in data:
|
||||
json.dump(ex, fout, ensure_ascii=False)
|
||||
ret_list.append(ex)
|
||||
fout.write("\n")
|
||||
return ret_list
|
||||
#print(f"Saved results to {output_path}")
|
||||
|
||||
# Write the query into test_robot.jsonl
|
||||
def get_json(query):
|
||||
dic = {"id": 1, "question": query}
|
||||
with open('test_robot.jsonl', "w", encoding='utf-8') as fout:
|
||||
json.dump(dic, fout, ensure_ascii=False)
|
||||
|
||||
def get_answer():
|
||||
with open('robot_result\\test_robot.jsonl', "r", encoding='utf-8') as fin:  # open for reading, not writing
|
||||
for k, example in enumerate(fin):
|
||||
example = json.loads(example)
|
||||
answer = example["ctxs"][0]["text"]
|
||||
score = example["ctxs"][0]["score"]
|
||||
return score, answer
|
||||
def retri(query):
|
||||
get_json(query)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--data",
|
||||
#required=True,
|
||||
type=str,
|
||||
default='test_robot.jsonl',
|
||||
help=".json file containing question and answers, similar format to reader data",
|
||||
)
|
||||
# parser.add_argument("--passages", type=str, default='C:/Users/huangyu/Desktop/RoboWaiter-main/RoboWaiter-main/train_robot.jsonl', help="Path to passages (.tsv file)")
|
||||
# parser.add_argument("--passages_embeddings", type=str, default='C:/Users/huangyu/Desktop/RoboWaiter-main/RoboWaiter-main/robot_embeddings/*', help="Glob path to encoded passages")
|
||||
parser.add_argument("--passages", type=str, default=f'{root_path}/robowaiter/llm_client/train_robot.jsonl', help="Path to passages (.tsv file)")
|
||||
parser.add_argument("--passages_embeddings", type=str, default=f'{root_path}/robowaiter/algos/retrieval/robot_embeddings/*', help="Glob path to encoded passages")
|
||||
|
||||
parser.add_argument(
|
||||
"--output_dir", type=str, default='robot_result', help="Results are written to outputdir with data suffix"
|
||||
)
|
||||
parser.add_argument("--n_docs", type=int, default=5, help="Number of documents to retrieve per questions") #可以改这个参数,返回前n_docs个检索结果
|
||||
parser.add_argument(
|
||||
"--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
|
||||
)
|
||||
parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding")
|
||||
parser.add_argument(
|
||||
"--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists"
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--model_name_or_path", type=str, default='C:\\Users\\huangyu\\Desktop\\RoboWaiter-main\\RoboWaiter-main\\contriever-msmarco',help="path to directory containing model weights and config file"
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--model_name_or_path", type=str, default=f'{root_path}/robowaiter/algos/retrieval/contriever-msmarco',help="path to directory containing model weights and config file"
|
||||
)
|
||||
parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
|
||||
parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
|
||||
parser.add_argument(
|
||||
"--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed"
|
||||
)
|
||||
parser.add_argument("--projection_size", type=int, default=768)
|
||||
parser.add_argument(
|
||||
"--n_subquantizers",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of subquantizer used for vector quantization, if 0 flat index is used",
|
||||
)
|
||||
parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer")
|
||||
parser.add_argument("--lang", nargs="+")
|
||||
parser.add_argument("--dataset", type=str, default="none")
|
||||
parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
|
||||
parser.add_argument("--normalize_text", action="store_true", help="normalize text")
|
||||
|
||||
args = parser.parse_args()
|
||||
init_distributed_mode(args)
|
||||
#print(args)
|
||||
ret = test(args)
|
||||
#print(ret)
|
||||
|
||||
return ret[0]
|
||||
|
||||
# example = ret[0]
|
||||
# answer = example["ctxs"][0]["text"]
|
||||
# score = example["ctxs"][0]["score"]
|
||||
# return score, answer
|
||||
|
||||
if __name__ == "__main__":
|
||||
# query = "请你拿一下软饮料到第三张桌子位置。"
|
||||
# score,answer = retri(query)
|
||||
# print(score,answer)
|
||||
|
||||
query = "你能把空调打开一下吗?"
|
||||
all_ret = retri(query)
|
||||
for i,example in enumerate(all_ret["ctxs"]):
|
||||
answer = example["text"]
|
||||
score = example["score"]
|
||||
id = example["id"]
|
||||
print(i,answer,score," id=",id)
|
||||
|
||||
|
|
@ -0,0 +1 @@
|
|||
{"id": 1, "question": "你能把空调打开一下吗?", "ctxs": [{"id": "505", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8567487"}, {"id": "313", "title": "你能把空调打开一下吗?", "text": "Is(AC,1)", "score": "1.8567487"}, {"id": "312", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8567487"}, {"id": "120", "title": "你能把空调打开一下吗?", "text": "Is(AC,1)", "score": "1.8567487"}, {"id": "119", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8567487"}]}
|
|
@ -0,0 +1,43 @@
|
|||
|
||||
import requests
|
||||
import urllib3
|
||||
########################################
|
||||
# This file implements simple communication with the large language model
|
||||
########################################
|
||||
|
||||
# Ignore HTTPS certificate warnings
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
|
||||
def single_round(question,prefix=""):
|
||||
url = "https://45.125.46.134:25344/v1/chat/completions"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {
|
||||
"model": "RoboWaiter",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个机器人服务员:RoboWaiter. 你的职责是为顾客提供对话及具身服务。"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prefix + question
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, verify=False)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return result['choices'][0]['message']['content'].strip()
|
||||
else:
|
||||
return "大模型请求失败:", response.status_code
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
question = '''
|
||||
给我一杯拿铁
|
||||
'''
|
||||
|
||||
print(single_round(question))
|
|
@ -0,0 +1 @@
|
|||
{"id": 1, "question": "你能把空调打开一下吗?"}
|
File diff suppressed because it is too large