Datasets
gfmrag.datasets
KGDataset
Bases: InMemoryDataset
A dataset class for processing and managing Knowledge Graph (KG) data.
This class extends InMemoryDataset to handle knowledge graph data, including entity-relation-entity triplets, and supports processing of both direct and inverse relations.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`root` | `str` | Root directory where the dataset should be saved. | *required* |
`data_name` | `str` | Name of the dataset. | *required* |
`text_emb_model_cfgs` | `DictConfig` | Configuration for the text embedding model. | *required* |
`force_rebuild` | `bool` | Whether to force rebuilding the processed data. Defaults to `False`. | `False` |
`**kwargs` | `str` | Additional keyword arguments. | `{}` |
Attributes:

Name | Type | Description |
---|---|---|
`name` | `str` | Name of the dataset. |
`fingerprint` | `str` | MD5 hash of the text embedding model configuration. |
`delimiter` | `str` | Delimiter used in the KG text file. |
`data` | `Data` | Processed graph data object. |
`slices` | `dict` | Data slices for batching. |
Note
- The class expects a 'kg.txt' file in the raw directory containing triplets.
- Processes both direct and inverse relations.
- Generates and stores relation embeddings using the specified text embedding model.
- Saves processed data along with entity and relation mappings.
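A minimal usage sketch follows. The embedding-model config is an illustrative assumption (the class instantiates it with Hydra's `instantiate`, so it needs at least a `_target_` entry); only the constructor signature and the expected `kg.txt` location come from this class.

```python
from omegaconf import OmegaConf
from gfmrag.datasets import KGDataset

# Hypothetical text-embedding config; the class hashes it into `fingerprint`
# and instantiates it via hydra, so `_target_` must resolve to a text
# embedding model class (the path and arguments here are assumptions).
text_emb_model_cfgs = OmegaConf.create(
    {
        "_target_": "gfmrag.text_emb_models.BaseTextEmbModel",
        "text_emb_model_name": "sentence-transformers/all-mpnet-base-v2",
    }
)

# Expects `<root>/<data_name>/processed/stage1/kg.txt`, with one
# head/relation/tail triplet per line separated by KG_DELIMITER.
kg_dataset = KGDataset(
    root="data",            # assumed directory layout
    data_name="my_kg",      # hypothetical dataset name
    text_emb_model_cfgs=text_emb_model_cfgs,
)
kg_data = kg_dataset[0]     # torch_geometric `Data` with relation embeddings attached
```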
Source code in gfmrag/datasets/kg_dataset.py
class KGDataset(InMemoryDataset):
"""A dataset class for processing and managing Knowledge Graph (KG) data.
This class extends InMemoryDataset to handle knowledge graph data, including entity-relation-entity triplets,
and supports processing of both direct and inverse relations.
Args:
root (str): Root directory where the dataset should be saved.
data_name (str): Name of the dataset.
text_emb_model_cfgs (DictConfig): Configuration for the text embedding model.
force_rebuild (bool, optional): Whether to force rebuilding the processed data. Defaults to False.
**kwargs (str): Additional keyword arguments.
Attributes:
name (str): Name of the dataset.
fingerprint (str): MD5 hash of the text embedding model configuration.
delimiter (str): Delimiter used in the KG text file.
data (Data): Processed graph data object.
slices (dict): Data slices for batching.
Note:
- The class expects a 'kg.txt' file in the raw directory containing triplets.
- Processes both direct and inverse relations.
- Generates and stores relation embeddings using the specified text embedding model.
- Saves processed data along with entity and relation mappings.
"""
delimiter = KG_DELIMITER
def __init__(
self,
root: str,
data_name: str,
text_emb_model_cfgs: DictConfig,
force_rebuild: bool = False,
**kwargs: str,
) -> None:
self.name = data_name
self.force_rebuild = force_rebuild
# Get fingerprint of the model configuration
self.fingerprint = hashlib.md5(
json.dumps(
OmegaConf.to_container(text_emb_model_cfgs, resolve=True)
).encode()
).hexdigest()
self.text_emb_model_cfgs = text_emb_model_cfgs
super().__init__(root, None, None)
self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)
@property
def raw_file_names(self) -> list:
return ["kg.txt"]
def load_file(
self, triplet_file: str, inv_entity_vocab: dict, inv_rel_vocab: dict
) -> dict:
"""Load a knowledge graph file and return the processed data."""
triplets = [] # Triples with inverse relations
entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)
with open(triplet_file, encoding="utf-8") as fin:
for line in fin:
try:
u, r, v = (
line.split()
if self.delimiter is None
else line.strip().split(self.delimiter)
)
except Exception as e:
logger.error(f"Error in line: {line}, {e}, Skipping")
continue
if u not in inv_entity_vocab:
inv_entity_vocab[u] = entity_cnt
entity_cnt += 1
if v not in inv_entity_vocab:
inv_entity_vocab[v] = entity_cnt
entity_cnt += 1
if r not in inv_rel_vocab:
inv_rel_vocab[r] = rel_cnt
rel_cnt += 1
u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]
triplets.append((u, v, r))
return {
"triplets": triplets,
"num_node": len(inv_entity_vocab), # entity_cnt,
"num_relation": rel_cnt,
"inv_entity_vocab": inv_entity_vocab,
"inv_rel_vocab": inv_rel_vocab,
}
def _process(self) -> None:
if is_main_process():
logger.info(f"Processing KG dataset {self.name} at rank {get_rank()}")
f = osp.join(self.processed_dir, "pre_transform.pt")
if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
self.pre_transform
):
warnings.warn( # noqa:B028
f"The `pre_transform` argument differs from the one used in "
f"the pre-processed version of this dataset. If you want to "
f"make use of another pre-processing technique, make sure to "
f"delete '{self.processed_dir}' first",
stacklevel=1,
)
f = osp.join(self.processed_dir, "pre_filter.pt")
if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
self.pre_filter
):
warnings.warn(
f"The `pre_filter` argument differs from the one used in "
f"the pre-processed version of this dataset. If you want to "
f"make use of another pre-fitering technique, make sure to "
f"delete '{self.processed_dir}' first",
stacklevel=1,
)
if self.force_rebuild or not files_exist(self.processed_paths):
if self.log and "pytest" not in sys.modules:
print("Processing...", file=sys.stderr)
makedirs(self.processed_dir)
self.process()
path = osp.join(self.processed_dir, "pre_transform.pt")
torch.save(_repr(self.pre_transform), path)
path = osp.join(self.processed_dir, "pre_filter.pt")
torch.save(_repr(self.pre_filter), path)
if self.log and "pytest" not in sys.modules:
print("Done!", file=sys.stderr)
else:
logger.info(
f"Rank [{get_rank()}]: Waiting for main process to finish processing KG dataset {self.name}"
)
synchronize()
def process(self) -> None:
"""Process the knowledge graph dataset.
This method processes the raw knowledge graph file and creates the following:
1. Loads the KG triplets and vocabulary
2. Creates edge indices and types for both original and inverse relations
3. Saves entity and relation mappings to JSON files
4. Generates relation embeddings using a text embedding model
5. Builds relation graphs
6. Saves the processed data and model configurations
The processed data includes:
- Edge indices and types for both original and inverse edges
- Target edge indices and types (original edges only)
- Number of nodes and relations
- Relation embeddings
- Relation graphs
Files created:
- ent2id.json: Entity to ID mapping
- rel2id.json: Relation to ID mapping (including inverse relations)
- text_emb_model_cfgs.json: Text embedding model configuration
- Processed graph data file at self.processed_paths[0]
"""
kg_file = self.raw_paths[0]
kg_result = self.load_file(kg_file, inv_entity_vocab={}, inv_rel_vocab={})
# in some datasets, there are several new nodes in the test set, eg 123,143 YAGO train and 123,182 in YAGO test
# for consistency with other experimental results, we'll include those in the full vocab and num nodes
num_node = kg_result["num_node"]
# the same for rels: in most cases train == test for transductive
# for AristoV4 train rels 1593, test 1604
num_relations = kg_result["num_relation"]
kg_triplets = kg_result["triplets"]
train_target_edges = torch.tensor(
[[t[0], t[1]] for t in kg_triplets], dtype=torch.long
).t()
train_target_etypes = torch.tensor([t[2] for t in kg_triplets])
# Add inverse edges
train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
train_etypes = torch.cat(
[train_target_etypes, train_target_etypes + num_relations]
)
with open(self.processed_dir + "/ent2id.json", "w") as f:
json.dump(kg_result["inv_entity_vocab"], f)
rel2id = kg_result["inv_rel_vocab"]
id2rel = {v: k for k, v in rel2id.items()}
for etype in train_etypes:
if etype.item() >= num_relations:
raw_etype = etype - num_relations
raw_rel = id2rel[raw_etype.item()]
rel2id["inverse_" + raw_rel] = etype.item()
with open(self.processed_dir + "/rel2id.json", "w") as f:
json.dump(rel2id, f)
# Generate relation embeddings
logger.info("Generating relation embeddings")
text_emb_model: BaseTextEmbModel = instantiate(self.text_emb_model_cfgs)
rel_emb = text_emb_model.encode(list(rel2id.keys()), is_query=False).cpu()
kg_data = Data(
edge_index=train_edges,
edge_type=train_etypes,
num_nodes=num_node,
target_edge_index=train_target_edges,
target_edge_type=train_target_etypes,
num_relations=num_relations * 2,
rel_emb=rel_emb,
)
# build graphs of relations
kg_data = build_relation_graph(kg_data)
torch.save((self.collate([kg_data])), self.processed_paths[0])
# Save text embeddings model configuration
with open(self.processed_dir + "/text_emb_model_cfgs.json", "w") as f:
json.dump(OmegaConf.to_container(self.text_emb_model_cfgs), f, indent=4)
def __repr__(self) -> str:
return f"{self.name}()"
@property
def num_relations(self) -> int:
return int(self.data.edge_type.max()) + 1
@property
def raw_dir(self) -> str:
return os.path.join(str(self.root), str(self.name), "processed", "stage1")
@property
def processed_dir(self) -> str:
return os.path.join(
str(self.root),
str(self.name),
"processed",
"stage2",
self.fingerprint,
)
@property
def processed_file_names(self) -> str:
return "data.pt"
load_file(triplet_file, inv_entity_vocab, inv_rel_vocab)
Load a knowledge graph file and return the processed data.
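A sketch of the expected input and output, based on the parsing loop below and reusing the `kg_dataset` from the earlier sketch; the delimiter value and the example triplet are illustrative:

```python
# kg.txt: one triplet per line, split on KG_DELIMITER (or on whitespace when
# the delimiter is None). Example line, assuming a comma delimiter:
#
#   fall of the roman empire,happened in,476 ad
#
result = kg_dataset.load_file("path/to/kg.txt", inv_entity_vocab={}, inv_rel_vocab={})
# result == {
#     "triplets": [(0, 1, 0)],      # (head_id, tail_id, relation_id)
#     "num_node": 2,
#     "num_relation": 1,
#     "inv_entity_vocab": {"fall of the roman empire": 0, "476 ad": 1},
#     "inv_rel_vocab": {"happened in": 0},
# }
```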
Source code in gfmrag/datasets/kg_dataset.py
def load_file(
self, triplet_file: str, inv_entity_vocab: dict, inv_rel_vocab: dict
) -> dict:
"""Load a knowledge graph file and return the processed data."""
triplets = [] # Triples with inverse relations
entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)
with open(triplet_file, encoding="utf-8") as fin:
for line in fin:
try:
u, r, v = (
line.split()
if self.delimiter is None
else line.strip().split(self.delimiter)
)
except Exception as e:
logger.error(f"Error in line: {line}, {e}, Skipping")
continue
if u not in inv_entity_vocab:
inv_entity_vocab[u] = entity_cnt
entity_cnt += 1
if v not in inv_entity_vocab:
inv_entity_vocab[v] = entity_cnt
entity_cnt += 1
if r not in inv_rel_vocab:
inv_rel_vocab[r] = rel_cnt
rel_cnt += 1
u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]
triplets.append((u, v, r))
return {
"triplets": triplets,
"num_node": len(inv_entity_vocab), # entity_cnt,
"num_relation": rel_cnt,
"inv_entity_vocab": inv_entity_vocab,
"inv_rel_vocab": inv_rel_vocab,
}
process()
Process the knowledge graph dataset.
This method processes the raw knowledge graph file and creates the following:
- Loads the KG triplets and vocabulary
- Creates edge indices and types for both original and inverse relations
- Saves entity and relation mappings to JSON files
- Generates relation embeddings using a text embedding model
- Builds relation graphs
- Saves the processed data and model configurations
The processed data includes:
- Edge indices and types for both original and inverse edges
- Target edge indices and types (original edges only)
- Number of nodes and relations
- Relation embeddings
- Relation graphs
Files created:
- ent2id.json: Entity to ID mapping
- rel2id.json: Relation to ID mapping (including inverse relations)
- text_emb_model_cfgs.json: Text embedding model configuration
- Processed graph data file at self.processed_paths[0]
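The sketch below illustrates how the method mirrors edges and names inverse relations; the relations and IDs are made up, and only the flip/offset logic mirrors the code that follows.

```python
import torch

num_relations = 2
rel2id = {"born in": 0, "works for": 1}        # hypothetical relations

# (head, tail) pairs as a (2, num_edges) tensor plus their relation types.
target_edges = torch.tensor([[0, 2], [1, 3]])  # heads on row 0, tails on row 1
target_etypes = torch.tensor([0, 1])

# As in process(): append flipped edges and offset their types by num_relations.
edge_index = torch.cat([target_edges, target_edges.flip(0)], dim=1)
edge_type = torch.cat([target_etypes, target_etypes + num_relations])

# rel2id then gains "inverse_"-prefixed entries for the offset IDs:
# {"born in": 0, "works for": 1, "inverse_born in": 2, "inverse_works for": 3}
# and the resulting Data object stores num_relations * 2 relation types.
```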
Source code in gfmrag/datasets/kg_dataset.py
def process(self) -> None:
"""Process the knowledge graph dataset.
This method processes the raw knowledge graph file and creates the following:
1. Loads the KG triplets and vocabulary
2. Creates edge indices and types for both original and inverse relations
3. Saves entity and relation mappings to JSON files
4. Generates relation embeddings using a text embedding model
5. Builds relation graphs
6. Saves the processed data and model configurations
The processed data includes:
- Edge indices and types for both original and inverse edges
- Target edge indices and types (original edges only)
- Number of nodes and relations
- Relation embeddings
- Relation graphs
Files created:
- ent2id.json: Entity to ID mapping
- rel2id.json: Relation to ID mapping (including inverse relations)
- text_emb_model_cfgs.json: Text embedding model configuration
- Processed graph data file at self.processed_paths[0]
"""
kg_file = self.raw_paths[0]
kg_result = self.load_file(kg_file, inv_entity_vocab={}, inv_rel_vocab={})
# in some datasets, there are several new nodes in the test set, eg 123,143 YAGO train and 123,182 in YAGO test
# for consistency with other experimental results, we'll include those in the full vocab and num nodes
num_node = kg_result["num_node"]
# the same for rels: in most cases train == test for transductive
# for AristoV4 train rels 1593, test 1604
num_relations = kg_result["num_relation"]
kg_triplets = kg_result["triplets"]
train_target_edges = torch.tensor(
[[t[0], t[1]] for t in kg_triplets], dtype=torch.long
).t()
train_target_etypes = torch.tensor([t[2] for t in kg_triplets])
# Add inverse edges
train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
train_etypes = torch.cat(
[train_target_etypes, train_target_etypes + num_relations]
)
with open(self.processed_dir + "/ent2id.json", "w") as f:
json.dump(kg_result["inv_entity_vocab"], f)
rel2id = kg_result["inv_rel_vocab"]
id2rel = {v: k for k, v in rel2id.items()}
for etype in train_etypes:
if etype.item() >= num_relations:
raw_etype = etype - num_relations
raw_rel = id2rel[raw_etype.item()]
rel2id["inverse_" + raw_rel] = etype.item()
with open(self.processed_dir + "/rel2id.json", "w") as f:
json.dump(rel2id, f)
# Generate relation embeddings
logger.info("Generating relation embeddings")
text_emb_model: BaseTextEmbModel = instantiate(self.text_emb_model_cfgs)
rel_emb = text_emb_model.encode(list(rel2id.keys()), is_query=False).cpu()
kg_data = Data(
edge_index=train_edges,
edge_type=train_etypes,
num_nodes=num_node,
target_edge_index=train_target_edges,
target_edge_type=train_target_etypes,
num_relations=num_relations * 2,
rel_emb=rel_emb,
)
# build graphs of relations
kg_data = build_relation_graph(kg_data)
torch.save((self.collate([kg_data])), self.processed_paths[0])
# Save text embeddings model configuration
with open(self.processed_dir + "/text_emb_model_cfgs.json", "w") as f:
json.dump(OmegaConf.to_container(self.text_emb_model_cfgs), f, indent=4)
QADataset
Bases: InMemoryDataset
A dataset class for Question-Answering tasks built on top of a Knowledge Graph.
This dataset inherits from torch_geometric's InMemoryDataset and processes raw QA data into a format suitable for graph-based QA models. It handles both training and test splits.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`root` | `str` | Root directory where the dataset should be saved. | *required* |
`data_name` | `str` | Name of the dataset. | *required* |
`text_emb_model_cfgs` | `DictConfig` | Configuration for the text embedding model used to encode questions. | *required* |
`force_rebuild` | `bool` | If True, forces the dataset to be reprocessed even if it exists. Defaults to `False`. | `False` |
Attributes:

Name | Type | Description |
---|---|---|
`name` | `str` | Name of the dataset. |
`kg` | `KGDataset` | The underlying knowledge graph dataset. |
`rel_emb_dim` | `int` | Dimension of relation embeddings. |
`ent2id` | `dict` | Mapping from entity names to IDs. |
`rel2id` | `dict` | Mapping from relation names to IDs. |
`doc` | `dict` | Corpus of documents. |
`doc2entities` | `dict` | Mapping from documents to contained entities. |
`raw_train_data` | `list` | Raw training data samples. |
`raw_test_data` | `list` | Raw test data samples. |
`ent2docs` | `Tensor` | Sparse tensor mapping entities to documents. |
`id2doc` | `dict` | Mapping from document IDs to document names. |
Notes

The processed dataset contains:

- Question embeddings
- Question entity masks
- Supporting entity masks
- Supporting document masks
- Sample IDs

The dataset processes raw JSON files and creates PyTorch tensors for efficient training.
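A minimal usage sketch, reusing the hypothetical embedding config from the KGDataset example above; the root and dataset name are assumptions.

```python
from omegaconf import OmegaConf
from gfmrag.datasets import QADataset

text_emb_model_cfgs = OmegaConf.create(
    {"_target_": "gfmrag.text_emb_models.BaseTextEmbModel"}  # assumed target/args
)

qa_dataset = QADataset(
    root="data",                 # assumed directory layout
    data_name="my_qa_dataset",   # hypothetical dataset name
    text_emb_model_cfgs=text_emb_model_cfgs,
)

# qa_dataset.data holds one torch Subset per raw split, in raw_file_names
# order (train.json, then test.json).
train_split, test_split = qa_dataset.data
sample = train_split[0]
# sample keys: "question_embeddings", "question_entities_masks",
# "supporting_entities_masks", "supporting_docs_masks", "sample_id"
```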
Source code in gfmrag/datasets/qa_dataset.py
class QADataset(InMemoryDataset):
"""A dataset class for Question-Answering tasks built on top of a Knowledge Graph.
This dataset inherits from torch_geometric's InMemoryDataset and processes raw QA data
into a format suitable for graph-based QA models. It handles both training and test splits.
Args:
root (str): Root directory where the dataset should be saved.
data_name (str): Name of the dataset.
text_emb_model_cfgs (DictConfig): Configuration for the text embedding model used to encode questions.
force_rebuild (bool, optional): If True, forces the dataset to be reprocessed even if it exists. Defaults to False.
Attributes:
name (str): Name of the dataset.
kg (KGDataset): The underlying knowledge graph dataset.
rel_emb_dim (int): Dimension of relation embeddings.
ent2id (dict): Mapping from entity names to IDs.
rel2id (dict): Mapping from relation names to IDs.
doc (dict): Corpus of documents.
doc2entities (dict): Mapping from documents to contained entities.
raw_train_data (list): Raw training data samples.
raw_test_data (list): Raw test data samples.
ent2docs (torch.Tensor): Sparse tensor mapping entities to documents.
id2doc (dict): Mapping from document IDs to document names.
Notes:
The processed dataset contains:
- Question embeddings
- Question entity masks
- Supporting entity masks
- Supporting document masks
- Sample IDs
The dataset processes raw JSON files and creates PyTorch tensors for efficient training.
"""
def __init__(
self,
root: str,
data_name: str,
text_emb_model_cfgs: DictConfig,
force_rebuild: bool = False,
):
self.name = data_name
self.force_rebuild = force_rebuild
self.text_emb_model_cfgs = text_emb_model_cfgs
# Get fingerprint of the model configuration
self.fingerprint = hashlib.md5(
json.dumps(
OmegaConf.to_container(text_emb_model_cfgs, resolve=True)
).encode()
).hexdigest()
self.kg = KGDataset(root, data_name, text_emb_model_cfgs, force_rebuild)[0]
self.rel_emb_dim = self.kg.rel_emb.shape[-1]
super().__init__(root, None, None)
self.data = torch.load(self.processed_paths[0], weights_only=False)
self.load_property()
def __repr__(self) -> str:
return f"{self.name}()"
@property
def raw_file_names(self) -> list:
return ["train.json", "test.json"]
@property
def raw_dir(self) -> str:
return os.path.join(str(self.root), str(self.name), "processed", "stage1")
@property
def processed_dir(self) -> str:
return os.path.join(
str(self.root),
str(self.name),
"processed",
"stage2",
self.fingerprint,
)
@property
def processed_file_names(self) -> str:
return "qa_data.pt"
def load_property(self) -> None:
"""
Load necessary properties from the KG dataset.
"""
with open(os.path.join(self.processed_dir, "ent2id.json")) as fin:
self.ent2id = json.load(fin)
with open(os.path.join(self.processed_dir, "rel2id.json")) as fin:
self.rel2id = json.load(fin)
with open(
os.path.join(str(self.root), str(self.name), "raw", "dataset_corpus.json")
) as fin:
self.doc = json.load(fin)
with open(os.path.join(self.raw_dir, "document2entities.json")) as fin:
self.doc2entities = json.load(fin)
if os.path.exists(os.path.join(self.raw_dir, "train.json")):
with open(os.path.join(self.raw_dir, "train.json")) as fin:
self.raw_train_data = json.load(fin)
else:
self.raw_train_data = []
if os.path.exists(os.path.join(self.raw_dir, "test.json")):
with open(os.path.join(self.raw_dir, "test.json")) as fin:
self.raw_test_data = json.load(fin)
else:
self.raw_test_data = []
self.ent2docs = torch.load(
os.path.join(self.processed_dir, "ent2doc.pt"), weights_only=False
) # (n_nodes, n_docs)
self.id2doc = {i: doc for i, doc in enumerate(self.doc2entities)}
def _process(self) -> None:
if is_main_process():
logger.info(f"Processing QA dataset {self.name} at rank {get_rank()}")
f = osp.join(self.processed_dir, "pre_transform.pt")
if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
self.pre_transform
):
warnings.warn(
f"The `pre_transform` argument differs from the one used in "
f"the pre-processed version of this dataset. If you want to "
f"make use of another pre-processing technique, make sure to "
f"delete '{self.processed_dir}' first",
stacklevel=1,
)
f = osp.join(self.processed_dir, "pre_filter.pt")
if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
self.pre_filter
):
warnings.warn(
f"The `pre_filter` argument differs from the one used in "
f"the pre-processed version of this dataset. If you want to "
f"make use of another pre-fitering technique, make sure to "
f"delete '{self.processed_dir}' first",
stacklevel=1,
)
if self.force_rebuild or not files_exist(self.processed_paths):
if self.log and "pytest" not in sys.modules:
print("Processing...", file=sys.stderr)
makedirs(self.processed_dir)
self.process()
path = osp.join(self.processed_dir, "pre_transform.pt")
torch.save(_repr(self.pre_transform), path)
path = osp.join(self.processed_dir, "pre_filter.pt")
torch.save(_repr(self.pre_filter), path)
if self.log and "pytest" not in sys.modules:
print("Done!", file=sys.stderr)
else:
logger.info(
f"Rank [{get_rank()}]: Waiting for main process to finish processing QA dataset {self.name}"
)
synchronize()
def process(self) -> None:
"""Process and prepare the question-answering dataset.
This method processes raw data files to create a structured dataset for question answering
tasks. It performs the following main operations:
1. Loads entity and relation mappings from processed files
2. Creates entity-document mapping tensors
3. Processes question samples to generate:
- Question embeddings
- Question entity masks
- Supporting entity masks
- Supporting document masks
The processed dataset is saved as torch splits containing:
- Question embeddings
- Various mask tensors for entities and documents
- Sample IDs
Files created:
- ent2doc.pt: Sparse tensor mapping entities to documents
- qa_data.pt: Processed QA dataset
- text_emb_model_cfgs.json: Text embedding model configuration
The method also saves the text embedding model configuration.
Returns:
None
"""
with open(os.path.join(self.processed_dir, "ent2id.json")) as fin:
self.ent2id = json.load(fin)
with open(os.path.join(self.processed_dir, "rel2id.json")) as fin:
self.rel2id = json.load(fin)
with open(os.path.join(self.raw_dir, "document2entities.json")) as fin:
self.doc2entities = json.load(fin)
num_nodes = self.kg.num_nodes
doc2id = {doc: i for i, doc in enumerate(self.doc2entities)}
# Convert document to entities to entity to document
n_docs = len(self.doc2entities)
# Create a sparse tensor for entity to document
doc2ent = torch.zeros((n_docs, num_nodes))
for doc, entities in self.doc2entities.items():
entity_ids = [self.ent2id[ent] for ent in entities if ent in self.ent2id]
doc2ent[doc2id[doc], entity_ids] = 1
ent2doc = doc2ent.T.to_sparse() # (n_nodes, n_docs)
torch.save(ent2doc, os.path.join(self.processed_dir, "ent2doc.pt"))
sample_id = []
questions = []
question_entities_masks = [] # Convert question entities to mask with number of nodes
supporting_entities_masks = []
supporting_docs_masks = []
num_samples = []
for path in self.raw_paths:
if not os.path.exists(path):
num_samples.append(0)
continue # Skip if the file does not exist
num_sample = 0
with open(path) as fin:
data = json.load(fin)
for index, item in enumerate(data):
question_entities = [
self.ent2id[x]
for x in item["question_entities"]
if x in self.ent2id
]
supporting_entities = [
self.ent2id[x]
for x in item["supporting_entities"]
if x in self.ent2id
]
supporting_docs = [
doc2id[doc] for doc in item["supporting_facts"] if doc in doc2id
]
# Skip samples if any of the entities or documents are empty
if any(
len(x) == 0
for x in [
question_entities,
supporting_entities,
supporting_docs,
]
):
continue
num_sample += 1
sample_id.append(index)
question = item["question"]
questions.append(question)
question_entities_masks.append(
entities_to_mask(question_entities, num_nodes)
)
supporting_entities_masks.append(
entities_to_mask(supporting_entities, num_nodes)
)
supporting_docs_masks.append(
entities_to_mask(supporting_docs, n_docs)
)
num_samples.append(num_sample)
# Generate question embeddings
logger.info("Generating question embeddings")
text_emb_model: BaseTextEmbModel = instantiate(self.text_emb_model_cfgs)
question_embeddings = text_emb_model.encode(
questions,
is_query=True,
).cpu()
question_entities_masks = torch.stack(question_entities_masks)
supporting_entities_masks = torch.stack(supporting_entities_masks)
supporting_docs_masks = torch.stack(supporting_docs_masks)
sample_id = torch.tensor(sample_id, dtype=torch.long)
dataset = datasets.Dataset.from_dict(
{
"question_embeddings": question_embeddings,
"question_entities_masks": question_entities_masks,
"supporting_entities_masks": supporting_entities_masks,
"supporting_docs_masks": supporting_docs_masks,
"sample_id": sample_id,
}
).with_format("torch")
offset = 0
splits = []
for num_sample in num_samples:
split = torch_data.Subset(dataset, range(offset, offset + num_sample))
splits.append(split)
offset += num_sample
torch.save(splits, self.processed_paths[0])
# Save text embeddings model configuration
with open(self.processed_dir + "/text_emb_model_cfgs.json", "w") as f:
json.dump(OmegaConf.to_container(self.text_emb_model_cfgs), f, indent=4)
load_property()
Load necessary properties from the KG dataset.
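After `load_property()` runs, the mappings below are available on the dataset. The lookups shown are illustrative (the entity and relation names are hypothetical) and assume `dataset_corpus.json` is keyed by document title.

```python
qa_dataset.ent2id["jane doe"]             # entity name -> entity ID
qa_dataset.rel2id["born in"]              # relation name -> relation ID
qa_dataset.id2doc[0]                      # document ID -> document title
qa_dataset.doc[qa_dataset.id2doc[0]]      # document title -> raw text from dataset_corpus.json
qa_dataset.ent2docs                       # sparse tensor of shape (num_entities, num_documents)
```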
Source code in gfmrag/datasets/qa_dataset.py
def load_property(self) -> None:
"""
Load necessary properties from the KG dataset.
"""
with open(os.path.join(self.processed_dir, "ent2id.json")) as fin:
self.ent2id = json.load(fin)
with open(os.path.join(self.processed_dir, "rel2id.json")) as fin:
self.rel2id = json.load(fin)
with open(
os.path.join(str(self.root), str(self.name), "raw", "dataset_corpus.json")
) as fin:
self.doc = json.load(fin)
with open(os.path.join(self.raw_dir, "document2entities.json")) as fin:
self.doc2entities = json.load(fin)
if os.path.exists(os.path.join(self.raw_dir, "train.json")):
with open(os.path.join(self.raw_dir, "train.json")) as fin:
self.raw_train_data = json.load(fin)
else:
self.raw_train_data = []
if os.path.exists(os.path.join(self.raw_dir, "test.json")):
with open(os.path.join(self.raw_dir, "test.json")) as fin:
self.raw_test_data = json.load(fin)
else:
self.raw_test_data = []
self.ent2docs = torch.load(
os.path.join(self.processed_dir, "ent2doc.pt"), weights_only=False
) # (n_nodes, n_docs)
self.id2doc = {i: doc for i, doc in enumerate(self.doc2entities)}
process()
Process and prepare the question-answering dataset.
This method processes raw data files to create a structured dataset for question answering tasks. It performs the following main operations:
- Loads entity and relation mappings from processed files
- Creates entity-document mapping tensors
- Processes question samples to generate:
- Question embeddings
- Question entity masks
- Supporting entity masks
- Supporting document masks
The processed dataset is saved as torch splits containing:
- Question embeddings
- Various mask tensors for entities and documents
- Sample IDs
Files created:
- ent2doc.pt: Sparse tensor mapping entities to documents
- qa_data.pt: Processed QA dataset
- text_emb_model_cfgs.json: Text embedding model configuration
The method also saves the text embedding model configuration.
Returns:

Type | Description |
---|---|
`None` | None |
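For reference, a hypothetical `train.json` / `test.json` entry containing the four fields this method reads; the values are illustrative.

```python
sample = {
    "question": "Where was the author of 'Example Book' born?",
    "question_entities": ["example book"],               # resolved through ent2id
    "supporting_entities": ["jane doe", "springfield"],  # resolved through ent2id
    "supporting_facts": ["Jane Doe"],                    # document titles, resolved through doc2id
}
# Entries where any of the three lists resolves to an empty ID list are skipped,
# so they receive no question embedding or masks.
```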
Source code in gfmrag/datasets/qa_dataset.py
def process(self) -> None:
"""Process and prepare the question-answering dataset.
This method processes raw data files to create a structured dataset for question answering
tasks. It performs the following main operations:
1. Loads entity and relation mappings from processed files
2. Creates entity-document mapping tensors
3. Processes question samples to generate:
- Question embeddings
- Question entity masks
- Supporting entity masks
- Supporting document masks
The processed dataset is saved as torch splits containing:
- Question embeddings
- Various mask tensors for entities and documents
- Sample IDs
Files created:
- ent2doc.pt: Sparse tensor mapping entities to documents
- qa_data.pt: Processed QA dataset
- text_emb_model_cfgs.json: Text embedding model configuration
The method also saves the text embedding model configuration.
Returns:
None
"""
with open(os.path.join(self.processed_dir, "ent2id.json")) as fin:
self.ent2id = json.load(fin)
with open(os.path.join(self.processed_dir, "rel2id.json")) as fin:
self.rel2id = json.load(fin)
with open(os.path.join(self.raw_dir, "document2entities.json")) as fin:
self.doc2entities = json.load(fin)
num_nodes = self.kg.num_nodes
doc2id = {doc: i for i, doc in enumerate(self.doc2entities)}
# Convert document to entities to entity to document
n_docs = len(self.doc2entities)
# Create a sparse tensor for entity to document
doc2ent = torch.zeros((n_docs, num_nodes))
for doc, entities in self.doc2entities.items():
entity_ids = [self.ent2id[ent] for ent in entities if ent in self.ent2id]
doc2ent[doc2id[doc], entity_ids] = 1
ent2doc = doc2ent.T.to_sparse() # (n_nodes, n_docs)
torch.save(ent2doc, os.path.join(self.processed_dir, "ent2doc.pt"))
sample_id = []
questions = []
question_entities_masks = [] # Convert question entities to mask with number of nodes
supporting_entities_masks = []
supporting_docs_masks = []
num_samples = []
for path in self.raw_paths:
if not os.path.exists(path):
num_samples.append(0)
continue # Skip if the file does not exist
num_sample = 0
with open(path) as fin:
data = json.load(fin)
for index, item in enumerate(data):
question_entities = [
self.ent2id[x]
for x in item["question_entities"]
if x in self.ent2id
]
supporting_entities = [
self.ent2id[x]
for x in item["supporting_entities"]
if x in self.ent2id
]
supporting_docs = [
doc2id[doc] for doc in item["supporting_facts"] if doc in doc2id
]
# Skip samples if any of the entities or documents are empty
if any(
len(x) == 0
for x in [
question_entities,
supporting_entities,
supporting_docs,
]
):
continue
num_sample += 1
sample_id.append(index)
question = item["question"]
questions.append(question)
question_entities_masks.append(
entities_to_mask(question_entities, num_nodes)
)
supporting_entities_masks.append(
entities_to_mask(supporting_entities, num_nodes)
)
supporting_docs_masks.append(
entities_to_mask(supporting_docs, n_docs)
)
num_samples.append(num_sample)
# Generate question embeddings
logger.info("Generating question embeddings")
text_emb_model: BaseTextEmbModel = instantiate(self.text_emb_model_cfgs)
question_embeddings = text_emb_model.encode(
questions,
is_query=True,
).cpu()
question_entities_masks = torch.stack(question_entities_masks)
supporting_entities_masks = torch.stack(supporting_entities_masks)
supporting_docs_masks = torch.stack(supporting_docs_masks)
sample_id = torch.tensor(sample_id, dtype=torch.long)
dataset = datasets.Dataset.from_dict(
{
"question_embeddings": question_embeddings,
"question_entities_masks": question_entities_masks,
"supporting_entities_masks": supporting_entities_masks,
"supporting_docs_masks": supporting_docs_masks,
"sample_id": sample_id,
}
).with_format("torch")
offset = 0
splits = []
for num_sample in num_samples:
split = torch_data.Subset(dataset, range(offset, offset + num_sample))
splits.append(split)
offset += num_sample
torch.save(splits, self.processed_paths[0])
# Save text embeddings model configuration
with open(self.processed_dir + "/text_emb_model_cfgs.json", "w") as f:
json.dump(OmegaConf.to_container(self.text_emb_model_cfgs), f, indent=4)