
Datasets

gfmrag.datasets

KGDataset

Bases: InMemoryDataset

A dataset class for processing and managing Knowledge Graph (KG) data.

This class extends InMemoryDataset to handle knowledge graph data, including entity-relation-entity triplets, and supports processing of both direct and inverse relations.

Parameters:

root (str): Root directory where the dataset should be saved. Required.
data_name (str): Name of the dataset. Required.
text_emb_model_cfgs (DictConfig): Configuration for the text embedding model. Required.
force_rebuild (bool, optional): Whether to force rebuilding the processed data. Defaults to False.
**kwargs (str): Additional keyword arguments.

Attributes:

name (str): Name of the dataset.
fingerprint (str): MD5 hash of the text embedding model configuration.
delimiter (str): Delimiter used in the KG text file.
data (Data): Processed graph data object.
slices (dict): Data slices for batching.

Note
  • The class expects a 'kg.txt' file in the raw directory containing triplets.
  • Processes both direct and inverse relations.
  • Generates and stores relation embeddings using the specified text embedding model.
  • Saves processed data along with entity and relation mappings.
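
For example, the dataset can be constructed directly from an OmegaConf config. This is a minimal sketch: the `_target_` class and embedding-model name below are illustrative assumptions rather than values prescribed by the package, and the dataset name is hypothetical; the class expects `<root>/<data_name>/processed/stage1/kg.txt` to exist (see raw_dir below).

Python
from omegaconf import OmegaConf

from gfmrag.datasets import KGDataset

# Illustrative embedding-model config: the `_target_` and model name are
# assumptions for this sketch, not values prescribed by the package.
text_emb_model_cfgs = OmegaConf.create(
    {
        "_target_": "gfmrag.text_emb_models.BaseTextEmbModel",
        "text_emb_model_name": "sentence-transformers/all-mpnet-base-v2",
    }
)

# Expects <root>/<data_name>/processed/stage1/kg.txt to exist.
dataset = KGDataset(
    root="data",
    data_name="my_kg",  # hypothetical dataset name
    text_emb_model_cfgs=text_emb_model_cfgs,
    force_rebuild=False,
)

kg = dataset[0]  # a torch_geometric.data.Data object
print(kg.num_nodes, dataset.num_relations, kg.rel_emb.shape)
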
Source code in gfmrag/datasets/kg_dataset.py
Python
class KGDataset(InMemoryDataset):
    """A dataset class for processing and managing Knowledge Graph (KG) data.

    This class extends InMemoryDataset to handle knowledge graph data, including entity-relation-entity triplets,
    and supports processing of both direct and inverse relations.

    Args:
        root (str): Root directory where the dataset should be saved.
        data_name (str): Name of the dataset.
        text_emb_model_cfgs (DictConfig): Configuration for the text embedding model.
        force_rebuild (bool, optional): Whether to force rebuilding the processed data. Defaults to False.
        **kwargs (str): Additional keyword arguments.

    Attributes:
        name (str): Name of the dataset.
        fingerprint (str): MD5 hash of the text embedding model configuration.
        delimiter (str): Delimiter used in the KG text file.
        data (Data): Processed graph data object.
        slices (dict): Data slices for batching.

    Note:
        - The class expects a 'kg.txt' file in the raw directory containing triplets.
        - Processes both direct and inverse relations.
        - Generates and stores relation embeddings using the specified text embedding model.
        - Saves processed data along with entity and relation mappings.
    """

    delimiter = KG_DELIMITER

    def __init__(
        self,
        root: str,
        data_name: str,
        text_emb_model_cfgs: DictConfig,
        force_rebuild: bool = False,
        **kwargs: str,
    ) -> None:
        self.name = data_name
        self.force_rebuild = force_rebuild
        # Get fingerprint of the model configuration
        self.fingerprint = hashlib.md5(
            json.dumps(
                OmegaConf.to_container(text_emb_model_cfgs, resolve=True)
            ).encode()
        ).hexdigest()
        self.text_emb_model_cfgs = text_emb_model_cfgs
        super().__init__(root, None, None)
        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

    @property
    def raw_file_names(self) -> list:
        return ["kg.txt"]

    def load_file(
        self, triplet_file: str, inv_entity_vocab: dict, inv_rel_vocab: dict
    ) -> dict:
        """Load a knowledge graph file and return the processed data."""

        triplets = []  # (head, tail, relation) ID triples; inverse relations are added later in process()
        entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)

        with open(triplet_file, encoding="utf-8") as fin:
            for line in fin:
                try:
                    u, r, v = (
                        line.split()
                        if self.delimiter is None
                        else line.strip().split(self.delimiter)
                    )
                except Exception as e:
                    logger.error(f"Error in line: {line}, {e}, Skipping")
                    continue
                if u not in inv_entity_vocab:
                    inv_entity_vocab[u] = entity_cnt
                    entity_cnt += 1
                if v not in inv_entity_vocab:
                    inv_entity_vocab[v] = entity_cnt
                    entity_cnt += 1
                if r not in inv_rel_vocab:
                    inv_rel_vocab[r] = rel_cnt
                    rel_cnt += 1
                u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]

                triplets.append((u, v, r))

        return {
            "triplets": triplets,
            "num_node": len(inv_entity_vocab),  # entity_cnt,
            "num_relation": rel_cnt,
            "inv_entity_vocab": inv_entity_vocab,
            "inv_rel_vocab": inv_rel_vocab,
        }

    def _process(self) -> None:
        if is_main_process():
            logger.info(f"Processing KG dataset {self.name} at rank {get_rank()}")
            f = osp.join(self.processed_dir, "pre_transform.pt")
            if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
                self.pre_transform
            ):
                warnings.warn(  # noqa:B028
                    f"The `pre_transform` argument differs from the one used in "
                    f"the pre-processed version of this dataset. If you want to "
                    f"make use of another pre-processing technique, make sure to "
                    f"delete '{self.processed_dir}' first",
                    stacklevel=1,
                )

            f = osp.join(self.processed_dir, "pre_filter.pt")
            if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
                self.pre_filter
            ):
                warnings.warn(
                    f"The `pre_filter` argument differs from the one used in "
                    f"the pre-processed version of this dataset. If you want to "
                    f"make use of another pre-fitering technique, make sure to "
                    f"delete '{self.processed_dir}' first",
                    stacklevel=1,
                )

            if self.force_rebuild or not files_exist(self.processed_paths):
                if self.log and "pytest" not in sys.modules:
                    print("Processing...", file=sys.stderr)

                makedirs(self.processed_dir)
                self.process()

                path = osp.join(self.processed_dir, "pre_transform.pt")
                torch.save(_repr(self.pre_transform), path)
                path = osp.join(self.processed_dir, "pre_filter.pt")
                torch.save(_repr(self.pre_filter), path)

                if self.log and "pytest" not in sys.modules:
                    print("Done!", file=sys.stderr)
        else:
            logger.info(
                f"Rank [{get_rank()}]: Waiting for main process to finish processing KG dataset {self.name}"
            )
        synchronize()

    def process(self) -> None:
        """Process the knowledge graph dataset.

        This method processes the raw knowledge graph file and creates the following:

        1. Loads the KG triplets and vocabulary
        2. Creates edge indices and types for both original and inverse relations
        3. Saves entity and relation mappings to JSON files
        4. Generates relation embeddings using a text embedding model
        5. Builds relation graphs
        6. Saves the processed data and model configurations

        The processed data includes:

        - Edge indices and types for both original and inverse edges
        - Target edge indices and types (original edges only)
        - Number of nodes and relations
        - Relation embeddings
        - Relation graphs

        Files created:

        - ent2id.json: Entity to ID mapping
        - rel2id.json: Relation to ID mapping (including inverse relations)
        - text_emb_model_cfgs.json: Text embedding model configuration
        - Processed graph data file at self.processed_paths[0]
        """
        kg_file = self.raw_paths[0]

        kg_result = self.load_file(kg_file, inv_entity_vocab={}, inv_rel_vocab={})

        # in some datasets, there are several new nodes in the test set, eg 123,143 YAGO train and 123,182 in YAGO test
        # for consistency with other experimental results, we'll include those in the full vocab and num nodes
        num_node = kg_result["num_node"]
        # the same for rels: in most cases train == test for transductive
        # for AristoV4 train rels 1593, test 1604
        num_relations = kg_result["num_relation"]

        kg_triplets = kg_result["triplets"]

        train_target_edges = torch.tensor(
            [[t[0], t[1]] for t in kg_triplets], dtype=torch.long
        ).t()
        train_target_etypes = torch.tensor([t[2] for t in kg_triplets])

        # Add inverse edges
        train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_etypes = torch.cat(
            [train_target_etypes, train_target_etypes + num_relations]
        )

        with open(self.processed_dir + "/ent2id.json", "w") as f:
            json.dump(kg_result["inv_entity_vocab"], f)
        rel2id = kg_result["inv_rel_vocab"]
        id2rel = {v: k for k, v in rel2id.items()}
        for etype in train_etypes:
            if etype.item() >= num_relations:
                raw_etype = etype - num_relations
                raw_rel = id2rel[raw_etype.item()]
                rel2id["inverse_" + raw_rel] = etype.item()
        with open(self.processed_dir + "/rel2id.json", "w") as f:
            json.dump(rel2id, f)

        # Generate relation embeddings
        logger.info("Generating relation embeddings")
        text_emb_model: BaseTextEmbModel = instantiate(self.text_emb_model_cfgs)
        rel_emb = text_emb_model.encode(list(rel2id.keys()), is_query=False).cpu()

        kg_data = Data(
            edge_index=train_edges,
            edge_type=train_etypes,
            num_nodes=num_node,
            target_edge_index=train_target_edges,
            target_edge_type=train_target_etypes,
            num_relations=num_relations * 2,
            rel_emb=rel_emb,
        )

        # build graphs of relations
        kg_data = build_relation_graph(kg_data)

        torch.save((self.collate([kg_data])), self.processed_paths[0])

        # Save text embeddings model configuration
        with open(self.processed_dir + "/text_emb_model_cfgs.json", "w") as f:
            json.dump(OmegaConf.to_container(self.text_emb_model_cfgs), f, indent=4)

    def __repr__(self) -> str:
        return f"{self.name}()"

    @property
    def num_relations(self) -> int:
        return int(self.data.edge_type.max()) + 1

    @property
    def raw_dir(self) -> str:
        return os.path.join(str(self.root), str(self.name), "processed", "stage1")

    @property
    def processed_dir(self) -> str:
        return os.path.join(
            str(self.root),
            str(self.name),
            "processed",
            "stage2",
            self.fingerprint,
        )

    @property
    def processed_file_names(self) -> str:
        return "data.pt"

load_file(triplet_file, inv_entity_vocab, inv_rel_vocab)

Load a knowledge graph file and return the processed data.

Source code in gfmrag/datasets/kg_dataset.py
Python
def load_file(
    self, triplet_file: str, inv_entity_vocab: dict, inv_rel_vocab: dict
) -> dict:
    """Load a knowledge graph file and return the processed data."""

    triplets = []  # (head, tail, relation) ID triples; inverse relations are added later in process()
    entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)

    with open(triplet_file, encoding="utf-8") as fin:
        for line in fin:
            try:
                u, r, v = (
                    line.split()
                    if self.delimiter is None
                    else line.strip().split(self.delimiter)
                )
            except Exception as e:
                logger.error(f"Error in line: {line}, {e}, Skipping")
                continue
            if u not in inv_entity_vocab:
                inv_entity_vocab[u] = entity_cnt
                entity_cnt += 1
            if v not in inv_entity_vocab:
                inv_entity_vocab[v] = entity_cnt
                entity_cnt += 1
            if r not in inv_rel_vocab:
                inv_rel_vocab[r] = rel_cnt
                rel_cnt += 1
            u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]

            triplets.append((u, v, r))

    return {
        "triplets": triplets,
        "num_node": len(inv_entity_vocab),  # entity_cnt,
        "num_relation": rel_cnt,
        "inv_entity_vocab": inv_entity_vocab,
        "inv_rel_vocab": inv_rel_vocab,
    }
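
As a concrete illustration (assuming a comma as the delimiter; the actual value comes from KG_DELIMITER), a toy kg.txt and the resulting return value might look like this, where `dataset` is the KGDataset instance from the example above:

Python
# kg.txt: one "head<delimiter>relation<delimiter>tail" triple per line, e.g.:
#   fort riley,located in,geary county
#   geary county,state,kansas
result = dataset.load_file(
    "data/my_kg/processed/stage1/kg.txt",  # path matches the example above
    inv_entity_vocab={},
    inv_rel_vocab={},
)
# result["triplets"]     -> [(0, 1, 0), (1, 2, 1)]  as (head_id, tail_id, rel_id)
# result["num_node"]     -> 3
# result["num_relation"] -> 2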

process()

Process the knowledge graph dataset.

This method processes the raw knowledge graph file and creates the following:

  1. Loads the KG triplets and vocabulary
  2. Creates edge indices and types for both original and inverse relations
  3. Saves entity and relation mappings to JSON files
  4. Generates relation embeddings using a text embedding model
  5. Builds relation graphs
  6. Saves the processed data and model configurations

The processed data includes:

  • Edge indices and types for both original and inverse edges
  • Target edge indices and types (original edges only)
  • Number of nodes and relations
  • Relation embeddings
  • Relation graphs

Files created:

  • ent2id.json: Entity to ID mapping
  • rel2id.json: Relation to ID mapping (including inverse relations)
  • text_emb_model_cfgs.json: Text embedding model configuration
  • Processed graph data file at self.processed_paths[0]
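
The inverse-edge construction in step 2 can be seen in isolation with a toy example (a self-contained sketch, independent of the class):

Python
import torch

num_relations = 2
# toy (head, tail, relation) ID triples, as produced by load_file
kg_triplets = [(0, 1, 0), (1, 2, 1), (2, 0, 0)]

target_edge_index = torch.tensor(
    [[h, t] for h, t, r in kg_triplets], dtype=torch.long
).t()
target_edge_type = torch.tensor([r for h, t, r in kg_triplets])

# Inverse edges: swap (head, tail) and offset the relation ID by num_relations
edge_index = torch.cat([target_edge_index, target_edge_index.flip(0)], dim=1)
edge_type = torch.cat([target_edge_type, target_edge_type + num_relations])

print(edge_index)
# tensor([[0, 1, 2, 1, 2, 0],
#         [1, 2, 0, 0, 1, 2]])
print(edge_type)  # tensor([0, 1, 0, 2, 3, 2])
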
Source code in gfmrag/datasets/kg_dataset.py
Python
def process(self) -> None:
    """Process the knowledge graph dataset.

    This method processes the raw knowledge graph file and creates the following:

    1. Loads the KG triplets and vocabulary
    2. Creates edge indices and types for both original and inverse relations
    3. Saves entity and relation mappings to JSON files
    4. Generates relation embeddings using a text embedding model
    5. Builds relation graphs
    6. Saves the processed data and model configurations

    The processed data includes:

    - Edge indices and types for both original and inverse edges
    - Target edge indices and types (original edges only)
    - Number of nodes and relations
    - Relation embeddings
    - Relation graphs

    Files created:

    - ent2id.json: Entity to ID mapping
    - rel2id.json: Relation to ID mapping (including inverse relations)
    - text_emb_model_cfgs.json: Text embedding model configuration
    - Processed graph data file at self.processed_paths[0]
    """
    kg_file = self.raw_paths[0]

    kg_result = self.load_file(kg_file, inv_entity_vocab={}, inv_rel_vocab={})

    # in some datasets, there are several new nodes in the test set, eg 123,143 YAGO train and 123,182 in YAGO test
    # for consistency with other experimental results, we'll include those in the full vocab and num nodes
    num_node = kg_result["num_node"]
    # the same for rels: in most cases train == test for transductive
    # for AristoV4 train rels 1593, test 1604
    num_relations = kg_result["num_relation"]

    kg_triplets = kg_result["triplets"]

    train_target_edges = torch.tensor(
        [[t[0], t[1]] for t in kg_triplets], dtype=torch.long
    ).t()
    train_target_etypes = torch.tensor([t[2] for t in kg_triplets])

    # Add inverse edges
    train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
    train_etypes = torch.cat(
        [train_target_etypes, train_target_etypes + num_relations]
    )

    with open(self.processed_dir + "/ent2id.json", "w") as f:
        json.dump(kg_result["inv_entity_vocab"], f)
    rel2id = kg_result["inv_rel_vocab"]
    id2rel = {v: k for k, v in rel2id.items()}
    for etype in train_etypes:
        if etype.item() >= num_relations:
            raw_etype = etype - num_relations
            raw_rel = id2rel[raw_etype.item()]
            rel2id["inverse_" + raw_rel] = etype.item()
    with open(self.processed_dir + "/rel2id.json", "w") as f:
        json.dump(rel2id, f)

    # Generate relation embeddings
    logger.info("Generating relation embeddings")
    text_emb_model: BaseTextEmbModel = instantiate(self.text_emb_model_cfgs)
    rel_emb = text_emb_model.encode(list(rel2id.keys()), is_query=False).cpu()

    kg_data = Data(
        edge_index=train_edges,
        edge_type=train_etypes,
        num_nodes=num_node,
        target_edge_index=train_target_edges,
        target_edge_type=train_target_etypes,
        num_relations=num_relations * 2,
        rel_emb=rel_emb,
    )

    # build graphs of relations
    kg_data = build_relation_graph(kg_data)

    torch.save((self.collate([kg_data])), self.processed_paths[0])

    # Save text embeddings model configuration
    with open(self.processed_dir + "/text_emb_model_cfgs.json", "w") as f:
        json.dump(OmegaConf.to_container(self.text_emb_model_cfgs), f, indent=4)

QADataset

Bases: InMemoryDataset

A dataset class for Question-Answering tasks built on top of a Knowledge Graph.

This dataset inherits from torch_geometric's InMemoryDataset and processes raw QA data into a format suitable for graph-based QA models. It handles both training and test splits.

Parameters:

root (str): Root directory where the dataset should be saved. Required.
data_name (str): Name of the dataset. Required.
text_emb_model_cfgs (DictConfig): Configuration for the text embedding model used to encode questions. Required.
force_rebuild (bool, optional): If True, forces the dataset to be reprocessed even if it exists. Defaults to False.

Attributes:

name (str): Name of the dataset.
kg (KGDataset): The underlying knowledge graph dataset.
rel_emb_dim (int): Dimension of relation embeddings.
ent2id (dict): Mapping from entity names to IDs.
rel2id (dict): Mapping from relation names to IDs.
doc (dict): Corpus of documents.
doc2entities (dict): Mapping from documents to contained entities.
raw_train_data (list): Raw training data samples.
raw_test_data (list): Raw test data samples.
ent2docs (Tensor): Sparse tensor mapping entities to documents.
id2doc (dict): Mapping from document IDs to document names.

Notes

The processed dataset contains:

  • Question embeddings
  • Question entity masks
  • Supporting entity masks
  • Supporting document masks
  • Sample IDs

The dataset processes raw JSON files and creates PyTorch tensors for efficient training.
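
A minimal usage sketch (the embedding-model config and dataset name are illustrative assumptions; train.json, test.json and document2entities.json are expected under the stage1 directory, and dataset_corpus.json under the raw directory):

Python
from omegaconf import OmegaConf

from gfmrag.datasets import QADataset

# Same illustrative embedding-model config as in the KGDataset example above.
text_emb_model_cfgs = OmegaConf.create(
    {
        "_target_": "gfmrag.text_emb_models.BaseTextEmbModel",
        "text_emb_model_name": "sentence-transformers/all-mpnet-base-v2",
    }
)

qa_dataset = QADataset(
    root="data",
    data_name="my_qa_dataset",  # hypothetical dataset name
    text_emb_model_cfgs=text_emb_model_cfgs,
)

# One split per raw file: train.json and test.json.
train_split, test_split = qa_dataset.data
sample = train_split[0]
print(sample["question_embeddings"].shape)      # (emb_dim,)
print(sample["question_entities_masks"].shape)  # (num_nodes,)
print(qa_dataset.ent2docs.shape)                # (num_nodes, num_docs)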

Source code in gfmrag/datasets/qa_dataset.py
Python
class QADataset(InMemoryDataset):
    """A dataset class for Question-Answering tasks built on top of a Knowledge Graph.

    This dataset inherits from torch_geometric's InMemoryDataset and processes raw QA data
    into a format suitable for graph-based QA models. It handles both training and test splits.

    Args:
        root (str): Root directory where the dataset should be saved.
        data_name (str): Name of the dataset.
        text_emb_model_cfgs (DictConfig): Configuration for the text embedding model used to encode questions.
        force_rebuild (bool, optional): If True, forces the dataset to be reprocessed even if it exists. Defaults to False.

    Attributes:
        name (str): Name of the dataset.
        kg (KGDataset): The underlying knowledge graph dataset.
        rel_emb_dim (int): Dimension of relation embeddings.
        ent2id (dict): Mapping from entity names to IDs.
        rel2id (dict): Mapping from relation names to IDs.
        doc (dict): Corpus of documents.
        doc2entities (dict): Mapping from documents to contained entities.
        raw_train_data (list): Raw training data samples.
        raw_test_data (list): Raw test data samples.
        ent2docs (torch.Tensor): Sparse tensor mapping entities to documents.
        id2doc (dict): Mapping from document IDs to document names.

    Notes:
        The processed dataset contains:
        - Question embeddings
        - Question entity masks
        - Supporting entity masks
        - Supporting document masks
        - Sample IDs

        The dataset processes raw JSON files and creates PyTorch tensors for efficient training.
    """

    def __init__(
        self,
        root: str,
        data_name: str,
        text_emb_model_cfgs: DictConfig,
        force_rebuild: bool = False,
    ):
        self.name = data_name
        self.force_rebuild = force_rebuild
        self.text_emb_model_cfgs = text_emb_model_cfgs
        # Get fingerprint of the model configuration
        self.fingerprint = hashlib.md5(
            json.dumps(
                OmegaConf.to_container(text_emb_model_cfgs, resolve=True)
            ).encode()
        ).hexdigest()
        self.kg = KGDataset(root, data_name, text_emb_model_cfgs, force_rebuild)[0]
        self.rel_emb_dim = self.kg.rel_emb.shape[-1]
        super().__init__(root, None, None)
        self.data = torch.load(self.processed_paths[0], weights_only=False)
        self.load_property()

    def __repr__(self) -> str:
        return f"{self.name}()"

    @property
    def raw_file_names(self) -> list:
        return ["train.json", "test.json"]

    @property
    def raw_dir(self) -> str:
        return os.path.join(str(self.root), str(self.name), "processed", "stage1")

    @property
    def processed_dir(self) -> str:
        return os.path.join(
            str(self.root),
            str(self.name),
            "processed",
            "stage2",
            self.fingerprint,
        )

    @property
    def processed_file_names(self) -> str:
        return "qa_data.pt"

    def load_property(self) -> None:
        """
        Load necessary properties from the KG dataset.
        """
        with open(os.path.join(self.processed_dir, "ent2id.json")) as fin:
            self.ent2id = json.load(fin)
        with open(os.path.join(self.processed_dir, "rel2id.json")) as fin:
            self.rel2id = json.load(fin)
        with open(
            os.path.join(str(self.root), str(self.name), "raw", "dataset_corpus.json")
        ) as fin:
            self.doc = json.load(fin)
        with open(os.path.join(self.raw_dir, "document2entities.json")) as fin:
            self.doc2entities = json.load(fin)
        if os.path.exists(os.path.join(self.raw_dir, "train.json")):
            with open(os.path.join(self.raw_dir, "train.json")) as fin:
                self.raw_train_data = json.load(fin)
        else:
            self.raw_train_data = []
        if os.path.exists(os.path.join(self.raw_dir, "test.json")):
            with open(os.path.join(self.raw_dir, "test.json")) as fin:
                self.raw_test_data = json.load(fin)
        else:
            self.raw_test_data = []

        self.ent2docs = torch.load(
            os.path.join(self.processed_dir, "ent2doc.pt"), weights_only=False
        )  # (n_nodes, n_docs)
        self.id2doc = {i: doc for i, doc in enumerate(self.doc2entities)}

    def _process(self) -> None:
        if is_main_process():
            logger.info(f"Processing QA dataset {self.name} at rank {get_rank()}")
            f = osp.join(self.processed_dir, "pre_transform.pt")
            if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
                self.pre_transform
            ):
                warnings.warn(
                    f"The `pre_transform` argument differs from the one used in "
                    f"the pre-processed version of this dataset. If you want to "
                    f"make use of another pre-processing technique, make sure to "
                    f"delete '{self.processed_dir}' first",
                    stacklevel=1,
                )

            f = osp.join(self.processed_dir, "pre_filter.pt")
            if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
                self.pre_filter
            ):
                warnings.warn(
                    f"The `pre_filter` argument differs from the one used in "
                    f"the pre-processed version of this dataset. If you want to "
                    f"make use of another pre-fitering technique, make sure to "
                    f"delete '{self.processed_dir}' first",
                    stacklevel=1,
                )

            if self.force_rebuild or not files_exist(self.processed_paths):
                if self.log and "pytest" not in sys.modules:
                    print("Processing...", file=sys.stderr)

                makedirs(self.processed_dir)
                self.process()

                path = osp.join(self.processed_dir, "pre_transform.pt")
                torch.save(_repr(self.pre_transform), path)
                path = osp.join(self.processed_dir, "pre_filter.pt")
                torch.save(_repr(self.pre_filter), path)

                if self.log and "pytest" not in sys.modules:
                    print("Done!", file=sys.stderr)
        else:
            logger.info(
                f"Rank [{get_rank()}]: Waiting for main process to finish processing QA dataset {self.name}"
            )
        synchronize()

    def process(self) -> None:
        """Process and prepare the question-answering dataset.

        This method processes raw data files to create a structured dataset for question answering
        tasks. It performs the following main operations:

        1. Loads entity and relation mappings from processed files
        2. Creates entity-document mapping tensors
        3. Processes question samples to generate:
            - Question embeddings
            - Question entity masks
            - Supporting entity masks
            - Supporting document masks

        The processed dataset is saved as torch splits containing:

        - Question embeddings
        - Various mask tensors for entities and documents
        - Sample IDs

        Files created:

        - ent2doc.pt: Sparse tensor mapping entities to documents
        - qa_data.pt: Processed QA dataset
        - text_emb_model_cfgs.json: Text embedding model configuration

        The method also saves the text embedding model configuration.

        Returns:
            None
        """
        with open(os.path.join(self.processed_dir, "ent2id.json")) as fin:
            self.ent2id = json.load(fin)
        with open(os.path.join(self.processed_dir, "rel2id.json")) as fin:
            self.rel2id = json.load(fin)
        with open(os.path.join(self.raw_dir, "document2entities.json")) as fin:
            self.doc2entities = json.load(fin)

        num_nodes = self.kg.num_nodes
        doc2id = {doc: i for i, doc in enumerate(self.doc2entities)}
        # Convert the document-to-entities mapping into an entity-to-document mapping
        n_docs = len(self.doc2entities)
        # Create a sparse tensor for entity to document
        doc2ent = torch.zeros((n_docs, num_nodes))
        for doc, entities in self.doc2entities.items():
            entity_ids = [self.ent2id[ent] for ent in entities if ent in self.ent2id]
            doc2ent[doc2id[doc], entity_ids] = 1
        ent2doc = doc2ent.T.to_sparse()  # (n_nodes, n_docs)
        torch.save(ent2doc, os.path.join(self.processed_dir, "ent2doc.pt"))

        sample_id = []
        questions = []
        question_entities_masks = []  # Convert question entities to mask with number of nodes
        supporting_entities_masks = []
        supporting_docs_masks = []
        num_samples = []

        for path in self.raw_paths:
            if not os.path.exists(path):
                num_samples.append(0)
                continue  # Skip if the file does not exist
            num_sample = 0
            with open(path) as fin:
                data = json.load(fin)
                for index, item in enumerate(data):
                    question_entities = [
                        self.ent2id[x]
                        for x in item["question_entities"]
                        if x in self.ent2id
                    ]

                    supporting_entities = [
                        self.ent2id[x]
                        for x in item["supporting_entities"]
                        if x in self.ent2id
                    ]

                    supporting_docs = [
                        doc2id[doc] for doc in item["supporting_facts"] if doc in doc2id
                    ]

                    # Skip samples if any of the entities or documents are empty
                    if any(
                        len(x) == 0
                        for x in [
                            question_entities,
                            supporting_entities,
                            supporting_docs,
                        ]
                    ):
                        continue
                    num_sample += 1
                    sample_id.append(index)
                    question = item["question"]
                    questions.append(question)

                    question_entities_masks.append(
                        entities_to_mask(question_entities, num_nodes)
                    )

                    supporting_entities_masks.append(
                        entities_to_mask(supporting_entities, num_nodes)
                    )

                    supporting_docs_masks.append(
                        entities_to_mask(supporting_docs, n_docs)
                    )
                num_samples.append(num_sample)

        # Generate question embeddings
        logger.info("Generating question embeddings")
        text_emb_model: BaseTextEmbModel = instantiate(self.text_emb_model_cfgs)
        question_embeddings = text_emb_model.encode(
            questions,
            is_query=True,
        ).cpu()
        question_entities_masks = torch.stack(question_entities_masks)
        supporting_entities_masks = torch.stack(supporting_entities_masks)
        supporting_docs_masks = torch.stack(supporting_docs_masks)
        sample_id = torch.tensor(sample_id, dtype=torch.long)

        dataset = datasets.Dataset.from_dict(
            {
                "question_embeddings": question_embeddings,
                "question_entities_masks": question_entities_masks,
                "supporting_entities_masks": supporting_entities_masks,
                "supporting_docs_masks": supporting_docs_masks,
                "sample_id": sample_id,
            }
        ).with_format("torch")
        offset = 0
        splits = []
        for num_sample in num_samples:
            split = torch_data.Subset(dataset, range(offset, offset + num_sample))
            splits.append(split)
            offset += num_sample
        torch.save(splits, self.processed_paths[0])

        # Save text embeddings model configuration
        with open(self.processed_dir + "/text_emb_model_cfgs.json", "w") as f:
            json.dump(OmegaConf.to_container(self.text_emb_model_cfgs), f, indent=4)

load_property()

Load necessary properties from the KG dataset.
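
After load_property(), ent2docs is a sparse (num_nodes, num_docs) indicator tensor, so the candidate documents for a question entity can be looked up directly. A small sketch, continuing the usage example above (the entity string is purely illustrative):

Python
# Look up the documents that mention a given question entity.
ent_id = qa_dataset.ent2id["some entity"]         # entity string is illustrative
doc_row = qa_dataset.ent2docs.to_dense()[ent_id]  # (num_docs,) 0/1 vector
doc_ids = doc_row.nonzero(as_tuple=True)[0].tolist()
doc_titles = [qa_dataset.id2doc[i] for i in doc_ids]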

Source code in gfmrag/datasets/qa_dataset.py
Python
def load_property(self) -> None:
    """
    Load necessary properties from the KG dataset.
    """
    with open(os.path.join(self.processed_dir, "ent2id.json")) as fin:
        self.ent2id = json.load(fin)
    with open(os.path.join(self.processed_dir, "rel2id.json")) as fin:
        self.rel2id = json.load(fin)
    with open(
        os.path.join(str(self.root), str(self.name), "raw", "dataset_corpus.json")
    ) as fin:
        self.doc = json.load(fin)
    with open(os.path.join(self.raw_dir, "document2entities.json")) as fin:
        self.doc2entities = json.load(fin)
    if os.path.exists(os.path.join(self.raw_dir, "train.json")):
        with open(os.path.join(self.raw_dir, "train.json")) as fin:
            self.raw_train_data = json.load(fin)
    else:
        self.raw_train_data = []
    if os.path.exists(os.path.join(self.raw_dir, "test.json")):
        with open(os.path.join(self.raw_dir, "test.json")) as fin:
            self.raw_test_data = json.load(fin)
    else:
        self.raw_test_data = []

    self.ent2docs = torch.load(
        os.path.join(self.processed_dir, "ent2doc.pt"), weights_only=False
    )  # (n_nodes, n_docs)
    self.id2doc = {i: doc for i, doc in enumerate(self.doc2entities)}

process()

Process and prepare the question-answering dataset.

This method processes raw data files to create a structured dataset for question answering tasks. It performs the following main operations:

  1. Loads entity and relation mappings from processed files
  2. Creates entity-document mapping tensors
  3. Processes question samples to generate:
    • Question embeddings
    • Question entity masks
    • Supporting entity masks
    • Supporting document masks

The processed dataset is saved as torch splits containing:

  • Question embeddings
  • Various mask tensors for entities and documents
  • Sample IDs

Files created:

  • ent2doc.pt: Sparse tensor mapping entities to documents
  • qa_data.pt: Processed QA dataset
  • text_emb_model_cfgs.json: Text embedding model configuration

The method also saves the text embedding model configuration.

Returns:

None
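
Steps 2 and 3 reduce to building 0/1 indicator tensors. A self-contained sketch with toy vocabularies (the names are made up; `entities_to_mask` in the real code produces the same kind of vector as the manual mask below):

Python
import torch

ent2id = {"fort riley": 0, "geary county": 1, "kansas": 2}  # toy vocabularies
doc2entities = {
    "Fort Riley": ["fort riley", "geary county"],
    "Kansas": ["kansas"],
}
doc2id = {doc: i for i, doc in enumerate(doc2entities)}
num_nodes, n_docs = len(ent2id), len(doc2entities)

# Step 2: dense doc-to-entity indicator matrix, transposed and sparsified.
doc2ent = torch.zeros((n_docs, num_nodes))
for doc, entities in doc2entities.items():
    doc2ent[doc2id[doc], [ent2id[e] for e in entities if e in ent2id]] = 1
ent2doc = doc2ent.T.to_sparse()  # (num_nodes, n_docs)

# Step 3: a question-entity mask is the same kind of 0/1 indicator vector
# that entities_to_mask produces.
question_entities = [ent2id["fort riley"]]
mask = torch.zeros(num_nodes)
mask[question_entities] = 1  # tensor([1., 0., 0.])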

Source code in gfmrag/datasets/qa_dataset.py
Python
def process(self) -> None:
    """Process and prepare the question-answering dataset.

    This method processes raw data files to create a structured dataset for question answering
    tasks. It performs the following main operations:

    1. Loads entity and relation mappings from processed files
    2. Creates entity-document mapping tensors
    3. Processes question samples to generate:
        - Question embeddings
        - Question entity masks
        - Supporting entity masks
        - Supporting document masks

    The processed dataset is saved as torch splits containing:

    - Question embeddings
    - Various mask tensors for entities and documents
    - Sample IDs

    Files created:

    - ent2doc.pt: Sparse tensor mapping entities to documents
    - qa_data.pt: Processed QA dataset
    - text_emb_model_cfgs.json: Text embedding model configuration

    The method also saves the text embedding model configuration.

    Returns:
        None
    """
    with open(os.path.join(self.processed_dir, "ent2id.json")) as fin:
        self.ent2id = json.load(fin)
    with open(os.path.join(self.processed_dir, "rel2id.json")) as fin:
        self.rel2id = json.load(fin)
    with open(os.path.join(self.raw_dir, "document2entities.json")) as fin:
        self.doc2entities = json.load(fin)

    num_nodes = self.kg.num_nodes
    doc2id = {doc: i for i, doc in enumerate(self.doc2entities)}
    # Convert the document-to-entities mapping into an entity-to-document mapping
    n_docs = len(self.doc2entities)
    # Create a sparse tensor for entity to document
    doc2ent = torch.zeros((n_docs, num_nodes))
    for doc, entities in self.doc2entities.items():
        entity_ids = [self.ent2id[ent] for ent in entities if ent in self.ent2id]
        doc2ent[doc2id[doc], entity_ids] = 1
    ent2doc = doc2ent.T.to_sparse()  # (n_nodes, n_docs)
    torch.save(ent2doc, os.path.join(self.processed_dir, "ent2doc.pt"))

    sample_id = []
    questions = []
    question_entities_masks = []  # Convert question entities to mask with number of nodes
    supporting_entities_masks = []
    supporting_docs_masks = []
    num_samples = []

    for path in self.raw_paths:
        if not os.path.exists(path):
            num_samples.append(0)
            continue  # Skip if the file does not exist
        num_sample = 0
        with open(path) as fin:
            data = json.load(fin)
            for index, item in enumerate(data):
                question_entities = [
                    self.ent2id[x]
                    for x in item["question_entities"]
                    if x in self.ent2id
                ]

                supporting_entities = [
                    self.ent2id[x]
                    for x in item["supporting_entities"]
                    if x in self.ent2id
                ]

                supporting_docs = [
                    doc2id[doc] for doc in item["supporting_facts"] if doc in doc2id
                ]

                # Skip samples if any of the entities or documents are empty
                if any(
                    len(x) == 0
                    for x in [
                        question_entities,
                        supporting_entities,
                        supporting_docs,
                    ]
                ):
                    continue
                num_sample += 1
                sample_id.append(index)
                question = item["question"]
                questions.append(question)

                question_entities_masks.append(
                    entities_to_mask(question_entities, num_nodes)
                )

                supporting_entities_masks.append(
                    entities_to_mask(supporting_entities, num_nodes)
                )

                supporting_docs_masks.append(
                    entities_to_mask(supporting_docs, n_docs)
                )
            num_samples.append(num_sample)

    # Generate question embeddings
    logger.info("Generating question embeddings")
    text_emb_model: BaseTextEmbModel = instantiate(self.text_emb_model_cfgs)
    question_embeddings = text_emb_model.encode(
        questions,
        is_query=True,
    ).cpu()
    question_entities_masks = torch.stack(question_entities_masks)
    supporting_entities_masks = torch.stack(supporting_entities_masks)
    supporting_docs_masks = torch.stack(supporting_docs_masks)
    sample_id = torch.tensor(sample_id, dtype=torch.long)

    dataset = datasets.Dataset.from_dict(
        {
            "question_embeddings": question_embeddings,
            "question_entities_masks": question_entities_masks,
            "supporting_entities_masks": supporting_entities_masks,
            "supporting_docs_masks": supporting_docs_masks,
            "sample_id": sample_id,
        }
    ).with_format("torch")
    offset = 0
    splits = []
    for num_sample in num_samples:
        split = torch_data.Subset(dataset, range(offset, offset + num_sample))
        splits.append(split)
        offset += num_sample
    torch.save(splits, self.processed_paths[0])

    # Save text embeddings model configuration
    with open(self.processed_dir + "/text_emb_model_cfgs.json", "w") as f:
        json.dump(OmegaConf.to_container(self.text_emb_model_cfgs), f, indent=4)