vetrag/ollama_client.go

274 lines
8.6 KiB
Go

package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/sirupsen/logrus"
)
// --- OllamaClient implementation ---
// OllamaClient is an LLM client backed by an Ollama-compatible HTTP API.
// It performs keyword extraction, disambiguation, embeddings, and
// translation via the /api/chat and /api/embeddings endpoints.
type OllamaClient struct {
// APIKey, when non-empty, is sent as "Authorization: Bearer <key>".
APIKey string
// BaseURL is the chat endpoint; empty defaults to http://localhost:11434/api/chat.
BaseURL string
// Model is the chat model name; also the fallback for embeddings.
Model string
// EmbeddingModel, when non-empty, overrides Model for GetEmbeddings.
EmbeddingModel string
// Repo is the chat persistence layer — not used by any method visible
// in this file; presumably consumed elsewhere. TODO confirm.
Repo ChatRepositoryAPI
}
// NewOllamaClient constructs an OllamaClient for the given endpoint and chat
// model, backed by repo. EmbeddingModel starts empty, so embeddings fall back
// to the chat model until it is set explicitly.
func NewOllamaClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OllamaClient {
	client := &OllamaClient{
		APIKey:  apiKey,
		BaseURL: baseURL,
		Model:   model,
		Repo:    repo,
	}
	return client
}
// ExtractKeywords returns the parsed keyword map for message, discarding the
// raw LLM response text that ExtractKeywordsRaw also produces.
func (llm *OllamaClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
	_, keywords, err := llm.ExtractKeywordsRaw(ctx, message)
	if err != nil {
		// Raw returns a nil map alongside any error, so this is equivalent
		// to passing its result straight through.
		return nil, err
	}
	return keywords, nil
}
// ExtractKeywordsRaw renders the configured keyword-extraction prompt,
// queries the model with a structured-output schema, and returns both the
// raw response text and its JSON-decoded form.
func (llm *OllamaClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) {
	prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
	if err != nil {
		logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt")
		return "", nil, err
	}
	logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")

	// Structured-output schema comes from the shared helper rather than an
	// inline definition.
	schema := GetExtractKeywordsFormat()
	raw, err := llm.ollamaCompletion(ctx, prompt, schema)
	logrus.WithFields(logrus.Fields{"response": raw, "err": err}).Info("[LLM] ExtractKeywords response")
	if err != nil {
		return raw, nil, err
	}

	var decoded map[string]interface{}
	if err := json.Unmarshal([]byte(raw), &decoded); err != nil {
		return raw, nil, err
	}
	return raw, decoded, nil
}
// DisambiguateBestMatch is a convenience wrapper around
// DisambiguateBestMatchRaw that drops the raw response text and keeps only
// the selected visit reason.
func (llm *OllamaClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
	_, reason, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates)
	return reason, err
}
// DisambiguateBestMatchRaw asks the model to pick the best-matching entry
// among candidates for the user's message. It returns the raw response text
// together with the extracted "visitReason" field.
func (llm *OllamaClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
	// Structured-output schema comes from the shared helper rather than an
	// inline definition.
	schema := GetDisambiguateFormat()
	encoded, _ := json.Marshal(candidates)

	prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(encoded), "Message": message})
	if err != nil {
		logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt")
		return "", "", err
	}
	logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")

	raw, err := llm.ollamaCompletion(ctx, prompt, schema)
	logrus.WithFields(logrus.Fields{"response": raw, "err": err}).Info("[LLM] DisambiguateBestMatch response")
	if err != nil {
		return raw, "", err
	}

	var fields map[string]string
	if err := json.Unmarshal([]byte(raw), &fields); err != nil {
		return raw, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
	}
	reason := strings.TrimSpace(fields["visitReason"])
	if reason == "" {
		return raw, "", fmt.Errorf("visitReason not found in response")
	}
	return raw, reason, nil
}
// ollamaCompletion sends a single non-streaming chat request to the Ollama
// /api/chat endpoint and returns the assistant message content.
//
// prompt is wrapped as the user message; a fixed system message is always
// prepended to suppress chain-of-thought output (an env toggle for this was
// removed — the commented-out DISABLE_THINK guard is gone). format, when
// non-nil, is passed through as Ollama's structured-output schema.
//
// Fix over the previous version: json.Marshal and http.NewRequestWithContext
// errors were discarded with `_`; an invalid BaseURL produced a nil request
// and a nil-pointer panic inside client.Do. Both errors are now returned.
func (llm *OllamaClient) ollamaCompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
	apiURL := llm.BaseURL
	if apiURL == "" {
		apiURL = "http://localhost:11434/api/chat"
	}

	// System message to suppress chain-of-thought style outputs.
	messages := []map[string]string{
		{
			"role":    "system",
			"content": "You are a concise assistant. Output ONLY the final answer requested by the user. Do not include reasoning, analysis, or <think> tags.",
		},
		{"role": "user", "content": prompt},
	}

	body := map[string]interface{}{
		"model":    llm.Model,
		"messages": messages,
		"stream":   false,
		"format":   format,
	}
	// Optional: add a stop sequence to prevent <think> tags if they appear.
	if os.Getenv("DISABLE_THINK") == "1" {
		body["options"] = map[string]interface{}{"stop": []string{"<think>"}}
	}

	jsonBody, err := json.Marshal(body)
	if err != nil {
		return "", fmt.Errorf("marshaling chat request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(jsonBody))
	if err != nil {
		return "", fmt.Errorf("building chat request: %w", err)
	}
	if llm.APIKey != "" {
		req.Header.Set("Authorization", "Bearer "+llm.APIKey)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	// No client-level timeout: cancellation/deadline is expected to arrive
	// via ctx on the request.
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}

	var ollama struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
		Error string `json:"error"`
	}
	// Successful decode with non-empty content wins; otherwise fall through
	// to the provider-reported error, then a generic "unrecognized" error
	// carrying a truncated copy of the raw body for diagnosis.
	if err := json.Unmarshal(raw, &ollama); err == nil && ollama.Message.Content != "" {
		return ollama.Message.Content, nil
	}
	if ollama.Error != "" {
		return "", fmt.Errorf("provider error: %s", ollama.Error)
	}
	return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw))
}
// normalizeOllamaHost reduces a configured Ollama URL to its host base.
// An empty value defaults to the local daemon; a URL ending in a known
// /api/* endpoint path (matched case-insensitively) has that path removed,
// preserving the original casing of what remains. Anything else is returned
// unchanged.
func normalizeOllamaHost(raw string) string {
	if raw == "" {
		return "http://localhost:11434"
	}
	lowered := strings.ToLower(raw)
	// strip trailing /api/* paths if user provided full endpoint
	endpoints := []string{"/api/chat", "/api/embeddings", "/api/generate"}
	for _, endpoint := range endpoints {
		if strings.HasSuffix(lowered, endpoint) {
			return raw[:len(raw)-len(endpoint)]
		}
	}
	return raw
}
// GetEmbeddings requests an embedding vector for input from the Ollama
// /api/embeddings endpoint, retrying transient failures with exponential
// backoff. The model is EmbeddingModel when set, otherwise Model. The
// attempt count defaults to 5 and can be overridden via the
// OLLAMA_EMBED_ATTEMPTS env var (accepted range 1..19).
//
// Fixes over the previous version:
//   - the inter-attempt backoff used time.Sleep, which ignored context
//     cancellation; it now selects on ctx.Done() as well,
//   - a fresh http.Client was built on every attempt, defeating connection
//     reuse; one client is now shared across attempts,
//   - json.Marshal / http.NewRequestWithContext errors were discarded.
func (llm *OllamaClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
	host := normalizeOllamaHost(llm.BaseURL)
	apiURL := host + "/api/embeddings"
	modelName := llm.Model
	if llm.EmbeddingModel != "" {
		modelName = llm.EmbeddingModel
	}

	// retry parameters (env override OLLAMA_EMBED_ATTEMPTS)
	maxAttempts := 5
	if v := os.Getenv("OLLAMA_EMBED_ATTEMPTS"); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n > 0 && n < 20 {
			maxAttempts = n
		}
	}
	baseBackoff := 300 * time.Millisecond

	client := &http.Client{} // shared so the transport can pool connections
	var lastErr error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}

		emb, err := llm.embeddingsOnce(ctx, client, apiURL, modelName, input)
		if err == nil {
			return emb, nil
		}
		lastErr = err
		logrus.WithError(err).Warnf("[Ollama] embeddings request attempt=%d failed", attempt+1)

		// backoff if not last attempt; context cancellation aborts the wait
		if attempt < maxAttempts-1 {
			delay := baseBackoff << attempt
			if strings.Contains(strings.ToLower(lastErr.Error()), "model loading") {
				// give a cold model extra time to finish loading
				delay += 1 * time.Second
			}
			timer := time.NewTimer(delay)
			select {
			case <-ctx.Done():
				timer.Stop()
				return nil, ctx.Err()
			case <-timer.C:
			}
		}
	}
	if lastErr == nil {
		lastErr = fmt.Errorf("embedding retrieval failed with no error info")
	}
	return nil, lastErr
}

