Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: consolidate azure with openai provider #60

Merged
merged 5 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 10 additions & 80 deletions src/exchange/providers/azure.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,16 @@
import os
from typing import Any, Dict, List, Tuple, Type
from typing import Type

import httpx

from exchange.message import Message
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import (
messages_to_openai_spec,
openai_response_to_message,
openai_single_message_context_length_exceeded,
raise_for_status,
tools_to_openai_spec,
)
from exchange.tool import Tool
from exchange.providers import OpenAiProvider

retry_procedure = retry(
wait=wait_fixed(2),
stop=stop_after_attempt(2),
retry=retry_if_status(codes=[429], above=500),
reraise=True,
)

class AzureProvider(OpenAiProvider):
"""Provides chat completions for models hosted by the Azure OpenAI Service"""

class AzureProvider(Provider):
"""Provides chat completions for models hosted directly by OpenAI"""

def __init__(self, client: httpx.Client, deployment_name: str, api_version: str) -> None:
super().__init__()
self.client = client
self.deployment_name = deployment_name
self.api_version = api_version
def __init__(self, client: httpx.Client) -> None:
super().__init__(client)

@classmethod
def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
Expand All @@ -55,61 +34,12 @@ def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_KEY from the environment.")

# format the url host/"openai/deployments/" + deployment_name + "/chat/completions?api-version=" + api_version
url = f"{url}/openai/deployments/{deployment_name}"
# format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version
url = f"{url}/openai/deployments/{deployment_name}/"
client = httpx.Client(
base_url=url,
headers={"api-key": key, "Content-Type": "application/json"},
params={"api-version": api_version},
timeout=httpx.Timeout(60 * 10),
)
return cls(client, deployment_name, api_version)

@staticmethod
def get_usage(data: dict) -> Usage:
usage = data.pop("usage")
input_tokens = usage.get("prompt_tokens")
output_tokens = usage.get("completion_tokens")
total_tokens = usage.get("total_tokens")

if total_tokens is None and input_tokens is not None and output_tokens is not None:
total_tokens = input_tokens + output_tokens

return Usage(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
)

def complete(
self,
model: str,
system: str,
messages: List[Message],
tools: Tuple[Tool],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
payload = dict(
messages=[
{"role": "system", "content": system},
*messages_to_openai_spec(messages),
],
tools=tools_to_openai_spec(tools) if tools else [],
**kwargs,
)

payload = {k: v for k, v in payload.items() if v}
request_url = f"{self.client.base_url}/chat/completions?api-version={self.api_version}"
response = self._post(payload, request_url)

# Check for context_length_exceeded error for single, long input message
if "error" in response and len(messages) == 1:
openai_single_message_context_length_exceeded(response["error"])

message = openai_response_to_message(response)
usage = self.get_usage(response)
return message, usage

@retry_procedure
def _post(self, payload: dict, request_url: str) -> dict:
response = self.client.post(request_url, json=payload)
return raise_for_status(response).json()
return cls(client)
17 changes: 8 additions & 9 deletions src/exchange/providers/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,14 @@ def __init__(self, client: httpx.Client) -> None:

@classmethod
def from_env(cls: Type["OllamaProvider"]) -> "OllamaProvider":
url = os.environ.get("OLLAMA_HOST", OLLAMA_HOST)
client = httpx.Client(
base_url=url,
timeout=httpx.Timeout(60 * 10),
)
ollama_url = os.environ.get("OLLAMA_HOST", OLLAMA_HOST)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@baxen lemme know if you think this is more sensible than before

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah this looks good!

timeout = httpx.Timeout(60 * 10)

# from_env is expected to fail if required ENV variables are not
# available. Since this provider can run with defaults, we substitute
# a health check to verify the endpoint is running.
client.get("")
# The OpenAI API is defined after "v1/", so we need to join it here.
client.base_url = client.base_url.join("v1/")
# an Ollama health check (GET /) to determine if the service is ok.
httpx.get(ollama_url, timeout=timeout)

# When served by Ollama, the OpenAI API is available at the path "v1/".
client = httpx.Client(base_url=ollama_url + "v1/", timeout=timeout)
return cls(client)
68 changes: 68 additions & 0 deletions tests/providers/cassettes/test_azure_complete.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"}], "model": "gpt-4o-mini"}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
api-key:
- test_azure_api_key
connection:
- keep-alive
content-length:
- '139'
content-type:
- application/json
host:
- test.openai.azure.com
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://test.openai.azure.com/openai/deployments/test-azure-deployment/chat/completions?api-version=2024-05-01-preview
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@michaelneale this is a scrubbed real request/response

response:
body:
string: '{"choices":[{"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}},"finish_reason":"stop","index":0,"logprobs":null,"message":{"content":"Hello!
How can I assist you today?","role":"assistant"}}],"created":1727230065,"id":"chatcmpl-ABBjN3AoYlxkP7Vg2lBvUhYeA6j5K","model":"gpt-4-32k","object":"chat.completion","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"system_fingerprint":null,"usage":{"completion_tokens":9,"prompt_tokens":18,"total_tokens":27}}

