package main

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/sashabaranov/go-openai"
	"github.com/sirupsen/logrus"
)

// Constants for OpenAI client
const (
	// DefaultMaxTokens defines the default maximum number of tokens for completions.
	DefaultMaxTokens = 1500
	// DefaultTemperature defines the default temperature for model responses.
	DefaultTemperature = 0.7
	// OpenRouterDomain is used to detect if we're using OpenRouter.
	OpenRouterDomain = "openrouter.ai"
)

// OpenAIClient implements the LLMClientAPI interface using OpenAI's API.
type OpenAIClient struct {
	APIKey  string            // API key used to authenticate requests
	BaseURL string            // raw base URL as supplied by the caller (may be empty)
	Model   string            // model identifier sent with every request
	Repo    ChatRepositoryAPI // chat persistence layer
	client  *openai.Client    // configured SDK client
	logger  *logrus.Entry     // logger pre-tagged with component/model/base_url
}

// NewOpenAIClient creates a new OpenAI client with the provided configuration.
//
// If baseURL is non-empty it is normalized before being handed to the SDK.
// When the URL points at OpenRouter, an HTTP client that injects the
// OpenRouter attribution headers is installed on the SDK config.
func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OpenAIClient {
	config := openai.DefaultConfig(apiKey)

	// Set custom base URL if provided.
	if baseURL != "" {
		config.BaseURL = normalizeBaseURL(baseURL)
	}

	// Configure HTTP client with appropriate headers for OpenRouter if needed.
	if isOpenRouter(baseURL) {
		config.HTTPClient = createOpenRouterHTTPClient()
	}

	logger := logrus.WithFields(logrus.Fields{
		"component": "openai_client",
		"model":     model,
		"base_url":  config.BaseURL,
	})
	logger.Info("Initializing OpenAI client")

	client := openai.NewClientWithConfig(config)
	return &OpenAIClient{
		APIKey:  apiKey,
		BaseURL: baseURL,
		Model:   model,
		Repo:    repo,
		client:  client,
		logger:  logger,
	}
}

// ExtractKeywords extracts keywords from a message.
//
// It renders the configured ExtractKeywords prompt, requests a completion
// constrained to the keywords JSON schema, and unmarshals the response
// into a generic map. Returns an error if prompt rendering, the API call,
// or JSON parsing fails.
func (c *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
	c.logger.WithField("message_length", len(message)).Debug("Extracting keywords")

	prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
	if err != nil {
		c.logger.WithError(err).Error("Failed to render ExtractKeywords prompt")
		return nil, fmt.Errorf("failed to render prompt: %w", err)
	}

	format := GetExtractKeywordsFormat()
	c.logger.WithField("prompt", prompt).Debug("ExtractKeywords prompt prepared")

	resp, err := c.createCompletion(ctx, prompt, format)
	if err != nil {
		return nil, err
	}

	var result map[string]interface{}
	if err := json.Unmarshal([]byte(resp), &result); err != nil {
		c.logger.WithError(err).Error("Failed to parse ExtractKeywords response")
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	return result, nil
}

// DisambiguateBestMatch finds the best match among candidates for a message.
//
// The candidate visits are marshaled to JSON and embedded in the configured
// Disambiguate prompt; the model's JSON response is expected to contain a
// non-empty "visitReason" key, which is returned trimmed.
func (c *OpenAIClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
	c.logger.WithFields(logrus.Fields{
		"message_length": len(message),
		"candidates":     len(candidates),
	}).Debug("Disambiguating best match")

	format := GetDisambiguateFormat()

	entries, err := json.Marshal(candidates)
	if err != nil {
		c.logger.WithError(err).Error("Failed to marshal candidates")
		return "", fmt.Errorf("failed to marshal candidates: %w", err)
	}

	prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{
		"Entries": string(entries),
		"Message": message,
	})
	if err != nil {
		c.logger.WithError(err).Error("Failed to render Disambiguate prompt")
		return "", fmt.Errorf("failed to render prompt: %w", err)
	}
	c.logger.WithField("prompt", prompt).Debug("DisambiguateBestMatch prompt prepared")

	resp, err := c.createCompletion(ctx, prompt, format)
	if err != nil {
		return "", err
	}

	var parsed map[string]string
	if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
		c.logger.WithError(err).Error("Failed to parse disambiguation response")
		return "", fmt.Errorf("failed to parse response: %w", err)
	}

	visitReason := strings.TrimSpace(parsed["visitReason"])
	if visitReason == "" {
		return "", errors.New("visitReason not found in response")
	}
	return visitReason, nil
}

