# stable_diffusion_runner / ai_server.py
# (NOTE: "Newer" / "Older" blog-navigation residue from the original paste was
# commented out here so this file is valid Python.)
# conda activate PyTorch

# Start the server with:
# uvicorn ai_server:app --reload

import torch
import os
import threading, queue
import time

from os.path import join, dirname

from diffusers import StableDiffusionPipeline

from dotenv import load_dotenv

from fastapi import FastAPI
# ASGI application object served by `uvicorn ai_server:app --reload`.
app = FastAPI()

# Load environment variables from a .env file.
# NOTE(review): load_dotenv is called twice — first against the process's
# current working directory, then explicitly against the .env file next to
# this script. The second call is the one that works when the server is
# launched from another directory; the first is likely redundant.
load_dotenv(verbose=True)

dotenv_path = join(dirname(__file__), '.env')

load_dotenv(dotenv_path)

# Hugging Face token used below to download the Stable Diffusion weights.
# May be None if the variable is not set — from_pretrained would then fail.
AUTH_TOKEN = os.environ.get("AUTH_TOKEN")

# Load Stable Diffusion v1.4 with half-precision (fp16) weights.
# This is a slow, network-dependent call on first run (downloads the model).
pipe = StableDiffusionPipeline.from_pretrained(

    "CompVis/stable-diffusion-v1-4",

    revision="fp16",

    torch_dtype=torch.float16, use_auth_token=AUTH_TOKEN)
# Move the pipeline to the GPU — this script assumes CUDA is available.
pipe.to("cuda")

# Sample prompt kept from development.
# NOTE(review): `prompt` and `OUTPUT_DIR` are not referenced anywhere in this
# file — the /render endpoint supplies its own prompt/outputdir per request.
# Candidates for removal; confirm nothing else imports them.
prompt = "concept art. Illustration,hyper quality,highly detailed, cinematic lighting. A teenage Japanese youth with neko cat ears on head top and neko cat tail on waist, He has a beautiful face with a hint of childhood. equipped with post-apocalyptic armour and tactical hand axe, is looking at us against an abandoned Tokyo building. He is wearing jeans. Long hair that extends to his back."

OUTPUT_DIR = "output/catboy04"

# Image generation worker routine.
def imgGenerate(imgOrder, filename="output001.png"):
    """Run one Stable Diffusion generation and save the result to disk.

    Args:
        imgOrder: dict with keys "prompt" (text prompt for the pipeline)
            and "outputdir" (directory the image is written to).
        filename: name of the saved file inside outputdir. Defaults to
            "output001.png" to preserve the original behaviour.

    NOTE(review): with the default filename, successive orders for the same
    outputdir overwrite each other — pass a unique filename to keep them all.
    """
    # Use torch.autocast explicitly so this function does not depend on the
    # stray `from torch import autocast` that appears later in the module.
    with torch.autocast("cuda"):
        image = pipe(imgOrder["prompt"])["sample"][0]

    # exist_ok=True avoids the check-then-create race of the original
    # `if not os.path.exists(...): os.makedirs(...)` sequence.
    os.makedirs(imgOrder["outputdir"], exist_ok=True)

    image.save(os.path.join(imgOrder["outputdir"], filename))

# Image-generation worker-thread loop.
def generateThread(q):
    """Consume image orders from *q* forever and render each one.

    Each queue item is a dict accepted by imgGenerate(). A failure in one
    order is logged and swallowed so a single bad request cannot kill the
    worker thread (the original version died silently on the first
    exception, and `q.task_done()` was then never called, which would
    deadlock any future `q.join()`).
    """
    print("スレッド開始")
    while True:
        imgOrder = q.get()
        try:
            imgGenerate(imgOrder)
            # Short cool-down between jobs (preserved from the original).
            time.sleep(3)
        except Exception as exc:
            # Keep the worker alive; report and move on to the next order.
            print(f"imgGenerate failed: {exc!r}")
        finally:
            # Always balance q.get(), even when the order failed.
            q.task_done()


# Work queue feeding the single generator thread; the FastAPI handlers
# enqueue orders and return immediately (generation is asynchronous).
q = queue.Queue()

# Start the image-generation worker.
# daemon=True so this infinite-loop thread does not block interpreter
# shutdown — without it, `uvicorn --reload` restarts hang waiting on the
# non-daemon worker thread.
thread = threading.Thread(target=generateThread, args=(q,), daemon=True)
thread.start()

from torch import autocast

# HTTP endpoint definitions
@app.get("/")
async def hello():
    """Liveness check: return a fixed greeting."""
    greeting = {"message" : "Hello,World"}
    return greeting

@app.get("/render")
async def render(prompt: str = "", outputdir: str = "output/default"):
    """Accept a generation request and hand it to the worker thread.

    The order is enqueued and the endpoint returns immediately; the image
    is produced asynchronously by the generator thread.
    """
    # Build the order as a single literal rather than key-by-key assignment.
    q.put({"prompt": prompt, "outputdir": outputdir})

    return {"message" : "render"}