From 8dcf2be8d79aa91eb22441ffecf389ae5eaf60e0 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Wed, 25 Sep 2024 11:36:53 -0400 Subject: [PATCH 1/3] stash - kinda working with messages_dict but its messy --- src/exchange/providers/openai.py | 274 ++++++++++++++++++++++++++++++- 1 file changed, 266 insertions(+), 8 deletions(-) diff --git a/src/exchange/providers/openai.py b/src/exchange/providers/openai.py index cd0504f..f6b3142 100644 --- a/src/exchange/providers/openai.py +++ b/src/exchange/providers/openai.py @@ -1,6 +1,7 @@ import os +import sys from typing import Any, Dict, List, Tuple, Type - +import json import httpx from exchange.message import Message @@ -15,6 +16,7 @@ from exchange.tool import Tool from tenacity import retry, wait_fixed, stop_after_attempt from exchange.providers.utils import retry_if_status +from exchange.content import Text OPENAI_HOST = "https://api.openai.com/" @@ -25,6 +27,217 @@ reraise=True, ) +USER_PROMPT_TEMPLATE = """ +## Task + +{system} + +## Available Tools & Response Guidelines +You can either respond to the user with a message or compose tool calls. Your task is to translate user queries into appropriate tool calls or response messages in JSON format. + +Follow these guidelines: +- Always respond with a valid JSON object containing the function to call and its parameters. +- Do not include any additional text or explanations. +- If you are responding with a message, include the key "message" with the response text. +- If you are composing a tool call, include the key "tool_calls" with a list of tool calls. + +Here are some examples: + +Example 1: +--- +User Query: What's the weather like in New York today? +Available Functions: +1. get_current_weather(location) +2. get_forecast(location, days) + +Response: +{"tool_calls": [{ + "function": "get_current_weather", + "parameters": { + "location": "New York" + } +}]} +--- + +Example 2: +--- +User Query: Find me Italian restaurants nearby. +Available Functions: +1. search_restaurants(cuisine, location) +2. get_restaurant_details(restaurant_id) + +Response: +{"tool_calls": [{ + "function": "search_restaurants", + "parameters": { + "cuisine": "Italian", + "location": "current location" + } +}]} +--- + +Example 3: +--- +User Query: Schedule a meeting with John tomorrow at 10 AM and show me the calendar. +Available Functions: +1. create_event(title, datetime, participants) +2. get_calendar() + +Response: +{"tool_calls": [ + { + "function": "create_event", + "parameters": { + "title": "Meeting with John", + "datetime": "tomorrow at 10 AM", + "participants": ["John"] + } + }, + { + "function": "get_calendar", + "parameters": {{}} + } +]} +--- + +Example 4: +--- +User Query: Hi there! +Available Functions: +1. create_event(title, datetime, participants) +2. get_calendar() + +Response: +{ + "message": Hey! How can I help you today? +} + +Now, given the following user query and available functions, respond with the appropriate function call in JSON format. + +User Query: {user_query} +Available Functions: +{available_functions} + +Response: +""" + + +def is_o1(model: str) -> bool: + return model.startswith("o1") + + +def update_system_message(system: str, tools: Tuple[Tool]) -> str: + if not tools: + return system + + tool_names_str = "" + for i, tool in enumerate(tools, start=1): + tool_names_str += f"{i}. 
{tool.name}({', '.join(tool.parameters.keys())})\n" + + return USER_PROMPT_TEMPLATE.format(system=system, available_functions=tool_names_str) + + +def merge_consecutive_roles(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Merges consecutive messages with the same role into a single message. + + Args: + messages (List[Dict[str, Any]]): The list of messages to merge. + + Returns: + List[Dict[str, Any]]: The list of messages with consecutive messages of the same role merged. + """ + merged_messages = [] + current_role = None + current_content = "" + + for msg in messages: + role = msg.get("role") + content = msg.get("content", "") + + if role == current_role: + current_content += "\n" + content + else: + if current_role: + merged_messages.append({"role": current_role, "content": current_content.strip()}) + current_role = role + current_content = content + + if current_role: + merged_messages.append({"role": current_role, "content": current_content.strip()}) + + return merged_messages + + +def convert_messages(original_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Converts original messages with 'system', 'user', 'assistant', and 'tool' roles + into messages containing only 'user' and 'assistant' roles. + + Args: + original_messages (List[Dict[str, Any]]): The original list of messages. + + Returns: + List[Dict[str, Any]]: The converted list of messages with only 'user' and 'assistant' roles. + """ + converted_messages = [] + tool_call_map = {} # Maps tool_call_id to function details + + for msg in original_messages: + role = msg.get("role") + content = msg.get("content", "") + + if role == "system": + # Convert 'system' messages to 'user' messages, optionally prefixing for clarity + content = f"[System]: {content}" + converted_messages.append({"role": "user", "content": content}) + + elif role == "user": + converted_messages.append({"role": "user", "content": content}) + + elif role == "assistant": + # Check for 'tool_calls' in the 'assistant' message + tool_calls = msg.get("tool_calls", []) + if tool_calls: + for tool_call in tool_calls: + tool_id = tool_call.get("id") + function = tool_call.get("function", {}) + function_name = function.get("name") + arguments = function.get("arguments") + + # Store the tool call details + tool_call_map[tool_id] = {"name": function_name, "arguments": arguments} + + # Optionally, you can indicate that the assistant initiated tool calls + # For this implementation, we'll not modify the assistant's content + # But this can be customized based on specific needs + + # Append the 'assistant' message + assistant_content = msg.get("content", "") + converted_messages.append({"role": "assistant", "content": assistant_content}) + + elif role == "tool": + # 'tool' messages are outputs from tool calls; convert them to 'assistant' messages + # Find the corresponding tool call based on 'tool_call_id' + tool_call_id = msg.get("tool_call_id") + tool_output = msg.get("content", "") + + if tool_call_id and tool_call_id in tool_call_map: + function_details = tool_call_map[tool_call_id] + function_name = function_details["name"] + + # You can format the assistant's response to include tool output contextually + # For simplicity, we'll append the tool output directly + assistant_output = f"[Tool Output - {function_name}]: {tool_output}" + converted_messages.append({"role": "assistant", "content": assistant_output}) + else: + # If 'tool_call_id' is missing or not found, append the tool output as-is + assistant_output = f"[Tool Output]: 
{tool_output}" + converted_messages.append({"role": "assistant", "content": assistant_output}) + + merged_converted_messages = merge_consecutive_roles(messages=converted_messages) + return merged_converted_messages + class OpenAiProvider(Provider): """Provides chat completions for models hosted directly by OpenAI""" @@ -73,13 +286,24 @@ def complete( tools: Tuple[Tool], **kwargs: Dict[str, Any], ) -> Tuple[Message, Usage]: - system_message = [] if model.startswith("o1") else [{"role": "system", "content": system}] - payload = dict( - messages=system_message + messages_to_openai_spec(messages), - model=model, - tools=tools_to_openai_spec(tools) if tools else [], - **kwargs, - ) + if is_o1(model): + system_with_tools = update_system_message(system, tools) + converted = convert_messages( + [{"role": "system", "content": system}] + messages_to_openai_spec(messages), + ) + messages = [Message(role=m["role"], content=[Text(text=m["content"])]) for m in converted] + payload = dict( + messages=messages_to_openai_spec(messages), + model=model, + **kwargs, + ) + else: + payload = dict( + messages=[{"role": "system", "content": system}] + messages_to_openai_spec(messages), + model=model, + tools=tools_to_openai_spec(tools) if tools else [], + **kwargs, + ) payload = {k: v for k, v in payload.items() if v} response = self._post(payload) @@ -95,3 +319,37 @@ def complete( def _post(self, payload: dict) -> dict: response = self.client.post("v1/chat/completions", json=payload) return raise_for_status(response).json() + + +if __name__ == "__main__": + original_messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hi there!"}, + {"role": "user", "content": "can you help me?"}, + {"role": "assistant", "content": "Sure! 
What do you need assistance with?", "tool_calls": []}, + {"role": "user", "content": "I need to book a flight to Paris on December 25th."}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "tcall_1", + "type": "function", + "function": {"name": "book_flight", "arguments": '{"destination": "Paris", "date": "2023-12-25"}'}, + } + ], + }, + { + "role": "tool", + "content": "Your flight to Paris on December 25th has been booked.", + "tool_call_id": "tcall_1", + }, + {"role": "assistant", "content": "I've booked your flight to Paris on December 25th."}, + ] + + converted = convert_messages(original_messages) + + messages = [Message(role=m["role"], content=[Text(text=m["content"])]) for m in converted] + print("Messages:") + for msg in messages: + print(msg) From 16420d17e4425b0a81116d46164a3a7bf1829c61 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Wed, 25 Sep 2024 14:40:58 -0400 Subject: [PATCH 2/3] add o1 implementation for openai provider --- src/exchange/exchange.py | 1 + src/exchange/providers/openai.py | 630 ++++++++++++++++++++++--------- tests/providers/test_openai.py | 47 ++- tests/test_exchange.py | 55 +++ 4 files changed, 559 insertions(+), 174 deletions(-) diff --git a/src/exchange/exchange.py b/src/exchange/exchange.py index 301208c..d7932fa 100644 --- a/src/exchange/exchange.py +++ b/src/exchange/exchange.py @@ -107,6 +107,7 @@ def reply(self, max_tool_use: int = 128) -> Message: for tool_use in response.tool_use: tool_result = self.call_function(tool_use) content.append(tool_result) + print(f"Tool call result - user content: {content}") self.add(Message(role="user", content=content)) # We've reached the limit of tool calls - break out of the loop diff --git a/src/exchange/providers/openai.py b/src/exchange/providers/openai.py index f6b3142..c30e186 100644 --- a/src/exchange/providers/openai.py +++ b/src/exchange/providers/openai.py @@ -1,5 +1,5 @@ import os -import sys +import re from typing import Any, Dict, List, Tuple, Type import json import httpx @@ -16,7 +16,8 @@ from exchange.tool import Tool from tenacity import retry, wait_fixed, stop_after_attempt from exchange.providers.utils import retry_if_status -from exchange.content import Text +from exchange.content import Text, ToolUse +from exchange.utils import create_object_id OPENAI_HOST = "https://api.openai.com/" @@ -28,12 +29,13 @@ ) USER_PROMPT_TEMPLATE = """ -## Task +## Instructions {system} ## Available Tools & Response Guidelines -You can either respond to the user with a message or compose tool calls. Your task is to translate user queries into appropriate tool calls or response messages in JSON format. +You can either respond to the user with a message or compose tool calls. Your task is to translate user queries into +appropriate tool calls or response messages in JSON format. Follow these guidelines: - Always respond with a valid JSON object containing the function to call and its parameters. @@ -45,198 +47,408 @@ Example 1: --- -User Query: What's the weather like in New York today? +User Query: +What's the weather like in New York today? + Available Functions: -1. get_current_weather(location) -2. get_forecast(location, days) +[ + {{ + "type": "function", + "function": {{ + "name": "get_current_weather", + "description": "Get the current weather for the specified location.", + "parameters": {{ + "type": "object", + "properties": {{ + "location": {{ + "type": "string", + "description": "The location to get the weather for." 
+ }} + }}, + "required": [ + "location" + ] + }} + }} + }}, + {{ + "type": "function", + "function": {{ + "name": "get_forecast", + "description": "Get the weather forecast for the specified location for the next 'days' days.", + "parameters": {{ + "type": "object", + "properties": {{ + "location": {{ + "type": "string", + "description": "The location to get the weather forecast for." + }}, + "days": {{ + "type": "integer", + "description": "The number of days to get the forecast for." + }} + }}, + "required": [ + "location", + "days" + ] + }} + }} + }} +] Response: -{"tool_calls": [{ +```json +{{"tool_calls": [{{ "function": "get_current_weather", - "parameters": { + "parameters": {{ "location": "New York" - } -}]} + }} +}}]}} +``` --- Example 2: --- -User Query: Find me Italian restaurants nearby. +User Query: +Find me Italian restaurants nearby. + Available Functions: -1. search_restaurants(cuisine, location) -2. get_restaurant_details(restaurant_id) +[ + {{ + "type": "function", + "function": {{ + "name": "search_restaurants", + "description": "Search for restaurants of the specified cuisine near the given location.", + "parameters": {{ + "type": "object", + "properties": {{ + "cuisine": {{ + "type": "string", + "description": "The type of cuisine to search for." + }}, + "location": {{ + "type": "string", + "description": "The location to search near." + }} + }}, + "required": [ + "cuisine", + "location" + ] + }} + }} + }}, + {{ + "type": "function", + "function": {{ + "name": "get_restaurant_details", + "description": "Get the details for the specified restaurant.", + "parameters": {{ + "type": "object", + "properties": {{ + "restaurant_id": {{ + "type": "string", + "description": "The unique identifier for the restaurant." + }} + }}, + "required": [ + "restaurant_id" + ] + }} + }} + }} +] Response: -{"tool_calls": [{ +```json +{{"tool_calls": [{{ "function": "search_restaurants", - "parameters": { + "parameters": {{ "cuisine": "Italian", "location": "current location" - } -}]} + }} +}}]}} +``` --- Example 3: --- -User Query: Schedule a meeting with John tomorrow at 10 AM and show me the calendar. +User Query: +Schedule a meeting with John tomorrow at 10 AM and show me the calendar. + Available Functions: -1. create_event(title, datetime, participants) -2. get_calendar() +[ + {{ + "type": "function", + "function": {{ + "name": "create_event", + "description": "Create an event with the specified title, datetime, and participants.", + "parameters": {{ + "type": "object", + "properties": {{ + "title": {{ + "type": "string", + "description": "The title of the event." + }}, + "datetime": {{ + "type": "string", + "description": "The date and time of the event." + }}, + "participants": {{ + "type": "array", + "items": {{ + "type": "string" + }}, + "description": "The list of participants for the event." + }} + }}, + "required": [ + "title", + "datetime", + "participants" + ] + }} + }} + }}, + {{ + "type": "function", + "function": {{ + "name": "get_calendar", + "description": "Get the user's calendar.", + "parameters": {{ + "type": "object", + "properties": {{}}, + "required": [] + }} + }} + }} +] Response: -{"tool_calls": [ - { +```json +{{"tool_calls": [ + {{ "function": "create_event", - "parameters": { + "parameters": {{ "title": "Meeting with John", "datetime": "tomorrow at 10 AM", "participants": ["John"] - } - }, - { + }} + }}, + {{ "function": "get_calendar", "parameters": {{}} - } -]} + }} +]}} +``` --- Example 4: --- -User Query: Hi there! +User Query: +Hi there! 
+ Available Functions: -1. create_event(title, datetime, participants) -2. get_calendar() +[ + {{ + "type": "function", + "function": {{ + "name": "create_event", + "description": "Create an event with the specified title, datetime, and participants.", + "parameters": {{ + "type": "object", + "properties": {{ + "title": {{ + "type": "string", + "description": "The title of the event." + }}, + "datetime": {{ + "type": "string", + "description": "The date and time of the event." + }}, + "participants": {{ + "type": "array", + "items": {{ + "type": "string" + }}, + "description": "The list of participants for the event." + }} + }}, + "required": [ + "title", + "datetime", + "participants" + ] + }} + }} + }}, + {{ + "type": "function", + "function": {{ + "name": "get_calendar", + "description": "Get the user's calendar.", + "parameters": {{ + "type": "object", + "properties": {{}}, + "required": [] + }} + }} + }} +] Response: -{ - "message": Hey! How can I help you today? -} +```json +{{ + "message": "Hey! How can I help you today?" +}} +``` +--- + +Example 5: +--- +User Query: +There is no user query. The last assistant message contained tool calls. Here are the tool results: +[{{"tool_use_id": "tool_use_a0ce4b0f4ff8476f99c77f43", "output":'"Your flight to London on 2024-03-14 has been booked.", "is_error": false}}] + +Available Functions: +[ + {{ + "type": "function", + "function": {{ + "name": "create_event", + "description": "Create an event with the specified title, datetime, and participants.", + "parameters": {{ + "type": "object", + "properties": {{ + "title": {{ + "type": "string", + "description": "The title of the event." + }}, + "datetime": {{ + "type": "string", + "description": "The date and time of the event." + }}, + "participants": {{ + "type": "array", + "items": {{ + "type": "string" + }}, + "description": "The list of participants for the event." + }} + }}, + "required": [ + "title", + "datetime", + "participants" + ] + }} + }} + }}, + {{ + "type": "function", + "function": {{ + "name": "get_calendar", + "description": "Get the user's calendar.", + "parameters": {{ + "type": "object", + "properties": {{}}, + "required": [] + }} + }} + }} +] + +Response: +```json +{{ + "message": "As requested, we have booked your flight to London on 2024-03-14. Please let me know if you need anything else." +}} +``` + +## Task Now, given the following user query and available functions, respond with the appropriate function call in JSON format. -User Query: {user_query} +User Query: +{user_query} + Available Functions: {available_functions} Response: -""" +""".strip() def is_o1(model: str) -> bool: return model.startswith("o1") -def update_system_message(system: str, tools: Tuple[Tool]) -> str: - if not tools: - return system +def update_system_message(system: str, tools: Tuple[Tool], user_query: str) -> str: + tool_descriptions = json.dumps(tools_to_openai_spec(tools), indent=2) - tool_names_str = "" - for i, tool in enumerate(tools, start=1): - tool_names_str += f"{i}. 
{tool.name}({', '.join(tool.parameters.keys())})\n" + return USER_PROMPT_TEMPLATE.format(system=system, available_functions=tool_descriptions, user_query=user_query) - return USER_PROMPT_TEMPLATE.format(system=system, available_functions=tool_names_str) +def extract_code_blocks(text: str) -> list: + # Regular expression to match code blocks + code_blocks = re.findall(r"```(?:\w+\n)?(.*?)```", text, re.DOTALL) + return code_blocks -def merge_consecutive_roles(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - Merges consecutive messages with the same role into a single message. - Args: - messages (List[Dict[str, Any]]): The list of messages to merge. +def parse_o1_assistant_reply(reply_text: str) -> Message: + error_msg = "ERROR: This is not a valid response. Please make sure the response is in JSON format and contains either a 'message' key or 'tool_calls' key." - Returns: - List[Dict[str, Any]]: The list of messages with consecutive messages of the same role merged. - """ - merged_messages = [] - current_role = None - current_content = "" + try: + code_blocks = extract_code_blocks(reply_text) + if code_blocks: + # If code blocks are present, treat the first block as JSON response + response_data = json.loads(code_blocks[0]) + else: + response_data = json.loads(reply_text) + except json.JSONDecodeError: + # If parsing fails, treat the reply as regular text + return Message( + role="assistant", + content=[Text(text=error_msg)], + ) - for msg in messages: - role = msg.get("role") - content = msg.get("content", "") + content = [] - if role == current_role: - current_content += "\n" + content - else: - if current_role: - merged_messages.append({"role": current_role, "content": current_content.strip()}) - current_role = role - current_content = content + if "tool_calls" in response_data: + tool_calls = response_data["tool_calls"] + for tool_call in tool_calls: + tool_use = ToolUse( + id=create_object_id("tool_use"), + name=tool_call["function"], + parameters=tool_call["parameters"], + ) + content.append(tool_use) + elif "message" in response_data: + message_text = response_data["message"] + content.append(Text(text=message_text)) + else: + # Unrecognized format, treat as regular text + content.append(Text(text=error_msg)) - if current_role: - merged_messages.append({"role": current_role, "content": current_content.strip()}) + return Message(role="assistant", content=content) - return merged_messages +def merge_consecutive_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + if not messages: + return messages + + merged_messages = [messages[0]] -def convert_messages(original_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - Converts original messages with 'system', 'user', 'assistant', and 'tool' roles - into messages containing only 'user' and 'assistant' roles. - - Args: - original_messages (List[Dict[str, Any]]): The original list of messages. - - Returns: - List[Dict[str, Any]]: The converted list of messages with only 'user' and 'assistant' roles. 
- """ - converted_messages = [] - tool_call_map = {} # Maps tool_call_id to function details - - for msg in original_messages: - role = msg.get("role") - content = msg.get("content", "") - - if role == "system": - # Convert 'system' messages to 'user' messages, optionally prefixing for clarity - content = f"[System]: {content}" - converted_messages.append({"role": "user", "content": content}) - - elif role == "user": - converted_messages.append({"role": "user", "content": content}) - - elif role == "assistant": - # Check for 'tool_calls' in the 'assistant' message - tool_calls = msg.get("tool_calls", []) - if tool_calls: - for tool_call in tool_calls: - tool_id = tool_call.get("id") - function = tool_call.get("function", {}) - function_name = function.get("name") - arguments = function.get("arguments") - - # Store the tool call details - tool_call_map[tool_id] = {"name": function_name, "arguments": arguments} - - # Optionally, you can indicate that the assistant initiated tool calls - # For this implementation, we'll not modify the assistant's content - # But this can be customized based on specific needs - - # Append the 'assistant' message - assistant_content = msg.get("content", "") - converted_messages.append({"role": "assistant", "content": assistant_content}) - - elif role == "tool": - # 'tool' messages are outputs from tool calls; convert them to 'assistant' messages - # Find the corresponding tool call based on 'tool_call_id' - tool_call_id = msg.get("tool_call_id") - tool_output = msg.get("content", "") - - if tool_call_id and tool_call_id in tool_call_map: - function_details = tool_call_map[tool_call_id] - function_name = function_details["name"] - - # You can format the assistant's response to include tool output contextually - # For simplicity, we'll append the tool output directly - assistant_output = f"[Tool Output - {function_name}]: {tool_output}" - converted_messages.append({"role": "assistant", "content": assistant_output}) - else: - # If 'tool_call_id' is missing or not found, append the tool output as-is - assistant_output = f"[Tool Output]: {tool_output}" - converted_messages.append({"role": "assistant", "content": assistant_output}) - - merged_converted_messages = merge_consecutive_roles(messages=converted_messages) - return merged_converted_messages + for current_message in messages[1:]: + last_message = merged_messages[-1] + if current_message["role"] == last_message["role"]: + # Merge contents + last_message["content"] += "\n" + current_message["content"] + else: + merged_messages.append(current_message) + + return merged_messages class OpenAiProvider(Provider): @@ -287,13 +499,76 @@ def complete( **kwargs: Dict[str, Any], ) -> Tuple[Message, Usage]: if is_o1(model): - system_with_tools = update_system_message(system, tools) - converted = convert_messages( - [{"role": "system", "content": system}] + messages_to_openai_spec(messages), + # Prepare the messages for o1 models + # Find the last user message + last_user_message_index = None + for i in range(len(messages) - 1, -1, -1): + if messages[i].role == "user": + last_user_message_index = i + break + + if last_user_message_index is None: + raise ValueError("No user message found in messages") + + last_user_message = messages[last_user_message_index] + + tool_results = last_user_message.tool_result + user_query = last_user_message.text + if tool_results: + tool_result_str = "[\n" + "\n".join([json.dumps(tr.to_dict()) for tr in tool_results]) + "\n]" + user_query = f"There is no user query. 
The last assistant message contained tool calls. Here are the tool results:\n{tool_result_str}" + elif not user_query: + user_query = "There is no user query. You can respond to the user with 'message'." + + # Update the system message (incorporated into the user message) + combined_user_message_text = update_system_message(system, tools, user_query) + + # Prepare the messages to send + messages_to_send = [] + + # Process previous messages before last user message + for message in messages[:last_user_message_index]: + if message.role == "assistant": + messages_to_send.append( + { + "role": "assistant", + "content": message.text, + } + ) + elif message.role == "user": + # Include user messages (e.g., tool results) + if message.tool_result: + tool_result_texts = [tr.output for tr in message.tool_result] + tool_result_combined = "\n".join(tool_result_texts) + messages_to_send.append( + { + "role": "user", + "content": tool_result_combined, + } + ) + else: + messages_to_send.append( + { + "role": "user", + "content": message.text, + } + ) + else: + pass # Ignore other roles + + # Add the combined user message + messages_to_send.append( + { + "role": "user", + "content": combined_user_message_text, + } ) - messages = [Message(role=m["role"], content=[Text(text=m["content"])]) for m in converted] + + # Merge consecutive messages with the same role + messages_to_send = merge_consecutive_messages(messages_to_send) + payload = dict( - messages=messages_to_openai_spec(messages), + messages=messages_to_send, model=model, **kwargs, ) @@ -311,7 +586,12 @@ def complete( if "error" in response and len(messages) == 1: openai_single_message_context_length_exceeded(response["error"]) - message = openai_response_to_message(response) + if is_o1(model): + assistant_reply = response["choices"][0]["message"]["content"] + message = parse_o1_assistant_reply(assistant_reply) + else: + message = openai_response_to_message(response) + usage = self.get_usage(response) return message, usage @@ -322,34 +602,38 @@ def _post(self, payload: dict) -> dict: if __name__ == "__main__": - original_messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hi there!"}, - {"role": "user", "content": "can you help me?"}, - {"role": "assistant", "content": "Sure! What do you need assistance with?", "tool_calls": []}, - {"role": "user", "content": "I need to book a flight to Paris on December 25th."}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "tcall_1", - "type": "function", - "function": {"name": "book_flight", "arguments": '{"destination": "Paris", "date": "2023-12-25"}'}, - } - ], - }, - { - "role": "tool", - "content": "Your flight to Paris on December 25th has been booked.", - "tool_call_id": "tcall_1", - }, - {"role": "assistant", "content": "I've booked your flight to Paris on December 25th."}, - ] - - converted = convert_messages(original_messages) - - messages = [Message(role=m["role"], content=[Text(text=m["content"])]) for m in converted] - print("Messages:") - for msg in messages: - print(msg) + from exchange import Exchange, Text + from exchange.moderators.passive import PassiveModerator + import pprint + + def book_flight(destination: str, date: str): + """Book a flight to the specified destination on the given date. + + Args: + destination (str): The airport code for destination of the flight. E.g., "LAX" for Los Angeles. + date (str): The date of the flight in "YYYY-MM-DD" format. E.g., "2023-12-25". 
+ """ + return f"Your flight to {destination} on {date} has been booked." + + system = "You are a helpful assistant" + tools = [Tool.from_function(book_flight)] + provider = OpenAiProvider.from_env() + + ex = Exchange( + provider=provider, + model="o1-mini", # Use the 'o1' model + system="You are a helpful assistant.", + tools=tools, + moderator=PassiveModerator(), + ) + + ex.add(Message(role="user", content=[Text(text="Hi there!"), Text(text="can you help me?")])) + pprint.pp(ex) + print("-" * 80) + + pprint.pp(ex.reply()) + print("-" * 80) + + ex.add(Message.user("I need to book a flight to Paris on December 25th.")) + pprint.pp(ex.reply()) + print("-" * 80) diff --git a/tests/providers/test_openai.py b/tests/providers/test_openai.py index 0e0e000..35c5b5a 100644 --- a/tests/providers/test_openai.py +++ b/tests/providers/test_openai.py @@ -1,9 +1,15 @@ import os -from unittest.mock import patch +from unittest.mock import ANY, patch import pytest from exchange import Message, Text from exchange.providers.openai import OpenAiProvider +from exchange.tool import Tool + + +def dummy_tool() -> str: + """An example tool""" + return "dummy response" @pytest.fixture @@ -45,6 +51,45 @@ def test_openai_completion(mock_error, mock_warning, mock_sleep, mock_post, open ) +@patch("httpx.Client.post") +def test_openai_completion_o1_model(mock_post, openai_provider): + # Mock response from 'o1' model + mock_reply_content = '{"tool_calls": [{"function": "dummy_tool", "parameters": {}}]}' + mock_response = { + "choices": [{"message": {"role": "assistant", "content": mock_reply_content}}], + "usage": {"prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35}, + } + + mock_post.return_value.json.return_value = mock_response + + model = "o1-mini" + system = "You are a helpful assistant." 
+ messages = [Message.user("Hello")] + tools = [Tool.from_function(dummy_tool)] + + # Call the complete method + reply_message, reply_usage = openai_provider.complete(model=model, system=system, messages=messages, tools=tools) + + # Check that the assistant's reply was parsed correctly + assert len(reply_message.tool_use) == 1 + assert reply_message.tool_use[0].name == "dummy_tool" + + # Check that the request payload was constructed correctly + # For 'o1' models, the system prompt and user query are combined + expected_user_content = ANY # We can use ANY because the exact content is constructed dynamically + expected_messages = [ + { + "role": "user", + "content": expected_user_content, + } + ] + mock_post.assert_called_once() + args, kwargs = mock_post.call_args + assert kwargs["json"]["model"] == model + assert kwargs["json"]["messages"] == expected_messages + assert "tools" not in kwargs["json"] # 'o1' models should not have 'tools' in the payload + + @pytest.mark.integration def test_openai_integration(): provider = OpenAiProvider.from_env() diff --git a/tests/test_exchange.py b/tests/test_exchange.py index f01ef46..a06baf6 100644 --- a/tests/test_exchange.py +++ b/tests/test_exchange.py @@ -1,6 +1,7 @@ from typing import List, Tuple import pytest +import json from exchange.checkpoint import Checkpoint, CheckpointData from exchange.content import Text, ToolResult, ToolUse @@ -95,6 +96,60 @@ def test_reply_with_unsupported_tool(): assert isinstance(content, ToolResult) and content.is_error and "no tool exists" in content.output.lower() +def test_generate_with_o1_model(): + from exchange.providers.openai import parse_o1_assistant_reply + + # Mock provider to simulate OpenAI 'o1' model behavior + class MockO1Provider(Provider): + def complete(self, model: str, system: str, messages: List[Message], tools: Tuple[Tool]): + # Simulate assistant's reply with a tool call in JSON format + reply_content = json.dumps({"tool_calls": [{"function": "dummy_tool", "parameters": {}}]}) + message = parse_o1_assistant_reply(reply_content) + usage = Usage(input_tokens=10, output_tokens=20, total_tokens=30) + return message, usage + + ex = Exchange( + provider=MockO1Provider(), + model="o1-mini", + system="You are a helpful assistant.", + tools=[Tool.from_function(dummy_tool)], + ) + ex.add(Message.user("Call dummy tool twice, then msg user.")) + message = ex.generate() + + assert message.tool_use[0].name == "dummy_tool" + + +def test_reply_with_o1_model(): + from exchange.providers.openai import parse_o1_assistant_reply + + # Mock provider to simulate OpenAI 'o1' model behavior + class MockO1Provider(Provider): + tool_call_count = 0 + + def complete(self, model: str, system: str, messages: List[Message], tools: Tuple[Tool]): + # Simulate assistant's reply with a tool call in JSON format + if self.tool_call_count < 2: + reply_content = json.dumps({"tool_calls": [{"function": "dummy_tool", "parameters": {}}]}) + self.tool_call_count += 1 + else: + reply_content = json.dumps({"message": "Hi user!"}) + message = parse_o1_assistant_reply(reply_content) + usage = Usage(input_tokens=10, output_tokens=20, total_tokens=30) + return message, usage + + ex = Exchange( + provider=MockO1Provider(), + model="o1-mini", + system="You are a helpful assistant.", + tools=[Tool.from_function(dummy_tool)], + ) + ex.add(Message.user("Call dummy tool twice, then msg user.")) + reply = ex.reply() + + assert reply.text == "Hi user!" 
+
+
 def test_invalid_tool_parameters():
     """Test handling of invalid tool parameters response"""
     ex = Exchange(

From 4318cee6a60b7a5b7284d86166bae47b6f61ac8d Mon Sep 17 00:00:00 2001
From: Salman Mohammed
Date: Wed, 25 Sep 2024 14:41:49 -0400
Subject: [PATCH 3/3] remove print statement

---
 src/exchange/exchange.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/exchange/exchange.py b/src/exchange/exchange.py
index d7932fa..301208c 100644
--- a/src/exchange/exchange.py
+++ b/src/exchange/exchange.py
@@ -107,7 +107,6 @@ def reply(self, max_tool_use: int = 128) -> Message:
         for tool_use in response.tool_use:
             tool_result = self.call_function(tool_use)
             content.append(tool_result)
-            print(f"Tool call result - user content: {content}")
         self.add(Message(role="user", content=content))
 
         # We've reached the limit of tool calls - break out of the loop
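
Taken together, these patches replace native OpenAI tool calling with a prompt-engineered JSON protocol for o1 models: the system prompt and the tool schemas are folded into a single user message, and the assistant's JSON reply is parsed back into `ToolUse` content. Below is a minimal sketch of that round trip, assuming the patched module is importable as `exchange.providers.openai`; every name comes from the patches above, and the model reply is hand-written rather than fetched from the API.

```python
from exchange.providers.openai import parse_o1_assistant_reply, update_system_message
from exchange.tool import Tool


def book_flight(destination: str, date: str) -> str:
    """Book a flight to the specified destination on the given date.

    Args:
        destination (str): The airport code for the destination of the flight.
        date (str): The date of the flight in "YYYY-MM-DD" format.
    """
    return f"Your flight to {destination} on {date} has been booked."


tools = (Tool.from_function(book_flight),)

# o1 models accept neither a "system" role nor a "tools" parameter, so both
# are folded into the single user message that complete() ends up sending.
prompt = update_system_message(
    system="You are a helpful assistant.",
    tools=tools,
    user_query="Book me a flight to Paris on 2024-12-25.",
)

# A well-behaved reply under the template's guidelines: JSON with a
# "tool_calls" key, which parses back into ToolUse content.
reply = '{"tool_calls": [{"function": "book_flight", "parameters": {"destination": "CDG", "date": "2024-12-25"}}]}'
message = parse_o1_assistant_reply(reply)
print(message.tool_use[0].name)        # book_flight
print(message.tool_use[0].parameters)  # {'destination': 'CDG', 'date': '2024-12-25'}
```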
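Because the o1 path rewrites system prompts and tool results into plain `user`/`assistant` text, the flattened history can contain back-to-back messages with the same role; `merge_consecutive_messages` collapses those before the payload is sent. A quick sketch of its behavior — it is a pure function defined in patch 2, so no API access is needed:

```python
from exchange.providers.openai import merge_consecutive_messages

history = [
    {"role": "user", "content": "Hi there!"},
    {"role": "user", "content": "can you help me?"},
    {"role": "assistant", "content": "Sure! What do you need assistance with?"},
]

# Consecutive same-role entries are joined with a newline, leaving a strictly
# alternating history, which the single-user-message protocol above assumes.
print(merge_consecutive_messages(history))
# [{'role': 'user', 'content': 'Hi there!\ncan you help me?'},
#  {'role': 'assistant', 'content': 'Sure! What do you need assistance with?'}]
```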
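Finally, `parse_o1_assistant_reply` is deliberately forgiving about fencing: if the model wraps its JSON in a code fence, `extract_code_blocks` strips the fence first, and anything unparseable degrades to the canned error text rather than raising. A short sketch, again assuming the patched module is importable:

```python
from exchange.providers.openai import parse_o1_assistant_reply

fence = "`" * 3  # built dynamically only to avoid clashing with this block's own fence
fenced_reply = f'{fence}json\n{{"message": "Hey! How can I help you today?"}}\n{fence}'
print(parse_o1_assistant_reply(fenced_reply).text)
# Hey! How can I help you today?

# A reply that is neither fenced JSON nor bare JSON falls back to the canned
# error message instead of raising.
print(parse_o1_assistant_reply("not json at all").text)
# ERROR: This is not a valid response. ...
```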