From 1f639760e4a126e597c368f28d097c03ff19a357 Mon Sep 17 00:00:00 2001
From: Remixer Dec <6587642+remixer-dec@users.noreply.github.com>
Date: Wed, 16 Nov 2022 22:35:45 +0400
Subject: [PATCH 1/3] Added support for proxy engines

---
 engine.py | 248 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 server.py | 170 ++++---------------------------------
 utils.py  |  35 ++++
 3 files changed, 299 insertions(+), 154 deletions(-)
 create mode 100644 engine.py
 create mode 100644 utils.py

diff --git a/engine.py b/engine.py
new file mode 100644
index 0000000..0aab3bb
--- /dev/null
+++ b/engine.py
@@ -0,0 +1,248 @@
+import re
+import sys
+import flask
+import torch
+import diffusers
+import requests
+import json
+from PIL import Image
+
+from utils import retrieve_param, pil_to_b64, b64_to_pil, get_compute_platform
+
+# hf_token is needed when the stock weights are downloaded from Hugging Face;
+# it is read from the same config.json that server.py parses.
+with open('config.json', 'r') as config_file:
+    hf_token = json.load(config_file).get('hf_token')
+
+class Engine(object):
+    def __init__(self):
+        pass
+
+    def process(self, kwargs):
+        return []
+
+class ProxyEngine(Engine):
+    def __init__(self, base_url):
+        super().__init__()
+        self.base_url = base_url
+
+    def process(self, url, args_dict):
+        response = requests.post(url, json=args_dict)
+        if response.status_code != 200:
+            raise RuntimeError(response.content)
+        return response
+
+    def run(self, task):
+        pass
+
+
+class EngineStableDiffusion(Engine):
+    def __init__(self, pipe, sibling=None, custom_model_path=None, requires_safety_checker=True):
+        super().__init__()
+        if sibling == None:
+            self.engine = pipe.from_pretrained( 'runwayml/stable-diffusion-v1-5', use_auth_token=hf_token.strip() )
+        elif custom_model_path:
+            if requires_safety_checker:
+                self.engine = diffusers.StableDiffusionPipeline.from_pretrained(custom_model_path,
+                                                                                safety_checker=sibling.engine.safety_checker,
+                                                                                feature_extractor=sibling.engine.feature_extractor)
+            else:
+                self.engine = diffusers.StableDiffusionPipeline.from_pretrained(custom_model_path,
+                                                                                feature_extractor=sibling.engine.feature_extractor)
+        else:
+            self.engine = pipe(
+                vae=sibling.engine.vae,
+                text_encoder=sibling.engine.text_encoder,
+                tokenizer=sibling.engine.tokenizer,
+                unet=sibling.engine.unet,
+                scheduler=sibling.engine.scheduler,
+                safety_checker=sibling.engine.safety_checker,
+                feature_extractor=sibling.engine.feature_extractor
+            )
+        self.engine.to( get_compute_platform('engine') )
+
+    def process(self, kwargs):
+        output = self.engine( **kwargs )
+        return {'image': output.images[0], 'nsfw': output.nsfw_content_detected[0]}
+
+    def prepare_args(self, task):
+        seed = retrieve_param( 'seed', flask.request.form, int, 0 )
+        prompt = flask.request.form[ 'prompt' ]
+
+        args_dict = {
+            'prompt' : [ prompt ],
+            'seed': seed,
+            'num_inference_steps' : retrieve_param( 'num_inference_steps', flask.request.form, int, 100 ),
+            'guidance_scale' : retrieve_param( 'guidance_scale', flask.request.form, float, 7.5 ),
+            'eta' : retrieve_param( 'eta', flask.request.form, float, 0.0 )
+        }
+
+        if (task == 'txt2img'):
+            args_dict[ 'width' ] = retrieve_param( 'width', flask.request.form, int, 512 )
+            args_dict[ 'height' ] = retrieve_param( 'height', flask.request.form, int, 512 )
+        if (task == 'img2img' or task == 'masking'):
+            init_img_b64 = flask.request.form[ 'init_image' ]
+            init_img_b64 = re.sub( '^data:image/png;base64,', '', init_img_b64 )
+            init_img_pil = b64_to_pil( init_img_b64 )
+            args_dict[ 'init_image' ] = init_img_pil
+            args_dict[ 'strength' ] = retrieve_param( 'strength', flask.request.form, float, 0.7 )
+        if (task == 'masking'):
+            mask_img_b64 = flask.request.form[ 'mask_image' ]
+            mask_img_b64 = re.sub( '^data:image/png;base64,', '', mask_img_b64 )
+            mask_img_pil = b64_to_pil( mask_img_b64 )
+            args_dict[ 'mask_image' ] = mask_img_pil
+        return args_dict
+
+
+    def run(self, task):
+        total_results = []
+        output_data = {}
+        count = retrieve_param( 'num_outputs', flask.request.form, int, 1 )
+        for i in range( count ):
+            args_dict = self.prepare_args(task)
+            # The diffusers pipeline expects a generator, not a raw seed, so pop it here:
+            seed = args_dict.pop( 'seed' )
+            generator = torch.Generator( device=get_compute_platform('generator') )
+            if (seed == 0):
+                # No seed requested: draw a random one so it can be reported back.
+                new_seed = generator.seed()
+            else:
+                generator.manual_seed( seed )
+                new_seed = seed
+            args_dict['generator'] = generator
+            # Perform inference:
+            pipeline_output = self.process( args_dict )
+            pipeline_output[ 'seed' ] = new_seed
+            total_results.append( pipeline_output )
+        # Prepare response
+        output_data[ 'status' ] = 'success'
+        images = []
+        for result in total_results:
+            images.append({
+                'base64' : pil_to_b64( result['image'].convert( 'RGB' ) ),
+                'seed' : result['seed'],
+                'mime_type': 'image/png',
+                'nsfw': result['nsfw']
+            })
+        output_data[ 'images' ] = images
+        return output_data
+
+class A1111EngineStableDiffusion(ProxyEngine):
+    def prepare_args(self, task):
+        args_dict = {
+            'prompt' : flask.request.form[ 'prompt' ],
+            'steps' : retrieve_param( 'num_inference_steps', flask.request.form, int, 100 ),
+            'cfg_scale' : retrieve_param( 'guidance_scale', flask.request.form, float, 7.5 ),
+            'eta' : retrieve_param( 'eta', flask.request.form, float, 0.0 ),
+            'n_iter': retrieve_param( 'num_outputs', flask.request.form, int, 1 ),
+            'seed': retrieve_param( 'seed', flask.request.form, int, -1 )
+        }
+
+        if (task == 'txt2img'):
+            args_dict[ 'width' ] = retrieve_param( 'width', flask.request.form, int, 512 )
+            args_dict[ 'height' ] = retrieve_param( 'height', flask.request.form, int, 512 )
+            self.endpoint_url = '/sdapi/v1/txt2img'
+        if (task == 'img2img' or task == 'masking'):
+            init_img_b64 = flask.request.form[ 'init_image' ]
+            init_img_b64 = 'data:image/png;base64,' + init_img_b64 if init_img_b64[0:4] != 'data' else init_img_b64
+            args_dict[ 'init_images' ] = (init_img_b64,)
+            args_dict[ 'denoising_strength' ] = 1.0 - retrieve_param( 'strength', flask.request.form, float, 0.7 )
+            self.endpoint_url = '/sdapi/v1/img2img'
+        if (task == 'masking'):
+            mask_img_b64 = flask.request.form[ 'mask_image' ]
+            mask_img_b64 = 'data:image/png;base64,' + mask_img_b64 if mask_img_b64[0:4] != 'data' else mask_img_b64
+            args_dict[ 'mask' ] = mask_img_b64
+        return args_dict
+
+    def run(self, task):
+        output_data = {}
+        args_dict = self.prepare_args(task)
+        # process() already raises RuntimeError on any non-200 response:
+        response = self.process(self.base_url + self.endpoint_url, args_dict)
+        output_data[ 'status' ] = 'success'
+        images = []
+        data = response.json()
+        info = json.loads(data[ 'info' ])
+
+        for idx, result in enumerate(data[ 'images' ]):
+            images.append({
+                'base64': result,
+                'seed': info['all_seeds'][idx],
+                'mime_type': 'image/png',
+                'nsfw': False
+            })
+        output_data[ 'images' ] = images
+        return output_data
+
+class InvokeAIEngineStableDiffusion(ProxyEngine):
+    def prepare_args(self, task):
+        args_dict = {
+            'prompt' : flask.request.form[ 'prompt' ],
+            'steps' : retrieve_param( 'num_inference_steps', flask.request.form, int, 100 ),
+            'cfg_scale' : retrieve_param( 'guidance_scale', flask.request.form, float, 7.5 ),
+            'eta' : retrieve_param( 'eta', flask.request.form, float, 0.0 ),
+            'seed': retrieve_param( 'seed', flask.request.form, int, -1 ),
+            'iterations': retrieve_param( 'num_outputs', flask.request.form, int, 1 ),
+            'sampler_name': 'k_lms',
+            'width': retrieve_param( 'width', flask.request.form, int, 512 ),
+            'height': retrieve_param( 'height', flask.request.form, int, 512 ),
+            'threshold': 0,
+            'perlin': 0,
+            'variation_amount': 0,
+            'with_variations': '',
+            'initimg': None,
+            'strength': 0.99,
+            'fit': 'on',
+            'facetool_strength': 0.0,
+            'upscale_level': '',
+            'upscale_strength': 0,
+            'initimg_name': ''
+        }
+
+        if (task == 'img2img' or task == 'masking'):
+            init_img_b64 = flask.request.form[ 'init_image' ]
+            init_img_b64 = 'data:image/png;base64,' + init_img_b64 if init_img_b64[0:4] != 'data' else init_img_b64
+            args_dict[ 'initimg' ] = init_img_b64
+            args_dict[ 'initimg_name' ] = 'temp.png'
+            args_dict[ 'strength' ] = 1.0 - retrieve_param( 'strength', flask.request.form, float, 0.7 )
+        if (task == 'masking'):
+            mask_img_b64 = flask.request.form[ 'mask_image' ]
+            mask_img_b64 = 'data:image/png;base64,' + mask_img_b64 if mask_img_b64[0:4] != 'data' else mask_img_b64
+            args_dict[ 'mask' ] = mask_img_b64
+        return args_dict
+
+    def run(self, task):
+        output_data = {}
+        args_dict = self.prepare_args(task)
+        response = self.process(self.base_url, args_dict)
+        output_data[ 'status' ] = 'success'
+        images = []
+        # The legacy API streams one JSON event per line; wrap them into a JSON array:
+        json_data = '[{}]'.format(','.join(response.text.split('\n'))[:-1])
+        data = json.loads(json_data)
+        data = [item for item in data if item['event'] == 'result']
+        for result in data:
+            url = self.base_url + '/' + result['url']
+            images.append({
+                'base64': pil_to_b64(Image.open(requests.get(url, stream=True).raw)),
+                'seed': result[ 'seed' ],
+                'mime_type': 'image/png',
+                'nsfw': False
+            })
+        output_data[ 'images' ] = images
+        return output_data
+
+class EngineManager(object):
+    def __init__(self):
+        self.engines = {}
+
+    def has_engine(self, name):
+        return ( name in self.engines )
+
+    def add_engine(self, name, engine):
+        if self.has_engine( name ):
+            return False
+        self.engines[ name ] = engine
+        return True
+
+    def get_engine(self, name):
+        # In proxy mode a single 'universal' engine handles every task:
+        if self.has_engine( 'universal' ):
+            return self.engines[ 'universal' ]
+        if not self.has_engine( name ):
+            return None
+        engine = self.engines[ name ]
+        return engine
diff --git a/server.py b/server.py
index 71f4dab..46717b4 100644
--- a/server.py
+++ b/server.py
@@ -1,110 +1,9 @@
-import re
-import time
-import inspect
 import json
 import flask
 import sys
-import base64
-from PIL import Image
-from io import BytesIO
-
-import torch
 import diffusers
-
-##################################################
-# Utils
-
-def retrieve_param(key, data, cast, default):
-    if key in data:
-        value = flask.request.form[ key ]
-        value = cast( value )
-        return value
-    return default
-
-def pil_to_b64(input):
-    buffer = BytesIO()
-    input.save( buffer, 'PNG' )
-    output = base64.b64encode( buffer.getvalue() ).decode( 'utf-8' ).replace( '\n', '' )
-    buffer.close()
-    return output
-
-def b64_to_pil(input):
-    output = Image.open( BytesIO( base64.b64decode( input ) ) )
-    return output
-
-def get_compute_platform(context):
-    try:
-        import torch
-        if torch.cuda.is_available():
-            return 'cuda'
-        elif torch.backends.mps.is_available() and context == 'engine':
-            return 'mps'
-        else:
-            return 'cpu'
-    except ImportError:
-        return 'cpu'
-
-##################################################
-# Engines
-
-class Engine(object):
-    def __init__(self):
-        pass
-
-    def process(self, kwargs):
-        return []
-
-class EngineStableDiffusion(Engine):
-    def __init__(self, pipe, sibling=None, custom_model_path=None, requires_safety_checker=True):
-        super().__init__()
-        if sibling == None:
-            self.engine = pipe.from_pretrained( 'runwayml/stable-diffusion-v1-5', use_auth_token=hf_token.strip() )
-        elif custom_model_path:
-            if requires_safety_checker:
-                self.engine = diffusers.StableDiffusionPipeline.from_pretrained(custom_model_path,
-                                                                                safety_checker=sibling.engine.safety_checker,
-                                                                                feature_extractor=sibling.engine.feature_extractor)
-            else:
-                self.engine = diffusers.StableDiffusionPipeline.from_pretrained(custom_model_path,
-                                                                                feature_extractor=sibling.engine.feature_extractor)
-        else:
-            self.engine = pipe(
-                vae=sibling.engine.vae,
-                text_encoder=sibling.engine.text_encoder,
-                tokenizer=sibling.engine.tokenizer,
-                unet=sibling.engine.unet,
-                scheduler=sibling.engine.scheduler,
-                safety_checker=sibling.engine.safety_checker,
-                feature_extractor=sibling.engine.feature_extractor
-            )
-        self.engine.to( get_compute_platform('engine') )
-
-    def process(self, kwargs):
-        output = self.engine( **kwargs )
-        return {'image': output.images[0], 'nsfw':output.nsfw_content_detected[0]}
-
-class EngineManager(object):
-    def __init__(self):
-        self.engines = {}
-
-    def has_engine(self, name):
-        return ( name in self.engines )
-
-    def add_engine(self, name, engine):
-        if self.has_engine( name ):
-            return False
-        self.engines[ name ] = engine
-        return True
-
-    def get_engine(self, name):
-        if not self.has_engine( name ):
-            return None
-        engine = self.engines[ name ]
-        return engine
-
-##################################################
-# App
+from engine import EngineManager, EngineStableDiffusion, A1111EngineStableDiffusion, InvokeAIEngineStableDiffusion
 
 # Load and parse the config file:
 try:
@@ -116,7 +15,7 @@ def get_engine(self, name):
 
 hf_token = config['hf_token']
 
-if (hf_token == None):
+if (hf_token == None and config['mode'] != 'proxy'):
     sys.exit('No Hugging Face token found in config.json.')
 
 custom_models = config['custom_models'] if 'custom_models' in config else []
@@ -128,14 +27,22 @@ def get_engine(self, name):
 manager = EngineManager()
 
 # Add supported engines to manager:
-manager.add_engine( 'txt2img', EngineStableDiffusion( diffusers.StableDiffusionPipeline, sibling=None ) )
-manager.add_engine( 'img2img', EngineStableDiffusion( diffusers.StableDiffusionImg2ImgPipeline, sibling=manager.get_engine( 'txt2img' ) ) )
-manager.add_engine( 'masking', EngineStableDiffusion( diffusers.StableDiffusionInpaintPipeline, sibling=manager.get_engine( 'txt2img' ) ) )
-for custom_model in custom_models:
-    manager.add_engine( custom_model['url_path'],
+if (config.get('mode') != 'proxy'):
+    manager.add_engine( 'txt2img', EngineStableDiffusion( diffusers.StableDiffusionPipeline, sibling=None ) )
+    manager.add_engine( 'img2img', EngineStableDiffusion( diffusers.StableDiffusionImg2ImgPipeline, sibling=manager.get_engine( 'txt2img' ) ) )
+    manager.add_engine( 'masking', EngineStableDiffusion( diffusers.StableDiffusionInpaintPipeline, sibling=manager.get_engine( 'txt2img' ) ) )
+    for custom_model in custom_models:
+        manager.add_engine( custom_model['url_path'],
                         EngineStableDiffusion( diffusers.StableDiffusionPipeline, sibling=manager.get_engine( 'txt2img' ),
                         custom_model_path=custom_model['model_path'],
                         requires_safety_checker=custom_model['requires_safety_checker'] ) )
+else:
+    engine = None
+    if config['base_provider'] == 'AUTOMATIC1111':
+        engine = A1111EngineStableDiffusion(config['base_url'])
+    elif config['base_provider'] == 'InvokeAI':
+        engine = InvokeAIEngineStableDiffusion(config['base_url'])
+    manager.add_engine('universal', engine)
 
 # Define routes:
 @app.route('/ping', methods=['GET'])
@@ -177,52 +84,7 @@ def _generate(task, engine=None):
 
     # Handle request:
     try:
-        seed = retrieve_param( 'seed', flask.request.form, int, 0 )
-        count = retrieve_param( 'num_outputs', flask.request.form, int, 1 )
-        total_results = []
-        for i in range( count ):
-            if (seed == 0):
-                generator = torch.Generator( device=get_compute_platform('generator') )
-            else:
-                generator = torch.Generator( device=get_compute_platform('generator') ).manual_seed( seed )
-            new_seed = generator.seed()
-            prompt = flask.request.form[ 'prompt' ]
-            args_dict = {
-                'prompt' : [ prompt ],
-                'num_inference_steps' : retrieve_param( 'num_inference_steps', flask.request.form, int, 100 ),
-                'guidance_scale' : retrieve_param( 'guidance_scale', flask.request.form, float, 7.5 ),
-                'eta' : retrieve_param( 'eta', flask.request.form, float, 0.0 ),
-                'generator' : generator
-            }
-            if (task == 'txt2img'):
-                args_dict[ 'width' ] = retrieve_param( 'width', flask.request.form, int, 512 )
-                args_dict[ 'height' ] = retrieve_param( 'height', flask.request.form, int, 512 )
-            if (task == 'img2img' or task == 'masking'):
-                init_img_b64 = flask.request.form[ 'init_image' ]
-                init_img_b64 = re.sub( '^data:image/png;base64,', '', init_img_b64 )
-                init_img_pil = b64_to_pil( init_img_b64 )
-                args_dict[ 'init_image' ] = init_img_pil
-                args_dict[ 'strength' ] = retrieve_param( 'strength', flask.request.form, float, 0.7 )
-            if (task == 'masking'):
-                mask_img_b64 = flask.request.form[ 'mask_image' ]
-                mask_img_b64 = re.sub( '^data:image/png;base64,', '', mask_img_b64 )
-                mask_img_pil = b64_to_pil( mask_img_b64 )
-                args_dict[ 'mask_image' ] = mask_img_pil
-            # Perform inference:
-            pipeline_output = engine.process( args_dict )
-            pipeline_output[ 'seed' ] = new_seed
-            total_results.append( pipeline_output )
-        # Prepare response
-        output_data[ 'status' ] = 'success'
-        images = []
-        for result in total_results:
-            images.append({
-                'base64' : pil_to_b64( result['image'].convert( 'RGB' ) ),
-                'seed' : result['seed'],
-                'mime_type': 'image/png',
-                'nsfw': result['nsfw']
-            })
-        output_data[ 'images' ] = images
+        output_data = engine.run(task)
     except RuntimeError as e:
         output_data[ 'status' ] = 'failure'
         output_data[ 'message' ] = 'A RuntimeError occurred. You probably ran out of GPU memory. Check the server logs for more details.'
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..cfb3dd0
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,35 @@
+from PIL import Image
+from io import BytesIO
+import flask
+import base64
+import torch
+
+def retrieve_param(key, data, cast, default):
+    if key in data:
+        value = flask.request.form[ key ]
+        value = cast( value )
+        return value
+    return default
+
+def pil_to_b64(input):
+    buffer = BytesIO()
+    input.save( buffer, 'PNG' )
+    output = base64.b64encode( buffer.getvalue() ).decode( 'utf-8' ).replace( '\n', '' )
+    buffer.close()
+    return output
+
+def b64_to_pil(input):
+    output = Image.open( BytesIO( base64.b64decode( input ) ) )
+    return output
+
+def get_compute_platform(context):
+    try:
+        import torch
+        if torch.cuda.is_available():
+            return 'cuda'
+        elif torch.backends.mps.is_available() and context == 'engine':
+            return 'mps'
+        else:
+            return 'cpu'
+    except ImportError:
+        return 'cpu'
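The `ProxyEngine` base class introduced in this patch keeps all backend-specific behavior in `prepare_args` and `run`, so supporting a further HTTP backend only takes a small subclass. The sketch below is not part of the patch: the backend, its `/generate` endpoint, and its response shape are hypothetical, and it exists only to illustrate the extension point.

```python
# Hypothetical third backend, for illustration only (not part of this patch).
import flask

from engine import ProxyEngine
from utils import retrieve_param

class ExampleProxyEngine(ProxyEngine):
    def prepare_args(self, task):
        # Translate the plugin's form fields into whatever the backend expects.
        return {
            'prompt': flask.request.form['prompt'],
            'steps': retrieve_param('num_inference_steps', flask.request.form, int, 100),
            'cfg_scale': retrieve_param('guidance_scale', flask.request.form, float, 7.5),
        }

    def run(self, task):
        args_dict = self.prepare_args(task)
        # ProxyEngine.process() POSTs JSON and raises RuntimeError on non-200:
        response = self.process(self.base_url + '/generate', args_dict)  # assumed endpoint
        data = response.json()
        # The plugin expects 'status' plus an 'images' list of dicts with
        # 'base64', 'seed', 'mime_type', and 'nsfw' keys:
        return {'status': 'success', 'images': data.get('images', [])}
```

Wiring such a class in would mirror the `else` branch added to `server.py` above, for example behind a new `base_provider` value.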
+ "mode": "proxy", + "base_provider": "AUTOMATIC1111" or "InvokeAI", + "base_url": "http://localhost:7860" or "http://localhost:9090" or any other IP/Port of your base server +} +``` + + ## REST API Note that all `POST` requests use the `application/x-www-form-urlencoded` content type, and all images are base64 encoded strings. From 30f4fef5c2fde704a422d36534e7dfd9d71c2bdb Mon Sep 17 00:00:00 2001 From: Remixer Dec <6587642+remixer-dec@users.noreply.github.com> Date: Thu, 17 Nov 2022 01:07:50 +0400 Subject: [PATCH 3/3] Hotfix for non-existing "mode" key --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 46717b4..339312f 100644 --- a/server.py +++ b/server.py @@ -15,7 +15,7 @@ hf_token = config['hf_token'] -if (hf_token == None and config['mode'] != 'proxy'): +if (hf_token == None and config.get('mode') != 'proxy'): sys.exit('No Hugging Face token found in config.json.') custom_models = config['custom_models'] if 'custom_models' in config else []