separate clients

This commit is contained in:
lehel 2025-10-08 13:17:56 +02:00
parent 4ada379be9
commit a0f477c9a8
No known key found for this signature in database
GPG Key ID: 9C4F9D6111EE5CFA
5 changed files with 421 additions and 315 deletions

View File

@ -83,20 +83,9 @@ func (cs *ChatService) findBestVisit(ctx context.Context, req ChatRequest, keywo
bestID := ""
rawDis := ""
if len(candidates) > 0 {
if real, ok := cs.LLM.(*LLMClient); ok {
raw, vr, derr := real.DisambiguateBestMatchRaw(ctx, req.Message, candidates)
rawDis = raw
bestID = vr
if derr != nil {
cs.logBestID(bestID, derr)
} else {
cs.logBestID(bestID, nil)
}
} else {
bestID, err = cs.LLM.DisambiguateBestMatch(ctx, req.Message, candidates)
cs.logBestID(bestID, err)
}
}
visit, err := cs.visitsDB.FindById(bestID)
if err != nil {
return nil, rawDis, fmt.Errorf("FindById: %w", err)
@ -236,3 +225,8 @@ func (cs *ChatService) persistInteraction(ctx context.Context, correlationID str
logrus.WithError(err).Debug("failed to save chat interaction")
}
}
// correlationIDCtxKeyType is a private, zero-size type used as a context key.
// Using a dedicated struct type (rather than a plain string) guarantees the
// key cannot collide with context values set by other packages.
type correlationIDCtxKeyType struct{}

// correlationIDCtxKey is the single context key under which the request's
// correlation ID is stored and looked up.
var correlationIDCtxKey = correlationIDCtxKeyType{}

339
llm.go
View File

@ -3,55 +3,66 @@ package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"text/template"
"time"
"github.com/sirupsen/logrus"
)
// LLMClient abstracts LLM API calls
type LLMClient struct {
APIKey string
BaseURL string
Model string
Repo ChatRepositoryAPI
// LLMClientAPI allows mocking LLMClient in other places.
// It is the provider-agnostic interface returned by NewLLMClientFromEnv;
// both the OpenAI and Ollama clients satisfy it.
type LLMClientAPI interface {
	// ExtractKeywords asks the model to extract structured keyword data from a user message.
	ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error)
	// DisambiguateBestMatch asks the model to pick the best matching visit from candidates.
	DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error)
	// GetEmbeddings returns an embedding vector for the input text.
	GetEmbeddings(ctx context.Context, input string) ([]float64, error)
}
// NewLLMClient constructs a new LLMClient with the given API key, base URL, model, and optional repository
func NewLLMClient(apiKey, baseURL string, model string, repo ChatRepositoryAPI) *LLMClient {
return &LLMClient{APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo}
}
// --- Format Utilities ---
func (llm *LLMClient) SetRepository(r ChatRepositoryAPI) { llm.Repo = r }
// helper to get correlation id from context
const correlationIDCtxKey = "corr_id"
func correlationIDFromCtx(ctx context.Context) string {
v := ctx.Value(correlationIDCtxKey)
if s, ok := v.(string); ok {
return s
// GetExtractKeywordsFormat returns the format specification for keyword extraction
func GetExtractKeywordsFormat() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"translate": map[string]interface{}{"type": "string"},
"keyword": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}},
"animal": map[string]interface{}{"type": "string"},
},
"required": []string{"translate", "keyword", "animal"},
}
return ""
}
func (llm *LLMClient) persistRaw(ctx context.Context, phase, raw string) {
if llm == nil || llm.Repo == nil || raw == "" {
return
// GetDisambiguateFormat returns the format specification for disambiguation
func GetDisambiguateFormat() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"visitReason": map[string]interface{}{"type": "string"},
},
"required": []string{"visitReason"},
}
cid := correlationIDFromCtx(ctx)
if cid == "" {
return
}
_ = llm.Repo.SaveLLMRawEvent(ctx, cid, phase, raw)
}
// renderPrompt renders a Go template with the given data
// --- Factory ---

