185 lines
5.9 KiB
Go
185 lines
5.9 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// --- OllamaClient implementation ---
|
|
|
|
type OllamaClient struct {
|
|
APIKey string
|
|
BaseURL string
|
|
Model string
|
|
Repo ChatRepositoryAPI
|
|
}
|
|
|
|
func NewOllamaClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OllamaClient {
|
|
return &OllamaClient{APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo}
|
|
}
|
|
|
|
func (llm *OllamaClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
|
|
_, parsed, err := llm.ExtractKeywordsRaw(ctx, message)
|
|
return parsed, err
|
|
}
|
|
|
|
func (llm *OllamaClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) {
|
|
prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
|
|
if err != nil {
|
|
logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt")
|
|
return "", nil, err
|
|
}
|
|
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
|
|
|
|
// Use the utility function instead of inline format definition
|
|
format := GetExtractKeywordsFormat()
|
|
|
|
resp, err := llm.ollamaCompletion(ctx, prompt, format)
|
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response")
|
|
if err != nil {
|
|
return resp, nil, err
|
|
}
|
|
var result map[string]interface{}
|
|
if err := json.Unmarshal([]byte(resp), &result); err != nil {
|
|
return resp, nil, err
|
|
}
|
|
return resp, result, nil
|
|
}
|
|
|
|
func (llm *OllamaClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
|
|
_, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates)
|
|
return vr, err
|
|
}
|
|
|
|
func (llm *OllamaClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
|
|
// Use the utility function instead of inline format definition
|
|
format := GetDisambiguateFormat()
|
|
|
|
entries, _ := json.Marshal(candidates)
|
|
prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
|
|
if err != nil {
|
|
logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt")
|
|
return "", "", err
|
|
}
|
|
logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")
|
|
resp, err := llm.ollamaCompletion(ctx, prompt, format)
|
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response")
|
|
if err != nil {
|
|
return resp, "", err
|
|
}
|
|
var parsed map[string]string
|
|
if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
|
|
return resp, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
|
|
}
|
|
visitReason := strings.TrimSpace(parsed["visitReason"])
|
|
if visitReason == "" {
|
|
return resp, "", fmt.Errorf("visitReason not found in response")
|
|
}
|
|
return resp, visitReason, nil
|
|
}
|
|
|
|
func (llm *OllamaClient) ollamaCompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
|
|
apiURL := llm.BaseURL
|
|
if apiURL == "" {
|
|
apiURL = "http://localhost:11434/api/chat"
|
|
}
|
|
body := map[string]interface{}{
|
|
"model": llm.Model,
|
|
"messages": []map[string]string{{"role": "user", "content": prompt}},
|
|
"stream": false,
|
|
"format": format,
|
|
}
|
|
jsonBody, _ := json.Marshal(body)
|
|
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
|
|
if llm.APIKey != "" {
|
|
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Accept", "application/json")
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
raw, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
var ollama struct {
|
|
Message struct {
|
|
Content string `json:"content"`
|
|
} `json:"message"`
|
|
Error string `json:"error"`
|
|
}
|
|
if err := json.Unmarshal(raw, &ollama); err == nil && ollama.Message.Content != "" {
|
|
return ollama.Message.Content, nil
|
|
}
|
|
if ollama.Error != "" {
|
|
return "", fmt.Errorf("provider error: %s", ollama.Error)
|
|
}
|
|
return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw))
|
|
}
|
|
|
|
func (llm *OllamaClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
|
|
apiURL := llm.BaseURL
|
|
if apiURL == "" {
|
|
apiURL = "http://localhost:11434/api/embeddings"
|
|
}
|
|
body := map[string]interface{}{
|
|
"model": llm.Model,
|
|
"prompt": input,
|
|
}
|
|
jsonBody, _ := json.Marshal(body)
|
|
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
|
|
if llm.APIKey != "" {
|
|
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Accept", "application/json")
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
raw, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var ollama struct {
|
|
Embedding []float64 `json:"embedding"`
|
|
Error string `json:"error"`
|
|
}
|
|
if err := json.Unmarshal(raw, &ollama); err == nil && len(ollama.Embedding) > 0 {
|
|
return ollama.Embedding, nil
|
|
}
|
|
if ollama.Error != "" {
|
|
return nil, fmt.Errorf("embedding error: %s", ollama.Error)
|
|
}
|
|
return nil, fmt.Errorf("unrecognized embedding response: %.200s", string(raw))
|
|
}
|
|
|
|
func (llm *OllamaClient) TranslateToEnglish(ctx context.Context, message string) (string, error) {
|
|
prompt, err := renderPrompt(appConfig.LLM.TranslatePrompt, map[string]string{"Message": message})
|
|
if err != nil {
|
|
logrus.WithError(err).Error("[CONFIG] Failed to render Translate prompt")
|
|
return "", err
|
|
}
|
|
logrus.WithField("prompt", prompt).Info("[LLM] TranslateToEnglish prompt")
|
|
|
|
resp, err := llm.ollamaCompletion(ctx, prompt, nil)
|
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] TranslateToEnglish response")
|
|
if err != nil {
|
|
return resp, err
|
|
}
|
|
return strings.TrimSpace(resp), nil
|
|
}
|