package main import ( "context" "crypto/tls" "encoding/json" "fmt" "net/http" "strings" "time" "github.com/sashabaranov/go-openai" "github.com/sirupsen/logrus" ) // --- OpenAIClient implementation --- type OpenAIClient struct { APIKey string BaseURL string Model string Repo ChatRepositoryAPI client *openai.Client } func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OpenAIClient { config := openai.DefaultConfig(apiKey) if baseURL != "" { config.BaseURL = baseURL } // Special handling for OpenRouter // Create a new HTTP client with custom headers httpClient := &http.Client{} if strings.Contains(strings.ToLower(baseURL), "openrouter.ai") { // Use custom transport to add OpenRouter-specific headers defaultTransport := http.DefaultTransport.(*http.Transport).Clone() defaultTransport.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12} httpClient.Transport = &customTransport{ base: defaultTransport, headers: map[string]string{ "Referer": "https://github.com/", "X-Title": "vetrag-app", }, } config.HTTPClient = httpClient } client := openai.NewClientWithConfig(config) return &OpenAIClient{ APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo, client: client, } } func (llm *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) { _, parsed, err := llm.ExtractKeywordsRaw(ctx, message) return parsed, err } func (llm *OpenAIClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) { prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message}) if err != nil { logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt") return "", nil, err } logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt") // Format remains the same format := GetExtractKeywordsFormat() resp, err := llm.openAICompletion(ctx, prompt, format) logrus.WithFields(logrus.Fields{"response": resp, "err": 
err}).Info("[LLM] ExtractKeywords response") if err != nil { return resp, nil, err } var result map[string]interface{} if err := json.Unmarshal([]byte(resp), &result); err != nil { return resp, nil, err } return resp, result, nil } func (llm *OpenAIClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) { _, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates) return vr, err } func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) { format := GetDisambiguateFormat() entries, _ := json.Marshal(candidates) prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message}) if err != nil { logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt") return "", "", err } logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt") resp, err := llm.openAICompletion(ctx, prompt, format) logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response") if err != nil { return resp, "", err } var parsed map[string]string if err := json.Unmarshal([]byte(resp), &parsed); err != nil { return resp, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err) } visitReason := strings.TrimSpace(parsed["visitReason"]) if visitReason == "" { return resp, "", fmt.Errorf("visitReason not found in response") } return resp, visitReason, nil } func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) { truncate := func(s string, n int) string { if len(s) <= n { return s } return s[:n] + "..." } // Build system message with schema if format is provided systemContent := "You are a helpful assistant." if format != nil { schemaJSON, _ := json.MarshalIndent(format, "", " ") systemContent = "You are a strict JSON generator. 
ONLY output valid JSON matching this schema: " + string(schemaJSON) + " Do not add explanations." } start := time.Now() // Create the chat completion request req := openai.ChatCompletionRequest{ Model: llm.Model, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleSystem, Content: systemContent, }, { Role: openai.ChatMessageRoleUser, Content: prompt, }, }, } // If we have a format schema, set the response format to JSON if format != nil { req.ResponseFormat = &openai.ChatCompletionResponseFormat{ Type: openai.ChatCompletionResponseFormatTypeJSONObject, } } // Log the request logrus.WithFields(logrus.Fields{ "event": "llm_request", "api_url": llm.BaseURL, "model": llm.Model, "prompt_len": len(prompt), }).Info("[LLM] sending request") // Make the API call resp, err := llm.client.CreateChatCompletion(ctx, req) dur := time.Since(start) // Handle errors if err != nil { logrus.WithFields(logrus.Fields{ "event": "llm_response", "latency_ms": dur.Milliseconds(), "error": err.Error(), }).Error("[LLM] request failed") return "", fmt.Errorf("provider error: %w", err) } // Extract content from response if len(resp.Choices) == 0 { logrus.WithFields(logrus.Fields{ "event": "llm_response", "latency_ms": dur.Milliseconds(), }).Error("[LLM] empty choices in response") return "", fmt.Errorf("provider error: no completion choices returned") } content := resp.Choices[0].Message.Content // Log successful response logrus.WithFields(logrus.Fields{ "event": "llm_response", "latency_ms": dur.Milliseconds(), "content_len": len(content), "content_snip": truncate(content, 300), "finish_reason": resp.Choices[0].FinishReason, }).Info("[LLM] parsed response") return content, nil } func (llm *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) { start := time.Now() // Create embedding request req := openai.EmbeddingRequest{ // Convert the string model to an EmbeddingModel type Model: openai.EmbeddingModel(llm.Model), Input: input, } // Log the 
request logrus.WithFields(logrus.Fields{ "event": "embedding_request", "model": llm.Model, "input_len": len(input), }).Info("[LLM] sending embedding request") // Make the API call resp, err := llm.client.CreateEmbeddings(ctx, req) dur := time.Since(start) // Handle errors if err != nil { logrus.WithFields(logrus.Fields{ "event": "embedding_response", "latency_ms": dur.Milliseconds(), "error": err.Error(), }).Error("[LLM] embedding request failed") return nil, fmt.Errorf("embedding error: %w", err) } // Check if we got embeddings if len(resp.Data) == 0 { logrus.WithFields(logrus.Fields{ "event": "embedding_response", "latency_ms": dur.Milliseconds(), }).Error("[LLM] empty embedding data in response") return nil, fmt.Errorf("embedding error: no embedding data returned") } // Convert []float32 to []float64 embeddings := make([]float64, len(resp.Data[0].Embedding)) for i, v := range resp.Data[0].Embedding { embeddings[i] = float64(v) } // Log successful response logrus.WithFields(logrus.Fields{ "event": "embedding_response", "latency_ms": dur.Milliseconds(), "vector_size": len(embeddings), }).Info("[LLM] embedding response") return embeddings, nil } func (llm *OpenAIClient) TranslateToEnglish(ctx context.Context, message string) (string, error) { prompt, err := renderPrompt(appConfig.LLM.TranslatePrompt, map[string]string{"Message": message}) if err != nil { logrus.WithError(err).Error("[CONFIG] Failed to render Translate prompt") return "", err } logrus.WithField("prompt", prompt).Info("[LLM] TranslateToEnglish prompt") resp, err := llm.openAICompletion(ctx, prompt, nil) logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] TranslateToEnglish response") if err != nil { return resp, err } return strings.TrimSpace(resp), nil } // customTransport is an http.RoundTripper that adds custom headers to requests. 
type customTransport struct {
	// base performs the actual round trip; nil falls back to http.DefaultTransport.
	base http.RoundTripper
	// headers are set on every request, overwriting any existing value for the same key.
	headers map[string]string
}

// RoundTrip sends req through the base transport with t.headers applied.
//
// The http.RoundTripper contract forbids modifying the caller's request
// (other than consuming/closing its Body). The headers are therefore set on a
// clone, and req itself is left untouched.
func (t *customTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	base := t.base
	if base == nil {
		base = http.DefaultTransport
	}
	if len(t.headers) == 0 {
		return base.RoundTrip(req)
	}

	// Clone deep-copies the header map; the Body is shared, which is fine
	// because only the base transport reads it.
	outgoing := req.Clone(req.Context())
	if outgoing.Header == nil {
		outgoing.Header = make(http.Header, len(t.headers))
	}
	for key, value := range t.headers {
		outgoing.Header.Set(key, value)
	}
	return base.RoundTrip(outgoing)
}