package main import ( "bytes" "context" "encoding/json" "net/http" "net/http/httptest" "sync" "testing" "time" "github.com/gin-gonic/gin" ) // mockHandleChatLLM mocks LLM behavior for integration tests // It implements only the public interface methods. type mockHandleChatLLM struct { keywordsResp map[string]interface{} disambigID string keywordsErr error disambigErr error embeddings []float64 embeddingErr error } func (m *mockHandleChatLLM) ExtractKeywords(ctx context.Context, msg string) (map[string]interface{}, error) { return m.keywordsResp, m.keywordsErr } func (m *mockHandleChatLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Visit) (string, error) { return m.disambigID, m.disambigErr } func (m *mockHandleChatLLM) GetEmbeddings(ctx context.Context, input string) ([]float64, error) { return m.embeddings, m.embeddingErr } func (m *mockHandleChatLLM) TranslateToEnglish(ctx context.Context, msg string) (string, error) { return msg, nil } // mapChatRepo is an in-memory implementation of ChatRepositoryAPI for tests. type mapChatRepo struct { mu sync.Mutex interactions []ChatInteraction rawEvents []struct{ CorrelationID, Phase, Raw string } } func (r *mapChatRepo) SaveChatInteraction(ctx context.Context, rec ChatInteraction) error { r.mu.Lock() defer r.mu.Unlock() r.interactions = append(r.interactions, rec) return nil } func (r *mapChatRepo) ListChatInteractions(ctx context.Context, limit, offset int) ([]ChatInteraction, error) { r.mu.Lock() defer r.mu.Unlock() if offset >= len(r.interactions) { return []ChatInteraction{}, nil } end := offset + limit if end > len(r.interactions) { end = len(r.interactions) } // return a copy slice to avoid mutation out := make([]ChatInteraction, end-offset) copy(out, r.interactions[offset:end]) return out, nil } func (r *mapChatRepo) SaveLLMRawEvent(ctx context.Context, correlationID, phase, raw string) error { r.mu.Lock() defer r.mu.Unlock() r.rawEvents = append(r.rawEvents, struct{ CorrelationID, Phase, Raw string }{correlationID, phase, raw}) return nil } func (r *mapChatRepo) ListLLMRawEvents(ctx context.Context, correlationID string, limit, offset int) ([]RawLLMEvent, error) { r.mu.Lock() defer r.mu.Unlock() var out []RawLLMEvent for _, e := range r.rawEvents { if e.CorrelationID == correlationID { out = append(out, RawLLMEvent{CorrelationID: e.CorrelationID, Phase: e.Phase, RawJSON: e.Raw, CreatedAt: time.Now()}) } } return out, nil } func (r *mapChatRepo) CountUsers(ctx context.Context) (int, error) { return 0, nil } func (r *mapChatRepo) CreateUser(ctx context.Context, username, passwordHash string) error { return nil } func (r *mapChatRepo) GetUserByUsername(ctx context.Context, username string) (*User, error) { return nil, nil } func (r *mapChatRepo) SaveKnowledgeModel(ctx context.Context, text string) error { return nil } func (r *mapChatRepo) ListKnowledgeModels(ctx context.Context, limit, offset int) ([]knowledgeModelMeta, error) { return nil, nil } func (r *mapChatRepo) GetKnowledgeModelText(ctx context.Context, id int64) (string, error) { return "", nil } // testVisitDB2 replicates a minimal VisitDB for integration // (avoids relying on real Bleve index) type testVisitDB2 struct { byID map[string]Visit candidates []Visit findErr error } func (db *testVisitDB2) FindCandidates(keywords []string) ([]Visit, error) { return db.candidates, db.findErr } func (db *testVisitDB2) FindById(id string) (Visit, error) { if v, ok := db.byID[id]; ok { return v, nil } return Visit{}, context.DeadlineExceeded } func TestHandleChat_PersistsSuccessInteraction(t *testing.T) { gin.SetMode(gin.TestMode) visit := Visit{ID: "xray", Notes: "Exam note", Procedures: []Procedure{{Name: "Röntgen vizsgálat", Price: 16000, DurationMin: 25}}} db := &testVisitDB2{byID: map[string]Visit{"xray": visit}, candidates: []Visit{visit}} llm := &mockHandleChatLLM{keywordsResp: map[string]interface{}{"translate": "xray leg", "animal": "dog", "keyword": []string{"xray", "bone"}}, disambigID: "xray"} repo := &mapChatRepo{} cs := NewChatService(llm, db, repo) r := gin.New() r.POST("/chat", cs.HandleChat) body := map[string]string{"message": "my dog needs an x-ray"} b, _ := json.Marshal(body) req, _ := http.NewRequest(http.MethodPost, "/chat", bytes.NewReader(b)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() r.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200 got %d", w.Code) } corrID := w.Header().Get("X-Correlation-ID") if corrID == "" { t.Fatalf("expected correlation id header set") } repo.mu.Lock() if len(repo.interactions) != 1 { repo.mu.Unlock() bdy := w.Body.String() t.Fatalf("expected 1 interaction persisted, got %d; body=%s", len(repo.interactions), bdy) } rec := repo.interactions[0] repo.mu.Unlock() if rec.CorrelationID != corrID { t.Errorf("correlation mismatch: header=%s rec=%s", corrID, rec.CorrelationID) } if rec.BestVisitID != "xray" { t.Errorf("expected BestVisitID xray got %s", rec.BestVisitID) } if rec.TotalPrice != 16000 || rec.TotalDuration != 25 { t.Errorf("unexpected totals: %+v", rec) } if len(rec.Keywords) != 2 { t.Errorf("expected 2 keywords got %v", rec.Keywords) } } func TestHandleChat_PersistsOnLLMError(t *testing.T) { gin.SetMode(gin.TestMode) llm := &mockHandleChatLLM{keywordsErr: context.DeadlineExceeded} db := &testVisitDB2{byID: map[string]Visit{}, candidates: []Visit{}} repo := &mapChatRepo{} cs := NewChatService(llm, db, repo) r := gin.New() r.POST("/chat", cs.HandleChat) body := map[string]string{"message": "some message"} b, _ := json.Marshal(body) req, _ := http.NewRequest(http.MethodPost, "/chat", bytes.NewReader(b)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() r.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200 got %d", w.Code) } repo.mu.Lock() cnt := len(repo.interactions) var rec ChatInteraction if cnt == 1 { rec = repo.interactions[0] } repo.mu.Unlock() if cnt != 1 { t.Fatalf("expected 1 interaction persisted on error got %d", cnt) } if rec.BestVisitID != "" { t.Errorf("expected no best visit on error got %s", rec.BestVisitID) } cid := w.Header().Get("X-Correlation-ID") if cid == "" { t.Fatalf("expected correlation id header on error path") } }