This commit is contained in:
liwang_zhang 2023-11-24 11:28:04 +08:00
commit 089d987c28
49 changed files with 54961 additions and 0 deletions

1
.gitignore vendored
View File

@@ -19,6 +19,7 @@ share/python-wheels/
MANIFEST
MO-VLN/
GLIP/
pytorch_model.bin
sub_task.ptml

View File

@@ -15,6 +15,8 @@ pip install -e .
### Install the UI
1. Install [graphviz-9.0.0](https://gitlab.com/api/v4/projects/4207231/packages/generic/graphviz-releases/9.0.0/windows_10_cmake_Release_graphviz-install-9.0.0-win64.exe) (see the [official site](https://www.graphviz.org/download/#windows) for details).
2. Add the bin folder of the install directory to the system environment. For example, on Windows with Graphviz installed at D:\Program Files (x86)\Graphviz2.38, add that folder's bin path to the system PATH environment variable, i.e. D:\Program Files (x86)\Graphviz2.38\bin.
3. Install the vector database (a quick sanity check follows the command below):
conda install -c conda-forge faiss
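To confirm the faiss install is usable, a minimal sketch (random vectors only, not how the repo builds its index; the 768-dim size matches the Contriever config used by the retrieval scripts):

import faiss
import numpy as np

xb = np.random.rand(8, 768).astype("float32")  # 8 dummy passage vectors
index = faiss.IndexFlatIP(768)                 # flat inner-product index
index.add(xb)
scores, ids = index.search(xb[:1], k=3)
print(scores, ids)  # the top hit should be the query vector itself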
### Quick Start
1. Install UE and the Harix plugin, open the default project, and run it.

View File

View File

@@ -0,0 +1,25 @@
{
"architectures": [
"Contriever"
],
"attention_probs_dropout_prob": 0.1,
"classifier_dropout": null,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"torch_dtype": "float32",
"transformers_version": "4.15.0",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 30522
}
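This is the encoder config for the Contriever retriever (a BERT-base layout: 12 layers, 12 heads, hidden size 768). A hedged loading sketch with 🤗 Transformers, assuming the config, tokenizer files, and downloaded weights sit together in a local directory such as model/contriever-msmarco (the directory name mirrors the shell script further below; the repo's own Contriever class additionally mean-pools the hidden states):

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

model_dir = "model/contriever-msmarco"  # assumed local directory holding this config.json plus pytorch_model.bin
config = AutoConfig.from_pretrained(model_dir)        # model_type "bert", hidden_size 768
tokenizer = AutoTokenizer.from_pretrained(model_dir)  # BertTokenizer, per tokenizer_config.json
model = AutoModel.from_pretrained(model_dir, config=config)

inputs = tokenizer("put the yogurt on the coffee table", return_tensors="pt")
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state  # shape [1, seq_len, 768]
print(hidden.shape)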

View File

@@ -0,0 +1 @@
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1 @@
{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "bert-base-uncased", "tokenizer_class": "BertTokenizer"}

File diff suppressed because it is too large

View File

@@ -0,0 +1,4 @@
pip install gdown
gdown 1IYNAkwawfCDiBL27BlBqGssxFQH9vOux
unzip enwiki_2020_intro_only.zip
rm enwiki_2020_intro_only.zip

View File

@@ -0,0 +1,731 @@
#!/usr/bin/env python
# coding=utf-8
import argparse
import logging
import math
import os
import random
import datasets
import torch
import copy
from functools import partial
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from typing import Optional, Dict, Sequence
import json
import transformers
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
LlamaTokenizer,
LlamaTokenizerFast,
SchedulerType,
DataCollatorForSeq2Seq,
get_scheduler,
GPTNeoXTokenizerFast,
GPT2Tokenizer,
OPTForCausalLM
)
from peft import LoraConfig, TaskType, get_peft_model
logger = get_logger(__name__)
PROMPT_DICT = {
"prompt_input": (
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
),
"prompt_no_input": (
"### Instruction:\n{instruction}\n\n### Response:\n"
),
}
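# Example: with example = {"instruction": "请把酸奶放在咖啡台上", "input": ""}, the encoding code below selects
# prompt_no_input, and PROMPT_DICT["prompt_no_input"].format_map(example) renders to
#   "### Instruction:\n请把酸奶放在咖啡台上\n\n### Response:\n"
# before the example's 'output' (plus the EOS token) is appended as the completion.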
def parse_args():
parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help="The name of the dataset to use (via the datasets library).",
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The configuration name of the dataset to use (via the datasets library).",
)
parser.add_argument(
"--train_file", type=str, default=None, help="A csv or a json file containing the training data."
)
parser.add_argument(
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=False,
)
parser.add_argument(
"--config_name",
type=str,
default=None,
help="Pretrained config name or path if not the same as model_name",
)
parser.add_argument(
"--use_lora",
action="store_true",
help="If passed, will use LORA (low-rank parameter-efficient training) to train the model.",
)
parser.add_argument(
"--lora_rank",
type=int,
default=64,
help="The rank of lora.",
)
parser.add_argument(
"--lora_alpha",
type=float,
default=16,
help="The alpha parameter of lora.",
)
parser.add_argument(
"--lora_dropout",
type=float,
default=0.1,
help="The dropout rate of lora modules.",
)
parser.add_argument(
"--save_merged_lora_model",
action="store_true",
help="If passed, will merge the lora modules and save the entire model.",
)
parser.add_argument(
"--use_flash_attn",
action="store_true",
help="If passed, will use flash attention to train the model.",
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--use_slow_tokenizer",
action="store_true",
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
)
parser.add_argument(
"--max_seq_length",
type=int,
default=512,
help="The maximum total sequence length (prompt+completion) of each training example.",
)
parser.add_argument(
"--per_device_train_batch_size",
type=int,
default=8,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-5,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--lr_scheduler_type",
type=SchedulerType,
default="linear",
help="The scheduler type to use.",
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
)
parser.add_argument(
"--warmup_ratio", type=float, default=0, help="Ratio of total training steps used for warmup."
)
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--preprocessing_num_workers",
type=int,
default=None,
help="The number of processes to use for the preprocessing.",
)
parser.add_argument(
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
)
parser.add_argument(
"--checkpointing_steps",
type=str,
default=None,
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
)
parser.add_argument(
"--logging_steps",
type=int,
default=None,
help="Log the training loss and learning rate every logging_steps steps.",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--with_tracking",
action="store_true",
help="Whether to enable experiment trackers for logging.",
)
parser.add_argument(
"--report_to",
type=str,
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
parser.add_argument(
"--low_cpu_mem_usage",
action="store_true",
help=(
"It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded."
"If passed, LLM loading time and RAM consumption will be benefited."
),
)
parser.add_argument(
"--use_special_tokens",
action="store_true",
help=(
"Use special tokens."
),
)
args = parser.parse_args()
# Sanity checks
if args.dataset_name is None and args.train_file is None:
raise ValueError("Need either a dataset name or a training file.")
else:
if args.train_file is not None:
extension = args.train_file.split(".")[-1]
assert extension in ["json", "jsonl"], "`train_file` should be a json/jsonl file."
return args
def _tokenize_fn(text: str, tokenizer: transformers.PreTrainedTokenizer, max_seq_length: int) -> Dict:
"""Tokenize a list of strings."""
input_ids = labels = tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=max_seq_length,
truncation=True,
).input_ids
input_ids_lens = labels_lens = input_ids.ne(tokenizer.pad_token_id).sum().item()
print(input_ids_lens)
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def encode_with_prompt_completion_format(example, tokenizer, max_seq_length, context_markups=None):
'''
Here we assume each example has 'instruction', an optional 'input', and an 'output' field (see PROMPT_DICT above).
We concatenate the prompt and the output and tokenize them together, because otherwise the prompt would be padded/truncated
on its own and the output would no longer directly follow it.
'''
# build the prompt from the instruction (and the input, if present)
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
source_text = prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
target_text = example['output'] + tokenizer.eos_token
examples_tokenized = _tokenize_fn(source_text + target_text, tokenizer, max_seq_length)
sources_tokenized = _tokenize_fn(source_text, tokenizer, max_seq_length)
input_ids = examples_tokenized["input_ids"].flatten()
source_len = sources_tokenized["input_ids_lens"]
labels = copy.deepcopy(input_ids)
labels[ :source_len-1] = -100
if context_markups is not None:
context_start = False
for j, orig_token in enumerate(labels[source_len:]):
if context_start is False and orig_token == context_markups[0]:
context_start = True
assert labels[source_len+j] == context_markups[0]
start_idx = j+source_len
end_idx = None
for k, orig_token_2 in enumerate(labels[start_idx:]):
if orig_token_2 == context_markups[1]:
end_idx = start_idx + k
if end_idx is None:
end_idx = start_idx + k
else:
assert labels[end_idx] == context_markups[1]
labels[start_idx+1:end_idx] = -100
context_start = False
attention_mask = torch.ones_like(input_ids)
return {
'input_ids': input_ids.flatten(),
'labels': labels.flatten(),
'attention_mask': attention_mask.flatten()
}
def encode_with_messages_format(example, tokenizer, max_seq_length):
'''
Here we assume each example has a 'messages' field. Each message is a dict with 'role' and 'content' fields.
We concatenate all messages with the roles as delimiters and tokenize them together.
'''
messages = example['messages']
if len(messages) == 0:
raise ValueError('messages field is empty.')
def _concat_messages(messages):
message_text = ""
for message in messages:
if message["role"] == "system":
message_text += "<|system|>\n" + message["content"].strip() + "\n"
elif message["role"] == "user":
message_text += "<|user|>\n" + message["content"].strip() + "\n"
elif message["role"] == "assistant":
message_text += "<|assistant|>\n" + message["content"].strip() + tokenizer.eos_token + "\n"
else:
raise ValueError("Invalid role: {}".format(message["role"]))
return message_text
example_text = _concat_messages(messages).strip()
tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
input_ids = tokenized_example.input_ids
labels = input_ids.clone()
# mask the non-assistant part so it is excluded from the loss
for message_idx, message in enumerate(messages):
if message["role"] != "assistant":
if message_idx == 0:
message_start_idx = 0
else:
message_start_idx = tokenizer(
_concat_messages(messages[:message_idx]), return_tensors='pt', max_length=max_seq_length, truncation=True
).input_ids.shape[1]
if message_idx < len(messages) - 1 and messages[message_idx+1]["role"] == "assistant":
# here we also ignore the role of the assistant
messages_so_far = _concat_messages(messages[:message_idx+1]) + "<|assistant|>\n"
else:
messages_so_far = _concat_messages(messages[:message_idx+1])
message_end_idx = tokenizer(
messages_so_far,
return_tensors='pt',
max_length=max_seq_length,
truncation=True
).input_ids.shape[1]
labels[:, message_start_idx:message_end_idx] = -100
if message_end_idx >= max_seq_length:
break
attention_mask = torch.ones_like(input_ids)
return {
'input_ids': input_ids.flatten(),
'labels': labels.flatten(),
'attention_mask': attention_mask.flatten(),
}
def main():
args = parse_args()
# A hacky way to make llama work with flash attention
if args.use_flash_attn:
from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
# in the environment
accelerator_log_kwargs = {}
if args.with_tracking:
accelerator_log_kwargs["log_with"] = args.report_to
accelerator_log_kwargs["project_dir"] = args.output_dir
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
raw_datasets = load_dataset(
args.dataset_name,
args.dataset_config_name,
)
else:
data_files = {}
dataset_args = {}
if args.train_file is not None:
data_files["train"] = args.train_file
raw_datasets = load_dataset(
"json",
data_files=data_files,
**dataset_args,
)
# Load pretrained model and tokenizer
if args.config_name:
config = AutoConfig.from_pretrained(args.config_name)
elif args.model_name_or_path:
config = AutoConfig.from_pretrained(args.model_name_or_path)
else:
raise ValueError(
"You are instantiating a new config instance from scratch. This is not supported by this script."
)
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
elif args.model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)
if args.model_name_or_path:
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)
else:
logger.info("Training new model from scratch")
model = AutoModelForCausalLM.from_config(config)
# no default pad token for llama!
# here we add all special tokens again, because the default ones are not in the special_tokens_map
if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast):
if args.use_special_tokens is True:
special_token_dict = {"additional_special_tokens": ["[No Retrieval]", "[Retrieval]", "[Continue to Use Evidence]", "[Irrelevant]", "[Relevant]", "<paragraph>", "</paragraph>", "[Utility:1]", "[Utility:2]", "[Utility:3]", "[Utility:4]", "[Utility:5]", "[Fully supported]", "[Partially supported]", "[No support / Contradictory]"]}
special_token_dict["bos_token"] = "<s>"
special_token_dict["eos_token"] = "</s>"
special_token_dict["unk_token"] = "<unk>"
special_token_dict["pad_token"] = "<pad>"
num_added_tokens = tokenizer.add_special_tokens(special_token_dict)
context_markups = []
for token in ["<paragraph>", "</paragraph>"]:
context_markups.append(tokenizer.convert_tokens_to_ids(token))
if args.use_special_tokens is False:
assert num_added_tokens in [0, 1], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present."
else:
assert num_added_tokens > 10, "special tokens must be added to the original tokenizers."
elif isinstance(tokenizer, GPTNeoXTokenizerFast):
num_added_tokens = tokenizer.add_special_tokens({
"pad_token": "<pad>",
})
assert num_added_tokens == 1, "GPTNeoXTokenizer should only add one special token - the pad_token."
elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM):
num_added_tokens = tokenizer.add_special_tokens({'unk_token': '<unk>'})
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
model.resize_token_embeddings(len(tokenizer))
if args.use_lora:
logger.info("Initializing LORA model...")
modules_to_save = ["embed_tokens"]
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=args.lora_rank,
#modules_to_save=modules_to_save,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
encode_function = partial(
encode_with_prompt_completion_format,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
context_markups=context_markups if args.use_special_tokens is True else None
)
# elif "messages" in raw_datasets["train"].column_names:
# encode_function = partial(
# encode_with_messages_format,
# tokenizer=tokenizer,
# max_seq_length=args.max_seq_length,
# )
with accelerator.main_process_first():
lm_datasets = raw_datasets.map(
encode_function,
batched=False,
num_proc=args.preprocessing_num_workers,
load_from_cache_file=not args.overwrite_cache,
remove_columns=[name for name in raw_datasets["train"].column_names if name not in ["input_ids", "labels", "attention_mask"]],
desc="Tokenizing and reformatting instruction data",
)
lm_datasets.set_format(type="pt")
lm_datasets = lm_datasets.filter(lambda example: (example['labels'] != -100).any())
train_dataset = lm_datasets["train"]
#print(train_dataset[0])
#print(train_dataset[1000])
#print(train_dataset[500])
#print(train_dataset[2000])
#print(train_dataset[10000])
with open("processed.json", "w") as outfile:
new_data = []
for item in train_dataset:
print(item)
labels = [int(i) for i in item["labels"]]
input_ids = [int(i) for i in item["input_ids"]]
new_data.append({"labels": labels, "input_ids": input_ids})
json.dump(new_data, outfile)
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
# DataLoaders creation:
train_dataloader = DataLoader(
train_dataset,
shuffle=True,
collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"),
batch_size=args.per_device_train_batch_size
)
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "layer_norm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
# Create the learning rate scheduler.
# Note: the current accelerator.step() calls the .step() of the real scheduler `num_processes` times. This is because it assumes
# the user initialized the scheduler with the entire training set. In the case of data parallel training, each process only
# sees a subset (1/num_processes) of the training set. So each time the process needs to update the lr multiple times so that the total
# number of updates in the end matches the num_training_steps here.
# Here we need to set the num_training_steps to either using the entire training set (when epochs is specified) or we need to multiply the
# num_training_steps by num_processes so that the total number of updates matches the num_training_steps.
num_training_steps_for_scheduler = args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_training_steps=num_training_steps_for_scheduler,
num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio),
)
# Prepare everything with `accelerator`.
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Figure out how many steps we should save the Accelerator states
checkpointing_steps = args.checkpointing_steps
if checkpointing_steps is not None and checkpointing_steps.isdigit():
checkpointing_steps = int(checkpointing_steps)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("open_instruct", experiment_config)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint)
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
# Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0]
if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else:
# need to multiply `gradient_accumulation_steps` to reflect real steps
resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
starting_epoch = resume_step // len(train_dataloader)
resume_step -= starting_epoch * len(train_dataloader)
# update the progress_bar if load from checkpoint
progress_bar.update(starting_epoch * num_update_steps_per_epoch)
completed_steps = starting_epoch * num_update_steps_per_epoch
for epoch in range(starting_epoch, args.num_train_epochs):
model.train()
total_loss = 0
for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and completed_steps < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
completed_steps += 1
continue
with accelerator.accumulate(model):
outputs = model(**batch, use_cache=False)
loss = outputs.loss
# We keep track of the loss at each logged step
total_loss += loss.detach().float()
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
# # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
completed_steps += 1
if args.logging_steps and completed_steps % args.logging_steps == 0:
avg_loss = accelerator.gather(total_loss).mean().item() / args.gradient_accumulation_steps / args.logging_steps
logger.info(f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}")
if args.with_tracking:
accelerator.log(
{
"learning_rate": lr_scheduler.get_last_lr()[0],
"train_loss": avg_loss,
},
step=completed_steps,
)
total_loss = 0
if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}"
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
if completed_steps >= args.max_train_steps:
break
if args.checkpointing_steps == "epoch":
output_dir = f"epoch_{epoch}"
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
if args.with_tracking:
accelerator.end_training()
if args.output_dir is not None:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir)
unwrapped_model = accelerator.unwrap_model(model)
# When doing multi-gpu training, we need to use accelerator.get_state_dict(model) to get the state_dict.
# Otherwise, sometimes the model will be saved with only part of the parameters.
# Also, accelerator needs to use the wrapped model to get the state_dict.
state_dict = accelerator.get_state_dict(model)
if args.use_lora:
# When using lora, the unwrapped model is a PeftModel, which doesn't support the is_main_process
# and has its own save_pretrained function for only saving lora modules.
# We have to manually specify the is_main_process outside the save_pretrained function.
if accelerator.is_main_process:
unwrapped_model.save_pretrained(args.output_dir, state_dict=state_dict)
else:
unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=state_dict
)
if __name__ == "__main__":
main()
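As a quick reference for the label masking above, a self-contained sketch (toy token ids, no tokenizer or model) of how encode_with_prompt_completion_format builds labels from input_ids: positions before source_len-1 are set to -100, so the prompt portion is ignored by the loss.

import torch

input_ids = torch.tensor([101, 2009, 2003, 1037, 3231, 7592, 2088, 102])  # 5 toy prompt tokens + 3 completion tokens
source_len = 5                  # length of the tokenized prompt, as returned by _tokenize_fn
labels = input_ids.clone()
labels[:source_len - 1] = -100  # same rule as labels[ :source_len-1] = -100 in the function above
print(labels.tolist())          # [-100, -100, -100, -100, 3231, 7592, 2088, 102]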

View File

@@ -0,0 +1,115 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import argparse
import pickle
import torch

# fully qualified robowaiter modules referenced below
import robowaiter.llm_client.retrieval_lm.src.contriever
import robowaiter.llm_client.retrieval_lm.src.data
import robowaiter.llm_client.retrieval_lm.src.normalize_text
import robowaiter.llm_client.retrieval_lm.src.slurm
def embed_passages(args, passages, model, tokenizer):
total = 0
allids, allembeddings = [], []
batch_ids, batch_text = [], []
with torch.no_grad():
for k, p in enumerate(passages):
batch_ids.append(p["id"])
"""if args.no_title or not "title" in p:
text = p["text"]
else:
text = p["title"] + " " + p["text"]"""
text = p["title"]
if args.lowercase:
text = text.lower()
if args.normalize_text:
text = robowaiter.llm_client.retrieval_lm.src.normalize_text.normalize(text)
batch_text.append(text)
if len(batch_text) == args.per_gpu_batch_size or k == len(passages) - 1:
encoded_batch = tokenizer.batch_encode_plus(
batch_text,
return_tensors="pt",
max_length=args.passage_maxlength,
padding=True,
truncation=True,
)
encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
embeddings = model(**encoded_batch)
embeddings = embeddings.cpu()
total += len(batch_ids)
allids.extend(batch_ids)
allembeddings.append(embeddings)
batch_text = []
batch_ids = []
if k % 100000 == 0 and k > 0:
print(f"Encoded passages {total}")
allembeddings = torch.cat(allembeddings, dim=0).numpy()
return allids, allembeddings
def main(args):
model, tokenizer, _ = robowaiter.llm_client.retrieval_lm.src.contriever.load_retriever(args.model_name_or_path)
print(f"Model loaded from {args.model_name_or_path}.", flush=True)
model.eval()
model = model.cuda()
if not args.no_fp16:
model = model.half()
passages = robowaiter.llm_client.retrieval_lm.src.data.load_passages(args.passages)
shard_size = len(passages) // args.num_shards
start_idx = args.shard_id * shard_size
end_idx = start_idx + shard_size
if args.shard_id == args.num_shards - 1:
end_idx = len(passages)
passages = passages[start_idx:end_idx]
print(f"Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}.")
allids, allembeddings = embed_passages(args, passages, model, tokenizer)
save_file = os.path.join(args.output_dir, args.prefix + f"_{args.shard_id:02d}")
os.makedirs(args.output_dir, exist_ok=True)
print(f"Saving {len(allids)} passage embeddings to {save_file}.")
with open(save_file, mode="wb") as f:
pickle.dump((allids, allembeddings), f)
print(f"Total passages processed {len(allids)}. Written to {save_file}.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
parser.add_argument("--output_dir", type=str, default="wikipedia_embeddings", help="dir path to save embeddings")
parser.add_argument("--prefix", type=str, default="passages", help="prefix path to save embeddings")
parser.add_argument("--shard_id", type=int, default=0, help="Id of the current shard")
parser.add_argument("--num_shards", type=int, default=1, help="Total number of shards")
parser.add_argument(
"--per_gpu_batch_size", type=int, default=512, help="Batch size for the passage encoder forward pass"
)
parser.add_argument("--passage_maxlength", type=int, default=512, help="Maximum number of tokens in a passage")
parser.add_argument(
"--model_name_or_path", type=str, help="path to directory containing model weights and config file"
)
parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
parser.add_argument("--no_title", action="store_true", help="title not added to the passage body")
parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
parser.add_argument("--normalize_text", action="store_true", help="lowercase text before encoding")
args = parser.parse_args()
robowaiter.llm_client.retrieval_lm.src.slurm.init_distributed_mode(args)
main(args)
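The sharding in main() above splits the passage list evenly and lets the last shard absorb the remainder; a self-contained sketch with toy data:

passages = list(range(10))   # stand-in for the loaded passages
num_shards, shard_id = 3, 2  # e.g. the last of three shards
shard_size = len(passages) // num_shards
start_idx = shard_id * shard_size
end_idx = start_idx + shard_size
if shard_id == num_shards - 1:
    end_idx = len(passages)
print(passages[start_idx:end_idx])  # [6, 7, 8, 9]: the final shard also gets the leftover passage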

View File

@@ -0,0 +1,119 @@
from typing import List, Optional, Tuple
import torch
from torch import nn
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from einops import rearrange
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
except ImportError:
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
assert past_key_value is None, "past_key_value is not supported"
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
assert not output_attentions, "output_attentions is not supported"
assert not use_cache, "use_cache is not supported"
# Flash attention codes from
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
# transform the data into the format required by flash attention
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask = attention_mask
if key_padding_mask is None:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len
cu_q_lens = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
)
output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(
pad_input(
rearrange(output_unpad,
"nnz h d -> nnz (h d)"), indices, bsz, q_len
),
"b s (h d) -> b s h d",
h=nheads,
)
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
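A hedged usage sketch: the finetune script above imports and calls this patch before constructing the model, and a standalone caller would do the same. The module name follows that import; flash-attn being installed, a CUDA GPU, and the model id are assumptions here.

from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
from transformers import AutoModelForCausalLM

replace_llama_attn_with_flash_attn()  # patch LlamaAttention.forward before the model is instantiated
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed model id
model = model.half().cuda()           # flash attention kernels expect fp16/bf16 tensors on GPU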

View File

@@ -0,0 +1,87 @@
import numpy as np
import string
import re
from collections import Counter
def exact_match_score(prediction, ground_truth):
return (normalize_answer(prediction) == normalize_answer(ground_truth))
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def accuracy(preds, labels):
match_count = 0
for pred, label in zip(preds, labels):
target = label[0]
if pred == target:
match_count += 1
return 100 * (match_count / len(preds))
def f1(decoded_preds, decoded_labels):
f1_all = []
for prediction, answers in zip(decoded_preds, decoded_labels):
if type(answers) == list:
if len(answers) == 0:
return 0
f1_all.append(np.max([qa_f1_score(prediction, gt)
for gt in answers]))
else:
f1_all.append(qa_f1_score(prediction, answers))
return 100 * np.mean(f1_all)
def qa_f1_score(prediction, ground_truth):
prediction_tokens = normalize_answer(prediction).split()
ground_truth_tokens = normalize_answer(ground_truth).split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def normalize_answer(s):
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def find_entity_tags(sentence):
entity_regex = r'(.+?)(?=\s<|$)'
tag_regex = r'<(.+?)>'
entity_names = re.findall(entity_regex, sentence)
tags = re.findall(tag_regex, sentence)
results = {}
for entity, tag in zip(entity_names, tags):
if "<" in entity:
results[entity.split("> ")[1]] = tag
else:
results[entity] = tag
return results
def match(prediction, ground_truth):
for gt in ground_truth:
if gt in prediction:
return 1
return 0
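A small usage sketch for these metrics, assuming the file is importable as metrics (the file name is not visible in this diff):

from metrics import exact_match_score, qa_f1_score, metric_max_over_ground_truths

prediction = "the Eiffel Tower"
ground_truths = ["Eiffel Tower", "Paris"]
em = metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
f1 = metric_max_over_ground_truths(qa_f1_score, prediction, ground_truths)
print(em, f1)  # True 1.0 -- normalize_answer lowercases, strips punctuation, and drops the article "the"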

View File

@@ -0,0 +1,8 @@
export CUDA_VISIBLE_DEVICES=0
python3 ../generate_passage_embeddings.py \
--model_name_or_path ../../model/contriever-msmarco \
--passages train_robot.jsonl \
--output_dir robot_embeddings \
--shard_id 0 \
--num_shards 1 \
--per_gpu_batch_size 500

File diff suppressed because it is too large

View File

@@ -0,0 +1,250 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import argparse
import csv
import json
import logging
import pickle
import time
import glob
from pathlib import Path
import numpy as np
import torch
import transformers
import src.index
import src.contriever
import src.utils
import src.slurm
import src.data
from src.evaluation import calculate_matches
import src.normalize_text
os.environ["TOKENIZERS_PARALLELISM"] = "true"
def embed_queries(args, queries, model, tokenizer):
model.eval()
embeddings, batch_question = [], []
with torch.no_grad():
for k, q in enumerate(queries):
if args.lowercase:
q = q.lower()
if args.normalize_text:
q = src.normalize_text.normalize(q)
batch_question.append(q)
if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1:
encoded_batch = tokenizer.batch_encode_plus(
batch_question,
return_tensors="pt",
max_length=args.question_maxlength,
padding=True,
truncation=True,
)
encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
output = model(**encoded_batch)
embeddings.append(output.cpu())
batch_question = []
embeddings = torch.cat(embeddings, dim=0)
print(f"Questions embeddings shape: {embeddings.size()}")
return embeddings.numpy()
def index_encoded_data(index, embedding_files, indexing_batch_size):
allids = []
allembeddings = np.array([])
for i, file_path in enumerate(embedding_files):
print(f"Loading file {file_path}")
with open(file_path, "rb") as fin:
ids, embeddings = pickle.load(fin)
allembeddings = np.vstack((allembeddings, embeddings)) if allembeddings.size else embeddings
allids.extend(ids)
while allembeddings.shape[0] > indexing_batch_size:
allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
while allembeddings.shape[0] > 0:
allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
print("Data indexing completed.")
def add_embeddings(index, embeddings, ids, indexing_batch_size):
end_idx = min(indexing_batch_size, embeddings.shape[0])
ids_toadd = ids[:end_idx]
embeddings_toadd = embeddings[:end_idx]
ids = ids[end_idx:]
embeddings = embeddings[end_idx:]
index.index_data(ids_toadd, embeddings_toadd)
return embeddings, ids
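# Example: with indexing_batch_size=2 and 5 embeddings loaded in total, index_encoded_data above
# calls index.index_data on chunks of 2, 2, and finally 1 once every embedding file has been read.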
def validate(data, workers_num):
match_stats = calculate_matches(data, workers_num)
top_k_hits = match_stats.top_k_hits
print("Validation results: top k documents hits %s", top_k_hits)
top_k_hits = [v / len(data) for v in top_k_hits]
message = ""
for k in [5, 10, 20, 100]:
if k <= len(top_k_hits):
message += f"R@{k}: {top_k_hits[k-1]} "
print(message)
return match_stats.questions_doc_hits
def add_passages(data, passages, top_passages_and_scores):
# add passages to original data
merged_data = []
assert len(data) == len(top_passages_and_scores)
for i, d in enumerate(data):
results_and_scores = top_passages_and_scores[i]
#print(passages[2393])
docs = [passages[int(doc_id)] for doc_id in results_and_scores[0]]
scores = [str(score) for score in results_and_scores[1]]
ctxs_num = len(docs)
d["ctxs"] = [
{
"id": results_and_scores[0][c],
"title": docs[c]["title"],
"text": docs[c]["text"],
"score": scores[c],
}
for c in range(ctxs_num)
]
def add_hasanswer(data, hasanswer):
# add hasanswer to data
for i, ex in enumerate(data):
for k, d in enumerate(ex["ctxs"]):
d["hasanswer"] = hasanswer[i][k]
def load_data(data_path):
if data_path.endswith(".json"):
with open(data_path, "r") as fin:
data = json.load(fin)
elif data_path.endswith(".jsonl"):
data = []
with open(data_path, "r") as fin:
for k, example in enumerate(fin):
example = json.loads(example)
data.append(example)
return data
def main(args):
print(f"Loading model from: {args.model_name_or_path}")
model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path)
model.eval()
model = model.cuda()
if not args.no_fp16:
model = model.half()
index = src.index.Indexer(args.projection_size, args.n_subquantizers, args.n_bits)
# index all passages
input_paths = glob.glob(args.passages_embeddings)
input_paths = sorted(input_paths)
embeddings_dir = os.path.dirname(input_paths[0])
index_path = os.path.join(embeddings_dir, "index.faiss")
if args.save_or_load_index and os.path.exists(index_path):
index.deserialize_from(embeddings_dir)
else:
print(f"Indexing passages from files {input_paths}")
start_time_indexing = time.time()
index_encoded_data(index, input_paths, args.indexing_batch_size)
print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.")
if args.save_or_load_index:
index.serialize(embeddings_dir)
# load passages
passages = src.data.load_passages(args.passages)
passage_id_map = {x["id"]: x for x in passages}
data_paths = glob.glob(args.data)
alldata = []
for path in data_paths:
data = load_data(path)
output_path = os.path.join(args.output_dir, os.path.basename(path))
queries = [ex["question"] for ex in data]
questions_embedding = embed_queries(args, queries, model, tokenizer)
# get top k results
start_time_retrieval = time.time()
top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs)
print(f"Search time: {time.time()-start_time_retrieval:.1f} s.")
add_passages(data, passage_id_map, top_ids_and_scores)
#hasanswer = validate(data, args.validation_workers)
#add_hasanswer(data, hasanswer)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w") as fout:
for ex in data:
json.dump(ex, fout, ensure_ascii=False)
fout.write("\n")
print(f"Saved results to {output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data",
required=True,
type=str,
default=None,
help=".json file containing question and answers, similar format to reader data",
)
parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
parser.add_argument("--passages_embeddings", type=str, default=None, help="Glob path to encoded passages")
parser.add_argument(
"--output_dir", type=str, default=None, help="Results are written to outputdir with data suffix"
)
parser.add_argument("--n_docs", type=int, default=100, help="Number of documents to retrieve per questions")
parser.add_argument(
"--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
)
parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding")
parser.add_argument(
"--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists"
)
parser.add_argument(
"--model_name_or_path", type=str, help="path to directory containing model weights and config file"
)
parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
parser.add_argument(
"--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed"
)
parser.add_argument("--projection_size", type=int, default=768)
parser.add_argument(
"--n_subquantizers",
type=int,
default=0,
help="Number of subquantizer used for vector quantization, if 0 flat index is used",
)
parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer")
parser.add_argument("--lang", nargs="+")
parser.add_argument("--dataset", type=str, default="none")
parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
parser.add_argument("--normalize_text", action="store_true", help="normalize text")
args = parser.parse_args()
src.slurm.init_distributed_mode(args)
main(args)
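Each output line is a JSON object carrying the original question plus a ctxs list of {id, title, text, score} entries ordered by retrieval score (see the sample output further below). A hedged reading sketch; the file path is an assumption:

import json

with open("robot_embeddings/test_robot.jsonl", encoding="utf-8") as fin:  # assumed output path
    for line in fin:
        ex = json.loads(line)
        top = ex["ctxs"][0]
        print(ex["question"], "->", top["text"], f"(score={top['score']})")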

View File

@@ -0,0 +1,41 @@
import json
import jsonlines
import argparse
def train(args):
filename=args.passages
with open(filename, 'r', encoding="utf-8") as f:
k=0
for line in f:
data = json.loads(line)
dict={"id":k,'title':data['title'],'text':data['text']}
k+=1
with jsonlines.open("train_robot.jsonl", "a") as file_jsonl:
file_jsonl.write(dict)
def test(args):
filename = args.passages
with open(filename, 'r', encoding="utf-8") as f:
k=0
for line in f:
if k<1000:
data = json.loads(line)
dict={"id":data['id'],'question':data['title'],'answers':data['text']}
k+=1
with jsonlines.open("test_robot.jsonl", "a") as file_jsonl:
file_jsonl.write(dict)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--passages", type=str, default=None, help="Path to passages")
parser.add_argument("--mode", type=str, default=None, help="train or test")
args = parser.parse_args()
if args.mode=='train':
train(args)
elif args.mode=='test':
test(args)
else:
print("error mode!")

View File

@@ -0,0 +1,2 @@
{"id": 0, "question": "请把酸奶放在咖啡台上,并打开窗帘。", "ctxs": [{"id": "0", "title": "请把酸奶放在咖啡台上,并打开窗帘。", "text": "On(Yogurt,CoffeeTable),Is(Curtain,Open)", "score": "1.9694625"}, {"id": "1", "title": "可以把牛奶饮料放在2号桌子上吗还有关掉灯光。", "text": "On(MilkDrink,Table2),Is(TubeLight,Off)", "score": "1.8284101"}, {"id": "2", "title": "你好,可以给我上一份甜点吗?", "text": "On(Dessert,Table1)", "score": "1.4835652"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "1.4412252"}, {"id": "4", "title": "可以送一瓶牛奶饮料到1号桌吗", "text": "On(MilkDrink,Table1)", "score": "1.2867957"}, {"id": "3", "title": "你能到另一个吧台这边来吗?空调可以关掉吗?", "text": "At(Robot,Bar2),Is(AC,On)", "score": "1.2599907"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}]}
{"id": 1, "question": "可以把牛奶饮料放在2号桌子上吗还有关掉灯光。", "ctxs": [{"id": "1", "title": "可以把牛奶饮料放在2号桌子上吗还有关掉灯光。", "text": "On(MilkDrink,Table2),Is(TubeLight,Off)", "score": "2.138029"}, {"id": "0", "title": "请把酸奶放在咖啡台上,并打开窗帘。", "text": "On(Yogurt,CoffeeTable),Is(Curtain,Open)", "score": "1.8282425"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "1.6972268"}, {"id": "2", "title": "你好,可以给我上一份甜点吗?", "text": "On(Dessert,Table1)", "score": "1.4741647"}, {"id": "4", "title": "可以送一瓶牛奶饮料到1号桌吗", "text": "On(MilkDrink,Table1)", "score": "1.4532053"}, {"id": "3", "title": "你能到另一个吧台这边来吗?空调可以关掉吗?", "text": "At(Robot,Bar2),Is(AC,On)", "score": "1.3438905"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}, {"id": "5", "title": "可以把酸奶放在2号桌上吗还有能关掉筒灯吗", "text": "On(Yogurt,Table2),Is(TubeLight,Off)", "score": "-3.4028235e+38"}]}

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,208 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import glob
from collections import defaultdict
from typing import List, Dict
import numpy as np
import torch
import torch.distributed as dist
import beir.util
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch
from beir.reranking.models import CrossEncoder
from beir.reranking import Rerank
import src.dist_utils as dist_utils
from src import normalize_text
class DenseEncoderModel:
def __init__(
self,
query_encoder,
doc_encoder=None,
tokenizer=None,
max_length=512,
add_special_tokens=True,
norm_query=False,
norm_doc=False,
lower_case=False,
normalize_text=False,
**kwargs,
):
self.query_encoder = query_encoder
self.doc_encoder = doc_encoder
self.tokenizer = tokenizer
self.max_length = max_length
self.add_special_tokens = add_special_tokens
self.norm_query = norm_query
self.norm_doc = norm_doc
self.lower_case = lower_case
self.normalize_text = normalize_text
def encode_queries(self, queries: List[str], batch_size: int, **kwargs) -> np.ndarray:
if dist.is_initialized():
idx = np.array_split(range(len(queries)), dist.get_world_size())[dist.get_rank()]
else:
idx = range(len(queries))
queries = [queries[i] for i in idx]
if self.normalize_text:
queries = [normalize_text.normalize(q) for q in queries]
if self.lower_case:
queries = [q.lower() for q in queries]
allemb = []
nbatch = (len(queries) - 1) // batch_size + 1
with torch.no_grad():
for k in range(nbatch):
start_idx = k * batch_size
end_idx = min((k + 1) * batch_size, len(queries))
qencode = self.tokenizer.batch_encode_plus(
queries[start_idx:end_idx],
max_length=self.max_length,
padding=True,
truncation=True,
add_special_tokens=self.add_special_tokens,
return_tensors="pt",
)
qencode = {key: value.cuda() for key, value in qencode.items()}
emb = self.query_encoder(**qencode, normalize=self.norm_query)
allemb.append(emb.cpu())
allemb = torch.cat(allemb, dim=0)
allemb = allemb.cuda()
if dist.is_initialized():
allemb = dist_utils.varsize_gather_nograd(allemb)
allemb = allemb.cpu().numpy()
return allemb
def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs):
if dist.is_initialized():
idx = np.array_split(range(len(corpus)), dist.get_world_size())[dist.get_rank()]
else:
idx = range(len(corpus))
corpus = [corpus[i] for i in idx]
corpus = [c["title"] + " " + c["text"] if len(c["title"]) > 0 else c["text"] for c in corpus]
if self.normalize_text:
corpus = [normalize_text.normalize(c) for c in corpus]
if self.lower_case:
corpus = [c.lower() for c in corpus]
allemb = []
nbatch = (len(corpus) - 1) // batch_size + 1
with torch.no_grad():
for k in range(nbatch):
start_idx = k * batch_size
end_idx = min((k + 1) * batch_size, len(corpus))
cencode = self.tokenizer.batch_encode_plus(
corpus[start_idx:end_idx],
max_length=self.max_length,
padding=True,
truncation=True,
add_special_tokens=self.add_special_tokens,
return_tensors="pt",
)
cencode = {key: value.cuda() for key, value in cencode.items()}
emb = self.doc_encoder(**cencode, normalize=self.norm_doc)
allemb.append(emb.cpu())
allemb = torch.cat(allemb, dim=0)
allemb = allemb.cuda()
if dist.is_initialized():
allemb = dist_utils.varsize_gather_nograd(allemb)
allemb = allemb.cpu().numpy()
return allemb
def evaluate_model(
query_encoder,
doc_encoder,
tokenizer,
dataset,
batch_size=128,
add_special_tokens=True,
norm_query=False,
norm_doc=False,
is_main=True,
split="test",
score_function="dot",
beir_dir="BEIR/datasets",
save_results_path=None,
lower_case=False,
normalize_text=False,
):
metrics = defaultdict(list) # store final results
if hasattr(query_encoder, "module"):
query_encoder = query_encoder.module
query_encoder.eval()
if doc_encoder is not None:
if hasattr(doc_encoder, "module"):
doc_encoder = doc_encoder.module
doc_encoder.eval()
else:
doc_encoder = query_encoder
dmodel = DenseRetrievalExactSearch(
DenseEncoderModel(
query_encoder=query_encoder,
doc_encoder=doc_encoder,
tokenizer=tokenizer,
add_special_tokens=add_special_tokens,
norm_query=norm_query,
norm_doc=norm_doc,
lower_case=lower_case,
normalize_text=normalize_text,
),
batch_size=batch_size,
)
retriever = EvaluateRetrieval(dmodel, score_function=score_function)
data_path = os.path.join(beir_dir, dataset)
if not os.path.isdir(data_path) and is_main:
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
data_path = beir.util.download_and_unzip(url, beir_dir)
dist_utils.barrier()
if not dataset == "cqadupstack":
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)
results = retriever.retrieve(corpus, queries)
if is_main:
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"):
if isinstance(metric, str):
metric = retriever.evaluate_custom(qrels, results, retriever.k_values, metric=metric)
for key, value in metric.items():
metrics[key].append(value)
if save_results_path is not None:
torch.save(results, f"{save_results_path}")
elif dataset == "cqadupstack": # compute macroaverage over datasets
paths = glob.glob(data_path)
for path in paths:
corpus, queries, qrels = GenericDataLoader(data_folder=path).load(split=split)
results = retriever.retrieve(corpus, queries)
if is_main:
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"):
if isinstance(metric, str):
metric = retriever.evaluate_custom(qrels, results, retriever.k_values, metric=metric)
for key, value in metric.items():
metrics[key].append(value)
for key, value in metrics.items():
assert (
len(value) == 12
), f"cqadupstack includes 12 datasets, only {len(value)} values were computed for the {key} metric"
metrics = {key: 100 * np.mean(value) for key, value in metrics.items()}
return metrics
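# A hedged call sketch (not executed here): "scifact" and the encoder/tokenizer names below are
# placeholders for whatever BEIR dataset and Contriever-style model are actually loaded.
# metrics = evaluate_model(
#     query_encoder=model, doc_encoder=None, tokenizer=tokenizer,
#     dataset="scifact", batch_size=64, beir_dir="BEIR/datasets",
#     score_function="dot", norm_query=False, norm_doc=False,
# )
# print(metrics)  # e.g. NDCG@k / Recall@k values averaged by the loop above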

View File

@ -0,0 +1,139 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import torch
import transformers
from transformers import BertModel, XLMRobertaModel
from robowaiter.algos.retrieval.retrieval_lm.src import utils
class Contriever(BertModel):
def __init__(self, config, pooling="average", **kwargs):
super().__init__(config, add_pooling_layer=False)
if not hasattr(config, "pooling"):
self.config.pooling = pooling
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
normalize=False,
):
model_output = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
last_hidden = model_output["last_hidden_state"]
last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
if self.config.pooling == "average":
emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
elif self.config.pooling == "cls":
emb = last_hidden[:, 0]
if normalize:
emb = torch.nn.functional.normalize(emb, dim=-1)
return emb
class XLMRetriever(XLMRobertaModel):
def __init__(self, config, pooling="average", **kwargs):
super().__init__(config, add_pooling_layer=False)
if not hasattr(config, "pooling"):
self.config.pooling = pooling
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
normalize=False,
):
model_output = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
last_hidden = model_output["last_hidden_state"]
last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
if self.config.pooling == "average":
emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
elif self.config.pooling == "cls":
emb = last_hidden[:, 0]
if normalize:
emb = torch.nn.functional.normalize(emb, dim=-1)
return emb
def load_retriever(model_path, pooling="average", random_init=False):
# check whether a trained checkpoint exists locally
path = os.path.join(model_path, "checkpoint.pth")
if os.path.exists(path):
pretrained_dict = torch.load(path, map_location="cpu")
opt = pretrained_dict["opt"]
if hasattr(opt, "retriever_model_id"):
retriever_model_id = opt.retriever_model_id
else:
# retriever_model_id = "bert-base-uncased"
retriever_model_id = "bert-base-multilingual-cased"
tokenizer = utils.load_hf(transformers.AutoTokenizer, retriever_model_id)
cfg = utils.load_hf(transformers.AutoConfig, retriever_model_id)
if "xlm" in retriever_model_id:
model_class = XLMRetriever
else:
model_class = Contriever
retriever = model_class(cfg)
pretrained_dict = pretrained_dict["model"]
if any("encoder_q." in key for key in pretrained_dict.keys()): # test if model is defined with moco class
pretrained_dict = {k.replace("encoder_q.", ""): v for k, v in pretrained_dict.items() if "encoder_q." in k}
elif any("encoder." in key for key in pretrained_dict.keys()): # test if model is defined with inbatch class
pretrained_dict = {k.replace("encoder.", ""): v for k, v in pretrained_dict.items() if "encoder." in k}
retriever.load_state_dict(pretrained_dict, strict=False)
else:
retriever_model_id = model_path
if "xlm" in retriever_model_id:
model_class = XLMRetriever
else:
model_class = Contriever
cfg = utils.load_hf(transformers.AutoConfig, model_path)
tokenizer = utils.load_hf(transformers.AutoTokenizer, model_path)
retriever = utils.load_hf(model_class, model_path)
return retriever, tokenizer, retriever_model_id
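# A minimal usage sketch; the "facebook/contriever" checkpoint id and the two sentences are
# illustrative assumptions, any Contriever-style BERT checkpoint can be substituted.
if __name__ == "__main__":
    tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/contriever")
    model = Contriever.from_pretrained("facebook/contriever")
    sentences = ["Where was Marie Curie born?", "Maria Sklodowska was born in Warsaw."]
    inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        emb = model(**inputs, normalize=True)  # (2, hidden_size), mean-pooled, unit-normalized
    print((emb[0] @ emb[1]).item())  # dot product equals cosine similarity after normalization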

View File

@ -0,0 +1,243 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import glob
import torch
import random
import json
import csv
import numpy as np
import numpy.random
import logging
from collections import defaultdict
from robowaiter.algos.retrieval.retrieval_lm.src import dist_utils
logger = logging.getLogger(__name__)
def load_data(opt, tokenizer):
datasets = {}
for path in opt.train_data:
data = load_dataset(path, opt.loading_mode)
if data is not None:
datasets[path] = Dataset(data, opt.chunk_length, tokenizer, opt)
dataset = MultiDataset(datasets)
dataset.set_prob(coeff=opt.sampling_coefficient)
return dataset
def load_dataset(data_path, loading_mode):
files = glob.glob(os.path.join(data_path, "*.p*"))
files.sort()
tensors = []
if loading_mode == "split":
files_split = list(np.array_split(files, dist_utils.get_world_size()))[dist_utils.get_rank()]
for filepath in files_split:
try:
tensors.append(torch.load(filepath, map_location="cpu"))
except:
logger.warning(f"Unable to load file {filepath}")
elif loading_mode == "full":
for fin in files:
tensors.append(torch.load(fin, map_location="cpu"))
elif loading_mode == "single":
tensors.append(torch.load(files[0], map_location="cpu"))
if len(tensors) == 0:
return None
tensor = torch.cat(tensors)
return tensor
class MultiDataset(torch.utils.data.Dataset):
def __init__(self, datasets):
self.datasets = datasets
self.prob = [1 / len(self.datasets) for _ in self.datasets]
self.dataset_ids = list(self.datasets.keys())
def __len__(self):
return sum([len(dataset) for dataset in self.datasets.values()])
def __getitem__(self, index):
dataset_idx = numpy.random.choice(range(len(self.prob)), 1, p=self.prob)[0]
did = self.dataset_ids[dataset_idx]
index = random.randint(0, len(self.datasets[did]) - 1)
sample = self.datasets[did][index]
sample["dataset_id"] = did
return sample
def generate_offset(self):
for dataset in self.datasets.values():
dataset.generate_offset()
def set_prob(self, coeff=0.0):
prob = np.array([float(len(dataset)) for _, dataset in self.datasets.items()])
prob /= prob.sum()
prob = np.array([p**coeff for p in prob])
prob /= prob.sum()
self.prob = prob
class Dataset(torch.utils.data.Dataset):
"""Monolingual dataset based on a list of paths"""
def __init__(self, data, chunk_length, tokenizer, opt):
self.data = data
self.chunk_length = chunk_length
self.tokenizer = tokenizer
self.opt = opt
self.generate_offset()
def __len__(self):
return (self.data.size(0) - self.offset) // self.chunk_length
def __getitem__(self, index):
start_idx = self.offset + index * self.chunk_length
end_idx = start_idx + self.chunk_length
tokens = self.data[start_idx:end_idx]
q_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max)
k_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max)
q_tokens = apply_augmentation(q_tokens, self.opt)
q_tokens = add_bos_eos(q_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id)
k_tokens = apply_augmentation(k_tokens, self.opt)
k_tokens = add_bos_eos(k_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id)
return {"q_tokens": q_tokens, "k_tokens": k_tokens}
def generate_offset(self):
self.offset = random.randint(0, self.chunk_length - 1)
class Collator(object):
def __init__(self, opt):
self.opt = opt
def __call__(self, batch_examples):
batch = defaultdict(list)
for example in batch_examples:
for k, v in example.items():
batch[k].append(v)
q_tokens, q_mask = build_mask(batch["q_tokens"])
k_tokens, k_mask = build_mask(batch["k_tokens"])
batch["q_tokens"] = q_tokens
batch["q_mask"] = q_mask
batch["k_tokens"] = k_tokens
batch["k_mask"] = k_mask
return batch
def randomcrop(x, ratio_min, ratio_max):
ratio = random.uniform(ratio_min, ratio_max)
length = int(len(x) * ratio)
start = random.randint(0, len(x) - length)
end = start + length
crop = x[start:end].clone()
return crop
def build_mask(tensors):
shapes = [x.shape for x in tensors]
maxlength = max([len(x) for x in tensors])
returnmasks = []
ids = []
for k, x in enumerate(tensors):
returnmasks.append(torch.tensor([1] * len(x) + [0] * (maxlength - len(x))))
ids.append(torch.cat((x, torch.tensor([0] * (maxlength - len(x))))))
ids = torch.stack(ids, dim=0).long()
returnmasks = torch.stack(returnmasks, dim=0).bool()
return ids, returnmasks
def add_token(x, token):
x = torch.cat((torch.tensor([token]), x))
return x
def deleteword(x, p=0.1):
mask = np.random.rand(len(x))
x = [e for e, m in zip(x, mask) if m > p]
return x
def replaceword(x, min_random, max_random, p=0.1):
mask = np.random.rand(len(x))
x = [e if m > p else random.randint(min_random, max_random) for e, m in zip(x, mask)]
return x
def maskword(x, mask_id, p=0.1):
mask = np.random.rand(len(x))
x = [e if m > p else mask_id for e, m in zip(x, mask)]
return x
def shuffleword(x, p=0.1):
"""Shuffles a randomly chosen number of values in the list"""
count = (np.random.rand(len(x)) < p).sum()
indices_to_shuffle = random.sample(range(len(x)), k=count)
to_shuffle = [x[i] for i in indices_to_shuffle]
random.shuffle(to_shuffle)
for index, value in enumerate(to_shuffle):
old_index = indices_to_shuffle[index]
x[old_index] = value
return x
def apply_augmentation(x, opt):
if opt.augmentation == "mask":
return torch.tensor(maskword(x, mask_id=opt.mask_id, p=opt.prob_augmentation))
elif opt.augmentation == "replace":
return torch.tensor(
replaceword(x, min_random=opt.start_id, max_random=opt.vocab_size - 1, p=opt.prob_augmentation)
)
elif opt.augmentation == "delete":
return torch.tensor(deleteword(x, p=opt.prob_augmentation))
elif opt.augmentation == "shuffle":
return torch.tensor(shuffleword(x, p=opt.prob_augmentation))
else:
if not isinstance(x, torch.Tensor):
x = torch.Tensor(x)
return x
def add_bos_eos(x, bos_token_id, eos_token_id):
if not isinstance(x, torch.Tensor):
x = torch.Tensor(x)
if bos_token_id is None and eos_token_id is not None:
x = torch.cat([x.clone().detach(), torch.tensor([eos_token_id])])
elif bos_token_id is not None and eos_token_id is None:
x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach()])
elif bos_token_id is None and eos_token_id is None:
pass
else:
x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach(), torch.tensor([eos_token_id])])
return x
# Used for passage retrieval
def load_passages(path):
if not os.path.exists(path):
logger.info(f"{path} does not exist")
return
logger.info(f"Loading passages from: {path}")
passages = []
with open(path,encoding='UTF-8') as fin:
if path.endswith(".jsonl"):
for k, line in enumerate(fin):
ex = json.loads(line)
passages.append(ex)
else:
reader = csv.reader(fin, delimiter="\t")
for k, row in enumerate(reader):
if not row[0] == "id":
ex = {"id": row[0], "title": row[2], "text": row[1]}
passages.append(ex)
return passages
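# A tiny illustration of the padding helper used by Collator; the token ids below are
# arbitrary assumptions, any 1-D LongTensors of different lengths behave the same way.
if __name__ == "__main__":
    a = torch.tensor([101, 2023, 2003, 102])
    b = torch.tensor([101, 2460, 102])
    ids, mask = build_mask([a, b])
    print(ids)   # shape (2, 4): the shorter sequence is right-padded with 0
    print(mask)  # boolean mask, True for real tokens and False for padding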

View File

@ -0,0 +1,128 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import torch.distributed as dist
class Gather(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.tensor):
output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
dist.all_gather(output, x)
return tuple(output)
@staticmethod
def backward(ctx, *grads):
all_gradients = torch.stack(grads)
dist.all_reduce(all_gradients)
return all_gradients[dist.get_rank()]
def gather(x: torch.tensor):
if not dist.is_initialized():
return x
x_gather = Gather.apply(x)
x_gather = torch.cat(x_gather, dim=0)
return x_gather
@torch.no_grad()
def gather_nograd(x: torch.tensor):
if not dist.is_initialized():
return x
x_gather = [torch.ones_like(x) for _ in range(dist.get_world_size())]
dist.all_gather(x_gather, x, async_op=False)
x_gather = torch.cat(x_gather, dim=0)
return x_gather
@torch.no_grad()
def varsize_gather_nograd(x: torch.Tensor):
"""gather tensors of different sizes along the first dimension"""
if not dist.is_initialized():
return x
# determine max size
size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
dist.all_gather(allsizes, size)
max_size = max([size.cpu().max() for size in allsizes])
padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device)
padded[: x.shape[0]] = x
output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())]
dist.all_gather(output, padded)
output = [tensor[: allsizes[k]] for k, tensor in enumerate(output)]
output = torch.cat(output, dim=0)
return output
@torch.no_grad()
def get_varsize(x: torch.Tensor):
"""gather tensors of different sizes along the first dimension"""
if not dist.is_initialized():
return [x.shape[0]]
# determine max size
size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
dist.all_gather(allsizes, size)
allsizes = torch.cat(allsizes)
return allsizes
def get_rank():
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def is_main():
return get_rank() == 0
def get_world_size():
if not dist.is_initialized():
return 1
else:
return dist.get_world_size()
def barrier():
if dist.is_initialized():
dist.barrier()
def average_main(x):
if not dist.is_initialized():
return x
if dist.is_initialized() and dist.get_world_size() > 1:
dist.reduce(x, 0, op=dist.ReduceOp.SUM)
if is_main():
x = x / dist.get_world_size()
return x
def sum_main(x):
if not dist.is_initialized():
return x
if dist.is_initialized() and dist.get_world_size() > 1:
dist.reduce(x, 0, op=dist.ReduceOp.SUM)
return x
def weighted_average(x, count):
if not dist.is_initialized():
if isinstance(x, torch.Tensor):
x = x.item()
return x, count
t_loss = torch.tensor([x * count]).cuda()
t_total = torch.tensor([count]).cuda()
t_loss = sum_main(t_loss)
t_total = sum_main(t_total)
return (t_loss / t_total).item(), t_total.item()

View File

@ -0,0 +1,190 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import collections
import logging
import regex
import string
import unicodedata
from functools import partial
from multiprocessing import Pool as ProcessPool
from typing import Tuple, List, Dict
import numpy as np
"""
Evaluation code from DPR: https://github.com/facebookresearch/DPR
"""
class SimpleTokenizer(object):
ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
NON_WS = r'[^\p{Z}\p{C}]'
def __init__(self):
"""
Args:
annotators: None or empty set (only tokenizes).
"""
self._regexp = regex.compile(
'(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
)
def tokenize(self, text, uncased=False):
matches = [m for m in self._regexp.finditer(text)]
if uncased:
tokens = [m.group().lower() for m in matches]
else:
tokens = [m.group() for m in matches]
return tokens
logger = logging.getLogger(__name__)
QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits'])
def calculate_matches(data: List, workers_num: int):
"""
Evaluates answer presence in the retrieved documents. This function is meant for large collections of
documents and results: it forks multiple sub-processes for evaluation and then merges their results.
:param data: list of examples, each holding an 'answers' list and a 'ctxs' list of retrieved documents
:param workers_num: number of parallel processes used to check the documents
:return: matching information tuple.
top_k_hits - a list where the index is the number of top documents retrieved and the value is the total number of
valid matches across the entire dataset.
questions_doc_hits - more detailed info with answer matches for every question and every retrieved document
"""
logger.info('Matching answers in top docs...')
tokenizer = SimpleTokenizer()
get_score_partial = partial(check_answer, tokenizer=tokenizer)
processes = ProcessPool(processes=workers_num)
scores = processes.map(get_score_partial, data)
logger.info('Per question validation results len=%d', len(scores))
n_docs = len(data[0]['ctxs'])
top_k_hits = [0] * n_docs
for question_hits in scores:
best_hit = next((i for i, x in enumerate(question_hits) if x), None)
if best_hit is not None:
top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]
return QAMatchStats(top_k_hits, scores)
def check_answer(example, tokenizer) -> List[bool]:
"""Search through all the top docs to see if they have any of the answers."""
answers = example['answers']
ctxs = example['ctxs']
hits = []
for i, doc in enumerate(ctxs):
text = doc['text']
if text is None: # cannot find the document for some reason
logger.warning("no doc in db")
hits.append(False)
continue
hits.append(has_answer(answers, text, tokenizer))
return hits
def has_answer(answers, text, tokenizer) -> bool:
"""Check if a document contains an answer string."""
text = _normalize(text)
text = tokenizer.tokenize(text, uncased=True)
for answer in answers:
answer = _normalize(answer)
answer = tokenizer.tokenize(answer, uncased=True)
for i in range(0, len(text) - len(answer) + 1):
if answer == text[i: i + len(answer)]:
return True
return False
#################################################
######## READER EVALUATION ########
#################################################
def _normalize(text):
return unicodedata.normalize('NFD', text)
#Normalization and score functions from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
def normalize_answer(s):
def remove_articles(text):
return regex.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def em(prediction, ground_truth):
return normalize_answer(prediction) == normalize_answer(ground_truth)
def f1(prediction, ground_truth):
prediction_tokens = normalize_answer(prediction).split()
ground_truth_tokens = normalize_answer(ground_truth).split()
common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def f1_score(prediction, ground_truths):
return max([f1(prediction, gt) for gt in ground_truths])
def exact_match_score(prediction, ground_truths):
return max([em(prediction, gt) for gt in ground_truths])
####################################################
######## RETRIEVER EVALUATION ########
####################################################
def eval_batch(scores, inversions, avg_topk, idx_topk):
for k, s in enumerate(scores):
s = s.cpu().numpy()
sorted_idx = np.argsort(-s)
score(sorted_idx, inversions, avg_topk, idx_topk)
def count_inversions(arr):
inv_count = 0
lenarr = len(arr)
for i in range(lenarr):
for j in range(i + 1, lenarr):
if (arr[i] > arr[j]):
inv_count += 1
return inv_count
def score(x, inversions, avg_topk, idx_topk):
x = np.array(x)
inversions.append(count_inversions(x))
for k in avg_topk:
# ratio of passages in the predicted top-k that are
# also in the topk given by gold score
avg_pred_topk = (x[:k]<k).mean()
avg_topk[k].append(avg_pred_topk)
for k in idx_topk:
below_k = (x<k)
# number of passages required to obtain all passages from gold top-k
idx_gold_topk = len(x) - np.argmax(below_k[::-1])
idx_topk[k].append(idx_gold_topk)
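# A worked example of the answer-matching helpers above; the strings are assumptions chosen
# to show the effect of normalization on EM and F1.
if __name__ == "__main__":
    print(normalize_answer("The Eiffel Tower!"))        # -> "eiffel tower"
    print(em("The Eiffel Tower!", "eiffel tower"))      # -> True
    print(round(f1("Eiffel", "the Eiffel Tower"), 2))   # -> 0.67 (precision 1.0, recall 0.5)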

View File

@ -0,0 +1,171 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import random
import json
import sys
import numpy as np
from src import normalize_text
class Dataset(torch.utils.data.Dataset):
def __init__(
self,
datapaths,
negative_ctxs=1,
negative_hard_ratio=0.0,
negative_hard_min_idx=0,
training=False,
global_rank=-1,
world_size=-1,
maxload=None,
normalize=False,
):
self.negative_ctxs = negative_ctxs
self.negative_hard_ratio = negative_hard_ratio
self.negative_hard_min_idx = negative_hard_min_idx
self.training = training
self.normalize_fn = normalize_text.normalize if normalize else lambda x: x
self._load_data(datapaths, global_rank, world_size, maxload)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
example = self.data[index]
question = example["question"]
if self.training:
gold = random.choice(example["positive_ctxs"])
n_hard_negatives, n_random_negatives = self.sample_n_hard_negatives(example)
negatives = []
if n_random_negatives > 0:
random_negatives = random.sample(example["negative_ctxs"], n_random_negatives)
negatives += random_negatives
if n_hard_negatives > 0:
hard_negatives = random.sample(
example["hard_negative_ctxs"][self.negative_hard_min_idx :], n_hard_negatives
)
negatives += hard_negatives
else:
gold = example["positive_ctxs"][0]
nidx = 0
if "negative_ctxs" in example:
negatives = [example["negative_ctxs"][nidx]]
else:
negatives = []
gold = gold["title"] + " " + gold["text"] if "title" in gold and len(gold["title"]) > 0 else gold["text"]
negatives = [
n["title"] + " " + n["text"] if ("title" in n and len(n["title"]) > 0) else n["text"] for n in negatives
]
example = {
"query": self.normalize_fn(question),
"gold": self.normalize_fn(gold),
"negatives": [self.normalize_fn(n) for n in negatives],
}
return example
def _load_data(self, datapaths, global_rank, world_size, maxload):
counter = 0
self.data = []
for path in datapaths:
path = str(path)
if path.endswith(".jsonl"):
file_data, counter = self._load_data_jsonl(path, global_rank, world_size, counter, maxload)
elif path.endswith(".json"):
file_data, counter = self._load_data_json(path, global_rank, world_size, counter, maxload)
self.data.extend(file_data)
if maxload is not None and maxload > 0 and counter >= maxload:
break
def _load_data_json(self, path, global_rank, world_size, counter, maxload=None):
examples = []
with open(path, "r") as fin:
data = json.load(fin)
for example in data:
counter += 1
if global_rank > -1 and not counter % world_size == global_rank:
continue
examples.append(example)
if maxload is not None and maxload > 0 and counter == maxload:
break
return examples, counter
def _load_data_jsonl(self, path, global_rank, world_size, counter, maxload=None):
examples = []
with open(path, "r") as fin:
for line in fin:
counter += 1
if global_rank > -1 and not counter % world_size == global_rank:
continue
example = json.loads(line)
examples.append(example)
if maxload is not None and maxload > 0 and counter == maxload:
break
return examples, counter
def sample_n_hard_negatives(self, ex):
if "hard_negative_ctxs" in ex:
n_hard_negatives = sum([random.random() < self.negative_hard_ratio for _ in range(self.negative_ctxs)])
n_hard_negatives = min(n_hard_negatives, len(ex["hard_negative_ctxs"][self.negative_hard_min_idx :]))
else:
n_hard_negatives = 0
n_random_negatives = self.negative_ctxs - n_hard_negatives
if "negative_ctxs" in ex:
n_random_negatives = min(n_random_negatives, len(ex["negative_ctxs"]))
else:
n_random_negatives = 0
return n_hard_negatives, n_random_negatives
class Collator(object):
def __init__(self, tokenizer, passage_maxlength=200):
self.tokenizer = tokenizer
self.passage_maxlength = passage_maxlength
def __call__(self, batch):
queries = [ex["query"] for ex in batch]
golds = [ex["gold"] for ex in batch]
negs = [item for ex in batch for item in ex["negatives"]]
allpassages = golds + negs
qout = self.tokenizer.batch_encode_plus(
queries,
max_length=self.passage_maxlength,
truncation=True,
padding=True,
add_special_tokens=True,
return_tensors="pt",
)
kout = self.tokenizer.batch_encode_plus(
allpassages,
max_length=self.passage_maxlength,
truncation=True,
padding=True,
add_special_tokens=True,
return_tensors="pt",
)
q_tokens, q_mask = qout["input_ids"], qout["attention_mask"].bool()
k_tokens, k_mask = kout["input_ids"], kout["attention_mask"].bool()
g_tokens, g_mask = k_tokens[: len(golds)], k_mask[: len(golds)]
n_tokens, n_mask = k_tokens[len(golds) :], k_mask[len(golds) :]
batch = {
"q_tokens": q_tokens,
"q_mask": q_mask,
"k_tokens": k_tokens,
"k_mask": k_mask,
"g_tokens": g_tokens,
"g_mask": g_mask,
"n_tokens": n_tokens,
"n_mask": n_mask,
}
return batch

View File

@ -0,0 +1,90 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import torch.nn as nn
import numpy as np
import math
import random
import transformers
import logging
import torch.distributed as dist
from src import contriever, dist_utils, utils
logger = logging.getLogger(__name__)
class InBatch(nn.Module):
def __init__(self, opt, retriever=None, tokenizer=None):
super(InBatch, self).__init__()
self.opt = opt
self.norm_doc = opt.norm_doc
self.norm_query = opt.norm_query
self.label_smoothing = opt.label_smoothing
if retriever is None or tokenizer is None:
retriever, tokenizer = self._load_retriever(
opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init
)
self.tokenizer = tokenizer
self.encoder = retriever
def _load_retriever(self, model_id, pooling, random_init):
cfg = utils.load_hf(transformers.AutoConfig, model_id)
tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id)
if "xlm" in model_id:
model_class = contriever.XLMRetriever
else:
model_class = contriever.Contriever
if random_init:
retriever = model_class(cfg)
else:
retriever = utils.load_hf(model_class, model_id)
if "bert-" in model_id:
if tokenizer.bos_token_id is None:
tokenizer.bos_token = "[CLS]"
if tokenizer.eos_token_id is None:
tokenizer.eos_token = "[SEP]"
retriever.config.pooling = pooling
return retriever, tokenizer
def get_encoder(self):
return self.encoder
def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs):
bsz = len(q_tokens)
labels = torch.arange(0, bsz, dtype=torch.long, device=q_tokens.device)
qemb = self.encoder(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query)
kemb = self.encoder(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc)
gather_fn = dist_utils.gather
gather_kemb = gather_fn(kemb)
labels = labels + dist_utils.get_rank() * len(kemb)
scores = torch.einsum("id, jd->ij", qemb / self.opt.temperature, gather_kemb)
loss = torch.nn.functional.cross_entropy(scores, labels, label_smoothing=self.label_smoothing)
# log stats
if len(stats_prefix) > 0:
stats_prefix = stats_prefix + "/"
iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz)
predicted_idx = torch.argmax(scores, dim=-1)
accuracy = 100 * (predicted_idx == labels).float().mean()
stdq = torch.std(qemb, dim=0).mean().item()
stdk = torch.std(kemb, dim=0).mean().item()
iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz)
iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz)
iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz)
return loss, iter_stats
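# A toy illustration of the in-batch objective computed in forward(): every query is scored
# against every key in the (gathered) batch and the diagonal entries are the positives.
# The embedding size, batch size and temperature below are assumptions.
if __name__ == "__main__":
    qemb = torch.nn.functional.normalize(torch.randn(4, 8), dim=-1)
    kemb = torch.nn.functional.normalize(torch.randn(4, 8), dim=-1)
    scores = torch.einsum("id, jd->ij", qemb / 0.05, kemb)  # (4, 4) score matrix
    labels = torch.arange(4)
    loss = torch.nn.functional.cross_entropy(scores, labels)
    print(scores.shape, loss.item())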

View File

@ -0,0 +1,73 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import pickle
from typing import List, Tuple
import faiss
import numpy as np
from tqdm import tqdm
class Indexer(object):
def __init__(self, vector_sz, n_subquantizers=0, n_bits=8):
if n_subquantizers > 0:
self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
else:
self.index = faiss.IndexFlatIP(vector_sz)
#self.index_id_to_db_id = np.empty((0), dtype=np.int64)
self.index_id_to_db_id = []
def index_data(self, ids, embeddings):
self._update_id_mapping(ids)
embeddings = embeddings.astype('float32')
if not self.index.is_trained:
self.index.train(embeddings)
self.index.add(embeddings)
print(f'Total data indexed {len(self.index_id_to_db_id)}')
def search_knn(self, query_vectors: np.array, top_docs: int, index_batch_size: int = 2048) -> List[Tuple[List[object], List[float]]]:
query_vectors = query_vectors.astype('float32')
result = []
nbatch = (len(query_vectors)-1) // index_batch_size + 1
for k in tqdm(range(nbatch)):
start_idx = k*index_batch_size
end_idx = min((k+1)*index_batch_size, len(query_vectors))
q = query_vectors[start_idx: end_idx]
scores, indexes = self.index.search(q, top_docs)
# convert to external ids
db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes]
result.extend([(db_ids[i], scores[i]) for i in range(len(db_ids))])
return result
def serialize(self, dir_path):
index_file = os.path.join(dir_path, 'index.faiss')
meta_file = os.path.join(dir_path, 'index_meta.faiss')
print(f'Serializing index to {index_file}, meta data to {meta_file}')
faiss.write_index(self.index, index_file)
with open(meta_file, mode='wb') as f:
pickle.dump(self.index_id_to_db_id, f)
def deserialize_from(self, dir_path):
index_file = os.path.join(dir_path, 'index.faiss')
meta_file = os.path.join(dir_path, 'index_meta.faiss')
print(f'Loading index from {index_file}, meta data from {meta_file}')
self.index = faiss.read_index(index_file)
print('Loaded index of type %s and size %d' % (type(self.index), self.index.ntotal))
with open(meta_file, "rb") as reader:
self.index_id_to_db_id = pickle.load(reader)
assert len(
self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'
def _update_id_mapping(self, db_ids: List):
#new_ids = np.array(db_ids, dtype=np.int64)
#self.index_id_to_db_id = np.concatenate((self.index_id_to_db_id, new_ids), axis=0)
self.index_id_to_db_id.extend(db_ids)
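# A minimal usage sketch with random data; the vector size, ids and number of passages are
# assumptions, real embeddings from the retriever are used the same way.
if __name__ == "__main__":
    dim = 8
    indexer = Indexer(dim)
    vectors = np.random.rand(5, dim).astype('float32')
    indexer.index_data(ids=[f"doc{i}" for i in range(5)], embeddings=vectors)
    top_ids, scores = indexer.search_knn(vectors[:1], top_docs=3)[0]
    print(top_ids, scores)  # the query vector itself should come back as the best match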

View File

@ -0,0 +1,140 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import torch.nn as nn
import logging
import copy
import transformers
from src import contriever, dist_utils, utils
logger = logging.getLogger(__name__)
class MoCo(nn.Module):
def __init__(self, opt):
super(MoCo, self).__init__()
self.queue_size = opt.queue_size
self.momentum = opt.momentum
self.temperature = opt.temperature
self.label_smoothing = opt.label_smoothing
self.norm_doc = opt.norm_doc
self.norm_query = opt.norm_query
self.moco_train_mode_encoder_k = opt.moco_train_mode_encoder_k # apply the encoder on keys in train mode
retriever, tokenizer = self._load_retriever(
opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init
)
self.tokenizer = tokenizer
self.encoder_q = retriever
self.encoder_k = copy.deepcopy(retriever)
for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
param_k.data.copy_(param_q.data)
param_k.requires_grad = False
# create the queue
self.register_buffer("queue", torch.randn(opt.projection_size, self.queue_size))
self.queue = nn.functional.normalize(self.queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
def _load_retriever(self, model_id, pooling, random_init):
cfg = utils.load_hf(transformers.AutoConfig, model_id)
tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id)
if "xlm" in model_id:
model_class = contriever.XLMRetriever
else:
model_class = contriever.Contriever
if random_init:
retriever = model_class(cfg)
else:
retriever = utils.load_hf(model_class, model_id)
if "bert-" in model_id:
if tokenizer.bos_token_id is None:
tokenizer.bos_token = "[CLS]"
if tokenizer.eos_token_id is None:
tokenizer.eos_token = "[SEP]"
retriever.config.pooling = pooling
return retriever, tokenizer
def get_encoder(self, return_encoder_k=False):
if return_encoder_k:
return self.encoder_k
else:
return self.encoder_q
def _momentum_update_key_encoder(self):
"""
Update of the key encoder
"""
for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
param_k.data = param_k.data * self.momentum + param_q.data * (1.0 - self.momentum)
@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
# gather keys before updating queue
keys = dist_utils.gather_nograd(keys.contiguous())
batch_size = keys.shape[0]
ptr = int(self.queue_ptr)
assert self.queue_size % batch_size == 0, f"{batch_size}, {self.queue_size}" # for simplicity
# replace the keys at ptr (dequeue and enqueue)
self.queue[:, ptr : ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.queue_size # move pointer
self.queue_ptr[0] = ptr
def _compute_logits(self, q, k):
l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])
logits = torch.cat([l_pos, l_neg], dim=1)
return logits
def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs):
bsz = q_tokens.size(0)
q = self.encoder_q(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query)
# compute key features
with torch.no_grad(): # no gradient to keys
self._momentum_update_key_encoder() # update the key encoder
if not self.encoder_k.training and not self.moco_train_mode_encoder_k:
self.encoder_k.eval()
k = self.encoder_k(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc)
logits = self._compute_logits(q, k) / self.temperature
# labels: positive key indicators
labels = torch.zeros(bsz, dtype=torch.long).cuda()
loss = torch.nn.functional.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
self._dequeue_and_enqueue(k)
# log stats
if len(stats_prefix) > 0:
stats_prefix = stats_prefix + "/"
iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz)
predicted_idx = torch.argmax(logits, dim=-1)
accuracy = 100 * (predicted_idx == labels).float().mean()
stdq = torch.std(q, dim=0).mean().item()
stdk = torch.std(k, dim=0).mean().item()
iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz)
iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz)
iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz)
return loss, iter_stats

View File

@ -0,0 +1,162 @@
"""
adapted from chemdataextractor.text.normalize
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tools for normalizing text.
https://github.com/mcs07/ChemDataExtractor
:copyright: Copyright 2016 by Matt Swain.
:license: MIT
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
'Software'), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
#: Control characters.
CONTROLS = {
'\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u000e', '\u000f', '\u0011',
'\u0012', '\u0013', '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b',
}
# There are further control characters, but they are instead replaced with a space by unicode normalization
# '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c', '\u001d', '\u001e', '\u001f'
#: Hyphen and dash characters.
HYPHENS = {
'-', # \u002d Hyphen-minus
'\u2010',  # Hyphen
'\u2011',  # Non-breaking hyphen
'\u2043',  # Hyphen bullet
'\u2012',  # figure dash
'\u2013',  # en dash
'\u2014',  # em dash
'\u2015',  # horizontal bar
}
#: Minus characters.
MINUSES = {
'-', # \u002d Hyphen-minus
'\u2212',  # Minus
'\uff0d',  # Full-width Hyphen-minus
'\u207b',  # Superscript minus
}
#: Plus characters.
PLUSES = {
'+', # \u002b Plus
'\uff0b',  # Full-width Plus
'\u207a',  # Superscript plus
}
#: Slash characters.
SLASHES = {
'/', # \u002f Solidus
'\u2044',  # Fraction slash
'\u2215',  # Division slash
}
#: Tilde characters.
TILDES = {
'~', # \u007e Tilde
'˜', # \u02dc Small tilde
'\u2053',  # Swung dash
'\u223c',  # Tilde operator #in mbert vocab
'\u223d',  # Reversed tilde
'\u223f',  # Sine wave
'\u301c',  # Wave dash #in mbert vocab
'\uff5e',  # Full-width tilde #in mbert vocab
}
#: Apostrophe characters.
APOSTROPHES = {
"'", # \u0027
'\u2019',
'՚',  # \u055a
'\ua78b',
'\ua78c',
'\uff07',
}
#: Single quote characters.
SINGLE_QUOTES = {
"'", # \u0027
'\u2018',
'\u2019',
'\u201a',
'\u201b',
}
#: Double quote characters.
DOUBLE_QUOTES = {
'"', # \u0022
'\u201c',
'\u201d',
'\u201e',
'\u201f',
}
#: Accent characters.
ACCENTS = {
'`', # \u0060
'´', # \u00b4
}
#: Prime characters.
PRIMES = {
'\u2032',
'\u2033',
'\u2034',
'\u2035',
'\u2036',
'\u2037',
'\u2057',
}
#: Quote characters, including apostrophes, single quotes, double quotes, accents and primes.
QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES
def normalize(text):
for control in CONTROLS:
text = text.replace(control, '')
text = text.replace('\u000b', ' ').replace('\u000c', ' ').replace(u'\u0085', ' ')
for hyphen in HYPHENS | MINUSES:
text = text.replace(hyphen, '-')
text = text.replace('\u00ad', '')
for double_quote in DOUBLE_QUOTES:
text = text.replace(double_quote, '"') # \u0022
for single_quote in (SINGLE_QUOTES | APOSTROPHES | ACCENTS):
text = text.replace(single_quote, "'") # \u0027
text = text.replace('\u2032', "'")    # prime
text = text.replace('\u2035', "'")    # reversed prime
text = text.replace('\u2033', "''")   # double prime
text = text.replace('\u2036', "''")   # reversed double prime
text = text.replace('\u2034', "'''")  # triple prime
text = text.replace('\u2037', "'''")  # reversed triple prime
text = text.replace('\u2057', "''''")  # quadruple prime
text = text.replace('\u2026', '...').replace(' . . . ', ' ... ')  # horizontal ellipsis
for slash in SLASHES:
text = text.replace(slash, '/')
#for tilde in TILDES:
# text = text.replace(tilde, '~')
return text
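# A short illustrative call (the input string is an assumption): curly quotes, an en dash
# and a prime are all mapped back to plain ASCII.
if __name__ == "__main__":
    print(normalize("\u201cnaive\u201d \u2013 45\u2032 long"))  # -> "naive" - 45' long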

View File

@ -0,0 +1,132 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import argparse
import os
class Options:
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.initialize()
def initialize(self):
# basic parameters
self.parser.add_argument(
"--output_dir", type=str, default="./checkpoint/my_experiments", help="models are saved here"
)
self.parser.add_argument(
"--train_data",
nargs="+",
default=[],
help="Data used for training, passed as a list of directories split into tensor files.",
)
self.parser.add_argument(
"--eval_data",
nargs="+",
default=[],
help="Data used for evaluation during finetuning, this option is not used during contrastive pre-training.",
)
self.parser.add_argument(
"--eval_datasets", nargs="+", default=[], help="List of datasets used for evaluation, in BEIR format"
)
self.parser.add_argument(
"--eval_datasets_dir", type=str, default="./", help="Directory where eval datasets are stored"
)
self.parser.add_argument("--model_path", type=str, default="none", help="path for retraining")
self.parser.add_argument("--continue_training", action="store_true")
self.parser.add_argument("--num_workers", type=int, default=5)
self.parser.add_argument("--chunk_length", type=int, default=256)
self.parser.add_argument("--loading_mode", type=str, default="split")
self.parser.add_argument("--lower_case", action="store_true", help="perform evaluation after lowercasing")
self.parser.add_argument(
"--sampling_coefficient",
type=float,
default=0.0,
help="coefficient used for sampling between different datasets during training, \
by default sampling is uniform over datasets",
)
self.parser.add_argument("--augmentation", type=str, default="none")
self.parser.add_argument("--prob_augmentation", type=float, default=0.0)
self.parser.add_argument("--dropout", type=float, default=0.1)
self.parser.add_argument("--rho", type=float, default=0.05)
self.parser.add_argument("--contrastive_mode", type=str, default="moco")
self.parser.add_argument("--queue_size", type=int, default=65536)
self.parser.add_argument("--temperature", type=float, default=1.0)
self.parser.add_argument("--momentum", type=float, default=0.999)
self.parser.add_argument("--moco_train_mode_encoder_k", action="store_true")
self.parser.add_argument("--eval_normalize_text", action="store_true")
self.parser.add_argument("--norm_query", action="store_true")
self.parser.add_argument("--norm_doc", action="store_true")
self.parser.add_argument("--projection_size", type=int, default=768)
self.parser.add_argument("--ratio_min", type=float, default=0.1)
self.parser.add_argument("--ratio_max", type=float, default=0.5)
self.parser.add_argument("--score_function", type=str, default="dot")
self.parser.add_argument("--retriever_model_id", type=str, default="bert-base-uncased")
self.parser.add_argument("--pooling", type=str, default="average")
self.parser.add_argument("--random_init", action="store_true", help="init model with random weights")
# dataset parameters
self.parser.add_argument("--per_gpu_batch_size", default=64, type=int, help="Batch size per GPU for training.")
self.parser.add_argument(
"--per_gpu_eval_batch_size", default=256, type=int, help="Batch size per GPU for evaluation."
)
self.parser.add_argument("--total_steps", type=int, default=1000)
self.parser.add_argument("--warmup_steps", type=int, default=-1)
self.parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
self.parser.add_argument("--main_port", type=int, default=10001, help="Master port (for multi-node SLURM jobs)")
self.parser.add_argument("--seed", type=int, default=0, help="random seed for initialization")
# training parameters
self.parser.add_argument("--optim", type=str, default="adamw")
self.parser.add_argument("--scheduler", type=str, default="linear")
self.parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
self.parser.add_argument(
"--lr_min_ratio",
type=float,
default=0.0,
help="minimum learning rate at the end of the optimization schedule as a ratio of the learning rate",
)
self.parser.add_argument("--weight_decay", type=float, default=0.01, help="weight decay")
self.parser.add_argument("--beta1", type=float, default=0.9, help="beta1")
self.parser.add_argument("--beta2", type=float, default=0.98, help="beta2")
self.parser.add_argument("--eps", type=float, default=1e-6, help="eps")
self.parser.add_argument(
"--log_freq", type=int, default=100, help="log train stats every <log_freq> steps during training"
)
self.parser.add_argument(
"--eval_freq", type=int, default=500, help="evaluate model every <eval_freq> steps during training"
)
self.parser.add_argument("--save_freq", type=int, default=50000)
self.parser.add_argument("--maxload", type=int, default=None)
self.parser.add_argument("--label_smoothing", type=float, default=0.0)
# finetuning options
self.parser.add_argument("--negative_ctxs", type=int, default=1)
self.parser.add_argument("--negative_hard_min_idx", type=int, default=0)
self.parser.add_argument("--negative_hard_ratio", type=float, default=0.0)
def print_options(self, opt):
message = ""
for k, v in sorted(vars(opt).items()):
comment = ""
default = self.parser.get_default(k)
if v != default:
comment = f"\t[default: %s]" % str(default)
message += f"{str(k):>40}: {str(v):<40}{comment}\n"
print(message, flush=True)
model_dir = os.path.join(opt.output_dir, "models")
if not os.path.exists(model_dir):
os.makedirs(os.path.join(opt.output_dir, "models"))
file_name = os.path.join(opt.output_dir, "opt.txt")
with open(file_name, "wt") as opt_file:
opt_file.write(message)
opt_file.write("\n")
def parse(self):
opt, _ = self.parser.parse_known_args()
# opt = self.parser.parse_args()
return opt

View File

@ -0,0 +1,114 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from logging import getLogger
import os
import sys
import torch
import socket
import signal
import subprocess
logger = getLogger()
def sig_handler(signum, frame):
logger.warning("Signal handler called with signal " + str(signum))
prod_id = int(os.environ['SLURM_PROCID'])
logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id))
if prod_id == 0:
logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID'])
os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID'])
else:
logger.warning("Not the main process, no need to requeue.")
sys.exit(-1)
def term_handler(signum, frame):
logger.warning("Signal handler called with signal " + str(signum))
logger.warning("Bypassing SIGTERM.")
def init_signal_handler():
"""
Handle signals sent by SLURM for time limit / pre-emption.
"""
signal.signal(signal.SIGUSR1, sig_handler)
signal.signal(signal.SIGTERM, term_handler)
def init_distributed_mode(params):
"""
Handle single and multi-GPU / multi-node / SLURM jobs.
Initialize the following variables:
- local_rank
- global_rank
- world_size
"""
is_slurm_job = 'SLURM_JOB_ID' in os.environ and not 'WORLD_SIZE' in os.environ
has_local_rank = hasattr(params, 'local_rank')
# SLURM job without torch.distributed.launch
if is_slurm_job and has_local_rank:
assert params.local_rank == -1 # on the cluster, this is handled by SLURM
# local rank on the current node / global rank
params.local_rank = int(os.environ['SLURM_LOCALID'])
params.global_rank = int(os.environ['SLURM_PROCID'])
params.world_size = int(os.environ['SLURM_NTASKS'])
# define master address and master port
hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']])
params.main_addr = hostnames.split()[0].decode('utf-8')
assert 10001 <= params.main_port <= 20000 or params.world_size == 1
# set environment variables for 'env://'
os.environ['MASTER_ADDR'] = params.main_addr
os.environ['MASTER_PORT'] = str(params.main_port)
os.environ['WORLD_SIZE'] = str(params.world_size)
os.environ['RANK'] = str(params.global_rank)
is_distributed = True
# multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch
elif has_local_rank and params.local_rank != -1:
assert params.main_port == -1
# read environment variables
params.global_rank = int(os.environ['RANK'])
params.world_size = int(os.environ['WORLD_SIZE'])
is_distributed = True
# local job (single GPU)
else:
params.local_rank = 0
params.global_rank = 0
params.world_size = 1
is_distributed = False
# set GPU device
torch.cuda.set_device(params.local_rank)
# initialize multi-GPU
if is_distributed:
# http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization
# 'env://' will read these environment variables:
# MASTER_PORT - required; has to be a free port on machine with rank 0
# MASTER_ADDR - required (except for rank 0); address of rank 0 node
# WORLD_SIZE - required; can be set either here, or in a call to init function
# RANK - required; can be set either here, or in a call to init function
#print("Initializing PyTorch distributed ...")
torch.distributed.init_process_group(
init_method='env://',
backend='nccl',
#world_size=params.world_size,
#rank=params.global_rank,
)

View File

@ -0,0 +1,213 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import sys
import logging
import math
import torch
import errno
from typing import Union, Tuple, Dict
from collections import defaultdict
from robowaiter.algos.retrieval.retrieval_lm.src import dist_utils
Number = Union[float, int]
logger = logging.getLogger(__name__)
def init_logger(args, stdout_only=False):
if torch.distributed.is_initialized():
torch.distributed.barrier()
stdout_handler = logging.StreamHandler(sys.stdout)
handlers = [stdout_handler]
if not stdout_only:
file_handler = logging.FileHandler(filename=os.path.join(args.output_dir, "run.log"))
handlers.append(file_handler)
logging.basicConfig(
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if dist_utils.is_main() else logging.WARN,
format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s",
handlers=handlers,
)
return logger
def symlink_force(target, link_name):
try:
os.symlink(target, link_name)
except OSError as e:
if e.errno == errno.EEXIST:
os.remove(link_name)
os.symlink(target, link_name)
else:
raise e
def save(model, optimizer, scheduler, step, opt, dir_path, name):
model_to_save = model.module if hasattr(model, "module") else model
path = os.path.join(dir_path, "checkpoint")
epoch_path = os.path.join(path, name) # "step-%s" % step)
os.makedirs(epoch_path, exist_ok=True)
cp = os.path.join(path, "latest")
fp = os.path.join(epoch_path, "checkpoint.pth")
checkpoint = {
"step": step,
"model": model_to_save.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"opt": opt,
}
torch.save(checkpoint, fp)
symlink_force(epoch_path, cp)
if not name == "lastlog":
logger.info(f"Saving model to {epoch_path}")
def load(model_class, dir_path, opt, reset_params=False):
epoch_path = os.path.realpath(dir_path)
checkpoint_path = os.path.join(epoch_path, "checkpoint.pth")
logger.info(f"loading checkpoint {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
opt_checkpoint = checkpoint["opt"]
state_dict = checkpoint["model"]
model = model_class(opt_checkpoint)
model.load_state_dict(state_dict, strict=True)
model = model.cuda()
step = checkpoint["step"]
if not reset_params:
optimizer, scheduler = set_optim(opt_checkpoint, model)
scheduler.load_state_dict(checkpoint["scheduler"])
optimizer.load_state_dict(checkpoint["optimizer"])
else:
optimizer, scheduler = set_optim(opt, model)
return model, optimizer, scheduler, opt_checkpoint, step
############ OPTIM
class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR):
def __init__(self, optimizer, warmup, total, ratio, last_epoch=-1):
self.warmup = warmup
self.total = total
self.ratio = ratio
super(WarmupLinearScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
def lr_lambda(self, step):
if step < self.warmup:
return (1 - self.ratio) * step / float(max(1, self.warmup))
return max(
0.0,
1.0 + (self.ratio - 1) * (step - self.warmup) / float(max(1.0, self.total - self.warmup)),
)
class CosineScheduler(torch.optim.lr_scheduler.LambdaLR):
def __init__(self, optimizer, warmup, total, ratio=0.1, last_epoch=-1):
self.warmup = warmup
self.total = total
self.ratio = ratio
super(CosineScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
def lr_lambda(self, step):
if step < self.warmup:
return float(step) / self.warmup
s = float(step - self.warmup) / (self.total - self.warmup)
return self.ratio + (1.0 - self.ratio) * math.cos(0.5 * math.pi * s)
def set_optim(opt, model):
if opt.optim == "adamw":
optimizer = torch.optim.AdamW(
model.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), eps=opt.eps, weight_decay=opt.weight_decay
)
else:
raise NotImplementedError("optimizer class not implemented")
scheduler_args = {
"warmup": opt.warmup_steps,
"total": opt.total_steps,
"ratio": opt.lr_min_ratio,
}
if opt.scheduler == "linear":
scheduler_class = WarmupLinearScheduler
elif opt.scheduler == "cosine":
scheduler_class = CosineScheduler
else:
raise ValueError
scheduler = scheduler_class(optimizer, **scheduler_args)
return optimizer, scheduler
def get_parameters(net, verbose=False):
num_params = 0
for param in net.parameters():
num_params += param.numel()
message = "[Network] Total number of parameters : %.6f M" % (num_params / 1e6)
return message
class WeightedAvgStats:
"""provides an average over a bunch of stats"""
def __init__(self):
self.raw_stats: Dict[str, float] = defaultdict(float)
self.total_weights: Dict[str, float] = defaultdict(float)
def update(self, vals: Dict[str, Tuple[Number, Number]]) -> None:
for key, (value, weight) in vals.items():
self.raw_stats[key] += value * weight
self.total_weights[key] += weight
@property
def stats(self) -> Dict[str, float]:
return {x: self.raw_stats[x] / self.total_weights[x] for x in self.raw_stats.keys()}
@property
def tuple_stats(self) -> Dict[str, Tuple[float, float]]:
return {x: (self.raw_stats[x] / self.total_weights[x], self.total_weights[x]) for x in self.raw_stats.keys()}
def reset(self) -> None:
self.raw_stats = defaultdict(float)
self.total_weights = defaultdict(float)
@property
def average_stats(self) -> Dict[str, float]:
keys = sorted(self.raw_stats.keys())
if torch.distributed.is_initialized():
torch.distributed.broadcast_object_list(keys, src=0)
global_dict = {}
for k in keys:
if not k in self.total_weights:
v = 0.0
else:
v = self.raw_stats[k] / self.total_weights[k]
v, _ = dist_utils.weighted_average(v, self.total_weights[k])
global_dict[k] = v
return global_dict
def load_hf(object_class, model_name):
try:
obj = object_class.from_pretrained(model_name, local_files_only=True)
except:
obj = object_class.from_pretrained(model_name, local_files_only=False)
return obj
def init_tb_logger(output_dir):
try:
from torch.utils import tensorboard
if dist_utils.is_main():
tb_logger = tensorboard.SummaryWriter(output_dir)
else:
tb_logger = None
except:
logger.warning("Tensorboard is not available.")
tb_logger = None
return tb_logger
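# A small sketch of the warmup-then-linear-decay schedule defined above; the dummy parameter,
# learning rate and step counts are assumptions.
if __name__ == "__main__":
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.AdamW([param], lr=1e-4)
    scheduler = WarmupLinearScheduler(optimizer, warmup=10, total=100, ratio=0.0)
    lrs = []
    for _ in range(100):
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])
    print(max(lrs), lrs[-1])  # peaks at 1e-4 right after warmup, then decays linearly to 0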

View File

@ -0,0 +1,23 @@
{
"bf16": {
"enabled": "auto"
},
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 1e5,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@ -0,0 +1,194 @@
import jsonlines
import json
import copy
import re
PROMPT_DICT = {
"prompt_input": (
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
),
"prompt_no_input": (
"### Instruction:\n{instruction}\n\n### Response:\n"
),
}
TASK_INST = {"wow": "Given a chat history separated by new lines, generates an informative, knowledgeable and engaging response. ",
"fever": "Is the following statement correct or not? Say true if it's correct; otherwise say false.",
"eli5": "Provide a paragraph-length response using simple words to answer the following question.",
"obqa": "Given four answer candidates, A, B, C and D, choose the best answer choice.",
"arc_easy": "Given four answer candidates, A, B, C and D, choose the best answer choice.",
"arc_c": "Given four answer candidates, A, B, C and D, choose the best answer choice.",
"trex": "Given the input format 'Subject Entity [SEP] Relationship Type,' predict the target entity.",
"asqa": "Answer the following question. The question may be ambiguous and have multiple correct answers, and in that case, you have to provide a long-form answer including all correct answers."}
rel_tokens_names = ["[Irrelevant]", "[Relevant]"]
retrieval_tokens_names = ["[No Retrieval]",
"[Retrieval]", "[Continue to Use Evidence]"]
utility_tokens_names = ["[Utility:1]", "[Utility:2]",
"[Utility:3]", "[Utility:4]", "[Utility:5]"]
ground_tokens_names = ["[Fully supported]",
"[Partially supported]", "[No support / Contradictory]"]
other_special_tokens = ["<s>", "</s>", "[PAD]",
"<unk>", "<paragraph>", "</paragraph>"]
control_tokens = ["[Fully supported]", "[Partially supported]", "[No support / Contradictory]", "[No Retrieval]", "[Retrieval]",
"[Irrelevant]", "[Relevant]", "<paragraph>", "</paragraph>", "[Utility:1]", "[Utility:2]", "[Utility:3]", "[Utility:4]", "[Utility:5]"]
def load_special_tokens(tokenizer, use_grounding=False, use_utility=False):
ret_tokens = {token: tokenizer.convert_tokens_to_ids(
token) for token in retrieval_tokens_names}
rel_tokens = {}
for token in ["[Irrelevant]", "[Relevant]"]:
rel_tokens[token] = tokenizer.convert_tokens_to_ids(token)
grd_tokens = None
if use_grounding is True:
grd_tokens = {}
for token in ground_tokens_names:
grd_tokens[token] = tokenizer.convert_tokens_to_ids(token)
ut_tokens = None
if use_utility is True:
ut_tokens = {}
for token in utility_tokens_names:
ut_tokens[token] = tokenizer.convert_tokens_to_ids(token)
return ret_tokens, rel_tokens, grd_tokens, ut_tokens
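# Illustrative usage (assumes the tokenizer has already been extended with the
# special tokens listed above):
#   ret_tokens, rel_tokens, grd_tokens, ut_tokens = load_special_tokens(
#       tokenizer, use_grounding=True, use_utility=True)
# Each returned dict maps a token string such as "[Retrieval]" to its vocabulary id.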
def fix_spacing(input_text):
# Add a space after periods that lack whitespace
output_text = re.sub(r'(?<=\w)([.!?])(?=\w)', r'\1 ', input_text)
return output_text
def postprocess(pred):
special_tokens = control_tokens  # identical to the module-level control_tokens list above
for item in special_tokens:
pred = pred.replace(item, "")
pred = pred.replace("</s>", "")
if len(pred) == 0:
return ""
if pred[0] == " ":
pred = pred[1:]
return pred
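# Example (illustrative, not from the original file): postprocess strips the
# reflection/control tokens emitted around an answer, e.g.
#   postprocess("[Relevant]Is(AC,1)</s>")  ->  "Is(AC,1)"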
def load_jsonlines(file):
with jsonlines.open(file, 'r') as jsonl_f:
lst = [obj for obj in jsonl_f]
return lst
def load_file(input_fp):
if input_fp.endswith(".json"):
input_data = json.load(open(input_fp))
else:
input_data = load_jsonlines(input_fp)
return input_data
def save_file_jsonl(data, fp):
with jsonlines.open(fp, mode='w') as writer:
writer.write_all(data)
def preprocess_input(input_data, task):
if task == "factscore":
for item in input_data:
item["instruction"] = item["input"]
item["output"] = [item["output"]
] if "output" in item else [item["topic"]]
return input_data
elif task == "qa":
for item in input_data:
if "instruction" not in item:
item["instruction"] = item["question"]
if "answers" not in item and "output" in item:
item["answers"] = "output"
return input_data
elif task in ["asqa", "eli5"]:
processed_input_data = []
for instance_idx, item in enumerate(input_data["data"]):
prompt = item["question"]
instructions = TASK_INST[task]
prompt = instructions + "## Input:\n\n" + prompt
entry = copy.deepcopy(item)
entry["instruction"] = prompt
processed_input_data.append(entry)
return processed_input_data
def postprocess_output(input_instance, prediction, task, intermediate_results=None):
if task == "factscore":
return {"input": input_instance["input"], "output": prediction, "topic": input_instance["topic"], "cat": input_instance["cat"]}
elif task == "qa":
input_instance["pred"] = prediction
return input_instance
elif task in ["asqa", "eli5"]:
# ALCE datasets require additional postprocessing to compute citation accuracy.
final_output = ""
docs = []
if "splitted_sentences" not in intermediate_results:
input_instance["output"] = postprocess(prediction)
else:
for idx, (sent, doc) in enumerate(zip(intermediate_results["splitted_sentences"][0], intermediate_results["ctxs"][0])):
if len(sent) == 0:
continue
postprocessed_result = postprocess(sent)
final_output += postprocessed_result[:-1] + " [{}]".format(idx) + ". "
docs.append(doc)
if final_output[-1] == " ":
final_output = final_output[:-1]
input_instance["output"] = final_output
input_instance["docs"] = docs
return input_instance
def process_arc_instruction(item, instruction):
choices = item["choices"]
answer_labels = {}
for i in range(len(choices["label"])):
answer_key = choices["label"][i]
text = choices["text"][i]
if answer_key == "1":
answer_labels["A"] = text
if answer_key == "2":
answer_labels["B"] = text
if answer_key == "3":
answer_labels["C"] = text
if answer_key == "4":
answer_labels["D"] = text
if answer_key in ["A", "B", "C", "D"]:
answer_labels[answer_key] = text
if "D" not in answer_labels:
answer_labels["D"] = ""
choices = "\nA: {0}\nB: {1}\nC: {2}\nD: {3}".format(answer_labels["A"], answer_labels["B"], answer_labels["C"], answer_labels["D"])
if "E" in answer_labels:
choices += "\nE: {}".format(answer_labels["E"])
processed_instruction = instruction + "\n\n### Input:\n" + item["instruction"] + choices
return processed_instruction
def postprocess_answers_closed(output, task, choices=None):
final_output = None
if choices is not None:
for c in choices.split(" "):
if c in output:
final_output = c
if task == "fever" and output in ["REFUTES", "SUPPORTS"]:
final_output = "true" if output == "SUPPORTS" else "REFUTES"
if task == "fever" and output.lower() in ["true", "false"]:
final_output = output.lower()
if final_output is None:
return output
else:
return final_output
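# Illustrative behaviour of postprocess_answers_closed (example values, not from this repo):
#   postprocess_answers_closed("The answer is B", task="arc_c", choices="A B C D")  ->  "B"
#   postprocess_answers_closed("SUPPORTS", task="fever")                            ->  "true"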

View File

@ -0,0 +1,10 @@
python passage_retrieval3.py \
--model_name_or_path ../model/contriever-msmarco \
--passages train_robot.jsonl \
--passages_embeddings "robot_embeddings/*" \
--data test_robot.jsonl \
--output_dir robot_result \
--n_docs 2
#python passage_retrieval3.py --model_name_or_path contriever-msmarco --passages train_robot.jsonl --passages_embeddings "robot_embeddings/*" --data test_robot.jsonl --output_dir robot_result --n_docs 2
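# Hedged reading of the command above: embed each question in test_robot.jsonl with
# the contriever-msmarco encoder, search the FAISS index built from the pickled
# embeddings matched by robot_embeddings/*, and write the top n_docs passages per
# question to robot_result/test_robot.jsonl, one JSON object per line.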

View File

@ -195,6 +195,7 @@ get_object_info
我带着孩子呢,想要宽敞亮堂的地方。
好的,我明白了,那么我们推荐您到大厅的桌子,那里的空间比较宽敞,环境也比较明亮,适合带着孩子一起用餐。
冰红茶
好的
create_sub_task

View File

@ -0,0 +1,312 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import argparse
import json
import pickle
import time
import glob
import numpy as np
import torch
from robowaiter.algos.retrieval.retrieval_lm.src.slurm import init_distributed_mode
from robowaiter.algos.retrieval.retrieval_lm.src.normalize_text import normalize
from robowaiter.algos.retrieval.retrieval_lm.src.contriever import load_retriever
from robowaiter.algos.retrieval.retrieval_lm.src.index import Indexer
from robowaiter.algos.retrieval.retrieval_lm.src.data import load_passages
from robowaiter.algos.retrieval.retrieval_lm.src.evaluation import calculate_matches
import warnings
from robowaiter.utils.basic import get_root_path
root_path = get_root_path()
warnings.filterwarnings('ignore')
os.environ["TOKENIZERS_PARALLELISM"] = "true"
def embed_queries(args, queries, model, tokenizer):
model.eval()
embeddings, batch_question = [], []
with torch.no_grad():
for k, q in enumerate(queries):
if args.lowercase:
q = q.lower()
if args.normalize_text:
q = normalize(q)
batch_question.append(q)
if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1:
encoded_batch = tokenizer.batch_encode_plus(
batch_question,
return_tensors="pt",
max_length=args.question_maxlength,
padding=True,
truncation=True,
)
encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
output = model(**encoded_batch)
embeddings.append(output.cpu())
batch_question = []
embeddings = torch.cat(embeddings, dim=0)
#print(f"Questions embeddings shape: {embeddings.size()}")
return embeddings.numpy()
def index_encoded_data(index, embedding_files, indexing_batch_size):
allids = []
allembeddings = np.array([])
for i, file_path in enumerate(embedding_files):
#print(f"Loading file {file_path}")
with open(file_path, "rb") as fin:
ids, embeddings = pickle.load(fin)
allembeddings = np.vstack((allembeddings, embeddings)) if allembeddings.size else embeddings
allids.extend(ids)
while allembeddings.shape[0] > indexing_batch_size:
allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
while allembeddings.shape[0] > 0:
allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
#print("Data indexing completed.")
def add_embeddings(index, embeddings, ids, indexing_batch_size):
end_idx = min(indexing_batch_size, embeddings.shape[0])
ids_toadd = ids[:end_idx]
embeddings_toadd = embeddings[:end_idx]
ids = ids[end_idx:]
embeddings = embeddings[end_idx:]
index.index_data(ids_toadd, embeddings_toadd)
return embeddings, ids
def validate(data, workers_num):
match_stats = calculate_matches(data, workers_num)
top_k_hits = match_stats.top_k_hits
# print("Validation results: top k documents hits %s", top_k_hits)
top_k_hits = [v / len(data) for v in top_k_hits]
message = ""
for k in [5, 10, 20, 100]:
if k <= len(top_k_hits):
message += f"R@{k}: {top_k_hits[k-1]} "
#print(message)
return match_stats.questions_doc_hits
def add_passages(data, passages, top_passages_and_scores):
# add passages to original data
merged_data = []
assert len(data) == len(top_passages_and_scores)
for i, d in enumerate(data):
results_and_scores = top_passages_and_scores[i]
#print(passages[2393])
docs = [passages[int(doc_id)] for doc_id in results_and_scores[0]]
scores = [str(score) for score in results_and_scores[1]]
ctxs_num = len(docs)
d["ctxs"] = [
{
"id": results_and_scores[0][c],
"title": docs[c]["title"],
"text": docs[c]["text"],
"score": scores[c],
}
for c in range(ctxs_num)
]
def add_hasanswer(data, hasanswer):
# add hasanswer to data
for i, ex in enumerate(data):
for k, d in enumerate(ex["ctxs"]):
d["hasanswer"] = hasanswer[i][k]
def load_data(data_path):
if data_path.endswith(".json"):
with open(data_path, "r",encoding='utf-8') as fin:
data = json.load(fin)
elif data_path.endswith(".jsonl"):
data = []
with open(data_path, "r",encoding='utf-8') as fin:
for k, example in enumerate(fin):
example = json.loads(example)
#print("example:",example)
data.append(example)
return data
def test(args):
# args carry the retrieval settings, e.g. model_name_or_path="contriever-msmarco",
# passages="train_robot.jsonl", passages_embeddings="robot_embeddings/*",
# data="test_robot.jsonl" (the query file), output_dir="robot_result", n_docs=1.
#print(f"Loading model from: {args.model_name_or_path}")
model, tokenizer, _ = load_retriever(args.model_name_or_path)
model.eval()
model = model.cuda()
if not args.no_fp16:
model = model.half()
index = Indexer(args.projection_size, args.n_subquantizers, args.n_bits)
# index all passages
input_paths = glob.glob(args.passages_embeddings)
input_paths = sorted(input_paths)
embeddings_dir = os.path.dirname(input_paths[0])
index_path = os.path.join(embeddings_dir, "index.faiss")
if args.save_or_load_index and os.path.exists(index_path):
index.deserialize_from(embeddings_dir)
else:
#print(f"Indexing passages from files {input_paths}")
start_time_indexing = time.time()
index_encoded_data(index, input_paths, args.indexing_batch_size)
#print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.")
if args.save_or_load_index:
index.serialize(embeddings_dir)
# load passages
passages = load_passages(args.passages)
passage_id_map = {x["id"]: x for x in passages}
data_paths = glob.glob(args.data)
alldata = []
for path in data_paths:
data = load_data(path)
#print("data:",data)
output_path = os.path.join(args.output_dir, os.path.basename(path))
queries = [ex["question"] for ex in data]
questions_embedding = embed_queries(args, queries, model, tokenizer)
# get top k results
start_time_retrieval = time.time()
top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs)
#print(f"Search time: {time.time()-start_time_retrieval:.1f} s.")
add_passages(data, passage_id_map, top_ids_and_scores)
#hasanswer = validate(data, args.validation_workers)
#add_hasanswer(data, hasanswer)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
ret_list = []
with open(output_path, "w",encoding='utf-8') as fout:
for ex in data:
json.dump(ex, fout, ensure_ascii=False)
ret_list.append(ex)
fout.write("\n")
return ret_list
#print(f"Saved results to {output_path}")
# Write the query to test_robot.jsonl so the retriever can read it.
def get_json(query):
dic = {"id": 1, "question": query}
with open('test_robot.jsonl', "w", encoding='utf-8') as fout:
json.dump(dic, fout, ensure_ascii=False)
def get_answer():
# Read (not write) the retrieval output and return the score and text of the
# top-ranked passage.
with open('robot_result/test_robot.jsonl', "r", encoding='utf-8') as fin:
for k, example in enumerate(fin):
example = json.loads(example)
answer = example["ctxs"][0]["text"]
score = example["ctxs"][0]["score"]
return score, answer
def retri(query):
get_json(query)
parser = argparse.ArgumentParser()
parser.add_argument(
"--data",
#required=True,
type=str,
default='test_robot.jsonl',
help=".json file containing question and answers, similar format to reader data",
)
# parser.add_argument("--passages", type=str, default='C:/Users/huangyu/Desktop/RoboWaiter-main/RoboWaiter-main/train_robot.jsonl', help="Path to passages (.tsv file)")
# parser.add_argument("--passages_embeddings", type=str, default='C:/Users/huangyu/Desktop/RoboWaiter-main/RoboWaiter-main/robot_embeddings/*', help="Glob path to encoded passages")
parser.add_argument("--passages", type=str, default=f'{root_path}/robowaiter/llm_client/train_robot.jsonl', help="Path to passages (.tsv file)")
parser.add_argument("--passages_embeddings", type=str, default=f'{root_path}/robowaiter/algos/retrieval/robot_embeddings/*', help="Glob path to encoded passages")
parser.add_argument(
"--output_dir", type=str, default='robot_result', help="Results are written to outputdir with data suffix"
)
parser.add_argument("--n_docs", type=int, default=5, help="Number of documents to retrieve per questions") #可以改这个参数返回前n_docs个检索结果
parser.add_argument(
"--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
)
parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding")
parser.add_argument(
"--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists"
)
parser.add_argument(
"--model_name_or_path", type=str, default=f'{root_path}/robowaiter/algos/retrieval/contriever-msmarco',help="path to directory containing model weights and config file"
)
parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
parser.add_argument(
"--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed"
)
parser.add_argument("--projection_size", type=int, default=768)
parser.add_argument(
"--n_subquantizers",
type=int,
default=0,
help="Number of subquantizer used for vector quantization, if 0 flat index is used",
)
parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer")
parser.add_argument("--lang", nargs="+")
parser.add_argument("--dataset", type=str, default="none")
parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
parser.add_argument("--normalize_text", action="store_true", help="normalize text")
args = parser.parse_args()
init_distributed_mode(args)
#print(args)
ret = test(args)
#print(ret)
return ret[0]
# example = ret[0]
# answer = example["ctxs"][0]["text"]
# score = example["ctxs"][0]["score"]
# return score, answer
if __name__ == "__main__":
# query = "请你拿一下软饮料到第三张桌子位置。"
# score,answer = retri(query)
# print(score,answer)
query = "你能把空调打开一下吗?"
all_ret = retri(query)
for i, example in enumerate(all_ret["ctxs"]):
answer = example["text"]
score = example["score"]
doc_id = example["id"]  # avoid shadowing the built-in id()
print(i, answer, score, " id=", doc_id)
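# With the shipped index, each entry in all_ret["ctxs"] is a dict with keys
# id, title, text and score, as in robot_result/test_robot.jsonl.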

View File

@ -0,0 +1 @@
{"id": 1, "question": "你能把空调打开一下吗?", "ctxs": [{"id": "505", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8567487"}, {"id": "313", "title": "你能把空调打开一下吗?", "text": "Is(AC,1)", "score": "1.8567487"}, {"id": "312", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8567487"}, {"id": "120", "title": "你能把空调打开一下吗?", "text": "Is(AC,1)", "score": "1.8567487"}, {"id": "119", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8567487"}]}

View File

@ -0,0 +1,43 @@
import requests
import urllib3
########################################
# This file implements simple communication with the LLM.
########################################
# Suppress HTTPS certificate warnings.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
def single_round(question,prefix=""):
url = "https://45.125.46.134:25344/v1/chat/completions"
headers = {"Content-Type": "application/json"}
data = {
"model": "RoboWaiter",
"messages": [
{
"role": "system",
"content": "你是一个机器人服务员RoboWaiter. 你的职责是为顾客提供对话及具身服务。"
},
{
"role": "user",
"content": prefix + question
}
]
}
response = requests.post(url, headers=headers, json=data, verify=False)
if response.status_code == 200:
result = response.json()
return result['choices'][0]['message']['content'].strip()
else:
# Return an error message string rather than a (message, status_code) tuple,
# so callers always receive a string.
return "LLM request failed: {}".format(response.status_code)
if __name__ == '__main__':
question = '''
给我一杯拿铁
'''
print(single_round(question))
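# Illustrative combination with the retrieval module (hypothetical import path and
# prefix format; neither is defined in this file):
#   from robowaiter.algos.retrieval.retrieval_lm.passage_retrieval3 import retri
#   hint = retri("给我一杯拿铁")["ctxs"][0]["text"]
#   print(single_round("给我一杯拿铁", prefix="参考指令:" + hint + "\n"))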

View File

@ -0,0 +1 @@
{"id": 1, "question": "你能把空调打开一下吗?"}

File diff suppressed because it is too large Load Diff