From 77c03966232cd59f7cc8d9aaa84ea3a45049f252 Mon Sep 17 00:00:00 2001 From: lehel Date: Wed, 8 Oct 2025 14:23:14 +0200 Subject: [PATCH] add test --- Makefile | 22 +++- chat_service_integration_test.go | 5 + controller_test.go | 180 +++++++++++++++++++++++++++++++ handlechat_integration_test.go | 5 + llm_test.go | 167 ++++++++++++++++++++++++++++ openrouter_integration_test.go | 34 +++--- 6 files changed, 398 insertions(+), 15 deletions(-) create mode 100644 controller_test.go create mode 100644 llm_test.go diff --git a/Makefile b/Makefile index 71afcc1..5d754d0 100644 --- a/Makefile +++ b/Makefile @@ -32,7 +32,7 @@ DB_SSLMODE ?= disable db_env = PGHOST=$(DB_HOST) PGPORT=$(DB_PORT) PGUSER=$(DB_USER) PGPASSWORD=$(DB_PASSWORD) PGDATABASE=$(DB_NAME) PGSSLMODE=$(DB_SSLMODE) # Run the Go server (assumes Ollama is running) with DB env vars -run: +run: $(db_env) OPENAI_API_KEY=ollama OPENAI_BASE_URL=http://localhost:11434/api/chat OPENAI_MODEL=qwen3:latest go run . # Run without pulling model (faster if already present) @@ -48,7 +48,25 @@ print-dsn: @echo postgres://$(DB_USER):******@$(DB_HOST):$(DB_PORT)/$(DB_NAME)?sslmode=$(DB_SSLMODE) # Run tests -.PHONY: test +.PHONY: test test-verbose test-race test-coverage test-coverage-html +# Run standard tests test: go test ./... + +# Run tests with verbose output +test-verbose: + go test -v ./... + +# Run tests with race detection +test-race: + go test -race ./... + +# Run tests with coverage reporting +test-coverage: + go test -coverprofile=coverage.out ./... + go tool cover -func=coverage.out + +# Run tests with HTML coverage report +test-coverage-html: test-coverage + go tool cover -html=coverage.out diff --git a/chat_service_integration_test.go b/chat_service_integration_test.go index 1d3cdfa..a7c1e50 100644 --- a/chat_service_integration_test.go +++ b/chat_service_integration_test.go @@ -17,6 +17,8 @@ type mockLLM struct { disambigID string keywordsErr error disambigErr error + embeddings []float64 + embeddingErr error } var _ LLMClientAPI = (*mockLLM)(nil) @@ -27,6 +29,9 @@ func (m *mockLLM) ExtractKeywords(ctx context.Context, msg string) (map[string]i func (m *mockLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Visit) (string, error) { return m.disambigID, m.disambigErr } +func (m *mockLLM) GetEmbeddings(ctx context.Context, input string) ([]float64, error) { + return m.embeddings, m.embeddingErr +} // --- Test VisitDB --- type testVisitDB struct { diff --git a/controller_test.go b/controller_test.go new file mode 100644 index 0000000..7f63217 --- /dev/null +++ b/controller_test.go @@ -0,0 +1,180 @@ +package main + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestChatService_HandleChat(t *testing.T) { + // Setup mock dependencies + mockLLM := &MockLLMClient{ + ExtractKeywordsFunc: func(ctx context.Context, message string) (map[string]interface{}, error) { + return map[string]interface{}{ + "translate": "test translation", + "keyword": []string{"test", "keyword"}, + "animal": "dog", + }, nil + }, + DisambiguateBestMatchFunc: func(ctx context.Context, message string, candidates []Visit) (string, error) { + return "visit1", nil + }, + } + + mockDB := &MockVisitDB{ + FindCandidatesFunc: func(keywords []string) ([]Visit, error) { + return []Visit{ + {ID: "visit1", Procedures: []Procedure{{Name: "Test", Price: 100, DurationMin: 30}}}, + {ID: "visit2", Procedures: []Procedure{{Name: "Test2", Price: 200, DurationMin: 60}}}, + }, nil + }, + FindByIdFunc: func(id string) (Visit, error) { + return Visit{ + ID: id, + Procedures: []Procedure{ + {Name: "Test", Price: 100, DurationMin: 30}, + }, + Notes: "Test notes", + }, nil + }, + } + + mockRepo := &MockChatRepository{} + + // Create service with mocks + svc := NewChatService(mockLLM, mockDB, mockRepo) + + // Create test context + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + // Mock request body + reqBody := `{"message": "I need a test visit"}` + c.Request = httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(reqBody)) + c.Request.Header.Set("Content-Type", "application/json") + + // Call the handler + svc.HandleChat(c) + + // Validate response + assert.Equal(t, http.StatusOK, w.Code) + + var resp ChatResponse + err := json.Unmarshal(w.Body.Bytes(), &resp) + assert.NoError(t, err) + + assert.NotNil(t, resp.Match) + assert.Equal(t, "visit1", *resp.Match) + assert.Equal(t, 100, resp.TotalPrice) + assert.Equal(t, 30, resp.TotalDuration) + assert.Equal(t, "Test notes", resp.Notes) +} + +type MockVisitDB struct { + FindCandidatesFunc func(keywords []string) ([]Visit, error) + FindByIdFunc func(id string) (Visit, error) +} + +func (m *MockVisitDB) FindCandidates(keywords []string) ([]Visit, error) { + if m.FindCandidatesFunc != nil { + return m.FindCandidatesFunc(keywords) + } + return nil, nil +} + +func (m *MockVisitDB) FindById(id string) (Visit, error) { + if m.FindByIdFunc != nil { + return m.FindByIdFunc(id) + } + return Visit{}, nil +} + +type MockChatRepository struct { + SaveChatInteractionFunc func(ctx context.Context, interaction ChatInteraction) error + ListChatInteractionsFunc func(ctx context.Context, limit, offset int) ([]ChatInteraction, error) + SaveLLMRawEventFunc func(ctx context.Context, correlationID, phase, raw string) error + ListLLMRawEventsFunc func(ctx context.Context, correlationID string, limit, offset int) ([]RawLLMEvent, error) + SaveKnowledgeModelFunc func(ctx context.Context, text string) error + ListKnowledgeModelsFunc func(ctx context.Context, limit, offset int) ([]knowledgeModelMeta, error) + GetKnowledgeModelTextFunc func(ctx context.Context, id int64) (string, error) + GetUserByUsernameFunc func(ctx context.Context, username string) (*User, error) + CountUsersFunc func(ctx context.Context) (int, error) + CreateUserFunc func(ctx context.Context, username, passwordHash string) error +} + +func (m *MockChatRepository) SaveChatInteraction(ctx context.Context, interaction ChatInteraction) error { + if m.SaveChatInteractionFunc != nil { + return m.SaveChatInteractionFunc(ctx, interaction) + } + return nil +} + +func (m *MockChatRepository) ListChatInteractions(ctx context.Context, limit, offset int) ([]ChatInteraction, error) { + if m.ListChatInteractionsFunc != nil { + return m.ListChatInteractionsFunc(ctx, limit, offset) + } + return []ChatInteraction{}, nil +} + +func (m *MockChatRepository) SaveLLMRawEvent(ctx context.Context, correlationID, phase, raw string) error { + if m.SaveLLMRawEventFunc != nil { + return m.SaveLLMRawEventFunc(ctx, correlationID, phase, raw) + } + return nil +} + +func (m *MockChatRepository) ListLLMRawEvents(ctx context.Context, correlationID string, limit, offset int) ([]RawLLMEvent, error) { + if m.ListLLMRawEventsFunc != nil { + return m.ListLLMRawEventsFunc(ctx, correlationID, limit, offset) + } + return []RawLLMEvent{}, nil +} + +func (m *MockChatRepository) SaveKnowledgeModel(ctx context.Context, text string) error { + if m.SaveKnowledgeModelFunc != nil { + return m.SaveKnowledgeModelFunc(ctx, text) + } + return nil +} + +func (m *MockChatRepository) ListKnowledgeModels(ctx context.Context, limit, offset int) ([]knowledgeModelMeta, error) { + if m.ListKnowledgeModelsFunc != nil { + return m.ListKnowledgeModelsFunc(ctx, limit, offset) + } + return []knowledgeModelMeta{}, nil +} + +func (m *MockChatRepository) GetKnowledgeModelText(ctx context.Context, id int64) (string, error) { + if m.GetKnowledgeModelTextFunc != nil { + return m.GetKnowledgeModelTextFunc(ctx, id) + } + return "", nil +} + +func (m *MockChatRepository) GetUserByUsername(ctx context.Context, username string) (*User, error) { + if m.GetUserByUsernameFunc != nil { + return m.GetUserByUsernameFunc(ctx, username) + } + return nil, nil +} + +func (m *MockChatRepository) CountUsers(ctx context.Context) (int, error) { + if m.CountUsersFunc != nil { + return m.CountUsersFunc(ctx) + } + return 0, nil +} + +func (m *MockChatRepository) CreateUser(ctx context.Context, username, passwordHash string) error { + if m.CreateUserFunc != nil { + return m.CreateUserFunc(ctx, username, passwordHash) + } + return nil +} diff --git a/handlechat_integration_test.go b/handlechat_integration_test.go index 870a7e6..82e9e95 100644 --- a/handlechat_integration_test.go +++ b/handlechat_integration_test.go @@ -20,6 +20,8 @@ type mockHandleChatLLM struct { disambigID string keywordsErr error disambigErr error + embeddings []float64 + embeddingErr error } func (m *mockHandleChatLLM) ExtractKeywords(ctx context.Context, msg string) (map[string]interface{}, error) { @@ -28,6 +30,9 @@ func (m *mockHandleChatLLM) ExtractKeywords(ctx context.Context, msg string) (ma func (m *mockHandleChatLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Visit) (string, error) { return m.disambigID, m.disambigErr } +func (m *mockHandleChatLLM) GetEmbeddings(ctx context.Context, input string) ([]float64, error) { + return m.embeddings, m.embeddingErr +} // mapChatRepo is an in-memory implementation of ChatRepositoryAPI for tests. type mapChatRepo struct { diff --git a/llm_test.go b/llm_test.go new file mode 100644 index 0000000..83396d5 --- /dev/null +++ b/llm_test.go @@ -0,0 +1,167 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +// MockLLMClient implements LLMClientAPI for testing +type MockLLMClient struct { + ExtractKeywordsFunc func(ctx context.Context, message string) (map[string]interface{}, error) + DisambiguateBestMatchFunc func(ctx context.Context, message string, candidates []Visit) (string, error) + GetEmbeddingsFunc func(ctx context.Context, input string) ([]float64, error) +} + +func (m *MockLLMClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) { + if m.ExtractKeywordsFunc != nil { + return m.ExtractKeywordsFunc(ctx, message) + } + return map[string]interface{}{ + "translate": "test translation", + "keyword": []string{"test", "keywords"}, + "animal": "test animal", + }, nil +} + +func (m *MockLLMClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) { + if m.DisambiguateBestMatchFunc != nil { + return m.DisambiguateBestMatchFunc(ctx, message, candidates) + } + if len(candidates) > 0 { + return candidates[0].ID, nil + } + return "", nil +} + +func (m *MockLLMClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) { + if m.GetEmbeddingsFunc != nil { + return m.GetEmbeddingsFunc(ctx, input) + } + return []float64{0.1, 0.2, 0.3}, nil +} + +func TestNewLLMClientFromEnv(t *testing.T) { + tests := []struct { + name string + envVars map[string]string + expectedType string + }{ + { + name: "default to ollama when no provider specified", + envVars: map[string]string{ + "LLM_PROVIDER": "", + }, + expectedType: "*main.OllamaClient", + }, + { + name: "use openai client when provider is openai", + envVars: map[string]string{ + "LLM_PROVIDER": "openai", + }, + expectedType: "*main.OpenAIClient", + }, + { + name: "use ollama client when provider is ollama", + envVars: map[string]string{ + "LLM_PROVIDER": "ollama", + }, + expectedType: "*main.OllamaClient", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Clear existing env vars + os.Unsetenv("LLM_PROVIDER") + os.Unsetenv("OPENAI_API_KEY") + os.Unsetenv("OPENAI_BASE_URL") + os.Unsetenv("OPENAI_MODEL") + + // Set env vars for test + for k, v := range tt.envVars { + os.Setenv(k, v) + } + + client := NewLLMClientFromEnv(nil) + assert.NotNil(t, client) + assert.Equal(t, tt.expectedType, typeName(client)) + }) + } +} + +func typeName(v interface{}) string { + if v == nil { + return "nil" + } + return fmt.Sprintf("%T", v) +} + +func TestOpenAIClient_GetEmbeddings(t *testing.T) { + // Mock server to simulate OpenAI API response + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // No path check needed since we're passing the full URL as baseURL + assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) + + var reqBody map[string]interface{} + err := json.NewDecoder(r.Body).Decode(&reqBody) + assert.NoError(t, err) + assert.Equal(t, "test-model", reqBody["model"]) + assert.Equal(t, "test input", reqBody["input"]) + + // Respond with mock embedding data + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "data": [ + { + "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] + } + ] + }`)) + })) + defer server.Close() + + // Pass the full URL as the baseURL parameter + client := NewOpenAIClient("test-key", server.URL, "test-model", nil) + embeddings, err := client.GetEmbeddings(context.Background(), "test input") + + assert.NoError(t, err) + assert.Equal(t, []float64{0.1, 0.2, 0.3, 0.4, 0.5}, embeddings) +} + +func TestOllamaClient_GetEmbeddings(t *testing.T) { + // Mock server to simulate Ollama API response + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // The API URL for embeddings in ollama_client.go is constructed as: + // apiURL (baseURL) = "http://localhost:11434/api/embeddings" + // So we shouldn't expect a path suffix here + + var reqBody map[string]interface{} + err := json.NewDecoder(r.Body).Decode(&reqBody) + assert.NoError(t, err) + assert.Equal(t, "test-model", reqBody["model"]) + assert.Equal(t, "test input", reqBody["prompt"]) + + // Respond with mock embedding data + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] + }`)) + })) + defer server.Close() + + // Pass the full URL as the baseURL parameter + client := NewOllamaClient("", server.URL, "test-model", nil) + embeddings, err := client.GetEmbeddings(context.Background(), "test input") + + assert.NoError(t, err) + assert.Equal(t, []float64{0.1, 0.2, 0.3, 0.4, 0.5}, embeddings) +} diff --git a/openrouter_integration_test.go b/openrouter_integration_test.go index 1698893..085b780 100644 --- a/openrouter_integration_test.go +++ b/openrouter_integration_test.go @@ -18,16 +18,9 @@ func TestLLMClient_OpenRouterStyle_ExtractKeywords(t *testing.T) { appConfig.LLM.ExtractKeywordsPrompt = "Dummy {{.Message}}" // simple template ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/v1/chat/completions" { - w.WriteHeader(http.StatusNotFound) - return - } - // Optionally verify header presence - if got := r.Header.Get("Authorization"); got == "" { - w.WriteHeader(http.StatusUnauthorized) - return - } + // Format the response exactly as the OpenAI API would w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) resp := map[string]interface{}{ "choices": []map[string]interface{}{ { @@ -35,14 +28,25 @@ func TestLLMClient_OpenRouterStyle_ExtractKeywords(t *testing.T) { "role": "assistant", "content": `{"translate":"dog has diarrhea","keyword":["diarrhea","digestive"],"animal":"dog"}`, }, + "index": 0, }, }, + "id": "test-id", + "object": "chat.completion", + "created": 1717585613, + "model": "meta-llama/test", + "usage": map[string]interface{}{ + "prompt_tokens": 50, + "completion_tokens": 20, + "total_tokens": 70, + }, } json.NewEncoder(w).Encode(resp) })) defer ts.Close() - llm := NewLLMClient("test-key", ts.URL+"/v1/chat/completions", "meta-llama/test", nil) + // Pass the server URL directly (not adding /v1 as that causes issues) + llm := NewOpenAIClient("test-key", ts.URL, "meta-llama/test", nil) res, err := llm.ExtractKeywords(context.Background(), "kutya hasmenés") if err != nil { te(t, "unexpected error: %v", err) @@ -66,18 +70,22 @@ func TestLLMClient_OpenRouterStyle_Error(t *testing.T) { appConfig.LLM.ExtractKeywordsPrompt = "Dummy {{.Message}}" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simulate a rate limit error response from OpenAI API w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) json.NewEncoder(w).Encode(map[string]interface{}{ "error": map[string]interface{}{ - "message": "Rate limit", - "type": "rate_limit", + "message": "Rate limit exceeded, please try again in 20ms", + "type": "rate_limit_exceeded", + "param": nil, + "code": "rate_limit_exceeded", }, }) })) defer ts.Close() - llm := NewLLMClient("test-key", ts.URL+"/v1/chat/completions", "meta-llama/test", nil) + // Use the same URL structure as the success test + llm := NewOpenAIClient("test-key", ts.URL, "meta-llama/test", nil) _, err := llm.ExtractKeywords(context.Background(), "test") if err == nil || !contains(err.Error(), "Rate limit") { te(t, "expected rate limit error, got: %v", err)