Graph Index Datasets
gfmrag.graph_index_datasets
¶
GraphDatasetLoader
¶
On-demand data loader for multiple datasets with LRU caching and async loading.
Source code in gfmrag/graph_index_datasets/graph_dataset_loader.py
class GraphDatasetLoader:
    """
    On-demand data loader for multiple datasets with LRU caching and async loading.

    Datasets are loaded in worker processes (spawned, not forked) so that CUDA
    contexts in the parent process are not inherited. At most
    ``max_datasets_in_memory`` datasets are cached with LRU eviction, and up to
    ``data_loading_workers`` upcoming datasets are prefetched while iterating.

    NOTE(review): only ``loading_futures`` is protected by the internal lock;
    ``loaded_datasets`` is touched without it — assumes a single consumer
    iterates this loader. Confirm with callers.
    """

    def __init__(
        self,
        datasets_cfg: DictConfig,
        data_names: list[str],
        shuffle: bool = True,
        max_datasets_in_memory: int = 1,
        data_loading_workers: int = 2,
    ) -> None:
        """
        Initialize the data loader.

        Args:
            datasets_cfg (DictConfig): Configuration for datasets.
            data_names (list[str]): List of dataset names to load.
            shuffle (bool): Whether to shuffle the datasets.
            max_datasets_in_memory (int): Maximum number of datasets to keep in memory.
            data_loading_workers (int): Number of workers for async loading.
                A value of 0 disables the process pool and all prefetching.
        """
        self.datasets_cfg = datasets_cfg
        self.data_names = data_names
        self.shuffle = shuffle
        self.max_datasets_in_memory = max_datasets_in_memory
        self.data_loading_workers = data_loading_workers
        # Use OrderedDict to maintain LRU cache: oldest entries are at the
        # front and are evicted first by _manage_memory().
        self.loaded_datasets: OrderedDict[str, Any] = OrderedDict()
        # Multiprocessing components
        self.executor: None | ProcessPoolExecutor = None
        self.loading_futures: dict[str, Future] = {}  # Track ongoing loading tasks
        self.loading_lock = (
            threading.Lock()
        )  # Protect concurrent access to loading_futures
        # Initialize process pool only if max_workers > 0
        if self.data_loading_workers > 0:
            self._init_process_pool()
        else:
            self.executor = None

    def _init_process_pool(self) -> None:
        """Initialize the process pool executor"""
        if self.data_loading_workers <= 0:
            self.executor = None
            return
        # Use spawn method to avoid issues with CUDA contexts
        mp_context = mp.get_context("spawn")
        self.executor = ProcessPoolExecutor(
            max_workers=self.data_loading_workers, mp_context=mp_context
        )

    def __del__(self) -> None:
        """Cleanup when object is destroyed"""
        # shutdown() is idempotent and guards on hasattr, so this is safe even
        # if __init__ failed part-way through.
        self.shutdown()

    def shutdown(self) -> None:
        """Shutdown the process pool executor"""
        if hasattr(self, "executor") and self.executor:
            # Cancel all pending futures
            with self.loading_lock:
                for future in self.loading_futures.values():
                    future.cancel()
                self.loading_futures.clear()
            # wait=False: do not block destruction on in-flight loads.
            self.executor.shutdown(wait=False)
            self.executor = None

    def set_epoch(self, epoch: int) -> None:
        # Seeds NumPy's *global* RNG so per-epoch shuffling in __iter__ is
        # reproducible (and identical across processes calling set_epoch).
        np.random.seed(epoch)

    def _manage_memory(self) -> None:
        """Manage memory to ensure not exceeding maximum dataset count"""
        # Evict until there is room for one more dataset (hence >=, not >).
        while len(self.loaded_datasets) >= self.max_datasets_in_memory:
            # Remove the oldest dataset
            oldest_name, oldest_dataset = self.loaded_datasets.popitem(last=False)
            del oldest_dataset
            gc.collect()

    def _start_async_loading(self, data_names: list[str]) -> None:
        """Start async loading for multiple datasets (up to max_workers)"""
        # Skip async loading if max_workers is 0
        if self.data_loading_workers <= 0 or not self.executor:
            return
        with self.loading_lock:
            # Calculate how many we can start loading
            available_workers = self.data_loading_workers - len(self.loading_futures)
            datasets_to_load = []
            for data_name in data_names[:available_workers]:
                # Skip if already loaded or currently being loaded
                if (
                    data_name in self.loaded_datasets
                    or data_name in self.loading_futures
                ):
                    continue
                datasets_to_load.append(data_name)
            # Start async loading for each dataset
            for data_name in datasets_to_load:
                try:
                    # Convert DictConfig to dict for serialization
                    datasets_cfg_dict = OmegaConf.to_container(
                        self.datasets_cfg, resolve=True
                    )
                    future = self.executor.submit(
                        _load_dataset_worker,
                        datasets_cfg_dict,
                        data_name,
                    )
                    self.loading_futures[data_name] = future
                except Exception as e:
                    print(f"Failed to start async loading for {data_name}: {e}")

    def _get_next_datasets_to_prefetch(
        self, current_index: int, data_name_list: list[str]
    ) -> list[str]:
        """Get the next max_workers datasets to prefetch"""
        # Window of up to data_loading_workers names *after* current_index.
        start_idx = current_index + 1
        end_idx = min(start_idx + self.data_loading_workers, len(data_name_list))
        return data_name_list[start_idx:end_idx]

    def _wait_for_dataset(self, data_name: str, timeout: float = 30.0) -> dict | None:
        """Wait for async loading to complete and return dataset"""
        with self.loading_lock:
            if data_name not in self.loading_futures:
                return None
            # Pop under the lock so no other caller waits on the same future.
            future = self.loading_futures.pop(data_name)
        try:
            # Wait for the loading to complete (outside the lock, so other
            # bookkeeping can make progress while we block).
            dataset = future.result(timeout=timeout)
            return dataset
        except Exception as e:
            print(f"Error loading dataset {data_name}: {e}")
            return None

    def _cleanup_completed_futures(self) -> None:
        """Clean up completed futures and store results"""
        if not self.executor:
            return
        with self.loading_lock:
            completed_names = []
            for name, future in self.loading_futures.items():
                if future.done():
                    try:
                        # Try to get the result to handle any exceptions
                        dataset = future.result()
                        if dataset is not None:
                            # Only store if we have space and don't already have it
                            if name not in self.loaded_datasets:
                                self._manage_memory()
                                self.loaded_datasets[name] = dataset
                            completed_names.append(name)
                        # NOTE(review): a worker returning None leaves its
                        # future registered (and done) forever — confirm
                        # whether None results should also be dropped here.
                    except Exception as e:
                        print(f"Error in background loading of {name}: {e}")
                        completed_names.append(name)
            # Remove completed futures
            for name in completed_names:
                self.loading_futures.pop(name, None)

    def _preload_datasets(self, data_name_list: list[str]) -> None:
        """Preload datasets up to memory limit"""
        # Preload first few datasets synchronously
        sync_preload_count = min(self.max_datasets_in_memory, len(data_name_list))
        for i in range(sync_preload_count):
            data_name = data_name_list[i]
            if data_name not in self.loaded_datasets:
                self._manage_memory()
                dataset = _load_dataset_worker(
                    OmegaConf.to_container(self.datasets_cfg, resolve=True), data_name
                )
                self.loaded_datasets[data_name] = dataset
        # Start async loading for next max_workers datasets only if max_workers > 0
        if self.data_loading_workers > 0 and sync_preload_count < len(data_name_list):
            async_start_idx = sync_preload_count
            # _get_next_datasets_to_prefetch starts at index + 1, hence the -1.
            async_datasets = self._get_next_datasets_to_prefetch(
                async_start_idx - 1, data_name_list
            )
            if async_datasets:
                self._start_async_loading(async_datasets)

    def _get_dataset(self, data_name: str) -> Any:
        """Get dataset, load if not in memory"""
        # Clean up any completed background loading first
        self._cleanup_completed_futures()
        if data_name in self.loaded_datasets:
            # Move to most recently used position (LRU update)
            dataset = self.loaded_datasets.pop(data_name)
            self.loaded_datasets[data_name] = dataset
            return dataset
        else:
            # Check if it's being loaded asynchronously
            if data_name in self.loading_futures:
                dataset = self._wait_for_dataset(data_name)
                if dataset is not None:
                    self._manage_memory()
                    self.loaded_datasets[data_name] = dataset
                    return dataset
            # Dataset not in memory and not being loaded, load synchronously
            self._manage_memory()
            dataset = _load_dataset_worker(
                OmegaConf.to_container(self.datasets_cfg, resolve=True),
                data_name,
            )
            self.loaded_datasets[data_name] = dataset
            return dataset

    def __iter__(self) -> Generator[GraphDataset, None, None]:
        """Yield each dataset once per epoch, prefetching upcoming ones."""
        data_name_list = self.data_names.copy()
        if self.shuffle:
            np.random.shuffle(data_name_list)
        # Preload datasets
        self._preload_datasets(data_name_list)
        for i, data_name in enumerate(data_name_list):
            # Start async loading for next max_workers datasets only if max_workers > 0
            if self.data_loading_workers > 0:
                next_datasets = self._get_next_datasets_to_prefetch(i, data_name_list)
                if next_datasets:
                    self._start_async_loading(next_datasets)
            # Get current dataset
            dataset = self._get_dataset(data_name)
            yield GraphDataset(name=data_name, data=dataset)

    def clear_cache(self) -> None:
        """Clear all cached datasets and cancel pending loads"""
        # Cancel all pending async loads
        with self.loading_lock:
            for future in self.loading_futures.values():
                future.cancel()
            self.loading_futures.clear()
        self.loaded_datasets.clear()
        gc.collect()

    def get_memory_info(self) -> dict:
        """Get current memory usage information"""
        with self.loading_lock:
            loading_count = len(self.loading_futures)
            loading_names = list(self.loading_futures.keys())
        return {
            "loaded_datasets_count": len(self.loaded_datasets),
            "max_datasets_in_memory": self.max_datasets_in_memory,
            "loaded_dataset_names": list(self.loaded_datasets.keys()),
            "async_loading_count": loading_count,
            "async_loading_names": loading_names,
            "max_workers": self.data_loading_workers,
        }

    def wait_for_all_loading(self, timeout: float = 60.0) -> None:
        """Wait for all async loading to complete"""
        if not self.executor:
            return
        start_time = time.time()
        while True:
            with self.loading_lock:
                if not self.loading_futures:
                    break
            # Check for completed futures
            self._cleanup_completed_futures()
            # Check timeout
            if time.time() - start_time > timeout:
                print("Warning: Timeout waiting for async loading to complete")
                break
            time.sleep(0.1)
__del__()
¶
__init__(datasets_cfg, data_names, shuffle=True, max_datasets_in_memory=1, data_loading_workers=2)
¶
Initialize the data loader. Args: datasets_cfg (DictConfig): Configuration for datasets. data_names (list[str]): List of dataset names to load. shuffle (bool): Whether to shuffle the datasets. max_datasets_in_memory (int): Maximum number of datasets to keep in memory. data_loading_workers (int): Number of workers for async loading.
Source code in gfmrag/graph_index_datasets/graph_dataset_loader.py
def __init__(
    self,
    datasets_cfg: DictConfig,
    data_names: list[str],
    shuffle: bool = True,
    max_datasets_in_memory: int = 1,
    data_loading_workers: int = 2,
) -> None:
    """
    Initialize the data loader.

    Args:
        datasets_cfg (DictConfig): Configuration for datasets.
        data_names (list[str]): List of dataset names to load.
        shuffle (bool): Whether to shuffle the datasets.
        max_datasets_in_memory (int): Maximum number of datasets to keep in memory.
        data_loading_workers (int): Number of workers for async loading.
            A value of 0 disables the process pool and all prefetching.
    """
    self.datasets_cfg = datasets_cfg
    self.data_names = data_names
    self.shuffle = shuffle
    self.max_datasets_in_memory = max_datasets_in_memory
    self.data_loading_workers = data_loading_workers
    # Use OrderedDict to maintain LRU cache: oldest entries are evicted first.
    self.loaded_datasets: OrderedDict[str, Any] = OrderedDict()
    # Multiprocessing components
    self.executor: None | ProcessPoolExecutor = None
    self.loading_futures: dict[str, Future] = {}  # Track ongoing loading tasks
    self.loading_lock = (
        threading.Lock()
    )  # Protect concurrent access to loading_futures
    # Initialize process pool only if max_workers > 0
    if self.data_loading_workers > 0:
        self._init_process_pool()
    else:
        self.executor = None
