Python接入指南
更新時間:
操作步驟
步驟一:安裝依賴
pip install requests
pip install dataclass_wizard
步驟二:增加Client類
增加 client.py,按需修改 package 名稱
import time
import uuid
import hmac
import hashlib
import base64
import json
import io
import requests
class Client:
    """Signed HTTP client for an API-gateway endpoint (digest / HmacSHA256 authentication).

    Every request carries ``x-ca-key``, ``x-ca-timestamp``, ``x-ca-nonce`` and an
    ``x-ca-signature`` computed with ``app_secret`` over the HTTP method, Accept,
    Content-MD5, Content-Type, Date, the signed ``x-ca-*`` headers and the path
    (API Gateway "use digest authentication to call an API" scheme).
    """

    # Sent as both Accept and Content-Type, and signed as such.
    _JSON_CONTENT_TYPE = 'application/json; charset=utf-8'

    def __init__(self, endpoint: str, app_key: str, app_secret: str):
        """
        Args:
            endpoint: gateway host name without scheme; requests go to
                ``https://{endpoint}{path}``.
            app_key: application key, sent as the ``x-ca-key`` header.
            app_secret: application secret, used only as the HMAC key.
        """
        self.endpoint = endpoint
        self.app_key = app_key
        self.app_secret = app_secret

    def invoke(self, path: str, params: dict = None, method='POST', headers: dict = None, **kwargs):
        """Send a signed request to ``https://{endpoint}{path}``.

        Args:
            path: request path, e.g. ``"/scc/comfy_prompt"``.
            params: for GET, sent as query parameters (``None`` values are dropped,
                as requests would drop them); otherwise serialized once to JSON and
                sent as the request body.
            method: ``"GET"`` (case-insensitive) issues a GET; any other value
                issues a POST.
            headers: extra headers merged over the generated ones *after* signing,
                so they are not covered by the signature.
            **kwargs: forwarded to ``requests.get`` / ``requests.post`` (e.g. ``timeout``).

        Returns:
            The ``requests.Response``.
        """
        url = f'https://{self.endpoint}{path}'
        if (method or '').upper() == 'GET':
            # GET has no body: the query takes part in the signature as
            # "path?k1=v1&k2=v2" (sorted by key) and is actually sent on the URL.
            query = {k: v for k, v in (params or {}).items() if v is not None}
            gen_headers = self._sign('GET', self._signed_path(path, query), None, headers)
            return requests.get(url, params=query, headers=gen_headers, **kwargs)
        # Serialize once so content-md5 is computed over exactly the bytes sent.
        body = self._encode_body(params) if params is not None else None
        gen_headers = self._sign('POST', path, body, headers)
        return requests.post(url, data=body, headers=gen_headers, **kwargs)

    def _generate_header(self, http_method: str, path: str, body: dict = None, hdrs: dict = None):
        """Return signed headers for a request whose JSON body is ``body``.

        Args:
            http_method: HTTP method placed in the string to sign.
            path: path signed as given (append any query string yourself, in
                sorted signing form).
            body: JSON-serializable body; hashed into ``content-md5`` only when truthy.
            hdrs: extra headers merged after signing (not signed).

        Returns:
            dict of request headers including ``x-ca-signature``.
        """
        return self._sign(http_method, path, self._encode_body(body) if body else None, hdrs)

    @staticmethod
    def _encode_body(payload) -> bytes:
        """Serialize ``payload`` to the UTF-8 JSON bytes that are both hashed and sent."""
        # NaN/Infinity are not valid JSON; fail here rather than send them.
        return json.dumps(payload, allow_nan=False).encode('utf-8')

    @staticmethod
    def _signed_path(path: str, query: dict = None) -> str:
        """Return ``path`` plus its query parameters in signing form.

        Keys are sorted; a parameter whose value is the empty string contributes
        only its key. Values are rendered with ``str()`` (scalar values assumed).
        """
        if not query:
            return path
        pairs = [key if query[key] == '' else f'{key}={query[key]}' for key in sorted(query)]
        return f"{path}?{'&'.join(pairs)}"

    def _sign(self, http_method: str, signed_path: str, body: bytes = None, hdrs: dict = None) -> dict:
        """Build the request headers, including the HmacSHA256 ``x-ca-signature``.

        Args:
            http_method: HTTP method placed in the string to sign.
            signed_path: path (plus sorted query, if any) placed last in the string to sign.
            body: exact request body bytes; when non-empty, their base64 MD5 is sent
                as ``content-md5`` and signed.
            hdrs: extra headers merged after signing (not signed).
        """
        timestamp = time.time()
        # The same string is sent in the "date" header and signed.
        date_str = time.strftime('%a, %d %b %Y %H:%M:%S GMT', time.gmtime(timestamp)).replace('GMT', 'GMT+00:00')
        timestamp_str = str(int(timestamp * 1000))  # milliseconds since the epoch
        nonce = str(uuid.uuid4())  # fresh per request
        content_type = self._JSON_CONTENT_TYPE
        headers = {
            'date': date_str,
            'x-ca-key': self.app_key,
            'x-ca-timestamp': timestamp_str,
            'x-ca-nonce': nonce,
            'x-ca-signature-method': 'HmacSHA256',
            'x-ca-signature-headers': 'x-ca-timestamp,x-ca-key,x-ca-nonce,x-ca-signature-method',
            'Content-Type': content_type,
            'Accept': content_type,
        }
        content_md5 = ''
        if body:
            content_md5 = base64.b64encode(hashlib.md5(body).digest()).decode('utf-8')
            headers['content-md5'] = content_md5
        string_to_sign = '\n'.join([
            http_method,
            content_type,  # Accept
            content_md5,   # kept as an empty line when there is no body
            content_type,  # Content-Type
            date_str,
            # Signed x-ca-* headers as "name:value", in lexicographic order of name.
            f'x-ca-key:{self.app_key}',
            f'x-ca-nonce:{nonce}',
            'x-ca-signature-method:HmacSHA256',
            f'x-ca-timestamp:{timestamp_str}',
            signed_path,
        ])
        digest = hmac.new(self.app_secret.encode('utf-8'), string_to_sign.encode('utf-8'), hashlib.sha256).digest()
        headers['x-ca-signature'] = base64.b64encode(digest).decode('utf-8')
        if hdrs:
            # Applied after signing: these headers are not covered by the signature.
            headers.update(hdrs)
        return headers
步驟三:增加Proto類(以ComfyUI生圖服務舉例)
增加 proto.py
# -*- coding: utf-8 -*-
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional

from dataclass_wizard import DumpMeta, JSONWizard

from util import batch_download_images
class PredictResultStatusCode(Enum):
    """Task states reported by the service; member values are the wire strings."""

    TASK_INPROGRESS = "running"
    TASK_FAILED = "failed"
    TASK_QUEUE = "waiting"
    TASK_FINISH = "succeeded"

    def finished(self):
        """Return True for terminal states (succeeded or failed), False otherwise."""
        cls = PredictResultStatusCode
        return self is cls.TASK_FINISH or self is cls.TASK_FAILED
class JSONe(JSONWizard):
    """JSONWizard base class whose subclasses dump field names in snake_case.

    ``__init_subclass__`` runs for every subclass and binds a ``DumpMeta`` with
    ``key_transform='SNAKE'`` to it, so ``to_dict()`` emits keys such as
    ``workflow_id`` / ``alias_id``.
    """

    def __init_subclass__(cls, **options):
        super().__init_subclass__(**options)
        snake_case_dump = DumpMeta(key_transform='SNAKE')
        snake_case_dump.bind_to(cls)
@dataclass
class GatewayResponse(JSONe):
    """Envelope fields shared by the gateway JSON responses; subclasses add a typed ``data``."""
    # Service status code; comfy_prompt treats 20 in a progress response as a failure.
    status: Optional[int] = 0
    # Non-empty when the call failed; comfy_prompt raises with err_message in that case.
    err_code: Optional[str] = ""
    err_message: Optional[str] = ""
    # Presumably finer-grained error details — semantics not visible here.
    sub_err_code: Optional[str] = ""
    sub_err_message: Optional[str] = ""
    # Presumably a request/trace id useful when contacting support — confirm.
    api_invoke_id: Optional[str] = ""
@dataclass
class ComfyRequest(JSONe):
    """Request body for ``/scc/comfy_prompt``: which workflow to run and with which inputs."""
    # Workflow id obtained from the console.
    workflow_id: str
    version_id: Optional[str] = None
    # Workflow input values keyed by input name, e.g. {"prompt": "..."}.
    # typing.Any (not the builtin function ``any``) so values may be any JSON value.
    inputs: Optional[Dict[str, Any]] = None
    # Alias configured for the workflow.
    alias_id: Optional[str] = None
