// vetrag/handlechat_integration_test.go
//
// 181 lines
// 5.4 KiB
// Go

package main
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
)
// mockHandleChatLLM is a canned LLM stand-in for the HandleChat integration
// tests. It ignores every input and hands back whatever response/error pair
// the test configured on the struct, letting each test choose what the chat
// service sees from the LLM.
//
// Only ExtractKeywords and DisambiguateBestMatch are implemented; presumably
// that is the full method set NewChatService needs — confirm against the
// LLM interface if it grows.
type mockHandleChatLLM struct {
	keywordsResp map[string]interface{} // payload returned by ExtractKeywords
	disambigID   string                 // visit ID returned by DisambiguateBestMatch
	keywordsErr  error                  // error returned by ExtractKeywords
	disambigErr  error                  // error returned by DisambiguateBestMatch
}

// ExtractKeywords returns the configured keywordsResp and keywordsErr,
// without looking at ctx or the message.
func (m *mockHandleChatLLM) ExtractKeywords(_ context.Context, _ string) (resp map[string]interface{}, err error) {
	resp, err = m.keywordsResp, m.keywordsErr
	return resp, err
}

// DisambiguateBestMatch returns the configured disambigID and disambigErr,
// without looking at ctx, the message, or the candidate visits.
func (m *mockHandleChatLLM) DisambiguateBestMatch(_ context.Context, _ string, _ []Visit) (id string, err error) {
	id, err = m.disambigID, m.disambigErr
	return id, err
}
// mapChatRepo is an in-memory implementation of ChatRepositoryAPI for tests.
// Every method takes mu, so the handler under test may write while the test
// reads (tests that inspect the fields directly must lock mu themselves).
type mapChatRepo struct {
	mu           sync.Mutex
	interactions []ChatInteraction
	rawEvents    []struct{ CorrelationID, Phase, Raw string }
}

// mapChatRepoPage converts (limit, offset) into a safe [start, end) window
// over a collection of length total. A negative offset is treated as 0 and a
// negative limit as 0 (empty page). The end is computed without evaluating
// offset+limit, so huge limits cannot overflow.
func mapChatRepoPage(total, limit, offset int) (start, end int) {
	if offset < 0 {
		offset = 0
	}
	if offset >= total {
		return total, total
	}
	if limit < 0 {
		limit = 0
	}
	end = total
	if limit < total-offset {
		end = offset + limit
	}
	return offset, end
}

// SaveChatInteraction appends rec to the in-memory interaction log.
func (r *mapChatRepo) SaveChatInteraction(ctx context.Context, rec ChatInteraction) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.interactions = append(r.interactions, rec)
	return nil
}

// ListChatInteractions returns up to limit interactions starting at offset,
// in insertion order. An out-of-range page yields an empty, non-nil slice.
func (r *mapChatRepo) ListChatInteractions(ctx context.Context, limit, offset int) ([]ChatInteraction, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	start, end := mapChatRepoPage(len(r.interactions), limit, offset)
	// Hand back a fresh backing array so callers cannot overwrite stored
	// records; this is a shallow copy, so reference-typed fields are shared.
	out := make([]ChatInteraction, end-start)
	copy(out, r.interactions[start:end])
	return out, nil
}

// SaveLLMRawEvent records one raw LLM payload for the given correlation ID
// and phase.
func (r *mapChatRepo) SaveLLMRawEvent(ctx context.Context, correlationID, phase, raw string) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.rawEvents = append(r.rawEvents, struct{ CorrelationID, Phase, Raw string }{correlationID, phase, raw})
	return nil
}

// ListLLMRawEvents returns the raw events saved for correlationID, in save
// order, paged by limit/offset in the same way as ListChatInteractions.
// It returns nil when nothing matches.
//
// CreatedAt is stamped at list time, not save time: the stored tuple carries
// no timestamp.
func (r *mapChatRepo) ListLLMRawEvents(ctx context.Context, correlationID string, limit, offset int) ([]RawLLMEvent, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	var matched []RawLLMEvent
	for _, e := range r.rawEvents {
		if e.CorrelationID != correlationID {
			continue
		}
		matched = append(matched, RawLLMEvent{CorrelationID: e.CorrelationID, Phase: e.Phase, RawJSON: e.Raw, CreatedAt: time.Now()})
	}
	start, end := mapChatRepoPage(len(matched), limit, offset)
	return matched[start:end], nil
}
// testVisitDB2 is a minimal stand-in for VisitDB used by the integration
// tests, so they do not depend on a real Bleve index.
type testVisitDB2 struct {
	byID       map[string]Visit // lookup table for FindById
	candidates []Visit          // returned as-is by FindCandidates
	findErr    error            // error returned by FindCandidates
}

// FindCandidates ignores the keywords and returns the configured candidates
// together with findErr.
func (db *testVisitDB2) FindCandidates(_ []string) ([]Visit, error) {
	found, err := db.candidates, db.findErr
	return found, err
}