'
headers:
Cache-Control:
- no-cache, must-revalidate
Content-Length:
- '825'
Content-Type:
- application/json
Date:
- Wed, 25 Sep 2024 02:07:45 GMT
Set-Cookie: test_set_cookie
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
access-control-allow-origin:
- '*'
apim-request-id:
- 82e66ef8-ac07-4a43-b60f-9aecec1d8c81
azureml-model-session:
- d145-20240919052126
openai-organization: test_openai_org_key
x-accel-buffering:
- 'no'
x-content-type-options:
- nosniff
x-ms-client-request-id:
- 82e66ef8-ac07-4a43-b60f-9aecec1d8c81
x-ms-rai-invoked:
- 'true'
x-ms-region:
- Switzerland North
x-ratelimit-remaining-requests:
- '79'
x-ratelimit-remaining-tokens:
- '79984'
x-request-id:
- 38db9001-8b16-4efe-84c9-620e10f18c3c
status:
code: 200
message: OK
version: 1
74 changes: 74 additions & 0 deletions tests/providers/cassettes/test_azure_tools.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant.
Expect to need to read a file using read_file."}, {"role": "user", "content":
"What are the contents of this file? test.txt"}], "model": "gpt-4o-mini", "tools":
[{"type": "function", "function": {"name": "read_file", "description": "Read
the contents of the file.", "parameters": {"type": "object", "properties": {"filename":
{"type": "string", "description": "The path to the file, which can be relative
or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent
working directory."}}, "required": ["filename"]}}}]}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
api-key:
- test_azure_api_key
connection:
- keep-alive
content-length:
- '608'
content-type:
- application/json
host:
- test.openai.azure.com
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://test.openai.azure.com/openai/deployments/test-azure-deployment/chat/completions?api-version=2024-05-01-preview
response:
body:
string: '{"choices":[{"content_filter_results":{},"finish_reason":"tool_calls","index":0,"logprobs":null,"message":{"content":null,"role":"assistant","tool_calls":[{"function":{"arguments":"{\n \"filename\":
\"test.txt\"\n}","name":"read_file"},"id":"call_a47abadDxlGKIWjvYYvGVAHa","type":"function"}]}}],"created":1727256650,"id":"chatcmpl-ABIeABbq5WVCq0e0AriGFaYDSih3P","model":"gpt-4-32k","object":"chat.completion","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"system_fingerprint":null,"usage":{"completion_tokens":16,"prompt_tokens":109,"total_tokens":125}}

'
headers:
Cache-Control:
- no-cache, must-revalidate
Content-Length:
- '769'
Content-Type:
- application/json
Date:
- Wed, 25 Sep 2024 09:30:50 GMT
Set-Cookie: test_set_cookie
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
access-control-allow-origin:
- '*'
apim-request-id:
- 8c0e3372-8ffd-4ff5-a5d1-0b962c4ea339
azureml-model-session:
- d145-20240919052126
openai-organization: test_openai_org_key
x-accel-buffering:
- 'no'
x-content-type-options:
- nosniff
x-ms-client-request-id:
- 8c0e3372-8ffd-4ff5-a5d1-0b962c4ea339
x-ms-rai-invoked:
- 'true'
x-ms-region:
- Switzerland North
x-ratelimit-remaining-requests:
- '79'
x-ratelimit-remaining-tokens:
- '79824'
x-request-id:
- 401bd803-b790-47b7-b098-98708d44f060
status:
code: 200
message: OK
version: 1
42 changes: 42 additions & 0 deletions tests/providers/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
from typing import Type, Tuple

import pytest
Expand Down Expand Up @@ -26,6 +27,32 @@ def default_openai_env(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", OPENAI_API_KEY)


AZURE_ENDPOINT = "https://test.openai.azure.com"
AZURE_DEPLOYMENT_NAME = "test-azure-deployment"
AZURE_API_VERSION = "2024-05-01-preview"
AZURE_API_KEY = "test_azure_api_key"


@pytest.fixture
def default_azure_env(monkeypatch):
"""
This fixture prevents AzureProvider.from_env() from erring on missing
environment variables.

When running VCR tests for the first time or after deleting a cassette
recording, set required environment variables, so that real requests don't
fail. Subsequent runs use the recorded data, so don't need them.
"""
if "AZURE_CHAT_COMPLETIONS_HOST_NAME" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_HOST_NAME", AZURE_ENDPOINT)
if "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", AZURE_DEPLOYMENT_NAME)
if "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", AZURE_API_VERSION)
if "AZURE_CHAT_COMPLETIONS_KEY" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_KEY", AZURE_API_KEY)


@pytest.fixture(scope="module")
def vcr_config():
"""
Expand All @@ -43,10 +70,25 @@ def vcr_config():
("openai-project", OPENAI_PROJECT_ID),
("cookie", None),
],
"before_record_request": scrub_request_url,
"before_record_response": scrub_response_headers,
}


def scrub_request_url(request):
"""
This scrubs sensitive request data in provider-specific way. Note that headers
are case-sensitive!
"""
if "openai.azure.com" in request.uri:
request.uri = re.sub(r"https://[^/]+", AZURE_ENDPOINT, request.uri)
request.uri = re.sub(r"/deployments/[^/]+", f"/deployments/{AZURE_DEPLOYMENT_NAME}", request.uri)
request.headers["host"] = AZURE_ENDPOINT.replace("https://", "")
request.headers["api-key"] = AZURE_API_KEY

return request


def scrub_response_headers(response):
"""
This scrubs sensitive response headers. Note they are case-sensitive!
Expand Down
Loading