// TranslateToEnglish translates a message to English.
//
// No response-format schema is passed, so the raw (trimmed) completion
// text is returned as the translation.
func (c *OpenAIClient) TranslateToEnglish(ctx context.Context, message string) (string, error) {
	c.logger.WithField("message_length", len(message)).Debug("Translating to English")

	prompt, err := renderPrompt(appConfig.LLM.TranslatePrompt, map[string]string{"Message": message})
	if err != nil {
		c.logger.WithError(err).Error("Failed to render Translate prompt")
		return "", fmt.Errorf("failed to render prompt: %w", err)
	}
	c.logger.WithField("prompt", prompt).Debug("TranslateToEnglish prompt prepared")

	resp, err := c.createCompletion(ctx, prompt, nil)
	if err != nil {
		return "", err
	}
	return strings.TrimSpace(resp), nil
}

// GetEmbeddings generates embeddings for the input text.
//
// The first embedding vector in the response is converted from []float32
// to []float64 and returned. An error is returned when the API call fails
// or the response contains no embedding data.
//
// NOTE(review): c.Model (the chat model ID) is reused as the embedding
// model here — confirm the configured model is valid for the embeddings
// endpoint.
func (c *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
	start := time.Now()
	c.logger.WithField("input_length", len(input)).Debug("Generating embeddings")

	// Create embedding request.
	req := openai.EmbeddingRequest{
		Model: openai.EmbeddingModel(c.Model),
		Input: input,
	}

	// Make the API call.
	resp, err := c.client.CreateEmbeddings(ctx, req)
	duration := time.Since(start)
	if err != nil {
		c.logger.WithFields(logrus.Fields{
			"latency_ms": duration.Milliseconds(),
			"error":      err.Error(),
		}).Error("Embedding request failed")
		return nil, fmt.Errorf("embedding error: %w", err)
	}
	if len(resp.Data) == 0 {
		c.logger.WithField("latency_ms", duration.Milliseconds()).Error("Empty embedding data in response")
		return nil, errors.New("embedding error: no embedding data returned")
	}

	// Convert []float32 to []float64.
	embeddings := make([]float64, len(resp.Data[0].Embedding))
	for i, v := range resp.Data[0].Embedding {
		embeddings[i] = float64(v)
	}

	c.logger.WithFields(logrus.Fields{
		"latency_ms":  duration.Milliseconds(),
		"vector_size": len(embeddings),
	}).Debug("Embedding generated successfully")
	return embeddings, nil
}

