# GFM-RAG SFT Training Config
This page documents `gfmrag/workflow/config/gfm_rag/sft_training.yaml`.

## Purpose

This preset is used by `python -m gfmrag.workflow.sft_training` for supervised fine-tuning and retrieval evaluation in the original GFM-RAG model family.
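As a quick illustration, the preset can be launched with standard Hydra `key=value` overrides from the command line. The override keys below mirror fields in this file; the concrete values are only examples, not recommended settings.

```shell
# Launch SFT training with this preset, overriding a few fields
# (standard Hydra override syntax; values are illustrative).
python -m gfmrag.workflow.sft_training \
    datasets.cfgs.root=./data \
    trainer.args.num_epoch=5 \
    optimizer.lr=1e-4
```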
**`gfmrag/workflow/config/gfm_rag/sft_training.yaml`**

```yaml
hydra:
  run:
    dir: outputs/qa_finetune/${now:%Y-%m-%d}/${now:%H-%M-%S} # Output directory
  searchpath:
    - pkg://gfmrag.workflow.config

defaults:
  - _self_
  - doc_ranker: idf_topk_ranker # The document ranker to use
  - text_emb_model: mpnet # The text embedding model to use
  - wandb: default # Weights & Biases configuration

# Misc
seed: 1024
timeout: 60 # Timeout in minutes for multi-GPU training
save_pretrained: no # Save the model in pre-trained format
load_model_from_pretrained: null # Load model from pre-trained format; overrides the model configuration

datasets:
  _target_: gfmrag.graph_index_datasets.GraphIndexDatasetV1 # The QA dataset class
  cfgs:
    root: ./data # Data root directory
    force_reload: False # Whether to force rebuild the dataset
    text_emb_model_cfgs: ${text_emb_model} # The text embedding model configuration
    target_type: entity # Only nodes of this type are used for graph index construction; must be one of the node type names in the graph
    use_node_feat: False # GFM-RAG v1 does not use node features
    use_edge_feat: False # GFM-RAG v1 does not use edge features
    use_relation_feat: True
    inverse_relation_feat: text # GFM-RAG v1 adds "inverse" to the original relation names to generate inverse relation features
  train_names: # List of training dataset names
    - hotpotqa_train_example
  valid_names: # List of validation dataset names
    - hotpotqa_test
    - musique_test
    - 2wikimultihopqa_test
  init_datasets: True # Whether to pre-process all datasets in train_names and valid_names at startup
  feat_dim: 768 # Feature dimension of the embeddings; must be given if init_datasets is False
  max_datasets_in_memory: 10 # Number of datasets to keep in memory at once
  data_loading_workers: 4 # Number of workers for data loading

# GFM model configuration
model:
  _target_: gfmrag.models.gfm_rag_v1.GNNRetriever
  init_nodes_weight: True # Whether to initialize node weights at the input
  init_nodes_type: document # The type of nodes used for weight calculation
  ranker: ${doc_ranker} # The document ranker to use
  entity_model:
    _target_: gfmrag.models.ultra.models.QueryNBFNet
    input_dim: 512
    hidden_dims: [512, 512, 512, 512, 512, 512]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes

# Loss configuration
losses:
  - name: bce_loss
    loss:
      _target_: gfmrag.losses.BCELoss
      adversarial_temperature: 0.2
    weight: 0.3
    target_node_type: entity # The type of nodes to apply the BCE loss to
  - name: pcr_loss
    loss:
      _target_: gfmrag.losses.ListCELoss
    weight: 0.7
    target_node_type: entity

# Optimizer configuration
optimizer:
  _target_: torch.optim.AdamW
  lr: 5.0e-4

# Training configuration
trainer:
  _target_: gfmrag.trainers.SFTTrainer
  args:
    _target_: gfmrag.trainers.TrainingArguments
    train_batch_size: 8
    num_epoch: 20
    logging_steps: 100
    max_steps_per_epoch: null
    resume_from_checkpoint: null
    do_train: true
    do_eval: true
    save_best_only: yes
    metric_for_best_model: document_mrr
  metrics: [mrr, hits@1, hits@2, hits@3, hits@5, hits@10, hits@20, hits@50, hits@100]
  target_types: [entity, document] # The target node types used for metric calculation
```
## Top-level Fields

| Parameter | Options | Note |
|---|---|---|
| `hydra.run.dir` | `outputs/qa_finetune/<date>/<time>/` | Directory used by Hydra for runtime logs and outputs. |
| `defaults` | List of config groups | Pulls in `doc_ranker`, `text_emb_model`, and `wandb`. |
| `datasets` | Mapping | Selects indexed datasets and feature settings. |
| `model` | Mapping | Configures `gfmrag.models.gfm_rag_v1.GNNRetriever`. |
| `losses` | List | Defines one or more supervised losses. |
| `optimizer` | Mapping | Sets optimizer type and learning rate. |
| `trainer` | Mapping | Configures `SFTTrainer`, evaluation metrics, and prediction behavior. |
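Every `_target_` in this file names a class that Hydra imports and constructs, passing the sibling keys as keyword arguments. A minimal stdlib sketch of that mechanism (not Hydra's actual implementation, which also handles nesting, interpolation, and partial instantiation), using a stdlib class in place of e.g. `torch.optim.AdamW`:

```python
import importlib


def instantiate(cfg: dict):
    """Toy Hydra-style instantiation: import the _target_ class and
    construct it with the remaining keys as keyword arguments."""
    cfg = dict(cfg)  # don't mutate the caller's config
    module_path, _, cls_name = cfg.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**cfg)


# collections.Counter stands in for a real target like torch.optim.AdamW
counter = instantiate({"_target_": "collections.Counter", "a": 2, "b": 1})
```

The same pattern applies uniformly to `datasets`, `model`, `losses[].loss`, `optimizer`, and `trainer` above.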
## `datasets` Fields

The default dataset class is `GraphIndexDatasetV1`.

| Parameter | Options | Note |
|---|---|---|
| `datasets.cfgs.root` | Any valid data root | Root directory containing indexed datasets. |
| `datasets.cfgs.target_type` | Node type string | Graph target node type. |
| `datasets.cfgs.use_node_feat` | `True`, `False` | Whether to use node features. |
| `datasets.cfgs.use_edge_feat` | `True`, `False` | Whether to use edge features. |
| `datasets.cfgs.use_relation_feat` | `True`, `False` | Whether to use relation features. |
| `datasets.train_names` | List of dataset names | Training split list. |
| `datasets.valid_names` | List of dataset names | Validation split list. |
| `datasets.max_datasets_in_memory` | Positive integer | Maximum number of datasets kept in memory. |
| `datasets.data_loading_workers` | Positive integer | Number of background loading workers. |
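Fields such as `datasets.cfgs.text_emb_model_cfgs: ${text_emb_model}` reuse another config group by interpolation rather than duplicating it. A toy sketch of how a top-level `${key}` reference resolves (OmegaConf, which Hydra builds on, handles this far more generally, including nested and typed references):

```python
import re


def resolve(cfg: dict, node):
    """Resolve simple ${key} references against the top-level config."""
    if isinstance(node, str):
        m = re.fullmatch(r"\$\{(\w+)\}", node)
        return cfg[m.group(1)] if m else node
    if isinstance(node, dict):
        return {k: resolve(cfg, v) for k, v in node.items()}
    return node


# Hypothetical miniature of the preset's sharing pattern
cfg = {
    "text_emb_model": {"name": "mpnet"},
    "datasets": {"cfgs": {"text_emb_model_cfgs": "${text_emb_model}"}},
}
resolved = resolve(cfg, cfg)
```

After resolution, `datasets.cfgs.text_emb_model_cfgs` carries the same mapping as the top-level `text_emb_model` group, so both the dataset indexing and the model see one embedding configuration.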
## `model` Fields

The default model target is `gfmrag.models.gfm_rag_v1.GNNRetriever`.

| Parameter | Options | Note |
|---|---|---|
| `model.init_nodes_weight` | `True`, `False` | Whether to initialize node weights at the input. |
| `model.init_nodes_type` | Node type string | Node type used for initialization. |
| `model.ranker` | Mapping | Shared document ranker config. |
| `model.entity_model` | Mapping | Nested `QueryNBFNet` settings. |
## `losses` Fields

The default preset includes two losses: `bce_loss` and `pcr_loss`. Each loss block defines:

| Parameter | Options | Note |
|---|---|---|
| `losses[].name` | Any string | Human-readable loss name. |
| `losses[].loss._target_` | Loss class path | Concrete loss implementation. |
| `losses[].weight` | Float | Loss weight in the total objective. |
| `losses[].target_node_type` | Node type string | Target node type for that loss. |
| `losses[].is_distillation_loss` | `True`, `False` | Optional distillation flag. |
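The two default losses are combined as a weighted sum (`0.3 * bce_loss + 0.7 * pcr_loss`). A self-contained numeric sketch of that combination, using plain scalar stand-ins for the model's per-node scores; the real `BCELoss` and `ListCELoss` operate on tensors, and `BCELoss` additionally applies adversarial negative weighting via `adversarial_temperature`:

```python
import math


def bce(score: float, label: float) -> float:
    # Binary cross-entropy on a single logit (scalar sketch).
    p = 1.0 / (1.0 + math.exp(-score))
    return -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))


def list_ce(scores: list[float], pos_idx: int) -> float:
    # Listwise cross-entropy: -log softmax probability of the positive item.
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return -(scores[pos_idx] - log_z)


# Weighted sum mirroring weight: 0.3 / weight: 0.7 in the preset
total = 0.3 * bce(2.0, 1.0) + 0.7 * list_ce([2.0, 0.5, -1.0], 0)
```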
## `trainer` Fields

| Parameter | Options | Note |
|---|---|---|
| `trainer.args.train_batch_size` | Positive integer | Training batch size. |
| `trainer.args.num_epoch` | Positive integer | Number of epochs. |
| `trainer.args.do_train` | `True`, `False` | Whether to run training. |
| `trainer.args.do_eval` | `True`, `False` | Whether to run evaluation. |
| `trainer.args.save_best_only` | `True`, `False` | Whether to save only the best checkpoint. |
| `trainer.args.metric_for_best_model` | Metric name | Metric used to select the best checkpoint. |
| `trainer.metrics` | List of metric names | Evaluation metrics. |
| `trainer.target_types` | List of node types | Node types used in evaluation. |
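The retrieval metrics listed in `trainer.metrics` (`mrr`, `hits@k`) are standard ranking measures over the 1-based rank each gold target receives. A hedged sketch, assuming the ranks have already been computed per query:

```python
def mrr(ranks: list[int]) -> float:
    # Mean reciprocal rank over 1-based gold ranks.
    return sum(1.0 / r for r in ranks) / len(ranks)


def hits_at_k(ranks: list[int], k: int) -> float:
    # Fraction of queries whose gold target ranks within the top k.
    return sum(1 for r in ranks if r <= k) / len(ranks)


ranks = [1, 3, 12]  # hypothetical gold ranks for three queries
mrr_score = mrr(ranks)
hits10 = hits_at_k(ranks, 10)
```

Because `target_types` lists both `entity` and `document`, the trainer reports each metric separately per node type (e.g. `document_mrr`, the default `metric_for_best_model`).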