Files
go-receipt-tracker/workers/openai.go
2025-02-03 17:50:47 -05:00

161 lines
3.9 KiB
Go

package workers
import (
"context"
"log/slog"
"sync/atomic"
"git.savin.nyc/alex/go-receipt-tracker/bus"
"git.savin.nyc/alex/go-receipt-tracker/config"
"git.savin.nyc/alex/go-receipt-tracker/utils"
"github.com/sashabaranov/go-openai"
)
// OpenAi is a worker that reports its type as "parser" and uses an OpenAI
// chat-completion client to turn receipt images into text results.
type OpenAi struct {
	// id is the worker identifier returned by ID(); taken from the parser config.
	id string
	// model is the model name taken from the parser config.
	model string
	// opts is the parser configuration this worker was created from.
	opts *config.Parser
	// client is the OpenAI API client; it is created in Init.
	client *openai.Client
	// subscriptions maps a bus channel name to the event channel received
	// when subscribing to it.
	subscriptions map[string]chan bus.Event
	// bus is the message bus used for subscribing and publishing.
	bus *bus.Bus
	// log is the structured logger supplied to Init.
	log *slog.Logger
	// cancel stops the event loop started by Serve.
	cancel context.CancelFunc
	// end is a flag read atomically; Serve does nothing once it is 1.
	// It is intended to ensure the close methods are only called once.
	end uint32
}
// NewOpenAi builds an OpenAi worker from the parser configuration cfg and the
// bus b. The worker id and model are copied from cfg.Type and cfg.Model, and
// the subscription table starts out empty. The API client and logger are not
// set here; Init assigns them.
func NewOpenAi(cfg *config.Parser, b *bus.Bus) *OpenAi {
	return &OpenAi{
		id:            cfg.Type,  // e.g. "openai"
		model:         cfg.Model, // e.g. "gpt-4o-mini"
		opts:          cfg,
		bus:           b,
		subscriptions: map[string]chan bus.Event{},
	}
}
// ID returns the identifier this worker was configured with.
func (ai *OpenAi) ID() string {
	return ai.id
}
// Type returns the worker category, which is always "parser".
func (ai *OpenAi) Type() string {
	const kind = "parser"
	return kind
}
// Init creates the OpenAI client from the configured API key and stores the
// logger used by the other methods. It always returns nil.
func (ai *OpenAi) Init(log *slog.Logger) error {
	apiKey := ai.opts.ApiKey
	ai.client = openai.NewClient(apiKey)
	ai.log = log
	return nil
}
// Serve subscribes the worker to its "parser:<id>" channel and starts the
// event loop in a new goroutine. It does nothing when the end flag is set.
//
// If the subscription fails, the event loop is not started: there would be no
// channel to read from, and the goroutine would only block until cancelled.
// The failure itself is logged by subscribe.
func (ai *OpenAi) Serve() {
	if atomic.LoadUint32(&ai.end) == 1 {
		return
	}

	// Install the cancel function before anything can fail so that Close can
	// always call it, even when the subscription below does not succeed.
	ctx, cancel := context.WithCancel(context.Background())
	ai.cancel = cancel

	if err := ai.subscribe("parser:" + ai.ID()); err != nil {
		// Release the context right away; calling cancel again later is harmless.
		cancel()
		return
	}

	go ai.eventLoop(ctx)
}
// OneTime is a no-op for this worker and always returns nil.
func (ai *OpenAi) OneTime() error {
	return nil
}
// Stop is intentionally empty: this worker has nothing to do on stop.
func (ai *OpenAi) Stop() {
	// no-op
}
// Close cancels the context of the event loop started by Serve.
//
// It is safe to call more than once and safe to call when Serve was never
// called: the end flag is flipped atomically from 0 to 1 by the first call
// only, and a nil cancel function is skipped instead of being invoked.
// Because Serve checks the same flag, Serve does nothing after Close.
func (ai *OpenAi) Close() {
	if !atomic.CompareAndSwapUint32(&ai.end, 0, 1) {
		return // already closed
	}
	if ai.cancel != nil {
		ai.cancel()
	}
}
// subscribe registers the worker on bus channel chn, passing "<type>:<id>" as
// the second argument to bus.Subscribe, and stores the returned event channel
// in ai.subscriptions under chn. On failure the error is logged and returned,
// and the subscription table is left untouched.
func (ai *OpenAi) subscribe(chn string) error {
	subscriber := ai.Type() + ":" + ai.ID()

	events, err := ai.bus.Subscribe(chn, subscriber)
	if err == nil {
		ai.subscriptions[chn] = events
		return nil
	}

	ai.log.Error("couldn't subscribe to a channel", "channel", chn, "error", err.Error())
	return err
}
// eventLoop receives events from the worker's "parser:<id>" subscription until
// ctx is cancelled or the subscription channel is closed.
//
// For every event whose payload is a non-nil *bus.Message, the image is sent
// to recognize as a data URL ("data:<type>;base64,<data>"), the result is
// stored in msg.Image.Parsed under the worker id, and the message is published
// to "processor:receipt_add". Events with any other payload are ignored.
func (ai *OpenAi) eventLoop(ctx context.Context) {
	ai.log.Debug(ai.ID() + " communication event loop is started")
	defer ai.log.Debug(ai.ID() + " communication event loop halted")

	channel := "parser:" + ai.ID()
	events := ai.subscriptions[channel]

	for {
		ai.log.Debug("looping over the subscriptions", "channel", channel, "worker", ai.ID())

		select {
		case event, ok := <-events:
			if !ok {
				// A closed channel yields zero-value events without blocking;
				// leave instead of spinning on them.
				ai.log.Info("subscription channel closed, stopping " + ai.ID() + " communication event loop")
				return
			}
			ai.log.Debug("worker got a new message from a channel", "channel", channel, "type", ai.Type(), "id", ai.ID())

			msg, isMessage := event.Payload.(*bus.Message)
			if !isMessage || msg == nil {
				continue
			}

			res, err := ai.recognize("data:" + msg.Image.Type + ";base64," + msg.Image.Base64)
			if err != nil {
				ai.log.Error("got an error from parser ("+ai.ID()+")", "error", err)
			}
			// The (possibly empty) result is recorded and the message is
			// forwarded even after a parser error, as before.
			// NOTE(review): assumes msg.Image.Parsed was initialised by the
			// publisher — confirm; writing to a nil map would panic.
			msg.Image.Parsed[ai.ID()] = res

			if err := ai.bus.Publish("processor:receipt_add", msg); err != nil {
				ai.log.Error("couldn't publish to a channel", "channel", "processor:receipt_add", "error", err.Error())
			}

		case <-ctx.Done():
			ai.log.Info("stopping " + ai.ID() + " communication event loop")
			return
		}
	}
}
// openAiResponseError is a constant-style error describing a chat completion
// that cannot be used as a recognition result.
type openAiResponseError string

