# WhisperLivesubs / app.py
# Committed by: codewithdark
# Commit 81d799d (verified): "Rename demo.py to app.py"
# Hugging Face file-page metadata: raw · history · blame · contribute · delete · No virus · 2.53 kB
import streamlit as st
import sounddevice as sd
import numpy as np
import torch
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
import soundfile as sf # Using soundfile for audio file handling
import librosa
# Load model
@st.cache_resource
def load_model():
    """Load and return the speech-to-text processor and model.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the same
    loaded objects instead of calling ``from_pretrained`` again.

    Returns:
        tuple: ``(processor, model)`` loaded from the
        ``codewithdark/WhisperLiveSubs`` repository.
    """
    checkpoint = "codewithdark/WhisperLiveSubs"
    loaded_processor = AutoProcessor.from_pretrained(checkpoint)
    loaded_model = AutoModelForSpeechSeq2Seq.from_pretrained(checkpoint)
    return loaded_processor, loaded_model
# Load the model once at startup. If it fails, report it and halt this script
# run: every later call to transcribe_audio reads the module-level `processor`
# and `model`, which would otherwise be undefined (NameError).
try:
    processor, model = load_model()
except ConnectionError:
    # NOTE(review): this is the builtin ConnectionError; network failures from
    # transformers/requests may surface as other OSError subclasses and land in
    # the generic branch below instead — confirm.
    st.error("Error loading model: Check your Internet Connection")
    st.stop()
except Exception:
    st.error("Error loading model: Please try again")
    st.stop()
# Function to transcribe audio
def transcribe_audio(audio, sample_rate, *, chunk_seconds=30.0,
                     asr_processor=None, asr_model=None):
    """Transcribe a waveform and return the decoded text.

    Args:
        audio: Waveform samples (array-like). A 2-D input is treated as
            (frames, channels), the layout used by the callers in this file,
            and is averaged down to mono.
        sample_rate: Sampling rate of ``audio`` in Hz, forwarded to the
            processor unchanged.
        chunk_seconds: Length in seconds of each window fed to the model.
            Whisper-style feature extractors pad/truncate every input to a
            fixed 30 s window, so longer audio is transcribed window by window
            instead of being silently cut off. (Assumes the checkpoint is a
            Whisper variant — TODO confirm; for other models this is only a
            harmless split.)
        asr_processor: Processor to use; defaults to the module-level
            ``processor``.
        asr_model: Model to use; defaults to the module-level ``model``.

    Returns:
        The per-window transcriptions, stripped and joined with single
        spaces, or ``""`` if ``audio`` contains no samples.

    Raises:
        ValueError: If ``chunk_seconds`` is not positive.
    """
    if chunk_seconds <= 0:
        raise ValueError("chunk_seconds must be positive")
    # Fall back to the globals only when no explicit object was passed; the
    # conditional expression does not touch the global otherwise.
    proc = processor if asr_processor is None else asr_processor
    mdl = model if asr_model is None else asr_model

    samples = np.asarray(audio, dtype=np.float32)
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    if samples.size == 0:
        return ""

    window = max(1, int(chunk_seconds * sample_rate))
    pieces = []
    for start in range(0, samples.shape[0], window):
        chunk = samples[start:start + window]
        input_features = proc(chunk, sampling_rate=sample_rate, return_tensors="pt").input_features
        predicted_ids = mdl.generate(input_features)
        text = proc.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
        if text:
            pieces.append(text)
    return " ".join(pieces)
# Streamlit app
st.title("Speech-to-Text Transcription")

# File upload: decode the file, convert it to 16 kHz mono, play it back and
# show its transcription.
uploaded_file = st.file_uploader("Choose an audio file", type=["wav", "mp3"])
if uploaded_file is not None:
    try:
        # Read the audio file; soundfile returns (frames,) or (frames, channels).
        # NOTE(review): decoding MP3 requires libsndfile >= 1.1.0 — confirm
        # the deployed version supports it.
        audio_data, sample_rate = sf.read(uploaded_file)

        # Downmix to mono BEFORE resampling: librosa.resample works along the
        # last axis, which for soundfile's (frames, channels) layout is the
        # channel axis, so resampling a stereo array first would scramble it.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)

        # Resample if necessary
        target_sample_rate = 16000
        if sample_rate != target_sample_rate:
            audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=target_sample_rate)

        # Play the upload back with its reported MIME type (e.g. audio/mpeg
        # for MP3) instead of always labelling it as WAV.
        st.audio(uploaded_file.getvalue(), format=uploaded_file.type or "audio/wav")
        transcription = transcribe_audio(audio_data, target_sample_rate)
        st.write("Transcription:", transcription)
    except Exception as e:
        st.error(f"Error processing the file: {e}")
# Real-time voice input: record a fixed-length mono clip at 16 kHz and
# transcribe it.
# NOTE(review): sounddevice records from the audio device of the machine
# running this script (the server), not the visitor's browser microphone —
# confirm this is intended for the deployment target.
if st.button("Start Recording"):
    duration = 15  # Record for 15 seconds
    sample_rate = 16000
    st.write("Recording...")
    try:
        recording = sd.rec(int(duration * sample_rate), samplerate=sample_rate, channels=1)
        sd.wait()
    except Exception as e:
        # e.g. no input device available; report it like the upload path does
        # instead of showing a raw traceback.
        st.error(f"Error recording audio: {e}")
    else:
        st.write("Recording finished!")
        # (frames, 1) -> (frames,) so the model receives a 1-D waveform.
        audio_data = recording.flatten()
        try:
            transcription = transcribe_audio(audio_data, sample_rate)
        except Exception as e:
            st.error(f"Error transcribing the recording: {e}")
        else:
            st.write("Transcription:", transcription)