157 lines
5.1 KiB
Go
157 lines
5.1 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"text/template"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// LLMClient abstracts LLM API calls
type LLMClient struct {
	APIKey  string // bearer token sent in the Authorization header
	BaseURL string // chat endpoint URL; empty falls back to the local Ollama default (see openAICompletion)
	Model   string // model identifier included in every request body
}
|
|
|
|
// NewLLMClient constructs a new LLMClient with the given API key and base URL
|
|
func NewLLMClient(apiKey, baseURL string, model string) *LLMClient {
|
|
return &LLMClient{
|
|
APIKey: apiKey,
|
|
BaseURL: baseURL,
|
|
Model: model,
|
|
}
|
|
}
|
|
|
|
// renderPrompt expands tmplStr as a Go text/template against data and
// returns the rendered text, or an error if parsing or execution fails.
func renderPrompt(tmplStr string, data any) (string, error) {
	parsed, parseErr := template.New("").Parse(tmplStr)
	if parseErr != nil {
		return "", parseErr
	}
	var rendered bytes.Buffer
	if execErr := parsed.Execute(&rendered, data); execErr != nil {
		return "", execErr
	}
	return rendered.String(), nil
}
|
|
|
|
// ExtractKeywords calls LLM to extract keywords from user message
|
|
func (llm *LLMClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
|
|
prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
|
|
if err != nil {
|
|
logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt")
|
|
return nil, err
|
|
}
|
|
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
|
|
format := map[string]interface{}{
|
|
"type": "object",
|
|
"properties": map[string]interface{}{
|
|
"translate": map[string]interface{}{"type": "string"},
|
|
"keyword": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}},
|
|
"animal": map[string]interface{}{"type": "string"},
|
|
},
|
|
"required": []string{"translate", "keyword", "animal"},
|
|
}
|
|
resp, err := llm.openAICompletion(ctx, prompt, format)
|
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var result map[string]interface{}
|
|
if err := json.Unmarshal([]byte(resp), &result); err != nil {
|
|
return nil, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// DisambiguateBestMatch calls LLM to pick best match from candidates
|
|
func (llm *LLMClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
|
|
format := map[string]interface{}{
|
|
"type": "object",
|
|
"properties": map[string]interface{}{
|
|
"visitReason": map[string]interface{}{"type": "string"},
|
|
},
|
|
"required": []string{"visitReason"},
|
|
}
|
|
entries, _ := json.Marshal(candidates)
|
|
prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
|
|
if err != nil {
|
|
logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt")
|
|
return "", err
|
|
}
|
|
logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")
|
|
resp, err := llm.openAICompletion(ctx, prompt, format)
|
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
var parsed map[string]string
|
|
if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
|
|
return "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
|
|
}
|
|
|
|
visitReason := strings.TrimSpace(parsed["visitReason"])
|
|
if visitReason == "" {
|
|
return "", fmt.Errorf("visitReason not found in response")
|
|
}
|
|
|
|
return visitReason, nil
|
|
}
|
|
|
|
// openAICompletion calls Ollama API with prompt and structure, returns structured result
|
|
func (llm *LLMClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
|
|
apiURL := llm.BaseURL
|
|
if apiURL == "" {
|
|
apiURL = "http://localhost:11434/api/chat"
|
|
}
|
|
logrus.WithFields(logrus.Fields{"api_url": apiURL, "prompt": prompt, "format": format}).Info("[LLM] openAICompletion POST")
|
|
body := map[string]interface{}{
|
|
"model": llm.Model, // "qwen3:latest",
|
|
"messages": []map[string]string{{"role": "user", "content": prompt}},
|
|
"stream": false,
|
|
"format": format,
|
|
}
|
|
jsonBody, _ := json.Marshal(body)
|
|
req, _ := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonBody))
|
|
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
logrus.WithError(err).Error("[LLM] openAICompletion error")
|
|
return "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
var result struct {
|
|
Message struct {
|
|
Content string `json:"content"`
|
|
} `json:"message"`
|
|
}
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
logrus.WithError(err).Error("[LLM] openAICompletion decode error")
|
|
return "", err
|
|
}
|
|
if result.Message.Content == "" {
|
|
logrus.Warn("[LLM] openAICompletion: no content returned %v body:[%v]", resp.Status, resp.Body)
|
|
return "", nil
|
|
}
|
|
logrus.WithField("content", result.Message.Content).Info("[LLM] openAICompletion: got content")
|
|
return result.Message.Content, nil
|
|
}
|
|
|
|
// LLMClientAPI allows mocking LLMClient in other places.
// Only public methods should be included.

type LLMClientAPI interface {
	ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error)
	DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error)
}

// Compile-time assertion that *LLMClient satisfies LLMClientAPI.
var _ LLMClientAPI = (*LLMClient)(nil)