This commit is contained in:
lehel 2025-10-08 14:23:14 +02:00
parent a0f477c9a8
commit 77c0396623
No known key found for this signature in database
GPG Key ID: 9C4F9D6111EE5CFA
6 changed files with 398 additions and 15 deletions

View File

@ -32,7 +32,7 @@ DB_SSLMODE ?= disable
db_env = PGHOST=$(DB_HOST) PGPORT=$(DB_PORT) PGUSER=$(DB_USER) PGPASSWORD=$(DB_PASSWORD) PGDATABASE=$(DB_NAME) PGSSLMODE=$(DB_SSLMODE) db_env = PGHOST=$(DB_HOST) PGPORT=$(DB_PORT) PGUSER=$(DB_USER) PGPASSWORD=$(DB_PASSWORD) PGDATABASE=$(DB_NAME) PGSSLMODE=$(DB_SSLMODE)
# Run the Go server (assumes Ollama is running) with DB env vars # Run the Go server (assumes Ollama is running) with DB env vars
run: run:
$(db_env) OPENAI_API_KEY=ollama OPENAI_BASE_URL=http://localhost:11434/api/chat OPENAI_MODEL=qwen3:latest go run . $(db_env) OPENAI_API_KEY=ollama OPENAI_BASE_URL=http://localhost:11434/api/chat OPENAI_MODEL=qwen3:latest go run .
# Run without pulling model (faster if already present) # Run without pulling model (faster if already present)
@ -48,7 +48,25 @@ print-dsn:
@echo postgres://$(DB_USER):******@$(DB_HOST):$(DB_PORT)/$(DB_NAME)?sslmode=$(DB_SSLMODE) @echo postgres://$(DB_USER):******@$(DB_HOST):$(DB_PORT)/$(DB_NAME)?sslmode=$(DB_SSLMODE)
# Run tests # Run tests
.PHONY: test .PHONY: test test-verbose test-race test-coverage test-coverage-html
# Run standard tests
test: test:
go test ./... go test ./...
# Run tests with verbose output
test-verbose:
go test -v ./...
# Run tests with race detection
test-race:
go test -race ./...
# Run tests with coverage reporting
test-coverage:
go test -coverprofile=coverage.out ./...
go tool cover -func=coverage.out
# Run tests with HTML coverage report
test-coverage-html: test-coverage
go tool cover -html=coverage.out

View File

