fix: improve ollama workflow from CI #53

Merged · 3 commits · Sep 25, 2024
4 changes: 3 additions & 1 deletion src/exchange/exchange.py
@@ -3,7 +3,7 @@
 from copy import deepcopy
 from typing import Any, Dict, List, Mapping, Tuple

-from attrs import define, evolve, field
+from attrs import define, evolve, field, Factory
 from tiktoken import get_encoding

 from exchange.checkpoint import Checkpoint, CheckpointData
@@ -44,6 +44,7 @@ class Exchange:
     tools: Tuple[Tool] = field(factory=tuple, converter=tuple)
     messages: List[Message] = field(factory=list)
     checkpoint_data: CheckpointData = field(factory=CheckpointData)
+    generation_args: dict = field(default=Factory(dict))

     @property
     def _toolmap(self) -> Mapping[str, Tool]:
@@ -77,6 +78,7 @@ def generate(self) -> Message:
             self.system,
             messages=self.messages,
             tools=self.tools,
+            **self.generation_args,
         )
         self.add(message)
         self.add_checkpoints_from_usage(usage)  # this has to come after adding the response
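(Aside, not part of the PR: a minimal runnable sketch of the pattern the diff above introduces, assuming only the attrs library; `MiniExchange` and `FakeProvider` are hypothetical stand-ins, not the real exchange classes.)

```python
# Minimal sketch of the generation_args pattern above; MiniExchange and
# FakeProvider are hypothetical stand-ins, not the real exchange classes.
from attrs import Factory, define, field


@define
class FakeProvider:
    def complete(self, model: str, system: str, **kwargs) -> str:
        # In the real provider, kwargs such as seed/temperature would be
        # forwarded into the API request payload.
        return f"complete(model={model!r}, extra={kwargs})"


@define
class MiniExchange:
    provider: FakeProvider
    model: str
    system: str
    # Factory(dict) gives every instance its own empty dict, avoiding the
    # shared-mutable-default pitfall of writing default={}.
    generation_args: dict = field(default=Factory(dict))

    def generate(self) -> str:
        # The extra args are passed through verbatim to the provider.
        return self.provider.complete(self.model, self.system, **self.generation_args)


ex = MiniExchange(FakeProvider(), "llama3", "You are helpful.",
                  generation_args=dict(seed=3, temperature=0.1))
print(ex.generate())
# complete(model='llama3', extra={'seed': 3, 'temperature': 0.1})
```

The `Factory(dict)` default matters because a plain `default={}` would be shared across all instances, letting one exchange's settings leak into another's.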
29 changes: 17 additions & 12 deletions tests/test_integration.py
@@ -10,23 +10,25 @@
 too_long_chars = "x" * (2**20 + 1)

 cases = [
-    (get_provider("ollama"), os.getenv("OLLAMA_MODEL", OLLAMA_MODEL)),
-    (get_provider("openai"), "gpt-4o-mini"),
-    (get_provider("databricks"), "databricks-meta-llama-3-70b-instruct"),
-    (get_provider("bedrock"), "anthropic.claude-3-5-sonnet-20240620-v1:0"),
+    # Set seed and temperature for more determinism, to avoid flakes
+    (get_provider("ollama"), os.getenv("OLLAMA_MODEL", OLLAMA_MODEL), dict(seed=3, temperature=0.1)),
(get_provider("openai"), "gpt-4o-mini", dict()),
(get_provider("databricks"), "databricks-meta-llama-3-70b-instruct", dict()),
(get_provider("bedrock"), "anthropic.claude-3-5-sonnet-20240620-v1:0", dict()),
]


 @pytest.mark.integration  # skipped in CI/CD
-@pytest.mark.parametrize("provider,model", cases)
-def test_simple(provider, model):
+@pytest.mark.parametrize("provider,model,kwargs", cases)
+def test_simple(provider, model, kwargs):
     provider = provider.from_env()

     ex = Exchange(
         provider=provider,
         model=model,
         moderator=ContextTruncate(model),
         system="You are a helpful assistant.",
+        generation_args=kwargs,
     )

     ex.add(Message.user("Who is the most famous wizard from the lord of the rings"))
@@ -38,8 +40,8 @@ def test_simple(provider, model):


 @pytest.mark.integration  # skipped in CI/CD
-@pytest.mark.parametrize("provider,model", cases)
-def test_tools(provider, model, tmp_path):
+@pytest.mark.parametrize("provider,model,kwargs", cases)
+def test_tools(provider, model, kwargs, tmp_path):
     provider = provider.from_env()

     def read_file(filename: str) -> str:
@@ -48,8 +50,8 @@ def read_file(filename: str) -> str:

         Args:
             filename (str): The path to the file, which can be relative or
-            absolute. If it is a plain filename, it is assumed to be in the
-            current working directory.
+                absolute. If it is a plain filename, it is assumed to be in the
+                current working directory.

         Returns:
             str: The contents of the file.
@@ -60,8 +62,10 @@ def read_file(filename: str) -> str:
     ex = Exchange(
         provider=provider,
         model=model,
+        moderator=ContextTruncate(model),
         system="You are a helpful assistant. Expect to need to read a file using read_file.",
         tools=(Tool.from_function(read_file),),
+        generation_args=kwargs,
     )

     ex.add(Message.user("What are the contents of this file? test.txt"))
@@ -72,8 +76,8 @@ def read_file(filename: str) -> str:


 @pytest.mark.integration
-@pytest.mark.parametrize("provider,model", cases)
-def test_tool_use_output_chars(provider, model):
+@pytest.mark.parametrize("provider,model,kwargs", cases)
+def test_tool_use_output_chars(provider, model, kwargs):
     provider = provider.from_env()

     def get_password() -> str:
@@ -86,6 +90,7 @@ def get_password() -> str:
         moderator=ContextTruncate(model),
         system="You are a helpful assistant. Expect to need to authenticate using get_password.",
         tools=(Tool.from_function(get_password),),
+        generation_args=kwargs,
     )

     ex.add(Message.user("Can you authenticate this session by responding with the password"))