refactor to add tests

This commit is contained in:
lehel 2025-09-27 11:20:32 +02:00
parent e60f0c0ed6
commit 3876ba502e
No known key found for this signature in database
GPG Key ID: 9C4F9D6111EE5CFA
6 changed files with 307 additions and 202 deletions

View File

@ -9,34 +9,43 @@ import (
"github.com/sirupsen/logrus"
)
type ChatService struct {
LLM *LLMClient
reasonsDB ReasonDB
// ChatServiceAPI allows mocking ChatService in other places
// Only public methods should be included
type ChatServiceAPI interface {
HandleChat(c *gin.Context)
}
func NewChatService(llm *LLMClient, db ReasonDB) *ChatService {
// ChatService handles chat interactions and orchestrates LLM and DB calls.
type ChatService struct {
LLM LLMClientAPI
reasonsDB ReasonDBAPI
}
var _ ChatServiceAPI = (*ChatService)(nil)
func NewChatService(llm LLMClientAPI, db ReasonDBAPI) ChatServiceAPI {
return &ChatService{LLM: llm, reasonsDB: db}
}
// HandleChat is the main entrypoint for chat requests. It delegates to modular helpers.
func (cs *ChatService) HandleChat(c *gin.Context) {
ctx := context.Background()
req, err := cs.parseRequest(c)
if err != nil {
return
}
keywordsResponse, err := cs.getKeywordsResponse(ctx, req.Message)
keywords, err := cs.extractKeywords(ctx, req.Message)
if err != nil {
cs.logChat(req, keywordsResponse, nil, "", err)
c.JSON(http.StatusOK, ChatResponse{Match: nil})
cs.respondWithError(c, req, keywords, err)
return
}
keywordList := cs.keywordsToStrings(keywordsResponse["keyword"])
best, err := cs.findVisitReason(ctx, req, keywordList)
best, err := cs.findBestReason(ctx, req, keywords)
resp := cs.buildResponse(best)
c.JSON(http.StatusOK, resp)
}
// parseRequest parses and validates the incoming chat request.
func (cs *ChatService) parseRequest(c *gin.Context) (ChatRequest, error) {
var req ChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
@ -47,56 +56,36 @@ func (cs *ChatService) parseRequest(c *gin.Context) (ChatRequest, error) {
return req, nil
}
func (cs *ChatService) getKeywordsResponse(ctx context.Context, message string) (map[string]interface{}, error) {
keywords, err := cs.LLM.ExtractKeywords(ctx, message)
return keywords, err
}
func (cs *ChatService) keywordsToStrings(kwIface interface{}) []string {
var kwArr []string
switch v := kwIface.(type) {
case []interface{}:
for _, item := range v {
if s, ok := item.(string); ok {
kwArr = append(kwArr, s)
}
}
case []string:
kwArr = v
// extractKeywords gets keywords from the LLM and normalizes them.
func (cs *ChatService) extractKeywords(ctx context.Context, message string) ([]string, error) {
kwResp, err := cs.LLM.ExtractKeywords(ctx, message)
if err != nil {
return nil, err
}
return kwArr
return cs.keywordsToStrings(kwResp["keyword"]), nil
}
func (cs *ChatService) findVisitReason(ctx context.Context, req ChatRequest, keywordList []string) (*Reason, error) {
logrus.WithFields(logrus.Fields{
"keywords": keywordList,
"message": req.Message,
}).Info("Finding visit reason candidates")
candidateReasonse, err := cs.reasonsDB.findCandidates(keywordList)
logrus.WithFields(logrus.Fields{
"candidates": candidateReasonse,
"error": err,
}).Info("Candidate reasons found")
// findBestReason finds candidate reasons and disambiguates the best match.
func (cs *ChatService) findBestReason(ctx context.Context, req ChatRequest, keywords []string) (*Reason, error) {
cs.logKeywords(keywords, req.Message)
candidates, err := cs.reasonsDB.FindCandidates(keywords)
cs.logCandidates(candidates, err)
if err != nil {
return nil, err
}
bestID := ""
if len(candidateReasonse) > 0 {
bestID, err = cs.LLM.DisambiguateBestMatch(ctx, req.Message, candidateReasonse)
logrus.WithFields(logrus.Fields{
"bestID": bestID,
"error": err,
}).Info("Disambiguated best match")
if len(candidates) > 0 {
bestID, err = cs.LLM.DisambiguateBestMatch(ctx, req.Message, candidates)
cs.logBestID(bestID, err)
}
reason, err := cs.reasonsDB.findById(bestID)
reason, err := cs.reasonsDB.FindById(bestID)
if err != nil {
return nil, fmt.Errorf("findById: %w", err)
return nil, fmt.Errorf("FindById: %w", err)
}
return &reason, nil
}
// buildResponse constructs the ChatResponse from the best Reason.
func (cs *ChatService) buildResponse(best *Reason) ChatResponse {
if best == nil {
resp := ChatResponse{Match: nil}
@ -115,8 +104,65 @@ func (cs *ChatService) buildResponse(best *Reason) ChatResponse {
return resp
}
func (cs *ChatService) logChat(req ChatRequest, keywords map[string]interface{}, candidates []Reason, bestID string, err error) {
logRequest(req, keywords, candidates, bestID, err)
// respondWithError logs and responds with error details.
func (cs *ChatService) respondWithError(c *gin.Context, req ChatRequest, keywords []string, err error) {
kwMap := map[string]interface{}{"keyword": keywords}
cs.logChat(req, kwMap, nil, "", err)
c.JSON(http.StatusOK, ChatResponse{Match: nil})
}
// keywordsToStrings normalizes keyword interface to []string.
func (cs *ChatService) keywordsToStrings(kwIface interface{}) []string {
var kwArr []string
switch v := kwIface.(type) {
case []interface{}:
for _, item := range v {
if s, ok := item.(string); ok {
kwArr = append(kwArr, s)
}
}
case []string:
kwArr = v
}
return kwArr
}
// logKeywords logs extracted keywords.
func (cs *ChatService) logKeywords(keywords []string, message string) {
logrus.WithFields(logrus.Fields{
"keywords": keywords,
"message": message,
}).Info("Finding visit reason candidates")
}
// logCandidates logs candidate reasons.
func (cs *ChatService) logCandidates(candidates []Reason, err error) {
logrus.WithFields(logrus.Fields{
"candidates": candidates,
"error": err,
}).Info("Candidate reasons found")
}
// logBestID logs the best candidate ID.
func (cs *ChatService) logBestID(bestID string, err error) {
logrus.WithFields(logrus.Fields{
"bestID": bestID,
"error": err,
}).Info("Disambiguated best match")
}
// logChat logs the chat request and result details.
func (cs *ChatService) logChat(req ChatRequest, keywords interface{}, candidates []Reason, bestID string, err error) {
var kwMap map[string]interface{}
switch v := keywords.(type) {
case []string:
kwMap = map[string]interface{}{"keyword": v}
case map[string]interface{}:
kwMap = v
default:
kwMap = map[string]interface{}{"keyword": v}
}
logRequest(req, kwMap, candidates, bestID, err)
if candidates != nil && bestID != "" {
var best *Reason
for i := range candidates {

View File

@ -0,0 +1,176 @@
package main
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// --- Mocks ---
type mockLLM struct {
keywordsResp map[string]interface{}
disambigID string
keywordsErr error
disambigErr error
}
var _ LLMClientAPI = (*mockLLM)(nil)
func (m *mockLLM) ExtractKeywords(ctx context.Context, msg string) (map[string]interface{}, error) {
return m.keywordsResp, m.keywordsErr
}
func (m *mockLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Reason) (string, error) {
return m.disambigID, m.disambigErr
}
// --- Test ReasonDB ---
type testReasonDB struct {
candidates []Reason
findErr error
byID map[string]Reason
}
var _ ReasonDBAPI = (*testReasonDB)(nil)
func (db *testReasonDB) FindCandidates(keywords []string) ([]Reason, error) {
return db.candidates, db.findErr
}
func (db *testReasonDB) FindById(id string) (Reason, error) {
r, ok := db.byID[id]
if !ok {
return Reason{}, context.DeadlineExceeded
}
return r, nil
}
// --- Integration tests ---
func TestChatService_MatchFound(t *testing.T) {
gin.SetMode(gin.TestMode)
var llm LLMClientAPI = &mockLLM{
keywordsResp: map[string]interface{}{"keyword": []string{"worms", "deworming"}},
disambigID: "deworming",
}
reason := Reason{
ID: "deworming",
Procedures: []Procedure{{Name: "Deworming tablet", Price: 30, DurationMin: 10}},
Notes: "Bloodwork ensures organs are safe for treatment.",
}
var db ReasonDBAPI = &testReasonDB{
candidates: []Reason{reason},
byID: map[string]Reason{"deworming": reason},
}
var cs ChatServiceAPI = NewChatService(llm, db)
r := gin.New()
r.POST("/chat", cs.HandleChat)
w := httptest.NewRecorder()
body := map[string]string{"message": "My dog needs deworming"}
jsonBody, _ := json.Marshal(body)
req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected 200, got %d", w.Code)
}
var resp ChatResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("Invalid response: %v", err)
}
if resp.Match == nil || *resp.Match != "deworming" {
t.Errorf("Expected match 'deworming', got %v", resp.Match)
}
if len(resp.Procedures) != 1 || resp.Procedures[0].Name != "Deworming tablet" {
t.Errorf("Expected procedure 'Deworming tablet', got %+v", resp.Procedures)
}
if resp.Notes != reason.Notes {
t.Errorf("Expected notes '%s', got '%s'", reason.Notes, resp.Notes)
}
}
func TestChatService_NoMatch(t *testing.T) {
gin.SetMode(gin.TestMode)
llm := &mockLLM{
keywordsResp: map[string]interface{}{"keyword": []string{"unknown"}},
disambigID: "",
}
db := &testReasonDB{
candidates: []Reason{},
byID: map[string]Reason{},
}
cs := NewChatService(llm, db)
r := gin.New()
r.POST("/chat", cs.HandleChat)
w := httptest.NewRecorder()
body := map[string]string{"message": "No known reason"}
jsonBody, _ := json.Marshal(body)
req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected 200, got %d", w.Code)
}
var resp ChatResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("Invalid response: %v", err)
}
if resp.Match != nil {
t.Errorf("Expected no match, got %v", resp.Match)
}
}
func TestChatService_LLMError(t *testing.T) {
gin.SetMode(gin.TestMode)
llm := &mockLLM{
keywordsErr: context.DeadlineExceeded,
}
db := &testReasonDB{}
cs := NewChatService(llm, db)
r := gin.New()
r.POST("/chat", cs.HandleChat)
w := httptest.NewRecorder()
body := map[string]string{"message": "error"}
jsonBody, _ := json.Marshal(body)
req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected 200, got %d", w.Code)
}
var resp ChatResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("Invalid response: %v", err)
}
if resp.Match != nil {
t.Errorf("Expected no match on LLM error, got %v", resp.Match)
}
}
// --- Helper for request body ---
func httptestNewBody(b []byte) *httptestBody {
return &httptestBody{b: b}
}
type httptestBody struct {
b []byte
pos int
}
func (r *httptestBody) Read(p []byte) (int, error) {
if r.pos >= len(r.b) {
return 0, io.EOF
}
n := copy(p, r.b[r.pos:])
r.pos += n
return n, nil
}
func (r *httptestBody) Close() error { return nil }

16
db.go
View File

@ -22,7 +22,7 @@ func NewReasonDB() ReasonDB {
db.init()
return db
}
func (rdb *ReasonDB) findById(reasonId string) (Reason, error) {
func (rdb *ReasonDB) FindById(reasonId string) (Reason, error) {
for _, reason := range rdb.reasonsDB {
if reason.ID == reasonId {
return reason, nil
@ -73,8 +73,8 @@ func (rdb *ReasonDB) indexReasons(reasons []Reason) error {
return rdb.reasonsIdx.Batch(batch)
}
// findCandidates returns reasons with overlapping keywords
func (rdb *ReasonDB) findCandidates(keywords []string) ([]Reason, error) {
// FindCandidates returns reasons with overlapping keywords
func (rdb *ReasonDB) FindCandidates(keywords []string) ([]Reason, error) {
query := bleve.NewMatchQuery(strings.Join(keywords, " "))
search := bleve.NewSearchRequest(query)
searchResults, err := rdb.reasonsIdx.Search(search)
@ -104,3 +104,13 @@ func sumProcedures(procs []Procedure) (int, int) {
}
return totalPrice, totalDuration
}
// ReasonDBAPI allows mocking ReasonDB in other places
// Only public methods should be included
type ReasonDBAPI interface {
FindById(reasonId string) (Reason, error)
FindCandidates(keywords []string) ([]Reason, error)
}
var _ ReasonDBAPI = (*ReasonDB)(nil)

18
llm.go
View File

@ -18,6 +18,14 @@ type LLMClient struct {
BaseURL string
}
// NewLLMClient constructs a new LLMClient with the given API key and base URL
func NewLLMClient(apiKey, baseURL string) *LLMClient {
return &LLMClient{
APIKey: apiKey,
BaseURL: baseURL,
}
}
// renderPrompt renders a Go template with the given data
func renderPrompt(tmplStr string, data any) (string, error) {
tmpl, err := template.New("").Parse(tmplStr)
@ -133,3 +141,13 @@ func (llm *LLMClient) openAICompletion(ctx context.Context, prompt string, forma
logrus.WithField("content", result.Message.Content).Info("[LLM] openAICompletion: got content")
return result.Message.Content, nil
}
// LLMClientAPI allows mocking LLMClient in other places
// Only public methods should be included
type LLMClientAPI interface {
ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error)
DisambiguateBestMatch(ctx context.Context, message string, candidates []Reason) (string, error)
}
var _ LLMClientAPI = (*LLMClient)(nil)

10
main.go
View File

@ -20,11 +20,11 @@ func main() {
if err := loadUITemplate("ui.html"); err != nil {
logrus.Fatalf("Failed to load ui.html: %v", err)
}
llm := &LLMClient{
APIKey: os.Getenv("OPENAI_API_KEY"),
BaseURL: os.Getenv("OPENAI_BASE_URL"),
}
chatService := NewChatService(llm, reasonDB)
var llm LLMClientAPI = NewLLMClient(
os.Getenv("OPENAI_API_KEY"),
os.Getenv("OPENAI_BASE_URL"),
)
chatService := NewChatService(llm, &reasonDB)
r := gin.Default()
r.GET("/", func(c *gin.Context) {
c.Status(200)

View File

@ -1,145 +0,0 @@
package main
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/gin-gonic/gin"
)
type testDB struct {
file string
data string
}
func (tdb *testDB) setup() {
err := os.WriteFile(tdb.file, []byte(tdb.data), 0644)
if err != nil {
panic(err)
}
}
func (tdb *testDB) teardown() {
_ = os.Remove(tdb.file)
}
func TestChatEndpoint_MatchFound(t *testing.T) {
tdb := testDB{
file: "db.yaml",
data: `
- id: deworming
reason: Deworming for dogs
keywords: ["worms", "deworming", "parasite"]
procedures:
- name: Deworming tablet
price: 30
duration_minutes: 10
- name: Bloodwork
price: 35
duration_minutes: 35
notes: Bloodwork ensures organs are safe for treatment.
- id: vaccination
reason: Annual vaccination
keywords: ["vaccine", "vaccination", "shots"]
procedures:
- name: Vaccine injection
price: 50
duration_minutes: 15
`,
}
tdb.setup()
defer tdb.teardown()
if err := loadYAMLDB(tdb.file); err != nil {
t.Fatalf("Failed to load test db: %v", err)
}
r := setupRouter()
w := httptest.NewRecorder()
body := map[string]string{"message": "My dog needs deworming and bloodwork"}
jsonBody, _ := json.Marshal(body)
req, _ := http.NewRequest("POST", "/chat", bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected 200, got %d", w.Code)
}
respBody, _ := io.ReadAll(w.Body)
if !bytes.Contains(respBody, []byte("deworming")) {
t.Errorf("Expected match for deworming, got %s", string(respBody))
}
}
func TestChatEndpoint_NoMatch(t *testing.T) {
tdb := testDB{
file: "db.yaml",
data: `
- id: vaccination
reason: Annual vaccination
keywords: ["vaccine", "vaccination", "shots"]
procedures:
- name: Vaccine injection
price: 50
duration_minutes: 15
`,
}
tdb.setup()
defer tdb.teardown()
if err := loadYAMLDB(tdb.file); err != nil {
t.Fatalf("Failed to load test db: %v", err)
}
r := setupRouter()
w := httptest.NewRecorder()
body := map[string]string{"message": "My dog has worms"}
jsonBody, _ := json.Marshal(body)
req, _ := http.NewRequest("POST", "/chat", bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected 200, got %d", w.Code)
}
respBody, _ := io.ReadAll(w.Body)
if !bytes.Contains(respBody, []byte(`"match":null`)) {
t.Errorf("Expected no match, got %s", string(respBody))
}
}
func setupRouter() *gin.Engine {
r := gin.Default()
r.POST("/chat", func(c *gin.Context) {
var req ChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
return
}
keywords := naiveKeywordExtract(req.Message)
candidates := findCandidates(keywords)
if len(candidates) == 0 {
c.JSON(http.StatusOK, ChatResponse{Match: nil})
return
}
best := candidates[0]
totalPrice, totalDuration := sumProcedures(best.Procedures)
c.JSON(http.StatusOK, ChatResponse{
Match: &best.ID,
Procedures: best.Procedures,
TotalPrice: totalPrice,
TotalDuration: totalDuration,
Notes: best.Notes,
})
})
return r
}