Files
go-rtsp-audio-detection-vad/main.go
2025-08-18 11:18:27 -04:00

385 lines
10 KiB
Go

package main
import (
"flag"
"fmt"
"io"
"log/slog"
"os"
"time"
"encoding/binary"
"github.com/bluenviron/gortsplib/v4"
"github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/mediacommon/v2/pkg/codecs/g711"
"github.com/maxhawkins/go-webrtcvad"
"github.com/pion/rtp"
)
// CircularBuffer is a fixed-size ring buffer of bytes used to hold audio data.
//
// A write that does not fit in the free space overwrites the oldest unread
// bytes, so the buffer always holds the most recent data (up to its size).
// The type performs no locking of its own.
type CircularBuffer struct {
	data   []byte // backing storage; len(data) == size
	size   int    // capacity in bytes
	head   int    // index of the oldest unread byte, always in [0, size)
	tail   int    // index the next written byte goes to, always in [0, size)
	isFull bool   // tells "full" apart from "empty" when head == tail
}

// NewCircularBuffer returns an empty buffer able to hold size bytes.
// size must not be negative (make panics on a negative length).
func NewCircularBuffer(size int) *CircularBuffer {
	return &CircularBuffer{
		data: make([]byte, size),
		size: size,
	}
}

// Write appends p to the buffer and returns len(p).
//
// If p is larger than the whole buffer nothing is written and an error is
// returned. If p is larger than the free space, the oldest unread bytes are
// dropped to make room.
func (b *CircularBuffer) Write(p []byte) (n int, err error) {
	if len(p) > b.size {
		return 0, fmt.Errorf("write data exceeds buffer size")
	}
	n = len(p)
	if n == 0 {
		return 0, nil
	}

	// Free space must be measured before tail moves.
	free := b.size - b.Len()

	// Copy in at most two segments: tail..end, then start..wrapped tail.
	first := copy(b.data[b.tail:], p)
	copy(b.data, p[first:])
	b.tail = (b.tail + n) % b.size

	// Only when the write consumed all free space is the buffer full; if it
	// needed more than that, the oldest bytes were overwritten and the read
	// position must follow the write position.
	if n >= free {
		b.head = b.tail
		b.isFull = true
	}
	return n, nil
}

// Read copies up to len(p) of the oldest unread bytes into p, removes them
// from the buffer, and returns how many were copied. It returns io.EOF when
// the buffer is empty.
func (b *CircularBuffer) Read(p []byte) (n int, err error) {
	available := b.Len()
	if available == 0 {
		return 0, io.EOF
	}
	n = len(p)
	if n > available {
		n = available
	}
	if n == 0 {
		// A zero-length read must not disturb the buffer state.
		return 0, nil
	}

	// Copy out in at most two segments: head..end, then start..wrapped head.
	first := copy(p[:n], b.data[b.head:])
	copy(p[first:n], b.data)
	b.head = (b.head + n) % b.size

	// At least one byte was consumed, so the buffer can no longer be full.
	b.isFull = false
	return n, nil
}

// Len returns the number of unread bytes currently held in the buffer.
func (b *CircularBuffer) Len() int {
	if b.isFull {
		return b.size
	}
	if b.tail >= b.head {
		return b.tail - b.head
	}
	return b.size - b.head + b.tail
}

