Skip to content

Commit

Permalink
Apply ruff and add to CI
Browse files Browse the repository at this point in the history
  • Loading branch information
baxen committed Sep 3, 2024
1 parent 119b8d7 commit 128925a
Show file tree
Hide file tree
Showing 11 changed files with 150 additions and 108 deletions.
33 changes: 17 additions & 16 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,27 @@ jobs:

strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
python-version:
- "3.10"
- "3.11"
- "3.12"

steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install UV
run: curl -LsSf https://astral.sh/uv/install.sh | sh

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install uv pytest
- name: Source Cargo Environment
run: source $HOME/.cargo/env

- name: Install UV
run: curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}

- name: Source Cargo Environment
run: source $HOME/.cargo/env
- name: Ruff
run: |
uvx ruff check
uvx ruff format --check
- name: Run tests
run: uv run pytest tests -m 'not integration'
- name: Run tests
run: uvx pytest tests -m 'not integration'
8 changes: 6 additions & 2 deletions src/exchange/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ class Checkpoint:
end_index: int = field(default=0) # inclusive
token_count: int = field(default=0)

def __deepcopy__(self, _) -> "Checkpoint":
def __deepcopy__(self, _) -> "Checkpoint": # noqa: ANN001
"""
Returns a deep copy of the Checkpoint object.
"""
return Checkpoint(start_index=self.start_index, end_index=self.end_index, token_count=self.token_count)
return Checkpoint(
start_index=self.start_index,
end_index=self.end_index,
token_count=self.token_count,
)


@define
Expand Down
17 changes: 11 additions & 6 deletions src/exchange/moderators/truncate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from typing import List, Optional, Type
from __future__ import annotations

from typing import TYPE_CHECKING, List

from exchange.checkpoint import CheckpointData
from exchange.message import Message
from exchange.moderators import PassiveModerator
from exchange.moderators.base import Moderator

if TYPE_CHECKING:
from exchange.exchange import Exchange

# currently this is the point at which we start to truncate, so
# so once we get to this token size the token count will exceed this
# by a little bit.
Expand All @@ -15,15 +20,15 @@
class ContextTruncate(Moderator):
def __init__(
self,
model: Optional[str] = "gpt-4o-mini",
max_tokens: Optional[int] = MAX_TOKENS,
model: str = "gpt-4o-mini",
max_tokens: int = MAX_TOKENS,
) -> None:
self.model = model
self.system_prompt_token_count = 0
self.max_tokens = max_tokens
self.last_system_prompt = None

def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None:
def rewrite(self, exchange: Exchange) -> None:
"""Truncate the exchange messages with a FIFO strategy."""
self._update_system_prompt_token_count(exchange)

Expand All @@ -34,7 +39,7 @@ def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None:
for _ in range(len(messages_to_remove)):
exchange.pop_first_message()

