# GFM-RAG KGC Training Config

This page documents `gfmrag/workflow/config/gfm_rag/kgc_training.yaml`.
## Purpose

This preset is used by `python -m gfmrag.workflow.kgc_training` to train the original GFM-RAG query-side graph model on graph-construction data.
`gfmrag/workflow/config/gfm_rag/kgc_training.yaml`
```yaml
hydra:
  run:
    dir: outputs/kg_pretrain/${now:%Y-%m-%d}/${now:%H-%M-%S} # Output directory
  searchpath:
    - pkg://gfmrag.workflow.config

defaults:
  - _self_
  - text_emb_model: mpnet # The text embedding model to use
  - wandb: default # Weights & Biases configuration

# Misc
seed: 1024
timeout: 60 # Timeout in minutes for multi-GPU training
save_pretrained: no # Save the model in pre-trained format
load_model_from_pretrained: null # Load a model in pre-trained format, overwriting the model configuration

datasets:
  _target_: gfmrag.graph_index_datasets.GraphIndexDatasetV1 # The KG dataset class
  cfgs:
    root: ./data # Data root directory
    force_reload: False # Whether to force-rebuild the dataset
    text_emb_model_cfgs: ${text_emb_model} # The text embedding model configuration
    target_type: entity # Only nodes of this type are used for graph index construction; must be one of the node type names in the graph
    use_node_feat: False # GFM-RAG v1 does not use node features
    use_edge_feat: False # GFM-RAG v1 does not use edge features
    use_relation_feat: True
    inverse_relation_feat: text # GFM-RAG v1 prepends "inverse" to the original relation names to generate inverse relation features
  train_names: # List of training dataset names
    - hotpotqa_train_example
  valid_names: [] # List of validation dataset names
  init_datasets: True # Whether to pre-process all datasets in train_names and valid_names before training
  feat_dim: 768 # Feature dimension of the embeddings; must be given if init_datasets is False
  max_datasets_in_memory: 10 # Number of datasets to keep in memory at once
  data_loading_workers: 4 # Number of workers for data loading

# GFM model configuration
model:
  _target_: gfmrag.models.gfm_rag_v1.QueryGNN
  entity_model:
    _target_: gfmrag.models.ultra.models.QueryNBFNet
    input_dim: 512
    hidden_dims: [512, 512, 512, 512, 512, 512]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes

optimizer:
  _target_: torch.optim.AdamW
  lr: 5.0e-4

# Training configuration
trainer:
  _target_: gfmrag.trainers.KGCTrainer
  args:
    _target_: gfmrag.trainers.TrainingArguments
    train_batch_size: 8
    num_epoch: 10
    logging_steps: 100
    max_steps_per_epoch: null
    resume_from_checkpoint: null
    save_best_only: no
    metric_for_best_model: mrr
    num_negative: 256
    strict_negative: yes
    adversarial_temperature: 1
    metrics: [mr, mrr, hits@1, hits@3, hits@10]
    fast_test: 500
```
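Hydra resolves the `${now:...}` interpolations in `hydra.run.dir` with `strftime`-style patterns, so each run gets a date/time-stamped output directory. A minimal sketch of how that path is derived; `resolve_run_dir` is an illustrative helper, not part of gfmrag or Hydra:

```python
from datetime import datetime

def resolve_run_dir(ts: datetime) -> str:
    """Mimic Hydra's ${now:%Y-%m-%d}/${now:%H-%M-%S} resolution (illustrative only)."""
    return f"outputs/kg_pretrain/{ts:%Y-%m-%d}/{ts:%H-%M-%S}"

print(resolve_run_dir(datetime(2025, 1, 31, 9, 5, 7)))
# outputs/kg_pretrain/2025-01-31/09-05-07
```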
## Top-level Fields

| Parameter | Options | Note |
|---|---|---|
| `hydra.run.dir` | `outputs/kg_pretrain/<date>/<time>/` | Directory used by Hydra for runtime logs and outputs. |
| `defaults` | List of config groups | Pulls in `text_emb_model` and `wandb`. |
| `datasets` | Mapping | Defines the training and validation dataset loader. |
| `model` | Mapping | Configures `gfmrag.models.gfm_rag_v1.QueryGNN`. |
| `optimizer` | Mapping | Sets the optimizer type and learning rate. |
| `trainer` | Mapping | Configures `KGCTrainer` and negative-sampling behavior. |
## datasets Fields

The default dataset class is `GraphIndexDatasetV1`.

| Parameter | Options | Note |
|---|---|---|
| `datasets.cfgs.root` | Any valid data root | Root directory containing datasets. |
| `datasets.cfgs.text_emb_model_cfgs` | `${text_emb_model}` | Shared text embedding model config. |
| `datasets.cfgs.target_type` | Node type string | Node type used as the graph target. |
| `datasets.train_names` | List of dataset names | Training dataset list. |
| `datasets.valid_names` | List of dataset names | Validation dataset list. |
| `datasets.init_datasets` | `True`, `False` | Whether to preprocess all listed datasets before training. |
| `datasets.feat_dim` | Positive integer | Required when `init_datasets` is `False`. |
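The coupling between `init_datasets` and `feat_dim` can be expressed as a small validation check. The helper below is a hypothetical sketch, not a gfmrag API; it assumes that when preprocessing runs up front, the feature dimension can be probed from the produced embeddings (e.g. 768 for mpnet):

```python
def resolve_feat_dim(init_datasets: bool, feat_dim, probed_dim: int) -> int:
    """Hypothetical sketch of the init_datasets/feat_dim contract.

    If init_datasets is True, the dimension can be read from the
    preprocessed embeddings; otherwise feat_dim must be set explicitly.
    """
    if init_datasets:
        return probed_dim
    if feat_dim is None:
        raise ValueError("feat_dim must be given when init_datasets is False")
    return feat_dim

print(resolve_feat_dim(True, None, 768))   # dimension probed from embeddings
print(resolve_feat_dim(False, 768, 0))     # dimension taken from the config
```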
## model Fields

The default model target is `gfmrag.models.gfm_rag_v1.QueryGNN`.

The nested `entity_model` block configures `QueryNBFNet`, including:

| Parameter | Options | Note |
|---|---|---|
| `entity_model.input_dim` | Positive integer | Input embedding dimension. |
| `entity_model.hidden_dims` | List of integers | Hidden dimension of each layer. |
| `entity_model.message_func` | `distmult` and others | Message function used by the graph encoder. |
| `entity_model.aggregate_func` | `sum` and others | Aggregation function used by the graph encoder. |
| `entity_model.short_cut` | `yes`, `no` | Whether to enable shortcut connections. |
| `entity_model.layer_norm` | `yes`, `no` | Whether to enable layer normalization. |
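To make `message_func: distmult` and `aggregate_func: sum` concrete: a DistMult-style message is the element-wise product of a neighbor's state with the relation embedding, and sum aggregation adds incoming messages dimension-wise. A minimal plain-Python sketch of these two operations under that assumption, not the actual `QueryNBFNet` implementation:

```python
def distmult_message(h, r):
    # DistMult-style message: element-wise product of node state and relation embedding.
    return [hi * ri for hi, ri in zip(h, r)]

def sum_aggregate(messages):
    # Sum aggregation: add incoming messages dimension-wise.
    return [sum(dims) for dims in zip(*messages)]

msgs = [
    distmult_message([1.0, 2.0], [0.5, 0.5]),  # -> [0.5, 1.0]
    distmult_message([2.0, 0.0], [1.0, 1.0]),  # -> [2.0, 0.0]
]
print(sum_aggregate(msgs))  # [2.5, 1.0]
```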
## trainer Fields

| Parameter | Options | Note |
|---|---|---|
| `trainer.args.train_batch_size` | Positive integer | Training batch size. |
| `trainer.args.num_epoch` | Positive integer | Number of training epochs. |
| `trainer.args.logging_steps` | Positive integer | Logging interval in steps. |
| `trainer.args.num_negative` | Positive integer | Number of negative samples per query. |
| `trainer.args.strict_negative` | `True`, `False` | Whether to sample strict negatives. |
| `trainer.args.adversarial_temperature` | Float | Negative-sampling temperature. |
| `trainer.args.metrics` | List of metric names | Evaluation metrics such as `mr`, `mrr`, `hits@k`. |
| `trainer.args.fast_test` | Positive integer | Number of samples used in fast evaluation mode. |
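The `metrics` entries are the standard ranking metrics: mean rank (`mr`), mean reciprocal rank (`mrr`), and `hits@k` over the 1-based ranks of the true answers. A minimal sketch of how they are computed from a list of ranks; this is illustrative, not the gfmrag evaluation code:

```python
def ranking_metrics(ranks, ks=(1, 3, 10)):
    """Compute mr, mrr, and hits@k from 1-based ranks of the true answers."""
    n = len(ranks)
    out = {
        "mr": sum(ranks) / n,                      # mean rank (lower is better)
        "mrr": sum(1.0 / r for r in ranks) / n,    # mean reciprocal rank (higher is better)
    }
    for k in ks:
        out[f"hits@{k}"] = sum(r <= k for r in ranks) / n  # fraction ranked in the top k
    return out

print(ranking_metrics([1, 2, 5, 20]))
# {'mr': 7.0, 'mrr': 0.4375, 'hits@1': 0.25, 'hits@3': 0.5, 'hits@10': 0.75}
```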