
# GFM-RAG SFT Training Config

This page documents `gfmrag/workflow/config/gfm_rag/sft_training.yaml`.

## Purpose

This preset is consumed by `python -m gfmrag.workflow.sft_training` for supervised fine-tuning (SFT) and retrieval evaluation in the original GFM-RAG model family.
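Because the preset lives on Hydra's package search path (`pkg://gfmrag.workflow.config`), a project-local config can compose it and override individual fields. The snippet below is a hypothetical user config, not part of the repository; the override values are purely illustrative:

```yaml
# my_sft_config.yaml -- hypothetical user config composing this preset
defaults:
  - gfm_rag/sft_training  # pull in the full preset documented on this page
  - _self_                # apply the overrides below last

seed: 42                  # override the preset's seed (1024)
optimizer:
  lr: 1.0e-4              # lower learning rate than the preset's 5.0e-4
```

Because `_self_` comes after the preset in the defaults list, the values in this file win over the preset's values during composition.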

## `gfmrag/workflow/config/gfm_rag/sft_training.yaml`

```yaml
hydra:
  run:
    dir: outputs/qa_finetune/${now:%Y-%m-%d}/${now:%H-%M-%S} # Output directory
  searchpath:
    - pkg://gfmrag.workflow.config


defaults:
  - _self_
  - doc_ranker: idf_topk_ranker # The document ranker to use
  - text_emb_model: mpnet # The text embedding model to use
  - wandb: default # Weights & Biases configuration

# Misc
seed: 1024
timeout: 60 # Timeout (in minutes) for multi-GPU training
save_pretrained: no # Save the model in pre-trained format
load_model_from_pretrained: null # Load a model in pre-trained format; overrides the model configuration below

datasets:
  _target_: gfmrag.graph_index_datasets.GraphIndexDatasetV1 # The QA dataset class
  cfgs:
    root: ./data # Data root directory
    force_reload: False # Whether to force-rebuild the dataset
    text_emb_model_cfgs: ${text_emb_model} # The text embedding model configuration
    target_type: entity # Only nodes of this type are used for graph index construction; must be one of the node type names in the graph
    use_node_feat: False # GFM-RAG v1 does not use node features
    use_edge_feat: False # GFM-RAG v1 does not use edge features
    use_relation_feat: True
    inverse_relation_feat: text # GFM-RAG v1 adds "inverse" to the original relation names to generate inverse relation features
  train_names: # List of training dataset names
    - hotpotqa_train_example
  valid_names: # List of validation dataset names
    - hotpotqa_test
    - musique_test
    - 2wikimultihopqa_test
  init_datasets: True # If True, pre-process all datasets in train_names and valid_names at startup
  feat_dim: 768 # Feature dimension of the embeddings; must be given if init_datasets is False
  max_datasets_in_memory: 10 # Number of datasets to load into memory at once
  data_loading_workers: 4 # Number of workers for data loading

# GFM model configuration
model:
  _target_: gfmrag.models.gfm_rag_v1.GNNRetriever
  init_nodes_weight: True # Whether to initialize node weights at the input
  init_nodes_type: document # The type of nodes used for weight calculation
  ranker: ${doc_ranker} # The document ranker to use
  entity_model:
    _target_: gfmrag.models.ultra.models.QueryNBFNet
    input_dim: 512
    hidden_dims: [512, 512, 512, 512, 512, 512]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes

# Loss configuration
losses:
  - name: bce_loss
    loss:
      _target_: gfmrag.losses.BCELoss
      adversarial_temperature: 0.2
    weight: 0.3
    target_node_type: entity # The type of nodes to apply the BCE loss to
  - name: pcr_loss
    loss:
      _target_: gfmrag.losses.ListCELoss
    weight: 0.7
    target_node_type: entity

# Optimizer configuration
optimizer:
  _target_: torch.optim.AdamW
  lr: 5.0e-4


# Training configuration
trainer:
  _target_: gfmrag.trainers.SFTTrainer
  args:
    _target_: gfmrag.trainers.TrainingArguments
    train_batch_size: 8
    num_epoch: 20
    logging_steps: 100
    max_steps_per_epoch: null
    resume_from_checkpoint: null
    do_train: true
    do_eval: true
    save_best_only: yes
    metric_for_best_model: document_mrr
  metrics: [mrr, hits@1, hits@2, hits@3, hits@5, hits@10, hits@20, hits@50, hits@100]
  target_types: [entity, document] # The target node types used for metric calculation
```

## Top-level Fields

| Parameter | Options | Note |
| --- | --- | --- |
| `hydra.run.dir` | `outputs/qa_finetune/<date>/<time>/` | Directory used by Hydra for runtime logs and outputs. |
| `defaults` | List of config groups | Pulls in `doc_ranker`, `text_emb_model`, and `wandb`. |
| `datasets` | Mapping | Selects indexed datasets and feature settings. |
| `model` | Mapping | Configures `gfmrag.models.gfm_rag_v1.GNNRetriever`. |
| `losses` | List | Defines one or more supervised losses. |
| `optimizer` | Mapping | Sets the optimizer type and learning rate. |
| `trainer` | Mapping | Configures `SFTTrainer`, evaluation metrics, and prediction behavior. |

## `datasets` Fields

The default dataset class is `GraphIndexDatasetV1`.

| Parameter | Options | Note |
| --- | --- | --- |
| `datasets.cfgs.root` | Any valid data root | Root directory containing indexed datasets. |
| `datasets.cfgs.target_type` | Node type string | Graph target node type. |
| `datasets.cfgs.use_node_feat` | `True`, `False` | Whether to use node features. |
| `datasets.cfgs.use_edge_feat` | `True`, `False` | Whether to use edge features. |
| `datasets.cfgs.use_relation_feat` | `True`, `False` | Whether to use relation features. |
| `datasets.train_names` | List of dataset names | Training split list. |
| `datasets.valid_names` | List of dataset names | Validation split list. |
| `datasets.max_datasets_in_memory` | Positive integer | Maximum number of datasets kept in memory. |
| `datasets.data_loading_workers` | Positive integer | Number of background loading workers. |
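As an illustration, a run that trains on a custom indexed corpus only needs to override the split lists. The sketch below is hypothetical: `my_corpus_train` is a placeholder name that would have to match an indexed dataset under `datasets.cfgs.root`:

```yaml
datasets:
  train_names:
    - my_corpus_train     # placeholder; must exist under datasets.cfgs.root
  valid_names:
    - hotpotqa_test       # validation sets from the preset can be kept as-is
  data_loading_workers: 8 # more loader workers for a larger training set
```

Note that Hydra/OmegaConf replaces a list wholesale on override, so `valid_names` here keeps only the splits you restate.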

## `model` Fields

The default model target is `gfmrag.models.gfm_rag_v1.GNNRetriever`.

| Parameter | Options | Note |
| --- | --- | --- |
| `model.init_nodes_weight` | `True`, `False` | Whether to initialize node weights at input. |
| `model.init_nodes_type` | Node type string | Node type used for initialization. |
| `model.ranker` | Mapping | Shared document ranker config. |
| `model.entity_model` | Mapping | Nested `QueryNBFNet` settings. |
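The nested `entity_model` block configures the GNN itself; each entry in `hidden_dims` corresponds to one message-passing layer. A shallower variant for quicker experiments could be sketched as follows (the layer count here is illustrative, not a recommended setting):

```yaml
model:
  entity_model:
    input_dim: 512
    hidden_dims: [512, 512, 512] # 3 layers instead of the preset's 6
```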

## `losses` Fields

The default preset includes two losses: `bce_loss` and `pcr_loss`. Each loss block defines:

| Parameter | Options | Note |
| --- | --- | --- |
| `losses[].name` | Any string | Human-readable loss name. |
| `losses[].loss._target_` | Loss class path | Concrete loss implementation. |
| `losses[].weight` | Float | Loss weight in the total objective. |
| `losses[].target_node_type` | Node type string | Target node type for that loss. |
| `losses[].is_distillation_loss` | `True`, `False` | Optional distillation flag. |
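Since `losses` is a plain list, additional objectives can be appended; remember that overriding a list replaces it entirely, so the default entries must be restated. The third entry below only illustrates the shape of a loss block with the optional `is_distillation_loss` flag: the class path `gfmrag.losses.KLDivLoss` is hypothetical and not guaranteed to exist in the library.

```yaml
losses:
  - name: bce_loss
    loss:
      _target_: gfmrag.losses.BCELoss
      adversarial_temperature: 0.2
    weight: 0.3
    target_node_type: entity
  - name: pcr_loss
    loss:
      _target_: gfmrag.losses.ListCELoss
    weight: 0.7
    target_node_type: entity
  - name: distill_loss                   # hypothetical extra objective
    loss:
      _target_: gfmrag.losses.KLDivLoss  # placeholder class path
    weight: 0.1
    target_node_type: entity
    is_distillation_loss: True           # optional flag from the table above
```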

## `trainer` Fields

| Parameter | Options | Note |
| --- | --- | --- |
| `trainer.args.train_batch_size` | Positive integer | Training batch size. |
| `trainer.args.num_epoch` | Positive integer | Number of epochs. |
| `trainer.args.do_train` | `True`, `False` | Whether to run training. |
| `trainer.args.do_eval` | `True`, `False` | Whether to run evaluation. |
| `trainer.args.save_best_only` | `True`, `False` | Whether to save only the best checkpoint. |
| `trainer.args.metric_for_best_model` | Metric name | Metric used to select the best checkpoint. |
| `trainer.metrics` | List of metric names | Evaluation metrics. |
| `trainer.target_types` | List of node types | Node types used in evaluation. |
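For a quick smoke test of the pipeline, the trainer arguments can be dialed down; the values below are illustrative, not recommended settings:

```yaml
trainer:
  args:
    train_batch_size: 2
    num_epoch: 1
    max_steps_per_epoch: 10 # cap each epoch instead of the preset's null (no cap)
    do_eval: false          # skip retrieval evaluation during the smoke run
```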