Feat: AIPPT & DynamicToolParamter (#2725)

This commit is contained in:
Yeuoly
2024-03-07 15:04:42 +08:00
committed by GitHub
parent 7052565380
commit 27e678480e
8 changed files with 720 additions and 11 deletions

View File

@@ -2,11 +2,11 @@ import io
import json
from base64 import b64decode, b64encode
from copy import deepcopy
from os.path import join
from typing import Any, Union
from httpx import get, post
from PIL import Image
from yarl import URL
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
@@ -79,7 +79,7 @@ class StableDiffusionTool(BuiltinTool):
# set model
try:
url = join(base_url, 'sdapi/v1/options')
url = str(URL(base_url) / 'sdapi' / 'v1' / 'options')
response = post(url, data=json.dumps({
'sd_model_checkpoint': model
}))
@@ -153,8 +153,21 @@ class StableDiffusionTool(BuiltinTool):
if not model:
raise ToolProviderCredentialValidationError('Please input model')
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
if response.status_code != 200:
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
response = get(url=api_url, timeout=10)
if response.status_code == 404:
# try draw a picture
self._invoke(
user_id='test',
tool_parameters={
'prompt': 'a cat',
'width': 1024,
'height': 1024,
'steps': 1,
'lora': '',
}
)
elif response.status_code != 200:
raise ToolProviderCredentialValidationError('Failed to get models')
else:
models = [d['model_name'] for d in response.json()]
@@ -165,6 +178,23 @@ class StableDiffusionTool(BuiltinTool):
except Exception as e:
raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
def get_sd_models(self) -> list[str]:
"""
get sd models
"""
try:
base_url = self.runtime.credentials.get('base_url', None)
if not base_url:
return []
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
response = get(url=api_url, timeout=10)
if response.status_code != 200:
return []
else:
return [d['model_name'] for d in response.json()]
except Exception as e:
return []
def img2img(self, base_url: str, lora: str, image_binary: bytes,
prompt: str, negative_prompt: str,
width: int, height: int, steps: int) \
@@ -192,7 +222,7 @@ class StableDiffusionTool(BuiltinTool):
draw_options['prompt'] = prompt
try:
url = join(base_url, 'sdapi/v1/img2img')
url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img')
response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200:
return self.create_text_message('Failed to generate image')
@@ -225,7 +255,7 @@ class StableDiffusionTool(BuiltinTool):
draw_options['negative_prompt'] = negative_prompt
try:
url = join(base_url, 'sdapi/v1/txt2img')
url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img')
response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200:
return self.create_text_message('Failed to generate image')
@@ -269,5 +299,29 @@ class StableDiffusionTool(BuiltinTool):
label=I18nObject(en_US=i.name, zh_Hans=i.name)
) for i in self.list_default_image_variables()])
)
if self.runtime.credentials:
try:
models = self.get_sd_models()
if len(models) != 0:
parameters.append(
ToolParameter(name='model',
label=I18nObject(en_US='Model', zh_Hans='Model'),
human_description=I18nObject(
en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档',
),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
required=True,
default=models[0],
options=[ToolParameterOption(
value=i,
label=I18nObject(en_US=i, zh_Hans=i)
) for i in models])
)
except:
pass
return parameters