GFM-RAG Retrieval
GFM-RAG can be directly used for retrieval on a given dataset without fine-tuning. We provide an easy-to-use GFMRetriever interface for inference.
Config
You need to create a configuration file for inference.
gfmrag/workflow/config/stage3_qa_ircot_inference.yaml
hydra:
run:
dir: outputs/qa_agent_inference/${dataset.data_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} # Output directory
defaults:
- _self_
- doc_ranker: idf_topk_ranker # The document ranker to use
- agent_prompt: hotpotqa_ircot # The agent prompt to use
- qa_prompt: hotpotqa # The QA prompt to use
- ner_model: llm_ner_model # The NER model to use
- el_model: colbert_el_model # The EL model to use
- qa_evaluator: hotpotqa # The QA evaluator to use
seed: 1024
dataset:
root: ./data # data root directory
data_name: hotpotqa_test # data name
llm:
_target_: gfmrag.llms.ChatGPT # The language model to use
model_name_or_path: gpt-3.5-turbo # The model name or path
retry: 5 # Number of retries
graph_retriever:
model_path: rmanluo/GFM-RAG-8M # Checkpoint path of the pre-trained GFM-RAG model
doc_ranker: ${doc_ranker} # The document ranker to use
  ner_model: ${ner_model} # The NER model to use
el_model: ${el_model} # The EL model to use
qa_evaluator: ${qa_evaluator} # The QA evaluator to use
init_entities_weight: True # Whether to initialize the entities weight
test:
top_k: 10 # Number of documents to retrieve
max_steps: 2 # Maximum number of steps
max_test_samples: -1 # -1 for all samples
resume: null # Resume from previous prediction
Details of the configuration parameters are explained in the GFM-RAG Configuration page.
Initialize GFMRetriever
You can initialize the GFMRetriever with the following code. It will load the pre-trained GFM-RAG model and the KG-index for retrieval.
import logging
import os
import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from gfmrag import GFMRetriever
logger = logging.getLogger(__name__)
@hydra.main(
config_path="config", config_name="stage3_qa_ircot_inference", version_base=None
)
def main(cfg: DictConfig) -> None:
output_dir = HydraConfig.get().runtime.output_dir
logger.info(f"Config:\n {OmegaConf.to_yaml(cfg)}")
logger.info(f"Current working directory: {os.getcwd()}")
logger.info(f"Output directory: {output_dir}")
gfmrag_retriever = GFMRetriever.from_config(cfg)
Document Retrieval
You can use the GFM-RAG retriever to reason over the KG-index and retrieve supporting documents for a given query.
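For example, the following sketch retrieves the top-k supporting documents for a question. The query string is illustrative; retrieve() returns a ranked list of document dicts whose title and norm_score fields are also used by the agent workflow below.
current_query = "Who is the president of France?"  # illustrative question
retrieved_docs = gfmrag_retriever.retrieve(current_query, top_k=5)
for doc in retrieved_docs:
    print(doc["title"], doc["norm_score"])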
Question Answering
from hydra.utils import instantiate
from gfmrag.llms import BaseLanguageModel
from gfmrag.prompt_builder import QAPromptBuilder
llm = instantiate(cfg.llm)
qa_prompt_builder = QAPromptBuilder(cfg.qa_prompt)
message = qa_prompt_builder.build_input_prompt(current_query, retrieved_docs)
answer = llm.generate_sentence(message) # Answer: "Emmanuel Macron"
GFM-RAG + Agent for Multi-step Retrieval
You can also integrate GFM-RAG with arbitrary reasoning agents to perform multi-step RAG. Here is an example of IRCOT + GFM-RAG:
You can run the following command to perform multi-step reasoning:
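The module invocation below is an assumption, mirroring the batch-retrieval command shown later on this page:
python -m gfmrag.workflow.stage3_qa_ircot_inference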
gfmrag/workflow/stage3_qa_ircot_inference.py
import json
import logging
import os
import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
from gfmrag import GFMRetriever
from gfmrag.evaluation import RetrievalEvaluator
from gfmrag.llms import BaseLanguageModel
from gfmrag.prompt_builder import QAPromptBuilder
from gfmrag.ultra import query_utils
# A logger for this file
logger = logging.getLogger(__name__)
def agent_reasoning(
cfg: DictConfig,
gfmrag_retriever: GFMRetriever,
llm: BaseLanguageModel,
qa_prompt_builder: QAPromptBuilder,
query: str,
) -> dict:
step = 1
current_query = query
thoughts: list[str] = []
retrieved_docs = gfmrag_retriever.retrieve(current_query, top_k=cfg.test.top_k)
logs = []
while step <= cfg.test.max_steps:
message = qa_prompt_builder.build_input_prompt(
current_query, retrieved_docs, thoughts
)
response = llm.generate_sentence(message)
if isinstance(response, Exception):
raise response from None
thoughts.append(response)
logs.append(
{
"step": step,
"query": current_query,
"retrieved_docs": retrieved_docs,
"response": response,
"thoughts": thoughts,
}
)
if "So the answer is:" in response:
break
step += 1
new_ret_docs = gfmrag_retriever.retrieve(response, top_k=cfg.test.top_k)
retrieved_docs_dict = {doc["title"]: doc for doc in retrieved_docs}
for doc in new_ret_docs:
if doc["title"] in retrieved_docs_dict:
if doc["norm_score"] > retrieved_docs_dict[doc["title"]]["norm_score"]:
retrieved_docs_dict[doc["title"]]["score"] = doc["score"]
retrieved_docs_dict[doc["title"]]["norm_score"] = doc["norm_score"]
else:
retrieved_docs_dict[doc["title"]] = doc
# Sort the retrieved docs by score
retrieved_docs = sorted(
retrieved_docs_dict.values(), key=lambda x: x["norm_score"], reverse=True
)
# Only keep the top k
retrieved_docs = retrieved_docs[: cfg.test.top_k]
final_response = " ".join(thoughts)
return {"response": final_response, "retrieved_docs": retrieved_docs, "logs": logs}
@hydra.main(
config_path="config", config_name="stage3_qa_ircot_inference", version_base=None
)
def main(cfg: DictConfig) -> None:
output_dir = HydraConfig.get().runtime.output_dir
logger.info(f"Config:\n {OmegaConf.to_yaml(cfg)}")
logger.info(f"Current working directory: {os.getcwd()}")
logger.info(f"Output directory: {output_dir}")
gfmrag_retriever = GFMRetriever.from_config(cfg)
llm = instantiate(cfg.llm)
agent_prompt_builder = QAPromptBuilder(cfg.agent_prompt)
qa_prompt_builder = QAPromptBuilder(cfg.qa_prompt)
test_data = gfmrag_retriever.qa_data.raw_test_data
max_samples = (
cfg.test.max_test_samples if cfg.test.max_test_samples > 0 else len(test_data)
)
processed_data = {}
if cfg.test.resume:
logger.info(f"Resuming from previous prediction {cfg.test.resume}")
try:
with open(cfg.test.resume) as f:
for line in f:
result = json.loads(line)
processed_data[result["id"]] = result
except Exception as e:
logger.error(f"Could not resume from previous prediction {e}")
with open(os.path.join(output_dir, "prediction.jsonl"), "w") as f:
for i in tqdm(range(max_samples)):
sample = test_data[i]
if i >= max_samples:
break
query = sample["question"]
if sample["id"] in processed_data:
result = processed_data[sample["id"]]
else:
result = agent_reasoning(
cfg, gfmrag_retriever, llm, agent_prompt_builder, query
)
# Generate QA response
retrieved_docs = result["retrieved_docs"]
message = qa_prompt_builder.build_input_prompt(query, retrieved_docs)
qa_response = llm.generate_sentence(message)
result = {
"id": sample["id"],
"question": sample["question"],
"answer": sample["answer"],
"answer_aliases": sample.get(
"answer_aliases", []
), # Some datasets have answer aliases
"supporting_facts": sample["supporting_facts"],
"response": qa_response,
"retrieved_docs": retrieved_docs,
"logs": result["logs"],
}
f.write(json.dumps(result) + "\n")
f.flush()
result_path = os.path.join(output_dir, "prediction.jsonl")
# Evaluation
evaluator = instantiate(cfg.qa_evaluator, prediction_file=result_path)
metrics = evaluator.evaluate()
query_utils.print_metrics(metrics, logger)
# Eval retrieval results
retrieval_evaluator = RetrievalEvaluator(prediction_file=result_path)
retrieval_metrics = retrieval_evaluator.evaluate()
query_utils.print_metrics(retrieval_metrics, logger)
if __name__ == "__main__":
main()
You can override the configuration from the command line like this:
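For instance (the values are illustrative; any field of the configuration file above can be overridden with Hydra's key=value syntax):
python -m gfmrag.workflow.stage3_qa_ircot_inference test.top_k=20 test.max_steps=3 llm.model_name_or_path=gpt-4o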
Batch Retrieval
You can also perform batch retrieval with GFM-RAG, with multi-GPU support, by running the following command:
gfmrag/workflow/config/stage3_qa_inference.yaml
hydra:
run:
dir: outputs/qa_inference/${dataset.data_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
defaults:
- _self_
- doc_ranker: idf_topk_ranker
- qa_prompt: hotpotqa
- qa_evaluator: hotpotqa
seed: 1024
dataset:
root: ./data
data_name: hotpotqa_test
llm:
_target_: gfmrag.llms.ChatGPT
model_name_or_path: gpt-3.5-turbo
retry: 5
graph_retriever:
model_path: rmanluo/GFM-RAG-8M
test:
  retrieval_batch_size: 8 # Batch size for document retrieval
  top_k: 5 # Number of documents to retrieve
  save_retrieval: False # Whether to save the raw retrieval results to disk
  save_top_k_entity: 10 # Number of top-ranked entities to keep in the prediction file
  n_threads: 5 # Number of threads for parallel LLM QA prediction
  retrieved_result_path: null # Path to a previously saved retrieval result to reuse
  prediction_result_path: null # Path to a previous prediction file; skips retrieval and QA prediction
  init_entities_weight: True # Whether to initialize the entities weight
gfmrag/workflow/stage3_qa_inference.py
import json
import logging
import os
from multiprocessing.dummy import Pool as ThreadPool
import hydra
import torch
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from torch import nn
from torch.utils import data as torch_data
from torch.utils.data import Dataset
from tqdm import tqdm
from gfmrag import utils
from gfmrag.datasets import QADataset
from gfmrag.prompt_builder import QAPromptBuilder
from gfmrag.ultra import query_utils
# A logger for this file
logger = logging.getLogger(__name__)
@torch.no_grad()
def doc_retrieval(
cfg: DictConfig,
model: nn.Module,
qa_data: Dataset,
device: torch.device,
) -> list[dict]:
world_size = utils.get_world_size()
rank = utils.get_rank()
_, test_data = qa_data._data
graph = qa_data.kg
ent2docs = qa_data.ent2docs
# Retrieve the supporting documents for each query
sampler = torch_data.DistributedSampler(test_data, world_size, rank, shuffle=False)
test_loader = torch_data.DataLoader(
test_data, cfg.test.retrieval_batch_size, sampler=sampler
)
# Create doc retriever
doc_ranker = instantiate(cfg.doc_ranker, ent2doc=ent2docs)
if cfg.test.init_entities_weight:
entities_weight = utils.get_entities_weight(ent2docs)
else:
entities_weight = None
model.eval()
all_predictions: list[dict] = []
for batch in tqdm(test_loader):
batch = query_utils.cuda(batch, device=device)
ent_pred = model(graph, batch, entities_weight=entities_weight)
doc_pred = doc_ranker(ent_pred) # Ent2docs mapping
idx = batch["sample_id"]
all_predictions.extend(
{"id": i, "ent_pred": e, "doc_pred": d}
for i, e, d in zip(idx.cpu(), ent_pred.cpu(), doc_pred.cpu())
)
# Gather the predictions across all processes
if utils.get_world_size() > 1:
gathered_predictions = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(gathered_predictions, all_predictions)
else:
gathered_predictions = [all_predictions] # type: ignore
sorted_predictions = sorted(
[item for sublist in gathered_predictions for item in sublist], # type: ignore
key=lambda x: x["id"],
)
utils.synchronize()
return sorted_predictions
def ans_prediction(
cfg: DictConfig, output_dir: str, qa_data: Dataset, retrieval_result: list[dict]
) -> str:
llm = instantiate(cfg.llm)
doc_retriever = utils.DocumentRetriever(qa_data.doc, qa_data.id2doc)
test_data = qa_data.raw_test_data
id2ent = {v: k for k, v in qa_data.ent2id.items()}
prompt_builder = QAPromptBuilder(cfg.qa_prompt)
def predict(qa_input: tuple[dict, torch.Tensor]) -> dict | Exception:
data, retrieval_doc = qa_input
retrieved_ent_idx = torch.topk(
retrieval_doc["ent_pred"], cfg.test.save_top_k_entity, dim=-1
).indices
retrieved_ent = [id2ent[i.item()] for i in retrieved_ent_idx]
retrieved_docs = doc_retriever(retrieval_doc["doc_pred"], top_k=cfg.test.top_k)
message = prompt_builder.build_input_prompt(data["question"], retrieved_docs)
response = llm.generate_sentence(message)
if isinstance(response, Exception):
return response
else:
return {
"id": data["id"],
"question": data["question"],
"answer": data["answer"],
"answer_aliases": data.get(
"answer_aliases", []
), # Some datasets have answer aliases
"response": response,
"retrieved_ent": retrieved_ent,
"retrieved_docs": retrieved_docs,
}
with open(os.path.join(output_dir, "prediction.jsonl"), "w") as f:
with ThreadPool(cfg.test.n_threads) as pool:
for results in tqdm(
pool.imap(predict, zip(test_data, retrieval_result)),
total=len(test_data),
):
if isinstance(results, Exception):
logger.error(f"Error: {results}")
continue
f.write(json.dumps(results) + "\n")
f.flush()
return os.path.join(output_dir, "prediction.jsonl")
@hydra.main(config_path="config", config_name="stage3_qa_inference", version_base=None)
def main(cfg: DictConfig) -> None:
output_dir = HydraConfig.get().runtime.output_dir
utils.init_distributed_mode()
torch.manual_seed(cfg.seed + utils.get_rank())
if utils.get_rank() == 0:
logger.info(f"Config:\n {OmegaConf.to_yaml(cfg)}")
logger.info(f"Current working directory: {os.getcwd()}")
logger.info(f"Output directory: {output_dir}")
model, model_config = utils.load_model_from_pretrained(
cfg.graph_retriever.model_path
)
qa_data = QADataset(
**cfg.dataset,
text_emb_model_cfgs=OmegaConf.create(model_config["text_emb_model_config"]),
)
device = utils.get_device()
model = model.to(device)
qa_data.kg = qa_data.kg.to(device)
qa_data.ent2docs = qa_data.ent2docs.to(device)
if cfg.test.retrieved_result_path:
retrieval_result = torch.load(cfg.test.retrieved_result_path, weights_only=True)
else:
if cfg.test.prediction_result_path:
retrieval_result = None
else:
retrieval_result = doc_retrieval(cfg, model, qa_data, device=device)
if utils.is_main_process():
if cfg.test.save_retrieval and retrieval_result is not None:
logger.info(
f"Ranking saved to disk: {os.path.join(output_dir, 'retrieval_result.pt')}"
)
torch.save(
retrieval_result, os.path.join(output_dir, "retrieval_result.pt")
)
if cfg.test.prediction_result_path:
output_path = cfg.test.prediction_result_path
else:
output_path = ans_prediction(cfg, output_dir, qa_data, retrieval_result)
# Evaluation
evaluator = instantiate(cfg.qa_evaluator, prediction_file=output_path)
metrics = evaluator.evaluate()
query_utils.print_metrics(metrics, logger)
return metrics
if __name__ == "__main__":
main()
# Single-GPU retrieval
python -m gfmrag.workflow.stage3_qa_inference
# Multi-GPU retrieval
torchrun --nproc_per_node=4 -m gfmrag.workflow.stage3_qa_inference
You can override the configuration from the command line like this:
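For example (illustrative values; the fields correspond to the stage3_qa_inference configuration shown above):
python -m gfmrag.workflow.stage3_qa_inference test.top_k=10 test.retrieval_batch_size=16 test.save_retrieval=True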