gfmrag.models
¶
GNNRetriever
¶
Bases: QueryGNN
A Query-dependent Graph Neural Network-based retriever that processes questions and entities for information retrieval.
This class extends QueryGNN to implement a GNN-based retrieval system that processes question embeddings and entity information to retrieve relevant information from a graph.
Attributes:
Name | Type | Description |
---|---|---|
question_mlp | Linear | Linear layer for transforming question embeddings. |
Parameters:
Name | Type | Description | Default |
---|---|---|---|
entity_model | QueryNBFNet | The underlying query-dependent GNN for reasoning on the graph. | required |
rel_emb_dim | int | Dimension of relation embeddings. | required |
*args | Any | Variable length argument list. | () |
**kwargs | Any | Arbitrary keyword arguments. | {} |
Methods:
Name | Description |
---|---|
forward | Processes the input graph and question embeddings to generate retrieval scores. |
visualize | Generates visualization data for the model's reasoning process. |
Note
The visualization method currently only supports batch size of 1.
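A minimal usage sketch follows. It is hedged: the tensor sizes are illustrative assumptions, and the `entity_model` and `graph` objects are assumed to be built elsewhere, so the model calls are left commented out.

```python
import torch

# Illustrative sizes (assumptions, not library defaults)
num_nodes, rel_emb_dim = 100, 768

batch = {
    "question_embeddings": torch.randn(2, rel_emb_dim),
    "question_entities_masks": torch.zeros(2, num_nodes),
}
batch["question_entities_masks"][0, 3] = 1.0  # question 0 mentions entity 3
batch["question_entities_masks"][1, 7] = 1.0  # question 1 mentions entity 7

# Assuming `entity_model` (a QueryNBFNet) and `graph` (a PyG Data object
# with a `rel_emb` attribute) have been constructed elsewhere:
# retriever = GNNRetriever(entity_model=entity_model, rel_emb_dim=rel_emb_dim)
# scores = retriever(graph, batch)  # (batch_size, num_nodes) retrieval scores
```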
Source code in gfmrag/models.py
class GNNRetriever(QueryGNN):
"""A Query-dependent Graph Neural Network-based retriever that processes questions and entities for information retrieval.
This class extends QueryGNN to implement a GNN-based retrieval system that processes question
embeddings and entity information to retrieve relevant information from a graph.
Attributes:
question_mlp (nn.Linear): Linear layer for transforming question embeddings.
Args:
entity_model (QueryNBFNet): The underlying query-dependent GNN for reasoning on graph.
rel_emb_dim (int): Dimension of relation embeddings.
*args (Any): Variable length argument list.
**kwargs (Any): Arbitrary keyword arguments.
Methods:
forward(graph, batch, entities_weight=None):
Processes the input graph and question embeddings to generate retrieval scores.
Args:
graph (Data): The input graph structure.
batch (dict[str, torch.Tensor]): Batch of input data containing question embeddings and masks.
entities_weight (torch.Tensor, optional): Optional weights for entities.
Returns:
torch.Tensor: Output scores for retrieval.
visualize(graph, sample, entities_weight=None):
Generates visualization data for the model's reasoning process.
Args:
graph (Data): The input graph structure.
sample (dict[str, torch.Tensor]): Single sample data containing question embeddings and masks.
entities_weight (torch.Tensor, optional): Optional weights for entities.
Returns:
dict[int, torch.Tensor]: Visualization data for each reasoning step.
Note:
The visualization method currently only supports batch size of 1.
"""
"""Wrap the GNN model for retrieval."""
def __init__(
self, entity_model: QueryNBFNet, rel_emb_dim: int, *args: Any, **kwargs: Any
) -> None:
"""
Initialize the GNNRetriever model.
Args:
entity_model (QueryNBFNet): Model for entity embedding and message passing
rel_emb_dim (int): Dimension of relation embeddings
*args: Variable length argument list
**kwargs: Arbitrary keyword arguments
Returns:
None
Note:
This constructor initializes the base class with entity_model and rel_emb_dim,
and creates a linear layer to project question embeddings to entity dimension.
"""
super().__init__(entity_model, rel_emb_dim)
self.question_mlp = nn.Linear(self.rel_emb_dim, self.entity_model.dims[0])
def forward(
self,
graph: Data,
batch: dict[str, torch.Tensor],
entities_weight: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass of the model.
This method processes a graph and question embeddings to produce entity-level reasoning output.
Args:
graph (Data): A PyTorch Geometric Data object containing the graph structure and features.
batch (dict[str, torch.Tensor]): A dictionary containing:
- question_embeddings: Tensor of question embeddings
- question_entities_masks: Tensor of masks for question entities
entities_weight (torch.Tensor | None, optional): Optional weight tensor for entities. Defaults to None.
Returns:
torch.Tensor: The output tensor representing entity-level reasoning results.
Notes:
The forward pass includes:
1. Processing question embeddings through MLP
2. Expanding relation representations
3. Applying optional entity weights
4. Computing entity-question interaction
5. Running entity-level reasoning model
"""
question_emb = batch["question_embeddings"]
question_entities_mask = batch["question_entities_masks"]
question_embedding = self.question_mlp(question_emb) # shape: (bs, emb_dim)
batch_size = question_embedding.size(0)
relation_representations = (
self.rel_mlp(graph.rel_emb).unsqueeze(0).expand(batch_size, -1, -1)
)
# initialize the input with the fuzzy set and question embs
if entities_weight is not None:
question_entities_mask = question_entities_mask * entities_weight.unsqueeze(
0
)
input = torch.einsum(
"bn, bd -> bnd", question_entities_mask, question_embedding
)
# GNN model: run the entity-level reasoner to get a scalar distribution over nodes
output = self.entity_model(
graph, input, relation_representations, question_embedding
)
return output
def visualize(
self,
graph: Data,
sample: dict[str, torch.Tensor],
entities_weight: torch.Tensor | None = None,
) -> dict[int, torch.Tensor]:
"""Visualizes attention weights and intermediate states for the model.
This function generates visualization data for understanding how the model processes
inputs and generates entity predictions. It is designed for debugging and analysis purposes.
Args:
graph (Data): The input knowledge graph structure containing entity and relation information
sample (dict[str, torch.Tensor]): Dictionary containing:
- question_embeddings: Tensor of question text embeddings
- question_entities_masks: Binary mask tensor indicating question entities
entities_weight (torch.Tensor | None, optional): Optional tensor of entity weights to apply.
Defaults to None.
Returns:
dict[int, torch.Tensor]: Dictionary mapping layer indices to attention weight tensors,
allowing visualization of attention patterns at different model depths.
Note:
Currently only supports batch size of 1 for visualization purposes.
Raises:
AssertionError: If batch size is not 1
"""
question_emb = sample["question_embeddings"]
question_entities_mask = sample["question_entities_masks"]
question_embedding = self.question_mlp(question_emb) # shape: (bs, emb_dim)
batch_size = question_embedding.size(0)
assert batch_size == 1, "Currently only supports batch size 1 for visualization"
relation_representations = (
self.rel_mlp(graph.rel_emb).unsqueeze(0).expand(batch_size, -1, -1)
)
# initialize the input with the fuzzy set and question embs
if entities_weight is not None:
question_entities_mask = question_entities_mask * entities_weight.unsqueeze(
0
)
input = torch.einsum(
"bn, bd -> bnd", question_entities_mask, question_embedding
)
return self.entity_model.visualize(
graph,
sample,
input,
relation_representations,
question_embedding, # type: ignore
)
__init__(entity_model, rel_emb_dim, *args, **kwargs)
¶
Initialize the GNNRetriever model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
entity_model | QueryNBFNet | Model for entity embedding and message passing | required |
rel_emb_dim | int | Dimension of relation embeddings | required |
*args | Any | Variable length argument list | () |
**kwargs | Any | Arbitrary keyword arguments | {} |
Returns:
Type | Description |
---|---|
None | None |
Note
This constructor initializes the base class with entity_model and rel_emb_dim, and creates a linear layer to project question embeddings to entity dimension.
Source code in gfmrag/models.py
def __init__(
self, entity_model: QueryNBFNet, rel_emb_dim: int, *args: Any, **kwargs: Any
) -> None:
"""
Initialize the GNNRetriever model.
Args:
entity_model (QueryNBFNet): Model for entity embedding and message passing
rel_emb_dim (int): Dimension of relation embeddings
*args: Variable length argument list
**kwargs: Arbitrary keyword arguments
Returns:
None
Note:
This constructor initializes the base class with entity_model and rel_emb_dim,
and creates a linear layer to project question embeddings to entity dimension.
"""
super().__init__(entity_model, rel_emb_dim)
self.question_mlp = nn.Linear(self.rel_emb_dim, self.entity_model.dims[0])
forward(graph, batch, entities_weight=None)
¶
Forward pass of the model.
This method processes a graph and question embeddings to produce entity-level reasoning output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
graph | Data | A PyTorch Geometric Data object containing the graph structure and features. | required |
batch | dict[str, Tensor] | A dictionary containing `question_embeddings` (tensor of question embeddings) and `question_entities_masks` (tensor of masks for question entities). | required |
entities_weight | Tensor \| None | Optional weight tensor for entities. Defaults to None. | None |
Returns:
Type | Description |
---|---|
Tensor
|
torch.Tensor: The output tensor representing entity-level reasoning results. |
Notes
The forward pass includes:

1. Processing question embeddings through MLP
2. Expanding relation representations
3. Applying optional entity weights
4. Computing entity-question interaction
5. Running entity-level reasoning model
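The fuzzy-set initialization (steps 3-4) can be reproduced in isolation. A short, self-contained sketch with made-up sizes:

```python
import torch

bs, num_nodes, dim = 2, 6, 4
question_embedding = torch.randn(bs, dim)            # output of question_mlp
question_entities_mask = torch.zeros(bs, num_nodes)
question_entities_mask[0, 1] = 1.0                   # entity 1 appears in question 0
question_entities_mask[1, 4] = 1.0                   # entity 4 appears in question 1

# optional per-entity weights rescale the mask before seeding
entities_weight = torch.rand(num_nodes)
question_entities_mask = question_entities_mask * entities_weight.unsqueeze(0)

# every masked entity is seeded with its (scaled) question embedding
input = torch.einsum("bn, bd -> bnd", question_entities_mask, question_embedding)
print(input.shape)  # torch.Size([2, 6, 4])
```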
Source code in gfmrag/models.py
def forward(
self,
graph: Data,
batch: dict[str, torch.Tensor],
entities_weight: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass of the model.
This method processes a graph and question embeddings to produce entity-level reasoning output.
Args:
graph (Data): A PyTorch Geometric Data object containing the graph structure and features.
batch (dict[str, torch.Tensor]): A dictionary containing:
- question_embeddings: Tensor of question embeddings
- question_entities_masks: Tensor of masks for question entities
entities_weight (torch.Tensor | None, optional): Optional weight tensor for entities. Defaults to None.
Returns:
torch.Tensor: The output tensor representing entity-level reasoning results.
Notes:
The forward pass includes:
1. Processing question embeddings through MLP
2. Expanding relation representations
3. Applying optional entity weights
4. Computing entity-question interaction
5. Running entity-level reasoning model
"""
question_emb = batch["question_embeddings"]
question_entities_mask = batch["question_entities_masks"]
question_embedding = self.question_mlp(question_emb) # shape: (bs, emb_dim)
batch_size = question_embedding.size(0)
relation_representations = (
self.rel_mlp(graph.rel_emb).unsqueeze(0).expand(batch_size, -1, -1)
)
# initialize the input with the fuzzy set and question embs
if entities_weight is not None:
question_entities_mask = question_entities_mask * entities_weight.unsqueeze(
0
)
input = torch.einsum(
"bn, bd -> bnd", question_entities_mask, question_embedding
)
# GNN model: run the entity-level reasoner to get a scalar distribution over nodes
output = self.entity_model(
graph, input, relation_representations, question_embedding
)
return output
visualize(graph, sample, entities_weight=None)
¶
Visualizes attention weights and intermediate states for the model.
This function generates visualization data for understanding how the model processes inputs and generates entity predictions. It is designed for debugging and analysis purposes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
graph | Data | The input knowledge graph structure containing entity and relation information | required |
sample | dict[str, Tensor] | Dictionary containing `question_embeddings` (tensor of question text embeddings) and `question_entities_masks` (binary mask tensor indicating question entities). | required |
entities_weight | Tensor \| None | Optional tensor of entity weights to apply. Defaults to None. | None |
|
Returns:
Type | Description |
---|---|
dict[int, Tensor] | Dictionary mapping layer indices to attention weight tensors, allowing visualization of attention patterns at different model depths. |
Note
Currently only supports batch size of 1 for visualization purposes.
Raises:
Type | Description |
---|---|
AssertionError | If batch size is not 1 |
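A hedged calling sketch (the `retriever` and `graph` objects are assumed to be built as in the earlier example; sizes are illustrative, and the batch dimension must be 1):

```python
import torch

sample = {
    "question_embeddings": torch.randn(1, 768),      # batch size must be 1
    "question_entities_masks": torch.zeros(1, 100),
}
sample["question_entities_masks"][0, 3] = 1.0

# viz = retriever.visualize(graph, sample)
# for key, value in viz.items():
#     ...  # inspect visualization data per reasoning step
```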
Source code in gfmrag/models.py
def visualize(
self,
graph: Data,
sample: dict[str, torch.Tensor],
entities_weight: torch.Tensor | None = None,
) -> dict[int, torch.Tensor]:
"""Visualizes attention weights and intermediate states for the model.
This function generates visualization data for understanding how the model processes
inputs and generates entity predictions. It is designed for debugging and analysis purposes.
Args:
graph (Data): The input knowledge graph structure containing entity and relation information
sample (dict[str, torch.Tensor]): Dictionary containing:
- question_embeddings: Tensor of question text embeddings
- question_entities_masks: Binary mask tensor indicating question entities
entities_weight (torch.Tensor | None, optional): Optional tensor of entity weights to apply.
Defaults to None.
Returns:
dict[int, torch.Tensor]: Dictionary mapping layer indices to attention weight tensors,
allowing visualization of attention patterns at different model depths.
Note:
Currently only supports batch size of 1 for visualization purposes.
Raises:
AssertionError: If batch size is not 1
"""
question_emb = sample["question_embeddings"]
question_entities_mask = sample["question_entities_masks"]
question_embedding = self.question_mlp(question_emb) # shape: (bs, emb_dim)
batch_size = question_embedding.size(0)
assert batch_size == 1, "Currently only supports batch size 1 for visualization"
relation_representations = (
self.rel_mlp(graph.rel_emb).unsqueeze(0).expand(batch_size, -1, -1)
)
# initialize the input with the fuzzy set and question embs
if entities_weight is not None:
question_entities_mask = question_entities_mask * entities_weight.unsqueeze(
0
)
input = torch.einsum(
"bn, bd -> bnd", question_entities_mask, question_embedding
)
return self.entity_model.visualize(
graph,
sample,
input,
relation_representations,
question_embedding, # type: ignore
)
QueryGNN
¶
Bases: Module
A neural network module for query embedding in graph neural networks.
This class implements a query embedding model that combines relation embeddings with an entity-based graph neural network for knowledge graph completion tasks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
entity_model | EntityNBFNet | The entity-based neural network model for reasoning on the graph structure. | required |
rel_emb_dim | int | Dimension of the relation embeddings. | required |
*args | Any | Variable length argument list. | () |
**kwargs | Any | Arbitrary keyword arguments. | {} |
Attributes:
Name | Type | Description |
---|---|---|
rel_emb_dim | int | Dimension of relation embeddings. |
entity_model | EntityNBFNet | The entity model instance. |
rel_mlp | Linear | Linear transformation layer for relation embeddings. |
Methods:
Name | Description |
---|---|
forward | Forward pass of the query GNN model: scores a batch of (head, tail, relation) triples against the graph. |
Source code in gfmrag/models.py
class QueryGNN(nn.Module):
"""A neural network module for query embedding in graph neural networks.
This class implements a query embedding model that combines relation embeddings with an entity-based graph neural network
for knowledge graph completion tasks.
Args:
entity_model (EntityNBFNet): The entity-based neural network model for reasoning on graph structure.
rel_emb_dim (int): Dimension of the relation embeddings.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Attributes:
rel_emb_dim (int): Dimension of relation embeddings.
entity_model (EntityNBFNet): The entity model instance.
rel_mlp (nn.Linear): Linear transformation layer for relation embeddings.
Methods:
forward(data: Data, batch: torch.Tensor) -> torch.Tensor:
Forward pass of the query GNN model.
Args:
data (Data): Graph data object containing the knowledge graph structure and features.
batch (torch.Tensor): Batch of triples with shape (batch_size, 1+num_negatives, 3),
where each triple contains (head, tail, relation) indices.
Returns:
torch.Tensor: Scoring tensor for the input triples.
"""
def __init__(
self, entity_model: EntityNBFNet, rel_emb_dim: int, *args: Any, **kwargs: Any
) -> None:
"""Initialize the model.
Args:
entity_model (EntityNBFNet): The entity model component
rel_emb_dim (int): Dimension of relation embeddings
*args (Any): Variable length argument list
**kwargs (Any): Arbitrary keyword arguments
"""
super().__init__()
self.rel_emb_dim = rel_emb_dim
self.entity_model = entity_model
self.rel_mlp = nn.Linear(rel_emb_dim, self.entity_model.dims[0])
def forward(self, data: Data, batch: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the model.
Args:
data (Data): Graph data object containing entity embeddings and graph structure.
batch (torch.Tensor): Batch of triple indices with shape (batch_size, 1+num_negatives, 3),
where each triple contains (head_idx, tail_idx, relation_idx).
Returns:
torch.Tensor: Scores for the triples in the batch.
Notes:
- Relations are assumed to be the same across all positive and negative triples
- Easy edges are removed before processing to encourage learning of non-trivial paths
- The batch tensor contains both positive and negative samples where the first sample
is positive and the rest are negative samples
"""
# batch shape: (bs, 1+num_negs, 3)
# relations are the same across all positive and negative triples, so we can extract just one from the first triple among the 1+num_negs
batch_size = len(batch)
relation_representations = (
self.rel_mlp(data.rel_emb).unsqueeze(0).expand(batch_size, -1, -1)
)
h_index, t_index, r_index = batch.unbind(-1)
# to make NBFNet iteration learn non-trivial paths
data = self.entity_model.remove_easy_edges(data, h_index, t_index, r_index)
score = self.entity_model(data, relation_representations, batch)
return score
__init__(entity_model, rel_emb_dim, *args, **kwargs)
¶
Initialize the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
entity_model | EntityNBFNet | The entity model component | required |
rel_emb_dim | int | Dimension of relation embeddings | required |
*args | Any | Variable length argument list | () |
**kwargs | Any | Arbitrary keyword arguments | {} |
|
Source code in gfmrag/models.py
def __init__(
self, entity_model: EntityNBFNet, rel_emb_dim: int, *args: Any, **kwargs: Any
) -> None:
"""Initialize the model.
Args:
entity_model (EntityNBFNet): The entity model component
rel_emb_dim (int): Dimension of relation embeddings
*args (Any): Variable length argument list
**kwargs (Any): Arbitrary keyword arguments
"""
super().__init__()
self.rel_emb_dim = rel_emb_dim
self.entity_model = entity_model
self.rel_mlp = nn.Linear(rel_emb_dim, self.entity_model.dims[0])
forward(data, batch)
¶
Forward pass of the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data | Data | Graph data object containing entity embeddings and graph structure. | required |
batch | Tensor | Batch of triple indices with shape (batch_size, 1+num_negatives, 3), where each triple contains (head_idx, tail_idx, relation_idx). | required |
Returns:
Type | Description |
---|---|
Tensor | Scores for the triples in the batch. |
Notes
- Relations are assumed to be the same across all positive and negative triples
- Easy edges are removed before processing to encourage learning of non-trivial paths
- The batch tensor contains both positive and negative samples where the first sample is positive and the rest are negative samples
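A self-contained sketch of the expected batch layout. Sizes and the tail-corruption sampling scheme are illustrative assumptions; the model call is commented out because it needs a fully built QueryGNN and graph.

```python
import torch

batch_size, num_negs, num_entities, num_rels = 4, 2, 100, 10

# Tail-corruption negatives: the head and relation are shared across the
# positive (index 0 along dim 1) and its negatives; only the tail varies.
h = torch.randint(0, num_entities, (batch_size, 1)).expand(-1, 1 + num_negs)
t = torch.randint(0, num_entities, (batch_size, 1 + num_negs))
r = torch.randint(0, num_rels, (batch_size, 1)).expand(-1, 1 + num_negs)
batch = torch.stack([h, t, r], dim=-1)  # (batch_size, 1 + num_negs, 3)

# score = query_gnn(data, batch)  # (batch_size, 1 + num_negs)
```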
Source code in gfmrag/models.py
def forward(self, data: Data, batch: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the model.
Args:
data (Data): Graph data object containing entity embeddings and graph structure.
batch (torch.Tensor): Batch of triple indices with shape (batch_size, 1+num_negatives, 3),
where each triple contains (head_idx, tail_idx, relation_idx).
Returns:
torch.Tensor: Scores for the triples in the batch.
Notes:
- Relations are assumed to be the same across all positive and negative triples
- Easy edges are removed before processing to encourage learning of non-trivial paths
- The batch tensor contains both positive and negative samples where the first sample
is positive and the rest are negative samples
"""
# batch shape: (bs, 1+num_negs, 3)
# relations are the same across all positive and negative triples, so we can extract just one from the first triple among the 1+num_negs
batch_size = len(batch)
relation_representations = (
self.rel_mlp(data.rel_emb).unsqueeze(0).expand(batch_size, -1, -1)
)
h_index, t_index, r_index = batch.unbind(-1)
# to make NBFNet iteration learn non-trivial paths
data = self.entity_model.remove_easy_edges(data, h_index, t_index, r_index)
score = self.entity_model(data, relation_representations, batch)
return score
gfmrag.ultra.models.EntityNBFNet
¶
Bases: BaseNBFNet
Neural Bellman-Ford Network for Entity Prediction.
This class extends BaseNBFNet to perform entity prediction in knowledge graphs using a neural version of the Bellman-Ford algorithm. It learns entity representations through message passing over the graph structure.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_dim | int | Dimension of input node/relation features | required |
hidden_dims | list | List of hidden dimensions for each layer | required |
num_relation | int | Number of relation types. Defaults to 1 (dummy value) | 1 |
**kwargs | | Additional arguments passed to BaseNBFNet | {} |
|
Attributes:
Name | Type | Description |
---|---|---|
layers |
ModuleList
|
List of GeneralizedRelationalConv layers |
mlp |
Sequential
|
Multi-layer perceptron for final prediction |
query |
Tensor
|
Relation type embeddings used as queries |
Methods:
Name | Description |
---|---|
bellmanford | Performs neural Bellman-Ford message passing iterations; returns node features and edge weights. |
forward | Forward pass for entity prediction; returns prediction scores for tail entities. |
Source code in gfmrag/ultra/models.py
class EntityNBFNet(BaseNBFNet):
"""Neural Bellman-Ford Network for Entity Prediction.
This class extends BaseNBFNet to perform entity prediction in knowledge graphs using a neural
version of the Bellman-Ford algorithm. It learns entity representations through message passing
over the graph structure.
Args:
input_dim (int): Dimension of input node/relation features
hidden_dims (list): List of hidden dimensions for each layer
num_relation (int, optional): Number of relation types. Defaults to 1 (dummy value)
**kwargs: Additional arguments passed to BaseNBFNet
Attributes:
layers (nn.ModuleList): List of GeneralizedRelationalConv layers
mlp (nn.Sequential): Multi-layer perceptron for final prediction
query (torch.Tensor): Relation type embeddings used as queries
Methods:
bellmanford(data, h_index, r_index, separate_grad=False):
Performs neural Bellman-Ford message passing iterations.
Args:
data: Graph data object containing edge information
h_index (torch.Tensor): Indices of head entities
r_index (torch.Tensor): Indices of relations
separate_grad (bool): Whether to use separate gradients for visualization
Returns:
dict: Contains node features and edge weights after message passing
forward(data, relation_representations, batch):
Forward pass for entity prediction.
Args:
data: Graph data object
relation_representations (torch.Tensor): Embeddings of relations
batch: Batch of (head, tail, relation) triples
Returns:
torch.Tensor: Prediction scores for tail entities
"""
def __init__(self, input_dim, hidden_dims, num_relation=1, **kwargs):
# dummy num_relation = 1 as we won't use it in the NBFNet layer
super().__init__(input_dim, hidden_dims, num_relation, **kwargs)
self.layers = nn.ModuleList()
for i in range(len(self.dims) - 1):
self.layers.append(
layers.GeneralizedRelationalConv(
self.dims[i],
self.dims[i + 1],
num_relation,
self.dims[0],
self.message_func,
self.aggregate_func,
self.layer_norm,
self.activation,
dependent=False,
project_relations=True,
)
)
feature_dim = (
sum(hidden_dims) if self.concat_hidden else hidden_dims[-1]
) + input_dim
self.mlp = nn.Sequential()
mlp = []
for i in range(self.num_mlp_layers - 1):
mlp.append(nn.Linear(feature_dim, feature_dim))
mlp.append(nn.ReLU())
mlp.append(nn.Linear(feature_dim, 1))
self.mlp = nn.Sequential(*mlp)
def bellmanford(self, data, h_index, r_index, separate_grad=False):
batch_size = len(r_index)
# initialize queries (relation types of the given triples)
query = self.query[torch.arange(batch_size, device=r_index.device), r_index]
index = h_index.unsqueeze(-1).expand_as(query)
# initial (boundary) condition - initialize all node states as zeros
boundary = torch.zeros(
batch_size, data.num_nodes, self.dims[0], device=h_index.device
)
# by the scatter operation we put query (relation) embeddings as init features of source (index) nodes
boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))
size = (data.num_nodes, data.num_nodes)
edge_weight = torch.ones(data.num_edges, device=h_index.device)
hiddens = []
edge_weights = []
layer_input = boundary
for layer in self.layers:
# for visualization
if separate_grad:
edge_weight = edge_weight.clone().requires_grad_()
# Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states
hidden = layer(
layer_input,
query,
boundary,
data.edge_index,
data.edge_type,
size,
edge_weight,
)
if self.short_cut and hidden.shape == layer_input.shape:
# residual connection here
hidden = hidden + layer_input
hiddens.append(hidden)
edge_weights.append(edge_weight)
layer_input = hidden
# original query (relation type) embeddings
node_query = query.unsqueeze(1).expand(
-1, data.num_nodes, -1
) # (batch_size, num_nodes, input_dim)
if self.concat_hidden:
output = torch.cat(hiddens + [node_query], dim=-1)
else:
output = torch.cat([hiddens[-1], node_query], dim=-1)
return {
"node_feature": output,
"edge_weights": edge_weights,
}
def forward(self, data, relation_representations, batch):
h_index, t_index, r_index = batch.unbind(-1)
# initial query representations are those from the relation graph
self.query = relation_representations
# initialize relations in each NBFNet layer (with unique projection internally)
for layer in self.layers:
layer.relation = relation_representations
# if self.training:
# Edge dropout in the training mode
# here we want to remove immediate edges (head, relation, tail) from the edge_index and edge_types
# to make NBFNet iteration learn non-trivial paths
# data = self.remove_easy_edges(data, h_index, t_index, r_index)
shape = h_index.shape
# turn all triples in a batch into a tail prediction mode
h_index, t_index, r_index = self.negative_sample_to_tail(
h_index, t_index, r_index, num_direct_rel=data.num_relations // 2
)
assert (h_index[:, [0]] == h_index).all()
assert (r_index[:, [0]] == r_index).all()
# message passing and updated node representations
output = self.bellmanford(
data, h_index[:, 0], r_index[:, 0]
) # output["node_feature"]: (batch_size, num_nodes, feature_dim)
feature = output["node_feature"]
index = t_index.unsqueeze(-1).expand(-1, -1, feature.shape[-1])
# extract representations of tail entities from the updated node states
feature = feature.gather(
1, index
) # (batch_size, num_negative + 1, feature_dim)
# probability logit for each tail node in the batch
# (batch_size, num_negative + 1, dim) -> (batch_size, num_negative + 1)
score = self.mlp(feature).squeeze(-1)
return score.view(shape)
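The boundary condition in `bellmanford` seeds each head node with its query (relation) embedding via `scatter_add_`. A self-contained sketch of that step with made-up sizes:

```python
import torch

batch_size, num_nodes, dim = 2, 5, 4
query = torch.randn(batch_size, dim)   # per-sample relation embedding
h_index = torch.tensor([0, 3])         # head entity of each sample

# all node states start at zero ...
boundary = torch.zeros(batch_size, num_nodes, dim)
# ... then each sample's head node is seeded with its query embedding
index = h_index.unsqueeze(-1).expand_as(query)   # (batch_size, dim)
boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))

assert torch.allclose(boundary[0, 0], query[0])
assert torch.allclose(boundary[1, 3], query[1])
```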
gfmrag.ultra.models.QueryNBFNet
¶
Bases: EntityNBFNet
The entity-level reasoner for UltraQuery-like complex query answering pipelines.
This class extends EntityNBFNet to handle query-specific reasoning in knowledge graphs. Key differences from EntityNBFNet include:
- Initial node features are provided during forward pass rather than read from triples batch
- Query comes from outer loop
- Returns distribution over all nodes (assuming t_index covers all nodes)
Attributes:
Name | Type | Description |
---|---|---|
layers | | List of neural network layers for message passing |
short_cut | | Boolean flag for using residual connections |
concat_hidden | | Boolean flag for concatenating hidden states |
mlp | | Multi-layer perceptron for final scoring |
num_beam | | Beam size for path search |
path_topk | | Number of top paths to return |
Methods:
Name | Description |
---|---|
bellmanford | Performs Bellman-Ford message passing iterations from caller-provided node features. |
forward | Main forward pass; returns a score for each node in the graph. |
visualize | Visualizes reasoning paths for the given target entities. |
Source code in gfmrag/ultra/models.py
class QueryNBFNet(EntityNBFNet):
"""
The entity-level reasoner for UltraQuery-like complex query answering pipelines.
This class extends EntityNBFNet to handle query-specific reasoning in knowledge graphs.
Key differences from EntityNBFNet include:
1. Initial node features are provided during forward pass rather than read from triples batch
2. Query comes from outer loop
3. Returns distribution over all nodes (assuming t_index covers all nodes)
Attributes:
layers: List of neural network layers for message passing
short_cut: Boolean flag for using residual connections
concat_hidden: Boolean flag for concatenating hidden states
mlp: Multi-layer perceptron for final scoring
num_beam: Beam size for path search
path_topk: Number of top paths to return
Methods:
bellmanford(data, node_features, query, separate_grad=False):
Performs Bellman-Ford message passing iterations.
Args:
data: Graph data object containing edge information
node_features: Initial node representations
query: Query representation
separate_grad: Whether to track gradients separately for edges
Returns:
dict: Contains node features and edge weights
forward(data, node_features, relation_representations, query):
Main forward pass of the model.
Args:
data: Graph data object
node_features: Initial node features
relation_representations: Representations for relations
query: Query representation
Returns:
torch.Tensor: Scores for each node
visualize(data, sample, node_features, relation_representations, query):
Visualizes reasoning paths for given entities.
Args:
data: Graph data object
sample: Dictionary containing entity masks
node_features: Initial node features
relation_representations: Representations for relations
query: Query representation
Returns:
dict: Contains paths and weights for target entities
"""
def bellmanford(self, data, node_features, query, separate_grad=False):
size = (data.num_nodes, data.num_nodes)
edge_weight = torch.ones(data.num_edges, device=query.device)
hiddens = []
edge_weights = []
layer_input = node_features
for layer in self.layers:
# for visualization
if separate_grad:
edge_weight = edge_weight.clone().requires_grad_()
# Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states
hidden = layer(
layer_input,
query,
node_features,
data.edge_index,
data.edge_type,
size,
edge_weight,
)
if self.short_cut and hidden.shape == layer_input.shape:
# residual connection here
hidden = hidden + layer_input
hiddens.append(hidden)
edge_weights.append(edge_weight)
layer_input = hidden
# original query (relation type) embeddings
node_query = query.unsqueeze(1).expand(
-1, data.num_nodes, -1
) # (batch_size, num_nodes, input_dim)
if self.concat_hidden:
output = torch.cat(hiddens + [node_query], dim=-1)
else:
output = torch.cat([hiddens[-1], node_query], dim=-1)
return {
"node_feature": output,
"edge_weights": edge_weights,
}
def forward(self, data, node_features, relation_representations, query):
# initialize relations in each NBFNet layer (with unique projection internally)
for layer in self.layers:
layer.relation = relation_representations
# we already did traversal_dropout in the outer loop of UltraQuery
# if self.training:
# # Edge dropout in the training mode
# # here we want to remove immediate edges (head, relation, tail) from the edge_index and edge_types
# # to make NBFNet iteration learn non-trivial paths
# data = self.remove_easy_edges(data, h_index, t_index, r_index)
# node features arrive in shape (bs, num_nodes, dim)
# NBFNet needs batch size on the first place
output = self.bellmanford(
data, node_features, query
) # output["node_feature"]: (batch_size, num_nodes, feature_dim)
score = self.mlp(output["node_feature"]).squeeze(-1) # (bs, num_nodes)
return score
def visualize(self, data, sample, node_features, relation_representations, query):
for layer in self.layers:
layer.relation = relation_representations
output = self.bellmanford(
data, node_features, query, separate_grad=True
) # output["node_feature"]: (batch_size, num_nodes, feature_dim)
node_feature = output["node_feature"]
edge_weights = output["edge_weights"]
question_entities_mask = sample["question_entities_masks"]
target_entities_mask = sample["supporting_entities_masks"]
query_entities_index = question_entities_mask.nonzero(as_tuple=True)[1]
target_entities_index = target_entities_mask.nonzero(as_tuple=True)[1]
paths_results = {}
for t_index in target_entities_index:
index = (
t_index.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(-1)
.expand(-1, -1, node_feature.shape[-1])
)
feature = node_feature.gather(1, index).squeeze(0)
score = self.mlp(feature).squeeze(-1)
edge_grads = autograd.grad(score, edge_weights, retain_graph=True)
distances, back_edges = self.beam_search_distance(
data, edge_grads, query_entities_index, t_index, self.num_beam
)
paths, weights = self.topk_average_length(
distances, back_edges, t_index, self.path_topk
)
paths_results[t_index.item()] = (paths, weights)
return paths_results
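To close the loop with `GNNRetriever`: a shape-contract sketch for `QueryNBFNet.forward`. All sizes are illustrative assumptions, and the model call is commented out since it requires a fully constructed network and graph. Unlike `EntityNBFNet`, the initial node states come from the caller, e.g. the fuzzy-set initialization shown earlier.

```python
import torch

bs, num_nodes, num_rels, dim = 2, 6, 3, 4
node_features = torch.randn(bs, num_nodes, dim)            # caller-provided init
relation_representations = torch.randn(bs, num_rels, dim)  # projected rel_emb
query = torch.randn(bs, dim)                               # projected question

# score = query_nbfnet(graph, node_features, relation_representations, query)
# score.shape == (bs, num_nodes): a score distribution over all nodes
```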