vetrag/main_test.go

146 lines
3.2 KiB
Go

package main
import (
	"bytes"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"testing"

	"github.com/gin-gonic/gin"
)
// testDB describes an on-disk fixture used by the endpoint tests: data is the
// content that gets written to the path held in file.
type testDB struct {
	file, data string
}

// setup writes d.data to d.file with mode 0644 and panics if the write fails.
func (d *testDB) setup() {
	if writeErr := os.WriteFile(d.file, []byte(d.data), 0o644); writeErr != nil {
		panic(writeErr)
	}
}

// teardown removes d.file. The removal error is discarded, so calling it when
// the file is already gone does not fail.
func (d *testDB) teardown() {
	_ = os.Remove(d.file)
}
// TestChatEndpoint_MatchFound verifies that POST /chat answers 200 with a body
// containing "deworming" when the message mentions deworming and the loaded
// fixture has a deworming entry.
func TestChatEndpoint_MatchFound(t *testing.T) {
	tdb := testDB{
		// Keep the fixture in a per-test temporary directory so that a real
		// db.yaml in the package directory is never overwritten or removed.
		file: filepath.Join(t.TempDir(), "db.yaml"),
		data: `
- id: deworming
  reason: Deworming for dogs
  keywords: ["worms", "deworming", "parasite"]
  procedures:
    - name: Deworming tablet
      price: 30
      duration_minutes: 10
    - name: Bloodwork
      price: 35
      duration_minutes: 35
  notes: Bloodwork ensures organs are safe for treatment.
- id: vaccination
  reason: Annual vaccination
  keywords: ["vaccine", "vaccination", "shots"]
  procedures:
    - name: Vaccine injection
      price: 50
      duration_minutes: 15
`,
	}
	tdb.setup()
	defer tdb.teardown()
	if err := loadYAMLDB(tdb.file); err != nil {
		t.Fatalf("Failed to load test db: %v", err)
	}

	body := map[string]string{"message": "My dog needs deworming and bloodwork"}
	jsonBody, err := json.Marshal(body)
	if err != nil {
		t.Fatalf("Failed to marshal request body: %v", err)
	}
	req, err := http.NewRequest(http.MethodPost, "/chat", bytes.NewBuffer(jsonBody))
	if err != nil {
		t.Fatalf("Failed to build request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	r := setupRouter()
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("Expected 200, got %d", w.Code)
	}

	respBody, err := io.ReadAll(w.Body)
	if err != nil {
		t.Fatalf("Failed to read response body: %v", err)
	}
	if !bytes.Contains(respBody, []byte("deworming")) {
		t.Errorf("Expected match for deworming, got %s", string(respBody))
	}
}
// TestChatEndpoint_NoMatch verifies that POST /chat answers 200 with a body
// containing `"match":null` when the fixture holds only the vaccination entry
// and the message is about worms.
func TestChatEndpoint_NoMatch(t *testing.T) {
	tdb := testDB{
		// Keep the fixture in a per-test temporary directory so that a real
		// db.yaml in the package directory is never overwritten or removed.
		file: filepath.Join(t.TempDir(), "db.yaml"),
		data: `
- id: vaccination
  reason: Annual vaccination
  keywords: ["vaccine", "vaccination", "shots"]
  procedures:
    - name: Vaccine injection
      price: 50
      duration_minutes: 15
`,
	}
	tdb.setup()
	defer tdb.teardown()
	if err := loadYAMLDB(tdb.file); err != nil {
		t.Fatalf("Failed to load test db: %v", err)
	}

	body := map[string]string{"message": "My dog has worms"}
	jsonBody, err := json.Marshal(body)
	if err != nil {
		t.Fatalf("Failed to marshal request body: %v", err)
	}
	req, err := http.NewRequest(http.MethodPost, "/chat", bytes.NewBuffer(jsonBody))
	if err != nil {
		t.Fatalf("Failed to build request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	r := setupRouter()
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("Expected 200, got %d", w.Code)
	}

	respBody, err := io.ReadAll(w.Body)
	if err != nil {
		t.Fatalf("Failed to read response body: %v", err)
	}
	if !bytes.Contains(respBody, []byte(`"match":null`)) {
		t.Errorf("Expected no match, got %s", string(respBody))
	}
}
// setupRouter builds a gin engine exposing the POST /chat route exercised by
// the endpoint tests in this file.
//
// The handler binds the JSON body into a ChatRequest and answers 400 when
// binding fails. Otherwise it extracts keywords from the message and looks up
// candidates: with none it answers 200 with a nil Match, and with at least one
// it answers 200 describing the first candidate together with the totals
// returned by sumProcedures.
func setupRouter() *gin.Engine {
	chat := func(ctx *gin.Context) {
		var in ChatRequest
		if bindErr := ctx.ShouldBindJSON(&in); bindErr != nil {
			ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
			return
		}

		matches := findCandidates(naiveKeywordExtract(in.Message))
		if len(matches) == 0 {
			ctx.JSON(http.StatusOK, ChatResponse{Match: nil})
			return
		}

		// Only the first candidate is reported.
		top := matches[0]
		price, duration := sumProcedures(top.Procedures)
		out := ChatResponse{
			Match:         &top.ID,
			Procedures:    top.Procedures,
			TotalPrice:    price,
			TotalDuration: duration,
			Notes:         top.Notes,
		}
		ctx.JSON(http.StatusOK, out)
	}

	engine := gin.Default()
	engine.POST("/chat", chat)
	return engine
}