// Reset discards all buffered data. The backing storage is kept.
func (b *CircularBuffer) Reset() {
	b.head = 0
	b.tail = 0
	b.isFull = false
}
// toLittleEndian swaps each pair of bytes in pcm in place, converting 16-bit
// samples from big-endian to little-endian (or back). If pcm has an odd
// length the trailing unpaired byte is left unchanged.
func toLittleEndian(pcm []byte) {
	// i+1 < len(pcm) guards the second index so an odd length cannot panic.
	for i := 0; i+1 < len(pcm); i += 2 {
		pcm[i], pcm[i+1] = pcm[i+1], pcm[i]
	}
}
// This example shows how to
// 1. connect to a RTSP server.
// 2. check if there's a G711 stream.
// 3. decode the G711 stream into audio samples.
// 4. detect audio and silence with a 2-5 second silence threshold.
// 5. buffer audio only during detected audio messages and save to a WAV file when silence is detected.
// 6. count the duration of continuous audio before silence is detected.
func main() {
// Command-line arguments
rtspURL := flag.String("rtsp", "", "RTSP URL (e.g., rtsp://localhost:8554/stream)")
vadMode := flag.Int("vad-mode", 3, "VAD sensitivity mode (0-3, 3 is most aggressive)")
frameMs := flag.Int("frame-ms", 20, "VAD frame duration in milliseconds (10, 20, or 30)")
logLevel := flag.String("log-level", "info", "Log level (debug, info, warn)")
saveBuffer := flag.Bool("save-buffer", false, "Save audio message to a WAV file when silence is detected")
flag.Parse()
if *rtspURL == "" {
slog.Error("RTSP URL is required")
flag.Usage()
os.Exit(1)
}
// Structured logging setup
var lvl slog.Level
switch *logLevel {
case "debug":
lvl = slog.LevelDebug
case "info":
lvl = slog.LevelInfo
case "warn":
lvl = slog.LevelWarn
default:
slog.Error("Invalid log level", "level", *logLevel)
os.Exit(1)
}
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: lvl}))
slog.SetDefault(logger)
// Parse URL
u, err := base.ParseURL(*rtspURL)
if err != nil {
slog.Error("Failed to parse RTSP URL", "error", err, "url", *rtspURL)
os.Exit(1)
}
c := gortsplib.Client{
Scheme: u.Scheme,
Host: u.Host,
}
// Connect to the server
err = c.Start2()
if err != nil {
slog.Error("Failed to connect to RTSP server", "error", err)
os.Exit(1)
}
defer c.Close()
// Find available medias
desc, _, err := c.Describe(u)
if err != nil {
slog.Error("Failed to describe RTSP stream", "error", err)
os.Exit(1)
}
// Find the G711 media and format
var forma *format.G711
medi := desc.FindFormat(&forma)
if medi == nil {
slog.Error("G711 media not found in RTSP stream")
os.Exit(1)
}
// Setup a single media
_, err = c.Setup(desc.BaseURL, medi, 0, 0)
if err != nil {
slog.Error("Failed to setup RTSP media", "error", err)
os.Exit(1)
}
// Initialize VAD
vad, err := webrtcvad.New()
if err != nil {
slog.Error("Failed to initialize VAD", "error", err)
os.Exit(1)
}
if err := vad.SetMode(*vadMode); err != nil {
slog.Error("Failed to set VAD mode", "mode", *vadMode, "error", err)
os.Exit(1)
}
slog.Info("Initialized VAD", "mode", *vadMode)
// Validate frame duration
const sampleRate = 8000
if *frameMs != 10 && *frameMs != 20 && *frameMs != 30 {
slog.Error("Invalid frame duration", "frame_ms", *frameMs, "allowed", "10, 20, or 30")
os.Exit(1)
}
frameSamples := sampleRate * *frameMs / 1000
frameBytes := frameSamples * 2 // 16-bit PCM
if ok := vad.ValidRateAndFrameLength(sampleRate, frameBytes); !ok {
slog.Error("Invalid rate or frame length for VAD", "sample_rate", sampleRate, "frame_bytes", frameBytes)
os.Exit(1)
}
slog.Debug("VAD parameters", "sample_rate", sampleRate, "frame_ms", *frameMs, "frame_bytes", frameBytes)
// Initialize audio processing
var pcmBuffer []byte
var isSilent = true
var silenceStart time.Time
var audioStart time.Time
const minSilenceDuration = 4 * time.Second
const maxSilenceDuration = 6 * time.Second
// Initialize ring buffer for audio messages (sized for 30 seconds to handle long messages)
const bufferDuration = 30 * time.Second
const bytesPerSecond = sampleRate * 2 // 16-bit PCM at 8000 Hz
bufferSize := bytesPerSecond * int(bufferDuration.Seconds())
audioBuffer := NewCircularBuffer(bufferSize)
slog.Info("Initialized audio buffer", "size_bytes", bufferSize, "duration_s", bufferDuration.Seconds())
// Use Mu-law decoding
slog.Info("Using Mu-law decoding")
decodeFunc := func(data []byte) g711.Mulaw {
var raw g711.Mulaw
raw.Unmarshal(data)
toLittleEndian(raw)
return raw
}
// Function to save buffer to WAV file
saveBufferToWAV := func(filename string, buffer *CircularBuffer) error {
if buffer.Len() == 0 {
return fmt.Errorf("no audio data to save")
}
file, err := os.Create(filename)
if err != nil {
return fmt.Errorf("failed to create WAV file: %v", err)
}
defer file.Close()
// Write WAV header
dataSize := buffer.Len()
header := make([]byte, 44)
copy(header[0:4], []byte("RIFF"))
binary.LittleEndian.PutUint32(header[4:8], uint32(36+dataSize)) // File size
copy(header[8:12], []byte("WAVE"))
copy(header[12:16], []byte("fmt "))
binary.LittleEndian.PutUint32(header[16:20], 16) // Subchunk1 size
binary.LittleEndian.PutUint16(header[20:22], 1) // Audio format (PCM)
binary.LittleEndian.PutUint16(header[22:24], 1) // Num channels
binary.LittleEndian.PutUint32(header[24:28], sampleRate) // Sample rate
binary.LittleEndian.PutUint32(header[28:32], sampleRate*2) // Byte rate
binary.LittleEndian.PutUint16(header[32:34], 2) // Block align
binary.LittleEndian.PutUint16(header[34:36], 16) // Bits per sample
copy(header[36:40], []byte("data"))
binary.LittleEndian.PutUint32(header[40:44], uint32(dataSize)) // Data size
if _, err := file.Write(header); err != nil {
return fmt.Errorf("failed to write WAV header: %v", err)
}
// Write buffer data
data := make([]byte, dataSize)
_, err = buffer.Read(data)
if err != nil && err != io.EOF {
return fmt.Errorf("failed to read from buffer: %v", err)
}
if _, err := file.Write(data); err != nil {
return fmt.Errorf("failed to write WAV data: %v", err)
}
slog.Info("Saved audio message to WAV", "filename", filename, "size_bytes", dataSize)
return nil
}
// Process RTP packets
c.OnPacketRTP(medi, forma, func(pkt *rtp.Packet) {
pcm := decodeFunc(pkt.Payload)
if len(pcm) == 0 {
slog.Warn("Empty PCM data after decoding, skipping")
return
}
pcmBuffer = append(pcmBuffer, pcm...)
for len(pcmBuffer) >= frameBytes {
frame := pcmBuffer[:frameBytes]
pcmBuffer = pcmBuffer[frameBytes:]
active, err := vad.Process(sampleRate, frame)
if err != nil {
slog.Warn("VAD processing error", "error", err)
return
}
now := time.Now()
if active {
// Audio detected, start buffering
if isSilent {
slog.Info("Audio begins (silence ends)", "timestamp", now.Format("2006-01-02 15:04:05"))
isSilent = false
audioStart = now // Start tracking audio
audioBuffer.Reset() // Clear buffer for new audio message
}
// Add PCM data to buffer only during audio
_, err := audioBuffer.Write(frame)
if err != nil {
slog.Warn("Failed to write to audio buffer", "error", err)
}
silenceStart = time.Time{} // Clear silence start
} else {
// Silence detected
if !isSilent {
if silenceStart.IsZero() {
silenceStart = now
} else if now.Sub(silenceStart) >= minSilenceDuration && now.Sub(silenceStart) <= maxSilenceDuration {
// Log audio duration if audio was active
var audioDurationMs int64
if !audioStart.IsZero() {
audioDurationMs = now.Sub(audioStart).Milliseconds()
}
slog.Info("Silence detected",
"timestamp", now.Format("2006-01-02 15:04:05"),
"silence_duration_ms", now.Sub(silenceStart).Milliseconds(),
"audio_duration_ms", audioDurationMs)
isSilent = true
// Optionally save buffer on silence detection
if *saveBuffer {
filename := fmt.Sprintf("audio_buffer_%s.wav", now.Format("20060102_150405"))
if err := saveBufferToWAV(filename, audioBuffer); err != nil {
slog.Error("Failed to save audio buffer", "error", err)
}
}
audioStart = time.Time{} // Reset audio start time
silenceStart = time.Time{} // Reset silence start time to allow new silence detection
}
}
}
// slog.Debug("Processed audio frame", "active", active)
}
})
// Start playing
_, err = c.Play(nil)
if err != nil {
slog.Error("Failed to start RTSP playback", "error", err)
os.Exit(1)
}
slog.Info("Started RTSP playback")
// Wait for errors or interruption
err = c.Wait()
if err != nil {
slog.Error("RTSP client error", "error", err)
os.Exit(1)
}
}