bbm25_haystack.bbm25_retriever
# SPDX-FileCopyrightText: 2024-present Yuxuan Wang <wangy49@seas.upenn.edu>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional

from haystack import (
    DeserializationError,
    Document,
    component,
    default_from_dict,
    default_to_dict,
)

from bbm25_haystack.bbm25_store import BetterBM25DocumentStore


def _validate_search_params(filters: Optional[dict[str, Any]], top_k: int) -> None:
    """
    Validate the search parameters.

    :param filters: A dictionary with filters to narrow down the search space,
        or None for no filtering.
    :type filters: Optional[dict[str, Any]]
    :param top_k: The maximum number of documents to retrieve.
    :type top_k: int

    :raises TypeError: If top_k is not an integer (``bool`` is rejected even
        though it subclasses ``int``), or if filters is neither None nor a
        dictionary.
    :raises ValueError: If the specified top_k is not > 0.
    """
    # bool is a subclass of int; without the explicit check True would be
    # accepted as top_k=1.
    if isinstance(top_k, bool) or not isinstance(top_k, int):
        msg = f"top_k must be an integer; got {type(top_k)} instead"
        raise TypeError(msg)

    if top_k <= 0:
        msg = f"top_k must be > 0; got {top_k} instead"
        raise ValueError(msg)

    if filters is not None and (not isinstance(filters, dict)):
        msg = f"filters must be a dictionary; got {type(filters)} instead"
        raise TypeError(msg)


@component
class BetterBM25Retriever:
    """
    A component for retrieving documents from a BetterBM25DocumentStore.
    """

    def __init__(
        self,
        document_store: BetterBM25DocumentStore,
        *,
        filters: Optional[dict[str, Any]] = None,
        top_k: int = 10,
        set_score: bool = True,
    ) -> None:
        """
        Create a BetterBM25Retriever component.

        :param document_store: The document store used to retrieve documents.
        :type document_store: BetterBM25DocumentStore
        :param filters: A dictionary with filters to narrow down the
            search space (default is None).
        :type filters: Optional[dict[str, Any]]
        :param top_k: The maximum number of documents to retrieve
            (default is 10).
        :type top_k: int
        :param set_score: Whether to set the similarity scores
            on retrieved documents (default is True).
        :type set_score: bool

        :raises ValueError: If the specified top_k is not > 0.
        :raises TypeError: If top_k is not an integer, filters is neither
            None nor a dictionary, or document_store is not a
            BetterBM25DocumentStore.
        """
        _validate_search_params(filters, top_k)

        self.filters = filters
        self.top_k = top_k
        self.set_score = set_score

        if not isinstance(document_store, BetterBM25DocumentStore):
            msg = "document_store must be an instance of BetterBM25DocumentStore"
            raise TypeError(msg)
        self.document_store = document_store

    @component.output_types(documents=list[Document])
    def run(
        self,
        query: str,
        *,
        filters: Optional[dict[str, Any]] = None,
        top_k: Optional[int] = None,
    ) -> dict[str, list[Document]]:
        """
        Run the Retriever on the given query.

        This method always returns copies of the documents retrieved from
        the document store (each one is rebuilt via to_dict/from_dict).

        :param query: The query to run the Retriever on.
        :type query: str
        :param filters: A dictionary with filters to narrow down the search
            space. If None or empty, the filters given at initialization
            are used.
        :type filters: Optional[dict[str, Any]]
        :param top_k: The maximum number of documents to retrieve. If None,
            the top_k given at initialization is used.
        :type top_k: Optional[int]

        :return: A dictionary with key ``"documents"`` holding the retrieved
            documents; their ``score`` is set when ``set_score`` is True.
        :raises ValueError: If the effective top_k is not > 0.
        :raises TypeError: If the effective top_k is not an integer or the
            effective filters are not a dictionary.
        """
        filters = filters or self.filters
        # Only fall back when top_k is omitted; `top_k or self.top_k` would
        # silently replace an invalid explicit top_k=0 with the default
        # instead of letting validation reject it.
        if top_k is None:
            top_k = self.top_k

        _validate_search_params(filters, top_k)

        sim = self.document_store._retrieval(query, filters=filters, top_k=top_k)

        ret = []
        for doc, scr in sim:
            data = doc.to_dict()
            if self.set_score:
                data["score"] = scr
            ret.append(Document.from_dict(data))

        return {"documents": ret}

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :return: dictionary with serialized data, including the serialized
            document store.
        """
        return default_to_dict(
            self,
            filters=self.filters,
            top_k=self.top_k,
            document_store=self.document_store.to_dict(),
            set_score=self.set_score,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BetterBM25Retriever":
        """
        Deserializes the component from a dictionary.

        The input dictionary is not modified.

        :param data: dictionary to deserialize from.
        :returns: deserialized component.
        :raises DeserializationError: If the document store's serialization
            data, or its ``type`` entry, is missing.
        """
        init_params = data.get("init_parameters", {})
        doc_store_params = init_params.get("document_store")
        if doc_store_params is None:
            msg = "Missing 'document_store' in serialization data"
            raise DeserializationError(msg)

        if doc_store_params.get("type") is None:
            msg = "Missing 'type' in document store's serialization data"
            raise DeserializationError(msg)

        # Build new dicts rather than writing the store object back into
        # `data`: mutating the caller's dict would make a second from_dict
        # call on the same data fail.
        init_params = {
            **init_params,
            "document_store": BetterBM25DocumentStore.from_dict(doc_store_params),
        }
        return default_from_dict(cls, {**data, "init_parameters": init_params})
@component
class BetterBM25Retriever:
    """
    A component for retrieving documents from a BetterBM25DocumentStore.
    """

    def __init__(
        self,
        document_store: BetterBM25DocumentStore,
        *,
        filters: Optional[dict[str, Any]] = None,
        top_k: int = 10,
        set_score: bool = True,
    ) -> None:
        """
        Create a BetterBM25Retriever component.

        :param document_store: The document store used to retrieve documents.
        :type document_store: BetterBM25DocumentStore
        :param filters: A dictionary with filters to narrow down the
            search space (default is None).
        :type filters: Optional[dict[str, Any]]
        :param top_k: The maximum number of documents to retrieve
            (default is 10).
        :type top_k: int
        :param set_score: Whether to set the similarity scores
            on retrieved documents (default is True).
        :type set_score: bool

        :raises ValueError: If the specified top_k is not > 0.
        :raises TypeError: If top_k is not an integer, filters is neither
            None nor a dictionary, or document_store is not a
            BetterBM25DocumentStore.
        """
        _validate_search_params(filters, top_k)

        self.filters = filters
        self.top_k = top_k
        self.set_score = set_score

        if not isinstance(document_store, BetterBM25DocumentStore):
            msg = "document_store must be an instance of BetterBM25DocumentStore"
            raise TypeError(msg)
        self.document_store = document_store

    @component.output_types(documents=list[Document])
    def run(
        self,
        query: str,
        *,
        filters: Optional[dict[str, Any]] = None,
        top_k: Optional[int] = None,
    ) -> dict[str, list[Document]]:
        """
        Run the Retriever on the given query.

        This method always returns copies of the documents retrieved from
        the document store (each one is rebuilt via to_dict/from_dict).

        :param query: The query to run the Retriever on.
        :type query: str
        :param filters: A dictionary with filters to narrow down the search
            space. If None or empty, the filters given at initialization
            are used.
        :type filters: Optional[dict[str, Any]]
        :param top_k: The maximum number of documents to retrieve. If None,
            the top_k given at initialization is used.
        :type top_k: Optional[int]

        :return: A dictionary with key ``"documents"`` holding the retrieved
            documents; their ``score`` is set when ``set_score`` is True.
        :raises ValueError: If the effective top_k is not > 0.
        :raises TypeError: If the effective top_k is not an integer or the
            effective filters are not a dictionary.
        """
        filters = filters or self.filters
        # Only fall back when top_k is omitted; `top_k or self.top_k` would
        # silently replace an invalid explicit top_k=0 with the default
        # instead of letting validation reject it.
        if top_k is None:
            top_k = self.top_k

        _validate_search_params(filters, top_k)

        sim = self.document_store._retrieval(query, filters=filters, top_k=top_k)

        ret = []
        for doc, scr in sim:
            data = doc.to_dict()
            if self.set_score:
                data["score"] = scr
            ret.append(Document.from_dict(data))

        return {"documents": ret}

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :return: dictionary with serialized data, including the serialized
            document store.
        """
        return default_to_dict(
            self,
            filters=self.filters,
            top_k=self.top_k,
            document_store=self.document_store.to_dict(),
            set_score=self.set_score,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BetterBM25Retriever":
        """
        Deserializes the component from a dictionary.

        The input dictionary is not modified.

        :param data: dictionary to deserialize from.
        :returns: deserialized component.
        :raises DeserializationError: If the document store's serialization
            data, or its ``type`` entry, is missing.
        """
        init_params = data.get("init_parameters", {})
        doc_store_params = init_params.get("document_store")
        if doc_store_params is None:
            msg = "Missing 'document_store' in serialization data"
            raise DeserializationError(msg)

        if doc_store_params.get("type") is None:
            msg = "Missing 'type' in document store's serialization data"
            raise DeserializationError(msg)

        # Build new dicts rather than writing the store object back into
        # `data`: mutating the caller's dict would make a second from_dict
        # call on the same data fail.
        init_params = {
            **init_params,
            "document_store": BetterBM25DocumentStore.from_dict(doc_store_params),
        }
        return default_from_dict(cls, {**data, "init_parameters": init_params})
A component for retrieving documents from a BetterBM25DocumentStore.
def __init__(
    self,
    document_store: BetterBM25DocumentStore,
    *,
    filters: Optional[dict[str, Any]] = None,
    top_k: int = 10,
    set_score: bool = True,
) -> None:
    """
    Build a retriever bound to a BetterBM25DocumentStore.

    :param document_store: The store that documents are retrieved from.
    :type document_store: BetterBM25DocumentStore
    :param filters: Default filters restricting the search space
        (default is None).
    :type filters: Optional[dict[str, Any]]
    :param top_k: Default maximum number of documents to return
        (default is 10).
    :type top_k: int
    :param set_score: Whether retrieved documents get their similarity
        score filled in (default is True).
    :type set_score: bool

    :raises ValueError: If top_k is not > 0.
    :raises TypeError: If top_k is not an integer, filters is neither None
        nor a dictionary, or document_store is not a BetterBM25DocumentStore.
    """
    # Search defaults are validated (and stored) before the store type check,
    # so an invalid top_k/filters is reported first.
    _validate_search_params(filters, top_k)
    self.filters, self.top_k, self.set_score = filters, top_k, set_score

    if isinstance(document_store, BetterBM25DocumentStore):
        self.document_store = document_store
        return
    raise TypeError("document_store must be an instance of BetterBM25DocumentStore")
Create a BetterBM25Retriever component.
Parameters
- document_store: A Document Store object used to retrieve documents
- filters: A dictionary with filters to narrow down the search space (default is None).
- top_k: The maximum number of documents to retrieve (default is 10).
- set_score: Whether to set the similarity scores to retrieved documents (default is True).
Raises
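- ValueError: If the specified top_k is not > 0.
- TypeError: If top_k is not an integer, filters is neither None nor a dictionary, or document_store is not a BetterBM25DocumentStore.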
@component.output_types(documents=list[Document])
def run(
    self,
    query: str,
    *,
    filters: Optional[dict[str, Any]] = None,
    top_k: Optional[int] = None,
) -> dict[str, list[Document]]:
    """
    Run the Retriever on the given query.

    This method always returns copies of the documents retrieved from
    the document store (each one is rebuilt via to_dict/from_dict).

    :param query: The query to run the Retriever on.
    :type query: str
    :param filters: A dictionary with filters to narrow down the search
        space. If None or empty, the filters given at initialization
        are used.
    :type filters: Optional[dict[str, Any]]
    :param top_k: The maximum number of documents to retrieve. If None,
        the top_k given at initialization is used.
    :type top_k: Optional[int]

    :return: A dictionary with key ``"documents"`` holding the retrieved
        documents; their ``score`` is set when ``set_score`` is True.
    :raises ValueError: If the effective top_k is not > 0.
    :raises TypeError: If the effective top_k is not an integer or the
        effective filters are not a dictionary.
    """
    filters = filters or self.filters
    # Only fall back when top_k is omitted; `top_k or self.top_k` would
    # silently replace an invalid explicit top_k=0 with the default
    # instead of letting validation reject it.
    if top_k is None:
        top_k = self.top_k

    _validate_search_params(filters, top_k)

    sim = self.document_store._retrieval(query, filters=filters, top_k=top_k)

    ret = []
    for doc, scr in sim:
        data = doc.to_dict()
        if self.set_score:
            data["score"] = scr
        ret.append(Document.from_dict(data))

    return {"documents": ret}
Run the Retriever on the given query.
This method always returns copies of the documents retrieved from the document store.
Parameters
- query: The query to run the Retriever on.
- filters: A dictionary with filters to narrow down the search space (default is None). If None or empty, the filters given at initialization are used.
- top_k: The maximum number of documents to retrieve (default is None). If None, the top_k given at initialization is used.
Returns
The retrieved documents.
def to_dict(self) -> dict[str, Any]:
    """
    Serialize this retriever, together with its document store, into a
    dictionary.

    :return: The serialized component.
    """
    # Keyword order is kept so the init-parameter keys come out in the same order.
    init_kwargs = {
        "filters": self.filters,
        "top_k": self.top_k,
        "document_store": self.document_store.to_dict(),
        "set_score": self.set_score,
    }
    return default_to_dict(self, **init_kwargs)
Serializes the component to a dictionary.
Returns
dictionary with serialized data.
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "BetterBM25Retriever":
    """
    Deserializes the component from a dictionary.

    The input dictionary is not modified.

    :param data: dictionary to deserialize from.
    :returns: deserialized component.
    :raises DeserializationError: If the document store's serialization
        data, or its ``type`` entry, is missing.
    """
    init_params = data.get("init_parameters", {})
    doc_store_params = init_params.get("document_store")
    if doc_store_params is None:
        msg = "Missing 'document_store' in serialization data"
        raise DeserializationError(msg)

    if doc_store_params.get("type") is None:
        msg = "Missing 'type' in document store's serialization data"
        raise DeserializationError(msg)

    # Build new dicts rather than writing the store object back into
    # `data`: mutating the caller's dict would make a second from_dict
    # call on the same data fail.
    init_params = {
        **init_params,
        "document_store": BetterBM25DocumentStore.from_dict(doc_store_params),
    }
    return default_from_dict(cls, {**data, "init_parameters": init_params})
Deserializes the component from a dictionary.
Parameters
- data: dictionary to deserialize from.
Returns
deserialized component.