Skip to content

Commit

Permalink
Add weight mask and update tests (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhluca authored Sep 3, 2024
1 parent 2b97cc5 commit 8b36bdd
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 5 deletions.
34 changes: 29 additions & 5 deletions bm25s/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def get_tokens_ids(self, query_tokens: List[str]) -> List[int]:
self.vocab_dict[token] for token in query_tokens if token in self.vocab_dict
]

def get_scores_from_ids(self, query_tokens_ids: List[int]) -> np.ndarray:
def get_scores_from_ids(self, query_tokens_ids: List[int], weight_mask=None) -> np.ndarray:
data = self.scores["data"]
indices = self.scores["indices"]
indptr = self.scores["indptr"]
Expand All @@ -454,6 +454,10 @@ def get_scores_from_ids(self, query_tokens_ids: List[int]) -> np.ndarray:
dtype=dtype,
)

if weight_mask is not None:
# multiply the scores by the weight mask
scores *= weight_mask

# if there's a non-occurrence array, we need to add the non-occurrence score
# back to the scores
if self.nonoccurrence_array is not None:
Expand All @@ -462,23 +466,24 @@ def get_scores_from_ids(self, query_tokens_ids: List[int]) -> np.ndarray:

return scores

def get_scores(self, query_tokens_single: List[str]) -> np.ndarray:
def get_scores(self, query_tokens_single: List[str], weight_mask=None) -> np.ndarray:
query_tokens_ids = self.get_tokens_ids(query_tokens_single)
return self.get_scores_from_ids(query_tokens_ids)
return self.get_scores_from_ids(query_tokens_ids, weight_mask=weight_mask)

def _get_top_k_results(
self,
query_tokens_single: List[str],
k: int = 1000,
backend="auto",
sorted: bool = False,
weight_mask: np.ndarray = None,
):
"""
This function is used to retrieve the top-k results for a single query.
Since it's a hidden function, the user should not call it directly, and it
may change in the future. Please use the `retrieve` function instead.
"""
scores_q = self.get_scores(query_tokens_single)
scores_q = self.get_scores(query_tokens_single, weight_mask=weight_mask)
if backend.startswith('numba'):
if selection_jit is None:
raise ImportError("Numba is not installed. Please install numba to use the numba backend.")
Expand All @@ -505,6 +510,7 @@ def retrieve(
n_threads: int = 0,
chunksize: int = 50,
backend_selection: str = "auto",
weight_mask: np.ndarray = None,
):
"""
Retrieve the top-k documents for each query (tokenized).
Expand Down Expand Up @@ -554,6 +560,10 @@ def retrieve(
backend_selection : str
The backend to use for the top-k retrieval. Choose from "auto", "numpy", "jax", "numba".
If "auto", it will use JAX if it is available, otherwise it will use numpy.
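weight_mask : np.ndarray
A weight mask to filter the documents. It must be a 1D numpy array with one entry
per indexed document. If provided, each document's score is multiplied by its mask
entry, so a 0 entry zeroes out that document's matching score (to avoid returning
it in the results) and a 1 entry leaves the score unchanged.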
"""
allowed_return_as = ["tuple", "documents"]

Expand Down Expand Up @@ -591,6 +601,20 @@ def retrieve(
query_tokens = tokenization.convert_tokenized_to_string_list(query_tokens)

corpus = corpus if corpus is not None else self.corpus

if weight_mask is not None:
if not isinstance(weight_mask, np.ndarray):
raise ValueError("weight_mask must be a numpy array.")

# check if weight_mask is a 1D array, if not raise an error
if weight_mask.ndim != 1:
raise ValueError("weight_mask must be a 1D array.")

# check if the length of the weight_mask is the same as the length of the corpus
if len(weight_mask) != self.scores["num_docs"]:
raise ValueError(
"The length of the weight_mask must be the same as the length of the corpus."
)

if self.backend == "numba":
if _retrieve_numba_functional is None:
Expand Down Expand Up @@ -628,7 +652,7 @@ def retrieve(
"disable": not show_progress,
}
topk_fn = partial(
self._get_top_k_results, k=k, sorted=sorted, backend=backend_selection
self._get_top_k_results, k=k, sorted=sorted, backend=backend_selection, weight_mask=weight_mask
)

if n_threads == 0:
Expand Down
6 changes: 6 additions & 0 deletions bm25s/numba/retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def _retrieve_internal_jitted_parallel(
indices: np.ndarray,
num_docs: int,
nonoccurrence_array: np.ndarray = None,
weight_mask: np.ndarray = None,
):
N = len(query_pointers) - 1

Expand All @@ -48,6 +49,9 @@ def _retrieve_internal_jitted_parallel(
nonoccurrence_scores = nonoccurrence_array[query_tokens_single].sum()
scores_single += nonoccurrence_scores

if weight_mask is not None:
scores_single = scores_single * weight_mask

topk_scores_sing, topk_indices_sing = _numba_sorted_top_k(
scores_single, k=k, sorted=sorted
)
Expand All @@ -72,6 +76,7 @@ def _retrieve_numba_functional(
backend_selection="numba",
dtype="float32",
int_dtype="int32",
weight_mask=None,
):
from numba import get_num_threads, set_num_threads, njit

Expand Down Expand Up @@ -121,6 +126,7 @@ def _retrieve_numba_functional(
indices=scores["indices"],
num_docs=scores["num_docs"],
nonoccurrence_array=nonoccurrence_array,
weight_mask=weight_mask,
)

# reset the number of threads
Expand Down
38 changes: 38 additions & 0 deletions tests/core/test_retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,44 @@ def test_retrieve(self):
results = self.retriever.retrieve(queries_as_tuple, k=1).documents
self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}")


def test_retrieve_with_weight_mask(self):
# Checks that retrieve() applies `weight_mask` for each query format it is
# called with below: a tokenized object with ids, plain text tokens, an
# (ids, vocab) tuple, and a 2-tuple of text-token lists.
# NOTE(review): indentation is not visible in this view, so the extent of the
# `for dt` loop body cannot be confirmed here; the statements after it all use
# the `weight_mask` built by the loop.


# first, try with default mode
query = "cat feline dog bird fish" # weights should be [2, 1, 1, 1], but after masking should be [2, 0, 0, 1]

# the same 0/1 mask is built as float32, int32 and bool
for dt in [np.float32, np.int32, np.bool_]:
# entries for documents 1 and 2 are 0, entries for documents 0 and 3 are 1
weight_mask = np.array([1, 0, 0, 1], dtype=dt)
# one query, k=2: documents 0 and 3 are expected, in that order
ground_truth = np.array([[0, 3]])

query_tokens_obj = bm25s.tokenize([query], stopwords="en", stemmer=self.stemmer, return_ids=True)

# retrieve the top 2 documents
results = self.retriever.retrieve(query_tokens_obj, k=2, weight_mask=weight_mask).documents

# assert that the retrieved indices are correct
self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}")

# now, try tokenizing with text tokens
query_tokens_texts = bm25s.tokenize([query], stopwords="en", stemmer=self.stemmer, return_ids=False)
results = self.retriever.retrieve(query_tokens_texts, k=2, weight_mask=weight_mask).documents
self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}")

# now, try to pass a tuple of tokens
ids, vocab = query_tokens_obj
query_tokens_tuple = (ids, vocab)
results = self.retriever.retrieve(query_tokens_tuple, k=2, weight_mask=weight_mask).documents
self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}")

# finally, try to pass a 2-tuple of tokens with text tokens to "try to trick the system"
# (the same token list is passed twice, and the expected result has one row per entry)
queries_as_tuple = (query_tokens_texts[0], query_tokens_texts[0])
# only retrieve 1 document
ground_truth = np.array([[0], [0]])
results = self.retriever.retrieve(queries_as_tuple, k=1, weight_mask=weight_mask).documents
self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}")


def test_failure_of_bad_tuple(self):
# try to pass a tuple of tokens with different lengths
query = "a cat is a feline, it's sometimes beautiful but cannot fly"
Expand Down
26 changes: 26 additions & 0 deletions tests/numba/test_numba_backend_retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,32 @@ def test_c_mmap_retrieve_with_numba(self):
self.assertTrue(np.all(retrieved.documents == retrieved_np.documents), "The retrieved documents should be the same")
self.assertTrue(np.all(retrieved_docs == retrieved_docs_np), "The results should be the same")

# test weight_mask in retrieve()
def test_d_retrieve_with_weight_mask(self):
    """Check that the numba-backed retriever accepts a weight mask of each dtype.

    The index is loaded once from the temporary directory, then the same two
    queries are retrieved with a float32, an int32 and a boolean mask. For
    every dtype the call must succeed and return one row of ``top_k`` results
    per query.
    """
    # load the retriever from temp dir (loop-invariant, so done only once)
    retriever = bm25s.BM25.load(
        self.tmpdirname,
        data_name="data.index.csc.npy",
        indices_name="indices.index.csc.npy",
        indptr_name="indptr.index.csc.npy",
        vocab_name="vocab.json",
        nnoc_name="nonoccurrence_array.npy",
        params_name="params.json",
        load_corpus=True,
    )

    self.assertTrue(retriever.backend == "numba", "The backend should be 'numba'")

    # now, let's retrieve the top-k results for a query
    query = ["my cat loves to purr", "a fish likes swimming"]

    query_tokens = bm25s.tokenize(query, stopwords="en", stemmer=self.stemmer)

    top_k = 2

    for dt in [np.float32, np.int32, np.bool_]:
        # subTest reports which dtype failed instead of stopping at the first one
        with self.subTest(dtype=dt):
            weight_mask = np.array([1, 1, 0, 1], dtype=dt)

            # retrieve the top-k results
            retrieved = retriever.retrieve(
                query_tokens, k=top_k, return_as="tuple", weight_mask=weight_mask
            )

            # the original test stopped here without checking anything; at the
            # very least the result must hold top_k entries for every query
            self.assertEqual(
                np.shape(retrieved.documents),
                (len(query), top_k),
                "Expected one row of top_k documents per query",
            )

@classmethod
def tearDownClass(cls):
Expand Down

0 comments on commit 8b36bdd

Please sign in to comment.