// embeddingsOnce performs a single embeddings request/decode round trip.
// It returns the vector on success, or an error describing the failure mode
// (transport error, parse failure, transient "model loading" state,
// provider-reported error, or an unrecognized body).
func (llm *OllamaClient) embeddingsOnce(ctx context.Context, client *http.Client, apiURL, modelName, input string) ([]float64, error) {
	jsonBody, err := json.Marshal(map[string]interface{}{
		"model":  modelName,
		"prompt": input,
	})
	if err != nil {
		return nil, fmt.Errorf("marshaling embeddings request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("building embeddings request: %w", err)
	}
	if llm.APIKey != "" {
		req.Header.Set("Authorization", "Bearer "+llm.APIKey)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	var generic map[string]json.RawMessage
	if jerr := json.Unmarshal(raw, &generic); jerr != nil {
		return nil, fmt.Errorf("unrecognized response (parse): %w", jerr)
	}
	if embRaw, ok := generic["embedding"]; ok && len(embRaw) > 0 {
		var emb []float64
		if jerr := json.Unmarshal(embRaw, &emb); jerr != nil {
			return nil, fmt.Errorf("failed to decode embedding: %w", jerr)
		}
		if len(emb) == 0 {
			return nil, fmt.Errorf("empty embedding returned")
		}
		return emb, nil
	}
	if drRaw, ok := generic["done_reason"]; ok {
		var reason string
		_ = json.Unmarshal(drRaw, &reason) // best-effort; an empty reason is still reported below
		if reason == "load" {              // transient model loading state
			return nil, fmt.Errorf("model loading")
		}
		return nil, fmt.Errorf("unexpected done_reason=%s", reason)
	}
	if errRaw, ok := generic["error"]; ok {
		var errMsg string
		_ = json.Unmarshal(errRaw, &errMsg) // best-effort; empty message handled below
		if errMsg != "" {
			return nil, fmt.Errorf("embedding error: %s", errMsg)
		}
		return nil, fmt.Errorf("embedding error (empty message)")
	}
	return nil, fmt.Errorf("unrecognized embedding response: %.200s", string(raw))
}
// TranslateToEnglish renders the configured translation prompt for message,
// sends it to the model without a structured-output schema, and returns the
// whitespace-trimmed answer.
func (llm *OllamaClient) TranslateToEnglish(ctx context.Context, message string) (string, error) {
	prompt, renderErr := renderPrompt(appConfig.LLM.TranslatePrompt, map[string]string{"Message": message})
	if renderErr != nil {
		logrus.WithError(renderErr).Error("[CONFIG] Failed to render Translate prompt")
		return "", renderErr
	}
	logrus.WithField("prompt", prompt).Info("[LLM] TranslateToEnglish prompt")

	answer, callErr := llm.ollamaCompletion(ctx, prompt, nil)
	logrus.WithFields(logrus.Fields{"response": answer, "err": callErr}).Info("[LLM] TranslateToEnglish response")
	if callErr != nil {
		// Return the raw (possibly partial) response alongside the error,
		// matching the other completion helpers.
		return answer, callErr
	}
	return strings.TrimSpace(answer), nil
}