From e69201e3e919a37efd08397572ca0a2a0ab25f17 Mon Sep 17 00:00:00 2001 From: lehel Date: Wed, 8 Oct 2025 22:32:13 +0200 Subject: [PATCH] test move to client lib for openai --- go.mod | 1 + go.sum | 2 + openai_client.go | 341 +++++++++++++++++++++-------------------------- 3 files changed, 155 insertions(+), 189 deletions(-) diff --git a/go.mod b/go.mod index a992042..424c599 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.7.5 github.com/pressly/goose/v3 v3.26.0 + github.com/sashabaranov/go-openai v1.41.2 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.11.1 golang.org/x/crypto v0.42.0 diff --git a/go.sum b/go.sum index 580c43e..57247d0 100644 --- a/go.sum +++ b/go.sum @@ -124,6 +124,8 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94 github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/sashabaranov/go-openai v1.41.2 h1:vfPRBZNMpnqu8ELsclWcAvF19lDNgh1t6TVfFFOPiSM= +github.com/sashabaranov/go-openai v1.41.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE= github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= diff --git a/openai_client.go b/openai_client.go index 7b178aa..44af87c 100644 --- a/openai_client.go +++ b/openai_client.go @@ -1,15 +1,15 @@ package main import ( - "bytes" "context" + "crypto/tls" "encoding/json" "fmt" - "io" "net/http" "strings" "time" + "github.com/sashabaranov/go-openai" "github.com/sirupsen/logrus" ) @@ -20,10 +20,40 @@ type OpenAIClient struct { BaseURL string Model string Repo ChatRepositoryAPI + client *openai.Client } func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OpenAIClient { - return &OpenAIClient{APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo} + config := openai.DefaultConfig(apiKey) + if baseURL != "" { + config.BaseURL = baseURL + } + + // Special handling for OpenRouter + // Create a new HTTP client with custom headers + httpClient := &http.Client{} + if strings.Contains(strings.ToLower(baseURL), "openrouter.ai") { + // Use custom transport to add OpenRouter-specific headers + defaultTransport := http.DefaultTransport.(*http.Transport).Clone() + defaultTransport.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12} + httpClient.Transport = &customTransport{ + base: defaultTransport, + headers: map[string]string{ + "Referer": "https://github.com/", + "X-Title": "vetrag-app", + }, + } + config.HTTPClient = httpClient + } + + client := openai.NewClientWithConfig(config) + return &OpenAIClient{ + APIKey: apiKey, + BaseURL: baseURL, + Model: model, + Repo: repo, + client: client, + } } func (llm *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) { @@ -39,7 +69,7 @@ func (llm *OpenAIClient) ExtractKeywordsRaw(ctx context.Context, message string) } logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt") - // Use the utility function instead of inline format definition + // Format remains the same format := GetExtractKeywordsFormat() resp, err := llm.openAICompletion(ctx, prompt, format) @@ -60,7 +90,6 @@ func (llm *OpenAIClient) DisambiguateBestMatch(ctx context.Context, message stri } func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) { - // Use the utility function instead of inline format definition format := GetDisambiguateFormat() entries, _ := json.Marshal(candidates) @@ -87,19 +116,6 @@ func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message s } func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) { - apiURL := llm.BaseURL - if apiURL == "" { - apiURL = "https://api.openai.com/v1/chat/completions" - } - - isOpenAIStyle := strings.Contains(apiURL, "openrouter.ai") || strings.Contains(apiURL, "/v1/") - - // Helper to stringify the expected JSON schema for instructions - schemaDesc := func() string { - b, _ := json.MarshalIndent(format, "", " ") - return string(b) - } - truncate := func(s string, n int) string { if len(s) <= n { return s @@ -107,204 +123,136 @@ func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, fo return s[:n] + "..." } - buildBody := func() map[string]interface{} { - if isOpenAIStyle { - return map[string]interface{}{ - "model": llm.Model, - "messages": []map[string]string{ - {"role": "system", "content": "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + schemaDesc() + " Do not add explanations."}, - {"role": "user", "content": prompt}, - }, - "response_format": map[string]interface{}{"type": "json_object"}, - } - } - // This should never be reached in OpenAI client but keeping for safety - return map[string]interface{}{ - "model": llm.Model, - "messages": []map[string]string{{"role": "user", "content": prompt}}, - "stream": false, - "format": format, - } - } - - body := buildBody() - - // Enhanced logging similar to the unified client - jsonBody, _ := json.Marshal(body) - bodySize := len(jsonBody) - logrus.WithFields(logrus.Fields{ - "event": "llm_request", - "api_url": apiURL, - "model": llm.Model, - "is_openai_style": isOpenAIStyle, - "prompt_len": len(prompt), - "body_size": bodySize, - }).Info("[LLM] sending request") - - req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody)) - if llm.APIKey != "" { - req.Header.Set("Authorization", "Bearer "+llm.APIKey) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - if strings.Contains(apiURL, "openrouter.ai") { - req.Header.Set("Referer", "https://github.com/") - req.Header.Set("X-Title", "vetrag-app") + // Build system message with schema if format is provided + systemContent := "You are a helpful assistant." + if format != nil { + schemaJSON, _ := json.MarshalIndent(format, "", " ") + systemContent = "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + string(schemaJSON) + " Do not add explanations." } start := time.Now() - client := &http.Client{} - resp, err := client.Do(req) + + // Create the chat completion request + req := openai.ChatCompletionRequest{ + Model: llm.Model, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: systemContent, + }, + { + Role: openai.ChatMessageRoleUser, + Content: prompt, + }, + }, + } + + // If we have a format schema, set the response format to JSON + if format != nil { + req.ResponseFormat = &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONObject, + } + } + + // Log the request + logrus.WithFields(logrus.Fields{ + "event": "llm_request", + "api_url": llm.BaseURL, + "model": llm.Model, + "prompt_len": len(prompt), + }).Info("[LLM] sending request") + + // Make the API call + resp, err := llm.client.CreateChatCompletion(ctx, req) dur := time.Since(start) + // Handle errors if err != nil { logrus.WithFields(logrus.Fields{ "event": "llm_response", - "status": 0, "latency_ms": dur.Milliseconds(), - "error": err, + "error": err.Error(), }).Error("[LLM] request failed") - return "", err + return "", fmt.Errorf("provider error: %w", err) } - defer resp.Body.Close() - raw, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - - logrus.WithFields(logrus.Fields{ - "event": "llm_raw_response", - "status": resp.StatusCode, - "latency_ms": dur.Milliseconds(), - "raw_trunc": truncate(string(raw), 600), - "raw_len": len(raw), - }).Debug("[LLM] raw response body") - - parseVariant := "unknown" - - // Attempt OpenAI/OpenRouter style parse first - var openAI struct { - Choices []struct { - Message struct { - Content string `json:"content"` - } `json:"message"` - } `json:"choices"` - Error *struct { - Message string `json:"message"` - Type string `json:"type"` - } `json:"error"` - } - if err := json.Unmarshal(raw, &openAI); err == nil { - if openAI.Error != nil || resp.StatusCode >= 400 { - parseVariant = "openai" - var msg string - if openAI.Error != nil { - msg = openAI.Error.Message - } else { - msg = string(raw) - } - logrus.WithFields(logrus.Fields{ - "event": "llm_response", - "status": resp.StatusCode, - "latency_ms": dur.Milliseconds(), - "parse_variant": parseVariant, - "error": msg, - }).Error("[LLM] provider error") - return "", fmt.Errorf("provider error: %s", msg) - } - if len(openAI.Choices) > 0 && openAI.Choices[0].Message.Content != "" { - parseVariant = "openai" - content := openAI.Choices[0].Message.Content - logrus.WithFields(logrus.Fields{ - "event": "llm_response", - "status": resp.StatusCode, - "latency_ms": dur.Milliseconds(), - "parse_variant": parseVariant, - "content_len": len(content), - "content_snip": truncate(content, 300), - }).Info("[LLM] parsed response") - return content, nil - } - } - - // As a fallback, attempt Ollama format parse - var ollama struct { - Message struct { - Content string `json:"content"` - } `json:"message"` - Error string `json:"error"` - } - if err := json.Unmarshal(raw, &ollama); err == nil && ollama.Message.Content != "" { - parseVariant = "ollama" - content := ollama.Message.Content + // Extract content from response + if len(resp.Choices) == 0 { logrus.WithFields(logrus.Fields{ - "event": "llm_response", - "status": resp.StatusCode, - "latency_ms": dur.Milliseconds(), - "parse_variant": parseVariant, - "content_len": len(content), - "content_snip": truncate(content, 300), - }).Info("[LLM] parsed response") - return content, nil + "event": "llm_response", + "latency_ms": dur.Milliseconds(), + }).Error("[LLM] empty choices in response") + return "", fmt.Errorf("provider error: no completion choices returned") } + content := resp.Choices[0].Message.Content + + // Log successful response logrus.WithFields(logrus.Fields{ "event": "llm_response", - "status": resp.StatusCode, "latency_ms": dur.Milliseconds(), - "parse_variant": parseVariant, - "raw_snip": truncate(string(raw), 300), - }).Error("[LLM] unrecognized response format") + "content_len": len(content), + "content_snip": truncate(content, 300), + "finish_reason": resp.Choices[0].FinishReason, + }).Info("[LLM] parsed response") - return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw)) + return content, nil } func (llm *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) { - apiURL := llm.BaseURL - if apiURL == "" { - apiURL = "https://api.openai.com/v1/embeddings" + start := time.Now() + + // Create embedding request + req := openai.EmbeddingRequest{ + // Convert the string model to an EmbeddingModel type + Model: openai.EmbeddingModel(llm.Model), + Input: input, } - body := map[string]interface{}{ - "model": llm.Model, - "input": input, - } - jsonBody, _ := json.Marshal(body) - req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody)) - if llm.APIKey != "" { - req.Header.Set("Authorization", "Bearer "+llm.APIKey) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - if strings.Contains(apiURL, "openrouter.ai") { - req.Header.Set("Referer", "https://github.com/") - req.Header.Set("X-Title", "vetrag-app") - } - client := &http.Client{} - resp, err := client.Do(req) + + // Log the request + logrus.WithFields(logrus.Fields{ + "event": "embedding_request", + "model": llm.Model, + "input_len": len(input), + }).Info("[LLM] sending embedding request") + + // Make the API call + resp, err := llm.client.CreateEmbeddings(ctx, req) + dur := time.Since(start) + + // Handle errors if err != nil { - return nil, err + logrus.WithFields(logrus.Fields{ + "event": "embedding_response", + "latency_ms": dur.Milliseconds(), + "error": err.Error(), + }).Error("[LLM] embedding request failed") + return nil, fmt.Errorf("embedding error: %w", err) } - defer resp.Body.Close() - raw, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err + + // Check if we got embeddings + if len(resp.Data) == 0 { + logrus.WithFields(logrus.Fields{ + "event": "embedding_response", + "latency_ms": dur.Milliseconds(), + }).Error("[LLM] empty embedding data in response") + return nil, fmt.Errorf("embedding error: no embedding data returned") } - var openAI struct { - Data []struct { - Embedding []float64 `json:"embedding"` - } `json:"data"` - Error *struct { - Message string `json:"message"` - } `json:"error"` + + // Convert []float32 to []float64 + embeddings := make([]float64, len(resp.Data[0].Embedding)) + for i, v := range resp.Data[0].Embedding { + embeddings[i] = float64(v) } - if err := json.Unmarshal(raw, &openAI); err == nil && len(openAI.Data) > 0 { - return openAI.Data[0].Embedding, nil - } - if openAI.Error != nil { - return nil, fmt.Errorf("embedding error: %s", openAI.Error.Message) - } - return nil, fmt.Errorf("unrecognized embedding response: %.200s", string(raw)) + + // Log successful response + logrus.WithFields(logrus.Fields{ + "event": "embedding_response", + "latency_ms": dur.Milliseconds(), + "vector_size": len(embeddings), + }).Info("[LLM] embedding response") + + return embeddings, nil } func (llm *OpenAIClient) TranslateToEnglish(ctx context.Context, message string) (string, error) { @@ -322,3 +270,18 @@ func (llm *OpenAIClient) TranslateToEnglish(ctx context.Context, message string) } return strings.TrimSpace(resp), nil } + +// customTransport is an http.RoundTripper that adds custom headers to requests. +type customTransport struct { + base http.RoundTripper + headers map[string]string +} + +func (t *customTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Add custom headers to the request + for key, value := range t.headers { + req.Header.Set(key, value) + } + // Call the base RoundTripper + return t.base.RoundTrip(req) +}