diff --git a/.mypy.ini b/.mypy.ini index a799cdd..cc3ed4a 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -38,3 +38,6 @@ disallow_untyped_decorators = False [mypy-pwnlib] ignore_missing_imports = True +[mypy-bson] +ignore_missing_imports = True + diff --git a/setup.py b/setup.py index dd948de..e630eaa 100755 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ setuptools.setup( name="enochecker", - version="0.4.0", + version="0.4.1", author="domenukk", author_email="dmaier@sect.tu-berlin.de", description="Library to build checker scripts for EnoEngine A/D CTF Framework in Python", diff --git a/src/enochecker/storeddict.py b/src/enochecker/storeddict.py index d05fee6..e7eddbf 100644 --- a/src/enochecker/storeddict.py +++ b/src/enochecker/storeddict.py @@ -1,6 +1,5 @@ """Backend for team_db based on a local filesystem directory.""" -import json import logging import os import threading @@ -10,6 +9,8 @@ from pathlib import Path from typing import Any, Callable, Dict, Iterator, Optional, Set +import bson + from .utils import base64ify, debase64ify, ensure_valid_filename logging.basicConfig(level=logging.DEBUG) @@ -22,7 +23,7 @@ 6 # 2**6 / 10 seconds are 6.4 secs. -> That's how long the db will wait for a log ) DB_PREFIX = "_store_" # Prefix all db files will get -DB_EXTENSION = ".json" # Extension all db files will get +DB_EXTENSION = ".bson" # Extension all db files will get DB_LOCK_EXTENSION = ".lock" # Extension all lock folders will get DB_GLOBAL_CACHE_SETTING = True @@ -129,9 +130,9 @@ def _dir(self, key: str) -> str: """ return os.path.join(self.path, DB_PREFIX + base64ify(key, b"+-")) - def _dir_jsonname(self, key: str) -> str: + def _dir_bsonname(self, key: str) -> str: """ - Return the path for the json db file for this key. + Return the path for the bson db file for this key. See :func:`_dir` """ @@ -240,7 +241,7 @@ def persist(self) -> None: locked = self.is_locked(key) or self.ignore_locks if not locked: self.lock(key) - os.remove(self._dir_jsonname(key)) + os.remove(self._dir_bsonname(key)) if not locked: self.release(key) self.logger.debug(f"Deleted {key} from db {self.name}") @@ -251,8 +252,8 @@ def persist(self) -> None: if not locked: self.lock(key) try: - with open(self._dir_jsonname(key), "wb") as f: - f.write(json.dumps(self._cache[key]).encode("utf-8")) + with open(self._dir_bsonname(key), "wb") as f: + f.write(bson.BSON.encode({"value": self._cache[key]})) finally: if not locked: self.release(key) @@ -272,9 +273,9 @@ def __getitem__(self, key: str) -> Any: if not locked: self.lock(key) try: - with open(self._dir_jsonname(key), "rb") as f: - val = json.loads(f.read().decode("utf-8")) - except (OSError, json.decoder.JSONDecodeError) as ex: + with open(self._dir_bsonname(key), "rb") as f: + val = bson.BSON(f.read()).decode()["value"] + except (OSError, bson.errors.BSONError) as ex: raise KeyError("Key {} not found - {}".format(key, ex)) finally: if not locked: diff --git a/tests/test_enochecker.py b/tests/test_enochecker.py index 91b1b32..524c0cb 100644 --- a/tests/test_enochecker.py +++ b/tests/test_enochecker.py @@ -171,6 +171,20 @@ def test_dict(): assert len(db) == 0 +@pytest.mark.parametrize( + "value", [b"binarydata", {"test": b"test", "test2": 123}, ["asd", 123, b"xyz"]] +) +@temp_storage_dir +def test_storeddict_complex_types(value): + db = enochecker.storeddict.StoredDict(name="test", base_path=STORAGE_DIR) + + db["test"] = value + db.persist() + + db.reload() + assert db["test"] == value + + @temp_storage_dir def test_checker(): flag = "ENOFLAG"