Skip to content

Commit 8b74c8e

Browse files
committed
Server: add client docstrings
1 parent 06796d5 commit 8b74c8e

File tree

1 file changed

+119
-16
lines changed

1 file changed

+119
-16
lines changed

examples/server/test_client.py

Lines changed: 119 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,36 @@
33
from PIL import Image, PngImagePlugin, ImageShow
44
from threading import Thread
55

6-
def save_img(img: Image,path: str) -> None:
6+
from typing import List
7+
8+
9+
def save_img(img: Image, path: str) -> None:
10+
"""
11+
Save the image to the specified path with metadata.
12+
13+
Args:
14+
img (Image): The image to be saved.
15+
path (str): The path where the image will be saved.
16+
17+
Returns:
18+
None
19+
"""
720
info = PngImagePlugin.PngInfo()
821
for key, value in img.info.items():
922
info.add_text(key, value)
10-
img.save(path,pnginfo=info)
23+
img.save(path, pnginfo=info)
1124

1225
def show_img(img: Image, title = None) -> None:
26+
"""
27+
Display the image (with metadata) in a new window and print the path of the temporary file.
28+
29+
Args:
30+
img (Image): The image to be displayed.
31+
title (str, optional): The title of the image window. Defaults to None.
32+
33+
Returns:
34+
None
35+
"""
1336
info = PngImagePlugin.PngInfo()
1437
for key, value in img.info.items():
1538
info.add_text(key, value)
@@ -23,9 +46,26 @@ def show_img(img: Image, title = None) -> None:
2346
_server = "localhost"
2447
_port = 8080
2548
_endpoint = "txt2img"
49+
url=""
2650

27-
def update_url(protocol= None, server=None,port=None,endpoint=None):
28-
global _protocol, _server, _port, _endpoint
51+
def update_url(protocol=None, server=None, port=None, endpoint=None) -> str:
52+
"""
53+
Update the global URL variable with the provided protocol, server, port, and endpoint.
54+
55+
This function takes optional arguments for protocol, server, port, and endpoint.
56+
If any of these arguments are provided, the corresponding global variable is updated with the new value.
57+
The function then constructs the URL using the updated global variables and returns it.
58+
59+
Args:
60+
protocol (str, optional): The protocol to be used in the URL. Defaults to None.
61+
server (str, optional): The server address to be used in the URL. Defaults to None.
62+
port (int, optional): The port number to be used in the URL. Defaults to None.
63+
endpoint (str, optional): The endpoint to be used in the URL. Defaults to None.
64+
65+
Returns:
66+
str: The updated URL.
67+
"""
68+
global _protocol, _server, _port, _endpoint, url
2969
if protocol:
3070
_protocol = protocol
3171
if server:
@@ -34,27 +74,90 @@ def update_url(protocol= None, server=None,port=None,endpoint=None):
3474
_port = port
3575
if endpoint:
3676
_endpoint = endpoint
37-
return f"{_protocol}://{_server}:{_port}/{_endpoint}"
77+
url = f"{_protocol}://{_server}:{_port}/{_endpoint}"
78+
return url
79+
80+
# set default url value
81+
update_url()
82+
83+
def sendRequest(payload: str) -> str:
84+
"""
85+
Send a POST request to the API endpoint with the provided payload.
3886
39-
url = update_url(port=8084)
87+
This function takes a payload as input and sends a POST request to the API endpoint specified by the global URL variable.
88+
The function then returns the text content of the response.
4089
41-
def sendRequest(payload):
90+
Args:
91+
payload (str): The payload to be sent in the POST request.
92+
93+
Returns:
94+
str: The text content of the response from the POST request.
95+
"""
4296
global url
43-
return requests.post(url,payload ).text
97+
return requests.post(url, payload).text
98+
99+
def getImages(response: str) -> List[Image.Image]:
100+
"""
101+
Convert base64 encoded image data from the API response into a list of Image objects.
102+
103+
This function takes the text response from the API as input and parses it as JSON.
104+
It then iterates over each image data in the JSON response, decodes the base64 encoded image data,
105+
and uses the BytesIO class to convert it into a PIL Image object.
106+
The function returns a list of these Image objects.
107+
108+
Args:
109+
response (str): The text response from the API containing base64 encoded image data.
110+
111+
Returns:
112+
List[Image.Image]: A list of PIL Image objects decoded from the base64 encoded image data in the API response.
113+
"""
114+
return [Image.open(BytesIO(base64.b64decode(img["data"]))) for img in json.loads(response)]
115+
116+
def showImages(imgs: List[Image.Image]) -> None:
117+
"""
118+
Display a list of images in separate threads.
44119
45-
def getImages(response):
46-
return [Image.open(BytesIO(base64.b64decode(img["data"]) )) for img in json.loads(response)]
120+
This function takes a list of PIL Image objects as input and creates a new thread for each image.
121+
Each thread calls the show_img function to display the image in a new window and print the path of the temporary file.
122+
The function does not return any value.
47123
48-
def showImages(imgs):
49-
for img in imgs:
50-
t =Thread(target=show_img,args=(img,))
124+
Args:
125+
imgs (List[Image.Image]): A list of PIL Image objects to be displayed.
126+
127+
Returns:
128+
None
129+
"""
130+
for (i,img) in enumerate(imgs):
131+
t = Thread(target=show_img, args=(img, f"IMG {i}"))
51132
t.setDaemon(True)
52-
t.start()
133+
t.start
134+
135+
def saveImages(imgs: List[Image.Image], path: str) -> None:
136+
"""
137+
Save a list of images to the specified path with metadata.
138+
139+
This function takes a list of PIL Image objects and a path as input.
140+
For each image, it calls the save_img function to save the image to a file
141+
with the name "{path}{i}.png", where i is the index of the image in the list.
142+
The function does not return any value.
143+
144+
Args:
145+
imgs (List[Image.Image]): A list of PIL Image objects to be saved.
146+
path (str): The path where the images will be saved.
147+
148+
Returns:
149+
None
150+
"""
151+
if path.endswith(".png"):
152+
path = path[:-4]
153+
for (i, img) in enumerate(imgs):
154+
save_img(img, f"{path}{i}.png")
53155

54156

55-
def print_usage():
157+
def _print_usage():
56158
print("""Example usage (images will be displayed and saved to a temporary file):
159+
update_url(server=127.0.0.1, port=8080)
57160
showImages(getImages(sendRequest(json.dumps({'seed': -1, 'batch_count':4, 'sample_steps':24, 'width': 512, 'height':768, 'negative_prompt': "Bad quality", 'prompt': "A beautiful image"}))))""")
58161

59162
if __name__ == "__main__":
60-
print_usage()
163+
_print_usage()

0 commit comments

Comments
 (0)