Entity Linking Model
gfmrag.kg_construction.entity_linking_model
BaseELModel
Bases: ABC
Source code in gfmrag/kg_construction/entity_linking_model/base_model.py
class BaseELModel(ABC):
@abstractmethod
def __init__(self, **kwargs: Any) -> None:
pass
@abstractmethod
def index(self, entity_list: list) -> None:
"""
This method creates an index for the provided list of entities to enable efficient entity linking and searching capabilities.
Args:
entity_list (list): A list of entities to be indexed. Each entity should be a string or dictionary containing
the entity text and other relevant metadata.
Returns:
None: This method modifies the internal index structure but does not return anything.
Raises:
ValueError: If entity_list is empty or contains invalid entity formats.
TypeError: If entity_list is not a list type.
Examples:
>>> model = EntityLinkingModel()
>>> entities = ["Paris", "France", "Eiffel Tower"]
>>> model.index(entities)
"""
pass
@abstractmethod
def __call__(self, ner_entity_list: list, topk: int = 1) -> dict:
"""
Link entities in the given text to the knowledge graph.
Args:
ner_entity_list (list): list of named entities
topk (int): number of linked entities to return
Returns:
dict: dict of linked entities in the knowledge graph
- key (str): named entity
- value (list[dict]): list of linked entities
- entity: linked entity
- score: score of the entity
- norm_score: normalized score of the entity
"""
pass
__call__(ner_entity_list, topk=1) (abstractmethod)
Link entities in the given text to the knowledge graph.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ner_entity_list | list | list of named entities | required |
| topk | int | number of linked entities to return | 1 |
Returns:
| Name | Type | Description |
|---|---|---|
| dict | dict | dict of linked entities in the knowledge graph |
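Concretely, the returned mapping looks like the illustrative literal below (entity names and scores are made up; norm_score is each score divided by the top score for that query):

```python
# Illustrative return value of __call__(["paris city"], topk=2); values are made up.
{
    "paris city": [
        {"entity": "Paris", "score": 0.82, "norm_score": 1.0},   # top hit: 0.82 / 0.82
        {"entity": "London", "score": 0.41, "norm_score": 0.5},  # 0.41 / 0.82
    ]
}
```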
Source code in gfmrag/kg_construction/entity_linking_model/base_model.py
@abstractmethod
def __call__(self, ner_entity_list: list, topk: int = 1) -> dict:
"""
Link entities in the given text to the knowledge graph.
Args:
ner_entity_list (list): list of named entities
topk (int): number of linked entities to return
Returns:
dict: dict of linked entities in the knowledge graph
- key (str): named entity
- value (list[dict]): list of linked entities
- entity: linked entity
- score: score of the entity
- norm_score: normalized score of the entity
"""
pass
index(entity_list) (abstractmethod)
This method creates an index for the provided list of entities to enable efficient entity linking and searching capabilities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| entity_list | list | A list of entities to be indexed. Each entity should be a string or dictionary containing the entity text and other relevant metadata. | required |

Returns:

| Type | Description |
|---|---|
| None | This method modifies the internal index structure but does not return anything. |
Raises:
| Type | Description |
|---|---|
| ValueError | If entity_list is empty or contains invalid entity formats. |
| TypeError | If entity_list is not a list type. |
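The abstract base class does not enforce these checks itself; a concrete subclass would typically validate its input along these lines (a hypothetical sketch, not part of gfmrag):

```python
def index(self, entity_list: list) -> None:
    # Hypothetical validation a concrete subclass might perform before building its index.
    if not isinstance(entity_list, list):
        raise TypeError("entity_list must be a list")
    if not entity_list:
        raise ValueError("entity_list must not be empty")
    ...  # build the backend-specific index here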
Examples:
>>> model = EntityLinkingModel()
>>> entities = ["Paris", "France", "Eiffel Tower"]
>>> model.index(entities)
Source code in gfmrag/kg_construction/entity_linking_model/base_model.py
@abstractmethod
def index(self, entity_list: list) -> None:
"""
This method creates an index for the provided list of entities to enable efficient entity linking and searching capabilities.
Args:
entity_list (list): A list of entities to be indexed. Each entity should be a string or dictionary containing
the entity text and other relevant metadata.
Returns:
None: This method modifies the internal index structure but does not return anything.
Raises:
ValueError: If entity_list is empty or contains invalid entity formats.
TypeError: If entity_list is not a list type.
Examples:
>>> model = EntityLinkingModel()
>>> entities = ["Paris", "France", "Eiffel Tower"]
>>> model.index(entities)
"""
pass
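To add a new backend, subclass BaseELModel and implement both abstract methods. Below is a minimal, hypothetical fuzzy-string linker (using difflib from the standard library; the class name is illustrative, and the import path is taken from the source file shown above). It is a sketch of the interface contract, not a model shipped with gfmrag:

```python
import difflib
from typing import Any

from gfmrag.kg_construction.entity_linking_model.base_model import BaseELModel


class FuzzyELModel(BaseELModel):
    """Hypothetical backend: links mentions to entities by string similarity."""

    def __init__(self, **kwargs: Any) -> None:
        self.entity_list: list = []

    def index(self, entity_list: list) -> None:
        # No real index structure; the entities are simply kept for later matching.
        self.entity_list = entity_list

    def __call__(self, ner_entity_list: list, topk: int = 1) -> dict:
        linked_entity_dict: dict[str, list] = {}
        for mention in ner_entity_list:
            scored = sorted(
                ((e, difflib.SequenceMatcher(None, mention.lower(), e.lower()).ratio())
                 for e in self.entity_list),
                key=lambda pair: pair[1],
                reverse=True,
            )[:topk]
            max_score = scored[0][1] if scored and scored[0][1] > 0 else 1.0
            linked_entity_dict[mention] = [
                {"entity": e, "score": s, "norm_score": s / max_score} for e, s in scored
            ]
        return linked_entity_dict
```

Calling `FuzzyELModel().index(["Paris", "France"])` and then `model(["paris"], topk=1)` yields the same `{mention: [{"entity", "score", "norm_score"}]}` structure that the built-in models return.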
ColbertELModel
Bases: BaseELModel
ColBERT-based Entity Linking Model.
This class implements an entity linking model using ColBERT, a neural information retrieval framework. It indexes a list of entities and performs entity linking by finding the most similar entities in the index for given named entities.
Attributes:
| Name | Type | Description |
|---|---|---|
| checkpoint_path | str | Path to the ColBERT checkpoint file |
| root | str | Root directory for storing indices |
| doc_index_name | str | Name of document index |
| phrase_index_name | str | Name of phrase index |
| force | bool | Whether to force reindex if index exists |
| entity_list | list | List of entities to be indexed |
| phrase_searcher | Searcher | ColBERT phrase searcher object |
Raises:
| Type | Description |
|---|---|
| FileNotFoundError | If the checkpoint file is not found at the specified path. |
| AttributeError | If entity linking is attempted before indexing. |
Notes
You need to download the checkpoint by running the following command:
wget https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz && tar -zxvf colbertv2.0.tar.gz -C tmp/
Examples:
>>> model = ColbertELModel("path/to/checkpoint")
>>> model.index(["entity1", "entity2", "entity3"])
>>> results = model(["paris city"], topk=2)
>>> print(results)
{'paris city': [{'entity': 'entity1', 'score': 0.82, 'norm_score': 1.0},
{'entity': 'entity2', 'score': 0.35, 'norm_score': 0.43}]}
Source code in gfmrag/kg_construction/entity_linking_model/colbert_el_model.py
class ColbertELModel(BaseELModel):
"""ColBERT-based Entity Linking Model.
This class implements an entity linking model using ColBERT, a neural information retrieval
framework. It indexes a list of entities and performs entity linking by finding the most
similar entities in the index for given named entities.
Attributes:
checkpoint_path (str): Path to the ColBERT checkpoint file
root (str): Root directory for storing indices
doc_index_name (str): Name of document index
phrase_index_name (str): Name of phrase index
force (bool): Whether to force reindex if index exists
entity_list (list): List of entities to be indexed
phrase_searcher: ColBERT phrase searcher object
Raises:
FileNotFoundError: If the checkpoint file is not found at the specified path.
AttributeError: If entity linking is attempted before indexing.
Notes:
You need to download the checkpoint by running the following command:
`wget https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz && tar -zxvf colbertv2.0.tar.gz -C tmp/`
Examples:
>>> model = ColbertELModel("path/to/checkpoint")
>>> model.index(["entity1", "entity2", "entity3"])
>>> results = model(["paris city"], topk=2)
>>> print(results)
{'paris city': [{'entity': 'entity1', 'score': 0.82, 'norm_score': 1.0},
{'entity': 'entity2', 'score': 0.35, 'norm_score': 0.43}]}
"""
def __init__(
self,
checkpoint_path: str,
root: str = "tmp",
doc_index_name: str = "nbits_2",
phrase_index_name: str = "nbits_2",
force: bool = False,
) -> None:
"""
Initialize the ColBERT entity linking model.
This initializes a ColBERT model for entity linking using pre-trained checkpoints and indices.
Args:
checkpoint_path (str): Path to the ColBERT checkpoint file. Model weights will be loaded from this path. Can be downloaded [here](https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz)
root (str, optional): Root directory for storing indices. Defaults to "tmp".
doc_index_name (str, optional): Name of the document index. Defaults to "nbits_2".
phrase_index_name (str, optional): Name of the phrase index. Defaults to "nbits_2".
force (bool, optional): Whether to force recomputation of existing indices. Defaults to False.
Raises:
FileNotFoundError: If the checkpoint file does not exist at the specified path.
Returns:
None
"""
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(
"Checkpoint not found, download the checkpoint with: 'wget https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz && tar -zxvf tmp/colbertv2.0.tar.gz -C tmp/'"
)
self.checkpoint_path = checkpoint_path
self.root = root
self.doc_index_name = doc_index_name
self.phrase_index_name = phrase_index_name
self.force = force
def index(self, entity_list: list) -> None:
"""
Index a list of entities using ColBERT for efficient similarity search.
This method processes and indexes a list of entities using the ColBERT model. It creates
a unique index based on the MD5 hash of the entity list and stores it in the specified
root directory.
Args:
entity_list (list): List of entity strings to be indexed.
Returns:
None
Notes:
- Creates a unique index directory based on MD5 hash of entities
- If force=True, will delete existing index with same fingerprint
- Processes entities into phrases before indexing
- Sets up ColBERT indexer and searcher with specified configuration
- Stores phrase_searcher as instance variable for later use
"""
self.entity_list = entity_list
# Get md5 fingerprint of the whole given entity list
fingerprint = hashlib.md5("".join(entity_list).encode()).hexdigest()
exp_name = f"Entity_index_{fingerprint}"
if os.path.exists(f"{self.root}/colbert/{fingerprint}") and self.force:
shutil.rmtree(f"{self.root}/colbert/{fingerprint}")
colbert_config = {
"root": f"{self.root}/colbert/{fingerprint}",
"doc_index_name": self.doc_index_name,
"phrase_index_name": self.phrase_index_name,
}
phrases = [processing_phrases(p) for p in entity_list]
with Run().context(
RunConfig(nranks=1, experiment=exp_name, root=colbert_config["root"])
):
config = ColBERTConfig(
nbits=2,
root=colbert_config["root"],
)
indexer = Indexer(checkpoint=self.checkpoint_path, config=config)
indexer.index(
name=self.phrase_index_name, collection=phrases, overwrite="reuse"
)
with Run().context(
RunConfig(nranks=1, experiment=exp_name, root=colbert_config["root"])
):
config = ColBERTConfig(
root=colbert_config["root"],
)
phrase_searcher = Searcher(
index=colbert_config["phrase_index_name"], config=config, verbose=1
)
self.phrase_searcher = phrase_searcher
def __call__(self, ner_entity_list: list, topk: int = 1) -> dict:
"""
Link entities in the given text to the knowledge graph.
Args:
ner_entity_list (list): list of named entities
topk (int): number of linked entities to return
Returns:
dict: dict of linked entities in the knowledge graph
- key (str): named entity
- value (list[dict]): list of linked entities
- entity: linked entity
- score: score of the entity
- norm_score: normalized score of the entity
"""
try:
self.__getattribute__("phrase_searcher")
except AttributeError as e:
raise AttributeError("Index the entities first using index method") from e
ner_entity_list = [processing_phrases(p) for p in ner_entity_list]
query_data: dict[int, str] = {
i: query for i, query in enumerate(ner_entity_list)
}
queries = Queries(path=None, data=query_data)
ranking = self.phrase_searcher.search_all(queries, k=topk)
linked_entity_dict: dict[str, list] = {}
for i in range(len(queries)):
query = queries[i]
rank = ranking.data[i]
linked_entity_dict[query] = []
max_score = rank[0][2]
for phrase_id, _rank, score in rank:
linked_entity_dict[query].append(
{
"entity": self.entity_list[phrase_id],
"score": score,
"norm_score": score / max_score,
}
)
return linked_entity_dict
__call__(ner_entity_list, topk=1)
Link entities in the given text to the knowledge graph.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ner_entity_list | list | list of named entities | required |
| topk | int | number of linked entities to return | 1 |
Returns:
| Name | Type | Description |
|---|---|---|
| dict | dict | dict of linked entities in the knowledge graph |
Source code in gfmrag/kg_construction/entity_linking_model/colbert_el_model.py
def __call__(self, ner_entity_list: list, topk: int = 1) -> dict:
"""
Link entities in the given text to the knowledge graph.
Args:
ner_entity_list (list): list of named entities
topk (int): number of linked entities to return
Returns:
dict: dict of linked entities in the knowledge graph
- key (str): named entity
- value (list[dict]): list of linked entities
- entity: linked entity
- score: score of the entity
- norm_score: normalized score of the entity
"""
try:
self.__getattribute__("phrase_searcher")
except AttributeError as e:
raise AttributeError("Index the entities first using index method") from e
ner_entity_list = [processing_phrases(p) for p in ner_entity_list]
query_data: dict[int, str] = {
i: query for i, query in enumerate(ner_entity_list)
}
queries = Queries(path=None, data=query_data)
ranking = self.phrase_searcher.search_all(queries, k=topk)
linked_entity_dict: dict[str, list] = {}
for i in range(len(queries)):
query = queries[i]
rank = ranking.data[i]
linked_entity_dict[query] = []
max_score = rank[0][2]
for phrase_id, _rank, score in rank:
linked_entity_dict[query].append(
{
"entity": self.entity_list[phrase_id],
"score": score,
"norm_score": score / max_score,
}
)
return linked_entity_dict
__init__(checkpoint_path, root='tmp', doc_index_name='nbits_2', phrase_index_name='nbits_2', force=False)
Initialize the ColBERT entity linking model.
This initializes a ColBERT model for entity linking using pre-trained checkpoints and indices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| checkpoint_path | str | Path to the ColBERT checkpoint file. Model weights will be loaded from this path. Can be downloaded [here](https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz) | required |
| root | str | Root directory for storing indices. Defaults to "tmp". | 'tmp' |
| doc_index_name | str | Name of the document index. Defaults to "nbits_2". | 'nbits_2' |
| phrase_index_name | str | Name of the phrase index. Defaults to "nbits_2". | 'nbits_2' |
| force | bool | Whether to force recomputation of existing indices. Defaults to False. | False |
Raises:
| Type | Description |
|---|---|
| FileNotFoundError | If the checkpoint file does not exist at the specified path. |
Returns:
| Type | Description |
|---|---|
| None | None |
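Putting the download note and the constructor together, instantiation might look like the sketch below. The extracted directory name `tmp/colbertv2.0` follows from the archive name in the Notes, and the import assumes the class is exported from the package at the top of this page; adjust both to your setup:

```python
from gfmrag.kg_construction.entity_linking_model import ColbertELModel

# Assumes colbertv2.0.tar.gz was downloaded and extracted into tmp/ as described above.
model = ColbertELModel(
    checkpoint_path="tmp/colbertv2.0",  # raises FileNotFoundError if the path is missing
    root="tmp",                         # indices are written under {root}/colbert/<fingerprint>
    force=False,                        # set True to rebuild an index that already exists
)
```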
Source code in gfmrag/kg_construction/entity_linking_model/colbert_el_model.py
def __init__(
self,
checkpoint_path: str,
root: str = "tmp",
doc_index_name: str = "nbits_2",
phrase_index_name: str = "nbits_2",
force: bool = False,
) -> None:
"""
Initialize the ColBERT entity linking model.
This initializes a ColBERT model for entity linking using pre-trained checkpoints and indices.
Args:
checkpoint_path (str): Path to the ColBERT checkpoint file. Model weights will be loaded from this path. Can be downloaded [here](https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz)
root (str, optional): Root directory for storing indices. Defaults to "tmp".
doc_index_name (str, optional): Name of the document index. Defaults to "nbits_2".
phrase_index_name (str, optional): Name of the phrase index. Defaults to "nbits_2".
force (bool, optional): Whether to force recomputation of existing indices. Defaults to False.
Raises:
FileNotFoundError: If the checkpoint file does not exist at the specified path.
Returns:
None
"""
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(
"Checkpoint not found, download the checkpoint with: 'wget https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz && tar -zxvf tmp/colbertv2.0.tar.gz -C tmp/'"
)
self.checkpoint_path = checkpoint_path
self.root = root
self.doc_index_name = doc_index_name
self.phrase_index_name = phrase_index_name
self.force = force
index(entity_list)
Index a list of entities using ColBERT for efficient similarity search.
This method processes and indexes a list of entities using the ColBERT model. It creates a unique index based on the MD5 hash of the entity list and stores it in the specified root directory.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| entity_list | list | List of entity strings to be indexed. | required |

Returns:

| Type | Description |
|---|---|
| None | None |
Notes
- Creates a unique index directory based on MD5 hash of entities
- If force=True, will delete existing index with same fingerprint
- Processes entities into phrases before indexing
- Sets up ColBERT indexer and searcher with specified configuration
- Stores phrase_searcher as instance variable for later use
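As a rough sketch of where the index ends up (derived from the constructor defaults and the fingerprinting described above; the resulting hash value depends on the entity strings):

```python
import hashlib

entities = ["Paris", "France", "Eiffel Tower"]
fingerprint = hashlib.md5("".join(entities).encode()).hexdigest()

# With root="tmp", the index for this exact entity list is stored under:
index_root = f"tmp/colbert/{fingerprint}"
# Re-indexing the same list reuses this directory unless force=True was passed.
```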
Source code in gfmrag/kg_construction/entity_linking_model/colbert_el_model.py
def index(self, entity_list: list) -> None:
"""
Index a list of entities using ColBERT for efficient similarity search.
This method processes and indexes a list of entities using the ColBERT model. It creates
a unique index based on the MD5 hash of the entity list and stores it in the specified
root directory.
Args:
entity_list (list): List of entity strings to be indexed.
Returns:
None
Notes:
- Creates a unique index directory based on MD5 hash of entities
- If force=True, will delete existing index with same fingerprint
- Processes entities into phrases before indexing
- Sets up ColBERT indexer and searcher with specified configuration
- Stores phrase_searcher as instance variable for later use
"""
self.entity_list = entity_list
# Get md5 fingerprint of the whole given entity list
fingerprint = hashlib.md5("".join(entity_list).encode()).hexdigest()
exp_name = f"Entity_index_{fingerprint}"
if os.path.exists(f"{self.root}/colbert/{fingerprint}") and self.force:
shutil.rmtree(f"{self.root}/colbert/{fingerprint}")
colbert_config = {
"root": f"{self.root}/colbert/{fingerprint}",
"doc_index_name": self.doc_index_name,
"phrase_index_name": self.phrase_index_name,
}
phrases = [processing_phrases(p) for p in entity_list]
with Run().context(
RunConfig(nranks=1, experiment=exp_name, root=colbert_config["root"])
):
config = ColBERTConfig(
nbits=2,
root=colbert_config["root"],
)
indexer = Indexer(checkpoint=self.checkpoint_path, config=config)
indexer.index(
name=self.phrase_index_name, collection=phrases, overwrite="reuse"
)
with Run().context(
RunConfig(nranks=1, experiment=exp_name, root=colbert_config["root"])
):
config = ColBERTConfig(
root=colbert_config["root"],
)
phrase_searcher = Searcher(
index=colbert_config["phrase_index_name"], config=config, verbose=1
)
self.phrase_searcher = phrase_searcher
DPRELModel
Bases: BaseELModel
Entity Linking Model based on Dense Passage Retrieval (DPR).
This class implements an entity linking model using DPR architecture and SentenceTransformer for encoding entities and computing similarity scores between mentions and candidate entities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_name | str | Name or path of the SentenceTransformer model to use | required |
| root | str | Root directory for caching embeddings. Defaults to "tmp". | 'tmp' |
| use_cache | bool | Whether to cache and reuse entity embeddings. Defaults to True. | True |
| normalize | bool | Whether to L2-normalize embeddings. Defaults to True. | True |
| batch_size | int | Batch size for encoding. Defaults to 32. | 32 |
| query_instruct | str | Instruction/prompt prefix for query encoding. Defaults to "". | '' |
| passage_instruct | str | Instruction/prompt prefix for passage encoding. Defaults to "". | '' |
| model_kwargs | dict | Additional kwargs to pass to SentenceTransformer. Defaults to None. | None |
Methods:
| Name | Description |
|---|---|
| index | Indexes a list of entities by computing and caching their embeddings |
| __call__ | Links named entities to indexed entities and returns top-k matches |
Examples:
>>> model = DPRELModel('sentence-transformers/all-mpnet-base-v2')
>>> model.index(['Paris', 'London', 'Berlin'])
>>> results = model(['paris city'], topk=2)
>>> print(results)
{'paris city': [{'entity': 'Paris', 'score': 0.82, 'norm_score': 1.0},
{'entity': 'London', 'score': 0.35, 'norm_score': 0.43}]}
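For instruction-tuned embedding models, the query and passage prefixes can be supplied at construction time. The sketch below is illustrative only: the model name and prefix strings follow the E5 convention rather than anything prescribed by gfmrag, and the import assumes the class is exported from the package at the top of this page:

```python
from gfmrag.kg_construction.entity_linking_model import DPRELModel

model = DPRELModel(
    "intfloat/e5-base-v2",         # any SentenceTransformer-compatible model name
    query_instruct="query: ",      # prepended to every mention before encoding
    passage_instruct="passage: ",  # prepended to every indexed entity
    use_cache=True,                # entity embeddings cached under {root}/<model>_dpr_cache
    batch_size=64,
)
model.index(["Paris", "London", "Berlin"])
print(model(["paris city"], topk=2))
```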
Source code in gfmrag/kg_construction/entity_linking_model/dpr_el_model.py
class DPRELModel(BaseELModel):
"""
Entity Linking Model based on Dense Passage Retrieval (DPR).
This class implements an entity linking model using DPR architecture and SentenceTransformer
for encoding entities and computing similarity scores between mentions and candidate entities.
Args:
model_name (str): Name or path of the SentenceTransformer model to use
root (str, optional): Root directory for caching embeddings. Defaults to "tmp".
use_cache (bool, optional): Whether to cache and reuse entity embeddings. Defaults to True.
normalize (bool, optional): Whether to L2-normalize embeddings. Defaults to True.
batch_size (int, optional): Batch size for encoding. Defaults to 32.
query_instruct (str, optional): Instruction/prompt prefix for query encoding. Defaults to "".
passage_instruct (str, optional): Instruction/prompt prefix for passage encoding. Defaults to "".
model_kwargs (dict, optional): Additional kwargs to pass to SentenceTransformer. Defaults to None.
Methods:
index(entity_list): Indexes a list of entities by computing and caching their embeddings
__call__(ner_entity_list, topk): Links named entities to indexed entities and returns top-k matches
Examples:
>>> model = DPRELModel('sentence-transformers/all-mpnet-base-v2')
>>> model.index(['Paris', 'London', 'Berlin'])
>>> results = model(['paris city'], topk=2)
>>> print(results)
{'paris city': [{'entity': 'Paris', 'score': 0.82, 'norm_score': 1.0},
{'entity': 'London', 'score': 0.35, 'norm_score': 0.43}]}
"""
def __init__(
self,
model_name: str,
root: str = "tmp",
use_cache: bool = True,
normalize: bool = True,
batch_size: int = 32,
query_instruct: str = "",
passage_instruct: str = "",
model_kwargs: dict | None = None,
) -> None:
"""Initialize DPR Entity Linking Model.
Args:
model_name (str): Name or path of the pre-trained model to load.
root (str, optional): Root directory for cache storage. Defaults to "tmp".
use_cache (bool, optional): Whether to use cache for embeddings. Defaults to True.
normalize (bool, optional): Whether to normalize the embeddings. Defaults to True.
batch_size (int, optional): Batch size for encoding. Defaults to 32.
query_instruct (str, optional): Instruction prefix for query encoding. Defaults to "".
passage_instruct (str, optional): Instruction prefix for passage encoding. Defaults to "".
model_kwargs (dict | None, optional): Additional arguments to pass to the model. Defaults to None.
"""
self.model_name = model_name
self.use_cache = use_cache
self.normalize = normalize
self.batch_size = batch_size
self.root = os.path.join(root, f"{self.model_name.replace("/", "_")}_dpr_cache")
if self.use_cache and not os.path.exists(self.root):
os.makedirs(self.root)
self.model = SentenceTransformer(
model_name, trust_remote_code=True, model_kwargs=model_kwargs
)
self.query_instruct = query_instruct
self.passage_instruct = passage_instruct
def index(self, entity_list: list) -> None:
"""
Index a list of entities by encoding them into embeddings and optionally caching the results.
This method processes a list of entity strings, converting them into dense vector representations
using a pre-trained model. To avoid redundant computation, it implements a caching mechanism
based on the MD5 hash of the input entity list.
Args:
entity_list (list): A list of strings representing entities to be indexed.
Returns:
None
Notes:
- The method stores the embeddings in self.entity_embeddings
- If caching is enabled and a cache file exists for the given entity list,
embeddings are loaded from cache instead of being recomputed
- Cache files are stored using the MD5 hash of the concatenated entity list as filename
- Embeddings are computed on GPU if available, otherwise on CPU
"""
self.entity_list = entity_list
# Get md5 fingerprint of the whole given entity list
fingerprint = hashlib.md5("".join(entity_list).encode()).hexdigest()
cache_file = f"{self.root}/{fingerprint}.pt"
if os.path.exists(cache_file):
self.entity_embeddings = torch.load(
cache_file, map_location="cuda" if torch.cuda.is_available() else "cpu"
)
else:
self.entity_embeddings = self.model.encode(
entity_list,
device="cuda" if torch.cuda.is_available() else "cpu",
convert_to_tensor=True,
show_progress_bar=True,
prompt=self.passage_instruct,
normalize_embeddings=self.normalize,
batch_size=self.batch_size,
)
if self.use_cache:
torch.save(self.entity_embeddings, cache_file)
def __call__(self, ner_entity_list: list, topk: int = 1) -> dict:
"""
Performs entity linking by matching input entities with pre-encoded entity embeddings.
This method takes a list of named entities (e.g., from NER), computes their embeddings,
and finds the closest matching entities from the pre-encoded knowledge base using
cosine similarity.
Args:
ner_entity_list (list): List of named entities to link
topk (int, optional): Number of top matches to return for each entity. Defaults to 1.
Returns:
dict: Dictionary mapping each input entity to its linked candidates. For each candidate:
- entity (str): The matched entity name from the knowledge base
- score (float): Raw similarity score
- norm_score (float): Normalized similarity score (relative to top match)
"""
ner_entity_embeddings = self.model.encode(
ner_entity_list,
device="cuda" if torch.cuda.is_available() else "cpu",
convert_to_tensor=True,
prompt=self.query_instruct,
normalize_embeddings=self.normalize,
batch_size=self.batch_size,
)
scores = ner_entity_embeddings @ self.entity_embeddings.T
top_k_scores, top_k_values = torch.topk(scores, topk, dim=-1)
linked_entity_dict: dict[str, list] = {}
for i in range(len(ner_entity_list)):
linked_entity_dict[ner_entity_list[i]] = []
sorted_score = top_k_scores[i]
sorted_indices = top_k_values[i]
max_score = sorted_score[0].item()
for score, top_k_index in zip(sorted_score, sorted_indices):
linked_entity_dict[ner_entity_list[i]].append(
{
"entity": self.entity_list[top_k_index],
"score": score.item(),
"norm_score": score.item() / max_score,
}
)
return linked_entity_dict
__call__(ner_entity_list, topk=1)
Performs entity linking by matching input entities with pre-encoded entity embeddings.
This method takes a list of named entities (e.g., from NER), computes their embeddings, and finds the closest matching entities from the pre-encoded knowledge base using cosine similarity.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ner_entity_list | list | List of named entities to link | required |
| topk | int | Number of top matches to return for each entity. Defaults to 1. | 1 |
Returns:
| Name | Type | Description |
|---|---|---|
| dict | dict | Dictionary mapping each input entity to its linked candidates. Each candidate contains: entity (str), the matched entity name from the knowledge base; score (float), raw similarity score; norm_score (float), normalized similarity score relative to the top match. |
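The scoring step is plain dense retrieval: with normalize=True both sides are unit vectors, so the dot product equals cosine similarity, and the top-k candidates per mention are selected with torch.topk. A standalone sketch of that step (random tensors stand in for the SentenceTransformer embeddings):

```python
import torch

# Stand-ins for (num_mentions, dim) query embeddings and (num_entities, dim) entity embeddings,
# L2-normalized as SentenceTransformer does with normalize_embeddings=True.
query_emb = torch.nn.functional.normalize(torch.randn(2, 8), dim=-1)
entity_emb = torch.nn.functional.normalize(torch.randn(5, 8), dim=-1)

scores = query_emb @ entity_emb.T                      # cosine similarities, shape (2, 5)
top_scores, top_indices = torch.topk(scores, k=2, dim=-1)

# norm_score divides each candidate's score by the best score for its mention.
norm_scores = top_scores / top_scores[:, :1]
```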
Source code in gfmrag/kg_construction/entity_linking_model/dpr_el_model.py
def __call__(self, ner_entity_list: list, topk: int = 1) -> dict:
"""
Performs entity linking by matching input entities with pre-encoded entity embeddings.
This method takes a list of named entities (e.g., from NER), computes their embeddings,
and finds the closest matching entities from the pre-encoded knowledge base using
cosine similarity.
Args:
ner_entity_list (list): List of named entities to link
topk (int, optional): Number of top matches to return for each entity. Defaults to 1.
Returns:
dict: Dictionary mapping each input entity to its linked candidates. For each candidate:
- entity (str): The matched entity name from the knowledge base
- score (float): Raw similarity score
- norm_score (float): Normalized similarity score (relative to top match)
"""
ner_entity_embeddings = self.model.encode(
ner_entity_list,
device="cuda" if torch.cuda.is_available() else "cpu",
convert_to_tensor=True,
prompt=self.query_instruct,
normalize_embeddings=self.normalize,
batch_size=self.batch_size,
)
scores = ner_entity_embeddings @ self.entity_embeddings.T
top_k_scores, top_k_values = torch.topk(scores, topk, dim=-1)
linked_entity_dict: dict[str, list] = {}
for i in range(len(ner_entity_list)):
linked_entity_dict[ner_entity_list[i]] = []
sorted_score = top_k_scores[i]
sorted_indices = top_k_values[i]
max_score = sorted_score[0].item()
for score, top_k_index in zip(sorted_score, sorted_indices):
linked_entity_dict[ner_entity_list[i]].append(
{
"entity": self.entity_list[top_k_index],
"score": score.item(),
"norm_score": score.item() / max_score,
}
)
return linked_entity_dict
__init__(model_name, root='tmp', use_cache=True, normalize=True, batch_size=32, query_instruct='', passage_instruct='', model_kwargs=None)
Initialize DPR Entity Linking Model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_name | str | Name or path of the pre-trained model to load. | required |
| root | str | Root directory for cache storage. Defaults to "tmp". | 'tmp' |
| use_cache | bool | Whether to use cache for embeddings. Defaults to True. | True |
| normalize | bool | Whether to normalize the embeddings. Defaults to True. | True |
| batch_size | int | Batch size for encoding. Defaults to 32. | 32 |
| query_instruct | str | Instruction prefix for query encoding. Defaults to "". | '' |
| passage_instruct | str | Instruction prefix for passage encoding. Defaults to "". | '' |
| model_kwargs | dict \| None | Additional arguments to pass to the model. Defaults to None. | None |
Source code in gfmrag/kg_construction/entity_linking_model/dpr_el_model.py
def __init__(
self,
model_name: str,
root: str = "tmp",
use_cache: bool = True,
normalize: bool = True,
batch_size: int = 32,
query_instruct: str = "",
passage_instruct: str = "",
model_kwargs: dict | None = None,
) -> None:
"""Initialize DPR Entity Linking Model.
Args:
model_name (str): Name or path of the pre-trained model to load.
root (str, optional): Root directory for cache storage. Defaults to "tmp".
use_cache (bool, optional): Whether to use cache for embeddings. Defaults to True.
normalize (bool, optional): Whether to normalize the embeddings. Defaults to True.
batch_size (int, optional): Batch size for encoding. Defaults to 32.
query_instruct (str, optional): Instruction prefix for query encoding. Defaults to "".
passage_instruct (str, optional): Instruction prefix for passage encoding. Defaults to "".
model_kwargs (dict | None, optional): Additional arguments to pass to the model. Defaults to None.
"""
self.model_name = model_name
self.use_cache = use_cache
self.normalize = normalize
self.batch_size = batch_size
self.root = os.path.join(root, f"{self.model_name.replace("/", "_")}_dpr_cache")
if self.use_cache and not os.path.exists(self.root):
os.makedirs(self.root)
self.model = SentenceTransformer(
model_name, trust_remote_code=True, model_kwargs=model_kwargs
)
self.query_instruct = query_instruct
self.passage_instruct = passage_instruct
index(entity_list)
Index a list of entities by encoding them into embeddings and optionally caching the results.
This method processes a list of entity strings, converting them into dense vector representations using a pre-trained model. To avoid redundant computation, it implements a caching mechanism based on the MD5 hash of the input entity list.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| entity_list | list | A list of strings representing entities to be indexed. | required |

Returns:

| Type | Description |
|---|---|
| None | None |
Notes
- The method stores the embeddings in self.entity_embeddings
- If caching is enabled and a cache file exists for the given entity list, embeddings are loaded from cache instead of being recomputed
- Cache files are stored using the MD5 hash of the concatenated entity list as filename
- Embeddings are computed on GPU if available, otherwise on CPU
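In other words, the cache file is a pure function of the model name and the entity list. A quick sketch of the path that gets checked, assuming the default root of "tmp" (the model name here is just an example):

```python
import hashlib
import os

model_name = "sentence-transformers/all-mpnet-base-v2"  # example model name
entities = ["Paris", "London", "Berlin"]

cache_dir = os.path.join("tmp", f"{model_name.replace('/', '_')}_dpr_cache")
fingerprint = hashlib.md5("".join(entities).encode()).hexdigest()
cache_file = os.path.join(cache_dir, f"{fingerprint}.pt")
# If this file exists and use_cache=True, the embeddings are loaded with torch.load
# instead of being recomputed.
```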
Source code in gfmrag/kg_construction/entity_linking_model/dpr_el_model.py
def index(self, entity_list: list) -> None:
"""
Index a list of entities by encoding them into embeddings and optionally caching the results.
This method processes a list of entity strings, converting them into dense vector representations
using a pre-trained model. To avoid redundant computation, it implements a caching mechanism
based on the MD5 hash of the input entity list.
Args:
entity_list (list): A list of strings representing entities to be indexed.
Returns:
None
Notes:
- The method stores the embeddings in self.entity_embeddings
- If caching is enabled and a cache file exists for the given entity list,
embeddings are loaded from cache instead of being recomputed
- Cache files are stored using the MD5 hash of the concatenated entity list as filename
- Embeddings are computed on GPU if available, otherwise on CPU
"""
self.entity_list = entity_list
# Get md5 fingerprint of the whole given entity list
fingerprint = hashlib.md5("".join(entity_list).encode()).hexdigest()
cache_file = f"{self.root}/{fingerprint}.pt"
if os.path.exists(cache_file):
self.entity_embeddings = torch.load(
cache_file, map_location="cuda" if torch.cuda.is_available() else "cpu"
)
else:
self.entity_embeddings = self.model.encode(
entity_list,
device="cuda" if torch.cuda.is_available() else "cpu",
convert_to_tensor=True,
show_progress_bar=True,
prompt=self.passage_instruct,
normalize_embeddings=self.normalize,
batch_size=self.batch_size,
)
if self.use_cache:
torch.save(self.entity_embeddings, cache_file)
NVEmbedV2ELModel
Bases: DPRELModel
A DPR-based Entity Linking model specialized for NVEmbed V2 embeddings.
This class extends DPRELModel with specific adaptations for handling NVEmbed V2 models, including increased sequence length and right-side padding.
Attributes:
| Name | Type | Description |
|---|---|---|
| model | SentenceTransformer | The underlying model with max_seq_length of 32768 and right-side padding. |
Methods:
| Name | Description |
|---|---|
| add_eos | Adds EOS token to input examples. |
| __call__ | Processes entity list with EOS tokens before linking. |
Examples:
>>> model = NVEmbedV2ELModel('nvidia/NV-Embed-v2', query_instruct="Instruct: Given a entity, retrieve entities that are semantically equivalent to the given entity\nQuery: ")
>>> model.index(['Paris', 'London', 'Berlin'])
>>> results = model(['paris city'], topk=2)
>>> print(results)
{'paris city': [{'entity': 'Paris', 'score': 0.82, 'norm_score': 1.0},
{'entity': 'London', 'score': 0.35, 'norm_score': 0.43}]}
Source code in gfmrag/kg_construction/entity_linking_model/dpr_el_model.py
class NVEmbedV2ELModel(DPRELModel):
"""
A DPR-based Entity Linking model specialized for NVEmbed V2 embeddings.
This class extends DPRELModel with specific adaptations for handling NVEmbed V2 models,
including increased sequence length and right-side padding.
Attributes:
model: The underlying model with max_seq_length of 32768 and right-side padding.
Methods:
add_eos(input_examples): Adds EOS token to input examples.
__call__(ner_entity_list): Processes entity list with EOS tokens before linking.
Examples:
>>> model = NVEmbedV2ELModel('nvidia/NV-Embed-v2', query_instruct=\"Instruct: Given a entity, retrieve entities that are semantically equivalent to the given entity\\nQuery: \")
>>> model.index(['Paris', 'London', 'Berlin'])
>>> results = model(['paris city'], topk=2)
>>> print(results)
{'paris city': [{'entity': 'Paris', 'score': 0.82, 'norm_score': 1.0},
{'entity': 'London', 'score': 0.35, 'norm_score': 0.43}]}
"""
def __init__(
self,
*args: Any,
**kwargs: Any,
) -> None:
"""
Initialize the DPR Entity Linking model.
This initialization extends the base class initialization and sets specific model parameters
for entity linking tasks. It configures the maximum sequence length to 32768 and sets
the tokenizer padding side to "right".
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
None
"""
super().__init__(
*args,
**kwargs,
)
self.model.max_seq_length = 32768
self.model.tokenizer.padding_side = "right"
def add_eos(self, input_examples: list[str]) -> list[str]:
"""
Appends EOS (End of Sequence) token to each input example in the list.
Args:
input_examples (list[str]): List of input text strings.
Returns:
list[str]: List of input texts with EOS token appended to each example.
"""
input_examples = [
input_example + self.model.tokenizer.eos_token
for input_example in input_examples
]
return input_examples
def __call__(self, ner_entity_list: list, *args: Any, **kwargs: Any) -> dict:
"""
Execute entity linking for a list of named entities.
Args:
ner_entity_list (list): List of named entities to be linked.
*args (Any): Variable length argument list.
**kwargs (Any): Arbitrary keyword arguments.
Returns:
dict: Entity linking results mapping entities to their linked entries.
"""
ner_entity_list = self.add_eos(ner_entity_list)
return super().__call__(ner_entity_list, *args, **kwargs)
__call__(ner_entity_list, *args, **kwargs)
Execute entity linking for a list of named entities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ner_entity_list | list | List of named entities to be linked. | required |
| *args | Any | Variable length argument list. | () |
| **kwargs | Any | Arbitrary keyword arguments. | {} |
Returns:
| Name | Type | Description |
|---|---|---|
| dict | dict | Entity linking results mapping entities to their linked entries. |
Source code in gfmrag/kg_construction/entity_linking_model/dpr_el_model.py
def __call__(self, ner_entity_list: list, *args: Any, **kwargs: Any) -> dict:
"""
Execute entity linking for a list of named entities.
Args:
ner_entity_list (list): List of named entities to be linked.
*args (Any): Variable length argument list.
**kwargs (Any): Arbitrary keyword arguments.
Returns:
dict: Entity linking results mapping entities to their linked entries.
"""
ner_entity_list = self.add_eos(ner_entity_list)
return super().__call__(ner_entity_list, *args, **kwargs)
__init__(*args, **kwargs)
Initialize the DPR Entity Linking model.
This initialization extends the base class initialization and sets specific model parameters for entity linking tasks. It configures the maximum sequence length to 32768 and sets the tokenizer padding side to "right".
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | Any | Variable length argument list. | () |
| **kwargs | Any | Arbitrary keyword arguments. | {} |
Returns:
| Type | Description |
|---|---|
| None | None |
Source code in gfmrag/kg_construction/entity_linking_model/dpr_el_model.py
def __init__(
self,
*args: Any,
**kwargs: Any,
) -> None:
"""
Initialize the DPR Entity Linking model.
This initialization extends the base class initialization and sets specific model parameters
for entity linking tasks. It configures the maximum sequence length to 32768 and sets
the tokenizer padding side to "right".
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
None
"""
super().__init__(
*args,
**kwargs,
)
self.model.max_seq_length = 32768
self.model.tokenizer.padding_side = "right"
add_eos(input_examples)
Appends EOS (End of Sequence) token to each input example in the list.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_examples | list[str] | List of input text strings. | required |
Returns:
| Type | Description |
|---|---|
| list[str] | List of input texts with EOS token appended to each example. |
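For example (hypothetical output; the actual EOS string depends on the model's tokenizer, and `</s>` is just a common value):

```python
model.add_eos(["Paris", "London"])
# -> ['Paris</s>', 'London</s>']
```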
Source code in gfmrag/kg_construction/entity_linking_model/dpr_el_model.py
def add_eos(self, input_examples: list[str]) -> list[str]:
"""
Appends EOS (End of Sequence) token to each input example in the list.
Args:
input_examples (list[str]): List of input text strings.
Returns:
list[str]: List of input texts with EOS token appended to each example.
"""
input_examples = [
input_example + self.model.tokenizer.eos_token
for input_example in input_examples
]
return input_examples