// NewLLMClientFromEnv builds the LLM client selected by the LLM_PROVIDER
// environment variable ("openai"/"openrouter" picks the OpenAI-compatible
// client; "ollama" or empty picks Ollama; anything else warns and falls
// back to Ollama). API key, base URL and model are read from
// OPENAI_API_KEY, OPENAI_BASE_URL and OPENAI_MODEL respectively.
func NewLLMClientFromEnv(repo ChatRepositoryAPI) LLMClientAPI {
	var (
		rawProvider = os.Getenv("LLM_PROVIDER")
		apiKey      = os.Getenv("OPENAI_API_KEY")
		baseURL     = os.Getenv("OPENAI_BASE_URL")
		model       = os.Getenv("OPENAI_MODEL")
	)
	switch strings.ToLower(rawProvider) {
	case "openai", "openrouter":
		return NewOpenAIClient(apiKey, baseURL, model, repo)
	case "ollama", "":
		return NewOllamaClient(apiKey, baseURL, model, repo)
	}
	// Unknown value: warn with the original (un-lowercased) setting.
	logrus.Warnf("Unknown LLM_PROVIDER %q, defaulting to Ollama", rawProvider)
	return NewOllamaClient(apiKey, baseURL, model, repo)
}
// --- Utility ---
func renderPrompt(tmplStr string, data any) (string, error) {
tmpl, err := template.New("").Parse(tmplStr)
if err != nil {
@ -63,263 +74,3 @@ func renderPrompt(tmplStr string, data any) (string, error) {
}
return buf.String(), nil
}
// ExtractKeywords calls LLM to extract keywords from user message
func (llm *LLMClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
_, parsed, err := llm.ExtractKeywordsRaw(ctx, message)
return parsed, err
}
// ExtractKeywordsRaw returns the raw JSON string and parsed map
func (llm *LLMClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) {
prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
if err != nil {
logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt")
return "", nil, err
}
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
format := map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"translate": map[string]interface{}{"type": "string"},
"keyword": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}},
"animal": map[string]interface{}{"type": "string"},
},
"required": []string{"translate", "keyword", "animal"},
}
resp, err := llm.openAICompletion(ctx, prompt, format)
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response")
if err != nil {
return resp, nil, err // return whatever raw we got (may be empty)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(resp), &result); err != nil {
return resp, nil, err
}
llm.persistRaw(ctx, "extract_keywords", resp)
return resp, result, nil
}
// DisambiguateBestMatch calls LLM to pick best match from candidates
func (llm *LLMClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
_, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates)
return vr, err
}
// DisambiguateBestMatchRaw returns raw JSON and visitReason
func (llm *LLMClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
format := map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"visitReason": map[string]interface{}{"type": "string"},
},
"required": []string{"visitReason"},
}
entries, _ := json.Marshal(candidates)
prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
if err != nil {
logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt")
return "", "", err
}
logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")
resp, err := llm.openAICompletion(ctx, prompt, format)
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response")
if err != nil {
return resp, "", err
}
var parsed map[string]string
if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
return resp, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
}
visitReason := strings.TrimSpace(parsed["visitReason"])
if visitReason == "" {
return resp, "", fmt.Errorf("visitReason not found in response")
}
llm.persistRaw(ctx, "disambiguate", resp)
return resp, visitReason, nil
}
// openAICompletion now supports both Ollama (default local) and OpenRouter/OpenAI-compatible APIs without external branching.
// It auto-detects by inspecting the BaseURL. If the URL contains "openrouter.ai" or "/v1/", it assumes OpenAI-style.
func (llm *LLMClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
apiURL := llm.BaseURL
if apiURL == "" {
// Default to Ollama local chat endpoint
apiURL = "http://localhost:11434/api/chat"
}
isOpenAIStyle := strings.Contains(apiURL, "openrouter.ai") || strings.Contains(apiURL, "/v1/")
// Helper to stringify the expected JSON schema for instructions
schemaDesc := func() string {
b, _ := json.MarshalIndent(format, "", " ")
return string(b)
}
truncate := func(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "...<truncated>"
}
buildBody := func() map[string]interface{} {
if isOpenAIStyle {
return map[string]interface{}{
"model": llm.Model,
"messages": []map[string]string{
{"role": "system", "content": "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + schemaDesc() + " Do not add explanations."},
{"role": "user", "content": prompt},
},
"response_format": map[string]interface{}{"type": "json_object"},
}
}
// Ollama style
return map[string]interface{}{
"model": llm.Model,
"messages": []map[string]string{{"role": "user", "content": prompt}},
"stream": false,
"format": format,
}
}
body := buildBody()
doRequest := func(body map[string]interface{}) (raw []byte, status int, err error, dur time.Duration) {
jsonBody, _ := json.Marshal(body)
bodySize := len(jsonBody)
logrus.WithFields(logrus.Fields{
"event": "llm_request",
"api_url": apiURL,
"model": llm.Model,
"is_openai_style": isOpenAIStyle,
"prompt_len": len(prompt),
"body_size": bodySize,
}).Info("[LLM] sending request")
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
if llm.APIKey != "" {
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
if strings.Contains(apiURL, "openrouter.ai") {
req.Header.Set("Referer", "https://github.com/")
req.Header.Set("X-Title", "vetrag-app")
}
start := time.Now()
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, 0, err, time.Since(start)
}
defer resp.Body.Close()
raw, rerr := io.ReadAll(resp.Body)
return raw, resp.StatusCode, rerr, time.Since(start)
}
raw, status, err, dur := doRequest(body)
if err != nil {
logrus.WithFields(logrus.Fields{
"event": "llm_response",
"status": status,
"latency_ms": dur.Milliseconds(),
"error": err,
}).Error("[LLM] request failed")
return "", err
}
logrus.WithFields(logrus.Fields{
"event": "llm_raw_response",
"status": status,
"latency_ms": dur.Milliseconds(),
"raw_trunc": truncate(string(raw), 600),
"raw_len": len(raw),
}).Debug("[LLM] raw response body")
parseVariant := "unknown"
// Attempt Ollama format parse
var ollama struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
Error string `json:"error"`
}
if err := json.Unmarshal(raw, &ollama); err == nil && ollama.Message.Content != "" {
parseVariant = "ollama"
content := ollama.Message.Content
logrus.WithFields(logrus.Fields{
"event": "llm_response",
"status": status,
"latency_ms": dur.Milliseconds(),
"parse_variant": parseVariant,
"content_len": len(content),
"content_snip": truncate(content, 300),
}).Info("[LLM] parsed response")
return content, nil
}
// Attempt OpenAI/OpenRouter style parse
var openAI struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
Error *struct {
Message string `json:"message"`
Type string `json:"type"`
} `json:"error"`
}
if err := json.Unmarshal(raw, &openAI); err == nil {
if openAI.Error != nil || status >= 400 {
parseVariant = "openai"
var msg string
if openAI.Error != nil {
msg = openAI.Error.Message
} else {
msg = string(raw)
}
logrus.WithFields(logrus.Fields{
"event": "llm_response",
"status": status,
"latency_ms": dur.Milliseconds(),
"parse_variant": parseVariant,
"error": msg,
}).Error("[LLM] provider error")
return "", fmt.Errorf("provider error: %s", msg)
}
if len(openAI.Choices) > 0 && openAI.Choices[0].Message.Content != "" {
parseVariant = "openai"
content := openAI.Choices[0].Message.Content
logrus.WithFields(logrus.Fields{
"event": "llm_response",
"status": status,
"latency_ms": dur.Milliseconds(),
"parse_variant": parseVariant,
"content_len": len(content),
"content_snip": truncate(content, 300),
}).Info("[LLM] parsed response")
return content, nil
}
}
logrus.WithFields(logrus.Fields{
"event": "llm_response",
"status": status,
"latency_ms": dur.Milliseconds(),
"parse_variant": parseVariant,
"raw_snip": truncate(string(raw), 300),
}).Error("[LLM] unrecognized response format")
return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw))
}
// LLMClientAPI allows mocking LLMClient in other places
// Only public methods should be included
type LLMClientAPI interface {
ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error)
DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error)
}
var _ LLMClientAPI = (*LLMClient)(nil)