def _update_system_prompt_token_count(self, exchange: Type["exchange.exchange.Exchange"]) -> None:
def _update_system_prompt_token_count(self, exchange: Exchange) -> None:
is_different_system_prompt = False
if self.last_system_prompt != exchange.system:
is_different_system_prompt = True
Expand All @@ -55,7 +60,7 @@ def _update_system_prompt_token_count(self, exchange: Type["exchange.exchange.Ex
exchange.checkpoint_data.total_token_count -= last_system_prompt_token_count
exchange.checkpoint_data.total_token_count += self.system_prompt_token_count

def _get_messages_to_remove(self, exchange: Type["exchange.exchange.Exchange"]) -> List[Message]:
def _get_messages_to_remove(self, exchange: Exchange) -> List[Message]:
# this keeps all the messages/checkpoints
throwaway_exchange = exchange.replace(
moderator=PassiveModerator(),
Expand Down
4 changes: 2 additions & 2 deletions src/exchange/providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def complete(
usage = self.get_usage(response_data)

return message, usage

@retry_httpx_request()
def _send_request(self, payload: Dict[str, Any]) -> httpx.Response:
return self.client.post(ANTHROPIC_HOST, json=payload)
return self.client.post(ANTHROPIC_HOST, json=payload)
28 changes: 10 additions & 18 deletions src/exchange/providers/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,38 +23,30 @@ def __init__(self, client: httpx.Client, deployment_name: str, api_version: str)
super().__init__()
self.client = client
self.deployment_name = deployment_name
self.api_version = api_version
self.api_version = api_version

@classmethod
def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
try:
url = os.environ["AZURE_CHAT_COMPLETIONS_HOST_NAME"]
except KeyError:
raise RuntimeError(
"Failed to get AZURE_CHAT_COMPLETIONS_HOST_NAME from the environment."
)

raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_HOST_NAME from the environment.")

try:
deployment_name = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"]
except KeyError:
raise RuntimeError(
"Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME from the environment."
)

raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME from the environment.")

try:
api_version = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"]
except KeyError:
raise RuntimeError(
"Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION from the environment."
)

raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION from the environment.")

try:
key = os.environ["AZURE_CHAT_COMPLETIONS_KEY"]
except KeyError:
raise RuntimeError(
"Failed to get AZURE_CHAT_COMPLETIONS_KEY from the environment."
)

raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_KEY from the environment.")

# format the url host/"openai/deployments/" + deployment_name + "/chat/completions?api-version=" + api_version
url = f"{url}/openai/deployments/{deployment_name}"
client = httpx.Client(
Expand Down Expand Up @@ -111,6 +103,6 @@ def complete(
usage = self.get_usage(data)
return message, usage

@retry_httpx_request()
@retry_httpx_request()
def _send_request(self, payload: Any, request_url: str) -> httpx.Response: # noqa: ANN401
return self.client.post(request_url, json=payload)
2 changes: 1 addition & 1 deletion src/exchange/providers/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def complete(
return self.response_to_message(response_message), usage

@retry_httpx_request()
def _send_request(self, payload: Any, path:str) -> httpx.Response: # noqa: ANN401
def _send_request(self, payload: Any, path: str) -> httpx.Response: # noqa: ANN401
return self.client.post(path, json=payload)

@staticmethod
Expand Down
7 changes: 4 additions & 3 deletions src/exchange/providers/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

OLLAMA_HOST = "http://localhost:11434/"


#
# NOTE: this is experimental, best used with 70B model or larger if you can.
# NOTE: this is experimental, best used with 70B model or larger if you can.
# Example profile config to try:
class OllamaProvider(Provider):
"""Provides chat completions for models hosted by Ollama"""
Expand All @@ -33,10 +34,10 @@ class OllamaProvider(Provider):
toolkits:
- name: developer
requires: {}
"""
"""

def __init__(self, client: httpx.Client) -> None:
print('PLEASE NOTE: the ollama provider is experimental, use with care')
print("PLEASE NOTE: the ollama provider is experimental, use with care")
super().__init__()
self.client = client

Expand Down
21 changes: 13 additions & 8 deletions src/exchange/providers/retry_with_back_off_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@


def retry_with_backoff(
should_retry: Callable,
max_retries: Optional[int] = 5,
initial_wait: Optional[float] = 10,
backoff_factor: Optional[float] = 1,
handle_retry_exhausted: Optional[Callable] = None) -> Callable:
should_retry: Callable,
max_retries: Optional[int] = 5,
initial_wait: Optional[float] = 10,
backoff_factor: Optional[float] = 1,
handle_retry_exhausted: Optional[Callable] = None,
) -> Callable:
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args: List, **kwargs: Dict) -> Any: # noqa: ANN401
Expand All @@ -21,14 +22,17 @@ def wrapper(*args: List, **kwargs: Dict) -> Any: # noqa: ANN401
return result
if (retry + 1) == max_retries:
break
sleep_time = initial_wait + (backoff_factor * (2 ** retry))
sleep_time = initial_wait + (backoff_factor * (2**retry))
time.sleep(sleep_time)
if handle_retry_exhausted:
handle_retry_exhausted(result, max_retries)
return result

return wrapper

return decorator


def retry_httpx_request(
retry_on_status_code: Optional[Iterable[int]] = None,
max_retries: Optional[int] = 5,
Expand All @@ -37,6 +41,7 @@ def retry_httpx_request(
) -> Callable:
if retry_on_status_code is None:
retry_on_status_code = set(range(401, 999))

def should_retry(response: Response) -> bool:
return response.status_code in retry_on_status_code

Expand All @@ -52,5 +57,5 @@ def handle_retry_exhausted(response: Response, max_retries: int) -> None:
initial_wait=initial_wait,
backoff_factor=backoff_factor,
should_retry=should_retry,
handle_retry_exhausted=handle_retry_exhausted
)
handle_retry_exhausted=handle_retry_exhausted,
)
15 changes: 9 additions & 6 deletions tests/providers/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@


@pytest.fixture
@patch.dict(os.environ, {
"AZURE_CHAT_COMPLETIONS_HOST_NAME": "https://test.openai.azure.com/",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test-deployment",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "2024-02-15-preview",
"AZURE_CHAT_COMPLETIONS_KEY": "test_api_key"
})
@patch.dict(
os.environ,
{
"AZURE_CHAT_COMPLETIONS_HOST_NAME": "https://test.openai.azure.com/",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test-deployment",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "2024-02-15-preview",
"AZURE_CHAT_COMPLETIONS_KEY": "test_api_key",
},
)
def azure_provider():
return AzureProvider.from_env()

Expand Down
Loading

0 comments on commit 128925a

Please sign in to comment.