Feat: AIPPT & DynamicToolParamter (#2725)
This commit is contained in:
@@ -2,11 +2,11 @@ import io
|
||||
import json
|
||||
from base64 import b64decode, b64encode
|
||||
from copy import deepcopy
|
||||
from os.path import join
|
||||
from typing import Any, Union
|
||||
|
||||
from httpx import get, post
|
||||
from PIL import Image
|
||||
from yarl import URL
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
|
||||
@@ -79,7 +79,7 @@ class StableDiffusionTool(BuiltinTool):
|
||||
|
||||
# set model
|
||||
try:
|
||||
url = join(base_url, 'sdapi/v1/options')
|
||||
url = str(URL(base_url) / 'sdapi' / 'v1' / 'options')
|
||||
response = post(url, data=json.dumps({
|
||||
'sd_model_checkpoint': model
|
||||
}))
|
||||
@@ -153,8 +153,21 @@ class StableDiffusionTool(BuiltinTool):
|
||||
if not model:
|
||||
raise ToolProviderCredentialValidationError('Please input model')
|
||||
|
||||
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
|
||||
if response.status_code != 200:
|
||||
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
|
||||
response = get(url=api_url, timeout=10)
|
||||
if response.status_code == 404:
|
||||
# try draw a picture
|
||||
self._invoke(
|
||||
user_id='test',
|
||||
tool_parameters={
|
||||
'prompt': 'a cat',
|
||||
'width': 1024,
|
||||
'height': 1024,
|
||||
'steps': 1,
|
||||
'lora': '',
|
||||
}
|
||||
)
|
||||
elif response.status_code != 200:
|
||||
raise ToolProviderCredentialValidationError('Failed to get models')
|
||||
else:
|
||||
models = [d['model_name'] for d in response.json()]
|
||||
@@ -165,6 +178,23 @@ class StableDiffusionTool(BuiltinTool):
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
|
||||
|
||||
def get_sd_models(self) -> list[str]:
|
||||
"""
|
||||
get sd models
|
||||
"""
|
||||
try:
|
||||
base_url = self.runtime.credentials.get('base_url', None)
|
||||
if not base_url:
|
||||
return []
|
||||
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
|
||||
response = get(url=api_url, timeout=10)
|
||||
if response.status_code != 200:
|
||||
return []
|
||||
else:
|
||||
return [d['model_name'] for d in response.json()]
|
||||
except Exception as e:
|
||||
return []
|
||||
|
||||
def img2img(self, base_url: str, lora: str, image_binary: bytes,
|
||||
prompt: str, negative_prompt: str,
|
||||
width: int, height: int, steps: int) \
|
||||
@@ -192,7 +222,7 @@ class StableDiffusionTool(BuiltinTool):
|
||||
draw_options['prompt'] = prompt
|
||||
|
||||
try:
|
||||
url = join(base_url, 'sdapi/v1/img2img')
|
||||
url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img')
|
||||
response = post(url, data=json.dumps(draw_options), timeout=120)
|
||||
if response.status_code != 200:
|
||||
return self.create_text_message('Failed to generate image')
|
||||
@@ -225,7 +255,7 @@ class StableDiffusionTool(BuiltinTool):
|
||||
draw_options['negative_prompt'] = negative_prompt
|
||||
|
||||
try:
|
||||
url = join(base_url, 'sdapi/v1/txt2img')
|
||||
url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img')
|
||||
response = post(url, data=json.dumps(draw_options), timeout=120)
|
||||
if response.status_code != 200:
|
||||
return self.create_text_message('Failed to generate image')
|
||||
@@ -269,5 +299,29 @@ class StableDiffusionTool(BuiltinTool):
|
||||
label=I18nObject(en_US=i.name, zh_Hans=i.name)
|
||||
) for i in self.list_default_image_variables()])
|
||||
)
|
||||
|
||||
if self.runtime.credentials:
|
||||
try:
|
||||
models = self.get_sd_models()
|
||||
if len(models) != 0:
|
||||
parameters.append(
|
||||
ToolParameter(name='model',
|
||||
label=I18nObject(en_US='Model', zh_Hans='Model'),
|
||||
human_description=I18nObject(
|
||||
en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
|
||||
zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档',
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
|
||||
required=True,
|
||||
default=models[0],
|
||||
options=[ToolParameterOption(
|
||||
value=i,
|
||||
label=I18nObject(en_US=i, zh_Hans=i)
|
||||
) for i in models])
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
return parameters
|
||||
|
||||
Reference in New Issue
Block a user