77 lines
2.2 KiB
Go
77 lines
2.2 KiB
Go
package main
|
|
|
|
import (
	"bytes"
	"context"
	"fmt"
	"os"
	"strings"
	"text/template"

	"github.com/sirupsen/logrus"
)
|
|
|
|
// LLMClientAPI allows mocking LLMClient in other places
|
|
type LLMClientAPI interface {
|
|
ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error)
|
|
DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error)
|
|
GetEmbeddings(ctx context.Context, input string) ([]float64, error)
|
|
}
|
|
|
|
// --- Format Utilities ---
|
|
|
|
// GetExtractKeywordsFormat describes, as a JSON-Schema-style map, the object
// expected back from keyword extraction: a "translate" string, a "keyword"
// array of strings, and an "animal" string, all three listed as required.
//
// Every call allocates brand-new nested maps, so callers may modify the
// result without affecting later calls.
func GetExtractKeywordsFormat() map[string]interface{} {
	keywordList := map[string]any{
		"type":  "array",
		"items": map[string]any{"type": "string"},
	}

	// translate and animal each get their own map so that mutating one
	// property description never leaks into the other.
	properties := map[string]any{
		"translate": map[string]any{"type": "string"},
		"keyword":   keywordList,
		"animal":    map[string]any{"type": "string"},
	}

	schema := make(map[string]any, 3)
	schema["type"] = "object"
	schema["properties"] = properties
	schema["required"] = []string{"translate", "keyword", "animal"}
	return schema
}
|
|
|
|
// GetDisambiguateFormat describes, as a JSON-Schema-style map, the object
// expected back from disambiguation: a single required string field named
// "visitReason".
//
// Every call allocates new maps, so callers are free to modify the result.
func GetDisambiguateFormat() map[string]interface{} {
	const field = "visitReason"

	schema := map[string]any{"type": "object"}
	schema["properties"] = map[string]any{
		field: map[string]any{"type": "string"},
	}
	schema["required"] = []string{field}
	return schema
}
|
|
|
|
// --- Factory ---
|
|
|
|
// NewLLMClientFromEnv constructs an LLMClientAPI using settings read from the
// process environment.
//
// LLM_PROVIDER selects the implementation. It is matched case-insensitively,
// with surrounding whitespace ignored:
//   - "openai" or "openrouter" -> NewOpenAIClient
//   - "ollama" or empty        -> NewOllamaClient
//   - anything else            -> a warning is logged and NewOllamaClient is used
//
// OPENAI_API_KEY, OPENAI_BASE_URL and OPENAI_MODEL are read unconditionally
// and passed, together with repo, to whichever constructor is chosen.
func NewLLMClientFromEnv(repo ChatRepositoryAPI) LLMClientAPI {
	provider := os.Getenv("LLM_PROVIDER")
	apiKey := os.Getenv("OPENAI_API_KEY")
	baseURL := os.Getenv("OPENAI_BASE_URL")
	model := os.Getenv("OPENAI_MODEL")

	// Trim before matching so values such as "openai " (easy to produce in a
	// hand-edited .env file) are not mistaken for an unknown provider and
	// silently routed to Ollama.
	switch strings.ToLower(strings.TrimSpace(provider)) {
	case "openai", "openrouter":
		return NewOpenAIClient(apiKey, baseURL, model, repo)
	case "ollama", "":
		return NewOllamaClient(apiKey, baseURL, model, repo)
	default:
		logrus.Warnf("Unknown LLM_PROVIDER %q, defaulting to Ollama", provider)
		return NewOllamaClient(apiKey, baseURL, model, repo)
	}
}
|
|
|
|
// --- Utility ---
|
|
|
|
// renderPrompt parses tmplStr as a text/template and executes it against
// data, returning the rendered text.
//
// On failure it returns "" and an error that wraps the underlying
// text/template error, prefixed to say whether parsing or execution failed.
// Callers can still inspect the cause with errors.Is / errors.As.
//
// Caveats inherited from text/template: output is not escaped, and a map key
// missing from data renders as "<no value>" rather than causing an error.
func renderPrompt(tmplStr string, data any) (string, error) {
	// A non-empty name makes text/template diagnostics read
	// "template: prompt:LINE: ..." instead of "template: :LINE: ...".
	tmpl, err := template.New("prompt").Parse(tmplStr)
	if err != nil {
		return "", fmt.Errorf("parsing prompt template: %w", err)
	}

	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, data); err != nil {
		// Execute may already have written partial output; discard it.
		return "", fmt.Errorf("executing prompt template: %w", err)
	}
	return buf.String(), nil
}
|