package main import ( "context" "encoding/json" "fmt" "net/http" "net/http/httptest" "os" "testing" "github.com/stretchr/testify/assert" ) // MockLLMClient implements LLMClientAPI for testing type MockLLMClient struct { ExtractKeywordsFunc func(ctx context.Context, message string) (map[string]interface{}, error) DisambiguateBestMatchFunc func(ctx context.Context, message string, candidates []Visit) (string, error) GetEmbeddingsFunc func(ctx context.Context, input string) ([]float64, error) } func (m *MockLLMClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) { if m.ExtractKeywordsFunc != nil { return m.ExtractKeywordsFunc(ctx, message) } return map[string]interface{}{ "translate": "test translation", "keyword": []string{"test", "keywords"}, "animal": "test animal", }, nil } func (m *MockLLMClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) { if m.DisambiguateBestMatchFunc != nil { return m.DisambiguateBestMatchFunc(ctx, message, candidates) } if len(candidates) > 0 { return candidates[0].ID, nil } return "", nil } func (m *MockLLMClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) { if m.GetEmbeddingsFunc != nil { return m.GetEmbeddingsFunc(ctx, input) } return []float64{0.1, 0.2, 0.3}, nil } func TestNewLLMClientFromEnv(t *testing.T) { tests := []struct { name string envVars map[string]string expectedType string }{ { name: "default to ollama when no provider specified", envVars: map[string]string{ "LLM_PROVIDER": "", }, expectedType: "*main.OllamaClient", }, { name: "use openai client when provider is openai", envVars: map[string]string{ "LLM_PROVIDER": "openai", }, expectedType: "*main.OpenAIClient", }, { name: "use ollama client when provider is ollama", envVars: map[string]string{ "LLM_PROVIDER": "ollama", }, expectedType: "*main.OllamaClient", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Clear existing env vars os.Unsetenv("LLM_PROVIDER") os.Unsetenv("OPENAI_API_KEY") os.Unsetenv("OPENAI_BASE_URL") os.Unsetenv("OPENAI_MODEL") // Set env vars for test for k, v := range tt.envVars { os.Setenv(k, v) } client := NewLLMClientFromEnv(nil) assert.NotNil(t, client) assert.Equal(t, tt.expectedType, typeName(client)) }) } } func typeName(v interface{}) string { if v == nil { return "nil" } return fmt.Sprintf("%T", v) } func TestOpenAIClient_GetEmbeddings(t *testing.T) { // Mock server to simulate OpenAI API response server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // No path check needed since we're passing the full URL as baseURL assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) var reqBody map[string]interface{} err := json.NewDecoder(r.Body).Decode(&reqBody) assert.NoError(t, err) assert.Equal(t, "test-model", reqBody["model"]) assert.Equal(t, "test input", reqBody["input"]) // Respond with mock embedding data w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write([]byte(`{ "data": [ { "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] } ] }`)) })) defer server.Close() // Pass the full URL as the baseURL parameter client := NewOpenAIClient("test-key", server.URL, "test-model", nil) embeddings, err := client.GetEmbeddings(context.Background(), "test input") assert.NoError(t, err) assert.Equal(t, []float64{0.1, 0.2, 0.3, 0.4, 0.5}, embeddings) } func TestOllamaClient_GetEmbeddings(t *testing.T) { // Mock server to simulate Ollama API response server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // The API URL for embeddings in ollama_client.go is constructed as: // apiURL (baseURL) = "http://localhost:11434/api/embeddings" // So we shouldn't expect a path suffix here var reqBody map[string]interface{} err := json.NewDecoder(r.Body).Decode(&reqBody) assert.NoError(t, err) assert.Equal(t, "test-model", reqBody["model"]) assert.Equal(t, "test input", reqBody["prompt"]) // Respond with mock embedding data w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write([]byte(`{ "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] }`)) })) defer server.Close() // Pass the full URL as the baseURL parameter client := NewOllamaClient("", server.URL, "test-model", nil) embeddings, err := client.GetEmbeddings(context.Background(), "test input") assert.NoError(t, err) assert.Equal(t, []float64{0.1, 0.2, 0.3, 0.4, 0.5}, embeddings) }