View File

@ -5,7 +5,6 @@ import (
"database/sql"
"html/template"
"net/http"
"os"
"github.com/gin-gonic/gin"
_ "github.com/jackc/pgx/v5/stdlib"
@ -71,13 +70,7 @@ func main() {
// defer repo.Close() // optionally enable
// Initialize LLM client
llmClient := NewLLMClient(
os.Getenv("OPENAI_API_KEY"),
os.Getenv("OPENAI_BASE_URL"),
os.Getenv("OPENAI_MODEL"),
repo,
)
var llm LLMClientAPI = llmClient
llm := NewLLMClientFromEnv(repo)
// Wrap templates for controller
uiTmpl := &TemplateWrapper{Tmpl: uiTemplate}

168
ollama_client.go Normal file
View File

@ -0,0 +1,168 @@
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/sirupsen/logrus"
)
// --- OllamaClient implementation ---

// OllamaClient is an LLM client for Ollama-style HTTP APIs (chat and
// embeddings). It satisfies LLMClientAPI.
type OllamaClient struct {
	APIKey  string            // optional; sent as "Authorization: Bearer <key>" when non-empty
	BaseURL string            // endpoint URL; empty selects the local Ollama defaults
	Model   string            // model identifier included in every request body
	Repo    ChatRepositoryAPI // repository handle; not referenced by the methods in this file
}
// NewOllamaClient constructs an OllamaClient with the given API key, base URL,
// model name, and optional repository.
func NewOllamaClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OllamaClient {
	return &OllamaClient{APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo}
}
// ExtractKeywords extracts structured keyword data from the user message,
// returning only the parsed map. It is a thin wrapper over
// ExtractKeywordsRaw, discarding the raw JSON string.
func (llm *OllamaClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
	_, parsed, err := llm.ExtractKeywordsRaw(ctx, message)
	return parsed, err
}
// ExtractKeywordsRaw renders the keyword-extraction prompt, sends it to the
// model constrained by the shared keyword JSON schema, and returns the raw
// JSON response alongside the parsed map. The raw string is returned even
// when the call or decoding fails, so callers can log/persist it.
func (llm *OllamaClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) {
	rendered, renderErr := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
	if renderErr != nil {
		logrus.WithError(renderErr).Error("[CONFIG] Failed to render ExtractKeywords prompt")
		return "", nil, renderErr
	}
	logrus.WithField("prompt", rendered).Info("[LLM] ExtractKeywords prompt")

	// Schema shared with the other client implementations.
	raw, callErr := llm.ollamaCompletion(ctx, rendered, GetExtractKeywordsFormat())
	logrus.WithFields(logrus.Fields{"response": raw, "err": callErr}).Info("[LLM] ExtractKeywords response")
	if callErr != nil {
		return raw, nil, callErr
	}

	var parsed map[string]interface{}
	if unmarshalErr := json.Unmarshal([]byte(raw), &parsed); unmarshalErr != nil {
		return raw, nil, unmarshalErr
	}
	return raw, parsed, nil
}
// DisambiguateBestMatch picks the best matching visit reason from candidates,
// returning only the visitReason string. It is a thin wrapper over
// DisambiguateBestMatchRaw, discarding the raw JSON response.
func (llm *OllamaClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
	_, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates)
	return vr, err
}
// DisambiguateBestMatchRaw asks the model to pick the best matching visit for
// the message out of candidates. It returns the raw JSON response, the chosen
// visitReason, and an error; the raw string is returned even for failures
// that occur after the HTTP call so it can be logged/persisted.
func (llm *OllamaClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
	// Candidates are embedded into the prompt as a JSON array; marshal of a
	// []Visit cannot fail, so the error is intentionally ignored.
	candidateJSON, _ := json.Marshal(candidates)
	rendered, renderErr := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(candidateJSON), "Message": message})
	if renderErr != nil {
		logrus.WithError(renderErr).Error("[CONFIG] Failed to render Disambiguate prompt")
		return "", "", renderErr
	}
	logrus.WithField("prompt", rendered).Info("[LLM] DisambiguateBestMatch prompt")

	raw, callErr := llm.ollamaCompletion(ctx, rendered, GetDisambiguateFormat())
	logrus.WithFields(logrus.Fields{"response": raw, "err": callErr}).Info("[LLM] DisambiguateBestMatch response")
	if callErr != nil {
		return raw, "", callErr
	}

	var fields map[string]string
	if err := json.Unmarshal([]byte(raw), &fields); err != nil {
		return raw, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
	}
	reason := strings.TrimSpace(fields["visitReason"])
	if reason == "" {
		return raw, "", fmt.Errorf("visitReason not found in response")
	}
	return raw, reason, nil
}
// ollamaCompletion sends one non-streaming chat request to the Ollama
// /api/chat endpoint and returns the assistant message content. The format
// map is forwarded as Ollama's structured-output "format" field.
//
// Fix: HTTP-level failures (wrong URL, auth, proxies) that return non-JSON
// bodies previously fell through to the generic "unrecognized" error with no
// status code; they now report the status. Marshal/request-build errors are
// no longer silently ignored.
func (llm *OllamaClient) ollamaCompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
	apiURL := llm.BaseURL
	if apiURL == "" {
		// Default to the local Ollama chat endpoint.
		apiURL = "http://localhost:11434/api/chat"
	}
	body := map[string]interface{}{
		"model":    llm.Model,
		"messages": []map[string]string{{"role": "user", "content": prompt}},
		"stream":   false,
		"format":   format,
	}
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return "", fmt.Errorf("marshal request body: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
	if err != nil {
		return "", fmt.Errorf("build request: %w", err)
	}
	if llm.APIKey != "" {
		req.Header.Set("Authorization", "Bearer "+llm.APIKey)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	var ollama struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
		Error string `json:"error"`
	}
	if err := json.Unmarshal(raw, &ollama); err == nil && ollama.Message.Content != "" {
		return ollama.Message.Content, nil
	}
	if ollama.Error != "" {
		return "", fmt.Errorf("provider error: %s", ollama.Error)
	}
	// Surface HTTP errors whose bodies are not Ollama-style JSON.
	if resp.StatusCode >= 400 {
		return "", fmt.Errorf("ollama HTTP %d: %.200s", resp.StatusCode, string(raw))
	}
	return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw))
}
// GetEmbeddings requests an embedding vector for input from the Ollama
// /api/embeddings endpoint and returns it as a float64 slice.
//
// Fix: non-JSON HTTP error responses now report the status code instead of
// the generic "unrecognized" error; marshal/request-build errors are handled.
func (llm *OllamaClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
	apiURL := llm.BaseURL
	if apiURL == "" {
		// Default to the local Ollama embeddings endpoint.
		apiURL = "http://localhost:11434/api/embeddings"
	}
	body := map[string]interface{}{
		"model":  llm.Model,
		"prompt": input,
	}
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("marshal request body: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("build request: %w", err)
	}
	if llm.APIKey != "" {
		req.Header.Set("Authorization", "Bearer "+llm.APIKey)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	var ollama struct {
		Embedding []float64 `json:"embedding"`
		Error     string    `json:"error"`
	}
	if err := json.Unmarshal(raw, &ollama); err == nil && len(ollama.Embedding) > 0 {
		return ollama.Embedding, nil
	}
	if ollama.Error != "" {
		return nil, fmt.Errorf("embedding error: %s", ollama.Error)
	}
	// Surface HTTP errors whose bodies are not Ollama-style JSON.
	if resp.StatusCode >= 400 {
		return nil, fmt.Errorf("ollama HTTP %d: %.200s", resp.StatusCode, string(raw))
	}
	return nil, fmt.Errorf("unrecognized embedding response: %.200s", string(raw))
}

200
openai_client.go Normal file
View File

@ -0,0 +1,200 @@
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/sirupsen/logrus"
)
// --- OpenAIClient implementation ---

