EugeneBos r/learnrust problem
use whisper_rs::{WhisperContext, WhisperState, WhisperContextParameters, FullParams, SamplingStrategy};

//// new
use std::sync::{Arc, Mutex};
use once_cell::sync::Lazy;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use byteorder::{LittleEndian, ReadBytesExt};
use vad_rs::{Vad, VadStatus};
use std::process::{Command, Stdio};
use std::thread;
use std::sync::mpsc;
use tokio::time::Instant;
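// once_cell's Lazy does the job the lazy_static! macro used to do here:
// the statics below are initialized on first access rather than at startup,
// so the lazy_static import from the old revision is no longer needed.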
// State
static IS_SPEECH: AtomicBool = AtomicBool::new(false);
static SPEECH_DUR: AtomicUsize = AtomicUsize::new(0);
static SILENCE_DUR: AtomicUsize = AtomicUsize::new(0);
static SPEECH_BUF: Mutex<Vec<f32>> = Mutex::new(Vec::new());

// Options
const MIN_SPEECH_DUR: usize = 50; // 0.6s
const MIN_SILENCE_DUR: usize = 200; // 1s

// Vad
static VAD_BUF: Mutex<Vec<f32>> = Mutex::new(Vec::new());
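// Caveat on units: SPEECH_DUR and SILENCE_DUR grow by `chunk_size` samples
// per VAD pass (480 samples per 30 ms chunk at 16 kHz), so thresholds in
// samples for 0.6 s and 1 s would be 9600 and 16000; the values above appear
// to assume a different unit. Note also that SPEECH_BUF is filled while
// IS_SPEECH is set but is never read or cleared anywhere below.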
/////////
mod audio_files;
// Old attempt: keep the context and state together in one static. This does
// not compile because WhisperState borrows from WhisperContext:
// struct WhisperBundle {
//     context: WhisperContext,
//     state: WhisperState<'static>,
// }
// lazy_static! {
//     static ref WHISPER_BUNDLE: Arc<Mutex<Option<WhisperBundle>>> = Arc::new(Mutex::new(None));
//     static ref WHISPER_STATE: Arc<Mutex<Option<WhisperState>>> = Arc::new(Mutex::new(None));
// }
// static WHISPER_STATE: Lazy<Arc<Mutex<Option<WhisperState>>>> = Lazy::new(|| Arc::new(Mutex::new(None)));

static WHISPER_STATE: Lazy<Mutex<WhisperState>> = Lazy::new(|| {
    let context = Box::leak(Box::new(
        WhisperContext::new_with_params(
            "/home/eugenebos/Documents/rust/whisper/ggml-small.bin",
            WhisperContextParameters::default(),
        )
        .expect("Context succeeds with known bin file"),
    ));
    Mutex::new(
        context
            .create_state()
            .expect("State should succeed with known bin file"),
    )
});
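// Why Box::leak works here: in whisper-rs versions where create_state()
// borrows the context (the WhisperState<'static> in the old attempt above),
// a static WhisperState needs a &'static WhisperContext to borrow from.
// Leaking the Box hands out exactly that reference; the context is simply
// never freed, which is fine for one that lives as long as the process.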
// Old loader for the Option-based static above; unused now that the Lazy
// static initializes itself on first access:
// fn load_whisper() {
//     let path_to_model = "/home/eugenebos/Documents/rust/whisper/ggml-small.bin";
//     // Load the context
//     let params = WhisperContextParameters::default();
//     let context = WhisperContext::new_with_params(path_to_model, params).unwrap();
//     // Create the state
//     let state = context.create_state().expect("failed to create state");
//     *WHISPER_STATE.lock().unwrap() = Some(state);
// }
fn transcribe(audio_data: Vec<f32>) {
    // Create a params object
    let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });

    // Lock the lazily initialized state and run the model
    let mut state = WHISPER_STATE.lock().unwrap();
    state
        .full(params, &audio_data[..])
        .expect("failed to run model");

    // Fetch the results
    let num_segments = state
        .full_n_segments()
        .expect("failed to get number of segments");
    for i in 0..num_segments {
        let segment = state
            .full_get_segment_text(i)
            .expect("failed to get segment");
        let start_timestamp = state
            .full_get_segment_t0(i)
            .expect("failed to get segment start timestamp");
        let end_timestamp = state
            .full_get_segment_t1(i)
            .expect("failed to get segment end timestamp");
        println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
    }
}
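// whisper.cpp reports segment t0/t1 in ticks of 10 ms, so a hypothetical
// helper (not part of the original post) to print seconds instead of ticks
// could look like this:
#[allow(dead_code)]
fn ticks_to_seconds(t: i64) -> f64 {
    t as f64 / 100.0
}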
fn on_stream_data(input: Vec<f32>, vad_handle: Arc<Mutex<Vad>>, processing: &mut bool) {
    println!("Data chunk: {}", input.len());

    // Input is already f32 mono at 16 kHz; a VAD chunk covers 30 ms of it
    let sample_rate = 16000;
    let chunk_size = (30 * sample_rate / 1000) as usize;

    let mut vad = vad_handle.lock().unwrap();
    let mut vad_buf = VAD_BUF.lock().unwrap();
    vad_buf.extend(input.clone());

    if IS_SPEECH.load(Ordering::Acquire) {
        SPEECH_BUF.lock().unwrap().extend(input.clone());
    }

    if vad_buf.len() as f32 > sample_rate as f32 * 0.1 {
        let start_time = Instant::now();
        if let Ok(mut result) = vad.compute(&input) {
            // Calculate the elapsed time
            let elapsed_time = start_time.elapsed();
            let elapsed_ms = elapsed_time.as_secs_f64() * 1000.0;

            // Log or handle the situation if computation time exceeds a threshold
            if elapsed_ms > 100.0 {
                eprintln!(
                    "Warning: VAD computation took too long: {} ms (expected < 30 ms)",
                    elapsed_ms
                );
            }

            match result.status() {
                VadStatus::Speech => {
                    SPEECH_DUR.fetch_add(chunk_size, Ordering::AcqRel);
                    if SPEECH_DUR.load(Ordering::Acquire) >= MIN_SPEECH_DUR
                        && !IS_SPEECH.load(Ordering::Acquire)
                    {
                        println!("Speech Start");
                        SILENCE_DUR.store(0, Ordering::Release);
                        IS_SPEECH.store(true, Ordering::Release);
                        vad_buf.extend(input.clone());
                    }
                }
                VadStatus::Silence => {
                    SILENCE_DUR.fetch_add(chunk_size, Ordering::AcqRel);
                    if SILENCE_DUR.load(Ordering::Acquire) >= MIN_SILENCE_DUR
                        && IS_SPEECH.load(Ordering::Acquire)
                    {
                        println!("Speech End");
                        match audio_files::save_audio_to_wav(&vad_buf, "stream.wav") {
                            Ok(_) => println!("Audio data saved successfully as WAV"),
                            Err(e) => eprintln!("Error saving audio data: {}", e),
                        }
                        // Transcribe the buffered audio before clearing it;
                        // clearing first would hand transcribe an empty buffer
                        transcribe(vad_buf.clone());
                        vad_buf.clear();
                        *processing = false;
                        SPEECH_DUR.store(0, Ordering::Release);
                        IS_SPEECH.store(false, Ordering::Release);
                    }
                }
                _ => {}
            }
        } else {
            eprintln!("Some error 1");
        }
    }
}
async fn listen_stream() -> Result<(), Box<dyn std::error::Error>> {
    let vad: Vad = Vad::new("silero_vad.onnx", 16000).unwrap();
    let vad_handle = Arc::new(Mutex::new(vad));
    let mut processing = true;

    let mut cmd: std::process::Child = Command::new("sh")
        .arg("-c")
        .arg("streamlink \"--twitch-api-header=Authorization=OAuth INSERT_OAUTH_TOKEN_HERE\" https://www.twitch.tv/sometalkingstreamer audio_only --twitch-disable-ads -O | ffmpeg -i - -f f32le -acodec pcm_f32le -ar 16000 -ac 1 pipe:1")
        .stdout(Stdio::piped())
        .spawn()?;

    let (tx, rx) = mpsc::channel();

    // Reader thread: pull raw f32le samples off the ffmpeg pipe in 1 s blocks
    let reader_thread = thread::spawn(move || {
        let stdout = cmd.stdout.take().expect("Failed to get stdout");
        let mut reader = std::io::BufReader::new(stdout);
        let mut buffer = Vec::with_capacity(16000); // 1 second of audio at 16kHz
        loop {
            buffer.clear();
            for _ in 0..16000 {
                match reader.read_f32::<LittleEndian>() {
                    Ok(sample) => buffer.push(sample),
                    Err(e) => {
                        eprintln!("End of stream or error: {:?}", e);
                        return;
                    }
                }
            }
            if tx.send(buffer.clone()).is_err() {
                eprintln!("Channel closed, stop reading");
                break;
            }
        }
    });

    // Main thread: process the audio data
    for chunk in rx.iter() {
        on_stream_data(chunk, vad_handle.clone(), &mut processing);
    }

    reader_thread.join().expect("Reader thread panicked");
    Ok(())
}
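// A possible tightening of the read loop (a sketch, not from the original
// post): read a whole second of bytes at once and reinterpret them, instead
// of calling read_f32 sixteen thousand times per block.
#[allow(dead_code)]
fn read_second(reader: &mut impl std::io::Read) -> std::io::Result<Vec<f32>> {
    let mut bytes = vec![0u8; 16000 * 4]; // 1 s of f32le mono at 16 kHz
    reader.read_exact(&mut bytes)?;
    Ok(bytes
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect())
}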
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // No explicit load_whisper() step needed anymore: WHISPER_STATE
    // initializes lazily on the first transcribe() call.
    listen_stream().await?;
    Ok(())
}
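// Crates this file assumes in Cargo.toml (exact versions not pinned here):
// whisper-rs, once_cell, byteorder, vad-rs, and tokio with the "macros",
// "rt-multi-thread", and "time" features enabled.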