// Source listing: vetrag/chat_service_integration_te... (177 lines, 4.5 KiB, Go)

package main
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// --- Mocks ---
// mockLLM is a canned LLMClientAPI for tests: each method ignores its inputs
// and returns the values configured on the struct.
type mockLLM struct {
	keywordsResp map[string]interface{} // value returned by ExtractKeywords
	disambigID   string                 // value returned by DisambiguateBestMatch

	// Errors returned by ExtractKeywords and DisambiguateBestMatch respectively.
	keywordsErr, disambigErr error
}

// Compile-time check that *mockLLM satisfies LLMClientAPI.
var _ LLMClientAPI = (*mockLLM)(nil)
// ExtractKeywords ignores ctx and msg and hands back the configured keyword
// map together with the configured error (both are returned even when the
// error is non-nil).
func (m *mockLLM) ExtractKeywords(ctx context.Context, msg string) (map[string]interface{}, error) {
	resp, err := m.keywordsResp, m.keywordsErr
	return resp, err
}
// DisambiguateBestMatch ignores ctx, msg and candidates and hands back the
// configured visit ID together with the configured error.
func (m *mockLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Visit) (string, error) {
	id, err := m.disambigID, m.disambigErr
	return id, err
}
// --- Test VisitDB ---
// testVisitDB is an in-memory VisitDBAPI stub for tests.
type testVisitDB struct {
	candidates []Visit          // returned by FindCandidates regardless of keywords
	findErr    error            // error returned by FindCandidates
	byID       map[string]Visit // lookup table consulted by FindById
}

