diff --git a/chat_service.go b/chat_service.go index e8e6bb7..a444618 100644 --- a/chat_service.go +++ b/chat_service.go @@ -9,34 +9,43 @@ import ( "github.com/sirupsen/logrus" ) -type ChatService struct { - LLM *LLMClient - reasonsDB ReasonDB +// ChatServiceAPI allows mocking ChatService in other places +// Only public methods should be included + +type ChatServiceAPI interface { + HandleChat(c *gin.Context) } -func NewChatService(llm *LLMClient, db ReasonDB) *ChatService { +// ChatService handles chat interactions and orchestrates LLM and DB calls. +type ChatService struct { + LLM LLMClientAPI + reasonsDB ReasonDBAPI +} + +var _ ChatServiceAPI = (*ChatService)(nil) + +func NewChatService(llm LLMClientAPI, db ReasonDBAPI) ChatServiceAPI { return &ChatService{LLM: llm, reasonsDB: db} } +// HandleChat is the main entrypoint for chat requests. It delegates to modular helpers. func (cs *ChatService) HandleChat(c *gin.Context) { ctx := context.Background() req, err := cs.parseRequest(c) if err != nil { return } - keywordsResponse, err := cs.getKeywordsResponse(ctx, req.Message) + keywords, err := cs.extractKeywords(ctx, req.Message) if err != nil { - cs.logChat(req, keywordsResponse, nil, "", err) - c.JSON(http.StatusOK, ChatResponse{Match: nil}) + cs.respondWithError(c, req, keywords, err) return } - keywordList := cs.keywordsToStrings(keywordsResponse["keyword"]) - best, err := cs.findVisitReason(ctx, req, keywordList) - + best, err := cs.findBestReason(ctx, req, keywords) resp := cs.buildResponse(best) c.JSON(http.StatusOK, resp) } +// parseRequest parses and validates the incoming chat request. func (cs *ChatService) parseRequest(c *gin.Context) (ChatRequest, error) { var req ChatRequest if err := c.ShouldBindJSON(&req); err != nil { @@ -47,56 +56,36 @@ func (cs *ChatService) parseRequest(c *gin.Context) (ChatRequest, error) { return req, nil } -func (cs *ChatService) getKeywordsResponse(ctx context.Context, message string) (map[string]interface{}, error) { - keywords, err := cs.LLM.ExtractKeywords(ctx, message) - return keywords, err -} - -func (cs *ChatService) keywordsToStrings(kwIface interface{}) []string { - var kwArr []string - switch v := kwIface.(type) { - case []interface{}: - for _, item := range v { - if s, ok := item.(string); ok { - kwArr = append(kwArr, s) - } - } - case []string: - kwArr = v +// extractKeywords gets keywords from the LLM and normalizes them. +func (cs *ChatService) extractKeywords(ctx context.Context, message string) ([]string, error) { + kwResp, err := cs.LLM.ExtractKeywords(ctx, message) + if err != nil { + return nil, err } - return kwArr + return cs.keywordsToStrings(kwResp["keyword"]), nil } -func (cs *ChatService) findVisitReason(ctx context.Context, req ChatRequest, keywordList []string) (*Reason, error) { - logrus.WithFields(logrus.Fields{ - "keywords": keywordList, - "message": req.Message, - }).Info("Finding visit reason candidates") - - candidateReasonse, err := cs.reasonsDB.findCandidates(keywordList) - logrus.WithFields(logrus.Fields{ - "candidates": candidateReasonse, - "error": err, - }).Info("Candidate reasons found") +// findBestReason finds candidate reasons and disambiguates the best match. +func (cs *ChatService) findBestReason(ctx context.Context, req ChatRequest, keywords []string) (*Reason, error) { + cs.logKeywords(keywords, req.Message) + candidates, err := cs.reasonsDB.FindCandidates(keywords) + cs.logCandidates(candidates, err) if err != nil { return nil, err } bestID := "" - if len(candidateReasonse) > 0 { - bestID, err = cs.LLM.DisambiguateBestMatch(ctx, req.Message, candidateReasonse) - logrus.WithFields(logrus.Fields{ - "bestID": bestID, - "error": err, - }).Info("Disambiguated best match") + if len(candidates) > 0 { + bestID, err = cs.LLM.DisambiguateBestMatch(ctx, req.Message, candidates) + cs.logBestID(bestID, err) } - reason, err := cs.reasonsDB.findById(bestID) + reason, err := cs.reasonsDB.FindById(bestID) if err != nil { - return nil, fmt.Errorf("findById: %w", err) + return nil, fmt.Errorf("FindById: %w", err) } return &reason, nil - } +// buildResponse constructs the ChatResponse from the best Reason. func (cs *ChatService) buildResponse(best *Reason) ChatResponse { if best == nil { resp := ChatResponse{Match: nil} @@ -115,8 +104,65 @@ func (cs *ChatService) buildResponse(best *Reason) ChatResponse { return resp } -func (cs *ChatService) logChat(req ChatRequest, keywords map[string]interface{}, candidates []Reason, bestID string, err error) { - logRequest(req, keywords, candidates, bestID, err) +// respondWithError logs and responds with error details. +func (cs *ChatService) respondWithError(c *gin.Context, req ChatRequest, keywords []string, err error) { + kwMap := map[string]interface{}{"keyword": keywords} + cs.logChat(req, kwMap, nil, "", err) + c.JSON(http.StatusOK, ChatResponse{Match: nil}) +} + +// keywordsToStrings normalizes keyword interface to []string. +func (cs *ChatService) keywordsToStrings(kwIface interface{}) []string { + var kwArr []string + switch v := kwIface.(type) { + case []interface{}: + for _, item := range v { + if s, ok := item.(string); ok { + kwArr = append(kwArr, s) + } + } + case []string: + kwArr = v + } + return kwArr +} + +// logKeywords logs extracted keywords. +func (cs *ChatService) logKeywords(keywords []string, message string) { + logrus.WithFields(logrus.Fields{ + "keywords": keywords, + "message": message, + }).Info("Finding visit reason candidates") +} + +// logCandidates logs candidate reasons. +func (cs *ChatService) logCandidates(candidates []Reason, err error) { + logrus.WithFields(logrus.Fields{ + "candidates": candidates, + "error": err, + }).Info("Candidate reasons found") +} + +// logBestID logs the best candidate ID. +func (cs *ChatService) logBestID(bestID string, err error) { + logrus.WithFields(logrus.Fields{ + "bestID": bestID, + "error": err, + }).Info("Disambiguated best match") +} + +// logChat logs the chat request and result details. +func (cs *ChatService) logChat(req ChatRequest, keywords interface{}, candidates []Reason, bestID string, err error) { + var kwMap map[string]interface{} + switch v := keywords.(type) { + case []string: + kwMap = map[string]interface{}{"keyword": v} + case map[string]interface{}: + kwMap = v + default: + kwMap = map[string]interface{}{"keyword": v} + } + logRequest(req, kwMap, candidates, bestID, err) if candidates != nil && bestID != "" { var best *Reason for i := range candidates { diff --git a/chat_service_integration_test.go b/chat_service_integration_test.go new file mode 100644 index 0000000..5f94655 --- /dev/null +++ b/chat_service_integration_test.go @@ -0,0 +1,176 @@ +package main + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +// --- Mocks --- +type mockLLM struct { + keywordsResp map[string]interface{} + disambigID string + keywordsErr error + disambigErr error +} + +var _ LLMClientAPI = (*mockLLM)(nil) + +func (m *mockLLM) ExtractKeywords(ctx context.Context, msg string) (map[string]interface{}, error) { + return m.keywordsResp, m.keywordsErr +} +func (m *mockLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Reason) (string, error) { + return m.disambigID, m.disambigErr +} + +// --- Test ReasonDB --- +type testReasonDB struct { + candidates []Reason + findErr error + byID map[string]Reason +} + +var _ ReasonDBAPI = (*testReasonDB)(nil) + +func (db *testReasonDB) FindCandidates(keywords []string) ([]Reason, error) { + return db.candidates, db.findErr +} +func (db *testReasonDB) FindById(id string) (Reason, error) { + r, ok := db.byID[id] + if !ok { + return Reason{}, context.DeadlineExceeded + } + return r, nil +} + +// --- Integration tests --- +func TestChatService_MatchFound(t *testing.T) { + gin.SetMode(gin.TestMode) + var llm LLMClientAPI = &mockLLM{ + keywordsResp: map[string]interface{}{"keyword": []string{"worms", "deworming"}}, + disambigID: "deworming", + } + reason := Reason{ + ID: "deworming", + Procedures: []Procedure{{Name: "Deworming tablet", Price: 30, DurationMin: 10}}, + Notes: "Bloodwork ensures organs are safe for treatment.", + } + var db ReasonDBAPI = &testReasonDB{ + candidates: []Reason{reason}, + byID: map[string]Reason{"deworming": reason}, + } + var cs ChatServiceAPI = NewChatService(llm, db) + r := gin.New() + r.POST("/chat", cs.HandleChat) + + w := httptest.NewRecorder() + body := map[string]string{"message": "My dog needs deworming"} + jsonBody, _ := json.Marshal(body) + req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody)) + req.Header.Set("Content-Type", "application/json") + + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", w.Code) + } + var resp ChatResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Invalid response: %v", err) + } + if resp.Match == nil || *resp.Match != "deworming" { + t.Errorf("Expected match 'deworming', got %v", resp.Match) + } + if len(resp.Procedures) != 1 || resp.Procedures[0].Name != "Deworming tablet" { + t.Errorf("Expected procedure 'Deworming tablet', got %+v", resp.Procedures) + } + if resp.Notes != reason.Notes { + t.Errorf("Expected notes '%s', got '%s'", reason.Notes, resp.Notes) + } +} + +func TestChatService_NoMatch(t *testing.T) { + gin.SetMode(gin.TestMode) + llm := &mockLLM{ + keywordsResp: map[string]interface{}{"keyword": []string{"unknown"}}, + disambigID: "", + } + db := &testReasonDB{ + candidates: []Reason{}, + byID: map[string]Reason{}, + } + cs := NewChatService(llm, db) + r := gin.New() + r.POST("/chat", cs.HandleChat) + + w := httptest.NewRecorder() + body := map[string]string{"message": "No known reason"} + jsonBody, _ := json.Marshal(body) + req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody)) + req.Header.Set("Content-Type", "application/json") + + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", w.Code) + } + var resp ChatResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Invalid response: %v", err) + } + if resp.Match != nil { + t.Errorf("Expected no match, got %v", resp.Match) + } +} + +func TestChatService_LLMError(t *testing.T) { + gin.SetMode(gin.TestMode) + llm := &mockLLM{ + keywordsErr: context.DeadlineExceeded, + } + db := &testReasonDB{} + cs := NewChatService(llm, db) + r := gin.New() + r.POST("/chat", cs.HandleChat) + + w := httptest.NewRecorder() + body := map[string]string{"message": "error"} + jsonBody, _ := json.Marshal(body) + req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody)) + req.Header.Set("Content-Type", "application/json") + + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", w.Code) + } + var resp ChatResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Invalid response: %v", err) + } + if resp.Match != nil { + t.Errorf("Expected no match on LLM error, got %v", resp.Match) + } +} + +// --- Helper for request body --- +func httptestNewBody(b []byte) *httptestBody { + return &httptestBody{b: b} +} + +type httptestBody struct { + b []byte + pos int +} + +func (r *httptestBody) Read(p []byte) (int, error) { + if r.pos >= len(r.b) { + return 0, io.EOF + } + n := copy(p, r.b[r.pos:]) + r.pos += n + return n, nil +} +func (r *httptestBody) Close() error { return nil } diff --git a/db.go b/db.go index 4df460e..deab67e 100644 --- a/db.go +++ b/db.go @@ -22,7 +22,7 @@ func NewReasonDB() ReasonDB { db.init() return db } -func (rdb *ReasonDB) findById(reasonId string) (Reason, error) { +func (rdb *ReasonDB) FindById(reasonId string) (Reason, error) { for _, reason := range rdb.reasonsDB { if reason.ID == reasonId { return reason, nil @@ -73,8 +73,8 @@ func (rdb *ReasonDB) indexReasons(reasons []Reason) error { return rdb.reasonsIdx.Batch(batch) } -// findCandidates returns reasons with overlapping keywords -func (rdb *ReasonDB) findCandidates(keywords []string) ([]Reason, error) { +// FindCandidates returns reasons with overlapping keywords +func (rdb *ReasonDB) FindCandidates(keywords []string) ([]Reason, error) { query := bleve.NewMatchQuery(strings.Join(keywords, " ")) search := bleve.NewSearchRequest(query) searchResults, err := rdb.reasonsIdx.Search(search) @@ -104,3 +104,13 @@ func sumProcedures(procs []Procedure) (int, int) { } return totalPrice, totalDuration } + +// ReasonDBAPI allows mocking ReasonDB in other places +// Only public methods should be included + +type ReasonDBAPI interface { + FindById(reasonId string) (Reason, error) + FindCandidates(keywords []string) ([]Reason, error) +} + +var _ ReasonDBAPI = (*ReasonDB)(nil) diff --git a/llm.go b/llm.go index a653201..4836557 100644 --- a/llm.go +++ b/llm.go @@ -18,6 +18,14 @@ type LLMClient struct { BaseURL string } +// NewLLMClient constructs a new LLMClient with the given API key and base URL +func NewLLMClient(apiKey, baseURL string) *LLMClient { + return &LLMClient{ + APIKey: apiKey, + BaseURL: baseURL, + } +} + // renderPrompt renders a Go template with the given data func renderPrompt(tmplStr string, data any) (string, error) { tmpl, err := template.New("").Parse(tmplStr) @@ -133,3 +141,13 @@ func (llm *LLMClient) openAICompletion(ctx context.Context, prompt string, forma logrus.WithField("content", result.Message.Content).Info("[LLM] openAICompletion: got content") return result.Message.Content, nil } + +// LLMClientAPI allows mocking LLMClient in other places +// Only public methods should be included + +type LLMClientAPI interface { + ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error) + DisambiguateBestMatch(ctx context.Context, message string, candidates []Reason) (string, error) +} + +var _ LLMClientAPI = (*LLMClient)(nil) diff --git a/main.go b/main.go index 369f34a..0cc512c 100644 --- a/main.go +++ b/main.go @@ -20,11 +20,11 @@ func main() { if err := loadUITemplate("ui.html"); err != nil { logrus.Fatalf("Failed to load ui.html: %v", err) } - llm := &LLMClient{ - APIKey: os.Getenv("OPENAI_API_KEY"), - BaseURL: os.Getenv("OPENAI_BASE_URL"), - } - chatService := NewChatService(llm, reasonDB) + var llm LLMClientAPI = NewLLMClient( + os.Getenv("OPENAI_API_KEY"), + os.Getenv("OPENAI_BASE_URL"), + ) + chatService := NewChatService(llm, &reasonDB) r := gin.Default() r.GET("/", func(c *gin.Context) { c.Status(200) diff --git a/main_test.go b/main_test.go deleted file mode 100644 index cbe6193..0000000 --- a/main_test.go +++ /dev/null @@ -1,145 +0,0 @@ -package main - -import ( - "bytes" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "os" - "testing" - - "github.com/gin-gonic/gin" -) - -type testDB struct { - file string - data string -} - -func (tdb *testDB) setup() { - err := os.WriteFile(tdb.file, []byte(tdb.data), 0644) - if err != nil { - panic(err) - } -} - -func (tdb *testDB) teardown() { - _ = os.Remove(tdb.file) -} - -func TestChatEndpoint_MatchFound(t *testing.T) { - tdb := testDB{ - file: "db.yaml", - data: ` -- id: deworming - reason: Deworming for dogs - keywords: ["worms", "deworming", "parasite"] - procedures: - - name: Deworming tablet - price: 30 - duration_minutes: 10 - - name: Bloodwork - price: 35 - duration_minutes: 35 - notes: Bloodwork ensures organs are safe for treatment. -- id: vaccination - reason: Annual vaccination - keywords: ["vaccine", "vaccination", "shots"] - procedures: - - name: Vaccine injection - price: 50 - duration_minutes: 15 -`, - } - tdb.setup() - defer tdb.teardown() - - if err := loadYAMLDB(tdb.file); err != nil { - t.Fatalf("Failed to load test db: %v", err) - } - - r := setupRouter() - - w := httptest.NewRecorder() - body := map[string]string{"message": "My dog needs deworming and bloodwork"} - jsonBody, _ := json.Marshal(body) - req, _ := http.NewRequest("POST", "/chat", bytes.NewBuffer(jsonBody)) - req.Header.Set("Content-Type", "application/json") - - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("Expected 200, got %d", w.Code) - } - respBody, _ := io.ReadAll(w.Body) - if !bytes.Contains(respBody, []byte("deworming")) { - t.Errorf("Expected match for deworming, got %s", string(respBody)) - } -} - -func TestChatEndpoint_NoMatch(t *testing.T) { - tdb := testDB{ - file: "db.yaml", - data: ` -- id: vaccination - reason: Annual vaccination - keywords: ["vaccine", "vaccination", "shots"] - procedures: - - name: Vaccine injection - price: 50 - duration_minutes: 15 -`, - } - tdb.setup() - defer tdb.teardown() - - if err := loadYAMLDB(tdb.file); err != nil { - t.Fatalf("Failed to load test db: %v", err) - } - - r := setupRouter() - - w := httptest.NewRecorder() - body := map[string]string{"message": "My dog has worms"} - jsonBody, _ := json.Marshal(body) - req, _ := http.NewRequest("POST", "/chat", bytes.NewBuffer(jsonBody)) - req.Header.Set("Content-Type", "application/json") - - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("Expected 200, got %d", w.Code) - } - respBody, _ := io.ReadAll(w.Body) - if !bytes.Contains(respBody, []byte(`"match":null`)) { - t.Errorf("Expected no match, got %s", string(respBody)) - } -} - -func setupRouter() *gin.Engine { - r := gin.Default() - r.POST("/chat", func(c *gin.Context) { - var req ChatRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"}) - return - } - keywords := naiveKeywordExtract(req.Message) - candidates := findCandidates(keywords) - if len(candidates) == 0 { - c.JSON(http.StatusOK, ChatResponse{Match: nil}) - return - } - best := candidates[0] - totalPrice, totalDuration := sumProcedures(best.Procedures) - c.JSON(http.StatusOK, ChatResponse{ - Match: &best.ID, - Procedures: best.Procedures, - TotalPrice: totalPrice, - TotalDuration: totalDuration, - Notes: best.Notes, - }) - }) - return r -}