vetrag/llm.go

77 lines
2.2 KiB
Go

package main
import (
"bytes"
"context"
"os"
"strings"
"text/template"
"github.com/sirupsen/logrus"
)
// LLMClientAPI abstracts the language-model operations the application
// relies on. Depending on this interface rather than on a concrete client
// lets other code (and tests) substitute a mock implementation.
type LLMClientAPI interface {
	// ExtractKeywords asks the model to extract keywords from message.
	// NOTE(review): the returned map presumably follows the shape of
	// GetExtractKeywordsFormat ("translate", "keyword", "animal") — confirm
	// against the concrete implementations.
	ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error)

	// DisambiguateBestMatch asks the model to pick, among candidates, the
	// entry that best matches message. NOTE(review): the returned string is
	// presumably the "visitReason" field of GetDisambiguateFormat — confirm.
	DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error)

	// GetEmbeddings returns an embedding vector computed for input.
	GetEmbeddings(ctx context.Context, input string) ([]float64, error)
}
// --- Format Utilities ---
// GetExtractKeywordsFormat returns a JSON-Schema-shaped description of the
// structured output expected from keyword extraction. The object has three
// required properties:
//
//   - "translate": a string
//   - "keyword": an array of strings
//   - "animal": a string
//
// A new map is built on every call, so callers may modify the result freely.
func GetExtractKeywordsFormat() map[string]interface{} {
	stringType := func() map[string]interface{} {
		return map[string]interface{}{"type": "string"}
	}

	properties := map[string]interface{}{
		"translate": stringType(),
		"keyword": map[string]interface{}{
			"type":  "array",
			"items": stringType(),
		},
		"animal": stringType(),
	}

	schema := make(map[string]interface{}, 3)
	schema["type"] = "object"
	schema["properties"] = properties
	schema["required"] = []string{"translate", "keyword", "animal"}
	return schema
}
// GetDisambiguateFormat returns a JSON-Schema-shaped description of the
// structured output expected from the disambiguation step: an object with a
// single required string property, "visitReason". A new map is built on
// every call.
func GetDisambiguateFormat() map[string]interface{} {
	const field = "visitReason"

	schema := map[string]interface{}{"type": "object"}
	schema["properties"] = map[string]interface{}{
		field: map[string]interface{}{"type": "string"},
	}
	schema["required"] = []string{field}
	return schema
}
// --- Factory ---
// NewLLMClientFromEnv returns an LLM client chosen by the LLM_PROVIDER
// environment variable. The value is matched case-insensitively after
// surrounding whitespace is trimmed:
//
//   - "openai" or "openrouter": the client built by NewOpenAIClient.
//   - "ollama" or empty/unset: the client built by NewOllamaClient.
//   - anything else: a warning is logged and NewOllamaClient is used.
//
// Every provider is configured from the same OPENAI_API_KEY, OPENAI_BASE_URL
// and OPENAI_MODEL variables, and repo is handed to the chosen constructor
// unchanged.
func NewLLMClientFromEnv(repo ChatRepositoryAPI) LLMClientAPI {
	provider := os.Getenv("LLM_PROVIDER")
	apiKey := os.Getenv("OPENAI_API_KEY")
	baseURL := os.Getenv("OPENAI_BASE_URL")
	model := os.Getenv("OPENAI_MODEL")

	// Trim before matching so values like "openai " (easy to produce in a
	// hand-edited .env file) are not silently treated as unknown providers.
	switch strings.ToLower(strings.TrimSpace(provider)) {
	case "openai", "openrouter":
		return NewOpenAIClient(apiKey, baseURL, model, repo)
	case "ollama", "":
		return NewOllamaClient(apiKey, baseURL, model, repo)
	default:
		// %q keeps the raw value (including any odd characters) visible in logs.
		logrus.Warnf("Unknown LLM_PROVIDER %q, defaulting to Ollama", provider)
		return NewOllamaClient(apiKey, baseURL, model, repo)
	}
}
// --- Utility ---
// renderPrompt parses tmplStr as a text/template (so no HTML escaping is
// applied) and executes it against data, returning the rendered text.
// On a parse or execution error it returns "" together with that error,
// unwrapped.
func renderPrompt(tmplStr string, data any) (string, error) {
	t, parseErr := template.New("").Parse(tmplStr)
	if parseErr != nil {
		return "", parseErr
	}

	out := new(bytes.Buffer)
	if execErr := t.Execute(out, data); execErr != nil {
		// Any partially written output is discarded.
		return "", execErr
	}
	return out.String(), nil
}