GFM-RAG Training

You can further fine-tune the pre-trained GFM-RAG model on your own dataset to improve the performance of the model on your specific domain.

Data Preparation

Please follow the instructions in the Data Preparation guide to prepare your dataset in the following structure. Make sure train.json is present, as it is required for fine-tuning.

Text Only
data_name/
├── raw/
│   ├── dataset_corpus.json
│   ├── train.json
│   └── test.json # (optional)
└── processed/
    └── stage1/
        ├── kg.txt
        ├── document2entities.json
        ├── train.json
        └── test.json # (optional)
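
If you want to sanity-check this layout before launching fine-tuning, a minimal sketch along the following lines can help. The check_dataset_layout helper and the my_dataset name are hypothetical; only the files listed above are assumed, and test.json remains optional.

Python
from pathlib import Path


def check_dataset_layout(root: str, data_name: str) -> None:
    """Hypothetical helper: verify the dataset layout shown above."""
    base = Path(root) / data_name
    required = [
        base / "raw" / "dataset_corpus.json",
        base / "raw" / "train.json",  # required for fine-tuning
        base / "processed" / "stage1" / "kg.txt",
        base / "processed" / "stage1" / "document2entities.json",
        base / "processed" / "stage1" / "train.json",
    ]
    missing = [str(p) for p in required if not p.exists()]
    if missing:
        raise FileNotFoundError(f"Missing dataset files: {missing}")


check_dataset_layout("./data", "my_dataset")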

GFM Fine-tuning

During fine-tuning, the GFM model is trained on the query-document pairs in train.json from the labeled dataset to learn the complex relationships needed for retrieval.

Fine-tuning on your own dataset improves the model's performance on your specific domain.

An example of the training data:

JSON
[
    {
        "id": "5abc553a554299700f9d7871",
        "question": "Kyle Ezell is a professor at what School of Architecture building at Ohio State?",
        "answer": "Knowlton Hall",
        "supporting_facts": [
            "Knowlton Hall",
            "Kyle Ezell"
        ],
        "question_entities": [
            "kyle ezell",
            "architectural association school of architecture",
            "ohio state"
        ],
        "supporting_entities": [
            "10 million donation",
            "2004",
            "architecture",
            "austin e  knowlton",
            "austin e  knowlton school of architecture",
            "bachelor s in architectural engineering",
            "city and regional planning",
            "columbus  ohio  united states",
            "ives hall",
            "july 2002",
            "knowlton hall",
            "ksa",
        ]
    },
    ...
]
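
Before training, it can be worth checking that every entry in train.json carries the fields used in this example. A small sketch (the validate_train_file name and the file path are hypothetical; adjust REQUIRED_KEYS if your schema differs):

Python
import json

# Fields shown in the example entry above.
REQUIRED_KEYS = {
    "id",
    "question",
    "answer",
    "supporting_facts",
    "question_entities",
    "supporting_entities",
}


def validate_train_file(path: str) -> None:
    with open(path) as f:
        samples = json.load(f)
    for i, sample in enumerate(samples):
        missing = REQUIRED_KEYS - sample.keys()
        if missing:
            raise ValueError(f"Sample {i} ({sample.get('id')}) is missing {missing}")


validate_train_file("./data/my_dataset/raw/train.json")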

You need to create a configuration file for fine-tuning.

gfmrag/workflow/config/stage2_qa_finetune.yaml
hydra:
  run:
    dir: outputs/qa_finetune/${now:%Y-%m-%d}/${now:%H-%M-%S} # Output directory

defaults:
  - _self_
  - doc_ranker: idf_topk_ranker # The document ranker to use
  - text_emb_model: mpnet # The text embedding model to use

seed: 1024

datasets:
  _target_: gfmrag.datasets.QADataset # The QA dataset class
  cfgs:
    root: ./data # data root directory
    force_rebuild: False # Whether to force rebuild the dataset
    text_emb_model_cfgs: ${text_emb_model} # The text embedding model configuration
  train_names: # List of training dataset names
    - hotpotqa_train_example
  valid_names: # List of validation dataset names
    - hotpotqa_test
    - musique_test
    - 2wikimultihopqa_test

# GFM model configuration
model:
  _target_: gfmrag.models.GNNRetriever
  entity_model:
    _target_: gfmrag.ultra.models.QueryNBFNet
    input_dim: 512
    hidden_dims: [512, 512, 512, 512, 512, 512]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes

# Loss configuration
task:
  strict_negative: yes
  metric:
    [mrr, hits@1, hits@2, hits@3, hits@5, hits@10, hits@20, hits@50, hits@100]
  losses:
    - name: ent_bce_loss
      loss:
        _target_: gfmrag.losses.BCELoss
        adversarial_temperature: 0.2
      cfg:
        weight: 0.3
        is_doc_loss: False
    - name: ent_pcr_loss
      loss:
        _target_: gfmrag.losses.ListCELoss
      cfg:
        weight: 0.7
        is_doc_loss: False

# Optimizer configuration
optimizer:
  _target_: torch.optim.AdamW
  lr: 5.0e-4

# Training configuration
train:
  batch_size: 8
  num_epoch: 20
  log_interval: 100
  batch_per_epoch: null
  save_best_only: yes
  save_pretrained: yes # Save the model for QA inference
  do_eval: yes
  timeout: 60 # timeout minutes for multi-gpu training
  init_entities_weight: True

  checkpoint: null

Details of the configuration parameters are explained in the GFM-RAG Fine-tuning Configuration page.
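
If you want to inspect the fully resolved configuration after your overrides without launching a training run, Hydra's compose API can be used. This is a minimal sketch, assuming it is executed from the repository root so the relative config path resolves; the my_dataset_train name is hypothetical.

Python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Resolve the fine-tuning config with a couple of overrides and print it,
# without starting training. Run from the repository root.
with initialize(config_path="gfmrag/workflow/config", version_base=None):
    cfg = compose(
        config_name="stage2_qa_finetune",
        overrides=[
            "train.batch_size=4",
            "datasets.train_names=[my_dataset_train]",
        ],
    )
print(OmegaConf.to_yaml(cfg))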

You can fine-tune the pre-trained GFM-RAG model on your dataset using the following command:

gfmrag/workflow/stage2_qa_finetune.py
import logging
import os
from itertools import islice

import hydra
import numpy as np
import torch
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from torch import nn
from torch.nn import functional as F  # noqa:N812
from torch.utils import data as torch_data
from tqdm import tqdm

from gfmrag import utils
from gfmrag.ultra import query_utils

# A logger for this file
logger = logging.getLogger(__name__)

separator = ">" * 30
line = "-" * 30


def train_and_validate(
    cfg: DictConfig,
    output_dir: str,
    model: nn.Module,
    train_datasets: dict,
    valid_datasets: dict,
    device: torch.device,
    batch_per_epoch: int | None = None,
) -> None:
    if cfg.train.num_epoch == 0:
        return

    world_size = utils.get_world_size()
    rank = utils.get_rank()

    # Create dataloader for each dataset
    train_dataloader_dict = {}
    for data_name, dataset in train_datasets.items():
        train_data = dataset["data"]
        sampler = torch_data.DistributedSampler(train_data, world_size, rank)
        train_loader = torch_data.DataLoader(
            train_data, cfg.train.batch_size, sampler=sampler
        )
        train_dataloader_dict[data_name] = train_loader
    data_name_list = list(train_dataloader_dict.keys())

    batch_per_epoch = batch_per_epoch or len(train_loader)

    optimizer = instantiate(cfg.optimizer, model.parameters())

    num_params = sum(p.numel() for p in model.parameters())
    logger.warning(line)
    logger.warning(f"Number of parameters: {num_params}")

    # Initialize Losses
    loss_fn_list = []
    has_doc_loss = False
    for loss_cfg in cfg.task.losses:
        loss_fn = instantiate(loss_cfg.loss)
        if loss_cfg.cfg.is_doc_loss:
            has_doc_loss = True
        loss_fn_list.append(
            {
                "name": loss_cfg.name,
                "loss_fn": loss_fn,
                **loss_cfg.cfg,
            }
        )

    if world_size > 1:
        parallel_model = nn.parallel.DistributedDataParallel(model, device_ids=[device])
    else:
        parallel_model = model

    best_result = float("-inf")
    best_epoch = -1

    batch_id = 0
    for i in range(0, cfg.train.num_epoch):
        epoch = i + 1
        parallel_model.train()

        if utils.get_rank() == 0:
            logger.warning(separator)
            logger.warning("Epoch %d begin" % epoch)

        losses: dict[str, list] = {loss_dict["name"]: [] for loss_dict in loss_fn_list}
        losses["loss"] = []

        for dataloader in train_dataloader_dict.values():
            dataloader.sampler.set_epoch(epoch)
        # np.random.seed(epoch)  # TODO: should we use the same dataloader for all processes?
        shuffled_data_name_list = np.random.permutation(
            data_name_list
        )  # Shuffle the dataloaders
        for data_name in shuffled_data_name_list:
            train_loader = train_dataloader_dict[data_name]
            graph = train_datasets[data_name]["graph"]
            ent2docs = train_datasets[data_name]["ent2docs"]
            entities_weight = None
            if cfg.train.init_entities_weight:
                entities_weight = utils.get_entities_weight(ent2docs)
            for batch in tqdm(
                islice(train_loader, batch_per_epoch),
                desc=f"Training Batches: {epoch}",
                total=batch_per_epoch,
                disable=not utils.is_main_process(),
            ):
                batch = query_utils.cuda(batch, device=device)
                pred = parallel_model(graph, batch, entities_weight=entities_weight)
                target = batch["supporting_entities_masks"]  # supporting_entities_mask

                if has_doc_loss:
                    doc_pred = torch.sparse.mm(pred, ent2docs)
                    doc_target = batch["supporting_docs_masks"]  # supporting_docs_mask

                loss = 0
                tmp_losses = {}
                for loss_dict in loss_fn_list:
                    loss_fn = loss_dict["loss_fn"]
                    weight = loss_dict["weight"]
                    if loss_dict["is_doc_loss"]:
                        single_loss = loss_fn(doc_pred, doc_target)
                    else:
                        single_loss = loss_fn(pred, target)
                    tmp_losses[loss_dict["name"]] = single_loss.item()
                    loss += weight * single_loss
                tmp_losses["loss"] = loss.item()  # type: ignore

                loss.backward()  # type: ignore
                optimizer.step()
                optimizer.zero_grad()

                for loss_log in tmp_losses:
                    losses[loss_log].append(tmp_losses[loss_log])

                if utils.get_rank() == 0 and batch_id % cfg.train.log_interval == 0:
                    logger.warning(separator)
                    for loss_log in tmp_losses:
                        logger.warning(f"{loss_log}: {tmp_losses[loss_log]:g}")
                batch_id += 1

        if utils.get_rank() == 0:
            logger.warning(separator)
            logger.warning("Epoch %d end" % epoch)
            logger.warning(line)
            for loss_log in losses:
                logger.warning(
                    f"Avg: {loss_log}: {sum(losses[loss_log]) / len(losses[loss_log]):g}"
                )

        utils.synchronize()

        if cfg.train.do_eval:
            if rank == 0:
                logger.warning(separator)
                logger.warning("Evaluate on valid")
            result = test(cfg, model, valid_datasets, device=device)
        else:
            result = float("inf")
            best_result = float("-inf")
        if rank == 0:
            if result > best_result:
                best_result = result
                best_epoch = epoch
                logger.warning("Save checkpoint to model_best.pth")
                state = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                }
                torch.save(state, os.path.join(output_dir, "model_best.pth"))
            if not cfg.train.save_best_only:
                logger.warning("Save checkpoint to model_epoch_%d.pth" % epoch)
                state = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                }
                torch.save(
                    state, os.path.join(output_dir, "model_epoch_%d.pth" % epoch)
                )
            logger.warning(f"Best mrr: {best_result:g} at epoch {best_epoch}")

    if rank == 0:
        logger.warning("Load checkpoint from model_best.pth")
    state = torch.load(
        os.path.join(output_dir, "model_best.pth"),
        map_location=device,
        weights_only=False,
    )
    model.load_state_dict(state["model"])
    utils.synchronize()


@torch.no_grad()
def test(
    cfg: DictConfig,
    model: nn.Module,
    test_datasets: dict,
    device: torch.device,
    return_metrics: bool = False,
) -> float | dict:
    world_size = utils.get_world_size()
    rank = utils.get_rank()

    # process sequentially of test datasets
    all_metrics = {}
    all_mrr = []
    for data_name, q_data in test_datasets.items():
        test_data = q_data["data"]
        graph = q_data["graph"]
        ent2docs = q_data["ent2docs"]
        sampler = torch_data.DistributedSampler(test_data, world_size, rank)
        test_loader = torch_data.DataLoader(
            test_data, cfg.train.batch_size, sampler=sampler
        )

        model.eval()
        ent_preds = []
        ent_targets = []
        doc_preds = []
        doc_targets = []

        # Create doc retriever
        doc_ranker = instantiate(
            cfg.doc_ranker,
            ent2doc=ent2docs,
        )

        entities_weight = None
        if cfg.train.init_entities_weight:
            entities_weight = utils.get_entities_weight(ent2docs)

        for batch in tqdm(
            test_loader,
            desc=f"Testing {data_name}",
            disable=not utils.is_main_process(),
        ):
            batch = query_utils.cuda(batch, device=device)
            ent_pred = model(graph, batch, entities_weight=entities_weight)
            doc_pred = doc_ranker(ent_pred)  # Ent2docs mapping
            target_entities_mask = batch[
                "supporting_entities_masks"
            ]  # supporting_entities_mask
            target_docs_mask = batch["supporting_docs_masks"]  # supporting_docs_mask
            target_entities = target_entities_mask.bool()
            target_docs = target_docs_mask.bool()
            ent_ranking, target_ent_ranking = utils.batch_evaluate(
                ent_pred, target_entities
            )
            doc_ranking, target_doc_ranking = utils.batch_evaluate(
                doc_pred, target_docs
            )

            # answer set cardinality prediction
            ent_prob = F.sigmoid(ent_pred)
            num_pred = (ent_prob * (ent_prob > 0.5)).sum(dim=-1)
            num_target = target_entities_mask.sum(dim=-1)
            ent_preds.append((ent_ranking, num_pred))
            ent_targets.append((target_ent_ranking, num_target))

            # document set cardinality prediction
            doc_prob = F.sigmoid(doc_pred)
            num_pred = (doc_prob * (doc_prob > 0.5)).sum(dim=-1)
            num_target = target_docs_mask.sum(dim=-1)
            doc_preds.append((doc_ranking, num_pred))
            doc_targets.append((target_doc_ranking, num_target))

        ent_pred = query_utils.cat(ent_preds)
        ent_target = query_utils.cat(ent_targets)
        doc_pred = query_utils.cat(doc_preds)
        doc_target = query_utils.cat(doc_targets)

        ent_pred, ent_target = utils.gather_results(
            ent_pred, ent_target, rank, world_size, device
        )
        doc_pred, doc_target = utils.gather_results(
            doc_pred, doc_target, rank, world_size, device
        )
        ent_metrics = utils.evaluate(ent_pred, ent_target, cfg.task.metric)
        metrics = {}
        if rank == 0:
            doc_metrics = utils.evaluate(doc_pred, doc_target, cfg.task.metric)
            for key, value in ent_metrics.items():
                metrics[f"ent_{key}"] = value
            for key, value in doc_metrics.items():
                metrics[f"doc_{key}"] = value
            metrics["mrr"] = ent_metrics["mrr"]
            logger.warning(f"{'-' * 15} Test on {data_name} {'-' * 15}")
            query_utils.print_metrics(metrics, logger)
        else:
            metrics["mrr"] = ent_metrics["mrr"]
        all_metrics[data_name] = metrics
        all_mrr.append(metrics["mrr"])
    utils.synchronize()
    all_avg_mrr = np.mean(all_mrr)
    return all_avg_mrr if not return_metrics else all_metrics


@hydra.main(config_path="config", config_name="stage2_qa_finetune", version_base=None)
def main(cfg: DictConfig) -> None:
    output_dir = HydraConfig.get().runtime.output_dir
    utils.init_distributed_mode(cfg.train.timeout)
    torch.manual_seed(cfg.seed + utils.get_rank())
    if utils.get_rank() == 0:
        logger.info(f"Config:\n {OmegaConf.to_yaml(cfg)}")
        logger.info(f"Current working directory: {os.getcwd()}")
        logger.info(f"Output directory: {output_dir}")

    qa_datasets = utils.get_multi_dataset(cfg)
    device = utils.get_device()

    rel_emb_dim = {qa_data.rel_emb_dim for qa_data in qa_datasets.values()}
    assert (
        len(rel_emb_dim) == 1
    ), "All datasets should have the same relation embedding dimension"

    model = instantiate(cfg.model, rel_emb_dim=rel_emb_dim.pop())

    train_datasets = {}
    valid_datasets = {}

    for data_name, qa_data in qa_datasets.items():
        if data_name not in cfg.datasets.train_names + cfg.datasets.valid_names:
            raise ValueError(f"Unknown data name: {data_name}")

        train_data, valid_data = qa_data._data
        graph = qa_data.kg.to(device)
        ent2docs = qa_data.ent2docs.to(device)
        if data_name in cfg.datasets.train_names:
            train_datasets[data_name] = {
                "data": train_data,
                "graph": graph,
                "ent2docs": ent2docs,
            }
        if data_name in cfg.datasets.valid_names:
            valid_datasets[data_name] = {
                "data": valid_data,
                "graph": graph,
                "ent2docs": ent2docs,
            }

    if "checkpoint" in cfg.train and cfg.train.checkpoint is not None:
        if os.path.exists(cfg.train.checkpoint):
            state = torch.load(cfg.train.checkpoint, map_location="cpu")
            model.load_state_dict(state["model"])
        # Try to load the model from the remote dictionary
        else:
            model, _ = utils.load_model_from_pretrained(cfg.train.checkpoint)

    model = model.to(device)
    train_and_validate(
        cfg,
        output_dir,
        model,
        train_datasets,
        valid_datasets,
        device=device,
        batch_per_epoch=cfg.train.batch_per_epoch,
    )

    if cfg.train.do_eval:
        if utils.get_rank() == 0:
            logger.warning(separator)
            logger.warning("Evaluate on valid")
        test(cfg, model, valid_datasets, device=device)

    # Save the model into the format for QA inference
    if (
        utils.is_main_process()
        and cfg.train.save_pretrained
        and cfg.train.num_epoch > 0
    ):
        pre_trained_dir = os.path.join(output_dir, "pretrained")
        utils.save_model_to_pretrained(model, cfg, pre_trained_dir)

    utils.synchronize()
    utils.cleanup()


if __name__ == "__main__":
    main()

Bash
python -m gfmrag.workflow.stage2_qa_finetune
# Multi-GPU training
torchrun --nproc_per_node=4 -m gfmrag.workflow.stage2_qa_finetune
# Multi-node Multi-GPU training
torchrun --nproc_per_node=4 --nnodes=2 -m gfmrag.workflow.stage2_qa_finetune

You can overwrite the configuration like this:

Bash
python -m gfmrag.workflow.stage2_qa_finetune train.batch_size=4

GFM Pre-training

During pre-training, the GFM model samples triples from the KG-index kg.txt to construct synthetic queries and target entities for training.

Tip

Pre-training is recommended only when you want to train the model from scratch or when you have a large amount of unlabeled data.

Tip

It is recommended to fine-tune the model after pre-training so that it learns to understand user queries and retrieve the relevant documents.

An example of the KG-index:

Text Only
fred gehrke,was,american football player
fred gehrke,was,executive
fred gehrke,played for,cleveland   los angeles rams
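
Conceptually, each sampled triple provides a synthetic training example: the (head, relation) pair plays the role of a query and the tail entity is the target to be predicted. The sketch below reads kg.txt in that spirit; it is illustrative only (the actual sampling happens inside the training workflow) and assumes, as in the sample above, that entities and relations contain no commas.

Python
def load_triples(path: str) -> list[tuple[str, str, str]]:
    """Read kg.txt lines of the form head,relation,tail."""
    triples = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Assumes head, relation and tail themselves contain no commas.
            head, relation, tail = line.split(",", maxsplit=2)
            triples.append((head, relation, tail))
    return triples


triples = load_triples("./data/my_dataset/processed/stage1/kg.txt")
# Each triple yields a synthetic (query, target) pair for pre-training.
synthetic_examples = [((head, rel), tail) for head, rel, tail in triples]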

You need to create a configuration file for pre-training.

gfmrag/workflow/config/stage2_kg_pretrain.yaml
hydra:
  run:
    dir: outputs/kg_pretrain/${now:%Y-%m-%d}/${now:%H-%M-%S} # Output directory

defaults:
  - _self_
  - text_emb_model: mpnet # The text embedding model to use

seed: 1024

datasets:
  _target_: gfmrag.datasets.KGDataset # The KG dataset class
  cfgs:
    root: ./data # data root directory
    force_rebuild: False # Whether to force rebuild the dataset
    text_emb_model_cfgs: ${text_emb_model} # The text embedding model configuration
  train_names: # List of training dataset names
    - hotpotqa_train_example
  valid_names: []

# GFM model configuration
model:
  _target_: gfmrag.models.QueryGNN
  entity_model:
    _target_: gfmrag.ultra.models.EntityNBFNet
    input_dim: 512
    hidden_dims: [512, 512, 512, 512, 512, 512]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes

# Loss configuration
task:
  num_negative: 256
  strict_negative: yes
  adversarial_temperature: 1
  metric: [mr, mrr, hits@1, hits@3, hits@10]

optimizer:
  _target_: torch.optim.AdamW
  lr: 5.0e-4

# Training configuration
train:
  batch_size: 8
  num_epoch: 10
  log_interval: 100
  fast_test: 500
  save_best_only: no
  save_pretrained: no # Save the model for QA inference
  batch_per_epoch: null
  timeout: 60 # timeout minutes for multi-gpu training
  # Checkpoint configuration
  checkpoint: null

Details of the configuration parameters are explained in the GFM-RAG Pre-training Config page.

You can pre-train the GFM-RAG model on your dataset using the following command:

gfmrag/workflow/stage2_kg_pretrain.py

Bash
python -m gfmrag.workflow.stage2_kg_pretrain
# Multi-GPU training
torchrun --nproc_per_node=4 -m gfmrag.workflow.stage2_kg_pretrain
# Multi-node Multi-GPU training
torchrun --nproc_per_node=4 --nnodes=2 -m gfmrag.workflow.stage2_kg_pretrain

You can overwrite the configuration like this:

Bash
python -m gfmrag.workflow.stage2_kg_pretrain train.batch_size=4