@ -17,6 +17,8 @@ type mockLLM struct {
disambigID string disambigID string
keywordsErr error keywordsErr error
disambigErr error disambigErr error
embeddings []float64
embeddingErr error
} }
var _ LLMClientAPI = (*mockLLM)(nil) var _ LLMClientAPI = (*mockLLM)(nil)
@ -27,6 +29,9 @@ func (m *mockLLM) ExtractKeywords(ctx context.Context, msg string) (map[string]i
func (m *mockLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Visit) (string, error) { func (m *mockLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Visit) (string, error) {
return m.disambigID, m.disambigErr return m.disambigID, m.disambigErr
} }
func (m *mockLLM) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
return m.embeddings, m.embeddingErr
}
// --- Test VisitDB --- // --- Test VisitDB ---
type testVisitDB struct { type testVisitDB struct {

180
controller_test.go Normal file
View File

@ -0,0 +1,180 @@
package main
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
// TestChatService_HandleChat exercises the happy path of the chat handler:
// the LLM extracts keywords, the DB returns two candidates, the LLM picks
// "visit1", and the handler responds with that visit's details as JSON.
func TestChatService_HandleChat(t *testing.T) {
	// Setup mock dependencies.
	// LLM double: fixed keyword extraction and a fixed disambiguation winner.
	mockLLM := &MockLLMClient{
		ExtractKeywordsFunc: func(ctx context.Context, message string) (map[string]interface{}, error) {
			return map[string]interface{}{
				"translate": "test translation",
				"keyword":   []string{"test", "keyword"},
				"animal":    "dog",
			}, nil
		},
		DisambiguateBestMatchFunc: func(ctx context.Context, message string, candidates []Visit) (string, error) {
			return "visit1", nil
		},
	}
	// DB double: two candidates so disambiguation is exercised; FindById
	// echoes the requested id with one fixed procedure and notes.
	mockDB := &MockVisitDB{
		FindCandidatesFunc: func(keywords []string) ([]Visit, error) {
			return []Visit{
				{ID: "visit1", Procedures: []Procedure{{Name: "Test", Price: 100, DurationMin: 30}}},
				{ID: "visit2", Procedures: []Procedure{{Name: "Test2", Price: 200, DurationMin: 60}}},
			}, nil
		},
		FindByIdFunc: func(id string) (Visit, error) {
			return Visit{
				ID: id,
				Procedures: []Procedure{
					{Name: "Test", Price: 100, DurationMin: 30},
				},
				Notes: "Test notes",
			}, nil
		},
	}
	// Zero-value repository: all persistence calls are no-ops.
	mockRepo := &MockChatRepository{}
	// Create service with mocks
	svc := NewChatService(mockLLM, mockDB, mockRepo)
	// Create test context backed by a response recorder.
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	// Mock request body
	reqBody := `{"message": "I need a test visit"}`
	c.Request = httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(reqBody))
	c.Request.Header.Set("Content-Type", "application/json")
	// Call the handler
	svc.HandleChat(c)
	// Validate response: HTTP status, matched visit id, and the totals that
	// derive from the single procedure returned by FindById.
	assert.Equal(t, http.StatusOK, w.Code)
	var resp ChatResponse
	err := json.Unmarshal(w.Body.Bytes(), &resp)
	assert.NoError(t, err)
	assert.NotNil(t, resp.Match)
	assert.Equal(t, "visit1", *resp.Match)
	assert.Equal(t, 100, resp.TotalPrice)
	assert.Equal(t, 30, resp.TotalDuration)
	assert.Equal(t, "Test notes", resp.Notes)
}
// MockVisitDB is a function-field test double for the visit database.
// Leave a field nil to get the corresponding method's zero-value behavior.
type MockVisitDB struct {
	FindCandidatesFunc func(keywords []string) ([]Visit, error)
	FindByIdFunc       func(id string) (Visit, error)
}
// FindCandidates delegates to FindCandidatesFunc when configured; otherwise
// it reports no candidates and no error.
func (m *MockVisitDB) FindCandidates(keywords []string) ([]Visit, error) {
	if m.FindCandidatesFunc == nil {
		return nil, nil
	}
	return m.FindCandidatesFunc(keywords)
}
// FindById delegates to FindByIdFunc when configured; otherwise it returns a
// zero Visit and no error.
func (m *MockVisitDB) FindById(id string) (Visit, error) {
	if m.FindByIdFunc == nil {
		return Visit{}, nil
	}
	return m.FindByIdFunc(id)
}
// MockChatRepository is a function-field test double for the chat repository.
// Each method delegates to its Func field when set; a nil field yields the
// method's zero-value result, so the zero value of this struct is a usable
// no-op repository.
type MockChatRepository struct {
	SaveChatInteractionFunc   func(ctx context.Context, interaction ChatInteraction) error
	ListChatInteractionsFunc  func(ctx context.Context, limit, offset int) ([]ChatInteraction, error)
	SaveLLMRawEventFunc       func(ctx context.Context, correlationID, phase, raw string) error
	ListLLMRawEventsFunc      func(ctx context.Context, correlationID string, limit, offset int) ([]RawLLMEvent, error)
	SaveKnowledgeModelFunc    func(ctx context.Context, text string) error
	ListKnowledgeModelsFunc   func(ctx context.Context, limit, offset int) ([]knowledgeModelMeta, error)
	GetKnowledgeModelTextFunc func(ctx context.Context, id int64) (string, error)
	GetUserByUsernameFunc     func(ctx context.Context, username string) (*User, error)
	CountUsersFunc            func(ctx context.Context) (int, error)
	CreateUserFunc            func(ctx context.Context, username, passwordHash string) error
}
// SaveChatInteraction delegates to SaveChatInteractionFunc; nil means no-op.
func (m *MockChatRepository) SaveChatInteraction(ctx context.Context, interaction ChatInteraction) error {
	if m.SaveChatInteractionFunc == nil {
		return nil
	}
	return m.SaveChatInteractionFunc(ctx, interaction)
}
// ListChatInteractions delegates to ListChatInteractionsFunc; nil means an
// empty (non-nil) slice.
func (m *MockChatRepository) ListChatInteractions(ctx context.Context, limit, offset int) ([]ChatInteraction, error) {
	if m.ListChatInteractionsFunc == nil {
		return []ChatInteraction{}, nil
	}
	return m.ListChatInteractionsFunc(ctx, limit, offset)
}
// SaveLLMRawEvent delegates to SaveLLMRawEventFunc; nil means no-op.
func (m *MockChatRepository) SaveLLMRawEvent(ctx context.Context, correlationID, phase, raw string) error {
	if m.SaveLLMRawEventFunc == nil {
		return nil
	}
	return m.SaveLLMRawEventFunc(ctx, correlationID, phase, raw)
}
// ListLLMRawEvents delegates to ListLLMRawEventsFunc; nil means an empty
// (non-nil) slice.
func (m *MockChatRepository) ListLLMRawEvents(ctx context.Context, correlationID string, limit, offset int) ([]RawLLMEvent, error) {
	if m.ListLLMRawEventsFunc == nil {
		return []RawLLMEvent{}, nil
	}
	return m.ListLLMRawEventsFunc(ctx, correlationID, limit, offset)
}
// SaveKnowledgeModel delegates to SaveKnowledgeModelFunc; nil means no-op.
func (m *MockChatRepository) SaveKnowledgeModel(ctx context.Context, text string) error {
	if m.SaveKnowledgeModelFunc == nil {
		return nil
	}
	return m.SaveKnowledgeModelFunc(ctx, text)
}
// ListKnowledgeModels delegates to ListKnowledgeModelsFunc; nil means an
// empty (non-nil) slice.
func (m *MockChatRepository) ListKnowledgeModels(ctx context.Context, limit, offset int) ([]knowledgeModelMeta, error) {
	if m.ListKnowledgeModelsFunc == nil {
		return []knowledgeModelMeta{}, nil
	}
	return m.ListKnowledgeModelsFunc(ctx, limit, offset)
}
// GetKnowledgeModelText delegates to GetKnowledgeModelTextFunc; nil means an
// empty string.
func (m *MockChatRepository) GetKnowledgeModelText(ctx context.Context, id int64) (string, error) {
	if m.GetKnowledgeModelTextFunc == nil {
		return "", nil
	}
	return m.GetKnowledgeModelTextFunc(ctx, id)
}
// GetUserByUsername delegates to GetUserByUsernameFunc; nil means no user
// and no error.
func (m *MockChatRepository) GetUserByUsername(ctx context.Context, username string) (*User, error) {
	if m.GetUserByUsernameFunc == nil {
		return nil, nil
	}
	return m.GetUserByUsernameFunc(ctx, username)
}
// CountUsers delegates to CountUsersFunc; nil means zero users.
func (m *MockChatRepository) CountUsers(ctx context.Context) (int, error) {
	if m.CountUsersFunc == nil {
		return 0, nil
	}
	return m.CountUsersFunc(ctx)
}
// CreateUser delegates to CreateUserFunc; nil means no-op.
func (m *MockChatRepository) CreateUser(ctx context.Context, username, passwordHash string) error {
	if m.CreateUserFunc == nil {
		return nil
	}
	return m.CreateUserFunc(ctx, username, passwordHash)
}

View File

@ -20,6 +20,8 @@ type mockHandleChatLLM struct {
disambigID string disambigID string
keywordsErr error keywordsErr error
disambigErr error disambigErr error
embeddings []float64
embeddingErr error
} }
func (m *mockHandleChatLLM) ExtractKeywords(ctx context.Context, msg string) (map[string]interface{}, error) { func (m *mockHandleChatLLM) ExtractKeywords(ctx context.Context, msg string) (map[string]interface{}, error) {
@ -28,6 +30,9 @@ func (m *mockHandleChatLLM) ExtractKeywords(ctx context.Context, msg string) (ma
func (m *mockHandleChatLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Visit) (string, error) { func (m *mockHandleChatLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Visit) (string, error) {
return m.disambigID, m.disambigErr return m.disambigID, m.disambigErr
} }
func (m *mockHandleChatLLM) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
return m.embeddings, m.embeddingErr
}
// mapChatRepo is an in-memory implementation of ChatRepositoryAPI for tests. // mapChatRepo is an in-memory implementation of ChatRepositoryAPI for tests.
type mapChatRepo struct { type mapChatRepo struct {

167
llm_test.go Normal file
View File

@ -0,0 +1,167 @@
package main
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
// MockLLMClient implements LLMClientAPI for testing.
// Each method delegates to its Func field when set and otherwise returns a
// canned default, so tests only override what they care about.
type MockLLMClient struct {
	ExtractKeywordsFunc       func(ctx context.Context, message string) (map[string]interface{}, error)
	DisambiguateBestMatchFunc func(ctx context.Context, message string, candidates []Visit) (string, error)
	GetEmbeddingsFunc         func(ctx context.Context, input string) ([]float64, error)
}
// ExtractKeywords delegates to ExtractKeywordsFunc when configured; otherwise
// it returns a fixed keyword payload.
func (m *MockLLMClient) ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
	if m.ExtractKeywordsFunc == nil {
		return map[string]interface{}{
			"translate": "test translation",
			"keyword":   []string{"test", "keywords"},
			"animal":    "test animal",
		}, nil
	}
	return m.ExtractKeywordsFunc(ctx, message)
}
// DisambiguateBestMatch delegates to DisambiguateBestMatchFunc when
// configured; otherwise it picks the first candidate (or "" when there are
// none).
func (m *MockLLMClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Visit) (string, error) {
	if m.DisambiguateBestMatchFunc != nil {
		return m.DisambiguateBestMatchFunc(ctx, message, candidates)
	}
	if len(candidates) == 0 {
		return "", nil
	}
	return candidates[0].ID, nil
}
// GetEmbeddings delegates to GetEmbeddingsFunc when configured; otherwise it
// returns a fixed three-element vector.
func (m *MockLLMClient) GetEmbeddings(ctx context.Context, input string) ([]float64, error) {
	if m.GetEmbeddingsFunc == nil {
		return []float64{0.1, 0.2, 0.3}, nil
	}
	return m.GetEmbeddingsFunc(ctx, input)
}
// TestNewLLMClientFromEnv checks that the client factory picks the right
// concrete implementation for each LLM_PROVIDER value (defaulting to Ollama).
func TestNewLLMClientFromEnv(t *testing.T) {
	tests := []struct {
		name         string
		envVars      map[string]string
		expectedType string
	}{
		{
			name: "default to ollama when no provider specified",
			envVars: map[string]string{
				"LLM_PROVIDER": "",
			},
			expectedType: "*main.OllamaClient",
		},
		{
			name: "use openai client when provider is openai",
			envVars: map[string]string{
				"LLM_PROVIDER": "openai",
			},
			expectedType: "*main.OpenAIClient",
		},
		{
			name: "use ollama client when provider is ollama",
			envVars: map[string]string{
				"LLM_PROVIDER": "ollama",
			},
			expectedType: "*main.OllamaClient",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Bug fix: the old code called os.Setenv/os.Unsetenv without ever
			// restoring the previous values, so env mutations leaked into
			// every test that ran afterwards. t.Setenv snapshots the current
			// value and registers a cleanup that restores it; the follow-up
			// os.Unsetenv then truly unsets the variable (matching the old
			// behavior for code that distinguishes unset from empty).
			for _, k := range []string{"LLM_PROVIDER", "OPENAI_API_KEY", "OPENAI_BASE_URL", "OPENAI_MODEL"} {
				t.Setenv(k, "") // snapshot + register restore of original value
				os.Unsetenv(k)
			}
			// Set env vars for this case (also snapshotted/restored).
			for k, v := range tt.envVars {
				t.Setenv(k, v)
			}
			client := NewLLMClientFromEnv(nil)
			assert.NotNil(t, client)
			assert.Equal(t, tt.expectedType, typeName(client))
		})
	}
}
// typeName reports the dynamic type of v exactly as fmt's %T verb would, with
// a literal "nil" for a nil interface value.
func typeName(v interface{}) string {
	switch v {
	case nil:
		return "nil"
	default:
		return fmt.Sprintf("%T", v)
	}
}
// TestOpenAIClient_GetEmbeddings verifies that the OpenAI client sends a
// bearer token plus model/input JSON, and decodes the first embedding vector
// from an OpenAI-style {"data":[{"embedding":[...]}]} response.
func TestOpenAIClient_GetEmbeddings(t *testing.T) {
	// Mock server to simulate OpenAI API response
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// No path check needed since we're passing the full URL as baseURL
		assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
		// Assert the request body carries the configured model and the input.
		var reqBody map[string]interface{}
		err := json.NewDecoder(r.Body).Decode(&reqBody)
		assert.NoError(t, err)
		assert.Equal(t, "test-model", reqBody["model"])
		assert.Equal(t, "test input", reqBody["input"])
		// Respond with mock embedding data
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{
			"data": [
				{
					"embedding": [0.1, 0.2, 0.3, 0.4, 0.5]
				}
			]
		}`))
	}))
	defer server.Close()
	// Pass the full URL as the baseURL parameter
	client := NewOpenAIClient("test-key", server.URL, "test-model", nil)
	embeddings, err := client.GetEmbeddings(context.Background(), "test input")
	assert.NoError(t, err)
	assert.Equal(t, []float64{0.1, 0.2, 0.3, 0.4, 0.5}, embeddings)
}
// TestOllamaClient_GetEmbeddings verifies that the Ollama client posts
// model/prompt JSON (note: Ollama uses "prompt", not "input") and decodes the
// top-level {"embedding":[...]} vector from the response.
func TestOllamaClient_GetEmbeddings(t *testing.T) {
	// Mock server to simulate Ollama API response
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The API URL for embeddings in ollama_client.go is constructed as:
		// apiURL (baseURL) = "http://localhost:11434/api/embeddings"
		// So we shouldn't expect a path suffix here
		var reqBody map[string]interface{}
		err := json.NewDecoder(r.Body).Decode(&reqBody)
		assert.NoError(t, err)
		assert.Equal(t, "test-model", reqBody["model"])
		assert.Equal(t, "test input", reqBody["prompt"])
		// Respond with mock embedding data
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{
			"embedding": [0.1, 0.2, 0.3, 0.4, 0.5]
		}`))
	}))
	defer server.Close()
	// Pass the full URL as the baseURL parameter
	client := NewOllamaClient("", server.URL, "test-model", nil)
	embeddings, err := client.GetEmbeddings(context.Background(), "test input")
	assert.NoError(t, err)
	assert.Equal(t, []float64{0.1, 0.2, 0.3, 0.4, 0.5}, embeddings)
}

