Move OpenAI calls to the go-openai client library

This commit is contained in:
lehel 2025-10-08 22:32:13 +02:00
parent 1484b519d7
commit e69201e3e9
No known key found for this signature in database
GPG Key ID: 9C4F9D6111EE5CFA
3 changed files with 155 additions and 189 deletions

1
go.mod
View File

@ -9,6 +9,7 @@ require (
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.7.5
github.com/pressly/goose/v3 v3.26.0
github.com/sashabaranov/go-openai v1.41.2
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.11.1
golang.org/x/crypto v0.42.0

2
go.sum
View File

@ -124,6 +124,8 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/sashabaranov/go-openai v1.41.2 h1:vfPRBZNMpnqu8ELsclWcAvF19lDNgh1t6TVfFFOPiSM=
github.com/sashabaranov/go-openai v1.41.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=

View File

@ -1,15 +1,15 @@
package main
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/sashabaranov/go-openai"
"github.com/sirupsen/logrus"
)
@ -20,10 +20,40 @@ type OpenAIClient struct {
BaseURL string
Model string
Repo ChatRepositoryAPI
client *openai.Client
}
// NewOpenAIClient builds an OpenAIClient backed by the go-openai SDK.
//
// apiKey authenticates every request; baseURL overrides the SDK's default
// endpoint when non-empty (e.g. for OpenRouter or a proxy); model is the
// chat/embedding model identifier; repo is the chat persistence dependency.
func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OpenAIClient {
	config := openai.DefaultConfig(apiKey)
	if baseURL != "" {
		config.BaseURL = baseURL
	}
	// OpenRouter requires attribution headers on every request, so wrap the
	// HTTP client with a transport that injects them. The client is only
	// built inside this branch — other providers keep the SDK default.
	if strings.Contains(strings.ToLower(baseURL), "openrouter.ai") {
		defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
		defaultTransport.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12}
		config.HTTPClient = &http.Client{
			Transport: &customTransport{
				base: defaultTransport,
				headers: map[string]string{
					"Referer": "https://github.com/",
					"X-Title": "vetrag-app",
				},
			},
		}
	}
	return &OpenAIClient{
		APIKey:  apiKey,
		BaseURL: baseURL,
		Model:   model,
		Repo:    repo,
		client:  openai.NewClientWithConfig(config),
	}
}
func (llm *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
@ -39,7 +69,7 @@ func (llm *OpenAIClient) ExtractKeywordsRaw(ctx context.Context, message string)
}
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
// Use the utility function instead of inline format definition
// Format remains the same
format := GetExtractKeywordsFormat()
resp, err := llm.openAICompletion(ctx, prompt, format)
@ -60,7 +90,6 @@ func (llm *OpenAIClient) DisambiguateBestMatch(ctx context.Context, message stri
}
func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
// Use the utility function instead of inline format definition
format := GetDisambiguateFormat()
entries, _ := json.Marshal(candidates)
@ -87,19 +116,6 @@ func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message s
}
func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
apiURL := llm.BaseURL
if apiURL == "" {
apiURL = "https://api.openai.com/v1/chat/completions"
}
isOpenAIStyle := strings.Contains(apiURL, "openrouter.ai") || strings.Contains(apiURL, "/v1/")
// Helper to stringify the expected JSON schema for instructions
schemaDesc := func() string {
b, _ := json.MarshalIndent(format, "", " ")
return string(b)
}
truncate := func(s string, n int) string {
if len(s) <= n {
return s
@ -107,204 +123,136 @@ func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, fo
return s[:n] + "...<truncated>"
}
buildBody := func() map[string]interface{} {
if isOpenAIStyle {
return map[string]interface{}{
"model": llm.Model,
"messages": []map[string]string{
{"role": "system", "content": "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + schemaDesc() + " Do not add explanations."},
{"role": "user", "content": prompt},
},
"response_format": map[string]interface{}{"type": "json_object"},
}
}
// This should never be reached in OpenAI client but keeping for safety
return map[string]interface{}{
"model": llm.Model,
"messages": []map[string]string{{"role": "user", "content": prompt}},
"stream": false,
"format": format,
}
}
body := buildBody()
// Enhanced logging similar to the unified client
jsonBody, _ := json.Marshal(body)
bodySize := len(jsonBody)
logrus.WithFields(logrus.Fields{
"event": "llm_request",
"api_url": apiURL,
"model": llm.Model,
"is_openai_style": isOpenAIStyle,
"prompt_len": len(prompt),
"body_size": bodySize,
}).Info("[LLM] sending request")
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
if llm.APIKey != "" {
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
if strings.Contains(apiURL, "openrouter.ai") {
req.Header.Set("Referer", "https://github.com/")
req.Header.Set("X-Title", "vetrag-app")
// Build system message with schema if format is provided
systemContent := "You are a helpful assistant."
if format != nil {
schemaJSON, _ := json.MarshalIndent(format, "", " ")
systemContent = "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + string(schemaJSON) + " Do not add explanations."
}
start := time.Now()
client := &http.Client{}
resp, err := client.Do(req)
// Create the chat completion request
req := openai.ChatCompletionRequest{
Model: llm.Model,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: systemContent,
},
{
Role: openai.ChatMessageRoleUser,
Content: prompt,
},
},
}
// If we have a format schema, set the response format to JSON
if format != nil {
req.ResponseFormat = &openai.ChatCompletionResponseFormat{
Type: openai.ChatCompletionResponseFormatTypeJSONObject,
}
}
// Log the request
logrus.WithFields(logrus.Fields{
"event": "llm_request",
"api_url": llm.BaseURL,
"model": llm.Model,
"prompt_len": len(prompt),
}).Info("[LLM] sending request")
// Make the API call
resp, err := llm.client.CreateChatCompletion(ctx, req)
dur := time.Since(start)
// Handle errors
if err != nil {
logrus.WithFields(logrus.Fields{
"event": "llm_response",
"status": 0,
"latency_ms": dur.Milliseconds(),
"error": err,
"error": err.Error(),
}).Error("[LLM] request failed")
return "", err
return "", fmt.Errorf("provider error: %w", err)
}
defer resp.Body.Close()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
logrus.WithFields(logrus.Fields{
"event": "llm_raw_response",
"status": resp.StatusCode,
"latency_ms": dur.Milliseconds(),
"raw_trunc": truncate(string(raw), 600),
"raw_len": len(raw),
}).Debug("[LLM] raw response body")
parseVariant := "unknown"
// Attempt OpenAI/OpenRouter style parse first
var openAI struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
Error *struct {
Message string `json:"message"`
Type string `json:"type"`
} `json:"error"`
}
if err := json.Unmarshal(raw, &openAI); err == nil {
if openAI.Error != nil || resp.StatusCode >= 400 {
parseVariant = "openai"
var msg string
if openAI.Error != nil {
msg = openAI.Error.Message
} else {
msg = string(raw)
}
logrus.WithFields(logrus.Fields{
"event": "llm_response",
"status": resp.StatusCode,
"latency_ms": dur.Milliseconds(),
"parse_variant": parseVariant,
"error": msg,
}).Error("[LLM] provider error")
return "", fmt.Errorf("provider error: %s", msg)
}
if len(openAI.Choices) > 0 && openAI.Choices[0].Message.Content != "" {
parseVariant = "openai"
content := openAI.Choices[0].Message.Content
logrus.WithFields(logrus.Fields{
"event": "llm_response",
"status": resp.StatusCode,
"latency_ms": dur.Milliseconds(),
"parse_variant": parseVariant,
"content_len": len(content),
"content_snip": truncate(content, 300),
}).Info("[LLM] parsed response")
return content, nil
}
}
// As a fallback, attempt Ollama format parse
var ollama struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
Error string `json:"error"`
}
if err := json.Unmarshal(raw, &ollama); err == nil && ollama.Message.Content != "" {
parseVariant = "ollama"
content := ollama.Message.Content
// Extract content from response
if len(resp.Choices) == 0 {
logrus.WithFields(logrus.Fields{
"event": "llm_response",
"status": resp.StatusCode,
"latency_ms": dur.Milliseconds(),
"parse_variant": parseVariant,
"content_len": len(content),
"content_snip": truncate(content, 300),
}).Info("[LLM] parsed response")
return content, nil
"event": "llm_response",
"latency_ms": dur.Milliseconds(),
}).Error("[LLM] empty choices in response")
return "", fmt.Errorf("provider error: no completion choices returned")
}
content := resp.Choices[0].Message.Content
// Log successful response
logrus.WithFields(logrus.Fields{
"event": "llm_response",
"status": resp.StatusCode,
"latency_ms": dur.Milliseconds(),
"parse_variant": parseVariant,
"raw_snip": truncate(string(raw), 300),
}).Error("[LLM] unrecognized response format")
"content_len": len(content),
"content_snip": truncate(content, 300),
"finish_reason": resp.Choices[0].FinishReason,
}).Info("[LLM] parsed response")
return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw))
return content, nil
}
func (llm *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
apiURL := llm.BaseURL
if apiURL == "" {
apiURL = "https://api.openai.com/v1/embeddings"
start := time.Now()
// Create embedding request
req := openai.EmbeddingRequest{
// Convert the string model to an EmbeddingModel type
Model: openai.EmbeddingModel(llm.Model),
Input: input,
}
body := map[string]interface{}{
"model": llm.Model,
"input": input,
}
jsonBody, _ := json.Marshal(body)
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
if llm.APIKey != "" {
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
if strings.Contains(apiURL, "openrouter.ai") {
req.Header.Set("Referer", "https://github.com/")
req.Header.Set("X-Title", "vetrag-app")
}
client := &http.Client{}
resp, err := client.Do(req)
// Log the request
logrus.WithFields(logrus.Fields{
"event": "embedding_request",
"model": llm.Model,
"input_len": len(input),
}).Info("[LLM] sending embedding request")
// Make the API call
resp, err := llm.client.CreateEmbeddings(ctx, req)
dur := time.Since(start)
// Handle errors
if err != nil {
return nil, err
logrus.WithFields(logrus.Fields{
"event": "embedding_response",
"latency_ms": dur.Milliseconds(),
"error": err.Error(),
}).Error("[LLM] embedding request failed")
return nil, fmt.Errorf("embedding error: %w", err)
}
defer resp.Body.Close()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
// Check if we got embeddings
if len(resp.Data) == 0 {
logrus.WithFields(logrus.Fields{
"event": "embedding_response",
"latency_ms": dur.Milliseconds(),
}).Error("[LLM] empty embedding data in response")
return nil, fmt.Errorf("embedding error: no embedding data returned")
}
var openAI struct {
Data []struct {
Embedding []float64 `json:"embedding"`
} `json:"data"`
Error *struct {
Message string `json:"message"`
} `json:"error"`
// Convert []float32 to []float64
embeddings := make([]float64, len(resp.Data[0].Embedding))
for i, v := range resp.Data[0].Embedding {
embeddings[i] = float64(v)
}
if err := json.Unmarshal(raw, &openAI); err == nil && len(openAI.Data) > 0 {
return openAI.Data[0].Embedding, nil
}
if openAI.Error != nil {
return nil, fmt.Errorf("embedding error: %s", openAI.Error.Message)
}
return nil, fmt.Errorf("unrecognized embedding response: %.200s", string(raw))
// Log successful response
logrus.WithFields(logrus.Fields{
"event": "embedding_response",
"latency_ms": dur.Milliseconds(),
"vector_size": len(embeddings),
}).Info("[LLM] embedding response")
return embeddings, nil
}
func (llm *OpenAIClient) TranslateToEnglish(ctx context.Context, message string) (string, error) {
@ -322,3 +270,18 @@ func (llm *OpenAIClient) TranslateToEnglish(ctx context.Context, message string)
}
return strings.TrimSpace(resp), nil
}
// customTransport is an http.RoundTripper that adds custom headers to requests.
type customTransport struct {
base http.RoundTripper
headers map[string]string
}
func (t *customTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Add custom headers to the request
for key, value := range t.headers {
req.Header.Set(key, value)
}
// Call the base RoundTripper
return t.base.RoundTrip(req)
}