Merge branch 'main' of https://github.com/HPCL-EI/RoboWaiter
This commit is contained in:
commit 089d987c28
@@ -19,6 +19,7 @@ share/python-wheels/
MANIFEST
MO-VLN/
GLIP/
pytorch_model.bin
sub_task.ptml
@@ -15,6 +15,8 @@ pip install -e .
### Install the UI
1. Install [graphviz-9.0.0](https://gitlab.com/api/v4/projects/4207231/packages/generic/graphviz-releases/9.0.0/windows_10_cmake_Release_graphviz-install-9.0.0-win64.exe) (see the [official site](https://www.graphviz.org/download/#windows) for details)
2. Add the bin directory under the Graphviz installation folder to the system environment. For example, on Windows with Graphviz installed at D:\Program Files (x86)\Graphviz2.38, that folder contains a bin directory; add its path, i.e. D:\Program Files (x86)\Graphviz2.38\bin, to the system PATH environment variable.
3. Install the vector database

   conda install -c conda-forge faiss

### Quick Start
1. Install UE and the Harix plugin, open the default project, and run it
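To confirm that the vector database from step 3 above installed correctly, a minimal sketch (the vector size and random data are illustrative only, not values RoboWaiter uses):

```python
# Quick check that the conda-forge faiss build imports and searches correctly.
import faiss
import numpy as np

dim = 768                                        # e.g. the Contriever hidden size
xb = np.random.rand(100, dim).astype("float32")  # 100 dummy passage vectors
xq = xb[:3]                                      # reuse a few of them as queries

index = faiss.IndexFlatIP(dim)                   # exact inner-product index
index.add(xb)
scores, ids = index.search(xq, 5)                # top-5 neighbours per query
print(ids[0])                                    # the first hit should be the query itself
```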
@@ -0,0 +1,25 @@
{
  "architectures": [
    "Contriever"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.15.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}
@@ -0,0 +1 @@
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}

File diff suppressed because one or more lines are too long
@@ -0,0 +1 @@
{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "bert-base-uncased", "tokenizer_class": "BertTokenizer"}
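These config and tokenizer files describe the Contriever (BERT-based) passage encoder. A minimal loading sketch with transformers: the local directory name is an assumption (the repo's scripts point at model/contriever-msmarco), and mean pooling over token embeddings is the usual Contriever recipe rather than something specified in this commit:

```python
# Hypothetical local path; adjust to wherever these config/tokenizer files live.
from transformers import AutoTokenizer, AutoModel
import torch

path = "model/contriever-msmarco"   # assumed directory holding config.json, vocab, etc.
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModel.from_pretrained(path)     # model_type "bert" -> loads a BertModel

def embed(texts):
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        out = model(**batch).last_hidden_state       # [batch, seq, 768]
    mask = batch["attention_mask"].unsqueeze(-1)     # ignore padding tokens
    return (out * mask).sum(1) / mask.sum(1)         # mean-pooled passage embeddings

print(embed(["put the coffee on the table"]).shape)  # torch.Size([1, 768])
```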
File diff suppressed because it is too large
@@ -0,0 +1,4 @@
pip install gdown
gdown 1IYNAkwawfCDiBL27BlBqGssxFQH9vOux
unzip enwiki_2020_intro_only.zip
rm enwiki_2020_intro_only.zip
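The same download can be scripted from Python; a sketch assuming the Google Drive file id above, a recent gdown (which accepts a file id directly), and enough disk space for the archive:

```python
# Python equivalent of the shell snippet above.
import gdown
import zipfile
import os

archive = gdown.download(id="1IYNAkwawfCDiBL27BlBqGssxFQH9vOux",
                         output="enwiki_2020_intro_only.zip")
with zipfile.ZipFile(archive) as zf:
    zf.extractall(".")        # unpack next to the script
os.remove(archive)            # mirror the `rm` step
```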
@ -0,0 +1,731 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import datasets
|
||||||
|
import torch
|
||||||
|
import copy
|
||||||
|
from functools import partial
|
||||||
|
from accelerate import Accelerator
|
||||||
|
from accelerate.logging import get_logger
|
||||||
|
from accelerate.utils import set_seed
|
||||||
|
from datasets import load_dataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from typing import Optional, Dict, Sequence
|
||||||
|
import json
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
LlamaTokenizer,
|
||||||
|
LlamaTokenizerFast,
|
||||||
|
SchedulerType,
|
||||||
|
DataCollatorForSeq2Seq,
|
||||||
|
get_scheduler,
|
||||||
|
GPTNeoXTokenizerFast,
|
||||||
|
GPT2Tokenizer,
|
||||||
|
OPTForCausalLM
|
||||||
|
)
|
||||||
|
from peft import LoraConfig, TaskType, get_peft_model
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
PROMPT_DICT = {
|
||||||
|
"prompt_input": (
|
||||||
|
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||||
|
),
|
||||||
|
"prompt_no_input": (
|
||||||
|
"### Instruction:\n{instruction}\n\n### Response:\n"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset_name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The name of the dataset to use (via the datasets library).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset_config_name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The configuration name of the dataset to use (via the datasets library).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--train_file", type=str, default=None, help="A csv or a json file containing the training data."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name_or_path",
|
||||||
|
type=str,
|
||||||
|
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config_name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Pretrained config name or path if not the same as model_name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_lora",
|
||||||
|
action="store_true",
|
||||||
|
help="If passed, will use LORA (low-rank parameter-efficient training) to train the model.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora_rank",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="The rank of lora.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora_alpha",
|
||||||
|
type=float,
|
||||||
|
default=16,
|
||||||
|
help="The alpha parameter of lora.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora_dropout",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="The dropout rate of lora modules.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_merged_lora_model",
|
||||||
|
action="store_true",
|
||||||
|
help="If passed, will merge the lora modules and save the entire model.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_flash_attn",
|
||||||
|
action="store_true",
|
||||||
|
help="If passed, will use flash attention to train the model.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokenizer_name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_slow_tokenizer",
|
||||||
|
action="store_true",
|
||||||
|
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_seq_length",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="The maximum total sequence length (prompt+completion) of each training example.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--per_device_train_batch_size",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Batch size (per device) for the training dataloader.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--learning_rate",
|
||||||
|
type=float,
|
||||||
|
default=5e-5,
|
||||||
|
help="Initial learning rate (after the potential warmup period) to use.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
||||||
|
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_train_steps",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gradient_accumulation_steps",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lr_scheduler_type",
|
||||||
|
type=SchedulerType,
|
||||||
|
default="linear",
|
||||||
|
help="The scheduler type to use.",
|
||||||
|
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--warmup_ratio", type=float, default=0, help="Ratio of total training steps used for warmup."
|
||||||
|
)
|
||||||
|
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
||||||
|
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--preprocessing_num_workers",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="The number of processes to use for the preprocessing.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpointing_steps",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--logging_steps",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Log the training loss and learning rate every logging_steps steps.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--resume_from_checkpoint",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="If the training should continue from a checkpoint folder.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--with_tracking",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to enable experiment trackers for logging.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--report_to",
|
||||||
|
type=str,
|
||||||
|
default="all",
|
||||||
|
help=(
|
||||||
|
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
|
||||||
|
' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.'
|
||||||
|
"Only applicable when `--with_tracking` is passed."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--low_cpu_mem_usage",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded."
|
||||||
|
"If passed, LLM loading time and RAM consumption will be benefited."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_special_tokens",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"Use special tokens."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Sanity checks
|
||||||
|
if args.dataset_name is None and args.train_file is None:
|
||||||
|
raise ValueError("Need either a dataset name or a training file.")
|
||||||
|
else:
|
||||||
|
if args.train_file is not None:
|
||||||
|
extension = args.train_file.split(".")[-1]
|
||||||
|
assert extension in ["json", "jsonl"], "`train_file` should be a json/jsonl file."
|
||||||
|
return args
|
||||||
|
|
||||||
|
def _tokenize_fn(text: str, tokenizer: transformers.PreTrainedTokenizer, max_seq_length: int) -> Dict:
|
||||||
|
"""Tokenize a list of strings."""
|
||||||
|
input_ids = labels = tokenizer(
|
||||||
|
text,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding="longest",
|
||||||
|
max_length=max_seq_length,
|
||||||
|
truncation=True,
|
||||||
|
).input_ids
|
||||||
|
input_ids_lens = labels_lens = input_ids.ne(tokenizer.pad_token_id).sum().item()
|
||||||
|
print(input_ids_lens)
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
input_ids=input_ids,
|
||||||
|
labels=labels,
|
||||||
|
input_ids_lens=input_ids_lens,
|
||||||
|
labels_lens=labels_lens,
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode_with_prompt_completion_format(example, tokenizer, max_seq_length, context_markups=None):
|
||||||
|
'''
|
||||||
|
Here we assume each example has 'prompt' and 'completion' fields.
|
||||||
|
We concatenate prompt and completion and tokenize them together because otherwise the prompt will be padded/truncated
|
||||||
|
and it doesn't make sense to follow directly with the completion.
|
||||||
|
'''
|
||||||
|
# if prompt doesn't end with space and completion doesn't start with space, add space
|
||||||
|
|
||||||
|
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
|
||||||
|
source_text = prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
|
||||||
|
target_text = example['output'] + tokenizer.eos_token
|
||||||
|
examples_tokenized = _tokenize_fn(source_text + target_text, tokenizer, max_seq_length)
|
||||||
|
sources_tokenized = _tokenize_fn(source_text, tokenizer, max_seq_length)
|
||||||
|
|
||||||
|
input_ids = examples_tokenized["input_ids"].flatten()
|
||||||
|
source_len = sources_tokenized["input_ids_lens"]
|
||||||
|
labels = copy.deepcopy(input_ids)
|
||||||
|
labels[ :source_len-1] = -100
|
||||||
|
|
||||||
|
if context_markups is not None:
|
||||||
|
context_start = False
|
||||||
|
for j, orig_token in enumerate(labels[source_len:]):
|
||||||
|
if context_start is False and orig_token == context_markups[0]:
|
||||||
|
context_start = True
|
||||||
|
assert labels[source_len+j] == context_markups[0]
|
||||||
|
start_idx = j+source_len
|
||||||
|
end_idx = None
|
||||||
|
for k, orig_token_2 in enumerate(labels[start_idx:]):
|
||||||
|
if orig_token_2 == context_markups[1]:
|
||||||
|
end_idx = start_idx + k
|
||||||
|
if end_idx is None:
|
||||||
|
end_idx = start_idx + k
|
||||||
|
else:
|
||||||
|
assert labels[end_idx] == context_markups[1]
|
||||||
|
labels[start_idx+1:end_idx] = -100
|
||||||
|
context_start = False
|
||||||
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'input_ids': input_ids.flatten(),
|
||||||
|
'labels': labels.flatten(),
|
||||||
|
'attention_mask': attention_mask.flatten()
|
||||||
|
}
|
||||||
|
|
||||||
|
def encode_with_messages_format(example, tokenizer, max_seq_length):
|
||||||
|
'''
|
||||||
|
Here we assume each example has a 'messages' field. Each message is a dict with 'role' and 'content' fields.
|
||||||
|
We concatenate all messages with the roles as delimiters and tokenize them together.
|
||||||
|
'''
|
||||||
|
messages = example['messages']
|
||||||
|
if len(messages) == 0:
|
||||||
|
raise ValueError('messages field is empty.')
|
||||||
|
|
||||||
|
def _concat_messages(messages):
|
||||||
|
message_text = ""
|
||||||
|
for message in messages:
|
||||||
|
if message["role"] == "system":
|
||||||
|
message_text += "<|system|>\n" + message["content"].strip() + "\n"
|
||||||
|
elif message["role"] == "user":
|
||||||
|
message_text += "<|user|>\n" + message["content"].strip() + "\n"
|
||||||
|
elif message["role"] == "assistant":
|
||||||
|
message_text += "<|assistant|>\n" + message["content"].strip() + tokenizer.eos_token + "\n"
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid role: {}".format(message["role"]))
|
||||||
|
return message_text
|
||||||
|
|
||||||
|
example_text = _concat_messages(messages).strip()
|
||||||
|
tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
|
||||||
|
input_ids = tokenized_example.input_ids
|
||||||
|
labels = input_ids.clone()
|
||||||
|
|
||||||
|
# mask the non-assistant part for avoiding loss
|
||||||
|
for message_idx, message in enumerate(messages):
|
||||||
|
if message["role"] != "assistant":
|
||||||
|
if message_idx == 0:
|
||||||
|
message_start_idx = 0
|
||||||
|
else:
|
||||||
|
message_start_idx = tokenizer(
|
||||||
|
_concat_messages(messages[:message_idx]), return_tensors='pt', max_length=max_seq_length, truncation=True
|
||||||
|
).input_ids.shape[1]
|
||||||
|
if message_idx < len(messages) - 1 and messages[message_idx+1]["role"] == "assistant":
|
||||||
|
# here we also ignore the role of the assistant
|
||||||
|
messages_so_far = _concat_messages(messages[:message_idx+1]) + "<|assistant|>\n"
|
||||||
|
else:
|
||||||
|
messages_so_far = _concat_messages(messages[:message_idx+1])
|
||||||
|
message_end_idx = tokenizer(
|
||||||
|
messages_so_far,
|
||||||
|
return_tensors='pt',
|
||||||
|
max_length=max_seq_length,
|
||||||
|
truncation=True
|
||||||
|
).input_ids.shape[1]
|
||||||
|
labels[:, message_start_idx:message_end_idx] = -100
|
||||||
|
|
||||||
|
if message_end_idx >= max_seq_length:
|
||||||
|
break
|
||||||
|
|
||||||
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
return {
|
||||||
|
'input_ids': input_ids.flatten(),
|
||||||
|
'labels': labels.flatten(),
|
||||||
|
'attention_mask': attention_mask.flatten(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# A hacky way to make llama work with flash attention
|
||||||
|
if args.use_flash_attn:
|
||||||
|
from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
|
||||||
|
replace_llama_attn_with_flash_attn()
|
||||||
|
|
||||||
|
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
|
||||||
|
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
|
||||||
|
# in the environment
|
||||||
|
accelerator_log_kwargs = {}
|
||||||
|
|
||||||
|
if args.with_tracking:
|
||||||
|
accelerator_log_kwargs["log_with"] = args.report_to
|
||||||
|
accelerator_log_kwargs["project_dir"] = args.output_dir
|
||||||
|
|
||||||
|
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
|
||||||
|
|
||||||
|
# Make one log on every process with the configuration for debugging.
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
|
level=logging.INFO,
|
||||||
|
)
|
||||||
|
logger.info(accelerator.state, main_process_only=False)
|
||||||
|
if accelerator.is_local_main_process:
|
||||||
|
datasets.utils.logging.set_verbosity_warning()
|
||||||
|
transformers.utils.logging.set_verbosity_info()
|
||||||
|
else:
|
||||||
|
datasets.utils.logging.set_verbosity_error()
|
||||||
|
transformers.utils.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
# If passed along, set the training seed now.
|
||||||
|
if args.seed is not None:
|
||||||
|
set_seed(args.seed)
|
||||||
|
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
if args.output_dir is not None:
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
if args.dataset_name is not None:
|
||||||
|
# Downloading and loading a dataset from the hub.
|
||||||
|
raw_datasets = load_dataset(
|
||||||
|
args.dataset_name,
|
||||||
|
args.dataset_config_name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data_files = {}
|
||||||
|
dataset_args = {}
|
||||||
|
if args.train_file is not None:
|
||||||
|
data_files["train"] = args.train_file
|
||||||
|
raw_datasets = load_dataset(
|
||||||
|
"json",
|
||||||
|
data_files=data_files,
|
||||||
|
**dataset_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load pretrained model and tokenizer
|
||||||
|
if args.config_name:
|
||||||
|
config = AutoConfig.from_pretrained(args.config_name)
|
||||||
|
elif args.model_name_or_path:
|
||||||
|
config = AutoConfig.from_pretrained(args.model_name_or_path)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"You are instantiating a new config instance from scratch. This is not supported by this script."
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.tokenizer_name:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
|
||||||
|
elif args.model_name_or_path:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
||||||
|
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.model_name_or_path:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
args.model_name_or_path,
|
||||||
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
|
config=config,
|
||||||
|
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("Training new model from scratch")
|
||||||
|
model = AutoModelForCausalLM.from_config(config)
|
||||||
|
|
||||||
|
|
||||||
|
# no default pad token for llama!
|
||||||
|
# here we add all special tokens again, because the default ones are not in the special_tokens_map
|
||||||
|
if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast):
|
||||||
|
if args.use_special_tokens is True:
|
||||||
|
special_token_dict = {"additional_special_tokens": ["[No Retrieval]", "[Retrieval]", "[Continue to Use Evidence]", "[Irrelevant]", "[Relevant]", "<paragraph>", "</paragraph>", "[Utility:1]", "[Utility:2]", "[Utility:3]", "[Utility:4]", "[Utility:5]", "[Fully supported]", "[Partially supported]", "[No support / Contradictory]"]}
|
||||||
|
special_token_dict["bos_token"] = "<s>"
|
||||||
|
special_token_dict["eos_token"] = "</s>"
|
||||||
|
special_token_dict["unk_token"] = "<unk>"
|
||||||
|
special_token_dict["pad_token"] = "<pad>"
|
||||||
|
num_added_tokens = tokenizer.add_special_tokens(special_token_dict)
|
||||||
|
|
||||||
|
context_markups = []
|
||||||
|
for token in ["<paragraph>", "</paragraph>"]:
|
||||||
|
context_markups.append(tokenizer.convert_tokens_to_ids(token))
|
||||||
|
if args.use_special_tokens is False:
|
||||||
|
assert num_added_tokens in [0, 1], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present."
|
||||||
|
else:
|
||||||
|
assert num_added_tokens > 10, "special tokens must be added to the original tokenizers."
|
||||||
|
elif isinstance(tokenizer, GPTNeoXTokenizerFast):
|
||||||
|
num_added_tokens = tokenizer.add_special_tokens({
|
||||||
|
"pad_token": "<pad>",
|
||||||
|
})
|
||||||
|
assert num_added_tokens == 1, "GPTNeoXTokenizer should only add one special token - the pad_token."
|
||||||
|
elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM):
|
||||||
|
num_added_tokens = tokenizer.add_special_tokens({'unk_token': '<unk>'})
|
||||||
|
|
||||||
|
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
|
||||||
|
# on a small vocab and want a smaller embedding size, remove this test.
|
||||||
|
embedding_size = model.get_input_embeddings().weight.shape[0]
|
||||||
|
if len(tokenizer) > embedding_size:
|
||||||
|
model.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
|
if args.use_lora:
|
||||||
|
logger.info("Initializing LORA model...")
|
||||||
|
modules_to_save = ["embed_tokens"]
|
||||||
|
peft_config = LoraConfig(
|
||||||
|
task_type=TaskType.CAUSAL_LM,
|
||||||
|
inference_mode=False,
|
||||||
|
r=args.lora_rank,
|
||||||
|
#modules_to_save=modules_to_save,
|
||||||
|
lora_alpha=args.lora_alpha,
|
||||||
|
lora_dropout=args.lora_dropout
|
||||||
|
)
|
||||||
|
model = get_peft_model(model, peft_config)
|
||||||
|
model.print_trainable_parameters()
|
||||||
|
|
||||||
|
|
||||||
|
encode_function = partial(
|
||||||
|
encode_with_prompt_completion_format,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_seq_length=args.max_seq_length,
|
||||||
|
context_markups=context_markups if args.use_special_tokens is True else None
|
||||||
|
)
|
||||||
|
# elif "messages" in raw_datasets["train"].column_names:
|
||||||
|
# encode_function = partial(
|
||||||
|
# encode_with_messages_format,
|
||||||
|
# tokenizer=tokenizer,
|
||||||
|
# max_seq_length=args.max_seq_length,
|
||||||
|
# )
|
||||||
|
with accelerator.main_process_first():
|
||||||
|
lm_datasets = raw_datasets.map(
|
||||||
|
encode_function,
|
||||||
|
batched=False,
|
||||||
|
num_proc=args.preprocessing_num_workers,
|
||||||
|
load_from_cache_file=not args.overwrite_cache,
|
||||||
|
remove_columns=[name for name in raw_datasets["train"].column_names if name not in ["input_ids", "labels", "attention_mask"]],
|
||||||
|
desc="Tokenizing and reformatting instruction data",
|
||||||
|
)
|
||||||
|
lm_datasets.set_format(type="pt")
|
||||||
|
lm_datasets = lm_datasets.filter(lambda example: (example['labels'] != -100).any())
|
||||||
|
|
||||||
|
train_dataset = lm_datasets["train"]
|
||||||
|
#print(train_dataset[0])
|
||||||
|
#print(train_dataset[1000])
|
||||||
|
#print(train_dataset[500])
|
||||||
|
#print(train_dataset[2000])
|
||||||
|
#print(train_dataset[10000])
|
||||||
|
with open("processed.json", "w") as outfile:
|
||||||
|
new_data = []
|
||||||
|
for item in train_dataset:
|
||||||
|
print(item)
|
||||||
|
labels = [int(i) for i in item["labels"]]
|
||||||
|
input_ids = [int(i) for i in item["input_ids"]]
|
||||||
|
new_data.append({"labels": labels, "input_ids": input_ids})
|
||||||
|
json.dump(new_data, outfile)
|
||||||
|
# Log a few random samples from the training set:
|
||||||
|
for index in random.sample(range(len(train_dataset)), 3):
|
||||||
|
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||||
|
|
||||||
|
# DataLoaders creation:
|
||||||
|
train_dataloader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
shuffle=True,
|
||||||
|
collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"),
|
||||||
|
batch_size=args.per_device_train_batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optimizer
|
||||||
|
# Split weights in two groups, one with weight decay and the other not.
|
||||||
|
no_decay = ["bias", "layer_norm.weight"]
|
||||||
|
optimizer_grouped_parameters = [
|
||||||
|
{
|
||||||
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": args.weight_decay,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
|
||||||
|
|
||||||
|
# Scheduler and math around the number of training steps.
|
||||||
|
overrode_max_train_steps = False
|
||||||
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
if args.max_train_steps is None:
|
||||||
|
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||||
|
overrode_max_train_steps = True
|
||||||
|
|
||||||
|
# Create the learning rate scheduler.
|
||||||
|
# Note: the current accelerator.step() calls the .step() of the real scheduler for the `num_processes` times. This is because they assume
|
||||||
|
# the user initialize the scheduler with the entire training set. In the case of data parallel training, each process only
|
||||||
|
# sees a subset (1/num_processes) of the training set. So each time the process needs to update the lr multiple times so that the total
|
||||||
|
# number of updates in the end matches the num_training_steps here.
|
||||||
|
# Here we need to set the num_training_steps to either using the entire training set (when epochs is specified) or we need to multiply the
|
||||||
|
# num_training_steps by num_processes so that the total number of updates matches the num_training_steps.
|
||||||
|
num_training_steps_for_scheduler = args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes
|
||||||
|
lr_scheduler = get_scheduler(
|
||||||
|
name=args.lr_scheduler_type,
|
||||||
|
optimizer=optimizer,
|
||||||
|
num_training_steps=num_training_steps_for_scheduler,
|
||||||
|
num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare everything with `accelerator`.
|
||||||
|
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
model, optimizer, train_dataloader, lr_scheduler
|
||||||
|
)
|
||||||
|
|
||||||
|
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||||
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
if overrode_max_train_steps:
|
||||||
|
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||||
|
# Afterwards we recalculate our number of training epochs
|
||||||
|
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
|
||||||
|
# Figure out how many steps we should save the Accelerator states
|
||||||
|
checkpointing_steps = args.checkpointing_steps
|
||||||
|
if checkpointing_steps is not None and checkpointing_steps.isdigit():
|
||||||
|
checkpointing_steps = int(checkpointing_steps)
|
||||||
|
|
||||||
|
# We need to initialize the trackers we use, and also store our configuration.
|
||||||
|
# The trackers initializes automatically on the main process.
|
||||||
|
if args.with_tracking:
|
||||||
|
experiment_config = vars(args)
|
||||||
|
# TensorBoard cannot log Enums, need the raw value
|
||||||
|
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
|
||||||
|
accelerator.init_trackers("open_instruct", experiment_config)
|
||||||
|
|
||||||
|
# Train!
|
||||||
|
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
|
|
||||||
|
logger.info("***** Running training *****")
|
||||||
|
logger.info(f" Num examples = {len(train_dataset)}")
|
||||||
|
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||||
|
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
|
||||||
|
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||||
|
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||||
|
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||||
|
# Only show the progress bar once on each machine.
|
||||||
|
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||||
|
completed_steps = 0
|
||||||
|
starting_epoch = 0
|
||||||
|
|
||||||
|
# Potentially load in the weights and states from a previous save
|
||||||
|
if args.resume_from_checkpoint:
|
||||||
|
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
|
||||||
|
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
|
||||||
|
accelerator.load_state(args.resume_from_checkpoint)
|
||||||
|
path = os.path.basename(args.resume_from_checkpoint)
|
||||||
|
else:
|
||||||
|
# Get the most recent checkpoint
|
||||||
|
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
|
||||||
|
dirs.sort(key=os.path.getctime)
|
||||||
|
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
|
||||||
|
# Extract `epoch_{i}` or `step_{i}`
|
||||||
|
training_difference = os.path.splitext(path)[0]
|
||||||
|
|
||||||
|
if "epoch" in training_difference:
|
||||||
|
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
|
||||||
|
resume_step = None
|
||||||
|
else:
|
||||||
|
# need to multiply `gradient_accumulation_steps` to reflect real steps
|
||||||
|
resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
|
||||||
|
starting_epoch = resume_step // len(train_dataloader)
|
||||||
|
resume_step -= starting_epoch * len(train_dataloader)
|
||||||
|
|
||||||
|
# update the progress_bar if load from checkpoint
|
||||||
|
progress_bar.update(starting_epoch * num_update_steps_per_epoch)
|
||||||
|
completed_steps = starting_epoch * num_update_steps_per_epoch
|
||||||
|
|
||||||
|
for epoch in range(starting_epoch, args.num_train_epochs):
|
||||||
|
model.train()
|
||||||
|
total_loss = 0
|
||||||
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
# We need to skip steps until we reach the resumed step
|
||||||
|
if args.resume_from_checkpoint and epoch == starting_epoch:
|
||||||
|
if resume_step is not None and completed_steps < resume_step:
|
||||||
|
if step % args.gradient_accumulation_steps == 0:
|
||||||
|
progress_bar.update(1)
|
||||||
|
completed_steps += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
with accelerator.accumulate(model):
|
||||||
|
outputs = model(**batch, use_cache=False)
|
||||||
|
loss = outputs.loss
|
||||||
|
# We keep track of the loss at each logged step
|
||||||
|
total_loss += loss.detach().float()
|
||||||
|
accelerator.backward(loss)
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
lr_scheduler.step()
|
||||||
|
|
||||||
|
# # Checks if the accelerator has performed an optimization step behind the scenes
|
||||||
|
if accelerator.sync_gradients:
|
||||||
|
progress_bar.update(1)
|
||||||
|
completed_steps += 1
|
||||||
|
|
||||||
|
if args.logging_steps and completed_steps % args.logging_steps == 0:
|
||||||
|
avg_loss = accelerator.gather(total_loss).mean().item() / args.gradient_accumulation_steps / args.logging_steps
|
||||||
|
logger.info(f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}")
|
||||||
|
if args.with_tracking:
|
||||||
|
accelerator.log(
|
||||||
|
{
|
||||||
|
"learning_rate": lr_scheduler.get_last_lr()[0],
|
||||||
|
"train_loss": avg_loss,
|
||||||
|
},
|
||||||
|
step=completed_steps,
|
||||||
|
)
|
||||||
|
total_loss = 0
|
||||||
|
|
||||||
|
if isinstance(checkpointing_steps, int):
|
||||||
|
if completed_steps % checkpointing_steps == 0:
|
||||||
|
output_dir = f"step_{completed_steps}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
if completed_steps >= args.max_train_steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
if args.checkpointing_steps == "epoch":
|
||||||
|
output_dir = f"epoch_{epoch}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
|
if args.with_tracking:
|
||||||
|
accelerator.end_training()
|
||||||
|
|
||||||
|
if args.output_dir is not None:
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
|
# When doing multi-gpu training, we need to use accelerator.get_state_dict(model) to get the state_dict.
|
||||||
|
# Otherwise, sometimes the model will be saved with only part of the parameters.
|
||||||
|
# Also, accelerator needs to use the wrapped model to get the state_dict.
|
||||||
|
state_dict = accelerator.get_state_dict(model)
|
||||||
|
if args.use_lora:
|
||||||
|
# When using lora, the unwrapped model is a PeftModel, which doesn't support the is_main_process
|
||||||
|
# and has its own save_pretrained function for only saving lora modules.
|
||||||
|
# We have to manually specify the is_main_process outside the save_pretrained function.
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
unwrapped_model.save_pretrained(args.output_dir, state_dict=state_dict)
|
||||||
|
else:
|
||||||
|
unwrapped_model.save_pretrained(
|
||||||
|
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=state_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
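finetune.py masks every prompt token in `labels` so that only completion tokens contribute to the loss. A self-contained sketch of that masking scheme with made-up token ids (no real tokenizer involved); positions set to -100 are ignored by the cross-entropy loss in transformers:

```python
import copy
import torch

# Toy ids standing in for tokenized "prompt + completion"; the values are illustrative only.
input_ids = torch.tensor([101, 7592, 2088, 999, 2003, 1037, 3231, 102])
source_len = 4                      # number of tokens that belong to the prompt

labels = copy.deepcopy(input_ids)
labels[: source_len - 1] = -100     # same slice used in encode_with_prompt_completion_format
print(labels)                       # the first source_len-1 positions are excluded from the loss
```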
@ -0,0 +1,115 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import torch

# These project modules are referenced below but were never imported; without them the
# script fails with a NameError on `robowaiter`.
import robowaiter.llm_client.retrieval_lm.src.contriever
import robowaiter.llm_client.retrieval_lm.src.data
import robowaiter.llm_client.retrieval_lm.src.normalize_text
import robowaiter.llm_client.retrieval_lm.src.slurm
|
||||||
|
|
||||||
|
|
||||||
|
def embed_passages(args, passages, model, tokenizer):
|
||||||
|
total = 0
|
||||||
|
allids, allembeddings = [], []
|
||||||
|
batch_ids, batch_text = [], []
|
||||||
|
with torch.no_grad():
|
||||||
|
for k, p in enumerate(passages):
|
||||||
|
batch_ids.append(p["id"])
|
||||||
|
|
||||||
|
"""if args.no_title or not "title" in p:
|
||||||
|
text = p["text"]
|
||||||
|
else:
|
||||||
|
text = p["title"] + " " + p["text"]"""
|
||||||
|
text = p["title"]
|
||||||
|
if args.lowercase:
|
||||||
|
text = text.lower()
|
||||||
|
if args.normalize_text:
|
||||||
|
text = robowaiter.llm_client.retrieval_lm.src.normalize_text.normalize(text)
|
||||||
|
batch_text.append(text)
|
||||||
|
|
||||||
|
if len(batch_text) == args.per_gpu_batch_size or k == len(passages) - 1:
|
||||||
|
|
||||||
|
encoded_batch = tokenizer.batch_encode_plus(
|
||||||
|
batch_text,
|
||||||
|
return_tensors="pt",
|
||||||
|
max_length=args.passage_maxlength,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
|
||||||
|
embeddings = model(**encoded_batch)
|
||||||
|
|
||||||
|
embeddings = embeddings.cpu()
|
||||||
|
total += len(batch_ids)
|
||||||
|
allids.extend(batch_ids)
|
||||||
|
allembeddings.append(embeddings)
|
||||||
|
|
||||||
|
batch_text = []
|
||||||
|
batch_ids = []
|
||||||
|
if k % 100000 == 0 and k > 0:
|
||||||
|
print(f"Encoded passages {total}")
|
||||||
|
|
||||||
|
allembeddings = torch.cat(allembeddings, dim=0).numpy()
|
||||||
|
return allids, allembeddings
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
model, tokenizer, _ = robowaiter.llm_client.retrieval_lm.src.contriever.load_retriever(args.model_name_or_path)
|
||||||
|
print(f"Model loaded from {args.model_name_or_path}.", flush=True)
|
||||||
|
model.eval()
|
||||||
|
model = model.cuda()
|
||||||
|
if not args.no_fp16:
|
||||||
|
model = model.half()
|
||||||
|
|
||||||
|
passages = robowaiter.llm_client.retrieval_lm.src.data.load_passages(args.passages)
|
||||||
|
|
||||||
|
shard_size = len(passages) // args.num_shards
|
||||||
|
start_idx = args.shard_id * shard_size
|
||||||
|
end_idx = start_idx + shard_size
|
||||||
|
if args.shard_id == args.num_shards - 1:
|
||||||
|
end_idx = len(passages)
|
||||||
|
|
||||||
|
passages = passages[start_idx:end_idx]
|
||||||
|
print(f"Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}.")
|
||||||
|
|
||||||
|
allids, allembeddings = embed_passages(args, passages, model, tokenizer)
|
||||||
|
|
||||||
|
save_file = os.path.join(args.output_dir, args.prefix + f"_{args.shard_id:02d}")
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
print(f"Saving {len(allids)} passage embeddings to {save_file}.")
|
||||||
|
with open(save_file, mode="wb") as f:
|
||||||
|
pickle.dump((allids, allembeddings), f)
|
||||||
|
|
||||||
|
print(f"Total passages processed {len(allids)}. Written to {save_file}.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
|
||||||
|
parser.add_argument("--output_dir", type=str, default="wikipedia_embeddings", help="dir path to save embeddings")
|
||||||
|
parser.add_argument("--prefix", type=str, default="passages", help="prefix path to save embeddings")
|
||||||
|
parser.add_argument("--shard_id", type=int, default=0, help="Id of the current shard")
|
||||||
|
parser.add_argument("--num_shards", type=int, default=1, help="Total number of shards")
|
||||||
|
parser.add_argument(
|
||||||
|
"--per_gpu_batch_size", type=int, default=512, help="Batch size for the passage encoder forward pass"
|
||||||
|
)
|
||||||
|
parser.add_argument("--passage_maxlength", type=int, default=512, help="Maximum number of tokens in a passage")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name_or_path", type=str, help="path to directory containing model weights and config file"
|
||||||
|
)
|
||||||
|
parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
|
||||||
|
parser.add_argument("--no_title", action="store_true", help="title not added to the passage body")
|
||||||
|
parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
|
||||||
|
parser.add_argument("--normalize_text", action="store_true", help="lowercase text before encoding")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
robowaiter.llm_client.retrieval_lm.src.slurm.init_distributed_mode(args)
|
||||||
|
|
||||||
|
main(args)
|
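Each shard is written with pickle as an (ids, embeddings) pair. A quick sketch of reading one back; the file name assumes the default --output_dir and --prefix and a single shard (shard_id 0):

```python
import pickle

# Assumed location of a single-shard run with default arguments.
with open("wikipedia_embeddings/passages_00", "rb") as f:
    ids, embeddings = pickle.load(f)

print(len(ids), embeddings.shape)        # one embedding row per passage id
assert embeddings.shape[0] == len(ids)
```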
|
@ -0,0 +1,119 @@
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
try:
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
|
||||||
|
except ImportError:
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
||||||
|
|
||||||
|
from flash_attn.bert_padding import unpad_input, pad_input
|
||||||
|
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
"""Input shape: Batch x Time x Channel
|
||||||
|
|
||||||
|
attention_mask: [bsz, q_len]
|
||||||
|
"""
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = (
|
||||||
|
self.q_proj(hidden_states)
|
||||||
|
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
.transpose(1, 2)
|
||||||
|
)
|
||||||
|
key_states = (
|
||||||
|
self.k_proj(hidden_states)
|
||||||
|
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
.transpose(1, 2)
|
||||||
|
)
|
||||||
|
value_states = (
|
||||||
|
self.v_proj(hidden_states)
|
||||||
|
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
.transpose(1, 2)
|
||||||
|
)
|
||||||
|
# [bsz, q_len, nh, hd]
|
||||||
|
# [bsz, nh, q_len, hd]
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
assert past_key_value is None, "past_key_value is not supported"
|
||||||
|
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin, position_ids
|
||||||
|
)
|
||||||
|
# [bsz, nh, t, hd]
|
||||||
|
assert not output_attentions, "output_attentions is not supported"
|
||||||
|
assert not use_cache, "use_cache is not supported"
|
||||||
|
|
||||||
|
# Flash attention codes from
|
||||||
|
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
|
||||||
|
|
||||||
|
# transform the data into the format required by flash attention
|
||||||
|
qkv = torch.stack(
|
||||||
|
[query_states, key_states, value_states], dim=2
|
||||||
|
) # [bsz, nh, 3, q_len, hd]
|
||||||
|
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||||
|
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||||
|
# the attention_mask should be the same as the key_padding_mask
|
||||||
|
key_padding_mask = attention_mask
|
||||||
|
|
||||||
|
if key_padding_mask is None:
|
||||||
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
max_s = q_len
|
||||||
|
cu_q_lens = torch.arange(
|
||||||
|
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
|
||||||
|
)
|
||||||
|
output = flash_attn_unpadded_qkvpacked_func(
|
||||||
|
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||||
|
)
|
||||||
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
|
else:
|
||||||
|
nheads = qkv.shape[-2]
|
||||||
|
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
||||||
|
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
||||||
|
x_unpad = rearrange(
|
||||||
|
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
|
||||||
|
)
|
||||||
|
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
||||||
|
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||||
|
)
|
||||||
|
output = rearrange(
|
||||||
|
pad_input(
|
||||||
|
rearrange(output_unpad,
|
||||||
|
"nnz h d -> nnz (h d)"), indices, bsz, q_len
|
||||||
|
),
|
||||||
|
"b s (h d) -> b s h d",
|
||||||
|
h=nheads,
|
||||||
|
)
|
||||||
|
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
|
||||||
|
|
||||||
|
|
||||||
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
|
# requires the attention mask to be the same as the key_padding_mask
|
||||||
|
def _prepare_decoder_attention_mask(
|
||||||
|
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||||
|
):
|
||||||
|
# [bsz, seq_len]
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
def replace_llama_attn_with_flash_attn():
|
||||||
|
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
|
||||||
|
_prepare_decoder_attention_mask
|
||||||
|
)
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
|
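The monkey patch has to be installed before the Llama weights are instantiated (finetune.py does this when --use_flash_attn is passed). A minimal usage sketch; the checkpoint path is a placeholder, not part of this commit:

```python
# Apply the patch first, so the model is built with the flash-attention forward defined above.
from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
from transformers import AutoModelForCausalLM
import torch

replace_llama_attn_with_flash_attn()     # swaps LlamaAttention.forward for the version above

model = AutoModelForCausalLM.from_pretrained(
    "path/to/llama-checkpoint",          # placeholder path
    torch_dtype=torch.float16,           # flash attention expects fp16/bf16 activations
)
```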
@ -0,0 +1,87 @@
|
||||||
|
import numpy as np
|
||||||
|
import string
|
||||||
|
import re
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
|
||||||
|
def exact_match_score(prediction, ground_truth):
|
||||||
|
return (normalize_answer(prediction) == normalize_answer(ground_truth))
|
||||||
|
|
||||||
|
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
||||||
|
scores_for_ground_truths = []
|
||||||
|
for ground_truth in ground_truths:
|
||||||
|
score = metric_fn(prediction, ground_truth)
|
||||||
|
scores_for_ground_truths.append(score)
|
||||||
|
return max(scores_for_ground_truths)
|
||||||
|
|
||||||
|
def accuracy(preds, labels):
|
||||||
|
match_count = 0
|
||||||
|
for pred, label in zip(preds, labels):
|
||||||
|
target = label[0]
|
||||||
|
if pred == target:
|
||||||
|
match_count += 1
|
||||||
|
|
||||||
|
return 100 * (match_count / len(preds))
|
||||||
|
|
||||||
|
|
||||||
|
def f1(decoded_preds, decoded_labels):
|
||||||
|
f1_all = []
|
||||||
|
for prediction, answers in zip(decoded_preds, decoded_labels):
|
||||||
|
if type(answers) == list:
|
||||||
|
if len(answers) == 0:
|
||||||
|
return 0
|
||||||
|
f1_all.append(np.max([qa_f1_score(prediction, gt)
|
||||||
|
for gt in answers]))
|
||||||
|
else:
|
||||||
|
f1_all.append(qa_f1_score(prediction, answers))
|
||||||
|
return 100 * np.mean(f1_all)
|
||||||
|
|
||||||
|
|
||||||
|
def qa_f1_score(prediction, ground_truth):
|
||||||
|
prediction_tokens = normalize_answer(prediction).split()
|
||||||
|
ground_truth_tokens = normalize_answer(ground_truth).split()
|
||||||
|
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
|
||||||
|
num_same = sum(common.values())
|
||||||
|
if num_same == 0:
|
||||||
|
return 0
|
||||||
|
precision = 1.0 * num_same / len(prediction_tokens)
|
||||||
|
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||||
|
f1 = (2 * precision * recall) / (precision + recall)
|
||||||
|
return f1
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_answer(s):
|
||||||
|
def remove_articles(text):
|
||||||
|
return re.sub(r'\b(a|an|the)\b', ' ', text)
|
||||||
|
|
||||||
|
def white_space_fix(text):
|
||||||
|
return ' '.join(text.split())
|
||||||
|
|
||||||
|
def remove_punc(text):
|
||||||
|
exclude = set(string.punctuation)
|
||||||
|
return ''.join(ch for ch in text if ch not in exclude)
|
||||||
|
|
||||||
|
def lower(text):
|
||||||
|
return text.lower()
|
||||||
|
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||||
|
|
||||||
|
def find_entity_tags(sentence):
|
||||||
|
entity_regex = r'(.+?)(?=\s<|$)'
|
||||||
|
tag_regex = r'<(.+?)>'
|
||||||
|
entity_names = re.findall(entity_regex, sentence)
|
||||||
|
tags = re.findall(tag_regex, sentence)
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
for entity, tag in zip(entity_names, tags):
|
||||||
|
if "<" in entity:
|
||||||
|
results[entity.split("> ")[1]] = tag
|
||||||
|
else:
|
||||||
|
results[entity] = tag
|
||||||
|
return results
|
||||||
|
|
||||||
|
def match(prediction, ground_truth):
|
||||||
|
for gt in ground_truth:
|
||||||
|
if gt in prediction:
|
||||||
|
return 1
|
||||||
|
return 0
|
|
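A short example of how these metric helpers combine on toy predictions (values invented for illustration; it assumes the functions above are in scope, e.g. run inside this module or after importing it):

```python
# Toy check of the QA metrics defined above.
prediction = "the Coffee Table"
answers = ["coffee table", "side table"]

print(exact_match_score(prediction, answers[0]))                               # True after normalization
print(metric_max_over_ground_truths(exact_match_score, prediction, answers))   # True
print(round(qa_f1_score(prediction, answers[1]), 2))                           # partial token overlap -> 0.5
print(match(prediction.lower(), ["coffee table"]))                             # 1: substring match
```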
@@ -0,0 +1,8 @@
export CUDA_VISIBLE_DEVICES=0
python3 ../generate_passage_embeddings.py \
    --model_name_or_path ../../model/contriever-msmarco \
    --passages train_robot.jsonl \
    --output_dir robot_embeddings \
    --shard_id 0 \
    --num_shards 1 \
    --per_gpu_batch_size 500
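The --shard_id/--num_shards pair splits the passage file across parallel jobs. A small sketch of the partitioning logic the script applies, using a toy passage count rather than real data:

```python
# How generate_passage_embeddings.py slices passages for one shard (toy numbers).
passages = list(range(10))      # stand-in for the loaded passage list
num_shards, shard_id = 3, 2     # e.g. the last of three parallel jobs

shard_size = len(passages) // num_shards
start_idx = shard_id * shard_size
end_idx = start_idx + shard_size
if shard_id == num_shards - 1:  # the final shard absorbs the remainder
    end_idx = len(passages)

print(passages[start_idx:end_idx])   # -> [6, 7, 8, 9]
```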
File diff suppressed because it is too large
@ -0,0 +1,250 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
import glob
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
import src.index
|
||||||
|
import src.contriever
|
||||||
|
import src.utils
|
||||||
|
import src.slurm
|
||||||
|
import src.data
|
||||||
|
from src.evaluation import calculate_matches
|
||||||
|
import src.normalize_text
|
||||||
|
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
def embed_queries(args, queries, model, tokenizer):
|
||||||
|
model.eval()
|
||||||
|
embeddings, batch_question = [], []
|
||||||
|
with torch.no_grad():
|
||||||
|
|
||||||
|
for k, q in enumerate(queries):
|
||||||
|
if args.lowercase:
|
||||||
|
q = q.lower()
|
||||||
|
if args.normalize_text:
|
||||||
|
q = src.normalize_text.normalize(q)
|
||||||
|
batch_question.append(q)
|
||||||
|
|
||||||
|
if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1:
|
||||||
|
|
||||||
|
encoded_batch = tokenizer.batch_encode_plus(
|
||||||
|
batch_question,
|
||||||
|
return_tensors="pt",
|
||||||
|
max_length=args.question_maxlength,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
)
|
||||||
|
encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
|
||||||
|
output = model(**encoded_batch)
|
||||||
|
embeddings.append(output.cpu())
|
||||||
|
|
||||||
|
batch_question = []
|
||||||
|
|
||||||
|
embeddings = torch.cat(embeddings, dim=0)
|
||||||
|
print(f"Questions embeddings shape: {embeddings.size()}")
|
||||||
|
|
||||||
|
return embeddings.numpy()
|
||||||
|
|
||||||
|
|
||||||
|
def index_encoded_data(index, embedding_files, indexing_batch_size):
|
||||||
|
allids = []
|
||||||
|
allembeddings = np.array([])
|
||||||
|
for i, file_path in enumerate(embedding_files):
|
||||||
|
print(f"Loading file {file_path}")
|
||||||
|
with open(file_path, "rb") as fin:
|
||||||
|
ids, embeddings = pickle.load(fin)
|
||||||
|
|
||||||
|
allembeddings = np.vstack((allembeddings, embeddings)) if allembeddings.size else embeddings
|
||||||
|
allids.extend(ids)
|
||||||
|
while allembeddings.shape[0] > indexing_batch_size:
|
||||||
|
allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
|
||||||
|
|
||||||
|
while allembeddings.shape[0] > 0:
|
||||||
|
allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
|
||||||
|
|
||||||
|
print("Data indexing completed.")
|
||||||
|
|
||||||
|
|
||||||
|
def add_embeddings(index, embeddings, ids, indexing_batch_size):
|
||||||
|
end_idx = min(indexing_batch_size, embeddings.shape[0])
|
||||||
|
ids_toadd = ids[:end_idx]
|
||||||
|
embeddings_toadd = embeddings[:end_idx]
|
||||||
|
ids = ids[end_idx:]
|
||||||
|
embeddings = embeddings[end_idx:]
|
||||||
|
index.index_data(ids_toadd, embeddings_toadd)
|
||||||
|
return embeddings, ids
|
||||||
|
|
||||||
|
|
||||||
|
def validate(data, workers_num):
|
||||||
|
match_stats = calculate_matches(data, workers_num)
|
||||||
|
top_k_hits = match_stats.top_k_hits
|
||||||
|
|
||||||
|
print("Validation results: top k documents hits %s", top_k_hits)
|
||||||
|
top_k_hits = [v / len(data) for v in top_k_hits]
|
||||||
|
message = ""
|
||||||
|
for k in [5, 10, 20, 100]:
|
||||||
|
if k <= len(top_k_hits):
|
||||||
|
message += f"R@{k}: {top_k_hits[k-1]} "
|
||||||
|
print(message)
|
||||||
|
return match_stats.questions_doc_hits
|
||||||
|
|
||||||
|
|
||||||
|
def add_passages(data, passages, top_passages_and_scores):
|
||||||
|
# add passages to original data
|
||||||
|
merged_data = []
|
||||||
|
assert len(data) == len(top_passages_and_scores)
|
||||||
|
for i, d in enumerate(data):
|
||||||
|
results_and_scores = top_passages_and_scores[i]
|
||||||
|
#print(passages[2393])
|
||||||
|
docs = [passages[int(doc_id)] for doc_id in results_and_scores[0]]
|
||||||
|
scores = [str(score) for score in results_and_scores[1]]
|
||||||
|
ctxs_num = len(docs)
|
||||||
|
d["ctxs"] = [
|
||||||
|
{
|
||||||
|
"id": results_and_scores[0][c],
|
||||||
|
"title": docs[c]["title"],
|
||||||
|
"text": docs[c]["text"],
|
||||||
|
"score": scores[c],
|
||||||
|
}
|
||||||
|
for c in range(ctxs_num)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def add_hasanswer(data, hasanswer):
|
||||||
|
# add hasanswer to data
|
||||||
|
for i, ex in enumerate(data):
|
||||||
|
for k, d in enumerate(ex["ctxs"]):
|
||||||
|
d["hasanswer"] = hasanswer[i][k]
|
||||||
|
|
||||||
|
|
||||||
|
def load_data(data_path):
|
||||||
|
if data_path.endswith(".json"):
|
||||||
|
with open(data_path, "r") as fin:
|
||||||
|
data = json.load(fin)
|
||||||
|
elif data_path.endswith(".jsonl"):
|
||||||
|
data = []
|
||||||
|
with open(data_path, "r") as fin:
|
||||||
|
for k, example in enumerate(fin):
|
||||||
|
example = json.loads(example)
|
||||||
|
data.append(example)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
|
||||||
|
print(f"Loading model from: {args.model_name_or_path}")
|
||||||
|
model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path)
|
||||||
|
model.eval()
|
||||||
|
model = model.cuda()
|
||||||
|
if not args.no_fp16:
|
||||||
|
model = model.half()
|
||||||
|
|
||||||
|
index = src.index.Indexer(args.projection_size, args.n_subquantizers, args.n_bits)
|
||||||
|
|
||||||
|
# index all passages
|
||||||
|
input_paths = glob.glob(args.passages_embeddings)
|
||||||
|
input_paths = sorted(input_paths)
|
||||||
|
embeddings_dir = os.path.dirname(input_paths[0])
|
||||||
|
index_path = os.path.join(embeddings_dir, "index.faiss")
|
||||||
|
if args.save_or_load_index and os.path.exists(index_path):
|
||||||
|
index.deserialize_from(embeddings_dir)
|
||||||
|
else:
|
||||||
|
print(f"Indexing passages from files {input_paths}")
|
||||||
|
start_time_indexing = time.time()
|
||||||
|
index_encoded_data(index, input_paths, args.indexing_batch_size)
|
||||||
|
print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.")
|
||||||
|
if args.save_or_load_index:
|
||||||
|
index.serialize(embeddings_dir)
|
||||||
|
|
||||||
|
# load passages
|
||||||
|
passages = src.data.load_passages(args.passages)
|
||||||
|
passage_id_map = {x["id"]: x for x in passages}
|
||||||
|
|
||||||
|
data_paths = glob.glob(args.data)
|
||||||
|
alldata = []
|
||||||
|
for path in data_paths:
|
||||||
|
data = load_data(path)
|
||||||
|
output_path = os.path.join(args.output_dir, os.path.basename(path))
|
||||||
|
|
||||||
|
queries = [ex["question"] for ex in data]
|
||||||
|
questions_embedding = embed_queries(args, queries, model, tokenizer)
|
||||||
|
|
||||||
|
# get top k results
|
||||||
|
start_time_retrieval = time.time()
|
||||||
|
top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs)
|
||||||
|
print(f"Search time: {time.time()-start_time_retrieval:.1f} s.")
|
||||||
|
|
||||||
|
add_passages(data, passage_id_map, top_ids_and_scores)
|
||||||
|
#hasanswer = validate(data, args.validation_workers)
|
||||||
|
#add_hasanswer(data, hasanswer)
|
||||||
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||||
|
with open(output_path, "w") as fout:
|
||||||
|
for ex in data:
|
||||||
|
json.dump(ex, fout, ensure_ascii=False)
|
||||||
|
fout.write("\n")
|
||||||
|
print(f"Saved results to {output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--data",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=".json file containing question and answers, similar format to reader data",
|
||||||
|
)
|
||||||
|
parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
|
||||||
|
parser.add_argument("--passages_embeddings", type=str, default=None, help="Glob path to encoded passages")
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir", type=str, default=None, help="Results are written to outputdir with data suffix"
|
||||||
|
)
|
||||||
|
parser.add_argument("--n_docs", type=int, default=100, help="Number of documents to retrieve per questions")
|
||||||
|
parser.add_argument(
|
||||||
|
"--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
|
||||||
|
)
|
||||||
|
parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding")
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name_or_path", type=str, help="path to directory containing model weights and config file"
|
||||||
|
)
|
||||||
|
parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
|
||||||
|
parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
|
||||||
|
parser.add_argument(
|
||||||
|
"--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed"
|
||||||
|
)
|
||||||
|
parser.add_argument("--projection_size", type=int, default=768)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_subquantizers",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Number of subquantizer used for vector quantization, if 0 flat index is used",
|
||||||
|
)
|
||||||
|
parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer")
|
||||||
|
parser.add_argument("--lang", nargs="+")
|
||||||
|
parser.add_argument("--dataset", type=str, default="none")
|
||||||
|
parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
|
||||||
|
parser.add_argument("--normalize_text", action="store_true", help="normalize text")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
src.slurm.init_distributed_mode(args)
|
||||||
|
main(args)
|
|
@ -0,0 +1,41 @@
|
||||||
|
import json
import argparse

import jsonlines


def train(args):
    # Convert raw passages (one JSON object per line with "title" and "text")
    # into train_robot.jsonl, assigning a sequential integer id to each passage.
    with open(args.passages, 'r', encoding="utf-8") as f, \
            jsonlines.open("train_robot.jsonl", "a") as writer:
        for k, line in enumerate(f):
            data = json.loads(line)
            record = {"id": k, "title": data["title"], "text": data["text"]}
            writer.write(record)


def test(args):
    # Build a small retrieval test set (at most 1000 queries) from passages that
    # already carry an "id": the title becomes the question, the text the answer.
    with open(args.passages, 'r', encoding="utf-8") as f, \
            jsonlines.open("test_robot.jsonl", "a") as writer:
        for k, line in enumerate(f):
            if k >= 1000:
                break
            data = json.loads(line)
            record = {"id": data["id"], "question": data["title"], "answers": data["text"]}
            writer.write(record)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--passages", type=str, default=None, help="Path to passages")
    parser.add_argument("--mode", type=str, default=None, help="train or test")
    args = parser.parse_args()

    if args.mode == 'train':
        train(args)
    elif args.mode == 'test':
        test(args)
    else:
        print("error: --mode must be 'train' or 'test'")
|
Binary file not shown.
|
@ -0,0 +1,2 @@
|
||||||
|
{"id": 0, "question": "请把酸奶放在咖啡台上,并打开窗帘。", "ctxs": [{"id": "0", "title": "请把酸奶放在咖啡台上,并打开窗帘。", "text": "On(Yogurt,CoffeeTable),Is(Curtain,Open)", "score": "1.9694625"}, {"id": "1", "title": "可以把牛奶饮料放在2号桌子上吗?还有关掉灯光。", "text": "On(MilkDrink,Table2),Is(TubeLight,Off)", "score": "1.8284101"}, {"id": "2", "title": "你好,可以给我上一份甜点吗?", "text": "On(Dessert,Table1)", "score": "1.4835652"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "1.4412252"}, {"id": "4", "title": "可以送一瓶牛奶饮料到1号桌吗?", "text": "On(MilkDrink,Table1)", "score": "1.2867957"}, {"id": "3", "title": "你能到另一个吧台这边来吗?空调可以关掉吗?", "text": "At(Robot,Bar2),Is(AC,On)", "score": "1.2599907"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}]}
|
||||||
|
{"id": 1, "question": "可以把牛奶饮料放在2号桌子上吗?还有关掉灯光。", "ctxs": [{"id": "1", "title": "可以把牛奶饮料放在2号桌子上吗?还有关掉灯光。", "text": "On(MilkDrink,Table2),Is(TubeLight,Off)", "score": "2.138029"}, {"id": "0", "title": "请把酸奶放在咖啡台上,并打开窗帘。", "text": "On(Yogurt,CoffeeTable),Is(Curtain,Open)", "score": "1.8282425"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "1.6972268"}, {"id": "2", "title": "你好,可以给我上一份甜点吗?", "text": "On(Dessert,Table1)", "score": "1.4741647"}, {"id": "4", "title": "可以送一瓶牛奶饮料到1号桌吗?", "text": "On(MilkDrink,Table1)", "score": "1.4532053"}, {"id": "3", "title": "你能到另一个吧台这边来吗?空调可以关掉吗?", "text": "At(Robot,Bar2),Is(AC,On)", "score": "1.3438905"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗?还有,能关掉筒灯吗?", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}]}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,208 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
|
||||||
|
import os
import glob
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import List, Dict
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
import beir.util
|
||||||
|
from beir.datasets.data_loader import GenericDataLoader
|
||||||
|
from beir.retrieval.evaluation import EvaluateRetrieval
|
||||||
|
from beir.retrieval.search.dense import DenseRetrievalExactSearch
|
||||||
|
|
||||||
|
from beir.reranking.models import CrossEncoder
|
||||||
|
from beir.reranking import Rerank
|
||||||
|
|
||||||
|
import src.dist_utils as dist_utils
|
||||||
|
from src import normalize_text
|
||||||
|
|
||||||
|
|
||||||
|
class DenseEncoderModel:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
query_encoder,
|
||||||
|
doc_encoder=None,
|
||||||
|
tokenizer=None,
|
||||||
|
max_length=512,
|
||||||
|
add_special_tokens=True,
|
||||||
|
norm_query=False,
|
||||||
|
norm_doc=False,
|
||||||
|
lower_case=False,
|
||||||
|
normalize_text=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.query_encoder = query_encoder
|
||||||
|
self.doc_encoder = doc_encoder
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_length = max_length
|
||||||
|
self.add_special_tokens = add_special_tokens
|
||||||
|
self.norm_query = norm_query
|
||||||
|
self.norm_doc = norm_doc
|
||||||
|
self.lower_case = lower_case
|
||||||
|
self.normalize_text = normalize_text
|
||||||
|
|
||||||
|
def encode_queries(self, queries: List[str], batch_size: int, **kwargs) -> np.ndarray:
|
||||||
|
|
||||||
|
if dist.is_initialized():
|
||||||
|
idx = np.array_split(range(len(queries)), dist.get_world_size())[dist.get_rank()]
|
||||||
|
else:
|
||||||
|
idx = range(len(queries))
|
||||||
|
|
||||||
|
queries = [queries[i] for i in idx]
|
||||||
|
if self.normalize_text:
|
||||||
|
queries = [normalize_text.normalize(q) for q in queries]
|
||||||
|
if self.lower_case:
|
||||||
|
queries = [q.lower() for q in queries]
|
||||||
|
|
||||||
|
allemb = []
|
||||||
|
nbatch = (len(queries) - 1) // batch_size + 1
|
||||||
|
with torch.no_grad():
|
||||||
|
for k in range(nbatch):
|
||||||
|
start_idx = k * batch_size
|
||||||
|
end_idx = min((k + 1) * batch_size, len(queries))
|
||||||
|
|
||||||
|
qencode = self.tokenizer.batch_encode_plus(
|
||||||
|
queries[start_idx:end_idx],
|
||||||
|
max_length=self.max_length,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
add_special_tokens=self.add_special_tokens,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
qencode = {key: value.cuda() for key, value in qencode.items()}
|
||||||
|
emb = self.query_encoder(**qencode, normalize=self.norm_query)
|
||||||
|
allemb.append(emb.cpu())
|
||||||
|
|
||||||
|
allemb = torch.cat(allemb, dim=0)
|
||||||
|
allemb = allemb.cuda()
|
||||||
|
if dist.is_initialized():
|
||||||
|
allemb = dist_utils.varsize_gather_nograd(allemb)
|
||||||
|
allemb = allemb.cpu().numpy()
|
||||||
|
return allemb
|
||||||
|
|
||||||
|
def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs):
|
||||||
|
|
||||||
|
if dist.is_initialized():
|
||||||
|
idx = np.array_split(range(len(corpus)), dist.get_world_size())[dist.get_rank()]
|
||||||
|
else:
|
||||||
|
idx = range(len(corpus))
|
||||||
|
corpus = [corpus[i] for i in idx]
|
||||||
|
corpus = [c["title"] + " " + c["text"] if len(c["title"]) > 0 else c["text"] for c in corpus]
|
||||||
|
if self.normalize_text:
|
||||||
|
corpus = [normalize_text.normalize(c) for c in corpus]
|
||||||
|
if self.lower_case:
|
||||||
|
corpus = [c.lower() for c in corpus]
|
||||||
|
|
||||||
|
allemb = []
|
||||||
|
nbatch = (len(corpus) - 1) // batch_size + 1
|
||||||
|
with torch.no_grad():
|
||||||
|
for k in range(nbatch):
|
||||||
|
start_idx = k * batch_size
|
||||||
|
end_idx = min((k + 1) * batch_size, len(corpus))
|
||||||
|
|
||||||
|
cencode = self.tokenizer.batch_encode_plus(
|
||||||
|
corpus[start_idx:end_idx],
|
||||||
|
max_length=self.max_length,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
add_special_tokens=self.add_special_tokens,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
cencode = {key: value.cuda() for key, value in cencode.items()}
|
||||||
|
emb = self.doc_encoder(**cencode, normalize=self.norm_doc)
|
||||||
|
allemb.append(emb.cpu())
|
||||||
|
|
||||||
|
allemb = torch.cat(allemb, dim=0)
|
||||||
|
allemb = allemb.cuda()
|
||||||
|
if dist.is_initialized():
|
||||||
|
allemb = dist_utils.varsize_gather_nograd(allemb)
|
||||||
|
allemb = allemb.cpu().numpy()
|
||||||
|
return allemb
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_model(
|
||||||
|
query_encoder,
|
||||||
|
doc_encoder,
|
||||||
|
tokenizer,
|
||||||
|
dataset,
|
||||||
|
batch_size=128,
|
||||||
|
add_special_tokens=True,
|
||||||
|
norm_query=False,
|
||||||
|
norm_doc=False,
|
||||||
|
is_main=True,
|
||||||
|
split="test",
|
||||||
|
score_function="dot",
|
||||||
|
beir_dir="BEIR/datasets",
|
||||||
|
save_results_path=None,
|
||||||
|
lower_case=False,
|
||||||
|
normalize_text=False,
|
||||||
|
):
|
||||||
|
|
||||||
|
metrics = defaultdict(list) # store final results
|
||||||
|
|
||||||
|
if hasattr(query_encoder, "module"):
|
||||||
|
query_encoder = query_encoder.module
|
||||||
|
query_encoder.eval()
|
||||||
|
|
||||||
|
if doc_encoder is not None:
|
||||||
|
if hasattr(doc_encoder, "module"):
|
||||||
|
doc_encoder = doc_encoder.module
|
||||||
|
doc_encoder.eval()
|
||||||
|
else:
|
||||||
|
doc_encoder = query_encoder
|
||||||
|
|
||||||
|
dmodel = DenseRetrievalExactSearch(
|
||||||
|
DenseEncoderModel(
|
||||||
|
query_encoder=query_encoder,
|
||||||
|
doc_encoder=doc_encoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
add_special_tokens=add_special_tokens,
|
||||||
|
norm_query=norm_query,
|
||||||
|
norm_doc=norm_doc,
|
||||||
|
lower_case=lower_case,
|
||||||
|
normalize_text=normalize_text,
|
||||||
|
),
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
retriever = EvaluateRetrieval(dmodel, score_function=score_function)
|
||||||
|
data_path = os.path.join(beir_dir, dataset)
|
||||||
|
|
||||||
|
if not os.path.isdir(data_path) and is_main:
|
||||||
|
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
|
||||||
|
data_path = beir.util.download_and_unzip(url, beir_dir)
|
||||||
|
dist_utils.barrier()
|
||||||
|
|
||||||
|
if not dataset == "cqadupstack":
|
||||||
|
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)
|
||||||
|
results = retriever.retrieve(corpus, queries)
|
||||||
|
if is_main:
|
||||||
|
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
|
||||||
|
for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"):
|
||||||
|
if isinstance(metric, str):
|
||||||
|
metric = retriever.evaluate_custom(qrels, results, retriever.k_values, metric=metric)
|
||||||
|
for key, value in metric.items():
|
||||||
|
metrics[key].append(value)
|
||||||
|
if save_results_path is not None:
|
||||||
|
torch.save(results, f"{save_results_path}")
|
||||||
|
elif dataset == "cqadupstack": # compute macroaverage over datasets
|
||||||
|
paths = glob.glob(data_path)
|
||||||
|
for path in paths:
|
||||||
|
corpus, queries, qrels = GenericDataLoader(data_folder=path).load(split=split)
|
||||||
|
results = retriever.retrieve(corpus, queries)
|
||||||
|
if is_main:
|
||||||
|
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
|
||||||
|
for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"):
|
||||||
|
if isinstance(metric, str):
|
||||||
|
metric = retriever.evaluate_custom(qrels, results, retriever.k_values, metric=metric)
|
||||||
|
for key, value in metric.items():
|
||||||
|
metrics[key].append(value)
|
||||||
|
for key, value in metrics.items():
|
||||||
|
assert (
|
||||||
|
len(value) == 12
|
||||||
|
), f"cqadupstack includes 12 datasets, only {len(value)} values were compute for the {key} metric"
|
||||||
|
|
||||||
|
metrics = {key: 100 * np.mean(value) for key, value in metrics.items()}
|
||||||
|
|
||||||
|
return metrics
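A hedged example call (not in the original file): it assumes the BEIR package and datasets directory are set up, and that a contriever model and tokenizer have already been loaded, for instance via src.contriever.load_retriever:
metrics = evaluate_model(
    query_encoder=model, doc_encoder=model, tokenizer=tokenizer,
    dataset="scifact",             # any BEIR dataset name; downloaded on first use
    batch_size=64, beir_dir="BEIR/datasets",
)
print(metrics["NDCG@10"])          # prints the nDCG@10 result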
|
|
@ -0,0 +1,139 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from transformers import BertModel, XLMRobertaModel
|
||||||
|
|
||||||
|
from robowaiter.algos.retrieval.retrieval_lm.src import utils
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Contriever(BertModel):
|
||||||
|
def __init__(self, config, pooling="average", **kwargs):
|
||||||
|
super().__init__(config, add_pooling_layer=False)
|
||||||
|
if not hasattr(config, "pooling"):
|
||||||
|
self.config.pooling = pooling
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
normalize=False,
|
||||||
|
):
|
||||||
|
|
||||||
|
model_output = super().forward(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
last_hidden = model_output["last_hidden_state"]
|
||||||
|
last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
||||||
|
|
||||||
|
if self.config.pooling == "average":
|
||||||
|
emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
||||||
|
elif self.config.pooling == "cls":
|
||||||
|
emb = last_hidden[:, 0]
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
emb = torch.nn.functional.normalize(emb, dim=-1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class XLMRetriever(XLMRobertaModel):
|
||||||
|
def __init__(self, config, pooling="average", **kwargs):
|
||||||
|
super().__init__(config, add_pooling_layer=False)
|
||||||
|
if not hasattr(config, "pooling"):
|
||||||
|
self.config.pooling = pooling
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
normalize=False,
|
||||||
|
):
|
||||||
|
|
||||||
|
model_output = super().forward(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
last_hidden = model_output["last_hidden_state"]
|
||||||
|
last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
||||||
|
if self.config.pooling == "average":
|
||||||
|
emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
||||||
|
elif self.config.pooling == "cls":
|
||||||
|
emb = last_hidden[:, 0]
|
||||||
|
if normalize:
|
||||||
|
emb = torch.nn.functional.normalize(emb, dim=-1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
def load_retriever(model_path, pooling="average", random_init=False):
|
||||||
|
# check whether a fine-tuned checkpoint (checkpoint.pth) exists locally under model_path
|
||||||
|
path = os.path.join(model_path, "checkpoint.pth")
|
||||||
|
if os.path.exists(path):
|
||||||
|
pretrained_dict = torch.load(path, map_location="cpu")
|
||||||
|
opt = pretrained_dict["opt"]
|
||||||
|
if hasattr(opt, "retriever_model_id"):
|
||||||
|
retriever_model_id = opt.retriever_model_id
|
||||||
|
else:
|
||||||
|
# retriever_model_id = "bert-base-uncased"
|
||||||
|
retriever_model_id = "bert-base-multilingual-cased"
|
||||||
|
tokenizer = utils.load_hf(transformers.AutoTokenizer, retriever_model_id)
|
||||||
|
cfg = utils.load_hf(transformers.AutoConfig, retriever_model_id)
|
||||||
|
if "xlm" in retriever_model_id:
|
||||||
|
model_class = XLMRetriever
|
||||||
|
else:
|
||||||
|
model_class = Contriever
|
||||||
|
retriever = model_class(cfg)
|
||||||
|
pretrained_dict = pretrained_dict["model"]
|
||||||
|
|
||||||
|
if any("encoder_q." in key for key in pretrained_dict.keys()): # test if model is defined with moco class
|
||||||
|
pretrained_dict = {k.replace("encoder_q.", ""): v for k, v in pretrained_dict.items() if "encoder_q." in k}
|
||||||
|
elif any("encoder." in key for key in pretrained_dict.keys()): # test if model is defined with inbatch class
|
||||||
|
pretrained_dict = {k.replace("encoder.", ""): v for k, v in pretrained_dict.items() if "encoder." in k}
|
||||||
|
retriever.load_state_dict(pretrained_dict, strict=False)
|
||||||
|
else:
|
||||||
|
retriever_model_id = model_path
|
||||||
|
if "xlm" in retriever_model_id:
|
||||||
|
model_class = XLMRetriever
|
||||||
|
else:
|
||||||
|
model_class = Contriever
|
||||||
|
cfg = utils.load_hf(transformers.AutoConfig, model_path)
|
||||||
|
tokenizer = utils.load_hf(transformers.AutoTokenizer, model_path)
|
||||||
|
retriever = utils.load_hf(model_class, model_path)
|
||||||
|
|
||||||
|
return retriever, tokenizer, retriever_model_id
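A minimal sketch (not in the original file) of loading a retriever and embedding two sentences; the identifier facebook/contriever is an assumption about where weights live, and the import path follows the src.contriever usage in the retrieval script above:
import torch
from src.contriever import load_retriever   # adjust to the actual package path if needed

retriever, tokenizer, _ = load_retriever("facebook/contriever")
retriever.eval()

sentences = ["Where is the yogurt?", "On(Yogurt,CoffeeTable)"]
inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    emb = retriever(**inputs, normalize=True)  # mean-pooled, L2-normalized sentence embeddings
print(emb.shape)  # torch.Size([2, 768]) for the BERT-base config shown in config.json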
|
|
@ -0,0 +1,243 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
import torch
|
||||||
|
import random
|
||||||
|
import json
|
||||||
|
import csv
|
||||||
|
import numpy as np
|
||||||
|
import numpy.random
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from robowaiter.algos.retrieval.retrieval_lm.src import dist_utils
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def load_data(opt, tokenizer):
|
||||||
|
datasets = {}
|
||||||
|
for path in opt.train_data:
|
||||||
|
data = load_dataset(path, opt.loading_mode)
|
||||||
|
if data is not None:
|
||||||
|
datasets[path] = Dataset(data, opt.chunk_length, tokenizer, opt)
|
||||||
|
dataset = MultiDataset(datasets)
|
||||||
|
dataset.set_prob(coeff=opt.sampling_coefficient)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset(data_path, loading_mode):
|
||||||
|
files = glob.glob(os.path.join(data_path, "*.p*"))
|
||||||
|
files.sort()
|
||||||
|
tensors = []
|
||||||
|
if loading_mode == "split":
|
||||||
|
files_split = list(np.array_split(files, dist_utils.get_world_size()))[dist_utils.get_rank()]
|
||||||
|
for filepath in files_split:
|
||||||
|
try:
|
||||||
|
tensors.append(torch.load(filepath, map_location="cpu"))
|
||||||
|
except:
|
||||||
|
logger.warning(f"Unable to load file {filepath}")
|
||||||
|
elif loading_mode == "full":
|
||||||
|
for fin in files:
|
||||||
|
tensors.append(torch.load(fin, map_location="cpu"))
|
||||||
|
elif loading_mode == "single":
|
||||||
|
tensors.append(torch.load(files[0], map_location="cpu"))
|
||||||
|
if len(tensors) == 0:
|
||||||
|
return None
|
||||||
|
tensor = torch.cat(tensors)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
class MultiDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, datasets):
|
||||||
|
|
||||||
|
self.datasets = datasets
|
||||||
|
self.prob = [1 / len(self.datasets) for _ in self.datasets]
|
||||||
|
self.dataset_ids = list(self.datasets.keys())
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return sum([len(dataset) for dataset in self.datasets.values()])
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
dataset_idx = numpy.random.choice(range(len(self.prob)), 1, p=self.prob)[0]
|
||||||
|
did = self.dataset_ids[dataset_idx]
|
||||||
|
index = random.randint(0, len(self.datasets[did]) - 1)
|
||||||
|
sample = self.datasets[did][index]
|
||||||
|
sample["dataset_id"] = did
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def generate_offset(self):
|
||||||
|
for dataset in self.datasets.values():
|
||||||
|
dataset.generate_offset()
|
||||||
|
|
||||||
|
def set_prob(self, coeff=0.0):
|
||||||
|
|
||||||
|
prob = np.array([float(len(dataset)) for _, dataset in self.datasets.items()])
|
||||||
|
prob /= prob.sum()
|
||||||
|
prob = np.array([p**coeff for p in prob])
|
||||||
|
prob /= prob.sum()
|
||||||
|
self.prob = prob
|
||||||
|
|
||||||
|
|
||||||
|
class Dataset(torch.utils.data.Dataset):
|
||||||
|
"""Monolingual dataset based on a list of paths"""
|
||||||
|
|
||||||
|
def __init__(self, data, chunk_length, tokenizer, opt):
|
||||||
|
|
||||||
|
self.data = data
|
||||||
|
self.chunk_length = chunk_length
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.opt = opt
|
||||||
|
self.generate_offset()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return (self.data.size(0) - self.offset) // self.chunk_length
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
start_idx = self.offset + index * self.chunk_length
|
||||||
|
end_idx = start_idx + self.chunk_length
|
||||||
|
tokens = self.data[start_idx:end_idx]
|
||||||
|
q_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max)
|
||||||
|
k_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max)
|
||||||
|
q_tokens = apply_augmentation(q_tokens, self.opt)
|
||||||
|
q_tokens = add_bos_eos(q_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id)
|
||||||
|
k_tokens = apply_augmentation(k_tokens, self.opt)
|
||||||
|
k_tokens = add_bos_eos(k_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id)
|
||||||
|
|
||||||
|
return {"q_tokens": q_tokens, "k_tokens": k_tokens}
|
||||||
|
|
||||||
|
def generate_offset(self):
|
||||||
|
self.offset = random.randint(0, self.chunk_length - 1)
|
||||||
|
|
||||||
|
|
||||||
|
class Collator(object):
|
||||||
|
def __init__(self, opt):
|
||||||
|
self.opt = opt
|
||||||
|
|
||||||
|
def __call__(self, batch_examples):
|
||||||
|
|
||||||
|
batch = defaultdict(list)
|
||||||
|
for example in batch_examples:
|
||||||
|
for k, v in example.items():
|
||||||
|
batch[k].append(v)
|
||||||
|
|
||||||
|
q_tokens, q_mask = build_mask(batch["q_tokens"])
|
||||||
|
k_tokens, k_mask = build_mask(batch["k_tokens"])
|
||||||
|
|
||||||
|
batch["q_tokens"] = q_tokens
|
||||||
|
batch["q_mask"] = q_mask
|
||||||
|
batch["k_tokens"] = k_tokens
|
||||||
|
batch["k_mask"] = k_mask
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def randomcrop(x, ratio_min, ratio_max):
|
||||||
|
|
||||||
|
ratio = random.uniform(ratio_min, ratio_max)
|
||||||
|
length = int(len(x) * ratio)
|
||||||
|
start = random.randint(0, len(x) - length)
|
||||||
|
end = start + length
|
||||||
|
crop = x[start:end].clone()
|
||||||
|
return crop
|
||||||
|
|
||||||
|
|
||||||
|
def build_mask(tensors):
|
||||||
|
shapes = [x.shape for x in tensors]
|
||||||
|
maxlength = max([len(x) for x in tensors])
|
||||||
|
returnmasks = []
|
||||||
|
ids = []
|
||||||
|
for k, x in enumerate(tensors):
|
||||||
|
returnmasks.append(torch.tensor([1] * len(x) + [0] * (maxlength - len(x))))
|
||||||
|
ids.append(torch.cat((x, torch.tensor([0] * (maxlength - len(x))))))
|
||||||
|
ids = torch.stack(ids, dim=0).long()
|
||||||
|
returnmasks = torch.stack(returnmasks, dim=0).bool()
|
||||||
|
return ids, returnmasks
|
||||||
|
|
||||||
|
|
||||||
|
def add_token(x, token):
|
||||||
|
x = torch.cat((torch.tensor([token]), x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def deleteword(x, p=0.1):
|
||||||
|
mask = np.random.rand(len(x))
|
||||||
|
x = [e for e, m in zip(x, mask) if m > p]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def replaceword(x, min_random, max_random, p=0.1):
|
||||||
|
mask = np.random.rand(len(x))
|
||||||
|
x = [e if m > p else random.randint(min_random, max_random) for e, m in zip(x, mask)]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def maskword(x, mask_id, p=0.1):
|
||||||
|
mask = np.random.rand(len(x))
|
||||||
|
x = [e if m > p else mask_id for e, m in zip(x, mask)]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def shuffleword(x, p=0.1):
|
||||||
|
count = (np.random.rand(len(x)) < p).sum()
|
||||||
|
"""Shuffles any n number of values in a list"""
|
||||||
|
indices_to_shuffle = random.sample(range(len(x)), k=count)
|
||||||
|
to_shuffle = [x[i] for i in indices_to_shuffle]
|
||||||
|
random.shuffle(to_shuffle)
|
||||||
|
for index, value in enumerate(to_shuffle):
|
||||||
|
old_index = indices_to_shuffle[index]
|
||||||
|
x[old_index] = value
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def apply_augmentation(x, opt):
|
||||||
|
if opt.augmentation == "mask":
|
||||||
|
return torch.tensor(maskword(x, mask_id=opt.mask_id, p=opt.prob_augmentation))
|
||||||
|
elif opt.augmentation == "replace":
|
||||||
|
return torch.tensor(
|
||||||
|
replaceword(x, min_random=opt.start_id, max_random=opt.vocab_size - 1, p=opt.prob_augmentation)
|
||||||
|
)
|
||||||
|
elif opt.augmentation == "delete":
|
||||||
|
return torch.tensor(deleteword(x, p=opt.prob_augmentation))
|
||||||
|
elif opt.augmentation == "shuffle":
|
||||||
|
return torch.tensor(shuffleword(x, p=opt.prob_augmentation))
|
||||||
|
else:
|
||||||
|
if not isinstance(x, torch.Tensor):
|
||||||
|
x = torch.Tensor(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def add_bos_eos(x, bos_token_id, eos_token_id):
|
||||||
|
if not isinstance(x, torch.Tensor):
|
||||||
|
x = torch.Tensor(x)
|
||||||
|
if bos_token_id is None and eos_token_id is not None:
|
||||||
|
x = torch.cat([x.clone().detach(), torch.tensor([eos_token_id])])
|
||||||
|
elif bos_token_id is not None and eos_token_id is None:
|
||||||
|
x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach()])
|
||||||
|
elif bos_token_id is None and eos_token_id is None:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach(), torch.tensor([eos_token_id])])
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# Used for passage retrieval
|
||||||
|
def load_passages(path):
|
||||||
|
if not os.path.exists(path):
|
||||||
|
logger.info(f"{path} does not exist")
|
||||||
|
return
|
||||||
|
logger.info(f"Loading passages from: {path}")
|
||||||
|
passages = []
|
||||||
|
with open(path,encoding='UTF-8') as fin:
|
||||||
|
if path.endswith(".jsonl"):
|
||||||
|
for k, line in enumerate(fin):
|
||||||
|
ex = json.loads(line)
|
||||||
|
passages.append(ex)
|
||||||
|
else:
|
||||||
|
reader = csv.reader(fin, delimiter="\t")
|
||||||
|
for k, row in enumerate(reader):
|
||||||
|
if not row[0] == "id":
|
||||||
|
ex = {"id": row[0], "title": row[2], "text": row[1]}
|
||||||
|
passages.append(ex)
|
||||||
|
return passages
|
|
@ -0,0 +1,128 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
class Gather(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x: torch.tensor):
|
||||||
|
output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
|
||||||
|
dist.all_gather(output, x)
|
||||||
|
return tuple(output)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, *grads):
|
||||||
|
all_gradients = torch.stack(grads)
|
||||||
|
dist.all_reduce(all_gradients)
|
||||||
|
return all_gradients[dist.get_rank()]
|
||||||
|
|
||||||
|
|
||||||
|
def gather(x: torch.tensor):
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return x
|
||||||
|
x_gather = Gather.apply(x)
|
||||||
|
x_gather = torch.cat(x_gather, dim=0)
|
||||||
|
return x_gather
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def gather_nograd(x: torch.tensor):
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return x
|
||||||
|
x_gather = [torch.ones_like(x) for _ in range(dist.get_world_size())]
|
||||||
|
dist.all_gather(x_gather, x, async_op=False)
|
||||||
|
|
||||||
|
x_gather = torch.cat(x_gather, dim=0)
|
||||||
|
return x_gather
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def varsize_gather_nograd(x: torch.Tensor):
|
||||||
|
"""gather tensors of different sizes along the first dimension"""
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return x
|
||||||
|
|
||||||
|
# determine max size
|
||||||
|
size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
|
||||||
|
allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
|
||||||
|
dist.all_gather(allsizes, size)
|
||||||
|
max_size = max([size.cpu().max() for size in allsizes])
|
||||||
|
|
||||||
|
padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device)
|
||||||
|
padded[: x.shape[0]] = x
|
||||||
|
output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())]
|
||||||
|
dist.all_gather(output, padded)
|
||||||
|
|
||||||
|
output = [tensor[: allsizes[k]] for k, tensor in enumerate(output)]
|
||||||
|
output = torch.cat(output, dim=0)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_varsize(x: torch.Tensor):
|
||||||
|
"""gather tensors of different sizes along the first dimension"""
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return [x.shape[0]]
|
||||||
|
|
||||||
|
# determine max size
|
||||||
|
size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
|
||||||
|
allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
|
||||||
|
dist.all_gather(allsizes, size)
|
||||||
|
allsizes = torch.cat(allsizes)
|
||||||
|
return allsizes
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank():
|
||||||
|
if not dist.is_available():
|
||||||
|
return 0
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return 0
|
||||||
|
return dist.get_rank()
|
||||||
|
|
||||||
|
|
||||||
|
def is_main():
|
||||||
|
return get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size():
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return dist.get_world_size()
|
||||||
|
|
||||||
|
|
||||||
|
def barrier():
|
||||||
|
if dist.is_initialized():
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
|
||||||
|
def average_main(x):
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return x
|
||||||
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
|
dist.reduce(x, 0, op=dist.ReduceOp.SUM)
|
||||||
|
if is_main():
|
||||||
|
x = x / dist.get_world_size()
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def sum_main(x):
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return x
|
||||||
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
|
dist.reduce(x, 0, op=dist.ReduceOp.SUM)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def weighted_average(x, count):
|
||||||
|
if not dist.is_initialized():
|
||||||
|
if isinstance(x, torch.Tensor):
|
||||||
|
x = x.item()
|
||||||
|
return x, count
|
||||||
|
t_loss = torch.tensor([x * count]).cuda()
|
||||||
|
t_total = torch.tensor([count]).cuda()
|
||||||
|
t_loss = sum_main(t_loss)
|
||||||
|
t_total = sum_main(t_total)
|
||||||
|
return (t_loss / t_total).item(), t_total.item()
|
|
@ -0,0 +1,190 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import logging
|
||||||
|
import regex
|
||||||
|
import string
|
||||||
|
import unicodedata
|
||||||
|
from functools import partial
|
||||||
|
from multiprocessing import Pool as ProcessPool
|
||||||
|
from typing import Tuple, List, Dict
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
"""
|
||||||
|
Evaluation code from DPR: https://github.com/facebookresearch/DPR
|
||||||
|
"""
|
||||||
|
|
||||||
|
class SimpleTokenizer(object):
|
||||||
|
ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
|
||||||
|
NON_WS = r'[^\p{Z}\p{C}]'
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
annotators: None or empty set (only tokenizes).
|
||||||
|
"""
|
||||||
|
self._regexp = regex.compile(
|
||||||
|
'(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
|
||||||
|
flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
|
||||||
|
)
|
||||||
|
|
||||||
|
def tokenize(self, text, uncased=False):
|
||||||
|
matches = [m for m in self._regexp.finditer(text)]
|
||||||
|
if uncased:
|
||||||
|
tokens = [m.group().lower() for m in matches]
|
||||||
|
else:
|
||||||
|
tokens = [m.group() for m in matches]
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits'])
|
||||||
|
|
||||||
|
def calculate_matches(data: List, workers_num: int):
|
||||||
|
"""
|
||||||
|
Evaluates answers presence in the set of documents. This function is supposed to be used with a large collection of
|
||||||
|
documents and results. It internally forks multiple sub-processes for evaluation and then merges results
|
||||||
|
:param all_docs: dictionary of the entire documents database. doc_id -> (doc_text, title)
|
||||||
|
:param answers: list of answers's list. One list per question
|
||||||
|
:param closest_docs: document ids of the top results along with their scores
|
||||||
|
:param workers_num: amount of parallel threads to process data
|
||||||
|
:param match_type: type of answer matching. Refer to has_answer code for available options
|
||||||
|
:return: matching information tuple.
|
||||||
|
top_k_hits - a list where the index is the amount of top documents retrieved and the value is the total amount of
|
||||||
|
valid matches across an entire dataset.
|
||||||
|
questions_doc_hits - more detailed info with answer matches for every question and every retrieved document
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger.info('Matching answers in top docs...')
|
||||||
|
|
||||||
|
tokenizer = SimpleTokenizer()
|
||||||
|
get_score_partial = partial(check_answer, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
processes = ProcessPool(processes=workers_num)
|
||||||
|
scores = processes.map(get_score_partial, data)
|
||||||
|
|
||||||
|
logger.info('Per question validation results len=%d', len(scores))
|
||||||
|
|
||||||
|
n_docs = len(data[0]['ctxs'])
|
||||||
|
top_k_hits = [0] * n_docs
|
||||||
|
for question_hits in scores:
|
||||||
|
best_hit = next((i for i, x in enumerate(question_hits) if x), None)
|
||||||
|
if best_hit is not None:
|
||||||
|
top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]
|
||||||
|
|
||||||
|
return QAMatchStats(top_k_hits, scores)
|
||||||
|
|
||||||
|
def check_answer(example, tokenizer) -> List[bool]:
|
||||||
|
"""Search through all the top docs to see if they have any of the answers."""
|
||||||
|
answers = example['answers']
|
||||||
|
ctxs = example['ctxs']
|
||||||
|
|
||||||
|
hits = []
|
||||||
|
|
||||||
|
for i, doc in enumerate(ctxs):
|
||||||
|
text = doc['text']
|
||||||
|
|
||||||
|
if text is None: # cannot find the document for some reason
|
||||||
|
logger.warning("no doc in db")
|
||||||
|
hits.append(False)
|
||||||
|
continue
|
||||||
|
|
||||||
|
hits.append(has_answer(answers, text, tokenizer))
|
||||||
|
|
||||||
|
return hits
|
||||||
|
|
||||||
|
def has_answer(answers, text, tokenizer) -> bool:
|
||||||
|
"""Check if a document contains an answer string."""
|
||||||
|
text = _normalize(text)
|
||||||
|
text = tokenizer.tokenize(text, uncased=True)
|
||||||
|
|
||||||
|
for answer in answers:
|
||||||
|
answer = _normalize(answer)
|
||||||
|
answer = tokenizer.tokenize(answer, uncased=True)
|
||||||
|
for i in range(0, len(text) - len(answer) + 1):
|
||||||
|
if answer == text[i: i + len(answer)]:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
#################################################
|
||||||
|
######## READER EVALUATION ########
|
||||||
|
#################################################
|
||||||
|
|
||||||
|
def _normalize(text):
|
||||||
|
return unicodedata.normalize('NFD', text)
|
||||||
|
|
||||||
|
#Normalization and score functions from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
|
||||||
|
def normalize_answer(s):
|
||||||
|
def remove_articles(text):
|
||||||
|
return regex.sub(r'\b(a|an|the)\b', ' ', text)
|
||||||
|
|
||||||
|
def white_space_fix(text):
|
||||||
|
return ' '.join(text.split())
|
||||||
|
|
||||||
|
def remove_punc(text):
|
||||||
|
exclude = set(string.punctuation)
|
||||||
|
return ''.join(ch for ch in text if ch not in exclude)
|
||||||
|
|
||||||
|
def lower(text):
|
||||||
|
return text.lower()
|
||||||
|
|
||||||
|
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||||
|
|
||||||
|
def em(prediction, ground_truth):
|
||||||
|
return normalize_answer(prediction) == normalize_answer(ground_truth)
|
||||||
|
|
||||||
|
def f1(prediction, ground_truth):
|
||||||
|
prediction_tokens = normalize_answer(prediction).split()
|
||||||
|
ground_truth_tokens = normalize_answer(ground_truth).split()
|
||||||
|
common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens)
|
||||||
|
num_same = sum(common.values())
|
||||||
|
if num_same == 0:
|
||||||
|
return 0
|
||||||
|
precision = 1.0 * num_same / len(prediction_tokens)
|
||||||
|
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||||
|
f1 = (2 * precision * recall) / (precision + recall)
|
||||||
|
return f1
|
||||||
|
|
||||||
|
def f1_score(prediction, ground_truths):
|
||||||
|
return max([f1(prediction, gt) for gt in ground_truths])
|
||||||
|
|
||||||
|
def exact_match_score(prediction, ground_truths):
|
||||||
|
return max([em(prediction, gt) for gt in ground_truths])
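Small illustrations (not in the original file) of the SQuAD-style helpers above, assuming they are imported into scope:
print(em("The Yogurt", "yogurt"))                        # True: case, punctuation and articles are stripped
print(exact_match_score("Table 2", ["table 2", "bar"]))   # True: best match over all ground truths
print(round(f1("milk drink", "a milk drink please"), 2))  # 0.8: token-level F1 after normalization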
|
||||||
|
|
||||||
|
####################################################
|
||||||
|
######## RETRIEVER EVALUATION ########
|
||||||
|
####################################################
|
||||||
|
|
||||||
|
def eval_batch(scores, inversions, avg_topk, idx_topk):
|
||||||
|
for k, s in enumerate(scores):
|
||||||
|
s = s.cpu().numpy()
|
||||||
|
sorted_idx = np.argsort(-s)
|
||||||
|
score(sorted_idx, inversions, avg_topk, idx_topk)
|
||||||
|
|
||||||
|
def count_inversions(arr):
|
||||||
|
inv_count = 0
|
||||||
|
lenarr = len(arr)
|
||||||
|
for i in range(lenarr):
|
||||||
|
for j in range(i + 1, lenarr):
|
||||||
|
if (arr[i] > arr[j]):
|
||||||
|
inv_count += 1
|
||||||
|
return inv_count
|
||||||
|
|
||||||
|
def score(x, inversions, avg_topk, idx_topk):
|
||||||
|
x = np.array(x)
|
||||||
|
inversions.append(count_inversions(x))
|
||||||
|
for k in avg_topk:
|
||||||
|
# ratio of passages in the predicted top-k that are
|
||||||
|
# also in the topk given by gold score
|
||||||
|
avg_pred_topk = (x[:k]<k).mean()
|
||||||
|
avg_topk[k].append(avg_pred_topk)
|
||||||
|
for k in idx_topk:
|
||||||
|
below_k = (x<k)
|
||||||
|
# number of passages required to obtain all passages from gold top-k
|
||||||
|
idx_gold_topk = len(x) - np.argmax(below_k[::-1])
|
||||||
|
idx_topk[k].append(idx_gold_topk)
|
|
@ -0,0 +1,171 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import random
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
from src import normalize_text
|
||||||
|
|
||||||
|
|
||||||
|
class Dataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
datapaths,
|
||||||
|
negative_ctxs=1,
|
||||||
|
negative_hard_ratio=0.0,
|
||||||
|
negative_hard_min_idx=0,
|
||||||
|
training=False,
|
||||||
|
global_rank=-1,
|
||||||
|
world_size=-1,
|
||||||
|
maxload=None,
|
||||||
|
normalize=False,
|
||||||
|
):
|
||||||
|
self.negative_ctxs = negative_ctxs
|
||||||
|
self.negative_hard_ratio = negative_hard_ratio
|
||||||
|
self.negative_hard_min_idx = negative_hard_min_idx
|
||||||
|
self.training = training
|
||||||
|
self.normalize_fn = normalize_text.normalize if normalize else lambda x: x  # apply text normalization only when requested
|
||||||
|
self._load_data(datapaths, global_rank, world_size, maxload)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
example = self.data[index]
|
||||||
|
question = example["question"]
|
||||||
|
if self.training:
|
||||||
|
gold = random.choice(example["positive_ctxs"])
|
||||||
|
|
||||||
|
n_hard_negatives, n_random_negatives = self.sample_n_hard_negatives(example)
|
||||||
|
negatives = []
|
||||||
|
if n_random_negatives > 0:
|
||||||
|
random_negatives = random.sample(example["negative_ctxs"], n_random_negatives)
|
||||||
|
negatives += random_negatives
|
||||||
|
if n_hard_negatives > 0:
|
||||||
|
hard_negatives = random.sample(
|
||||||
|
example["hard_negative_ctxs"][self.negative_hard_min_idx :], n_hard_negatives
|
||||||
|
)
|
||||||
|
negatives += hard_negatives
|
||||||
|
else:
|
||||||
|
gold = example["positive_ctxs"][0]
|
||||||
|
nidx = 0
|
||||||
|
if "negative_ctxs" in example:
|
||||||
|
negatives = [example["negative_ctxs"][nidx]]
|
||||||
|
else:
|
||||||
|
negatives = []
|
||||||
|
|
||||||
|
gold = gold["title"] + " " + gold["text"] if "title" in gold and len(gold["title"]) > 0 else gold["text"]
|
||||||
|
|
||||||
|
negatives = [
|
||||||
|
n["title"] + " " + n["text"] if ("title" in n and len(n["title"]) > 0) else n["text"] for n in negatives
|
||||||
|
]
|
||||||
|
|
||||||
|
example = {
|
||||||
|
"query": self.normalize_fn(question),
|
||||||
|
"gold": self.normalize_fn(gold),
|
||||||
|
"negatives": [self.normalize_fn(n) for n in negatives],
|
||||||
|
}
|
||||||
|
return example
|
||||||
|
|
||||||
|
def _load_data(self, datapaths, global_rank, world_size, maxload):
|
||||||
|
counter = 0
|
||||||
|
self.data = []
|
||||||
|
for path in datapaths:
|
||||||
|
path = str(path)
|
||||||
|
if path.endswith(".jsonl"):
|
||||||
|
file_data, counter = self._load_data_jsonl(path, global_rank, world_size, counter, maxload)
|
||||||
|
elif path.endswith(".json"):
|
||||||
|
file_data, counter = self._load_data_json(path, global_rank, world_size, counter, maxload)
|
||||||
|
self.data.extend(file_data)
|
||||||
|
if maxload is not None and maxload > 0 and counter >= maxload:
|
||||||
|
break
|
||||||
|
|
||||||
|
def _load_data_json(self, path, global_rank, world_size, counter, maxload=None):
|
||||||
|
examples = []
|
||||||
|
with open(path, "r") as fin:
|
||||||
|
data = json.load(fin)
|
||||||
|
for example in data:
|
||||||
|
counter += 1
|
||||||
|
if global_rank > -1 and not counter % world_size == global_rank:
|
||||||
|
continue
|
||||||
|
examples.append(example)
|
||||||
|
if maxload is not None and maxload > 0 and counter == maxload:
|
||||||
|
break
|
||||||
|
|
||||||
|
return examples, counter
|
||||||
|
|
||||||
|
def _load_data_jsonl(self, path, global_rank, world_size, counter, maxload=None):
|
||||||
|
examples = []
|
||||||
|
with open(path, "r") as fin:
|
||||||
|
for line in fin:
|
||||||
|
counter += 1
|
||||||
|
if global_rank > -1 and not counter % world_size == global_rank:
|
||||||
|
continue
|
||||||
|
example = json.loads(line)
|
||||||
|
examples.append(example)
|
||||||
|
if maxload is not None and maxload > 0 and counter == maxload:
|
||||||
|
break
|
||||||
|
|
||||||
|
return examples, counter
|
||||||
|
|
||||||
|
def sample_n_hard_negatives(self, ex):
|
||||||
|
|
||||||
|
if "hard_negative_ctxs" in ex:
|
||||||
|
n_hard_negatives = sum([random.random() < self.negative_hard_ratio for _ in range(self.negative_ctxs)])
|
||||||
|
n_hard_negatives = min(n_hard_negatives, len(ex["hard_negative_ctxs"][self.negative_hard_min_idx :]))
|
||||||
|
else:
|
||||||
|
n_hard_negatives = 0
|
||||||
|
n_random_negatives = self.negative_ctxs - n_hard_negatives
|
||||||
|
if "negative_ctxs" in ex:
|
||||||
|
n_random_negatives = min(n_random_negatives, len(ex["negative_ctxs"]))
|
||||||
|
else:
|
||||||
|
n_random_negatives = 0
|
||||||
|
return n_hard_negatives, n_random_negatives
|
||||||
|
|
||||||
|
|
||||||
|
class Collator(object):
|
||||||
|
def __init__(self, tokenizer, passage_maxlength=200):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.passage_maxlength = passage_maxlength
|
||||||
|
|
||||||
|
def __call__(self, batch):
|
||||||
|
queries = [ex["query"] for ex in batch]
|
||||||
|
golds = [ex["gold"] for ex in batch]
|
||||||
|
negs = [item for ex in batch for item in ex["negatives"]]
|
||||||
|
allpassages = golds + negs
|
||||||
|
|
||||||
|
qout = self.tokenizer.batch_encode_plus(
|
||||||
|
queries,
|
||||||
|
max_length=self.passage_maxlength,
|
||||||
|
truncation=True,
|
||||||
|
padding=True,
|
||||||
|
add_special_tokens=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
kout = self.tokenizer.batch_encode_plus(
|
||||||
|
allpassages,
|
||||||
|
max_length=self.passage_maxlength,
|
||||||
|
truncation=True,
|
||||||
|
padding=True,
|
||||||
|
add_special_tokens=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
q_tokens, q_mask = qout["input_ids"], qout["attention_mask"].bool()
|
||||||
|
k_tokens, k_mask = kout["input_ids"], kout["attention_mask"].bool()
|
||||||
|
|
||||||
|
g_tokens, g_mask = k_tokens[: len(golds)], k_mask[: len(golds)]
|
||||||
|
n_tokens, n_mask = k_tokens[len(golds) :], k_mask[len(golds) :]
|
||||||
|
|
||||||
|
batch = {
|
||||||
|
"q_tokens": q_tokens,
|
||||||
|
"q_mask": q_mask,
|
||||||
|
"k_tokens": k_tokens,
|
||||||
|
"k_mask": k_mask,
|
||||||
|
"g_tokens": g_tokens,
|
||||||
|
"g_mask": g_mask,
|
||||||
|
"n_tokens": n_tokens,
|
||||||
|
"n_mask": n_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
return batch
|
|
@ -0,0 +1,90 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
import transformers
|
||||||
|
import logging
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from src import contriever, dist_utils, utils
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class InBatch(nn.Module):
|
||||||
|
def __init__(self, opt, retriever=None, tokenizer=None):
|
||||||
|
super(InBatch, self).__init__()
|
||||||
|
|
||||||
|
self.opt = opt
|
||||||
|
self.norm_doc = opt.norm_doc
|
||||||
|
self.norm_query = opt.norm_query
|
||||||
|
self.label_smoothing = opt.label_smoothing
|
||||||
|
if retriever is None or tokenizer is None:
|
||||||
|
retriever, tokenizer = self._load_retriever(
|
||||||
|
opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init
|
||||||
|
)
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.encoder = retriever
|
||||||
|
|
||||||
|
def _load_retriever(self, model_id, pooling, random_init):
|
||||||
|
cfg = utils.load_hf(transformers.AutoConfig, model_id)
|
||||||
|
tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id)
|
||||||
|
|
||||||
|
if "xlm" in model_id:
|
||||||
|
model_class = contriever.XLMRetriever
|
||||||
|
else:
|
||||||
|
model_class = contriever.Contriever
|
||||||
|
|
||||||
|
if random_init:
|
||||||
|
retriever = model_class(cfg)
|
||||||
|
else:
|
||||||
|
retriever = utils.load_hf(model_class, model_id)
|
||||||
|
|
||||||
|
if "bert-" in model_id:
|
||||||
|
if tokenizer.bos_token_id is None:
|
||||||
|
tokenizer.bos_token = "[CLS]"
|
||||||
|
if tokenizer.eos_token_id is None:
|
||||||
|
tokenizer.eos_token = "[SEP]"
|
||||||
|
|
||||||
|
retriever.config.pooling = pooling
|
||||||
|
|
||||||
|
return retriever, tokenizer
|
||||||
|
|
||||||
|
def get_encoder(self):
|
||||||
|
return self.encoder
|
||||||
|
|
||||||
|
def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs):
|
||||||
|
|
||||||
|
bsz = len(q_tokens)
|
||||||
|
labels = torch.arange(0, bsz, dtype=torch.long, device=q_tokens.device)
|
||||||
|
|
||||||
|
qemb = self.encoder(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query)
|
||||||
|
kemb = self.encoder(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc)
|
||||||
|
|
||||||
|
gather_fn = dist_utils.gather
|
||||||
|
|
||||||
|
gather_kemb = gather_fn(kemb)
|
||||||
|
|
||||||
|
labels = labels + dist_utils.get_rank() * len(kemb)
|
||||||
|
|
||||||
|
scores = torch.einsum("id, jd->ij", qemb / self.opt.temperature, gather_kemb)
|
||||||
|
|
||||||
|
loss = torch.nn.functional.cross_entropy(scores, labels, label_smoothing=self.label_smoothing)
|
||||||
|
|
||||||
|
# log stats
|
||||||
|
if len(stats_prefix) > 0:
|
||||||
|
stats_prefix = stats_prefix + "/"
|
||||||
|
iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz)
|
||||||
|
|
||||||
|
predicted_idx = torch.argmax(scores, dim=-1)
|
||||||
|
accuracy = 100 * (predicted_idx == labels).float().mean()
|
||||||
|
stdq = torch.std(qemb, dim=0).mean().item()
|
||||||
|
stdk = torch.std(kemb, dim=0).mean().item()
|
||||||
|
iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz)
|
||||||
|
iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz)
|
||||||
|
iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz)
|
||||||
|
|
||||||
|
return loss, iter_stats
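A hedged smoke-test sketch (not part of the commit) of the in-batch contrastive loss above; it assumes this module and its src package are importable, that utils.load_hf can resolve bert-base-uncased, and a PyTorch version that supports label_smoothing in cross_entropy:
from types import SimpleNamespace

opt = SimpleNamespace(
    retriever_model_id="bert-base-uncased", pooling="average", random_init=True,
    norm_doc=False, norm_query=False, label_smoothing=0.0, temperature=0.05,
)
model = InBatch(opt)
batch = model.tokenizer(
    ["please put the yogurt on the coffee table", "turn off the tube light"],
    padding=True, truncation=True, return_tensors="pt",
)
# Reuse the same sentences as queries and keys: each query's positive key is itself.
loss, stats = model(
    q_tokens=batch["input_ids"], q_mask=batch["attention_mask"],
    k_tokens=batch["input_ids"], k_mask=batch["attention_mask"],
)
print(loss.item())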
|
|
@ -0,0 +1,73 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import pickle
from typing import List, Tuple

import faiss
import numpy as np
from tqdm import tqdm


class Indexer(object):

    def __init__(self, vector_sz, n_subquantizers=0, n_bits=8):
        if n_subquantizers > 0:
            self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
        else:
            self.index = faiss.IndexFlatIP(vector_sz)
        #self.index_id_to_db_id = np.empty((0), dtype=np.int64)
        self.index_id_to_db_id = []

    def index_data(self, ids, embeddings):
        self._update_id_mapping(ids)
        embeddings = embeddings.astype('float32')
        if not self.index.is_trained:
            self.index.train(embeddings)
        self.index.add(embeddings)

        print(f'Total data indexed {len(self.index_id_to_db_id)}')

    def search_knn(self, query_vectors: np.array, top_docs: int, index_batch_size: int = 2048) -> List[Tuple[List[object], List[float]]]:
        query_vectors = query_vectors.astype('float32')
        result = []
        nbatch = (len(query_vectors) - 1) // index_batch_size + 1
        for k in tqdm(range(nbatch)):
            start_idx = k * index_batch_size
            end_idx = min((k + 1) * index_batch_size, len(query_vectors))
            q = query_vectors[start_idx: end_idx]
            scores, indexes = self.index.search(q, top_docs)
            # convert to external ids
            db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes]
            result.extend([(db_ids[i], scores[i]) for i in range(len(db_ids))])
        return result

    def serialize(self, dir_path):
        index_file = os.path.join(dir_path, 'index.faiss')
        meta_file = os.path.join(dir_path, 'index_meta.faiss')
        print(f'Serializing index to {index_file}, meta data to {meta_file}')

        faiss.write_index(self.index, index_file)
        with open(meta_file, mode='wb') as f:
            pickle.dump(self.index_id_to_db_id, f)

    def deserialize_from(self, dir_path):
        index_file = os.path.join(dir_path, 'index.faiss')
        meta_file = os.path.join(dir_path, 'index_meta.faiss')
        print(f'Loading index from {index_file}, meta data from {meta_file}')

        self.index = faiss.read_index(index_file)
        print('Loaded index of type %s and size %d' % (type(self.index), self.index.ntotal))

        with open(meta_file, "rb") as reader:
            self.index_id_to_db_id = pickle.load(reader)
        assert len(
            self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'

    def _update_id_mapping(self, db_ids: List):
        #new_ids = np.array(db_ids, dtype=np.int64)
        #self.index_id_to_db_id = np.concatenate((self.index_id_to_db_id, new_ids), axis=0)
        self.index_id_to_db_id.extend(db_ids)
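A hedged usage sketch of the Indexer above; the ids, vectors and dimension are made up for illustration, and the import line only indicates where the class is assumed to live.

import numpy as np
# from src.index import Indexer  # assumed module path for the class defined above

ids = ["0", "1", "2"]
embeddings = np.random.rand(3, 768).astype("float32")  # 768 matches the --projection_size default

index = Indexer(vector_sz=768)        # flat inner-product index (n_subquantizers=0)
index.index_data(ids, embeddings)     # register external ids and add vectors
top = index.search_knn(embeddings[:1], top_docs=2)
print(top[0])                         # ([db_id, db_id], [score, score]) for the first query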
@ -0,0 +1,140 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import torch
import torch.nn as nn
import logging
import copy
import transformers

from src import contriever, dist_utils, utils

logger = logging.getLogger(__name__)


class MoCo(nn.Module):
    def __init__(self, opt):
        super(MoCo, self).__init__()

        self.queue_size = opt.queue_size
        self.momentum = opt.momentum
        self.temperature = opt.temperature
        self.label_smoothing = opt.label_smoothing
        self.norm_doc = opt.norm_doc
        self.norm_query = opt.norm_query
        self.moco_train_mode_encoder_k = opt.moco_train_mode_encoder_k  # apply the encoder on keys in train mode

        retriever, tokenizer = self._load_retriever(
            opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init
        )

        self.tokenizer = tokenizer
        self.encoder_q = retriever
        self.encoder_k = copy.deepcopy(retriever)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # create the queue
        self.register_buffer("queue", torch.randn(opt.projection_size, self.queue_size))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    def _load_retriever(self, model_id, pooling, random_init):
        cfg = utils.load_hf(transformers.AutoConfig, model_id)
        tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id)

        if "xlm" in model_id:
            model_class = contriever.XLMRetriever
        else:
            model_class = contriever.Contriever

        if random_init:
            retriever = model_class(cfg)
        else:
            retriever = utils.load_hf(model_class, model_id)

        if "bert-" in model_id:
            if tokenizer.bos_token_id is None:
                tokenizer.bos_token = "[CLS]"
            if tokenizer.eos_token_id is None:
                tokenizer.eos_token = "[SEP]"

        retriever.config.pooling = pooling

        return retriever, tokenizer

    def get_encoder(self, return_encoder_k=False):
        if return_encoder_k:
            return self.encoder_k
        else:
            return self.encoder_q

    def _momentum_update_key_encoder(self):
        """
        Update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.momentum + param_q.data * (1.0 - self.momentum)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        keys = dist_utils.gather_nograd(keys.contiguous())

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0, f"{batch_size}, {self.queue_size}"  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr : ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.queue_size  # move pointer

        self.queue_ptr[0] = ptr

    def _compute_logits(self, q, k):
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])

        logits = torch.cat([l_pos, l_neg], dim=1)
        return logits

    def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs):
        bsz = q_tokens.size(0)

        q = self.encoder_q(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            if not self.encoder_k.training and not self.moco_train_mode_encoder_k:
                self.encoder_k.eval()

            k = self.encoder_k(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc)

        logits = self._compute_logits(q, k) / self.temperature

        # labels: positive key indicators
        labels = torch.zeros(bsz, dtype=torch.long).cuda()

        loss = torch.nn.functional.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)

        self._dequeue_and_enqueue(k)

        # log stats
        if len(stats_prefix) > 0:
            stats_prefix = stats_prefix + "/"
        iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz)

        predicted_idx = torch.argmax(logits, dim=-1)
        accuracy = 100 * (predicted_idx == labels).float().mean()
        stdq = torch.std(q, dim=0).mean().item()
        stdk = torch.std(k, dim=0).mean().item()
        iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz)
        iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz)
        iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz)

        return loss, iter_stats
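For orientation, a stripped-down sketch of the two mechanisms MoCo relies on above, the momentum (EMA) update of the key encoder and the circular key queue, using plain tensors; the momentum value, dimensions and queue size are illustrative assumptions only.

import torch

momentum, dim, queue_size, bsz = 0.999, 8, 16, 4  # toy values

# momentum (EMA) update: key-encoder weights slowly track query-encoder weights
w_q, w_k = torch.randn(dim), torch.randn(dim)
w_k = w_k * momentum + w_q * (1.0 - momentum)

# circular queue of past keys used as negatives
queue = torch.randn(dim, queue_size)
queue_ptr = 0
new_keys = torch.randn(bsz, dim)
queue[:, queue_ptr:queue_ptr + bsz] = new_keys.T  # enqueue newest keys at the pointer
queue_ptr = (queue_ptr + bsz) % queue_size        # advance pointer, wrapping around
print(queue_ptr)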
@ -0,0 +1,162 @@
"""
adapted from chemdataextractor.text.normalize
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tools for normalizing text.
https://github.com/mcs07/ChemDataExtractor
:copyright: Copyright 2016 by Matt Swain.
:license: MIT

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
'Software'), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

#: Control characters.
CONTROLS = {
    '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u000e', '\u000f', '\u0011',
    '\u0012', '\u0013', '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b',
}
# There are further control characters, but they are instead replaced with a space by unicode normalization
# '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c', '\u001d', '\u001e', '\u001f'


#: Hyphen and dash characters.
HYPHENS = {
    '-',  # \u002d Hyphen-minus
    '‐',  # \u2010 Hyphen
    '‑',  # \u2011 Non-breaking hyphen
    '⁃',  # \u2043 Hyphen bullet
    '‒',  # \u2012 figure dash
    '–',  # \u2013 en dash
    '—',  # \u2014 em dash
    '―',  # \u2015 horizontal bar
}

#: Minus characters.
MINUSES = {
    '-',  # \u002d Hyphen-minus
    '−',  # \u2212 Minus
    '-',  # \uff0d Full-width Hyphen-minus
    '⁻',  # \u207b Superscript minus
}

#: Plus characters.
PLUSES = {
    '+',  # \u002b Plus
    '+',  # \uff0b Full-width Plus
    '⁺',  # \u207a Superscript plus
}

#: Slash characters.
SLASHES = {
    '/',  # \u002f Solidus
    '⁄',  # \u2044 Fraction slash
    '∕',  # \u2215 Division slash
}

#: Tilde characters.
TILDES = {
    '~',  # \u007e Tilde
    '˜',  # \u02dc Small tilde
    '⁓',  # \u2053 Swung dash
    '∼',  # \u223c Tilde operator #in mbert vocab
    '∽',  # \u223d Reversed tilde
    '∿',  # \u223f Sine wave
    '〜',  # \u301c Wave dash #in mbert vocab
    '~',  # \uff5e Full-width tilde #in mbert vocab
}

#: Apostrophe characters.
APOSTROPHES = {
    "'",  # \u0027
    '’',  # \u2019
    '՚',  # \u055a
    'Ꞌ',  # \ua78b
    'ꞌ',  # \ua78c
    ''',  # \uff07
}

#: Single quote characters.
SINGLE_QUOTES = {
    "'",  # \u0027
    '‘',  # \u2018
    '’',  # \u2019
    '‚',  # \u201a
    '‛',  # \u201b
}

#: Double quote characters.
DOUBLE_QUOTES = {
    '"',  # \u0022
    '“',  # \u201c
    '”',  # \u201d
    '„',  # \u201e
    '‟',  # \u201f
}

#: Accent characters.
ACCENTS = {
    '`',  # \u0060
    '´',  # \u00b4
}

#: Prime characters.
PRIMES = {
    '′',  # \u2032
    '″',  # \u2033
    '‴',  # \u2034
    '‵',  # \u2035
    '‶',  # \u2036
    '‷',  # \u2037
    '⁗',  # \u2057
}

#: Quote characters, including apostrophes, single quotes, double quotes, accents and primes.
QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES


def normalize(text):
    for control in CONTROLS:
        text = text.replace(control, '')
    text = text.replace('\u000b', ' ').replace('\u000c', ' ').replace(u'\u0085', ' ')

    for hyphen in HYPHENS | MINUSES:
        text = text.replace(hyphen, '-')
    text = text.replace('\u00ad', '')

    for double_quote in DOUBLE_QUOTES:
        text = text.replace(double_quote, '"')  # \u0022
    for single_quote in (SINGLE_QUOTES | APOSTROPHES | ACCENTS):
        text = text.replace(single_quote, "'")  # \u0027
    text = text.replace('′', "'")     # \u2032 prime
    text = text.replace('‵', "'")     # \u2035 reversed prime
    text = text.replace('″', "''")    # \u2033 double prime
    text = text.replace('‶', "''")    # \u2036 reversed double prime
    text = text.replace('‴', "'''")   # \u2034 triple prime
    text = text.replace('‷', "'''")   # \u2037 reversed triple prime
    text = text.replace('⁗', "''''")  # \u2057 quadruple prime

    text = text.replace('…', '...').replace(' . . . ', ' ... ')  # \u2026

    for slash in SLASHES:
        text = text.replace(slash, '/')

    #for tilde in TILDES:
    #    text = text.replace(tilde, '~')

    return text
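A quick, hedged example of what normalize() does to typographic punctuation; the input string is invented for illustration.

# Assumes normalize() from the module above.
s = "“Smart–looking” café … naïve"
print(normalize(s))  # -> '"Smart-looking" café ... naïve'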
@ -0,0 +1,132 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import argparse
import os


class Options:
    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialize()

    def initialize(self):
        # basic parameters
        self.parser.add_argument(
            "--output_dir", type=str, default="./checkpoint/my_experiments", help="models are saved here"
        )
        self.parser.add_argument(
            "--train_data",
            nargs="+",
            default=[],
            help="Data used for training, passed as a list of directories split into tensor files.",
        )
        self.parser.add_argument(
            "--eval_data",
            nargs="+",
            default=[],
            help="Data used for evaluation during finetuning, this option is not used during contrastive pre-training.",
        )
        self.parser.add_argument(
            "--eval_datasets", nargs="+", default=[], help="List of datasets used for evaluation, in BEIR format"
        )
        self.parser.add_argument(
            "--eval_datasets_dir", type=str, default="./", help="Directory where eval datasets are stored"
        )
        self.parser.add_argument("--model_path", type=str, default="none", help="path for retraining")
        self.parser.add_argument("--continue_training", action="store_true")
        self.parser.add_argument("--num_workers", type=int, default=5)

        self.parser.add_argument("--chunk_length", type=int, default=256)
        self.parser.add_argument("--loading_mode", type=str, default="split")
        self.parser.add_argument("--lower_case", action="store_true", help="perform evaluation after lowercasing")
        self.parser.add_argument(
            "--sampling_coefficient",
            type=float,
            default=0.0,
            help="coefficient used for sampling between different datasets during training, \
                by default sampling is uniform over datasets",
        )
        self.parser.add_argument("--augmentation", type=str, default="none")
        self.parser.add_argument("--prob_augmentation", type=float, default=0.0)

        self.parser.add_argument("--dropout", type=float, default=0.1)
        self.parser.add_argument("--rho", type=float, default=0.05)

        self.parser.add_argument("--contrastive_mode", type=str, default="moco")
        self.parser.add_argument("--queue_size", type=int, default=65536)
        self.parser.add_argument("--temperature", type=float, default=1.0)
        self.parser.add_argument("--momentum", type=float, default=0.999)
        self.parser.add_argument("--moco_train_mode_encoder_k", action="store_true")
        self.parser.add_argument("--eval_normalize_text", action="store_true")
        self.parser.add_argument("--norm_query", action="store_true")
        self.parser.add_argument("--norm_doc", action="store_true")
        self.parser.add_argument("--projection_size", type=int, default=768)

        self.parser.add_argument("--ratio_min", type=float, default=0.1)
        self.parser.add_argument("--ratio_max", type=float, default=0.5)
        self.parser.add_argument("--score_function", type=str, default="dot")
        self.parser.add_argument("--retriever_model_id", type=str, default="bert-base-uncased")
        self.parser.add_argument("--pooling", type=str, default="average")
        self.parser.add_argument("--random_init", action="store_true", help="init model with random weights")

        # dataset parameters
        self.parser.add_argument("--per_gpu_batch_size", default=64, type=int, help="Batch size per GPU for training.")
        self.parser.add_argument(
            "--per_gpu_eval_batch_size", default=256, type=int, help="Batch size per GPU for evaluation."
        )
        self.parser.add_argument("--total_steps", type=int, default=1000)
        self.parser.add_argument("--warmup_steps", type=int, default=-1)

        self.parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
        self.parser.add_argument("--main_port", type=int, default=10001, help="Master port (for multi-node SLURM jobs)")
        self.parser.add_argument("--seed", type=int, default=0, help="random seed for initialization")
        # training parameters
        self.parser.add_argument("--optim", type=str, default="adamw")
        self.parser.add_argument("--scheduler", type=str, default="linear")
        self.parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
        self.parser.add_argument(
            "--lr_min_ratio",
            type=float,
            default=0.0,
            help="minimum learning rate at the end of the optimization schedule as a ratio of the learning rate",
        )
        self.parser.add_argument("--weight_decay", type=float, default=0.01, help="weight decay")
        self.parser.add_argument("--beta1", type=float, default=0.9, help="beta1")
        self.parser.add_argument("--beta2", type=float, default=0.98, help="beta2")
        self.parser.add_argument("--eps", type=float, default=1e-6, help="eps")
        self.parser.add_argument(
            "--log_freq", type=int, default=100, help="log train stats every <log_freq> steps during training"
        )
        self.parser.add_argument(
            "--eval_freq", type=int, default=500, help="evaluate model every <eval_freq> steps during training"
        )
        self.parser.add_argument("--save_freq", type=int, default=50000)
        self.parser.add_argument("--maxload", type=int, default=None)
        self.parser.add_argument("--label_smoothing", type=float, default=0.0)

        # finetuning options
        self.parser.add_argument("--negative_ctxs", type=int, default=1)
        self.parser.add_argument("--negative_hard_min_idx", type=int, default=0)
        self.parser.add_argument("--negative_hard_ratio", type=float, default=0.0)

    def print_options(self, opt):
        message = ""
        for k, v in sorted(vars(opt).items()):
            comment = ""
            default = self.parser.get_default(k)
            if v != default:
                comment = f"\t[default: %s]" % str(default)
            message += f"{str(k):>40}: {str(v):<40}{comment}\n"
        print(message, flush=True)

        model_dir = os.path.join(opt.output_dir, "models")
        if not os.path.exists(model_dir):
            os.makedirs(os.path.join(opt.output_dir, "models"))
        file_name = os.path.join(opt.output_dir, "opt.txt")
        with open(file_name, "wt") as opt_file:
            opt_file.write(message)
            opt_file.write("\n")

    def parse(self):
        opt, _ = self.parser.parse_known_args()
        # opt = self.parser.parse_args()
        return opt
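A hedged sketch of how these options would typically be consumed by a training entry point; the module paths (src.options, src.moco) and the choice of MoCo are assumptions based on the classes defined above, not paths taken from the repository.

from src.options import Options   # assumed import path
from src.moco import MoCo          # assumed import path

options = Options()
opt = options.parse()              # parse known CLI args, ignore the rest
options.print_options(opt)         # also writes opt.txt under opt.output_dir
model = MoCo(opt)                  # or InBatch(opt) when --contrastive_mode is not "moco"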
@ -0,0 +1,114 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from logging import getLogger
import os
import sys
import torch
import socket
import signal
import subprocess


logger = getLogger()


def sig_handler(signum, frame):
    logger.warning("Signal handler called with signal " + str(signum))
    prod_id = int(os.environ['SLURM_PROCID'])
    logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id))
    if prod_id == 0:
        logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID'])
        os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID'])
    else:
        logger.warning("Not the main process, no need to requeue.")
    sys.exit(-1)


def term_handler(signum, frame):
    logger.warning("Signal handler called with signal " + str(signum))
    logger.warning("Bypassing SIGTERM.")


def init_signal_handler():
    """
    Handle signals sent by SLURM for time limit / pre-emption.
    """
    signal.signal(signal.SIGUSR1, sig_handler)
    signal.signal(signal.SIGTERM, term_handler)


def init_distributed_mode(params):
    """
    Handle single and multi-GPU / multi-node / SLURM jobs.
    Initialize the following variables:
        - local_rank
        - global_rank
        - world_size
    """
    is_slurm_job = 'SLURM_JOB_ID' in os.environ and not 'WORLD_SIZE' in os.environ
    has_local_rank = hasattr(params, 'local_rank')

    # SLURM job without torch.distributed.launch
    if is_slurm_job and has_local_rank:

        assert params.local_rank == -1  # on the cluster, this is handled by SLURM

        # local rank on the current node / global rank
        params.local_rank = int(os.environ['SLURM_LOCALID'])
        params.global_rank = int(os.environ['SLURM_PROCID'])
        params.world_size = int(os.environ['SLURM_NTASKS'])

        # define master address and master port
        hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']])
        params.main_addr = hostnames.split()[0].decode('utf-8')
        assert 10001 <= params.main_port <= 20000 or params.world_size == 1

        # set environment variables for 'env://'
        os.environ['MASTER_ADDR'] = params.main_addr
        os.environ['MASTER_PORT'] = str(params.main_port)
        os.environ['WORLD_SIZE'] = str(params.world_size)
        os.environ['RANK'] = str(params.global_rank)
        is_distributed = True

    # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch
    elif has_local_rank and params.local_rank != -1:

        assert params.main_port == -1

        # read environment variables
        params.global_rank = int(os.environ['RANK'])
        params.world_size = int(os.environ['WORLD_SIZE'])

        is_distributed = True

    # local job (single GPU)
    else:
        params.local_rank = 0
        params.global_rank = 0
        params.world_size = 1
        is_distributed = False

    # set GPU device
    torch.cuda.set_device(params.local_rank)

    # initialize multi-GPU
    if is_distributed:

        # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization
        # 'env://' will read these environment variables:
        # MASTER_PORT - required; has to be a free port on machine with rank 0
        # MASTER_ADDR - required (except for rank 0); address of rank 0 node
        # WORLD_SIZE - required; can be set either here, or in a call to init function
        # RANK - required; can be set either here, or in a call to init function

        #print("Initializing PyTorch distributed ...")
        torch.distributed.init_process_group(
            init_method='env://',
            backend='nccl',
            #world_size=params.world_size,
            #rank=params.global_rank,
        )
@ -0,0 +1,213 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import os
import sys
import math
import logging
import torch
import errno
from typing import Union, Tuple, Dict
from collections import defaultdict

from robowaiter.algos.retrieval.retrieval_lm.src import dist_utils

Number = Union[float, int]

logger = logging.getLogger(__name__)


def init_logger(args, stdout_only=False):
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
    stdout_handler = logging.StreamHandler(sys.stdout)
    handlers = [stdout_handler]
    if not stdout_only:
        file_handler = logging.FileHandler(filename=os.path.join(args.output_dir, "run.log"))
        handlers.append(file_handler)
    logging.basicConfig(
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if dist_utils.is_main() else logging.WARN,
        format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s",
        handlers=handlers,
    )
    return logger


def symlink_force(target, link_name):
    try:
        os.symlink(target, link_name)
    except OSError as e:
        if e.errno == errno.EEXIST:
            os.remove(link_name)
            os.symlink(target, link_name)
        else:
            raise e


def save(model, optimizer, scheduler, step, opt, dir_path, name):
    model_to_save = model.module if hasattr(model, "module") else model
    path = os.path.join(dir_path, "checkpoint")
    epoch_path = os.path.join(path, name)  # "step-%s" % step)
    os.makedirs(epoch_path, exist_ok=True)
    cp = os.path.join(path, "latest")
    fp = os.path.join(epoch_path, "checkpoint.pth")
    checkpoint = {
        "step": step,
        "model": model_to_save.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "opt": opt,
    }
    torch.save(checkpoint, fp)
    symlink_force(epoch_path, cp)
    if not name == "lastlog":
        logger.info(f"Saving model to {epoch_path}")


def load(model_class, dir_path, opt, reset_params=False):
    epoch_path = os.path.realpath(dir_path)
    checkpoint_path = os.path.join(epoch_path, "checkpoint.pth")
    logger.info(f"loading checkpoint {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    opt_checkpoint = checkpoint["opt"]
    state_dict = checkpoint["model"]

    model = model_class(opt_checkpoint)
    model.load_state_dict(state_dict, strict=True)
    model = model.cuda()
    step = checkpoint["step"]
    if not reset_params:
        optimizer, scheduler = set_optim(opt_checkpoint, model)
        scheduler.load_state_dict(checkpoint["scheduler"])
        optimizer.load_state_dict(checkpoint["optimizer"])
    else:
        optimizer, scheduler = set_optim(opt, model)

    return model, optimizer, scheduler, opt_checkpoint, step


############ OPTIM


class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR):
    def __init__(self, optimizer, warmup, total, ratio, last_epoch=-1):
        self.warmup = warmup
        self.total = total
        self.ratio = ratio
        super(WarmupLinearScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup:
            return (1 - self.ratio) * step / float(max(1, self.warmup))

        return max(
            0.0,
            1.0 + (self.ratio - 1) * (step - self.warmup) / float(max(1.0, self.total - self.warmup)),
        )


class CosineScheduler(torch.optim.lr_scheduler.LambdaLR):
    def __init__(self, optimizer, warmup, total, ratio=0.1, last_epoch=-1):
        self.warmup = warmup
        self.total = total
        self.ratio = ratio
        super(CosineScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup:
            return float(step) / self.warmup
        s = float(step - self.warmup) / (self.total - self.warmup)
        return self.ratio + (1.0 - self.ratio) * math.cos(0.5 * math.pi * s)


def set_optim(opt, model):
    if opt.optim == "adamw":
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), eps=opt.eps, weight_decay=opt.weight_decay
        )
    else:
        raise NotImplementedError("optimizer class not implemented")

    scheduler_args = {
        "warmup": opt.warmup_steps,
        "total": opt.total_steps,
        "ratio": opt.lr_min_ratio,
    }
    if opt.scheduler == "linear":
        scheduler_class = WarmupLinearScheduler
    elif opt.scheduler == "cosine":
        scheduler_class = CosineScheduler
    else:
        raise ValueError
    scheduler = scheduler_class(optimizer, **scheduler_args)
    return optimizer, scheduler


def get_parameters(net, verbose=False):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    message = "[Network] Total number of parameters : %.6f M" % (num_params / 1e6)
    return message


class WeightedAvgStats:
    """provides an average over a bunch of stats"""

    def __init__(self):
        self.raw_stats: Dict[str, float] = defaultdict(float)
        self.total_weights: Dict[str, float] = defaultdict(float)

    def update(self, vals: Dict[str, Tuple[Number, Number]]) -> None:
        for key, (value, weight) in vals.items():
            self.raw_stats[key] += value * weight
            self.total_weights[key] += weight

    @property
    def stats(self) -> Dict[str, float]:
        return {x: self.raw_stats[x] / self.total_weights[x] for x in self.raw_stats.keys()}

    @property
    def tuple_stats(self) -> Dict[str, Tuple[float, float]]:
        return {x: (self.raw_stats[x] / self.total_weights[x], self.total_weights[x]) for x in self.raw_stats.keys()}

    def reset(self) -> None:
        self.raw_stats = defaultdict(float)
        self.total_weights = defaultdict(float)

    @property
    def average_stats(self) -> Dict[str, float]:
        keys = sorted(self.raw_stats.keys())
        if torch.distributed.is_initialized():
            torch.distributed.broadcast_object_list(keys, src=0)
        global_dict = {}
        for k in keys:
            if not k in self.total_weights:
                v = 0.0
            else:
                v = self.raw_stats[k] / self.total_weights[k]
            v, _ = dist_utils.weighted_average(v, self.total_weights[k])
            global_dict[k] = v
        return global_dict


def load_hf(object_class, model_name):
    try:
        obj = object_class.from_pretrained(model_name, local_files_only=True)
    except Exception:
        obj = object_class.from_pretrained(model_name, local_files_only=False)
    return obj


def init_tb_logger(output_dir):
    try:
        from torch.utils import tensorboard

        if dist_utils.is_main():
            tb_logger = tensorboard.SummaryWriter(output_dir)
        else:
            tb_logger = None
    except Exception:
        logger.warning("Tensorboard is not available.")
        tb_logger = None

    return tb_logger
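To illustrate the warmup-then-linear-decay behaviour of WarmupLinearScheduler above, here is a hedged, self-contained sketch with a throwaway parameter; the step counts and learning rate are illustrative only.

import torch

# Assumes WarmupLinearScheduler from the utils module above is in scope.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = WarmupLinearScheduler(optimizer, warmup=10, total=100, ratio=0.0)

for step in range(100):
    optimizer.step()
    scheduler.step()
    if step in (0, 9, 50, 99):
        print(step, scheduler.get_last_lr())  # lr ramps up over 10 steps, then decays linearly to 0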
@ -0,0 +1,23 @@
{
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 1e5,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
@ -0,0 +1,194 @@
import jsonlines
import json
import copy
import re

PROMPT_DICT = {
    "prompt_input": (
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    ),
    "prompt_no_input": (
        "### Instruction:\n{instruction}\n\n### Response:\n"
    ),
}

TASK_INST = {"wow": "Given a chat history separated by new lines, generates an informative, knowledgeable and engaging response. ",
             "fever": "Is the following statement correct or not? Say true if it's correct; otherwise say false.",
             "eli5": "Provide a paragraph-length response using simple words to answer the following question.",
             "obqa": "Given four answer candidates, A, B, C and D, choose the best answer choice.",
             "arc_easy": "Given four answer candidates, A, B, C and D, choose the best answer choice.",
             "arc_c": "Given four answer candidates, A, B, C and D, choose the best answer choice.",
             "trex": "Given the input format 'Subject Entity [SEP] Relationship Type,' predict the target entity.",
             "asqa": "Answer the following question. The question may be ambiguous and have multiple correct answers, and in that case, you have to provide a long-form answer including all correct answers."}

rel_tokens_names = ["[Irrelevant]", "[Relevant]"]
retrieval_tokens_names = ["[No Retrieval]",
                          "[Retrieval]", "[Continue to Use Evidence]"]
utility_tokens_names = ["[Utility:1]", "[Utility:2]",
                        "[Utility:3]", "[Utility:4]", "[Utility:5]"]
ground_tokens_names = ["[Fully supported]",
                       "[Partially supported]", "[No support / Contradictory]"]
other_special_tokens = ["<s>", "</s>", "[PAD]",
                        "<unk>", "<paragraph>", "</paragraph>"]
control_tokens = ["[Fully supported]", "[Partially supported]", "[No support / Contradictory]", "[No Retrieval]", "[Retrieval]",
                  "[Irrelevant]", "[Relevant]", "<paragraph>", "</paragraph>", "[Utility:1]", "[Utility:2]", "[Utility:3]", "[Utility:4]", "[Utility:5]"]


def load_special_tokens(tokenizer, use_grounding=False, use_utility=False):
    ret_tokens = {token: tokenizer.convert_tokens_to_ids(
        token) for token in retrieval_tokens_names}
    rel_tokens = {}
    for token in ["[Irrelevant]", "[Relevant]"]:
        rel_tokens[token] = tokenizer.convert_tokens_to_ids(token)

    grd_tokens = None
    if use_grounding is True:
        grd_tokens = {}
        for token in ground_tokens_names:
            grd_tokens[token] = tokenizer.convert_tokens_to_ids(token)

    ut_tokens = None
    if use_utility is True:
        ut_tokens = {}
        for token in utility_tokens_names:
            ut_tokens[token] = tokenizer.convert_tokens_to_ids(token)

    return ret_tokens, rel_tokens, grd_tokens, ut_tokens


def fix_spacing(input_text):
    # Add a space after periods that lack whitespace
    output_text = re.sub(r'(?<=\w)([.!?])(?=\w)', r'\1 ', input_text)
    return output_text


def postprocess(pred):
    special_tokens = ["[Fully supported]", "[Partially supported]", "[No support / Contradictory]", "[No Retrieval]", "[Retrieval]",
                      "[Irrelevant]", "[Relevant]", "<paragraph>", "</paragraph>", "[Utility:1]", "[Utility:2]", "[Utility:3]", "[Utility:4]", "[Utility:5]"]
    for item in special_tokens:
        pred = pred.replace(item, "")
    pred = pred.replace("</s>", "")

    if len(pred) == 0:
        return ""
    if pred[0] == " ":
        pred = pred[1:]
    return pred


def load_jsonlines(file):
    with jsonlines.open(file, 'r') as jsonl_f:
        lst = [obj for obj in jsonl_f]
    return lst


def load_file(input_fp):
    if input_fp.endswith(".json"):
        input_data = json.load(open(input_fp))
    else:
        input_data = load_jsonlines(input_fp)
    return input_data


def save_file_jsonl(data, fp):
    with jsonlines.open(fp, mode='w') as writer:
        writer.write_all(data)


def preprocess_input(input_data, task):
    if task == "factscore":
        for item in input_data:
            item["instruction"] = item["input"]
            item["output"] = [item["output"]
                              ] if "output" in item else [item["topic"]]
        return input_data

    elif task == "qa":
        for item in input_data:
            if "instruction" not in item:
                item["instruction"] = item["question"]
            if "answers" not in item and "output" in item:
                item["answers"] = "output"
        return input_data

    elif task in ["asqa", "eli5"]:
        processed_input_data = []
        for instance_idx, item in enumerate(input_data["data"]):
            prompt = item["question"]
            instructions = TASK_INST[task]
            prompt = instructions + "## Input:\n\n" + prompt
            entry = copy.deepcopy(item)
            entry["instruction"] = prompt
            processed_input_data.append(entry)
        return processed_input_data


def postprocess_output(input_instance, prediction, task, intermediate_results=None):
    if task == "factscore":
        return {"input": input_instance["input"], "output": prediction, "topic": input_instance["topic"], "cat": input_instance["cat"]}

    elif task == "qa":
        input_instance["pred"] = prediction
        return input_instance

    elif task in ["asqa", "eli5"]:
        # ALCE datasets require additional postprocessing to compute citation accuracy.
        final_output = ""
        docs = []
        if "splitted_sentences" not in intermediate_results:
            input_instance["output"] = postprocess(prediction)

        else:
            for idx, (sent, doc) in enumerate(zip(intermediate_results["splitted_sentences"][0], intermediate_results["ctxs"][0])):
                if len(sent) == 0:
                    continue
                postprocessed_result = postprocess(sent)
                final_output += postprocessed_result[:-
                                                     1] + " [{}]".format(idx) + ". "
                docs.append(doc)
            if final_output[-1] == " ":
                final_output = final_output[:-1]
            input_instance["output"] = final_output
            input_instance["docs"] = docs
        return input_instance


def process_arc_instruction(item, instruction):
    choices = item["choices"]
    answer_labels = {}
    for i in range(len(choices["label"])):
        answer_key = choices["label"][i]
        text = choices["text"][i]
        if answer_key == "1":
            answer_labels["A"] = text
        if answer_key == "2":
            answer_labels["B"] = text
        if answer_key == "3":
            answer_labels["C"] = text
        if answer_key == "4":
            answer_labels["D"] = text
        if answer_key in ["A", "B", "C", "D"]:
            answer_labels[answer_key] = text

    if "D" not in answer_labels:
        answer_labels["D"] = ""
    choices = "\nA: {0}\nB: {1}\nC: {2}\nD: {3}".format(answer_labels["A"], answer_labels["B"], answer_labels["C"], answer_labels["D"])
    if "E" in answer_labels:
        choices += "\nE: {}".format(answer_labels["E"])
    processed_instruction = instruction + "\n\n### Input:\n" + item["instruction"] + choices
    return processed_instruction


def postprocess_answers_closed(output, task, choices=None):
    final_output = None
    if choices is not None:
        for c in choices.split(" "):
            if c in output:
                final_output = c
    if task == "fever" and output in ["REFUTES", "SUPPORTS"]:
        final_output = "true" if output == "SUPPORTS" else "REFUTES"
    if task == "fever" and output.lower() in ["true", "false"]:
        final_output = output.lower()
    if final_output is None:
        return output
    else:
        return final_output
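A brief, hedged example of how PROMPT_DICT and postprocess() above fit together when formatting a request and cleaning a raw model reply; the instruction, input, and prediction strings are invented for illustration.

# Assumes PROMPT_DICT and postprocess from the module above.
prompt = PROMPT_DICT["prompt_input"].format(
    instruction="Given four answer candidates, A, B, C and D, choose the best answer choice.",
    input="Which drink is caffeine-free? A: cola B: coffee C: herbal tea D: black tea",
)
raw_prediction = "[Relevant]C: herbal tea</s>"
print(postprocess(raw_prediction))  # -> "C: herbal tea"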
Binary file not shown.
@ -0,0 +1,10 @@
python passage_retrieval3.py \
    --model_name_or_path ../model/contriever-msmarco \
    --passages train_robot.jsonl \
    --passages_embeddings "robot_embeddings/*" \
    --data test_robot.jsonl \
    --output_dir robot_result \
    --n_docs 2


#python passage_retrieval3.py --model_name_or_path contriever-msmarco --passages train_robot.jsonl --passages_embeddings "robot_embeddings/*" --data test_robot.jsonl --output_dir robot_result --n_docs 2
@ -195,6 +195,7 @@ get_object_info
我带着孩子呢,想要宽敞亮堂的地方。
好的,我明白了,那么我们推荐您到大厅的桌子,那里的空间比较宽敞,环境也比较明亮,适合带着孩子一起用餐。

冰红茶
好的
create_sub_task
@ -0,0 +1,312 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
import glob
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from robowaiter.algos.retrieval.retrieval_lm.src.slurm import init_distributed_mode
|
||||||
|
from robowaiter.algos.retrieval.retrieval_lm.src.normalize_text import normalize
|
||||||
|
from robowaiter.algos.retrieval.retrieval_lm.src.contriever import load_retriever
|
||||||
|
from robowaiter.algos.retrieval.retrieval_lm.src.index import Indexer
|
||||||
|
from robowaiter.algos.retrieval.retrieval_lm.src.data import load_passages
|
||||||
|
|
||||||
|
from robowaiter.algos.retrieval.retrieval_lm.src.evaluation import calculate_matches
|
||||||
|
import warnings
|
||||||
|
from robowaiter.utils.basic import get_root_path
|
||||||
|
root_path = get_root_path()
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
def embed_queries(args, queries, model, tokenizer):
|
||||||
|
model.eval()
|
||||||
|
embeddings, batch_question = [], []
|
||||||
|
with torch.no_grad():
|
||||||
|
|
||||||
|
for k, q in enumerate(queries):
|
||||||
|
if args.lowercase:
|
||||||
|
q = q.lower()
|
||||||
|
if args.normalize_text:
|
||||||
|
q = normalize(q)
|
||||||
|
batch_question.append(q)
|
||||||
|
|
||||||
|
if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1:
|
||||||
|
|
||||||
|
encoded_batch = tokenizer.batch_encode_plus(
|
||||||
|
batch_question,
|
||||||
|
return_tensors="pt",
|
||||||
|
max_length=args.question_maxlength,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
)
|
||||||
|
encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
|
||||||
|
output = model(**encoded_batch)
|
||||||
|
embeddings.append(output.cpu())
|
||||||
|
|
||||||
|
batch_question = []
|
||||||
|
|
||||||
|
embeddings = torch.cat(embeddings, dim=0)
|
||||||
|
#print(f"Questions embeddings shape: {embeddings.size()}")
|
||||||
|
|
||||||
|
return embeddings.numpy()
|
||||||
|
|
||||||
|
|
||||||
|
def index_encoded_data(index, embedding_files, indexing_batch_size):
|
||||||
|
allids = []
|
||||||
|
allembeddings = np.array([])
|
||||||
|
for i, file_path in enumerate(embedding_files):
|
||||||
|
#print(f"Loading file {file_path}")
|
||||||
|
with open(file_path, "rb") as fin:
|
||||||
|
ids, embeddings = pickle.load(fin)
|
||||||
|
|
||||||
|
allembeddings = np.vstack((allembeddings, embeddings)) if allembeddings.size else embeddings
|
||||||
|
allids.extend(ids)
|
||||||
|
while allembeddings.shape[0] > indexing_batch_size:
|
||||||
|
allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
|
||||||
|
|
||||||
|
while allembeddings.shape[0] > 0:
|
||||||
|
allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
|
||||||
|
|
||||||
|
#print("Data indexing completed.")
|
||||||
|
|
||||||
|
|
||||||
|
def add_embeddings(index, embeddings, ids, indexing_batch_size):
|
||||||
|
end_idx = min(indexing_batch_size, embeddings.shape[0])
|
||||||
|
ids_toadd = ids[:end_idx]
|
||||||
|
embeddings_toadd = embeddings[:end_idx]
|
||||||
|
ids = ids[end_idx:]
|
||||||
|
embeddings = embeddings[end_idx:]
|
||||||
|
index.index_data(ids_toadd, embeddings_toadd)
|
||||||
|
return embeddings, ids
|
||||||
|
|
||||||
|
|
||||||
|
def validate(data, workers_num):
|
||||||
|
match_stats = calculate_matches(data, workers_num)
|
||||||
|
top_k_hits = match_stats.top_k_hits
|
||||||
|
|
||||||
|
# print("Validation results: top k documents hits %s", top_k_hits)
|
||||||
|
top_k_hits = [v / len(data) for v in top_k_hits]
|
||||||
|
message = ""
|
||||||
|
for k in [5, 10, 20, 100]:
|
||||||
|
if k <= len(top_k_hits):
|
||||||
|
message += f"R@{k}: {top_k_hits[k-1]} "
|
||||||
|
#print(message)
|
||||||
|
return match_stats.questions_doc_hits
|
||||||
|
|
||||||
|
|
||||||
|
def add_passages(data, passages, top_passages_and_scores):
|
||||||
|
# add passages to original data
|
||||||
|
merged_data = []
|
||||||
|
assert len(data) == len(top_passages_and_scores)
|
||||||
|
for i, d in enumerate(data):
|
||||||
|
results_and_scores = top_passages_and_scores[i]
|
||||||
|
#print(passages[2393])
|
||||||
|
docs = [passages[int(doc_id)] for doc_id in results_and_scores[0]]
|
||||||
|
scores = [str(score) for score in results_and_scores[1]]
|
||||||
|
ctxs_num = len(docs)
|
||||||
|
d["ctxs"] = [
|
||||||
|
{
|
||||||
|
"id": results_and_scores[0][c],
|
||||||
|
"title": docs[c]["title"],
|
||||||
|
"text": docs[c]["text"],
|
||||||
|
"score": scores[c],
|
||||||
|
}
|
||||||
|
for c in range(ctxs_num)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def add_hasanswer(data, hasanswer):
|
||||||
|
# add hasanswer to data
|
||||||
|
for i, ex in enumerate(data):
|
||||||
|
for k, d in enumerate(ex["ctxs"]):
|
||||||
|
d["hasanswer"] = hasanswer[i][k]
|
||||||
|
|
||||||
|
|
||||||
|
# def load_data(data_path):
|
||||||
|
# if data_path.endswith(".json"):
|
||||||
|
# with open(data_path, "r",encoding='utf-8') as fin:
|
||||||
|
# data = json.load(fin)
|
||||||
|
# elif data_path.endswith(".jsonl"):
|
||||||
|
# data = []
|
||||||
|
# with open(data_path, "r",encoding='utf-8') as fin:
|
||||||
|
# for k, example in enumerate(fin):
|
||||||
|
# example = json.loads(example)
|
||||||
|
# data.append(example)
|
||||||
|
# print("data:",data)
|
||||||
|
# return data
|
||||||
|
def load_data(data_path):
|
||||||
|
if data_path.endswith(".json"):
|
||||||
|
with open(data_path, "r",encoding='utf-8') as fin:
|
||||||
|
data = json.load(fin)
|
||||||
|
elif data_path.endswith(".jsonl"):
|
||||||
|
data = []
|
||||||
|
with open(data_path, "r",encoding='utf-8') as fin:
|
||||||
|
for k, example in enumerate(fin):
|
||||||
|
example = json.loads(example)
|
||||||
|
#print("example:",example)
|
||||||
|
data.append(example)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def test(args):#path为query
|
||||||
|
# args = {"model_name_or_path":"contriever-msmarco","passages":"train_robot.jsonl"\
|
||||||
|
# passages_embeddings = "robot_embeddings/*"
|
||||||
|
# data = "test_robot.jsonl"
|
||||||
|
# output_dir = "robot_result"
|
||||||
|
# n_docs = 1
|
||||||
|
|
||||||
|
#print(f"Loading model from: {args.model_name_or_path}")
|
||||||
|
model, tokenizer, _ = load_retriever(args.model_name_or_path)
|
||||||
|
model.eval()
|
||||||
|
model = model.cuda()
|
||||||
|
if not args.no_fp16:
|
||||||
|
model = model.half()
|
||||||
|
|
||||||
|
index = Indexer(args.projection_size, args.n_subquantizers, args.n_bits)
|
||||||
|
|
||||||
|
# index all passages
|
||||||
|
input_paths = glob.glob(args.passages_embeddings)
|
||||||
|
input_paths = sorted(input_paths)
|
||||||
|
embeddings_dir = os.path.dirname(input_paths[0])
|
||||||
|
index_path = os.path.join(embeddings_dir, "index.faiss")
|
||||||
|
if args.save_or_load_index and os.path.exists(index_path):
|
||||||
|
index.deserialize_from(embeddings_dir)
|
||||||
|
else:
|
||||||
|
#print(f"Indexing passages from files {input_paths}")
|
||||||
|
start_time_indexing = time.time()
|
||||||
|
index_encoded_data(index, input_paths, args.indexing_batch_size)
|
||||||
|
#print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.")
|
||||||
|
if args.save_or_load_index:
|
||||||
|
index.serialize(embeddings_dir)
|
||||||
|
|
||||||
|
# load passages
|
||||||
|
passages = load_passages(args.passages)
|
||||||
|
passage_id_map = {x["id"]: x for x in passages}
|
||||||
|
|
||||||
|
data_paths = glob.glob(args.data)
|
||||||
|
alldata = []
|
||||||
|
for path in data_paths:
|
||||||
|
data = load_data(path)
|
||||||
|
#print("data:",data)
|
||||||
|
output_path = os.path.join(args.output_dir, os.path.basename(path))
|
||||||
|
|
||||||
|
queries = [ex["question"] for ex in data]
|
||||||
|
questions_embedding = embed_queries(args, queries, model, tokenizer)
|
||||||
|
|
||||||
|
# get top k results
|
||||||
|
start_time_retrieval = time.time()
|
||||||
|
top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs)
|
||||||
|
#print(f"Search time: {time.time()-start_time_retrieval:.1f} s.")
|
||||||
|
|
||||||
|
add_passages(data, passage_id_map, top_ids_and_scores)
|
||||||
|
#hasanswer = validate(data, args.validation_workers)
|
||||||
|
#add_hasanswer(data, hasanswer)
|
||||||
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||||
|
ret_list = []
|
||||||
|
with open(output_path, "w",encoding='utf-8') as fout:
|
||||||
|
for ex in data:
|
||||||
|
json.dump(ex, fout, ensure_ascii=False)
|
||||||
|
ret_list.append(ex)
|
||||||
|
fout.write("\n")
|
||||||
|
return ret_list
|
||||||
|
#print(f"Saved results to {output_path}")
|
||||||
|
|
||||||
|
#将query写到test_robot.jsonl
|
||||||
|
def get_json(query):
|
||||||
|
dic = {"id": 1, "question": query}
|
||||||
|
with open('test_robot.jsonl', "w", encoding='utf-8') as fout:
|
||||||
|
json.dump(dic, fout, ensure_ascii=False)
|
||||||
|
|
||||||
|
def get_answer():
|
||||||
|
with open('robot_result\\test_robot.jsonl', "w", encoding='utf-8') as fin:
|
||||||
|
for k, example in enumerate(fin):
|
||||||
|
example = json.loads(example)
|
||||||
|
answer = example["ctxs"][0]["text"]
|
||||||
|
score = example["ctxs"][0]["score"]
|
||||||
|
return score, answer
|
||||||
|
def retri(query):
    get_json(query)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data",
        # required=True,
        type=str,
        default='test_robot.jsonl',
        help=".json file containing question and answers, similar format to reader data",
    )
    # parser.add_argument("--passages", type=str, default='C:/Users/huangyu/Desktop/RoboWaiter-main/RoboWaiter-main/train_robot.jsonl', help="Path to passages (.tsv file)")
    # parser.add_argument("--passages_embeddings", type=str, default='C:/Users/huangyu/Desktop/RoboWaiter-main/RoboWaiter-main/robot_embeddings/*', help="Glob path to encoded passages")
    parser.add_argument("--passages", type=str, default=f'{root_path}/robowaiter/llm_client/train_robot.jsonl', help="Path to passages (.tsv file)")
    parser.add_argument("--passages_embeddings", type=str, default=f'{root_path}/robowaiter/algos/retrieval/robot_embeddings/*', help="Glob path to encoded passages")

    parser.add_argument(
        "--output_dir", type=str, default='robot_result', help="Results are written to output_dir with data suffix"
    )
    parser.add_argument("--n_docs", type=int, default=5, help="Number of documents to retrieve per question")  # tune this to control how many retrieved results are returned
    parser.add_argument(
        "--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
    )
    parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding")
    parser.add_argument(
        "--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists"
    )
    # parser.add_argument(
    #     "--model_name_or_path", type=str, default='C:\\Users\\huangyu\\Desktop\\RoboWaiter-main\\RoboWaiter-main\\contriever-msmarco', help="path to directory containing model weights and config file"
    # )
    parser.add_argument(
        "--model_name_or_path", type=str, default=f'{root_path}/robowaiter/algos/retrieval/contriever-msmarco', help="path to directory containing model weights and config file"
    )
    parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
    parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
    parser.add_argument(
        "--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed"
    )
    parser.add_argument("--projection_size", type=int, default=768)
    parser.add_argument(
        "--n_subquantizers",
        type=int,
        default=0,
        help="Number of subquantizers used for vector quantization, if 0 flat index is used",
    )
    parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer")
    parser.add_argument("--lang", nargs="+")
    parser.add_argument("--dataset", type=str, default="none")
    parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
    parser.add_argument("--normalize_text", action="store_true", help="normalize text")

    args = parser.parse_args()
    init_distributed_mode(args)
    # print(args)
    ret = test(args)
    # print(ret)

    return ret[0]

    # example = ret[0]
    # answer = example["ctxs"][0]["text"]
    # score = example["ctxs"][0]["score"]
    # return score, answer

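# A minimal sketch (an assumption, not the project's actual index code) of what
# --projection_size / --n_subquantizers / --n_bits correspond to in FAISS:
# n_subquantizers == 0 keeps an exact flat inner-product index, otherwise a
# product-quantized (compressed, approximate) index of the same dimensionality.
def _example_build_faiss_index(projection_size=768, n_subquantizers=0, n_bits=8):
    import faiss
    if n_subquantizers == 0:
        return faiss.IndexFlatIP(projection_size)
    return faiss.IndexPQ(projection_size, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
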
if __name__ == "__main__":
    # query = "请你拿一下软饮料到第三张桌子位置。"
    # score, answer = retri(query)
    # print(score, answer)

    query = "你能把空调打开一下吗?"
    all_ret = retri(query)
    for i, example in enumerate(all_ret["ctxs"]):
        answer = example["text"]
        score = example["score"]
        id = example["id"]
        print(i, answer, score, " id=", id)

@ -0,0 +1 @@
{"id": 1, "question": "你能把空调打开一下吗?", "ctxs": [{"id": "505", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8567487"}, {"id": "313", "title": "你能把空调打开一下吗?", "text": "Is(AC,1)", "score": "1.8567487"}, {"id": "312", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8567487"}, {"id": "120", "title": "你能把空调打开一下吗?", "text": "Is(AC,1)", "score": "1.8567487"}, {"id": "119", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8567487"}]}
@ -0,0 +1,43 @@

import requests
import urllib3

########################################
# This file implements simple communication with the large language model
########################################

# Ignore HTTPS security warnings for the self-signed endpoint
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


def single_round(question, prefix=""):
    url = "https://45.125.46.134:25344/v1/chat/completions"
    headers = {"Content-Type": "application/json"}
    data = {
        "model": "RoboWaiter",
        "messages": [
            {
                "role": "system",
                "content": "你是一个机器人服务员:RoboWaiter. 你的职责是为顾客提供对话及具身服务。"
            },
            {
                "role": "user",
                "content": prefix + question
            }
        ]
    }

    response = requests.post(url, headers=headers, json=data, verify=False)

    if response.status_code == 200:
        result = response.json()
        return result['choices'][0]['message']['content'].strip()
    else:
        # return a readable error string rather than a (str, int) tuple
        return f"大模型请求失败:{response.status_code}"


if __name__ == '__main__':
    question = '''
    给我一杯拿铁
    '''

    print(single_round(question))

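# A hedged sketch (not part of the original file): one plausible way to feed a
# retrieved passage into single_round() via its prefix argument. retri_fn is assumed
# to be the retri() function from the retrieval script earlier in this commit, and
# the prefix wording is an illustrative assumption, not the project's actual prompt.
def _example_rag_round(question, retri_fn):
    top = retri_fn(question)["ctxs"][0]           # highest-scoring retrieved passage
    prefix = f"已知条件:{top['text']}\n"          # prepend the retrieved condition as context
    return single_round(question, prefix=prefix)
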
@ -0,0 +1 @@
{"id": 1, "question": "你能把空调打开一下吗?"}
File diff suppressed because it is too large
Load Diff