refactor to add tests
This commit is contained in:
parent
e60f0c0ed6
commit
3876ba502e
146
chat_service.go
146
chat_service.go
|
|
@ -9,34 +9,43 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatService struct {
|
// ChatServiceAPI allows mocking ChatService in other places
|
||||||
LLM *LLMClient
|
// Only public methods should be included
|
||||||
reasonsDB ReasonDB
|
|
||||||
|
type ChatServiceAPI interface {
|
||||||
|
HandleChat(c *gin.Context)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatService(llm *LLMClient, db ReasonDB) *ChatService {
|
// ChatService handles chat interactions and orchestrates LLM and DB calls.
|
||||||
|
type ChatService struct {
|
||||||
|
LLM LLMClientAPI
|
||||||
|
reasonsDB ReasonDBAPI
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ ChatServiceAPI = (*ChatService)(nil)
|
||||||
|
|
||||||
|
func NewChatService(llm LLMClientAPI, db ReasonDBAPI) ChatServiceAPI {
|
||||||
return &ChatService{LLM: llm, reasonsDB: db}
|
return &ChatService{LLM: llm, reasonsDB: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HandleChat is the main entrypoint for chat requests. It delegates to modular helpers.
|
||||||
func (cs *ChatService) HandleChat(c *gin.Context) {
|
func (cs *ChatService) HandleChat(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
req, err := cs.parseRequest(c)
|
req, err := cs.parseRequest(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
keywordsResponse, err := cs.getKeywordsResponse(ctx, req.Message)
|
keywords, err := cs.extractKeywords(ctx, req.Message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cs.logChat(req, keywordsResponse, nil, "", err)
|
cs.respondWithError(c, req, keywords, err)
|
||||||
c.JSON(http.StatusOK, ChatResponse{Match: nil})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
keywordList := cs.keywordsToStrings(keywordsResponse["keyword"])
|
best, err := cs.findBestReason(ctx, req, keywords)
|
||||||
best, err := cs.findVisitReason(ctx, req, keywordList)
|
|
||||||
|
|
||||||
resp := cs.buildResponse(best)
|
resp := cs.buildResponse(best)
|
||||||
c.JSON(http.StatusOK, resp)
|
c.JSON(http.StatusOK, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseRequest parses and validates the incoming chat request.
|
||||||
func (cs *ChatService) parseRequest(c *gin.Context) (ChatRequest, error) {
|
func (cs *ChatService) parseRequest(c *gin.Context) (ChatRequest, error) {
|
||||||
var req ChatRequest
|
var req ChatRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
|
@ -47,56 +56,36 @@ func (cs *ChatService) parseRequest(c *gin.Context) (ChatRequest, error) {
|
||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *ChatService) getKeywordsResponse(ctx context.Context, message string) (map[string]interface{}, error) {
|
// extractKeywords gets keywords from the LLM and normalizes them.
|
||||||
keywords, err := cs.LLM.ExtractKeywords(ctx, message)
|
func (cs *ChatService) extractKeywords(ctx context.Context, message string) ([]string, error) {
|
||||||
return keywords, err
|
kwResp, err := cs.LLM.ExtractKeywords(ctx, message)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return cs.keywordsToStrings(kwResp["keyword"]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *ChatService) keywordsToStrings(kwIface interface{}) []string {
|
// findBestReason finds candidate reasons and disambiguates the best match.
|
||||||
var kwArr []string
|
func (cs *ChatService) findBestReason(ctx context.Context, req ChatRequest, keywords []string) (*Reason, error) {
|
||||||
switch v := kwIface.(type) {
|
cs.logKeywords(keywords, req.Message)
|
||||||
case []interface{}:
|
candidates, err := cs.reasonsDB.FindCandidates(keywords)
|
||||||
for _, item := range v {
|
cs.logCandidates(candidates, err)
|
||||||
if s, ok := item.(string); ok {
|
|
||||||
kwArr = append(kwArr, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case []string:
|
|
||||||
kwArr = v
|
|
||||||
}
|
|
||||||
return kwArr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *ChatService) findVisitReason(ctx context.Context, req ChatRequest, keywordList []string) (*Reason, error) {
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"keywords": keywordList,
|
|
||||||
"message": req.Message,
|
|
||||||
}).Info("Finding visit reason candidates")
|
|
||||||
|
|
||||||
candidateReasonse, err := cs.reasonsDB.findCandidates(keywordList)
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"candidates": candidateReasonse,
|
|
||||||
"error": err,
|
|
||||||
}).Info("Candidate reasons found")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
bestID := ""
|
bestID := ""
|
||||||
if len(candidateReasonse) > 0 {
|
if len(candidates) > 0 {
|
||||||
bestID, err = cs.LLM.DisambiguateBestMatch(ctx, req.Message, candidateReasonse)
|
bestID, err = cs.LLM.DisambiguateBestMatch(ctx, req.Message, candidates)
|
||||||
logrus.WithFields(logrus.Fields{
|
cs.logBestID(bestID, err)
|
||||||
"bestID": bestID,
|
|
||||||
"error": err,
|
|
||||||
}).Info("Disambiguated best match")
|
|
||||||
}
|
}
|
||||||
reason, err := cs.reasonsDB.findById(bestID)
|
reason, err := cs.reasonsDB.FindById(bestID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("findById: %w", err)
|
return nil, fmt.Errorf("FindById: %w", err)
|
||||||
}
|
}
|
||||||
return &reason, nil
|
return &reason, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buildResponse constructs the ChatResponse from the best Reason.
|
||||||
func (cs *ChatService) buildResponse(best *Reason) ChatResponse {
|
func (cs *ChatService) buildResponse(best *Reason) ChatResponse {
|
||||||
if best == nil {
|
if best == nil {
|
||||||
resp := ChatResponse{Match: nil}
|
resp := ChatResponse{Match: nil}
|
||||||
|
|
@ -115,8 +104,65 @@ func (cs *ChatService) buildResponse(best *Reason) ChatResponse {
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *ChatService) logChat(req ChatRequest, keywords map[string]interface{}, candidates []Reason, bestID string, err error) {
|
// respondWithError logs and responds with error details.
|
||||||
logRequest(req, keywords, candidates, bestID, err)
|
func (cs *ChatService) respondWithError(c *gin.Context, req ChatRequest, keywords []string, err error) {
|
||||||
|
kwMap := map[string]interface{}{"keyword": keywords}
|
||||||
|
cs.logChat(req, kwMap, nil, "", err)
|
||||||
|
c.JSON(http.StatusOK, ChatResponse{Match: nil})
|
||||||
|
}
|
||||||
|
|
||||||
|
// keywordsToStrings normalizes keyword interface to []string.
|
||||||
|
func (cs *ChatService) keywordsToStrings(kwIface interface{}) []string {
|
||||||
|
var kwArr []string
|
||||||
|
switch v := kwIface.(type) {
|
||||||
|
case []interface{}:
|
||||||
|
for _, item := range v {
|
||||||
|
if s, ok := item.(string); ok {
|
||||||
|
kwArr = append(kwArr, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case []string:
|
||||||
|
kwArr = v
|
||||||
|
}
|
||||||
|
return kwArr
|
||||||
|
}
|
||||||
|
|
||||||
|
// logKeywords logs extracted keywords.
|
||||||
|
func (cs *ChatService) logKeywords(keywords []string, message string) {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"keywords": keywords,
|
||||||
|
"message": message,
|
||||||
|
}).Info("Finding visit reason candidates")
|
||||||
|
}
|
||||||
|
|
||||||
|
// logCandidates logs candidate reasons.
|
||||||
|
func (cs *ChatService) logCandidates(candidates []Reason, err error) {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"candidates": candidates,
|
||||||
|
"error": err,
|
||||||
|
}).Info("Candidate reasons found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// logBestID logs the best candidate ID.
|
||||||
|
func (cs *ChatService) logBestID(bestID string, err error) {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"bestID": bestID,
|
||||||
|
"error": err,
|
||||||
|
}).Info("Disambiguated best match")
|
||||||
|
}
|
||||||
|
|
||||||
|
// logChat logs the chat request and result details.
|
||||||
|
func (cs *ChatService) logChat(req ChatRequest, keywords interface{}, candidates []Reason, bestID string, err error) {
|
||||||
|
var kwMap map[string]interface{}
|
||||||
|
switch v := keywords.(type) {
|
||||||
|
case []string:
|
||||||
|
kwMap = map[string]interface{}{"keyword": v}
|
||||||
|
case map[string]interface{}:
|
||||||
|
kwMap = v
|
||||||
|
default:
|
||||||
|
kwMap = map[string]interface{}{"keyword": v}
|
||||||
|
}
|
||||||
|
logRequest(req, kwMap, candidates, bestID, err)
|
||||||
if candidates != nil && bestID != "" {
|
if candidates != nil && bestID != "" {
|
||||||
var best *Reason
|
var best *Reason
|
||||||
for i := range candidates {
|
for i := range candidates {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,176 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Mocks ---
|
||||||
|
type mockLLM struct {
|
||||||
|
keywordsResp map[string]interface{}
|
||||||
|
disambigID string
|
||||||
|
keywordsErr error
|
||||||
|
disambigErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ LLMClientAPI = (*mockLLM)(nil)
|
||||||
|
|
||||||
|
func (m *mockLLM) ExtractKeywords(ctx context.Context, msg string) (map[string]interface{}, error) {
|
||||||
|
return m.keywordsResp, m.keywordsErr
|
||||||
|
}
|
||||||
|
func (m *mockLLM) DisambiguateBestMatch(ctx context.Context, msg string, candidates []Reason) (string, error) {
|
||||||
|
return m.disambigID, m.disambigErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Test ReasonDB ---
|
||||||
|
type testReasonDB struct {
|
||||||
|
candidates []Reason
|
||||||
|
findErr error
|
||||||
|
byID map[string]Reason
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ ReasonDBAPI = (*testReasonDB)(nil)
|
||||||
|
|
||||||
|
func (db *testReasonDB) FindCandidates(keywords []string) ([]Reason, error) {
|
||||||
|
return db.candidates, db.findErr
|
||||||
|
}
|
||||||
|
func (db *testReasonDB) FindById(id string) (Reason, error) {
|
||||||
|
r, ok := db.byID[id]
|
||||||
|
if !ok {
|
||||||
|
return Reason{}, context.DeadlineExceeded
|
||||||
|
}
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Integration tests ---
|
||||||
|
func TestChatService_MatchFound(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
var llm LLMClientAPI = &mockLLM{
|
||||||
|
keywordsResp: map[string]interface{}{"keyword": []string{"worms", "deworming"}},
|
||||||
|
disambigID: "deworming",
|
||||||
|
}
|
||||||
|
reason := Reason{
|
||||||
|
ID: "deworming",
|
||||||
|
Procedures: []Procedure{{Name: "Deworming tablet", Price: 30, DurationMin: 10}},
|
||||||
|
Notes: "Bloodwork ensures organs are safe for treatment.",
|
||||||
|
}
|
||||||
|
var db ReasonDBAPI = &testReasonDB{
|
||||||
|
candidates: []Reason{reason},
|
||||||
|
byID: map[string]Reason{"deworming": reason},
|
||||||
|
}
|
||||||
|
var cs ChatServiceAPI = NewChatService(llm, db)
|
||||||
|
r := gin.New()
|
||||||
|
r.POST("/chat", cs.HandleChat)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
body := map[string]string{"message": "My dog needs deworming"}
|
||||||
|
jsonBody, _ := json.Marshal(body)
|
||||||
|
req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("Expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
var resp ChatResponse
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("Invalid response: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Match == nil || *resp.Match != "deworming" {
|
||||||
|
t.Errorf("Expected match 'deworming', got %v", resp.Match)
|
||||||
|
}
|
||||||
|
if len(resp.Procedures) != 1 || resp.Procedures[0].Name != "Deworming tablet" {
|
||||||
|
t.Errorf("Expected procedure 'Deworming tablet', got %+v", resp.Procedures)
|
||||||
|
}
|
||||||
|
if resp.Notes != reason.Notes {
|
||||||
|
t.Errorf("Expected notes '%s', got '%s'", reason.Notes, resp.Notes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatService_NoMatch(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
llm := &mockLLM{
|
||||||
|
keywordsResp: map[string]interface{}{"keyword": []string{"unknown"}},
|
||||||
|
disambigID: "",
|
||||||
|
}
|
||||||
|
db := &testReasonDB{
|
||||||
|
candidates: []Reason{},
|
||||||
|
byID: map[string]Reason{},
|
||||||
|
}
|
||||||
|
cs := NewChatService(llm, db)
|
||||||
|
r := gin.New()
|
||||||
|
r.POST("/chat", cs.HandleChat)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
body := map[string]string{"message": "No known reason"}
|
||||||
|
jsonBody, _ := json.Marshal(body)
|
||||||
|
req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("Expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
var resp ChatResponse
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("Invalid response: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Match != nil {
|
||||||
|
t.Errorf("Expected no match, got %v", resp.Match)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatService_LLMError(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
llm := &mockLLM{
|
||||||
|
keywordsErr: context.DeadlineExceeded,
|
||||||
|
}
|
||||||
|
db := &testReasonDB{}
|
||||||
|
cs := NewChatService(llm, db)
|
||||||
|
r := gin.New()
|
||||||
|
r.POST("/chat", cs.HandleChat)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
body := map[string]string{"message": "error"}
|
||||||
|
jsonBody, _ := json.Marshal(body)
|
||||||
|
req, _ := http.NewRequest("POST", "/chat", httptestNewBody(jsonBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("Expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
var resp ChatResponse
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("Invalid response: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Match != nil {
|
||||||
|
t.Errorf("Expected no match on LLM error, got %v", resp.Match)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Helper for request body ---
|
||||||
|
func httptestNewBody(b []byte) *httptestBody {
|
||||||
|
return &httptestBody{b: b}
|
||||||
|
}
|
||||||
|
|
||||||
|
type httptestBody struct {
|
||||||
|
b []byte
|
||||||
|
pos int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *httptestBody) Read(p []byte) (int, error) {
|
||||||
|
if r.pos >= len(r.b) {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n := copy(p, r.b[r.pos:])
|
||||||
|
r.pos += n
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
func (r *httptestBody) Close() error { return nil }
|
||||||
16
db.go
16
db.go
|
|
@ -22,7 +22,7 @@ func NewReasonDB() ReasonDB {
|
||||||
db.init()
|
db.init()
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
func (rdb *ReasonDB) findById(reasonId string) (Reason, error) {
|
func (rdb *ReasonDB) FindById(reasonId string) (Reason, error) {
|
||||||
for _, reason := range rdb.reasonsDB {
|
for _, reason := range rdb.reasonsDB {
|
||||||
if reason.ID == reasonId {
|
if reason.ID == reasonId {
|
||||||
return reason, nil
|
return reason, nil
|
||||||
|
|
@ -73,8 +73,8 @@ func (rdb *ReasonDB) indexReasons(reasons []Reason) error {
|
||||||
return rdb.reasonsIdx.Batch(batch)
|
return rdb.reasonsIdx.Batch(batch)
|
||||||
}
|
}
|
||||||
|
|
||||||
// findCandidates returns reasons with overlapping keywords
|
// FindCandidates returns reasons with overlapping keywords
|
||||||
func (rdb *ReasonDB) findCandidates(keywords []string) ([]Reason, error) {
|
func (rdb *ReasonDB) FindCandidates(keywords []string) ([]Reason, error) {
|
||||||
query := bleve.NewMatchQuery(strings.Join(keywords, " "))
|
query := bleve.NewMatchQuery(strings.Join(keywords, " "))
|
||||||
search := bleve.NewSearchRequest(query)
|
search := bleve.NewSearchRequest(query)
|
||||||
searchResults, err := rdb.reasonsIdx.Search(search)
|
searchResults, err := rdb.reasonsIdx.Search(search)
|
||||||
|
|
@ -104,3 +104,13 @@ func sumProcedures(procs []Procedure) (int, int) {
|
||||||
}
|
}
|
||||||
return totalPrice, totalDuration
|
return totalPrice, totalDuration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReasonDBAPI allows mocking ReasonDB in other places
|
||||||
|
// Only public methods should be included
|
||||||
|
|
||||||
|
type ReasonDBAPI interface {
|
||||||
|
FindById(reasonId string) (Reason, error)
|
||||||
|
FindCandidates(keywords []string) ([]Reason, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ ReasonDBAPI = (*ReasonDB)(nil)
|
||||||
|
|
|
||||||
18
llm.go
18
llm.go
|
|
@ -18,6 +18,14 @@ type LLMClient struct {
|
||||||
BaseURL string
|
BaseURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewLLMClient constructs a new LLMClient with the given API key and base URL
|
||||||
|
func NewLLMClient(apiKey, baseURL string) *LLMClient {
|
||||||
|
return &LLMClient{
|
||||||
|
APIKey: apiKey,
|
||||||
|
BaseURL: baseURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// renderPrompt renders a Go template with the given data
|
// renderPrompt renders a Go template with the given data
|
||||||
func renderPrompt(tmplStr string, data any) (string, error) {
|
func renderPrompt(tmplStr string, data any) (string, error) {
|
||||||
tmpl, err := template.New("").Parse(tmplStr)
|
tmpl, err := template.New("").Parse(tmplStr)
|
||||||
|
|
@ -133,3 +141,13 @@ func (llm *LLMClient) openAICompletion(ctx context.Context, prompt string, forma
|
||||||
logrus.WithField("content", result.Message.Content).Info("[LLM] openAICompletion: got content")
|
logrus.WithField("content", result.Message.Content).Info("[LLM] openAICompletion: got content")
|
||||||
return result.Message.Content, nil
|
return result.Message.Content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LLMClientAPI allows mocking LLMClient in other places
|
||||||
|
// Only public methods should be included
|
||||||
|
|
||||||
|
type LLMClientAPI interface {
|
||||||
|
ExtractKeywords(ctx context.Context, message string) (map[string]interface{}, error)
|
||||||
|
DisambiguateBestMatch(ctx context.Context, message string, candidates []Reason) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ LLMClientAPI = (*LLMClient)(nil)
|
||||||
|
|
|
||||||
10
main.go
10
main.go
|
|
@ -20,11 +20,11 @@ func main() {
|
||||||
if err := loadUITemplate("ui.html"); err != nil {
|
if err := loadUITemplate("ui.html"); err != nil {
|
||||||
logrus.Fatalf("Failed to load ui.html: %v", err)
|
logrus.Fatalf("Failed to load ui.html: %v", err)
|
||||||
}
|
}
|
||||||
llm := &LLMClient{
|
var llm LLMClientAPI = NewLLMClient(
|
||||||
APIKey: os.Getenv("OPENAI_API_KEY"),
|
os.Getenv("OPENAI_API_KEY"),
|
||||||
BaseURL: os.Getenv("OPENAI_BASE_URL"),
|
os.Getenv("OPENAI_BASE_URL"),
|
||||||
}
|
)
|
||||||
chatService := NewChatService(llm, reasonDB)
|
chatService := NewChatService(llm, &reasonDB)
|
||||||
r := gin.Default()
|
r := gin.Default()
|
||||||
r.GET("/", func(c *gin.Context) {
|
r.GET("/", func(c *gin.Context) {
|
||||||
c.Status(200)
|
c.Status(200)
|
||||||
|
|
|
||||||
145
main_test.go
145
main_test.go
|
|
@ -1,145 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
type testDB struct {
|
|
||||||
file string
|
|
||||||
data string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tdb *testDB) setup() {
|
|
||||||
err := os.WriteFile(tdb.file, []byte(tdb.data), 0644)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tdb *testDB) teardown() {
|
|
||||||
_ = os.Remove(tdb.file)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChatEndpoint_MatchFound(t *testing.T) {
|
|
||||||
tdb := testDB{
|
|
||||||
file: "db.yaml",
|
|
||||||
data: `
|
|
||||||
- id: deworming
|
|
||||||
reason: Deworming for dogs
|
|
||||||
keywords: ["worms", "deworming", "parasite"]
|
|
||||||
procedures:
|
|
||||||
- name: Deworming tablet
|
|
||||||
price: 30
|
|
||||||
duration_minutes: 10
|
|
||||||
- name: Bloodwork
|
|
||||||
price: 35
|
|
||||||
duration_minutes: 35
|
|
||||||
notes: Bloodwork ensures organs are safe for treatment.
|
|
||||||
- id: vaccination
|
|
||||||
reason: Annual vaccination
|
|
||||||
keywords: ["vaccine", "vaccination", "shots"]
|
|
||||||
procedures:
|
|
||||||
- name: Vaccine injection
|
|
||||||
price: 50
|
|
||||||
duration_minutes: 15
|
|
||||||
`,
|
|
||||||
}
|
|
||||||
tdb.setup()
|
|
||||||
defer tdb.teardown()
|
|
||||||
|
|
||||||
if err := loadYAMLDB(tdb.file); err != nil {
|
|
||||||
t.Fatalf("Failed to load test db: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r := setupRouter()
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
body := map[string]string{"message": "My dog needs deworming and bloodwork"}
|
|
||||||
jsonBody, _ := json.Marshal(body)
|
|
||||||
req, _ := http.NewRequest("POST", "/chat", bytes.NewBuffer(jsonBody))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Fatalf("Expected 200, got %d", w.Code)
|
|
||||||
}
|
|
||||||
respBody, _ := io.ReadAll(w.Body)
|
|
||||||
if !bytes.Contains(respBody, []byte("deworming")) {
|
|
||||||
t.Errorf("Expected match for deworming, got %s", string(respBody))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChatEndpoint_NoMatch(t *testing.T) {
|
|
||||||
tdb := testDB{
|
|
||||||
file: "db.yaml",
|
|
||||||
data: `
|
|
||||||
- id: vaccination
|
|
||||||
reason: Annual vaccination
|
|
||||||
keywords: ["vaccine", "vaccination", "shots"]
|
|
||||||
procedures:
|
|
||||||
- name: Vaccine injection
|
|
||||||
price: 50
|
|
||||||
duration_minutes: 15
|
|
||||||
`,
|
|
||||||
}
|
|
||||||
tdb.setup()
|
|
||||||
defer tdb.teardown()
|
|
||||||
|
|
||||||
if err := loadYAMLDB(tdb.file); err != nil {
|
|
||||||
t.Fatalf("Failed to load test db: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r := setupRouter()
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
body := map[string]string{"message": "My dog has worms"}
|
|
||||||
jsonBody, _ := json.Marshal(body)
|
|
||||||
req, _ := http.NewRequest("POST", "/chat", bytes.NewBuffer(jsonBody))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Fatalf("Expected 200, got %d", w.Code)
|
|
||||||
}
|
|
||||||
respBody, _ := io.ReadAll(w.Body)
|
|
||||||
if !bytes.Contains(respBody, []byte(`"match":null`)) {
|
|
||||||
t.Errorf("Expected no match, got %s", string(respBody))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupRouter() *gin.Engine {
|
|
||||||
r := gin.Default()
|
|
||||||
r.POST("/chat", func(c *gin.Context) {
|
|
||||||
var req ChatRequest
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
keywords := naiveKeywordExtract(req.Message)
|
|
||||||
candidates := findCandidates(keywords)
|
|
||||||
if len(candidates) == 0 {
|
|
||||||
c.JSON(http.StatusOK, ChatResponse{Match: nil})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
best := candidates[0]
|
|
||||||
totalPrice, totalDuration := sumProcedures(best.Procedures)
|
|
||||||
c.JSON(http.StatusOK, ChatResponse{
|
|
||||||
Match: &best.ID,
|
|
||||||
Procedures: best.Procedures,
|
|
||||||
TotalPrice: totalPrice,
|
|
||||||
TotalDuration: totalDuration,
|
|
||||||
Notes: best.Notes,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
Loading…
Reference in New Issue