gfmrag.models

GNNRetriever

Bases: QueryGNN

A Query-dependent Graph Neural Network-based retriever that processes questions and entities for information retrieval.

This class extends QueryGNN to implement a GNN-based retrieval system that processes question embeddings and entity information to retrieve relevant information from a graph.

Attributes:

Name Type Description
question_mlp Linear

Linear layer for transforming question embeddings.

Parameters:

Name Type Description Default
entity_model QueryNBFNet

The underlying query-dependent GNN for reasoning on graph.

required
rel_emb_dim int

Dimension of relation embeddings.

required
*args Any

Variable length argument list.

()
**kwargs Any

Arbitrary keyword arguments.

{}

Methods:

Name Description
forward

Processes the input graph and question embeddings to generate retrieval scores.

Args: graph (Data): The input graph structure. batch (dict[str, torch.Tensor]): Batch of input data containing question embeddings and masks. entities_weight (torch.Tensor, optional): Optional weights for entities.

Returns: torch.Tensor: Output scores for retrieval.

visualize

Generates visualization data for the model's reasoning process.

Args: graph (Data): The input graph structure. sample (dict[str, torch.Tensor]): Single sample data containing question embeddings and masks. entities_weight (torch.Tensor, optional): Optional weights for entities.

Returns: dict[int, torch.Tensor]: Top reasoning paths and their weights for each target entity.

Note

The visualization method currently only supports batch size of 1.

Source code in gfmrag/models.py
Python
class GNNRetriever(QueryGNN):
    """A Query-dependent Graph Neural Network-based retriever that processes questions and entities for information retrieval.

    This class extends QueryGNN to implement a GNN-based retrieval system that processes question
    embeddings and entity information to retrieve relevant information from a graph.

    Attributes:
        question_mlp (nn.Linear): Linear layer for transforming question embeddings.

    Args:
        entity_model (QueryNBFNet): The underlying query-dependent GNN for reasoning on graph.
        rel_emb_dim (int): Dimension of relation embeddings.
        *args (Any): Variable length argument list.
        **kwargs (Any): Arbitrary keyword arguments.

    Methods:
        forward(graph, batch, entities_weight=None):
            Processes the input graph and question embeddings to generate retrieval scores.

            Args:
                graph (Data): The input graph structure.
                batch (dict[str, torch.Tensor]): Batch of input data containing question embeddings and masks.
                entities_weight (torch.Tensor, optional): Optional weights for entities.

            Returns:
                torch.Tensor: Output scores for retrieval.

        visualize(graph, sample, entities_weight=None):
            Generates visualization data for the model's reasoning process.

            Args:
                graph (Data): The input graph structure.
                sample (dict[str, torch.Tensor]): Single sample data containing question embeddings and masks.
                entities_weight (torch.Tensor, optional): Optional weights for entities.

            Returns:
                dict[int, torch.Tensor]: Dictionary mapping target entity indices to their top reasoning paths and weights.

    Note:
        The visualization method currently only supports batch size of 1.
    """

    """Wrap the GNN model for retrieval."""

    def __init__(
        self, entity_model: QueryNBFNet, rel_emb_dim: int, *args: Any, **kwargs: Any
    ) -> None:
        """
        Initialize the GNNRetriever model.

        Args:
            entity_model (QueryNBFNet): Model for entity embedding and message passing
            rel_emb_dim (int): Dimension of relation embeddings
            *args: Variable length argument list
            **kwargs: Arbitrary keyword arguments

        Returns:
            None

        Note:
            This constructor initializes the base class with entity_model and rel_emb_dim,
            and creates a linear layer to project question embeddings to entity dimension.
        """

        super().__init__(entity_model, rel_emb_dim)
        self.question_mlp = nn.Linear(self.rel_emb_dim, self.entity_model.dims[0])

    def forward(
        self,
        graph: Data,
        batch: dict[str, torch.Tensor],
        entities_weight: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass of the model.

        This method processes a graph and question embeddings to produce entity-level reasoning output.

        Args:
            graph (Data): A PyTorch Geometric Data object containing the graph structure and features.
            batch (dict[str, torch.Tensor]): A dictionary containing:
                - question_embeddings: Tensor of question embeddings
                - question_entities_masks: Tensor of masks for question entities
            entities_weight (torch.Tensor | None, optional): Optional weight tensor for entities. Defaults to None.

        Returns:
            torch.Tensor: The output tensor representing entity-level reasoning results.

        Notes:
            The forward pass includes:
            1. Processing question embeddings through MLP
            2. Expanding relation representations
            3. Applying optional entity weights
            4. Computing entity-question interaction
            5. Running entity-level reasoning model
        """

        question_emb = batch["question_embeddings"]
        question_entities_mask = batch["question_entities_masks"]

        question_embedding = self.question_mlp(question_emb)  # shape: (bs, emb_dim)
        batch_size = question_embedding.size(0)
        relation_representations = (
            self.rel_mlp(graph.rel_emb).unsqueeze(0).expand(batch_size, -1, -1)
        )

        # initialize the input with the fuzzy set and question embs
        if entities_weight is not None:
            question_entities_mask = question_entities_mask * entities_weight.unsqueeze(
                0
            )

        input = torch.einsum(
            "bn, bd -> bnd", question_entities_mask, question_embedding
        )

        # GNN model: run the entity-level reasoner to get a score distribution over nodes
        output = self.entity_model(
            graph, input, relation_representations, question_embedding
        )

        return output

    def visualize(
        self,
        graph: Data,
        sample: dict[str, torch.Tensor],
        entities_weight: torch.Tensor | None = None,
    ) -> dict[int, torch.Tensor]:
        """Visualizes attention weights and intermediate states for the model.

        This function generates visualization data for understanding how the model processes
        inputs and generates entity predictions. It is designed for debugging and analysis purposes.

        Args:
            graph (Data): The input knowledge graph structure containing entity and relation information
            sample (dict[str, torch.Tensor]): Dictionary containing:
                - question_embeddings: Tensor of question text embeddings
                - question_entities_masks: Binary mask tensor indicating question entities
            entities_weight (torch.Tensor | None, optional): Optional tensor of entity weights to apply.
                Defaults to None.

        Returns:
            dict[int, torch.Tensor]: Dictionary mapping target entity indices to their top
                reasoning paths and associated weights.

        Note:
            Currently only supports batch size of 1 for visualization purposes.

        Raises:
            AssertionError: If batch size is not 1
        """

        question_emb = sample["question_embeddings"]
        question_entities_mask = sample["question_entities_masks"]
        question_embedding = self.question_mlp(question_emb)  # shape: (bs, emb_dim)
        batch_size = question_embedding.size(0)

        assert batch_size == 1, "Currently only supports batch size 1 for visualization"

        relation_representations = (
            self.rel_mlp(graph.rel_emb).unsqueeze(0).expand(batch_size, -1, -1)
        )

        # initialize the input with the fuzzy set and question embs
        if entities_weight is not None:
            question_entities_mask = question_entities_mask * entities_weight.unsqueeze(
                0
            )

        input = torch.einsum(
            "bn, bd -> bnd", question_entities_mask, question_embedding
        )
        return self.entity_model.visualize(
            graph,
            sample,
            input,
            relation_representations,
            question_embedding,  # type: ignore
        )

__init__(entity_model, rel_emb_dim, *args, **kwargs)

Initialize the GNNRetriever model.

Parameters:

Name Type Description Default
entity_model QueryNBFNet

Model for entity embedding and message passing

required
rel_emb_dim int

Dimension of relation embeddings

required
*args Any

Variable length argument list

()
**kwargs Any

Arbitrary keyword arguments

{}

Returns:

Type Description
None

None

Note

This constructor initializes the base class with entity_model and rel_emb_dim, and creates a linear layer to project question embeddings to entity dimension.

Source code in gfmrag/models.py
Python
def __init__(
    self, entity_model: QueryNBFNet, rel_emb_dim: int, *args: Any, **kwargs: Any
) -> None:
    """
    Initialize the GNNRetriever model.

    Args:
        entity_model (QueryNBFNet): Model for entity embedding and message passing
        rel_emb_dim (int): Dimension of relation embeddings
        *args: Variable length argument list
        **kwargs: Arbitrary keyword arguments

    Returns:
        None

    Note:
        This constructor initializes the base class with entity_model and rel_emb_dim,
        and creates a linear layer to project question embeddings to entity dimension.
    """

    super().__init__(entity_model, rel_emb_dim)
    self.question_mlp = nn.Linear(self.rel_emb_dim, self.entity_model.dims[0])
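
A minimal sketch of the projection this constructor sets up (sizes are hypothetical; entity_model.dims[0] is the entity model's input dimension):

Python
import torch
from torch import nn

rel_emb_dim, entity_dim = 768, 32                  # hypothetical sizes
question_mlp = nn.Linear(rel_emb_dim, entity_dim)  # mirrors self.question_mlp

q = torch.rand(4, rel_emb_dim)         # (bs, rel_emb_dim) question embeddings
print(question_mlp(q).shape)           # torch.Size([4, 32]) == entity_model.dims[0]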

forward(graph, batch, entities_weight=None)

Forward pass of the model.

This method processes a graph and question embeddings to produce entity-level reasoning output.

Parameters:

Name Type Description Default
graph Data

A PyTorch Geometric Data object containing the graph structure and features.

required
batch dict[str, Tensor]

A dictionary containing: - question_embeddings: Tensor of question embeddings - question_entities_masks: Tensor of masks for question entities

required
entities_weight Tensor | None

Optional weight tensor for entities. Defaults to None.

None

Returns:

Type Description
Tensor

torch.Tensor: The output tensor representing entity-level reasoning results.

Notes

The forward pass includes:

  1. Processing question embeddings through MLP
  2. Expanding relation representations
  3. Applying optional entity weights
  4. Computing entity-question interaction
  5. Running entity-level reasoning model
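
The entity-question interaction in step 4 is a per-sample outer product. A minimal, self-contained sketch (toy shapes only) showing that the einsum used in the source equals broadcast multiplication:

Python
import torch

bs, num_nodes, dim = 2, 5, 8
mask = torch.rand(bs, num_nodes)  # question_entities_mask: (bs, num_nodes)
q = torch.rand(bs, dim)           # projected question embedding: (bs, dim)

# every entity receives the question embedding scaled by its mask weight
out = torch.einsum("bn, bd -> bnd", mask, q)
assert torch.allclose(out, mask.unsqueeze(-1) * q.unsqueeze(1))
print(out.shape)  # torch.Size([2, 5, 8])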

Source code in gfmrag/models.py
Python
def forward(
    self,
    graph: Data,
    batch: dict[str, torch.Tensor],
    entities_weight: torch.Tensor | None = None,
) -> torch.Tensor:
    """Forward pass of the model.

    This method processes a graph and question embeddings to produce entity-level reasoning output.

    Args:
        graph (Data): A PyTorch Geometric Data object containing the graph structure and features.
        batch (dict[str, torch.Tensor]): A dictionary containing:
            - question_embeddings: Tensor of question embeddings
            - question_entities_masks: Tensor of masks for question entities
        entities_weight (torch.Tensor | None, optional): Optional weight tensor for entities. Defaults to None.

    Returns:
        torch.Tensor: The output tensor representing entity-level reasoning results.

    Notes:
        The forward pass includes:
        1. Processing question embeddings through MLP
        2. Expanding relation representations
        3. Applying optional entity weights
        4. Computing entity-question interaction
        5. Running entity-level reasoning model
    """

    question_emb = batch["question_embeddings"]
    question_entities_mask = batch["question_entities_masks"]

    question_embedding = self.question_mlp(question_emb)  # shape: (bs, emb_dim)
    batch_size = question_embedding.size(0)
    relation_representations = (
        self.rel_mlp(graph.rel_emb).unsqueeze(0).expand(batch_size, -1, -1)
    )

    # initialize the input with the fuzzy set and question embs
    if entities_weight is not None:
        question_entities_mask = question_entities_mask * entities_weight.unsqueeze(
            0
        )

    input = torch.einsum(
        "bn, bd -> bnd", question_entities_mask, question_embedding
    )

    # GNN model: run the entity-level reasoner to get a score distribution over nodes
    output = self.entity_model(
        graph, input, relation_representations, question_embedding
    )

    return output
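
A hypothetical end-to-end call, assuming retriever is an already-constructed GNNRetriever whose relation embedding dimension matches the graph's rel_emb attribute (all sizes and indices below are illustrative, not part of the library API):

Python
import torch
from torch_geometric.data import Data

num_nodes, num_rels, rel_emb_dim = 100, 20, 768  # toy sizes
graph = Data(
    edge_index=torch.randint(0, num_nodes, (2, 500)),
    edge_type=torch.randint(0, num_rels, (500,)),
    num_nodes=num_nodes,
)
graph.rel_emb = torch.rand(num_rels, rel_emb_dim)  # relation embeddings

question_entities_masks = torch.zeros(2, num_nodes)
question_entities_masks[0, 3] = 1.0  # seed entity for question 0
question_entities_masks[1, 7] = 1.0  # seed entity for question 1
batch = {
    "question_embeddings": torch.rand(2, rel_emb_dim),
    "question_entities_masks": question_entities_masks,
}
scores = retriever(graph, batch)  # (2, num_nodes) retrieval scores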

visualize(graph, sample, entities_weight=None)

Visualizes reasoning paths and intermediate states for the model.

This function generates visualization data for understanding how the model processes inputs and generates entity predictions. It is designed for debugging and analysis purposes.

Parameters:

Name Type Description Default
graph Data

The input knowledge graph structure containing entity and relation information

required
sample dict[str, Tensor]

Dictionary containing: - question_embeddings: Tensor of question text embeddings - question_entities_masks: Binary mask tensor indicating question entities

required
entities_weight Tensor | None

Optional tensor of entity weights to apply. Defaults to None.

None

Returns:

Type Description
dict[int, Tensor]

dict[int, torch.Tensor]: Dictionary mapping target entity indices to their top reasoning paths and associated weights.

Note

Currently only supports batch size of 1 for visualization purposes.

Raises:

Type Description
AssertionError

If batch size is not 1

Source code in gfmrag/models.py
Python
def visualize(
    self,
    graph: Data,
    sample: dict[str, torch.Tensor],
    entities_weight: torch.Tensor | None = None,
) -> dict[int, torch.Tensor]:
    """Visualizes attention weights and intermediate states for the model.

    This function generates visualization data for understanding how the model processes
    inputs and generates entity predictions. It is designed for debugging and analysis purposes.

    Args:
        graph (Data): The input knowledge graph structure containing entity and relation information
        sample (dict[str, torch.Tensor]): Dictionary containing:
            - question_embeddings: Tensor of question text embeddings
            - question_entities_masks: Binary mask tensor indicating question entities
        entities_weight (torch.Tensor | None, optional): Optional tensor of entity weights to apply.
            Defaults to None.

    Returns:
        dict[int, torch.Tensor]: Dictionary mapping target entity indices to their top
            reasoning paths and associated weights.

    Note:
        Currently only supports batch size of 1 for visualization purposes.

    Raises:
        AssertionError: If batch size is not 1
    """

    question_emb = sample["question_embeddings"]
    question_entities_mask = sample["question_entities_masks"]
    question_embedding = self.question_mlp(question_emb)  # shape: (bs, emb_dim)
    batch_size = question_embedding.size(0)

    assert batch_size == 1, "Currently only supports batch size 1 for visualization"

    relation_representations = (
        self.rel_mlp(graph.rel_emb).unsqueeze(0).expand(batch_size, -1, -1)
    )

    # initialize the input with the fuzzy set and question embs
    if entities_weight is not None:
        question_entities_mask = question_entities_mask * entities_weight.unsqueeze(
            0
        )

    input = torch.einsum(
        "bn, bd -> bnd", question_entities_mask, question_embedding
    )
    return self.entity_model.visualize(
        graph,
        sample,
        input,
        relation_representations,
        question_embedding,  # type: ignore
    )
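
A hypothetical call continuing the forward example above; batch size must be 1, and the underlying QueryNBFNet.visualize also reads the target entities from the sample (key names follow the source shown on this page; index values are illustrative):

Python
sample = {
    "question_embeddings": torch.rand(1, rel_emb_dim),
    "question_entities_masks": torch.zeros(1, num_nodes),
    "supporting_entities_masks": torch.zeros(1, num_nodes),
}
sample["question_entities_masks"][0, 3] = 1.0     # question (seed) entity
sample["supporting_entities_masks"][0, 42] = 1.0  # target entity to explain
paths_per_target = retriever.visualize(graph, sample)
# {42: (paths, weights)} -- top reasoning paths ending at each target entity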

QueryGNN

Bases: Module

A neural network module for query embedding in graph neural networks.

This class implements a query embedding model that combines relation embeddings with an entity-based graph neural network for knowledge graph completion tasks.

Parameters:

Name Type Description Default
entity_model EntityNBFNet

The entity-based neural network model for reasoning on graph structure.

required
rel_emb_dim int

Dimension of the relation embeddings.

required
*args Any

Variable length argument list.

()
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
rel_emb_dim int

Dimension of relation embeddings.

entity_model EntityNBFNet

The entity model instance.

rel_mlp Linear

Linear transformation layer for relation embeddings.

Methods:

Name Description
forward

Forward pass of the query GNN model.

Args: data (Data): Graph data object containing the knowledge graph structure and features. batch (torch.Tensor): Batch of triples with shape (batch_size, 1+num_negatives, 3), where each triple contains (head, tail, relation) indices.

Returns: torch.Tensor: Scoring tensor for the input triples.

Source code in gfmrag/models.py
Python
class QueryGNN(nn.Module):
    """A neural network module for query embedding in graph neural networks.

    This class implements a query embedding model that combines relation embeddings with an entity-based graph neural network
    for knowledge graph completion tasks.

    Args:
        entity_model (EntityNBFNet): The entity-based neural network model for reasoning on graph structure.
        rel_emb_dim (int): Dimension of the relation embeddings.
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.

    Attributes:
        rel_emb_dim (int): Dimension of relation embeddings.
        entity_model (EntityNBFNet): The entity model instance.
        rel_mlp (nn.Linear): Linear transformation layer for relation embeddings.

    Methods:
        forward(data: Data, batch: torch.Tensor) -> torch.Tensor:
            Forward pass of the query GNN model.

            Args:
                data (Data): Graph data object containing the knowledge graph structure and features.
                batch (torch.Tensor): Batch of triples with shape (batch_size, 1+num_negatives, 3),
                                    where each triple contains (head, tail, relation) indices.

            Returns:
                torch.Tensor: Scoring tensor for the input triples.
    """

    def __init__(
        self, entity_model: EntityNBFNet, rel_emb_dim: int, *args: Any, **kwargs: Any
    ) -> None:
        """Initialize the model.

        Args:
            entity_model (EntityNBFNet): The entity model component
            rel_emb_dim (int): Dimension of relation embeddings
            *args (Any): Variable length argument list
            **kwargs (Any): Arbitrary keyword arguments

        """

        super().__init__()
        self.rel_emb_dim = rel_emb_dim
        self.entity_model = entity_model
        self.rel_mlp = nn.Linear(rel_emb_dim, self.entity_model.dims[0])

    def forward(self, data: Data, batch: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            data (Data): Graph data object containing entity embeddings and graph structure.
            batch (torch.Tensor): Batch of triple indices with shape (batch_size, 1+num_negatives, 3),
                                where each triple contains (head_idx, tail_idx, relation_idx).

        Returns:
            torch.Tensor: Scores for the triples in the batch.

        Notes:
            - Relations are assumed to be the same across all positive and negative triples
            - Easy edges are removed before processing to encourage learning of non-trivial paths
            - The batch tensor contains both positive and negative samples where the first sample
              is positive and the rest are negative samples
        """
        # batch shape: (bs, 1+num_negs, 3)
        # relations are the same across all positive and negative triples, so we can take the one from the first triple among the 1+num_negs
        batch_size = len(batch)
        relation_representations = (
            self.rel_mlp(data.rel_emb).unsqueeze(0).expand(batch_size, -1, -1)
        )
        h_index, t_index, r_index = batch.unbind(-1)
        # to make NBFNet iteration learn non-trivial paths
        data = self.entity_model.remove_easy_edges(data, h_index, t_index, r_index)
        score = self.entity_model(data, relation_representations, batch)

        return score

__init__(entity_model, rel_emb_dim, *args, **kwargs)

Initialize the model.

Parameters:

Name Type Description Default
entity_model EntityNBFNet

The entity model component

required
rel_emb_dim int

Dimension of relation embeddings

required
*args Any

Variable length argument list

()
**kwargs Any

Arbitrary keyword arguments

{}
Source code in gfmrag/models.py
Python
def __init__(
    self, entity_model: EntityNBFNet, rel_emb_dim: int, *args: Any, **kwargs: Any
) -> None:
    """Initialize the model.

    Args:
        entity_model (EntityNBFNet): The entity model component
        rel_emb_dim (int): Dimension of relation embeddings
        *args (Any): Variable length argument list
        **kwargs (Any): Arbitrary keyword arguments

    """

    super().__init__()
    self.rel_emb_dim = rel_emb_dim
    self.entity_model = entity_model
    self.rel_mlp = nn.Linear(rel_emb_dim, self.entity_model.dims[0])

forward(data, batch)

Forward pass of the model.

Parameters:

Name Type Description Default
data Data

Graph data object containing entity embeddings and graph structure.

required
batch Tensor

Batch of triple indices with shape (batch_size, 1+num_negatives, 3), where each triple contains (head_idx, tail_idx, relation_idx).

required

Returns:

Type Description
Tensor

torch.Tensor: Scores for the triples in the batch.

Notes
  • Relations are assumed to be the same across all positive and negative triples
  • Easy edges are removed before processing to encourage learning of non-trivial paths
  • The batch tensor contains both positive and negative samples where the first sample is positive and the rest are negative samples
Source code in gfmrag/models.py
Python
def forward(self, data: Data, batch: torch.Tensor) -> torch.Tensor:
    """
    Forward pass of the model.

    Args:
        data (Data): Graph data object containing entity embeddings and graph structure.
        batch (torch.Tensor): Batch of triple indices with shape (batch_size, 1+num_negatives, 3),
                            where each triple contains (head_idx, tail_idx, relation_idx).

    Returns:
        torch.Tensor: Scores for the triples in the batch.

    Notes:
        - Relations are assumed to be the same across all positive and negative triples
        - Easy edges are removed before processing to encourage learning of non-trivial paths
        - The batch tensor contains both positive and negative samples where the first sample
          is positive and the rest are negative samples
    """
    # batch shape: (bs, 1+num_negs, 3)
    # relations are the same across all positive and negative triples, so we can take the one from the first triple among the 1+num_negs
    batch_size = len(batch)
    relation_representations = (
        self.rel_mlp(data.rel_emb).unsqueeze(0).expand(batch_size, -1, -1)
    )
    h_index, t_index, r_index = batch.unbind(-1)
    # to make NBFNet iteration learn non-trivial paths
    data = self.entity_model.remove_easy_edges(data, h_index, t_index, r_index)
    score = self.entity_model(data, relation_representations, batch)

    return score
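
A sketch of the expected batch layout: one positive triple followed by tail-corrupted negatives that share its head and relation (indices are illustrative):

Python
import torch

pos = torch.tensor([[0, 5, 2]])              # (1, 3): (head, tail, relation)
negs = torch.tensor([[0, 9, 2], [0, 4, 2]])  # (num_negs, 3): corrupted tails
batch = torch.cat([pos, negs]).unsqueeze(0)  # (bs=1, 1+num_negs, 3)

h_index, t_index, r_index = batch.unbind(-1)
print(h_index.shape, t_index.shape, r_index.shape)  # each (1, 3)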

gfmrag.ultra.models.EntityNBFNet

Bases: BaseNBFNet

Neural Bellman-Ford Network for Entity Prediction.

This class extends BaseNBFNet to perform entity prediction in knowledge graphs using a neural version of the Bellman-Ford algorithm. It learns entity representations through message passing over the graph structure.

Parameters:

Name Type Description Default
input_dim int

Dimension of input node/relation features

required
hidden_dims list

List of hidden dimensions for each layer

required
num_relation int

Number of relation types. Defaults to 1 (dummy value)

1
**kwargs

Additional arguments passed to BaseNBFNet

{}

Attributes:

Name Type Description
layers ModuleList

List of GeneralizedRelationalConv layers

mlp Sequential

Multi-layer perceptron for final prediction

query Tensor

Relation type embeddings used as queries

Methods:

Name Description
bellmanford

Performs neural Bellman-Ford message passing iterations.

Args: data: Graph data object containing edge information. h_index (torch.Tensor): Indices of head entities. r_index (torch.Tensor): Indices of relations. separate_grad (bool): Whether to use separate gradients for visualization.

Returns: dict: Contains node features and edge weights after message passing

forward

Forward pass for entity prediction.

Args: data: Graph data object. relation_representations (torch.Tensor): Embeddings of relations. batch: Batch of (head, tail, relation) triples.

Returns: torch.Tensor: Prediction scores for tail entities

Source code in gfmrag/ultra/models.py
Python
class EntityNBFNet(BaseNBFNet):
    """Neural Bellman-Ford Network for Entity Prediction.

    This class extends BaseNBFNet to perform entity prediction in knowledge graphs using a neural
    version of the Bellman-Ford algorithm. It learns entity representations through message passing
    over the graph structure.

    Args:
        input_dim (int): Dimension of input node/relation features
        hidden_dims (list): List of hidden dimensions for each layer
        num_relation (int, optional): Number of relation types. Defaults to 1 (dummy value)
        **kwargs: Additional arguments passed to BaseNBFNet

    Attributes:
        layers (nn.ModuleList): List of GeneralizedRelationalConv layers
        mlp (nn.Sequential): Multi-layer perceptron for final prediction
        query (torch.Tensor): Relation type embeddings used as queries

    Methods:
        bellmanford(data, h_index, r_index, separate_grad=False):
            Performs neural Bellman-Ford message passing iterations.

            Args:
                data: Graph data object containing edge information
                h_index (torch.Tensor): Indices of head entities
                r_index (torch.Tensor): Indices of relations
                separate_grad (bool): Whether to use separate gradients for visualization

            Returns:
                dict: Contains node features and edge weights after message passing

        forward(data, relation_representations, batch):
            Forward pass for entity prediction.

            Args:
                data: Graph data object
                relation_representations (torch.Tensor): Embeddings of relations
                batch: Batch of (head, tail, relation) triples

            Returns:
                torch.Tensor: Prediction scores for tail entities
    """
    def __init__(self, input_dim, hidden_dims, num_relation=1, **kwargs):
        # dummy num_relation = 1 as we won't use it in the NBFNet layer
        super().__init__(input_dim, hidden_dims, num_relation, **kwargs)

        self.layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            self.layers.append(
                layers.GeneralizedRelationalConv(
                    self.dims[i],
                    self.dims[i + 1],
                    num_relation,
                    self.dims[0],
                    self.message_func,
                    self.aggregate_func,
                    self.layer_norm,
                    self.activation,
                    dependent=False,
                    project_relations=True,
                )
            )

        feature_dim = (
            sum(hidden_dims) if self.concat_hidden else hidden_dims[-1]
        ) + input_dim
        self.mlp = nn.Sequential()
        mlp = []
        for i in range(self.num_mlp_layers - 1):
            mlp.append(nn.Linear(feature_dim, feature_dim))
            mlp.append(nn.ReLU())
        mlp.append(nn.Linear(feature_dim, 1))
        self.mlp = nn.Sequential(*mlp)

    def bellmanford(self, data, h_index, r_index, separate_grad=False):
        batch_size = len(r_index)

        # initialize queries (relation types of the given triples)
        query = self.query[torch.arange(batch_size, device=r_index.device), r_index]
        index = h_index.unsqueeze(-1).expand_as(query)

        # initial (boundary) condition - initialize all node states as zeros
        boundary = torch.zeros(
            batch_size, data.num_nodes, self.dims[0], device=h_index.device
        )
        # by the scatter operation we put query (relation) embeddings as init features of source (index) nodes
        boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))

        size = (data.num_nodes, data.num_nodes)
        edge_weight = torch.ones(data.num_edges, device=h_index.device)

        hiddens = []
        edge_weights = []
        layer_input = boundary

        for layer in self.layers:
            # for visualization
            if separate_grad:
                edge_weight = edge_weight.clone().requires_grad_()

            # Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states
            hidden = layer(
                layer_input,
                query,
                boundary,
                data.edge_index,
                data.edge_type,
                size,
                edge_weight,
            )
            if self.short_cut and hidden.shape == layer_input.shape:
                # residual connection here
                hidden = hidden + layer_input
            hiddens.append(hidden)
            edge_weights.append(edge_weight)
            layer_input = hidden

        # original query (relation type) embeddings
        node_query = query.unsqueeze(1).expand(
            -1, data.num_nodes, -1
        )  # (batch_size, num_nodes, input_dim)
        if self.concat_hidden:
            output = torch.cat(hiddens + [node_query], dim=-1)
        else:
            output = torch.cat([hiddens[-1], node_query], dim=-1)

        return {
            "node_feature": output,
            "edge_weights": edge_weights,
        }

    def forward(self, data, relation_representations, batch):
        h_index, t_index, r_index = batch.unbind(-1)

        # initial query representations are those from the relation graph
        self.query = relation_representations

        # initialize relations in each NBFNet layer (with unique projection internally)
        for layer in self.layers:
            layer.relation = relation_representations

        # if self.training:
            # Edge dropout in the training mode
            # here we want to remove immediate edges (head, relation, tail) from the edge_index and edge_types
            # to make NBFNet iteration learn non-trivial paths
            # data = self.remove_easy_edges(data, h_index, t_index, r_index)

        shape = h_index.shape
        # turn all triples in a batch into a tail prediction mode
        h_index, t_index, r_index = self.negative_sample_to_tail(
            h_index, t_index, r_index, num_direct_rel=data.num_relations // 2
        )
        assert (h_index[:, [0]] == h_index).all()
        assert (r_index[:, [0]] == r_index).all()

        # message passing and updated node representations
        output = self.bellmanford(
            data, h_index[:, 0], r_index[:, 0]
        )  # (batch_size, num_nodes, feature_dim)
        feature = output["node_feature"]
        index = t_index.unsqueeze(-1).expand(-1, -1, feature.shape[-1])
        # extract representations of tail entities from the updated node states
        feature = feature.gather(
            1, index
        )  # (batch_size, num_negative + 1, feature_dim)

        # probability logit for each tail node in the batch
        # (batch_size, num_negative + 1, dim) -> (batch_size, num_negative + 1)
        score = self.mlp(feature).squeeze(-1)
        return score.view(shape)
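
How bellmanford seeds its boundary condition, shown in isolation: the query (relation) embedding is written into the row of each sample's head entity via scatter_add_, and every other node starts at zero (toy sizes):

Python
import torch

batch_size, num_nodes, dim = 2, 6, 4
query = torch.rand(batch_size, dim)  # one relation embedding per sample
h_index = torch.tensor([1, 4])       # head entity index per sample

boundary = torch.zeros(batch_size, num_nodes, dim)
index = h_index.unsqueeze(-1).expand_as(query)  # (bs, dim)
boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))

assert torch.allclose(boundary[0, 1], query[0])  # head node carries the query
assert boundary[0, 0].abs().sum() == 0           # other nodes remain zero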

gfmrag.ultra.models.QueryNBFNet

Bases: EntityNBFNet

The entity-level reasoner for UltraQuery-like complex query answering pipelines.

This class extends EntityNBFNet to handle query-specific reasoning in knowledge graphs. Key differences from EntityNBFNet include:

  1. Initial node features are provided during forward pass rather than read from triples batch
  2. Query comes from outer loop
  3. Returns distribution over all nodes (assuming t_index covers all nodes)

Attributes:

Name Type Description
layers

List of neural network layers for message passing

short_cut

Boolean flag for using residual connections

concat_hidden

Boolean flag for concatenating hidden states

mlp

Multi-layer perceptron for final scoring

num_beam

Beam size for path search

path_topk

Number of top paths to return

Methods:

Name Description
bellmanford

Performs Bellman-Ford message passing iterations. Args: data: Graph data object containing edge information. node_features: Initial node representations. query: Query representation. separate_grad: Whether to track gradients separately for edges. Returns: dict: Contains node features and edge weights.

forward

Main forward pass of the model. Args: data: Graph data object. node_features: Initial node features. relation_representations: Representations for relations. query: Query representation. Returns: torch.Tensor: Scores for each node.

visualize

Visualizes reasoning paths for given entities. Args: data: Graph data object. sample: Dictionary containing entity masks. node_features: Initial node features. relation_representations: Representations for relations. query: Query representation. Returns: dict: Contains paths and weights for target entities.

Source code in gfmrag/ultra/models.py
Python
class QueryNBFNet(EntityNBFNet):
    """
    The entity-level reasoner for UltraQuery-like complex query answering pipelines.

    This class extends EntityNBFNet to handle query-specific reasoning in knowledge graphs.
    Key differences from EntityNBFNet include:

    1. Initial node features are provided during forward pass rather than read from triples batch
    2. Query comes from outer loop
    3. Returns distribution over all nodes (assuming t_index covers all nodes)

    Attributes:
        layers: List of neural network layers for message passing
        short_cut: Boolean flag for using residual connections
        concat_hidden: Boolean flag for concatenating hidden states
        mlp: Multi-layer perceptron for final scoring
        num_beam: Beam size for path search
        path_topk: Number of top paths to return

    Methods:
        bellmanford(data, node_features, query, separate_grad=False):
            Performs Bellman-Ford message passing iterations.
            Args:
                data: Graph data object containing edge information
                node_features: Initial node representations
                query: Query representation
                separate_grad: Whether to track gradients separately for edges
            Returns:
                dict: Contains node features and edge weights

        forward(data, node_features, relation_representations, query):
            Main forward pass of the model.
            Args:
                data: Graph data object
                node_features: Initial node features
                relation_representations: Representations for relations
                query: Query representation
            Returns:
                torch.Tensor: Scores for each node

        visualize(data, sample, node_features, relation_representations, query):
            Visualizes reasoning paths for given entities.
            Args:
                data: Graph data object
                sample: Dictionary containing entity masks
                node_features: Initial node features
                relation_representations: Representations for relations
                query: Query representation
            Returns:
                dict: Contains paths and weights for target entities
    """

    def bellmanford(self, data, node_features, query, separate_grad=False):
        size = (data.num_nodes, data.num_nodes)
        edge_weight = torch.ones(data.num_edges, device=query.device)

        hiddens = []
        edge_weights = []
        layer_input = node_features

        for layer in self.layers:
            # for visualization
            if separate_grad:
                edge_weight = edge_weight.clone().requires_grad_()

            # Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states
            hidden = layer(
                layer_input,
                query,
                node_features,
                data.edge_index,
                data.edge_type,
                size,
                edge_weight,
            )
            if self.short_cut and hidden.shape == layer_input.shape:
                # residual connection here
                hidden = hidden + layer_input
            hiddens.append(hidden)
            edge_weights.append(edge_weight)
            layer_input = hidden

        # original query (relation type) embeddings
        node_query = query.unsqueeze(1).expand(
            -1, data.num_nodes, -1
        )  # (batch_size, num_nodes, input_dim)
        if self.concat_hidden:
            output = torch.cat(hiddens + [node_query], dim=-1)
        else:
            output = torch.cat([hiddens[-1], node_query], dim=-1)

        return {
            "node_feature": output,
            "edge_weights": edge_weights,
        }

    def forward(self, data, node_features, relation_representations, query):
        # initialize relations in each NBFNet layer (with unique projection internally)
        for layer in self.layers:
            layer.relation = relation_representations

        # we already did traversal_dropout in the outer loop of UltraQuery
        # if self.training:
        #     # Edge dropout in the training mode
        #     # here we want to remove immediate edges (head, relation, tail) from the edge_index and edge_types
        #     # to make NBFNet iteration learn non-trivial paths
        #     data = self.remove_easy_edges(data, h_index, t_index, r_index)

        # node features arrive in shape (bs, num_nodes, dim)
        # NBFNet needs batch size on the first place
        output = self.bellmanford(
            data, node_features, query
        )  # (batch_size, num_nodes, feature_dim)
        score = self.mlp(output["node_feature"]).squeeze(-1)  # (bs, num_nodes)
        return score

    def visualize(self, data, sample, node_features, relation_representations, query):
        for layer in self.layers:
            layer.relation = relation_representations

        output = self.bellmanford(
            data, node_features, query, separate_grad=True
        )  # (batch_size, num_nodes, feature_dim)
        node_feature = output["node_feature"]
        edge_weights = output["edge_weights"]
        question_entities_mask = sample["question_entities_masks"]
        target_entities_mask = sample["supporting_entities_masks"]
        query_entities_index = question_entities_mask.nonzero(as_tuple=True)[1]
        target_entities_index = target_entities_mask.nonzero(as_tuple=True)[1]

        paths_results = {}
        for t_index in target_entities_index:
            index = t_index.unsqueeze(0).unsqueeze(0).unsqueeze(-1).expand(-1, -1, node_feature.shape[-1])
            feature = node_feature.gather(1, index).squeeze(0)
            score = self.mlp(feature).squeeze(-1)

            edge_grads = autograd.grad(score, edge_weights, retain_graph=True)
            distances, back_edges = self.beam_search_distance(data, edge_grads, query_entities_index, t_index, self.num_beam)
            paths, weights = self.topk_average_length(distances, back_edges, t_index, self.path_topk)
            paths_results[t_index.item()] = (paths, weights)
        return paths_results
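
The core mechanism behind visualize, in isolation: with separate_grad=True each layer's edge_weight tensor is a gradient leaf, so the gradient of a target node's score with respect to it measures per-edge contribution; beam_search_distance then chains these per-layer gradients into paths. A toy sketch of that gradient step (stand-in tensors, not the library API):

Python
import torch

# stand-in for one message-passing layer: a score that depends on edge weights
edge_weight = torch.ones(5, requires_grad=True)
messages = torch.rand(5, 3)
score = (edge_weight.unsqueeze(-1) * messages).sum()

# per-edge gradient magnitude ~ that edge's contribution to the score
(edge_grad,) = torch.autograd.grad(score, edge_weight, retain_graph=True)
print(edge_grad)  # equals messages.sum(-1) in this toy setup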