package main

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/sirupsen/logrus"
)

// --- OllamaClient implementation ---

// OllamaClient is an LLM client backed by an Ollama-compatible HTTP API
// (non-streaming chat completions plus embeddings).
type OllamaClient struct {
	APIKey         string            // optional bearer token, sent as "Authorization: Bearer <key>"
	BaseURL        string            // bare host ("http://host:11434") or full endpoint URL; normalized before use
	Model          string            // chat model name
	EmbeddingModel string            // embeddings model; GetEmbeddings falls back to Model when empty
	Repo           ChatRepositoryAPI // NOTE(review): not referenced in this file — presumably used elsewhere; verify
}

// NewOllamaClient builds an OllamaClient. EmbeddingModel starts empty, so
// embeddings use Model unless EmbeddingModel is assigned afterwards.
func NewOllamaClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OllamaClient {
	return &OllamaClient{APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo}
}

// ExtractKeywords returns the parsed keyword map for message, discarding the
// raw LLM response text.
func (llm *OllamaClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
	_, parsed, err := llm.ExtractKeywordsRaw(ctx, message)
	return parsed, err
}

// ExtractKeywordsRaw asks the LLM to extract keywords from message and
// returns both the raw response string and its JSON-decoded form. The raw
// string is returned even on decode failure so callers can log it.
func (llm *OllamaClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) {
	prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
	if err != nil {
		logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt")
		return "", nil, err
	}
	logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")

	// Use the utility function instead of inline format definition.
	format := GetExtractKeywordsFormat()

	resp, err := llm.ollamaCompletion(ctx, prompt, format)
	logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response")
	if err != nil {
		return resp, nil, err
	}

	var result map[string]interface{}
	if err := json.Unmarshal([]byte(resp), &result); err != nil {
		return resp, nil, err
	}
	return resp, result, nil
}

// DisambiguateBestMatch returns only the chosen visitReason, discarding the
// raw LLM response text.
func (llm *OllamaClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
	_, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates)
	return vr, err
}

// DisambiguateBestMatchRaw asks the LLM to pick the best-matching visit for
// message among candidates. It returns the raw response and the extracted
// "visitReason" field; an empty visitReason is treated as an error.
func (llm *OllamaClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
	// Use the utility function instead of inline format definition.
	format := GetDisambiguateFormat()

	entries, err := json.Marshal(candidates)
	if err != nil {
		// Previously ignored; a marshal failure would have silently rendered
		// an empty candidate list into the prompt.
		return "", "", fmt.Errorf("failed to marshal candidates: %w", err)
	}

	prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
	if err != nil {
		logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt")
		return "", "", err
	}
	logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")

	resp, err := llm.ollamaCompletion(ctx, prompt, format)
	logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response")
	if err != nil {
		return resp, "", err
	}

	var parsed map[string]string
	if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
		return resp, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
	}
	visitReason := strings.TrimSpace(parsed["visitReason"])
	if visitReason == "" {
		return resp, "", fmt.Errorf("visitReason not found in response")
	}
	return resp, visitReason, nil
}

// ollamaCompletion performs a single non-streaming /api/chat call and returns
// the assistant message content. format, when non-nil, is forwarded as
// Ollama's structured-output "format" field.
func (llm *OllamaClient) ollamaCompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
	// Normalize the base URL so both a bare host and a full endpoint work
	// (previously a bare host was used verbatim, missing the /api/chat path).
	// This mirrors how GetEmbeddings builds its URL.
	apiURL := normalizeOllamaHost(llm.BaseURL) + "/api/chat"

	messages := []map[string]string{{"role": "user", "content": prompt}}
	//if os.Getenv("DISABLE_THINK") == "1" {
	// System message to suppress chain-of-thought style outputs.
	messages = append([]map[string]string{{
		"role":    "system",
		"content": "You are a concise assistant. Output ONLY the final answer requested by the user. Do not include reasoning, analysis, or tags.",
	}}, messages...)
	//}

	body := map[string]interface{}{
		"model":    llm.Model,
		"messages": messages,
		"stream":   false,
		"format":   format,
	}
	// Optional: Add a stop sequence to prevent tags if they appear.
	// NOTE(review): an empty stop string is a no-op — the intended tag text
	// (presumably a think-style tag) looks like it was lost; confirm and restore.
	if os.Getenv("DISABLE_THINK") == "1" {
		body["options"] = map[string]interface{}{"stop": []string{""}}
	}

	jsonBody, err := json.Marshal(body)
	if err != nil {
		return "", fmt.Errorf("failed to marshal chat request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
	if err != nil {
		return "", fmt.Errorf("failed to build chat request: %w", err)
	}
	if llm.APIKey != "" {
		req.Header.Set("Authorization", "Bearer "+llm.APIKey)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	// Cancellation/timeout is governed by ctx; completions can legitimately
	// take a long time, so no fixed client timeout is imposed here.
	resp, err := (&http.Client{}).Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}

	// Tolerant decode: accept any payload carrying message.content, otherwise
	// surface the provider's "error" field if present.
	var ollama struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
		Error string `json:"error"`
	}
	if err := json.Unmarshal(raw, &ollama); err == nil && ollama.Message.Content != "" {
		return ollama.Message.Content, nil
	}
	if ollama.Error != "" {
		return "", fmt.Errorf("provider error: %s", ollama.Error)
	}
	// Include the HTTP status so the truncated body is easier to diagnose.
	return "", fmt.Errorf("unrecognized LLM response format (status %d): %.200s", resp.StatusCode, string(raw))
}

