G-reasoner SFT Training Configuration

This page documents the supervised fine-tuning presets in gfmrag/workflow/config/gfm_reasoner/.

sft_training.yaml

This preset is used by python -m gfmrag.workflow.sft_training --config-name gfm_reasoner/sft_training.
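A launch can also be scripted. The sketch below builds the command shown above and appends Hydra command-line overrides in key=value form. The helper name build_sft_command and the override values are illustrative assumptions, not part of gfmrag.

```python
import sys


def build_sft_command(overrides=None, config_name="gfm_reasoner/sft_training"):
    """Return the argv list for launching SFT with optional Hydra overrides.

    build_sft_command is a hypothetical helper; only the module path and
    the --config-name value come from this page.
    """
    cmd = [
        sys.executable,
        "-m",
        "gfmrag.workflow.sft_training",
        "--config-name",
        config_name,
    ]
    # Hydra accepts dotted-key overrides such as trainer.args.num_epoch=5.
    for key, value in (overrides or {}).items():
        cmd.append(f"{key}={value}")
    return cmd


cmd = build_sft_command({"trainer.args.num_epoch": 5, "optimizer.lr": 1e-4})
print(" ".join(cmd[1:]))
```

The resulting list can be passed to subprocess.run(cmd, check=True) in an environment where gfmrag is installed.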

gfmrag/workflow/config/gfm_reasoner/sft_training.yaml
hydra:
  run:
    dir: outputs/qa_finetune/${now:%Y-%m-%d}/${now:%H-%M-%S} # Output directory
  searchpath:
    - pkg://gfmrag.workflow.config

defaults:
  - _self_
  - text_emb_model: qwen3 # The text embedding model to use
  - wandb: default # Weights & Biases configuration

seed: 1024
timeout: 60 # timeout minutes for multi-gpu training
save_pretrained: no # Save the model in pre-trained format
load_model_from_pretrained: null # Load model from pre-trained format, which would overwrite the model configuration

datasets:
  _target_: gfmrag.graph_index_datasets.GraphIndexDataset # The QA dataset class
  cfgs:
    root: ./data # data root directory
    force_reload: False # Whether to force rebuild the dataset
    text_emb_model_cfgs: ${text_emb_model} # The text embedding model configuration
  train_names: # List of training dataset names
    - hotpotqa_train_example
  valid_names: # List of validation dataset names
    - hotpotqa_test
    - hotpotqa_test_v2
    - musique_test
    - 2wikimultihopqa_test
  init_datasets: True # Whether to pre-process datasets at the beginning, if true, it will pre-process all datasets in the train_names and valid_names at the beginning
  feat_dim: 1024 # Feature dimension for the embeddings, must be given if init_datasets is False
  max_datasets_in_memory: 10 # Number of datasets to load into memory at once
  data_loading_workers: 4 # Number of workers for data loading

# GFM model configuration
model:
  _target_: gfmrag.models.gfm_reasoner.GraphReasoner
  use_ent_emb: early-late-fusion
  # Mixed precision training configuration
  dtype: bfloat16 # Precision type: 'float32', 'float16', 'bfloat16', or 'auto'
  entity_model:
    _target_: gfmrag.models.ultra.models.QueryNBFNet
    input_dim: 1024
    hidden_dims: [1024, 1024, 1024, 1024, 1024, 1024]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes
    return_hidden: True # Return the hidden states of the entity model

# Loss configuration
losses:
  - name: bce_loss
    loss:
      _target_: gfmrag.losses.BCELoss
      adversarial_temperature: 0.2
    weight: 0.3
    target_node_type: document
  - name: pcr_loss
    loss:
      _target_: gfmrag.losses.ListCELoss
    weight: 0.7
    target_node_type: document
  - name: mse_loss
    loss:
      _target_: gfmrag.losses.MSELoss
    weight: 1
    target_node_type: document
    is_distillation_loss: True
  # - name: kl_loss
  #   loss:
  #     _target_: gfmrag.losses.KLDivLoss
  #   weight: 0.01
  #   target_node_type: document
  #   is_distillation_loss: True

# Optimizer configuration
optimizer:
  _target_: torch.optim.AdamW
  lr: 5.0e-4

# Training configuration
trainer:
  _target_: gfmrag.trainers.SFTTrainer
  args:
    _target_: gfmrag.trainers.TrainingArguments
    train_batch_size: 8
    num_epoch: 20
    logging_steps: 100
    max_steps_per_epoch: null
    resume_from_checkpoint: null
    do_train: true
    do_eval: true
    save_best_only: yes
    metric_for_best_model: document_mrr
    dtype: ${model.dtype} # Use the same precision type
    split_graph_inference: false
    split_graph_training: false
    split_graph_partition: contiguous  # contiguous or metis
  metrics: [mrr, hits@1, hits@2, hits@3, hits@5, hits@10, hits@20, recall@2, recall@3, recall@5, recall@10, recall@20] # List of metrics to calculate
  target_types: [document] # The target node types used for metric calculation

Top-level Fields

| Parameter | Options | Note |
| --- | --- | --- |
| hydra.run.dir | outputs/qa_finetune/${now:%Y-%m-%d}/${now:%H-%M-%S} | Directory used by Hydra for SFT outputs. |
| hydra.searchpath | pkg://gfmrag.workflow.config | Adds the packaged workflow config directory to Hydra's search path. |
| defaults | List of config groups | Selects the text embedding and wandb presets. |
| seed | Integer | Random seed used during training. |
| timeout | Positive integer | Timeout in minutes for multi-GPU training. |
| save_pretrained | yes, no | Whether to save the trained model in pretrained format. |
| load_model_from_pretrained | File path or null | Optional pretrained checkpoint that overrides the model definition. |
| datasets | Mapping | Dataset construction and loading options. |
| model | Mapping | GraphReasoner model configuration. |
| losses | List | Loss definitions used during fine-tuning. |
| optimizer | Mapping | Optimizer type and hyperparameters. |
| trainer | Mapping | Trainer arguments, evaluation metrics, and target types. |
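The hydra.run.dir value uses Hydra's now resolver, which formats the current time with strftime patterns. The sketch below reproduces the resulting path for a fixed timestamp; the function name sft_run_dir is a hypothetical helper used only for illustration.

```python
from datetime import datetime
from pathlib import PurePosixPath


def sft_run_dir(now):
    """Mirror outputs/qa_finetune/${now:%Y-%m-%d}/${now:%H-%M-%S} for a given time."""
    return (
        PurePosixPath("outputs/qa_finetune")
        / now.strftime("%Y-%m-%d")  # date component
        / now.strftime("%H-%M-%S")  # time component
    )


print(sft_run_dir(datetime(2025, 1, 31, 9, 5, 7)))
# outputs/qa_finetune/2025-01-31/09-05-07
```

Each run therefore gets its own directory, so outputs from successive runs do not overwrite each other unless two runs start within the same second.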

defaults Fields

| Parameter | Options | Note |
| --- | --- | --- |
| _self_ | Current file | Loads the local values in this preset. |
| text_emb_model | qwen3 by default | Text embedding preset used by the dataset loader. |
| wandb | default by default | Weights and Biases logging preset. |

datasets Fields

| Parameter | Options | Note |
| --- | --- | --- |
| _target_ | gfmrag.graph_index_datasets.GraphIndexDataset | Dataset class used for graph-index supervision data. |
| cfgs.root | Any valid data root | Root directory that contains the indexed datasets. |
| cfgs.force_reload | True, False | Whether to rebuild the dataset cache before loading. |
| cfgs.text_emb_model_cfgs | ${text_emb_model} by default | Text embedding config passed to the dataset loader. |
| train_names | List of dataset names | Training dataset splits. |
| valid_names | List of dataset names | Validation dataset splits. |
| init_datasets | True, False | Whether to preprocess all datasets in train_names and valid_names before training starts. |
| feat_dim | Positive integer | Embedding feature dimension; must be set when init_datasets is False. |
| max_datasets_in_memory | Positive integer | Maximum number of datasets kept in memory at once. |
| data_loading_workers | Positive integer | Number of worker processes for data loading. |
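The preset's inline comment states that feat_dim must be given if init_datasets is False. A minimal sketch of that rule as a pre-flight check on a plain dict follows; check_dataset_cfg is a hypothetical helper, and gfmrag may validate this differently or not at all.

```python
def check_dataset_cfg(cfg):
    """Raise ValueError if feat_dim is missing while init_datasets is False.

    cfg is a plain dict standing in for the datasets section of the preset.
    Treating a missing init_datasets key as True is an assumption of this sketch.
    """
    if not cfg.get("init_datasets", True) and cfg.get("feat_dim") is None:
        raise ValueError("feat_dim must be set when init_datasets is False")
    return True


# Values taken from the preset: both fields are set, so the check passes.
print(check_dataset_cfg({"init_datasets": True, "feat_dim": 1024}))
```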

model Fields

| Parameter | Options | Note |
| --- | --- | --- |
| _target_ | gfmrag.models.gfm_reasoner.GraphReasoner | Model class used for G-reasoner SFT. |
| use_ent_emb | early-late-fusion and other supported modes | Entity embedding integration mode. |
| dtype | float32, float16, bfloat16, auto | Precision mode for the model. |
| entity_model._target_ | gfmrag.models.ultra.models.QueryNBFNet | Graph encoder used inside GraphReasoner. |
| entity_model.input_dim | Positive integer | Input embedding dimension of the entity model. |
| entity_model.hidden_dims | List of integers | Hidden dimensions of each entity-model layer. |
| entity_model.message_func | Supported message functions such as distmult | Message function used by the graph encoder. |
| entity_model.aggregate_func | Supported aggregation functions such as sum | Aggregation function used by the graph encoder. |
| entity_model.short_cut | yes, no | Whether to enable shortcut connections. |
| entity_model.layer_norm | yes, no | Whether to enable layer normalization. |
| entity_model.return_hidden | True, False | Whether to return hidden states for downstream losses or distillation. |

losses Fields

Each entry in losses follows the same schema:

| Parameter | Options | Note |
| --- | --- | --- |
| name | Any loss name | Logical name used to identify the loss block. |
| loss._target_ | Loss class path | Concrete loss implementation, such as BCELoss, ListCELoss, or MSELoss. |
| loss.adversarial_temperature | Float | Optional temperature used by adversarial BCE loss. |
| weight | Float | Weight assigned to this loss in the total objective. |
| target_node_type | document, entity, or another node type | Node type supervised by this loss. |
| is_distillation_loss | True, False | Marks a loss as distillation-style supervision. |
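To make the role of weight concrete, the sketch below combines per-loss values into one objective as a plain weighted sum, using the weights from sft_training.yaml. The weighted-sum form, the total_loss helper, and the example loss values are assumptions for illustration; this page does not specify how the trainer combines losses or how is_distillation_loss changes that.

```python
def total_loss(loss_values, loss_cfgs):
    """Weighted sum of loss values, keyed by (name, target_node_type)."""
    total = 0.0
    for cfg in loss_cfgs:
        key = (cfg["name"], cfg["target_node_type"])
        total += cfg["weight"] * loss_values[key]
    return total


# Weights copied from the document-targeted blocks of the preset.
loss_cfgs = [
    {"name": "bce_loss", "weight": 0.3, "target_node_type": "document"},
    {"name": "pcr_loss", "weight": 0.7, "target_node_type": "document"},
    {"name": "mse_loss", "weight": 1, "target_node_type": "document"},
]

# Made-up per-batch loss values, purely for the arithmetic.
loss_values = {
    ("bce_loss", "document"): 0.5,
    ("pcr_loss", "document"): 2.0,
    ("mse_loss", "document"): 0.1,
}

# 0.3 * 0.5 + 0.7 * 2.0 + 1 * 0.1
print(round(total_loss(loss_values, loss_cfgs), 2))
```

Keying on the pair (name, target_node_type) rather than on name alone keeps the sketch usable when the same loss name appears for more than one node type.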

optimizer Fields

| Parameter | Options | Note |
| --- | --- | --- |
| _target_ | torch.optim.AdamW by default | Optimizer class used for training. |
| lr | Positive float | Learning rate. |

trainer Fields

| Parameter | Options | Note |
| --- | --- | --- |
| _target_ | gfmrag.trainers.SFTTrainer | Trainer class used for fine-tuning. |
| args._target_ | gfmrag.trainers.TrainingArguments | Training-argument class used by the trainer. |
| args.train_batch_size | Positive integer | Training batch size. |
| args.num_epoch | Positive integer | Number of epochs. |
| args.logging_steps | Positive integer | Logging interval in steps. |
| args.max_steps_per_epoch | Positive integer or null | Optional cap on steps per epoch. |
| args.resume_from_checkpoint | File path or null | Resume training from a saved checkpoint. |
| args.do_train | true, false | Whether to run training. |
| args.do_eval | true, false | Whether to run evaluation. |
| args.save_best_only | yes, no | Whether to keep only the best checkpoint. |
| args.metric_for_best_model | Metric name | Metric used to select the best checkpoint. |
| args.dtype | ${model.dtype} by default | Trainer-side precision mode. |
| args.split_graph_inference | true, false | Whether to enable split-graph inference. |
| args.split_graph_training | true, false | Whether to enable split-graph training. |
| args.split_graph_partition | contiguous, metis | Partition strategy used for split-graph execution. |
| metrics | List of metric names | Ranking metrics computed during evaluation. |
| target_types | List of node types | Node types included in metric computation. |
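The metric names in the preset (mrr, hits@k, recall@k) are standard ranking metrics. The sketch below gives their textbook per-query definitions for one ranked list; the function names are illustrative, and gfmrag's own implementation may differ, for example in how multiple relevant documents are ranked or averaged.

```python
def mrr(ranked, relevant):
    """Reciprocal rank of the first relevant item (0.0 if none is found)."""
    for rank, item in enumerate(ranked, start=1):
        if item in relevant:
            return 1.0 / rank
    return 0.0


def hits_at_k(ranked, relevant, k):
    """1.0 if any relevant item appears in the top k, else 0.0."""
    return 1.0 if any(item in relevant for item in ranked[:k]) else 0.0


def recall_at_k(ranked, relevant, k):
    """Fraction of relevant items that appear in the top k."""
    if not relevant:
        return 0.0
    return len(set(ranked[:k]) & set(relevant)) / len(relevant)


ranked = ["d3", "d1", "d7", "d2"]  # documents ordered by predicted score
relevant = {"d1", "d2"}            # supporting documents for the query

print(mrr(ranked, relevant))             # 0.5
print(hits_at_k(ranked, relevant, 1))    # 0.0
print(recall_at_k(ranked, relevant, 2))  # 0.5
```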

sft_training_w_answer.yaml

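This preset extends the base SFT configuration with answer supervision: each document-targeted loss gains an entity-targeted counterpart, and evaluation covers both document and entity nodes.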

gfmrag/workflow/config/gfm_reasoner/sft_training_w_answer.yaml
hydra:
  run:
    dir: outputs/qa_finetune/${now:%Y-%m-%d}/${now:%H-%M-%S} # Output directory
  searchpath:
    - pkg://gfmrag.workflow.config

defaults:
  - _self_
  - text_emb_model: qwen3 # The text embedding model to use
  - wandb: default # Weights & Biases configuration

seed: 1024
timeout: 60 # timeout minutes for multi-gpu training
save_pretrained: no # Save the model in pre-trained format
load_model_from_pretrained: null # Load model from pre-trained format, which would overwrite the model configuration

datasets:
  _target_: gfmrag.graph_index_datasets.GraphIndexDataset # The QA dataset class
  cfgs:
    root: ./data # data root directory
    force_reload: False # Whether to force rebuild the dataset
    text_emb_model_cfgs: ${text_emb_model} # The text embedding model configuration
  train_names: # List of training dataset names
    - hotpotqa_train_example
  valid_names: # List of validation dataset names
    - hotpotqa_test
    - hotpotqa_test_v2
    - musique_test
    - 2wikimultihopqa_test
  init_datasets: True # Whether to pre-process datasets at the beginning, if true, it will pre-process all datasets in the train_names and valid_names at the beginning
  feat_dim: 1024 # Feature dimension for the embeddings, must be given if init_datasets is False
  max_datasets_in_memory: 10 # Number of datasets to load into memory at once
  data_loading_workers: 4 # Number of workers for data loading

# GFM model configuration
model:
  _target_: gfmrag.models.gfm_reasoner.GraphReasoner
  use_ent_emb: early-late-fusion
  dtype: bfloat16 # Precision type: 'float32', 'float16', 'bfloat16', or 'auto'
  entity_model:
    _target_: gfmrag.models.ultra.models.QueryNBFNet
    input_dim: 1024
    hidden_dims: [1024, 1024, 1024, 1024, 1024, 1024]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes
    return_hidden: True # Return the hidden states of the entity model

# Loss configuration
losses:
  - name: bce_loss
    loss:
      _target_: gfmrag.losses.BCELoss
      adversarial_temperature: 0.2
    weight: 0.3
    target_node_type: document
  - name: pcr_loss
    loss:
      _target_: gfmrag.losses.ListCELoss
    weight: 0.7
    target_node_type: document
  - name: mse_loss
    loss:
      _target_: gfmrag.losses.MSELoss
    weight: 1
    target_node_type: document
    is_distillation_loss: True
  - name: bce_loss
    loss:
      _target_: gfmrag.losses.BCELoss
      adversarial_temperature: 0.2
    weight: 0.3
    target_node_type: entity
  - name: pcr_loss
    loss:
      _target_: gfmrag.losses.ListCELoss
    weight: 0.7
    target_node_type: entity
  - name: mse_loss
    loss:
      _target_: gfmrag.losses.MSELoss
    weight: 1
    target_node_type: entity
    is_distillation_loss: True
  # - name: kl_loss
  #   loss:
  #     _target_: gfmrag.losses.KLDivLoss
  #   weight: 0.01
  #   target_node_type: document
  #   is_distillation_loss: True

# Optimizer configuration
optimizer:
  _target_: torch.optim.AdamW
  lr: 5.0e-4

# Training configuration
trainer:
  _target_: gfmrag.trainers.SFTTrainer
  args:
    _target_: gfmrag.trainers.TrainingArguments
    train_batch_size: 8
    num_epoch: 20
    logging_steps: 100
    max_steps_per_epoch: null
    resume_from_checkpoint: null
    do_train: true
    do_eval: true
    save_best_only: yes
    metric_for_best_model: document_mrr
    dtype: ${model.dtype} # Use the same precision type
  metrics: [mrr, hits@1, hits@2, hits@3, hits@5, hits@10, hits@20, recall@2, recall@3, recall@5, recall@10, recall@20] # List of metrics to calculate
  target_types: [document,entity] # The target node types used for metric calculation

Additional Differences

Compared with sft_training.yaml, this variant:

  • adds entity-targeted bce_loss, pcr_loss, and mse_loss blocks
  • changes trainer.target_types from [document] to [document, entity]
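  • keeps the same dataset, model, and optimizer settings
  • omits the split_graph_inference, split_graph_training, and split_graph_partition options from trainer.args, leaving the other trainer arguments unchanged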
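The loss-level difference between the two presets can be checked mechanically. The sketch below compares the (name, target_node_type) pairs of both loss lists, written out here as plain dicts copied from the listings above; loss_targets is a hypothetical helper.

```python
def loss_targets(losses):
    """Return the set of (name, target_node_type) pairs in a losses list."""
    return {(entry["name"], entry["target_node_type"]) for entry in losses}


names = ["bce_loss", "pcr_loss", "mse_loss"]

# losses in sft_training.yaml: document-targeted only.
base = [{"name": n, "target_node_type": "document"} for n in names]

# losses in sft_training_w_answer.yaml: the same blocks plus entity-targeted ones.
with_answer = base + [{"name": n, "target_node_type": "entity"} for n in names]

added = sorted(loss_targets(with_answer) - loss_targets(base))
print(added)
# [('bce_loss', 'entity'), ('mse_loss', 'entity'), ('pcr_loss', 'entity')]
```

Because the entity-targeted blocks reuse the names bce_loss, pcr_loss, and mse_loss, the name field alone does not identify a block in this variant; the pair with target_node_type does.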