Doc ranker
gfmrag.models.gfm_rag_v1.rankers
¶
BaseDocRanker
¶
IDFWeightedRanker
¶
Bases: BaseDocRanker
Rank documents based on entity prediction with IDF weighting
Source code in gfmrag/models/gfm_rag_v1/rankers.py
class IDFWeightedRanker(BaseDocRanker):
    """
    Rank documents based on entity prediction with IDF weighting.

    Each entity score is scaled by an inverse-document-frequency weight
    ``1 / df``, where ``df`` is the entity's row sum in ``ent2doc`` (its number
    of linked documents when ``ent2doc`` is binary). Entities linked to no
    document get weight 0. The weighted scores are then propagated to
    documents through ``ent2doc``.
    """

    def __call__(self, ent_pred: torch.Tensor, ent2doc: torch.Tensor) -> torch.Tensor:
        """
        Rank documents based on entity prediction with IDF weighting.

        Args:
            ent_pred (torch.Tensor): Entity prediction, shape (batch_size, n_entities)
            ent2doc (torch.Tensor): Sparse tensor mapping entities to documents, shape (n_entities, n_docs)

        Returns:
            torch.Tensor: Document ranks, shape (batch_size, n_docs)
        """
        # Row sums of the sparse mapping give each entity's document frequency.
        doc_freq = torch.sparse.sum(ent2doc, dim=-1).to_dense()
        idf_weight = 1 / doc_freq
        # 1 / 0 yields inf for entities without documents; neutralise them.
        idf_weight[doc_freq == 0] = 0
        weighted_pred = ent_pred * idf_weight.unsqueeze(0)
        return torch.sparse.mm(weighted_pred, ent2doc)
__call__(ent_pred, ent2doc)
¶
Rank documents based on entity prediction with IDF weighting
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ent_pred` | `Tensor` | Entity prediction, shape (batch_size, n_entities) | required |
| `ent2doc` | `Tensor` | Sparse tensor mapping entities to documents, shape (n_entities, n_docs) | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Document ranks, shape (batch_size, n_docs) |
Source code in gfmrag/models/gfm_rag_v1/rankers.py
def __call__(self, ent_pred: torch.Tensor, ent2doc: torch.Tensor) -> torch.Tensor:
    """
    Rank documents based on entity prediction with IDF weighting.

    Each entity score is multiplied by ``1 / df``, where ``df`` is the entity's
    row sum in ``ent2doc``. Entities linked to no document get weight 0.

    Args:
        ent_pred (torch.Tensor): Entity prediction, shape (batch_size, n_entities)
        ent2doc (torch.Tensor): Sparse tensor mapping entities to documents, shape (n_entities, n_docs)

    Returns:
        torch.Tensor: Document ranks, shape (batch_size, n_docs)
    """
    # Row sums of the sparse mapping give each entity's document frequency.
    doc_freq = torch.sparse.sum(ent2doc, dim=-1).to_dense()
    idf_weight = 1 / doc_freq
    # 1 / 0 yields inf for entities without documents; neutralise them.
    idf_weight[doc_freq == 0] = 0
    weighted_pred = ent_pred * idf_weight.unsqueeze(0)
    return torch.sparse.mm(weighted_pred, ent2doc)
IDFWeightedTopKRanker
¶
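Bases: BaseDocRanker

Rank documents using only the top-k predicted entities per query. Each selected entity contributes its inverse document frequency (1 / its row sum in `ent2doc`) rather than its predicted score.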
Source code in gfmrag/models/gfm_rag_v1/rankers.py
class IDFWeightedTopKRanker(BaseDocRanker):
    """
    Rank documents using only the top-k predicted entities, weighted by IDF.

    For every query in the batch, the ``top_k`` highest-scoring entities are
    selected. Each selected entity contributes its inverse document frequency
    (``1 / df``, with ``df`` its row sum in ``ent2doc``) to the documents it
    maps to. Entity scores are used only for the selection, not as weights.

    Args:
        top_k (int): Number of entities kept per query. If it exceeds the
            number of entities, all entities are kept.
    """

    def __init__(self, top_k: int) -> None:
        self.top_k = top_k

    def __call__(self, ent_pred: torch.Tensor, ent2doc: torch.Tensor) -> torch.Tensor:
        """
        Rank documents based on top-k entity prediction.

        Args:
            ent_pred (torch.Tensor): Entity prediction, shape (batch_size, n_entities)
            ent2doc (torch.Tensor): Sparse tensor mapping entities to documents, shape (n_entities, n_docs)

        Returns:
            torch.Tensor: Document ranks, shape (batch_size, n_docs)
        """
        frequency = torch.sparse.sum(ent2doc, dim=-1).to_dense()
        idf_weight = 1 / frequency
        # 1 / 0 yields inf for entities without documents; neutralise them.
        idf_weight[frequency == 0] = 0
        # torch.topk raises when k exceeds the dimension size; clamp so small
        # entity sets still work (all entities are then selected).
        k = min(self.top_k, ent_pred.shape[-1])
        top_k_indices = torch.topk(ent_pred, k, dim=-1).indices
        # Advanced indexing of the 1-D weights yields (batch_size, k).
        top_k_idf = idf_weight[top_k_indices]
        masked_ent_pred = torch.zeros_like(ent_pred, dtype=idf_weight.dtype)
        masked_ent_pred.scatter_(1, top_k_indices, top_k_idf)
        return torch.sparse.mm(masked_ent_pred, ent2doc)
__call__(ent_pred, ent2doc)
¶
Rank documents based on top-k entity prediction
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ent_pred` | `Tensor` | Entity prediction, shape (batch_size, n_entities) | required |
| `ent2doc` | `Tensor` | Sparse tensor mapping entities to documents, shape (n_entities, n_docs) | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Document ranks, shape (batch_size, n_docs) |
Source code in gfmrag/models/gfm_rag_v1/rankers.py
def __call__(self, ent_pred: torch.Tensor, ent2doc: torch.Tensor) -> torch.Tensor:
    """
    Rank documents based on top-k entity prediction.

    The ``self.top_k`` highest-scoring entities per query each contribute their
    inverse document frequency (``1 / df``, with ``df`` the row sum in
    ``ent2doc``). If ``self.top_k`` exceeds the number of entities, all
    entities are selected.

    Args:
        ent_pred (torch.Tensor): Entity prediction, shape (batch_size, n_entities)
        ent2doc (torch.Tensor): Sparse tensor mapping entities to documents, shape (n_entities, n_docs)

    Returns:
        torch.Tensor: Document ranks, shape (batch_size, n_docs)
    """
    frequency = torch.sparse.sum(ent2doc, dim=-1).to_dense()
    idf_weight = 1 / frequency
    # 1 / 0 yields inf for entities without documents; neutralise them.
    idf_weight[frequency == 0] = 0
    # torch.topk raises when k exceeds the dimension size; clamp it.
    k = min(self.top_k, ent_pred.shape[-1])
    top_k_indices = torch.topk(ent_pred, k, dim=-1).indices
    # Advanced indexing of the 1-D weights yields (batch_size, k).
    top_k_idf = idf_weight[top_k_indices]
    masked_ent_pred = torch.zeros_like(ent_pred, dtype=idf_weight.dtype)
    masked_ent_pred.scatter_(1, top_k_indices, top_k_idf)
    return torch.sparse.mm(masked_ent_pred, ent2doc)
SimpleRanker
¶
Bases: BaseDocRanker
Rank documents based on entity prediction without any weighting
Source code in gfmrag/models/gfm_rag_v1/rankers.py
class SimpleRanker(BaseDocRanker):
    """
    Rank documents based on entity prediction without any weighting.

    Document scores are the plain matrix product of the entity scores with
    the entity-to-document mapping.
    """

    def __call__(self, ent_pred: torch.Tensor, ent2doc: torch.Tensor) -> torch.Tensor:
        """
        Rank documents based on entity prediction.

        Args:
            ent_pred (torch.Tensor): Entity prediction, shape (batch_size, n_entities)
            ent2doc (torch.Tensor): Sparse tensor mapping entities to documents, shape (n_entities, n_docs)

        Returns:
            torch.Tensor: Document ranks, shape (batch_size, n_docs)
        """
        # Each document accumulates the scores of its linked entities,
        # scaled by the corresponding ent2doc values.
        return torch.sparse.mm(ent_pred, ent2doc)
__call__(ent_pred, ent2doc)
¶
Rank documents based on entity prediction
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ent_pred` | `Tensor` | Entity prediction, shape (batch_size, n_entities) | required |
| `ent2doc` | `Tensor` | Sparse tensor mapping entities to documents, shape (n_entities, n_docs) | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Document ranks, shape (batch_size, n_docs) |
Source code in gfmrag/models/gfm_rag_v1/rankers.py
def __call__(self, ent_pred: torch.Tensor, ent2doc: torch.Tensor) -> torch.Tensor:
    """
    Rank documents based on entity prediction.

    Args:
        ent_pred (torch.Tensor): Entity prediction, shape (batch_size, n_entities)
        ent2doc (torch.Tensor): Sparse tensor mapping entities to documents, shape (n_entities, n_docs)

    Returns:
        torch.Tensor: Document ranks, shape (batch_size, n_docs)
    """
    # Unweighted propagation: a document's score is the ent2doc-weighted sum
    # of its entities' scores.
    return torch.sparse.mm(ent_pred, ent2doc)
TopKRanker
¶
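Bases: BaseDocRanker

Rank documents using only the top-k predicted entities per query. Each selected entity contributes a weight of 1, regardless of its predicted score.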
Source code in gfmrag/models/gfm_rag_v1/rankers.py
class TopKRanker(BaseDocRanker):
    """
    Rank documents using only the top-k predicted entities, unweighted.

    For every query, the ``top_k`` highest-scoring entities are marked with 1
    and all others with 0. The mask is then propagated to documents through
    ``ent2doc``.

    Args:
        top_k (int): Number of entities kept per query. If it exceeds the
            number of entities, all entities are kept.
    """

    def __init__(self, top_k: int) -> None:
        self.top_k = top_k

    def __call__(self, ent_pred: torch.Tensor, ent2doc: torch.Tensor) -> torch.Tensor:
        """
        Rank documents based on top-k entity prediction.

        Args:
            ent_pred (torch.Tensor): Entity prediction, shape (batch_size, n_entities)
            ent2doc (torch.Tensor): Sparse tensor mapping entities to documents, shape (n_entities, n_docs)

        Returns:
            torch.Tensor: Document ranks, shape (batch_size, n_docs)
        """
        # torch.topk raises when k exceeds the dimension size; clamp so small
        # entity sets still work (all entities are then selected).
        k = min(self.top_k, ent_pred.shape[-1])
        top_k_indices = torch.topk(ent_pred, k, dim=-1).indices
        masked_ent_pred = torch.zeros_like(ent_pred)
        # Binary mask: selected entities count as 1 regardless of their score.
        masked_ent_pred.scatter_(1, top_k_indices, 1)
        return torch.sparse.mm(masked_ent_pred, ent2doc)
__call__(ent_pred, ent2doc)
¶
Rank documents based on top-k entity prediction
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ent_pred` | `Tensor` | Entity prediction, shape (batch_size, n_entities) | required |
| `ent2doc` | `Tensor` | Sparse tensor mapping entities to documents, shape (n_entities, n_docs) | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Document ranks, shape (batch_size, n_docs) |
Source code in gfmrag/models/gfm_rag_v1/rankers.py
def __call__(self, ent_pred: torch.Tensor, ent2doc: torch.Tensor) -> torch.Tensor:
    """
    Rank documents based on top-k entity prediction.

    The ``self.top_k`` highest-scoring entities per query are marked with 1
    and propagated to documents through ``ent2doc``. If ``self.top_k``
    exceeds the number of entities, all entities are selected.

    Args:
        ent_pred (torch.Tensor): Entity prediction, shape (batch_size, n_entities)
        ent2doc (torch.Tensor): Sparse tensor mapping entities to documents, shape (n_entities, n_docs)

    Returns:
        torch.Tensor: Document ranks, shape (batch_size, n_docs)
    """
    # torch.topk raises when k exceeds the dimension size; clamp it.
    k = min(self.top_k, ent_pred.shape[-1])
    top_k_indices = torch.topk(ent_pred, k, dim=-1).indices
    masked_ent_pred = torch.zeros_like(ent_pred)
    # Binary mask: selected entities count as 1 regardless of their score.
    masked_ent_pred.scatter_(1, top_k_indices, 1)
    return torch.sparse.mm(masked_ent_pred, ent2doc)