177 lines
4.5 KiB
Go
177 lines
4.5 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// --- Mocks ---
|
|
type mockLLM struct {
|
|
keywordsResp map[string]interface{}
|
|
disambigID string
|
|
keywordsErr error
|
|
disambigErr error
|
|
}
|
|
|
|
var _ LLMClientAPI = (*mockLLM)(nil)
|
|
|
|
func (m *mockLLM) ExtractKeywords(ctx context.Context, msg string) (map[string]interface{}, error) {
|
|
return m.keywordsResp, m.keywordsErr
|
|
}
|
|
func (m *mockLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Visit) (string, error) {
|
|
return m.disambigID, m.disambigErr
|
|
}
|
|
|
|
// --- Test VisitDB ---
|
|
type testVisitDB struct {
|
|
candidates []Visit
|
|
findErr error
|
|
byID map[string]Visit
|
|
}
|
|
|
|
var _ VisitDBAPI = (*testVisitDB)(nil)
|
|
|
|
func (db *testVisitDB) FindCandidates(keywords []string) ([]Visit, error) {
|
|
return db.candidates, db.findErr
|
|
}
|
|
func (db *testVisitDB) FindById(id string) (Visit, error) {
|
|
r, ok := db.byID[id]
|
|
if !ok {
|
|
return Visit{}, context.DeadlineExceeded
|
|
}
|
|
return r, nil
|
|
}
|
|
|
|
// --- Integration tests ---
|
|
func TestChatService_MatchFound(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
var llm LLMClientAPI = &mockLLM{
|
|
keywordsResp: map[string]interface{}{"keyword": []string{"worms", "deworming"}},
|
|
disambigID: "deworming",
|
|
}
|
|
visit := Visit{
|
|
ID: "deworming",
|
|
Procedures: []Procedure{{Name: "Deworming tablet", Price: 30, DurationMin: 10}},
|
|
Notes: "Bloodwork ensures organs are safe for treatment.",
|
|
}
|
|
var db VisitDBAPI = &testVisitDB{
|
|
candidates: []Visit{visit},
|
|
byID: map[string]Visit{"deworming": visit},
|
|
}
|
|
var cs ChatServiceAPI = NewChatService(llm, db)
|
|
r := gin.New()
|
|
r.POST("/chat", cs.HandleChat)
|
|
|
|
w := httptest.NewRecorder()
|
|
body := map[string]string{"message": "My dog needs deworming"}
|
|
jsonBody, _ := json.Marshal(body)
|
|
req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
r.ServeHTTP(w, req)
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("Expected 200, got %d", w.Code)
|
|
}
|
|
var resp ChatResponse
|
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
|
t.Fatalf("Invalid response: %v", err)
|
|
}
|
|
if resp.Match == nil || *resp.Match != "deworming" {
|
|
t.Errorf("Expected match 'deworming', got %v", resp.Match)
|
|
}
|
|
if len(resp.Procedures) != 1 || resp.Procedures[0].Name != "Deworming tablet" {
|
|
t.Errorf("Expected procedure 'Deworming tablet', got %+v", resp.Procedures)
|
|
}
|
|
if resp.Notes != visit.Notes {
|
|
t.Errorf("Expected notes '%s', got '%s'", visit.Notes, resp.Notes)
|
|
}
|
|
}
|
|
|
|
func TestChatService_NoMatch(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
llm := &mockLLM{
|
|
keywordsResp: map[string]interface{}{"keyword": []string{"unknown"}},
|
|
disambigID: "",
|
|
}
|
|
db := &testVisitDB{
|
|
candidates: []Visit{},
|
|
byID: map[string]Visit{},
|
|
}
|
|
cs := NewChatService(llm, db)
|
|
r := gin.New()
|
|
r.POST("/chat", cs.HandleChat)
|
|
|
|
w := httptest.NewRecorder()
|
|
body := map[string]string{"message": "No known reason"}
|
|
jsonBody, _ := json.Marshal(body)
|
|
req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
r.ServeHTTP(w, req)
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("Expected 200, got %d", w.Code)
|
|
}
|
|
var resp ChatResponse
|
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
|
t.Fatalf("Invalid response: %v", err)
|
|
}
|
|
if resp.Match != nil {
|
|
t.Errorf("Expected no match, got %v", resp.Match)
|
|
}
|
|
}
|
|
|
|
func TestChatService_LLMError(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
llm := &mockLLM{
|
|
keywordsErr: context.DeadlineExceeded,
|
|
}
|
|
db := &testVisitDB{}
|
|
cs := NewChatService(llm, db)
|
|
r := gin.New()
|
|
r.POST("/chat", cs.HandleChat)
|
|
|
|
w := httptest.NewRecorder()
|
|
body := map[string]string{"message": "error"}
|
|
jsonBody, _ := json.Marshal(body)
|
|
req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
r.ServeHTTP(w, req)
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("Expected 200, got %d", w.Code)
|
|
}
|
|
var resp ChatResponse
|
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
|
t.Fatalf("Invalid response: %v", err)
|
|
}
|
|
if resp.Match != nil {
|
|
t.Errorf("Expected no match on LLM error, got %v", resp.Match)
|
|
}
|
|
}
|
|
|
|
// --- Helper for request body ---
|
|
// httptestNewBody wraps b in a minimal io.ReadCloser that the tests use as a
// request body. The slice is stored as-is, not copied.
func httptestNewBody(b []byte) *httptestBody {
	body := new(httptestBody)
	body.b = b
	return body
}

// httptestBody is a hand-rolled in-memory io.ReadCloser. Read drains b
// starting at offset pos, and Close is a no-op.
type httptestBody struct {
	b   []byte // full payload; never modified
	pos int    // offset of the next unread byte, 0 <= pos <= len(b)
}

// Read copies as many unread bytes as fit into p and advances the offset.
// Once everything has been consumed it returns (0, io.EOF). It never returns
// data and io.EOF in the same call.
func (hb *httptestBody) Read(p []byte) (int, error) {
	remaining := len(hb.b) - hb.pos
	if remaining <= 0 {
		return 0, io.EOF
	}
	copied := copy(p, hb.b[hb.pos:])
	hb.pos += copied
	return copied, nil
}

// Close is a no-op and always returns nil.
func (hb *httptestBody) Close() error {
	return nil
}
|