Update AI providers for tool call improvements

Provider updates across all supported backends:
- Anthropic: wire the new tool_choice field through Chat and ChatStream so tool use can be forced
- OpenAI: map ToolChoice to OpenAI's tool_choice values and drop the GPT-5.2 completions-endpoint special casing
- Gemini: add toolConfig/functionCallingConfig mapping, honor ToolChoiceNone, and simplify model listing
- Ollama: honor ToolChoiceNone and return an error instead of silently falling back to llama3

Includes test updates for OpenAI provider.
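Example usage of the new ToolChoice option (illustrative sketch only; the client value, context, model name, prompt, and registeredTools are placeholders, but the field and constant names match the shared types diff below):

    req := ChatRequest{
        Model:      "claude-sonnet-4",
        MaxTokens:  1024,
        Messages:   []Message{{Role: "user", Content: "List the VMs on node pve1"}},
        Tools:      registeredTools, // hypothetical []Tool with the function definitions
        ToolChoice: &ToolChoice{Type: ToolChoiceAny}, // force the model to call one of the tools
    }
    resp, err := client.Chat(ctx, req)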
rcourtman
2026-01-28 16:51:18 +00:00
parent 7be3ab2c1a
commit 641d29a16b
6 changed files with 202 additions and 167 deletions

View File

@@ -62,12 +62,20 @@ func (c *AnthropicClient) Name() string {
// anthropicRequest is the request body for the Anthropic API
type anthropicRequest struct {
Model string `json:"model"`
Messages []anthropicMessage `json:"messages"`
MaxTokens int `json:"max_tokens"`
System string `json:"system,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Tools []anthropicTool `json:"tools,omitempty"`
Model string `json:"model"`
Messages []anthropicMessage `json:"messages"`
MaxTokens int `json:"max_tokens"`
System string `json:"system,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Tools []anthropicTool `json:"tools,omitempty"`
ToolChoice *anthropicToolChoice `json:"tool_choice,omitempty"`
}
// anthropicToolChoice controls how Claude selects tools
// See: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/implement-tool-use#forcing-tool-use
type anthropicToolChoice struct {
Type string `json:"type"` // "auto", "any", "tool", or "none"
Name string `json:"name,omitempty"` // Only used when Type is "tool"
}
type anthropicMessage struct {
@@ -230,6 +238,16 @@ func (c *AnthropicClient) Chat(ctx context.Context, req ChatRequest) (*ChatRespo
}
}
// Add tool_choice if specified
// This controls whether Claude MUST use tools vs just being able to
// See: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/implement-tool-use#forcing-tool-use
if req.ToolChoice != nil {
anthropicReq.ToolChoice = &anthropicToolChoice{
Type: string(req.ToolChoice.Type),
Name: req.ToolChoice.Name,
}
}
body, err := json.Marshal(anthropicReq)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
@@ -441,13 +459,14 @@ func (c *AnthropicClient) SupportsThinking(model string) bool {
// anthropicStreamRequest is the request body for streaming API calls
type anthropicStreamRequest struct {
Model string `json:"model"`
Messages []anthropicMessage `json:"messages"`
MaxTokens int `json:"max_tokens"`
System string `json:"system,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Tools []anthropicTool `json:"tools,omitempty"`
Stream bool `json:"stream"`
Model string `json:"model"`
Messages []anthropicMessage `json:"messages"`
MaxTokens int `json:"max_tokens"`
System string `json:"system,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Tools []anthropicTool `json:"tools,omitempty"`
ToolChoice *anthropicToolChoice `json:"tool_choice,omitempty"`
Stream bool `json:"stream"`
}
// anthropicStreamEvent represents a streaming event from the Anthropic API
@@ -565,6 +584,14 @@ func (c *AnthropicClient) ChatStream(ctx context.Context, req ChatRequest, callb
}
}
// Add tool_choice if specified (same as non-streaming)
if req.ToolChoice != nil {
anthropicReq.ToolChoice = &anthropicToolChoice{
Type: string(req.ToolChoice.Type),
Name: req.ToolChoice.Name,
}
}
body, err := json.Marshal(anthropicReq)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)

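For reference, the JSON the new field serializes to when forcing a specific tool (sketch; the tool name is a placeholder):

    tc := anthropicToolChoice{Type: "tool", Name: "get_node_status"}
    b, _ := json.Marshal(tc)
    // string(b) == `{"type":"tool","name":"get_node_status"}`
    // With Type "auto", "any", or "none", Name is left empty and dropped by omitempty.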
View File

@@ -60,6 +60,17 @@ type geminiRequest struct {
SystemInstruction *geminiContent `json:"systemInstruction,omitempty"`
GenerationConfig *geminiGenerationConfig `json:"generationConfig,omitempty"`
Tools []geminiToolDef `json:"tools,omitempty"`
ToolConfig *geminiToolConfig `json:"toolConfig,omitempty"`
}
// geminiToolConfig controls how the model uses tools
// See: https://ai.google.dev/api/caching#ToolConfig
type geminiToolConfig struct {
FunctionCallingConfig *geminiFunctionCallingConfig `json:"functionCallingConfig,omitempty"`
}
type geminiFunctionCallingConfig struct {
Mode string `json:"mode"` // AUTO, ANY, or NONE
}
type geminiContent struct {
@@ -140,6 +151,28 @@ type geminiError struct {
} `json:"error"`
}
// convertToolChoiceToGemini converts our ToolChoice to Gemini's mode string
// Gemini uses: AUTO (default), ANY (force tool use), NONE (no tools)
// See: https://ai.google.dev/api/caching#FunctionCallingConfig
func convertToolChoiceToGemini(tc *ToolChoice) string {
if tc == nil {
return "AUTO"
}
switch tc.Type {
case ToolChoiceAuto:
return "AUTO"
case ToolChoiceNone:
return "NONE"
case ToolChoiceAny:
return "ANY"
case ToolChoiceTool:
// Gemini doesn't support forcing a specific tool, fall back to ANY
return "ANY"
default:
return "AUTO"
}
}
// Chat sends a chat request to the Gemini API
func (c *GeminiClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse, error) {
// Convert messages to Gemini format
@@ -244,8 +277,13 @@ func (c *GeminiClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
geminiReq.GenerationConfig.Temperature = req.Temperature
}
// Add tools if provided
if len(req.Tools) > 0 {
// Add tools if provided (unless ToolChoice is None)
shouldAddTools := len(req.Tools) > 0
if req.ToolChoice != nil && req.ToolChoice.Type == ToolChoiceNone {
shouldAddTools = false
}
if shouldAddTools {
funcDecls := make([]geminiFunctionDeclaration, 0, len(req.Tools))
for _, t := range req.Tools {
// Skip non-function tools
@@ -260,6 +298,15 @@ func (c *GeminiClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
}
if len(funcDecls) > 0 {
geminiReq.Tools = []geminiToolDef{{FunctionDeclarations: funcDecls}}
// Add tool_config based on ToolChoice
// Gemini uses: AUTO (default), ANY (force tool use), NONE (no tools)
geminiReq.ToolConfig = &geminiToolConfig{
FunctionCallingConfig: &geminiFunctionCallingConfig{
Mode: convertToolChoiceToGemini(req.ToolChoice),
},
}
log.Debug().Int("tool_count", len(funcDecls)).Strs("tool_names", func() []string {
names := make([]string, len(funcDecls))
for i, f := range funcDecls {
@@ -615,7 +662,13 @@ func (c *GeminiClient) ChatStream(ctx context.Context, req ChatRequest, callback
geminiReq.GenerationConfig.Temperature = req.Temperature
}
if len(req.Tools) > 0 {
// Add tools if provided (unless ToolChoice is None) - same as non-streaming
shouldAddTools := len(req.Tools) > 0
if req.ToolChoice != nil && req.ToolChoice.Type == ToolChoiceNone {
shouldAddTools = false
}
if shouldAddTools {
funcDecls := make([]geminiFunctionDeclaration, 0, len(req.Tools))
for _, t := range req.Tools {
if t.Type != "" && t.Type != "function" {
@@ -629,6 +682,23 @@ func (c *GeminiClient) ChatStream(ctx context.Context, req ChatRequest, callback
}
if len(funcDecls) > 0 {
geminiReq.Tools = []geminiToolDef{{FunctionDeclarations: funcDecls}}
// Add tool_config based on ToolChoice (same as non-streaming)
geminiReq.ToolConfig = &geminiToolConfig{
FunctionCallingConfig: &geminiFunctionCallingConfig{
Mode: convertToolChoiceToGemini(req.ToolChoice),
},
}
// Log tool names for debugging tool selection issues
toolNames := make([]string, len(funcDecls))
for i, f := range funcDecls {
toolNames[i] = f.Name
}
log.Debug().
Int("tool_count", len(funcDecls)).
Strs("tool_names", toolNames).
Msg("Gemini stream request includes tools")
}
}
@@ -637,6 +707,12 @@ func (c *GeminiClient) ChatStream(ctx context.Context, req ChatRequest, callback
return fmt.Errorf("failed to marshal request: %w", err)
}
// Log the full request body for debugging (at trace level to avoid noise)
log.Trace().
Str("model", model).
RawJSON("request_body", body).
Msg("Gemini stream request body")
// Use streamGenerateContent endpoint for streaming
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s&alt=sse", c.baseURL, model, c.apiKey)
@@ -725,6 +801,10 @@ func (c *GeminiClient) ChatStream(ctx context.Context, req ChatRequest, callback
if len(signature) == 0 {
signature = part.ThoughtSignatureSnake
}
log.Debug().
Str("tool_name", part.FunctionCall.Name).
Interface("tool_args", part.FunctionCall.Args).
Msg("Gemini called tool")
callback(StreamEvent{
Type: "tool_start",
Data: ToolStartEvent{
@@ -824,35 +904,6 @@ func (c *GeminiClient) ListModels(ctx context.Context) ([]ModelInfo, error) {
// Extract model ID from the full name (e.g., "models/gemini-1.5-pro" -> "gemini-1.5-pro")
modelID := strings.TrimPrefix(m.Name, "models/")
// Only include the useful Gemini models for chat/agentic tasks
// Filter out Gemma (open-source, no function calling), embedding, AQA, vision-only models
// Keep: gemini-3-*, gemini-2.5-*, gemini-2.0-*, gemini-1.5-* (pro and flash variants)
isUsefulModel := false
usefulPrefixes := []string{
"gemini-3-pro", "gemini-3-flash",
"gemini-2.5-pro", "gemini-2.5-flash",
"gemini-2.0-pro", "gemini-2.0-flash",
"gemini-1.5-pro", "gemini-1.5-flash",
"gemini-flash", "gemini-pro", // Latest aliases
}
for _, prefix := range usefulPrefixes {
if strings.HasPrefix(modelID, prefix) {
isUsefulModel = true
break
}
}
if !isUsefulModel {
continue
}
// Skip experimental/deprecated variants
if strings.Contains(modelID, "exp-") ||
strings.Contains(modelID, "-exp") ||
strings.Contains(modelID, "tuning") ||
strings.Contains(modelID, "8b") { // Skip smaller variants
continue
}
models = append(models, ModelInfo{
ID: modelID,
Name: m.DisplayName,

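A sketch of the toolConfig block the new mapping produces when tool use is forced (shapes follow the struct tags above):

    cfg := geminiToolConfig{
        FunctionCallingConfig: &geminiFunctionCallingConfig{
            Mode: convertToolChoiceToGemini(&ToolChoice{Type: ToolChoiceAny}), // "ANY"
        },
    }
    b, _ := json.Marshal(cfg)
    // string(b) == `{"functionCallingConfig":{"mode":"ANY"}}`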
View File

@@ -155,9 +155,8 @@ func (c *OllamaClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
if model == "" {
model = c.model
}
// Ultimate fallback - if no model configured anywhere, use llama3
if model == "" {
model = "llama3"
return nil, fmt.Errorf("no model specified")
}
ollamaReq := ollamaRequest{
@@ -167,7 +166,14 @@ func (c *OllamaClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
}
// Convert tools to Ollama format
if len(req.Tools) > 0 {
// Note: Ollama doesn't support tool_choice like Anthropic/OpenAI
// We handle ToolChoiceNone by not adding tools, but can't force tool use
shouldAddTools := len(req.Tools) > 0
if req.ToolChoice != nil && req.ToolChoice.Type == ToolChoiceNone {
shouldAddTools = false
}
if shouldAddTools {
ollamaReq.Tools = make([]ollamaTool, 0, len(req.Tools))
for _, t := range req.Tools {
// Skip non-function tools (like web_search which Ollama doesn't support)
@@ -318,7 +324,7 @@ func (c *OllamaClient) ChatStream(ctx context.Context, req ChatRequest, callback
model = c.model
}
if model == "" {
model = "llama3"
return fmt.Errorf("no model specified")
}
ollamaReq := ollamaRequest{
@@ -327,7 +333,13 @@ func (c *OllamaClient) ChatStream(ctx context.Context, req ChatRequest, callback
Stream: true, // Enable streaming
}
if len(req.Tools) > 0 {
// Handle tools with tool_choice support (same as non-streaming)
shouldAddTools := len(req.Tools) > 0
if req.ToolChoice != nil && req.ToolChoice.Type == ToolChoiceNone {
shouldAddTools = false
}
if shouldAddTools {
ollamaReq.Tools = make([]ollamaTool, 0, len(req.Tools))
for _, t := range req.Tools {
if t.Type != "" && t.Type != "function" {

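With the llama3 fallback removed, a missing model is now an explicit error rather than a silent default (sketch, assuming no model is set on either the client or the request):

    _, err := client.Chat(ctx, ChatRequest{
        Messages: []Message{{Role: "user", Content: "hello"}},
    })
    // err != nil: "no model specified"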
View File

@@ -87,23 +87,6 @@ type openaiRequest struct {
ToolChoice interface{} `json:"tool_choice,omitempty"` // "auto", "none", or specific tool
}
// deepseekRequest extends openaiRequest with DeepSeek-specific fields
type deepseekRequest struct {
Model string `json:"model"`
Messages []openaiMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
Tools []openaiTool `json:"tools,omitempty"`
ToolChoice interface{} `json:"tool_choice,omitempty"`
}
// openaiCompletionsRequest is for non-chat models like gpt-5.2-pro that use /v1/completions
type openaiCompletionsRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
}
// openaiTool represents a function tool in OpenAI format
type openaiTool struct {
Type string `json:"type"` // always "function"
@@ -147,8 +130,7 @@ type openaiResponse struct {
type openaiChoice struct {
Index int `json:"index"`
Message openaiRespMsg `json:"message"` // For chat completions
Text string `json:"text"` // For completions API (non-chat models)
Message openaiRespMsg `json:"message"`
FinishReason string `json:"finish_reason"` // "stop", "tool_calls", etc.
}
@@ -186,19 +168,37 @@ func (c *OpenAIClient) isDeepSeekReasoner() bool {
}
// requiresMaxCompletionTokens returns true for models that need max_completion_tokens instead of max_tokens
// Per OpenAI docs, o1/o3/o4 reasoning models require max_completion_tokens; max_tokens will error.
func (c *OpenAIClient) requiresMaxCompletionTokens(model string) bool {
// o1, o1-mini, o1-preview, o3, o3-mini, o4-mini, gpt-5.2, etc.
return strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") || strings.HasPrefix(model, "o4") || strings.HasPrefix(model, "gpt-5")
return strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") || strings.HasPrefix(model, "o4")
}
// isGPT52NonChat returns true if using GPT-5.2 models that require /v1/completions endpoint
// Only gpt-5.2-chat-latest uses chat completions; gpt-5.2, gpt-5.2-pro use completions
func (c *OpenAIClient) isGPT52NonChat(model string) bool {
if !strings.HasPrefix(model, "gpt-5.2") {
return false
// convertToolChoiceToOpenAI converts our ToolChoice to OpenAI's format
// OpenAI uses "required" instead of Anthropic's "any" to force tool use
// See: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
func convertToolChoiceToOpenAI(tc *ToolChoice) interface{} {
if tc == nil {
return "auto"
}
switch tc.Type {
case ToolChoiceAuto:
return "auto"
case ToolChoiceNone:
return "none"
case ToolChoiceAny:
// OpenAI uses "required" to force the model to use one of the provided tools
return "required"
case ToolChoiceTool:
// Force a specific tool
return map[string]interface{}{
"type": "function",
"function": map[string]string{
"name": tc.Name,
},
}
default:
return "auto"
}
// gpt-5.2-chat-latest is the only chat model
return !strings.Contains(model, "chat")
}
// Chat sends a chat request to the OpenAI API
@@ -309,42 +309,16 @@ func (c *OpenAIClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
})
}
if len(openaiReq.Tools) > 0 {
openaiReq.ToolChoice = "auto"
// Map ToolChoice to OpenAI format
// OpenAI uses "required" instead of Anthropic's "any"
openaiReq.ToolChoice = convertToolChoiceToOpenAI(req.ToolChoice)
}
}
// Log actual model being sent (INFO level for visibility)
log.Info().Str("model_in_request", openaiReq.Model).Str("base_url", c.baseURL).Msg("Sending OpenAI/DeepSeek request")
var body []byte
var err error
// GPT-5.2 non-chat models need completions format (prompt instead of messages)
if c.isGPT52NonChat(model) {
// Convert messages to a single prompt string
var promptBuilder strings.Builder
if req.System != "" {
promptBuilder.WriteString("System: ")
promptBuilder.WriteString(req.System)
promptBuilder.WriteString("\n\n")
}
for _, m := range req.Messages {
promptBuilder.WriteString(m.Role)
promptBuilder.WriteString(": ")
promptBuilder.WriteString(m.Content)
promptBuilder.WriteString("\n\n")
}
promptBuilder.WriteString("Assistant: ")
completionsReq := openaiCompletionsRequest{
Model: model,
Prompt: promptBuilder.String(),
MaxCompletionTokens: req.MaxTokens,
}
body, err = json.Marshal(completionsReq)
} else {
body, err = json.Marshal(openaiReq)
}
body, err := json.Marshal(openaiReq)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
@@ -370,14 +344,7 @@ func (c *OpenAIClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
}
}
// Use the appropriate endpoint
endpoint := c.baseURL
if c.isGPT52NonChat(model) && strings.Contains(c.baseURL, "api.openai.com") {
// GPT-5.2 non-chat models need completions endpoint
endpoint = strings.Replace(c.baseURL, "/chat/completions", "/completions", 1)
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(body))
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.baseURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
@@ -449,10 +416,6 @@ func (c *OpenAIClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
// For DeepSeek reasoner, the actual content may be in reasoning_content
// when content is empty (it shows the "thinking" but that's the full response)
contentToUse := choice.Message.Content
// Completions API uses Text instead of Message.Content
if contentToUse == "" && choice.Text != "" {
contentToUse = choice.Text
}
if contentToUse == "" && choice.Message.ReasoningContent != "" {
// DeepSeek reasoner puts output in reasoning_content
contentToUse = choice.Message.ReasoningContent
@@ -679,7 +642,8 @@ func (c *OpenAIClient) ChatStream(ctx context.Context, req ChatRequest, callback
})
}
if len(openaiReq.Tools) > 0 {
openaiReq.ToolChoice = "auto"
// Map ToolChoice to OpenAI format (same as non-streaming)
openaiReq.ToolChoice = convertToolChoiceToOpenAI(req.ToolChoice)
}
}

View File

@@ -269,50 +269,10 @@ func TestOpenAIClient_Chat_Success(t *testing.T) {
assert.Equal(t, 3, resp.OutputTokens)
}
func TestOpenAIClient_Chat_GPT52NonChat(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/v1/chat/completions", r.URL.Path)
var req openaiCompletionsRequest
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
assert.Equal(t, "gpt-5.2-pro", req.Model)
assert.Contains(t, req.Prompt, "System: sys")
assert.Contains(t, req.Prompt, "user: hi")
assert.Equal(t, 55, req.MaxCompletionTokens)
_ = json.NewEncoder(w).Encode(openaiResponse{
ID: "cmpl-1",
Model: "gpt-5.2-pro",
Choices: []openaiChoice{
{
Text: "Answer",
FinishReason: "stop",
},
},
Usage: openaiUsage{PromptTokens: 3, CompletionTokens: 4},
})
}))
defer server.Close()
client := NewOpenAIClient("sk-test", "gpt-5.2-pro", server.URL, 0)
resp, err := client.Chat(context.Background(), ChatRequest{
System: "sys",
MaxTokens: 55,
Messages: []Message{
{Role: "user", Content: "hi"},
},
})
require.NoError(t, err)
assert.Equal(t, "Answer", resp.Content)
}
func TestOpenAIClient_HelperFlags(t *testing.T) {
client := NewOpenAIClient("sk", "gpt-4", "https://api.openai.com", 0)
assert.True(t, client.requiresMaxCompletionTokens("o1-mini"))
assert.False(t, client.requiresMaxCompletionTokens("gpt-4"))
assert.True(t, client.isGPT52NonChat("gpt-5.2-pro"))
assert.False(t, client.isGPT52NonChat("gpt-5.2-chat-latest"))
}
func TestOpenAIClient_SupportsThinking(t *testing.T) {

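A possible follow-up test for the new OpenAI mapping helper (not part of this commit; sketched assuming it sits alongside the existing tests, with "list_vms" as a placeholder tool name):

    func TestConvertToolChoiceToOpenAI(t *testing.T) {
        assert.Equal(t, "auto", convertToolChoiceToOpenAI(nil))
        assert.Equal(t, "required", convertToolChoiceToOpenAI(&ToolChoice{Type: ToolChoiceAny}))
        assert.Equal(t, "none", convertToolChoiceToOpenAI(&ToolChoice{Type: ToolChoiceNone}))
        assert.Equal(t,
            map[string]interface{}{"type": "function", "function": map[string]string{"name": "list_vms"}},
            convertToolChoiceToOpenAI(&ToolChoice{Type: ToolChoiceTool, Name: "list_vms"}))
    }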
View File

@@ -39,14 +39,35 @@ type Tool struct {
MaxUses int `json:"max_uses,omitempty"` // For web search: limit searches per request
}
// ToolChoiceType represents how the model should choose tools
type ToolChoiceType string
const (
// ToolChoiceAuto lets the model decide whether to use tools (default)
ToolChoiceAuto ToolChoiceType = "auto"
// ToolChoiceAny forces the model to use one of the provided tools
ToolChoiceAny ToolChoiceType = "any"
// ToolChoiceNone prevents the model from using any tools
ToolChoiceNone ToolChoiceType = "none"
// ToolChoiceTool forces the model to use a specific tool (set ToolName)
ToolChoiceTool ToolChoiceType = "tool"
)
// ToolChoice controls how the model selects tools
type ToolChoice struct {
Type ToolChoiceType `json:"type"`
Name string `json:"name,omitempty"` // Only used when Type is ToolChoiceTool
}
// ChatRequest represents a request to the AI provider
type ChatRequest struct {
Messages []Message `json:"messages"`
Model string `json:"model"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
System string `json:"system,omitempty"` // System prompt (Anthropic style)
Tools []Tool `json:"tools,omitempty"` // Available tools
Messages []Message `json:"messages"`
Model string `json:"model"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
System string `json:"system,omitempty"` // System prompt (Anthropic style)
Tools []Tool `json:"tools,omitempty"` // Available tools
ToolChoice *ToolChoice `json:"tool_choice,omitempty"` // How to select tools (nil = auto)
}
// ChatResponse represents a response from the AI provider
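For quick reference, how each ToolChoiceType maps per provider in this commit (summary sketch; "list_vms" is a placeholder tool name):

    // ToolChoiceAuto -> Anthropic "auto", OpenAI "auto", Gemini mode "AUTO"
    // ToolChoiceAny  -> Anthropic "any",  OpenAI "required", Gemini mode "ANY"
    // ToolChoiceNone -> Anthropic "none", OpenAI "none"; Gemini and Ollama omit tools entirely
    // ToolChoiceTool -> Anthropic and OpenAI force the named tool; Gemini falls back to mode "ANY"
    // Ollama has no tool_choice equivalent, so values other than ToolChoiceNone are ignored.
    forced := &ToolChoice{Type: ToolChoiceTool, Name: "list_vms"}
    _ = forced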