From 2e0e2f8e5b9589b34f862967655116ba0d5bf13b Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 2 Oct 2024 10:41:27 +1000 Subject: [PATCH] throw error for unknown providers or moderator --- src/exchange/load_exchange_attribute_error.py | 6 ++++++ src/exchange/moderators/__init__.py | 6 +++++- src/exchange/providers/__init__.py | 6 +++++- tests/providers/test_provider.py | 16 ++++++++++++++++ tests/test_load_exchange_attribute_error.py | 11 +++++++++++ tests/test_moderators.py | 15 +++++++++++++++ 6 files changed, 58 insertions(+), 2 deletions(-) create mode 100644 src/exchange/load_exchange_attribute_error.py create mode 100644 tests/providers/test_provider.py create mode 100644 tests/test_load_exchange_attribute_error.py create mode 100644 tests/test_moderators.py diff --git a/src/exchange/load_exchange_attribute_error.py b/src/exchange/load_exchange_attribute_error.py new file mode 100644 index 0000000..131f6f9 --- /dev/null +++ b/src/exchange/load_exchange_attribute_error.py @@ -0,0 +1,6 @@ +class LoadExchangeAttributeError(Exception): + def __init__(self, attribute_name: str, attribute_value: str) -> None: + self.attribute_name = attribute_name + self.attribute_value = attribute_value + self.message = f"Unknown {attribute_name}: {attribute_value}" + super().__init__(self.message) diff --git a/src/exchange/moderators/__init__.py b/src/exchange/moderators/__init__.py index 56b198a..5835f88 100644 --- a/src/exchange/moderators/__init__.py +++ b/src/exchange/moderators/__init__.py @@ -1,6 +1,7 @@ from functools import cache from typing import Type +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError from exchange.moderators.base import Moderator from exchange.utils import load_plugins from exchange.moderators.passive import PassiveModerator # noqa @@ -10,4 +11,7 @@ @cache def get_moderator(name: str) -> Type[Moderator]: - return load_plugins(group="exchange.moderator")[name] + moderators = load_plugins(group="exchange.moderator") + if name not in moderators: + raise LoadExchangeAttributeError("moderator", name) + return moderators[name] diff --git a/src/exchange/providers/__init__.py b/src/exchange/providers/__init__.py index 177ea63..ca83c5b 100644 --- a/src/exchange/providers/__init__.py +++ b/src/exchange/providers/__init__.py @@ -1,6 +1,7 @@ from functools import cache from typing import Type +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError from exchange.providers.anthropic import AnthropicProvider # noqa from exchange.providers.base import Provider, Usage # noqa from exchange.providers.databricks import DatabricksProvider # noqa @@ -13,4 +14,7 @@ @cache def get_provider(name: str) -> Type[Provider]: - return load_plugins(group="exchange.provider")[name] + providers = load_plugins(group="exchange.provider") + if name not in providers: + raise LoadExchangeAttributeError("provider", name) + return providers[name] diff --git a/tests/providers/test_provider.py b/tests/providers/test_provider.py new file mode 100644 index 0000000..39b0992 --- /dev/null +++ b/tests/providers/test_provider.py @@ -0,0 +1,16 @@ +import pytest +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.providers import get_provider + + +def test_get_provider_valid(): + provider_name = "openai" + provider = get_provider(provider_name) + assert provider.__name__ == "OpenAiProvider" + + +def test_get_provider_throw_error_for_unknown_provider(): + with pytest.raises(LoadExchangeAttributeError) as error: + get_provider("nonexistent") + assert error.value.attribute_name == "provider" + assert error.value.attribute_value == "nonexistent" diff --git a/tests/test_load_exchange_attribute_error.py b/tests/test_load_exchange_attribute_error.py new file mode 100644 index 0000000..aa9e106 --- /dev/null +++ b/tests/test_load_exchange_attribute_error.py @@ -0,0 +1,11 @@ +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError + + +def test_load_exchange_attribute_error(): + attribute_name = "provider" + attribute_value = "not_exist" + error = LoadExchangeAttributeError(attribute_name, attribute_value) + + assert error.attribute_name == attribute_name + assert error.attribute_value == attribute_value + assert error.message == "Unknown provider: not_exist" diff --git a/tests/test_moderators.py b/tests/test_moderators.py new file mode 100644 index 0000000..f9ef0b9 --- /dev/null +++ b/tests/test_moderators.py @@ -0,0 +1,15 @@ +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.moderators import get_moderator +import pytest + + +def test_get_moderator(): + moderator = get_moderator("truncate") + assert moderator.__name__ == "ContextTruncate" + + +def test_get_moderator_raise_error_for_unknown_moderator(): + with pytest.raises(LoadExchangeAttributeError) as error: + get_moderator("nonexistent") + assert error.value.attribute_name == "moderator" + assert error.value.attribute_value == "nonexistent"