// createCompletion creates a chat completion with the given prompt and format.
//
// When format is non-nil, its JSON schema is embedded into the system
// message, and — for first-party models only — the request's response
// format is pinned to JSON object mode. The returned content is cleaned
// via cleanJSONResponse before being handed back to the caller.
func (c *OpenAIClient) createCompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
	start := time.Now()

	// Build system message with schema if format is provided.
	systemContent := "You are a helpful assistant."
	if format != nil {
		schemaJSON, err := json.MarshalIndent(format, "", " ")
		if err != nil {
			// Fall back to the generic system prompt rather than silently
			// embedding an empty schema in the instruction.
			c.logger.WithError(err).Warn("Failed to marshal response schema; using generic system prompt")
		} else {
			systemContent = fmt.Sprintf("You are a strict JSON generator. ONLY output valid JSON matching this schema: %s Do not add explanations.", string(schemaJSON))
		}
	}

	// Create the chat completion request.
	req := openai.ChatCompletionRequest{
		Model: c.Model,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleSystem,
				Content: systemContent,
			},
			{
				Role:    openai.ChatMessageRoleUser,
				Content: prompt,
			},
		},
		Temperature: DefaultTemperature,
		MaxTokens:   DefaultMaxTokens,
	}

	// Set response format to JSON if we have a schema and we're not using a
	// third-party model (a "/" in the model name, e.g. "vendor/model" on
	// OpenRouter, marks it as third-party — such backends may reject the
	// OpenAI-specific response_format field).
	isThirdPartyModel := strings.Contains(c.Model, "/")
	if format != nil && !isThirdPartyModel {
		req.ResponseFormat = &openai.ChatCompletionResponseFormat{
			Type: openai.ChatCompletionResponseFormatTypeJSONObject,
		}
	}

	// Log request details.
	c.logger.WithFields(logrus.Fields{
		"model":      c.Model,
		"prompt_len": len(prompt),
	}).Debug("Sending completion request")

	// Make the API call.
	resp, err := c.client.CreateChatCompletion(ctx, req)
	duration := time.Since(start)

	// Handle errors.
	if err != nil {
		c.logger.WithFields(logrus.Fields{
			"latency_ms": duration.Milliseconds(),
			"error":      err.Error(),
		}).Error("Completion request failed")
		return "", fmt.Errorf("completion error: %w", err)
	}

	// Check if we got a response.
	if len(resp.Choices) == 0 {
		c.logger.WithField("latency_ms", duration.Milliseconds()).Error("Empty choices in completion response")
		return "", errors.New("completion error: no completion choices returned")
	}

	// Extract and clean content.
	content := resp.Choices[0].Message.Content
	content = cleanJSONResponse(content, format != nil)

	c.logger.WithFields(logrus.Fields{
		"latency_ms":    duration.Milliseconds(),
		"content_len":   len(content),
		"finish_reason": resp.Choices[0].FinishReason,
	}).Debug("Completion request successful")
	return content, nil
}

// --- Helper functions ---
normalizeBaseURL ensures the base URL is properly formatted func normalizeBaseURL(url string) string { url = strings.TrimSpace(url) url = strings.TrimRight(url, "/") // Remove path components that will be added by the client if strings.HasSuffix(url, "/chat/completions") { url = strings.TrimSuffix(url, "/chat/completions") } return url } // isOpenRouter checks if the base URL is for OpenRouter func isOpenRouter(baseURL string) bool { return strings.Contains(strings.ToLower(baseURL), OpenRouterDomain) } // createOpenRouterHTTPClient creates an HTTP client with headers for OpenRouter func createOpenRouterHTTPClient() *http.Client { transport := http.DefaultTransport.(*http.Transport).Clone() return &http.Client{ Transport: &openRouterTransport{ base: transport, }, } } // openRouterTransport is a custom transport that adds OpenRouter-specific headers type openRouterTransport struct { base http.RoundTripper } // RoundTrip adds OpenRouter headers to requests func (t *openRouterTransport) RoundTrip(req *http.Request) (*http.Response, error) { // Add OpenRouter-specific headers req.Header.Set("HTTP-Referer", "https://github.com/") req.Header.Set("X-Title", "vetrag-app") return t.base.RoundTrip(req) } // cleanJSONResponse cleans up a response to ensure valid JSON func cleanJSONResponse(content string, isJSON bool) string { // If not expecting JSON, just return the trimmed content if !isJSON { return strings.TrimSpace(content) } // Remove any markdown code block markers content = strings.TrimPrefix(content, "```json") content = strings.TrimPrefix(content, "```") content = strings.TrimSuffix(content, "```") content = strings.TrimSpace(content) // If we expect JSON, make sure it ends properly if idx := strings.LastIndex(content, "}"); idx >= 0 && idx < len(content)-1 { // Only take up to the closing brace plus one character content = content[:idx+1] } return content }