diff --git a/internal/ai/providers/anthropic.go b/internal/ai/providers/anthropic.go
index 13a24f03d..9082f0c85 100644
--- a/internal/ai/providers/anthropic.go
+++ b/internal/ai/providers/anthropic.go
@@ -62,12 +62,20 @@ func (c *AnthropicClient) Name() string {
 
 // anthropicRequest is the request body for the Anthropic API
 type anthropicRequest struct {
-	Model       string             `json:"model"`
-	Messages    []anthropicMessage `json:"messages"`
-	MaxTokens   int                `json:"max_tokens"`
-	System      string             `json:"system,omitempty"`
-	Temperature float64            `json:"temperature,omitempty"`
-	Tools       []anthropicTool    `json:"tools,omitempty"`
+	Model       string               `json:"model"`
+	Messages    []anthropicMessage   `json:"messages"`
+	MaxTokens   int                  `json:"max_tokens"`
+	System      string               `json:"system,omitempty"`
+	Temperature float64              `json:"temperature,omitempty"`
+	Tools       []anthropicTool      `json:"tools,omitempty"`
+	ToolChoice  *anthropicToolChoice `json:"tool_choice,omitempty"`
+}
+
+// anthropicToolChoice controls how Claude selects tools
+// See: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/implement-tool-use#forcing-tool-use
+type anthropicToolChoice struct {
+	Type string `json:"type"`           // "auto", "any", "tool", or "none"
+	Name string `json:"name,omitempty"` // Only used when Type is "tool"
 }
 
 type anthropicMessage struct {
@@ -230,6 +238,16 @@ func (c *AnthropicClient) Chat(ctx context.Context, req ChatRequest) (*ChatRespo
 		}
 	}
 
+	// Add tool_choice if specified
+	// This controls whether Claude is required to use a tool or merely allowed to
+	// See: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/implement-tool-use#forcing-tool-use
+	if req.ToolChoice != nil {
+		anthropicReq.ToolChoice = &anthropicToolChoice{
+			Type: string(req.ToolChoice.Type),
+			Name: req.ToolChoice.Name,
+		}
+	}
+
 	body, err := json.Marshal(anthropicReq)
 	if err != nil {
 		return nil, fmt.Errorf("failed to marshal request: %w", err)
@@ -441,13 +459,14 @@ func (c *AnthropicClient) SupportsThinking(model string) bool {
 
 // anthropicStreamRequest is the request body for streaming API calls
 type anthropicStreamRequest struct {
-	Model       string             `json:"model"`
-	Messages    []anthropicMessage `json:"messages"`
-	MaxTokens   int                `json:"max_tokens"`
-	System      string             `json:"system,omitempty"`
-	Temperature float64            `json:"temperature,omitempty"`
-	Tools       []anthropicTool    `json:"tools,omitempty"`
-	Stream      bool               `json:"stream"`
+	Model       string               `json:"model"`
+	Messages    []anthropicMessage   `json:"messages"`
+	MaxTokens   int                  `json:"max_tokens"`
+	System      string               `json:"system,omitempty"`
+	Temperature float64              `json:"temperature,omitempty"`
+	Tools       []anthropicTool      `json:"tools,omitempty"`
+	ToolChoice  *anthropicToolChoice `json:"tool_choice,omitempty"`
+	Stream      bool                 `json:"stream"`
 }
 
 // anthropicStreamEvent represents a streaming event from the Anthropic API
@@ -565,6 +584,14 @@ func (c *AnthropicClient) ChatStream(ctx context.Context, req ChatRequest, callb
 		}
 	}
 
+	// Add tool_choice if specified (same as non-streaming)
+	if req.ToolChoice != nil {
+		anthropicReq.ToolChoice = &anthropicToolChoice{
+			Type: string(req.ToolChoice.Type),
+			Name: req.ToolChoice.Name,
+		}
+	}
+
 	body, err := json.Marshal(anthropicReq)
 	if err != nil {
 		return fmt.Errorf("failed to marshal request: %w", err)
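For reviewers, the wire format this mapping produces. A minimal standalone sketch that duplicates the anthropicToolChoice struct from this diff; the get_weather tool name is purely illustrative:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Copy of the anthropicToolChoice struct added in this diff.
type anthropicToolChoice struct {
	Type string `json:"type"`
	Name string `json:"name,omitempty"`
}

func main() {
	// Forcing a specific tool serializes both fields:
	forced, _ := json.Marshal(anthropicToolChoice{Type: "tool", Name: "get_weather"})
	fmt.Println(string(forced)) // {"type":"tool","name":"get_weather"}

	// "any" omits the name via omitempty:
	anyChoice, _ := json.Marshal(anthropicToolChoice{Type: "any"})
	fmt.Println(string(anyChoice)) // {"type":"any"}
}
```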
diff --git a/internal/ai/providers/gemini.go b/internal/ai/providers/gemini.go
index 8c6f294b9..c9bdfadd8 100644
--- a/internal/ai/providers/gemini.go
+++ b/internal/ai/providers/gemini.go
@@ -60,6 +60,17 @@ type geminiRequest struct {
 	SystemInstruction *geminiContent          `json:"systemInstruction,omitempty"`
 	GenerationConfig  *geminiGenerationConfig `json:"generationConfig,omitempty"`
 	Tools             []geminiToolDef         `json:"tools,omitempty"`
+	ToolConfig        *geminiToolConfig       `json:"toolConfig,omitempty"`
+}
+
+// geminiToolConfig controls how the model uses tools
+// See: https://ai.google.dev/api/caching#ToolConfig
+type geminiToolConfig struct {
+	FunctionCallingConfig *geminiFunctionCallingConfig `json:"functionCallingConfig,omitempty"`
+}
+
+type geminiFunctionCallingConfig struct {
+	Mode string `json:"mode"` // AUTO, ANY, or NONE
 }
 
 type geminiContent struct {
@@ -140,6 +151,28 @@ type geminiError struct {
 	} `json:"error"`
 }
 
+// convertToolChoiceToGemini converts our ToolChoice to Gemini's mode string
+// Gemini uses: AUTO (default), ANY (force tool use), NONE (no tools)
+// See: https://ai.google.dev/api/caching#FunctionCallingConfig
+func convertToolChoiceToGemini(tc *ToolChoice) string {
+	if tc == nil {
+		return "AUTO"
+	}
+	switch tc.Type {
+	case ToolChoiceAuto:
+		return "AUTO"
+	case ToolChoiceNone:
+		return "NONE"
+	case ToolChoiceAny:
+		return "ANY"
+	case ToolChoiceTool:
+		// Gemini doesn't support forcing a specific tool; fall back to ANY
+		return "ANY"
+	default:
+		return "AUTO"
+	}
+}
+
 // Chat sends a chat request to the Gemini API
 func (c *GeminiClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse, error) {
 	// Convert messages to Gemini format
@@ -244,8 +277,13 @@ func (c *GeminiClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
 		geminiReq.GenerationConfig.Temperature = req.Temperature
 	}
 
-	// Add tools if provided
-	if len(req.Tools) > 0 {
+	// Add tools if provided (unless ToolChoice is None)
+	shouldAddTools := len(req.Tools) > 0
+	if req.ToolChoice != nil && req.ToolChoice.Type == ToolChoiceNone {
+		shouldAddTools = false
+	}
+
+	if shouldAddTools {
 		funcDecls := make([]geminiFunctionDeclaration, 0, len(req.Tools))
 		for _, t := range req.Tools {
 			// Skip non-function tools
@@ -260,6 +298,15 @@ func (c *GeminiClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
 		}
 		if len(funcDecls) > 0 {
 			geminiReq.Tools = []geminiToolDef{{FunctionDeclarations: funcDecls}}
+
+			// Add tool_config based on ToolChoice
+			// Gemini uses: AUTO (default), ANY (force tool use), NONE (no tools)
+			geminiReq.ToolConfig = &geminiToolConfig{
+				FunctionCallingConfig: &geminiFunctionCallingConfig{
+					Mode: convertToolChoiceToGemini(req.ToolChoice),
+				},
+			}
+
 			log.Debug().Int("tool_count", len(funcDecls)).Strs("tool_names", func() []string {
 				names := make([]string, len(funcDecls))
 				for i, f := range funcDecls {
@@ -615,7 +662,13 @@ func (c *GeminiClient) ChatStream(ctx context.Context, req ChatRequest, callback
 		geminiReq.GenerationConfig.Temperature = req.Temperature
 	}
 
-	if len(req.Tools) > 0 {
+	// Add tools if provided (unless ToolChoice is None) - same as non-streaming
+	shouldAddTools := len(req.Tools) > 0
+	if req.ToolChoice != nil && req.ToolChoice.Type == ToolChoiceNone {
+		shouldAddTools = false
+	}
+
+	if shouldAddTools {
 		funcDecls := make([]geminiFunctionDeclaration, 0, len(req.Tools))
 		for _, t := range req.Tools {
 			if t.Type != "" && t.Type != "function" {
@@ -629,6 +682,23 @@ func (c *GeminiClient) ChatStream(ctx context.Context, req ChatRequest, callback
 		}
 		if len(funcDecls) > 0 {
 			geminiReq.Tools = []geminiToolDef{{FunctionDeclarations: funcDecls}}
+
+			// Add tool_config based on ToolChoice (same as non-streaming)
+			geminiReq.ToolConfig = &geminiToolConfig{
+				FunctionCallingConfig: &geminiFunctionCallingConfig{
+					Mode: convertToolChoiceToGemini(req.ToolChoice),
+				},
+			}
+
+			// Log tool names for debugging tool selection issues
+			toolNames := make([]string, len(funcDecls))
+			for i, f := range funcDecls {
+				toolNames[i] = f.Name
+			}
+			log.Debug().
+				Int("tool_count", len(funcDecls)).
+				Strs("tool_names", toolNames).
+				Msg("Gemini stream request includes tools")
 		}
 	}
 
@@ -637,6 +707,12 @@ func (c *GeminiClient) ChatStream(ctx context.Context, req ChatRequest, callback
 		return fmt.Errorf("failed to marshal request: %w", err)
 	}
 
+	// Log the full request body for debugging (at trace level to avoid noise)
+	log.Trace().
+		Str("model", model).
+		RawJSON("request_body", body).
+		Msg("Gemini stream request body")
+
 	// Use streamGenerateContent endpoint for streaming
 	url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s&alt=sse", c.baseURL, model, c.apiKey)
 
@@ -725,6 +801,10 @@ func (c *GeminiClient) ChatStream(ctx context.Context, req ChatRequest, callback
 			if len(signature) == 0 {
 				signature = part.ThoughtSignatureSnake
 			}
+			log.Debug().
+				Str("tool_name", part.FunctionCall.Name).
+				Interface("tool_args", part.FunctionCall.Args).
+				Msg("Gemini called tool")
 			callback(StreamEvent{
 				Type: "tool_start",
 				Data: ToolStartEvent{
@@ -824,35 +904,6 @@ func (c *GeminiClient) ListModels(ctx context.Context) ([]ModelInfo, error) {
 		// Extract model ID from the full name (e.g., "models/gemini-1.5-pro" -> "gemini-1.5-pro")
 		modelID := strings.TrimPrefix(m.Name, "models/")
 
-		// Only include the useful Gemini models for chat/agentic tasks
-		// Filter out Gemma (open-source, no function calling), embedding, AQA, vision-only models
-		// Keep: gemini-3-*, gemini-2.5-*, gemini-2.0-*, gemini-1.5-* (pro and flash variants)
-		isUsefulModel := false
-		usefulPrefixes := []string{
-			"gemini-3-pro", "gemini-3-flash",
-			"gemini-2.5-pro", "gemini-2.5-flash",
-			"gemini-2.0-pro", "gemini-2.0-flash",
-			"gemini-1.5-pro", "gemini-1.5-flash",
-			"gemini-flash", "gemini-pro", // Latest aliases
-		}
-		for _, prefix := range usefulPrefixes {
-			if strings.HasPrefix(modelID, prefix) {
-				isUsefulModel = true
-				break
-			}
-		}
-		if !isUsefulModel {
-			continue
-		}
-
-		// Skip experimental/deprecated variants
-		if strings.Contains(modelID, "exp-") ||
-			strings.Contains(modelID, "-exp") ||
-			strings.Contains(modelID, "tuning") ||
-			strings.Contains(modelID, "8b") { // Skip smaller variants
-			continue
-		}
-
 		models = append(models, ModelInfo{
 			ID:   modelID,
 			Name: m.DisplayName,
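Since ToolChoiceTool silently degrades to ANY on Gemini, the mapping is worth pinning down in a test. A hedged sketch, not part of this diff (file name hypothetical):

```go
// gemini_toolchoice_test.go (hypothetical) - sanity-checks the mode mapping
// implemented by convertToolChoiceToGemini above.
package providers

import "testing"

func TestConvertToolChoiceToGemini(t *testing.T) {
	cases := []struct {
		tc   *ToolChoice
		want string
	}{
		{nil, "AUTO"},
		{&ToolChoice{Type: ToolChoiceAuto}, "AUTO"},
		{&ToolChoice{Type: ToolChoiceNone}, "NONE"},
		{&ToolChoice{Type: ToolChoiceAny}, "ANY"},
		// Gemini cannot force a single named tool, so "tool" degrades to ANY.
		{&ToolChoice{Type: ToolChoiceTool, Name: "search"}, "ANY"},
	}
	for _, c := range cases {
		if got := convertToolChoiceToGemini(c.tc); got != c.want {
			t.Errorf("convertToolChoiceToGemini(%+v) = %q, want %q", c.tc, got, c.want)
		}
	}
}
```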
diff --git a/internal/ai/providers/ollama.go b/internal/ai/providers/ollama.go
index 762f9d22a..6fb500dac 100644
--- a/internal/ai/providers/ollama.go
+++ b/internal/ai/providers/ollama.go
@@ -155,9 +155,8 @@ func (c *OllamaClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
 	if model == "" {
 		model = c.model
 	}
-	// Ultimate fallback - if no model configured anywhere, use llama3
 	if model == "" {
-		model = "llama3"
+		return nil, fmt.Errorf("no model specified")
 	}
 
 	ollamaReq := ollamaRequest{
@@ -167,7 +166,14 @@ func (c *OllamaClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
 	}
 
 	// Convert tools to Ollama format
-	if len(req.Tools) > 0 {
+	// Note: Ollama doesn't support tool_choice like Anthropic/OpenAI
+	// We handle ToolChoiceNone by not adding tools, but can't force tool use
+	shouldAddTools := len(req.Tools) > 0
+	if req.ToolChoice != nil && req.ToolChoice.Type == ToolChoiceNone {
+		shouldAddTools = false
+	}
+
+	if shouldAddTools {
 		ollamaReq.Tools = make([]ollamaTool, 0, len(req.Tools))
 		for _, t := range req.Tools {
 			// Skip non-function tools (like web_search which Ollama doesn't support)
@@ -318,7 +324,7 @@ func (c *OllamaClient) ChatStream(ctx context.Context, req ChatRequest, callback
 		model = c.model
 	}
 	if model == "" {
-		model = "llama3"
+		return fmt.Errorf("no model specified")
 	}
 
 	ollamaReq := ollamaRequest{
@@ -327,7 +333,13 @@ func (c *OllamaClient) ChatStream(ctx context.Context, req ChatRequest, callback
 		Stream: true, // Enable streaming
 	}
 
-	if len(req.Tools) > 0 {
+	// Handle tools with tool_choice support (same as non-streaming)
+	shouldAddTools := len(req.Tools) > 0
+	if req.ToolChoice != nil && req.ToolChoice.Type == ToolChoiceNone {
+		shouldAddTools = false
+	}
+
+	if shouldAddTools {
 		ollamaReq.Tools = make([]ollamaTool, 0, len(req.Tools))
 		for _, t := range req.Tools {
 			if t.Type != "" && t.Type != "function" {
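To summarize the Ollama behavior: only ToolChoiceNone is honored; "any" and "tool" still send the tools but cannot force a call. A small helper mirroring the guard above, for illustration only (not in the diff):

```go
// requestSendsTools mirrors the guard used in both Chat and ChatStream above:
// tools are attached unless the caller explicitly opted out with ToolChoiceNone.
// ToolChoiceAny/ToolChoiceTool still send the tools but cannot force the model
// to call one - Ollama decides on its own.
func requestSendsTools(req ChatRequest) bool {
	if len(req.Tools) == 0 {
		return false
	}
	if req.ToolChoice != nil && req.ToolChoice.Type == ToolChoiceNone {
		return false
	}
	return true
}
```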
diff --git a/internal/ai/providers/openai.go b/internal/ai/providers/openai.go
index 10a32c9e5..dde0822bd 100644
--- a/internal/ai/providers/openai.go
+++ b/internal/ai/providers/openai.go
@@ -87,23 +87,6 @@ type openaiRequest struct {
 	ToolChoice interface{} `json:"tool_choice,omitempty"` // "auto", "none", or specific tool
 }
 
-// deepseekRequest extends openaiRequest with DeepSeek-specific fields
-type deepseekRequest struct {
-	Model      string          `json:"model"`
-	Messages   []openaiMessage `json:"messages"`
-	MaxTokens  int             `json:"max_tokens,omitempty"`
-	Tools      []openaiTool    `json:"tools,omitempty"`
-	ToolChoice interface{}     `json:"tool_choice,omitempty"`
-}
-
-// openaiCompletionsRequest is for non-chat models like gpt-5.2-pro that use /v1/completions
-type openaiCompletionsRequest struct {
-	Model               string  `json:"model"`
-	Prompt              string  `json:"prompt"`
-	MaxCompletionTokens int     `json:"max_completion_tokens,omitempty"`
-	Temperature         float64 `json:"temperature,omitempty"`
-}
-
 // openaiTool represents a function tool in OpenAI format
 type openaiTool struct {
 	Type string `json:"type"` // always "function"
@@ -147,8 +130,7 @@ type openaiResponse struct {
 
 type openaiChoice struct {
 	Index        int           `json:"index"`
-	Message      openaiRespMsg `json:"message"` // For chat completions
-	Text         string        `json:"text"`    // For completions API (non-chat models)
+	Message      openaiRespMsg `json:"message"`
 	FinishReason string        `json:"finish_reason"` // "stop", "tool_calls", etc.
 }
 
@@ -186,19 +168,37 @@ func (c *OpenAIClient) isDeepSeekReasoner() bool {
 }
 
 // requiresMaxCompletionTokens returns true for models that need max_completion_tokens instead of max_tokens
+// Per OpenAI docs, o1/o3/o4 reasoning models require max_completion_tokens; max_tokens will error.
 func (c *OpenAIClient) requiresMaxCompletionTokens(model string) bool {
-	// o1, o1-mini, o1-preview, o3, o3-mini, o4-mini, gpt-5.2, etc.
-	return strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") || strings.HasPrefix(model, "o4") || strings.HasPrefix(model, "gpt-5")
+	return strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") || strings.HasPrefix(model, "o4")
 }
 
-// isGPT52NonChat returns true if using GPT-5.2 models that require /v1/completions endpoint
-// Only gpt-5.2-chat-latest uses chat completions; gpt-5.2, gpt-5.2-pro use completions
-func (c *OpenAIClient) isGPT52NonChat(model string) bool {
-	if !strings.HasPrefix(model, "gpt-5.2") {
-		return false
+// convertToolChoiceToOpenAI converts our ToolChoice to OpenAI's format
+// OpenAI uses "required" instead of Anthropic's "any" to force tool use
+// See: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
+func convertToolChoiceToOpenAI(tc *ToolChoice) interface{} {
+	if tc == nil {
+		return "auto"
+	}
+	switch tc.Type {
+	case ToolChoiceAuto:
+		return "auto"
+	case ToolChoiceNone:
+		return "none"
+	case ToolChoiceAny:
+		// OpenAI uses "required" to force the model to use one of the provided tools
+		return "required"
+	case ToolChoiceTool:
+		// Force a specific tool
+		return map[string]interface{}{
+			"type": "function",
+			"function": map[string]string{
+				"name": tc.Name,
+			},
+		}
+	default:
+		return "auto"
 	}
-	// gpt-5.2-chat-latest is the only chat model
-	return !strings.Contains(model, "chat")
 }
 
 // Chat sends a chat request to the OpenAI API
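For reference, the three shapes convertToolChoiceToOpenAI can produce once serialized. A sketch inside package providers; the lookup_order name is illustrative:

```go
package providers

import (
	"encoding/json"
	"fmt"
)

// toolChoiceShapes prints each serialized form of the mapping.
func toolChoiceShapes() {
	for _, tc := range []*ToolChoice{
		nil,                   // -> "auto"
		{Type: ToolChoiceAny}, // -> "required" (OpenAI's spelling of Anthropic's "any")
		{Type: ToolChoiceTool, Name: "lookup_order"},
	} {
		b, _ := json.Marshal(convertToolChoiceToOpenAI(tc))
		fmt.Println(string(b))
	}
	// The forced-tool case marshals as an object (map keys sort alphabetically):
	// {"function":{"name":"lookup_order"},"type":"function"}
}
```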
@@ -309,42 +309,16 @@ func (c *OpenAIClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
 			})
 		}
 		if len(openaiReq.Tools) > 0 {
-			openaiReq.ToolChoice = "auto"
+			// Map ToolChoice to OpenAI format
+			// OpenAI uses "required" instead of Anthropic's "any"
+			openaiReq.ToolChoice = convertToolChoiceToOpenAI(req.ToolChoice)
 		}
 	}
 
 	// Log actual model being sent (INFO level for visibility)
 	log.Info().Str("model_in_request", openaiReq.Model).Str("base_url", c.baseURL).Msg("Sending OpenAI/DeepSeek request")
 
-	var body []byte
-	var err error
-
-	// GPT-5.2 non-chat models need completions format (prompt instead of messages)
-	if c.isGPT52NonChat(model) {
-		// Convert messages to a single prompt string
-		var promptBuilder strings.Builder
-		if req.System != "" {
-			promptBuilder.WriteString("System: ")
-			promptBuilder.WriteString(req.System)
-			promptBuilder.WriteString("\n\n")
-		}
-		for _, m := range req.Messages {
-			promptBuilder.WriteString(m.Role)
-			promptBuilder.WriteString(": ")
-			promptBuilder.WriteString(m.Content)
-			promptBuilder.WriteString("\n\n")
-		}
-		promptBuilder.WriteString("Assistant: ")
-
-		completionsReq := openaiCompletionsRequest{
-			Model:               model,
-			Prompt:              promptBuilder.String(),
-			MaxCompletionTokens: req.MaxTokens,
-		}
-		body, err = json.Marshal(completionsReq)
-	} else {
-		body, err = json.Marshal(openaiReq)
-	}
+	body, err := json.Marshal(openaiReq)
 	if err != nil {
 		return nil, fmt.Errorf("failed to marshal request: %w", err)
 	}
@@ -370,14 +344,7 @@ func (c *OpenAIClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
 		}
 	}
 
-	// Use the appropriate endpoint
-	endpoint := c.baseURL
-	if c.isGPT52NonChat(model) && strings.Contains(c.baseURL, "api.openai.com") {
-		// GPT-5.2 non-chat models need completions endpoint
-		endpoint = strings.Replace(c.baseURL, "/chat/completions", "/completions", 1)
-	}
-
-	httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(body))
+	httpReq, err := http.NewRequestWithContext(ctx, "POST", c.baseURL, bytes.NewReader(body))
 	if err != nil {
 		return nil, fmt.Errorf("failed to create request: %w", err)
 	}
%w", err) } @@ -449,10 +416,6 @@ func (c *OpenAIClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse // For DeepSeek reasoner, the actual content may be in reasoning_content // when content is empty (it shows the "thinking" but that's the full response) contentToUse := choice.Message.Content - // Completions API uses Text instead of Message.Content - if contentToUse == "" && choice.Text != "" { - contentToUse = choice.Text - } if contentToUse == "" && choice.Message.ReasoningContent != "" { // DeepSeek reasoner puts output in reasoning_content contentToUse = choice.Message.ReasoningContent @@ -679,7 +642,8 @@ func (c *OpenAIClient) ChatStream(ctx context.Context, req ChatRequest, callback }) } if len(openaiReq.Tools) > 0 { - openaiReq.ToolChoice = "auto" + // Map ToolChoice to OpenAI format (same as non-streaming) + openaiReq.ToolChoice = convertToolChoiceToOpenAI(req.ToolChoice) } } diff --git a/internal/ai/providers/openai_test.go b/internal/ai/providers/openai_test.go index 918f21bab..0f7581cbd 100644 --- a/internal/ai/providers/openai_test.go +++ b/internal/ai/providers/openai_test.go @@ -269,50 +269,10 @@ func TestOpenAIClient_Chat_Success(t *testing.T) { assert.Equal(t, 3, resp.OutputTokens) } -func TestOpenAIClient_Chat_GPT52NonChat(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/v1/chat/completions", r.URL.Path) - - var req openaiCompletionsRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&req)) - assert.Equal(t, "gpt-5.2-pro", req.Model) - assert.Contains(t, req.Prompt, "System: sys") - assert.Contains(t, req.Prompt, "user: hi") - assert.Equal(t, 55, req.MaxCompletionTokens) - - _ = json.NewEncoder(w).Encode(openaiResponse{ - ID: "cmpl-1", - Model: "gpt-5.2-pro", - Choices: []openaiChoice{ - { - Text: "Answer", - FinishReason: "stop", - }, - }, - Usage: openaiUsage{PromptTokens: 3, CompletionTokens: 4}, - }) - })) - defer server.Close() - - client := NewOpenAIClient("sk-test", "gpt-5.2-pro", server.URL, 0) - resp, err := client.Chat(context.Background(), ChatRequest{ - System: "sys", - MaxTokens: 55, - Messages: []Message{ - {Role: "user", Content: "hi"}, - }, - }) - require.NoError(t, err) - assert.Equal(t, "Answer", resp.Content) -} - func TestOpenAIClient_HelperFlags(t *testing.T) { client := NewOpenAIClient("sk", "gpt-4", "https://api.openai.com", 0) assert.True(t, client.requiresMaxCompletionTokens("o1-mini")) assert.False(t, client.requiresMaxCompletionTokens("gpt-4")) - - assert.True(t, client.isGPT52NonChat("gpt-5.2-pro")) - assert.False(t, client.isGPT52NonChat("gpt-5.2-chat-latest")) } func TestOpenAIClient_SupportsThinking(t *testing.T) { diff --git a/internal/ai/providers/provider.go b/internal/ai/providers/provider.go index b508a91f8..78ed766c1 100644 --- a/internal/ai/providers/provider.go +++ b/internal/ai/providers/provider.go @@ -39,14 +39,35 @@ type Tool struct { MaxUses int `json:"max_uses,omitempty"` // For web search: limit searches per request } +// ToolChoiceType represents how the model should choose tools +type ToolChoiceType string + +const ( + // ToolChoiceAuto lets the model decide whether to use tools (default) + ToolChoiceAuto ToolChoiceType = "auto" + // ToolChoiceAny forces the model to use one of the provided tools + ToolChoiceAny ToolChoiceType = "any" + // ToolChoiceNone prevents the model from using any tools + ToolChoiceNone ToolChoiceType = "none" + // ToolChoiceTool forces the model to use a specific tool (set ToolName) + ToolChoiceTool 
ToolChoiceType = "tool" +) + +// ToolChoice controls how the model selects tools +type ToolChoice struct { + Type ToolChoiceType `json:"type"` + Name string `json:"name,omitempty"` // Only used when Type is ToolChoiceTool +} + // ChatRequest represents a request to the AI provider type ChatRequest struct { - Messages []Message `json:"messages"` - Model string `json:"model"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - System string `json:"system,omitempty"` // System prompt (Anthropic style) - Tools []Tool `json:"tools,omitempty"` // Available tools + Messages []Message `json:"messages"` + Model string `json:"model"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + System string `json:"system,omitempty"` // System prompt (Anthropic style) + Tools []Tool `json:"tools,omitempty"` // Available tools + ToolChoice *ToolChoice `json:"tool_choice,omitempty"` // How to select tools (nil = auto) } // ChatResponse represents a response from the AI provider