288 lines
8.4 KiB
Go
288 lines
8.4 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/sashabaranov/go-openai"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// --- OpenAIClient implementation ---
|
|
|
|
type OpenAIClient struct {
|
|
APIKey string
|
|
BaseURL string
|
|
Model string
|
|
Repo ChatRepositoryAPI
|
|
client *openai.Client
|
|
}
|
|
|
|
func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OpenAIClient {
|
|
config := openai.DefaultConfig(apiKey)
|
|
if baseURL != "" {
|
|
config.BaseURL = baseURL
|
|
}
|
|
|
|
// Special handling for OpenRouter
|
|
// Create a new HTTP client with custom headers
|
|
httpClient := &http.Client{}
|
|
if strings.Contains(strings.ToLower(baseURL), "openrouter.ai") {
|
|
// Use custom transport to add OpenRouter-specific headers
|
|
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
defaultTransport.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
|
httpClient.Transport = &customTransport{
|
|
base: defaultTransport,
|
|
headers: map[string]string{
|
|
"Referer": "https://github.com/",
|
|
"X-Title": "vetrag-app",
|
|
},
|
|
}
|
|
config.HTTPClient = httpClient
|
|
}
|
|
|
|
client := openai.NewClientWithConfig(config)
|
|
return &OpenAIClient{
|
|
APIKey: apiKey,
|
|
BaseURL: baseURL,
|
|
Model: model,
|
|
Repo: repo,
|
|
client: client,
|
|
}
|
|
}
|
|
|
|
func (llm *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
|
|
_, parsed, err := llm.ExtractKeywordsRaw(ctx, message)
|
|
return parsed, err
|
|
}
|
|
|
|
func (llm *OpenAIClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) {
|
|
prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
|
|
if err != nil {
|
|
logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt")
|
|
return "", nil, err
|
|
}
|
|
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
|
|
|
|
// Format remains the same
|
|
format := GetExtractKeywordsFormat()
|
|
|
|
resp, err := llm.openAICompletion(ctx, prompt, format)
|
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response")
|
|
if err != nil {
|
|
return resp, nil, err
|
|
}
|
|
var result map[string]interface{}
|
|
if err := json.Unmarshal([]byte(resp), &result); err != nil {
|
|
return resp, nil, err
|
|
}
|
|
return resp, result, nil
|
|
}
|
|
|
|
func (llm *OpenAIClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
|
|
_, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates)
|
|
return vr, err
|
|
}
|
|
|
|
func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
|
|
format := GetDisambiguateFormat()
|
|
|
|
entries, _ := json.Marshal(candidates)
|
|
prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
|
|
if err != nil {
|
|
logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt")
|
|
return "", "", err
|
|
}
|
|
logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")
|
|
resp, err := llm.openAICompletion(ctx, prompt, format)
|
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response")
|
|
if err != nil {
|
|
return resp, "", err
|
|
}
|
|
var parsed map[string]string
|
|
if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
|
|
return resp, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
|
|
}
|
|
visitReason := strings.TrimSpace(parsed["visitReason"])
|
|
if visitReason == "" {
|
|
return resp, "", fmt.Errorf("visitReason not found in response")
|
|
}
|
|
return resp, visitReason, nil
|
|
}
|
|
|
|
func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
|
|
truncate := func(s string, n int) string {
|
|
if len(s) <= n {
|
|
return s
|
|
}
|
|
return s[:n] + "...<truncated>"
|
|
}
|
|
|
|
// Build system message with schema if format is provided
|
|
systemContent := "You are a helpful assistant."
|
|
if format != nil {
|
|
schemaJSON, _ := json.MarshalIndent(format, "", " ")
|
|
systemContent = "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + string(schemaJSON) + " Do not add explanations."
|
|
}
|
|
|
|
start := time.Now()
|
|
|
|
// Create the chat completion request
|
|
req := openai.ChatCompletionRequest{
|
|
Model: llm.Model,
|
|
Messages: []openai.ChatCompletionMessage{
|
|
{
|
|
Role: openai.ChatMessageRoleSystem,
|
|
Content: systemContent,
|
|
},
|
|
{
|
|
Role: openai.ChatMessageRoleUser,
|
|
Content: prompt,
|
|
},
|
|
},
|
|
}
|
|
|
|
// If we have a format schema, set the response format to JSON
|
|
if format != nil {
|
|
req.ResponseFormat = &openai.ChatCompletionResponseFormat{
|
|
Type: openai.ChatCompletionResponseFormatTypeJSONObject,
|
|
}
|
|
}
|
|
|
|
// Log the request
|
|
logrus.WithFields(logrus.Fields{
|
|
"event": "llm_request",
|
|
"api_url": llm.BaseURL,
|
|
"model": llm.Model,
|
|
"prompt_len": len(prompt),
|
|
}).Info("[LLM] sending request")
|
|
|
|
// Make the API call
|
|
resp, err := llm.client.CreateChatCompletion(ctx, req)
|
|
dur := time.Since(start)
|
|
|
|
// Handle errors
|
|
if err != nil {
|
|
logrus.WithFields(logrus.Fields{
|
|
"event": "llm_response",
|
|
"latency_ms": dur.Milliseconds(),
|
|
"error": err.Error(),
|
|
}).Error("[LLM] request failed")
|
|
return "", fmt.Errorf("provider error: %w", err)
|
|
}
|
|
|
|
// Extract content from response
|
|
if len(resp.Choices) == 0 {
|
|
logrus.WithFields(logrus.Fields{
|
|
"event": "llm_response",
|
|
"latency_ms": dur.Milliseconds(),
|
|
}).Error("[LLM] empty choices in response")
|
|
return "", fmt.Errorf("provider error: no completion choices returned")
|
|
}
|
|
|
|
content := resp.Choices[0].Message.Content
|
|
|
|
// Log successful response
|
|
logrus.WithFields(logrus.Fields{
|
|
"event": "llm_response",
|
|
"latency_ms": dur.Milliseconds(),
|
|
"content_len": len(content),
|
|
"content_snip": truncate(content, 300),
|
|
"finish_reason": resp.Choices[0].FinishReason,
|
|
}).Info("[LLM] parsed response")
|
|
|
|
return content, nil
|
|
}
|
|
|
|
func (llm *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
|
|
start := time.Now()
|
|
|
|
// Create embedding request
|
|
req := openai.EmbeddingRequest{
|
|
// Convert the string model to an EmbeddingModel type
|
|
Model: openai.EmbeddingModel(llm.Model),
|
|
Input: input,
|
|
}
|
|
|
|
// Log the request
|
|
logrus.WithFields(logrus.Fields{
|
|
"event": "embedding_request",
|
|
"model": llm.Model,
|
|
"input_len": len(input),
|
|
}).Info("[LLM] sending embedding request")
|
|
|
|
// Make the API call
|
|
resp, err := llm.client.CreateEmbeddings(ctx, req)
|
|
dur := time.Since(start)
|
|
|
|
// Handle errors
|
|
if err != nil {
|
|
logrus.WithFields(logrus.Fields{
|
|
"event": "embedding_response",
|
|
"latency_ms": dur.Milliseconds(),
|
|
"error": err.Error(),
|
|
}).Error("[LLM] embedding request failed")
|
|
return nil, fmt.Errorf("embedding error: %w", err)
|
|
}
|
|
|
|
// Check if we got embeddings
|
|
if len(resp.Data) == 0 {
|
|
logrus.WithFields(logrus.Fields{
|
|
"event": "embedding_response",
|
|
"latency_ms": dur.Milliseconds(),
|
|
}).Error("[LLM] empty embedding data in response")
|
|
return nil, fmt.Errorf("embedding error: no embedding data returned")
|
|
}
|
|
|
|
// Convert []float32 to []float64
|
|
embeddings := make([]float64, len(resp.Data[0].Embedding))
|
|
for i, v := range resp.Data[0].Embedding {
|
|
embeddings[i] = float64(v)
|
|
}
|
|
|
|
// Log successful response
|
|
logrus.WithFields(logrus.Fields{
|
|
"event": "embedding_response",
|
|
"latency_ms": dur.Milliseconds(),
|
|
"vector_size": len(embeddings),
|
|
}).Info("[LLM] embedding response")
|
|
|
|
return embeddings, nil
|
|
}
|
|
|
|
func (llm *OpenAIClient) TranslateToEnglish(ctx context.Context, message string) (string, error) {
|
|
prompt, err := renderPrompt(appConfig.LLM.TranslatePrompt, map[string]string{"Message": message})
|
|
if err != nil {
|
|
logrus.WithError(err).Error("[CONFIG] Failed to render Translate prompt")
|
|
return "", err
|
|
}
|
|
logrus.WithField("prompt", prompt).Info("[LLM] TranslateToEnglish prompt")
|
|
|
|
resp, err := llm.openAICompletion(ctx, prompt, nil)
|
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] TranslateToEnglish response")
|
|
if err != nil {
|
|
return resp, err
|
|
}
|
|
return strings.TrimSpace(resp), nil
|
|
}
|
|
|
|
// customTransport is an http.RoundTripper that sets a fixed collection of
// headers on each request before delegating to another RoundTripper.
type customTransport struct {
	// base is the RoundTripper that performs the request.
	base http.RoundTripper

	// headers maps header names to the values set on each request.
	headers map[string]string
}
|
|
|
|
func (t *customTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
// Add custom headers to the request
|
|
for key, value := range t.headers {
|
|
req.Header.Set(key, value)
|
|
}
|
|
// Call the base RoundTripper
|
|
return t.base.RoundTrip(req)
|
|
}
|