clear_cache()
¶
Clear all cached datasets and cancel pending loads
Source code in gfmrag/graph_index_datasets/graph_dataset_loader.py
get_memory_info()
¶
Get current memory usage information
Source code in gfmrag/graph_index_datasets/graph_dataset_loader.py
def get_memory_info(self) -> dict:
    """Get current memory usage information.

    Returns:
        dict: Counts and names of cached datasets and in-flight async loads,
        plus the configured memory/worker limits.
    """
    with self.loading_lock:
        loading_count = len(self.loading_futures)
        loading_names = list(self.loading_futures.keys())
    return {
        "loaded_datasets_count": len(self.loaded_datasets),
        "max_datasets_in_memory": self.max_datasets_in_memory,
        "loaded_dataset_names": list(self.loaded_datasets.keys()),
        "async_loading_count": loading_count,
        "async_loading_names": loading_names,
        "max_workers": self.data_loading_workers,
    }
shutdown()
¶
Shutdown the process pool executor
Source code in gfmrag/graph_index_datasets/graph_dataset_loader.py
def shutdown(self) -> None:
    """Shutdown the process pool executor.

    Cancels any futures that have not started and tears the pool down
    without waiting for in-flight loads. Safe to call multiple times.
    """
    # hasattr guard: may be invoked from __del__ after a partial __init__.
    if hasattr(self, "executor") and self.executor:
        # Cancel all pending futures
        with self.loading_lock:
            for future in self.loading_futures.values():
                future.cancel()
            self.loading_futures.clear()
        self.executor.shutdown(wait=False)
        self.executor = None
wait_for_all_loading(timeout=60.0)
¶
Wait for all async loading to complete
Source code in gfmrag/graph_index_datasets/graph_dataset_loader.py
def wait_for_all_loading(self, timeout: float = 60.0) -> None:
    """Wait for all async loading to complete.

    Polls every 100 ms, folding finished futures into the cache via
    _cleanup_completed_futures(). Gives up (with a warning) after ``timeout``
    seconds.

    Args:
        timeout (float): Maximum seconds to wait before giving up.
    """
    if not self.executor:
        return
    start_time = time.time()
    while True:
        with self.loading_lock:
            if not self.loading_futures:
                break
        # Check for completed futures
        self._cleanup_completed_futures()
        # Check timeout
        if time.time() - start_time > timeout:
            print("Warning: Timeout waiting for async loading to complete")
            break
        time.sleep(0.1)
GraphIndexDataset
¶
A dataset class for processing and managing graph index data.
GraphDataset provides a unified interface for loading, processing, and managing graph data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
root
|
str
|
Root directory where the dataset should be saved. |
required |
data_name
|
str
|
Name of the dataset. |
required |
text_emb_model_cfgs
|
DictConfig
|
Configuration for the text embedding model. |
required |
force_reload
|
bool
|
Whether to force rebuilding the processed data. Defaults to False. |
False
|
use_node_feat
|
bool
|
Whether to use node features. Defaults to True. |
True
|
use_relation_feat
|
bool
|
Whether to use relation features. Defaults to True. |
True
|
use_edge_feat
|
bool
|
Whether to use edge features. Defaults to False. |
False
|
inverse_relation_feat
|
Literal['text', 'inverse']
|
How to handle inverse relations. - 'text': Generate text embeddings for inverse relations by adding "inverse_" prefix to the relation name. - 'inverse': Use the negative of the relation embeddings for inverse relations following: http://arxiv.org/abs/2505.20422 |
'text'
|
skip_empty_target
|
bool
|
Whether to skip samples with empty target nodes. Defaults to True. Can be set to False for QA tasks where some samples may not have target nodes. |
True
|
**kwargs
|
str
|
Additional keyword arguments. |
{}
|
Attributes:
| Name | Type | Description |
|---|---|---|
name |
str
|
Name of the dataset. |
fingerprint |
str
|
MD5 hash of the text embedding model configuration. |
graph |
Data
|
Processed graph data object. |
train_data |
Dataset | None
|
Training data. |
test_data |
Dataset | None
|
Testing data. |
feat_dim |
int
|
Dimension of the entity and relation embeddings. |
node2id |
dict
|
Mapping from node name or uid to continuous IDs. |
rel2id |
dict
|
Mapping from relation name or uid to continuous IDs. |
id2node |
dict
|
Dict[str, int]: Mapping from continuous IDs to node uid. |
doc |
dict
|
Dict[str, Any]: The original document data |
raw_train_data |
Dict[str, Any]
|
The raw training data. |
raw_test_data |
Dict[str, Any]
|
The raw testing data. |
Note
- The class expects 'edges.csv', 'nodes.csv', 'relations.csv' files in the stage1 directory.
- Processes both direct and inverse relations.
- Generates and stores node and relation embeddings using the specified text embedding model.
- Saves processed data along with entity and relation mappings.
Files created
- graph.pt: Contains the processed graph data.
- train.pt: Contains the processed training data, if available.
- test.pt: Contains the processed testing data, if available.
- node2id.json: Maps node name or uid to continuous IDs.
- rel2id.json: Maps relation name or uid to continuous IDs (including inverse relations).
- config.json: Contains the configuration of the text embedding model and dataset attributes.
Source code in gfmrag/graph_index_datasets/graph_index_dataset.py
class GraphIndexDataset:
"""A dataset class for processing and managing graph index data.
GraphDataset provides an unified interface for loading, processing, and managing graph data.
Args:
root (str): Root directory where the dataset should be saved.
data_name (str): Name of the dataset.
text_emb_model_cfgs (DictConfig): Configuration for the text embedding model.
force_reload (bool, optional): Whether to force rebuilding the processed data. Defaults to False.
use_node_feat (bool, optional): Whether to use node features. Defaults to True.
use_relation_feat (bool, optional): Whether to use relation features. Defaults to True.
use_edge_feat (bool, optional): Whether to use edge features. Defaults to False.
inverse_relation_feat (Literal['text', 'inverse'], optional): How to handle inverse relations.
- 'text': Generate text embeddings for inverse relations by adding "inverse_" prefix to the relation name.
- 'inverse': Use the negative of the relation embeddings for inverse relations following: http://arxiv.org/abs/2505.20422
skip_empty_target (bool, optional): Whether to skip samples with empty target nodes. Defaults to True. Can be set to False for QA tasks where some samples may not have target nodes.
**kwargs (str): Additional keyword arguments.
Attributes:
name (str): Name of the dataset.
fingerprint (str): MD5 hash of the text embedding model configuration.
graph (Data): Processed graph data object.
train_data (torch.utils.data.Dataset | None): Training data.
test_data (torch.utils.data.Dataset | None): Testing data.
feat_dim (int): Dimension of the entity and relation embeddings.
node2id (dict): Mapping from node name or uid to continuous IDs.
rel2id (dict): Mapping from relation name or uid to continuous IDs.
id2node: Dict[str, int]: Mapping from continuous IDs to node uid.
doc: Dict[str, Any]: The original document data
raw_train_data (Dict[str, Any]): The raw training data.
raw_test_data (Dict[str, Any]): The raw testing data.
Note:
- The class expects 'edges.csv', 'nodes.csv', 'relations.csv' files in the stage1 directory.
- Processes both direct and inverse relations.
- Generates and stores node and relation embeddings using the specified text embedding model.
- Saves processed data along with entity and relation mappings.
Files created:
- graph.pt: Contains the processed graph data.
- train.pt: Contains the processed training data, if available.
- test.pt: Contains the processed testing data, if available.
- node2id.json: Maps node name or uid to continuous IDs.
- rel2id.json: Maps relation name or uid to continuous IDs (including inverse relations).
- config.json: Contains the configuration of the text embedding model and dataset attributes.
"""
FINGER_PRINT_ATTRS: ClassVar[list[str]] = [
"use_node_feat",
"use_relation_feat",
"use_edge_feat",
"inverse_relation_feat",
]
RAW_GRAPH_NAMES = ["nodes.csv", "relations.csv", "edges.csv"]
RAW_QA_DATA_NAMES = ["train.json", "test.json"]
RAW_DOCUMENT_NAME = "documents.json"
PROCESSED_GRAPH_NAMES = ["graph.pt", "node2id.json", "rel2id.json"]
PROCESSED_QA_DATA_NAMES = ["train.pt", "test.pt"]
@classmethod
def export_config_dict(
    cls, dataset_cfgs: DictConfig | dict[str, Any]
) -> dict[str, Any]:
    """Build the persisted dataset config dict.

    The returned structure intentionally matches the JSON produced by
    ``save_config()`` so it can also be embedded into model checkpoints.

    Args:
        dataset_cfgs: Dataset configuration, either a plain dict or an
            OmegaConf ``DictConfig`` (resolved before use).

    Returns:
        dict[str, Any]: ``class_name``, ``text_emb_model_cfgs`` and every
        attribute in ``FINGER_PRINT_ATTRS`` (missing keys export as ``None``).

    Raises:
        TypeError: If ``dataset_cfgs`` is neither a dict nor a DictConfig,
            or a DictConfig does not resolve to a mapping.
        KeyError: If ``text_emb_model_cfgs`` is absent from the config.
    """
    # Check plain dict first. The previous implementation left `cfgs` unbound
    # for unsupported types (raising a confusing NameError) and relied on a
    # bare assert, which disappears under `python -O`.
    if isinstance(dataset_cfgs, dict):
        cfgs: Any = dataset_cfgs
    elif isinstance(dataset_cfgs, DictConfig):
        cfgs = OmegaConf.to_container(dataset_cfgs, resolve=True)
    else:
        raise TypeError(
            f"dataset_cfgs must be a dict or DictConfig, got {type(dataset_cfgs).__name__}"
        )
    if not isinstance(cfgs, dict):
        # to_container() can yield a list for sequence-style configs.
        raise TypeError("dataset_cfgs must resolve to a mapping")
    config = {
        "class_name": cls.__name__,
        "text_emb_model_cfgs": cfgs["text_emb_model_cfgs"],
    }
    # Persist every fingerprint-relevant attribute; absent keys become None
    # so the exported schema is stable across config versions.
    for key in cls.FINGER_PRINT_ATTRS:
        config[key] = cfgs.get(key)
    return config
def __init__(
    self,
    root: str,
    data_name: str,
    text_emb_model_cfgs: DictConfig,
    force_reload: bool = False,
    use_node_feat: bool = True,
    use_relation_feat: bool = True,
    use_edge_feat: bool = False,
    inverse_relation_feat: Literal["text", "inverse"] = "text",
    skip_empty_target: bool = True,
    **kwargs: str,
) -> None:
    """Initialize the dataset: load (building if needed) graph and QA data.

    Args:
        root (str): Root directory where the dataset should be saved.
        data_name (str): Name of the dataset.
        text_emb_model_cfgs (DictConfig): Configuration for the text embedding model.
        force_reload (bool): Whether to force rebuilding the processed data.
        use_node_feat (bool): Whether to use node features.
        use_relation_feat (bool): Whether to use relation features.
        use_edge_feat (bool): Whether to use edge features.
        inverse_relation_feat (Literal["text", "inverse"]): How inverse-relation
            embeddings are produced (see class docstring).
        skip_empty_target (bool): Whether to skip samples with empty target nodes.
        **kwargs (str): Additional keyword arguments (currently unused here).
    """
    self.root = root
    self.name = data_name
    self.text_emb_model_cfgs = text_emb_model_cfgs
    self.use_node_feat = use_node_feat
    self.use_relation_feat = use_relation_feat
    self.use_edge_feat = use_edge_feat
    self.inverse_relation_feat = inverse_relation_feat
    self.skip_empty_target = skip_empty_target
    # Get fingerprint of the model configuration. batch_size is excluded so
    # that changing it does not invalidate cached processed data; the
    # FINGER_PRINT_ATTRS flags are folded in because they change the output.
    cfgs = OmegaConf.to_container(text_emb_model_cfgs, resolve=True)
    cfgs.pop("batch_size", None)  # Remove batch_size for fingerprinting
    for key in self.FINGER_PRINT_ATTRS:
        cfgs[key] = getattr(self, key, None)
    self.fingerprint = hashlib.md5(
        (self.__class__.__name__ + json.dumps(cfgs)).encode()
    ).hexdigest()
    graph_rebuild = self.load_graph(force_reload)
    qa_rebuild = self.load_qa_data(graph_rebuild, force_reload)
    if any([graph_rebuild, qa_rebuild]):
        # Save the dataset configuration if the graph or QA data was rebuilt
        self.save_config()
def load_graph(self, force_reload: bool = False) -> bool:
    """Load the processed graph data.

    Setting attributes:
        - self.graph: The processed graph data as a torch_geometric Data object.
        - self.node2id: A dictionary mapping node name or uid to continuous IDs.
        - self.rel2id: A dictionary mapping relation name or uid to continuous IDs.
        - self.id2node: A dictionary mapping continuous IDs back to node uid or name.
        - self.feat_dim: The dimension of the entity and relation embeddings.

    Args:
        force_reload (bool): Whether to force reload the graph data. If True, it will process the graph even if the processed files exist.

    Returns:
        bool: Whether the graph was rebuilt. If the graph was rebuilt, it returns True, otherwise False.
    """
    rebuild_graph = False
    if force_reload or not files_exist(self.processed_graph):
        os.makedirs(self.processed_dir, exist_ok=True)
        logger.warning(f"Processing graph for {self.name} at rank {get_rank()}")
        self.process_graph()
        # BUGFIX: previously never set, so this method always returned False,
        # which suppressed save_config() and QA reprocessing after a rebuild.
        rebuild_graph = True
    # Load the (possibly freshly written) processed artifacts.
    self.graph = torch.load(self.processed_graph[0], weights_only=False)
    with open(self.processed_graph[1]) as fin:
        self.node2id = json.load(fin)
    with open(self.processed_graph[2]) as fin:
        self.rel2id = json.load(fin)
    self.id2node = {v: k for k, v in self.node2id.items()}
    self.feat_dim = self.graph.feat_dim
    return rebuild_graph
def load_qa_data(self, graph_rebuild: bool, force_reload: bool = False) -> bool:
    """
    Load the QA data.

    Setting attributes:
        - self.train_data: The processed training data as a torch.utils.data.Dataset.
        - self.test_data: The processed testing data as a torch.utils.data.Dataset.
        - self.raw_train_data: The raw training data as a dictionary.
        - self.raw_test_data: The raw testing data as a dictionary.
        - self.doc: The original document data as a dictionary.

    Args:
        graph_rebuild (bool): Whether the graph was rebuilt.
        force_reload (bool): Whether to force reload the QA data. If True, it will process the QA data even if the processed files exist.

    Returns:
        bool: Whether the QA data was rebuilt.
    """
    rebuild_qa_data = False
    exist_raw_qa_data = []
    # Check if any raw QA data files exist (train.json / test.json are both
    # optional; either may be absent).
    for raw_data_name in self.raw_qa_data:
        if os.path.exists(raw_data_name):
            exist_raw_qa_data.append(raw_data_name)
    if len(exist_raw_qa_data) > 0:
        # Process the QA data if it does not exist or if force_reload is True or if the graph is rebuilt
        need_to_process_qa_data = [
            raw_data_name
            for raw_data_name in exist_raw_qa_data
            # A raw file needs processing when its "<stem>.pt" is missing.
            if not osp.exists(
                osp.join(
                    self.processed_dir,
                    f"{osp.basename(raw_data_name).split('.')[0]}.pt",
                )
            )
        ]
        if force_reload or graph_rebuild:
            # Node/relation IDs may have changed, so reprocess everything.
            need_to_process_qa_data = exist_raw_qa_data
        if len(need_to_process_qa_data) > 0:
            logger.warning(
                f"Processing QA data for {self.name} at rank {get_rank()}"
            )
            self.process_qa_data(need_to_process_qa_data)
            rebuild_qa_data = True
    # Load the processed QA data
    if osp.exists(osp.join(self.processed_dir, "train.pt")):
        self.train_data = torch.load(
            osp.join(self.processed_dir, "train.pt"), weights_only=False
        )
        with open(os.path.join(self.raw_dir, "train.json")) as fin:
            self.raw_train_data = json.load(fin)
    else:
        self.train_data = None
        self.raw_train_data = None
    if osp.exists(osp.join(self.processed_dir, "test.pt")):
        self.test_data = torch.load(
            osp.join(self.processed_dir, "test.pt"), weights_only=False
        )
        with open(os.path.join(self.raw_dir, "test.json")) as fin:
            self.raw_test_data = json.load(fin)
    else:
        self.test_data = None
        self.raw_test_data = None
    # documents.json is mandatory: loaded unconditionally.
    with open(
        os.path.join(str(self.root), str(self.name), "raw", self.RAW_DOCUMENT_NAME)
    ) as fin:
        self.doc = json.load(fin)
    return rebuild_qa_data
def attributes_to_text(self, attributes: dict | None = None, **kwargs: dict) -> str:
"""Return a string representation of the attributes.
Args:
attributes (dict | None): A dictionary of attributes.
**kwargs (dict): Additional keyword arguments to include in the string. The keys of the dictionary will be used as attribute names.
Returns:
str: A formatted string representation of the attributes.
Examples:
>>> attributes = {"description": "A node in the graph"}
>>> name = "Node1"
>>> print(attributes_to_text(attributes, name=name))
name: Node1
description: A node in the graph
"""
if attributes is None:
attributes = {}
if len(attributes) == 0 and len(kwargs) == 0:
raise ValueError(
"At least 'attributes' or other keyword arguments must be provided."
)
if len(attributes) > 0:
attr_str = "\n".join(
f"{key}: {value}"
for key, value in attributes.items()
if value is not None
)
else:
attr_str = ""
if len(kwargs) > 0:
additional_attrs = "\n".join(
f"{key}: {value}" for key, value in kwargs.items() if value is not None
)
else:
additional_attrs = ""
if attr_str and additional_attrs:
return f"{additional_attrs}\n{attr_str}".strip()
elif attr_str:
return attr_str.strip()
elif additional_attrs:
return additional_attrs.strip()
else:
return ""
def _read_csv_file(self, file_path: str) -> pd.DataFrame:
"""Read a CSV file and return a dict for nodes and relations.
Args:
file_path (str): Path to the CSV file.
Returns:
pd.DataFrame: A DataFrame
"""
if not osp.exists(file_path):
raise FileNotFoundError(f"File {file_path} does not exist.")
df = pd.read_csv(file_path, keep_default_na=False)
df["id"] = df.index # Add an ID column based on the index
# Change index to 'uid' or 'name' for nodes and 'relation' for relations
if "uid" in df.columns:
if df["uid"].nunique() != len(df):
raise ValueError(
f"The 'uid' column must contain unique values. Unique values found: {df['uid'].nunique()}, total rows: {len(df)}"
)
df = df.set_index("uid")
elif "name" in df.columns:
if df["name"].nunique() != len(df):
raise ValueError(
f"The 'name' column must contain unique values. Unique values found: {df['name'].nunique()}, total rows: {len(df)}"
)
df = df.set_index("name")
else:
raise ValueError(
"CSV file must contain either 'uid' or 'name' column as unique identifiers."
)
# Handle attributes
df["attributes"] = df["attributes"].apply(
lambda x: {} if pd.isna(x) else ast.literal_eval(x)
)
return df
def process_graph(self) -> None:
    """Process the graph index dataset.

    This method processes the raw graph index file and creates the following:
    1. Loads the nodes, edges, and relations from the raw files
    2. Creates edge indices and types for both original and inverse relations
    3. Saves entity and relation mappings to JSON files
    4. Generates relation, entity, edges features using a text embedding model
    5. Saves the processed data and model configurations

    The processed data includes:
    - Edge indices and types for both original and inverse edges
    - Target edge indices and types (original edges only)
    - Number of nodes and relations
    - Relation embeddings
    - Entity embeddings

    Files created:
    - graph.pt: Contains the processed graph data
    - node2id.json: Node to ID mapping
    - rel2id.json: Relation to ID mapping (including inverse relations)
    """
    node_file, relation_file, edge_file = self.raw_graph
    if (
        not osp.exists(node_file)
        or not osp.exists(relation_file)
        or not osp.exists(edge_file)
    ):
        raise FileNotFoundError(
            f"Required files not found in {self.raw_dir}. "
            "Please ensure 'nodes.csv', 'relations.csv', and 'edges.csv' exist."
        )
    # Load nodes
    nodes_df = self._read_csv_file(node_file)
    node2id = nodes_df["id"].to_dict()  # Map name or uids to continuous IDs
    nodes_type_id, node_type_names = pd.factorize(nodes_df["type"])
    nodes_df["type_id"] = nodes_type_id  # Add type ID column
    # Create a tensor for node types
    node_types = torch.LongTensor(nodes_type_id)
    # Save node ids under each type for fast access
    nodes_by_type = {}
    # Group node id by type
    node_types_group = nodes_df.groupby("type")["id"].apply(list).to_dict()
    for node_type, node_ids in node_types_group.items():
        nodes_by_type[node_type] = torch.LongTensor(node_ids)
    # Load relations
    relations_df = self._read_csv_file(relation_file)
    rel2id = relations_df["id"].to_dict()
    # Load triplets from edges.csv
    edges_df = pd.read_csv(edge_file, keep_default_na=False)
    edges_df["attributes"] = edges_df["attributes"].apply(
        lambda x: {} if pd.isna(x) else ast.literal_eval(x)
    )
    # Vectorized mapping of source, target, and relation to IDs
    edges_df["u"] = edges_df["source"].map(node2id)
    edges_df["v"] = edges_df["target"].map(node2id)
    edges_df["r"] = edges_df["relation"].map(rel2id)
    # Filter out rows with missing node or relation IDs
    valid_edges_df = edges_df.dropna(subset=["u", "v", "r"]).copy()
    # Log skipped edges
    skipped_edges = edges_df[edges_df[["u", "v", "r"]].isnull().any(axis=1)]
    for _, row in skipped_edges.iterrows():
        logger.warning(
            f"Skipping edge with missing node or relation: {row['source']}, {row['relation']}, {row['target']}"
        )
    # Convert IDs to int and build edge tuples
    edges = list(
        zip(
            valid_edges_df["u"].astype(int),
            valid_edges_df["v"].astype(int),
            valid_edges_df["r"].astype(int),
        )
    )
    # # Sort the edges by source and target for consistency
    # edges.sort(key=lambda x: (x[0], x[1]))
    num_nodes = len(node2id)
    num_relations = len(rel2id)
    # target_* hold the original (forward) edges only; 2 x |E| columns come
    # from appending flipped copies below.
    train_target_edges = torch.tensor(
        [[t[0], t[1]] for t in edges], dtype=torch.long
    ).t()
    train_target_etypes = torch.tensor([t[2] for t in edges])
    # Add inverse edges
    train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
    # Inverse relation of r is encoded as r + num_relations.
    train_etypes = torch.cat(
        [train_target_etypes, train_target_etypes + num_relations]
    )
    with open(self.processed_dir + "/node2id.json", "w") as f:
        json.dump(node2id, f)
    id2rel = {v: k for k, v in rel2id.items()}
    # Extend rel2id with "inverse_<rel>" entries for the shifted relation ids.
    for etype in train_etypes:
        if etype.item() >= num_relations:
            raw_etype = etype - num_relations
            raw_rel = id2rel[raw_etype.item()]
            rel2id["inverse_" + raw_rel] = etype.item()
    with open(self.processed_dir + "/rel2id.json", "w") as f:
        json.dump(rel2id, f)
    # Instantiate the text embedding model if attributes are used
    if self.use_node_feat or self.use_edge_feat or self.use_relation_feat:
        text_emb_model: BaseTextEmbModel = instantiate(self.text_emb_model_cfgs)
    # Generate relation embeddings
    if self.use_relation_feat:
        logger.info("Generating relation embeddings")
        relation_text_attributes = relations_df.apply(
            lambda row: self.attributes_to_text(row["attributes"], name=row.name),
            axis=1,
        ).to_list()
        rel_emb = text_emb_model.encode(
            relation_text_attributes, is_query=False
        ).cpu()
        if self.inverse_relation_feat == "inverse":
            # Inverse relations by adding the negative sign to the relation embeddings http://arxiv.org/abs/2505.20422
            rel_emb = torch.cat([rel_emb, -rel_emb], dim=0)
        elif self.inverse_relation_feat == "text":
            # Embed "inverse_<name>" texts so inverse relations get their own
            # learned-from-text representations.
            inverse_relation_text_attributes = relations_df.apply(
                lambda row: self.attributes_to_text(
                    row["attributes"], name="inverse_" + row.name
                ),
                axis=1,
            ).to_list()
            inverse_rel_emb = text_emb_model.encode(
                inverse_relation_text_attributes, is_query=False
            ).cpu()
            rel_emb = torch.cat([rel_emb, inverse_rel_emb], dim=0)
    else:
        rel_emb = None
    # Generate entity embeddings
    if self.use_node_feat:
        node_text_attributes = nodes_df.apply(
            lambda row: self.attributes_to_text(
                row["attributes"], name=row.name, type=row["type"]
            ),
            axis=1,
        ).to_list()
        logger.info("Generating entity embeddings")
        node_emb = text_emb_model.encode(node_text_attributes, is_query=False).cpu()
    else:
        node_emb = None
    if self.use_edge_feat:
        logger.info("Generating edge embeddings")
        edge_text_attributes = edges_df.apply(
            lambda row: self.attributes_to_text(
                row["attributes"],
            ),
            axis=1,
        ).to_list()
        edge_emb = text_emb_model.encode(edge_text_attributes, is_query=False).cpu()
    else:
        edge_emb = None
    # Get feature dimension from the first available embedding (for-else:
    # feat_dim stays 0 only when no embeddings were generated at all).
    for emb in [node_emb, rel_emb, edge_emb]:
        if emb is not None:
            if emb.ndim != 2:
                raise ValueError(
                    f"Expected 2D tensor for embeddings, got {emb.ndim}D tensor."
                )
            feat_dim = emb.size(1)
            break
    else:
        feat_dim = 0  # No embeddings available
    graph = Data(
        node_type=node_types,
        node_type_names=node_type_names,
        nodes_by_type=nodes_by_type,
        edge_index=train_edges,
        edge_type=train_etypes,
        num_nodes=num_nodes,
        target_edge_index=train_target_edges,
        target_edge_type=train_target_etypes,
        num_relations=num_relations * 2,
        x=node_emb,
        rel_attr=rel_emb,
        edge_attr=edge_emb,
        feat_dim=feat_dim,
    )
    torch.save(graph, self.processed_graph[0])
def process_qa_data(self, qa_data_names: list) -> None:
    """Process and prepare the question-answering dataset.

    This method processes raw data files to create a structured dataset for question answering
    tasks. It performs the following main operations:
    1. Loads entity and relation mappings from processed files
    2. Creates entity-document mapping tensors
    3. Processes question samples to generate:
        - Question embeddings
        - Combine masks for start nodes of each type
        - Combine masks for target nodes of each type

    The processed dataset is saved as torch splits containing:
    - Question embeddings
    - Start node masks
    - Target node masks
    - Sample IDs

    Files created:
    - train.pt: Contains the processed training data, if available
    - test.pt: Contains the processed testing data, if available

    Args:
        qa_data_names (list): List of raw QA data file names to process.

    Returns:
        None
    """
    num_nodes = self.graph.num_nodes
    # Accumulators are shared across all input files; they are sliced back
    # into per-file splits at the end using the num_samples offsets.
    start_nodes_mask = []
    target_nodes_mask = []
    sample_id = []
    questions = []
    num_samples = []  # number of KEPT samples per input file, in input order
    for data_name in qa_data_names:
        num_sample = 0
        with open(data_name) as fin:
            data = json.load(fin)
            for item in data:
                # Get start nodes and target nodes for each node type;
                # nodes unknown to the graph (missing from node2id) are dropped.
                start_nodes_ids = []
                target_nodes_ids = []
                for node in item["start_nodes"].values():
                    start_nodes_ids.extend(
                        [self.node2id[x] for x in node if x in self.node2id]
                    )
                for node in item["target_nodes"].values():
                    target_nodes_ids.extend(
                        [self.node2id[x] for x in node if x in self.node2id]
                    )
                # Skip samples if any of the entities or documents are empty
                if len(start_nodes_ids) == 0:
                    logger.warning(
                        f"Skipping sample {item['id']} in {data_name} due to empty start nodes."
                    )
                    continue
                if self.skip_empty_target and len(target_nodes_ids) == 0:
                    logger.warning(
                        f"Skipping sample {item['id']} in {data_name} due to empty target nodes."
                    )
                    continue
                num_sample += 1
                sample_id.append(item["id"])
                question = item["question"]
                questions.append(question)
                # Create masks for start nodes and target nodes
                start_nodes_mask.append(
                    entities_to_mask(start_nodes_ids, num_nodes)
                )
                target_nodes_mask.append(
                    entities_to_mask(target_nodes_ids, num_nodes)
                )
        num_samples.append(num_sample)
    # Generate question embeddings (single batched call over all splits)
    logger.info("Generating question embeddings")
    text_emb_model: BaseTextEmbModel = instantiate(self.text_emb_model_cfgs)
    question_embeddings = text_emb_model.encode(
        questions,
        is_query=True,
    ).cpu()
    start_nodes_mask = torch.stack(start_nodes_mask)
    target_nodes_mask = torch.stack(target_nodes_mask)
    dataset = datasets.Dataset.from_dict(
        {
            "question_embeddings": question_embeddings,
            "start_nodes_mask": start_nodes_mask,
            "target_nodes_mask": target_nodes_mask,
            "id": sample_id,
        }
    ).with_format("torch")
    # Slice the combined dataset back into contiguous per-file subsets and
    # save one .pt file per input split (named after the raw file stem).
    offset = 0
    for raw_data_name, num_sample in zip(qa_data_names, num_samples):
        split = torch_data.Subset(dataset, range(offset, offset + num_sample))
        split_name = osp.basename(raw_data_name).split(".")[0]
        processed_split_path = osp.join(self.processed_dir, f"{split_name}.pt")
        torch.save(split, processed_split_path)
        offset += num_sample
def __repr__(self) -> str:
    """Concise representation: the dataset name followed by ``()``."""
    return "{}()".format(self.name)
@property
def num_relations(self) -> int:
    """Number of edge types in the processed graph."""
    graph = self.graph
    return graph.num_edge_types
@property
def raw_dir(self) -> str:
    """Directory holding the stage-1 (raw) outputs for this dataset."""
    base = os.path.join(str(self.root), str(self.name))
    return os.path.join(base, "processed", "stage1")
@property
def raw_graph(self) -> list:
    """Absolute paths of the raw graph files under ``raw_dir``."""
    base = self.raw_dir
    paths = []
    for fname in self.RAW_GRAPH_NAMES:
        paths.append(osp.join(base, fname))
    return paths
@property
def raw_qa_data(self) -> list:
    """Absolute paths of the raw QA split files under ``raw_dir``."""
    base = self.raw_dir
    files = []
    for fname in self.RAW_QA_DATA_NAMES:
        files.append(osp.join(base, fname))
    return files
@property
def processed_dir(self) -> str:
    """Stage-2 output directory, namespaced by the config fingerprint."""
    parts = (
        str(self.root),
        str(self.name),
        "processed",
        "stage2",
        self.fingerprint,
    )
    return os.path.join(*parts)
@property
def processed_graph(self) -> list[str]:
    r"""The names of the processed files in the dataset."""
    out_dir = self.processed_dir
    return [osp.join(out_dir, fname) for fname in self.PROCESSED_GRAPH_NAMES]
@property
def processed_qa_data(self) -> list[str]:
    """Paths of the processed QA split files inside ``processed_dir``."""
    out_dir = self.processed_dir
    result = []
    for fname in self.PROCESSED_QA_DATA_NAMES:
        result.append(osp.join(out_dir, fname))
    return result
