Mirror of https://github.com/rcourtman/Pulse.git, synced 2026-02-18 00:17:39 +01:00
Update AI providers for tool call improvements
Provider updates across all supported backends:

- Anthropic: ToolChoice is passed through as tool_choice on both the standard and streaming requests
- OpenAI: ToolChoice is mapped to OpenAI's tool_choice ("required" to force tool use); the GPT-5.2 completions-endpoint path is removed
- Gemini: toolConfig with a functionCallingConfig mode (AUTO/ANY/NONE); the ListModels prefix filtering is removed
- Ollama: tools are omitted when ToolChoiceNone is set, and a missing model now returns an error instead of silently falling back to llama3

Includes test updates for the OpenAI provider.
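A minimal usage sketch of the new ToolChoice field on ChatRequest (the client value, model name, tools slice, and tool name below are illustrative, not part of this commit):

// Assumes a provider client (Anthropic, OpenAI, Gemini, or Ollama) and a []Tool built elsewhere.
req := ChatRequest{
	Model:     "claude-sonnet-4-5",
	System:    "You are a monitoring assistant.",
	MaxTokens: 1024,
	Messages:  []Message{{Role: "user", Content: "Is node pve1 healthy?"}},
	Tools:     tools,
	ToolChoice: &ToolChoice{
		Type: ToolChoiceTool,    // or ToolChoiceAuto, ToolChoiceAny, ToolChoiceNone
		Name: "get_node_status", // only consulted when Type is ToolChoiceTool
	},
}
resp, err := client.Chat(ctx, req)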
@@ -62,12 +62,20 @@ func (c *AnthropicClient) Name() string {
 
 // anthropicRequest is the request body for the Anthropic API
 type anthropicRequest struct {
-	Model       string             `json:"model"`
-	Messages    []anthropicMessage `json:"messages"`
-	MaxTokens   int                `json:"max_tokens"`
-	System      string             `json:"system,omitempty"`
-	Temperature float64            `json:"temperature,omitempty"`
-	Tools       []anthropicTool    `json:"tools,omitempty"`
+	Model       string               `json:"model"`
+	Messages    []anthropicMessage   `json:"messages"`
+	MaxTokens   int                  `json:"max_tokens"`
+	System      string               `json:"system,omitempty"`
+	Temperature float64              `json:"temperature,omitempty"`
+	Tools       []anthropicTool      `json:"tools,omitempty"`
+	ToolChoice  *anthropicToolChoice `json:"tool_choice,omitempty"`
 }
 
+// anthropicToolChoice controls how Claude selects tools
+// See: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/implement-tool-use#forcing-tool-use
+type anthropicToolChoice struct {
+	Type string `json:"type"`           // "auto", "any", "tool", or "none"
+	Name string `json:"name,omitempty"` // Only used when Type is "tool"
+}
+
 type anthropicMessage struct {
@@ -230,6 +238,16 @@ func (c *AnthropicClient) Chat(ctx context.Context, req ChatRequest) (*ChatRespo
 		}
 	}
 
+	// Add tool_choice if specified
+	// This controls whether Claude MUST use tools vs just being able to
+	// See: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/implement-tool-use#forcing-tool-use
+	if req.ToolChoice != nil {
+		anthropicReq.ToolChoice = &anthropicToolChoice{
+			Type: string(req.ToolChoice.Type),
+			Name: req.ToolChoice.Name,
+		}
+	}
+
 	body, err := json.Marshal(anthropicReq)
 	if err != nil {
 		return nil, fmt.Errorf("failed to marshal request: %w", err)
@@ -441,13 +459,14 @@ func (c *AnthropicClient) SupportsThinking(model string) bool {
 
 // anthropicStreamRequest is the request body for streaming API calls
 type anthropicStreamRequest struct {
-	Model       string             `json:"model"`
-	Messages    []anthropicMessage `json:"messages"`
-	MaxTokens   int                `json:"max_tokens"`
-	System      string             `json:"system,omitempty"`
-	Temperature float64            `json:"temperature,omitempty"`
-	Tools       []anthropicTool    `json:"tools,omitempty"`
-	Stream      bool               `json:"stream"`
+	Model       string               `json:"model"`
+	Messages    []anthropicMessage   `json:"messages"`
+	MaxTokens   int                  `json:"max_tokens"`
+	System      string               `json:"system,omitempty"`
+	Temperature float64              `json:"temperature,omitempty"`
+	Tools       []anthropicTool      `json:"tools,omitempty"`
+	ToolChoice  *anthropicToolChoice `json:"tool_choice,omitempty"`
+	Stream      bool                 `json:"stream"`
 }
 
 // anthropicStreamEvent represents a streaming event from the Anthropic API
@@ -565,6 +584,14 @@ func (c *AnthropicClient) ChatStream(ctx context.Context, req ChatRequest, callb
 		}
 	}
 
+	// Add tool_choice if specified (same as non-streaming)
+	if req.ToolChoice != nil {
+		anthropicReq.ToolChoice = &anthropicToolChoice{
+			Type: string(req.ToolChoice.Type),
+			Name: req.ToolChoice.Name,
+		}
+	}
+
 	body, err := json.Marshal(anthropicReq)
 	if err != nil {
 		return fmt.Errorf("failed to marshal request: %w", err)
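A hedged sketch of what the new Anthropic tool_choice field serializes to on the wire; this would live in the same package as anthropicRequest, and the model and tool names are illustrative:

// Illustrative only: shows the JSON produced for a forced tool call.
req := anthropicRequest{
	Model:      "claude-sonnet-4-5",
	Messages:   []anthropicMessage{},
	MaxTokens:  1024,
	ToolChoice: &anthropicToolChoice{Type: "tool", Name: "restart_vm"},
}
b, _ := json.Marshal(req)
fmt.Println(string(b))
// {"model":"claude-sonnet-4-5","messages":[],"max_tokens":1024,"tool_choice":{"type":"tool","name":"restart_vm"}}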
@@ -60,6 +60,17 @@ type geminiRequest struct {
 	SystemInstruction *geminiContent          `json:"systemInstruction,omitempty"`
 	GenerationConfig  *geminiGenerationConfig `json:"generationConfig,omitempty"`
 	Tools             []geminiToolDef         `json:"tools,omitempty"`
+	ToolConfig        *geminiToolConfig       `json:"toolConfig,omitempty"`
 }
 
+// geminiToolConfig controls how the model uses tools
+// See: https://ai.google.dev/api/caching#ToolConfig
+type geminiToolConfig struct {
+	FunctionCallingConfig *geminiFunctionCallingConfig `json:"functionCallingConfig,omitempty"`
+}
+
+type geminiFunctionCallingConfig struct {
+	Mode string `json:"mode"` // AUTO, ANY, or NONE
+}
+
 type geminiContent struct {
@@ -140,6 +151,28 @@ type geminiError struct {
 	} `json:"error"`
 }
 
+// convertToolChoiceToGemini converts our ToolChoice to Gemini's mode string
+// Gemini uses: AUTO (default), ANY (force tool use), NONE (no tools)
+// See: https://ai.google.dev/api/caching#FunctionCallingConfig
+func convertToolChoiceToGemini(tc *ToolChoice) string {
+	if tc == nil {
+		return "AUTO"
+	}
+	switch tc.Type {
+	case ToolChoiceAuto:
+		return "AUTO"
+	case ToolChoiceNone:
+		return "NONE"
+	case ToolChoiceAny:
+		return "ANY"
+	case ToolChoiceTool:
+		// Gemini doesn't support forcing a specific tool, fall back to ANY
+		return "ANY"
+	default:
+		return "AUTO"
+	}
+}
+
 // Chat sends a chat request to the Gemini API
 func (c *GeminiClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse, error) {
 	// Convert messages to Gemini format
@@ -244,8 +277,13 @@ func (c *GeminiClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
 		geminiReq.GenerationConfig.Temperature = req.Temperature
 	}
 
-	// Add tools if provided
-	if len(req.Tools) > 0 {
+	// Add tools if provided (unless ToolChoice is None)
+	shouldAddTools := len(req.Tools) > 0
+	if req.ToolChoice != nil && req.ToolChoice.Type == ToolChoiceNone {
+		shouldAddTools = false
+	}
+
+	if shouldAddTools {
 		funcDecls := make([]geminiFunctionDeclaration, 0, len(req.Tools))
 		for _, t := range req.Tools {
 			// Skip non-function tools
@@ -260,6 +298,15 @@ func (c *GeminiClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
 		}
 		if len(funcDecls) > 0 {
 			geminiReq.Tools = []geminiToolDef{{FunctionDeclarations: funcDecls}}
+
+			// Add tool_config based on ToolChoice
+			// Gemini uses: AUTO (default), ANY (force tool use), NONE (no tools)
+			geminiReq.ToolConfig = &geminiToolConfig{
+				FunctionCallingConfig: &geminiFunctionCallingConfig{
+					Mode: convertToolChoiceToGemini(req.ToolChoice),
+				},
+			}
+
 			log.Debug().Int("tool_count", len(funcDecls)).Strs("tool_names", func() []string {
 				names := make([]string, len(funcDecls))
 				for i, f := range funcDecls {
@@ -615,7 +662,13 @@ func (c *GeminiClient) ChatStream(ctx context.Context, req ChatRequest, callback
 		geminiReq.GenerationConfig.Temperature = req.Temperature
 	}
 
-	if len(req.Tools) > 0 {
+	// Add tools if provided (unless ToolChoice is None) - same as non-streaming
+	shouldAddTools := len(req.Tools) > 0
+	if req.ToolChoice != nil && req.ToolChoice.Type == ToolChoiceNone {
+		shouldAddTools = false
+	}
+
+	if shouldAddTools {
 		funcDecls := make([]geminiFunctionDeclaration, 0, len(req.Tools))
 		for _, t := range req.Tools {
 			if t.Type != "" && t.Type != "function" {
@@ -629,6 +682,23 @@ func (c *GeminiClient) ChatStream(ctx context.Context, req ChatRequest, callback
 		}
 		if len(funcDecls) > 0 {
 			geminiReq.Tools = []geminiToolDef{{FunctionDeclarations: funcDecls}}
+
+			// Add tool_config based on ToolChoice (same as non-streaming)
+			geminiReq.ToolConfig = &geminiToolConfig{
+				FunctionCallingConfig: &geminiFunctionCallingConfig{
+					Mode: convertToolChoiceToGemini(req.ToolChoice),
+				},
+			}
+
+			// Log tool names for debugging tool selection issues
+			toolNames := make([]string, len(funcDecls))
+			for i, f := range funcDecls {
+				toolNames[i] = f.Name
+			}
+			log.Debug().
+				Int("tool_count", len(funcDecls)).
+				Strs("tool_names", toolNames).
+				Msg("Gemini stream request includes tools")
 		}
 	}
 
@@ -637,6 +707,12 @@ func (c *GeminiClient) ChatStream(ctx context.Context, req ChatRequest, callback
 		return fmt.Errorf("failed to marshal request: %w", err)
 	}
 
+	// Log the full request body for debugging (at trace level to avoid noise)
+	log.Trace().
+		Str("model", model).
+		RawJSON("request_body", body).
+		Msg("Gemini stream request body")
+
 	// Use streamGenerateContent endpoint for streaming
 	url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s&alt=sse", c.baseURL, model, c.apiKey)
 
@@ -725,6 +801,10 @@ func (c *GeminiClient) ChatStream(ctx context.Context, req ChatRequest, callback
 				if len(signature) == 0 {
 					signature = part.ThoughtSignatureSnake
 				}
+				log.Debug().
+					Str("tool_name", part.FunctionCall.Name).
+					Interface("tool_args", part.FunctionCall.Args).
+					Msg("Gemini called tool")
 				callback(StreamEvent{
 					Type: "tool_start",
 					Data: ToolStartEvent{
@@ -824,35 +904,6 @@ func (c *GeminiClient) ListModels(ctx context.Context) ([]ModelInfo, error) {
 		// Extract model ID from the full name (e.g., "models/gemini-1.5-pro" -> "gemini-1.5-pro")
 		modelID := strings.TrimPrefix(m.Name, "models/")
 
-		// Only include the useful Gemini models for chat/agentic tasks
-		// Filter out Gemma (open-source, no function calling), embedding, AQA, vision-only models
-		// Keep: gemini-3-*, gemini-2.5-*, gemini-2.0-*, gemini-1.5-* (pro and flash variants)
-		isUsefulModel := false
-		usefulPrefixes := []string{
-			"gemini-3-pro", "gemini-3-flash",
-			"gemini-2.5-pro", "gemini-2.5-flash",
-			"gemini-2.0-pro", "gemini-2.0-flash",
-			"gemini-1.5-pro", "gemini-1.5-flash",
-			"gemini-flash", "gemini-pro", // Latest aliases
-		}
-		for _, prefix := range usefulPrefixes {
-			if strings.HasPrefix(modelID, prefix) {
-				isUsefulModel = true
-				break
-			}
-		}
-		if !isUsefulModel {
-			continue
-		}
-
-		// Skip experimental/deprecated variants
-		if strings.Contains(modelID, "exp-") ||
-			strings.Contains(modelID, "-exp") ||
-			strings.Contains(modelID, "tuning") ||
-			strings.Contains(modelID, "8b") { // Skip smaller variants
-			continue
-		}
-
 		models = append(models, ModelInfo{
 			ID:   modelID,
 			Name: m.DisplayName,
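A hedged, test-style sketch of the mode mapping introduced in convertToolChoiceToGemini; it assumes the same package and testify assertions, and the test and tool names are illustrative, not part of this commit:

func TestConvertToolChoiceToGemini_Mapping(t *testing.T) {
	// nil and ToolChoiceAuto both map to Gemini's default AUTO mode.
	assert.Equal(t, "AUTO", convertToolChoiceToGemini(nil))
	assert.Equal(t, "AUTO", convertToolChoiceToGemini(&ToolChoice{Type: ToolChoiceAuto}))
	assert.Equal(t, "NONE", convertToolChoiceToGemini(&ToolChoice{Type: ToolChoiceNone}))
	assert.Equal(t, "ANY", convertToolChoiceToGemini(&ToolChoice{Type: ToolChoiceAny}))
	// Gemini cannot force one specific tool, so ToolChoiceTool degrades to ANY.
	assert.Equal(t, "ANY", convertToolChoiceToGemini(&ToolChoice{Type: ToolChoiceTool, Name: "restart_vm"}))
}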
@@ -155,9 +155,8 @@ func (c *OllamaClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
 	if model == "" {
 		model = c.model
 	}
-	// Ultimate fallback - if no model configured anywhere, use llama3
 	if model == "" {
-		model = "llama3"
+		return nil, fmt.Errorf("no model specified")
 	}
 
 	ollamaReq := ollamaRequest{
@@ -167,7 +166,14 @@ func (c *OllamaClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
 	}
 
 	// Convert tools to Ollama format
-	if len(req.Tools) > 0 {
+	// Note: Ollama doesn't support tool_choice like Anthropic/OpenAI
+	// We handle ToolChoiceNone by not adding tools, but can't force tool use
+	shouldAddTools := len(req.Tools) > 0
+	if req.ToolChoice != nil && req.ToolChoice.Type == ToolChoiceNone {
+		shouldAddTools = false
+	}
+
+	if shouldAddTools {
 		ollamaReq.Tools = make([]ollamaTool, 0, len(req.Tools))
 		for _, t := range req.Tools {
 			// Skip non-function tools (like web_search which Ollama doesn't support)
@@ -318,7 +324,7 @@ func (c *OllamaClient) ChatStream(ctx context.Context, req ChatRequest, callback
 		model = c.model
 	}
 	if model == "" {
-		model = "llama3"
+		return fmt.Errorf("no model specified")
 	}
 
 	ollamaReq := ollamaRequest{
@@ -327,7 +333,13 @@ func (c *OllamaClient) ChatStream(ctx context.Context, req ChatRequest, callback
 		Stream: true, // Enable streaming
 	}
 
-	if len(req.Tools) > 0 {
+	// Handle tools with tool_choice support (same as non-streaming)
+	shouldAddTools := len(req.Tools) > 0
+	if req.ToolChoice != nil && req.ToolChoice.Type == ToolChoiceNone {
+		shouldAddTools = false
+	}
+
+	if shouldAddTools {
 		ollamaReq.Tools = make([]ollamaTool, 0, len(req.Tools))
 		for _, t := range req.Tools {
 			if t.Type != "" && t.Type != "function" {
@@ -87,23 +87,6 @@ type openaiRequest struct {
 	ToolChoice interface{} `json:"tool_choice,omitempty"` // "auto", "none", or specific tool
 }
 
-// deepseekRequest extends openaiRequest with DeepSeek-specific fields
-type deepseekRequest struct {
-	Model      string          `json:"model"`
-	Messages   []openaiMessage `json:"messages"`
-	MaxTokens  int             `json:"max_tokens,omitempty"`
-	Tools      []openaiTool    `json:"tools,omitempty"`
-	ToolChoice interface{}     `json:"tool_choice,omitempty"`
-}
-
-// openaiCompletionsRequest is for non-chat models like gpt-5.2-pro that use /v1/completions
-type openaiCompletionsRequest struct {
-	Model               string  `json:"model"`
-	Prompt              string  `json:"prompt"`
-	MaxCompletionTokens int     `json:"max_completion_tokens,omitempty"`
-	Temperature         float64 `json:"temperature,omitempty"`
-}
-
 // openaiTool represents a function tool in OpenAI format
 type openaiTool struct {
 	Type string `json:"type"` // always "function"
@@ -147,8 +130,7 @@ type openaiResponse struct {
 
 type openaiChoice struct {
 	Index        int           `json:"index"`
-	Message      openaiRespMsg `json:"message"` // For chat completions
-	Text         string        `json:"text"`    // For completions API (non-chat models)
+	Message      openaiRespMsg `json:"message"`
 	FinishReason string        `json:"finish_reason"` // "stop", "tool_calls", etc.
 }
 
@@ -186,19 +168,37 @@ func (c *OpenAIClient) isDeepSeekReasoner() bool {
 }
 
 // requiresMaxCompletionTokens returns true for models that need max_completion_tokens instead of max_tokens
 // Per OpenAI docs, o1/o3/o4 reasoning models require max_completion_tokens; max_tokens will error.
 func (c *OpenAIClient) requiresMaxCompletionTokens(model string) bool {
-	// o1, o1-mini, o1-preview, o3, o3-mini, o4-mini, gpt-5.2, etc.
-	return strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") || strings.HasPrefix(model, "o4") || strings.HasPrefix(model, "gpt-5")
+	return strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") || strings.HasPrefix(model, "o4")
 }
 
-// isGPT52NonChat returns true if using GPT-5.2 models that require /v1/completions endpoint
-// Only gpt-5.2-chat-latest uses chat completions; gpt-5.2, gpt-5.2-pro use completions
-func (c *OpenAIClient) isGPT52NonChat(model string) bool {
-	if !strings.HasPrefix(model, "gpt-5.2") {
-		return false
-	}
-	// gpt-5.2-chat-latest is the only chat model
-	return !strings.Contains(model, "chat")
-}
+// convertToolChoiceToOpenAI converts our ToolChoice to OpenAI's format
+// OpenAI uses "required" instead of Anthropic's "any" to force tool use
+// See: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
+func convertToolChoiceToOpenAI(tc *ToolChoice) interface{} {
+	if tc == nil {
+		return "auto"
+	}
+	switch tc.Type {
+	case ToolChoiceAuto:
+		return "auto"
+	case ToolChoiceNone:
+		return "none"
+	case ToolChoiceAny:
+		// OpenAI uses "required" to force the model to use one of the provided tools
+		return "required"
+	case ToolChoiceTool:
+		// Force a specific tool
+		return map[string]interface{}{
+			"type": "function",
+			"function": map[string]string{
+				"name": tc.Name,
+			},
+		}
+	default:
+		return "auto"
+	}
+}
 
 // Chat sends a chat request to the OpenAI API
@@ -309,42 +309,16 @@ func (c *OpenAIClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
 			})
 		}
 		if len(openaiReq.Tools) > 0 {
-			openaiReq.ToolChoice = "auto"
+			// Map ToolChoice to OpenAI format
+			// OpenAI uses "required" instead of Anthropic's "any"
+			openaiReq.ToolChoice = convertToolChoiceToOpenAI(req.ToolChoice)
 		}
 	}
 
 	// Log actual model being sent (INFO level for visibility)
 	log.Info().Str("model_in_request", openaiReq.Model).Str("base_url", c.baseURL).Msg("Sending OpenAI/DeepSeek request")
 
-	var body []byte
-	var err error
-
-	// GPT-5.2 non-chat models need completions format (prompt instead of messages)
-	if c.isGPT52NonChat(model) {
-		// Convert messages to a single prompt string
-		var promptBuilder strings.Builder
-		if req.System != "" {
-			promptBuilder.WriteString("System: ")
-			promptBuilder.WriteString(req.System)
-			promptBuilder.WriteString("\n\n")
-		}
-		for _, m := range req.Messages {
-			promptBuilder.WriteString(m.Role)
-			promptBuilder.WriteString(": ")
-			promptBuilder.WriteString(m.Content)
-			promptBuilder.WriteString("\n\n")
-		}
-		promptBuilder.WriteString("Assistant: ")
-
-		completionsReq := openaiCompletionsRequest{
-			Model:               model,
-			Prompt:              promptBuilder.String(),
-			MaxCompletionTokens: req.MaxTokens,
-		}
-		body, err = json.Marshal(completionsReq)
-	} else {
-		body, err = json.Marshal(openaiReq)
-	}
+	body, err := json.Marshal(openaiReq)
 	if err != nil {
 		return nil, fmt.Errorf("failed to marshal request: %w", err)
 	}
@@ -370,14 +344,7 @@ func (c *OpenAIClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
 		}
 	}
 
-	// Use the appropriate endpoint
-	endpoint := c.baseURL
-	if c.isGPT52NonChat(model) && strings.Contains(c.baseURL, "api.openai.com") {
-		// GPT-5.2 non-chat models need completions endpoint
-		endpoint = strings.Replace(c.baseURL, "/chat/completions", "/completions", 1)
-	}
-
-	httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(body))
+	httpReq, err := http.NewRequestWithContext(ctx, "POST", c.baseURL, bytes.NewReader(body))
 	if err != nil {
 		return nil, fmt.Errorf("failed to create request: %w", err)
 	}
@@ -449,10 +416,6 @@ func (c *OpenAIClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
 	// For DeepSeek reasoner, the actual content may be in reasoning_content
 	// when content is empty (it shows the "thinking" but that's the full response)
 	contentToUse := choice.Message.Content
-	// Completions API uses Text instead of Message.Content
-	if contentToUse == "" && choice.Text != "" {
-		contentToUse = choice.Text
-	}
 	if contentToUse == "" && choice.Message.ReasoningContent != "" {
 		// DeepSeek reasoner puts output in reasoning_content
 		contentToUse = choice.Message.ReasoningContent
@@ -679,7 +642,8 @@ func (c *OpenAIClient) ChatStream(ctx context.Context, req ChatRequest, callback
 			})
 		}
 		if len(openaiReq.Tools) > 0 {
-			openaiReq.ToolChoice = "auto"
+			// Map ToolChoice to OpenAI format (same as non-streaming)
+			openaiReq.ToolChoice = convertToolChoiceToOpenAI(req.ToolChoice)
 		}
 	}
 
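Similarly, a hedged test-style sketch of the convertToolChoiceToOpenAI mapping; it assumes the same test package and testify assertions, and the test and tool names are illustrative, not part of this commit:

func TestConvertToolChoiceToOpenAI_Mapping(t *testing.T) {
	assert.Equal(t, "auto", convertToolChoiceToOpenAI(nil))
	assert.Equal(t, "none", convertToolChoiceToOpenAI(&ToolChoice{Type: ToolChoiceNone}))
	// OpenAI's "required" is the equivalent of Anthropic's "any".
	assert.Equal(t, "required", convertToolChoiceToOpenAI(&ToolChoice{Type: ToolChoiceAny}))
	// Forcing a specific tool produces an object payload rather than a string.
	assert.Equal(t,
		map[string]interface{}{"type": "function", "function": map[string]string{"name": "restart_vm"}},
		convertToolChoiceToOpenAI(&ToolChoice{Type: ToolChoiceTool, Name: "restart_vm"}))
}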
@@ -269,50 +269,10 @@ func TestOpenAIClient_Chat_Success(t *testing.T) {
 	assert.Equal(t, 3, resp.OutputTokens)
 }
 
-func TestOpenAIClient_Chat_GPT52NonChat(t *testing.T) {
-	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		assert.Equal(t, "/v1/chat/completions", r.URL.Path)
-
-		var req openaiCompletionsRequest
-		require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
-		assert.Equal(t, "gpt-5.2-pro", req.Model)
-		assert.Contains(t, req.Prompt, "System: sys")
-		assert.Contains(t, req.Prompt, "user: hi")
-		assert.Equal(t, 55, req.MaxCompletionTokens)
-
-		_ = json.NewEncoder(w).Encode(openaiResponse{
-			ID:    "cmpl-1",
-			Model: "gpt-5.2-pro",
-			Choices: []openaiChoice{
-				{
-					Text:         "Answer",
-					FinishReason: "stop",
-				},
-			},
-			Usage: openaiUsage{PromptTokens: 3, CompletionTokens: 4},
-		})
-	}))
-	defer server.Close()
-
-	client := NewOpenAIClient("sk-test", "gpt-5.2-pro", server.URL, 0)
-	resp, err := client.Chat(context.Background(), ChatRequest{
-		System:    "sys",
-		MaxTokens: 55,
-		Messages: []Message{
-			{Role: "user", Content: "hi"},
-		},
-	})
-	require.NoError(t, err)
-	assert.Equal(t, "Answer", resp.Content)
-}
-
 func TestOpenAIClient_HelperFlags(t *testing.T) {
 	client := NewOpenAIClient("sk", "gpt-4", "https://api.openai.com", 0)
 	assert.True(t, client.requiresMaxCompletionTokens("o1-mini"))
 	assert.False(t, client.requiresMaxCompletionTokens("gpt-4"))
-
-	assert.True(t, client.isGPT52NonChat("gpt-5.2-pro"))
-	assert.False(t, client.isGPT52NonChat("gpt-5.2-chat-latest"))
 }
 
 func TestOpenAIClient_SupportsThinking(t *testing.T) {
@@ -39,14 +39,35 @@ type Tool struct {
 	MaxUses int `json:"max_uses,omitempty"` // For web search: limit searches per request
 }
 
+// ToolChoiceType represents how the model should choose tools
+type ToolChoiceType string
+
+const (
+	// ToolChoiceAuto lets the model decide whether to use tools (default)
+	ToolChoiceAuto ToolChoiceType = "auto"
+	// ToolChoiceAny forces the model to use one of the provided tools
+	ToolChoiceAny ToolChoiceType = "any"
+	// ToolChoiceNone prevents the model from using any tools
+	ToolChoiceNone ToolChoiceType = "none"
+	// ToolChoiceTool forces the model to use a specific tool (set ToolName)
+	ToolChoiceTool ToolChoiceType = "tool"
+)
+
+// ToolChoice controls how the model selects tools
+type ToolChoice struct {
+	Type ToolChoiceType `json:"type"`
+	Name string         `json:"name,omitempty"` // Only used when Type is ToolChoiceTool
+}
+
 // ChatRequest represents a request to the AI provider
 type ChatRequest struct {
-	Messages    []Message `json:"messages"`
-	Model       string    `json:"model"`
-	MaxTokens   int       `json:"max_tokens,omitempty"`
-	Temperature float64   `json:"temperature,omitempty"`
-	System      string    `json:"system,omitempty"` // System prompt (Anthropic style)
-	Tools       []Tool    `json:"tools,omitempty"`  // Available tools
+	Messages    []Message   `json:"messages"`
+	Model       string      `json:"model"`
+	MaxTokens   int         `json:"max_tokens,omitempty"`
+	Temperature float64     `json:"temperature,omitempty"`
+	System      string      `json:"system,omitempty"`      // System prompt (Anthropic style)
+	Tools       []Tool      `json:"tools,omitempty"`       // Available tools
+	ToolChoice  *ToolChoice `json:"tool_choice,omitempty"` // How to select tools (nil = auto)
 }
 
 // ChatResponse represents a response from the AI provider