GFM-RAG Fine-tuning Configuration
An example configuration file for GFM fine-tuning is shown below:
Example

gfmrag/workflow/config/stage2_qa_finetune.yaml:

```yaml
hydra:
  run:
    dir: outputs/qa_finetune/${now:%Y-%m-%d}/${now:%H-%M-%S} # Output directory

defaults:
  - _self_
  - doc_ranker: idf_topk_ranker # The document ranker to use
  - text_emb_model: mpnet # The text embedding model to use

seed: 1024

datasets:
  _target_: gfmrag.datasets.QADataset # The QA dataset class
  cfgs:
    root: ./data # Data root directory
    force_rebuild: False # Whether to force rebuild the dataset
    text_emb_model_cfgs: ${text_emb_model} # The text embedding model configuration
  train_names: # List of training dataset names
    - hotpotqa_train_example
  valid_names: # List of validation dataset names
    - hotpotqa_test
    - musique_test
    - 2wikimultihopqa_test

# GFM model configuration
model:
  _target_: gfmrag.models.GNNRetriever
  entity_model:
    _target_: gfmrag.ultra.models.QueryNBFNet
    input_dim: 512
    hidden_dims: [512, 512, 512, 512, 512, 512]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes

# Loss configuration
task:
  strict_negative: yes
  metric:
    [mrr, hits@1, hits@2, hits@3, hits@5, hits@10, hits@20, hits@50, hits@100]
  losses:
    - name: ent_bce_loss
      loss:
        _target_: gfmrag.losses.BCELoss
        adversarial_temperature: 0.2
      cfg:
        weight: 0.3
        is_doc_loss: False
    - name: ent_pcr_loss
      loss:
        _target_: gfmrag.losses.ListCELoss
      cfg:
        weight: 0.7
        is_doc_loss: False

# Optimizer configuration
optimizer:
  _target_: torch.optim.AdamW
  lr: 5.0e-4

# Training configuration
train:
  batch_size: 8
  num_epoch: 20
  log_interval: 100
  batch_per_epoch: null
  save_best_only: yes
  save_pretrained: yes # Save the model for QA inference
  do_eval: yes
  timeout: 60 # Timeout in minutes for multi-GPU training
  init_entities_weight: True
  checkpoint: null
```
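This file is a standard Hydra configuration: every block with a `_target_` key is instantiated into the corresponding Python object at runtime. Below is a minimal sketch of a Hydra entry point consuming it; the script itself is hypothetical (only the config path and name come from the file above), and it assumes the model can be built from the config alone.

```python
# Minimal sketch of a Hydra entry point for this config. The script is
# hypothetical; only the config path/name come from the example above.
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig


@hydra.main(
    config_path="gfmrag/workflow/config",
    config_name="stage2_qa_finetune",
    version_base=None,
)
def main(cfg: DictConfig) -> None:
    # Hydra resolves every `_target_` into a live object, recursively.
    model = instantiate(cfg.model)  # -> gfmrag.models.GNNRetriever
    optimizer = instantiate(cfg.optimizer, params=model.parameters())
    print(type(model).__name__, type(optimizer).__name__)


if __name__ == "__main__":
    main()
```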
General Configuration

| Parameter | Options | Note |
| --- | --- | --- |
| `run.dir` | None | The output directory of the logs |
Defaults

| Parameter | Options | Note |
| --- | --- | --- |
| `doc_ranker` | `idf_topk_ranker` | The document ranker to use |
| `text_emb_model` | `mpnet` | The text embedding model to use |
Training datasets

| Parameter | Options | Note |
| --- | --- | --- |
| `_target_` | None | The `QADataset` class |
| `cfgs.root` | None | Root directory of the dataset saving path |
| `cfgs.force_rebuild` | None | Whether to force rebuild the dataset |
| `cfgs.text_emb_model_cfgs` | None | Text embedding model configuration |
| `train_names` | [] | List of training dataset names |
| `valid_names` | [] | List of validation dataset names |
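To fine-tune on your own data, point `train_names` and `valid_names` at dataset folders under `cfgs.root`. A sketch of doing this programmatically with OmegaConf; the dataset names below are placeholders, not shipped datasets:

```python
# Sketch of overriding the dataset lists without editing the YAML.
# The dataset names below are placeholders.
from omegaconf import OmegaConf

cfg = OmegaConf.load("gfmrag/workflow/config/stage2_qa_finetune.yaml")
cfg.datasets.train_names = ["my_train_dataset"]
cfg.datasets.valid_names = ["my_valid_dataset"]
cfg.datasets.cfgs.force_rebuild = True  # rebuild cached artifacts after data changes
print(OmegaConf.to_yaml(cfg.datasets))
```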
GFM model configuration

| Parameter | Options | Note |
| --- | --- | --- |
| `_target_` | None | The `GNNRetriever` model class |
| `entity_model` | None | The `QueryNBFNet` entity model |
| `input_dim` | None | Input dimension of the model |
| `hidden_dims` | [] | Hidden dimensions of the model |
| `message_func` | `transe`, `rotate`, `distmult` | Message function of the model |
| `aggregate_func` | `pna`, `min`, `max`, `mean`, `sum` | Aggregation function of the model |
| `short_cut` | True, False | Whether to use shortcut connections |
| `layer_norm` | True, False | Whether to use layer normalization |
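The six 512-dimensional `hidden_dims` entries give the NBFNet-style entity model six message-passing layers; `message_func` chooses how a neighbor's state is combined with the relation embedding, and `aggregate_func` chooses how incoming messages are pooled. A toy illustration of the `distmult` + `sum` pair configured above (a simplification, not the gfmrag implementation):

```python
# Toy illustration of `message_func: distmult` and `aggregate_func: sum`.
# Not the gfmrag implementation, just the two operations in isolation.
import torch

num_nodes, dim = 5, 512
src = torch.tensor([0, 1, 1, 2, 3, 3, 4])  # source node of each edge
dst = torch.tensor([1, 0, 2, 3, 2, 4, 0])  # destination node of each edge

h = torch.randn(num_nodes, dim)  # node states
r = torch.randn(len(src), dim)   # relation embedding of each edge

msg = h[src] * r  # distmult message: elementwise product
agg = torch.zeros(num_nodes, dim).index_add_(0, dst, msg)  # sum aggregation
print(agg.shape)  # torch.Size([5, 512])
```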
Loss configuration

| Parameter | Options | Note |
| --- | --- | --- |
| `strict_negative` | None | Whether to use strict negative sampling |
| `metric` | None | Evaluation metrics to use |
| `losses` | None | List of losses to use |
| `losses[].name` | None | Name of the loss |
| `losses[].loss._target_` | None | Loss function to use |
| `losses[].cfg` | None | Configuration of the loss |
| `losses[].cfg.weight` | None | Weight of the loss |
| `losses[].cfg.is_doc_loss` | None | Whether the loss is applied at the document level |
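The `adversarial_temperature` on the BCE term points to self-adversarial weighting of negative samples (as in RotatE), and the `weight` fields suggest the two entity-level losses are combined as a weighted sum. A sketch of that combination, under that assumption; the exact gfmrag internals may differ:

```python
# Sketch of the weighted multi-loss combination implied by the `weight`
# fields above (0.3 for ent_bce_loss, 0.7 for ent_pcr_loss).
import torch

def combine(loss_terms: list[tuple[torch.Tensor, float]]) -> torch.Tensor:
    """Weighted sum of (loss, weight) pairs."""
    return sum(w * l for l, w in loss_terms)

ent_bce = torch.tensor(0.85)  # placeholder loss values
ent_pcr = torch.tensor(1.20)
total = combine([(ent_bce, 0.3), (ent_pcr, 0.7)])
print(float(total))  # 0.3 * 0.85 + 0.7 * 1.20 = 1.095
```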
Optimizer configuration

| Parameter | Options | Note |
| --- | --- | --- |
| `optimizer._target_` | None | Torch optimizer class for the model |
| `optimizer.lr` | None | Learning rate for the optimizer |
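Written out directly, this block is equivalent to constructing the optimizer by hand; a minimal example with a stand-in module:

```python
# What the optimizer block resolves to, using a stand-in for the GFM model.
import torch

model = torch.nn.Linear(512, 512)  # stand-in for the GFM model
optimizer = torch.optim.AdamW(model.parameters(), lr=5.0e-4)
```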
Training configuration

| Parameter | Options | Note |
| --- | --- | --- |
| `batch_size` | None | Batch size for training |
| `num_epoch` | None | Number of training epochs |
| `log_interval` | None | Logging interval during training |
| `fast_test` | None | Number of samples for a fast test run |
| `save_best_only` | None | Whether to save only the best model based on the metric |
| `save_pretrained` | None | Whether to save the model for QA inference |
| `batch_per_epoch` | None | Number of batches per epoch of training |
| `timeout` | None | Timeout in minutes for multi-GPU training |
| `checkpoint` | None | Checkpoint path to resume training from |
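These values can also be overridden at launch time without editing the file, for example via Hydra's compose API. A sketch, noting that `config_path` is resolved relative to the calling script and the checkpoint path below is a placeholder:

```python
# Sketch of overriding training settings via Hydra's compose API.
from hydra import compose, initialize

with initialize(config_path="gfmrag/workflow/config", version_base=None):
    cfg = compose(
        config_name="stage2_qa_finetune",
        overrides=[
            "train.batch_size=4",  # e.g. to fit smaller GPUs
            "train.num_epoch=10",
            "train.checkpoint=outputs/qa_finetune/last.ckpt",  # placeholder path
        ],
    )
print(cfg.train.batch_size)  # 4
```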