Skip to content

Commit

Permalink
throw error for unknown providers or moderator
Browse files Browse the repository at this point in the history
  • Loading branch information
lifeizhou-ap committed Oct 2, 2024
1 parent 78d4913 commit 2e0e2f8
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 2 deletions.
6 changes: 6 additions & 0 deletions src/exchange/load_exchange_attribute_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class LoadExchangeAttributeError(Exception):
def __init__(self, attribute_name: str, attribute_value: str) -> None:
self.attribute_name = attribute_name
self.attribute_value = attribute_value
self.message = f"Unknown {attribute_name}: {attribute_value}"
super().__init__(self.message)
6 changes: 5 additions & 1 deletion src/exchange/moderators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import cache
from typing import Type

from exchange.load_exchange_attribute_error import LoadExchangeAttributeError
from exchange.moderators.base import Moderator
from exchange.utils import load_plugins
from exchange.moderators.passive import PassiveModerator # noqa
Expand All @@ -10,4 +11,7 @@

@cache
def get_moderator(name: str) -> Type[Moderator]:
return load_plugins(group="exchange.moderator")[name]
moderators = load_plugins(group="exchange.moderator")
if name not in moderators:
raise LoadExchangeAttributeError("moderator", name)
return moderators[name]
6 changes: 5 additions & 1 deletion src/exchange/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import cache
from typing import Type

from exchange.load_exchange_attribute_error import LoadExchangeAttributeError
from exchange.providers.anthropic import AnthropicProvider # noqa
from exchange.providers.base import Provider, Usage # noqa
from exchange.providers.databricks import DatabricksProvider # noqa
Expand All @@ -13,4 +14,7 @@

@cache
def get_provider(name: str) -> Type[Provider]:
return load_plugins(group="exchange.provider")[name]
providers = load_plugins(group="exchange.provider")
if name not in providers:
raise LoadExchangeAttributeError("provider", name)
return providers[name]
16 changes: 16 additions & 0 deletions tests/providers/test_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest
from exchange.load_exchange_attribute_error import LoadExchangeAttributeError
from exchange.providers import get_provider


def test_get_provider_valid():
provider_name = "openai"
provider = get_provider(provider_name)
assert provider.__name__ == "OpenAiProvider"


def test_get_provider_throw_error_for_unknown_provider():
with pytest.raises(LoadExchangeAttributeError) as error:
get_provider("nonexistent")
assert error.value.attribute_name == "provider"
assert error.value.attribute_value == "nonexistent"
11 changes: 11 additions & 0 deletions tests/test_load_exchange_attribute_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from exchange.load_exchange_attribute_error import LoadExchangeAttributeError


def test_load_exchange_attribute_error():
attribute_name = "provider"
attribute_value = "not_exist"
error = LoadExchangeAttributeError(attribute_name, attribute_value)

assert error.attribute_name == attribute_name
assert error.attribute_value == attribute_value
assert error.message == "Unknown provider: not_exist"
15 changes: 15 additions & 0 deletions tests/test_moderators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from exchange.load_exchange_attribute_error import LoadExchangeAttributeError
from exchange.moderators import get_moderator
import pytest


def test_get_moderator():
moderator = get_moderator("truncate")
assert moderator.__name__ == "ContextTruncate"


def test_get_moderator_raise_error_for_unknown_moderator():
with pytest.raises(LoadExchangeAttributeError) as error:
get_moderator("nonexistent")
assert error.value.attribute_name == "moderator"
assert error.value.attribute_value == "nonexistent"

0 comments on commit 2e0e2f8

Please sign in to comment.