// normalizeOllamaHost reduces a configured base URL to "scheme://host[:port]"
// so callers can append a concrete /api/* path. Accepts either a bare host or
// a full endpoint URL; an empty value falls back to the local Ollama default.
func normalizeOllamaHost(raw string) string {
	if raw == "" {
		return "http://localhost:11434"
	}
	// Strip trailing /api/* paths if the user provided a full endpoint.
	lower := strings.ToLower(raw)
	for _, seg := range []string{"/api/chat", "/api/embeddings", "/api/generate"} {
		if strings.HasSuffix(lower, seg) {
			raw = raw[:len(raw)-len(seg)]
			break
		}
	}
	// Drop a trailing slash so appending "/api/..." cannot yield "//api/...".
	return strings.TrimSuffix(raw, "/")
}

// GetEmbeddings requests an embedding vector for input from /api/embeddings,
// retrying transient failures (network errors, model still loading) with
// exponential backoff. The model defaults to EmbeddingModel, falling back to
// Model. Attempt count is overridable via OLLAMA_EMBED_ATTEMPTS (1..19).
func (llm *OllamaClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
	host := normalizeOllamaHost(llm.BaseURL)
	apiURL := host + "/api/embeddings"

	modelName := llm.Model
	if llm.EmbeddingModel != "" {
		modelName = llm.EmbeddingModel
	}

	// Retry parameters (env override OLLAMA_EMBED_ATTEMPTS).
	maxAttempts := 5
	if v := os.Getenv("OLLAMA_EMBED_ATTEMPTS"); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n > 0 && n < 20 {
			maxAttempts = n
		}
	}
	baseBackoff := 300 * time.Millisecond

	var lastErr error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		// Bail out promptly if the caller cancelled between attempts.
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}

		body := map[string]interface{}{
			"model":  modelName,
			"prompt": input,
		}
		jsonBody, err := json.Marshal(body)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal embeddings request: %w", err)
		}
		req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
		if err != nil {
			return nil, fmt.Errorf("failed to build embeddings request: %w", err)
		}
		if llm.APIKey != "" {
			req.Header.Set("Authorization", "Bearer "+llm.APIKey)
		}
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept", "application/json")

		resp, err := (&http.Client{}).Do(req)
		if err != nil {
			lastErr = err
			logrus.WithError(err).Warnf("[Ollama] embeddings request attempt=%d failed", attempt+1)
		} else {
			raw, rerr := io.ReadAll(resp.Body)
			resp.Body.Close()
			if rerr != nil {
				lastErr = rerr
			} else {
				// Decode generically: the success payload carries "embedding";
				// failure payloads carry "done_reason" or "error".
				var generic map[string]json.RawMessage
				if jerr := json.Unmarshal(raw, &generic); jerr != nil {
					lastErr = fmt.Errorf("unrecognized response (parse): %w", jerr)
				} else if embRaw, ok := generic["embedding"]; ok && len(embRaw) > 0 {
					var emb []float64
					if jerr := json.Unmarshal(embRaw, &emb); jerr != nil {
						lastErr = fmt.Errorf("failed to decode embedding: %w", jerr)
					} else if len(emb) == 0 {
						lastErr = fmt.Errorf("empty embedding returned")
					} else {
						return emb, nil
					}
				} else if drRaw, ok := generic["done_reason"]; ok {
					var reason string
					_ = json.Unmarshal(drRaw, &reason) // best-effort; reason stays "" on failure
					if reason == "load" {
						// Transient model loading state — retried with extra delay below.
						lastErr = fmt.Errorf("model loading")
					} else {
						lastErr = fmt.Errorf("unexpected done_reason=%s", reason)
					}
				} else if errRaw, ok := generic["error"]; ok {
					var errMsg string
					_ = json.Unmarshal(errRaw, &errMsg) // best-effort; errMsg stays "" on failure
					if errMsg != "" {
						lastErr = fmt.Errorf("embedding error: %s", errMsg)
					} else {
						lastErr = fmt.Errorf("embedding error (empty message)")
					}
				} else {
					lastErr = fmt.Errorf("unrecognized embedding response: %.200s", string(raw))
				}
			}
		}

		if lastErr == nil {
			break
		}
		// Exponential backoff before the next attempt. Uses select so the
		// wait respects ctx cancellation (previously a plain time.Sleep
		// blocked even after the caller cancelled).
		if attempt < maxAttempts-1 {
			delay := baseBackoff << attempt
			if strings.Contains(strings.ToLower(lastErr.Error()), "model loading") {
				delay += 1 * time.Second
			}
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			case <-time.After(delay):
			}
		}
	}
	if lastErr == nil {
		lastErr = fmt.Errorf("embedding retrieval failed with no error info")
	}
	return nil, lastErr
}

// TranslateToEnglish asks the LLM to translate message to English and returns
// the whitespace-trimmed response text.
func (llm *OllamaClient) TranslateToEnglish(ctx context.Context, message string) (string, error) {
	prompt, err := renderPrompt(appConfig.LLM.TranslatePrompt, map[string]string{"Message": message})
	if err != nil {
		logrus.WithError(err).Error("[CONFIG] Failed to render Translate prompt")
		return "", err
	}
	logrus.WithField("prompt", prompt).Info("[LLM] TranslateToEnglish prompt")

	resp, err := llm.ollamaCompletion(ctx, prompt, nil)
	logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] TranslateToEnglish response")
	if err != nil {
		return resp, err
	}
	return strings.TrimSpace(resp), nil
}