From a0f477c9a818b3a8473567665303b628d6dead57 Mon Sep 17 00:00:00 2001 From: lehel Date: Wed, 8 Oct 2025 13:17:56 +0200 Subject: [PATCH] separate clients --- chat_service.go | 20 +-- llm.go | 339 +++++++---------------------------------------- main.go | 9 +- ollama_client.go | 168 +++++++++++++++++++++++ openai_client.go | 200 ++++++++++++++++++++++++++++ 5 files changed, 421 insertions(+), 315 deletions(-) create mode 100644 ollama_client.go create mode 100644 openai_client.go diff --git a/chat_service.go b/chat_service.go index bd89a38..f6117dc 100644 --- a/chat_service.go +++ b/chat_service.go @@ -83,19 +83,8 @@ func (cs *ChatService) findBestVisit(ctx context.Context, req ChatRequest, keywo bestID := "" rawDis := "" if len(candidates) > 0 { - if real, ok := cs.LLM.(*LLMClient); ok { - raw, vr, derr := real.DisambiguateBestMatchRaw(ctx, req.Message, candidates) - rawDis = raw - bestID = vr - if derr != nil { - cs.logBestID(bestID, derr) - } else { - cs.logBestID(bestID, nil) - } - } else { - bestID, err = cs.LLM.DisambiguateBestMatch(ctx, req.Message, candidates) - cs.logBestID(bestID, err) - } + bestID, err = cs.LLM.DisambiguateBestMatch(ctx, req.Message, candidates) + cs.logBestID(bestID, err) } visit, err := cs.visitsDB.FindById(bestID) if err != nil { @@ -236,3 +225,8 @@ func (cs *ChatService) persistInteraction(ctx context.Context, correlationID str logrus.WithError(err).Debug("failed to save chat interaction") } } + +// Add this at the top-level (outside any function) +type correlationIDCtxKeyType struct{} + +var correlationIDCtxKey = correlationIDCtxKeyType{} diff --git a/llm.go b/llm.go index 3ec8dd0..837de18 100644 --- a/llm.go +++ b/llm.go @@ -3,55 +3,66 @@ package main import ( "bytes" "context" - "encoding/json" - "fmt" - "io" - "net/http" + "os" "strings" "text/template" - "time" "github.com/sirupsen/logrus" ) -// LLMClient abstracts LLM API calls -type LLMClient struct { - APIKey string - BaseURL string - Model string - Repo ChatRepositoryAPI +// LLMClientAPI allows mocking LLMClient in other places +type LLMClientAPI interface { + ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) + DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) + GetEmbeddings(ctx context.Context, input string) ([]float64, error) } -// NewLLMClient constructs a new LLMClient with the given API key, base URL, model, and optional repository -func NewLLMClient(apiKey, baseURL string, model string, repo ChatRepositoryAPI) *LLMClient { - return &LLMClient{APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo} -} +// --- Format Utilities --- -func (llm *LLMClient) SetRepository(r ChatRepositoryAPI) { llm.Repo = r } - -// helper to get correlation id from context -const correlationIDCtxKey = "corr_id" - -func correlationIDFromCtx(ctx context.Context) string { - v := ctx.Value(correlationIDCtxKey) - if s, ok := v.(string); ok { - return s +// GetExtractKeywordsFormat returns the format specification for keyword extraction +func GetExtractKeywordsFormat() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "translate": map[string]interface{}{"type": "string"}, + "keyword": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}}, + "animal": map[string]interface{}{"type": "string"}, + }, + "required": []string{"translate", "keyword", "animal"}, } - return "" } -func (llm *LLMClient) persistRaw(ctx context.Context, phase, raw string) { - if llm == nil || llm.Repo == nil || raw == "" { - return +// GetDisambiguateFormat returns the format specification for disambiguation +func GetDisambiguateFormat() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "visitReason": map[string]interface{}{"type": "string"}, + }, + "required": []string{"visitReason"}, } - cid := correlationIDFromCtx(ctx) - if cid == "" { - return - } - _ = llm.Repo.SaveLLMRawEvent(ctx, cid, phase, raw) } -// renderPrompt renders a Go template with the given data +// --- Factory --- + +func NewLLMClientFromEnv(repo ChatRepositoryAPI) LLMClientAPI { + provider := os.Getenv("LLM_PROVIDER") + apiKey := os.Getenv("OPENAI_API_KEY") + baseURL := os.Getenv("OPENAI_BASE_URL") + model := os.Getenv("OPENAI_MODEL") + switch strings.ToLower(provider) { + case "openai", "openrouter": + return NewOpenAIClient(apiKey, baseURL, model, repo) + case "ollama", "": + return NewOllamaClient(apiKey, baseURL, model, repo) + default: + logrus.Warnf("Unknown LLM_PROVIDER %q, defaulting to Ollama", provider) + return NewOllamaClient(apiKey, baseURL, model, repo) + } +} + +// --- Utility --- + func renderPrompt(tmplStr string, data any) (string, error) { tmpl, err := template.New("").Parse(tmplStr) if err != nil { @@ -63,263 +74,3 @@ func renderPrompt(tmplStr string, data any) (string, error) { } return buf.String(), nil } - -// ExtractKeywords calls LLM to extract keywords from user message -func (llm *LLMClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) { - _, parsed, err := llm.ExtractKeywordsRaw(ctx, message) - return parsed, err -} - -// ExtractKeywordsRaw returns the raw JSON string and parsed map -func (llm *LLMClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) { - prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message}) - if err != nil { - logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt") - return "", nil, err - } - logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt") - format := map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "translate": map[string]interface{}{"type": "string"}, - "keyword": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}}, - "animal": map[string]interface{}{"type": "string"}, - }, - "required": []string{"translate", "keyword", "animal"}, - } - resp, err := llm.openAICompletion(ctx, prompt, format) - logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response") - if err != nil { - return resp, nil, err // return whatever raw we got (may be empty) - } - var result map[string]interface{} - if err := json.Unmarshal([]byte(resp), &result); err != nil { - return resp, nil, err - } - llm.persistRaw(ctx, "extract_keywords", resp) - return resp, result, nil -} - -// DisambiguateBestMatch calls LLM to pick best match from candidates -func (llm *LLMClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) { - _, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates) - return vr, err -} - -// DisambiguateBestMatchRaw returns raw JSON and visitReason -func (llm *LLMClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) { - format := map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "visitReason": map[string]interface{}{"type": "string"}, - }, - "required": []string{"visitReason"}, - } - entries, _ := json.Marshal(candidates) - prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message}) - if err != nil { - logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt") - return "", "", err - } - logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt") - resp, err := llm.openAICompletion(ctx, prompt, format) - logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response") - if err != nil { - return resp, "", err - } - var parsed map[string]string - if err := json.Unmarshal([]byte(resp), &parsed); err != nil { - return resp, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err) - } - visitReason := strings.TrimSpace(parsed["visitReason"]) - if visitReason == "" { - return resp, "", fmt.Errorf("visitReason not found in response") - } - llm.persistRaw(ctx, "disambiguate", resp) - return resp, visitReason, nil -} - -// openAICompletion now supports both Ollama (default local) and OpenRouter/OpenAI-compatible APIs without external branching. -// It auto-detects by inspecting the BaseURL. If the URL contains "openrouter.ai" or "/v1/", it assumes OpenAI-style. -func (llm *LLMClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) { - apiURL := llm.BaseURL - if apiURL == "" { - // Default to Ollama local chat endpoint - apiURL = "http://localhost:11434/api/chat" - } - - isOpenAIStyle := strings.Contains(apiURL, "openrouter.ai") || strings.Contains(apiURL, "/v1/") - - // Helper to stringify the expected JSON schema for instructions - schemaDesc := func() string { - b, _ := json.MarshalIndent(format, "", " ") - return string(b) - } - - truncate := func(s string, n int) string { - if len(s) <= n { - return s - } - return s[:n] + "..." - } - - buildBody := func() map[string]interface{} { - if isOpenAIStyle { - return map[string]interface{}{ - "model": llm.Model, - "messages": []map[string]string{ - {"role": "system", "content": "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + schemaDesc() + " Do not add explanations."}, - {"role": "user", "content": prompt}, - }, - "response_format": map[string]interface{}{"type": "json_object"}, - } - } - // Ollama style - return map[string]interface{}{ - "model": llm.Model, - "messages": []map[string]string{{"role": "user", "content": prompt}}, - "stream": false, - "format": format, - } - } - - body := buildBody() - - doRequest := func(body map[string]interface{}) (raw []byte, status int, err error, dur time.Duration) { - jsonBody, _ := json.Marshal(body) - bodySize := len(jsonBody) - logrus.WithFields(logrus.Fields{ - "event": "llm_request", - "api_url": apiURL, - "model": llm.Model, - "is_openai_style": isOpenAIStyle, - "prompt_len": len(prompt), - "body_size": bodySize, - }).Info("[LLM] sending request") - req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody)) - if llm.APIKey != "" { - req.Header.Set("Authorization", "Bearer "+llm.APIKey) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - if strings.Contains(apiURL, "openrouter.ai") { - req.Header.Set("Referer", "https://github.com/") - req.Header.Set("X-Title", "vetrag-app") - } - start := time.Now() - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return nil, 0, err, time.Since(start) - } - defer resp.Body.Close() - raw, rerr := io.ReadAll(resp.Body) - return raw, resp.StatusCode, rerr, time.Since(start) - } - - raw, status, err, dur := doRequest(body) - if err != nil { - logrus.WithFields(logrus.Fields{ - "event": "llm_response", - "status": status, - "latency_ms": dur.Milliseconds(), - "error": err, - }).Error("[LLM] request failed") - return "", err - } - logrus.WithFields(logrus.Fields{ - "event": "llm_raw_response", - "status": status, - "latency_ms": dur.Milliseconds(), - "raw_trunc": truncate(string(raw), 600), - "raw_len": len(raw), - }).Debug("[LLM] raw response body") - - parseVariant := "unknown" - - // Attempt Ollama format parse - var ollama struct { - Message struct { - Content string `json:"content"` - } `json:"message"` - Error string `json:"error"` - } - if err := json.Unmarshal(raw, &ollama); err == nil && ollama.Message.Content != "" { - parseVariant = "ollama" - content := ollama.Message.Content - logrus.WithFields(logrus.Fields{ - "event": "llm_response", - "status": status, - "latency_ms": dur.Milliseconds(), - "parse_variant": parseVariant, - "content_len": len(content), - "content_snip": truncate(content, 300), - }).Info("[LLM] parsed response") - return content, nil - } - - // Attempt OpenAI/OpenRouter style parse - var openAI struct { - Choices []struct { - Message struct { - Content string `json:"content"` - } `json:"message"` - } `json:"choices"` - Error *struct { - Message string `json:"message"` - Type string `json:"type"` - } `json:"error"` - } - if err := json.Unmarshal(raw, &openAI); err == nil { - if openAI.Error != nil || status >= 400 { - parseVariant = "openai" - var msg string - if openAI.Error != nil { - msg = openAI.Error.Message - } else { - msg = string(raw) - } - logrus.WithFields(logrus.Fields{ - "event": "llm_response", - "status": status, - "latency_ms": dur.Milliseconds(), - "parse_variant": parseVariant, - "error": msg, - }).Error("[LLM] provider error") - return "", fmt.Errorf("provider error: %s", msg) - } - if len(openAI.Choices) > 0 && openAI.Choices[0].Message.Content != "" { - parseVariant = "openai" - content := openAI.Choices[0].Message.Content - logrus.WithFields(logrus.Fields{ - "event": "llm_response", - "status": status, - "latency_ms": dur.Milliseconds(), - "parse_variant": parseVariant, - "content_len": len(content), - "content_snip": truncate(content, 300), - }).Info("[LLM] parsed response") - return content, nil - } - } - - logrus.WithFields(logrus.Fields{ - "event": "llm_response", - "status": status, - "latency_ms": dur.Milliseconds(), - "parse_variant": parseVariant, - "raw_snip": truncate(string(raw), 300), - }).Error("[LLM] unrecognized response format") - - return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw)) -} - -// LLMClientAPI allows mocking LLMClient in other places -// Only public methods should be included - -type LLMClientAPI interface { - ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) - DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) -} - -var _ LLMClientAPI = (*LLMClient)(nil) diff --git a/main.go b/main.go index 9ff1ba1..cfa03e4 100644 --- a/main.go +++ b/main.go @@ -5,7 +5,6 @@ import ( "database/sql" "html/template" "net/http" - "os" "github.com/gin-gonic/gin" _ "github.com/jackc/pgx/v5/stdlib" @@ -71,13 +70,7 @@ func main() { // defer repo.Close() // optionally enable // Initialize LLM client - llmClient := NewLLMClient( - os.Getenv("OPENAI_API_KEY"), - os.Getenv("OPENAI_BASE_URL"), - os.Getenv("OPENAI_MODEL"), - repo, - ) - var llm LLMClientAPI = llmClient + llm := NewLLMClientFromEnv(repo) // Wrap templates for controller uiTmpl := &TemplateWrapper{Tmpl: uiTemplate} diff --git a/ollama_client.go b/ollama_client.go new file mode 100644 index 0000000..a455c92 --- /dev/null +++ b/ollama_client.go @@ -0,0 +1,168 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/sirupsen/logrus" +) + +// --- OllamaClient implementation --- + +type OllamaClient struct { + APIKey string + BaseURL string + Model string + Repo ChatRepositoryAPI +} + +func NewOllamaClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OllamaClient { + return &OllamaClient{APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo} +} + +func (llm *OllamaClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) { + _, parsed, err := llm.ExtractKeywordsRaw(ctx, message) + return parsed, err +} + +func (llm *OllamaClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) { + prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message}) + if err != nil { + logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt") + return "", nil, err + } + logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt") + + // Use the utility function instead of inline format definition + format := GetExtractKeywordsFormat() + + resp, err := llm.ollamaCompletion(ctx, prompt, format) + logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response") + if err != nil { + return resp, nil, err + } + var result map[string]interface{} + if err := json.Unmarshal([]byte(resp), &result); err != nil { + return resp, nil, err + } + return resp, result, nil +} + +func (llm *OllamaClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) { + _, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates) + return vr, err +} + +func (llm *OllamaClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) { + // Use the utility function instead of inline format definition + format := GetDisambiguateFormat() + + entries, _ := json.Marshal(candidates) + prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message}) + if err != nil { + logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt") + return "", "", err + } + logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt") + resp, err := llm.ollamaCompletion(ctx, prompt, format) + logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response") + if err != nil { + return resp, "", err + } + var parsed map[string]string + if err := json.Unmarshal([]byte(resp), &parsed); err != nil { + return resp, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err) + } + visitReason := strings.TrimSpace(parsed["visitReason"]) + if visitReason == "" { + return resp, "", fmt.Errorf("visitReason not found in response") + } + return resp, visitReason, nil +} + +func (llm *OllamaClient) ollamaCompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) { + apiURL := llm.BaseURL + if apiURL == "" { + apiURL = "http://localhost:11434/api/chat" + } + body := map[string]interface{}{ + "model": llm.Model, + "messages": []map[string]string{{"role": "user", "content": prompt}}, + "stream": false, + "format": format, + } + jsonBody, _ := json.Marshal(body) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody)) + if llm.APIKey != "" { + req.Header.Set("Authorization", "Bearer "+llm.APIKey) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + raw, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + var ollama struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + Error string `json:"error"` + } + if err := json.Unmarshal(raw, &ollama); err == nil && ollama.Message.Content != "" { + return ollama.Message.Content, nil + } + if ollama.Error != "" { + return "", fmt.Errorf("provider error: %s", ollama.Error) + } + return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw)) +} + +func (llm *OllamaClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) { + apiURL := llm.BaseURL + if apiURL == "" { + apiURL = "http://localhost:11434/api/embeddings" + } + body := map[string]interface{}{ + "model": llm.Model, + "prompt": input, + } + jsonBody, _ := json.Marshal(body) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody)) + if llm.APIKey != "" { + req.Header.Set("Authorization", "Bearer "+llm.APIKey) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + var ollama struct { + Embedding []float64 `json:"embedding"` + Error string `json:"error"` + } + if err := json.Unmarshal(raw, &ollama); err == nil && len(ollama.Embedding) > 0 { + return ollama.Embedding, nil + } + if ollama.Error != "" { + return nil, fmt.Errorf("embedding error: %s", ollama.Error) + } + return nil, fmt.Errorf("unrecognized embedding response: %.200s", string(raw)) +} diff --git a/openai_client.go b/openai_client.go new file mode 100644 index 0000000..43a9f29 --- /dev/null +++ b/openai_client.go @@ -0,0 +1,200 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/sirupsen/logrus" +) + +// --- OpenAIClient implementation --- + +type OpenAIClient struct { + APIKey string + BaseURL string + Model string + Repo ChatRepositoryAPI +} + +func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OpenAIClient { + return &OpenAIClient{APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo} +} + +func (llm *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) { + _, parsed, err := llm.ExtractKeywordsRaw(ctx, message) + return parsed, err +} + +func (llm *OpenAIClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) { + prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message}) + if err != nil { + logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt") + return "", nil, err + } + logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt") + + // Use the utility function instead of inline format definition + format := GetExtractKeywordsFormat() + + resp, err := llm.openAICompletion(ctx, prompt, format) + logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response") + if err != nil { + return resp, nil, err + } + var result map[string]interface{} + if err := json.Unmarshal([]byte(resp), &result); err != nil { + return resp, nil, err + } + return resp, result, nil +} + +func (llm *OpenAIClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) { + _, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates) + return vr, err +} + +func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) { + // Use the utility function instead of inline format definition + format := GetDisambiguateFormat() + + entries, _ := json.Marshal(candidates) + prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message}) + if err != nil { + logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt") + return "", "", err + } + logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt") + resp, err := llm.openAICompletion(ctx, prompt, format) + logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response") + if err != nil { + return resp, "", err + } + var parsed map[string]string + if err := json.Unmarshal([]byte(resp), &parsed); err != nil { + return resp, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err) + } + visitReason := strings.TrimSpace(parsed["visitReason"]) + if visitReason == "" { + return resp, "", fmt.Errorf("visitReason not found in response") + } + return resp, visitReason, nil +} + +func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) { + apiURL := llm.BaseURL + if apiURL == "" { + apiURL = "https://api.openai.com/v1/chat/completions" + } + // Helper to stringify the expected JSON schema for instructions + schemaDesc := func() string { + b, _ := json.MarshalIndent(format, "", " ") + return string(b) + } + body := map[string]interface{}{ + "model": llm.Model, + "messages": []map[string]string{ + {"role": "system", "content": "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + schemaDesc() + " Do not add explanations."}, + {"role": "user", "content": prompt}, + }, + "response_format": map[string]interface{}{"type": "json_object"}, + } + jsonBody, _ := json.Marshal(body) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody)) + if llm.APIKey != "" { + req.Header.Set("Authorization", "Bearer "+llm.APIKey) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + if strings.Contains(apiURL, "openrouter.ai") { + req.Header.Set("Referer", "https://github.com/") + req.Header.Set("X-Title", "vetrag-app") + } + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + raw, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + var openAI struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + Error *struct { + Message string `json:"message"` + Type string `json:"type"` + } `json:"error"` + } + if err := json.Unmarshal(raw, &openAI); err == nil { + if openAI.Error != nil || resp.StatusCode >= 400 { + var msg string + if openAI.Error != nil { + msg = openAI.Error.Message + } else { + msg = string(raw) + } + return "", fmt.Errorf("provider error: %s", msg) + } + if len(openAI.Choices) > 0 && openAI.Choices[0].Message.Content != "" { + return openAI.Choices[0].Message.Content, nil + } + } + return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw)) +} + +func (llm *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) { + apiURL := llm.BaseURL + if apiURL == "" { + apiURL = "https://api.openai.com/v1/embeddings" + } + body := map[string]interface{}{ + "model": llm.Model, + "input": input, + } + jsonBody, _ := json.Marshal(body) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody)) + if llm.APIKey != "" { + req.Header.Set("Authorization", "Bearer "+llm.APIKey) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + if strings.Contains(apiURL, "openrouter.ai") { + req.Header.Set("Referer", "https://github.com/") + req.Header.Set("X-Title", "vetrag-app") + } + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + var openAI struct { + Data []struct { + Embedding []float64 `json:"embedding"` + } `json:"data"` + Error *struct { + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal(raw, &openAI); err == nil && len(openAI.Data) > 0 { + return openAI.Data[0].Embedding, nil + } + if openAI.Error != nil { + return nil, fmt.Errorf("embedding error: %s", openAI.Error.Message) + } + return nil, fmt.Errorf("unrecognized embedding response: %.200s", string(raw)) +}