opeain
This commit is contained in:
parent
e69201e3e9
commit
92bcf66766
2
main.go
2
main.go
|
|
@ -73,7 +73,7 @@ func main() {
|
|||
llm := NewLLMClientFromEnv(repo)
|
||||
|
||||
// Launch background backfill of sentence embeddings (non-blocking)
|
||||
startSentenceEmbeddingBackfill(repo, llm, &visitDB)
|
||||
//startSentenceEmbeddingBackfill(repo, llm, &visitDB)
|
||||
|
||||
// Wrap templates for controller
|
||||
uiTmpl := &TemplateWrapper{Tmpl: uiTemplate}
|
||||
|
|
|
|||
383
openai_client.go
383
openai_client.go
|
|
@ -2,8 +2,8 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
|
@ -13,39 +13,48 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// --- OpenAIClient implementation ---
|
||||
// Constants for OpenAI client
|
||||
const (
|
||||
// DefaultMaxTokens defines the default maximum number of tokens for completions
|
||||
DefaultMaxTokens = 1500
|
||||
// DefaultTemperature defines the default temperature for model responses
|
||||
DefaultTemperature = 0.7
|
||||
// OpenRouterDomain is used to detect if we're using OpenRouter
|
||||
OpenRouterDomain = "openrouter.ai"
|
||||
)
|
||||
|
||||
// OpenAIClient implements the LLMClientAPI interface using OpenAI's API
|
||||
type OpenAIClient struct {
|
||||
APIKey string
|
||||
BaseURL string
|
||||
Model string
|
||||
Repo ChatRepositoryAPI
|
||||
client *openai.Client
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
// NewOpenAIClient creates a new OpenAI client with the provided configuration
|
||||
func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OpenAIClient {
|
||||
config := openai.DefaultConfig(apiKey)
|
||||
|
||||
// Set custom base URL if provided
|
||||
if baseURL != "" {
|
||||
config.BaseURL = baseURL
|
||||
config.BaseURL = normalizeBaseURL(baseURL)
|
||||
}
|
||||
|
||||
// Special handling for OpenRouter
|
||||
// Create a new HTTP client with custom headers
|
||||
httpClient := &http.Client{}
|
||||
if strings.Contains(strings.ToLower(baseURL), "openrouter.ai") {
|
||||
// Use custom transport to add OpenRouter-specific headers
|
||||
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
defaultTransport.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
httpClient.Transport = &customTransport{
|
||||
base: defaultTransport,
|
||||
headers: map[string]string{
|
||||
"Referer": "https://github.com/",
|
||||
"X-Title": "vetrag-app",
|
||||
},
|
||||
}
|
||||
config.HTTPClient = httpClient
|
||||
// Configure HTTP client with appropriate headers for OpenRouter if needed
|
||||
if isOpenRouter(baseURL) {
|
||||
config.HTTPClient = createOpenRouterHTTPClient()
|
||||
}
|
||||
|
||||
logger := logrus.WithFields(logrus.Fields{
|
||||
"component": "openai_client",
|
||||
"model": model,
|
||||
"base_url": config.BaseURL,
|
||||
})
|
||||
|
||||
logger.Info("Initializing OpenAI client")
|
||||
|
||||
client := openai.NewClientWithConfig(config)
|
||||
return &OpenAIClient{
|
||||
APIKey: apiKey,
|
||||
|
|
@ -53,88 +62,158 @@ func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *Ope
|
|||
Model: model,
|
||||
Repo: repo,
|
||||
client: client,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (llm *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
|
||||
_, parsed, err := llm.ExtractKeywordsRaw(ctx, message)
|
||||
return parsed, err
|
||||
}
|
||||
// ExtractKeywords extracts keywords from a message
|
||||
func (c *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
|
||||
c.logger.WithField("message_length", len(message)).Debug("Extracting keywords")
|
||||
|
||||
func (llm *OpenAIClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) {
|
||||
prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt")
|
||||
return "", nil, err
|
||||
c.logger.WithError(err).Error("Failed to render ExtractKeywords prompt")
|
||||
return nil, fmt.Errorf("failed to render prompt: %w", err)
|
||||
}
|
||||
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
|
||||
|
||||
// Format remains the same
|
||||
format := GetExtractKeywordsFormat()
|
||||
c.logger.WithField("prompt", prompt).Debug("ExtractKeywords prompt prepared")
|
||||
|
||||
resp, err := llm.openAICompletion(ctx, prompt, format)
|
||||
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response")
|
||||
resp, err := c.createCompletion(ctx, prompt, format)
|
||||
if err != nil {
|
||||
return resp, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(resp), &result); err != nil {
|
||||
return resp, nil, err
|
||||
}
|
||||
return resp, result, nil
|
||||
c.logger.WithError(err).Error("Failed to parse ExtractKeywords response")
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
func (llm *OpenAIClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
|
||||
_, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates)
|
||||
return vr, err
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
|
||||
// DisambiguateBestMatch finds the best match among candidates for a message
|
||||
func (c *OpenAIClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
|
||||
c.logger.WithFields(logrus.Fields{
|
||||
"message_length": len(message),
|
||||
"candidates": len(candidates),
|
||||
}).Debug("Disambiguating best match")
|
||||
|
||||
format := GetDisambiguateFormat()
|
||||
|
||||
entries, _ := json.Marshal(candidates)
|
||||
prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
|
||||
entries, err := json.Marshal(candidates)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt")
|
||||
return "", "", err
|
||||
}
|
||||
logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")
|
||||
resp, err := llm.openAICompletion(ctx, prompt, format)
|
||||
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response")
|
||||
if err != nil {
|
||||
return resp, "", err
|
||||
}
|
||||
var parsed map[string]string
|
||||
if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
|
||||
return resp, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
|
||||
}
|
||||
visitReason := strings.TrimSpace(parsed["visitReason"])
|
||||
if visitReason == "" {
|
||||
return resp, "", fmt.Errorf("visitReason not found in response")
|
||||
}
|
||||
return resp, visitReason, nil
|
||||
c.logger.WithError(err).Error("Failed to marshal candidates")
|
||||
return "", fmt.Errorf("failed to marshal candidates: %w", err)
|
||||
}
|
||||
|
||||
func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
|
||||
truncate := func(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{
|
||||
"Entries": string(entries),
|
||||
"Message": message,
|
||||
})
|
||||
if err != nil {
|
||||
c.logger.WithError(err).Error("Failed to render Disambiguate prompt")
|
||||
return "", fmt.Errorf("failed to render prompt: %w", err)
|
||||
}
|
||||
return s[:n] + "...<truncated>"
|
||||
|
||||
c.logger.WithField("prompt", prompt).Debug("DisambiguateBestMatch prompt prepared")
|
||||
|
||||
resp, err := c.createCompletion(ctx, prompt, format)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var parsed map[string]string
|
||||
if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
|
||||
c.logger.WithError(err).Error("Failed to parse disambiguation response")
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
visitReason := strings.TrimSpace(parsed["visitReason"])
|
||||
if visitReason == "" {
|
||||
return "", errors.New("visitReason not found in response")
|
||||
}
|
||||
|
||||
return visitReason, nil
|
||||
}
|
||||
|
||||
// TranslateToEnglish translates a message to English
|
||||
func (c *OpenAIClient) TranslateToEnglish(ctx context.Context, message string) (string, error) {
|
||||
c.logger.WithField("message_length", len(message)).Debug("Translating to English")
|
||||
|
||||
prompt, err := renderPrompt(appConfig.LLM.TranslatePrompt, map[string]string{"Message": message})
|
||||
if err != nil {
|
||||
c.logger.WithError(err).Error("Failed to render Translate prompt")
|
||||
return "", fmt.Errorf("failed to render prompt: %w", err)
|
||||
}
|
||||
|
||||
c.logger.WithField("prompt", prompt).Debug("TranslateToEnglish prompt prepared")
|
||||
|
||||
resp, err := c.createCompletion(ctx, prompt, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return strings.TrimSpace(resp), nil
|
||||
}
|
||||
|
||||
// GetEmbeddings generates embeddings for the input text
|
||||
func (c *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
|
||||
start := time.Now()
|
||||
c.logger.WithField("input_length", len(input)).Debug("Generating embeddings")
|
||||
|
||||
// Create embedding request
|
||||
req := openai.EmbeddingRequest{
|
||||
Model: openai.EmbeddingModel(c.Model),
|
||||
Input: input,
|
||||
}
|
||||
|
||||
// Make the API call
|
||||
resp, err := c.client.CreateEmbeddings(ctx, req)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
c.logger.WithFields(logrus.Fields{
|
||||
"latency_ms": duration.Milliseconds(),
|
||||
"error": err.Error(),
|
||||
}).Error("Embedding request failed")
|
||||
return nil, fmt.Errorf("embedding error: %w", err)
|
||||
}
|
||||
|
||||
if len(resp.Data) == 0 {
|
||||
c.logger.WithField("latency_ms", duration.Milliseconds()).Error("Empty embedding data in response")
|
||||
return nil, errors.New("embedding error: no embedding data returned")
|
||||
}
|
||||
|
||||
// Convert []float32 to []float64
|
||||
embeddings := make([]float64, len(resp.Data[0].Embedding))
|
||||
for i, v := range resp.Data[0].Embedding {
|
||||
embeddings[i] = float64(v)
|
||||
}
|
||||
|
||||
c.logger.WithFields(logrus.Fields{
|
||||
"latency_ms": duration.Milliseconds(),
|
||||
"vector_size": len(embeddings),
|
||||
}).Debug("Embedding generated successfully")
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// createCompletion creates a chat completion with the given prompt and format
|
||||
func (c *OpenAIClient) createCompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
|
||||
start := time.Now()
|
||||
|
||||
// Build system message with schema if format is provided
|
||||
systemContent := "You are a helpful assistant."
|
||||
if format != nil {
|
||||
schemaJSON, _ := json.MarshalIndent(format, "", " ")
|
||||
systemContent = "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + string(schemaJSON) + " Do not add explanations."
|
||||
systemContent = fmt.Sprintf("You are a strict JSON generator. ONLY output valid JSON matching this schema: %s Do not add explanations.", string(schemaJSON))
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
// Create the chat completion request
|
||||
req := openai.ChatCompletionRequest{
|
||||
Model: llm.Model,
|
||||
Model: c.Model,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
|
|
@ -145,143 +224,117 @@ func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, fo
|
|||
Content: prompt,
|
||||
},
|
||||
},
|
||||
Temperature: DefaultTemperature,
|
||||
MaxTokens: DefaultMaxTokens,
|
||||
}
|
||||
|
||||
// If we have a format schema, set the response format to JSON
|
||||
if format != nil {
|
||||
// Set response format to JSON if we have a schema and we're not using a third-party model
|
||||
isThirdPartyModel := strings.Contains(c.Model, "/")
|
||||
if format != nil && !isThirdPartyModel {
|
||||
req.ResponseFormat = &openai.ChatCompletionResponseFormat{
|
||||
Type: openai.ChatCompletionResponseFormatTypeJSONObject,
|
||||
}
|
||||
}
|
||||
|
||||
// Log the request
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "llm_request",
|
||||
"api_url": llm.BaseURL,
|
||||
"model": llm.Model,
|
||||
// Log request details
|
||||
c.logger.WithFields(logrus.Fields{
|
||||
"model": c.Model,
|
||||
"prompt_len": len(prompt),
|
||||
}).Info("[LLM] sending request")
|
||||
}).Debug("Sending completion request")
|
||||
|
||||
// Make the API call
|
||||
resp, err := llm.client.CreateChatCompletion(ctx, req)
|
||||
dur := time.Since(start)
|
||||
resp, err := c.client.CreateChatCompletion(ctx, req)
|
||||
duration := time.Since(start)
|
||||
|
||||
// Handle errors
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "llm_response",
|
||||
"latency_ms": dur.Milliseconds(),
|
||||
c.logger.WithFields(logrus.Fields{
|
||||
"latency_ms": duration.Milliseconds(),
|
||||
"error": err.Error(),
|
||||
}).Error("[LLM] request failed")
|
||||
return "", fmt.Errorf("provider error: %w", err)
|
||||
}).Error("Completion request failed")
|
||||
return "", fmt.Errorf("completion error: %w", err)
|
||||
}
|
||||
|
||||
// Extract content from response
|
||||
// Check if we got a response
|
||||
if len(resp.Choices) == 0 {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "llm_response",
|
||||
"latency_ms": dur.Milliseconds(),
|
||||
}).Error("[LLM] empty choices in response")
|
||||
return "", fmt.Errorf("provider error: no completion choices returned")
|
||||
c.logger.WithField("latency_ms", duration.Milliseconds()).Error("Empty choices in completion response")
|
||||
return "", errors.New("completion error: no completion choices returned")
|
||||
}
|
||||
|
||||
// Extract and clean content
|
||||
content := resp.Choices[0].Message.Content
|
||||
content = cleanJSONResponse(content, format != nil)
|
||||
|
||||
// Log successful response
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "llm_response",
|
||||
"latency_ms": dur.Milliseconds(),
|
||||
c.logger.WithFields(logrus.Fields{
|
||||
"latency_ms": duration.Milliseconds(),
|
||||
"content_len": len(content),
|
||||
"content_snip": truncate(content, 300),
|
||||
"finish_reason": resp.Choices[0].FinishReason,
|
||||
}).Info("[LLM] parsed response")
|
||||
}).Debug("Completion request successful")
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (llm *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
|
||||
start := time.Now()
|
||||
// --- Helper functions ---
|
||||
|
||||
// Create embedding request
|
||||
req := openai.EmbeddingRequest{
|
||||
// Convert the string model to an EmbeddingModel type
|
||||
Model: openai.EmbeddingModel(llm.Model),
|
||||
Input: input,
|
||||
// normalizeBaseURL ensures the base URL is properly formatted
|
||||
func normalizeBaseURL(url string) string {
|
||||
url = strings.TrimSpace(url)
|
||||
url = strings.TrimRight(url, "/")
|
||||
|
||||
// Remove path components that will be added by the client
|
||||
if strings.HasSuffix(url, "/chat/completions") {
|
||||
url = strings.TrimSuffix(url, "/chat/completions")
|
||||
}
|
||||
|
||||
// Log the request
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "embedding_request",
|
||||
"model": llm.Model,
|
||||
"input_len": len(input),
|
||||
}).Info("[LLM] sending embedding request")
|
||||
|
||||
// Make the API call
|
||||
resp, err := llm.client.CreateEmbeddings(ctx, req)
|
||||
dur := time.Since(start)
|
||||
|
||||
// Handle errors
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "embedding_response",
|
||||
"latency_ms": dur.Milliseconds(),
|
||||
"error": err.Error(),
|
||||
}).Error("[LLM] embedding request failed")
|
||||
return nil, fmt.Errorf("embedding error: %w", err)
|
||||
return url
|
||||
}
|
||||
|
||||
// Check if we got embeddings
|
||||
if len(resp.Data) == 0 {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "embedding_response",
|
||||
"latency_ms": dur.Milliseconds(),
|
||||
}).Error("[LLM] empty embedding data in response")
|
||||
return nil, fmt.Errorf("embedding error: no embedding data returned")
|
||||
// isOpenRouter checks if the base URL is for OpenRouter
|
||||
func isOpenRouter(baseURL string) bool {
|
||||
return strings.Contains(strings.ToLower(baseURL), OpenRouterDomain)
|
||||
}
|
||||
|
||||
// Convert []float32 to []float64
|
||||
embeddings := make([]float64, len(resp.Data[0].Embedding))
|
||||
for i, v := range resp.Data[0].Embedding {
|
||||
embeddings[i] = float64(v)
|
||||
// createOpenRouterHTTPClient creates an HTTP client with headers for OpenRouter
|
||||
func createOpenRouterHTTPClient() *http.Client {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
return &http.Client{
|
||||
Transport: &openRouterTransport{
|
||||
base: transport,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Log successful response
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event": "embedding_response",
|
||||
"latency_ms": dur.Milliseconds(),
|
||||
"vector_size": len(embeddings),
|
||||
}).Info("[LLM] embedding response")
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
func (llm *OpenAIClient) TranslateToEnglish(ctx context.Context, message string) (string, error) {
|
||||
prompt, err := renderPrompt(appConfig.LLM.TranslatePrompt, map[string]string{"Message": message})
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("[CONFIG] Failed to render Translate prompt")
|
||||
return "", err
|
||||
}
|
||||
logrus.WithField("prompt", prompt).Info("[LLM] TranslateToEnglish prompt")
|
||||
|
||||
resp, err := llm.openAICompletion(ctx, prompt, nil)
|
||||
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] TranslateToEnglish response")
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
return strings.TrimSpace(resp), nil
|
||||
}
|
||||
|
||||
// customTransport is an http.RoundTripper that adds custom headers to requests.
|
||||
type customTransport struct {
|
||||
// openRouterTransport is a custom transport that adds OpenRouter-specific headers
|
||||
type openRouterTransport struct {
|
||||
base http.RoundTripper
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
func (t *customTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Add custom headers to the request
|
||||
for key, value := range t.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
// Call the base RoundTripper
|
||||
// RoundTrip adds OpenRouter headers to requests
|
||||
func (t *openRouterTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Add OpenRouter-specific headers
|
||||
req.Header.Set("HTTP-Referer", "https://github.com/")
|
||||
req.Header.Set("X-Title", "vetrag-app")
|
||||
return t.base.RoundTrip(req)
|
||||
}
|
||||
|
||||
// cleanJSONResponse cleans up a response to ensure valid JSON
|
||||
func cleanJSONResponse(content string, isJSON bool) string {
|
||||
// If not expecting JSON, just return the trimmed content
|
||||
if !isJSON {
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
// Remove any markdown code block markers
|
||||
content = strings.TrimPrefix(content, "```json")
|
||||
content = strings.TrimPrefix(content, "```")
|
||||
content = strings.TrimSuffix(content, "```")
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
// If we expect JSON, make sure it ends properly
|
||||
if idx := strings.LastIndex(content, "}"); idx >= 0 && idx < len(content)-1 {
|
||||
// Only take up to the closing brace plus one character
|
||||
content = content[:idx+1]
|
||||
}
|
||||
|
||||
return content
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue