refactor to add tests
This commit is contained in:
parent
e60f0c0ed6
commit
3876ba502e
144
chat_service.go
144
chat_service.go
|
|
@ -9,34 +9,43 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type ChatService struct {
|
||||
LLM *LLMClient
|
||||
reasonsDB ReasonDB
|
||||
// ChatServiceAPI allows mocking ChatService in other places
|
||||
// Only public methods should be included
|
||||
|
||||
type ChatServiceAPI interface {
|
||||
HandleChat(c *gin.Context)
|
||||
}
|
||||
|
||||
func NewChatService(llm *LLMClient, db ReasonDB) *ChatService {
|
||||
// ChatService handles chat interactions and orchestrates LLM and DB calls.
|
||||
type ChatService struct {
|
||||
LLM LLMClientAPI
|
||||
reasonsDB ReasonDBAPI
|
||||
}
|
||||
|
||||
var _ ChatServiceAPI = (*ChatService)(nil)
|
||||
|
||||
func NewChatService(llm LLMClientAPI, db ReasonDBAPI) ChatServiceAPI {
|
||||
return &ChatService{LLM: llm, reasonsDB: db}
|
||||
}
|
||||
|
||||
// HandleChat is the main entrypoint for chat requests. It delegates to modular helpers.
|
||||
func (cs *ChatService) HandleChat(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
req, err := cs.parseRequest(c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
keywordsResponse, err := cs.getKeywordsResponse(ctx, req.Message)
|
||||
keywords, err := cs.extractKeywords(ctx, req.Message)
|
||||
if err != nil {
|
||||
cs.logChat(req, keywordsResponse, nil, "", err)
|
||||
c.JSON(http.StatusOK, ChatResponse{Match: nil})
|
||||
cs.respondWithError(c, req, keywords, err)
|
||||
return
|
||||
}
|
||||
keywordList := cs.keywordsToStrings(keywordsResponse["keyword"])
|
||||
best, err := cs.findVisitReason(ctx, req, keywordList)
|
||||
|
||||
best, err := cs.findBestReason(ctx, req, keywords)
|
||||
resp := cs.buildResponse(best)
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// parseRequest parses and validates the incoming chat request.
|
||||
func (cs *ChatService) parseRequest(c *gin.Context) (ChatRequest, error) {
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
|
|
@ -47,56 +56,36 @@ func (cs *ChatService) parseRequest(c *gin.Context) (ChatRequest, error) {
|
|||
return req, nil
|
||||
}
|
||||
|
||||
func (cs *ChatService) getKeywordsResponse(ctx context.Context, message string) (map[string]interface{}, error) {
|
||||
keywords, err := cs.LLM.ExtractKeywords(ctx, message)
|
||||
return keywords, err
|
||||
}
|
||||
|
||||
func (cs *ChatService) keywordsToStrings(kwIface interface{}) []string {
|
||||
var kwArr []string
|
||||
switch v := kwIface.(type) {
|
||||
case []interface{}:
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok {
|
||||
kwArr = append(kwArr, s)
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
kwArr = v
|
||||
// extractKeywords gets keywords from the LLM and normalizes them.
|
||||
func (cs *ChatService) extractKeywords(ctx context.Context, message string) ([]string, error) {
|
||||
kwResp, err := cs.LLM.ExtractKeywords(ctx, message)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return kwArr
|
||||
return cs.keywordsToStrings(kwResp["keyword"]), nil
|
||||
}
|
||||
|
||||
func (cs *ChatService) findVisitReason(ctx context.Context, req ChatRequest, keywordList []string) (*Reason, error) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"keywords": keywordList,
|
||||
"message": req.Message,
|
||||
}).Info("Finding visit reason candidates")
|
||||
|
||||
candidateReasonse, err := cs.reasonsDB.findCandidates(keywordList)
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"candidates": candidateReasonse,
|
||||
"error": err,
|
||||
}).Info("Candidate reasons found")
|
||||
// findBestReason finds candidate reasons and disambiguates the best match.
|
||||
func (cs *ChatService) findBestReason(ctx context.Context, req ChatRequest, keywords []string) (*Reason, error) {
|
||||
cs.logKeywords(keywords, req.Message)
|
||||
candidates, err := cs.reasonsDB.FindCandidates(keywords)
|
||||
cs.logCandidates(candidates, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bestID := ""
|
||||
if len(candidateReasonse) > 0 {
|
||||
bestID, err = cs.LLM.DisambiguateBestMatch(ctx, req.Message, candidateReasonse)
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"bestID": bestID,
|
||||
"error": err,
|
||||
}).Info("Disambiguated best match")
|
||||
if len(candidates) > 0 {
|
||||
bestID, err = cs.LLM.DisambiguateBestMatch(ctx, req.Message, candidates)
|
||||
cs.logBestID(bestID, err)
|
||||
}
|
||||
reason, err := cs.reasonsDB.findById(bestID)
|
||||
reason, err := cs.reasonsDB.FindById(bestID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("findById: %w", err)
|
||||
return nil, fmt.Errorf("FindById: %w", err)
|
||||
}
|
||||
return &reason, nil
|
||||
|
||||
}
|
||||
|
||||
// buildResponse constructs the ChatResponse from the best Reason.
|
||||
func (cs *ChatService) buildResponse(best *Reason) ChatResponse {
|
||||
if best == nil {
|
||||
resp := ChatResponse{Match: nil}
|
||||
|
|
@ -115,8 +104,65 @@ func (cs *ChatService) buildResponse(best *Reason) ChatResponse {
|
|||
return resp
|
||||
}
|
||||
|
||||
func (cs *ChatService) logChat(req ChatRequest, keywords map[string]interface{}, candidates []Reason, bestID string, err error) {
|
||||
logRequest(req, keywords, candidates, bestID, err)
|
||||
// respondWithError logs and responds with error details.
|
||||
func (cs *ChatService) respondWithError(c *gin.Context, req ChatRequest, keywords []string, err error) {
|
||||
kwMap := map[string]interface{}{"keyword": keywords}
|
||||
cs.logChat(req, kwMap, nil, "", err)
|
||||
c.JSON(http.StatusOK, ChatResponse{Match: nil})
|
||||
}
|
||||
|
||||
// keywordsToStrings normalizes keyword interface to []string.
|
||||
func (cs *ChatService) keywordsToStrings(kwIface interface{}) []string {
|
||||
var kwArr []string
|
||||
switch v := kwIface.(type) {
|
||||
case []interface{}:
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok {
|
||||
kwArr = append(kwArr, s)
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
kwArr = v
|
||||
}
|
||||
return kwArr
|
||||
}
|
||||
|
||||
// logKeywords logs extracted keywords.
|
||||
func (cs *ChatService) logKeywords(keywords []string, message string) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"keywords": keywords,
|
||||
"message": message,
|
||||
}).Info("Finding visit reason candidates")
|
||||
}
|
||||
|
||||
// logCandidates logs candidate reasons.
|
||||
func (cs *ChatService) logCandidates(candidates []Reason, err error) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"candidates": candidates,
|
||||
"error": err,
|
||||
}).Info("Candidate reasons found")
|
||||
}
|
||||
|
||||
// logBestID logs the best candidate ID.
|
||||
func (cs *ChatService) logBestID(bestID string, err error) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"bestID": bestID,
|
||||
"error": err,
|
||||
}).Info("Disambiguated best match")
|
||||
}
|
||||
|
||||
// logChat logs the chat request and result details.
|
||||
func (cs *ChatService) logChat(req ChatRequest, keywords interface{}, candidates []Reason, bestID string, err error) {
|
||||
var kwMap map[string]interface{}
|
||||
switch v := keywords.(type) {
|
||||
case []string:
|
||||
kwMap = map[string]interface{}{"keyword": v}
|
||||
case map[string]interface{}:
|
||||
kwMap = v
|
||||
default:
|
||||
kwMap = map[string]interface{}{"keyword": v}
|
||||
}
|
||||
logRequest(req, kwMap, candidates, bestID, err)
|
||||
if candidates != nil && bestID != "" {
|
||||
var best *Reason
|
||||
for i := range candidates {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,176 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// --- Mocks ---
|
||||
type mockLLM struct {
|
||||
keywordsResp map[string]interface{}
|
||||
disambigID string
|
||||
keywordsErr error
|
||||
disambigErr error
|
||||
}
|
||||
|
||||
var _ LLMClientAPI = (*mockLLM)(nil)
|
||||
|
||||
func (m *mockLLM) ExtractKeywords(ctx context.Context, msg string) (map[string]interface{}, error) {
|
||||
return m.keywordsResp, m.keywordsErr
|
||||
}
|
||||
func (m *mockLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Reason) (string, error) {
|
||||
return m.disambigID, m.disambigErr
|
||||
}
|
||||
|
||||
// --- Test ReasonDB ---
|
||||
type testReasonDB struct {
|
||||
candidates []Reason
|
||||
findErr error
|
||||
byID map[string]Reason
|
||||
}
|
||||
|
||||
var _ ReasonDBAPI = (*testReasonDB)(nil)
|
||||
|
||||
func (db *testReasonDB) FindCandidates(keywords []string) ([]Reason, error) {
|
||||
return db.candidates, db.findErr
|
||||
}
|
||||
func (db *testReasonDB) FindById(id string) (Reason, error) {
|
||||
r, ok := db.byID[id]
|
||||
if !ok {
|
||||
return Reason{}, context.DeadlineExceeded
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// --- Integration tests ---
|
||||
func TestChatService_MatchFound(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
var llm LLMClientAPI = &mockLLM{
|
||||
keywordsResp: map[string]interface{}{"keyword": []string{"worms", "deworming"}},
|
||||
disambigID: "deworming",
|
||||
}
|
||||
reason := Reason{
|
||||
ID: "deworming",
|
||||
Procedures: []Procedure{{Name: "Deworming tablet", Price: 30, DurationMin: 10}},
|
||||
Notes: "Bloodwork ensures organs are safe for treatment.",
|
||||
}
|
||||
var db ReasonDBAPI = &testReasonDB{
|
||||
candidates: []Reason{reason},
|
||||
byID: map[string]Reason{"deworming": reason},
|
||||
}
|
||||
var cs ChatServiceAPI = NewChatService(llm, db)
|
||||
r := gin.New()
|
||||
r.POST("/chat", cs.HandleChat)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
body := map[string]string{"message": "My dog needs deworming"}
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
var resp ChatResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Invalid response: %v", err)
|
||||
}
|
||||
if resp.Match == nil || *resp.Match != "deworming" {
|
||||
t.Errorf("Expected match 'deworming', got %v", resp.Match)
|
||||
}
|
||||
if len(resp.Procedures) != 1 || resp.Procedures[0].Name != "Deworming tablet" {
|
||||
t.Errorf("Expected procedure 'Deworming tablet', got %+v", resp.Procedures)
|
||||
}
|
||||
if resp.Notes != reason.Notes {
|
||||
t.Errorf("Expected notes '%s', got '%s'", reason.Notes, resp.Notes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatService_NoMatch(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
llm := &mockLLM{
|
||||
keywordsResp: map[string]interface{}{"keyword": []string{"unknown"}},
|
||||
disambigID: "",
|
||||
}
|
||||
db := &testReasonDB{
|
||||
candidates: []Reason{},
|
||||
byID: map[string]Reason{},
|
||||
}
|
||||
cs := NewChatService(llm, db)
|
||||
r := gin.New()
|
||||
r.POST("/chat", cs.HandleChat)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
body := map[string]string{"message": "No known reason"}
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
var resp ChatResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Invalid response: %v", err)
|
||||
}
|
||||
if resp.Match != nil {
|
||||
t.Errorf("Expected no match, got %v", resp.Match)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatService_LLMError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
llm := &mockLLM{
|
||||
keywordsErr: context.DeadlineExceeded,
|
||||
}
|
||||
db := &testReasonDB{}
|
||||
cs := NewChatService(llm, db)
|
||||
r := gin.New()
|
||||
r.POST("/chat", cs.HandleChat)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
body := map[string]string{"message": "error"}
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
var resp ChatResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Invalid response: %v", err)
|
||||
}
|
||||
if resp.Match != nil {
|
||||
t.Errorf("Expected no match on LLM error, got %v", resp.Match)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper for request body ---
|
||||
func httptestNewBody(b []byte) *httptestBody {
|
||||
return &httptestBody{b: b}
|
||||
}
|
||||
|
||||
type httptestBody struct {
|
||||
b []byte
|
||||
pos int
|
||||
}
|
||||
|
||||
func (r *httptestBody) Read(p []byte) (int, error) {
|
||||
if r.pos >= len(r.b) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(p, r.b[r.pos:])
|
||||
r.pos += n
|
||||
return n, nil
|
||||
}
|
||||
func (r *httptestBody) Close() error { return nil }
|
||||
16
db.go
16
db.go
|
|
@ -22,7 +22,7 @@ func NewReasonDB() ReasonDB {
|
|||
db.init()
|
||||
return db
|
||||
}
|
||||
func (rdb *ReasonDB) findById(reasonId string) (Reason, error) {
|
||||
func (rdb *ReasonDB) FindById(reasonId string) (Reason, error) {
|
||||
for _, reason := range rdb.reasonsDB {
|
||||
if reason.ID == reasonId {
|
||||
return reason, nil
|
||||
|
|
@ -73,8 +73,8 @@ func (rdb *ReasonDB) indexReasons(reasons []Reason) error {
|
|||
return rdb.reasonsIdx.Batch(batch)
|
||||
}
|
||||
|
||||
// findCandidates returns reasons with overlapping keywords
|
||||
func (rdb *ReasonDB) findCandidates(keywords []string) ([]Reason, error) {
|
||||
// FindCandidates returns reasons with overlapping keywords
|
||||
func (rdb *ReasonDB) FindCandidates(keywords []string) ([]Reason, error) {
|
||||
query := bleve.NewMatchQuery(strings.Join(keywords, " "))
|
||||
search := bleve.NewSearchRequest(query)
|
||||
searchResults, err := rdb.reasonsIdx.Search(search)
|
||||
|
|
@ -104,3 +104,13 @@ func sumProcedures(procs []Procedure) (int, int) {
|
|||
}
|
||||
return totalPrice, totalDuration
|
||||
}
|
||||
|
||||
// ReasonDBAPI allows mocking ReasonDB in other places
|
||||
// Only public methods should be included
|
||||
|
||||
type ReasonDBAPI interface {
|
||||
FindById(reasonId string) (Reason, error)
|
||||
FindCandidates(keywords []string) ([]Reason, error)
|
||||
}
|
||||
|
||||
var _ ReasonDBAPI = (*ReasonDB)(nil)
|
||||
|
|
|
|||
18
llm.go
18
llm.go
|
|
@ -18,6 +18,14 @@ type LLMClient struct {
|
|||
BaseURL string
|
||||
}
|
||||
|
||||
// NewLLMClient constructs a new LLMClient with the given API key and base URL
|
||||
func NewLLMClient(apiKey, baseURL string) *LLMClient {
|
||||
return &LLMClient{
|
||||
APIKey: apiKey,
|
||||
BaseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
// renderPrompt renders a Go template with the given data
|
||||
func renderPrompt(tmplStr string, data any) (string, error) {
|
||||
tmpl, err := template.New("").Parse(tmplStr)
|
||||
|
|
@ -133,3 +141,13 @@ func (llm *LLMClient) openAICompletion(ctx context.Context, prompt string, forma
|
|||
logrus.WithField("content", result.Message.Content).Info("[LLM] openAICompletion: got content")
|
||||
return result.Message.Content, nil
|
||||
}
|
||||
|
||||
// LLMClientAPI allows mocking LLMClient in other places
|
||||
// Only public methods should be included
|
||||
|
||||
type LLMClientAPI interface {
|
||||
ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error)
|
||||
DisambiguateBestMatch(ctx context.Context, message string, candidates []Reason) (string, error)
|
||||
}
|
||||
|
||||
var _ LLMClientAPI = (*LLMClient)(nil)
|
||||
|
|
|
|||
10
main.go
10
main.go
|
|
@ -20,11 +20,11 @@ func main() {
|
|||
if err := loadUITemplate("ui.html"); err != nil {
|
||||
logrus.Fatalf("Failed to load ui.html: %v", err)
|
||||
}
|
||||
llm := &LLMClient{
|
||||
APIKey: os.Getenv("OPENAI_API_KEY"),
|
||||
BaseURL: os.Getenv("OPENAI_BASE_URL"),
|
||||
}
|
||||
chatService := NewChatService(llm, reasonDB)
|
||||
var llm LLMClientAPI = NewLLMClient(
|
||||
os.Getenv("OPENAI_API_KEY"),
|
||||
os.Getenv("OPENAI_BASE_URL"),
|
||||
)
|
||||
chatService := NewChatService(llm, &reasonDB)
|
||||
r := gin.Default()
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Status(200)
|
||||
|
|
|
|||
145
main_test.go
145
main_test.go
|
|
@ -1,145 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type testDB struct {
|
||||
file string
|
||||
data string
|
||||
}
|
||||
|
||||
func (tdb *testDB) setup() {
|
||||
err := os.WriteFile(tdb.file, []byte(tdb.data), 0644)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (tdb *testDB) teardown() {
|
||||
_ = os.Remove(tdb.file)
|
||||
}
|
||||
|
||||
func TestChatEndpoint_MatchFound(t *testing.T) {
|
||||
tdb := testDB{
|
||||
file: "db.yaml",
|
||||
data: `
|
||||
- id: deworming
|
||||
reason: Deworming for dogs
|
||||
keywords: ["worms", "deworming", "parasite"]
|
||||
procedures:
|
||||
- name: Deworming tablet
|
||||
price: 30
|
||||
duration_minutes: 10
|
||||
- name: Bloodwork
|
||||
price: 35
|
||||
duration_minutes: 35
|
||||
notes: Bloodwork ensures organs are safe for treatment.
|
||||
- id: vaccination
|
||||
reason: Annual vaccination
|
||||
keywords: ["vaccine", "vaccination", "shots"]
|
||||
procedures:
|
||||
- name: Vaccine injection
|
||||
price: 50
|
||||
duration_minutes: 15
|
||||
`,
|
||||
}
|
||||
tdb.setup()
|
||||
defer tdb.teardown()
|
||||
|
||||
if err := loadYAMLDB(tdb.file); err != nil {
|
||||
t.Fatalf("Failed to load test db: %v", err)
|
||||
}
|
||||
|
||||
r := setupRouter()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
body := map[string]string{"message": "My dog needs deworming and bloodwork"}
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("POST", "/chat", bytes.NewBuffer(jsonBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
respBody, _ := io.ReadAll(w.Body)
|
||||
if !bytes.Contains(respBody, []byte("deworming")) {
|
||||
t.Errorf("Expected match for deworming, got %s", string(respBody))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatEndpoint_NoMatch(t *testing.T) {
|
||||
tdb := testDB{
|
||||
file: "db.yaml",
|
||||
data: `
|
||||
- id: vaccination
|
||||
reason: Annual vaccination
|
||||
keywords: ["vaccine", "vaccination", "shots"]
|
||||
procedures:
|
||||
- name: Vaccine injection
|
||||
price: 50
|
||||
duration_minutes: 15
|
||||
`,
|
||||
}
|
||||
tdb.setup()
|
||||
defer tdb.teardown()
|
||||
|
||||
if err := loadYAMLDB(tdb.file); err != nil {
|
||||
t.Fatalf("Failed to load test db: %v", err)
|
||||
}
|
||||
|
||||
r := setupRouter()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
body := map[string]string{"message": "My dog has worms"}
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("POST", "/chat", bytes.NewBuffer(jsonBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
respBody, _ := io.ReadAll(w.Body)
|
||||
if !bytes.Contains(respBody, []byte(`"match":null`)) {
|
||||
t.Errorf("Expected no match, got %s", string(respBody))
|
||||
}
|
||||
}
|
||||
|
||||
func setupRouter() *gin.Engine {
|
||||
r := gin.Default()
|
||||
r.POST("/chat", func(c *gin.Context) {
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
keywords := naiveKeywordExtract(req.Message)
|
||||
candidates := findCandidates(keywords)
|
||||
if len(candidates) == 0 {
|
||||
c.JSON(http.StatusOK, ChatResponse{Match: nil})
|
||||
return
|
||||
}
|
||||
best := candidates[0]
|
||||
totalPrice, totalDuration := sumProcedures(best.Procedures)
|
||||
c.JSON(http.StatusOK, ChatResponse{
|
||||
Match: &best.ID,
|
||||
Procedures: best.Procedures,
|
||||
TotalPrice: totalPrice,
|
||||
TotalDuration: totalDuration,
|
||||
Notes: best.Notes,
|
||||
})
|
||||
})
|
||||
return r
|
||||
}
|
||||
Loading…
Reference in New Issue