def save_config(self) -> None:
    """Save the configuration of the dataset to a JSON file.

    Serializes the text-embedding model config plus every fingerprint
    attribute into ``<processed_dir>/config.json`` so processed artifacts
    can later be matched to the config that produced them.
    """
    # Resolve the OmegaConf node into plain Python containers so it is
    # JSON-serializable.
    text_emb_model_cfgs = OmegaConf.to_container(
        self.text_emb_model_cfgs, resolve=True
    )
    config = self.__class__.export_config_dict(
        {
            "text_emb_model_cfgs": text_emb_model_cfgs,
            # Fingerprint attributes missing on the instance are recorded as None.
            **{
                key: getattr(self, key, None)
                for key in self.__class__.FINGER_PRINT_ATTRS
            },
        }
    )
    with open(osp.join(self.processed_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=4)
processed_graph
property
¶
The names of the processed files in the dataset.
attributes_to_text(attributes=None, **kwargs)
¶
Return a string representation of the attributes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
attributes
|
dict | None
|
A dictionary of attributes. |
None
|
**kwargs
|
dict
|
Additional keyword arguments to include in the string. The keys of the dictionary will be used as attribute names. |
{}
|
Returns:
| Name | Type | Description |
|---|---|---|
str |
str
|
A formatted string representation of the attributes. |
Examples:
>>> attributes = {"description": "A node in the graph"}
>>> name = "Node1"
>>> print(attributes_to_text(attributes, name=name))
name: Node1
description: A node in the graph
Source code in gfmrag/graph_index_datasets/graph_index_dataset.py
def attributes_to_text(self, attributes: dict | None = None, **kwargs: dict) -> str:
"""Return a string representation of the attributes.
Args:
attributes (dict | None): A dictionary of attributes.
**kwargs (dict): Additional keyword arguments to include in the string. The keys of the dictionary will be used as attribute names.
Returns:
str: A formatted string representation of the attributes.
Examples:
>>> attributes = {"description": "A node in the graph"}
>>> name = "Node1"
>>> print(attributes_to_text(attributes, name=name))
name: Node1
description: A node in the graph
"""
if attributes is None:
attributes = {}
if len(attributes) == 0 and len(kwargs) == 0:
raise ValueError(
"At least 'attributes' or other keyword arguments must be provided."
)
if len(attributes) > 0:
attr_str = "\n".join(
f"{key}: {value}"
for key, value in attributes.items()
if value is not None
)
else:
attr_str = ""
if len(kwargs) > 0:
additional_attrs = "\n".join(
f"{key}: {value}" for key, value in kwargs.items() if value is not None
)
else:
additional_attrs = ""
if attr_str and additional_attrs:
return f"{additional_attrs}\n{attr_str}".strip()
elif attr_str:
return attr_str.strip()
elif additional_attrs:
return additional_attrs.strip()
else:
return ""
export_config_dict(dataset_cfgs)
classmethod
¶
Build the persisted dataset config dict.
The returned structure intentionally matches the JSON produced by
save_config() so it can also be embedded into model checkpoints.
Source code in gfmrag/graph_index_datasets/graph_index_dataset.py
@classmethod
def export_config_dict(
    cls, dataset_cfgs: DictConfig | dict[str, Any]
) -> dict[str, Any]:
    """Build the persisted dataset config dict.

    The returned structure intentionally matches the JSON produced by
    ``save_config()`` so it can also be embedded into model checkpoints.

    Args:
        dataset_cfgs: Either an OmegaConf ``DictConfig`` or a plain dict that
            contains at least ``text_emb_model_cfgs`` (fingerprint attributes
            are optional and default to None).

    Returns:
        dict[str, Any]: Dict with ``class_name``, ``text_emb_model_cfgs`` and
        one entry per fingerprint attribute.

    Raises:
        TypeError: If ``dataset_cfgs`` is neither a ``DictConfig`` nor a dict.
        KeyError: If ``text_emb_model_cfgs`` is missing.
    """
    if isinstance(dataset_cfgs, DictConfig):
        cfgs = OmegaConf.to_container(dataset_cfgs, resolve=True)
    elif isinstance(dataset_cfgs, dict):
        cfgs = dataset_cfgs
    else:
        # Previously any other type fell through with `cfgs` unbound and
        # crashed with an UnboundLocalError; fail fast with a clear message.
        raise TypeError(
            f"dataset_cfgs must be a DictConfig or dict, got {type(dataset_cfgs).__name__}"
        )
    assert isinstance(cfgs, dict)
    config = {
        "class_name": cls.__name__,
        "text_emb_model_cfgs": cfgs["text_emb_model_cfgs"],
    }
    # Absent fingerprint attributes are persisted as None (dict.get default).
    for key in cls.FINGER_PRINT_ATTRS:
        config[key] = cfgs.get(key)
    return config
load_graph(force_reload=False)
¶
Load the processed graph data.
Setting attributes
- self.graph: The processed graph data as a torch_geometric Data object.
- self.node2id: A dictionary mapping node name or uid to continuous IDs.
- self.rel2id: A dictionary mapping relation name or uid to continuous IDs.
- self.id2node: A dictionary mapping continuous IDs back to node uid or name.
- self.feat_dim: The dimension of the entity and relation embeddings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
force_reload
|
bool
|
Whether to force reload the graph data. If True, it will process the graph even if the processed files exist. |
False
|
Returns:
| Name | Type | Description |
|---|---|---|
bool |
bool
|
Whether the graph was rebuilt. If the graph was rebuilt, it returns True, otherwise False. |
Source code in gfmrag/graph_index_datasets/graph_index_dataset.py
def load_graph(self, force_reload: bool = False) -> bool:
    """Load the processed graph data.

    Setting attributes:
    - self.graph: The processed graph data as a torch_geometric Data object.
    - self.node2id: A dictionary mapping node name or uid to continuous IDs.
    - self.rel2id: A dictionary mapping relation name or uid to continuous IDs.
    - self.id2node: A dictionary mapping continuous IDs back to node uid or name.
    - self.feat_dim: The dimension of the entity and relation embeddings.

    Args:
        force_reload (bool): Whether to force reload the graph data. If True, it will process the graph even if the processed files exist.

    Returns:
        bool: Whether the graph was rebuilt. If the graph was rebuilt, it returns True, otherwise False.
    """
    rebuild_graph = False
    if force_reload or not files_exist(self.processed_graph):
        os.makedirs(self.processed_dir, exist_ok=True)
        logger.warning(f"Processing graph for {self.name} at rank {get_rank()}")
        self.process_graph()
        # BUG FIX: the flag was never set, so this method always returned
        # False and callers (load_qa_data's graph_rebuild argument) never
        # invalidated their cached QA splits after a graph rebuild.
        rebuild_graph = True
    self.graph = torch.load(self.processed_graph[0], weights_only=False)
    with open(self.processed_graph[1]) as fin:
        self.node2id = json.load(fin)
    with open(self.processed_graph[2]) as fin:
        self.rel2id = json.load(fin)
    # Reverse mapping for decoding predictions back to node names/uids.
    self.id2node = {v: k for k, v in self.node2id.items()}
    self.feat_dim = self.graph.feat_dim
    return rebuild_graph
load_qa_data(graph_rebuild, force_reload=False)
¶
Load the QA data.
Setting attributes
- self.train_data: The processed training data as a torch.utils.data.Dataset.
- self.test_data: The processed testing data as a torch.utils.data.Dataset.
- self.raw_train_data: The raw training data as a dictionary.
- self.raw_test_data: The raw testing data as a dictionary.
- self.doc: The original document data as a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
graph_rebuild
|
bool
|
Whether the graph was rebuilt. |
required |
force_reload
|
bool
|
Whether to force reload the QA data. If True, it will process the QA data even if the processed files exist. |
False
|
Source code in gfmrag/graph_index_datasets/graph_index_dataset.py
def load_qa_data(self, graph_rebuild: bool, force_reload: bool = False) -> bool:
    """
    Load the QA data.

    Setting attributes:
    - self.train_data: The processed training data as a torch.utils.data.Dataset.
    - self.test_data: The processed testing data as a torch.utils.data.Dataset.
    - self.raw_train_data: The raw training data as a dictionary.
    - self.raw_test_data: The raw testing data as a dictionary.
    - self.doc: The original document data as a dictionary.

    Args:
        graph_rebuild (bool): Whether the graph was rebuilt.
        force_reload (bool): Whether to force reload the QA data. If True, it will process the QA data even if the processed files exist.

    Returns:
        bool: Whether the QA data was rebuilt.
    """
    rebuild_qa_data = False
    exist_raw_qa_data = []
    # Check if any raw QA data files exist
    for raw_data_name in self.raw_qa_data:
        if os.path.exists(raw_data_name):
            exist_raw_qa_data.append(raw_data_name)
    if len(exist_raw_qa_data) > 0:
        # Process the QA data if it does not exist or if force_reload is True or if the graph is rebuilt
        need_to_process_qa_data = [
            raw_data_name
            for raw_data_name in exist_raw_qa_data
            if not osp.exists(
                osp.join(
                    self.processed_dir,
                    f"{osp.basename(raw_data_name).split('.')[0]}.pt",
                )
            )
        ]
        # A graph rebuild (or explicit force) invalidates ALL cached splits,
        # not just the missing ones.
        if force_reload or graph_rebuild:
            need_to_process_qa_data = exist_raw_qa_data
        if len(need_to_process_qa_data) > 0:
            logger.warning(
                f"Processing QA data for {self.name} at rank {get_rank()}"
            )
            self.process_qa_data(need_to_process_qa_data)
            rebuild_qa_data = True
    # Load the processed QA data; each split is optional and set to None
    # when its processed file is absent.
    if osp.exists(osp.join(self.processed_dir, "train.pt")):
        self.train_data = torch.load(
            osp.join(self.processed_dir, "train.pt"), weights_only=False
        )
        with open(os.path.join(self.raw_dir, "train.json")) as fin:
            self.raw_train_data = json.load(fin)
    else:
        self.train_data = None
        self.raw_train_data = None
    if osp.exists(osp.join(self.processed_dir, "test.pt")):
        self.test_data = torch.load(
            osp.join(self.processed_dir, "test.pt"), weights_only=False
        )
        with open(os.path.join(self.raw_dir, "test.json")) as fin:
            self.raw_test_data = json.load(fin)
    else:
        self.test_data = None
        self.raw_test_data = None
    # The source documents are always loaded, independent of the QA splits.
    with open(
        os.path.join(str(self.root), str(self.name), "raw", self.RAW_DOCUMENT_NAME)
    ) as fin:
        self.doc = json.load(fin)
    return rebuild_qa_data
process_graph()
¶
Process the graph index dataset.
This method processes the raw graph index file and creates the following:
- Loads the nodes, edges, and relations from the raw files
- Creates edge indices and types for both original and inverse relations
- Saves entity and relation mappings to JSON files
- Generates relation, entity, edges features using a text embedding model
- Saves the processed data and model configurations
The processed data includes:
- Edge indices and types for both original and inverse edges
- Target edge indices and types (original edges only)
- Number of nodes and relations
- Relation embeddings
- Entity embeddings
Files created:
- graph.pt: Contains the processed graph data
- node2id.json: Node to ID mapping
- rel2id.json: Relation to ID mapping (including inverse relations)
Source code in gfmrag/graph_index_datasets/graph_index_dataset.py
def process_graph(self) -> None:
    """Process the graph index dataset.

    This method processes the raw graph index file and creates the following:
    1. Loads the nodes, edges, and relations from the raw files
    2. Creates edge indices and types for both original and inverse relations
    3. Saves entity and relation mappings to JSON files
    4. Generates relation, entity, edges features using a text embedding model
    5. Saves the processed data and model configurations

    The processed data includes:
    - Edge indices and types for both original and inverse edges
    - Target edge indices and types (original edges only)
    - Number of nodes and relations
    - Relation embeddings
    - Entity embeddings

    Files created:
    - graph.pt: Contains the processed graph data
    - node2id.json: Node to ID mapping
    - rel2id.json: Relation to ID mapping (including inverse relations)
    """
    node_file, relation_file, edge_file = self.raw_graph
    if (
        not osp.exists(node_file)
        or not osp.exists(relation_file)
        or not osp.exists(edge_file)
    ):
        raise FileNotFoundError(
            f"Required files not found in {self.raw_dir}. "
            "Please ensure 'nodes.csv', 'relations.csv', and 'edges.csv' exist."
        )
    # Load nodes
    nodes_df = self._read_csv_file(node_file)
    node2id = nodes_df["id"].to_dict()  # Map name or uids to continuous IDs
    nodes_type_id, node_type_names = pd.factorize(nodes_df["type"])
    nodes_df["type_id"] = nodes_type_id  # Add type ID column
    # Create a tensor for node types
    node_types = torch.LongTensor(nodes_type_id)
    # Save node ids under each type for fast access
    nodes_by_type = {}
    # Group node id by type
    node_types_group = nodes_df.groupby("type")["id"].apply(list).to_dict()
    for node_type, node_ids in node_types_group.items():
        nodes_by_type[node_type] = torch.LongTensor(node_ids)
    # Load relations
    relations_df = self._read_csv_file(relation_file)
    rel2id = relations_df["id"].to_dict()
    # Load triplets from edges.csv
    # NOTE(review): with keep_default_na=False an empty "attributes" cell is
    # read as "" (not NaN), so the pd.isna guard below would not catch it and
    # ast.literal_eval("") would raise — confirm attributes are always populated.
    edges_df = pd.read_csv(edge_file, keep_default_na=False)
    edges_df["attributes"] = edges_df["attributes"].apply(
        lambda x: {} if pd.isna(x) else ast.literal_eval(x)
    )
    # Vectorized mapping of source, target, and relation to IDs
    edges_df["u"] = edges_df["source"].map(node2id)
    edges_df["v"] = edges_df["target"].map(node2id)
    edges_df["r"] = edges_df["relation"].map(rel2id)
    # Filter out rows with missing node or relation IDs
    valid_edges_df = edges_df.dropna(subset=["u", "v", "r"]).copy()
    # Log skipped edges
    skipped_edges = edges_df[edges_df[["u", "v", "r"]].isnull().any(axis=1)]
    for _, row in skipped_edges.iterrows():
        logger.warning(
            f"Skipping edge with missing node or relation: {row['source']}, {row['relation']}, {row['target']}"
        )
    # Convert IDs to int and build edge tuples
    edges = list(
        zip(
            valid_edges_df["u"].astype(int),
            valid_edges_df["v"].astype(int),
            valid_edges_df["r"].astype(int),
        )
    )
    # # Sort the edges by source and target for consistency
    # edges.sort(key=lambda x: (x[0], x[1]))
    num_nodes = len(node2id)
    num_relations = len(rel2id)
    train_target_edges = torch.tensor(
        [[t[0], t[1]] for t in edges], dtype=torch.long
    ).t()
    train_target_etypes = torch.tensor([t[2] for t in edges])
    # Add inverse edges: reversed (v, u) pairs with the relation id offset by
    # num_relations so original and inverse relations never collide.
    train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
    train_etypes = torch.cat(
        [train_target_etypes, train_target_etypes + num_relations]
    )
    with open(self.processed_dir + "/node2id.json", "w") as f:
        json.dump(node2id, f)
    id2rel = {v: k for k, v in rel2id.items()}
    # Register an "inverse_<rel>" name for each inverse relation id.
    # NOTE(review): this scans every edge, and only relations actually present
    # in the edge list get an inverse entry in rel2id.
    for etype in train_etypes:
        if etype.item() >= num_relations:
            raw_etype = etype - num_relations
            raw_rel = id2rel[raw_etype.item()]
            rel2id["inverse_" + raw_rel] = etype.item()
    with open(self.processed_dir + "/rel2id.json", "w") as f:
        json.dump(rel2id, f)
    # Instantiate the text embedding model if attributes are used
    if self.use_node_feat or self.use_edge_feat or self.use_relation_feat:
        text_emb_model: BaseTextEmbModel = instantiate(self.text_emb_model_cfgs)
    # Generate relation embeddings
    if self.use_relation_feat:
        logger.info("Generating relation embeddings")
        relation_text_attributes = relations_df.apply(
            lambda row: self.attributes_to_text(row["attributes"], name=row.name),
            axis=1,
        ).to_list()
        rel_emb = text_emb_model.encode(
            relation_text_attributes, is_query=False
        ).cpu()
        if self.inverse_relation_feat == "inverse":
            # Inverse relations by adding the negative sign to the relation embeddings http://arxiv.org/abs/2505.20422
            rel_emb = torch.cat([rel_emb, -rel_emb], dim=0)
        elif self.inverse_relation_feat == "text":
            # Alternatively encode an explicit "inverse_<name>" text per relation.
            inverse_relation_text_attributes = relations_df.apply(
                lambda row: self.attributes_to_text(
                    row["attributes"], name="inverse_" + row.name
                ),
                axis=1,
            ).to_list()
            inverse_rel_emb = text_emb_model.encode(
                inverse_relation_text_attributes, is_query=False
            ).cpu()
            rel_emb = torch.cat([rel_emb, inverse_rel_emb], dim=0)
    else:
        rel_emb = None
    # Generate entity embeddings
    if self.use_node_feat:
        node_text_attributes = nodes_df.apply(
            lambda row: self.attributes_to_text(
                row["attributes"], name=row.name, type=row["type"]
            ),
            axis=1,
        ).to_list()
        logger.info("Generating entity embeddings")
        node_emb = text_emb_model.encode(node_text_attributes, is_query=False).cpu()
    else:
        node_emb = None
    if self.use_edge_feat:
        logger.info("Generating edge embeddings")
        # NOTE(review): embeddings are computed over edges_df (all rows,
        # including edges skipped above) while edge_index is built from
        # valid_edges_df — confirm the alignment of edge_attr to edge_index.
        edge_text_attributes = edges_df.apply(
            lambda row: self.attributes_to_text(
                row["attributes"],
            ),
            axis=1,
        ).to_list()
        edge_emb = text_emb_model.encode(edge_text_attributes, is_query=False).cpu()
    else:
        edge_emb = None
    # Get feature dimension from the first available embedding; the for/else
    # falls through to 0 only when no embeddings were generated at all.
    for emb in [node_emb, rel_emb, edge_emb]:
        if emb is not None:
            if emb.ndim != 2:
                raise ValueError(
                    f"Expected 2D tensor for embeddings, got {emb.ndim}D tensor."
                )
            feat_dim = emb.size(1)
            break
    else:
        feat_dim = 0  # No embeddings available
    graph = Data(
        node_type=node_types,
        node_type_names=node_type_names,
        nodes_by_type=nodes_by_type,
        edge_index=train_edges,
        edge_type=train_etypes,
        num_nodes=num_nodes,
        target_edge_index=train_target_edges,
        target_edge_type=train_target_etypes,
        num_relations=num_relations * 2,
        x=node_emb,
        rel_attr=rel_emb,
        edge_attr=edge_emb,
        feat_dim=feat_dim,
    )
    torch.save(graph, self.processed_graph[0])
process_qa_data(qa_data_names)
¶
Process and prepare the question-answering dataset.
This method processes raw data files to create a structured dataset for question answering tasks. It performs the following main operations:
- Loads entity and relation mappings from processed files
- Creates entity-document mapping tensors
- Processes question samples to generate:
- Question embeddings
- Combine masks for start nodes of each type
- Combine masks for target nodes of each type
The processed dataset is saved as torch splits containing:
- Question embeddings
- Start node masks
- Target node masks
- Sample IDs
Files created:
- train.pt: Contains the processed training data, if available
- test.pt: Contains the processed testing data, if available
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
qa_data_names
|
list
|
List of raw QA data file names to process. |
required |
Returns:
| Type | Description |
|---|---|
None
|
None |
Source code in gfmrag/graph_index_datasets/graph_index_dataset.py
def process_qa_data(self, qa_data_names: list) -> None:
    """Process and prepare the question-answering dataset.

    This method processes raw data files to create a structured dataset for question answering
    tasks. It performs the following main operations:
    1. Loads entity and relation mappings from processed files
    2. Creates entity-document mapping tensors
    3. Processes question samples to generate:
        - Question embeddings
        - Combine masks for start nodes of each type
        - Combine masks for target nodes of each type

    The processed dataset is saved as torch splits containing:
    - Question embeddings
    - Start node masks
    - Target node masks
    - Sample IDs

    Files created:
    - train.pt: Contains the processed training data, if available
    - test.pt: Contains the processed testing data, if available

    Args:
        qa_data_names (list): List of raw QA data file names to process.

    Returns:
        None
    """
    num_nodes = self.graph.num_nodes
    # Accumulators are shared across all input files; they are sliced back
    # into per-file splits at the end using the num_samples offsets.
    start_nodes_mask = []
    target_nodes_mask = []
    sample_id = []
    questions = []
    num_samples = []  # number of KEPT samples per input file, in input order
    for data_name in qa_data_names:
        num_sample = 0
        with open(data_name) as fin:
            data = json.load(fin)
            for item in data:
                # Get start nodes and target nodes for each node type;
                # nodes unknown to the graph (missing from node2id) are dropped.
                start_nodes_ids = []
                target_nodes_ids = []
                for node in item["start_nodes"].values():
                    start_nodes_ids.extend(
                        [self.node2id[x] for x in node if x in self.node2id]
                    )
                for node in item["target_nodes"].values():
                    target_nodes_ids.extend(
                        [self.node2id[x] for x in node if x in self.node2id]
                    )
                # Skip samples if any of the entities or documents are empty
                if len(start_nodes_ids) == 0:
                    logger.warning(
                        f"Skipping sample {item['id']} in {data_name} due to empty start nodes."
                    )
                    continue
                if self.skip_empty_target and len(target_nodes_ids) == 0:
                    logger.warning(
                        f"Skipping sample {item['id']} in {data_name} due to empty target nodes."
                    )
                    continue
                num_sample += 1
                sample_id.append(item["id"])
                question = item["question"]
                questions.append(question)
                # Create masks for start nodes and target nodes
                start_nodes_mask.append(
                    entities_to_mask(start_nodes_ids, num_nodes)
                )
                target_nodes_mask.append(
                    entities_to_mask(target_nodes_ids, num_nodes)
                )
        num_samples.append(num_sample)
    # Generate question embeddings (single batched call over all splits)
    logger.info("Generating question embeddings")
    text_emb_model: BaseTextEmbModel = instantiate(self.text_emb_model_cfgs)
    question_embeddings = text_emb_model.encode(
        questions,
        is_query=True,
    ).cpu()
    start_nodes_mask = torch.stack(start_nodes_mask)
    target_nodes_mask = torch.stack(target_nodes_mask)
    dataset = datasets.Dataset.from_dict(
        {
            "question_embeddings": question_embeddings,
            "start_nodes_mask": start_nodes_mask,
            "target_nodes_mask": target_nodes_mask,
            "id": sample_id,
        }
    ).with_format("torch")
    # Slice the combined dataset back into contiguous per-file subsets and
    # save one .pt file per input split (named after the raw file stem).
    offset = 0
    for raw_data_name, num_sample in zip(qa_data_names, num_samples):
        split = torch_data.Subset(dataset, range(offset, offset + num_sample))
        split_name = osp.basename(raw_data_name).split(".")[0]
        processed_split_path = osp.join(self.processed_dir, f"{split_name}.pt")
        torch.save(split, processed_split_path)
        offset += num_sample
save_config()
¶
Save the configuration of the dataset to a JSON file.
Source code in gfmrag/graph_index_datasets/graph_index_dataset.py
def save_config(self) -> None:
    """Save the configuration of the dataset to a JSON file.

    Serializes the text-embedding model config plus every fingerprint
    attribute into ``<processed_dir>/config.json`` so processed artifacts
    can later be matched to the config that produced them.
    """
    # Resolve the OmegaConf node into plain Python containers so it is
    # JSON-serializable.
    text_emb_model_cfgs = OmegaConf.to_container(
        self.text_emb_model_cfgs, resolve=True
    )
    config = self.__class__.export_config_dict(
        {
            "text_emb_model_cfgs": text_emb_model_cfgs,
            # Fingerprint attributes missing on the instance are recorded as None.
            **{
                key: getattr(self, key, None)
                for key in self.__class__.FINGER_PRINT_ATTRS
            },
        }
    )
    with open(osp.join(self.processed_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=4)
GraphIndexDatasetV1
¶
Bases: GraphIndexDataset
Version 1 of the Graph Index Dataset for GFM-RAG.
This is a specialized version of the GraphIndexDataset tailored for the first iteration of the GFM-RAG framework, which predicts other types of nodes based on entity predictions. It inherits from GraphIndexDataset and can include additional features or modifications specific to version 1.
Source code in gfmrag/graph_index_datasets/graph_index_dataset_v1.py
class GraphIndexDatasetV1(GraphIndexDataset):
    """
    Version 1 of the Graph Index Dataset for GFM-RAG.

    This is a specialized version of the GraphIndexDataset tailored for the first
    iteration of the GFM-RAG framework, which predicts other types of nodes based
    on entity predictions. It inherits from GraphIndexDataset and can include
    additional features or modifications specific to version 1.
    """

    # target_type participates in the processed-data fingerprint, so changing
    # it yields a distinct processed directory.
    FINGER_PRINT_ATTRS: ClassVar[list[str]] = GraphIndexDataset.FINGER_PRINT_ATTRS + [
        "target_type"
    ]

    def __init__(self, target_type: str, **kwargs: Any):
        """
        Initialize the GraphIndexDatasetV1 with the specified prediction type and other parameters.

        Args:
            target_type (str): The type of node to be used to construct graph index (e.g., 'entity').
            **kwargs: Additional keyword arguments passed to the parent class.
        """
        # Set before super().__init__, which presumably computes the fingerprint
        # (target_type is in FINGER_PRINT_ATTRS) — confirm against the parent.
        self.target_type = target_type
        super().__init__(**kwargs)
        # Additional initialization or modifications for version 1 can be added here.

    def attributes_to_text(  # type: ignore[override]
        self, name: str | None = None, attributes: dict | None = None, **kwargs: dict
    ) -> str:
        """Return a string representation of the attributes. V1 version only encodes the name

        Args:
            name (str): The name of the node or relation to be used as the string representation
            attributes (dict | None): Attribute dictionary, used only when ``name`` is None.
            **kwargs (dict): Additional keyword arguments to include in the string. The keys of the dictionary will be used as attribute names.

        Returns:
            str: A formatted string representation of the attributes.

        Examples:
            >>> name = "Node1"
            >>> print(attributes_to_text(name=name))
            Node1
        """
        # Fast path: V1 encodes only the bare name when one is given.
        if name is not None:
            return name
        if attributes is None:
            attributes = {}
        # Fallback mirrors the base-class formatting (kwargs first, then
        # attributes; None values dropped) but never raises on empty input.
        if len(attributes) > 0:
            attr_str = "\n".join(
                f"{key}: {value}"
                for key, value in attributes.items()
                if value is not None
            )
        else:
            attr_str = ""
        if len(kwargs) > 0:
            additional_attrs = "\n".join(
                f"{key}: {value}" for key, value in kwargs.items() if value is not None
            )
        else:
            additional_attrs = ""
        if attr_str and additional_attrs:
            return f"{additional_attrs}\n{attr_str}".strip()
        elif attr_str:
            return attr_str.strip()
        elif additional_attrs:
            return additional_attrs.strip()
        else:
            return ""

    def process_graph(self) -> None:
        """Process the graph index dataset.

        This method processes the raw graph index file and creates the following:
        1. Loads the nodes, edges, and relations from the raw files
        2. Only use edges with the specified target type nodes for graph index construction
        3. Saves entity and relation mappings to JSON files
        4. Generates relation, entity, edges features using a text embedding model
        5. Saves the processed data and model configurations

        The processed data includes:
        - Edge indices and types for both original and inverse edges
        - Target edge indices and types (original edges only)
        - Number of nodes and relations
        - Relation embeddings
        - Entity embeddings
        - Target to other types mapping as a sparse tensor, used for getting the prediction of other types of nodes based on the target type nodes.

        Files created:
        - graph.pt: Contains the processed graph data
        - node2id.json: Node to ID mapping
        - rel2id.json: Relation to ID mapping (including inverse relations)
        """
        node_file, relation_file, edge_file = self.raw_graph
        if (
            not osp.exists(node_file)
            or not osp.exists(relation_file)
            or not osp.exists(edge_file)
        ):
            raise FileNotFoundError(
                f"Required files not found in {self.raw_dir}. "
                "Please ensure 'nodes.csv', 'relations.csv', and 'edges.csv' exist."
            )
        # Load nodes
        nodes_df = self._read_csv_file(node_file)
        node2id = nodes_df["id"].to_dict()  # Map names or uid to continuous IDs
        nodes_type_id, node_type_names = pd.factorize(nodes_df["type"])
        nodes_df["type_id"] = nodes_type_id  # Add type ID column
        # Create a tensor for node types
        node_types = torch.LongTensor(nodes_type_id)
        # Save node ids under each type for fast access
        nodes_by_type = {}
        # Group node id by type
        node_types_group = nodes_df.groupby("type")["id"].apply(list).to_dict()
        for node_type, node_ids in node_types_group.items():
            nodes_by_type[node_type] = torch.LongTensor(node_ids)
        # The target type must exist, otherwise the index would be empty.
        if len(nodes_by_type.get(self.target_type, [])) == 0:
            raise ValueError(
                f"No nodes found for target type '{self.target_type}'. "
                "Please ensure the target type exists in the graph."
            )
        # Load relations
        relations_df = self._read_csv_file(relation_file)
        rel2id = relations_df["id"].to_dict()
        # Load triplets from edges.csv
        # NOTE(review): unlike the base class, this read omits
        # keep_default_na=False — confirm the difference is intentional.
        edges_df = pd.read_csv(edge_file)
        edges_df["attributes"] = edges_df["attributes"].apply(
            lambda x: {} if pd.isna(x) else ast.literal_eval(x)
        )
        # Vectorized mapping of source, target, and relation to IDs
        edges_df["u"] = edges_df["source"].map(node2id)
        edges_df["v"] = edges_df["target"].map(node2id)
        edges_df["r"] = edges_df["relation"].map(rel2id)
        # Filter out rows with missing node or relation IDs
        valid_edges_df = edges_df.dropna(subset=["u", "v", "r"]).copy()
        # Log skipped edges
        skipped_edges = edges_df[edges_df[["u", "v", "r"]].isnull().any(axis=1)]
        for _, row in skipped_edges.iterrows():
            logger.warning(
                f"Skipping edge with missing node or relation: {row['source']}, {row['relation']}, {row['target']}"
            )
        # Apply node type for edges
        valid_edges_df["source_type"] = valid_edges_df["source"].apply(
            lambda x: nodes_df.loc[x, "type"]
        )
        valid_edges_df["target_type"] = valid_edges_df["target"].apply(
            lambda x: nodes_df.loc[x, "type"]
        )
        # Only select edges that the node type of both source and target is the target type
        target_edges_df = valid_edges_df[
            (valid_edges_df["source_type"] == self.target_type)
            & (valid_edges_df["target_type"] == self.target_type)
        ]
        num_nodes = len(node2id)
        num_relations = len(rel2id)
        # Create target type to other types mapping and store as sparse tensor, size: (n nodes, total number of nodes)
        target_to_other_types: dict[str, torch.Tensor] = dict()
        for other_type, group in valid_edges_df[
            valid_edges_df["source_type"] == self.target_type
        ].groupby("target_type"):
            # Skip if the other type is the same as the target type
            if other_type == self.target_type:
                continue
            indices = torch.tensor(group[["u", "v"]].astype(int).values.T)
            # One sparse adjacency per other type; entry (u, v) = 1 links a
            # target-type node u to an other-type node v.
            target_to_other_mapping = torch.sparse_coo_tensor(
                indices,
                torch.ones(indices.size(1), dtype=torch.float),
                size=(num_nodes, num_nodes),
            )
            target_to_other_types[other_type] = target_to_other_mapping
        # Convert IDs to int and build edge tuples
        edges = list(
            zip(
                target_edges_df["u"].astype(int),
                target_edges_df["v"].astype(int),
                target_edges_df["r"].astype(int),
            )
        )
        train_target_edges = torch.tensor(
            [[t[0], t[1]] for t in edges], dtype=torch.long
        ).t()
        train_target_etypes = torch.tensor([t[2] for t in edges])
        # Add inverse edges: reversed pairs with relation ids offset by
        # num_relations so original and inverse relations never collide.
        train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_etypes = torch.cat(
            [train_target_etypes, train_target_etypes + num_relations]
        )
        with open(self.processed_dir + "/node2id.json", "w") as f:
            json.dump(node2id, f)
        id2rel = {v: k for k, v in rel2id.items()}
        # Register an "inverse_<rel>" name for each inverse relation id that
        # occurs in the (filtered) edge list.
        for etype in train_etypes:
            if etype.item() >= num_relations:
                raw_etype = etype - num_relations
                raw_rel = id2rel[raw_etype.item()]
                rel2id["inverse_" + raw_rel] = etype.item()
        with open(self.processed_dir + "/rel2id.json", "w") as f:
            json.dump(rel2id, f)
        # Instantiate the text embedding model if attributes are used
        if self.use_node_feat or self.use_edge_feat or self.use_relation_feat:
            text_emb_model: BaseTextEmbModel = instantiate(self.text_emb_model_cfgs)
        # Generate relation embeddings
        if self.use_relation_feat:
            logger.info("Generating relation embeddings")
            relation_text_attributes = relations_df.apply(
                lambda row: self.attributes_to_text(
                    attributes=row["attributes"], name=row.name
                ),
                axis=1,
            ).to_list()
            rel_emb = text_emb_model.encode(
                relation_text_attributes, is_query=False
            ).cpu()
            if self.inverse_relation_feat == "inverse":
                # Inverse relations by adding the negative sign to the relation embeddings http://arxiv.org/abs/2505.20422
                rel_emb = torch.cat([rel_emb, -rel_emb], dim=0)
            elif self.inverse_relation_feat == "text":
                # Alternatively encode an explicit "inverse_<name>" text per relation.
                inverse_relation_text_attributes = relations_df.apply(
                    lambda row: self.attributes_to_text(
                        attributes=row["attributes"], name="inverse_" + row.name
                    ),
                    axis=1,
                ).to_list()
                inverse_rel_emb = text_emb_model.encode(
                    inverse_relation_text_attributes, is_query=False
                ).cpu()
                rel_emb = torch.cat([rel_emb, inverse_rel_emb], dim=0)
        else:
            rel_emb = None
        # Generate entity embeddings
        if self.use_node_feat:
            node_text_attributes = nodes_df.apply(
                lambda row: self.attributes_to_text(
                    attributes=row["attributes"], name=row.name, type=row["type"]
                ),
                axis=1,
            ).to_list()
            logger.info("Generating entity embeddings")
            node_emb = text_emb_model.encode(node_text_attributes, is_query=False).cpu()
        else:
            node_emb = None
        if self.use_edge_feat:
            logger.info("Generating edge embeddings")
            # NOTE(review): computed over edges_df (all rows) while edge_index
            # uses only target-type edges — confirm alignment of edge_attr.
            edge_text_attributes = edges_df.apply(
                lambda row: self.attributes_to_text(
                    attributes=row["attributes"],
                ),
                axis=1,
            ).to_list()
            edge_emb = text_emb_model.encode(edge_text_attributes, is_query=False).cpu()
        else:
            edge_emb = None
        # Get feature dimension from the first available embedding; the
        # for/else falls through to 0 only when no embeddings exist at all.
        for emb in [node_emb, rel_emb, edge_emb]:
            if emb is not None:
                if emb.ndim != 2:
                    raise ValueError(
                        f"Expected 2D tensor for embeddings, got {emb.ndim}D tensor."
                    )
                feat_dim = emb.size(1)
                break
        else:
            feat_dim = 0  # No embeddings available
        graph = Data(
            node_type=node_types,
            node_type_names=node_type_names,
            nodes_by_type=nodes_by_type,
            target_to_other_types=target_to_other_types,
            edge_index=train_edges,
            edge_type=train_etypes,
            num_nodes=num_nodes,
            target_edge_index=train_target_edges,
            target_edge_type=train_target_etypes,
            num_relations=num_relations * 2,
            x=node_emb,
            rel_attr=rel_emb,
            edge_attr=edge_emb,
            feat_dim=feat_dim,
        )
        torch.save(graph, self.processed_graph[0])
__init__(target_type, **kwargs)
¶
Initialize the GraphIndexDatasetV1 with the specified prediction type and other parameters. Args: target_type (str): The type of node to be used to construct graph index (e.g., 'entity'). **kwargs: Additional keyword arguments passed to the parent class.
Source code in gfmrag/graph_index_datasets/graph_index_dataset_v1.py
def __init__(self, target_type: str, **kwargs: Any):
    """Create a GraphIndexDatasetV1 for a given target node type.

    Args:
        target_type (str): The node type used to construct the graph index
            (e.g., ``'entity'``).
        **kwargs: Additional keyword arguments forwarded to the parent class.
    """
    # Record the target type before delegating to the parent constructor,
    # since the parent's processing pipeline reads it (see process_graph).
    self.target_type = target_type
    super().__init__(**kwargs)
attributes_to_text(name=None, attributes=None, **kwargs)
¶
Return a string representation of the attributes. V1 version only encodes the name
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | The name of the node or relation to be used as the string representation | `None` |
| `**kwargs` | `dict` | Additional keyword arguments to include in the string. The keys of the dictionary will be used as attribute names. | `{}` |
Returns:

| Name | Type | Description |
|---|---|---|
| `str` | `str` | A formatted string representation of the attributes. |
Examples:
Source code in gfmrag/graph_index_datasets/graph_index_dataset_v1.py
def attributes_to_text( # type: ignore[override]
self, name: str | None = None, attributes: dict | None = None, **kwargs: dict
) -> str:
"""Return a string representation of the attributes. V1 version only encodes the name
Args:
name (str): The name of the node or relation to be used as the string representation
**kwargs (dict): Additional keyword arguments to include in the string. The keys of the dictionary will be used as attribute names.
Returns:
str: A formatted string representation of the attributes.
Examples:
>>> name = "Node1"
>>> print(attributes_to_text(name=name))
Node1
"""
if name is not None:
return name
if attributes is None:
attributes = {}
if len(attributes) > 0:
attr_str = "\n".join(
f"{key}: {value}"
for key, value in attributes.items()
if value is not None
)
else:
attr_str = ""
if len(kwargs) > 0:
additional_attrs = "\n".join(
f"{key}: {value}" for key, value in kwargs.items() if value is not None
)
else:
additional_attrs = ""
if attr_str and additional_attrs:
return f"{additional_attrs}\n{attr_str}".strip()
elif attr_str:
return attr_str.strip()
elif additional_attrs:
return additional_attrs.strip()
else:
return ""
process_graph()
¶
Process the graph index dataset.
This method processes the raw graph index file and creates the following:
- Loads the nodes, edges, and relations from the raw files
Uses only edges whose source and target nodes are of the specified target type for graph index construction
- Saves entity and relation mappings to JSON files
Generates relation, entity, and edge features using a text embedding model
- Saves the processed data and model configurations
The processed data includes:
- Edge indices and types for both original and inverse edges
- Target edge indices and types (original edges only)
- Number of nodes and relations
- Relation embeddings
- Entity embeddings
- Target to other types mapping as a sparse tensor, used for getting the prediction of other types of nodes based on the target type nodes.
Files created:
- graph.pt: Contains the processed graph data
- node2id.json: Node to ID mapping
- rel2id.json: Relation to ID mapping (including inverse relations)
Source code in gfmrag/graph_index_datasets/graph_index_dataset_v1.py
def process_graph(self) -> None:
    """Process the graph index dataset.
    This method processes the raw graph index file and creates the following:
    1. Loads the nodes, edges, and relations from the raw files
    2. Only use edges with the specified target type nodes for graph index construction
    3. Saves entity and relation mappings to JSON files
    4. Generates relation, entity, edges features using a text embedding model
    5. Saves the processed data and model configurations
    The processed data includes:
    - Edge indices and types for both original and inverse edges
    - Target edge indices and types (original edges only)
    - Number of nodes and relations
    - Relation embeddings
    - Entity embeddings
    - Target to other types mapping as a sparse tensor, used for getting the prediction of other types of nodes based on the target type nodes.
    Files created:
    - graph.pt: Contains the processed graph data
    - node2id.json: Node to ID mapping
    - rel2id.json: Relation to ID mapping (including inverse relations)
    """
    node_file, relation_file, edge_file = self.raw_graph
    # Fail fast if any of the three raw CSV files is missing.
    if (
        not osp.exists(node_file)
        or not osp.exists(relation_file)
        or not osp.exists(edge_file)
    ):
        raise FileNotFoundError(
            f"Required files not found in {self.raw_dir}. "
            "Please ensure 'nodes.csv', 'relations.csv', and 'edges.csv' exist."
        )
    # Load nodes
    # NOTE(review): the .loc lookups below assume _read_csv_file indexes rows
    # by the node label (name/uid), so "id" maps label -> contiguous int ID —
    # confirm in the base class.
    nodes_df = self._read_csv_file(node_file)
    node2id = nodes_df["id"].to_dict()  # Map names or uid to continuous IDs
    # Factorize node type strings into contiguous integer codes plus the
    # ordered list of type names.
    nodes_type_id, node_type_names = pd.factorize(nodes_df["type"])
    nodes_df["type_id"] = nodes_type_id  # Add type ID column
    # Create a tensor for node types
    node_types = torch.LongTensor(nodes_type_id)
    # Save node ids under each type for fast access
    nodes_by_type: dict[str, torch.Tensor] = {}
    # Group node id by type
    node_types_group = nodes_df.groupby("type")["id"].apply(list).to_dict()
    for node_type, node_ids in node_types_group.items():
        nodes_by_type[node_type] = torch.LongTensor(node_ids)
    # Without any node of the target type there is nothing to index.
    if len(nodes_by_type.get(self.target_type, [])) == 0:
        raise ValueError(
            f"No nodes found for target type '{self.target_type}'. "
            "Please ensure the target type exists in the graph."
        )
    # Load relations
    relations_df = self._read_csv_file(relation_file)
    rel2id = relations_df["id"].to_dict()
    # Load triplets from edges.csv
    edges_df = pd.read_csv(edge_file)
    # Parse the stringified attribute dicts; missing cells become empty dicts.
    edges_df["attributes"] = edges_df["attributes"].apply(
        lambda x: {} if pd.isna(x) else ast.literal_eval(x)
    )
    # Vectorized mapping of source, target, and relation to IDs
    edges_df["u"] = edges_df["source"].map(node2id)
    edges_df["v"] = edges_df["target"].map(node2id)
    edges_df["r"] = edges_df["relation"].map(rel2id)
    # Filter out rows with missing node or relation IDs
    valid_edges_df = edges_df.dropna(subset=["u", "v", "r"]).copy()
    # Log skipped edges
    skipped_edges = edges_df[edges_df[["u", "v", "r"]].isnull().any(axis=1)]
    for _, row in skipped_edges.iterrows():
        logger.warning(
            f"Skipping edge with missing node or relation: {row['source']}, {row['relation']}, {row['target']}"
        )
    # Apply node type for edges
    valid_edges_df["source_type"] = valid_edges_df["source"].apply(
        lambda x: nodes_df.loc[x, "type"]
    )
    valid_edges_df["target_type"] = valid_edges_df["target"].apply(
        lambda x: nodes_df.loc[x, "type"]
    )
    # Only select edges that the node type of both source and target is the target type
    target_edges_df = valid_edges_df[
        (valid_edges_df["source_type"] == self.target_type)
        & (valid_edges_df["target_type"] == self.target_type)
    ]
    num_nodes = len(node2id)
    num_relations = len(rel2id)
    # Create target type to other types mapping and store as sparse tensor, size: (n nodes, total number of nodes)
    target_to_other_types: dict[str, torch.Tensor] = dict()
    for other_type, group in valid_edges_df[
        valid_edges_df["source_type"] == self.target_type
    ].groupby("target_type"):
        # Skip if the other type is the same as the target type
        if other_type == self.target_type:
            continue
        # (2, n_edges) index tensor of (target-type node, other-type node) pairs.
        indices = torch.tensor(group[["u", "v"]].astype(int).values.T)
        target_to_other_mapping = torch.sparse_coo_tensor(
            indices,
            torch.ones(indices.size(1), dtype=torch.float),
            size=(num_nodes, num_nodes),
        )
        target_to_other_types[other_type] = target_to_other_mapping
    # Convert IDs to int and build edge tuples
    edges = list(
        zip(
            target_edges_df["u"].astype(int),
            target_edges_df["v"].astype(int),
            target_edges_df["r"].astype(int),
        )
    )
    # (2, n_edges) source/target index tensor and matching relation-type tensor.
    train_target_edges = torch.tensor(
        [[t[0], t[1]] for t in edges], dtype=torch.long
    ).t()
    train_target_etypes = torch.tensor([t[2] for t in edges])
    # Add inverse edges
    # Inverse relation IDs are the forward IDs offset by num_relations.
    train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
    train_etypes = torch.cat(
        [train_target_etypes, train_target_etypes + num_relations]
    )
    with open(self.processed_dir + "/node2id.json", "w") as f:
        json.dump(node2id, f)
    id2rel = {v: k for k, v in rel2id.items()}
    # Register inverse relation names; note only relations that actually occur
    # in the training edges receive an "inverse_" entry in rel2id.
    for etype in train_etypes:
        if etype.item() >= num_relations:
            raw_etype = etype - num_relations
            raw_rel = id2rel[raw_etype.item()]
            rel2id["inverse_" + raw_rel] = etype.item()
    with open(self.processed_dir + "/rel2id.json", "w") as f:
        json.dump(rel2id, f)
    # Instantiate the text embedding model if attributes are used
    if self.use_node_feat or self.use_edge_feat or self.use_relation_feat:
        text_emb_model: BaseTextEmbModel = instantiate(self.text_emb_model_cfgs)
    # Generate relation embeddings
    if self.use_relation_feat:
        logger.info("Generating relation embeddings")
        relation_text_attributes = relations_df.apply(
            lambda row: self.attributes_to_text(
                attributes=row["attributes"], name=row.name
            ),
            axis=1,
        ).to_list()
        rel_emb = text_emb_model.encode(
            relation_text_attributes, is_query=False
        ).cpu()
        if self.inverse_relation_feat == "inverse":
            # Inverse relations by adding the negative sign to the relation embeddings http://arxiv.org/abs/2505.20422
            rel_emb = torch.cat([rel_emb, -rel_emb], dim=0)
        elif self.inverse_relation_feat == "text":
            # Encode "inverse_<name>" texts as separate inverse-relation features.
            inverse_relation_text_attributes = relations_df.apply(
                lambda row: self.attributes_to_text(
                    attributes=row["attributes"], name="inverse_" + row.name
                ),
                axis=1,
            ).to_list()
            inverse_rel_emb = text_emb_model.encode(
                inverse_relation_text_attributes, is_query=False
            ).cpu()
            rel_emb = torch.cat([rel_emb, inverse_rel_emb], dim=0)
        # NOTE(review): for any other inverse_relation_feat value, rel_emb keeps
        # only num_relations rows while the graph declares num_relations * 2
        # relation IDs — confirm this is intended.
    else:
        rel_emb = None
    # Generate entity embeddings
    if self.use_node_feat:
        node_text_attributes = nodes_df.apply(
            lambda row: self.attributes_to_text(
                attributes=row["attributes"], name=row.name, type=row["type"]
            ),
            axis=1,
        ).to_list()
        logger.info("Generating entity embeddings")
        node_emb = text_emb_model.encode(node_text_attributes, is_query=False).cpu()
    else:
        node_emb = None
    if self.use_edge_feat:
        logger.info("Generating edge embeddings")
        edge_text_attributes = edges_df.apply(
            lambda row: self.attributes_to_text(
                attributes=row["attributes"],
            ),
            axis=1,
        ).to_list()
        edge_emb = text_emb_model.encode(edge_text_attributes, is_query=False).cpu()
    else:
        edge_emb = None
    # Get feature dimension
    # for/else: feat_dim comes from the first available embedding; the else
    # branch only runs when no embedding was generated at all.
    for emb in [node_emb, rel_emb, edge_emb]:
        if emb is not None:
            if emb.ndim != 2:
                raise ValueError(
                    f"Expected 2D tensor for embeddings, got {emb.ndim}D tensor."
                )
            feat_dim = emb.size(1)
            break
    else:
        feat_dim = 0  # No embeddings available
    # Assemble the final PyG Data object and persist it to disk.
    graph = Data(
        node_type=node_types,
        node_type_names=node_type_names,
        nodes_by_type=nodes_by_type,
        target_to_other_types=target_to_other_types,
        edge_index=train_edges,
        edge_type=train_etypes,
        num_nodes=num_nodes,
        target_edge_index=train_target_edges,
        target_edge_type=train_target_etypes,
        num_relations=num_relations * 2,
        x=node_emb,
        rel_attr=rel_emb,
        edge_attr=edge_emb,
        feat_dim=feat_dim,
    )
    torch.save(graph, self.processed_graph[0])