// vetrag/openai_client.go
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/sirupsen/logrus"
)
// --- OpenAIClient implementation ---

// OpenAIClient is an LLM client for OpenAI-compatible HTTP APIs
// (chat completions and embeddings). The zero value is not useful;
// construct it with NewOpenAIClient.
type OpenAIClient struct {
	APIKey  string // bearer token for the Authorization header; empty means no auth header is sent
	BaseURL string // full endpoint URL; when empty, per-call defaults to the api.openai.com endpoints apply
	Model   string // model identifier placed in every request body
	Repo    ChatRepositoryAPI // chat repository dependency — not referenced in this file; presumably used by callers/other methods, verify
}
// NewOpenAIClient wires together an OpenAIClient from an API key, an
// endpoint base URL, a model name, and a chat repository.
func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OpenAIClient {
	c := &OpenAIClient{
		APIKey:  apiKey,
		BaseURL: baseURL,
		Model:   model,
		Repo:    repo,
	}
	return c
}
// ExtractKeywords is a convenience wrapper around ExtractKeywordsRaw that
// discards the raw LLM response string and returns only the parsed map.
func (llm *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
	_, parsed, err := llm.ExtractKeywordsRaw(ctx, message)
	if err != nil {
		return nil, err
	}
	return parsed, nil
}
// ExtractKeywordsRaw renders the configured keyword-extraction prompt for
// message, sends it to the chat-completions endpoint with the keyword JSON
// schema, and returns both the raw LLM response string and the parsed
// JSON object. On error the raw response (possibly empty) is still
// returned so callers can log it.
func (llm *OpenAIClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) {
	prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
	if err != nil {
		logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt")
		return "", nil, err
	}
	logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
	// Use the utility function instead of inline format definition
	format := GetExtractKeywordsFormat()
	resp, err := llm.openAICompletion(ctx, prompt, format)
	logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response")
	if err != nil {
		return resp, nil, err
	}
	var result map[string]interface{}
	if err := json.Unmarshal([]byte(resp), &result); err != nil {
		// Wrap with context so the failure site is identifiable, matching
		// the error style used by DisambiguateBestMatchRaw.
		return resp, nil, fmt.Errorf("failed to unmarshal keywords response: %w", err)
	}
	return resp, result, nil
}
// DisambiguateBestMatch is a convenience wrapper around
// DisambiguateBestMatchRaw that drops the raw LLM response and returns
// only the chosen visit reason.
func (llm *OpenAIClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
	_, reason, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates)
	return reason, err
}
// DisambiguateBestMatchRaw asks the LLM to pick the best-matching visit
// reason for message from the candidate visits. It returns the raw LLM
// response string and the extracted "visitReason" field. On error the
// raw response (possibly empty) is still returned for logging.
func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
	// Use the utility function instead of inline format definition
	format := GetDisambiguateFormat()
	// Previously this error was discarded; a failed marshal would have
	// silently injected "null" into the prompt.
	entries, err := json.Marshal(candidates)
	if err != nil {
		return "", "", fmt.Errorf("failed to marshal candidate visits: %w", err)
	}
	prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
	if err != nil {
		logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt")
		return "", "", err
	}
	logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")
	resp, err := llm.openAICompletion(ctx, prompt, format)
	logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response")
	if err != nil {
		return resp, "", err
	}
	var parsed map[string]string
	if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
		return resp, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
	}
	visitReason := strings.TrimSpace(parsed["visitReason"])
	if visitReason == "" {
		return resp, "", fmt.Errorf("visitReason not found in response")
	}
	return resp, visitReason, nil
}
// openAICompletion sends prompt to an OpenAI-compatible chat-completions
// endpoint and returns the assistant message content. The expected JSON
// schema (format) is embedded in the system message and json_object mode
// is requested. Returns a descriptive error on provider errors, HTTP
// error statuses, or unrecognized response bodies.
func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
	apiURL := llm.BaseURL
	if apiURL == "" {
		apiURL = "https://api.openai.com/v1/chat/completions"
	}
	// Helper to stringify the expected JSON schema for instructions
	schemaDesc := func() string {
		b, _ := json.MarshalIndent(format, "", " ")
		return string(b)
	}
	body := map[string]interface{}{
		"model": llm.Model,
		"messages": []map[string]string{
			{"role": "system", "content": "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + schemaDesc() + " Do not add explanations."},
			{"role": "user", "content": prompt},
		},
		"response_format": map[string]interface{}{"type": "json_object"},
	}
	// Previously these two errors were discarded; a malformed apiURL would
	// have left req nil and panicked on req.Header.Set below.
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return "", fmt.Errorf("failed to marshal completion request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
	if err != nil {
		return "", fmt.Errorf("failed to build completion request: %w", err)
	}
	if llm.APIKey != "" {
		req.Header.Set("Authorization", "Bearer "+llm.APIKey)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	if strings.Contains(apiURL, "openrouter.ai") {
		req.Header.Set("Referer", "https://github.com/")
		req.Header.Set("X-Title", "vetrag-app")
	}
	// No client-level timeout: LLM calls can be long-running; cancellation
	// and deadlines are expected to come through ctx.
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	var openAI struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
		Error *struct {
			Message string `json:"message"`
			Type    string `json:"type"`
		} `json:"error"`
	}
	if err := json.Unmarshal(raw, &openAI); err == nil {
		// Surface provider-reported errors (or any HTTP error status)
		// before looking at choices.
		if openAI.Error != nil || resp.StatusCode >= 400 {
			var msg string
			if openAI.Error != nil {
				msg = openAI.Error.Message
			} else {
				msg = string(raw)
			}
			return "", fmt.Errorf("provider error: %s", msg)
		}
		if len(openAI.Choices) > 0 && openAI.Choices[0].Message.Content != "" {
			return openAI.Choices[0].Message.Content, nil
		}
	}
	// Fall through: body was not JSON, or JSON without choices/error.
	return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw))
}
// GetEmbeddings sends input to an OpenAI-compatible embeddings endpoint
// and returns the first embedding vector. Provider-reported errors and
// HTTP error statuses are surfaced before the data payload is consulted,
// mirroring openAICompletion's handling.
func (llm *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
	apiURL := llm.BaseURL
	if apiURL == "" {
		apiURL = "https://api.openai.com/v1/embeddings"
	}
	body := map[string]interface{}{
		"model": llm.Model,
		"input": input,
	}
	// Previously these two errors were discarded; a malformed apiURL would
	// have left req nil and panicked on req.Header.Set below.
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal embeddings request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("failed to build embeddings request: %w", err)
	}
	if llm.APIKey != "" {
		req.Header.Set("Authorization", "Bearer "+llm.APIKey)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	if strings.Contains(apiURL, "openrouter.ai") {
		req.Header.Set("Referer", "https://github.com/")
		req.Header.Set("X-Title", "vetrag-app")
	}
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	var openAI struct {
		Data []struct {
			Embedding []float64 `json:"embedding"`
		} `json:"data"`
		Error *struct {
			Message string `json:"message"`
		} `json:"error"`
	}
	if err := json.Unmarshal(raw, &openAI); err == nil {
		// Check error/status before data, consistent with openAICompletion;
		// previously a 4xx body could fall through to the generic message.
		if openAI.Error != nil {
			return nil, fmt.Errorf("embedding error: %s", openAI.Error.Message)
		}
		if resp.StatusCode >= 400 {
			return nil, fmt.Errorf("embedding error: %.200s", string(raw))
		}
		if len(openAI.Data) > 0 {
			return openAI.Data[0].Embedding, nil
		}
	}
	return nil, fmt.Errorf("unrecognized embedding response: %.200s", string(raw))
}