GFM-RAG Training

You can further fine-tune the pre-trained GFM-RAG model on your own dataset to improve its performance on your specific domain.

Data Preparation

Please follow the instructions in the Data Preparation guide to prepare your dataset in the following structure. Make sure train.json is present, as it is required for fine-tuning.

Text Only
data_name/
├── raw/
│   ├── dataset_corpus.json
│   ├── train.json
│   └── test.json # (optional)
└── processed/
    └── stage1/
        ├── kg.txt
        ├── document2entities.json
        ├── train.json
        └── test.json # (optional)
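
If you want a quick sanity check before launching fine-tuning, a small script like the one below can confirm that the required raw files exist. This is an illustrative helper, not part of the gfmrag package, and data/data_name is a placeholder for your own dataset directory.

Python
from pathlib import Path

# Illustrative pre-flight check (not part of gfmrag): verify the raw files needed
# for fine-tuning exist under the dataset root. "data/data_name" is a placeholder.
def check_dataset_layout(root: str) -> None:
    root_path = Path(root)
    required = [
        root_path / "raw" / "dataset_corpus.json",
        root_path / "raw" / "train.json",
    ]
    for path in required:
        if not path.exists():
            raise FileNotFoundError(f"Missing required file: {path}")
    if not (root_path / "raw" / "test.json").exists():
        print("Note: optional raw/test.json not found.")
    print("Dataset layout looks good.")

check_dataset_layout("data/data_name")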

GFM Fine-tuning

During fine-tuning, the GFM model is trained on the query-document pairs in train.json from the labeled dataset to learn the complex relationships needed for retrieval. Fine-tuning on your own dataset can improve the model's performance on your specific domain.

An example of the training data:

JSON
[
    {
        "id": "5abc553a554299700f9d7871",
        "question": "Kyle Ezell is a professor at what School of Architecture building at Ohio State?",
        "answer": "Knowlton Hall",
        "supporting_facts": [
            "Knowlton Hall",
            "Kyle Ezell"
        ],
        "question_entities": [
            "kyle ezell",
            "architectural association school of architecture",
            "ohio state"
        ],
        "supporting_entities": [
            "10 million donation",
            "2004",
            "architecture",
            "austin e  knowlton",
            "austin e  knowlton school of architecture",
            "bachelor s in architectural engineering",
            "city and regional planning",
            "columbus  ohio  united states",
            "ives hall",
            "july 2002",
            "knowlton hall",
            "ksa",
        ]
    },
    ...
]
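
Before training, it can be useful to check that every record in train.json carries the fields shown in the example above. The snippet below is a minimal, illustrative check (not part of the gfmrag codebase); the file path is a placeholder.

Python
import json

# Illustrative check (not part of gfmrag): ensure each training record has the
# fields shown in the example above. The path below is a placeholder.
REQUIRED_FIELDS = [
    "id",
    "question",
    "answer",
    "supporting_facts",
    "question_entities",
    "supporting_entities",
]

with open("data/data_name/raw/train.json") as f:
    samples = json.load(f)

for i, sample in enumerate(samples):
    missing = [field for field in REQUIRED_FIELDS if field not in sample]
    if missing:
        raise ValueError(f"Sample {i} (id={sample.get('id', '?')}) is missing fields: {missing}")

print(f"All {len(samples)} training samples contain the expected fields.")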

Note

We have already released the pre-trained model checkpoint, which can be used for further fine-tuning. The model is downloaded automatically when you specify it in the configuration:

YAML
checkpoint: rmanluo/GFM-RAG-8M
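
If you want to load the released weights in your own code, note that the fine-tuning script below resolves a non-local checkpoint value through gfmrag.utils.load_model_from_pretrained. The following is a minimal sketch that mirrors that call; treat it as illustrative, since the helper's exact behavior (for example, what the second return value contains) is defined by the gfmrag library.

Python
import torch

from gfmrag import utils

# Load the released pre-trained checkpoint by its Hugging Face id, mirroring how
# stage2_qa_finetune.py handles a non-local `train.checkpoint` value.
# The second return value is ignored here, as in the training script.
model, _ = utils.load_model_from_pretrained("rmanluo/GFM-RAG-8M")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
print(f"Loaded GFM-RAG model with {sum(p.numel() for p in model.parameters())} parameters")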

You need to create a configuration file for fine-tuning.

gfmrag/workflow/config/stage2_qa_finetune.yaml
hydra:
  run:
    dir: outputs/qa_finetune/${now:%Y-%m-%d}/${now:%H-%M-%S} # Output directory

defaults:
  - _self_
  - doc_ranker: idf_topk_ranker # The document ranker to use
  - text_emb_model: mpnet # The text embedding model to use

seed: 1024

datasets:
  _target_: gfmrag.datasets.QADataset # The QA dataset class
  cfgs:
    root: ./data # data root directory
    force_rebuild: False # Whether to force rebuild the dataset
    text_emb_model_cfgs: ${text_emb_model} # The text embedding model configuration
  train_names: # List of training dataset names
    - hotpotqa_train_example
  valid_names: # List of validation dataset names
    - hotpotqa_test
    - musique_test
    - 2wikimultihopqa_test
  init_datasets: True # Whether to pre-process all datasets in train_names and valid_names at startup
  feat_dim: 768 # Feature dimension for the embeddings, must be given if init_datasets is False
  max_datasets_in_memory: 10 # Number of datasets to load into memory at once
  data_loading_workers: 4 # Number of workers for data loading

# GFM model configuration
model:
  _target_: gfmrag.models.GNNRetriever
  entity_model:
    _target_: gfmrag.ultra.models.QueryNBFNet
    input_dim: 512
    hidden_dims: [512, 512, 512, 512, 512, 512]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes

# Loss configuration
task:
  strict_negative: yes
  metric:
    [mrr, hits@1, hits@2, hits@3, hits@5, hits@10, hits@20, hits@50, hits@100]
  losses:
    - name: ent_bce_loss
      loss:
        _target_: gfmrag.losses.BCELoss
        adversarial_temperature: 0.2
      cfg:
        weight: 0.3
        is_doc_loss: False
    - name: ent_pcr_loss
      loss:
        _target_: gfmrag.losses.ListCELoss
      cfg:
        weight: 0.7
        is_doc_loss: False

# Optimizer configuration
optimizer:
  _target_: torch.optim.AdamW
  lr: 5.0e-4

# Training configuration
train:
  batch_size: 8
  num_epoch: 20
  log_interval: 100
  batch_per_epoch: null
  save_best_only: yes
  save_pretrained: yes # Save the model for QA inference
  do_eval: yes
  timeout: 60 # Timeout in minutes for multi-GPU training
  init_entities_weight: True

  checkpoint: null

Details of the configuration parameters are explained in the GFM-RAG Fine-tuning Configuration page.

You can fine-tune the pre-trained GFM-RAG model on your dataset using the following command:

gfmrag/workflow/stage2_qa_finetune.py
import logging
import os
from itertools import islice

import hydra
import numpy as np
import torch
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from torch import distributed as dist
from torch import nn
from torch.nn import functional as F  # noqa:N812
from torch.utils import data as torch_data
from tqdm import tqdm

from gfmrag import utils
from gfmrag.datasets import QADataset
from gfmrag.ultra import query_utils
from gfmrag.utils import GraphDatasetLoader

# A logger for this file
logger = logging.getLogger(__name__)

separator = ">" * 30
line = "-" * 30


def create_qa_dataloader(
    dataset: dict[str, QADataset],
    batch_size: int,
    world_size: int,
    rank: int,
    is_train: bool = True,
    shuffle: bool = True,
) -> dict:
    """
    Create a dataloader for the QA dataset.
    """
    data_name = dataset["data_name"]
    qa_data = dataset["data"]
    train_data, valid_data = qa_data._data
    data = train_data if is_train else valid_data

    sampler = torch_data.DistributedSampler(
        data,
        num_replicas=world_size,
        rank=rank,
        shuffle=shuffle,
    )
    data_loader = torch_data.DataLoader(
        data,
        batch_size=batch_size,
        sampler=sampler,
    )

    # Return data
    return {
        "data_name": data_name,
        "data_loader": data_loader,
        "graph": qa_data.kg,
        "ent2docs": qa_data.ent2docs,
    }


def train_and_validate(
    cfg: DictConfig,
    output_dir: str,
    model: nn.Module,
    train_dataset_loader: GraphDatasetLoader,
    valid_dataset_loader: GraphDatasetLoader,
    device: torch.device,
    batch_per_epoch: int | None = None,
) -> None:
    if cfg.train.num_epoch == 0:
        return

    world_size = utils.get_world_size()
    rank = utils.get_rank()

    optimizer = instantiate(cfg.optimizer, model.parameters())
    start_epoch = 0
    # Load optimizer state and epoch if exists
    if "checkpoint" in cfg.train and cfg.train.checkpoint is not None:
        if os.path.exists(cfg.train.checkpoint):
            state = torch.load(
                cfg.train.checkpoint, map_location="cpu", weights_only=True
            )
            if "optimizer" in state:
                optimizer.load_state_dict(state["optimizer"])
            else:
                logger.warning(
                    f"Optimizer state not found in {cfg.train.checkpoint}, using default optimizer."
                )
            if "epoch" in state:
                start_epoch = state["epoch"]
                logger.warning(f"Resuming training from epoch {start_epoch}.")
        else:
            logger.warning(
                f"Checkpoint {cfg.train.checkpoint} does not exist, using default optimizer."
            )

    # Initialize Losses
    loss_fn_list = []
    has_doc_loss = False
    for loss_cfg in cfg.task.losses:
        loss_fn = instantiate(loss_cfg.loss)
        if loss_cfg.cfg.is_doc_loss:
            has_doc_loss = True
        loss_fn_list.append(
            {
                "name": loss_cfg.name,
                "loss_fn": loss_fn,
                **loss_cfg.cfg,
            }
        )

    if world_size > 1:
        parallel_model = nn.parallel.DistributedDataParallel(model, device_ids=[device])
    else:
        parallel_model = model

    best_result = float("-inf")
    best_epoch = -1

    batch_id = 0
    for i in range(start_epoch, cfg.train.num_epoch):
        epoch = i + 1
        parallel_model.train()

        if utils.get_rank() == 0:
            logger.info(separator)
            logger.info(f"Epoch {epoch} begin")

        losses: dict[str, list] = {loss_dict["name"]: [] for loss_dict in loss_fn_list}
        losses["loss"] = []
        train_dataset_loader.set_epoch(
            epoch
        )  # Make sure the datasets order is the same across all processes
        for train_dataset in train_dataset_loader:
            train_dataset = create_qa_dataloader(
                train_dataset,
                cfg.train.batch_size,
                world_size,
                rank,
                is_train=True,
                shuffle=True,
            )
            train_loader = train_dataset["data_loader"]
            train_loader.sampler.set_epoch(epoch)
            data_name = train_dataset["data_name"]
            graph = train_dataset["graph"].to(device)
            ent2docs = train_dataset["ent2docs"].to(device)
            entities_weight = None
            if cfg.train.init_entities_weight:
                entities_weight = utils.get_entities_weight(ent2docs)
            batch_per_epoch = batch_per_epoch or len(train_loader)
            for batch in tqdm(
                islice(train_loader, batch_per_epoch),
                desc=f"Training Batches: {data_name}: {epoch}",
                total=batch_per_epoch,
                disable=not utils.is_main_process(),
            ):
                batch = query_utils.cuda(batch, device=device)
                pred = parallel_model(graph, batch, entities_weight=entities_weight)
                target = batch["supporting_entities_masks"]  # supporting_entities_mask

                if has_doc_loss:
                    doc_pred = torch.sparse.mm(pred, ent2docs)
                    doc_target = batch["supporting_docs_masks"]  # supporting_docs_mask

                loss = 0
                tmp_losses = {}
                for loss_dict in loss_fn_list:
                    loss_fn = loss_dict["loss_fn"]
                    weight = loss_dict["weight"]
                    if loss_dict["is_doc_loss"]:
                        single_loss = loss_fn(doc_pred, doc_target)
                    else:
                        single_loss = loss_fn(pred, target)
                    tmp_losses[loss_dict["name"]] = single_loss.item()
                    loss += weight * single_loss
                tmp_losses["loss"] = loss.item()  # type: ignore

                loss.backward()  # type: ignore
                optimizer.step()
                optimizer.zero_grad()

                for loss_log in tmp_losses:
                    losses[loss_log].append(tmp_losses[loss_log])

                if utils.get_rank() == 0 and batch_id % cfg.train.log_interval == 0:
                    logger.info(separator)
                    for loss_log in tmp_losses:
                        logger.info(f"{loss_log}: {tmp_losses[loss_log]:g}")
                batch_id += 1

        if utils.get_rank() == 0:
            logger.info(separator)
            logger.info(f"Epoch {epoch} end")
            logger.info(line)
            for loss_log in losses:
                logger.info(
                    f"Avg: {loss_log}: {sum(losses[loss_log]) / len(losses[loss_log]):g}"
                )

        utils.synchronize()

        if cfg.train.do_eval:
            if rank == 0:
                logger.info(separator)
                logger.info("Evaluate on valid")
            result = test(cfg, model, valid_dataset_loader, device=device)
        else:
            result = float("inf")
            best_result = float("-inf")
        if rank == 0:
            if result > best_result:
                best_result = result
                best_epoch = epoch
                logger.info("Save checkpoint to model_best.pth")
                state = {
                    "epoch": epoch,
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                }
                torch.save(state, os.path.join(output_dir, "model_best.pth"))
            if not cfg.train.save_best_only:
                logger.info(f"Save checkpoint to model_epoch_{epoch}.pth")
                state = {
                    "epoch": epoch,
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                }
                torch.save(state, os.path.join(output_dir, f"model_epoch_{epoch}.pth"))
            logger.info(f"Best mrr: {best_result:g} at epoch {best_epoch}")

    if rank == 0:
        logger.info("Load checkpoint from model_best.pth")
    utils.synchronize()
    state = torch.load(
        os.path.join(output_dir, "model_best.pth"),
        map_location=device,
        weights_only=True,
    )
    model.load_state_dict(state["model"])


@torch.no_grad()
def test(
    cfg: DictConfig,
    model: nn.Module,
    test_dataset_loader: GraphDatasetLoader,
    device: torch.device,
    return_metrics: bool = False,
) -> float | dict:
    world_size = utils.get_world_size()
    rank = utils.get_rank()

    # Process the test datasets sequentially
    all_metrics = {}
    all_mrr = []
    for dataset in test_dataset_loader:
        dataset = create_qa_dataloader(
            dataset,
            cfg.train.batch_size,
            world_size,
            rank,
            is_train=False,
            shuffle=False,
        )
        test_loader = dataset["data_loader"]
        test_loader.sampler.set_epoch(0)
        data_name = dataset["data_name"]
        graph = dataset["graph"].to(device)
        ent2docs = dataset["ent2docs"].to(device)

        model.eval()
        ent_preds = []
        ent_targets = []
        doc_preds = []
        doc_targets = []

        # Create doc retriever
        doc_ranker = instantiate(
            cfg.doc_ranker,
            ent2doc=ent2docs,
        )

        entities_weight = None
        if cfg.train.init_entities_weight:
            entities_weight = utils.get_entities_weight(ent2docs)

        for batch in tqdm(
            test_loader,
            desc=f"Testing {data_name}",
            disable=not utils.is_main_process(),
        ):
            batch = query_utils.cuda(batch, device=device)
            ent_pred = model(graph, batch, entities_weight=entities_weight)
            doc_pred = doc_ranker(ent_pred)  # Ent2docs mapping
            target_entities_mask = batch[
                "supporting_entities_masks"
            ]  # supporting_entities_mask
            target_docs_mask = batch["supporting_docs_masks"]  # supporting_docs_mask
            target_entities = target_entities_mask.bool()
            target_docs = target_docs_mask.bool()
            ent_ranking, target_ent_ranking = utils.batch_evaluate(
                ent_pred, target_entities
            )
            doc_ranking, target_doc_ranking = utils.batch_evaluate(
                doc_pred, target_docs
            )

            # answer set cardinality prediction
            ent_prob = F.sigmoid(ent_pred)
            num_pred = (ent_prob * (ent_prob > 0.5)).sum(dim=-1)
            num_target = target_entities_mask.sum(dim=-1)
            ent_preds.append((ent_ranking, num_pred))
            ent_targets.append((target_ent_ranking, num_target))

            # document set cardinality prediction
            doc_prob = F.sigmoid(doc_pred)
            num_pred = (doc_prob * (doc_prob > 0.5)).sum(dim=-1)
            num_target = target_docs_mask.sum(dim=-1)
            doc_preds.append((doc_ranking, num_pred))
            doc_targets.append((target_doc_ranking, num_target))

        ent_pred = query_utils.cat(ent_preds)
        ent_target = query_utils.cat(ent_targets)
        doc_pred = query_utils.cat(doc_preds)
        doc_target = query_utils.cat(doc_targets)

        ent_pred, ent_target = utils.gather_results(
            ent_pred, ent_target, rank, world_size, device
        )
        doc_pred, doc_target = utils.gather_results(
            doc_pred, doc_target, rank, world_size, device
        )
        ent_metrics = utils.evaluate(ent_pred, ent_target, cfg.task.metric)
        metrics = {}
        if rank == 0:
            doc_metrics = utils.evaluate(doc_pred, doc_target, cfg.task.metric)
            for key, value in ent_metrics.items():
                metrics[f"ent_{key}"] = value
            for key, value in doc_metrics.items():
                metrics[f"doc_{key}"] = value
            metrics["mrr"] = ent_metrics["mrr"]
            logger.info(f"{'-' * 15} Test on {data_name} {'-' * 15}")
            query_utils.print_metrics(metrics, logger)
        else:
            metrics["mrr"] = ent_metrics["mrr"]
        all_metrics[data_name] = metrics
        all_mrr.append(metrics["mrr"])
    utils.synchronize()
    all_avg_mrr = np.mean(all_mrr)
    return all_avg_mrr if not return_metrics else metrics


@hydra.main(config_path="config", config_name="stage2_qa_finetune", version_base=None)
def main(cfg: DictConfig) -> None:
    utils.init_distributed_mode(cfg.train.timeout)
    torch.manual_seed(cfg.seed + utils.get_rank())
    if utils.get_rank() == 0:
        output_dir = HydraConfig.get().runtime.output_dir
        logger.info(f"Config:\n {OmegaConf.to_yaml(cfg)}")
        logger.info(f"Current working directory: {os.getcwd()}")
        logger.info(f"Output directory: {output_dir}")
        output_dir_list = [output_dir]
    else:
        output_dir_list = [None]
    if utils.get_world_size() > 1:
        dist.broadcast_object_list(
            output_dir_list, src=0
        )  # Use the output dir from rank 0
    output_dir = output_dir_list[0]

    # Initialize the datasets in each process to make sure they are processed
    if cfg.datasets.init_datasets:
        rel_emb_dim_list = utils.init_multi_dataset(
            cfg, utils.get_world_size(), utils.get_rank()
        )
        rel_emb_dim = set(rel_emb_dim_list)
        assert len(rel_emb_dim) == 1, (
            "All datasets should have the same relation embedding dimension"
        )
    else:
        assert cfg.datasets.feat_dim is not None, (
            "If datasets.init_datasets is False, cfg.datasets.feat_dim must be set"
        )
        rel_emb_dim = {cfg.datasets.feat_dim}
    if utils.get_rank() == 0:
        logger.info(
            f"Datasets {cfg.datasets.train_names} and {cfg.datasets.valid_names} initialized"
        )

    device = utils.get_device()
    model = instantiate(cfg.model, rel_emb_dim=rel_emb_dim.pop())

    if "checkpoint" in cfg.train and cfg.train.checkpoint is not None:
        if os.path.exists(cfg.train.checkpoint):
            state = torch.load(
                cfg.train.checkpoint, map_location="cpu", weights_only=True
            )
            model.load_state_dict(state["model"])
        # Otherwise, try to load the model from a remote repository (e.g., the Hugging Face Hub)
        else:
            model, _ = utils.load_model_from_pretrained(cfg.train.checkpoint)

    model = model.to(device)
    if utils.get_rank() == 0:
        num_params = sum(p.numel() for p in model.parameters())
        logger.info(line)
        logger.info(f"Number of parameters: {num_params}")

    train_dataset_loader = GraphDatasetLoader(
        cfg.datasets,
        cfg.datasets.train_names,
        max_datasets_in_memory=cfg.datasets.max_datasets_in_memory,
        data_loading_workers=cfg.datasets.data_loading_workers,
    )
    valid_dataset_loader = GraphDatasetLoader(
        cfg.datasets,
        cfg.datasets.valid_names,
        shuffle=False,
        max_datasets_in_memory=cfg.datasets.max_datasets_in_memory,
        data_loading_workers=cfg.datasets.data_loading_workers,
    )

    train_and_validate(
        cfg,
        output_dir,
        model,
        train_dataset_loader,
        valid_dataset_loader,
        device=device,
        batch_per_epoch=cfg.train.batch_per_epoch,
    )

    if cfg.train.do_eval:
        if utils.get_rank() == 0:
            logger.info(separator)
            logger.info("Evaluate on valid")
        test(cfg, model, valid_dataset_loader, device=device)

    # Save the model into the format for QA inference
    if (
        utils.is_main_process()
        and cfg.train.save_pretrained
        and cfg.train.num_epoch > 0
    ):
        pre_trained_dir = os.path.join(output_dir, "pretrained")
        utils.save_model_to_pretrained(model, cfg, pre_trained_dir)

    # Shutdown the dataset loaders
    train_dataset_loader.shutdown()
    valid_dataset_loader.shutdown()

    utils.synchronize()
    utils.cleanup()


if __name__ == "__main__":
    main()

Bash
python -m gfmrag.workflow.stage2_qa_finetune
# Multi-GPU training
torchrun --nproc_per_node=4 -m gfmrag.workflow.stage2_qa_finetune
# Multi-node Multi-GPU training
torchrun --nproc_per_node=4 --nnodes=2 -m gfmrag.workflow.stage2_qa_finetune

You can override the configuration like this:

Bash
python -m gfmrag.workflow.stage2_qa_finetune train.batch_size=4

GFM Pre-training

During pre-training, the GFM model samples triples from the KG-index (kg.txt) to construct synthetic queries and target entities for training.

Tip

Pre-training is only recommended when you want to train the model from scratch or when you have a large amount of unlabeled data.

Tip

It is recommended to fine-tune the model after pre-training so that it learns to understand user queries and retrieve relevant documents.

An example of the KG-index:

Text Only
fred gehrke,was,american football player
fred gehrke,was,executive
fred gehrke,played for,cleveland   los angeles rams
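
To make the pre-training objective more concrete, the sketch below parses kg.txt and turns a sampled triple into a synthetic (head, relation) → tail query. It only illustrates the idea described above; the actual sampling and negative-sampling logic is implemented inside the gfmrag KG dataset and training code, and the file path is a placeholder.

Python
import random

# Illustrative only: turn KG-index triples into synthetic link-prediction queries,
# roughly matching the pre-training objective described above.
def load_triples(kg_path: str) -> list[tuple[str, str, str]]:
    triples = []
    with open(kg_path) as f:
        for line in f:
            parts = line.rstrip("\n").split(",", maxsplit=2)
            if len(parts) == 3:
                triples.append((parts[0], parts[1], parts[2]))
    return triples

triples = load_triples("data/data_name/processed/stage1/kg.txt")  # placeholder path
head, relation, tail = random.choice(triples)
print(f"Synthetic query: ({head}, {relation}, ?) -> target entity: {tail}")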

You need to create a configuration file for pre-training.

gfmrag/workflow/config/stage2_kg_pretrain.yaml
hydra:
  run:
    dir: outputs/kg_pretrain/${now:%Y-%m-%d}/${now:%H-%M-%S} # Output directory

defaults:
  - _self_
  - text_emb_model: mpnet # The text embedding model to use

seed: 1024

datasets:
  _target_: gfmrag.datasets.KGDataset # The KG dataset class
  cfgs:
    root: ./data # data root directory
    force_rebuild: False # Whether to force rebuild the dataset
    text_emb_model_cfgs: ${text_emb_model} # The text embedding model configuration
  train_names: # List of training dataset names
    - hotpotqa_train_example
  valid_names: []
  init_datasets: True # Whether to pre-process all datasets in train_names and valid_names at startup
  feat_dim: 768 # Feature dimension for the embeddings, must be given if init_datasets is False
  max_datasets_in_memory: 10 # Number of datasets to load into memory at once
  data_loading_workers: 4 # Number of workers for data loading

# GFM model configuration
model:
  _target_: gfmrag.models.QueryGNN
  entity_model:
    _target_: gfmrag.ultra.models.EntityNBFNet
    input_dim: 512
    hidden_dims: [512, 512, 512, 512, 512, 512]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes

# Loss configuration
task:
  num_negative: 256
  strict_negative: yes
  adversarial_temperature: 1
  metric: [mr, mrr, hits@1, hits@3, hits@10]

optimizer:
  _target_: torch.optim.AdamW
  lr: 5.0e-4

# Training configuration
train:
  batch_size: 8
  num_epoch: 10
  log_interval: 100
  fast_test: 500
  save_best_only: no
  save_pretrained: no # Save the model for QA inference
  batch_per_epoch: null
  timeout: 60 # Timeout in minutes for multi-GPU training
  # Checkpoint configuration
  checkpoint: null

Details of the configuration parameters are explained in the GFM-RAG Pre-training Config page.

You can pre-train the GFM-RAG model on your dataset using the following command:

The full pre-training script is available at gfmrag/workflow/stage2_kg_pretrain.py in the repository.

Bash
python -m gfmrag.workflow.stage2_kg_pretrain
# Multi-GPU training
torchrun --nproc_per_node=4 -m gfmrag.workflow.stage2_kg_pretrain
# Multi-node Multi-GPU training
torchrun --nproc_per_node=4 --nnodes=2 -m gfmrag.workflow.stage2_kg_pretrain

You can override the configuration like this:

Bash
python -m gfmrag.workflow.stage2_kg_pretrain train.batch_size=4