chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -13,12 +13,8 @@ class AzureDALLEProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
tool_parameters={
"prompt": "cute girl, blue eyes, white hair, anime style",
"size": "square",
"n": 1
},
user_id="",
tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "square", "n": 1},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@@ -9,47 +9,48 @@ from core.tools.tool.builtin_tool import BuiltinTool
class DallE3Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
client = AzureOpenAI(
api_version=self.runtime.credentials['azure_openai_api_version'],
azure_endpoint=self.runtime.credentials['azure_openai_base_url'],
api_key=self.runtime.credentials['azure_openai_api_key'],
api_version=self.runtime.credentials["azure_openai_api_version"],
azure_endpoint=self.runtime.credentials["azure_openai_base_url"],
api_key=self.runtime.credentials["azure_openai_api_key"],
)
SIZE_MAPPING = {
'square': '1024x1024',
'vertical': '1024x1792',
'horizontal': '1792x1024',
"square": "1024x1024",
"vertical": "1024x1792",
"horizontal": "1792x1024",
}
# prompt
prompt = tool_parameters.get('prompt', '')
prompt = tool_parameters.get("prompt", "")
if not prompt:
return self.create_text_message('Please input prompt')
return self.create_text_message("Please input prompt")
# get size
size = SIZE_MAPPING[tool_parameters.get('size', 'square')]
size = SIZE_MAPPING[tool_parameters.get("size", "square")]
# get n
n = tool_parameters.get('n', 1)
n = tool_parameters.get("n", 1)
# get quality
quality = tool_parameters.get('quality', 'standard')
if quality not in ['standard', 'hd']:
return self.create_text_message('Invalid quality')
quality = tool_parameters.get("quality", "standard")
if quality not in ["standard", "hd"]:
return self.create_text_message("Invalid quality")
# get style
style = tool_parameters.get('style', 'vivid')
if style not in ['natural', 'vivid']:
return self.create_text_message('Invalid style')
style = tool_parameters.get("style", "vivid")
if style not in ["natural", "vivid"]:
return self.create_text_message("Invalid style")
# set extra body
seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
extra_body = {'seed': seed_id}
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
extra_body = {"seed": seed_id}
# call openapi dalle3
model = self.runtime.credentials['azure_openai_api_model_name']
model = self.runtime.credentials["azure_openai_api_model_name"]
response = client.images.generate(
prompt=prompt,
model=model,
@@ -58,21 +59,25 @@ class DallE3Tool(BuiltinTool):
extra_body=extra_body,
style=style,
quality=quality,
response_format='b64_json'
response_format="b64_json",
)
result = []
for image in response.data:
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
meta={'mime_type': 'image/png'},
save_as=self.VARIABLE_KEY.IMAGE.value))
result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}'))
result.append(
self.create_blob_message(
blob=b64decode(image.b64_json),
meta={"mime_type": "image/png"},
save_as=self.VARIABLE_KEY.IMAGE.value,
)
)
result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}"))
return result
@staticmethod
def _generate_random_id(length=8):
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
random_id = ''.join(random.choices(characters, k=length))
characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
random_id = "".join(random.choices(characters, k=length))
return random_id