vetrag/llm.go

218 lines
7.1 KiB
Go

package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"text/template"
"github.com/sirupsen/logrus"
)
// LLMClient carries the settings needed to call an LLM chat endpoint, either a
// local Ollama server or an OpenAI-compatible API such as OpenRouter.
type LLMClient struct {
	// APIKey, when non-empty, is sent as "Authorization: Bearer <APIKey>".
	APIKey string
	// BaseURL is the full chat endpoint URL that requests are POSTed to.
	// An empty value selects the local Ollama chat endpoint.
	BaseURL string
	// Model is the model identifier placed in every request body.
	Model string
}
// NewLLMClient returns a pointer to an LLMClient holding the given API key,
// endpoint URL and model name. The values are stored as-is without
// validation; an empty baseURL is resolved to the local Ollama endpoint
// only when a request is made.
func NewLLMClient(apiKey, baseURL string, model string) *LLMClient {
	client := new(LLMClient)
	client.APIKey = apiKey
	client.BaseURL = baseURL
	client.Model = model
	return client
}
// renderPrompt parses tmplStr as a text/template and executes it against
// data, returning the rendered text. A parse or execution error is returned
// unchanged together with an empty string; partial output is discarded.
func renderPrompt(tmplStr string, data any) (string, error) {
	var rendered bytes.Buffer
	tmpl, err := template.New("").Parse(tmplStr)
	if err == nil {
		err = tmpl.Execute(&rendered, data)
	}
	if err != nil {
		return "", err
	}
	return rendered.String(), nil
}
// ExtractKeywords asks the LLM to analyse a user message and returns the
// decoded JSON object. The request carries a structured-output schema with
// "translate" (string), "keyword" (array of strings) and "animal" (string)
// fields.
//
// The prompt is rendered from appConfig.LLM.ExtractKeywordsPrompt with the
// message available to the template as {{.Message}}. The schema is only
// requested from the provider and not re-validated here, so callers should
// not assume every key is present in the returned map.
//
// Errors from rendering, the completion call, or JSON decoding are wrapped
// with %w. A literal JSON null reply is also reported as an error.
func (llm *LLMClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
	prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
	if err != nil {
		logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt")
		return nil, fmt.Errorf("rendering ExtractKeywords prompt: %w", err)
	}
	logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
	format := map[string]interface{}{
		"type": "object",
		"properties": map[string]interface{}{
			"translate": map[string]interface{}{"type": "string"},
			"keyword":   map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}},
			"animal":    map[string]interface{}{"type": "string"},
		},
		"required": []string{"translate", "keyword", "animal"},
	}
	resp, err := llm.openAICompletion(ctx, prompt, format)
	logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response")
	if err != nil {
		return nil, fmt.Errorf("keyword extraction completion: %w", err)
	}
	var result map[string]interface{}
	if err := json.Unmarshal([]byte(resp), &result); err != nil {
		return nil, fmt.Errorf("failed to unmarshal keyword extraction response: %w", err)
	}
	if result == nil {
		// json.Unmarshal accepts a literal `null` and leaves the map nil; do not
		// hand callers a (nil, nil) result.
		return nil, fmt.Errorf("keyword extraction response was null")
	}
	return result, nil
}
// DisambiguateBestMatch asks the LLM to choose the candidate that best fits
// message. It returns the trimmed "visitReason" string from the model's
// structured reply.
//
// The candidates are JSON-encoded and passed to the template in
// appConfig.LLM.DisambiguatePrompt as {{.Entries}}, alongside {{.Message}}.
// The function returns an error if any of the following happens:
//   - the candidates cannot be encoded;
//   - the prompt fails to render;
//   - the completion call fails;
//   - the reply is not a JSON object;
//   - the reply has no non-blank visitReason.
//
// Fields in the reply other than visitReason are ignored.
func (llm *LLMClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
	format := map[string]interface{}{
		"type": "object",
		"properties": map[string]interface{}{
			"visitReason": map[string]interface{}{"type": "string"},
		},
		"required": []string{"visitReason"},
	}
	entries, err := json.Marshal(candidates)
	if err != nil {
		// An ignored error here would silently send an empty candidate list.
		return "", fmt.Errorf("failed to marshal disambiguation candidates: %w", err)
	}
	prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
	if err != nil {
		logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt")
		return "", err
	}
	logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")
	resp, err := llm.openAICompletion(ctx, prompt, format)
	logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response")
	if err != nil {
		return "", fmt.Errorf("disambiguation completion: %w", err)
	}
	// A typed struct tolerates extra, non-string fields the model may add,
	// which would make decoding into map[string]string fail outright.
	var parsed struct {
		VisitReason string `json:"visitReason"`
	}
	if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
		return "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
	}
	visitReason := strings.TrimSpace(parsed.VisitReason)
	if visitReason == "" {
		return "", fmt.Errorf("visitReason not found in response")
	}
	return visitReason, nil
}
// openAICompletion sends prompt as a single user message to the configured
// chat endpoint and returns the assistant's reply text. It asks for output
// that conforms to the JSON schema in format.
//
// The wire format is chosen from the URL:
//   - If the URL contains "openrouter.ai" or "/v1/", an OpenAI-compatible
//     chat.completions body is sent, with the schema in response_format.
//   - Otherwise an Ollama /api/chat body is sent, with the schema in
//     "format" and streaming disabled.
//
// An empty BaseURL targets the local Ollama endpoint. The request is bounded
// only by ctx; no client-side timeout is applied here.
//
// An error is returned for:
//   - request encoding or construction failures;
//   - transport failures;
//   - non-2xx HTTP statuses;
//   - provider-reported errors;
//   - bodies that match neither the Ollama nor the OpenAI response shape.
func (llm *LLMClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
	apiURL := llm.BaseURL
	if apiURL == "" {
		// Default to Ollama local chat endpoint
		apiURL = "http://localhost:11434/api/chat"
	}
	isOpenAIStyle := strings.Contains(apiURL, "openrouter.ai") || strings.Contains(apiURL, "/v1/")

	jsonBody, err := json.Marshal(buildLLMRequestBody(llm.Model, prompt, format, isOpenAIStyle))
	if err != nil {
		return "", fmt.Errorf("encoding LLM request body: %w", err)
	}
	logrus.WithFields(logrus.Fields{"api_url": apiURL, "prompt": prompt, "is_openai_style": isOpenAIStyle}).Info("[LLM] completion POST")

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(jsonBody))
	if err != nil {
		// A malformed BaseURL fails here; continuing would dereference a nil
		// request below.
		return "", fmt.Errorf("building LLM request for %q: %w", apiURL, err)
	}
	if llm.APIKey != "" {
		// OpenRouter expects: Authorization: Bearer sk-... or OR-... depending on key type
		req.Header.Set("Authorization", "Bearer "+llm.APIKey)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		logrus.WithError(err).Error("[LLM] completion HTTP error")
		return "", err
	}
	defer resp.Body.Close()
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("failed reading response body: %w", err)
	}
	logrus.WithFields(logrus.Fields{"status": resp.StatusCode, "raw": string(raw)}).Debug("[LLM] completion raw response")

	content, parseErr := parseLLMCompletion(raw)
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		// Surface the provider's own error text when the body carries one.
		if parseErr != nil {
			return "", fmt.Errorf("LLM endpoint returned HTTP %d: %w", resp.StatusCode, parseErr)
		}
		return "", fmt.Errorf("LLM endpoint returned HTTP %d", resp.StatusCode)
	}
	return content, parseErr
}

