147 lines
3.4 KiB
Go
147 lines
3.4 KiB
Go
package parser
|
|
|
|
import (
	"context"
	"fmt"
	"log/slog"
	"sync/atomic"

	"git.savin.nyc/alex/go-receipt-tracker/bus"
	"git.savin.nyc/alex/go-receipt-tracker/config"
	"git.savin.nyc/alex/go-receipt-tracker/utils"
	"github.com/sashabaranov/go-openai"
)
|
|
|
|
type OpenAi struct {
|
|
id string
|
|
model string
|
|
opts *config.Parser
|
|
client *openai.Client
|
|
subscriptions map[string]chan bus.Event //
|
|
bus *bus.Bus
|
|
log *slog.Logger
|
|
cancel context.CancelFunc //
|
|
end uint32 // ensure the close methods are only called once
|
|
}
|
|
|
|
func NewOpenAi(cfg *config.Parser, bus *bus.Bus) *OpenAi {
|
|
|
|
ai := &OpenAi{
|
|
id: cfg.Type, // "openai",
|
|
model: cfg.Model, // "gpt-4o-mini"
|
|
opts: cfg,
|
|
bus: bus,
|
|
}
|
|
return ai
|
|
}
|
|
|
|
func (ai *OpenAi) ID() string {
|
|
return ai.id
|
|
}
|
|
|
|
func (ai *OpenAi) Type() string {
|
|
return "parser"
|
|
}
|
|
|
|
// Init finishes setting up the parser: it creates the OpenAI API client from
// the configured API key and stores the logger used by every other method.
//
// Every other method logs through ai.log unconditionally, so a nil logger is
// replaced with slog.Default() instead of causing a nil-pointer panic later.
// Init currently always returns nil.
func (ai *OpenAi) Init(log *slog.Logger) error {
	if log == nil {
		log = slog.Default()
	}

	ai.client = openai.NewClient(ai.opts.ApiKey)
	ai.log = log

	return nil
}
|
|
|
|
// Serve subscribes the parser to its own channel ("parser:<id>") and to the
// shared channel ("parser:*"), then runs eventLoop in a background goroutine.
// It does nothing when ai.end is 1. Subscription failures are logged by
// subscribe and otherwise ignored. The loop's context can be cancelled
// through ai.cancel.
func (ai *OpenAi) Serve() {
	if stopped := atomic.LoadUint32(&ai.end) == 1; stopped {
		return
	}

	for _, chn := range []string{"parser:" + ai.ID(), "parser:*"} {
		_ = ai.subscribe(chn) // the error is already logged inside subscribe
	}

	loopCtx, stop := context.WithCancel(context.Background())
	ai.cancel = stop
	go ai.eventLoop(loopCtx)
}
|
|
|
|
func (ai *OpenAi) subscribe(chn string) error {
|
|
s, err := ai.bus.Subscribe(chn, ai.Type()+":"+ai.ID())
|
|
if err != nil {
|
|
ai.log.Error("couldn't subscribe to a channel", "channel", chn, "error", err.Error())
|
|
return err
|
|
}
|
|
ai.subscriptions[chn] = s
|
|
|
|
return nil
|
|
}
|
|
|
|
// eventLoop receives events from every channel in ai.subscriptions and handles
// them one at a time until ctx is cancelled.
//
// Each subscription gets a forwarding goroutine that copies its events into a
// single merged channel. Blocking on the subscriptions one after another would
// starve every channel except the one currently being waited on, and reading a
// closed channel in a loop would spin forever on zero values. Events are still
// handled sequentially, so at most one recognition request is in flight.
func (ai *OpenAi) eventLoop(ctx context.Context) {
	ai.log.Debug(ai.ID() + " communication event loop started")
	defer ai.log.Debug(ai.ID() + " communication event loop halted")

	// delivery pairs an event with the name of the bus channel it came from.
	type delivery struct {
		channel string
		event   bus.Event
	}
	merged := make(chan delivery)

	for chn, ch := range ai.subscriptions {
		// chn and ch are passed as arguments so each goroutine owns its copy
		// regardless of the Go version's loop-variable semantics.
		go func(chn string, ch <-chan bus.Event) {
			for {
				select {
				case event, ok := <-ch:
					if !ok {
						return // subscription closed: stop forwarding
					}
					select {
					case merged <- delivery{channel: chn, event: event}:
					case <-ctx.Done():
						return
					}
				case <-ctx.Done():
					return
				}
			}
		}(chn, ch)
	}

	for {
		select {
		case d := <-merged:
			ai.handleEvent(d.channel, d.event)
		case <-ctx.Done():
			ai.log.Info("stopping " + ai.ID() + " communication event loop")
			return
		}
	}
}

// handleEvent processes a single bus event received on channel chn. Image
// payloads are sent to OpenAI and the recognized result is published to
// "telegram:publish"; other payload types are ignored. Nothing is published
// when recognition fails, so an empty result never reaches the telegram side.
func (ai *OpenAi) handleEvent(chn string, event bus.Event) {
	img, ok := event.Payload.(bus.Image)
	if !ok {
		return
	}
	ai.log.Debug("got a new message to a channel", "channel", chn)

	res, err := ai.recognize(img.Base64)
	if err != nil {
		ai.log.Error("got an error from parser ("+ai.ID()+")", "error", err)
		return
	}

	if err := ai.bus.Publish("telegram:publish", res); err != nil {
		ai.log.Error("couldn't publish to a channel", "channel", "telegram:publish", "error", err.Error())
	}
}
|
|
|
|
// recognize sends one image, together with the prompt from config.Request, to
// the OpenAI chat-completion API and returns the model's text reply.
//
// img is used verbatim as the image URL of the request. Presumably it is a data
// URL ("data:image/...;base64,...") taken from bus.Image.Base64 — TODO confirm
// with the producer of bus.Image.
//
// The model comes from the parser configuration (ai.model). It falls back to
// GPT-4o mini when none is configured. The reply must be valid JSON
// (utils.IsJSON), otherwise an error is returned. Whenever err is non-nil, res
// is "".
func (ai *OpenAi) recognize(img string) (res string, err error) {
	model := ai.model
	if model == "" {
		model = openai.GPT4oMini
	}

	resp, err := ai.client.CreateChatCompletion(
		context.Background(),
		openai.ChatCompletionRequest{
			Model:     model,
			MaxTokens: 1000,
			Messages: []openai.ChatCompletionMessage{
				{
					Role: openai.ChatMessageRoleUser,
					MultiContent: []openai.ChatMessagePart{
						{
							Type: openai.ChatMessagePartTypeText,
							Text: config.Request,
						},
						{
							Type: openai.ChatMessagePartTypeImageURL,
							ImageURL: &openai.ChatMessageImageURL{
								URL:    img,
								Detail: openai.ImageURLDetailHigh,
							},
						},
					},
				},
			},
		},
	)
	if err != nil {
		// resp is the zero value on error, so resp.Choices must not be
		// indexed here; the caller logs the wrapped error.
		return "", fmt.Errorf("openai chat completion: %w", err)
	}
	if len(resp.Choices) == 0 {
		return "", fmt.Errorf("openai chat completion returned no choices")
	}

	content := resp.Choices[0].Message.Content
	if !utils.IsJSON(content) {
		ai.log.Error("OpenAI returned not valid JSON", "content", content)
		// err is nil at this point (the API call itself succeeded), so an
		// explicit error is needed for callers to see the failure.
		return "", fmt.Errorf("openai returned a reply that is not valid JSON")
	}

	ai.log.Debug("recognition output", "content", content)
	return content, nil
}
|