Commit

Add saving and loading corpus/stopwords to Tokenizer and add integration to HF Hub via `bm25s.hf.TokenizerHF` (save/load) (#59)

* Add save_vocab, load_vocab, save_stopwords, load_stopwords

* Add support for saving/loading vocabulary and stopwords to the Hub

* Improve the auto-generated README with a section on the tokenizer; fix an error in the example
xhluca authored Sep 22, 2024
1 parent fd142da commit 1e636a9
Showing 6 changed files with 490 additions and 22 deletions.
15 changes: 14 additions & 1 deletion README.md
@@ -197,7 +197,7 @@ corpus = [

# Pick your favorite stemmer, and pass
stemmer = None
-stopwords = []
+stopwords = ["is"]
splitter = lambda x: x.split() # function or regex pattern
# Create a tokenizer
tokenizer = Tokenizer(
@@ -211,6 +211,19 @@ print("tokens:", corpus_tokens)
print("vocab:", tokenizer.get_vocab_dict())

# note: the vocab dict will be a dict of `word -> id` if you don't have a stemmer, or a dict of `stemmed word -> stem id` if you do.
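# For example (hypothetical toy values): without a stemmer you might get
# {"cats": 0, "felines": 1}; with a stemmer, {"cat": 0, "felin": 1}.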
# You can save the vocab; it's fine to use the same directory as your index as long as the filenames don't conflict.
tokenizer.save_vocab(save_dir="bm25s_very_big_index")

# loading:
new_tokenizer = Tokenizer(stemmer=stemmer, stopwords=[], splitter=splitter)
new_tokenizer.load_vocab("bm25s_very_big_index")
print("vocab reloaded:", new_tokenizer.get_vocab_dict())

# the same can be done for stopwords
print("stopwords before reload:", new_tokenizer.stopwords)
tokenizer.save_stopwords(save_dir="bm25s_very_big_index")
new_tokenizer.load_stopwords("bm25s_very_big_index")
print("stopwords reloaded:", new_tokenizer.stopwords)
```
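
To verify the round trip, you can compare the two tokenizers' outputs. This is a minimal sketch; the `tokenize` call and its `update_vocab=False` flag follow the library's `Tokenizer` examples, but treat the exact keyword arguments as an assumption to check against your installed version:

```python
# Sketch: the reloaded tokenizer should produce the same IDs as the original,
# since both now share the same vocab and stopwords.
query = ["a cat is a feline"]
ids_original = tokenizer.tokenize(query, update_vocab=False)
ids_reloaded = new_tokenizer.tokenize(query, update_vocab=False)
assert ids_original == ids_reloaded
```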

You can find advanced examples in [examples/tokenizer_class.py](examples/tokenizer_class.py), including how to:
269 changes: 259 additions & 10 deletions bm25s/hf.py
@@ -5,6 +5,7 @@
import tempfile
from typing import Iterable, Union
from . import BM25, __version__
from .tokenization import Tokenizer

try:
from huggingface_hub import HfApi
@@ -106,6 +107,32 @@
retriever = BM25HF.load_from_hub("{username}/{repo_name}", token=token)
```
## Tokenizer
If you have saved a `Tokenizer` object with the index using the following approach:
```python
from bm25s.hf import TokenizerHF
token = "your_hugging_face_token"
tokenizer = TokenizerHF(corpus=corpus, stopwords="english")
tokenizer.save_to_hub("{username}/{repo_name}", token=token)
# and stopwords too
tokenizer.save_stopwords_to_hub("{username}/{repo_name}", token=token)
```
Then, you can load the tokenizer using the following code:
```python
from bm25s.hf import TokenizerHF
tokenizer = TokenizerHF(corpus=corpus, stopwords=[])
tokenizer.load_vocab_from_hub("{username}/{repo_name}", token=token)
tokenizer.load_stopwords_from_hub("{username}/{repo_name}", token=token)
```
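Once loaded, the tokenizer works like a locally created one. A minimal sketch (assuming the `retriever` loaded above and a `tokenize` method as in the bm25s examples; verify the exact arguments against your installed version):
```python
queries = ["what is a cat?"]
query_tokens = tokenizer.tokenize(queries, return_as="tuple", update_vocab=False)
results, scores = retriever.retrieve(query_tokens, k=2)
```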
## Stats
This dataset was created using the following data:
@@ -133,15 +160,15 @@
To cite `bm25s`, please use the following bibtex:
```
-@misc{lu_2024_bm25s,
-      title={BM25S: Orders of magnitude faster lexical search via eager sparse scoring},
-      author={Xing Han Lù},
-      year={2024},
-      eprint={2407.03618},
-      archivePrefix={arXiv},
-      primaryClass={cs.IR},
-      url={https://arxiv.org/abs/2407.03618},
-}
+@misc{{lu_2024_bm25s,
+      title={{BM25S: Orders of magnitude faster lexical search via eager sparse scoring}},
+      author={{Xing Han Lù}},
+      year={{2024}},
+      eprint={{2407.03618}},
+      archivePrefix={{arXiv}},
+      primaryClass={{cs.IR}},
+      url={{https://arxiv.org/abs/2407.03618}},
+}}
```
"""
@@ -216,6 +243,228 @@ def can_save_locally(local_save_dir, overwrite_local: bool) -> bool:
return True


class TokenizerHF(Tokenizer):
def save_vocab_to_hub(
self,
repo_id: str,
token: str = None,
local_dir: str = None,
commit_message: str = "Update tokenizer",
overwrite_local: bool = False,
private=True,
**kwargs,
):
"""
This function saves the tokenizer's vocab to the Hugging Face Hub.

Parameters
----------
repo_id: str
The unique identifier of the repository to save the model to.
The `repo_id` should be in the form of "username/repo_name".
token: str
The Hugging Face API token to use.
local_dir: str
The directory to save the model to before pushing to the Hub.
If it is not empty and `overwrite_local` is False, it will fall
back to saving to a temporary directory.
commit_message: str
The commit message to use when saving the model.
overwrite_local: bool
Whether to overwrite the existing local directory if it exists.
private: bool
Whether to create the repository as private. Default is True.
kwargs: dict
Additional keyword arguments to pass to the `HfApi.upload_folder` call.
"""
api = HfApi(token=token)
repo_url = api.create_repo(
repo_id=repo_id,
token=api.token,
private=private,
repo_type="model",
exist_ok=True,
)
repo_id = repo_url.repo_id

saving_locally = can_save_locally(local_dir, overwrite_local)
if saving_locally:
os.makedirs(local_dir, exist_ok=True)
save_dir = local_dir
else:
# save to a temporary directory otherwise
save_dir = tempfile.mkdtemp()

self.save_vocab(save_dir)
# push the content of the save directory (local or temporary) to the repo
api.upload_folder(
repo_id=repo_id,
commit_message=commit_message,
token=api.token,
folder_path=save_dir,
repo_type=repo_url.repo_type,
**kwargs,
)
# delete the temporary directory if it was created
if not saving_locally:
shutil.rmtree(save_dir)

return repo_url
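
# Usage sketch (hypothetical repo id and token, for illustration only):
#   tok = TokenizerHF(corpus=corpus, stopwords="english")
#   tok.save_vocab_to_hub("username/my-bm25-index", token="hf_...")
# If `local_dir` is empty or `overwrite_local=True`, the vocab files are also
# kept locally; otherwise a temporary directory is used and then removed.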

def load_vocab_from_hub(
self,
repo_id: str,
revision=None,
token=None,
local_dir=None,
):
"""
This function loads the tokenizer's vocab from the Hugging Face Hub.

Parameters
----------
repo_id: str
The unique identifier of the repository to load the model from.
The `repo_id` should be in the form of "username/repo_name".
revision: str
The revision of the model to load.
token: str
The Hugging Face API token to use.
local_dir: str
The local dir where the model will be stored after downloading.
"""
api = HfApi(token=token)
# check if the model exists
repo_url = api.repo_info(repo_id)
if repo_url is None:
raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.")

snapshot = api.snapshot_download(
repo_id=repo_id, revision=revision, token=token, local_dir=local_dir
)
if snapshot is None:
raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.")

return self.load_vocab(save_dir=snapshot)
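
# Usage sketch (hypothetical repo id, for illustration only):
#   tok = TokenizerHF(corpus=[], stopwords=[])
#   tok.load_vocab_from_hub("username/my-bm25-index", token="hf_...")
# The repo is fetched with `snapshot_download`, then the vocab is read from
# the snapshot directory via `load_vocab`.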

def save_stopwords_to_hub(
self,
repo_id: str,
token: str = None,
local_dir: str = None,
commit_message: str = "Update stopwords",
overwrite_local: bool = False,
private=True,
**kwargs,
):
"""
This function saves the tokenizer's stopwords to the Hugging Face Hub.

Parameters
----------
repo_id: str
The unique identifier of the repository to save the model to.
The `repo_id` should be in the form of "username/repo_name".
token: str
The Hugging Face API token to use.
local_dir: str
The directory to save the model to before pushing to the Hub.
If it is not empty and `overwrite_local` is False, it will fall
back to saving to a temporary directory.
commit_message: str
The commit message to use when saving the model.
overwrite_local: bool
Whether to overwrite the existing local directory if it exists.
private: bool
Whether to create the repository as private. Default is True.
kwargs: dict
Additional keyword arguments to pass to the `HfApi.upload_folder` call.
"""
api = HfApi(token=token)
repo_url = api.create_repo(
repo_id=repo_id,
token=api.token,
private=private,
repo_type="model",
exist_ok=True,
)
repo_id = repo_url.repo_id

saving_locally = can_save_locally(local_dir, overwrite_local)
if saving_locally:
os.makedirs(local_dir, exist_ok=True)
save_dir = local_dir
else:
# save to a temporary directory otherwise
save_dir = tempfile.mkdtemp()

self.save_stopwords(save_dir)
# push the content of the save directory (local or temporary) to the repo
api.upload_folder(
repo_id=repo_id,
commit_message=commit_message,
token=api.token,
folder_path=save_dir,
repo_type=repo_url.repo_type,
**kwargs,
)
# delete the temporary directory if it was created
if not saving_locally:
shutil.rmtree(save_dir)

return repo_url

def load_stopwords_from_hub(
self,
repo_id: str,
revision=None,
token=None,
local_dir=None,
):
"""
This function loads the tokenizer's stopwords from the Hugging Face Hub.

Parameters
----------
repo_id: str
The unique identifier of the repository to load the model from.
The `repo_id` should be in the form of "username/repo_name".
revision: str
The revision of the model to load.
token: str
The Hugging Face API token to use.
local_dir: str
The local dir where the model will be stored after downloading.
"""
api = HfApi(token=token)
# check if the model exists
repo_url = api.repo_info(repo_id)
if repo_url is None:
raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.")

snapshot = api.snapshot_download(
repo_id=repo_id, revision=revision, token=token, local_dir=local_dir
)
if snapshot is None:
raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.")

return self.load_stopwords(save_dir=snapshot)
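
# Usage sketch (hypothetical repo id; the stopwords round trip mirrors the
# vocab methods above):
#   tok.save_stopwords_to_hub("username/my-bm25-index", token="hf_...")
#   tok.load_stopwords_from_hub("username/my-bm25-index", token="hf_...")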

class BM25HF(BM25):
def save_to_hub(
self,
@@ -238,7 +487,7 @@ def save_to_hub(
repo_id: str
The name of the repository to save the model to.
-It should be username/repo_name.
+The `repo_id` should be in the form of "username/repo_name".
token: str
The Hugging Face API token to use.
