separate clients
This commit is contained in:
parent
4ada379be9
commit
a0f477c9a8
|
|
@ -83,20 +83,9 @@ func (cs *ChatService) findBestVisit(ctx context.Context, req ChatRequest, keywo
|
||||||
bestID := ""
|
bestID := ""
|
||||||
rawDis := ""
|
rawDis := ""
|
||||||
if len(candidates) > 0 {
|
if len(candidates) > 0 {
|
||||||
if real, ok := cs.LLM.(*LLMClient); ok {
|
|
||||||
raw, vr, derr := real.DisambiguateBestMatchRaw(ctx, req.Message, candidates)
|
|
||||||
rawDis = raw
|
|
||||||
bestID = vr
|
|
||||||
if derr != nil {
|
|
||||||
cs.logBestID(bestID, derr)
|
|
||||||
} else {
|
|
||||||
cs.logBestID(bestID, nil)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
bestID, err = cs.LLM.DisambiguateBestMatch(ctx, req.Message, candidates)
|
bestID, err = cs.LLM.DisambiguateBestMatch(ctx, req.Message, candidates)
|
||||||
cs.logBestID(bestID, err)
|
cs.logBestID(bestID, err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
visit, err := cs.visitsDB.FindById(bestID)
|
visit, err := cs.visitsDB.FindById(bestID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, rawDis, fmt.Errorf("FindById: %w", err)
|
return nil, rawDis, fmt.Errorf("FindById: %w", err)
|
||||||
|
|
@ -236,3 +225,8 @@ func (cs *ChatService) persistInteraction(ctx context.Context, correlationID str
|
||||||
logrus.WithError(err).Debug("failed to save chat interaction")
|
logrus.WithError(err).Debug("failed to save chat interaction")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// correlationIDCtxKeyType is an unexported, empty struct type that exists only
// to serve as the type of a context key; a dedicated key type cannot collide
// with context keys declared by other packages.
|
||||||
|
type correlationIDCtxKeyType struct{}
|
||||||
|
|
||||||
|
// correlationIDCtxKey is the single key value of correlationIDCtxKeyType.
// NOTE(review): presumably the request correlation ID is stored in the context
// under this key — confirm that every writer and reader uses this value rather
// than a plain string key such as "corr_id".
var correlationIDCtxKey = correlationIDCtxKeyType{}
|
||||||
|
|
|
||||||
337
llm.go
337
llm.go
|
|
@ -3,55 +3,66 @@ package main
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"os"
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LLMClient abstracts LLM API calls
|
// LLMClientAPI allows mocking LLMClient in other places
|
||||||
type LLMClient struct {
|
type LLMClientAPI interface {
|
||||||
APIKey string
|
ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error)
|
||||||
BaseURL string
|
DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error)
|
||||||
Model string
|
GetEmbeddings(ctx context.Context, input string) ([]float64, error)
|
||||||
Repo ChatRepositoryAPI
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewLLMClient constructs a new LLMClient with the given API key, base URL, model, and optional repository
|
// --- Format Utilities ---
|
||||||
func NewLLMClient(apiKey, baseURL string, model string, repo ChatRepositoryAPI) *LLMClient {
|
|
||||||
return &LLMClient{APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo}
|
// GetExtractKeywordsFormat returns the format specification for keyword extraction
|
||||||
|
func GetExtractKeywordsFormat() map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"translate": map[string]interface{}{"type": "string"},
|
||||||
|
"keyword": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}},
|
||||||
|
"animal": map[string]interface{}{"type": "string"},
|
||||||
|
},
|
||||||
|
"required": []string{"translate", "keyword", "animal"},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *LLMClient) SetRepository(r ChatRepositoryAPI) { llm.Repo = r }
|
// GetDisambiguateFormat returns the format specification for disambiguation
|
||||||
|
func GetDisambiguateFormat() map[string]interface{} {
|
||||||
// helper to get correlation id from context
|
return map[string]interface{}{
|
||||||
const correlationIDCtxKey = "corr_id"
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
func correlationIDFromCtx(ctx context.Context) string {
|
"visitReason": map[string]interface{}{"type": "string"},
|
||||||
v := ctx.Value(correlationIDCtxKey)
|
},
|
||||||
if s, ok := v.(string); ok {
|
"required": []string{"visitReason"},
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *LLMClient) persistRaw(ctx context.Context, phase, raw string) {
|
// --- Factory ---
|
||||||
if llm == nil || llm.Repo == nil || raw == "" {
|
|
||||||
return
|
func NewLLMClientFromEnv(repo ChatRepositoryAPI) LLMClientAPI {
|
||||||
|
provider := os.Getenv("LLM_PROVIDER")
|
||||||
|
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||||
|
baseURL := os.Getenv("OPENAI_BASE_URL")
|
||||||
|
model := os.Getenv("OPENAI_MODEL")
|
||||||
|
switch strings.ToLower(provider) {
|
||||||
|
case "openai", "openrouter":
|
||||||
|
return NewOpenAIClient(apiKey, baseURL, model, repo)
|
||||||
|
case "ollama", "":
|
||||||
|
return NewOllamaClient(apiKey, baseURL, model, repo)
|
||||||
|
default:
|
||||||
|
logrus.Warnf("Unknown LLM_PROVIDER %q, defaulting to Ollama", provider)
|
||||||
|
return NewOllamaClient(apiKey, baseURL, model, repo)
|
||||||
}
|
}
|
||||||
cid := correlationIDFromCtx(ctx)
|
|
||||||
if cid == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = llm.Repo.SaveLLMRawEvent(ctx, cid, phase, raw)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// renderPrompt renders a Go template with the given data
|
// --- Utility ---
|
||||||
|
|
||||||
func renderPrompt(tmplStr string, data any) (string, error) {
|
func renderPrompt(tmplStr string, data any) (string, error) {
|
||||||
tmpl, err := template.New("").Parse(tmplStr)
|
tmpl, err := template.New("").Parse(tmplStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -63,263 +74,3 @@ func renderPrompt(tmplStr string, data any) (string, error) {
|
||||||
}
|
}
|
||||||
return buf.String(), nil
|
return buf.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExtractKeywords calls LLM to extract keywords from user message
|
|
||||||
func (llm *LLMClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
|
|
||||||
_, parsed, err := llm.ExtractKeywordsRaw(ctx, message)
|
|
||||||
return parsed, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExtractKeywordsRaw returns the raw JSON string and parsed map
|
|
||||||
func (llm *LLMClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) {
|
|
||||||
prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt")
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
|
|
||||||
format := map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"translate": map[string]interface{}{"type": "string"},
|
|
||||||
"keyword": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}},
|
|
||||||
"animal": map[string]interface{}{"type": "string"},
|
|
||||||
},
|
|
||||||
"required": []string{"translate", "keyword", "animal"},
|
|
||||||
}
|
|
||||||
resp, err := llm.openAICompletion(ctx, prompt, format)
|
|
||||||
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response")
|
|
||||||
if err != nil {
|
|
||||||
return resp, nil, err // return whatever raw we got (may be empty)
|
|
||||||
}
|
|
||||||
var result map[string]interface{}
|
|
||||||
if err := json.Unmarshal([]byte(resp), &result); err != nil {
|
|
||||||
return resp, nil, err
|
|
||||||
}
|
|
||||||
llm.persistRaw(ctx, "extract_keywords", resp)
|
|
||||||
return resp, result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DisambiguateBestMatch calls LLM to pick best match from candidates
|
|
||||||
func (llm *LLMClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
|
|
||||||
_, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates)
|
|
||||||
return vr, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// DisambiguateBestMatchRaw returns raw JSON and visitReason
|
|
||||||
func (llm *LLMClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
|
|
||||||
format := map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"visitReason": map[string]interface{}{"type": "string"},
|
|
||||||
},
|
|
||||||
"required": []string{"visitReason"},
|
|
||||||
}
|
|
||||||
entries, _ := json.Marshal(candidates)
|
|
||||||
prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt")
|
|
||||||
return "", "", err
|
|
||||||
}
|
|
||||||
logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")
|
|
||||||
resp, err := llm.openAICompletion(ctx, prompt, format)
|
|
||||||
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response")
|
|
||||||
if err != nil {
|
|
||||||
return resp, "", err
|
|
||||||
}
|
|
||||||
var parsed map[string]string
|
|
||||||
if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
|
|
||||||
return resp, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
|
|
||||||
}
|
|
||||||
visitReason := strings.TrimSpace(parsed["visitReason"])
|
|
||||||
if visitReason == "" {
|
|
||||||
return resp, "", fmt.Errorf("visitReason not found in response")
|
|
||||||
}
|
|
||||||
llm.persistRaw(ctx, "disambiguate", resp)
|
|
||||||
return resp, visitReason, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// openAICompletion now supports both Ollama (default local) and OpenRouter/OpenAI-compatible APIs without external branching.
|
|
||||||
// It auto-detects by inspecting the BaseURL. If the URL contains "openrouter.ai" or "/v1/", it assumes OpenAI-style.
|
|
||||||
func (llm *LLMClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
|
|
||||||
apiURL := llm.BaseURL
|
|
||||||
if apiURL == "" {
|
|
||||||
// Default to Ollama local chat endpoint
|
|
||||||
apiURL = "http://localhost:11434/api/chat"
|
|
||||||
}
|
|
||||||
|
|
||||||
isOpenAIStyle := strings.Contains(apiURL, "openrouter.ai") || strings.Contains(apiURL, "/v1/")
|
|
||||||
|
|
||||||
// Helper to stringify the expected JSON schema for instructions
|
|
||||||
schemaDesc := func() string {
|
|
||||||
b, _ := json.MarshalIndent(format, "", " ")
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
truncate := func(s string, n int) string {
|
|
||||||
if len(s) <= n {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
return s[:n] + "...<truncated>"
|
|
||||||
}
|
|
||||||
|
|
||||||
buildBody := func() map[string]interface{} {
|
|
||||||
if isOpenAIStyle {
|
|
||||||
return map[string]interface{}{
|
|
||||||
"model": llm.Model,
|
|
||||||
"messages": []map[string]string{
|
|
||||||
{"role": "system", "content": "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + schemaDesc() + " Do not add explanations."},
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
},
|
|
||||||
"response_format": map[string]interface{}{"type": "json_object"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Ollama style
|
|
||||||
return map[string]interface{}{
|
|
||||||
"model": llm.Model,
|
|
||||||
"messages": []map[string]string{{"role": "user", "content": prompt}},
|
|
||||||
"stream": false,
|
|
||||||
"format": format,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
body := buildBody()
|
|
||||||
|
|
||||||
doRequest := func(body map[string]interface{}) (raw []byte, status int, err error, dur time.Duration) {
|
|
||||||
jsonBody, _ := json.Marshal(body)
|
|
||||||
bodySize := len(jsonBody)
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"event": "llm_request",
|
|
||||||
"api_url": apiURL,
|
|
||||||
"model": llm.Model,
|
|
||||||
"is_openai_style": isOpenAIStyle,
|
|
||||||
"prompt_len": len(prompt),
|
|
||||||
"body_size": bodySize,
|
|
||||||
}).Info("[LLM] sending request")
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
|
|
||||||
if llm.APIKey != "" {
|
|
||||||
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
if strings.Contains(apiURL, "openrouter.ai") {
|
|
||||||
req.Header.Set("Referer", "https://github.com/")
|
|
||||||
req.Header.Set("X-Title", "vetrag-app")
|
|
||||||
}
|
|
||||||
start := time.Now()
|
|
||||||
client := &http.Client{}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err, time.Since(start)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
raw, rerr := io.ReadAll(resp.Body)
|
|
||||||
return raw, resp.StatusCode, rerr, time.Since(start)
|
|
||||||
}
|
|
||||||
|
|
||||||
raw, status, err, dur := doRequest(body)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"event": "llm_response",
|
|
||||||
"status": status,
|
|
||||||
"latency_ms": dur.Milliseconds(),
|
|
||||||
"error": err,
|
|
||||||
}).Error("[LLM] request failed")
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"event": "llm_raw_response",
|
|
||||||
"status": status,
|
|
||||||
"latency_ms": dur.Milliseconds(),
|
|
||||||
"raw_trunc": truncate(string(raw), 600),
|
|
||||||
"raw_len": len(raw),
|
|
||||||
}).Debug("[LLM] raw response body")
|
|
||||||
|
|
||||||
parseVariant := "unknown"
|
|
||||||
|
|
||||||
// Attempt Ollama format parse
|
|
||||||
var ollama struct {
|
|
||||||
Message struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
} `json:"message"`
|
|
||||||
Error string `json:"error"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(raw, &ollama); err == nil && ollama.Message.Content != "" {
|
|
||||||
parseVariant = "ollama"
|
|
||||||
content := ollama.Message.Content
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"event": "llm_response",
|
|
||||||
"status": status,
|
|
||||||
"latency_ms": dur.Milliseconds(),
|
|
||||||
"parse_variant": parseVariant,
|
|
||||||
"content_len": len(content),
|
|
||||||
"content_snip": truncate(content, 300),
|
|
||||||
}).Info("[LLM] parsed response")
|
|
||||||
return content, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt OpenAI/OpenRouter style parse
|
|
||||||
var openAI struct {
|
|
||||||
Choices []struct {
|
|
||||||
Message struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
} `json:"message"`
|
|
||||||
} `json:"choices"`
|
|
||||||
Error *struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
} `json:"error"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(raw, &openAI); err == nil {
|
|
||||||
if openAI.Error != nil || status >= 400 {
|
|
||||||
parseVariant = "openai"
|
|
||||||
var msg string
|
|
||||||
if openAI.Error != nil {
|
|
||||||
msg = openAI.Error.Message
|
|
||||||
} else {
|
|
||||||
msg = string(raw)
|
|
||||||
}
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"event": "llm_response",
|
|
||||||
"status": status,
|
|
||||||
"latency_ms": dur.Milliseconds(),
|
|
||||||
"parse_variant": parseVariant,
|
|
||||||
"error": msg,
|
|
||||||
}).Error("[LLM] provider error")
|
|
||||||
return "", fmt.Errorf("provider error: %s", msg)
|
|
||||||
}
|
|
||||||
if len(openAI.Choices) > 0 && openAI.Choices[0].Message.Content != "" {
|
|
||||||
parseVariant = "openai"
|
|
||||||
content := openAI.Choices[0].Message.Content
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"event": "llm_response",
|
|
||||||
"status": status,
|
|
||||||
"latency_ms": dur.Milliseconds(),
|
|
||||||
"parse_variant": parseVariant,
|
|
||||||
"content_len": len(content),
|
|
||||||
"content_snip": truncate(content, 300),
|
|
||||||
}).Info("[LLM] parsed response")
|
|
||||||
return content, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"event": "llm_response",
|
|
||||||
"status": status,
|
|
||||||
"latency_ms": dur.Milliseconds(),
|
|
||||||
"parse_variant": parseVariant,
|
|
||||||
"raw_snip": truncate(string(raw), 300),
|
|
||||||
}).Error("[LLM] unrecognized response format")
|
|
||||||
|
|
||||||
return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw))
|
|
||||||
}
|
|
||||||
|
|
||||||
// LLMClientAPI allows mocking LLMClient in other places
|
|
||||||
// Only public methods should be included
|
|
||||||
|
|
||||||
type LLMClientAPI interface {
|
|
||||||
ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error)
|
|
||||||
DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ LLMClientAPI = (*LLMClient)(nil)
|
|
||||||
|
|
|
||||||
9
main.go
9
main.go
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"html/template"
|
"html/template"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
_ "github.com/jackc/pgx/v5/stdlib"
|
_ "github.com/jackc/pgx/v5/stdlib"
|
||||||
|
|
@ -71,13 +70,7 @@ func main() {
|
||||||
// defer repo.Close() // optionally enable
|
// defer repo.Close() // optionally enable
|
||||||
|
|
||||||
// Initialize LLM client
|
// Initialize LLM client
|
||||||
llmClient := NewLLMClient(
|
llm := NewLLMClientFromEnv(repo)
|
||||||
os.Getenv("OPENAI_API_KEY"),
|
|
||||||
os.Getenv("OPENAI_BASE_URL"),
|
|
||||||
os.Getenv("OPENAI_MODEL"),
|
|
||||||
repo,
|
|
||||||
)
|
|
||||||
var llm LLMClientAPI = llmClient
|
|
||||||
|
|
||||||
// Wrap templates for controller
|
// Wrap templates for controller
|
||||||
uiTmpl := &TemplateWrapper{Tmpl: uiTemplate}
|
uiTmpl := &TemplateWrapper{Tmpl: uiTemplate}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,168 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- OllamaClient implementation ---
|
||||||
|
|
||||||
|
type OllamaClient struct {
|
||||||
|
APIKey string
|
||||||
|
BaseURL string
|
||||||
|
Model string
|
||||||
|
Repo ChatRepositoryAPI
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOllamaClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OllamaClient {
|
||||||
|
return &OllamaClient{APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *OllamaClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
|
||||||
|
_, parsed, err := llm.ExtractKeywordsRaw(ctx, message)
|
||||||
|
return parsed, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *OllamaClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) {
|
||||||
|
prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt")
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
|
||||||
|
|
||||||
|
// Use the utility function instead of inline format definition
|
||||||
|
format := GetExtractKeywordsFormat()
|
||||||
|
|
||||||
|
resp, err := llm.ollamaCompletion(ctx, prompt, format)
|
||||||
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response")
|
||||||
|
if err != nil {
|
||||||
|
return resp, nil, err
|
||||||
|
}
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(resp), &result); err != nil {
|
||||||
|
return resp, nil, err
|
||||||
|
}
|
||||||
|
return resp, result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *OllamaClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
|
||||||
|
_, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates)
|
||||||
|
return vr, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *OllamaClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
|
||||||
|
// Use the utility function instead of inline format definition
|
||||||
|
format := GetDisambiguateFormat()
|
||||||
|
|
||||||
|
entries, _ := json.Marshal(candidates)
|
||||||
|
prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt")
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")
|
||||||
|
resp, err := llm.ollamaCompletion(ctx, prompt, format)
|
||||||
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response")
|
||||||
|
if err != nil {
|
||||||
|
return resp, "", err
|
||||||
|
}
|
||||||
|
var parsed map[string]string
|
||||||
|
if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
|
||||||
|
return resp, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
|
||||||
|
}
|
||||||
|
visitReason := strings.TrimSpace(parsed["visitReason"])
|
||||||
|
if visitReason == "" {
|
||||||
|
return resp, "", fmt.Errorf("visitReason not found in response")
|
||||||
|
}
|
||||||
|
return resp, visitReason, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *OllamaClient) ollamaCompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
|
||||||
|
apiURL := llm.BaseURL
|
||||||
|
if apiURL == "" {
|
||||||
|
apiURL = "http://localhost:11434/api/chat"
|
||||||
|
}
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"model": llm.Model,
|
||||||
|
"messages": []map[string]string{{"role": "user", "content": prompt}},
|
||||||
|
"stream": false,
|
||||||
|
"format": format,
|
||||||
|
}
|
||||||
|
jsonBody, _ := json.Marshal(body)
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
|
||||||
|
if llm.APIKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
raw, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
var ollama struct {
|
||||||
|
Message struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"message"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(raw, &ollama); err == nil && ollama.Message.Content != "" {
|
||||||
|
return ollama.Message.Content, nil
|
||||||
|
}
|
||||||
|
if ollama.Error != "" {
|
||||||
|
return "", fmt.Errorf("provider error: %s", ollama.Error)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *OllamaClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
|
||||||
|
apiURL := llm.BaseURL
|
||||||
|
if apiURL == "" {
|
||||||
|
apiURL = "http://localhost:11434/api/embeddings"
|
||||||
|
}
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"model": llm.Model,
|
||||||
|
"prompt": input,
|
||||||
|
}
|
||||||
|
jsonBody, _ := json.Marshal(body)
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
|
||||||
|
if llm.APIKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
raw, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var ollama struct {
|
||||||
|
Embedding []float64 `json:"embedding"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(raw, &ollama); err == nil && len(ollama.Embedding) > 0 {
|
||||||
|
return ollama.Embedding, nil
|
||||||
|
}
|
||||||
|
if ollama.Error != "" {
|
||||||
|
return nil, fmt.Errorf("embedding error: %s", ollama.Error)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unrecognized embedding response: %.200s", string(raw))
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,200 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- OpenAIClient implementation ---
|
||||||
|
|
||||||
|
type OpenAIClient struct {
|
||||||
|
APIKey string
|
||||||
|
BaseURL string
|
||||||
|
Model string
|
||||||
|
Repo ChatRepositoryAPI
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OpenAIClient {
|
||||||
|
return &OpenAIClient{APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
|
||||||
|
_, parsed, err := llm.ExtractKeywordsRaw(ctx, message)
|
||||||
|
return parsed, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *OpenAIClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) {
|
||||||
|
prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt")
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
|
||||||
|
|
||||||
|
// Use the utility function instead of inline format definition
|
||||||
|
format := GetExtractKeywordsFormat()
|
||||||
|
|
||||||
|
resp, err := llm.openAICompletion(ctx, prompt, format)
|
||||||
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response")
|
||||||
|
if err != nil {
|
||||||
|
return resp, nil, err
|
||||||
|
}
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(resp), &result); err != nil {
|
||||||
|
return resp, nil, err
|
||||||
|
}
|
||||||
|
return resp, result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *OpenAIClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
|
||||||
|
_, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates)
|
||||||
|
return vr, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
|
||||||
|
// Use the utility function instead of inline format definition
|
||||||
|
format := GetDisambiguateFormat()
|
||||||
|
|
||||||
|
entries, _ := json.Marshal(candidates)
|
||||||
|
prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt")
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")
|
||||||
|
resp, err := llm.openAICompletion(ctx, prompt, format)
|
||||||
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response")
|
||||||
|
if err != nil {
|
||||||
|
return resp, "", err
|
||||||
|
}
|
||||||
|
var parsed map[string]string
|
||||||
|
if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
|
||||||
|
return resp, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
|
||||||
|
}
|
||||||
|
visitReason := strings.TrimSpace(parsed["visitReason"])
|
||||||
|
if visitReason == "" {
|
||||||
|
return resp, "", fmt.Errorf("visitReason not found in response")
|
||||||
|
}
|
||||||
|
return resp, visitReason, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
|
||||||
|
apiURL := llm.BaseURL
|
||||||
|
if apiURL == "" {
|
||||||
|
apiURL = "https://api.openai.com/v1/chat/completions"
|
||||||
|
}
|
||||||
|
// Helper to stringify the expected JSON schema for instructions
|
||||||
|
schemaDesc := func() string {
|
||||||
|
b, _ := json.MarshalIndent(format, "", " ")
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"model": llm.Model,
|
||||||
|
"messages": []map[string]string{
|
||||||
|
{"role": "system", "content": "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + schemaDesc() + " Do not add explanations."},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
},
|
||||||
|
"response_format": map[string]interface{}{"type": "json_object"},
|
||||||
|
}
|
||||||
|
jsonBody, _ := json.Marshal(body)
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
|
||||||
|
if llm.APIKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
if strings.Contains(apiURL, "openrouter.ai") {
|
||||||
|
req.Header.Set("Referer", "https://github.com/")
|
||||||
|
req.Header.Set("X-Title", "vetrag-app")
|
||||||
|
}
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
raw, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
var openAI struct {
|
||||||
|
Choices []struct {
|
||||||
|
Message struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"message"`
|
||||||
|
} `json:"choices"`
|
||||||
|
Error *struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
} `json:"error"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(raw, &openAI); err == nil {
|
||||||
|
if openAI.Error != nil || resp.StatusCode >= 400 {
|
||||||
|
var msg string
|
||||||
|
if openAI.Error != nil {
|
||||||
|
msg = openAI.Error.Message
|
||||||
|
} else {
|
||||||
|
msg = string(raw)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("provider error: %s", msg)
|
||||||
|
}
|
||||||
|
if len(openAI.Choices) > 0 && openAI.Choices[0].Message.Content != "" {
|
||||||
|
return openAI.Choices[0].Message.Content, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
|
||||||
|
apiURL := llm.BaseURL
|
||||||
|
if apiURL == "" {
|
||||||
|
apiURL = "https://api.openai.com/v1/embeddings"
|
||||||
|
}
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"model": llm.Model,
|
||||||
|
"input": input,
|
||||||
|
}
|
||||||
|
jsonBody, _ := json.Marshal(body)
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
|
||||||
|
if llm.APIKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
if strings.Contains(apiURL, "openrouter.ai") {
|
||||||
|
req.Header.Set("Referer", "https://github.com/")
|
||||||
|
req.Header.Set("X-Title", "vetrag-app")
|
||||||
|
}
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
raw, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var openAI struct {
|
||||||
|
Data []struct {
|
||||||
|
Embedding []float64 `json:"embedding"`
|
||||||
|
} `json:"data"`
|
||||||
|
Error *struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
} `json:"error"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(raw, &openAI); err == nil && len(openAI.Data) > 0 {
|
||||||
|
return openAI.Data[0].Embedding, nil
|
||||||
|
}
|
||||||
|
if openAI.Error != nil {
|
||||||
|
return nil, fmt.Errorf("embedding error: %s", openAI.Error.Message)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unrecognized embedding response: %.200s", string(raw))
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue