fffiloni committed on
Commit
05e653a
1 Parent(s): fff1f22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -54
app.py CHANGED
@@ -1,11 +1,7 @@
1
  import gradio as gr
2
  import torch
3
 
4
- from scipy.io import wavfile
5
- import numpy as np
6
- from PIL import Image
7
-
8
- from spectro import wav_bytes_from_spectrogram_image, spectrogram_from_waveform, image_from_spectrogram
9
 
10
  from diffusers import StableDiffusionPipeline
11
  from diffusers import StableDiffusionImg2ImgPipeline
@@ -15,63 +11,20 @@ from share_btn import community_icon_html, loading_icon_html, share_js
15
  MODEL_ID = "riffusion/riffusion-model-v1"
16
  pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
17
  pipe = pipe.to("cuda")
18
- pipe2 = StableDiffusionImg2ImgPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
19
- pipe2 = pipe2.to("cuda")
20
-
21
- def predict(prompt, audio_input, duration):
22
- if audio_input == None:
23
- return classic(prompt, duration)
24
- else:
25
- return audio_transfer(prompt, audio_input)
26
-
27
 
28
- def classic(prompt, duration):
 
29
  if duration == 5:
30
  width_duration=512
31
  else :
32
  width_duration = 512 + ((int(duration)-5) * 128)
33
- spec = pipe(prompt, height=512, width=width_duration).images[0]
34
  print(spec)
35
  wav = wav_bytes_from_spectrogram_image(spec)
36
  with open("output.wav", "wb") as f:
37
  f.write(wav[0].getbuffer())
38
  return spec, 'output.wav', gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
39
 
40
- def audio_transfer(prompt, audio):
41
- # read uploaded file to wav
42
- rate, data = wavfile.read(audio)
43
-
44
- # convert to mono
45
- data = np.mean(data, axis=0)
46
-
47
- # convert to float32
48
- data = data.astype(np.float32)
49
-
50
- # take a random 7 second slice of the audio
51
- data = data[rate*7:rate*14]
52
-
53
- spectrogram = spectrogram_from_waveform(
54
- waveform=data,
55
- sample_rate=rate,
56
- # width=768,
57
- n_fft=8192,
58
- hop_length=512,
59
- win_length=8192,
60
- )
61
-
62
- spec = image_from_spectrogram(spectrogram)
63
-
64
- images = pipe2(
65
- prompt=prompt,
66
- image=spec,
67
- strength=0.5,
68
- guidance_scale=7
69
- ).images
70
-
71
- wav = wav_bytes_from_spectrogram_image(images[0])
72
- with open("output.wav", "wb") as f:
73
- f.write(wav[0].getbuffer())
74
- return images[0], 'output.wav', gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
75
 
76
  title = """
77
  <div style="text-align: center; max-width: 500px; margin: 0 auto;">
@@ -189,8 +142,10 @@ with gr.Blocks(css=css) as demo:
189
  gr.HTML(title)
190
 
191
  prompt_input = gr.Textbox(placeholder="a cat diva singing in a New York jazz club", label="Musical prompt", elem_id="prompt-in")
192
- audio_input = gr.Audio(label="audio input", type="filepath", source="upload")
193
- duration_input = gr.Slider(label="Duration in seconds", minimum=5, maximum=10, step=1, value=8, elem_id="duration-slider")
 
 
194
  send_btn = gr.Button(value="Get a new spectrogram ! ", elem_id="submit-btn")
195
 
196
  with gr.Column(elem_id="col-container-2"):
@@ -205,7 +160,7 @@ with gr.Blocks(css=css) as demo:
205
 
206
  gr.HTML(article)
207
 
208
- send_btn.click(predict, inputs=[prompt_input, audio_input, duration_input], outputs=[spectrogram_output, sound_output, share_button, community_icon, loading_icon])
209
  share_button.click(None, [], [], _js=share_js)
210
 
211
  demo.queue(max_size=250).launch(debug=True)
 
1
  import gradio as gr
2
  import torch
3
 
4
+ from spectro import wav_bytes_from_spectrogram_image
 
 
 
 
5
 
6
  from diffusers import StableDiffusionPipeline
7
  from diffusers import StableDiffusionImg2ImgPipeline
 
11
  MODEL_ID = "riffusion/riffusion-model-v1"
12
  pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
13
  pipe = pipe.to("cuda")
 
 
 
 
 
 
 
 
 
14
 
15
+
16
def predict(prompt, negative_prompt, duration):
    """Generate a spectrogram image from the prompts and render it to a WAV file.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        negative_prompt: Text forwarded to the pipeline as ``negative_prompt``.
        duration: Value of the "Duration in seconds" slider; it selects the
            width of the generated image.

    Returns:
        A 5-tuple: the generated spectrogram image, the path of the written
        audio file (``'output.wav'``), and three ``gr.update(visible=True)``
        objects, one for each remaining output component wired to this
        callback.
    """
    # 512 px wide at a duration of 5; each additional whole second adds 128 px.
    image_width = 512 if duration == 5 else 512 + ((int(duration) - 5) * 128)

    generation = pipe(
        prompt,
        negative_prompt=negative_prompt,
        height=512,
        width=image_width,
    )
    spec = generation.images[0]
    print(spec)

    # The helper's first element is a bytes buffer holding the WAV data.
    wav = wav_bytes_from_spectrogram_image(spec)
    with open("output.wav", "wb") as audio_file:
        audio_file.write(wav[0].getbuffer())

    reveal_controls = tuple(gr.update(visible=True) for _ in range(3))
    return (spec, 'output.wav', *reveal_controls)
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  title = """
30
  <div style="text-align: center; max-width: 500px; margin: 0 auto;">
 
142
  gr.HTML(title)
143
 
144
  prompt_input = gr.Textbox(placeholder="a cat diva singing in a New York jazz club", label="Musical prompt", elem_id="prompt-in")
145
+ with gr.Row():
146
+ negative_prompt = gr.Textbox(label="Negative prompt")
147
+ duration_input = gr.Slider(label="Duration in seconds", minimum=5, maximum=10, step=1, value=8, elem_id="duration-slider")
148
+
149
  send_btn = gr.Button(value="Get a new spectrogram ! ", elem_id="submit-btn")
150
 
151
  with gr.Column(elem_id="col-container-2"):
 
160
 
161
  gr.HTML(article)
162
 
163
+ send_btn.click(predict, inputs=[prompt_input, negative_prompt, duration_input], outputs=[spectrogram_output, sound_output, share_button, community_icon, loading_icon])
164
  share_button.click(None, [], [], _js=share_js)
165
 
166
  demo.queue(max_size=250).launch(debug=True)