Skip to content

Commit

Permalink
Config for max rounds and concurrency of tool executions
Browse files Browse the repository at this point in the history
  • Loading branch information
raulraja committed Jun 12, 2024
1 parent 6543b89 commit b70dbc5
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ private suspend fun <A> Chat.promptWithFunctions(
invokeSerializer: suspend (FunctionCall) -> A = { serializer.invoke(it) }
): A {
usageTracker.llmCalls++
validateMaxToolCallsPerRound(prompt, usageTracker)
val result =
promptWithResponse(prompt, scope, tools.map { it.function }) { response ->
val responseUsage = response.usage
Expand All @@ -107,7 +108,7 @@ private suspend fun <A> Chat.promptWithFunctions(
result
} else {
val callRequestedMessages = listOf(assistantRequestedCallMessage(calls))
val resultMessages = callResultMessages(calls, tools, collector)
val resultMessages = callResultMessages(prompt, calls, tools, collector)
repeat(resultMessages.size) { usageTracker.toolInvocations++ }
val promptWithToolOutputs =
prompt.copy(messages = prompt.messages + callRequestedMessages + resultMessages)
Expand All @@ -126,12 +127,21 @@ private suspend fun <A> Chat.promptWithFunctions(
return result
}

private fun validateMaxToolCallsPerRound(prompt: Prompt, usageTracker: UsageTracker) {
if (usageTracker.toolInvocations >= prompt.configuration.maxToolCallsPerRound) {
error(
"Too many tool calls in this round: ${usageTracker.toolInvocations}, max allowed: ${prompt.configuration.maxToolCallsPerRound}"
)
}
}

private suspend fun callResultMessages(
prompt: Prompt,
calls: List<FunctionCall>,
functions: List<Tool<*>>,
collector: ProducerScope<AIEvent<*>>?
): List<ChatCompletionRequestMessage> =
calls.parMapNotNull { call ->
calls.parMapNotNull(concurrency = prompt.configuration.concurrentToolCallsPerRound) { call ->
val tool = functions.firstOrNull { it.function.name == call.functionName }
tool?.let { collector?.send(AIEvent.ToolExecutionRequest(it, call.arguments)) }
val invokeTool = tool?.invoke
Expand Down Expand Up @@ -251,7 +261,11 @@ private fun chatCompletionToolChoiceOption(adaptedPrompt: Prompt): ChatCompletio
function = ChatCompletionNamedToolChoiceFunction(adaptedPrompt.functions.first().name)
)
)
else ChatCompletionToolChoiceOption.CaseString("required")
else {
if (adaptedPrompt.model is CreateChatCompletionRequestModel.Custom)
ChatCompletionToolChoiceOption.CaseString("auto")
else ChatCompletionToolChoiceOption.CaseString("required")
}

private fun chatCompletionTools(adaptedPrompt: Prompt): List<ChatCompletionTool> =
adaptedPrompt.functions.map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ data class PromptConfiguration
@JvmOverloads
constructor(
var maxDeserializationAttempts: Int = 3,
var maxToolCallsPerRound: Int = 10000,
var concurrentToolCallsPerRound: Int = 5,
var user: String = ChatCompletionRole.user.value,
var temperature: Double = 0.4,
var numberOfPredictions: Int = 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ import kotlinx.serialization.Serializable
val ballCupLocation = 47

suspend fun ballLocationInfoFromLastCupTried(input: Int): String {
val tip = if (input < ballCupLocation) "higher" else "lower"
val recommendedCup =
if (input < ballCupLocation) (input + 1)..ballCupLocation else ballCupLocation until input
return "The ball is not under cup number $input. Try a cup with a $tip number. We recommend trying cup ${recommendedCup.random()}, ${recommendedCup.random()}, ${recommendedCup.random()} next"
return when {
input < ballCupLocation -> "${(input..ballCupLocation).random()} may have the ball."
input > ballCupLocation -> "${(ballCupLocation..input).random()} may have the ball."
else -> "The ball is under cup number $input."
}
}

fun lookUnderCupNumber(cupNumber: Int): String =
Expand All @@ -35,7 +36,9 @@ suspend fun main() {
listOf(
Tool(
::ballLocationInfoFromLastCupTried,
Description("Get a tip on where the ball is based on the last cup number tried.")
Description(
"Get a tip on where the ball is based on the last cup number tried. Request this tool in parallel with different cup numbers to get multiple tips."
)
),
Tool(::lookUnderCupNumber, Description("Look under a cup to find the ball."))
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.xebia.functional.xef.dsl.chat

import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
import com.xebia.functional.xef.AI
import com.xebia.functional.xef.AIConfig
import com.xebia.functional.xef.AIEvent
Expand All @@ -26,6 +27,7 @@ suspend fun main() {
prompt = "Where is the ball? use the available tools to find out.",
config =
AIConfig(
model = CreateChatCompletionRequestModel.gpt_3_5_turbo,
tools =
listOf(
Tool.suspend(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ sealed class Response {

fun informationOnPlanet(planet: Planet): String =
when (planet.name) {
"Mars" -> "It's a secret but Mars has 100 moons."
"Mars" -> "It's a secret but we just discovered that Mars has 100 moons."
else -> "I don't have information on the number of moons for ${planet.name}."
}

Expand All @@ -34,7 +34,7 @@ suspend fun main() {

val other =
AI<Flow<AIEvent<Response>>>(
prompt = "How many moons does Mars have?",
prompt = "How many moons does Mars have based on the recent discovery?",
config = AIConfig(tools = listOf(Tool(::informationOnPlanet)))
)
other.collect { it.debugPrint() }
Expand Down

0 comments on commit b70dbc5

Please sign in to comment.