improve: mordernizing validation by migrating pydantic from 1.x to 2.x (#4592)

This commit is contained in:
Bowen Liang
2024-06-14 01:05:37 +08:00
committed by GitHub
parent e8afc416dd
commit f976740b57
87 changed files with 697 additions and 300 deletions

View File

@@ -13,7 +13,7 @@ class UserTool(BaseModel):
name: str # identifier
label: I18nObject # label
description: I18nObject
parameters: Optional[list[ToolParameter]]
parameters: Optional[list[ToolParameter]] = None
labels: list[str] = None
UserToolProviderTypeLiteral = Optional[Literal[

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Optional, Union, cast
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from core.tools.entities.common_entities import I18nObject
@@ -116,6 +116,14 @@ class ToolParameterOption(BaseModel):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
@classmethod
@field_validator('value', mode='before')
def transform_id_to_str(cls, value) -> str:
if isinstance(value, bool):
return str(value)
else:
return value
class ToolParameter(BaseModel):
class ToolParameterType(str, Enum):
@@ -278,7 +286,7 @@ class ToolRuntimeVariablePool(BaseModel):
'conversation_id': self.conversation_id,
'user_id': self.user_id,
'tenant_id': self.tenant_id,
'pool': [variable.dict() for variable in self.pool],
'pool': [variable.model_dump() for variable in self.pool],
}
def set_text(self, tool_name: str, name: str, value: str) -> None:

View File

@@ -4,7 +4,7 @@ from hmac import new as hmac_new
from json import loads as json_loads
from threading import Lock
from time import sleep, time
from typing import Any
from typing import Any, Optional
from httpx import get, post
from requests import get as requests_get
@@ -22,9 +22,9 @@ class AIPPTGenerateTool(BuiltinTool):
_api_base_url = URL('https://co.aippt.cn/api')
_api_token_cache = {}
_api_token_cache_lock = Lock()
_api_token_cache_lock:Optional[Lock] = None
_style_cache = {}
_style_cache_lock = Lock()
_style_cache_lock:Optional[Lock] = None
_task = {}
_task_type_map = {
@@ -32,6 +32,11 @@ class AIPPTGenerateTool(BuiltinTool):
'markdown': 7,
}
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._api_token_cache_lock = Lock()
self._style_cache_lock = Lock()
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
Invokes the AIPPT generate tool with the given user ID and tool parameters.

View File

@@ -44,14 +44,10 @@ class ArxivAPIWrapper(BaseModel):
arxiv.run("tree of thought llm)
"""
arxiv_search = arxiv.Search #: :meta private:
arxiv_exceptions = (
arxiv.ArxivError,
arxiv.UnexpectedEmptyPageError,
arxiv.HTTPError,
) # :meta private:
arxiv_search: type[arxiv.Search] = arxiv.Search #: :meta private:
arxiv_http_error: tuple[type[Exception]] = (arxiv.ArxivError, arxiv.UnexpectedEmptyPageError, arxiv.HTTPError)
top_k_results: int = 3
ARXIV_MAX_QUERY_LENGTH = 300
ARXIV_MAX_QUERY_LENGTH: int = 300
load_max_docs: int = 100
load_all_available_meta: bool = False
doc_content_chars_max: Optional[int] = 4000
@@ -73,7 +69,7 @@ class ArxivAPIWrapper(BaseModel):
results = self.arxiv_search( # type: ignore
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
).results()
except self.arxiv_exceptions as ex:
except arxiv_http_error as ex:
return f"Arxiv exception: {ex}"
docs = [
f"Published: {result.updated.date()}\n"

View File

@@ -8,7 +8,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
class BingSearchTool(BuiltinTool):
url = 'https://api.bing.microsoft.com/v7.0/search'
url: str = 'https://api.bing.microsoft.com/v7.0/search'
def _invoke_bing(self,
user_id: str,

View File

@@ -15,7 +15,7 @@ class BraveSearchWrapper(BaseModel):
"""The API key to use for the Brave search engine."""
search_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the search request."""
base_url = "https://api.search.brave.com/res/v1/web/search"
base_url: str = "https://api.search.brave.com/res/v1/web/search"
"""The base URL for the Brave search engine."""
def run(self, query: str) -> str:
@@ -58,8 +58,8 @@ class BraveSearchWrapper(BaseModel):
class BraveSearch(BaseModel):
"""Tool that queries the BraveSearch."""
name = "brave_search"
description = (
name: str = "brave_search"
description: str = (
"a search engine. "
"useful for when you need to answer questions about current events."
" input should be a search query."

View File

@@ -0,0 +1,174 @@
<<<<<<< HEAD
=======
from typing import Any, Optional
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class DuckDuckGoSearchAPIWrapper(BaseModel):
"""Wrapper for DuckDuckGo Search API.
Free and does not require any setup.
"""
region: Optional[str] = "wt-wt"
safesearch: str = "moderate"
time: Optional[str] = "y"
max_results: int = 5
def get_snippets(self, query: str) -> list[str]:
"""Run query through DuckDuckGo and return concatenated results."""
from duckduckgo_search import DDGS
with DDGS() as ddgs:
results = ddgs.text(
query,
region=self.region,
safesearch=self.safesearch,
timelimit=self.time,
)
if results is None:
return ["No good DuckDuckGo Search Result was found"]
snippets = []
for i, res in enumerate(results, 1):
if res is not None:
snippets.append(res["body"])
if len(snippets) == self.max_results:
break
return snippets
def run(self, query: str) -> str:
snippets = self.get_snippets(query)
return " ".join(snippets)
def results(
self, query: str, num_results: int, backend: str = "api"
) -> list[dict[str, str]]:
"""Run query through DuckDuckGo and return metadata.
Args:
query: The query to search for.
num_results: The number of results to return.
Returns:
A list of dictionaries with the following keys:
snippet - The description of the result.
title - The title of the result.
link - The link to the result.
"""
from duckduckgo_search import DDGS
with DDGS() as ddgs:
results = ddgs.text(
query,
region=self.region,
safesearch=self.safesearch,
timelimit=self.time,
backend=backend,
)
if results is None:
return [{"Result": "No good DuckDuckGo Search Result was found"}]
def to_metadata(result: dict) -> dict[str, str]:
if backend == "news":
return {
"date": result["date"],
"title": result["title"],
"snippet": result["body"],
"source": result["source"],
"link": result["url"],
}
return {
"snippet": result["body"],
"title": result["title"],
"link": result["href"],
}
formatted_results = []
for i, res in enumerate(results, 1):
if res is not None:
formatted_results.append(to_metadata(res))
if len(formatted_results) == num_results:
break
return formatted_results
class DuckDuckGoSearchRun(BaseModel):
"""Tool that queries the DuckDuckGo search API."""
name: str = "duckduckgo_search"
description: str = (
"A wrapper around DuckDuckGo Search. "
"Useful for when you need to answer questions about current events. "
"Input should be a search query."
)
api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
default_factory=DuckDuckGoSearchAPIWrapper
)
def _run(
self,
query: str,
) -> str:
"""Use the tool."""
return self.api_wrapper.run(query)
class DuckDuckGoSearchResults(BaseModel):
"""Tool that queries the DuckDuckGo search API and gets back json."""
name: str = "DuckDuckGo Results JSON"
description: str = (
"A wrapper around Duck Duck Go Search. "
"Useful for when you need to answer questions about current events. "
"Input should be a search query. Output is a JSON array of the query results"
)
num_results: int = 4
api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
default_factory=DuckDuckGoSearchAPIWrapper
)
backend: str = "api"
def _run(
self,
query: str,
) -> str:
"""Use the tool."""
res = self.api_wrapper.results(query, self.num_results, backend=self.backend)
res_strs = [", ".join([f"{k}: {v}" for k, v in d.items()]) for d in res]
return ", ".join([f"[{rs}]" for rs in res_strs])
class DuckDuckGoInput(BaseModel):
query: str = Field(..., description="Search query.")
class DuckDuckGoSearchTool(BuiltinTool):
"""
Tool for performing a search using DuckDuckGo search engine.
"""
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
Invoke the DuckDuckGo search tool.
Args:
user_id (str): The ID of the user invoking the tool.
tool_parameters (dict[str, Any]): The parameters for the tool invocation.
Returns:
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation.
"""
query = tool_parameters.get('query', '')
if not query:
return self.create_text_message('Please input query')
tool = DuckDuckGoSearchRun(args_schema=DuckDuckGoInput)
result = tool._run(query)
return self.create_text_message(self.summary(user_id=user_id, content=result))
>>>>>>> 4c2ba442b (missing type in DuckDuckGoSearchAPIWrapper)

View File

@@ -69,10 +69,10 @@ parameters:
options:
- value: true
label:
en_US: Yes
en_US: 'Yes'
zh_Hans:
- value: false
label:
en_US: No
en_US: 'No'
zh_Hans:
default: false

View File

@@ -28,15 +28,15 @@ class PubMedAPIWrapper(BaseModel):
if False: the `metadata` gets only the most informative fields.
"""
base_url_esearch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
base_url_efetch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
max_retry = 5
sleep_time = 0.2
base_url_esearch: str = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
base_url_efetch: str = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
max_retry: int = 5
sleep_time: float = 0.2
# Default values for the parameters
top_k_results: int = 3
load_max_docs: int = 25
ARXIV_MAX_QUERY_LENGTH = 300
ARXIV_MAX_QUERY_LENGTH: int = 300
doc_content_chars_max: int = 2000
load_all_available_meta: bool = False
email: str = "your_email@example.com"
@@ -160,8 +160,8 @@ class PubMedAPIWrapper(BaseModel):
class PubmedQueryRun(BaseModel):
"""Tool that searches the PubMed API."""
name = "PubMed"
description = (
name: str = "PubMed"
description: str = (
"A wrapper around PubMed.org "
"Useful for when you need to answer questions about Physics, Mathematics, "
"Computer Science, Quantitative Biology, Quantitative Finance, Statistics, "

View File

@@ -12,7 +12,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
class QRCodeGeneratorTool(BuiltinTool):
error_correction_levels = {
error_correction_levels: dict[str, int] = {
'L': ERROR_CORRECT_L, # <=7%
'M': ERROR_CORRECT_M, # <=15%
'Q': ERROR_CORRECT_Q, # <=25%

View File

@@ -24,21 +24,21 @@ class SearXNGSearchTool(BuiltinTool):
Tool for performing a search using SearXNG engine.
"""
SEARCH_TYPE = {
SEARCH_TYPE: dict[str, str] = {
"page": "general",
"news": "news",
"image": "images",
# "video": "videos",
# "file": "files"
}
LINK_FILED = {
LINK_FILED: dict[str, str] = {
"page": "url",
"news": "url",
"image": "img_src",
# "video": "iframe_src",
# "file": "magnetlink"
}
TEXT_FILED = {
TEXT_FILED: dict[str, str] = {
"page": "content",
"news": "content",
"image": "img_src",

View File

@@ -11,7 +11,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
"""
This class is responsible for providing the stable diffusion tool.
"""
model_endpoint_map = {
model_endpoint_map: dict[str, str] = {
'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
'core': 'https://api.stability.ai/v2beta/stable-image/generate/core',

View File

@@ -98,11 +98,11 @@ parameters:
options:
- value: true
label:
en_US: Yes
en_US: 'Yes'
zh_Hans:
- value: false
label:
en_US: No
en_US: 'No'
zh_Hans:
default: true
- name: pagesize

View File

@@ -64,14 +64,14 @@ parameters:
options:
- value: true
label:
en_US: Yes
en_US: 'Yes'
zh_Hans:
pt_BR: Yes
pt_BR: 'Yes'
- value: false
label:
en_US: No
en_US: 'No'
zh_Hans:
pt_BR: No
pt_BR: 'No'
default: false
- name: include_answer
type: boolean
@@ -88,14 +88,14 @@ parameters:
options:
- value: true
label:
en_US: Yes
en_US: 'Yes'
zh_Hans:
pt_BR: Yes
pt_BR: 'Yes'
- value: false
label:
en_US: No
en_US: 'No'
zh_Hans:
pt_BR: No
pt_BR: 'No'
default: false
- name: include_raw_content
type: boolean
@@ -112,14 +112,14 @@ parameters:
options:
- value: true
label:
en_US: Yes
en_US: 'Yes'
zh_Hans:
pt_BR: Yes
pt_BR: 'Yes'
- value: false
label:
en_US: No
en_US: 'No'
zh_Hans:
pt_BR: No
pt_BR: 'No'
default: false
- name: max_results
type: number

View File

@@ -1,6 +1,6 @@
from typing import Any, Optional, Union
from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
@@ -15,7 +15,7 @@ class TwilioAPIWrapper(BaseModel):
named parameters to the constructor.
"""
client: Any #: :meta private:
client: Any = None #: :meta private:
account_sid: Optional[str] = None
"""Twilio account string identifier."""
auth_token: Optional[str] = None
@@ -32,7 +32,8 @@ class TwilioAPIWrapper(BaseModel):
must be empty.
"""
@validator("client", pre=True, always=True)
@classmethod
@field_validator('client', mode='before')
def set_validator(cls, values: dict) -> dict:
"""Validate that api key and python package exists in environment."""
try:

View File

@@ -51,10 +51,10 @@ parameters:
options:
- value: true
label:
en_US: Yes
en_US: 'Yes'
zh_Hans:
- value: false
label:
en_US: No
en_US: 'No'
zh_Hans:
default: false

View File

@@ -32,10 +32,10 @@ class ApiTool(Tool):
:return: the new tool
"""
return self.__class__(
identity=self.identity.copy() if self.identity else None,
identity=self.identity.model_copy() if self.identity else None,
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.copy() if self.description else None,
api_bundle=self.api_bundle.copy() if self.api_bundle else None,
description=self.description.model_copy() if self.description else None,
api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
runtime=Tool.Runtime(**runtime)
)

View File

@@ -2,7 +2,7 @@ from abc import abstractmethod
from typing import Any, Optional
from msal_extensions.persistence import ABC
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
@@ -17,9 +17,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool
retriever_from: str
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod
def _run(

View File

@@ -3,7 +3,8 @@ from copy import deepcopy
from enum import Enum
from typing import Any, Optional, Union
from pydantic import BaseModel, validator
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileVar
@@ -28,8 +29,12 @@ class Tool(BaseModel, ABC):
description: ToolDescription = None
is_team_authorization: bool = False
@validator('parameters', pre=True, always=True)
def set_parameters(cls, v, values):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@classmethod
@field_validator('parameters', mode='before')
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
return v or []
class Runtime(BaseModel):
@@ -65,9 +70,9 @@ class Tool(BaseModel, ABC):
:return: the new tool
"""
return self.__class__(
identity=self.identity.copy() if self.identity else None,
identity=self.identity.model_copy() if self.identity else None,
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.copy() if self.description else None,
description=self.description.model_copy() if self.description else None,
runtime=Tool.Runtime(**runtime),
)