This commit is contained in:
lehel 2025-10-09 13:17:42 +02:00
parent e69201e3e9
commit 92bcf66766
No known key found for this signature in database
GPG Key ID: 9C4F9D6111EE5CFA
2 changed files with 214 additions and 161 deletions

View File

@ -73,7 +73,7 @@ func main() {
llm := NewLLMClientFromEnv(repo) llm := NewLLMClientFromEnv(repo)
// Launch background backfill of sentence embeddings (non-blocking) // Launch background backfill of sentence embeddings (non-blocking)
startSentenceEmbeddingBackfill(repo, llm, &visitDB) //startSentenceEmbeddingBackfill(repo, llm, &visitDB)
// Wrap templates for controller // Wrap templates for controller
uiTmpl := &TemplateWrapper{Tmpl: uiTemplate} uiTmpl := &TemplateWrapper{Tmpl: uiTemplate}

View File

@ -2,8 +2,8 @@ package main
import ( import (
"context" "context"
"crypto/tls"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
@ -13,39 +13,48 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// --- OpenAIClient implementation --- // Constants for OpenAI client
const (
	// OpenRouterDomain is the substring looked for (case-insensitively) in the
	// configured base URL to decide whether requests are going to OpenRouter.
	OpenRouterDomain = "openrouter.ai"

	// DefaultTemperature is the sampling temperature set on every chat
	// completion request.
	DefaultTemperature = 0.7

	// DefaultMaxTokens is the token limit set on every chat completion request.
	DefaultMaxTokens = 1500
)
// OpenAIClient implements the LLMClientAPI interface using OpenAI's API
type OpenAIClient struct { type OpenAIClient struct {
APIKey string APIKey string
BaseURL string BaseURL string
Model string Model string
Repo ChatRepositoryAPI Repo ChatRepositoryAPI
client *openai.Client client *openai.Client
logger *logrus.Entry
} }
// NewOpenAIClient creates a new OpenAI client with the provided configuration
func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OpenAIClient { func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OpenAIClient {
config := openai.DefaultConfig(apiKey) config := openai.DefaultConfig(apiKey)
// Set custom base URL if provided
if baseURL != "" { if baseURL != "" {
config.BaseURL = baseURL config.BaseURL = normalizeBaseURL(baseURL)
} }
// Special handling for OpenRouter // Configure HTTP client with appropriate headers for OpenRouter if needed
// Create a new HTTP client with custom headers if isOpenRouter(baseURL) {
httpClient := &http.Client{} config.HTTPClient = createOpenRouterHTTPClient()
if strings.Contains(strings.ToLower(baseURL), "openrouter.ai") {
// Use custom transport to add OpenRouter-specific headers
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
defaultTransport.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12}
httpClient.Transport = &customTransport{
base: defaultTransport,
headers: map[string]string{
"Referer": "https://github.com/",
"X-Title": "vetrag-app",
},
}
config.HTTPClient = httpClient
} }
logger := logrus.WithFields(logrus.Fields{
"component": "openai_client",
"model": model,
"base_url": config.BaseURL,
})
logger.Info("Initializing OpenAI client")
client := openai.NewClientWithConfig(config) client := openai.NewClientWithConfig(config)
return &OpenAIClient{ return &OpenAIClient{
APIKey: apiKey, APIKey: apiKey,
@ -53,88 +62,158 @@ func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *Ope
Model: model, Model: model,
Repo: repo, Repo: repo,
client: client, client: client,
logger: logger,
} }
} }
func (llm *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) { // ExtractKeywords extracts keywords from a message
_, parsed, err := llm.ExtractKeywordsRaw(ctx, message) func (c *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
return parsed, err c.logger.WithField("message_length", len(message)).Debug("Extracting keywords")
}
func (llm *OpenAIClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) {
prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message}) prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
if err != nil { if err != nil {
logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt") c.logger.WithError(err).Error("Failed to render ExtractKeywords prompt")
return "", nil, err return nil, fmt.Errorf("failed to render prompt: %w", err)
} }
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
// Format remains the same
format := GetExtractKeywordsFormat() format := GetExtractKeywordsFormat()
c.logger.WithField("prompt", prompt).Debug("ExtractKeywords prompt prepared")
resp, err := llm.openAICompletion(ctx, prompt, format) resp, err := c.createCompletion(ctx, prompt, format)
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response")
if err != nil { if err != nil {
return resp, nil, err return nil, err
} }
var result map[string]interface{} var result map[string]interface{}
if err := json.Unmarshal([]byte(resp), &result); err != nil { if err := json.Unmarshal([]byte(resp), &result); err != nil {
return resp, nil, err c.logger.WithError(err).Error("Failed to parse ExtractKeywords response")
} return nil, fmt.Errorf("failed to parse response: %w", err)
return resp, result, nil
} }
func (llm *OpenAIClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) { return result, nil
_, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates)
return vr, err
} }
func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) { // DisambiguateBestMatch finds the best match among candidates for a message
func (c *OpenAIClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
c.logger.WithFields(logrus.Fields{
"message_length": len(message),
"candidates": len(candidates),
}).Debug("Disambiguating best match")
format := GetDisambiguateFormat() format := GetDisambiguateFormat()
entries, _ := json.Marshal(candidates) entries, err := json.Marshal(candidates)
prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
if err != nil { if err != nil {
logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt") c.logger.WithError(err).Error("Failed to marshal candidates")
return "", "", err return "", fmt.Errorf("failed to marshal candidates: %w", err)
}
logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")
resp, err := llm.openAICompletion(ctx, prompt, format)
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response")
if err != nil {
return resp, "", err
}
var parsed map[string]string
if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
return resp, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
}
visitReason := strings.TrimSpace(parsed["visitReason"])
if visitReason == "" {
return resp, "", fmt.Errorf("visitReason not found in response")
}
return resp, visitReason, nil
} }
func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) { prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{
truncate := func(s string, n int) string { "Entries": string(entries),
if len(s) <= n { "Message": message,
return s })
if err != nil {
c.logger.WithError(err).Error("Failed to render Disambiguate prompt")
return "", fmt.Errorf("failed to render prompt: %w", err)
} }
return s[:n] + "...<truncated>"
c.logger.WithField("prompt", prompt).Debug("DisambiguateBestMatch prompt prepared")
resp, err := c.createCompletion(ctx, prompt, format)
if err != nil {
return "", err
} }
var parsed map[string]string
if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
c.logger.WithError(err).Error("Failed to parse disambiguation response")
return "", fmt.Errorf("failed to parse response: %w", err)
}
visitReason := strings.TrimSpace(parsed["visitReason"])
if visitReason == "" {
return "", errors.New("visitReason not found in response")
}
return visitReason, nil
}
// TranslateToEnglish renders the configured translation prompt for message,
// sends it as a completion request without a response schema, and returns the
// model's reply with surrounding whitespace removed.
//
// A prompt-rendering failure is logged and returned wrapped; a completion
// failure is returned unchanged from createCompletion.
func (c *OpenAIClient) TranslateToEnglish(ctx context.Context, message string) (string, error) {
	c.logger.WithField("message_length", len(message)).Debug("Translating to English")

	templateVars := map[string]string{"Message": message}
	prompt, renderErr := renderPrompt(appConfig.LLM.TranslatePrompt, templateVars)
	if renderErr != nil {
		c.logger.WithError(renderErr).Error("Failed to render Translate prompt")
		return "", fmt.Errorf("failed to render prompt: %w", renderErr)
	}
	c.logger.WithField("prompt", prompt).Debug("TranslateToEnglish prompt prepared")

	// No format schema is passed: the translation is expected as plain text.
	translated, err := c.createCompletion(ctx, prompt, nil)
	if err != nil {
		return "", err
	}
	return strings.TrimSpace(translated), nil
}
// GetEmbeddings requests an embedding for input from the configured model and
// returns the first embedding in the response widened to []float64.
//
// It returns a wrapped error when the API call fails, and a plain error when
// the response carries no embedding data. Request latency is logged on every
// path.
func (c *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
	began := time.Now()
	c.logger.WithField("input_length", len(input)).Debug("Generating embeddings")

	resp, err := c.client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
		Model: openai.EmbeddingModel(c.Model),
		Input: input,
	})
	// Latency is measured once, right after the call, and reused by every log line below.
	elapsed := time.Since(began)

	if err != nil {
		c.logger.WithFields(logrus.Fields{
			"latency_ms": elapsed.Milliseconds(),
			"error":      err.Error(),
		}).Error("Embedding request failed")
		return nil, fmt.Errorf("embedding error: %w", err)
	}

	if len(resp.Data) == 0 {
		c.logger.WithField("latency_ms", elapsed.Milliseconds()).Error("Empty embedding data in response")
		return nil, errors.New("embedding error: no embedding data returned")
	}

	// The client returns float32 components; callers of this method expect float64.
	raw := resp.Data[0].Embedding
	vector := make([]float64, 0, len(raw))
	for _, component := range raw {
		vector = append(vector, float64(component))
	}

	c.logger.WithFields(logrus.Fields{
		"latency_ms":  elapsed.Milliseconds(),
		"vector_size": len(vector),
	}).Debug("Embedding generated successfully")

	return vector, nil
}
// createCompletion creates a chat completion with the given prompt and format
func (c *OpenAIClient) createCompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
start := time.Now()
// Build system message with schema if format is provided // Build system message with schema if format is provided
systemContent := "You are a helpful assistant." systemContent := "You are a helpful assistant."
if format != nil { if format != nil {
schemaJSON, _ := json.MarshalIndent(format, "", " ") schemaJSON, _ := json.MarshalIndent(format, "", " ")
systemContent = "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + string(schemaJSON) + " Do not add explanations." systemContent = fmt.Sprintf("You are a strict JSON generator. ONLY output valid JSON matching this schema: %s Do not add explanations.", string(schemaJSON))
} }
start := time.Now()
// Create the chat completion request // Create the chat completion request
req := openai.ChatCompletionRequest{ req := openai.ChatCompletionRequest{
Model: llm.Model, Model: c.Model,
Messages: []openai.ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: openai.ChatMessageRoleSystem, Role: openai.ChatMessageRoleSystem,
@ -145,143 +224,117 @@ func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, fo
Content: prompt, Content: prompt,
}, },
}, },
Temperature: DefaultTemperature,
MaxTokens: DefaultMaxTokens,
} }
// If we have a format schema, set the response format to JSON // Set response format to JSON if we have a schema and we're not using a third-party model
if format != nil { isThirdPartyModel := strings.Contains(c.Model, "/")
if format != nil && !isThirdPartyModel {
req.ResponseFormat = &openai.ChatCompletionResponseFormat{ req.ResponseFormat = &openai.ChatCompletionResponseFormat{
Type: openai.ChatCompletionResponseFormatTypeJSONObject, Type: openai.ChatCompletionResponseFormatTypeJSONObject,
} }
} }
// Log the request // Log request details
logrus.WithFields(logrus.Fields{ c.logger.WithFields(logrus.Fields{
"event": "llm_request", "model": c.Model,
"api_url": llm.BaseURL,
"model": llm.Model,
"prompt_len": len(prompt), "prompt_len": len(prompt),
}).Info("[LLM] sending request") }).Debug("Sending completion request")
// Make the API call // Make the API call
resp, err := llm.client.CreateChatCompletion(ctx, req) resp, err := c.client.CreateChatCompletion(ctx, req)
dur := time.Since(start) duration := time.Since(start)
// Handle errors // Handle errors
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ c.logger.WithFields(logrus.Fields{
"event": "llm_response", "latency_ms": duration.Milliseconds(),
"latency_ms": dur.Milliseconds(),
"error": err.Error(), "error": err.Error(),
}).Error("[LLM] request failed") }).Error("Completion request failed")
return "", fmt.Errorf("provider error: %w", err) return "", fmt.Errorf("completion error: %w", err)
} }
// Extract content from response // Check if we got a response
if len(resp.Choices) == 0 { if len(resp.Choices) == 0 {
logrus.WithFields(logrus.Fields{ c.logger.WithField("latency_ms", duration.Milliseconds()).Error("Empty choices in completion response")
"event": "llm_response", return "", errors.New("completion error: no completion choices returned")
"latency_ms": dur.Milliseconds(),
}).Error("[LLM] empty choices in response")
return "", fmt.Errorf("provider error: no completion choices returned")
} }
// Extract and clean content
content := resp.Choices[0].Message.Content content := resp.Choices[0].Message.Content
content = cleanJSONResponse(content, format != nil)
// Log successful response c.logger.WithFields(logrus.Fields{
logrus.WithFields(logrus.Fields{ "latency_ms": duration.Milliseconds(),
"event": "llm_response",
"latency_ms": dur.Milliseconds(),
"content_len": len(content), "content_len": len(content),
"content_snip": truncate(content, 300),
"finish_reason": resp.Choices[0].FinishReason, "finish_reason": resp.Choices[0].FinishReason,
}).Info("[LLM] parsed response") }).Debug("Completion request successful")
return content, nil return content, nil
} }
func (llm *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) { // --- Helper functions ---
start := time.Now()
// Create embedding request // normalizeBaseURL ensures the base URL is properly formatted
req := openai.EmbeddingRequest{ func normalizeBaseURL(url string) string {
// Convert the string model to an EmbeddingModel type url = strings.TrimSpace(url)
Model: openai.EmbeddingModel(llm.Model), url = strings.TrimRight(url, "/")
Input: input,
// Remove path components that will be added by the client
if strings.HasSuffix(url, "/chat/completions") {
url = strings.TrimSuffix(url, "/chat/completions")
} }
// Log the request return url
logrus.WithFields(logrus.Fields{
"event": "embedding_request",
"model": llm.Model,
"input_len": len(input),
}).Info("[LLM] sending embedding request")
// Make the API call
resp, err := llm.client.CreateEmbeddings(ctx, req)
dur := time.Since(start)
// Handle errors
if err != nil {
logrus.WithFields(logrus.Fields{
"event": "embedding_response",
"latency_ms": dur.Milliseconds(),
"error": err.Error(),
}).Error("[LLM] embedding request failed")
return nil, fmt.Errorf("embedding error: %w", err)
} }
// Check if we got embeddings // isOpenRouter checks if the base URL is for OpenRouter
if len(resp.Data) == 0 { func isOpenRouter(baseURL string) bool {
logrus.WithFields(logrus.Fields{ return strings.Contains(strings.ToLower(baseURL), OpenRouterDomain)
"event": "embedding_response",
"latency_ms": dur.Milliseconds(),
}).Error("[LLM] empty embedding data in response")
return nil, fmt.Errorf("embedding error: no embedding data returned")
} }
// Convert []float32 to []float64 // createOpenRouterHTTPClient creates an HTTP client with headers for OpenRouter
embeddings := make([]float64, len(resp.Data[0].Embedding)) func createOpenRouterHTTPClient() *http.Client {
for i, v := range resp.Data[0].Embedding { transport := http.DefaultTransport.(*http.Transport).Clone()
embeddings[i] = float64(v) return &http.Client{
Transport: &openRouterTransport{
base: transport,
},
}
} }
// Log successful response // openRouterTransport is a custom transport that adds OpenRouter-specific headers
logrus.WithFields(logrus.Fields{ type openRouterTransport struct {
"event": "embedding_response",
"latency_ms": dur.Milliseconds(),
"vector_size": len(embeddings),
}).Info("[LLM] embedding response")
return embeddings, nil
}
func (llm *OpenAIClient) TranslateToEnglish(ctx context.Context, message string) (string, error) {
prompt, err := renderPrompt(appConfig.LLM.TranslatePrompt, map[string]string{"Message": message})
if err != nil {
logrus.WithError(err).Error("[CONFIG] Failed to render Translate prompt")
return "", err
}
logrus.WithField("prompt", prompt).Info("[LLM] TranslateToEnglish prompt")
resp, err := llm.openAICompletion(ctx, prompt, nil)
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] TranslateToEnglish response")
if err != nil {
return resp, err
}
return strings.TrimSpace(resp), nil
}
// customTransport is an http.RoundTripper that adds custom headers to requests.
type customTransport struct {
base http.RoundTripper base http.RoundTripper
headers map[string]string
} }
func (t *customTransport) RoundTrip(req *http.Request) (*http.Response, error) { // RoundTrip adds OpenRouter headers to requests
// Add custom headers to the request func (t *openRouterTransport) RoundTrip(req *http.Request) (*http.Response, error) {
for key, value := range t.headers { // Add OpenRouter-specific headers
req.Header.Set(key, value) req.Header.Set("HTTP-Referer", "https://github.com/")
} req.Header.Set("X-Title", "vetrag-app")
// Call the base RoundTripper
return t.base.RoundTrip(req) return t.base.RoundTrip(req)
} }
// cleanJSONResponse normalizes raw model output before it is returned to the
// caller.
//
// When isJSON is false, content is returned with surrounding whitespace
// trimmed and nothing else changed. When isJSON is true, a surrounding
// Markdown code fence (``` or ```json) is removed and everything after the
// final closing brace is discarded.
//
// Whitespace is trimmed BEFORE the fence markers are stripped. Model replies
// commonly arrive with a newline before the opening fence or after the closing
// one; stripping the fences first would miss them and leave a stray "```json"
// prefix that breaks JSON decoding.
//
// NOTE(review): the trailing cut keys on "}", so a top-level JSON array would
// lose its closing "]". The visible callers decode into maps — confirm no
// caller expects an array.
func cleanJSONResponse(content string, isJSON bool) string {
	content = strings.TrimSpace(content)
	if !isJSON {
		return content
	}

	// Remove any Markdown code block markers around the payload.
	content = strings.TrimPrefix(content, "```json")
	content = strings.TrimPrefix(content, "```")
	content = strings.TrimSuffix(content, "```")
	content = strings.TrimSpace(content)

	// Drop any trailing commentary the model appended after the JSON object.
	if idx := strings.LastIndex(content, "}"); idx >= 0 {
		content = content[:idx+1]
	}

	return content
}