Daryl Fung committed on
Commit
8d83939
1 Parent(s): 0fd4a4d

finalize mvp

Browse files
Files changed (2) hide show
  1. app.py +18 -6
  2. db/query_db.py +1 -1
app.py CHANGED
@@ -1,16 +1,17 @@
1
  from fastapi import FastAPI
 
2
  import uvicorn
3
- import faiss
 
4
  from sentence_transformers import SentenceTransformer
5
  from pymilvus import Collection
 
 
6
 
7
  from db.db_connect import connect, disconnect
8
  from db.query_db import query
9
 
10
  model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
11
- index = faiss.IndexFlatL2(model.get_sentence_embedding_dimension()) # build the index
12
-
13
- index.add(model.encode(['hello']))
14
 
15
  app = FastAPI()
16
 
@@ -48,8 +49,19 @@ async def transcribe(text: str):
48
  insert_response_to_generate_for_audio(text, embeddings)
49
  audio = await query(WAIT_RESPONSES_EMBEDDINGS, threshold=0.8)
50
 
51
- return audio
 
 
 
 
 
 
 
 
 
 
 
52
 
53
 
54
  if __name__ == '__main__':
55
- uvicorn.run('app:app', host='0.0.0.0', port=7860)
 
1
  from fastapi import FastAPI
2
+ from fastapi.responses import Response
3
  import uvicorn
4
+ import numpy as np
5
+ import io
6
  from sentence_transformers import SentenceTransformer
7
  from pymilvus import Collection
8
+ import soundfile as sf
9
+ from bark import SAMPLE_RATE
10
 
11
  from db.db_connect import connect, disconnect
12
  from db.query_db import query
13
 
14
  model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
 
 
 
15
 
16
  app = FastAPI()
17
 
 
49
  insert_response_to_generate_for_audio(text, embeddings)
50
  audio = await query(WAIT_RESPONSES_EMBEDDINGS, threshold=0.8)
51
 
52
+ # convert audio bytes to appropriate format to return
53
+ audio_file = io.BytesIO(np.frombuffer(audio, dtype=np.int16))
54
+ audio, sample_rate = sf.read(audio_file)
55
+
56
+ audio_file = io.BytesIO()
57
+ sf.write(audio_file, audio, sample_rate, format='wav')
58
+ audio_file.seek(0)
59
+
60
+ return Response(
61
+ content=audio_file.read(),
62
+ media_type="audio/wav", # Same as the Content-Type header
63
+ )
64
 
65
 
66
  if __name__ == '__main__':
67
+ uvicorn.run('app:app', host='0.0.0.0', port=7861)
db/query_db.py CHANGED
@@ -21,7 +21,7 @@ async def query(embeddings, threshold=0.8):
21
 
22
  if len(similar_indexes) > 0:
23
  selected_index = random.choice(similar_indexes)
24
- selected_id = search_results[0].ids[selected_index]
25
  audio_obj = audio_response.query(f'id == {selected_id}', output_fields=['text', 'filename'])[0]
26
  audio_id = audio_obj['filename']
27
 
 
21
 
22
  if len(similar_indexes) > 0:
23
  selected_index = random.choice(similar_indexes)
24
+ selected_id = search_results.ids[selected_index]
25
  audio_obj = audio_response.query(f'id == {selected_id}', output_fields=['text', 'filename'])[0]
26
  audio_id = audio_obj['filename']
27