// buildLLMRequestBody assembles the chat request payload for either wire
// format.
//   - OpenAI-style endpoints get the schema via response_format/json_schema.
//   - Ollama gets it in "format", with streaming disabled.
func buildLLMRequestBody(model, prompt string, format map[string]interface{}, openAIStyle bool) map[string]interface{} {
	messages := []map[string]string{{"role": "user", "content": prompt}}
	if openAIStyle {
		// OpenAI / OpenRouter style (chat.completions)
		return map[string]interface{}{
			"model":    model,
			"messages": messages,
			"response_format": map[string]interface{}{
				"type": "json_schema",
				"json_schema": map[string]interface{}{
					"name":   "structured_output",
					"schema": format,
				},
			},
		}
	}
	// Ollama structured output extension
	return map[string]interface{}{
		"model":    model,
		"messages": messages,
		"stream":   false,
		"format":   format,
	}
}

// parseLLMCompletion extracts the assistant reply text from a chat response
// body, or returns the provider-reported or format error.
//
// It tries each known response shape in turn:
//  1. Ollama's {"message":{"content":...}}.
//  2. OpenAI's {"choices":[{"message":{"content":...}}]} or
//     {"error":{"message":...,"type":...}}.
//  3. A plain {"error":"..."} string.
func parseLLMCompletion(raw []byte) (string, error) {
	// Attempt Ollama format first (backwards compatible)
	var ollama struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	}
	if err := json.Unmarshal(raw, &ollama); err == nil && ollama.Message.Content != "" {
		logrus.WithField("content", ollama.Message.Content).Info("[LLM] completion (ollama) parsed")
		return ollama.Message.Content, nil
	}

	// Attempt OpenAI / OpenRouter style
	var openAI struct {
		Choices []struct {
			Message struct {
				Role    string `json:"role"`
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
		Error *struct {
			Message string `json:"message"`
			Type    string `json:"type"`
		} `json:"error"`
	}
	if err := json.Unmarshal(raw, &openAI); err == nil {
		if openAI.Error != nil {
			return "", fmt.Errorf("provider error: %s (%s)", openAI.Error.Message, openAI.Error.Type)
		}
		if len(openAI.Choices) > 0 && openAI.Choices[0].Message.Content != "" {
			content := openAI.Choices[0].Message.Content
			logrus.WithField("content", content).Info("[LLM] completion (openai) parsed")
			return content, nil
		}
	}

	// Some providers (e.g. Ollama) report failures as {"error": "<text>"},
	// which does not decode into the object-shaped error above.
	var plainErr struct {
		Error string `json:"error"`
	}
	if err := json.Unmarshal(raw, &plainErr); err == nil && plainErr.Error != "" {
		return "", fmt.Errorf("provider error: %s", plainErr.Error)
	}

	// If still nothing, return error with snippet
	return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw))
}
// LLMClientAPI is the exported surface of LLMClient that other code relies
// on, declared as an interface so a mock can stand in for the real client.
// Only exported methods belong here.
type LLMClientAPI interface {
	// DisambiguateBestMatch returns the visit reason the model selects from
	// candidates for the given message.
	DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error)
	// ExtractKeywords returns the structured keyword data the model
	// produces for the given message.
	ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error)
}

// Compile-time guarantee that *LLMClient implements LLMClientAPI.
var _ LLMClientAPI = (*LLMClient)(nil)