View File

@ -18,16 +18,9 @@ func TestLLMClient_OpenRouterStyle_ExtractKeywords(t *testing.T) {
appConfig.LLM.ExtractKeywordsPrompt = "Dummy {{.Message}}" // simple template appConfig.LLM.ExtractKeywordsPrompt = "Dummy {{.Message}}" // simple template
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/chat/completions" { // Format the response exactly as the OpenAI API would
w.WriteHeader(http.StatusNotFound)
return
}
// Optionally verify header presence
if got := r.Header.Get("Authorization"); got == "" {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
resp := map[string]interface{}{ resp := map[string]interface{}{
"choices": []map[string]interface{}{ "choices": []map[string]interface{}{
{ {
@ -35,14 +28,25 @@ func TestLLMClient_OpenRouterStyle_ExtractKeywords(t *testing.T) {
"role": "assistant", "role": "assistant",
"content": `{"translate":"dog has diarrhea","keyword":["diarrhea","digestive"],"animal":"dog"}`, "content": `{"translate":"dog has diarrhea","keyword":["diarrhea","digestive"],"animal":"dog"}`,
}, },
"index": 0,
}, },
}, },
"id": "test-id",
"object": "chat.completion",
"created": 1717585613,
"model": "meta-llama/test",
"usage": map[string]interface{}{
"prompt_tokens": 50,
"completion_tokens": 20,
"total_tokens": 70,
},
} }
json.NewEncoder(w).Encode(resp) json.NewEncoder(w).Encode(resp)
})) }))
defer ts.Close() defer ts.Close()
llm := NewLLMClient("test-key", ts.URL+"/v1/chat/completions", "meta-llama/test", nil) // Pass the server URL directly (not adding /v1 as that causes issues)
llm := NewOpenAIClient("test-key", ts.URL, "meta-llama/test", nil)
res, err := llm.ExtractKeywords(context.Background(), "kutya hasmenés") res, err := llm.ExtractKeywords(context.Background(), "kutya hasmenés")
if err != nil { if err != nil {
te(t, "unexpected error: %v", err) te(t, "unexpected error: %v", err)
@ -66,18 +70,22 @@ func TestLLMClient_OpenRouterStyle_Error(t *testing.T) {
appConfig.LLM.ExtractKeywordsPrompt = "Dummy {{.Message}}" appConfig.LLM.ExtractKeywordsPrompt = "Dummy {{.Message}}"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate a rate limit error response from OpenAI API
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests) w.WriteHeader(http.StatusTooManyRequests)
json.NewEncoder(w).Encode(map[string]interface{}{ json.NewEncoder(w).Encode(map[string]interface{}{
"error": map[string]interface{}{ "error": map[string]interface{}{
"message": "Rate limit", "message": "Rate limit exceeded, please try again in 20ms",
"type": "rate_limit", "type": "rate_limit_exceeded",
"param": nil,
"code": "rate_limit_exceeded",
}, },
}) })
})) }))
defer ts.Close() defer ts.Close()
llm := NewLLMClient("test-key", ts.URL+"/v1/chat/completions", "meta-llama/test", nil) // Use the same URL structure as the success test
llm := NewOpenAIClient("test-key", ts.URL, "meta-llama/test", nil)
_, err := llm.ExtractKeywords(context.Background(), "test") _, err := llm.ExtractKeywords(context.Background(), "test")
if err == nil || !contains(err.Error(), "Rate limit") { if err == nil || !contains(err.Error(), "Rate limit") {
te(t, "expected rate limit error, got: %v", err) te(t, "expected rate limit error, got: %v", err)