// FindById returns the visit stored under id in byID. On a miss it returns
// the zero Visit and context.DeadlineExceeded.
//
// NOTE(review): DeadlineExceeded is an unusual "not found" signal; a
// dedicated not-found error would describe the failure more honestly, if no
// caller depends on this exact value.
func (db *testVisitDB2) FindById(id string) (Visit, error) {
	v, found := db.byID[id]
	if !found {
		return Visit{}, context.DeadlineExceeded
	}
	return v, nil
}
// TestHandleChat_PersistsSuccessInteraction sends a successful /chat request
// through the gin router. It then checks that exactly one ChatInteraction was
// saved, and that the saved record carries:
//   - the X-Correlation-ID response header,
//   - the disambiguated visit ID,
//   - that visit's price/duration totals,
//   - the two extracted keywords.
func TestHandleChat_PersistsSuccessInteraction(t *testing.T) {
	gin.SetMode(gin.TestMode)
	visit := Visit{ID: "xray", Notes: "Exam note", Procedures: []Procedure{{Name: "Röntgen vizsgálat", Price: 16000, DurationMin: 25}}}
	db := &testVisitDB2{byID: map[string]Visit{"xray": visit}, candidates: []Visit{visit}}
	llm := &mockHandleChatLLM{
		keywordsResp: map[string]interface{}{"translate": "xray leg", "animal": "dog", "keyword": []string{"xray", "bone"}},
		disambigID:   "xray",
	}
	repo := &mapChatRepo{}
	cs := NewChatService(llm, db, repo)
	r := gin.New()
	r.POST("/chat", cs.HandleChat)

	b, err := json.Marshal(map[string]string{"message": "my dog needs an x-ray"})
	if err != nil {
		t.Fatalf("marshal request body: %v", err)
	}
	req, err := http.NewRequest(http.MethodPost, "/chat", bytes.NewReader(b))
	if err != nil {
		t.Fatalf("build request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200 got %d", w.Code)
	}
	corrID := w.Header().Get("X-Correlation-ID")
	if corrID == "" {
		t.Fatal("expected correlation id header set")
	}

	// Snapshot everything we need while holding the lock, so no repo field
	// is read after Unlock (that would race with a concurrent writer).
	repo.mu.Lock()
	n := len(repo.interactions)
	var rec ChatInteraction
	if n > 0 {
		rec = repo.interactions[0]
	}
	repo.mu.Unlock()
	if n != 1 {
		t.Fatalf("expected 1 interaction persisted, got %d; body=%s", n, w.Body.String())
	}

	if rec.CorrelationID != corrID {
		t.Errorf("correlation mismatch: header=%s rec=%s", corrID, rec.CorrelationID)
	}
	if rec.BestVisitID != "xray" {
		t.Errorf("expected BestVisitID xray got %s", rec.BestVisitID)
	}
	if rec.TotalPrice != 16000 || rec.TotalDuration != 25 {
		t.Errorf("unexpected totals: %+v", rec)
	}
	if len(rec.Keywords) != 2 {
		t.Errorf("expected 2 keywords got %v", rec.Keywords)
	}
}
// TestHandleChat_PersistsOnLLMError makes keyword extraction fail and checks
// three things:
//   - the handler still answers 200,
//   - it still persists exactly one interaction, with no best visit,
//   - it still sets the X-Correlation-ID response header.
func TestHandleChat_PersistsOnLLMError(t *testing.T) {
	gin.SetMode(gin.TestMode)
	llm := &mockHandleChatLLM{keywordsErr: context.DeadlineExceeded}
	db := &testVisitDB2{byID: map[string]Visit{}, candidates: []Visit{}}
	repo := &mapChatRepo{}
	cs := NewChatService(llm, db, repo)
	r := gin.New()
	r.POST("/chat", cs.HandleChat)

	b, err := json.Marshal(map[string]string{"message": "some message"})
	if err != nil {
		t.Fatalf("marshal request body: %v", err)
	}
	req, err := http.NewRequest(http.MethodPost, "/chat", bytes.NewReader(b))
	if err != nil {
		t.Fatalf("build request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200 got %d", w.Code)
	}

	// Copy the count and first record under the lock; assert afterwards.
	repo.mu.Lock()
	cnt := len(repo.interactions)
	var rec ChatInteraction
	if cnt == 1 {
		rec = repo.interactions[0]
	}
	repo.mu.Unlock()
	if cnt != 1 {
		t.Fatalf("expected 1 interaction persisted on error got %d", cnt)
	}
	if rec.BestVisitID != "" {
		t.Errorf("expected no best visit on error got %s", rec.BestVisitID)
	}

	if cid := w.Header().Get("X-Correlation-ID"); cid == "" {
		t.Fatal("expected correlation id header on error path")
	}
}