diff --git a/src/index.ts b/src/index.ts
index 58dd2d0..e556a00 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -4,6 +4,7 @@ import { FunctionDef, formatFunctionDefinitions } from "./functions";
 
 type Message = OpenAI.Chat.CreateChatCompletionRequestMessage;
 type Function = OpenAI.Chat.CompletionCreateParams.Function;
+type FunctionCall = OpenAI.Chat.CompletionCreateParams.FunctionCallOption;
 
 let encoder: Tiktoken | undefined;
 
@@ -17,9 +18,11 @@ let encoder: Tiktoken | undefined;
 export function promptTokensEstimate({
   messages,
   functions,
+  function_call,
 }: {
   messages: Message[];
   functions?: Function[];
+  function_call?: 'none' | 'auto' | FunctionCall;
 }): number {
   // It appears that if functions are present, the first system message is padded with a trailing newline. This
   // was inferred by trying lots of combinations of messages and functions and seeing what the token counts were.
@@ -49,6 +52,13 @@ export function promptTokensEstimate({
     tokens -= 4;
   }
 
+  // If function_call is 'none', add one token.
+  // If it's a FunctionCall object, add 4 + the number of tokens in the function name.
+  // If it's undefined or 'auto', don't add anything.
+  if (function_call && function_call !== 'auto') {
+    tokens += function_call === 'none' ? 1 : stringTokens(function_call.name) + 4;
+  }
+
   return tokens;
 }
 
diff --git a/tests/token-counts.test.ts b/tests/token-counts.test.ts
index 2d9ec03..3bc51d0 100644
--- a/tests/token-counts.test.ts
+++ b/tests/token-counts.test.ts
@@ -3,9 +3,11 @@ import { promptTokensEstimate } from "../src";
 
 type Message = OpenAI.Chat.CreateChatCompletionRequestMessage;
 type Function = OpenAI.Chat.CompletionCreateParams.Function;
+type FunctionCall = OpenAI.Chat.CompletionCreateParams.FunctionCallOption;
 type Example = {
   messages: Message[];
   functions?: Function[];
+  function_call?: "none" | "auto" | FunctionCall;
   tokens: number;
   validate?: boolean;
 };
@@ -109,6 +111,39 @@ const TEST_CASES: Example[] = [
     ],
     tokens: 31,
   },
+  {
+    messages: [{ role: "user", content: "hello" }],
+    functions: [
+      {
+        name: "foo",
+        parameters: { type: "object", properties: {} },
+      },
+    ],
+    function_call: "none",
+    tokens: 32,
+  },
+  {
+    messages: [{ role: "user", content: "hello" }],
+    functions: [
+      {
+        name: "foo",
+        parameters: { type: "object", properties: {} },
+      },
+    ],
+    function_call: "auto",
+    tokens: 31,
+  },
+  {
+    messages: [{ role: "user", content: "hello" }],
+    functions: [
+      {
+        name: "foo",
+        parameters: { type: "object", properties: {} },
+      },
+    ],
+    function_call: { name: "foo" },
+    tokens: 36,
+  },
   {
     messages: [{ role: "user", content: "hello" }],
     functions: [
@@ -263,6 +298,31 @@ const TEST_CASES: Example[] = [
     ],
     tokens: 40,
   },
+  {
+    messages: [
+      { role: "system", content: "Hello:" },
+      { role: "system", content: "Hello" },
+      { role: "user", content: "Hi there" },
+    ],
+    functions: [
+      { name: "do_stuff", parameters: { type: "object", properties: {} } },
+      { name: "do_other_stuff", parameters: { type: "object", properties: {} } },
+    ],
+    tokens: 49,
+  },
+  {
+    messages: [
+      { role: "system", content: "Hello:" },
+      { role: "system", content: "Hello" },
+      { role: "user", content: "Hi there" },
+    ],
+    functions: [
+      { name: "do_stuff", parameters: { type: "object", properties: {} } },
+      { name: "do_other_stuff", parameters: { type: "object", properties: {} } },
+    ],
+    function_call: { name: "do_stuff" },
+    tokens: 55,
+  },
   {
     messages: [{ role: "user", content: "hello" }],
     functions: [
@@ -394,6 +454,7 @@ describe.each(TEST_CASES)("token counts (%j)", (example) => {
       model: "gpt-3.5-turbo",
       messages: example.messages,
       functions: example.functions as any,
+      function_call: example.function_call,
       max_tokens: 10,
     });
     expect(response.usage?.prompt_tokens).toBe(example.tokens);
@@ -406,6 +467,7 @@ describe.each(TEST_CASES)("token counts (%j)", (example) => {
       promptTokensEstimate({
         messages: example.messages,
        functions: example.functions,
+        function_call: example.function_call,
       }),
     ).toBe(example.tokens);
   });