@dataclass
class ComfyResponseData(JSONe):
    """Payload of the ``/scc/comfy_prompt`` response."""
    # Id of the submitted task; used as "taskId" in the progress and result queries.
    task_id: str
    status: Optional[PredictResultStatusCode] = PredictResultStatusCode.TASK_INPROGRESS
@dataclass
class ComfyResponse(GatewayResponse):
    """Response of ``/scc/comfy_prompt``; ``data`` defaults to None when no payload is given."""
    data: Optional[ComfyResponseData] = None
@dataclass
class PredictResult(JSONe):
    """Payload of ``/scc/comfy_get_result``: the outputs of a finished task."""
    task_id: str
    # Image URLs; PredictResultResponse.download_images fetches them into imgs_bytes.
    images: Optional[List[str]] = None
    info: Optional[Dict[str, str]] = None
    parameters: Optional[Dict[str, str]] = None
    status: Optional[PredictResultStatusCode] = PredictResultStatusCode.TASK_INPROGRESS
    # Filled locally by download_images, not by the service.
    # NOTE(review): batch_download_images returns response.content values (bytes),
    # so List[str] looks wrong — List[bytes] is presumably intended; confirm that
    # dataclass_wizard handles bytes before changing the annotation.
    imgs_bytes: Optional[List[str]] = None
    # Raw result object; its schema is not visible here.
    result: Optional[Dict] = None
@dataclass
class PredictResultResponse(GatewayResponse):
    """Response of ``/scc/comfy_get_result``."""
    data: Optional[PredictResult] = None

    def download_images(self):
        """Download every URL in ``data.images`` and store the contents in ``data.imgs_bytes``.

        Does nothing when the response carries no payload or no image URLs
        (previously a missing payload raised AttributeError).
        """
        if self.data is None or not self.data.images:
            return
        self.data.imgs_bytes = batch_download_images(self.data.images)
@dataclass
class ProgressData(JSONe):
    """Payload of ``/scc/comfy_get_progress``."""
    task_id: str
    # Required (no default). NOTE(review): scale (0-1 vs. percent) not visible here — confirm.
    progress: float
    # Required (no default). Presumably the estimated remaining time; units unconfirmed.
    eta_relative: int
    message: Optional[str] = ""
    # comfy_prompt stops polling and fetches the result once status.finished() is True.
    status: Optional[PredictResultStatusCode] = PredictResultStatusCode.TASK_INPROGRESS
@dataclass
class ProgressResponse(GatewayResponse):
    """Response of ``/scc/comfy_get_progress``."""
    data: Optional[ProgressData] = None
步驟四:增加工具類(以ComfyUI生圖服務舉例)
增加 util.py
from multiprocessing.pool import ThreadPool
import logging
import requests
from dataclass_wizard.utils.string_conv import to_camel_case
logger = logging.getLogger(__name__)


def batch_download_images(image_links, attempts=3, timeout=100):
    """Download all ``image_links`` concurrently and return their raw contents.

    Each URL is tried up to ``attempts`` times; an exception or an HTTP error
    status counts as a failed attempt. Downloading is best-effort.

    Args:
        image_links: iterable of image URLs.
        attempts: maximum number of tries per URL.
        timeout: per-request timeout in seconds, passed to ``requests.get``.

    Returns:
        List of ``response.content`` values for the successful downloads, in the
        order of ``image_links``. URLs that failed every attempt are omitted, so
        the list can be shorter than the input.
    """
    def _download(image_link):
        for attempt in range(1, attempts + 1):
            try:
                response = requests.get(image_link, timeout=timeout)
                # Otherwise an error page (404/500) would be returned as image data.
                response.raise_for_status()
                return response.content
            except Exception as exc:  # best-effort: any failure just costs one attempt
                logger.warning("Failed to download image %s (attempt %d/%d): %s",
                               image_link, attempt, attempts, exc)
        logger.error("Giving up on image %s after %d attempts", image_link, attempts)
        return None

    # The context manager shuts the worker threads down once all results are in.
    with ThreadPool() as pool:
        contents = pool.map(_download, image_links)
    return [content for content in contents if content is not None]
def convert_to_camel_case(data_dict):
    """Recursively rename dict keys to camelCase.

    Dicts get every key passed through ``to_camel_case`` with their values
    converted recursively. Lists get each element converted. Any other value is
    returned unchanged.
    """
    if isinstance(data_dict, list):
        return list(map(convert_to_camel_case, data_dict))
    if not isinstance(data_dict, dict):
        return data_dict
    converted = {}
    for key, value in data_dict.items():
        converted[to_camel_case(key)] = convert_to_camel_case(value)
    return converted
步驟五:填寫調用AK/SK、調用路徑、調用參數
from client import Client
from proto import ComfyRequest, ComfyResponse, PredictResultResponse, ProgressResponse
import time
import json
# Gateway client used by call(). Replace the placeholders with the application's
# AK/SK, and avoid committing real credentials to source control.
cli = Client(
    endpoint="openai.edu-aliyun.com",
    app_key="應用AK",
    app_secret="應用SK",
)
# 原始調(diào)用方法
def call(url, body, method='POST', headers=None):
res = cli.invoke(url, body, method, headers)
if res:
data = res.json()
return data
def comfy_prompt(prompt: ComfyRequest, custom_resource_config_id='default',
                 max_polls=1200, poll_interval=1.0) -> PredictResultResponse:
    """Submit a ComfyUI workflow run and poll it until it finishes.

    Args:
        prompt: workflow request sent to ``/scc/comfy_prompt``.
        custom_resource_config_id: if truthy, sent as the ``X-SP-RESOURCE-CONFIG-ID``
            header on the submit request.
        max_polls: maximum number of progress queries before giving up.
        poll_interval: seconds to sleep between progress queries.

    Returns:
        The parsed ``/scc/comfy_get_result`` response once the task reaches a
        terminal state. This covers failed tasks too, so check ``result.data.status``.

    Raises:
        Exception: if submission fails or returns an error, if a progress query
            reports status 20, if fetching the result fails, or if the task has
            not finished after ``max_polls`` polls.
    """
    payload = prompt.to_dict()
    print(payload)
    headers = {}
    if custom_resource_config_id:
        headers['X-SP-RESOURCE-CONFIG-ID'] = custom_resource_config_id
    raw_submit = call("/scc/comfy_prompt", payload, headers=headers)
    if not raw_submit:
        raise Exception("Failed to submit prompt to /scc/comfy_prompt")
    submitted = ComfyResponse.from_dict(raw_submit)
    print(submitted.to_dict())
    if submitted.err_code:
        raise Exception(submitted.err_message)
    if submitted.data is None:
        raise Exception("No task data in /scc/comfy_prompt response: %s" % raw_submit)
    task_id = submitted.data.task_id

    for _ in range(max_polls):
        # Query progress; a failed request (None) is simply retried on the next poll.
        raw_res = call("/scc/comfy_get_progress", {"taskId": task_id}, headers=None)
        print(raw_res)
        if raw_res:
            progress = ProgressResponse.from_dict(raw_res)
            if progress.status == 20:
                pretty_json_str = json.dumps(raw_res, indent=2, ensure_ascii=False)
                print(pretty_json_str)
                raise Exception("Failed to call , error: %s" % pretty_json_str)
            data = progress.data
            if data is not None and data.status is not None and data.status.finished():
                # Fetch the final result.
                raw_result = call("/scc/comfy_get_result", {"taskId": task_id})
                if not raw_result:
                    raise Exception("Failed to fetch result of task %s" % task_id)
                return PredictResultResponse.from_dict(raw_result)
        time.sleep(poll_interval)
    raise Exception("%ds Timeout" % (max_polls * poll_interval))
if __name__ == '__main__':
    begin = time.time()
    # Workflow alias configured in the console, and the workflow id from the console.
    alias_id = "配置的工作流別名"
    workflow_id = "控制臺獲取的工作流id"
    # Workflow input parameters.
    params = {
        "prompt": "A man is walking on the street."
    }
    result = comfy_prompt(ComfyRequest(alias_id=alias_id,
                                       workflow_id=workflow_id,
                                       inputs=params))
    print("生圖結果:" + str(result))
    print("時間消耗: %.2fs" % (time.time() - begin))
文檔內容是否對您有幫助?