// OpenAIClient is an LLM client for OpenAI-compatible HTTP APIs
// (api.openai.com and OpenRouter). It satisfies LLMClientAPI.
type OpenAIClient struct {
	APIKey  string            // sent as "Authorization: Bearer <key>" when non-empty
	BaseURL string            // endpoint URL; empty selects the official OpenAI defaults
	Model   string            // model identifier included in every request body
	Repo    ChatRepositoryAPI // repository handle; not referenced by the methods in this file
}
// NewOpenAIClient constructs an OpenAIClient with the given API key, base URL,
// model name, and optional repository.
func NewOpenAIClient(apiKey, baseURL, model string, repo ChatRepositoryAPI) *OpenAIClient {
	return &OpenAIClient{APIKey: apiKey, BaseURL: baseURL, Model: model, Repo: repo}
}
// ExtractKeywords extracts structured keyword data from the user message,
// returning only the parsed map. It is a thin wrapper over
// ExtractKeywordsRaw, discarding the raw JSON string.
func (llm *OpenAIClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
	_, parsed, err := llm.ExtractKeywordsRaw(ctx, message)
	return parsed, err
}
// ExtractKeywordsRaw renders the keyword-extraction prompt, sends it to the
// model constrained by the shared keyword JSON schema, and returns the raw
// JSON response alongside the parsed map. The raw string is returned even
// when the call or decoding fails, so callers can log/persist it.
func (llm *OpenAIClient) ExtractKeywordsRaw(ctx context.Context, message string) (string, map[string]interface{}, error) {
	rendered, renderErr := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
	if renderErr != nil {
		logrus.WithError(renderErr).Error("[CONFIG] Failed to render ExtractKeywords prompt")
		return "", nil, renderErr
	}
	logrus.WithField("prompt", rendered).Info("[LLM] ExtractKeywords prompt")

	// Schema shared with the other client implementations.
	raw, callErr := llm.openAICompletion(ctx, rendered, GetExtractKeywordsFormat())
	logrus.WithFields(logrus.Fields{"response": raw, "err": callErr}).Info("[LLM] ExtractKeywords response")
	if callErr != nil {
		return raw, nil, callErr
	}

	var parsed map[string]interface{}
	if unmarshalErr := json.Unmarshal([]byte(raw), &parsed); unmarshalErr != nil {
		return raw, nil, unmarshalErr
	}
	return raw, parsed, nil
}
// DisambiguateBestMatch picks the best matching visit reason from candidates,
// returning only the visitReason string. It is a thin wrapper over
// DisambiguateBestMatchRaw, discarding the raw JSON response.
func (llm *OpenAIClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
	_, vr, err := llm.DisambiguateBestMatchRaw(ctx, message, candidates)
	return vr, err
}
// DisambiguateBestMatchRaw asks the model to pick the best matching visit for
// the message out of candidates. It returns the raw JSON response, the chosen
// visitReason, and an error; the raw string is returned even for failures
// that occur after the HTTP call so it can be logged/persisted.
func (llm *OpenAIClient) DisambiguateBestMatchRaw(ctx context.Context, message string, candidates []Visit) (string, string, error) {
	// Candidates are embedded into the prompt as a JSON array; marshal of a
	// []Visit cannot fail, so the error is intentionally ignored.
	candidateJSON, _ := json.Marshal(candidates)
	rendered, renderErr := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(candidateJSON), "Message": message})
	if renderErr != nil {
		logrus.WithError(renderErr).Error("[CONFIG] Failed to render Disambiguate prompt")
		return "", "", renderErr
	}
	logrus.WithField("prompt", rendered).Info("[LLM] DisambiguateBestMatch prompt")

	raw, callErr := llm.openAICompletion(ctx, rendered, GetDisambiguateFormat())
	logrus.WithFields(logrus.Fields{"response": raw, "err": callErr}).Info("[LLM] DisambiguateBestMatch response")
	if callErr != nil {
		return raw, "", callErr
	}

	var fields map[string]string
	if err := json.Unmarshal([]byte(raw), &fields); err != nil {
		return raw, "", fmt.Errorf("failed to unmarshal disambiguation response: %w", err)
	}
	reason := strings.TrimSpace(fields["visitReason"])
	if reason == "" {
		return raw, "", fmt.Errorf("visitReason not found in response")
	}
	return raw, reason, nil
}
// openAICompletion sends one chat-completion request to an OpenAI-compatible
// endpoint (OpenAI or OpenRouter) and returns the first choice's message
// content. The JSON schema in format is embedded into the system prompt and
// response_format is set to json_object to force strict JSON output.
//
// Fix: non-JSON error bodies (HTML gateway pages, truncated responses) with
// a 4xx/5xx status previously fell through to the generic "unrecognized"
// error with no status code; they now report it. Marshal/request-build
// errors are no longer silently ignored.
func (llm *OpenAIClient) openAICompletion(ctx context.Context, prompt string, format map[string]interface{}) (string, error) {
	apiURL := llm.BaseURL
	if apiURL == "" {
		apiURL = "https://api.openai.com/v1/chat/completions"
	}
	// Helper to stringify the expected JSON schema for instructions.
	schemaDesc := func() string {
		b, _ := json.MarshalIndent(format, "", " ")
		return string(b)
	}
	body := map[string]interface{}{
		"model": llm.Model,
		"messages": []map[string]string{
			{"role": "system", "content": "You are a strict JSON generator. ONLY output valid JSON matching this schema: " + schemaDesc() + " Do not add explanations."},
			{"role": "user", "content": prompt},
		},
		"response_format": map[string]interface{}{"type": "json_object"},
	}
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return "", fmt.Errorf("marshal request body: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
	if err != nil {
		return "", fmt.Errorf("build request: %w", err)
	}
	if llm.APIKey != "" {
		req.Header.Set("Authorization", "Bearer "+llm.APIKey)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	if strings.Contains(apiURL, "openrouter.ai") {
		// OpenRouter asks for attribution headers.
		req.Header.Set("Referer", "https://github.com/")
		req.Header.Set("X-Title", "vetrag-app")
	}
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	var openAI struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
		Error *struct {
			Message string `json:"message"`
			Type    string `json:"type"`
		} `json:"error"`
	}
	if err := json.Unmarshal(raw, &openAI); err == nil {
		if openAI.Error != nil || resp.StatusCode >= 400 {
			var msg string
			if openAI.Error != nil {
				msg = openAI.Error.Message
			} else {
				msg = string(raw)
			}
			return "", fmt.Errorf("provider error: %s", msg)
		}
		if len(openAI.Choices) > 0 && openAI.Choices[0].Message.Content != "" {
			return openAI.Choices[0].Message.Content, nil
		}
	}
	// Surface HTTP errors whose bodies could not be parsed as provider JSON.
	if resp.StatusCode >= 400 {
		return "", fmt.Errorf("provider HTTP %d: %.200s", resp.StatusCode, string(raw))
	}
	return "", fmt.Errorf("unrecognized LLM response format: %.200s", string(raw))
}
// GetEmbeddings requests an embedding vector for input from the OpenAI
// embeddings endpoint and returns the first data entry's vector.
//
// Fix: non-JSON HTTP error responses now report the status code instead of
// the generic "unrecognized" error; marshal/request-build errors are handled.
func (llm *OpenAIClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
	apiURL := llm.BaseURL
	if apiURL == "" {
		apiURL = "https://api.openai.com/v1/embeddings"
	}
	body := map[string]interface{}{
		"model": llm.Model,
		"input": input,
	}
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("marshal request body: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("build request: %w", err)
	}
	if llm.APIKey != "" {
		req.Header.Set("Authorization", "Bearer "+llm.APIKey)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	if strings.Contains(apiURL, "openrouter.ai") {
		// OpenRouter asks for attribution headers.
		req.Header.Set("Referer", "https://github.com/")
		req.Header.Set("X-Title", "vetrag-app")
	}
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	var openAI struct {
		Data []struct {
			Embedding []float64 `json:"embedding"`
		} `json:"data"`
		Error *struct {
			Message string `json:"message"`
		} `json:"error"`
	}
	if err := json.Unmarshal(raw, &openAI); err == nil && len(openAI.Data) > 0 {
		return openAI.Data[0].Embedding, nil
	}
	if openAI.Error != nil {
		return nil, fmt.Errorf("embedding error: %s", openAI.Error.Message)
	}
	// Surface HTTP errors whose bodies could not be parsed as provider JSON.
	if resp.StatusCode >= 400 {
		return nil, fmt.Errorf("provider HTTP %d: %.200s", resp.StatusCode, string(raw))
	}
	return nil, fmt.Errorf("unrecognized embedding response: %.200s", string(raw))
}