refactor break up a bit
This commit is contained in:
parent
3233692b66
commit
86fe25dbee
|
|
@ -0,0 +1,122 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type ChatService struct {
|
||||
LLM *LLMClient
|
||||
}
|
||||
|
||||
func NewChatService(llm *LLMClient) *ChatService {
|
||||
return &ChatService{LLM: llm}
|
||||
}
|
||||
|
||||
func (cs *ChatService) HandleChat(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
req, err := cs.parseRequest(c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
keywords, err := cs.extractKeywords(ctx, req.Message)
|
||||
if err != nil {
|
||||
cs.logChat(req, keywords, nil, "", err)
|
||||
c.JSON(http.StatusOK, ChatResponse{Match: nil})
|
||||
return
|
||||
}
|
||||
kwArr := cs.keywordsToStrings(keywords["keyword"])
|
||||
best, bestID, candidates, err := cs.findBestCandidate(ctx, req, kwArr)
|
||||
cs.logChat(req, keywords, candidates, bestID, err)
|
||||
resp := cs.buildResponse(best)
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (cs *ChatService) parseRequest(c *gin.Context) (ChatRequest, error) {
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logrus.WithError(err).Error("Invalid request")
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
||||
return req, err
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (cs *ChatService) extractKeywords(ctx context.Context, message string) (map[string]interface{}, error) {
|
||||
keywords, err := cs.LLM.ExtractKeywords(ctx, message)
|
||||
return keywords, err
|
||||
}
|
||||
|
||||
func (cs *ChatService) keywordsToStrings(kwIface interface{}) []string {
|
||||
var kwArr []string
|
||||
switch v := kwIface.(type) {
|
||||
case []interface{}:
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok {
|
||||
kwArr = append(kwArr, s)
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
kwArr = v
|
||||
}
|
||||
return kwArr
|
||||
}
|
||||
|
||||
func (cs *ChatService) findBestCandidate(ctx context.Context, req ChatRequest, kwArr []string) (*Reason, string, []Reason, error) {
|
||||
candidates := findCandidates(kwArr)
|
||||
bestID := ""
|
||||
var err error
|
||||
if len(candidates) > 0 {
|
||||
bestID, err = cs.LLM.DisambiguateBestMatch(ctx, req.Message, candidates)
|
||||
}
|
||||
var best *Reason
|
||||
for i := range candidates {
|
||||
if candidates[i].ID == bestID {
|
||||
best = &candidates[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil || len(kwArr) == 0 || len(candidates) == 0 || bestID == "" || best == nil {
|
||||
return nil, bestID, candidates, err
|
||||
}
|
||||
return best, bestID, candidates, nil
|
||||
}
|
||||
|
||||
func (cs *ChatService) buildResponse(best *Reason) ChatResponse {
|
||||
if best == nil {
|
||||
return ChatResponse{Match: nil}
|
||||
}
|
||||
totalPrice, totalDuration := sumProcedures(best.Procedures)
|
||||
return ChatResponse{
|
||||
Match: &best.ID,
|
||||
Procedures: best.Procedures,
|
||||
TotalPrice: totalPrice,
|
||||
TotalDuration: totalDuration,
|
||||
Notes: best.Notes,
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *ChatService) logChat(req ChatRequest, keywords map[string]interface{}, candidates []Reason, bestID string, err error) {
|
||||
logRequest(req, keywords, candidates, bestID, err)
|
||||
if candidates != nil && bestID != "" {
|
||||
var best *Reason
|
||||
for i := range candidates {
|
||||
if candidates[i].ID == bestID {
|
||||
best = &candidates[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if best != nil {
|
||||
totalPrice, totalDuration := sumProcedures(best.Procedures)
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"match": best.ID,
|
||||
"total_price": totalPrice,
|
||||
"total_duration": totalDuration,
|
||||
"notes": best.Notes,
|
||||
}).Info("Responding with match")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var reasonsDB []Reason
|
||||
|
||||
func loadYAMLDB(path string) error {
|
||||
data, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return yaml.Unmarshal(data, &reasonsDB)
|
||||
}
|
||||
|
||||
// findCandidates returns reasons with overlapping keywords
|
||||
func findCandidates(keywords []string) []Reason {
|
||||
kwSet := make(map[string]struct{})
|
||||
for _, k := range keywords {
|
||||
kwSet[k] = struct{}{}
|
||||
}
|
||||
var candidates []Reason
|
||||
for _, r := range reasonsDB {
|
||||
for _, k := range r.Keywords {
|
||||
if _, ok := kwSet[strings.ToLower(k)]; ok {
|
||||
candidates = append(candidates, r)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
// sumProcedures calculates total price and duration
|
||||
func sumProcedures(procs []Procedure) (int, int) {
|
||||
totalPrice := 0
|
||||
totalDuration := 0
|
||||
for _, p := range procs {
|
||||
totalPrice += p.Price
|
||||
totalDuration += p.DurationMin
|
||||
}
|
||||
return totalPrice, totalDuration
|
||||
}
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// logRequest logs incoming chat requests and extracted info
|
||||
func logRequest(req ChatRequest, keywords map[string]interface{}, candidates []Reason, bestID string, err error) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"message": req.Message,
|
||||
"keywords": keywords,
|
||||
"candidates": getCandidateIDs(candidates),
|
||||
"bestID": bestID,
|
||||
"err": err,
|
||||
}).Info("Chat request trace")
|
||||
}
|
||||
|
||||
func getCandidateIDs(candidates []Reason) []string {
|
||||
ids := make([]string, len(candidates))
|
||||
for i, c := range candidates {
|
||||
ids[i] = c.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// Procedure represents a single procedure for a visit reason
|
||||
type Procedure struct {
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Price int `yaml:"price" json:"price"`
|
||||
DurationMin int `yaml:"duration_minutes" json:"duration_minutes"`
|
||||
}
|
||||
|
||||
// Reason represents a visit reason entry
|
||||
type Reason struct {
|
||||
ID string `yaml:"id" json:"id"`
|
||||
Reason string `yaml:"reason" json:"reason"`
|
||||
Keywords []string `yaml:"keywords" json:"keywords"`
|
||||
Procedures []Procedure `yaml:"procedures" json:"procedures"`
|
||||
Notes string `yaml:"notes" json:"notes,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest represents the incoming chat message
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ChatResponse represents the response to the frontend
|
||||
type ChatResponse struct {
|
||||
Match *string `json:"match"`
|
||||
Procedures []Procedure `json:"procedures,omitempty"`
|
||||
TotalPrice int `json:"total_price,omitempty"`
|
||||
TotalDuration int `json:"total_duration,omitempty"`
|
||||
Notes string `json:"notes,omitempty"`
|
||||
}
|
||||
180
main.go
180
main.go
|
|
@ -1,115 +1,13 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Procedure represents a single procedure for a visit reason
|
||||
type Procedure struct {
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Price int `yaml:"price" json:"price"`
|
||||
DurationMin int `yaml:"duration_minutes" json:"duration_minutes"`
|
||||
}
|
||||
|
||||
// Reason represents a visit reason entry
|
||||
type Reason struct {
|
||||
ID string `yaml:"id" json:"id"`
|
||||
Reason string `yaml:"reason" json:"reason"`
|
||||
Keywords []string `yaml:"keywords" json:"keywords"`
|
||||
Procedures []Procedure `yaml:"procedures" json:"procedures"`
|
||||
Notes string `yaml:"notes" json:"notes,omitempty"`
|
||||
}
|
||||
|
||||
var reasonsDB []Reason
|
||||
|
||||
func loadYAMLDB(path string) error {
|
||||
data, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return yaml.Unmarshal(data, &reasonsDB)
|
||||
}
|
||||
|
||||
// ChatRequest represents the incoming chat message
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ChatResponse represents the response to the frontend
|
||||
type ChatResponse struct {
|
||||
Match *string `json:"match"`
|
||||
Procedures []Procedure `json:"procedures,omitempty"`
|
||||
TotalPrice int `json:"total_price,omitempty"`
|
||||
TotalDuration int `json:"total_duration,omitempty"`
|
||||
Notes string `json:"notes,omitempty"`
|
||||
}
|
||||
|
||||
// naiveKeywordExtract splits message into lowercase words (placeholder for LLM)
|
||||
func naiveKeywordExtract(msg string) []string {
|
||||
// TODO: Replace with LLM call
|
||||
words := make(map[string]struct{})
|
||||
for _, w := range strings.FieldsFunc(strings.ToLower(msg), func(r rune) bool {
|
||||
return r < 'a' || r > 'z' && r < 'á' || r > 'ű'
|
||||
}) {
|
||||
words[w] = struct{}{}
|
||||
}
|
||||
res := make([]string, 0, len(words))
|
||||
for w := range words {
|
||||
res = append(res, w)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// findCandidates returns reasons with overlapping keywords
|
||||
func findCandidates(keywords []string) []Reason {
|
||||
kwSet := make(map[string]struct{})
|
||||
for _, k := range keywords {
|
||||
kwSet[k] = struct{}{}
|
||||
}
|
||||
var candidates []Reason
|
||||
for _, r := range reasonsDB {
|
||||
for _, k := range r.Keywords {
|
||||
if _, ok := kwSet[strings.ToLower(k)]; ok {
|
||||
candidates = append(candidates, r)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
// sumProcedures calculates total price and duration
|
||||
func sumProcedures(procs []Procedure) (int, int) {
|
||||
totalPrice := 0
|
||||
totalDuration := 0
|
||||
for _, p := range procs {
|
||||
totalPrice += p.Price
|
||||
totalDuration += p.DurationMin
|
||||
}
|
||||
return totalPrice, totalDuration
|
||||
}
|
||||
|
||||
var uiTemplate *template.Template
|
||||
|
||||
func loadUITemplate(path string) error {
|
||||
tmpl, err := template.ParseFiles(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
uiTemplate = tmpl
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
logrus.SetFormatter(&logrus.TextFormatter{FullTimestamp: true})
|
||||
logrus.SetLevel(logrus.InfoLevel)
|
||||
|
|
@ -129,87 +27,13 @@ func main() {
|
|||
APIKey: os.Getenv("OPENAI_API_KEY"),
|
||||
BaseURL: os.Getenv("OPENAI_BASE_URL"),
|
||||
}
|
||||
chatService := NewChatService(llm)
|
||||
r := gin.Default()
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Status(200)
|
||||
uiTemplate.Execute(c.Writer, nil)
|
||||
})
|
||||
r.POST("/chat", func(c *gin.Context) {
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logrus.WithError(err).Error("Invalid request")
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
keywords, err := llm.ExtractKeywords(ctx, req.Message)
|
||||
kwIface := keywords["keyword"]
|
||||
var kwArr []string
|
||||
switch v := kwIface.(type) {
|
||||
case []interface{}:
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok {
|
||||
kwArr = append(kwArr, s)
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
kwArr = v
|
||||
}
|
||||
candidates := findCandidates(kwArr)
|
||||
bestID := ""
|
||||
if len(candidates) > 0 && err == nil {
|
||||
bestID, err = llm.DisambiguateBestMatch(ctx, req.Message, candidates)
|
||||
}
|
||||
logRequest(req, keywords, candidates, bestID, err)
|
||||
if err != nil || len(keywords) == 0 || len(candidates) == 0 || bestID == "" {
|
||||
c.JSON(http.StatusOK, ChatResponse{Match: nil})
|
||||
return
|
||||
}
|
||||
var best *Reason
|
||||
for i := range candidates {
|
||||
if candidates[i].ID == bestID {
|
||||
best = &candidates[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if best == nil {
|
||||
c.JSON(http.StatusOK, ChatResponse{Match: nil})
|
||||
return
|
||||
}
|
||||
totalPrice, totalDuration := sumProcedures(best.Procedures)
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"match": best.ID,
|
||||
"total_price": totalPrice,
|
||||
"total_duration": totalDuration,
|
||||
"notes": best.Notes,
|
||||
}).Info("Responding with match")
|
||||
c.JSON(http.StatusOK, ChatResponse{
|
||||
Match: &best.ID,
|
||||
Procedures: best.Procedures,
|
||||
TotalPrice: totalPrice,
|
||||
TotalDuration: totalDuration,
|
||||
Notes: best.Notes,
|
||||
})
|
||||
})
|
||||
r.POST("/chat", chatService.HandleChat)
|
||||
|
||||
r.Run(":8080")
|
||||
}
|
||||
|
||||
// logRequest logs incoming chat requests and extracted info
|
||||
func logRequest(req ChatRequest, keywords map[string]interface{}, candidates []Reason, bestID string, err error) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"message": req.Message,
|
||||
"keywords": keywords,
|
||||
"candidates": getCandidateIDs(candidates),
|
||||
"bestID": bestID,
|
||||
"err": err,
|
||||
}).Info("Chat request trace")
|
||||
}
|
||||
|
||||
func getCandidateIDs(candidates []Reason) []string {
|
||||
ids := make([]string, len(candidates))
|
||||
for i, c := range candidates {
|
||||
ids[i] = c.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue