# asr_pipeline/audio_processing.py
import os
import logging
import time
from difflib import SequenceMatcher

import numpy as np
import spaces
import torch
import whisperx
from dotenv import load_dotenv
from pyannote.audio import Pipeline

load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

hf_token = os.getenv("HF_TOKEN")

# Chunking parameters in seconds. OVERLAP is currently 0, so the
# overlap-skipping logic below is a no-op until it is raised.
CHUNK_LENGTH = 10
OVERLAP = 0
def preprocess_audio(audio, chunk_size=CHUNK_LENGTH * 16000, overlap=OVERLAP * 16000):
    """Split a 16 kHz mono waveform into fixed-size chunks, overlapping by `overlap` samples."""
    chunks = []
    step = chunk_size - overlap
    for i in range(0, len(audio), step):
        chunk = audio[i:i + chunk_size]
        if len(chunk) < chunk_size:
            # Zero-pad the final chunk so every chunk has the same length.
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
    return chunks
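
# Quick sanity check of the chunker (illustrative values only, not part of the
# pipeline): a 25 s clip at 16 kHz with CHUNK_LENGTH=10 and OVERLAP=0 yields
# three 10 s chunks, the last one zero-padded by 5 s.
#
#   >>> audio = np.zeros(25 * 16000, dtype=np.float32)
#   >>> [len(c) for c in preprocess_audio(audio)]
#   [160000, 160000, 160000]
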
@spaces.GPU(duration=120)
def process_audio(audio_file, translate=False, model_size="small"):
    """Transcribe (and optionally translate) an audio file, with speaker diarization."""
    start_time = time.time()
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")
        # int8 keeps GPU memory low; fall back to float32 on CPU.
        compute_type = "int8" if torch.cuda.is_available() else "float32"

        audio = whisperx.load_audio(audio_file)
        model = whisperx.load_model(model_size, device, compute_type=compute_type)

        diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization", use_auth_token=hf_token
        )
        diarization_pipeline = diarization_pipeline.to(torch.device(device))
        # pyannote expects a (channels, samples) float tensor plus the sample rate.
        diarization_result = diarization_pipeline(
            {"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000}
        )

        chunks = preprocess_audio(audio)

        language_segments = []
        final_segments = []
        overlap_duration = OVERLAP

        for i, chunk in enumerate(chunks):
            chunk_proc_start = time.time()  # wall-clock timer for this chunk
            chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)  # offset in the audio timeline
            chunk_end_time = chunk_start_time + CHUNK_LENGTH
            logger.info(f"Processing chunk {i + 1}/{len(chunks)}")

            lang = model.detect_language(chunk)
            result_transcribe = model.transcribe(chunk, language=lang)
            if translate:
                result_translate = model.transcribe(chunk, task="translate")
            for j, t_seg in enumerate(result_transcribe["segments"]):
                segment_start = chunk_start_time + t_seg["start"]
                segment_end = chunk_start_time + t_seg["end"]

                # Skip segments that fall in the overlap with the previous chunk.
                if i > 0 and segment_end <= chunk_start_time + overlap_duration:
                    logger.info(f"Skipping segment in overlap with previous chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue
                # Skip segments that fall in the overlap with the next chunk.
                if i < len(chunks) - 1 and segment_start >= chunk_end_time - overlap_duration:
                    logger.info(f"Skipping segment in overlap with next chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue

                # Collect every diarization turn that overlaps this segment,
                # then pick the most frequent speaker as a simple majority vote.
                speakers = []
                for turn, track, speaker in diarization_result.itertracks(yield_label=True):
                    if turn.start <= segment_end and turn.end >= segment_start:
                        speakers.append(speaker)
                segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": max(set(speakers), key=speakers.count) if speakers else "Unknown",
                    "text": t_seg["text"],
                }
                if translate:
                    # The transcription and translation passes can segment the
                    # audio differently, so guard against an index mismatch.
                    if j < len(result_translate["segments"]):
                        segment["translated"] = result_translate["segments"][j]["text"]
                final_segments.append(segment)

            language_segments.append({
                "language": lang,
                "start": chunk_start_time,
                "end": chunk_start_time + CHUNK_LENGTH,
            })
            # Use the wall-clock timer captured above; reusing chunk_start_time
            # here would mix audio-timeline offsets with wall-clock time.
            logger.info(f"Chunk {i + 1} processed in {time.time() - chunk_proc_start:.2f} seconds")
        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)

        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
        return language_segments, merged_segments
    except Exception as e:
        logger.error(f"An error occurred during audio processing: {str(e)}")
        raise
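
# Shape of the results (descriptive, derived from the code above):
#   language_segments: [{"language", "start", "end"}, ...], one entry per chunk.
#   merged_segments:   [{"start", "end", "language", "speaker", "text",
#                        ("translated" when translate=True)}, ...],
#                      sorted by start time and stitched by merge_nearby_segments.
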
def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
    """Merge consecutive segments whose text overlaps, stitching at the longest common run."""
    merged = []
    for segment in segments:
        if not merged or segment["start"] - merged[-1]["end"] > time_threshold:
            merged.append(segment)
        else:
            # Find the longest run of text shared by the two segments.
            matcher = SequenceMatcher(None, merged[-1]["text"], segment["text"])
            match = matcher.find_longest_match(0, len(merged[-1]["text"]), 0, len(segment["text"]))
            if len(segment["text"]) and match.size / len(segment["text"]) > similarity_threshold:
                # Merge: keep the earlier text and append only the non-overlapping tail.
                merged[-1]["text"] = merged[-1]["text"] + segment["text"][match.b + match.size:]
                # 'translated' is only present when translate=True; guard the lookup.
                if "translated" in merged[-1] and "translated" in segment:
                    merged[-1]["translated"] = merged[-1]["translated"] + segment["translated"][match.b + match.size:]
                merged[-1]["end"] = segment["end"]
            else:
                # No significant overlap; keep it as a separate segment.
                merged.append(segment)
    return merged
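
# How the stitch works (illustrative strings only): find_longest_match returns
# the longest block shared by the two texts; if it covers most of the newer
# segment, only the tail after the match is appended.
#
#   >>> m = SequenceMatcher(None, "thank you very much", "you very much!")
#   >>> m.find_longest_match(0, 19, 0, 14)
#   Match(a=6, b=0, size=13)
#
# Here 13/14 > 0.7, so the merged text becomes
# "thank you very much" + "you very much!"[13:] == "thank you very much!".
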
def print_results(segments):
    for segment in segments:
        print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) {segment['speaker']}:")
        print(f"Original: {segment['text']}")
        if "translated" in segment:
            print(f"Translated: {segment['translated']}")
        print()
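
# Minimal usage sketch ("sample.wav" is a hypothetical path, not shipped with
# this repo; HF_TOKEN must be set in the environment for pyannote access, and
# the @spaces.GPU decorator is only meaningful inside a Hugging Face Space):
if __name__ == "__main__":
    language_segments, segments = process_audio("sample.wav", translate=True)
    print_results(segments)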