326 lines
10 KiB
Go
326 lines
10 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"text/template"
|
|
"time"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// LLMClient abstracts LLM API calls
|
|
type LLMClient struct {
|
|
APIKey string
|
|
BaseURL string
|
|
Model string
|
|
Repo ChatRepositoryAPI
|
|
}
|
|
|
|
// NewLLMClient constructs a new LLMClient with the given API key, base URL, model, and optional repository
|
|
func NewLLMClient(apiKey, baseURL string, model string, repo ChatRepositoryAPI) *LLMClient {
|
|
return &LLMClient{APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo}
|
|
}
|
|
|
|
// SetRepository replaces the repository used to persist raw LLM responses.
func (llm *LLMClient) SetRepository(r ChatRepositoryAPI) {
	llm.Repo = r
}
|
|
|
|
// correlationIDCtxKey is the context key under which the correlation id is
// looked up.
const correlationIDCtxKey = "corr_id"

// correlationIDFromCtx returns the correlation id stored in ctx. It returns
// the empty string when no value is stored under the key or when the stored
// value is not a string.
func correlationIDFromCtx(ctx context.Context) string {
	id, isString := ctx.Value(correlationIDCtxKey).(string)
	if !isString {
		return ""
	}
	return id
}
|
|
|
|
func (llm *LLMClient) persistRaw(ctx context.Context, phase, raw string) {
|
|
if llm == nil || llm.Repo == nil || raw == "" {
|
|
return
|
|
}
|
|
cid := correlationIDFromCtx(ctx)
|
|
if cid == "" {
|
|
return
|
|
}
|
|
_ = llm.Repo.SaveLLMRawEvent(ctx, cid, phase, raw)
|
|
}
|
|
|
|
// renderPrompt parses tmplStr as a Go text template and executes it against
// data. It returns the rendered text, or an empty string together with the
// parse or execution error.
func renderPrompt(tmplStr string, data any) (string, error) {
	parsed, parseErr := template.New("").Parse(tmplStr)
	if parseErr != nil {
		return "", parseErr
	}

	var rendered strings.Builder
	if execErr := parsed.Execute(&rendered, data); execErr != nil {
		// Discard any partial output written before the failure.
		return "", execErr
	}
	return rendered.String(), nil
}
|
|
|
|
// ExtractKeywords calls LLM to extract keywords from user message
|
|
func (llm *LLMClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
|
|
_, parsed, err := llm.ExtractKeywordsRaw(ctx, message)
|
|
return parsed, err
|
|
}
|
|
|
|
// ExtractKeywordsRaw returns the raw JSON string and parsed map
|
|
func (llm *LLMClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) {
|
|
prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
|
|
if err != nil {
|
|
logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt")
|
|
return "", nil, err
|
|
}
|
|
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
|
|
format := map[string]interface{}{
|
|
"type": "object",
|
|
"properties": map[string]interface{}{
|
|
"translate": map[string]interface{}{"type": "string"},
|
|
"keyword": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}},
|
|
"animal": map[string]interface{}{"type": "string"},
|
|
},
|
|
"required": []string{"translate", "keyword", "animal"},
|
|
}
|
|
resp, err := llm.openAICompletion(ctx, prompt, format)
|
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response")
|
|
if err != nil {
|
|
return resp, nil, err // return whatever raw we got (may be empty)
|
|
}
|
|
var result map[string]interface{}
|
|
if err := json.Unmarshal([]byte(resp), &result); err != nil {
|
|
return resp, nil, err
|
|
}
|
|
llm.persistRaw(ctx, "extract_keywords", resp)
|
|
return resp, result, nil
|
|
}
|
|
|
|
// DisambiguateBestMatch calls LLM to pick best match from candidates
|
|
func (llm *LLMClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
|
|
_, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates)
|
|
return vr, err
|
|
}
|
|
|
|
// DisambiguateBestMatchRaw returns raw JSON and visitReason
|
|
func (llm *LLMClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
|
|
format := map[string]interface{}{
|
|
"type": "object",
|
|
"properties": map[string]interface{}{
|
|
"visitReason": map[string]interface{}{"type": "string"},
|
|
},
|
|
"required": []string{"visitReason"},
|
|
}
|
|
entries, _ := json.Marshal(candidates)
|
|
prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
|
|
if err != nil {
|
|
logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt")
|
|
return "", "", err
|
|
}
|
|
logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")
|
|
resp, err := llm.openAICompletion(ctx, prompt, format)
|
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response")
|
|
if err != nil {
|
|
return resp, "", err
|
|
}
|
|
var parsed map[string]string
|
|
if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
|
|
return resp, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
|
|
}
|
|
visitReason := strings.TrimSpace(parsed["visitReason"])
|
|
if visitReason == "" {
|
|
return resp, "", fmt.Errorf("visitReason not found in response")
|
|
}
|
|
llm.persistRaw(ctx, "disambiguate", resp)
|
|
return resp, visitReason, nil
|
|
}
|
|
|
|
// openAICompletion now supports both Ollama (default local) and OpenRouter/OpenAI-compatible APIs without external branching.
|
|
// It auto-detects by inspecting the BaseURL. If the URL contains "openrouter.ai" or "/v1/", it assumes OpenAI-style.
|
|
func (llm *LLMClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
|
|
apiURL := llm.BaseURL
|
|
if apiURL == "" {
|
|
// Default to Ollama local chat endpoint
|
|
apiURL = "http://localhost:11434/api/chat"
|
|
}
|
|
|
|
isOpenAIStyle := strings.Contains(apiURL, "openrouter.ai") || strings.Contains(apiURL, "/v1/")
|
|
|
|
// Helper to stringify the expected JSON schema for instructions
|
|
schemaDesc := func() string {
|
|
b, _ := json.MarshalIndent(format, "", " ")
|
|
return string(b)
|
|
}
|
|
|
|
truncate := func(s string, n int) string {
|
|
if len(s) <= n {
|
|
return s
|
|
}
|
|
return s[:n] + "...<truncated>"
|
|
}
|
|
|
|
buildBody := func() map[string]interface{} {
|
|
if isOpenAIStyle {
|
|
return map[string]interface{}{
|
|
"model": llm.Model,
|
|
"messages": []map[string]string{
|
|
{"role": "system", "content": "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + schemaDesc() + " Do not add explanations."},
|
|
{"role": "user", "content": prompt},
|
|
},
|
|
"response_format": map[string]interface{}{"type": "json_object"},
|
|
}
|
|
}
|
|
// Ollama style
|
|
return map[string]interface{}{
|
|
"model": llm.Model,
|
|
"messages": []map[string]string{{"role": "user", "content": prompt}},
|
|
"stream": false,
|
|
"format": format,
|
|
}
|
|
}
|
|
|
|
body := buildBody()
|
|
|
|
doRequest := func(body map[string]interface{}) (raw []byte, status int, err error, dur time.Duration) {
|
|
jsonBody, _ := json.Marshal(body)
|
|
bodySize := len(jsonBody)
|
|
logrus.WithFields(logrus.Fields{
|
|
"event": "llm_request",
|
|
"api_url": apiURL,
|
|
"model": llm.Model,
|
|
"is_openai_style": isOpenAIStyle,
|
|
"prompt_len": len(prompt),
|
|
"body_size": bodySize,
|
|
}).Info("[LLM] sending request")
|
|
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
|
|
if llm.APIKey != "" {
|
|
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Accept", "application/json")
|
|
if strings.Contains(apiURL, "openrouter.ai") {
|
|
req.Header.Set("Referer", "https://github.com/")
|
|
req.Header.Set("X-Title", "vetrag-app")
|
|
}
|
|
start := time.Now()
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, 0, err, time.Since(start)
|
|
}
|
|
defer resp.Body.Close()
|
|
raw, rerr := io.ReadAll(resp.Body)
|
|
return raw, resp.StatusCode, rerr, time.Since(start)
|
|
}
|
|
|
|
raw, status, err, dur := doRequest(body)
|
|
if err != nil {
|
|
logrus.WithFields(logrus.Fields{
|
|
"event": "llm_response",
|
|
"status": status,
|
|
"latency_ms": dur.Milliseconds(),
|
|
"error": err,
|
|
}).Error("[LLM] request failed")
|
|
return "", err
|
|
}
|
|
logrus.WithFields(logrus.Fields{
|
|
"event": "llm_raw_response",
|
|
"status": status,
|
|
"latency_ms": dur.Milliseconds(),
|
|
"raw_trunc": truncate(string(raw), 600),
|
|
"raw_len": len(raw),
|
|
}).Debug("[LLM] raw response body")
|
|
|
|
parseVariant := "unknown"
|
|
|
|
// Attempt Ollama format parse
|
|
var ollama struct {
|
|
Message struct {
|
|
Content string `json:"content"`
|
|
} `json:"message"`
|
|
Error string `json:"error"`
|
|
}
|
|
if err := json.Unmarshal(raw, &ollama); err == nil && ollama.Message.Content != "" {
|
|
parseVariant = "ollama"
|
|
content := ollama.Message.Content
|
|
logrus.WithFields(logrus.Fields{
|
|
"event": "llm_response",
|
|
"status": status,
|
|
"latency_ms": dur.Milliseconds(),
|
|
"parse_variant": parseVariant,
|
|
"content_len": len(content),
|
|
"content_snip": truncate(content, 300),
|
|
}).Info("[LLM] parsed response")
|
|
return content, nil
|
|
}
|
|
|
|
// Attempt OpenAI/OpenRouter style parse
|
|
var openAI struct {
|
|
Choices []struct {
|
|
Message struct {
|
|
Content string `json:"content"`
|
|
} `json:"message"`
|
|
} `json:"choices"`
|
|
Error *struct {
|
|
Message string `json:"message"`
|
|
Type string `json:"type"`
|
|
} `json:"error"`
|
|
}
|
|
if err := json.Unmarshal(raw, &openAI); err == nil {
|
|
if openAI.Error != nil || status >= 400 {
|
|
parseVariant = "openai"
|
|
var msg string
|
|
if openAI.Error != nil {
|
|
msg = openAI.Error.Message
|
|
} else {
|
|
msg = string(raw)
|
|
}
|
|
logrus.WithFields(logrus.Fields{
|
|
"event": "llm_response",
|
|
"status": status,
|
|
"latency_ms": dur.Milliseconds(),
|
|
"parse_variant": parseVariant,
|
|
"error": msg,
|
|
}).Error("[LLM] provider error")
|
|
return "", fmt.Errorf("provider error: %s", msg)
|
|
}
|
|
if len(openAI.Choices) > 0 && openAI.Choices[0].Message.Content != "" {
|
|
parseVariant = "openai"
|
|
content := openAI.Choices[0].Message.Content
|
|
logrus.WithFields(logrus.Fields{
|
|
"event": "llm_response",
|
|
"status": status,
|
|
"latency_ms": dur.Milliseconds(),
|
|
"parse_variant": parseVariant,
|
|
"content_len": len(content),
|
|
"content_snip": truncate(content, 300),
|
|
}).Info("[LLM] parsed response")
|
|
return content, nil
|
|
}
|
|
}
|
|
|
|
logrus.WithFields(logrus.Fields{
|
|
"event": "llm_response",
|
|
"status": status,
|
|
"latency_ms": dur.Milliseconds(),
|
|
"parse_variant": parseVariant,
|
|
"raw_snip": truncate(string(raw), 300),
|
|
}).Error("[LLM] unrecognized response format")
|
|
|
|
return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw))
|
|
}
|
|
|
|
// LLMClientAPI allows mocking LLMClient in other places
|
|
// Only public methods should be included
|
|
|
|
type LLMClientAPI interface {
|
|
ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error)
|
|
DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error)
|
|
}
|
|
|
|
var _ LLMClientAPI = (*LLMClient)(nil)
|