test move to client lib for openai
This commit is contained in:
parent
1484b519d7
commit
e69201e3e9
1
go.mod
1
go.mod
|
|
@ -9,6 +9,7 @@ require (
|
|||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.5
|
||||
github.com/pressly/goose/v3 v3.26.0
|
||||
github.com/sashabaranov/go-openai v1.41.2
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/stretchr/testify v1.11.1
|
||||
golang.org/x/crypto v0.42.0
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -124,6 +124,8 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94
|
|||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/sashabaranov/go-openai v1.41.2 h1:vfPRBZNMpnqu8ELsclWcAvF19lDNgh1t6TVfFFOPiSM=
|
||||
github.com/sashabaranov/go-openai v1.41.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
|
||||
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
|
|
|
|||
331
openai_client.go
331
openai_client.go
|
|
@ -1,15 +1,15 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
|
@ -20,10 +20,40 @@ type OpenAIClient struct {
|
|||
BaseURL string
|
||||
Model string
|
||||
Repo ChatRepositoryAPI
|
||||
client *openai.Client
|
||||
}
|
||||
|
||||
func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OpenAIClient {
|
||||
return &OpenAIClient{APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo}
|
||||
config := openai.DefaultConfig(apiKey)
|
||||
if baseURL != "" {
|
||||
config.BaseURL = baseURL
|
||||
}
|
||||
|
||||
// Special handling for OpenRouter
|
||||
// Create a new HTTP client with custom headers
|
||||
httpClient := &http.Client{}
|
||||
if strings.Contains(strings.ToLower(baseURL), "openrouter.ai") {
|
||||
// Use custom transport to add OpenRouter-specific headers
|
||||
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
defaultTransport.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
httpClient.Transport = &customTransport{
|
||||
base: defaultTransport,
|
||||
headers: map[string]string{
|
||||
"Referer": "https://github.com/",
|
||||
"X-Title": "vetrag-app",
|
||||
},
|
||||
}
|
||||
config.HTTPClient = httpClient
|
||||
}
|
||||
|
||||
client := openai.NewClientWithConfig(config)
|
||||
return &OpenAIClient{
|
||||
APIKey: apiKey,
|
||||
BaseURL: baseURL,
|
||||
Model: model,
|
||||
Repo: repo,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (llm *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
|
||||
|
|
@ -39,7 +69,7 @@ func (llm *OpenAIClient) ExtractKeywordsRaw(ctx context.Context, message string)
|
|||
}
|
||||
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
|
||||
|
||||
// Use the utility function instead of inline format definition
|
||||
// Format remains the same
|
||||
format := GetExtractKeywordsFormat()
|
||||
|
||||
resp, err := llm.openAICompletion(ctx, prompt, format)
|
||||
|
|
@ -60,7 +90,6 @@ func (llm *OpenAIClient) DisambiguateBestMatch(ctx context.Context, message stri
|
|||
}
|
||||
|
||||
func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
|
||||
// Use the utility function instead of inline format definition
|
||||
format := GetDisambiguateFormat()
|
||||
|
||||
entries, _ := json.Marshal(candidates)
|
||||
|
|
@ -87,19 +116,6 @@ func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message s
|
|||
}
|
||||
|
||||
func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
|
||||
apiURL := llm.BaseURL
|
||||
if apiURL == "" {
|
||||
apiURL = "https://api.openai.com/v1/chat/completions"
|
||||
}
|
||||
|
||||
isOpenAIStyle := strings.Contains(apiURL, "openrouter.ai") || strings.Contains(apiURL, "/v1/")
|
||||
|
||||
// Helper to stringify the expected JSON schema for instructions
|
||||
schemaDesc := func() string {
|
||||
b, _ := json.MarshalIndent(format, "", " ")
|
||||
return string(b)
|
||||
}
|
||||
|
||||
truncate := func(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
|
|
@ -107,204 +123,136 @@ func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, fo
|
|||
return s[:n] + "...<truncated>"
|
||||
}
|
||||
|
||||
buildBody := func() map[string]interface{} {
|
||||
if isOpenAIStyle {
|
||||
return map[string]interface{}{
|
||||
"model": llm.Model,
|
||||
"messages": []map[string]string{
|
||||
{"role": "system", "content": "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + schemaDesc() + " Do not add explanations."},
|
||||
{"role": "user", "content": prompt},
|
||||
},
|
||||
"response_format": map[string]interface{}{"type": "json_object"},
|
||||
}
|
||||
}
|
||||
// This should never be reached in OpenAI client but keeping for safety
|
||||
return map[string]interface{}{
|
||||
"model": llm.Model,
|
||||
"messages": []map[string]string{{"role": "user", "content": prompt}},
|
||||
"stream": false,
|
||||
"format": format,
|
||||
}
|
||||
}
|
||||
|
||||
body := buildBody()
|
||||
|
||||
// Enhanced logging similar to the unified client
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
bodySize := len(jsonBody)
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "llm_request",
|
||||
"api_url": apiURL,
|
||||
"model": llm.Model,
|
||||
"is_openai_style": isOpenAIStyle,
|
||||
"prompt_len": len(prompt),
|
||||
"body_size": bodySize,
|
||||
}).Info("[LLM] sending request")
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
|
||||
if llm.APIKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if strings.Contains(apiURL, "openrouter.ai") {
|
||||
req.Header.Set("Referer", "https://github.com/")
|
||||
req.Header.Set("X-Title", "vetrag-app")
|
||||
// Build system message with schema if format is provided
|
||||
systemContent := "You are a helpful assistant."
|
||||
if format != nil {
|
||||
schemaJSON, _ := json.MarshalIndent(format, "", " ")
|
||||
systemContent = "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + string(schemaJSON) + " Do not add explanations."
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
|
||||
// Create the chat completion request
|
||||
req := openai.ChatCompletionRequest{
|
||||
Model: llm.Model,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
Content: systemContent,
|
||||
},
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: prompt,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// If we have a format schema, set the response format to JSON
|
||||
if format != nil {
|
||||
req.ResponseFormat = &openai.ChatCompletionResponseFormat{
|
||||
Type: openai.ChatCompletionResponseFormatTypeJSONObject,
|
||||
}
|
||||
}
|
||||
|
||||
// Log the request
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "llm_request",
|
||||
"api_url": llm.BaseURL,
|
||||
"model": llm.Model,
|
||||
"prompt_len": len(prompt),
|
||||
}).Info("[LLM] sending request")
|
||||
|
||||
// Make the API call
|
||||
resp, err := llm.client.CreateChatCompletion(ctx, req)
|
||||
dur := time.Since(start)
|
||||
|
||||
// Handle errors
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "llm_response",
|
||||
"status": 0,
|
||||
"latency_ms": dur.Milliseconds(),
|
||||
"error": err,
|
||||
"error": err.Error(),
|
||||
}).Error("[LLM] request failed")
|
||||
return "", err
|
||||
return "", fmt.Errorf("provider error: %w", err)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
raw, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "llm_raw_response",
|
||||
"status": resp.StatusCode,
|
||||
"latency_ms": dur.Milliseconds(),
|
||||
"raw_trunc": truncate(string(raw), 600),
|
||||
"raw_len": len(raw),
|
||||
}).Debug("[LLM] raw response body")
|
||||
|
||||
parseVariant := "unknown"
|
||||
|
||||
// Attempt OpenAI/OpenRouter style parse first
|
||||
var openAI struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &openAI); err == nil {
|
||||
if openAI.Error != nil || resp.StatusCode >= 400 {
|
||||
parseVariant = "openai"
|
||||
var msg string
|
||||
if openAI.Error != nil {
|
||||
msg = openAI.Error.Message
|
||||
} else {
|
||||
msg = string(raw)
|
||||
}
|
||||
// Extract content from response
|
||||
if len(resp.Choices) == 0 {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "llm_response",
|
||||
"status": resp.StatusCode,
|
||||
"latency_ms": dur.Milliseconds(),
|
||||
"parse_variant": parseVariant,
|
||||
"error": msg,
|
||||
}).Error("[LLM] provider error")
|
||||
return "", fmt.Errorf("provider error: %s", msg)
|
||||
}).Error("[LLM] empty choices in response")
|
||||
return "", fmt.Errorf("provider error: no completion choices returned")
|
||||
}
|
||||
if len(openAI.Choices) > 0 && openAI.Choices[0].Message.Content != "" {
|
||||
parseVariant = "openai"
|
||||
content := openAI.Choices[0].Message.Content
|
||||
|
||||
content := resp.Choices[0].Message.Content
|
||||
|
||||
// Log successful response
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "llm_response",
|
||||
"status": resp.StatusCode,
|
||||
"latency_ms": dur.Milliseconds(),
|
||||
"parse_variant": parseVariant,
|
||||
"content_len": len(content),
|
||||
"content_snip": truncate(content, 300),
|
||||
"finish_reason": resp.Choices[0].FinishReason,
|
||||
}).Info("[LLM] parsed response")
|
||||
|
||||
return content, nil
|
||||
}
|
||||
}
|
||||
|
||||
// As a fallback, attempt Ollama format parse
|
||||
var ollama struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &ollama); err == nil && ollama.Message.Content != "" {
|
||||
parseVariant = "ollama"
|
||||
content := ollama.Message.Content
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "llm_response",
|
||||
"status": resp.StatusCode,
|
||||
"latency_ms": dur.Milliseconds(),
|
||||
"parse_variant": parseVariant,
|
||||
"content_len": len(content),
|
||||
"content_snip": truncate(content, 300),
|
||||
}).Info("[LLM] parsed response")
|
||||
return content, nil
|
||||
}
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "llm_response",
|
||||
"status": resp.StatusCode,
|
||||
"latency_ms": dur.Milliseconds(),
|
||||
"parse_variant": parseVariant,
|
||||
"raw_snip": truncate(string(raw), 300),
|
||||
}).Error("[LLM] unrecognized response format")
|
||||
|
||||
return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw))
|
||||
}
|
||||
|
||||
func (llm *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
|
||||
apiURL := llm.BaseURL
|
||||
if apiURL == "" {
|
||||
apiURL = "https://api.openai.com/v1/embeddings"
|
||||
start := time.Now()
|
||||
|
||||
// Create embedding request
|
||||
req := openai.EmbeddingRequest{
|
||||
// Convert the string model to an EmbeddingModel type
|
||||
Model: openai.EmbeddingModel(llm.Model),
|
||||
Input: input,
|
||||
}
|
||||
body := map[string]interface{}{
|
||||
|
||||
// Log the request
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "embedding_request",
|
||||
"model": llm.Model,
|
||||
"input": input,
|
||||
}
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
|
||||
if llm.APIKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if strings.Contains(apiURL, "openrouter.ai") {
|
||||
req.Header.Set("Referer", "https://github.com/")
|
||||
req.Header.Set("X-Title", "vetrag-app")
|
||||
}
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
"input_len": len(input),
|
||||
}).Info("[LLM] sending embedding request")
|
||||
|
||||
// Make the API call
|
||||
resp, err := llm.client.CreateEmbeddings(ctx, req)
|
||||
dur := time.Since(start)
|
||||
|
||||
// Handle errors
|
||||
if err != nil {
|
||||
return nil, err
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "embedding_response",
|
||||
"latency_ms": dur.Milliseconds(),
|
||||
"error": err.Error(),
|
||||
}).Error("[LLM] embedding request failed")
|
||||
return nil, fmt.Errorf("embedding error: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
raw, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
// Check if we got embeddings
|
||||
if len(resp.Data) == 0 {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "embedding_response",
|
||||
"latency_ms": dur.Milliseconds(),
|
||||
}).Error("[LLM] empty embedding data in response")
|
||||
return nil, fmt.Errorf("embedding error: no embedding data returned")
|
||||
}
|
||||
var openAI struct {
|
||||
Data []struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
} `json:"data"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
|
||||
// Convert []float32 to []float64
|
||||
embeddings := make([]float64, len(resp.Data[0].Embedding))
|
||||
for i, v := range resp.Data[0].Embedding {
|
||||
embeddings[i] = float64(v)
|
||||
}
|
||||
if err := json.Unmarshal(raw, &openAI); err == nil && len(openAI.Data) > 0 {
|
||||
return openAI.Data[0].Embedding, nil
|
||||
}
|
||||
if openAI.Error != nil {
|
||||
return nil, fmt.Errorf("embedding error: %s", openAI.Error.Message)
|
||||
}
|
||||
return nil, fmt.Errorf("unrecognized embedding response: %.200s", string(raw))
|
||||
|
||||
// Log successful response
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "embedding_response",
|
||||
"latency_ms": dur.Milliseconds(),
|
||||
"vector_size": len(embeddings),
|
||||
}).Info("[LLM] embedding response")
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
func (llm *OpenAIClient) TranslateToEnglish(ctx context.Context, message string) (string, error) {
|
||||
|
|
@ -322,3 +270,18 @@ func (llm *OpenAIClient) TranslateToEnglish(ctx context.Context, message string)
|
|||
}
|
||||
return strings.TrimSpace(resp), nil
|
||||
}
|
||||
|
||||
// customTransport is an http.RoundTripper that adds custom headers to requests.
|
||||
type customTransport struct {
|
||||
base http.RoundTripper
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
func (t *customTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Add custom headers to the request
|
||||
for key, value := range t.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
// Call the base RoundTripper
|
||||
return t.base.RoundTrip(req)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue