GFM-RAG Retrieval

GFM-RAG can be directly used for retrieval on a given dataset without fine-tuning. We provide an easy-to-use GFMRetriever interface for inference.

Config

You need to create a configuration file for inference.

gfmrag/workflow/config/stage3_qa_ircot_inference.yaml
hydra:
  run:
    dir: outputs/qa_agent_inference/${dataset.data_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} # Output directory

defaults:
  - _self_
  - doc_ranker: idf_topk_ranker # The document ranker to use
  - agent_prompt: hotpotqa_ircot # The agent prompt to use
  - qa_prompt: hotpotqa # The QA prompt to use
  - ner_model: llm_ner_model # The NER model to use
  - el_model: colbert_el_model # The EL model to use
  - qa_evaluator: hotpotqa # The QA evaluator to use

seed: 1024

dataset:
  root: ./data # data root directory
  data_name: hotpotqa_test # data name

llm:
  _target_: gfmrag.llms.ChatGPT # The language model to use
  model_name_or_path: gpt-3.5-turbo # The model name or path
  retry: 5 # Number of retries

graph_retriever:
  model_path: rmanluo/GFM-RAG-8M # Checkpoint path of the pre-trained GFM-RAG model
  doc_ranker: ${doc_ranker} # The document ranker to use
  ner_model: ${ner_model} # The NER model to use
  el_model: ${el_model} # The EL model to use
  qa_evaluator: ${qa_evaluator} # The QA evaluator to use
  init_entities_weight: True # Whether to initialize the entity weights

test:
  top_k: 10 # Number of documents to retrieve
  max_steps: 2 # Maximum number of steps
  max_test_samples: -1 # -1 for all samples
  resume: null # Resume from previous prediction

Details of the configuration parameters are explained in the GFM-RAG Configuration page.

Initialize GFMRetriever

You can initialize the GFMRetriever with the following code. It will load the pre-trained GFM-RAG model and the KG-index for retrieval.

Python
import logging
import os

import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf

from gfmrag import GFMRetriever

logger = logging.getLogger(__name__)


@hydra.main(
    config_path="config", config_name="stage3_qa_ircot_inference", version_base=None
)
def main(cfg: DictConfig) -> None:
    output_dir = HydraConfig.get().runtime.output_dir
    logger.info(f"Config:\n {OmegaConf.to_yaml(cfg)}")
    logger.info(f"Current working directory: {os.getcwd()}")
    logger.info(f"Output directory: {output_dir}")

    gfmrag_retriever = GFMRetriever.from_config(cfg)


if __name__ == "__main__":
    main()
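
If you want to use the GFMRetriever outside of a @hydra.main script (for example in a notebook), you can compose the same configuration programmatically with Hydra's compose API. This is a minimal sketch; the relative config_path is an assumption and depends on where your calling code lives.

Python
from hydra import compose, initialize

from gfmrag import GFMRetriever

# Compose the same inference config without a @hydra.main entry point.
# NOTE: config_path is resolved relative to the calling file; adjust it to your layout.
with initialize(version_base=None, config_path="gfmrag/workflow/config"):
    cfg = compose(config_name="stage3_qa_ircot_inference")

gfmrag_retriever = GFMRetriever.from_config(cfg)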

Document Retrieval

You can use the GFM-RAG retriever to reason over the KG-index and retrieve relevant documents for a given query.

Python
query = "Who is the president of France?"
docs = gfmrag_retriever.retrieve(query, top_k=5)
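
Each retrieved document is a dictionary. Judging from how the agent workflow below consumes the results, each entry carries at least a title and retrieval scores (score and norm_score); treat the exact field set as an assumption and inspect the output on your own data.

Python
# Field names are inferred from the agent workflow further below; verify them on your data.
for doc in docs:
    print(doc["title"], doc["norm_score"])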

Question Answering

You can then pass the retrieved documents to an LLM to generate the final answer:

Python
from hydra.utils import instantiate
from gfmrag.llms import BaseLanguageModel
from gfmrag.prompt_builder import QAPromptBuilder

llm = instantiate(cfg.llm)
qa_prompt_builder = QAPromptBuilder(cfg.qa_prompt)

message = qa_prompt_builder.build_input_prompt(query, docs)
answer = llm.generate_sentence(message)  # Answer: "Emmanuel Macron"
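
As in the workflow scripts below, llm.generate_sentence may return an Exception object instead of raising it when the underlying call fails, so a robust caller should check the return type. A minimal sketch:

Python
response = llm.generate_sentence(message)
if isinstance(response, Exception):
    # The LLM call failed; handle or re-raise it here.
    raise response
answer = response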

GFM-RAG + Agent for Multi-step Retrieval

You can also integrate GFM-RAG with arbitrary reasoning agents to perform multi-step RAG. Here is an example of IRCOT + GFM-RAG.

The workflow script below implements the multi-step reasoning loop; the command to run it is shown after the script:

gfmrag/workflow/stage3_qa_ircot_inference.py
import json
import logging
import os

import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

from gfmrag import GFMRetriever
from gfmrag.evaluation import RetrievalEvaluator
from gfmrag.llms import BaseLanguageModel
from gfmrag.prompt_builder import QAPromptBuilder
from gfmrag.ultra import query_utils

# A logger for this file
logger = logging.getLogger(__name__)


def agent_reasoning(
    cfg: DictConfig,
    gfmrag_retriever: GFMRetriever,
    llm: BaseLanguageModel,
    qa_prompt_builder: QAPromptBuilder,
    query: str,
) -> dict:
    step = 1
    current_query = query
    thoughts: list[str] = []
    retrieved_docs = gfmrag_retriever.retrieve(current_query, top_k=cfg.test.top_k)
    logs = []
    while step <= cfg.test.max_steps:
        message = qa_prompt_builder.build_input_prompt(
            current_query, retrieved_docs, thoughts
        )
        response = llm.generate_sentence(message)

        if isinstance(response, Exception):
            raise response from None

        thoughts.append(response)

        logs.append(
            {
                "step": step,
                "query": current_query,
                "retrieved_docs": retrieved_docs,
                "response": response,
                "thoughts": thoughts,
            }
        )

        if "So the answer is:" in response:
            break

        step += 1

        new_ret_docs = gfmrag_retriever.retrieve(response, top_k=cfg.test.top_k)

        retrieved_docs_dict = {doc["title"]: doc for doc in retrieved_docs}
        for doc in new_ret_docs:
            if doc["title"] in retrieved_docs_dict:
                if doc["norm_score"] > retrieved_docs_dict[doc["title"]]["norm_score"]:
                    retrieved_docs_dict[doc["title"]]["score"] = doc["score"]
                    retrieved_docs_dict[doc["title"]]["norm_score"] = doc["norm_score"]
            else:
                retrieved_docs_dict[doc["title"]] = doc
        # Sort the retrieved docs by score
        retrieved_docs = sorted(
            retrieved_docs_dict.values(), key=lambda x: x["norm_score"], reverse=True
        )
        # Only keep the top k
        retrieved_docs = retrieved_docs[: cfg.test.top_k]

    final_response = " ".join(thoughts)
    return {"response": final_response, "retrieved_docs": retrieved_docs, "logs": logs}


@hydra.main(
    config_path="config", config_name="stage3_qa_ircot_inference", version_base=None
)
def main(cfg: DictConfig) -> None:
    output_dir = HydraConfig.get().runtime.output_dir
    logger.info(f"Config:\n {OmegaConf.to_yaml(cfg)}")
    logger.info(f"Current working directory: {os.getcwd()}")
    logger.info(f"Output directory: {output_dir}")

    gfmrag_retriever = GFMRetriever.from_config(cfg)
    llm = instantiate(cfg.llm)
    agent_prompt_builder = QAPromptBuilder(cfg.agent_prompt)
    qa_prompt_builder = QAPromptBuilder(cfg.qa_prompt)
    test_data = gfmrag_retriever.qa_data.raw_test_data
    max_samples = (
        cfg.test.max_test_samples if cfg.test.max_test_samples > 0 else len(test_data)
    )
    processed_data = {}
    if cfg.test.resume:
        logger.info(f"Resuming from previous prediction {cfg.test.resume}")
        try:
            with open(cfg.test.resume) as f:
                for line in f:
                    result = json.loads(line)
                    processed_data[result["id"]] = result
        except Exception as e:
            logger.error(f"Could not resume from previous prediction {e}")
    with open(os.path.join(output_dir, "prediction.jsonl"), "w") as f:
        for i in tqdm(range(max_samples)):
            sample = test_data[i]
            if i >= max_samples:
                break
            query = sample["question"]
            if sample["id"] in processed_data:
                result = processed_data[sample["id"]]
            else:
                result = agent_reasoning(
                    cfg, gfmrag_retriever, llm, agent_prompt_builder, query
                )

                # Generate QA response
                retrieved_docs = result["retrieved_docs"]
                message = qa_prompt_builder.build_input_prompt(query, retrieved_docs)
                qa_response = llm.generate_sentence(message)

                result = {
                    "id": sample["id"],
                    "question": sample["question"],
                    "answer": sample["answer"],
                    "answer_aliases": sample.get(
                        "answer_aliases", []
                    ),  # Some datasets have answer aliases
                    "supporting_facts": sample["supporting_facts"],
                    "response": qa_response,
                    "retrieved_docs": retrieved_docs,
                    "logs": result["logs"],
                }
            f.write(json.dumps(result) + "\n")
            f.flush()

    result_path = os.path.join(output_dir, "prediction.jsonl")
    # Evaluation
    evaluator = instantiate(cfg.qa_evaluator, prediction_file=result_path)
    metrics = evaluator.evaluate()
    query_utils.print_metrics(metrics, logger)

    # Eval retrieval results
    retrieval_evaluator = RetrievalEvaluator(prediction_file=result_path)
    retrieval_metrics = retrieval_evaluator.evaluate()
    query_utils.print_metrics(retrieval_metrics, logger)


if __name__ == "__main__":
    main()

Bash
python -m gfmrag.workflow.stage3_qa_ircot_inference

You can override the configuration like this:

Bash
python -m gfmrag.workflow.stage3_qa_ircot_inference test.max_steps=3

Batch Retrieval

You can also perform batch retrieval with GFM-RAG, with multi-GPU support. The configuration and workflow script are shown below, followed by the commands to run them:

gfmrag/workflow/config/stage3_qa_inference.yaml
hydra:
  run:
    dir: outputs/qa_inference/${dataset.data_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

defaults:
  - _self_
  - doc_ranker: idf_topk_ranker
  - qa_prompt: hotpotqa
  - qa_evaluator: hotpotqa

seed: 1024

dataset:
  root: ./data
  data_name: hotpotqa_test

llm:
  _target_: gfmrag.llms.ChatGPT
  model_name_or_path: gpt-3.5-turbo
  retry: 5

graph_retriever:
  model_path: rmanluo/GFM-RAG-8M

test:
  retrieval_batch_size: 8
  top_k: 5
  save_retrieval: False
  save_top_k_entity: 10
  n_threads: 5
  retrieved_result_path: null
  prediction_result_path: null
  init_entities_weight: True

gfmrag/workflow/stage3_qa_inference.py
import json
import logging
import os
from multiprocessing.dummy import Pool as ThreadPool

import hydra
import torch
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from torch import nn
from torch.utils import data as torch_data
from torch.utils.data import Dataset
from tqdm import tqdm

from gfmrag import utils
from gfmrag.datasets import QADataset
from gfmrag.prompt_builder import QAPromptBuilder
from gfmrag.ultra import query_utils

# A logger for this file
logger = logging.getLogger(__name__)


@torch.no_grad()
def doc_retrieval(
    cfg: DictConfig,
    model: nn.Module,
    qa_data: Dataset,
    device: torch.device,
) -> list[dict]:
    world_size = utils.get_world_size()
    rank = utils.get_rank()

    _, test_data = qa_data._data
    graph = qa_data.kg
    ent2docs = qa_data.ent2docs

    # Retrieve the supporting documents for each query
    sampler = torch_data.DistributedSampler(test_data, world_size, rank, shuffle=False)
    test_loader = torch_data.DataLoader(
        test_data, cfg.test.retrieval_batch_size, sampler=sampler
    )

    # Create doc retriever
    doc_ranker = instantiate(cfg.doc_ranker, ent2doc=ent2docs)

    if cfg.test.init_entities_weight:
        entities_weight = utils.get_entities_weight(ent2docs)
    else:
        entities_weight = None

    model.eval()
    all_predictions: list[dict] = []
    for batch in tqdm(test_loader):
        batch = query_utils.cuda(batch, device=device)
        ent_pred = model(graph, batch, entities_weight=entities_weight)
        doc_pred = doc_ranker(ent_pred)  # Ent2docs mapping
        idx = batch["sample_id"]
        all_predictions.extend(
            {"id": i, "ent_pred": e, "doc_pred": d}
            for i, e, d in zip(idx.cpu(), ent_pred.cpu(), doc_pred.cpu())
        )

    # Gather the predictions across all processes
    if utils.get_world_size() > 1:
        gathered_predictions = [None] * torch.distributed.get_world_size()
        torch.distributed.all_gather_object(gathered_predictions, all_predictions)
    else:
        gathered_predictions = [all_predictions]  # type: ignore

    sorted_predictions = sorted(
        [item for sublist in gathered_predictions for item in sublist],  # type: ignore
        key=lambda x: x["id"],
    )
    utils.synchronize()
    return sorted_predictions


def ans_prediction(
    cfg: DictConfig, output_dir: str, qa_data: Dataset, retrieval_result: list[dict]
) -> str:
    llm = instantiate(cfg.llm)
    doc_retriever = utils.DocumentRetriever(qa_data.doc, qa_data.id2doc)
    test_data = qa_data.raw_test_data
    id2ent = {v: k for k, v in qa_data.ent2id.items()}

    prompt_builder = QAPromptBuilder(cfg.qa_prompt)

    def predict(qa_input: tuple[dict, torch.Tensor]) -> dict | Exception:
        data, retrieval_doc = qa_input
        retrieved_ent_idx = torch.topk(
            retrieval_doc["ent_pred"], cfg.test.save_top_k_entity, dim=-1
        ).indices
        retrieved_ent = [id2ent[i.item()] for i in retrieved_ent_idx]
        retrieved_docs = doc_retriever(retrieval_doc["doc_pred"], top_k=cfg.test.top_k)

        message = prompt_builder.build_input_prompt(data["question"], retrieved_docs)

        response = llm.generate_sentence(message)
        if isinstance(response, Exception):
            return response
        else:
            return {
                "id": data["id"],
                "question": data["question"],
                "answer": data["answer"],
                "answer_aliases": data.get(
                    "answer_aliases", []
                ),  # Some datasets have answer aliases
                "response": response,
                "retrieved_ent": retrieved_ent,
                "retrieved_docs": retrieved_docs,
            }

    with open(os.path.join(output_dir, "prediction.jsonl"), "w") as f:
        with ThreadPool(cfg.test.n_threads) as pool:
            for results in tqdm(
                pool.imap(predict, zip(test_data, retrieval_result)),
                total=len(test_data),
            ):
                if isinstance(results, Exception):
                    logger.error(f"Error: {results}")
                    continue

                f.write(json.dumps(results) + "\n")
                f.flush()

    return os.path.join(output_dir, "prediction.jsonl")


@hydra.main(config_path="config", config_name="stage3_qa_inference", version_base=None)
def main(cfg: DictConfig) -> None:
    output_dir = HydraConfig.get().runtime.output_dir
    utils.init_distributed_mode()
    torch.manual_seed(cfg.seed + utils.get_rank())
    if utils.get_rank() == 0:
        logger.info(f"Config:\n {OmegaConf.to_yaml(cfg)}")
        logger.info(f"Current working directory: {os.getcwd()}")
        logger.info(f"Output directory: {output_dir}")

    model, model_config = utils.load_model_from_pretrained(
        cfg.graph_retriever.model_path
    )
    qa_data = QADataset(
        **cfg.dataset,
        text_emb_model_cfgs=OmegaConf.create(model_config["text_emb_model_config"]),
    )
    device = utils.get_device()
    model = model.to(device)

    qa_data.kg = qa_data.kg.to(device)
    qa_data.ent2docs = qa_data.ent2docs.to(device)

    if cfg.test.retrieved_result_path:
        retrieval_result = torch.load(cfg.test.retrieved_result_path, weights_only=True)
    else:
        if cfg.test.prediction_result_path:
            retrieval_result = None
        else:
            retrieval_result = doc_retrieval(cfg, model, qa_data, device=device)
    if utils.is_main_process():
        if cfg.test.save_retrieval and retrieval_result is not None:
            logger.info(
                f"Ranking saved to disk: {os.path.join(output_dir, 'retrieval_result.pt')}"
            )
            torch.save(
                retrieval_result, os.path.join(output_dir, "retrieval_result.pt")
            )
        if cfg.test.prediction_result_path:
            output_path = cfg.test.prediction_result_path
        else:
            output_path = ans_prediction(cfg, output_dir, qa_data, retrieval_result)

        # Evaluation
        evaluator = instantiate(cfg.qa_evaluator, prediction_file=output_path)
        metrics = evaluator.evaluate()
        query_utils.print_metrics(metrics, logger)
        return metrics


if __name__ == "__main__":
    main()

Bash
python -m gfmrag.workflow.stage3_qa_inference
# Multi-GPU retrieval
torchrun --nproc_per_node=4 -m gfmrag.workflow.stage3_qa_inference

You can override the configuration like this:

Bash
python -m gfmrag.workflow.stage3_qa_inference test.retrieval_batch_size=4