bbm25_haystack.bbm25_store

  1# SPDX-FileCopyrightText: 2024-present Yuxuan Wang <wangy49@seas.upenn.edu>
  2#
  3# SPDX-License-Identifier: Apache-2.0
  4import heapq
  5import math
  6import os
  7from collections import Counter, deque
  8from collections.abc import Iterable
  9from itertools import chain
 10from typing import Any, Final, Optional, Union
 11
 12import pandas as pd
 13from haystack import Document, default_from_dict, default_to_dict, logging
 14from haystack.document_stores.errors import (
 15    DuplicateDocumentError,
 16    MissingDocumentError,
 17)
 18from haystack.document_stores.types import DuplicatePolicy
 19from haystack.utils.filters import document_matches_filter
 20from sentencepiece import SentencePieceProcessor  # type: ignore
 21
 22from bbm25_haystack.filters import apply_filters_to_document
 23
 24logger = logging.getLogger(__name__)
 25
 26
 27def _n_grams(seq: Iterable[str], n: int):
 28    """
 29    Returns a sliding window (of width n) over data from the iterable.
 30
 31    :param seq: the input token sequence.
 32    :type seq: Iterable[str]
 33    :param n: the window size.
 34    :type n: int
 35
 36    :return: the n-gram window generator.
 37    :rtype: Generator[tuple[str], None, None]
 38    """
 39    it = iter(seq)
 40    wd = deque((next(it, None) for _ in range(n)), maxlen=n)
 41
 42    yield tuple(wd)
 43    for el in it:
 44        wd.append(el)
 45        yield tuple(wd)
 46
 47
 48class BetterBM25DocumentStore:
 49    """
 50    An in-memory document store intended to improve the default BM25 document
 51    store shipped with Haystack.
 52    """
 53
 54    default_sp_file: Final = os.path.join(
 55        os.path.dirname(os.path.abspath(__file__)), "default.model"
 56    )
 57
 58    def __init__(
 59        self,
 60        *,
 61        k: float = 1.5,
 62        b: float = 0.75,
 63        delta: float = 1.0,
 64        sp_file: Optional[str] = None,
 65        n_grams: Union[int, tuple[int, int]] = 1,
 66        haystack_filter_logic: bool = True,
 67    ) -> None:
 68        """
 69        Creates a new BetterBM25DocumentStore instance.
 70
 71        An in-memory document store intended to improve the default
 72        BM25 document store shipped with Haystack. The default store
 73        recompute the index for the entire document store for every
 74        in-coming query, which is significantly inefficient. This
 75        store aims to improve the efficiency by pre-computing the
 76        index for all documents in the store and only do incremental
 77        updates when new documents are added or removed. Further, it
 78        leverages a SentencePiece model to tokenize the input text
 79        to allow more flexible and dynamic tokenization adapted to
 80        domain-specific text.
 81
 82        :param k: the k1 parameter in BM25+ formula.
 83        :type k: float, optional
 84        :param b: the b parameter in BM25+ formula.
 85        :type b: float, optional
 86        :param delta: the delta parameter in BM25+ formula.
 87        :type delta: float, optional
 88        :param sp_file: the SentencePiece model file to use for
 89            tokenization.
 90        :type sp_file: Optional[str], optional
 91        :param n_grams: the n-gram window size.
 92        :type n_grams: Optional[Union[int, tuple[int, int]]], optional
 93        :param haystack_filter_logic: Whether to use the Haystack
 94            filter logic or the one implemented in this store,
 95            which is more conservative.
 96        :type haystack_filter_logic: bool, optional
 97        """
 98        self.k = k
 99        self.b = b
100
101        # Adjust the delta value so that we can bring the `(k1 + 1)`
102        # term out of the 'term frequency' term in BM25+ formula and
103        # delete it; this will not affect the ranking
104        self.delta = delta / (self.k + 1.0)
105
106        self._parse_sp_file(sp_file=sp_file)
107        self._parse_n_grams(n_grams=n_grams)
108
109        self._haystack_filter_logic = haystack_filter_logic
110        self._filter_func = (
111            document_matches_filter
112            if self._haystack_filter_logic
113            else apply_filters_to_document
114        )
115
116        self._avg_doc_len: float = 0.0
117        self._freq_doc: Counter = Counter()
118        self._index: dict[str, tuple[Document, dict[tuple[str], int], int]] = {}
119
120    def _parse_sp_file(self, sp_file: Optional[str]) -> None:
121        self._sp_file = sp_file
122
123        if sp_file is None:
124            self._sp_inst = SentencePieceProcessor(model_file=self.default_sp_file)
125            return
126
127        if not os.path.exists(sp_file) or not os.path.isfile(sp_file):
128            msg = (
129                f"Tokenizer model file '{sp_file}' not accessible; "
130                f"fallback to default {self.default_sp_file}."
131            )
132            logger.warn(msg)
133            self._sp_inst = SentencePieceProcessor(model_file=self.default_sp_file)
134            return
135
136        try:
137            self._sp_inst = SentencePieceProcessor(model_file=sp_file)
138        except Exception as exc:
139            msg = (
140                f"Failed to load tokenizer model file '{sp_file}': {exc}; "
141                f"fallback to default {self.default_sp_file}."
142            )
143            logger.error(msg)
144            self._sp_inst = SentencePieceProcessor(model_file=self.default_sp_file)
145
146    def _parse_n_grams(self, n_grams: Optional[Union[int, tuple[int, int]]]) -> None:
147        self._n_grams = n_grams
148
149        if isinstance(n_grams, int):
150            self._n_grams_min = 1
151            self._n_grams_max = n_grams
152            return
153
154        if isinstance(n_grams, tuple):
155            self._n_grams_min, self._n_grams_max = n_grams
156            if not all(isinstance(n, int) for n in n_grams):
157                msg = f"Invalid n-gram window size: {n_grams}."
158                raise ValueError(msg)
159            return
160
161        msg = f"Invalid n-gram window size: {n_grams}; expected int or tuple."
162        raise ValueError(msg)
163
164    def _tokenize(self, texts: Union[str, list[str]]) -> list[list[tuple[str]]]:
165        """
166        Tokenize input text using SentencePiece model.
167
168        The input text can either be a single string or a list of strings,
169        such as a single user query or a group of raw document. The tokenized
170        text will be augmented into set of n-grams based.
171
172        :param texts: the input text to tokenize.
173        :type texts: Union[str, list[str]]
174
175        :return: the tokenized text, with n-grams augmented.
176        :rtype: list[list[tuple[str]]]
177        """
178
179        def _augment_to_n_grams(tokens: list[str]) -> list[tuple[str]]:
180            it = (
181                _n_grams(tokens, n)
182                for n in range(self._n_grams_min, self._n_grams_max + 1)
183            )
184            return list(chain(*it))
185
186        if isinstance(texts, str):
187            texts = [texts]
188        return [
189            _augment_to_n_grams(tokens)
190            for tokens in self._sp_inst.encode(texts, out_type=str)
191        ]
192
193    def _compute_bm25plus(
194        self,
195        query: str,
196        documents: list[Document],
197    ) -> list[tuple[Document, float]]:
198        """
199        Calculate the BM25+ score for all documents in this index.
200
201        :param query: the query to calculate the BM25+ score for.
202        :type query: str
203        :param documents: the pool of documents to calculate the BM25+ score for.
204        :type documents: list[Document]
205
206        :return: the BM25+ scores for all documents.
207        :rtype: list[tuple[Document, float]]
208        """
209        cnt = lambda ng: self._freq_doc.get(ng, 0)
210        idf = {
211            ng: math.log(1 + (len(self._index) - cnt(ng) + 0.5) / (cnt(ng) + 0.5))
212            for ng in self._tokenize(query)[0]
213        }
214
215        sim = []
216        for doc in documents:
217            _, freq, doc_len = self._index[doc.id]
218            doc_len_scaled = doc_len / self._avg_doc_len
219
220            scr = 0.0
221            for token, idf_val in idf.items():
222                freq_term = freq.get(token, 0.0)
223                freq_damp = self.k * (1 + self.b * (doc_len_scaled - 1))
224
225                tf_val = freq_term / (freq_term + freq_damp) + self.delta
226                scr += idf_val * tf_val
227
228            sim.append((doc, scr))
229
230        return sim
231
232    def _retrieval(
233        self,
234        query: str,
235        *,
236        filters: Optional[dict[str, Any]] = None,
237        top_k: Optional[int] = None,
238    ) -> list[tuple[Document, float]]:
239        """
240        Retrieve documents from the store using the given query.
241
242        :param query: the query to search for.
243        :type query: str
244        :param filters: the filters to apply to the document list.
245        :type filters: Optional[dict[str, Any]]
246        :param top_k: the number of documents to return.
247        :type top_k: int
248
249        :return: the top-k documents and corresponding sim score.
250        :rtype: list[tuple[Document, float]]
251        """
252        documents = self.filter_documents(filters)
253        if not documents:
254            return []
255
256        sim = self._compute_bm25plus(query, documents)
257        if top_k is None:
258            return sorted(sim, key=lambda x: x[1], reverse=True)
259        return heapq.nlargest(top_k, sim, key=lambda x: x[1])
260
261    def count_documents(self) -> int:
262        """
263        Returns how many documents are present in the document store.
264
265        :return: the number of documents in the store.
266        :rtype: int
267        """
268        return len(self._index)
269
270    def filter_documents(
271        self, filters: Optional[dict[str, Any]] = None
272    ) -> list[Document]:
273        """
274        Filter documents in the store using the given filters.
275
276        :param filters: the filters to apply to the document list.
277        :type filters: Optional[dict[str, Any]]
278
279        :return: the list of documents that match the given filters.
280        :rtype: list[Document]
281        """
282        if filters is None or not filters:
283            return [doc for doc, _, _ in self._index.values()]
284        return [
285            doc
286            for doc, _, _ in self._index.values()
287            if self._filter_func(filters, doc)
288        ]
289
290    def write_documents(
291        self,
292        documents: list[Document],
293        policy: DuplicatePolicy = DuplicatePolicy.FAIL,
294    ) -> int:
295        """
296        Writes (or overwrites) documents into the store.
297
298        :param documents: a list of documents.
299        :type documents: list[Document]
300        :param policy: documents with the same ID count as duplicates.
301            When duplicates are met, the store can:
302             - skip: keep the existing document and ignore the new one.
303             - overwrite: remove the old document and write the new one.
304             - fail: an error is raised
305        :type policy: DuplicatePolicy, optional
306
307        :raises DuplicateDocumentError: Exception trigger on duplicate
308            document if `policy=DuplicatePolicy.FAIL`
309
310        :return: Number of documents written.
311        :rtype: int
312        """
313        n_written = 0
314        for doc in documents:
315            if not isinstance(doc, Document):
316                msg = f"Expected document type, got '{doc}' of type '{type(doc)}'."
317                raise ValueError(msg)
318
319            if doc.id in self._index.keys():
320                if policy == DuplicatePolicy.SKIP:
321                    continue
322                elif policy == DuplicatePolicy.FAIL:
323                    msg = f"Document with ID '{doc.id}' already exists in the store."
324                    raise DuplicateDocumentError(msg)
325
326                # Overwrite if exists; delete first to keep the statistics consistent
327                logger.debug(
328                    f"Document '{doc.id}' already exists in the store, overwriting."
329                )
330                self.delete_documents([doc.id])
331
332            content = doc.content or ""
333            if content == "" and isinstance(doc.dataframe, pd.DataFrame):
334                content = doc.dataframe.astype(str).to_csv(index=False)
335
336            tokens = self._tokenize(content)[0]
337
338            self._index[doc.id] = (doc, Counter(tokens), len(tokens))
339            self._freq_doc.update(set(tokens))
340            self._avg_doc_len = (
341                len(tokens) + self._avg_doc_len * len(self._index)
342            ) / (len(self._index) + 1)
343
344            logger.debug(f"Document '{doc.id}' written to store.")
345            n_written += 1
346
347        return n_written
348
349    def delete_documents(self, document_ids: list[str]) -> int:
350        """
351        Deletes all documents with a matching document_ids.
352
353        Fails with `MissingDocumentError` if no document with
354        this id is present in the store.
355
356        :param object_ids: the object_ids to delete
357        :type object_ids: list[str]
358
359        :raises MissingDocumentError: trigger on missing document.
360
361        :return: Number of documents deleted.
362        :rtype: int
363        """
364        n_removal = 0
365        for doc_id in document_ids:
366            try:
367                _, freq, doc_len = self._index.pop(doc_id)
368                self._freq_doc.subtract(Counter(freq.keys()))
369                try:
370                    self._avg_doc_len = (
371                        self._avg_doc_len * (len(self._index) + 1) - doc_len
372                    ) / len(self._index)
373                except ZeroDivisionError:
374                    self._avg_doc_len = 0
375
376                logger.debug(f"Document '{doc_id}' deleted from store.")
377                n_removal += 1
378            except KeyError as exc:
379                msg = f"Document with ID '{doc_id}' not found, cannot delete it."
380                raise MissingDocumentError(msg) from exc
381
382        return n_removal
383
384    def to_dict(self) -> dict[str, Any]:
385        """Serializes this store to a dictionary."""
386        return default_to_dict(
387            self,
388            k=self.k,
389            b=self.b,
390            delta=self.delta * (self.k + 1.0),  # Because we scaled it on init
391            sp_file=self._sp_file,
392            n_grams=self._n_grams,
393            haystack_filter_logic=self._haystack_filter_logic,
394        )
395
396    @classmethod
397    def from_dict(cls, data: dict[str, Any]) -> "BetterBM25DocumentStore":
398        """Deserializes the store from a dictionary."""
399        return default_from_dict(cls, data)
logger = <Logger bbm25_haystack.bbm25_store (WARNING)>
class BetterBM25DocumentStore:
 49class BetterBM25DocumentStore:
 50    """
 51    An in-memory document store intended to improve the default BM25 document
 52    store shipped with Haystack.
 53    """
 54
 55    default_sp_file: Final = os.path.join(
 56        os.path.dirname(os.path.abspath(__file__)), "default.model"
 57    )
 58
 59    def __init__(
 60        self,
 61        *,
 62        k: float = 1.5,
 63        b: float = 0.75,
 64        delta: float = 1.0,
 65        sp_file: Optional[str] = None,
 66        n_grams: Union[int, tuple[int, int]] = 1,
 67        haystack_filter_logic: bool = True,
 68    ) -> None:
 69        """
 70        Creates a new BetterBM25DocumentStore instance.
 71
 72        An in-memory document store intended to improve the default
 73        BM25 document store shipped with Haystack. The default store
 74        recompute the index for the entire document store for every
 75        in-coming query, which is significantly inefficient. This
 76        store aims to improve the efficiency by pre-computing the
 77        index for all documents in the store and only do incremental
 78        updates when new documents are added or removed. Further, it
 79        leverages a SentencePiece model to tokenize the input text
 80        to allow more flexible and dynamic tokenization adapted to
 81        domain-specific text.
 82
 83        :param k: the k1 parameter in BM25+ formula.
 84        :type k: float, optional
 85        :param b: the b parameter in BM25+ formula.
 86        :type b: float, optional
 87        :param delta: the delta parameter in BM25+ formula.
 88        :type delta: float, optional
 89        :param sp_file: the SentencePiece model file to use for
 90            tokenization.
 91        :type sp_file: Optional[str], optional
 92        :param n_grams: the n-gram window size.
 93        :type n_grams: Optional[Union[int, tuple[int, int]]], optional
 94        :param haystack_filter_logic: Whether to use the Haystack
 95            filter logic or the one implemented in this store,
 96            which is more conservative.
 97        :type haystack_filter_logic: bool, optional
 98        """
 99        self.k = k
100        self.b = b
101
102        # Adjust the delta value so that we can bring the `(k1 + 1)`
103        # term out of the 'term frequency' term in BM25+ formula and
104        # delete it; this will not affect the ranking
105        self.delta = delta / (self.k + 1.0)
106
107        self._parse_sp_file(sp_file=sp_file)
108        self._parse_n_grams(n_grams=n_grams)
109
110        self._haystack_filter_logic = haystack_filter_logic
111        self._filter_func = (
112            document_matches_filter
113            if self._haystack_filter_logic
114            else apply_filters_to_document
115        )
116
117        self._avg_doc_len: float = 0.0
118        self._freq_doc: Counter = Counter()
119        self._index: dict[str, tuple[Document, dict[tuple[str], int], int]] = {}
120
121    def _parse_sp_file(self, sp_file: Optional[str]) -> None:
122        self._sp_file = sp_file
123
124        if sp_file is None:
125            self._sp_inst = SentencePieceProcessor(model_file=self.default_sp_file)
126            return
127
128        if not os.path.exists(sp_file) or not os.path.isfile(sp_file):
129            msg = (
130                f"Tokenizer model file '{sp_file}' not accessible; "
131                f"fallback to default {self.default_sp_file}."
132            )
133            logger.warn(msg)
134            self._sp_inst = SentencePieceProcessor(model_file=self.default_sp_file)
135            return
136
137        try:
138            self._sp_inst = SentencePieceProcessor(model_file=sp_file)
139        except Exception as exc:
140            msg = (
141                f"Failed to load tokenizer model file '{sp_file}': {exc}; "
142                f"fallback to default {self.default_sp_file}."
143            )
144            logger.error(msg)
145            self._sp_inst = SentencePieceProcessor(model_file=self.default_sp_file)
146
147    def _parse_n_grams(self, n_grams: Optional[Union[int, tuple[int, int]]]) -> None:
148        self._n_grams = n_grams
149
150        if isinstance(n_grams, int):
151            self._n_grams_min = 1
152            self._n_grams_max = n_grams
153            return
154
155        if isinstance(n_grams, tuple):
156            self._n_grams_min, self._n_grams_max = n_grams
157            if not all(isinstance(n, int) for n in n_grams):
158                msg = f"Invalid n-gram window size: {n_grams}."
159                raise ValueError(msg)
160            return
161
162        msg = f"Invalid n-gram window size: {n_grams}; expected int or tuple."
163        raise ValueError(msg)
164
165    def _tokenize(self, texts: Union[str, list[str]]) -> list[list[tuple[str]]]:
166        """
167        Tokenize input text using SentencePiece model.
168
169        The input text can either be a single string or a list of strings,
170        such as a single user query or a group of raw document. The tokenized
171        text will be augmented into set of n-grams based.
172
173        :param texts: the input text to tokenize.
174        :type texts: Union[str, list[str]]
175
176        :return: the tokenized text, with n-grams augmented.
177        :rtype: list[list[tuple[str]]]
178        """
179
180        def _augment_to_n_grams(tokens: list[str]) -> list[tuple[str]]:
181            it = (
182                _n_grams(tokens, n)
183                for n in range(self._n_grams_min, self._n_grams_max + 1)
184            )
185            return list(chain(*it))
186
187        if isinstance(texts, str):
188            texts = [texts]
189        return [
190            _augment_to_n_grams(tokens)
191            for tokens in self._sp_inst.encode(texts, out_type=str)
192        ]
193
194    def _compute_bm25plus(
195        self,
196        query: str,
197        documents: list[Document],
198    ) -> list[tuple[Document, float]]:
199        """
200        Calculate the BM25+ score for all documents in this index.
201
202        :param query: the query to calculate the BM25+ score for.
203        :type query: str
204        :param documents: the pool of documents to calculate the BM25+ score for.
205        :type documents: list[Document]
206
207        :return: the BM25+ scores for all documents.
208        :rtype: list[tuple[Document, float]]
209        """
210        cnt = lambda ng: self._freq_doc.get(ng, 0)
211        idf = {
212            ng: math.log(1 + (len(self._index) - cnt(ng) + 0.5) / (cnt(ng) + 0.5))
213            for ng in self._tokenize(query)[0]
214        }
215
216        sim = []
217        for doc in documents:
218            _, freq, doc_len = self._index[doc.id]
219            doc_len_scaled = doc_len / self._avg_doc_len
220
221            scr = 0.0
222            for token, idf_val in idf.items():
223                freq_term = freq.get(token, 0.0)
224                freq_damp = self.k * (1 + self.b * (doc_len_scaled - 1))
225
226                tf_val = freq_term / (freq_term + freq_damp) + self.delta
227                scr += idf_val * tf_val
228
229            sim.append((doc, scr))
230
231        return sim
232
233    def _retrieval(
234        self,
235        query: str,
236        *,
237        filters: Optional[dict[str, Any]] = None,
238        top_k: Optional[int] = None,
239    ) -> list[tuple[Document, float]]:
240        """
241        Retrieve documents from the store using the given query.
242
243        :param query: the query to search for.
244        :type query: str
245        :param filters: the filters to apply to the document list.
246        :type filters: Optional[dict[str, Any]]
247        :param top_k: the number of documents to return.
248        :type top_k: int
249
250        :return: the top-k documents and corresponding sim score.
251        :rtype: list[tuple[Document, float]]
252        """
253        documents = self.filter_documents(filters)
254        if not documents:
255            return []
256
257        sim = self._compute_bm25plus(query, documents)
258        if top_k is None:
259            return sorted(sim, key=lambda x: x[1], reverse=True)
260        return heapq.nlargest(top_k, sim, key=lambda x: x[1])
261
262    def count_documents(self) -> int:
263        """
264        Returns how many documents are present in the document store.
265
266        :return: the number of documents in the store.
267        :rtype: int
268        """
269        return len(self._index)
270
271    def filter_documents(
272        self, filters: Optional[dict[str, Any]] = None
273    ) -> list[Document]:
274        """
275        Filter documents in the store using the given filters.
276
277        :param filters: the filters to apply to the document list.
278        :type filters: Optional[dict[str, Any]]
279
280        :return: the list of documents that match the given filters.
281        :rtype: list[Document]
282        """
283        if filters is None or not filters:
284            return [doc for doc, _, _ in self._index.values()]
285        return [
286            doc
287            for doc, _, _ in self._index.values()
288            if self._filter_func(filters, doc)
289        ]
290
291    def write_documents(
292        self,
293        documents: list[Document],
294        policy: DuplicatePolicy = DuplicatePolicy.FAIL,
295    ) -> int:
296        """
297        Writes (or overwrites) documents into the store.
298
299        :param documents: a list of documents.
300        :type documents: list[Document]
301        :param policy: documents with the same ID count as duplicates.
302            When duplicates are met, the store can:
303             - skip: keep the existing document and ignore the new one.
304             - overwrite: remove the old document and write the new one.
305             - fail: an error is raised
306        :type policy: DuplicatePolicy, optional
307
308        :raises DuplicateDocumentError: Exception trigger on duplicate
309            document if `policy=DuplicatePolicy.FAIL`
310
311        :return: Number of documents written.
312        :rtype: int
313        """
314        n_written = 0
315        for doc in documents:
316            if not isinstance(doc, Document):
317                msg = f"Expected document type, got '{doc}' of type '{type(doc)}'."
318                raise ValueError(msg)
319
320            if doc.id in self._index.keys():
321                if policy == DuplicatePolicy.SKIP:
322                    continue
323                elif policy == DuplicatePolicy.FAIL:
324                    msg = f"Document with ID '{doc.id}' already exists in the store."
325                    raise DuplicateDocumentError(msg)
326
327                # Overwrite if exists; delete first to keep the statistics consistent
328                logger.debug(
329                    f"Document '{doc.id}' already exists in the store, overwriting."
330                )
331                self.delete_documents([doc.id])
332
333            content = doc.content or ""
334            if content == "" and isinstance(doc.dataframe, pd.DataFrame):
335                content = doc.dataframe.astype(str).to_csv(index=False)
336
337            tokens = self._tokenize(content)[0]
338
339            self._index[doc.id] = (doc, Counter(tokens), len(tokens))
340            self._freq_doc.update(set(tokens))
341            self._avg_doc_len = (
342                len(tokens) + self._avg_doc_len * len(self._index)
343            ) / (len(self._index) + 1)
344
345            logger.debug(f"Document '{doc.id}' written to store.")
346            n_written += 1
347
348        return n_written
349
350    def delete_documents(self, document_ids: list[str]) -> int:
351        """
352        Deletes all documents with a matching document_ids.
353
354        Fails with `MissingDocumentError` if no document with
355        this id is present in the store.
356
357        :param object_ids: the object_ids to delete
358        :type object_ids: list[str]
359
360        :raises MissingDocumentError: trigger on missing document.
361
362        :return: Number of documents deleted.
363        :rtype: int
364        """
365        n_removal = 0
366        for doc_id in document_ids:
367            try:
368                _, freq, doc_len = self._index.pop(doc_id)
369                self._freq_doc.subtract(Counter(freq.keys()))
370                try:
371                    self._avg_doc_len = (
372                        self._avg_doc_len * (len(self._index) + 1) - doc_len
373                    ) / len(self._index)
374                except ZeroDivisionError:
375                    self._avg_doc_len = 0
376
377                logger.debug(f"Document '{doc_id}' deleted from store.")
378                n_removal += 1
379            except KeyError as exc:
380                msg = f"Document with ID '{doc_id}' not found, cannot delete it."
381                raise MissingDocumentError(msg) from exc
382
383        return n_removal
384
385    def to_dict(self) -> dict[str, Any]:
386        """Serializes this store to a dictionary."""
387        return default_to_dict(
388            self,
389            k=self.k,
390            b=self.b,
391            delta=self.delta * (self.k + 1.0),  # Because we scaled it on init
392            sp_file=self._sp_file,
393            n_grams=self._n_grams,
394            haystack_filter_logic=self._haystack_filter_logic,
395        )
396
397    @classmethod
398    def from_dict(cls, data: dict[str, Any]) -> "BetterBM25DocumentStore":
399        """Deserializes the store from a dictionary."""
400        return default_from_dict(cls, data)

An in-memory document store intended to improve the default BM25 document store shipped with Haystack.

BetterBM25DocumentStore( *, k: float = 1.5, b: float = 0.75, delta: float = 1.0, sp_file: Optional[str] = None, n_grams: Union[int, tuple[int, int]] = 1, haystack_filter_logic: bool = True)
 59    def __init__(
 60        self,
 61        *,
 62        k: float = 1.5,
 63        b: float = 0.75,
 64        delta: float = 1.0,
 65        sp_file: Optional[str] = None,
 66        n_grams: Union[int, tuple[int, int]] = 1,
 67        haystack_filter_logic: bool = True,
 68    ) -> None:
 69        """
 70        Creates a new BetterBM25DocumentStore instance.
 71
 72        An in-memory document store intended to improve the default
 73        BM25 document store shipped with Haystack. The default store
 74        recompute the index for the entire document store for every
 75        in-coming query, which is significantly inefficient. This
 76        store aims to improve the efficiency by pre-computing the
 77        index for all documents in the store and only do incremental
 78        updates when new documents are added or removed. Further, it
 79        leverages a SentencePiece model to tokenize the input text
 80        to allow more flexible and dynamic tokenization adapted to
 81        domain-specific text.
 82
 83        :param k: the k1 parameter in BM25+ formula.
 84        :type k: float, optional
 85        :param b: the b parameter in BM25+ formula.
 86        :type b: float, optional
 87        :param delta: the delta parameter in BM25+ formula.
 88        :type delta: float, optional
 89        :param sp_file: the SentencePiece model file to use for
 90            tokenization.
 91        :type sp_file: Optional[str], optional
 92        :param n_grams: the n-gram window size.
 93        :type n_grams: Optional[Union[int, tuple[int, int]]], optional
 94        :param haystack_filter_logic: Whether to use the Haystack
 95            filter logic or the one implemented in this store,
 96            which is more conservative.
 97        :type haystack_filter_logic: bool, optional
 98        """
 99        self.k = k
100        self.b = b
101
102        # Adjust the delta value so that we can bring the `(k1 + 1)`
103        # term out of the 'term frequency' term in BM25+ formula and
104        # delete it; this will not affect the ranking
105        self.delta = delta / (self.k + 1.0)
106
107        self._parse_sp_file(sp_file=sp_file)
108        self._parse_n_grams(n_grams=n_grams)
109
110        self._haystack_filter_logic = haystack_filter_logic
111        self._filter_func = (
112            document_matches_filter
113            if self._haystack_filter_logic
114            else apply_filters_to_document
115        )
116
117        self._avg_doc_len: float = 0.0
118        self._freq_doc: Counter = Counter()
119        self._index: dict[str, tuple[Document, dict[tuple[str], int], int]] = {}

Creates a new BetterBM25DocumentStore instance.

An in-memory document store intended to improve the default BM25 document store shipped with Haystack. The default store recomputes the index for the entire document store for every incoming query, which is highly inefficient. This store instead pre-computes the index for all documents in the store and only performs incremental updates when documents are added or removed. Further, it leverages a SentencePiece model to tokenize the input text, allowing more flexible and dynamic tokenization adapted to domain-specific text.

Parameters
  • k: the k1 parameter in BM25+ formula.
  • b: the b parameter in BM25+ formula.
  • delta: the delta parameter in BM25+ formula.
  • sp_file: the SentencePiece model file to use for tokenization.
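  • n_grams: the n-gram window size; an int n selects all window sizes from 1 to n, and a (min, max) tuple selects all sizes from min to max inclusive.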
  • haystack_filter_logic: Whether to use the Haystack filter logic or the one implemented in this store, which is more conservative.
default_sp_file: Final = '/home/catcat/projects/bbm25-haystack/src/bbm25_haystack/default.model'
k
b
delta
def count_documents(self) -> int:
def count_documents(self) -> int:
    """
    Report the size of the store.

    :return: how many documents are currently indexed.
    :rtype: int
    """
    # The index holds exactly one entry per stored document, keyed by ID.
    n_docs = len(self._index)
    return n_docs

Returns how many documents are present in the document store.

Returns

the number of documents in the store.

def filter_documents( self, filters: Optional[dict[str, Any]] = None) -> list[haystack.dataclasses.document.Document]:
def filter_documents(
    self, filters: Optional[dict[str, Any]] = None
) -> list[Document]:
    """
    Select the stored documents that satisfy the given filters.

    A missing or empty ``filters`` selects every stored document. Otherwise
    each document is tested with the filter function chosen at construction
    time (Haystack's ``document_matches_filter`` or this package's
    ``apply_filters_to_document``, depending on ``haystack_filter_logic``).

    :param filters: the filters to apply to the document list.
    :type filters: Optional[dict[str, Any]]

    :return: the list of documents that match the given filters.
    :rtype: list[Document]
    """
    stored_docs = (doc for doc, _, _ in self._index.values())
    if not filters:
        # Both ``None`` and ``{}`` mean "no filtering".
        return list(stored_docs)
    return [doc for doc in stored_docs if self._filter_func(filters, doc)]

Filter documents in the store using the given filters.

Parameters
  • filters: the filters to apply to the document list.
Returns

the list of documents that match the given filters.

def write_documents( self, documents: list[haystack.dataclasses.document.Document], policy: haystack.document_stores.types.policy.DuplicatePolicy = <DuplicatePolicy.FAIL: 'fail'>) -> int:
def write_documents(
    self,
    documents: list[Document],
    policy: DuplicatePolicy = DuplicatePolicy.FAIL,
) -> int:
    """
    Writes (or overwrites) documents into the store.

    Each written document is tokenized once. Its term counts and length are
    stored in the index, and the store-wide document-frequency counter and
    average document length are updated incrementally.

    :param documents: a list of documents.
    :type documents: list[Document]
    :param policy: documents with the same ID count as duplicates.
        When duplicates are met, the store can:
         - skip: keep the existing document and ignore the new one.
         - overwrite: remove the old document and write the new one.
         - fail: an error is raised
    :type policy: DuplicatePolicy, optional

    :raises ValueError: if an element of ``documents`` is not a ``Document``.
    :raises DuplicateDocumentError: raised on a duplicate document if
        `policy=DuplicatePolicy.FAIL`

    :return: Number of documents written.
    :rtype: int
    """
    n_written = 0
    for doc in documents:
        if not isinstance(doc, Document):
            msg = f"Expected document type, got '{doc}' of type '{type(doc)}'."
            raise ValueError(msg)

        if doc.id in self._index:
            if policy == DuplicatePolicy.SKIP:
                continue
            elif policy == DuplicatePolicy.FAIL:
                msg = f"Document with ID '{doc.id}' already exists in the store."
                raise DuplicateDocumentError(msg)

            # Every other policy overwrites. NOTE(review): this includes
            # DuplicatePolicy.NONE, which Haystack's in-memory store treats
            # as FAIL -- confirm which behavior is intended here.
            # Delete first so the statistics stay consistent.
            logger.debug(
                f"Document '{doc.id}' already exists in the store, overwriting."
            )
            self.delete_documents([doc.id])

        content = doc.content or ""
        if content == "" and isinstance(doc.dataframe, pd.DataFrame):
            content = doc.dataframe.astype(str).to_csv(index=False)

        tokens = self._tokenize(content)[0]
        doc_len = len(tokens)

        self._index[doc.id] = (doc, Counter(tokens), doc_len)
        self._freq_doc.update(set(tokens))

        # Incremental mean. ``self._index`` already contains the new
        # document, so ``n_docs`` is the count *including* it:
        #   new_avg = (old_avg * (n_docs - 1) + doc_len) / n_docs
        # This is the exact inverse of the update in ``delete_documents``.
        n_docs = len(self._index)
        self._avg_doc_len += (doc_len - self._avg_doc_len) / n_docs

        logger.debug(f"Document '{doc.id}' written to store.")
        n_written += 1

    return n_written

Writes (or overwrites) documents into the store.

Parameters
  • documents: a list of documents.
  • policy: documents with the same ID count as duplicates. When duplicates are met, the store can:
    • skip: keep the existing document and ignore the new one.
    • overwrite: remove the old document and write the new one.
    • fail: an error is raised
Raises
  • DuplicateDocumentError: raised on a duplicate document if policy=DuplicatePolicy.FAIL
Returns

Number of documents written.

def delete_documents(self, document_ids: list[str]) -> int:
def delete_documents(self, document_ids: list[str]) -> int:
    """
    Deletes all documents whose IDs are listed in ``document_ids``.

    Documents are removed one at a time, in order. Each removal also rolls
    back that document's contribution to the document-frequency counter and
    to the average document length.

    Fails with `MissingDocumentError` if no document with a given id is
    present in the store. Documents listed before the missing id have
    already been deleted when the error is raised.

    :param document_ids: the IDs of the documents to delete.
    :type document_ids: list[str]

    :raises MissingDocumentError: if a document ID is not in the store.

    :return: Number of documents deleted.
    :rtype: int
    """
    n_removal = 0
    for doc_id in document_ids:
        try:
            _, freq, doc_len = self._index.pop(doc_id)
        except KeyError as exc:
            msg = f"Document with ID '{doc_id}' not found, cannot delete it."
            raise MissingDocumentError(msg) from exc

        # Each distinct token of the document counted once towards its
        # document frequency; ``subtract`` leaves zero counts in place
        # rather than dropping the keys.
        self._freq_doc.subtract(freq.keys())

        n_docs = len(self._index)
        if n_docs == 0:
            self._avg_doc_len = 0.0
        else:
            # Undo the incremental mean: remove this document's length from
            # the running total over the previous ``n_docs + 1`` documents.
            self._avg_doc_len = (
                self._avg_doc_len * (n_docs + 1) - doc_len
            ) / n_docs

        logger.debug(f"Document '{doc_id}' deleted from store.")
        n_removal += 1

    return n_removal

Deletes all documents whose IDs are listed in document_ids.

Fails with MissingDocumentError if no document with this id is present in the store.

Parameters
  • document_ids: the IDs of the documents to delete.
Raises
  • MissingDocumentError: raised when a document ID is not present in the store.
Returns

Number of documents deleted.

def to_dict(self) -> dict[str, typing.Any]:
def to_dict(self) -> dict[str, Any]:
    """
    Serialize this store's configuration to a dictionary.

    Only the constructor parameters are passed to the serializer; the indexed
    documents and corpus statistics are not part of the output.
    """
    # ``__init__`` stores delta / (k + 1) in ``self.delta``; undo that scaling
    # so that the serialized value is what the caller originally passed.
    original_delta = self.delta * (self.k + 1.0)
    init_params = dict(
        k=self.k,
        b=self.b,
        delta=original_delta,
        sp_file=self._sp_file,
        n_grams=self._n_grams,
        haystack_filter_logic=self._haystack_filter_logic,
    )
    return default_to_dict(self, **init_params)

Serializes this store to a dictionary.

@classmethod
def from_dict( cls, data: dict[str, typing.Any]) -> BetterBM25DocumentStore:
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "BetterBM25DocumentStore":
    """
    Rebuild a store from a dictionary produced by ``to_dict``.

    :param data: the serialized store configuration.
    :return: a new store built from the serialized constructor parameters.
    """
    restored = default_from_dict(cls, data)
    return restored

Deserializes the store from a dictionary.