// Error implements the error interface.
func (e openAiResponseError) Error() string { return string(e) }

// recognize sends the image img (a URL or data URL) together with the prompt
// config.Request to the chat-completion API and returns the content of the
// first choice.
//
// The configured model is used; when none is configured the request falls back
// to openai.GPT4oMini. A non-nil error is returned when the API call fails,
// when the response contains no choices, or when the returned content is not
// valid JSON according to utils.IsJSON. The result string is empty whenever
// an error is returned.
func (ai *OpenAi) recognize(img string) (res string, err error) {
	model := ai.model
	if model == "" {
		model = openai.GPT4oMini
	}

	resp, err := ai.client.CreateChatCompletion(
		context.Background(),
		openai.ChatCompletionRequest{
			Model:     model,
			MaxTokens: 1000,
			Messages: []openai.ChatCompletionMessage{
				{
					Role: openai.ChatMessageRoleUser,
					MultiContent: []openai.ChatMessagePart{
						{
							Type: openai.ChatMessagePartTypeText,
							Text: config.Request,
						},
						{
							Type: openai.ChatMessagePartTypeImageURL,
							ImageURL: &openai.ChatMessageImageURL{
								URL:    img,
								Detail: openai.ImageURLDetailHigh,
							},
						},
					},
				},
			},
		},
	)
	if err != nil {
		// The response carries no usable choices on failure, so only the
		// error itself is logged; indexing resp.Choices here would panic.
		ai.log.Error("error during recognition process", "worker", ai.ID(), "error", err)
		return "", err
	}
	if len(resp.Choices) == 0 {
		ai.log.Error("recognition returned no choices", "worker", ai.ID())
		return "", openAiResponseError("openai: completion contains no choices")
	}

	content := resp.Choices[0].Message.Content
	if !utils.IsJSON(content) {
		ai.log.Error("returned not valid JSON", "worker", ai.ID(), "output", content)
		// Report the failure explicitly; err is nil at this point.
		return "", openAiResponseError("openai: completion is not valid JSON")
	}

	ai.log.Debug("recognition result", "worker", ai.ID(), "output", content)
	return content, nil
}