package main import ( "context" "encoding/json" "io" "net/http" "net/http/httptest" "testing" "github.com/gin-gonic/gin" ) // --- Mocks --- type mockLLM struct { keywordsResp map[string]interface{} disambigID string keywordsErr error disambigErr error } var _ LLMClientAPI = (*mockLLM)(nil) func (m *mockLLM) ExtractKeywords(ctx context.Context, msg string) (map[string]interface{}, error) { return m.keywordsResp, m.keywordsErr } func (m *mockLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Visit) (string, error) { return m.disambigID, m.disambigErr } // --- Test VisitDB --- type testVisitDB struct { candidates []Visit findErr error byID map[string]Visit } var _ VisitDBAPI = (*testVisitDB)(nil) func (db *testVisitDB) FindCandidates(keywords []string) ([]Visit, error) { return db.candidates, db.findErr } func (db *testVisitDB) FindById(id string) (Visit, error) { r, ok := db.byID[id] if !ok { return Visit{}, context.DeadlineExceeded } return r, nil } // --- Integration tests --- func TestChatService_MatchFound(t *testing.T) { gin.SetMode(gin.TestMode) var llm LLMClientAPI = &mockLLM{ keywordsResp: map[string]interface{}{"keyword": []string{"worms", "deworming"}}, disambigID: "deworming", } visit := Visit{ ID: "deworming", Procedures: []Procedure{{Name: "Deworming tablet", Price: 30, DurationMin: 10}}, Notes: "Bloodwork ensures organs are safe for treatment.", } var db VisitDBAPI = &testVisitDB{ candidates: []Visit{visit}, byID: map[string]Visit{"deworming": visit}, } var cs ChatServiceAPI = NewChatService(llm, db) r := gin.New() r.POST("/chat", cs.HandleChat) w := httptest.NewRecorder() body := map[string]string{"message": "My dog needs deworming"} jsonBody, _ := json.Marshal(body) req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody)) req.Header.Set("Content-Type", "application/json") r.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("Expected 200, got %d", w.Code) } var resp ChatResponse if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatalf("Invalid response: %v", err) } if resp.Match == nil || *resp.Match != "deworming" { t.Errorf("Expected match 'deworming', got %v", resp.Match) } if len(resp.Procedures) != 1 || resp.Procedures[0].Name != "Deworming tablet" { t.Errorf("Expected procedure 'Deworming tablet', got %+v", resp.Procedures) } if resp.Notes != visit.Notes { t.Errorf("Expected notes '%s', got '%s'", visit.Notes, resp.Notes) } } func TestChatService_NoMatch(t *testing.T) { gin.SetMode(gin.TestMode) llm := &mockLLM{ keywordsResp: map[string]interface{}{"keyword": []string{"unknown"}}, disambigID: "", } db := &testVisitDB{ candidates: []Visit{}, byID: map[string]Visit{}, } cs := NewChatService(llm, db) r := gin.New() r.POST("/chat", cs.HandleChat) w := httptest.NewRecorder() body := map[string]string{"message": "No known reason"} jsonBody, _ := json.Marshal(body) req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody)) req.Header.Set("Content-Type", "application/json") r.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("Expected 200, got %d", w.Code) } var resp ChatResponse if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatalf("Invalid response: %v", err) } if resp.Match != nil { t.Errorf("Expected no match, got %v", resp.Match) } } func TestChatService_LLMError(t *testing.T) { gin.SetMode(gin.TestMode) llm := &mockLLM{ keywordsErr: context.DeadlineExceeded, } db := &testVisitDB{} cs := NewChatService(llm, db) r := gin.New() r.POST("/chat", cs.HandleChat) w := httptest.NewRecorder() body := map[string]string{"message": "error"} jsonBody, _ := json.Marshal(body) req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody)) req.Header.Set("Content-Type", "application/json") r.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("Expected 200, got %d", w.Code) } var resp ChatResponse if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatalf("Invalid response: %v", err) } if resp.Match != nil { t.Errorf("Expected no match on LLM error, got %v", resp.Match) } } // --- Helper for request body --- func httptestNewBody(b []byte) *httptestBody { return &httptestBody{b: b} } type httptestBody struct { b []byte pos int } func (r *httptestBody) Read(p []byte) (int, error) { if r.pos >= len(r.b) { return 0, io.EOF } n := copy(p, r.b[r.pos:]) r.pos += n return n, nil } func (r *httptestBody) Close() error { return nil }