Mini Project Day 2
2024. 11. 4. 00:58
# Deep learning environment setup: CUDA, cuDNN
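Before loading any models it is worth confirming that PyTorch actually sees the GPU and cuDNN. A minimal check (assuming PyTorch was installed with CUDA support):

import torch

print(torch.__version__)                  # PyTorch version
print(torch.cuda.is_available())          # True means a CUDA GPU can be used
print(torch.version.cuda)                 # CUDA version PyTorch was built against
print(torch.backends.cudnn.version())     # cuDNN version number
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of the detected GPU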
# Text -> Image (stabilityai/stable-diffusion-3.5-large)
Guide..
Failed code...
# '''
# 5. Music, quotes, pictures - recommendation/generation: (model undecided)
# (Candidate 1) facebook/musicgen-small
# A model that generates music when you write a prompt for it.
# The catch is that we will probably have to build a few prompts on the backend side and feed them in ourselves.
# '''
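Since MusicGen only takes free-text prompts, the "build a few prompts on the backend" idea could start as a small lookup of canned prompt strings keyed by mood. A rough sketch (the mood keys and prompt wording below are placeholders only, nothing is decided yet):

# Hypothetical mood -> MusicGen prompt mapping; keys and texts are examples only
MUSIC_PROMPTS = {
    "calm": "calm lo-fi piano with soft rain sounds, slow tempo",
    "happy": "upbeat acoustic pop with hand claps and bright guitar",
    "sad": "melancholic solo piano, slow and emotional",
}

def build_music_prompt(mood: str) -> str:
    # Fall back to a neutral prompt when the mood is unknown
    return MUSIC_PROMPTS.get(mood, "gentle ambient background music")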
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI, Form
from PIL import Image
import torch
import json
import requests
# Set up the Hugging Face API token
API_TOKEN = "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"
headers = {"Authorization": f"Bearer {API_TOKEN}"}
# Load the Stable Diffusion model
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    torch_dtype=torch.float16,
    use_auth_token=API_TOKEN  # authenticate with the API token
)
pipe = pipe.to("cuda")
# Create the FastAPI instance
app = FastAPI()
# Image generation endpoint
@app.post("/createimage/")
async def create_image(text: str = Form(...)):
"""
์ฃผ์ด์ง ํ
์คํธ ํ๋กฌํํธ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ด๋ฏธ์ง๋ฅผ ์์ฑํ์ฌ ์ ์ฅํฉ๋๋ค.
"""
# ํ
์คํธ ๊ธฐ๋ฐ์ผ๋ก ์ด๋ฏธ์ง ์์ฑ
image = pipe(
prompt=text,
num_inference_steps=28,
guidance_scale=3.5
).images[0]
# ์ด๋ฏธ์ง ์ ์ฅ
image_path = "generated_image.png"
image.save(image_path)
return {"image_path": image_path}
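One likely stumbling block here: stable-diffusion-3.5-large is an SD3-family model, so recent diffusers versions expect the SD3 pipeline class rather than StableDiffusionPipeline when loading it. A minimal sketch of just the loading part, assuming a diffusers build with SD3 support and a token that has access to the gated repo:

from diffusers import StableDiffusion3Pipeline
import torch

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    torch_dtype=torch.bfloat16,
    token=API_TOKEN  # newer diffusers accepts `token` instead of `use_auth_token`
)
pipe = pipe.to("cuda")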
# from diffusers import BitsAndBytesConfig, StableDiffusionPipeline
# from fastapi import FastAPI, Form
# from PIL import Image
# import torch
# # from huggingface_hub import login
# # Access token
# # Hugging Face login (token required)
# # login("your_hugging_face_access_token")
# # Model setup
# model_id = "stabilityai/stable-diffusion-3.5-large, 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=access_token"
# nf4_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16
# )
# # Load the model
# pipe = StableDiffusionPipeline.from_pretrained(
#     model_id,
#     torch_dtype=torch.bfloat16,
#     quantization_config=nf4_config
# )
# pipe = pipe.to("cuda")
# pipe.enable_model_cpu_offload()
# # Create the FastAPI instance
# app = FastAPI()
# # Image generation endpoint
# @app.post("/createimage/")
# async def create_image(text: str = Form(...)):
#     """
#     Generates and saves an image based on the given text prompt.
#     """
#     # Generate the image from the text prompt
#     image = pipe(
#         prompt=text,
#         num_inference_steps=28,
#         guidance_scale=3.5
#     ).images[0]
#     # Image save path
#     image_path = "generated_image.png"
#     image.save(image_path)
#     return {"image_path": image_path}
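For reference, the route the diffusers documentation takes for 4-bit SD3.5 is to quantize the transformer sub-model with BitsAndBytesConfig and then pass it into the SD3 pipeline, rather than giving quantization_config to the pipeline itself. A sketch of that pattern (assuming bitsandbytes is installed and the diffusers version supports it; not tested here):

from diffusers import BitsAndBytesConfig, SD3Transformer2DModel, StableDiffusion3Pipeline
import torch

model_id = "stabilityai/stable-diffusion-3.5-large"
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
# Quantize only the transformer, then hand it to the pipeline
transformer_nf4 = SD3Transformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_id,
    transformer=transformer_nf4,
    torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()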
# Text -> Audio (facebook/musicgen-small)
from fastapi import FastAPI, Form, Response
from transformers import pipeline, AutoProcessor, MusicgenForConditionalGeneration
import scipy.io.wavfile
import torch
import tempfile
import os
# print(torch.cuda.is_available())  # prints True when a GPU can be used
# Check that the model has been moved to the GPU
app = FastAPI()
# MusicGen model setup
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small", torch_dtype=torch.float32).to("cuda")
print(next(model.parameters()).device)  # prints "cuda:0" when the model is loaded on the GPU
# Music generation pipeline function
def generate_music(text: str, length: int = 512):
    # Convert the text into a form the model can take as input
    inputs = processor(text=[text], padding=True, return_tensors="pt").to("cuda")
    # Generate audio values with the model
    audio_values = model.generate(**inputs, max_new_tokens=length)
    # Save the generated audio to a temporary WAV file
    sampling_rate = model.config.audio_encoder.sampling_rate
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        scipy.io.wavfile.write(tmp_file.name, rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
        temp_path = tmp_file.name
    return temp_path
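For a sense of scale, MusicGen emits audio tokens at roughly 50 per second, so max_new_tokens=512 is on the order of ten seconds of audio (256 is about five). A quick local sanity check with a made-up prompt might look like:

if __name__ == "__main__":
    # ~5 seconds of audio; the prompt text is only an example
    test_path = generate_music("calm lo-fi piano with soft drums", length=256)
    print("wrote", test_path)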
# FastAPI endpoint: generate music from the summarized text
@app.post("/recommend_music/")
async def recommend_music(summary_text: str = Form(...), length: int = Form(512)):
    # Generate the music
    music_path = generate_music(summary_text, length)
    # Return the path of the generated music file
    return {"music_path": music_path}
# Endpoint so FastAPI can serve the generated WAV file (for quick testing)
@app.get("/download_music/")
async def download_music(music_path: str):
    if os.path.exists(music_path):
        with open(music_path, "rb") as f:
            data = f.read()
        # Return the raw WAV bytes (binary data cannot be placed directly in a JSON response)
        return Response(content=data, media_type="audio/wav")
    else:
        return {"error": "File not found"}
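Assuming this script is saved as something like music_api.py (the file name is arbitrary) and started with uvicorn music_api:app --port 8000, the two endpoints can be exercised from Python roughly like this:

import requests

# Ask the service to generate music from a summary text (the prompt is only an example)
resp = requests.post(
    "http://localhost:8000/recommend_music/",
    data={"summary_text": "a quiet rainy evening in the city", "length": 256},
)
music_path = resp.json()["music_path"]

# Download the generated WAV and save it locally
audio = requests.get("http://localhost:8000/download_music/", params={"music_path": music_path})
with open("recommended.wav", "wb") as f:
    f.write(audio.content)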