test move to client lib for openai
This commit is contained in:
parent
1484b519d7
commit
e69201e3e9
1
go.mod
1
go.mod
|
|
@ -9,6 +9,7 @@ require (
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/jackc/pgx/v5 v5.7.5
|
github.com/jackc/pgx/v5 v5.7.5
|
||||||
github.com/pressly/goose/v3 v3.26.0
|
github.com/pressly/goose/v3 v3.26.0
|
||||||
|
github.com/sashabaranov/go-openai v1.41.2
|
||||||
github.com/sirupsen/logrus v1.9.3
|
github.com/sirupsen/logrus v1.9.3
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
golang.org/x/crypto v0.42.0
|
golang.org/x/crypto v0.42.0
|
||||||
|
|
|
||||||
2
go.sum
2
go.sum
|
|
@ -124,6 +124,8 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||||
|
github.com/sashabaranov/go-openai v1.41.2 h1:vfPRBZNMpnqu8ELsclWcAvF19lDNgh1t6TVfFFOPiSM=
|
||||||
|
github.com/sashabaranov/go-openai v1.41.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||||
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
|
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
|
||||||
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
|
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
|
||||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
|
|
|
||||||
341
openai_client.go
341
openai_client.go
|
|
@ -1,15 +1,15 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -20,10 +20,40 @@ type OpenAIClient struct {
|
||||||
BaseURL string
|
BaseURL string
|
||||||
Model string
|
Model string
|
||||||
Repo ChatRepositoryAPI
|
Repo ChatRepositoryAPI
|
||||||
|
client *openai.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OpenAIClient {
|
func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OpenAIClient {
|
||||||
return &OpenAIClient{APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo}
|
config := openai.DefaultConfig(apiKey)
|
||||||
|
if baseURL != "" {
|
||||||
|
config.BaseURL = baseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special handling for OpenRouter
|
||||||
|
// Create a new HTTP client with custom headers
|
||||||
|
httpClient := &http.Client{}
|
||||||
|
if strings.Contains(strings.ToLower(baseURL), "openrouter.ai") {
|
||||||
|
// Use custom transport to add OpenRouter-specific headers
|
||||||
|
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
defaultTransport.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||||
|
httpClient.Transport = &customTransport{
|
||||||
|
base: defaultTransport,
|
||||||
|
headers: map[string]string{
|
||||||
|
"Referer": "https://github.com/",
|
||||||
|
"X-Title": "vetrag-app",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
config.HTTPClient = httpClient
|
||||||
|
}
|
||||||
|
|
||||||
|
client := openai.NewClientWithConfig(config)
|
||||||
|
return &OpenAIClient{
|
||||||
|
APIKey: apiKey,
|
||||||
|
BaseURL: baseURL,
|
||||||
|
Model: model,
|
||||||
|
Repo: repo,
|
||||||
|
client: client,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
|
func (llm *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
|
||||||
|
|
@ -39,7 +69,7 @@ func (llm *OpenAIClient) ExtractKeywordsRaw(ctx context.Context, message string)
|
||||||
}
|
}
|
||||||
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
|
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
|
||||||
|
|
||||||
// Use the utility function instead of inline format definition
|
// Format remains the same
|
||||||
format := GetExtractKeywordsFormat()
|
format := GetExtractKeywordsFormat()
|
||||||
|
|
||||||
resp, err := llm.openAICompletion(ctx, prompt, format)
|
resp, err := llm.openAICompletion(ctx, prompt, format)
|
||||||
|
|
@ -60,7 +90,6 @@ func (llm *OpenAIClient) DisambiguateBestMatch(ctx context.Context, message stri
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
|
func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
|
||||||
// Use the utility function instead of inline format definition
|
|
||||||
format := GetDisambiguateFormat()
|
format := GetDisambiguateFormat()
|
||||||
|
|
||||||
entries, _ := json.Marshal(candidates)
|
entries, _ := json.Marshal(candidates)
|
||||||
|
|
@ -87,19 +116,6 @@ func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
|
func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
|
||||||
apiURL := llm.BaseURL
|
|
||||||
if apiURL == "" {
|
|
||||||
apiURL = "https://api.openai.com/v1/chat/completions"
|
|
||||||
}
|
|
||||||
|
|
||||||
isOpenAIStyle := strings.Contains(apiURL, "openrouter.ai") || strings.Contains(apiURL, "/v1/")
|
|
||||||
|
|
||||||
// Helper to stringify the expected JSON schema for instructions
|
|
||||||
schemaDesc := func() string {
|
|
||||||
b, _ := json.MarshalIndent(format, "", " ")
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
truncate := func(s string, n int) string {
|
truncate := func(s string, n int) string {
|
||||||
if len(s) <= n {
|
if len(s) <= n {
|
||||||
return s
|
return s
|
||||||
|
|
@ -107,204 +123,136 @@ func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, fo
|
||||||
return s[:n] + "...<truncated>"
|
return s[:n] + "...<truncated>"
|
||||||
}
|
}
|
||||||
|
|
||||||
buildBody := func() map[string]interface{} {
|
// Build system message with schema if format is provided
|
||||||
if isOpenAIStyle {
|
systemContent := "You are a helpful assistant."
|
||||||
return map[string]interface{}{
|
if format != nil {
|
||||||
"model": llm.Model,
|
schemaJSON, _ := json.MarshalIndent(format, "", " ")
|
||||||
"messages": []map[string]string{
|
systemContent = "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + string(schemaJSON) + " Do not add explanations."
|
||||||
{"role": "system", "content": "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + schemaDesc() + " Do not add explanations."},
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
},
|
|
||||||
"response_format": map[string]interface{}{"type": "json_object"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// This should never be reached in OpenAI client but keeping for safety
|
|
||||||
return map[string]interface{}{
|
|
||||||
"model": llm.Model,
|
|
||||||
"messages": []map[string]string{{"role": "user", "content": prompt}},
|
|
||||||
"stream": false,
|
|
||||||
"format": format,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
body := buildBody()
|
|
||||||
|
|
||||||
// Enhanced logging similar to the unified client
|
|
||||||
jsonBody, _ := json.Marshal(body)
|
|
||||||
bodySize := len(jsonBody)
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"event": "llm_request",
|
|
||||||
"api_url": apiURL,
|
|
||||||
"model": llm.Model,
|
|
||||||
"is_openai_style": isOpenAIStyle,
|
|
||||||
"prompt_len": len(prompt),
|
|
||||||
"body_size": bodySize,
|
|
||||||
}).Info("[LLM] sending request")
|
|
||||||
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
|
|
||||||
if llm.APIKey != "" {
|
|
||||||
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
if strings.Contains(apiURL, "openrouter.ai") {
|
|
||||||
req.Header.Set("Referer", "https://github.com/")
|
|
||||||
req.Header.Set("X-Title", "vetrag-app")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
client := &http.Client{}
|
|
||||||
resp, err := client.Do(req)
|
// Create the chat completion request
|
||||||
|
req := openai.ChatCompletionRequest{
|
||||||
|
Model: llm.Model,
|
||||||
|
Messages: []openai.ChatCompletionMessage{
|
||||||
|
{
|
||||||
|
Role: openai.ChatMessageRoleSystem,
|
||||||
|
Content: systemContent,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: openai.ChatMessageRoleUser,
|
||||||
|
Content: prompt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have a format schema, set the response format to JSON
|
||||||
|
if format != nil {
|
||||||
|
req.ResponseFormat = &openai.ChatCompletionResponseFormat{
|
||||||
|
Type: openai.ChatCompletionResponseFormatTypeJSONObject,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log the request
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"event": "llm_request",
|
||||||
|
"api_url": llm.BaseURL,
|
||||||
|
"model": llm.Model,
|
||||||
|
"prompt_len": len(prompt),
|
||||||
|
}).Info("[LLM] sending request")
|
||||||
|
|
||||||
|
// Make the API call
|
||||||
|
resp, err := llm.client.CreateChatCompletion(ctx, req)
|
||||||
dur := time.Since(start)
|
dur := time.Since(start)
|
||||||
|
|
||||||
|
// Handle errors
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"event": "llm_response",
|
"event": "llm_response",
|
||||||
"status": 0,
|
|
||||||
"latency_ms": dur.Milliseconds(),
|
"latency_ms": dur.Milliseconds(),
|
||||||
"error": err,
|
"error": err.Error(),
|
||||||
}).Error("[LLM] request failed")
|
}).Error("[LLM] request failed")
|
||||||
return "", err
|
return "", fmt.Errorf("provider error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer resp.Body.Close()
|
// Extract content from response
|
||||||
raw, err := io.ReadAll(resp.Body)
|
if len(resp.Choices) == 0 {
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"event": "llm_raw_response",
|
|
||||||
"status": resp.StatusCode,
|
|
||||||
"latency_ms": dur.Milliseconds(),
|
|
||||||
"raw_trunc": truncate(string(raw), 600),
|
|
||||||
"raw_len": len(raw),
|
|
||||||
}).Debug("[LLM] raw response body")
|
|
||||||
|
|
||||||
parseVariant := "unknown"
|
|
||||||
|
|
||||||
// Attempt OpenAI/OpenRouter style parse first
|
|
||||||
var openAI struct {
|
|
||||||
Choices []struct {
|
|
||||||
Message struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
} `json:"message"`
|
|
||||||
} `json:"choices"`
|
|
||||||
Error *struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
} `json:"error"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(raw, &openAI); err == nil {
|
|
||||||
if openAI.Error != nil || resp.StatusCode >= 400 {
|
|
||||||
parseVariant = "openai"
|
|
||||||
var msg string
|
|
||||||
if openAI.Error != nil {
|
|
||||||
msg = openAI.Error.Message
|
|
||||||
} else {
|
|
||||||
msg = string(raw)
|
|
||||||
}
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"event": "llm_response",
|
|
||||||
"status": resp.StatusCode,
|
|
||||||
"latency_ms": dur.Milliseconds(),
|
|
||||||
"parse_variant": parseVariant,
|
|
||||||
"error": msg,
|
|
||||||
}).Error("[LLM] provider error")
|
|
||||||
return "", fmt.Errorf("provider error: %s", msg)
|
|
||||||
}
|
|
||||||
if len(openAI.Choices) > 0 && openAI.Choices[0].Message.Content != "" {
|
|
||||||
parseVariant = "openai"
|
|
||||||
content := openAI.Choices[0].Message.Content
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"event": "llm_response",
|
|
||||||
"status": resp.StatusCode,
|
|
||||||
"latency_ms": dur.Milliseconds(),
|
|
||||||
"parse_variant": parseVariant,
|
|
||||||
"content_len": len(content),
|
|
||||||
"content_snip": truncate(content, 300),
|
|
||||||
}).Info("[LLM] parsed response")
|
|
||||||
return content, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// As a fallback, attempt Ollama format parse
|
|
||||||
var ollama struct {
|
|
||||||
Message struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
} `json:"message"`
|
|
||||||
Error string `json:"error"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(raw, &ollama); err == nil && ollama.Message.Content != "" {
|
|
||||||
parseVariant = "ollama"
|
|
||||||
content := ollama.Message.Content
|
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"event": "llm_response",
|
"event": "llm_response",
|
||||||
"status": resp.StatusCode,
|
"latency_ms": dur.Milliseconds(),
|
||||||
"latency_ms": dur.Milliseconds(),
|
}).Error("[LLM] empty choices in response")
|
||||||
"parse_variant": parseVariant,
|
return "", fmt.Errorf("provider error: no completion choices returned")
|
||||||
"content_len": len(content),
|
|
||||||
"content_snip": truncate(content, 300),
|
|
||||||
}).Info("[LLM] parsed response")
|
|
||||||
return content, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
content := resp.Choices[0].Message.Content
|
||||||
|
|
||||||
|
// Log successful response
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"event": "llm_response",
|
"event": "llm_response",
|
||||||
"status": resp.StatusCode,
|
|
||||||
"latency_ms": dur.Milliseconds(),
|
"latency_ms": dur.Milliseconds(),
|
||||||
"parse_variant": parseVariant,
|
"content_len": len(content),
|
||||||
"raw_snip": truncate(string(raw), 300),
|
"content_snip": truncate(content, 300),
|
||||||
}).Error("[LLM] unrecognized response format")
|
"finish_reason": resp.Choices[0].FinishReason,
|
||||||
|
}).Info("[LLM] parsed response")
|
||||||
|
|
||||||
return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw))
|
return content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
|
func (llm *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
|
||||||
apiURL := llm.BaseURL
|
start := time.Now()
|
||||||
if apiURL == "" {
|
|
||||||
apiURL = "https://api.openai.com/v1/embeddings"
|
// Create embedding request
|
||||||
|
req := openai.EmbeddingRequest{
|
||||||
|
// Convert the string model to an EmbeddingModel type
|
||||||
|
Model: openai.EmbeddingModel(llm.Model),
|
||||||
|
Input: input,
|
||||||
}
|
}
|
||||||
body := map[string]interface{}{
|
|
||||||
"model": llm.Model,
|
// Log the request
|
||||||
"input": input,
|
logrus.WithFields(logrus.Fields{
|
||||||
}
|
"event": "embedding_request",
|
||||||
jsonBody, _ := json.Marshal(body)
|
"model": llm.Model,
|
||||||
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
|
"input_len": len(input),
|
||||||
if llm.APIKey != "" {
|
}).Info("[LLM] sending embedding request")
|
||||||
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
|
||||||
}
|
// Make the API call
|
||||||
req.Header.Set("Content-Type", "application/json")
|
resp, err := llm.client.CreateEmbeddings(ctx, req)
|
||||||
req.Header.Set("Accept", "application/json")
|
dur := time.Since(start)
|
||||||
if strings.Contains(apiURL, "openrouter.ai") {
|
|
||||||
req.Header.Set("Referer", "https://github.com/")
|
// Handle errors
|
||||||
req.Header.Set("X-Title", "vetrag-app")
|
|
||||||
}
|
|
||||||
client := &http.Client{}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"event": "embedding_response",
|
||||||
|
"latency_ms": dur.Milliseconds(),
|
||||||
|
"error": err.Error(),
|
||||||
|
}).Error("[LLM] embedding request failed")
|
||||||
|
return nil, fmt.Errorf("embedding error: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
|
||||||
raw, err := io.ReadAll(resp.Body)
|
// Check if we got embeddings
|
||||||
if err != nil {
|
if len(resp.Data) == 0 {
|
||||||
return nil, err
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"event": "embedding_response",
|
||||||
|
"latency_ms": dur.Milliseconds(),
|
||||||
|
}).Error("[LLM] empty embedding data in response")
|
||||||
|
return nil, fmt.Errorf("embedding error: no embedding data returned")
|
||||||
}
|
}
|
||||||
var openAI struct {
|
|
||||||
Data []struct {
|
// Convert []float32 to []float64
|
||||||
Embedding []float64 `json:"embedding"`
|
embeddings := make([]float64, len(resp.Data[0].Embedding))
|
||||||
} `json:"data"`
|
for i, v := range resp.Data[0].Embedding {
|
||||||
Error *struct {
|
embeddings[i] = float64(v)
|
||||||
Message string `json:"message"`
|
|
||||||
} `json:"error"`
|
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(raw, &openAI); err == nil && len(openAI.Data) > 0 {
|
|
||||||
return openAI.Data[0].Embedding, nil
|
// Log successful response
|
||||||
}
|
logrus.WithFields(logrus.Fields{
|
||||||
if openAI.Error != nil {
|
"event": "embedding_response",
|
||||||
return nil, fmt.Errorf("embedding error: %s", openAI.Error.Message)
|
"latency_ms": dur.Milliseconds(),
|
||||||
}
|
"vector_size": len(embeddings),
|
||||||
return nil, fmt.Errorf("unrecognized embedding response: %.200s", string(raw))
|
}).Info("[LLM] embedding response")
|
||||||
|
|
||||||
|
return embeddings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *OpenAIClient) TranslateToEnglish(ctx context.Context, message string) (string, error) {
|
func (llm *OpenAIClient) TranslateToEnglish(ctx context.Context, message string) (string, error) {
|
||||||
|
|
@ -322,3 +270,18 @@ func (llm *OpenAIClient) TranslateToEnglish(ctx context.Context, message string)
|
||||||
}
|
}
|
||||||
return strings.TrimSpace(resp), nil
|
return strings.TrimSpace(resp), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// customTransport is an http.RoundTripper that adds custom headers to requests.
|
||||||
|
type customTransport struct {
|
||||||
|
base http.RoundTripper
|
||||||
|
headers map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *customTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
// Add custom headers to the request
|
||||||
|
for key, value := range t.headers {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
// Call the base RoundTripper
|
||||||
|
return t.base.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue