chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -13,13 +13,8 @@ class DALLEProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
tool_parameters={
"prompt": "cute girl, blue eyes, white hair, anime style",
"size": "small",
"n": 1
},
user_id="",
tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "small", "n": 1},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@@ -9,59 +9,58 @@ from core.tools.tool.builtin_tool import BuiltinTool
class DallE2Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
openai_organization = self.runtime.credentials.get('openai_organization_id', None)
openai_organization = self.runtime.credentials.get("openai_organization_id", None)
if not openai_organization:
openai_organization = None
openai_base_url = self.runtime.credentials.get('openai_base_url', None)
openai_base_url = self.runtime.credentials.get("openai_base_url", None)
if not openai_base_url:
openai_base_url = None
else:
openai_base_url = str(URL(openai_base_url) / 'v1')
openai_base_url = str(URL(openai_base_url) / "v1")
client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'],
api_key=self.runtime.credentials["openai_api_key"],
base_url=openai_base_url,
organization=openai_organization
organization=openai_organization,
)
SIZE_MAPPING = {
'small': '256x256',
'medium': '512x512',
'large': '1024x1024',
"small": "256x256",
"medium": "512x512",
"large": "1024x1024",
}
# prompt
prompt = tool_parameters.get('prompt', '')
prompt = tool_parameters.get("prompt", "")
if not prompt:
return self.create_text_message('Please input prompt')
return self.create_text_message("Please input prompt")
# get size
size = SIZE_MAPPING[tool_parameters.get('size', 'large')]
size = SIZE_MAPPING[tool_parameters.get("size", "large")]
# get n
n = tool_parameters.get('n', 1)
n = tool_parameters.get("n", 1)
# call openapi dalle2
response = client.images.generate(
prompt=prompt,
model='dall-e-2',
size=size,
n=n,
response_format='b64_json'
)
response = client.images.generate(prompt=prompt, model="dall-e-2", size=size, n=n, response_format="b64_json")
result = []
for image in response.data:
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
meta={ 'mime_type': 'image/png' },
save_as=self.VARIABLE_KEY.IMAGE.value))
result.append(
self.create_blob_message(
blob=b64decode(image.b64_json),
meta={"mime_type": "image/png"},
save_as=self.VARIABLE_KEY.IMAGE.value,
)
)
return result

View File

@@ -10,69 +10,64 @@ from core.tools.tool.builtin_tool import BuiltinTool
class DallE3Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
openai_organization = self.runtime.credentials.get('openai_organization_id', None)
openai_organization = self.runtime.credentials.get("openai_organization_id", None)
if not openai_organization:
openai_organization = None
openai_base_url = self.runtime.credentials.get('openai_base_url', None)
openai_base_url = self.runtime.credentials.get("openai_base_url", None)
if not openai_base_url:
openai_base_url = None
else:
openai_base_url = str(URL(openai_base_url) / 'v1')
openai_base_url = str(URL(openai_base_url) / "v1")
client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'],
api_key=self.runtime.credentials["openai_api_key"],
base_url=openai_base_url,
organization=openai_organization
organization=openai_organization,
)
SIZE_MAPPING = {
'square': '1024x1024',
'vertical': '1024x1792',
'horizontal': '1792x1024',
"square": "1024x1024",
"vertical": "1024x1792",
"horizontal": "1792x1024",
}
# prompt
prompt = tool_parameters.get('prompt', '')
prompt = tool_parameters.get("prompt", "")
if not prompt:
return self.create_text_message('Please input prompt')
return self.create_text_message("Please input prompt")
# get size
size = SIZE_MAPPING[tool_parameters.get('size', 'square')]
size = SIZE_MAPPING[tool_parameters.get("size", "square")]
# get n
n = tool_parameters.get('n', 1)
n = tool_parameters.get("n", 1)
# get quality
quality = tool_parameters.get('quality', 'standard')
if quality not in ['standard', 'hd']:
return self.create_text_message('Invalid quality')
quality = tool_parameters.get("quality", "standard")
if quality not in ["standard", "hd"]:
return self.create_text_message("Invalid quality")
# get style
style = tool_parameters.get('style', 'vivid')
if style not in ['natural', 'vivid']:
return self.create_text_message('Invalid style')
style = tool_parameters.get("style", "vivid")
if style not in ["natural", "vivid"]:
return self.create_text_message("Invalid style")
# call openapi dalle3
response = client.images.generate(
prompt=prompt,
model='dall-e-3',
size=size,
n=n,
style=style,
quality=quality,
response_format='b64_json'
prompt=prompt, model="dall-e-3", size=size, n=n, style=style, quality=quality, response_format="b64_json"
)
result = []
for image in response.data:
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
blob_message = self.create_blob_message(blob=blob_image,
meta={'mime_type': mime_type},
save_as=self.VARIABLE_KEY.IMAGE.value)
blob_message = self.create_blob_message(
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VARIABLE_KEY.IMAGE.value
)
result.append(blob_message)
return result
@@ -86,7 +81,7 @@ class DallE3Tool(BuiltinTool):
:return: A tuple containing the MIME type and the decoded image bytes
"""
if DallE3Tool._is_plain_base64(base64_image):
return 'image/png', base64.b64decode(base64_image)
return "image/png", base64.b64decode(base64_image)
else:
return DallE3Tool._extract_mime_and_data(base64_image)
@@ -98,7 +93,7 @@ class DallE3Tool(BuiltinTool):
:param encoded_str: Base64 encoded image string
:return: True if the string is plain base64, False otherwise
"""
return not encoded_str.startswith('data:image')
return not encoded_str.startswith("data:image")
@staticmethod
def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]:
@@ -108,13 +103,13 @@ class DallE3Tool(BuiltinTool):
:param encoded_str: Base64 encoded image string with MIME type prefix
:return: A tuple containing the MIME type and the decoded image bytes
"""
mime_type = encoded_str.split(';')[0].split(':')[1]
image_data_base64 = encoded_str.split(',')[1]
mime_type = encoded_str.split(";")[0].split(":")[1]
image_data_base64 = encoded_str.split(",")[1]
decoded_data = base64.b64decode(image_data_base64)
return mime_type, decoded_data
@staticmethod
def _generate_random_id(length=8):
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
random_id = ''.join(random.choices(characters, k=length))
characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
random_id = "".join(random.choices(characters, k=length))
return random_id