improve: modernizing validation by migrating pydantic from 1.x to 2.x (#4592)
This commit is contained in:
@@ -13,7 +13,7 @@ class UserTool(BaseModel):
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[ToolParameter]]
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
labels: list[str] = None
|
||||
|
||||
UserToolProviderTypeLiteral = Optional[Literal[
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
@@ -116,6 +116,14 @@ class ToolParameterOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
|
||||
@field_validator('value', mode='before')
@classmethod
def transform_id_to_str(cls, value) -> str:
    """Convert a boolean option value to its string form before validation.

    The ``value`` field is declared as ``str``; a boolean input is turned
    into ``'True'`` / ``'False'`` here, and any other input is passed
    through unchanged for the normal field validation to handle.

    ``@field_validator`` has to be the outermost decorator: when
    ``@classmethod`` wraps it instead, the class attribute is a plain
    ``classmethod`` object and pydantic does not register the validator.

    :param value: the raw value supplied for the ``value`` field
    :return: ``str(value)`` for booleans, otherwise ``value`` as given
    """
    if isinstance(value, bool):
        return str(value)
    return value
|
||||
|
||||
|
||||
class ToolParameter(BaseModel):
|
||||
class ToolParameterType(str, Enum):
|
||||
@@ -278,7 +286,7 @@ class ToolRuntimeVariablePool(BaseModel):
|
||||
'conversation_id': self.conversation_id,
|
||||
'user_id': self.user_id,
|
||||
'tenant_id': self.tenant_id,
|
||||
'pool': [variable.dict() for variable in self.pool],
|
||||
'pool': [variable.model_dump() for variable in self.pool],
|
||||
}
|
||||
|
||||
def set_text(self, tool_name: str, name: str, value: str) -> None:
|
||||
|
||||
@@ -4,7 +4,7 @@ from hmac import new as hmac_new
|
||||
from json import loads as json_loads
|
||||
from threading import Lock
|
||||
from time import sleep, time
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from httpx import get, post
|
||||
from requests import get as requests_get
|
||||
@@ -22,9 +22,9 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
|
||||
_api_base_url = URL('https://co.aippt.cn/api')
|
||||
_api_token_cache = {}
|
||||
_api_token_cache_lock = Lock()
|
||||
_api_token_cache_lock:Optional[Lock] = None
|
||||
_style_cache = {}
|
||||
_style_cache_lock = Lock()
|
||||
_style_cache_lock:Optional[Lock] = None
|
||||
|
||||
_task = {}
|
||||
_task_type_map = {
|
||||
@@ -32,6 +32,11 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
'markdown': 7,
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs: Any):
    """Initialise the tool and create its two cache locks.

    :param kwargs: forwarded unchanged to the parent initialiser
    """
    super().__init__(**kwargs)
    # Both locks are created only after the parent initialiser has run.
    self._api_token_cache_lock, self._style_cache_lock = Lock(), Lock()
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
Invokes the AIPPT generate tool with the given user ID and tool parameters.
|
||||
|
||||
@@ -44,14 +44,10 @@ class ArxivAPIWrapper(BaseModel):
|
||||
arxiv.run("tree of thought llm)
|
||||
"""
|
||||
|
||||
arxiv_search = arxiv.Search #: :meta private:
|
||||
arxiv_exceptions = (
|
||||
arxiv.ArxivError,
|
||||
arxiv.UnexpectedEmptyPageError,
|
||||
arxiv.HTTPError,
|
||||
) # :meta private:
|
||||
arxiv_search: type[arxiv.Search] = arxiv.Search #: :meta private:
|
||||
arxiv_http_error: tuple[type[Exception]] = (arxiv.ArxivError, arxiv.UnexpectedEmptyPageError, arxiv.HTTPError)
|
||||
top_k_results: int = 3
|
||||
ARXIV_MAX_QUERY_LENGTH = 300
|
||||
ARXIV_MAX_QUERY_LENGTH: int = 300
|
||||
load_max_docs: int = 100
|
||||
load_all_available_meta: bool = False
|
||||
doc_content_chars_max: Optional[int] = 4000
|
||||
@@ -73,7 +69,7 @@ class ArxivAPIWrapper(BaseModel):
|
||||
results = self.arxiv_search( # type: ignore
|
||||
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
|
||||
).results()
|
||||
except self.arxiv_exceptions as ex:
|
||||
except arxiv_http_error as ex:
|
||||
return f"Arxiv exception: {ex}"
|
||||
docs = [
|
||||
f"Published: {result.updated.date()}\n"
|
||||
|
||||
@@ -8,7 +8,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class BingSearchTool(BuiltinTool):
|
||||
url = 'https://api.bing.microsoft.com/v7.0/search'
|
||||
url: str = 'https://api.bing.microsoft.com/v7.0/search'
|
||||
|
||||
def _invoke_bing(self,
|
||||
user_id: str,
|
||||
|
||||
@@ -15,7 +15,7 @@ class BraveSearchWrapper(BaseModel):
|
||||
"""The API key to use for the Brave search engine."""
|
||||
search_kwargs: dict = Field(default_factory=dict)
|
||||
"""Additional keyword arguments to pass to the search request."""
|
||||
base_url = "https://api.search.brave.com/res/v1/web/search"
|
||||
base_url: str = "https://api.search.brave.com/res/v1/web/search"
|
||||
"""The base URL for the Brave search engine."""
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
@@ -58,8 +58,8 @@ class BraveSearchWrapper(BaseModel):
|
||||
class BraveSearch(BaseModel):
|
||||
"""Tool that queries the BraveSearch."""
|
||||
|
||||
name = "brave_search"
|
||||
description = (
|
||||
name: str = "brave_search"
|
||||
description: str = (
|
||||
"a search engine. "
|
||||
"useful for when you need to answer questions about current events."
|
||||
" input should be a search query."
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
from typing import Any, Optional

from pydantic import BaseModel, Field

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class DuckDuckGoSearchAPIWrapper(BaseModel):
    """Wrapper for DuckDuckGo Search API.

    Free and does not require any setup.
    """

    region: Optional[str] = "wt-wt"
    safesearch: str = "moderate"
    time: Optional[str] = "y"
    max_results: int = 5

    def get_snippets(self, query: str) -> list[str]:
        """Run query through DuckDuckGo and return the result bodies.

        Args:
            query: The query to search for.

        Returns:
            A list with the ``body`` text of each result, cut off once
            ``max_results`` entries have been collected. When the search
            returns ``None``, a single-element list holding a
            "no result" message is returned instead.
        """
        # Imported lazily so the dependency is only needed when searching.
        from duckduckgo_search import DDGS

        with DDGS() as ddgs:
            results = ddgs.text(
                query,
                region=self.region,
                safesearch=self.safesearch,
                timelimit=self.time,
            )
            if results is None:
                return ["No good DuckDuckGo Search Result was found"]
            snippets = []
            for res in results:
                if res is not None:
                    snippets.append(res["body"])
                if len(snippets) == self.max_results:
                    break
        return snippets

    def run(self, query: str) -> str:
        """Return the snippets for ``query`` joined by single spaces."""
        snippets = self.get_snippets(query)
        return " ".join(snippets)

    def results(
        self, query: str, num_results: int, backend: str = "api"
    ) -> list[dict[str, str]]:
        """Run query through DuckDuckGo and return metadata.

        Args:
            query: The query to search for.
            num_results: The number of results to return.
            backend: Passed through to the search call. When it equals
                ``"news"``, each result is mapped to the news layout
                (``date``, ``title``, ``snippet``, ``source``, ``link``).

        Returns:
            A list of dictionaries with the following keys:
                snippet - The description of the result.
                title - The title of the result.
                link - The link to the result.
            When the search returns ``None``, a single dictionary with a
            ``Result`` key holding a "no result" message is returned.
        """
        from duckduckgo_search import DDGS

        with DDGS() as ddgs:
            results = ddgs.text(
                query,
                region=self.region,
                safesearch=self.safesearch,
                timelimit=self.time,
                backend=backend,
            )
            if results is None:
                return [{"Result": "No good DuckDuckGo Search Result was found"}]

            def to_metadata(result: dict) -> dict[str, str]:
                """Map one raw search result to the returned dictionary shape."""
                if backend == "news":
                    return {
                        "date": result["date"],
                        "title": result["title"],
                        "snippet": result["body"],
                        "source": result["source"],
                        "link": result["url"],
                    }
                return {
                    "snippet": result["body"],
                    "title": result["title"],
                    "link": result["href"],
                }

            formatted_results = []
            for res in results:
                if res is not None:
                    formatted_results.append(to_metadata(res))
                if len(formatted_results) == num_results:
                    break
        return formatted_results


class DuckDuckGoSearchRun(BaseModel):
    """Tool that queries the DuckDuckGo search API."""

    name: str = "duckduckgo_search"
    description: str = (
        "A wrapper around DuckDuckGo Search. "
        "Useful for when you need to answer questions about current events. "
        "Input should be a search query."
    )
    api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
        default_factory=DuckDuckGoSearchAPIWrapper
    )

    def _run(
        self,
        query: str,
    ) -> str:
        """Use the tool."""
        return self.api_wrapper.run(query)


class DuckDuckGoSearchResults(BaseModel):
    """Tool that queries the DuckDuckGo search API and gets back json."""

    name: str = "DuckDuckGo Results JSON"
    description: str = (
        "A wrapper around Duck Duck Go Search. "
        "Useful for when you need to answer questions about current events. "
        "Input should be a search query. Output is a JSON array of the query results"
    )
    num_results: int = 4
    api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
        default_factory=DuckDuckGoSearchAPIWrapper
    )
    backend: str = "api"

    def _run(
        self,
        query: str,
    ) -> str:
        """Use the tool.

        Each result dictionary is rendered as ``[key: value, ...]`` and the
        rendered results are joined with ``", "``.
        """
        res = self.api_wrapper.results(query, self.num_results, backend=self.backend)
        res_strs = [", ".join([f"{k}: {v}" for k, v in d.items()]) for d in res]
        return ", ".join([f"[{rs}]" for rs in res_strs])


class DuckDuckGoInput(BaseModel):
    """Input schema holding the search query."""

    query: str = Field(..., description="Search query.")


class DuckDuckGoSearchTool(BuiltinTool):
    """
    Tool for performing a search using DuckDuckGo search engine.
    """

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
        Invoke the DuckDuckGo search tool.

        Args:
            user_id (str): The ID of the user invoking the tool.
            tool_parameters (dict[str, Any]): The parameters for the tool invocation.

        Returns:
            ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation.
        """
        query = tool_parameters.get('query', '')

        if not query:
            return self.create_text_message('Please input query')

        # NOTE(review): `args_schema` is not a declared field of
        # DuckDuckGoSearchRun; presumably it is ignored under pydantic's
        # default handling of extra keyword arguments - confirm, then drop.
        tool = DuckDuckGoSearchRun(args_schema=DuckDuckGoInput)

        result = tool._run(query)

        return self.create_text_message(self.summary(user_id=user_id, content=result))
|
||||
@@ -69,10 +69,10 @@ parameters:
|
||||
options:
|
||||
- value: true
|
||||
label:
|
||||
en_US: Yes
|
||||
en_US: 'Yes'
|
||||
zh_Hans: 是
|
||||
- value: false
|
||||
label:
|
||||
en_US: No
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
|
||||
default: false
|
||||
|
||||
@@ -28,15 +28,15 @@ class PubMedAPIWrapper(BaseModel):
|
||||
if False: the `metadata` gets only the most informative fields.
|
||||
"""
|
||||
|
||||
base_url_esearch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
|
||||
base_url_efetch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
|
||||
max_retry = 5
|
||||
sleep_time = 0.2
|
||||
base_url_esearch: str = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
|
||||
base_url_efetch: str = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
|
||||
max_retry: int = 5
|
||||
sleep_time: float = 0.2
|
||||
|
||||
# Default values for the parameters
|
||||
top_k_results: int = 3
|
||||
load_max_docs: int = 25
|
||||
ARXIV_MAX_QUERY_LENGTH = 300
|
||||
ARXIV_MAX_QUERY_LENGTH: int = 300
|
||||
doc_content_chars_max: int = 2000
|
||||
load_all_available_meta: bool = False
|
||||
email: str = "your_email@example.com"
|
||||
@@ -160,8 +160,8 @@ class PubMedAPIWrapper(BaseModel):
|
||||
class PubmedQueryRun(BaseModel):
|
||||
"""Tool that searches the PubMed API."""
|
||||
|
||||
name = "PubMed"
|
||||
description = (
|
||||
name: str = "PubMed"
|
||||
description: str = (
|
||||
"A wrapper around PubMed.org "
|
||||
"Useful for when you need to answer questions about Physics, Mathematics, "
|
||||
"Computer Science, Quantitative Biology, Quantitative Finance, Statistics, "
|
||||
|
||||
@@ -12,7 +12,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class QRCodeGeneratorTool(BuiltinTool):
|
||||
error_correction_levels = {
|
||||
error_correction_levels: dict[str, int] = {
|
||||
'L': ERROR_CORRECT_L, # <=7%
|
||||
'M': ERROR_CORRECT_M, # <=15%
|
||||
'Q': ERROR_CORRECT_Q, # <=25%
|
||||
|
||||
@@ -24,21 +24,21 @@ class SearXNGSearchTool(BuiltinTool):
|
||||
Tool for performing a search using SearXNG engine.
|
||||
"""
|
||||
|
||||
SEARCH_TYPE = {
|
||||
SEARCH_TYPE: dict[str, str] = {
|
||||
"page": "general",
|
||||
"news": "news",
|
||||
"image": "images",
|
||||
# "video": "videos",
|
||||
# "file": "files"
|
||||
}
|
||||
LINK_FILED = {
|
||||
LINK_FILED: dict[str, str] = {
|
||||
"page": "url",
|
||||
"news": "url",
|
||||
"image": "img_src",
|
||||
# "video": "iframe_src",
|
||||
# "file": "magnetlink"
|
||||
}
|
||||
TEXT_FILED = {
|
||||
TEXT_FILED: dict[str, str] = {
|
||||
"page": "content",
|
||||
"news": "content",
|
||||
"image": "img_src",
|
||||
|
||||
@@ -11,7 +11,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
|
||||
"""
|
||||
This class is responsible for providing the stable diffusion tool.
|
||||
"""
|
||||
model_endpoint_map = {
|
||||
model_endpoint_map: dict[str, str] = {
|
||||
'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
|
||||
'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
|
||||
'core': 'https://api.stability.ai/v2beta/stable-image/generate/core',
|
||||
|
||||
@@ -98,11 +98,11 @@ parameters:
|
||||
options:
|
||||
- value: true
|
||||
label:
|
||||
en_US: Yes
|
||||
en_US: 'Yes'
|
||||
zh_Hans: 是
|
||||
- value: false
|
||||
label:
|
||||
en_US: No
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
|
||||
default: true
|
||||
- name: pagesize
|
||||
|
||||
@@ -64,14 +64,14 @@ parameters:
|
||||
options:
|
||||
- value: true
|
||||
label:
|
||||
en_US: Yes
|
||||
en_US: 'Yes'
|
||||
zh_Hans: 是
|
||||
pt_BR: Yes
|
||||
pt_BR: 'Yes'
|
||||
- value: false
|
||||
label:
|
||||
en_US: No
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
|
||||
pt_BR: No
|
||||
pt_BR: 'No'
|
||||
default: false
|
||||
- name: include_answer
|
||||
type: boolean
|
||||
@@ -88,14 +88,14 @@ parameters:
|
||||
options:
|
||||
- value: true
|
||||
label:
|
||||
en_US: Yes
|
||||
en_US: 'Yes'
|
||||
zh_Hans: 是
|
||||
pt_BR: Yes
|
||||
pt_BR: 'Yes'
|
||||
- value: false
|
||||
label:
|
||||
en_US: No
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
|
||||
pt_BR: No
|
||||
pt_BR: 'No'
|
||||
default: false
|
||||
- name: include_raw_content
|
||||
type: boolean
|
||||
@@ -112,14 +112,14 @@ parameters:
|
||||
options:
|
||||
- value: true
|
||||
label:
|
||||
en_US: Yes
|
||||
en_US: 'Yes'
|
||||
zh_Hans: 是
|
||||
pt_BR: Yes
|
||||
pt_BR: 'Yes'
|
||||
- value: false
|
||||
label:
|
||||
en_US: No
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
|
||||
pt_BR: No
|
||||
pt_BR: 'No'
|
||||
default: false
|
||||
- name: max_results
|
||||
type: number
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
@@ -15,7 +15,7 @@ class TwilioAPIWrapper(BaseModel):
|
||||
named parameters to the constructor.
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
client: Any = None #: :meta private:
|
||||
account_sid: Optional[str] = None
|
||||
"""Twilio account string identifier."""
|
||||
auth_token: Optional[str] = None
|
||||
@@ -32,7 +32,8 @@ class TwilioAPIWrapper(BaseModel):
|
||||
must be empty.
|
||||
"""
|
||||
|
||||
@validator("client", pre=True, always=True)
|
||||
@classmethod
|
||||
@field_validator('client', mode='before')
|
||||
def set_validator(cls, values: dict) -> dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
|
||||
@@ -51,10 +51,10 @@ parameters:
|
||||
options:
|
||||
- value: true
|
||||
label:
|
||||
en_US: Yes
|
||||
en_US: 'Yes'
|
||||
zh_Hans: 是
|
||||
- value: false
|
||||
label:
|
||||
en_US: No
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
|
||||
default: false
|
||||
|
||||
@@ -32,10 +32,10 @@ class ApiTool(Tool):
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
identity=self.identity.copy() if self.identity else None,
|
||||
identity=self.identity.model_copy() if self.identity else None,
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.copy() if self.description else None,
|
||||
api_bundle=self.api_bundle.copy() if self.api_bundle else None,
|
||||
description=self.description.model_copy() if self.description else None,
|
||||
api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
|
||||
runtime=Tool.Runtime(**runtime)
|
||||
)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from abc import abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from msal_extensions.persistence import ABC
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
|
||||
@@ -17,9 +17,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
|
||||
return_resource: bool
|
||||
retriever_from: str
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
def _run(
|
||||
|
||||
@@ -3,7 +3,8 @@ from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileVar
|
||||
@@ -28,8 +29,12 @@ class Tool(BaseModel, ABC):
|
||||
description: ToolDescription = None
|
||||
is_team_authorization: bool = False
|
||||
|
||||
@validator('parameters', pre=True, always=True)
|
||||
def set_parameters(cls, v, values):
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_validator('parameters', mode='before')
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
    """Replace a falsy ``parameters`` input (such as ``None``) with ``[]``.

    ``@field_validator`` has to be the outermost decorator: when
    ``@classmethod`` wraps it instead, the class attribute is a plain
    ``classmethod`` object and pydantic does not register the validator.

    :param v: the raw value supplied for ``parameters``
    :param validation_info: validation context supplied by pydantic (unused)
    :return: ``v`` when it is truthy, otherwise an empty list
    """
    # NOTE(review): a 'before' validator is presumably not run when the field
    # is omitted unless the field enables default validation - confirm the
    # `parameters` field declaration covers the "missing" case.
    return v or []
|
||||
|
||||
class Runtime(BaseModel):
|
||||
@@ -65,9 +70,9 @@ class Tool(BaseModel, ABC):
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
identity=self.identity.copy() if self.identity else None,
|
||||
identity=self.identity.model_copy() if self.identity else None,
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.copy() if self.description else None,
|
||||
description=self.description.model_copy() if self.description else None,
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user