bbm25_haystack.bbm25_store
# SPDX-FileCopyrightText: 2024-present Yuxuan Wang <wangy49@seas.upenn.edu>
#
# SPDX-License-Identifier: Apache-2.0
import heapq
import math
import os
from collections import Counter, deque
from collections.abc import Iterable
from itertools import chain
from typing import Any, Final, Optional, Union

import pandas as pd
from haystack import Document, default_from_dict, default_to_dict, logging
from haystack.document_stores.errors import (
    DuplicateDocumentError,
    MissingDocumentError,
)
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils.filters import document_matches_filter
from sentencepiece import SentencePieceProcessor  # type: ignore

from bbm25_haystack.filters import apply_filters_to_document

logger = logging.getLogger(__name__)


def _n_grams(seq: Iterable[str], n: int):
    """
    Returns a sliding window (of width n) over data from the iterable.

    Only complete windows are produced: a sequence with fewer than ``n``
    elements (including an empty one) yields nothing, so no window is
    ever padded with placeholder values.

    :param seq: the input token sequence.
    :type seq: Iterable[str]
    :param n: the window size; values below 1 yield nothing.
    :type n: int

    :return: the n-gram window generator.
    :rtype: Generator[tuple[str], None, None]
    """
    if n < 1:
        return

    window: deque = deque(maxlen=n)
    for token in seq:
        # `maxlen` drops the oldest token automatically once the window is full
        window.append(token)
        if len(window) == n:
            yield tuple(window)


class BetterBM25DocumentStore:
    """
    An in-memory document store intended to improve the default BM25 document
    store shipped with Haystack.

    Per-document n-gram counts, per-n-gram document frequencies and the
    average document length are maintained incrementally on every write
    and delete, so a query only has to score the candidate documents.
    """

    # SentencePiece model used when no (usable) model file is supplied;
    # it is looked up next to this module.
    default_sp_file: Final = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "default.model"
    )

    def __init__(
        self,
        *,
        k: float = 1.5,
        b: float = 0.75,
        delta: float = 1.0,
        sp_file: Optional[str] = None,
        n_grams: Union[int, tuple[int, int]] = 1,
        haystack_filter_logic: bool = True,
    ) -> None:
        """
        Creates a new BetterBM25DocumentStore instance.

        An in-memory document store intended to improve the default
        BM25 document store shipped with Haystack. The default store
        recomputes the index for the entire document store for every
        incoming query, which is significantly inefficient. This
        store aims to improve the efficiency by pre-computing the
        index for all documents in the store and only doing incremental
        updates when documents are added or removed. Further, it
        leverages a SentencePiece model to tokenize the input text
        to allow more flexible and dynamic tokenization adapted to
        domain-specific text.

        :param k: the k1 parameter in BM25+ formula.
        :type k: float, optional
        :param b: the b parameter in BM25+ formula.
        :type b: float, optional
        :param delta: the delta parameter in BM25+ formula.
        :type delta: float, optional
        :param sp_file: the SentencePiece model file to use for
            tokenization.
        :type sp_file: Optional[str], optional
        :param n_grams: the n-gram window size; an int ``n`` uses all
            n-gram sizes from 1 to ``n``, a pair ``(lo, hi)`` uses the
            sizes from ``lo`` to ``hi`` (both inclusive).
        :type n_grams: Optional[Union[int, tuple[int, int]]], optional
        :param haystack_filter_logic: Whether to use the Haystack
            filter logic or the one implemented in this store,
            which is more conservative.
        :type haystack_filter_logic: bool, optional

        :raises ValueError: if ``n_grams`` is not a valid window size.
        """
        self.k = k
        self.b = b

        # Adjust the delta value so that we can bring the `(k1 + 1)`
        # term out of the 'term frequency' term in BM25+ formula and
        # delete it; this will not affect the ranking
        self.delta = delta / (self.k + 1.0)

        self._parse_sp_file(sp_file=sp_file)
        self._parse_n_grams(n_grams=n_grams)

        self._haystack_filter_logic = haystack_filter_logic
        self._filter_func = (
            document_matches_filter
            if self._haystack_filter_logic
            else apply_filters_to_document
        )

        # Running statistics, kept in sync by write/delete:
        #   _avg_doc_len: mean number of n-grams per stored document
        #   _freq_doc:    n-gram -> number of documents containing it
        #   _index:       doc id -> (document, n-gram counts, n-gram total)
        self._avg_doc_len: float = 0.0
        self._freq_doc: Counter = Counter()
        self._index: dict[str, tuple[Document, dict[tuple[str], int], int]] = {}

    def _parse_sp_file(self, sp_file: Optional[str]) -> None:
        """
        Loads the SentencePiece tokenizer into ``self._sp_inst``.

        The requested path is remembered in ``self._sp_file`` as given.
        If the file is missing or cannot be loaded, a message is logged
        and the default model is used instead.

        :param sp_file: path to a SentencePiece model file, or ``None``
            to use the default model.
        :type sp_file: Optional[str]
        """
        self._sp_file = sp_file

        if sp_file is not None:
            if not os.path.isfile(sp_file):
                msg = (
                    f"Tokenizer model file '{sp_file}' not accessible; "
                    f"fallback to default {self.default_sp_file}."
                )
                logger.warning(msg)
            else:
                try:
                    sp_inst = SentencePieceProcessor(model_file=sp_file)
                except Exception as exc:
                    # Deliberately broad: any load failure falls back
                    # to the default model instead of aborting.
                    msg = (
                        f"Failed to load tokenizer model file '{sp_file}': {exc}; "
                        f"fallback to default {self.default_sp_file}."
                    )
                    logger.error(msg)
                else:
                    self._sp_inst = sp_inst
                    return

        self._sp_inst = SentencePieceProcessor(model_file=self.default_sp_file)

    def _parse_n_grams(self, n_grams: Optional[Union[int, tuple[int, int]]]) -> None:
        """
        Validates the n-gram setting and stores the window-size bounds.

        :param n_grams: an int ``n`` (sizes 1..n) or a pair ``(lo, hi)``
            (sizes lo..hi); a two-element list is accepted as well, so
            that a pair which was turned into a list by serialization
            can be read back.
        :type n_grams: Optional[Union[int, tuple[int, int]]]

        :raises ValueError: if the value is of the wrong type or shape,
            or if the bounds do not satisfy ``1 <= lo <= hi``.
        """
        if isinstance(n_grams, int):
            n_min, n_max = 1, n_grams
        elif isinstance(n_grams, (tuple, list)):
            # Validate before unpacking so a wrong length gets the same
            # explicit error as a wrong element type
            if len(n_grams) != 2 or not all(isinstance(n, int) for n in n_grams):
                msg = f"Invalid n-gram window size: {n_grams}."
                raise ValueError(msg)
            n_min, n_max = n_grams
            n_grams = (n_min, n_max)
        else:
            msg = f"Invalid n-gram window size: {n_grams}; expected int or tuple."
            raise ValueError(msg)

        # An empty or non-positive range would index no n-grams at all
        if not 1 <= n_min <= n_max:
            msg = f"Invalid n-gram window size: {n_grams}."
            raise ValueError(msg)

        self._n_grams = n_grams
        self._n_grams_min = n_min
        self._n_grams_max = n_max

    def _tokenize(self, texts: Union[str, list[str]]) -> list[list[tuple[str]]]:
        """
        Tokenize input text using SentencePiece model.

        The input text can either be a single string or a list of strings,
        such as a single user query or a group of raw documents. Each
        token sequence is expanded into its n-grams for every configured
        window size, smallest size first.

        :param texts: the input text to tokenize.
        :type texts: Union[str, list[str]]

        :return: the tokenized text, with n-grams augmented; one list
            per input text.
        :rtype: list[list[tuple[str]]]
        """
        if isinstance(texts, str):
            texts = [texts]

        sizes = range(self._n_grams_min, self._n_grams_max + 1)
        return [
            list(chain.from_iterable(_n_grams(tokens, n) for n in sizes))
            for tokens in self._sp_inst.encode(texts, out_type=str)
        ]

    def _compute_bm25plus(
        self,
        query: str,
        documents: list[Document],
    ) -> list[tuple[Document, float]]:
        """
        Calculate the BM25+ score of the query for the given documents.

        Each distinct query n-gram contributes once, regardless of how
        often it occurs in the query.

        :param query: the query to calculate the BM25+ score for.
        :type query: str
        :param documents: the pool of documents to calculate the BM25+
            score for; every document must be present in the index.
        :type documents: list[Document]

        :return: the BM25+ scores, in the order of ``documents``.
        :rtype: list[tuple[Document, float]]
        """
        n_docs = len(self._index)

        idf: dict[tuple[str], float] = {}
        for ng in self._tokenize(query)[0]:
            if ng not in idf:
                n_hit = self._freq_doc.get(ng, 0)
                idf[ng] = math.log(1 + (n_docs - n_hit + 0.5) / (n_hit + 0.5))

        avg_len = self._avg_doc_len
        sim = []
        for doc in documents:
            _, freq, doc_len = self._index[doc.id]

            # With a zero average (nothing but empty documents) treat the
            # document as average-length instead of dividing by zero
            doc_len_scaled = doc_len / avg_len if avg_len > 0 else 1.0
            # Depends on the document only, so compute it once per document
            freq_damp = self.k * (1 + self.b * (doc_len_scaled - 1))

            scr = 0.0
            for token, idf_val in idf.items():
                freq_term = freq.get(token, 0)
                # An absent term contributes 0; skipping the division also
                # avoids 0/0 when the damping term itself is zero
                tf_val = freq_term / (freq_term + freq_damp) if freq_term else 0.0
                scr += idf_val * (tf_val + self.delta)

            sim.append((doc, scr))

        return sim

    def _retrieval(
        self,
        query: str,
        *,
        filters: Optional[dict[str, Any]] = None,
        top_k: Optional[int] = None,
    ) -> list[tuple[Document, float]]:
        """
        Retrieve documents from the store using the given query.

        :param query: the query to search for.
        :type query: str
        :param filters: the filters to apply to the document list.
        :type filters: Optional[dict[str, Any]]
        :param top_k: the number of documents to return; ``None``
            returns every matching document.
        :type top_k: Optional[int]

        :return: the top-k documents and corresponding sim score,
            highest score first.
        :rtype: list[tuple[Document, float]]
        """
        documents = self.filter_documents(filters)
        if not documents:
            return []

        sim = self._compute_bm25plus(query, documents)
        if top_k is None:
            return sorted(sim, key=lambda x: x[1], reverse=True)
        return heapq.nlargest(top_k, sim, key=lambda x: x[1])

    def count_documents(self) -> int:
        """
        Returns how many documents are present in the document store.

        :return: the number of documents in the store.
        :rtype: int
        """
        return len(self._index)

    def filter_documents(
        self, filters: Optional[dict[str, Any]] = None
    ) -> list[Document]:
        """
        Filter documents in the store using the given filters.

        :param filters: the filters to apply to the document list;
            ``None`` or an empty dict selects every document.
        :type filters: Optional[dict[str, Any]]

        :return: the list of documents that match the given filters.
        :rtype: list[Document]
        """
        if not filters:
            return [doc for doc, _, _ in self._index.values()]
        return [
            doc
            for doc, _, _ in self._index.values()
            if self._filter_func(filters, doc)
        ]

    def write_documents(
        self,
        documents: list[Document],
        policy: DuplicatePolicy = DuplicatePolicy.FAIL,
    ) -> int:
        """
        Writes (or overwrites) documents into the store.

        Documents are processed in order; if an error is raised, the
        documents handled before it remain written.

        :param documents: a list of documents.
        :type documents: list[Document]
        :param policy: documents with the same ID count as duplicates.
            When duplicates are met, the store can:
             - skip: keep the existing document and ignore the new one.
             - overwrite: remove the old document and write the new one.
             - fail: an error is raised
        :type policy: DuplicatePolicy, optional

        :raises ValueError: if an element of ``documents`` is not a
            ``Document``.
        :raises DuplicateDocumentError: raised on a duplicate
            document if `policy=DuplicatePolicy.FAIL`

        :return: Number of documents written.
        :rtype: int
        """
        n_written = 0
        for doc in documents:
            if not isinstance(doc, Document):
                msg = f"Expected document type, got '{doc}' of type '{type(doc)}'."
                raise ValueError(msg)

            if doc.id in self._index:
                if policy == DuplicatePolicy.SKIP:
                    continue
                elif policy == DuplicatePolicy.FAIL:
                    msg = f"Document with ID '{doc.id}' already exists in the store."
                    raise DuplicateDocumentError(msg)

                # Overwrite if exists; delete first to keep the statistics consistent
                logger.debug(
                    f"Document '{doc.id}' already exists in the store, overwriting."
                )
                self.delete_documents([doc.id])

            content = doc.content or ""
            if content == "":
                # Not every Document is guaranteed to expose `dataframe`
                dataframe = getattr(doc, "dataframe", None)
                if isinstance(dataframe, pd.DataFrame):
                    content = dataframe.astype(str).to_csv(index=False)

            tokens = self._tokenize(content)[0]
            freq = Counter(tokens)

            # Update the running mean with the document count from BEFORE
            # the insertion; using the post-insert count skews the average
            n_docs = len(self._index)
            self._avg_doc_len = (self._avg_doc_len * n_docs + len(tokens)) / (
                n_docs + 1
            )

            self._index[doc.id] = (doc, freq, len(tokens))
            # Keys are distinct, so each n-gram's document frequency grows by 1
            self._freq_doc.update(freq.keys())

            logger.debug(f"Document '{doc.id}' written to store.")
            n_written += 1

        return n_written

    def delete_documents(self, document_ids: list[str]) -> int:
        """
        Deletes all documents with matching document_ids.

        Fails with `MissingDocumentError` if no document with
        a given id is present in the store; documents deleted
        before the failing id stay deleted.

        :param document_ids: the ids of the documents to delete
        :type document_ids: list[str]

        :raises MissingDocumentError: raised on a missing document.

        :return: Number of documents deleted.
        :rtype: int
        """
        n_removal = 0
        for doc_id in document_ids:
            # Keep the `try` narrow so only the lookup is translated
            try:
                _, freq, doc_len = self._index.pop(doc_id)
            except KeyError as exc:
                msg = f"Document with ID '{doc_id}' not found, cannot delete it."
                raise MissingDocumentError(msg) from exc

            # Lower the document frequencies and drop exhausted n-grams so
            # the table does not keep entries for vanished vocabulary
            for ng in freq:
                remaining = self._freq_doc[ng] - 1
                if remaining > 0:
                    self._freq_doc[ng] = remaining
                else:
                    del self._freq_doc[ng]

            n_left = len(self._index)
            if n_left:
                total_len = self._avg_doc_len * (n_left + 1) - doc_len
                self._avg_doc_len = total_len / n_left
            else:
                self._avg_doc_len = 0.0

            logger.debug(f"Document '{doc_id}' deleted from store.")
            n_removal += 1

        return n_removal

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes this store to a dictionary.

        Only the constructor parameters are passed to the serializer;
        the stored documents are not among them.
        """
        return default_to_dict(
            self,
            k=self.k,
            b=self.b,
            delta=self.delta * (self.k + 1.0),  # Because we scaled it on init
            sp_file=self._sp_file,
            n_grams=self._n_grams,
            haystack_filter_logic=self._haystack_filter_logic,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BetterBM25DocumentStore":
        """Deserializes the store from a dictionary."""
        return default_from_dict(cls, data)
class BetterBM25DocumentStore:
    """
    An in-memory document store intended to improve the default BM25 document
    store shipped with Haystack.

    Per-document n-gram counts, per-n-gram document frequencies and the
    average document length are maintained incrementally on every write
    and delete, so a query only has to score the candidate documents.
    """

    # SentencePiece model used when no (usable) model file is supplied;
    # it is looked up next to this module.
    default_sp_file: Final = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "default.model"
    )

    def __init__(
        self,
        *,
        k: float = 1.5,
        b: float = 0.75,
        delta: float = 1.0,
        sp_file: Optional[str] = None,
        n_grams: Union[int, tuple[int, int]] = 1,
        haystack_filter_logic: bool = True,
    ) -> None:
        """
        Creates a new BetterBM25DocumentStore instance.

        An in-memory document store intended to improve the default
        BM25 document store shipped with Haystack. The default store
        recomputes the index for the entire document store for every
        incoming query, which is significantly inefficient. This
        store aims to improve the efficiency by pre-computing the
        index for all documents in the store and only doing incremental
        updates when documents are added or removed. Further, it
        leverages a SentencePiece model to tokenize the input text
        to allow more flexible and dynamic tokenization adapted to
        domain-specific text.

        :param k: the k1 parameter in BM25+ formula.
        :type k: float, optional
        :param b: the b parameter in BM25+ formula.
        :type b: float, optional
        :param delta: the delta parameter in BM25+ formula.
        :type delta: float, optional
        :param sp_file: the SentencePiece model file to use for
            tokenization.
        :type sp_file: Optional[str], optional
        :param n_grams: the n-gram window size; an int ``n`` uses all
            n-gram sizes from 1 to ``n``, a pair ``(lo, hi)`` uses the
            sizes from ``lo`` to ``hi`` (both inclusive).
        :type n_grams: Optional[Union[int, tuple[int, int]]], optional
        :param haystack_filter_logic: Whether to use the Haystack
            filter logic or the one implemented in this store,
            which is more conservative.
        :type haystack_filter_logic: bool, optional

        :raises ValueError: if ``n_grams`` is not a valid window size.
        """
        self.k = k
        self.b = b

        # Adjust the delta value so that we can bring the `(k1 + 1)`
        # term out of the 'term frequency' term in BM25+ formula and
        # delete it; this will not affect the ranking
        self.delta = delta / (self.k + 1.0)

        self._parse_sp_file(sp_file=sp_file)
        self._parse_n_grams(n_grams=n_grams)

        self._haystack_filter_logic = haystack_filter_logic
        self._filter_func = (
            document_matches_filter
            if self._haystack_filter_logic
            else apply_filters_to_document
        )

        # Running statistics, kept in sync by write/delete:
        #   _avg_doc_len: mean number of n-grams per stored document
        #   _freq_doc:    n-gram -> number of documents containing it
        #   _index:       doc id -> (document, n-gram counts, n-gram total)
        self._avg_doc_len: float = 0.0
        self._freq_doc: Counter = Counter()
        self._index: dict[str, tuple[Document, dict[tuple[str], int], int]] = {}

    def _parse_sp_file(self, sp_file: Optional[str]) -> None:
        """
        Loads the SentencePiece tokenizer into ``self._sp_inst``.

        The requested path is remembered in ``self._sp_file`` as given.
        If the file is missing or cannot be loaded, a message is logged
        and the default model is used instead.

        :param sp_file: path to a SentencePiece model file, or ``None``
            to use the default model.
        :type sp_file: Optional[str]
        """
        self._sp_file = sp_file

        if sp_file is not None:
            if not os.path.isfile(sp_file):
                msg = (
                    f"Tokenizer model file '{sp_file}' not accessible; "
                    f"fallback to default {self.default_sp_file}."
                )
                logger.warning(msg)
            else:
                try:
                    sp_inst = SentencePieceProcessor(model_file=sp_file)
                except Exception as exc:
                    # Deliberately broad: any load failure falls back
                    # to the default model instead of aborting.
                    msg = (
                        f"Failed to load tokenizer model file '{sp_file}': {exc}; "
                        f"fallback to default {self.default_sp_file}."
                    )
                    logger.error(msg)
                else:
                    self._sp_inst = sp_inst
                    return

        self._sp_inst = SentencePieceProcessor(model_file=self.default_sp_file)

    def _parse_n_grams(self, n_grams: Optional[Union[int, tuple[int, int]]]) -> None:
        """
        Validates the n-gram setting and stores the window-size bounds.

        :param n_grams: an int ``n`` (sizes 1..n) or a pair ``(lo, hi)``
            (sizes lo..hi); a two-element list is accepted as well, so
            that a pair which was turned into a list by serialization
            can be read back.
        :type n_grams: Optional[Union[int, tuple[int, int]]]

        :raises ValueError: if the value is of the wrong type or shape,
            or if the bounds do not satisfy ``1 <= lo <= hi``.
        """
        if isinstance(n_grams, int):
            n_min, n_max = 1, n_grams
        elif isinstance(n_grams, (tuple, list)):
            # Validate before unpacking so a wrong length gets the same
            # explicit error as a wrong element type
            if len(n_grams) != 2 or not all(isinstance(n, int) for n in n_grams):
                msg = f"Invalid n-gram window size: {n_grams}."
                raise ValueError(msg)
            n_min, n_max = n_grams
            n_grams = (n_min, n_max)
        else:
            msg = f"Invalid n-gram window size: {n_grams}; expected int or tuple."
            raise ValueError(msg)

        # An empty or non-positive range would index no n-grams at all
        if not 1 <= n_min <= n_max:
            msg = f"Invalid n-gram window size: {n_grams}."
            raise ValueError(msg)

        self._n_grams = n_grams
        self._n_grams_min = n_min
        self._n_grams_max = n_max

    def _tokenize(self, texts: Union[str, list[str]]) -> list[list[tuple[str]]]:
        """
        Tokenize input text using SentencePiece model.

        The input text can either be a single string or a list of strings,
        such as a single user query or a group of raw documents. Each
        token sequence is expanded into its n-grams for every configured
        window size, smallest size first.

        :param texts: the input text to tokenize.
        :type texts: Union[str, list[str]]

        :return: the tokenized text, with n-grams augmented; one list
            per input text.
        :rtype: list[list[tuple[str]]]
        """
        if isinstance(texts, str):
            texts = [texts]

        sizes = range(self._n_grams_min, self._n_grams_max + 1)
        return [
            list(chain.from_iterable(_n_grams(tokens, n) for n in sizes))
            for tokens in self._sp_inst.encode(texts, out_type=str)
        ]

    def _compute_bm25plus(
        self,
        query: str,
        documents: list[Document],
    ) -> list[tuple[Document, float]]:
        """
        Calculate the BM25+ score of the query for the given documents.

        Each distinct query n-gram contributes once, regardless of how
        often it occurs in the query.

        :param query: the query to calculate the BM25+ score for.
        :type query: str
        :param documents: the pool of documents to calculate the BM25+
            score for; every document must be present in the index.
        :type documents: list[Document]

        :return: the BM25+ scores, in the order of ``documents``.
        :rtype: list[tuple[Document, float]]
        """
        n_docs = len(self._index)

        idf: dict[tuple[str], float] = {}
        for ng in self._tokenize(query)[0]:
            if ng not in idf:
                n_hit = self._freq_doc.get(ng, 0)
                idf[ng] = math.log(1 + (n_docs - n_hit + 0.5) / (n_hit + 0.5))

        avg_len = self._avg_doc_len
        sim = []
        for doc in documents:
            _, freq, doc_len = self._index[doc.id]

            # With a zero average length treat the document as
            # average-length instead of dividing by zero
            doc_len_scaled = doc_len / avg_len if avg_len > 0 else 1.0
            # Depends on the document only, so compute it once per document
            freq_damp = self.k * (1 + self.b * (doc_len_scaled - 1))

            scr = 0.0
            for token, idf_val in idf.items():
                freq_term = freq.get(token, 0)
                # An absent term contributes 0; skipping the division also
                # avoids 0/0 when the damping term itself is zero
                tf_val = freq_term / (freq_term + freq_damp) if freq_term else 0.0
                scr += idf_val * (tf_val + self.delta)

            sim.append((doc, scr))

        return sim

    def _retrieval(
        self,
        query: str,
        *,
        filters: Optional[dict[str, Any]] = None,
        top_k: Optional[int] = None,
    ) -> list[tuple[Document, float]]:
        """
        Retrieve documents from the store using the given query.

        :param query: the query to search for.
        :type query: str
        :param filters: the filters to apply to the document list.
        :type filters: Optional[dict[str, Any]]
        :param top_k: the number of documents to return; ``None``
            returns every matching document.
        :type top_k: Optional[int]

        :return: the top-k documents and corresponding sim score,
            highest score first.
        :rtype: list[tuple[Document, float]]
        """
        documents = self.filter_documents(filters)
        if not documents:
            return []

        sim = self._compute_bm25plus(query, documents)
        if top_k is None:
            return sorted(sim, key=lambda x: x[1], reverse=True)
        return heapq.nlargest(top_k, sim, key=lambda x: x[1])

    def count_documents(self) -> int:
        """
        Returns how many documents are present in the document store.

        :return: the number of documents in the store.
        :rtype: int
        """
        return len(self._index)

    def filter_documents(
        self, filters: Optional[dict[str, Any]] = None
    ) -> list[Document]:
        """
        Filter documents in the store using the given filters.

        :param filters: the filters to apply to the document list;
            ``None`` or an empty dict selects every document.
        :type filters: Optional[dict[str, Any]]

        :return: the list of documents that match the given filters.
        :rtype: list[Document]
        """
        if not filters:
            return [doc for doc, _, _ in self._index.values()]
        return [
            doc
            for doc, _, _ in self._index.values()
            if self._filter_func(filters, doc)
        ]

    def write_documents(
        self,
        documents: list[Document],
        policy: DuplicatePolicy = DuplicatePolicy.FAIL,
    ) -> int:
        """
        Writes (or overwrites) documents into the store.

        Documents are processed in order; if an error is raised, the
        documents handled before it remain written.

        :param documents: a list of documents.
        :type documents: list[Document]
        :param policy: documents with the same ID count as duplicates.
            When duplicates are met, the store can:
             - skip: keep the existing document and ignore the new one.
             - overwrite: remove the old document and write the new one.
             - fail: an error is raised
        :type policy: DuplicatePolicy, optional

        :raises ValueError: if an element of ``documents`` is not a
            ``Document``.
        :raises DuplicateDocumentError: raised on a duplicate
            document if `policy=DuplicatePolicy.FAIL`

        :return: Number of documents written.
        :rtype: int
        """
        n_written = 0
        for doc in documents:
            if not isinstance(doc, Document):
                msg = f"Expected document type, got '{doc}' of type '{type(doc)}'."
                raise ValueError(msg)

            if doc.id in self._index:
                if policy == DuplicatePolicy.SKIP:
                    continue
                elif policy == DuplicatePolicy.FAIL:
                    msg = f"Document with ID '{doc.id}' already exists in the store."
                    raise DuplicateDocumentError(msg)

                # Overwrite if exists; delete first to keep the statistics consistent
                logger.debug(
                    f"Document '{doc.id}' already exists in the store, overwriting."
                )
                self.delete_documents([doc.id])

            content = doc.content or ""
            if content == "":
                # Not every Document is guaranteed to expose `dataframe`
                dataframe = getattr(doc, "dataframe", None)
                if isinstance(dataframe, pd.DataFrame):
                    content = dataframe.astype(str).to_csv(index=False)

            tokens = self._tokenize(content)[0]
            freq = Counter(tokens)

            # Update the running mean with the document count from BEFORE
            # the insertion; using the post-insert count skews the average
            n_docs = len(self._index)
            self._avg_doc_len = (self._avg_doc_len * n_docs + len(tokens)) / (
                n_docs + 1
            )

            self._index[doc.id] = (doc, freq, len(tokens))
            # Keys are distinct, so each n-gram's document frequency grows by 1
            self._freq_doc.update(freq.keys())

            logger.debug(f"Document '{doc.id}' written to store.")
            n_written += 1

        return n_written

    def delete_documents(self, document_ids: list[str]) -> int:
        """
        Deletes all documents with matching document_ids.

        Fails with `MissingDocumentError` if no document with
        a given id is present in the store; documents deleted
        before the failing id stay deleted.

        :param document_ids: the ids of the documents to delete
        :type document_ids: list[str]

        :raises MissingDocumentError: raised on a missing document.

        :return: Number of documents deleted.
        :rtype: int
        """
        n_removal = 0
        for doc_id in document_ids:
            # Keep the `try` narrow so only the lookup is translated
            try:
                _, freq, doc_len = self._index.pop(doc_id)
            except KeyError as exc:
                msg = f"Document with ID '{doc_id}' not found, cannot delete it."
                raise MissingDocumentError(msg) from exc

            # Lower the document frequencies and drop exhausted n-grams so
            # the table does not keep entries for vanished vocabulary
            for ng in freq:
                remaining = self._freq_doc[ng] - 1
                if remaining > 0:
                    self._freq_doc[ng] = remaining
                else:
                    del self._freq_doc[ng]

            n_left = len(self._index)
            if n_left:
                total_len = self._avg_doc_len * (n_left + 1) - doc_len
                self._avg_doc_len = total_len / n_left
            else:
                self._avg_doc_len = 0.0

            logger.debug(f"Document '{doc_id}' deleted from store.")
            n_removal += 1

        return n_removal

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes this store to a dictionary.

        Only the constructor parameters are passed to the serializer;
        the stored documents are not among them.
        """
        return default_to_dict(
            self,
            k=self.k,
            b=self.b,
            delta=self.delta * (self.k + 1.0),  # Because we scaled it on init
            sp_file=self._sp_file,
            n_grams=self._n_grams,
            haystack_filter_logic=self._haystack_filter_logic,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BetterBM25DocumentStore":
        """Deserializes the store from a dictionary."""
        return default_from_dict(cls, data)
An in-memory document store intended to improve the default BM25 document store shipped with Haystack.
def __init__(
    self,
    *,
    k: float = 1.5,
    b: float = 0.75,
    delta: float = 1.0,
    sp_file: Optional[str] = None,
    n_grams: Union[int, tuple[int, int]] = 1,
    haystack_filter_logic: bool = True,
) -> None:
    """
    Set up an empty store with the given BM25+ and tokenizer settings.

    Unlike the default Haystack BM25 store, which rebuilds its index
    for every incoming query, this store keeps its statistics ready
    and adjusts them incrementally as documents are written or
    deleted. Text is tokenized with a SentencePiece model, which
    allows the tokenization to be adapted to domain-specific text.

    :param k: the k1 parameter in BM25+ formula.
    :type k: float, optional
    :param b: the b parameter in BM25+ formula.
    :type b: float, optional
    :param delta: the delta parameter in BM25+ formula.
    :type delta: float, optional
    :param sp_file: the SentencePiece model file to use for
        tokenization.
    :type sp_file: Optional[str], optional
    :param n_grams: the n-gram window size.
    :type n_grams: Optional[Union[int, tuple[int, int]]], optional
    :param haystack_filter_logic: Whether to use the Haystack
        filter logic or the one implemented in this store,
        which is more conservative.
    :type haystack_filter_logic: bool, optional
    """
    # Start from an empty index with neutral statistics
    self._index: dict[str, tuple[Document, dict[tuple[str], int], int]] = {}
    self._freq_doc: Counter = Counter()
    self._avg_doc_len: float = 0.0

    self.k, self.b = k, b

    # Pre-divide delta by `(k1 + 1)`: that factor is pulled out of the
    # term-frequency part of BM25+ and dropped, which leaves the
    # ranking untouched
    self.delta = delta / (self.k + 1.0)

    self._haystack_filter_logic = haystack_filter_logic
    if self._haystack_filter_logic:
        self._filter_func = document_matches_filter
    else:
        self._filter_func = apply_filters_to_document

    # Tokenizer and n-gram configuration
    self._parse_sp_file(sp_file=sp_file)
    self._parse_n_grams(n_grams=n_grams)
Creates a new BetterBM25DocumentStore instance.
An in-memory document store intended to improve the default BM25 document store shipped with Haystack. The default store recomputes the index for the entire document store for every incoming query, which is significantly inefficient. This store aims to improve the efficiency by pre-computing the index for all documents in the store and only doing incremental updates when documents are added or removed. Further, it leverages a SentencePiece model to tokenize the input text to allow more flexible and dynamic tokenization adapted to domain-specific text.
Parameters
- k: the k1 parameter in BM25+ formula.
- b: the b parameter in BM25+ formula.
- delta: the delta parameter in BM25+ formula.
- sp_file: the SentencePiece model file to use for tokenization.
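- n_grams: the n-gram window size; an int n uses all n-gram sizes from 1 through n, while a tuple (min, max) uses the sizes from min through max.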
- haystack_filter_logic: Whether to use the Haystack filter logic or the one implemented in this store, which is more conservative.
def count_documents(self) -> int:
    """
    Report the size of the store.

    :return: the number of documents in the store.
    :rtype: int
    """
    # One index entry exists per stored document
    n_docs = len(self._index)
    return n_docs
Returns how many documents are present in the document store.
Returns
the number of documents in the store.
def filter_documents(
    self, filters: Optional[dict[str, Any]] = None
) -> list[Document]:
    """
    Select the stored documents that satisfy the given filters.

    :param filters: the filters to apply to the document list;
        ``None`` or an empty dict selects every document.
    :type filters: Optional[dict[str, Any]]

    :return: the list of documents that match the given filters.
    :rtype: list[Document]
    """
    # Each index entry is (document, n-gram counts, length)
    stored = (entry[0] for entry in self._index.values())
    if not filters:
        return list(stored)

    matches = self._filter_func
    return [doc for doc in stored if matches(filters, doc)]
Filter documents in the store using the given filters.
Parameters
- filters: the filters to apply to the document list.
Returns
the list of documents that match the given filters.
def write_documents(
    self,
    documents: list[Document],
    policy: DuplicatePolicy = DuplicatePolicy.FAIL,
) -> int:
    """
    Writes (or overwrites) documents into the store.

    Documents are processed in order; if an error is raised, the
    documents handled before it remain written.

    :param documents: a list of documents.
    :type documents: list[Document]
    :param policy: documents with the same ID count as duplicates.
        When duplicates are met, the store can:
         - skip: keep the existing document and ignore the new one.
         - overwrite: remove the old document and write the new one.
         - fail: an error is raised
    :type policy: DuplicatePolicy, optional

    :raises ValueError: if an element of ``documents`` is not a
        ``Document``.
    :raises DuplicateDocumentError: raised on a duplicate
        document if `policy=DuplicatePolicy.FAIL`

    :return: Number of documents written.
    :rtype: int
    """
    n_written = 0
    for doc in documents:
        if not isinstance(doc, Document):
            msg = f"Expected document type, got '{doc}' of type '{type(doc)}'."
            raise ValueError(msg)

        if doc.id in self._index:
            if policy == DuplicatePolicy.SKIP:
                continue
            elif policy == DuplicatePolicy.FAIL:
                msg = f"Document with ID '{doc.id}' already exists in the store."
                raise DuplicateDocumentError(msg)

            # Overwrite if exists; delete first to keep the statistics consistent
            logger.debug(
                f"Document '{doc.id}' already exists in the store, overwriting."
            )
            self.delete_documents([doc.id])

        content = doc.content or ""
        if content == "":
            # Not every Document is guaranteed to expose `dataframe`
            dataframe = getattr(doc, "dataframe", None)
            if isinstance(dataframe, pd.DataFrame):
                content = dataframe.astype(str).to_csv(index=False)

        tokens = self._tokenize(content)[0]
        freq = Counter(tokens)

        # Update the running mean with the document count from BEFORE
        # the insertion; using the post-insert count skews the average
        # (the first document would get half its real length)
        n_docs = len(self._index)
        self._avg_doc_len = (self._avg_doc_len * n_docs + len(tokens)) / (
            n_docs + 1
        )

        self._index[doc.id] = (doc, freq, len(tokens))
        # Keys are distinct, so each n-gram's document frequency grows by 1
        self._freq_doc.update(freq.keys())

        logger.debug(f"Document '{doc.id}' written to store.")
        n_written += 1

    return n_written
Writes (or overwrites) documents into the store.
Parameters
- documents: a list of documents.
- policy: documents with the same ID count as duplicates.
When duplicates are met, the store can:
- skip: keep the existing document and ignore the new one.
- overwrite: remove the old document and write the new one.
- fail: an error is raised
Raises
- DuplicateDocumentError: raised on a duplicate document if policy=DuplicatePolicy.FAIL
Returns
Number of documents written.
def delete_documents(self, document_ids: list[str]) -> int:
    """
    Deletes all documents with matching document_ids.

    Fails with `MissingDocumentError` if no document with
    a given id is present in the store; documents deleted
    before the failing id stay deleted.

    :param document_ids: the ids of the documents to delete
    :type document_ids: list[str]

    :raises MissingDocumentError: raised on a missing document.

    :return: Number of documents deleted.
    :rtype: int
    """
    n_removal = 0
    for doc_id in document_ids:
        # Keep the `try` narrow so only the lookup is translated
        try:
            _, freq, doc_len = self._index.pop(doc_id)
        except KeyError as exc:
            msg = f"Document with ID '{doc_id}' not found, cannot delete it."
            raise MissingDocumentError(msg) from exc

        # Lower the document frequencies and drop exhausted n-grams so
        # the table does not keep entries for vanished vocabulary
        for ng in freq:
            remaining = self._freq_doc[ng] - 1
            if remaining > 0:
                self._freq_doc[ng] = remaining
            else:
                del self._freq_doc[ng]

        # Remove this document's length from the running mean; the
        # index has already shrunk, hence the `+ 1` for the old count
        n_left = len(self._index)
        if n_left:
            total_len = self._avg_doc_len * (n_left + 1) - doc_len
            self._avg_doc_len = total_len / n_left
        else:
            self._avg_doc_len = 0.0

        logger.debug(f"Document '{doc_id}' deleted from store.")
        n_removal += 1

    return n_removal
Deletes all documents with matching document_ids.
Fails with MissingDocumentError if no document with
this id is present in the store.
Parameters
- document_ids: the IDs of the documents to delete
Raises
- MissingDocumentError: raised when a document to delete is missing from the store.
Returns
Number of documents deleted.
def to_dict(self) -> dict[str, Any]:
    """
    Serialize this store's constructor parameters to a dictionary.

    :return: the dictionary produced by ``default_to_dict``.
    :rtype: dict[str, Any]
    """
    init_params = {
        "k": self.k,
        "b": self.b,
        # Undo the `(k + 1)` scaling that was applied to delta on init
        "delta": self.delta * (self.k + 1.0),
        "sp_file": self._sp_file,
        "n_grams": self._n_grams,
        "haystack_filter_logic": self._haystack_filter_logic,
    }
    return default_to_dict(self, **init_params)
Serializes this store to a dictionary.