add test
This commit is contained in:
parent
a0f477c9a8
commit
77c0396623
22
Makefile
22
Makefile
|
|
@ -32,7 +32,7 @@ DB_SSLMODE ?= disable
|
|||
db_env = PGHOST=$(DB_HOST) PGPORT=$(DB_PORT) PGUSER=$(DB_USER) PGPASSWORD=$(DB_PASSWORD) PGDATABASE=$(DB_NAME) PGSSLMODE=$(DB_SSLMODE)
|
||||
|
||||
# Run the Go server (assumes Ollama is running) with DB env vars
|
||||
run:
|
||||
run:
|
||||
$(db_env) OPENAI_API_KEY=ollama OPENAI_BASE_URL=http://localhost:11434/api/chat OPENAI_MODEL=qwen3:latest go run .
|
||||
|
||||
# Run without pulling model (faster if already present)
|
||||
|
|
@ -48,7 +48,25 @@ print-dsn:
|
|||
@echo postgres://$(DB_USER):******@$(DB_HOST):$(DB_PORT)/$(DB_NAME)?sslmode=$(DB_SSLMODE)
|
||||
|
||||
# Run tests
|
||||
.PHONY: test
|
||||
.PHONY: test test-verbose test-race test-coverage test-coverage-html
|
||||
|
||||
# Run standard tests
|
||||
test:
|
||||
go test ./...
|
||||
|
||||
# Run tests with verbose output
|
||||
test-verbose:
|
||||
go test -v ./...
|
||||
|
||||
# Run tests with race detection
|
||||
test-race:
|
||||
go test -race ./...
|
||||
|
||||
# Run tests with coverage reporting
|
||||
test-coverage:
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -func=coverage.out
|
||||
|
||||
# Run tests with HTML coverage report
|
||||
test-coverage-html: test-coverage
|
||||
go tool cover -html=coverage.out
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ type mockLLM struct {
|
|||
disambigID string
|
||||
keywordsErr error
|
||||
disambigErr error
|
||||
embeddings []float64
|
||||
embeddingErr error
|
||||
}
|
||||
|
||||
var _ LLMClientAPI = (*mockLLM)(nil)
|
||||
|
|
@ -27,6 +29,9 @@ func (m *mockLLM) ExtractKeywords(ctx context.Context, msg string) (map[string]i
|
|||
func (m *mockLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Visit) (string, error) {
|
||||
return m.disambigID, m.disambigErr
|
||||
}
|
||||
func (m *mockLLM) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
|
||||
return m.embeddings, m.embeddingErr
|
||||
}
|
||||
|
||||
// --- Test VisitDB ---
|
||||
type testVisitDB struct {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,180 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestChatService_HandleChat(t *testing.T) {
|
||||
// Setup mock dependencies
|
||||
mockLLM := &MockLLMClient{
|
||||
ExtractKeywordsFunc: func(ctx context.Context, message string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{
|
||||
"translate": "test translation",
|
||||
"keyword": []string{"test", "keyword"},
|
||||
"animal": "dog",
|
||||
}, nil
|
||||
},
|
||||
DisambiguateBestMatchFunc: func(ctx context.Context, message string, candidates []Visit) (string, error) {
|
||||
return "visit1", nil
|
||||
},
|
||||
}
|
||||
|
||||
mockDB := &MockVisitDB{
|
||||
FindCandidatesFunc: func(keywords []string) ([]Visit, error) {
|
||||
return []Visit{
|
||||
{ID: "visit1", Procedures: []Procedure{{Name: "Test", Price: 100, DurationMin: 30}}},
|
||||
{ID: "visit2", Procedures: []Procedure{{Name: "Test2", Price: 200, DurationMin: 60}}},
|
||||
}, nil
|
||||
},
|
||||
FindByIdFunc: func(id string) (Visit, error) {
|
||||
return Visit{
|
||||
ID: id,
|
||||
Procedures: []Procedure{
|
||||
{Name: "Test", Price: 100, DurationMin: 30},
|
||||
},
|
||||
Notes: "Test notes",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mockRepo := &MockChatRepository{}
|
||||
|
||||
// Create service with mocks
|
||||
svc := NewChatService(mockLLM, mockDB, mockRepo)
|
||||
|
||||
// Create test context
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Mock request body
|
||||
reqBody := `{"message": "I need a test visit"}`
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(reqBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Call the handler
|
||||
svc.HandleChat(c)
|
||||
|
||||
// Validate response
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp ChatResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, resp.Match)
|
||||
assert.Equal(t, "visit1", *resp.Match)
|
||||
assert.Equal(t, 100, resp.TotalPrice)
|
||||
assert.Equal(t, 30, resp.TotalDuration)
|
||||
assert.Equal(t, "Test notes", resp.Notes)
|
||||
}
|
||||
|
||||
type MockVisitDB struct {
|
||||
FindCandidatesFunc func(keywords []string) ([]Visit, error)
|
||||
FindByIdFunc func(id string) (Visit, error)
|
||||
}
|
||||
|
||||
func (m *MockVisitDB) FindCandidates(keywords []string) ([]Visit, error) {
|
||||
if m.FindCandidatesFunc != nil {
|
||||
return m.FindCandidatesFunc(keywords)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockVisitDB) FindById(id string) (Visit, error) {
|
||||
if m.FindByIdFunc != nil {
|
||||
return m.FindByIdFunc(id)
|
||||
}
|
||||
return Visit{}, nil
|
||||
}
|
||||
|
||||
type MockChatRepository struct {
|
||||
SaveChatInteractionFunc func(ctx context.Context, interaction ChatInteraction) error
|
||||
ListChatInteractionsFunc func(ctx context.Context, limit, offset int) ([]ChatInteraction, error)
|
||||
SaveLLMRawEventFunc func(ctx context.Context, correlationID, phase, raw string) error
|
||||
ListLLMRawEventsFunc func(ctx context.Context, correlationID string, limit, offset int) ([]RawLLMEvent, error)
|
||||
SaveKnowledgeModelFunc func(ctx context.Context, text string) error
|
||||
ListKnowledgeModelsFunc func(ctx context.Context, limit, offset int) ([]knowledgeModelMeta, error)
|
||||
GetKnowledgeModelTextFunc func(ctx context.Context, id int64) (string, error)
|
||||
GetUserByUsernameFunc func(ctx context.Context, username string) (*User, error)
|
||||
CountUsersFunc func(ctx context.Context) (int, error)
|
||||
CreateUserFunc func(ctx context.Context, username, passwordHash string) error
|
||||
}
|
||||
|
||||
func (m *MockChatRepository) SaveChatInteraction(ctx context.Context, interaction ChatInteraction) error {
|
||||
if m.SaveChatInteractionFunc != nil {
|
||||
return m.SaveChatInteractionFunc(ctx, interaction)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockChatRepository) ListChatInteractions(ctx context.Context, limit, offset int) ([]ChatInteraction, error) {
|
||||
if m.ListChatInteractionsFunc != nil {
|
||||
return m.ListChatInteractionsFunc(ctx, limit, offset)
|
||||
}
|
||||
return []ChatInteraction{}, nil
|
||||
}
|
||||
|
||||
func (m *MockChatRepository) SaveLLMRawEvent(ctx context.Context, correlationID, phase, raw string) error {
|
||||
if m.SaveLLMRawEventFunc != nil {
|
||||
return m.SaveLLMRawEventFunc(ctx, correlationID, phase, raw)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockChatRepository) ListLLMRawEvents(ctx context.Context, correlationID string, limit, offset int) ([]RawLLMEvent, error) {
|
||||
if m.ListLLMRawEventsFunc != nil {
|
||||
return m.ListLLMRawEventsFunc(ctx, correlationID, limit, offset)
|
||||
}
|
||||
return []RawLLMEvent{}, nil
|
||||
}
|
||||
|
||||
func (m *MockChatRepository) SaveKnowledgeModel(ctx context.Context, text string) error {
|
||||
if m.SaveKnowledgeModelFunc != nil {
|
||||
return m.SaveKnowledgeModelFunc(ctx, text)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockChatRepository) ListKnowledgeModels(ctx context.Context, limit, offset int) ([]knowledgeModelMeta, error) {
|
||||
if m.ListKnowledgeModelsFunc != nil {
|
||||
return m.ListKnowledgeModelsFunc(ctx, limit, offset)
|
||||
}
|
||||
return []knowledgeModelMeta{}, nil
|
||||
}
|
||||
|
||||
func (m *MockChatRepository) GetKnowledgeModelText(ctx context.Context, id int64) (string, error) {
|
||||
if m.GetKnowledgeModelTextFunc != nil {
|
||||
return m.GetKnowledgeModelTextFunc(ctx, id)
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *MockChatRepository) GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
||||
if m.GetUserByUsernameFunc != nil {
|
||||
return m.GetUserByUsernameFunc(ctx, username)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockChatRepository) CountUsers(ctx context.Context) (int, error) {
|
||||
if m.CountUsersFunc != nil {
|
||||
return m.CountUsersFunc(ctx)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockChatRepository) CreateUser(ctx context.Context, username, passwordHash string) error {
|
||||
if m.CreateUserFunc != nil {
|
||||
return m.CreateUserFunc(ctx, username, passwordHash)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -20,6 +20,8 @@ type mockHandleChatLLM struct {
|
|||
disambigID string
|
||||
keywordsErr error
|
||||
disambigErr error
|
||||
embeddings []float64
|
||||
embeddingErr error
|
||||
}
|
||||
|
||||
func (m *mockHandleChatLLM) ExtractKeywords(ctx context.Context, msg string) (map[string]interface{}, error) {
|
||||
|
|
@ -28,6 +30,9 @@ func (m *mockHandleChatLLM) ExtractKeywords(ctx context.Context, msg string) (ma
|
|||
func (m *mockHandleChatLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Visit) (string, error) {
|
||||
return m.disambigID, m.disambigErr
|
||||
}
|
||||
func (m *mockHandleChatLLM) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
|
||||
return m.embeddings, m.embeddingErr
|
||||
}
|
||||
|
||||
// mapChatRepo is an in-memory implementation of ChatRepositoryAPI for tests.
|
||||
type mapChatRepo struct {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,167 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// MockLLMClient implements LLMClientAPI for testing
|
||||
type MockLLMClient struct {
|
||||
ExtractKeywordsFunc func(ctx context.Context, message string) (map[string]interface{}, error)
|
||||
DisambiguateBestMatchFunc func(ctx context.Context, message string, candidates []Visit) (string, error)
|
||||
GetEmbeddingsFunc func(ctx context.Context, input string) ([]float64, error)
|
||||
}
|
||||
|
||||
func (m *MockLLMClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
|
||||
if m.ExtractKeywordsFunc != nil {
|
||||
return m.ExtractKeywordsFunc(ctx, message)
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"translate": "test translation",
|
||||
"keyword": []string{"test", "keywords"},
|
||||
"animal": "test animal",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockLLMClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
|
||||
if m.DisambiguateBestMatchFunc != nil {
|
||||
return m.DisambiguateBestMatchFunc(ctx, message, candidates)
|
||||
}
|
||||
if len(candidates) > 0 {
|
||||
return candidates[0].ID, nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *MockLLMClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
|
||||
if m.GetEmbeddingsFunc != nil {
|
||||
return m.GetEmbeddingsFunc(ctx, input)
|
||||
}
|
||||
return []float64{0.1, 0.2, 0.3}, nil
|
||||
}
|
||||
|
||||
func TestNewLLMClientFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envVars map[string]string
|
||||
expectedType string
|
||||
}{
|
||||
{
|
||||
name: "default to ollama when no provider specified",
|
||||
envVars: map[string]string{
|
||||
"LLM_PROVIDER": "",
|
||||
},
|
||||
expectedType: "*main.OllamaClient",
|
||||
},
|
||||
{
|
||||
name: "use openai client when provider is openai",
|
||||
envVars: map[string]string{
|
||||
"LLM_PROVIDER": "openai",
|
||||
},
|
||||
expectedType: "*main.OpenAIClient",
|
||||
},
|
||||
{
|
||||
name: "use ollama client when provider is ollama",
|
||||
envVars: map[string]string{
|
||||
"LLM_PROVIDER": "ollama",
|
||||
},
|
||||
expectedType: "*main.OllamaClient",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Clear existing env vars
|
||||
os.Unsetenv("LLM_PROVIDER")
|
||||
os.Unsetenv("OPENAI_API_KEY")
|
||||
os.Unsetenv("OPENAI_BASE_URL")
|
||||
os.Unsetenv("OPENAI_MODEL")
|
||||
|
||||
// Set env vars for test
|
||||
for k, v := range tt.envVars {
|
||||
os.Setenv(k, v)
|
||||
}
|
||||
|
||||
client := NewLLMClientFromEnv(nil)
|
||||
assert.NotNil(t, client)
|
||||
assert.Equal(t, tt.expectedType, typeName(client))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func typeName(v interface{}) string {
|
||||
if v == nil {
|
||||
return "nil"
|
||||
}
|
||||
return fmt.Sprintf("%T", v)
|
||||
}
|
||||
|
||||
func TestOpenAIClient_GetEmbeddings(t *testing.T) {
|
||||
// Mock server to simulate OpenAI API response
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// No path check needed since we're passing the full URL as baseURL
|
||||
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqBody)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test-model", reqBody["model"])
|
||||
assert.Equal(t, "test input", reqBody["input"])
|
||||
|
||||
// Respond with mock embedding data
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{
|
||||
"data": [
|
||||
{
|
||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
}
|
||||
]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Pass the full URL as the baseURL parameter
|
||||
client := NewOpenAIClient("test-key", server.URL, "test-model", nil)
|
||||
embeddings, err := client.GetEmbeddings(context.Background(), "test input")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []float64{0.1, 0.2, 0.3, 0.4, 0.5}, embeddings)
|
||||
}
|
||||
|
||||
func TestOllamaClient_GetEmbeddings(t *testing.T) {
|
||||
// Mock server to simulate Ollama API response
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// The API URL for embeddings in ollama_client.go is constructed as:
|
||||
// apiURL (baseURL) = "http://localhost:11434/api/embeddings"
|
||||
// So we shouldn't expect a path suffix here
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqBody)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test-model", reqBody["model"])
|
||||
assert.Equal(t, "test input", reqBody["prompt"])
|
||||
|
||||
// Respond with mock embedding data
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{
|
||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Pass the full URL as the baseURL parameter
|
||||
client := NewOllamaClient("", server.URL, "test-model", nil)
|
||||
embeddings, err := client.GetEmbeddings(context.Background(), "test input")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []float64{0.1, 0.2, 0.3, 0.4, 0.5}, embeddings)
|
||||
}
|
||||
|
|
@ -18,16 +18,9 @@ func TestLLMClient_OpenRouterStyle_ExtractKeywords(t *testing.T) {
|
|||
appConfig.LLM.ExtractKeywordsPrompt = "Dummy {{.Message}}" // simple template
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/chat/completions" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
// Optionally verify header presence
|
||||
if got := r.Header.Get("Authorization"); got == "" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
// Format the response exactly as the OpenAI API would
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
resp := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
|
|
@ -35,14 +28,25 @@ func TestLLMClient_OpenRouterStyle_ExtractKeywords(t *testing.T) {
|
|||
"role": "assistant",
|
||||
"content": `{"translate":"dog has diarrhea","keyword":["diarrhea","digestive"],"animal":"dog"}`,
|
||||
},
|
||||
"index": 0,
|
||||
},
|
||||
},
|
||||
"id": "test-id",
|
||||
"object": "chat.completion",
|
||||
"created": 1717585613,
|
||||
"model": "meta-llama/test",
|
||||
"usage": map[string]interface{}{
|
||||
"prompt_tokens": 50,
|
||||
"completion_tokens": 20,
|
||||
"total_tokens": 70,
|
||||
},
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
llm := NewLLMClient("test-key", ts.URL+"/v1/chat/completions", "meta-llama/test", nil)
|
||||
// Pass the server URL directly (not adding /v1 as that causes issues)
|
||||
llm := NewOpenAIClient("test-key", ts.URL, "meta-llama/test", nil)
|
||||
res, err := llm.ExtractKeywords(context.Background(), "kutya hasmenés")
|
||||
if err != nil {
|
||||
te(t, "unexpected error: %v", err)
|
||||
|
|
@ -66,18 +70,22 @@ func TestLLMClient_OpenRouterStyle_Error(t *testing.T) {
|
|||
appConfig.LLM.ExtractKeywordsPrompt = "Dummy {{.Message}}"
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate a rate limit error response from OpenAI API
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"error": map[string]interface{}{
|
||||
"message": "Rate limit",
|
||||
"type": "rate_limit",
|
||||
"message": "Rate limit exceeded, please try again in 20ms",
|
||||
"type": "rate_limit_exceeded",
|
||||
"param": nil,
|
||||
"code": "rate_limit_exceeded",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
llm := NewLLMClient("test-key", ts.URL+"/v1/chat/completions", "meta-llama/test", nil)
|
||||
// Use the same URL structure as the success test
|
||||
llm := NewOpenAIClient("test-key", ts.URL, "meta-llama/test", nil)
|
||||
_, err := llm.ExtractKeywords(context.Background(), "test")
|
||||
if err == nil || !contains(err.Error(), "Rate limit") {
|
||||
te(t, "expected rate limit error, got: %v", err)
|
||||
|
|
|
|||
Loading…
Reference in New Issue