// Compile-time check that *testVisitDB satisfies VisitDBAPI.
var _ VisitDBAPI = (*testVisitDB)(nil)
// FindCandidates ignores keywords and returns the configured candidate list
// and error unchanged.
func (db *testVisitDB) FindCandidates(keywords []string) ([]Visit, error) {
	found, err := db.candidates, db.findErr
	return found, err
}
// FindById returns the visit stored under id in db.byID. When id is absent
// (including when byID is nil) it returns the zero Visit and
// context.DeadlineExceeded.
//
// NOTE(review): context.DeadlineExceeded is an unusual stand-in for "not
// found"; presumably the handler only checks err != nil — confirm before
// relying on or changing this.
func (db *testVisitDB) FindById(id string) (Visit, error) {
	if v, found := db.byID[id]; found {
		return v, nil
	}
	return Visit{}, context.DeadlineExceeded
}
// --- Integration tests ---
// TestChatService_MatchFound covers the happy path: the mocked LLM extracts
// keywords and disambiguates to "deworming", the stub DB returns that visit,
// and POST /chat must respond 200 with the visit's ID, procedures and notes.
func TestChatService_MatchFound(t *testing.T) {
	gin.SetMode(gin.TestMode)
	var llm LLMClientAPI = &mockLLM{
		keywordsResp: map[string]interface{}{"keyword": []string{"worms", "deworming"}},
		disambigID:   "deworming",
	}
	visit := Visit{
		ID:         "deworming",
		Procedures: []Procedure{{Name: "Deworming tablet", Price: 30, DurationMin: 10, Note: ""}},
		Notes:      "Bloodwork ensures organs are safe for treatment.",
	}
	var db VisitDBAPI = &testVisitDB{
		candidates: []Visit{visit},
		byID:       map[string]Visit{"deworming": visit},
	}
	var cs ChatServiceAPI = NewChatService(llm, db, nil)
	r := gin.New()
	r.POST("/chat", cs.HandleChat)

	jsonBody, err := json.Marshal(map[string]string{"message": "My dog needs deworming"})
	if err != nil {
		t.Fatalf("marshal request body: %v", err)
	}
	req, err := http.NewRequest(http.MethodPost, "/chat", httptestNewBody(jsonBody))
	if err != nil {
		t.Fatalf("build request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		// Include the body so a failing handler's error payload is visible.
		t.Fatalf("Expected 200, got %d (body: %s)", w.Code, w.Body.String())
	}
	var resp ChatResponse
	if err = json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("Invalid response: %v (body: %s)", err, w.Body.String())
	}
	// resp.Match is a pointer; dereference it so failures show the value,
	// not an address.
	switch {
	case resp.Match == nil:
		t.Errorf("Expected match 'deworming', got nil")
	case *resp.Match != "deworming":
		t.Errorf("Expected match 'deworming', got %q", *resp.Match)
	}
	if len(resp.Procedures) != 1 || resp.Procedures[0].Name != "Deworming tablet" {
		t.Errorf("Expected procedure 'Deworming tablet', got %+v", resp.Procedures)
	}
	if resp.Notes != visit.Notes {
		t.Errorf("Expected notes %q, got %q", visit.Notes, resp.Notes)
	}
}
// TestChatService_NoMatch checks that when the DB yields no candidates and the
// LLM disambiguates to an empty ID, POST /chat still responds 200 but with no
// match.
func TestChatService_NoMatch(t *testing.T) {
	gin.SetMode(gin.TestMode)
	llm := &mockLLM{
		keywordsResp: map[string]interface{}{"keyword": []string{"unknown"}},
		disambigID:   "",
	}
	db := &testVisitDB{
		candidates: []Visit{},
		byID:       map[string]Visit{},
	}
	cs := NewChatService(llm, db, nil)
	r := gin.New()
	r.POST("/chat", cs.HandleChat)

	jsonBody, err := json.Marshal(map[string]string{"message": "No known reason"})
	if err != nil {
		t.Fatalf("marshal request body: %v", err)
	}
	req, err := http.NewRequest(http.MethodPost, "/chat", httptestNewBody(jsonBody))
	if err != nil {
		t.Fatalf("build request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("Expected 200, got %d (body: %s)", w.Code, w.Body.String())
	}
	var resp ChatResponse
	if err = json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("Invalid response: %v (body: %s)", err, w.Body.String())
	}
	if resp.Match != nil {
		// Dereference: printing the pointer itself would show only an address.
		t.Errorf("Expected no match, got %q", *resp.Match)
	}
}
// TestChatService_LLMError checks that a keyword-extraction failure from the
// LLM does not surface as an HTTP error: the handler is expected to respond
// 200 with no match.
func TestChatService_LLMError(t *testing.T) {
	gin.SetMode(gin.TestMode)
	llm := &mockLLM{
		keywordsErr: context.DeadlineExceeded,
	}
	db := &testVisitDB{}
	cs := NewChatService(llm, db, nil)
	r := gin.New()
	r.POST("/chat", cs.HandleChat)

	jsonBody, err := json.Marshal(map[string]string{"message": "error"})
	if err != nil {
		t.Fatalf("marshal request body: %v", err)
	}
	req, err := http.NewRequest(http.MethodPost, "/chat", httptestNewBody(jsonBody))
	if err != nil {
		t.Fatalf("build request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("Expected 200, got %d (body: %s)", w.Code, w.Body.String())
	}
	var resp ChatResponse
	if err = json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("Invalid response: %v (body: %s)", err, w.Body.String())
	}
	if resp.Match != nil {
		// Dereference: printing the pointer itself would show only an address.
		t.Errorf("Expected no match on LLM error, got %q", *resp.Match)
	}
}
// --- Helper for request body ---
// httptestNewBody wraps b in an io.ReadCloser usable as an http.Request body.
// b is not copied, so callers must not modify it while it is being read.
func httptestNewBody(b []byte) *httptestBody {
	body := new(httptestBody)
	body.b = b
	return body
}

// httptestBody is a minimal in-memory io.ReadCloser over a fixed byte slice.
type httptestBody struct {
	b   []byte // data served by Read
	pos int    // offset of the next unread byte in b
}

// Read copies as much of the unread data as fits into p and advances the
// cursor. Once all data has been consumed it returns 0, io.EOF.
func (body *httptestBody) Read(p []byte) (int, error) {
	if len(body.b)-body.pos <= 0 {
		return 0, io.EOF
	}
	copied := copy(p, body.b[body.pos:])
	body.pos += copied
	return copied, nil
}

// Close is a no-op; there is no underlying resource to release.
func (body *httptestBody) Close() error {
	return nil
}