GFM-RAG KGC Training Config

This page documents gfmrag/workflow/config/gfm_rag/kgc_training.yaml.

Purpose

This preset is used by python -m gfmrag.workflow.kgc_training to train the original GFM-RAG query-side graph model on graph construction data.
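Assuming a standard Hydra setup, the preset can be run as-is or with command-line overrides using Hydra's dotted syntax. The override values below are illustrative; they only touch fields documented on this page:

```sh
# Train with the defaults defined in kgc_training.yaml
python -m gfmrag.workflow.kgc_training

# Override individual fields with Hydra dotted syntax (values illustrative)
python -m gfmrag.workflow.kgc_training \
    trainer.args.train_batch_size=4 \
    trainer.args.num_epoch=20 \
    optimizer.lr=1.0e-4
```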

gfmrag/workflow/config/gfm_rag/kgc_training.yaml

```yaml
hydra:
  run:
    dir: outputs/kg_pretrain/${now:%Y-%m-%d}/${now:%H-%M-%S} # Output directory
  searchpath:
    - pkg://gfmrag.workflow.config

defaults:
  - _self_
  - text_emb_model: mpnet # The text embedding model to use
  - wandb: default # Weights & Biases configuration

# Misc
seed: 1024
timeout: 60 # Timeout in minutes for multi-GPU training
save_pretrained: no # Save the model in pre-trained format
load_model_from_pretrained: null # Load a model in pre-trained format; this overrides the model configuration below

datasets:
  _target_: gfmrag.graph_index_datasets.GraphIndexDatasetV1 # The KG dataset class
  cfgs:
    root: ./data # Data root directory
    force_reload: False # Whether to force a rebuild of the dataset
    text_emb_model_cfgs: ${text_emb_model} # The text embedding model configuration
    target_type: entity # Only nodes of this type are used for graph index construction; must be one of the node type names in the graph
    use_node_feat: False # GFM-RAG v1 does not use node features
    use_edge_feat: False # GFM-RAG v1 does not use edge features
    use_relation_feat: True
    inverse_relation_feat: text # GFM-RAG v1 adds "inverse" to the original relation names to generate inverse relation features
  train_names: # List of training dataset names
    - hotpotqa_train_example
  valid_names: []
  init_datasets: True # Whether to pre-process all datasets in train_names and valid_names before training starts
  feat_dim: 768 # Feature dimension of the embeddings; must be given if init_datasets is False
  max_datasets_in_memory: 10 # Number of datasets to keep in memory at once
  data_loading_workers: 4 # Number of workers for data loading

# GFM model configuration
model:
  _target_: gfmrag.models.gfm_rag_v1.QueryGNN
  entity_model:
    _target_: gfmrag.models.ultra.models.QueryNBFNet
    input_dim: 512
    hidden_dims: [512, 512, 512, 512, 512, 512]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes

optimizer:
  _target_: torch.optim.AdamW
  lr: 5.0e-4

# Training configuration
trainer:
  _target_: gfmrag.trainers.KGCTrainer
  args:
    _target_: gfmrag.trainers.TrainingArguments
    train_batch_size: 8
    num_epoch: 10
    logging_steps: 100
    max_steps_per_epoch: null
    resume_from_checkpoint: null
    save_best_only: no
    metric_for_best_model: mrr
  num_negative: 256
  strict_negative: yes
  adversarial_temperature: 1
  metrics: [mr, mrr, hits@1, hits@3, hits@10]
  fast_test: 500
```

Top-level Fields

| Parameter | Options | Note |
| --- | --- | --- |
| `hydra.run.dir` | `outputs/kg_pretrain/<date>/<time>/` | Directory used by Hydra for runtime logs and outputs. |
| `defaults` | List of config groups | Pulls in `text_emb_model` and `wandb`. |
| `datasets` | Mapping | Defines the training and validation dataset loader. |
| `model` | Mapping | Configures `gfmrag.models.gfm_rag_v1.QueryGNN`. |
| `optimizer` | Mapping | Sets the optimizer type and learning rate. |
| `trainer` | Mapping | Configures `KGCTrainer` and negative-sampling behavior. |

datasets Fields

The default dataset class is GraphIndexDatasetV1.

| Parameter | Options | Note |
| --- | --- | --- |
| `datasets.cfgs.root` | Any valid data root | Root directory containing the datasets. |
| `datasets.cfgs.text_emb_model_cfgs` | `${text_emb_model}` | Shared text embedding model config. |
| `datasets.cfgs.target_type` | Node type string | Node type used as the graph target. |
| `datasets.train_names` | List of dataset names | Training dataset list. |
| `datasets.valid_names` | List of dataset names | Validation dataset list. |
| `datasets.init_datasets` | `True`, `False` | Whether to preprocess all listed datasets before training. |
| `datasets.feat_dim` | Positive integer | Required when `init_datasets` is `False`. |
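As a sketch, pointing the loader at your own datasets and skipping up-front preprocessing might look like the fragment below. The dataset names are placeholders, and per the table above, feat_dim must then be set explicitly:

```yaml
datasets:
  train_names:
    - my_kg_train # placeholder dataset name
  valid_names:
    - my_kg_valid # placeholder dataset name
  init_datasets: False # skip preprocessing at startup
  feat_dim: 768 # required when init_datasets is False; 768 matches the default mpnet embedding size
```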

model Fields

The default model target is gfmrag.models.gfm_rag_v1.QueryGNN.

The nested entity_model block configures QueryNBFNet, including:

| Parameter | Options | Note |
| --- | --- | --- |
| `entity_model.input_dim` | Positive integer | Input embedding dimension. |
| `entity_model.hidden_dims` | List of integers | Hidden dimension of each layer. |
| `entity_model.message_func` | `distmult` and others | Message function used by the graph encoder. |
| `entity_model.aggregate_func` | `sum` and others | Aggregation function used by the graph encoder. |
| `entity_model.short_cut` | `yes`, `no` | Whether to enable shortcut connections. |
| `entity_model.layer_norm` | `yes`, `no` | Whether to enable layer normalization. |
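For instance, a lighter entity model could be sketched as follows. In NBFNet-style models the length of hidden_dims typically determines the number of message-passing layers; the values here are illustrative, not a recommended configuration:

```yaml
model:
  entity_model:
    input_dim: 256
    hidden_dims: [256, 256, 256] # 3 layers instead of the default 6
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes
```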

trainer Fields

| Parameter | Options | Note |
| --- | --- | --- |
| `trainer.args.train_batch_size` | Positive integer | Training batch size. |
| `trainer.args.num_epoch` | Positive integer | Number of training epochs. |
| `trainer.args.logging_steps` | Positive integer | Logging interval in steps. |
| `trainer.num_negative` | Positive integer | Number of negative samples per query. |
| `trainer.strict_negative` | `True`, `False` | Whether to sample strict negatives. |
| `trainer.adversarial_temperature` | Float | Temperature for self-adversarial negative sampling. |
| `trainer.metrics` | List of metric names | Evaluation metrics such as `mr`, `mrr`, `hits@k`. |
| `trainer.fast_test` | Positive integer | Number of samples used in fast evaluation mode. |
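As one possible variation, the fragment below keeps only the best checkpoint and resumes from a previous run; the checkpoint path is a placeholder and the other values are illustrative:

```yaml
trainer:
  args:
    train_batch_size: 4
    num_epoch: 20
    save_best_only: yes # keep only the checkpoint with the best metric
    metric_for_best_model: mrr
    resume_from_checkpoint: path/to/model_checkpoint.pth # placeholder path
  num_negative: 128 # fewer negatives to reduce memory use
```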