Introduce Plugins (#13836)
Signed-off-by: yihong0618 <zouzou0208@gmail.com> Signed-off-by: -LAN- <laipz8200@outlook.com> Signed-off-by: xhe <xw897002528@gmail.com> Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: takatost <takatost@gmail.com> Co-authored-by: kurokobo <kuro664@gmail.com> Co-authored-by: Novice Lee <novicelee@NoviPro.local> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: AkaraChen <akarachen@outlook.com> Co-authored-by: Yi <yxiaoisme@gmail.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: Hiroshi Fujita <fujita-h@users.noreply.github.com> Co-authored-by: AkaraChen <85140972+AkaraChen@users.noreply.github.com> Co-authored-by: NFish <douxc512@gmail.com> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: Novice <857526207@qq.com> Co-authored-by: Hiroki Nagai <82458324+nagaihiroki-git@users.noreply.github.com> Co-authored-by: Gen Sato <52241300+halogen22@users.noreply.github.com> Co-authored-by: eux <euxuuu@gmail.com> Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com> Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com> Co-authored-by: lotsik <lotsik@mail.ru> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: gakkiyomi <gakkiyomi@aliyun.com> Co-authored-by: CN-P5 <heibai2006@gmail.com> Co-authored-by: CN-P5 <heibai2006@qq.com> Co-authored-by: Chuehnone <1897025+chuehnone@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Kevin9703 <51311316+Kevin9703@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Boris Feld <lothiraldan@gmail.com> Co-authored-by: mbo <himabo@gmail.com> Co-authored-by: mabo <mabo@aeyes.ai> Co-authored-by: Warren Chen <warren.chen830@gmail.com> Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com> Co-authored-by: jiandanfeng <chenjh3@wangsu.com> Co-authored-by: zhu-an <70234959+xhdd123321@users.noreply.github.com> Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com> Co-authored-by: 海狸大師 <86974027+yenslife@users.noreply.github.com> Co-authored-by: Xu Song <xusong.vip@gmail.com> Co-authored-by: rayshaw001 <396301947@163.com> Co-authored-by: Ding Jiatong <dingjiatong@gmail.com> Co-authored-by: Bowen Liang <liangbowen@gf.com.cn> Co-authored-by: JasonVV <jasonwangiii@outlook.com> Co-authored-by: le0zh <newlight@qq.com> Co-authored-by: zhuxinliang <zhuxinliang@didiglobal.com> Co-authored-by: k-zaku <zaku99@outlook.jp> Co-authored-by: luckylhb90 <luckylhb90@gmail.com> Co-authored-by: hobo.l <hobo.l@binance.com> Co-authored-by: jiangbo721 <365065261@qq.com> Co-authored-by: 刘江波 <jiangbo721@163.com> Co-authored-by: Shun Miyazawa <34241526+miya@users.noreply.github.com> Co-authored-by: EricPan <30651140+Egfly@users.noreply.github.com> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: sino <sino2322@gmail.com> Co-authored-by: Jhvcc <37662342+Jhvcc@users.noreply.github.com> Co-authored-by: lowell <lowell.hu@zkteco.in> Co-authored-by: Boris Polonsky <BorisPolonsky@users.noreply.github.com> Co-authored-by: Ademílson Tonato <ademilsonft@outlook.com> Co-authored-by: Ademílson Tonato <ademilson.tonato@refurbed.com> Co-authored-by: IWAI, Masaharu <iwaim.sub@gmail.com> Co-authored-by: Yueh-Po Peng (Yabi) <94939112+y10ab1@users.noreply.github.com> Co-authored-by: Jason <ggbbddjm@gmail.com> Co-authored-by: Xin Zhang <sjhpzx@gmail.com> Co-authored-by: yjc980121 <3898524+yjc980121@users.noreply.github.com> Co-authored-by: heyszt <36215648+hieheihei@users.noreply.github.com> Co-authored-by: Abdullah AlOsaimi <osaimiacc@gmail.com> Co-authored-by: Abdullah AlOsaimi <189027247+osaimi@users.noreply.github.com> Co-authored-by: Yingchun Lai <laiyingchun@apache.org> Co-authored-by: Hash Brown <hi@xzd.me> Co-authored-by: zuodongxu <192560071+zuodongxu@users.noreply.github.com> Co-authored-by: Masashi Tomooka <tmokmss@users.noreply.github.com> Co-authored-by: aplio <ryo.091219@gmail.com> Co-authored-by: Obada Khalili <54270856+obadakhalili@users.noreply.github.com> Co-authored-by: Nam Vu <zuzoovn@gmail.com> Co-authored-by: Kei YAMAZAKI <1715090+kei-yamazaki@users.noreply.github.com> Co-authored-by: TechnoHouse <13776377+deephbz@users.noreply.github.com> Co-authored-by: Riddhimaan-Senapati <114703025+Riddhimaan-Senapati@users.noreply.github.com> Co-authored-by: MaFee921 <31881301+2284730142@users.noreply.github.com> Co-authored-by: te-chan <t-nakanome@sakura-is.co.jp> Co-authored-by: HQidea <HQidea@users.noreply.github.com> Co-authored-by: Joshbly <36315710+Joshbly@users.noreply.github.com> Co-authored-by: xhe <xw897002528@gmail.com> Co-authored-by: weiwenyan-dev <154779315+weiwenyan-dev@users.noreply.github.com> Co-authored-by: ex_wenyan.wei <ex_wenyan.wei@tcl.com> Co-authored-by: engchina <12236799+engchina@users.noreply.github.com> Co-authored-by: engchina <atjapan2015@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: 呆萌闷油瓶 <253605712@qq.com> Co-authored-by: Kemal <kemalmeler@outlook.com> Co-authored-by: Lazy_Frog <4590648+lazyFrogLOL@users.noreply.github.com> Co-authored-by: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Co-authored-by: Steven sun <98230804+Tuyohai@users.noreply.github.com> Co-authored-by: steven <sunzwj@digitalchina.com> Co-authored-by: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com> Co-authored-by: Katy Tao <34019945+KatyTao@users.noreply.github.com> Co-authored-by: depy <42985524+h4ckdepy@users.noreply.github.com> Co-authored-by: 胡春东 <gycm520@gmail.com> Co-authored-by: Junjie.M <118170653@qq.com> Co-authored-by: MuYu <mr.muzea@gmail.com> Co-authored-by: Naoki Takashima <39912547+takatea@users.noreply.github.com> Co-authored-by: Summer-Gu <37869445+gubinjie@users.noreply.github.com> Co-authored-by: Fei He <droxer.he@gmail.com> Co-authored-by: ybalbert001 <120714773+ybalbert001@users.noreply.github.com> Co-authored-by: Yuanbo Li <ybalbert@amazon.com> Co-authored-by: douxc <7553076+douxc@users.noreply.github.com> Co-authored-by: liuzhenghua <1090179900@qq.com> Co-authored-by: Wu Jiayang <62842862+Wu-Jiayang@users.noreply.github.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: kimjion <45935338+kimjion@users.noreply.github.com> Co-authored-by: AugNSo <song.tiankai@icloud.com> Co-authored-by: llinvokerl <38915183+llinvokerl@users.noreply.github.com> Co-authored-by: liusurong.lsr <liusurong.lsr@alibaba-inc.com> Co-authored-by: Vasu Negi <vasu-negi@users.noreply.github.com> Co-authored-by: Hundredwz <1808096180@qq.com> Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
This commit is contained in:
@@ -3,115 +3,120 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
|
||||
class ToolConfigurationManager(BaseModel):
|
||||
class ProviderConfigEncrypter(BaseModel):
|
||||
tenant_id: str
|
||||
provider_controller: ToolProviderController
|
||||
config: list[BasicProviderConfig]
|
||||
provider_type: str
|
||||
provider_identity: str
|
||||
|
||||
def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
deep copy credentials
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(credentials)
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
|
||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
credentials = self._deep_copy(credentials)
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = self.provider_controller.get_credentials_schema()
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
|
||||
if field_name in credentials:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
|
||||
credentials[field_name] = encrypted
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||
data[field_name] = encrypted
|
||||
|
||||
return credentials
|
||||
return data
|
||||
|
||||
def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
|
||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
credentials = self._deep_copy(credentials)
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = self.provider_controller.get_credentials_schema()
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
|
||||
if field_name in credentials:
|
||||
if len(credentials[field_name]) > 6:
|
||||
credentials[field_name] = (
|
||||
credentials[field_name][:2]
|
||||
+ "*" * (len(credentials[field_name]) - 4)
|
||||
+ credentials[field_name][-2:]
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
credentials[field_name] = "*" * len(credentials[field_name])
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return credentials
|
||||
return data
|
||||
|
||||
def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
identity_id = ""
|
||||
if self.provider_controller.identity:
|
||||
identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}"
|
||||
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=identity_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
credentials = self._deep_copy(credentials)
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = self.provider_controller.get_credentials_schema()
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
|
||||
if field_name in credentials:
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
try:
|
||||
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
|
||||
except:
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cache.set(credentials)
|
||||
return credentials
|
||||
cache.set(data)
|
||||
return data
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
identity_id = ""
|
||||
if self.provider_controller.identity:
|
||||
identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}"
|
||||
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=identity_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
|
||||
class ToolParameterConfigurationManager(BaseModel):
|
||||
class ToolParameterConfigurationManager:
|
||||
"""
|
||||
Tool parameter configuration manager
|
||||
"""
|
||||
@@ -119,9 +124,18 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
tenant_id: str
|
||||
tool_runtime: Tool
|
||||
provider_name: str
|
||||
provider_type: str
|
||||
provider_type: ToolProviderType
|
||||
identity_id: str
|
||||
|
||||
def __init__(
|
||||
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
|
||||
) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.tool_runtime = tool_runtime
|
||||
self.provider_name = provider_name
|
||||
self.provider_type = provider_type
|
||||
self.identity_id = identity_id
|
||||
|
||||
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
deep copy parameters
|
||||
@@ -133,7 +147,7 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
merge parameters
|
||||
"""
|
||||
# get tool parameters
|
||||
tool_parameters = self.tool_runtime.parameters or []
|
||||
tool_parameters = self.tool_runtime.entity.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = self.tool_runtime.get_runtime_parameters()
|
||||
# override parameters
|
||||
@@ -207,13 +221,11 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
|
||||
return a deep copy of parameters with decrypted values
|
||||
"""
|
||||
if self.tool_runtime is None or self.tool_runtime.identity is None:
|
||||
raise ValueError("tool_runtime is required")
|
||||
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f"{self.provider_type}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.identity.name,
|
||||
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.entity.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
@@ -234,7 +246,7 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
try:
|
||||
has_secret_input = True
|
||||
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if has_secret_input:
|
||||
@@ -243,13 +255,10 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
return parameters
|
||||
|
||||
def delete_tool_parameters_cache(self):
|
||||
if self.tool_runtime is None or self.tool_runtime.identity is None:
|
||||
raise ValueError("tool_runtime is required")
|
||||
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f"{self.provider_type}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.identity.name,
|
||||
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.entity.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,198 @@
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
class DatasetMultiRetrieverToolInput(BaseModel):
|
||||
query: str = Field(..., description="dataset multi retriever and rerank")
|
||||
|
||||
|
||||
class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
"""Tool for querying multi dataset."""
|
||||
|
||||
name: str = "dataset_"
|
||||
args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
|
||||
description: str = "dataset multi retriever and rerank. "
|
||||
dataset_ids: list[str]
|
||||
reranking_provider_name: str
|
||||
reranking_model_name: str
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
|
||||
return cls(
|
||||
name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs
|
||||
)
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
threads = []
|
||||
all_documents: list[RagDocument] = []
|
||||
for dataset_id in self.dataset_ids:
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"all_documents": all_documents,
|
||||
"hit_callbacks": self.hit_callbacks,
|
||||
},
|
||||
)
|
||||
threads.append(retrieval_thread)
|
||||
retrieval_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
# do rerank for searched documents
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=self.reranking_provider_name,
|
||||
model_type=ModelType.RERANK,
|
||||
model=self.reranking_model_name,
|
||||
)
|
||||
|
||||
rerank_runner = RerankModelRunner(rerank_model_instance)
|
||||
all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
|
||||
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_tool_end(all_documents)
|
||||
|
||||
document_score_list = {}
|
||||
for item in all_documents:
|
||||
if item.metadata and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
|
||||
segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.dataset_id.in_(self.dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
).all()
|
||||
|
||||
if segments:
|
||||
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
|
||||
sorted_segments = sorted(
|
||||
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
|
||||
)
|
||||
for segment in sorted_segments:
|
||||
if segment.answer:
|
||||
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
|
||||
else:
|
||||
document_context_list.append(segment.get_sign_content())
|
||||
if self.return_resource:
|
||||
context_list = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = Document.query.filter(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
).first()
|
||||
if dataset and document:
|
||||
source = {
|
||||
"position": resource_number,
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
}
|
||||
|
||||
if self.retriever_from == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
return ""
|
||||
|
||||
raise RuntimeError("not segments found")
|
||||
|
||||
def _retriever(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
all_documents: list,
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler],
|
||||
):
|
||||
with flask_app.app_context():
|
||||
dataset = (
|
||||
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
return []
|
||||
|
||||
for hit_callback in hit_callbacks:
|
||||
hit_callback.on_query(query, dataset.id)
|
||||
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search",
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 2,
|
||||
)
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
else:
|
||||
if self.top_k > 0:
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model["search_method"],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 2,
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
reranking_model=retrieval_model.get("reranking_model", None)
|
||||
if retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
)
|
||||
|
||||
all_documents.extend(documents)
|
||||
@@ -0,0 +1,33 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from msal_extensions.persistence import ABC # type: ignore
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
|
||||
|
||||
class DatasetRetrieverBaseTool(BaseModel, ABC):
|
||||
"""Tool for querying a Dataset."""
|
||||
|
||||
name: str = "dataset"
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
tenant_id: str
|
||||
top_k: int = 2
|
||||
score_threshold: Optional[float] = None
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
|
||||
return_resource: bool
|
||||
retriever_from: str
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
def _run(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use the tool.
|
||||
|
||||
Add run_manager: Optional[CallbackManagerForToolRun] = None
|
||||
to child implementations to enable tracing,
|
||||
"""
|
||||
201
api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
Normal file
201
api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
Normal file
@@ -0,0 +1,201 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"reranking_mode": "reranking_model",
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
class DatasetRetrieverToolInput(BaseModel):
|
||||
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
|
||||
|
||||
|
||||
class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
"""Tool for querying a Dataset."""
|
||||
|
||||
name: str = "dataset"
|
||||
args_schema: type[BaseModel] = DatasetRetrieverToolInput
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
dataset_id: str
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset: Dataset, **kwargs):
|
||||
description = dataset.description
|
||||
if not description:
|
||||
description = "useful for when you want to answer queries about the " + dataset.name
|
||||
|
||||
description = description.replace("\n", "").replace("\r", "")
|
||||
return cls(
|
||||
name=f"dataset_{dataset.id.replace('-', '_')}",
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
description=description,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
dataset = (
|
||||
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
return ""
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_query(query, dataset.id)
|
||||
if dataset.provider == "external":
|
||||
results = []
|
||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
external_retrieval_parameters=dataset.retrieval_model,
|
||||
)
|
||||
for external_document in external_documents:
|
||||
document = RetrievalDocument(
|
||||
page_content=external_document.get("content"),
|
||||
metadata=external_document.get("metadata"),
|
||||
provider="external",
|
||||
)
|
||||
if document.metadata is not None:
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset.id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
# deal with external documents
|
||||
context_list = []
|
||||
for position, item in enumerate(results, start=1):
|
||||
if item.metadata is not None:
|
||||
source = {
|
||||
"position": position,
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": item.metadata.get("score"),
|
||||
"title": item.metadata.get("title"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
context_list.append(source)
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join([item.page_content for item in results]))
|
||||
else:
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
|
||||
)
|
||||
return str("\n".join([document.page_content for document in documents]))
|
||||
else:
|
||||
if self.top_k > 0:
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
reranking_model=retrieval_model.get("reranking_model")
|
||||
if retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights"),
|
||||
)
|
||||
else:
|
||||
documents = []
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_tool_end(documents)
|
||||
document_score_list = {}
|
||||
if dataset.indexing_technique != "economy":
|
||||
for item in documents:
|
||||
if item.metadata is not None and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
document_context_list = []
|
||||
records = RetrievalService.format_retrieval_documents(documents)
|
||||
if records:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
if segment.answer:
|
||||
document_context_list.append(
|
||||
DocumentContext(
|
||||
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
|
||||
score=record.score,
|
||||
)
|
||||
)
|
||||
else:
|
||||
document_context_list.append(
|
||||
DocumentContext(
|
||||
content=segment.get_sign_content(),
|
||||
score=record.score,
|
||||
)
|
||||
)
|
||||
retrieval_resource_list = []
|
||||
if self.return_resource:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = DatasetDocument.query.filter(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).first()
|
||||
if dataset and document:
|
||||
source = {
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id, # type: ignore
|
||||
"document_name": document.name, # type: ignore
|
||||
"data_source_type": document.data_source_type, # type: ignore
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": record.score or 0.0,
|
||||
}
|
||||
|
||||
if self.retriever_from == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
retrieval_resource_list.append(source)
|
||||
|
||||
if self.return_resource and retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(
|
||||
retrieval_resource_list,
|
||||
key=lambda x: x.get("score") or 0.0,
|
||||
reverse=True,
|
||||
)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
|
||||
item["position"] = position # type: ignore
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||
if document_context_list:
|
||||
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||
return str("\n".join([document_context.content for document_context in document_context_list]))
|
||||
return ""
|
||||
134
api/core/tools/utils/dataset_retriever_tool.py
Normal file
134
api/core/tools/utils/dataset_retriever_tool.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
|
||||
|
||||
class DatasetRetrieverTool(Tool):
|
||||
retrieval_tool: DatasetRetrieverBaseTool
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.retrieval_tool = retrieval_tool
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_tools(
|
||||
tenant_id: str,
|
||||
dataset_ids: list[str],
|
||||
retrieve_config: DatasetRetrieveConfigEntity | None,
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
) -> list["DatasetRetrieverTool"]:
|
||||
"""
|
||||
get dataset tool
|
||||
"""
|
||||
# check if retrieve_config is valid
|
||||
if dataset_ids is None or len(dataset_ids) == 0:
|
||||
return []
|
||||
if retrieve_config is None:
|
||||
return []
|
||||
|
||||
feature = DatasetRetrieval()
|
||||
|
||||
# save original retrieve strategy, and set retrieve strategy to SINGLE
|
||||
# Agent only support SINGLE mode
|
||||
original_retriever_mode = retrieve_config.retrieve_strategy
|
||||
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
|
||||
retrieval_tools = feature.to_dataset_retriever_tool(
|
||||
tenant_id=tenant_id,
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=retrieve_config,
|
||||
return_resource=return_resource,
|
||||
invoke_from=invoke_from,
|
||||
hit_callback=hit_callback,
|
||||
)
|
||||
if retrieval_tools is None or len(retrieval_tools) == 0:
|
||||
return []
|
||||
|
||||
# restore retrieve strategy
|
||||
retrieve_config.retrieve_strategy = original_retriever_mode
|
||||
|
||||
# convert retrieval tools to Tools
|
||||
tools = []
|
||||
for retrieval_tool in retrieval_tools:
|
||||
tool = DatasetRetrieverTool(
|
||||
retrieval_tool=retrieval_tool,
|
||||
entity=ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
|
||||
),
|
||||
parameters=[],
|
||||
description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
|
||||
),
|
||||
runtime=ToolRuntime(tenant_id=tenant_id),
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
return [
|
||||
ToolParameter(
|
||||
name="query",
|
||||
label=I18nObject(en_US="", zh_Hans=""),
|
||||
human_description=I18nObject(en_US="", zh_Hans=""),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description="Query for the dataset to be used to retrieve the dataset.",
|
||||
required=True,
|
||||
default="",
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
),
|
||||
]
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.DATASET_RETRIEVAL
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke dataset retriever tool
|
||||
"""
|
||||
query = tool_parameters.get("query")
|
||||
if not query:
|
||||
yield self.create_text_message(text="please input query")
|
||||
else:
|
||||
# invoke dataset retriever tool
|
||||
result = self.retrieval_tool._run(query=query)
|
||||
yield self.create_text_message(text=result)
|
||||
|
||||
def validate_credentials(
|
||||
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
|
||||
) -> str | None:
|
||||
"""
|
||||
validate the credentials for dataset retriever tool
|
||||
"""
|
||||
pass
|
||||
@@ -1,919 +0,0 @@
|
||||
import json
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
def auth(credentials):
|
||||
app_id = credentials.get("app_id")
|
||||
app_secret = credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ToolProviderCredentialValidationError("app_id and app_secret is required")
|
||||
try:
|
||||
assert FeishuRequest(app_id, app_secret).tenant_access_token is not None
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
|
||||
def convert_add_records(json_str):
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
if not isinstance(data, list):
|
||||
raise ValueError("Parsed data must be a list")
|
||||
converted_data = [{"fields": json.dumps(item, ensure_ascii=False)} for item in data]
|
||||
return converted_data
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
except Exception as e:
|
||||
raise ValueError(f"An error occurred while processing the data: {e}")
|
||||
|
||||
|
||||
def convert_update_records(json_str):
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
if not isinstance(data, list):
|
||||
raise ValueError("Parsed data must be a list")
|
||||
|
||||
converted_data = [
|
||||
{"fields": json.dumps(record["fields"], ensure_ascii=False), "record_id": record["record_id"]}
|
||||
for record in data
|
||||
if "fields" in record and "record_id" in record
|
||||
]
|
||||
|
||||
if len(converted_data) != len(data):
|
||||
raise ValueError("Each record must contain 'fields' and 'record_id'")
|
||||
|
||||
return converted_data
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
except Exception as e:
|
||||
raise ValueError(f"An error occurred while processing the data: {e}")
|
||||
|
||||
|
||||
class FeishuRequest:
|
||||
API_BASE_URL = "https://lark-plugin-api.solutionsuite.cn/lark-plugin"
|
||||
|
||||
def __init__(self, app_id: str, app_secret: str):
|
||||
self.app_id = app_id
|
||||
self.app_secret = app_secret
|
||||
|
||||
@property
|
||||
def tenant_access_token(self):
|
||||
feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token"
|
||||
if redis_client.exists(feishu_tenant_access_token):
|
||||
return redis_client.get(feishu_tenant_access_token).decode()
|
||||
res = self.get_tenant_access_token(self.app_id, self.app_secret)
|
||||
redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token"))
|
||||
return res.get("tenant_access_token")
|
||||
|
||||
def _send_request(
|
||||
self,
|
||||
url: str,
|
||||
method: str = "post",
|
||||
require_token: bool = True,
|
||||
payload: Optional[dict] = None,
|
||||
params: Optional[dict] = None,
|
||||
):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"user-agent": "Dify",
|
||||
}
|
||||
if require_token:
|
||||
headers["tenant-access-token"] = f"{self.tenant_access_token}"
|
||||
res = httpx.request(method=method, url=url, headers=headers, json=payload, params=params, timeout=30).json()
|
||||
if res.get("code") != 0:
|
||||
raise Exception(res)
|
||||
return res
|
||||
|
||||
def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict:
|
||||
"""
|
||||
API url: https://open.feishu.cn/document/server-docs/authentication-management/access-token/tenant_access_token_internal
|
||||
Example Response:
|
||||
{
|
||||
"code": 0,
|
||||
"msg": "ok",
|
||||
"tenant_access_token": "t-caecc734c2e3328a62489fe0648c4b98779515d3",
|
||||
"expire": 7200
|
||||
}
|
||||
"""
|
||||
url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token"
|
||||
payload = {"app_id": app_id, "app_secret": app_secret}
|
||||
res: dict = self._send_request(url, require_token=False, payload=payload)
|
||||
return res
|
||||
|
||||
def create_document(self, title: str, content: str, folder_token: str) -> dict:
|
||||
"""
|
||||
API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/create
|
||||
Example Response:
|
||||
{
|
||||
"data": {
|
||||
"title": "title",
|
||||
"url": "https://svi136aogf123.feishu.cn/docx/VWbvd4fEdoW0WSxaY1McQTz8n7d",
|
||||
"type": "docx",
|
||||
"token": "VWbvd4fEdoW0WSxaY1McQTz8n7d"
|
||||
},
|
||||
"log_id": "021721281231575fdbddc0200ff00060a9258ec0000103df61b5d",
|
||||
"code": 0,
|
||||
"msg": "创建飞书文档成功,请查看"
|
||||
}
|
||||
"""
|
||||
url = f"{self.API_BASE_URL}/document/create_document"
|
||||
payload = {
|
||||
"title": title,
|
||||
"content": content,
|
||||
"folder_token": folder_token,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def write_document(self, document_id: str, content: str, position: str = "end") -> dict:
|
||||
url = f"{self.API_BASE_URL}/document/write_document"
|
||||
payload = {"document_id": document_id, "content": content, "position": position}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
return res
|
||||
|
||||
def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str:
|
||||
"""
|
||||
API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/raw_content
|
||||
Example Response:
|
||||
{
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
"data": {
|
||||
"content": "云文档\n多人实时协同,插入一切元素。不仅是在线文档,更是强大的创作和互动工具\n云文档:专为协作而生\n"
|
||||
}
|
||||
}
|
||||
""" # noqa: E501
|
||||
params = {
|
||||
"document_id": document_id,
|
||||
"mode": mode,
|
||||
"lang": lang,
|
||||
}
|
||||
url = f"{self.API_BASE_URL}/document/get_document_content"
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return cast(str, res.get("data", {}).get("content"))
|
||||
return ""
|
||||
|
||||
def list_document_blocks(
|
||||
self, document_id: str, page_token: str, user_id_type: str = "open_id", page_size: int = 500
|
||||
) -> dict:
|
||||
"""
|
||||
API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/list
|
||||
"""
|
||||
params = {
|
||||
"user_id_type": user_id_type,
|
||||
"document_id": document_id,
|
||||
"page_size": page_size,
|
||||
"page_token": page_token,
|
||||
}
|
||||
url = f"{self.API_BASE_URL}/document/list_document_blocks"
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict:
|
||||
"""
|
||||
API url: https://open.larkoffice.com/document/server-docs/im-v1/message/create
|
||||
"""
|
||||
url = f"{self.API_BASE_URL}/message/send_bot_message"
|
||||
params = {
|
||||
"receive_id_type": receive_id_type,
|
||||
}
|
||||
payload = {
|
||||
"receive_id": receive_id,
|
||||
"msg_type": msg_type,
|
||||
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
|
||||
}
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict:
|
||||
url = f"{self.API_BASE_URL}/message/send_webhook_message"
|
||||
payload = {
|
||||
"webhook": webhook,
|
||||
"msg_type": msg_type,
|
||||
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
|
||||
}
|
||||
res: dict = self._send_request(url, require_token=False, payload=payload)
|
||||
return res
|
||||
|
||||
def get_chat_messages(
|
||||
self,
|
||||
container_id: str,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
page_token: str,
|
||||
sort_type: str = "ByCreateTimeAsc",
|
||||
page_size: int = 20,
|
||||
) -> dict:
|
||||
"""
|
||||
API url: https://open.larkoffice.com/document/server-docs/im-v1/message/list
|
||||
"""
|
||||
url = f"{self.API_BASE_URL}/message/get_chat_messages"
|
||||
params = {
|
||||
"container_id": container_id,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"sort_type": sort_type,
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_thread_messages(
|
||||
self, container_id: str, page_token: str, sort_type: str = "ByCreateTimeAsc", page_size: int = 20
|
||||
) -> dict:
|
||||
"""
|
||||
API url: https://open.larkoffice.com/document/server-docs/im-v1/message/list
|
||||
"""
|
||||
url = f"{self.API_BASE_URL}/message/get_thread_messages"
|
||||
params = {
|
||||
"container_id": container_id,
|
||||
"sort_type": sort_type,
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict:
|
||||
# 创建任务
|
||||
url = f"{self.API_BASE_URL}/task/create_task"
|
||||
payload = {
|
||||
"summary": summary,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"completed_at": completed_time,
|
||||
"description": description,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def update_task(
|
||||
self, task_guid: str, summary: str, start_time: str, end_time: str, completed_time: str, description: str
|
||||
) -> dict:
|
||||
# 更新任务
|
||||
url = f"{self.API_BASE_URL}/task/update_task"
|
||||
payload = {
|
||||
"task_guid": task_guid,
|
||||
"summary": summary,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"completed_time": completed_time,
|
||||
"description": description,
|
||||
}
|
||||
res: dict = self._send_request(url, method="PATCH", payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def delete_task(self, task_guid: str) -> dict:
|
||||
# 删除任务
|
||||
url = f"{self.API_BASE_URL}/task/delete_task"
|
||||
payload = {
|
||||
"task_guid": task_guid,
|
||||
}
|
||||
res: dict = self._send_request(url, method="DELETE", payload=payload)
|
||||
return res
|
||||
|
||||
def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict:
|
||||
# 删除任务
|
||||
url = f"{self.API_BASE_URL}/task/add_members"
|
||||
payload = {
|
||||
"task_guid": task_guid,
|
||||
"member_phone_or_email": member_phone_or_email,
|
||||
"member_role": member_role,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
return res
|
||||
|
||||
def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict:
|
||||
# 获取知识库全部子节点列表
|
||||
url = f"{self.API_BASE_URL}/wiki/get_wiki_nodes"
|
||||
payload = {
|
||||
"space_id": space_id,
|
||||
"parent_node_token": parent_node_token,
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/get_primary_calendar"
|
||||
params = {
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_event(
|
||||
self,
|
||||
summary: str,
|
||||
description: str,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
attendee_ability: str,
|
||||
need_notification: bool = True,
|
||||
auto_record: bool = False,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/create_event"
|
||||
payload = {
|
||||
"summary": summary,
|
||||
"description": description,
|
||||
"need_notification": need_notification,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"auto_record": auto_record,
|
||||
"attendee_ability": attendee_ability,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def update_event(
|
||||
self,
|
||||
event_id: str,
|
||||
summary: str,
|
||||
description: str,
|
||||
need_notification: bool,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
auto_record: bool,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}"
|
||||
payload: dict[str, Any] = {}
|
||||
if summary:
|
||||
payload["summary"] = summary
|
||||
if description:
|
||||
payload["description"] = description
|
||||
if start_time:
|
||||
payload["start_time"] = start_time
|
||||
if end_time:
|
||||
payload["end_time"] = end_time
|
||||
if need_notification:
|
||||
payload["need_notification"] = need_notification
|
||||
if auto_record:
|
||||
payload["auto_record"] = auto_record
|
||||
res: dict = self._send_request(url, method="PATCH", payload=payload)
|
||||
return res
|
||||
|
||||
def delete_event(self, event_id: str, need_notification: bool = True) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/delete_event/{event_id}"
|
||||
params = {
|
||||
"need_notification": need_notification,
|
||||
}
|
||||
res: dict = self._send_request(url, method="DELETE", params=params)
|
||||
return res
|
||||
|
||||
def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/list_events"
|
||||
params = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def search_events(
|
||||
self,
|
||||
query: str,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
page_token: str,
|
||||
user_id_type: str = "open_id",
|
||||
page_size: int = 20,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/search_events"
|
||||
payload = {
|
||||
"query": query,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"page_token": page_token,
|
||||
"user_id_type": user_id_type,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict:
|
||||
# 参加日程参会人
|
||||
url = f"{self.API_BASE_URL}/calendar/add_event_attendees"
|
||||
payload = {
|
||||
"event_id": event_id,
|
||||
"attendee_phone_or_email": attendee_phone_or_email,
|
||||
"need_notification": need_notification,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_spreadsheet(
|
||||
self,
|
||||
title: str,
|
||||
folder_token: str,
|
||||
) -> dict:
|
||||
# 创建电子表格
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/create_spreadsheet"
|
||||
payload = {
|
||||
"title": title,
|
||||
"folder_token": folder_token,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_spreadsheet(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
# 获取电子表格信息
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/get_spreadsheet"
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def list_spreadsheet_sheets(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
) -> dict:
|
||||
# 列出电子表格的所有工作表
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/list_spreadsheet_sheets"
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_rows(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
sheet_id: str,
|
||||
sheet_name: str,
|
||||
length: int,
|
||||
values: str,
|
||||
) -> dict:
|
||||
# 增加行,在工作表最后添加
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/add_rows"
|
||||
payload = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"sheet_id": sheet_id,
|
||||
"sheet_name": sheet_name,
|
||||
"length": length,
|
||||
"values": values,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_cols(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
sheet_id: str,
|
||||
sheet_name: str,
|
||||
length: int,
|
||||
values: str,
|
||||
) -> dict:
|
||||
# 增加列,在工作表最后添加
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/add_cols"
|
||||
payload = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"sheet_id": sheet_id,
|
||||
"sheet_name": sheet_name,
|
||||
"length": length,
|
||||
"values": values,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_rows(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
sheet_id: str,
|
||||
sheet_name: str,
|
||||
start_row: int,
|
||||
num_rows: int,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
# 读取工作表行数据
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/read_rows"
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"sheet_id": sheet_id,
|
||||
"sheet_name": sheet_name,
|
||||
"start_row": start_row,
|
||||
"num_rows": num_rows,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_cols(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
sheet_id: str,
|
||||
sheet_name: str,
|
||||
start_col: int,
|
||||
num_cols: int,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
# 读取工作表列数据
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/read_cols"
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"sheet_id": sheet_id,
|
||||
"sheet_name": sheet_name,
|
||||
"start_col": start_col,
|
||||
"num_cols": num_cols,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_table(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
sheet_id: str,
|
||||
sheet_name: str,
|
||||
num_range: str,
|
||||
query: str,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
# 自定义读取行列数据
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/read_table"
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"sheet_id": sheet_id,
|
||||
"sheet_name": sheet_name,
|
||||
"range": num_range,
|
||||
"query": query,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_base(
|
||||
self,
|
||||
name: str,
|
||||
folder_token: str,
|
||||
) -> dict:
|
||||
# 创建多维表格
|
||||
url = f"{self.API_BASE_URL}/base/create_base"
|
||||
payload = {
|
||||
"name": name,
|
||||
"folder_token": folder_token,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_records(
|
||||
self,
|
||||
app_token: str,
|
||||
table_id: str,
|
||||
table_name: str,
|
||||
records: str,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
# 新增多条记录
|
||||
url = f"{self.API_BASE_URL}/base/add_records"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
payload = {
|
||||
"records": convert_add_records(records),
|
||||
}
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def update_records(
|
||||
self,
|
||||
app_token: str,
|
||||
table_id: str,
|
||||
table_name: str,
|
||||
records: str,
|
||||
user_id_type: str,
|
||||
) -> dict:
|
||||
# 更新多条记录
|
||||
url = f"{self.API_BASE_URL}/base/update_records"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
payload = {
|
||||
"records": convert_update_records(records),
|
||||
}
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def delete_records(
|
||||
self,
|
||||
app_token: str,
|
||||
table_id: str,
|
||||
table_name: str,
|
||||
record_ids: str,
|
||||
) -> dict:
|
||||
# 删除多条记录
|
||||
url = f"{self.API_BASE_URL}/base/delete_records"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
}
|
||||
if not record_ids:
|
||||
record_id_list = []
|
||||
else:
|
||||
try:
|
||||
record_id_list = json.loads(record_ids)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
payload = {
|
||||
"records": record_id_list,
|
||||
}
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def search_record(
|
||||
self,
|
||||
app_token: str,
|
||||
table_id: str,
|
||||
table_name: str,
|
||||
view_id: str,
|
||||
field_names: str,
|
||||
sort: str,
|
||||
filters: str,
|
||||
page_token: str,
|
||||
automatic_fields: bool = False,
|
||||
user_id_type: str = "open_id",
|
||||
page_size: int = 20,
|
||||
) -> dict:
|
||||
# 查询记录,单次最多查询 500 行记录。
|
||||
url = f"{self.API_BASE_URL}/base/search_record"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
"user_id_type": user_id_type,
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
if not field_names:
|
||||
field_name_list = []
|
||||
else:
|
||||
try:
|
||||
field_name_list = json.loads(field_names)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
|
||||
if not sort:
|
||||
sort_list = []
|
||||
else:
|
||||
try:
|
||||
sort_list = json.loads(sort)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
|
||||
if not filters:
|
||||
filter_dict = {}
|
||||
else:
|
||||
try:
|
||||
filter_dict = json.loads(filters)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
|
||||
payload: dict[str, Any] = {}
|
||||
|
||||
if view_id:
|
||||
payload["view_id"] = view_id
|
||||
if field_names:
|
||||
payload["field_names"] = field_name_list
|
||||
if sort:
|
||||
payload["sort"] = sort_list
|
||||
if filters:
|
||||
payload["filter"] = filter_dict
|
||||
if automatic_fields:
|
||||
payload["automatic_fields"] = automatic_fields
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_base_info(
|
||||
self,
|
||||
app_token: str,
|
||||
) -> dict:
|
||||
# 获取多维表格元数据
|
||||
url = f"{self.API_BASE_URL}/base/get_base_info"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_table(
|
||||
self,
|
||||
app_token: str,
|
||||
table_name: str,
|
||||
default_view_name: str,
|
||||
fields: str,
|
||||
) -> dict:
|
||||
# 新增一个数据表
|
||||
url = f"{self.API_BASE_URL}/base/create_table"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
}
|
||||
if not fields:
|
||||
fields_list = []
|
||||
else:
|
||||
try:
|
||||
fields_list = json.loads(fields)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
payload = {
|
||||
"name": table_name,
|
||||
"fields": fields_list,
|
||||
}
|
||||
if default_view_name:
|
||||
payload["default_view_name"] = default_view_name
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def delete_tables(
|
||||
self,
|
||||
app_token: str,
|
||||
table_ids: str,
|
||||
table_names: str,
|
||||
) -> dict:
|
||||
# 删除多个数据表
|
||||
url = f"{self.API_BASE_URL}/base/delete_tables"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
}
|
||||
if not table_ids:
|
||||
table_id_list = []
|
||||
else:
|
||||
try:
|
||||
table_id_list = json.loads(table_ids)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
|
||||
if not table_names:
|
||||
table_name_list = []
|
||||
else:
|
||||
try:
|
||||
table_name_list = json.loads(table_names)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
|
||||
payload = {
|
||||
"table_ids": table_id_list,
|
||||
"table_names": table_name_list,
|
||||
}
|
||||
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def list_tables(
|
||||
self,
|
||||
app_token: str,
|
||||
page_token: str,
|
||||
page_size: int = 20,
|
||||
) -> dict:
|
||||
# 列出多维表格下的全部数据表
|
||||
url = f"{self.API_BASE_URL}/base/list_tables"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_records(
|
||||
self,
|
||||
app_token: str,
|
||||
table_id: str,
|
||||
table_name: str,
|
||||
record_ids: str,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/base/read_records"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
}
|
||||
if not record_ids:
|
||||
record_id_list = []
|
||||
else:
|
||||
try:
|
||||
record_id_list = json.loads(record_ids)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
payload = {
|
||||
"record_ids": record_id_list,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
@@ -1,851 +0,0 @@
|
||||
import json
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
def lark_auth(credentials):
|
||||
app_id = credentials.get("app_id")
|
||||
app_secret = credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ToolProviderCredentialValidationError("app_id and app_secret is required")
|
||||
try:
|
||||
assert LarkRequest(app_id, app_secret).tenant_access_token is not None
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
|
||||
class LarkRequest:
|
||||
API_BASE_URL = "https://lark-plugin-api.solutionsuite.ai/lark-plugin"
|
||||
|
||||
def __init__(self, app_id: str, app_secret: str):
|
||||
self.app_id = app_id
|
||||
self.app_secret = app_secret
|
||||
|
||||
def convert_add_records(self, json_str):
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
if not isinstance(data, list):
|
||||
raise ValueError("Parsed data must be a list")
|
||||
converted_data = [{"fields": json.dumps(item, ensure_ascii=False)} for item in data]
|
||||
return converted_data
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
except Exception as e:
|
||||
raise ValueError(f"An error occurred while processing the data: {e}")
|
||||
|
||||
def convert_update_records(self, json_str):
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
if not isinstance(data, list):
|
||||
raise ValueError("Parsed data must be a list")
|
||||
|
||||
converted_data = [
|
||||
{"fields": json.dumps(record["fields"], ensure_ascii=False), "record_id": record["record_id"]}
|
||||
for record in data
|
||||
if "fields" in record and "record_id" in record
|
||||
]
|
||||
|
||||
if len(converted_data) != len(data):
|
||||
raise ValueError("Each record must contain 'fields' and 'record_id'")
|
||||
|
||||
return converted_data
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
except Exception as e:
|
||||
raise ValueError(f"An error occurred while processing the data: {e}")
|
||||
|
||||
@property
|
||||
def tenant_access_token(self) -> str:
|
||||
feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token"
|
||||
if redis_client.exists(feishu_tenant_access_token):
|
||||
return str(redis_client.get(feishu_tenant_access_token).decode())
|
||||
res: dict[str, str] = self.get_tenant_access_token(self.app_id, self.app_secret)
|
||||
redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token"))
|
||||
return res.get("tenant_access_token", "")
|
||||
|
||||
def _send_request(
|
||||
self,
|
||||
url: str,
|
||||
method: str = "post",
|
||||
require_token: bool = True,
|
||||
payload: Optional[dict] = None,
|
||||
params: Optional[dict] = None,
|
||||
):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"user-agent": "Dify",
|
||||
}
|
||||
if require_token:
|
||||
headers["tenant-access-token"] = f"{self.tenant_access_token}"
|
||||
res = httpx.request(method=method, url=url, headers=headers, json=payload, params=params, timeout=30).json()
|
||||
if res.get("code") != 0:
|
||||
raise Exception(res)
|
||||
return res
|
||||
|
||||
def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict:
|
||||
url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token"
|
||||
payload = {"app_id": app_id, "app_secret": app_secret}
|
||||
res: dict = self._send_request(url, require_token=False, payload=payload)
|
||||
return res
|
||||
|
||||
def create_document(self, title: str, content: str, folder_token: str) -> dict:
|
||||
url = f"{self.API_BASE_URL}/document/create_document"
|
||||
payload = {
|
||||
"title": title,
|
||||
"content": content,
|
||||
"folder_token": folder_token,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def write_document(self, document_id: str, content: str, position: str = "end") -> dict:
|
||||
url = f"{self.API_BASE_URL}/document/write_document"
|
||||
payload = {"document_id": document_id, "content": content, "position": position}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
return res
|
||||
|
||||
def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str | dict:
|
||||
params = {
|
||||
"document_id": document_id,
|
||||
"mode": mode,
|
||||
"lang": lang,
|
||||
}
|
||||
url = f"{self.API_BASE_URL}/document/get_document_content"
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return cast(dict, res.get("data", {}).get("content"))
|
||||
return ""
|
||||
|
||||
def list_document_blocks(
|
||||
self, document_id: str, page_token: str, user_id_type: str = "open_id", page_size: int = 500
|
||||
) -> dict:
|
||||
params = {
|
||||
"user_id_type": user_id_type,
|
||||
"document_id": document_id,
|
||||
"page_size": page_size,
|
||||
"page_token": page_token,
|
||||
}
|
||||
url = f"{self.API_BASE_URL}/document/list_document_blocks"
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict:
|
||||
url = f"{self.API_BASE_URL}/message/send_bot_message"
|
||||
params = {
|
||||
"receive_id_type": receive_id_type,
|
||||
}
|
||||
payload = {
|
||||
"receive_id": receive_id,
|
||||
"msg_type": msg_type,
|
||||
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
|
||||
}
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict:
|
||||
url = f"{self.API_BASE_URL}/message/send_webhook_message"
|
||||
payload = {
|
||||
"webhook": webhook,
|
||||
"msg_type": msg_type,
|
||||
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
|
||||
}
|
||||
res: dict = self._send_request(url, require_token=False, payload=payload)
|
||||
return res
|
||||
|
||||
def get_chat_messages(
|
||||
self,
|
||||
container_id: str,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
page_token: str,
|
||||
sort_type: str = "ByCreateTimeAsc",
|
||||
page_size: int = 20,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/message/get_chat_messages"
|
||||
params = {
|
||||
"container_id": container_id,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"sort_type": sort_type,
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_thread_messages(
|
||||
self, container_id: str, page_token: str, sort_type: str = "ByCreateTimeAsc", page_size: int = 20
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/message/get_thread_messages"
|
||||
params = {
|
||||
"container_id": container_id,
|
||||
"sort_type": sort_type,
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict:
|
||||
url = f"{self.API_BASE_URL}/task/create_task"
|
||||
payload = {
|
||||
"summary": summary,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"completed_at": completed_time,
|
||||
"description": description,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def update_task(
|
||||
self, task_guid: str, summary: str, start_time: str, end_time: str, completed_time: str, description: str
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/task/update_task"
|
||||
payload = {
|
||||
"task_guid": task_guid,
|
||||
"summary": summary,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"completed_time": completed_time,
|
||||
"description": description,
|
||||
}
|
||||
res: dict = self._send_request(url, method="PATCH", payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def delete_task(self, task_guid: str) -> dict:
|
||||
url = f"{self.API_BASE_URL}/task/delete_task"
|
||||
payload = {
|
||||
"task_guid": task_guid,
|
||||
}
|
||||
res: dict = self._send_request(url, method="DELETE", payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict:
|
||||
url = f"{self.API_BASE_URL}/task/add_members"
|
||||
payload = {
|
||||
"task_guid": task_guid,
|
||||
"member_phone_or_email": member_phone_or_email,
|
||||
"member_role": member_role,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict:
|
||||
url = f"{self.API_BASE_URL}/wiki/get_wiki_nodes"
|
||||
payload = {
|
||||
"space_id": space_id,
|
||||
"parent_node_token": parent_node_token,
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/get_primary_calendar"
|
||||
params = {
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_event(
|
||||
self,
|
||||
summary: str,
|
||||
description: str,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
attendee_ability: str,
|
||||
need_notification: bool = True,
|
||||
auto_record: bool = False,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/create_event"
|
||||
payload = {
|
||||
"summary": summary,
|
||||
"description": description,
|
||||
"need_notification": need_notification,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"auto_record": auto_record,
|
||||
"attendee_ability": attendee_ability,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def update_event(
|
||||
self,
|
||||
event_id: str,
|
||||
summary: str,
|
||||
description: str,
|
||||
need_notification: bool,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
auto_record: bool,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}"
|
||||
payload: dict[str, Any] = {}
|
||||
if summary:
|
||||
payload["summary"] = summary
|
||||
if description:
|
||||
payload["description"] = description
|
||||
if start_time:
|
||||
payload["start_time"] = start_time
|
||||
if end_time:
|
||||
payload["end_time"] = end_time
|
||||
if need_notification:
|
||||
payload["need_notification"] = need_notification
|
||||
if auto_record:
|
||||
payload["auto_record"] = auto_record
|
||||
res: dict = self._send_request(url, method="PATCH", payload=payload)
|
||||
return res
|
||||
|
||||
def delete_event(self, event_id: str, need_notification: bool = True) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/delete_event/{event_id}"
|
||||
params = {
|
||||
"need_notification": need_notification,
|
||||
}
|
||||
res: dict = self._send_request(url, method="DELETE", params=params)
|
||||
return res
|
||||
|
||||
def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/list_events"
|
||||
params = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def search_events(
|
||||
self,
|
||||
query: str,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
page_token: str,
|
||||
user_id_type: str = "open_id",
|
||||
page_size: int = 20,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/search_events"
|
||||
payload = {
|
||||
"query": query,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"page_token": page_token,
|
||||
"user_id_type": user_id_type,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/add_event_attendees"
|
||||
payload = {
|
||||
"event_id": event_id,
|
||||
"attendee_phone_or_email": attendee_phone_or_email,
|
||||
"need_notification": need_notification,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_spreadsheet(
|
||||
self,
|
||||
title: str,
|
||||
folder_token: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/create_spreadsheet"
|
||||
payload = {
|
||||
"title": title,
|
||||
"folder_token": folder_token,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_spreadsheet(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/get_spreadsheet"
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def list_spreadsheet_sheets(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/list_spreadsheet_sheets"
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_rows(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
sheet_id: str,
|
||||
sheet_name: str,
|
||||
length: int,
|
||||
values: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/add_rows"
|
||||
payload = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"sheet_id": sheet_id,
|
||||
"sheet_name": sheet_name,
|
||||
"length": length,
|
||||
"values": values,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_cols(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
sheet_id: str,
|
||||
sheet_name: str,
|
||||
length: int,
|
||||
values: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/add_cols"
|
||||
payload = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"sheet_id": sheet_id,
|
||||
"sheet_name": sheet_name,
|
||||
"length": length,
|
||||
"values": values,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_rows(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
sheet_id: str,
|
||||
sheet_name: str,
|
||||
start_row: int,
|
||||
num_rows: int,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/read_rows"
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"sheet_id": sheet_id,
|
||||
"sheet_name": sheet_name,
|
||||
"start_row": start_row,
|
||||
"num_rows": num_rows,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_cols(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
sheet_id: str,
|
||||
sheet_name: str,
|
||||
start_col: int,
|
||||
num_cols: int,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/read_cols"
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"sheet_id": sheet_id,
|
||||
"sheet_name": sheet_name,
|
||||
"start_col": start_col,
|
||||
"num_cols": num_cols,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_table(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
sheet_id: str,
|
||||
sheet_name: str,
|
||||
num_range: str,
|
||||
query: str,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/read_table"
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"sheet_id": sheet_id,
|
||||
"sheet_name": sheet_name,
|
||||
"range": num_range,
|
||||
"query": query,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_base(
|
||||
self,
|
||||
name: str,
|
||||
folder_token: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/base/create_base"
|
||||
payload = {
|
||||
"name": name,
|
||||
"folder_token": folder_token,
|
||||
}
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_records(
|
||||
self,
|
||||
app_token: str,
|
||||
table_id: str,
|
||||
table_name: str,
|
||||
records: str,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/base/add_records"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
payload = {
|
||||
"records": self.convert_add_records(records),
|
||||
}
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def update_records(
|
||||
self,
|
||||
app_token: str,
|
||||
table_id: str,
|
||||
table_name: str,
|
||||
records: str,
|
||||
user_id_type: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/base/update_records"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
payload = {
|
||||
"records": self.convert_update_records(records),
|
||||
}
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def delete_records(
|
||||
self,
|
||||
app_token: str,
|
||||
table_id: str,
|
||||
table_name: str,
|
||||
record_ids: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/base/delete_records"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
}
|
||||
if not record_ids:
|
||||
record_id_list = []
|
||||
else:
|
||||
try:
|
||||
record_id_list = json.loads(record_ids)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
payload = {
|
||||
"records": record_id_list,
|
||||
}
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def search_record(
|
||||
self,
|
||||
app_token: str,
|
||||
table_id: str,
|
||||
table_name: str,
|
||||
view_id: str,
|
||||
field_names: str,
|
||||
sort: str,
|
||||
filters: str,
|
||||
page_token: str,
|
||||
automatic_fields: bool = False,
|
||||
user_id_type: str = "open_id",
|
||||
page_size: int = 20,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/base/search_record"
|
||||
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
"user_id_type": user_id_type,
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
if not field_names:
|
||||
field_name_list = []
|
||||
else:
|
||||
try:
|
||||
field_name_list = json.loads(field_names)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
|
||||
if not sort:
|
||||
sort_list = []
|
||||
else:
|
||||
try:
|
||||
sort_list = json.loads(sort)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
|
||||
if not filters:
|
||||
filter_dict = {}
|
||||
else:
|
||||
try:
|
||||
filter_dict = json.loads(filters)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
|
||||
payload: dict[str, Any] = {}
|
||||
|
||||
if view_id:
|
||||
payload["view_id"] = view_id
|
||||
if field_names:
|
||||
payload["field_names"] = field_name_list
|
||||
if sort:
|
||||
payload["sort"] = sort_list
|
||||
if filters:
|
||||
payload["filter"] = filter_dict
|
||||
if automatic_fields:
|
||||
payload["automatic_fields"] = automatic_fields
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_base_info(
|
||||
self,
|
||||
app_token: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/base/get_base_info"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_table(
|
||||
self,
|
||||
app_token: str,
|
||||
table_name: str,
|
||||
default_view_name: str,
|
||||
fields: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/base/create_table"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
}
|
||||
if not fields:
|
||||
fields_list = []
|
||||
else:
|
||||
try:
|
||||
fields_list = json.loads(fields)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
payload = {
|
||||
"name": table_name,
|
||||
"fields": fields_list,
|
||||
}
|
||||
if default_view_name:
|
||||
payload["default_view_name"] = default_view_name
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def delete_tables(
|
||||
self,
|
||||
app_token: str,
|
||||
table_ids: str,
|
||||
table_names: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/base/delete_tables"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
}
|
||||
if not table_ids:
|
||||
table_id_list = []
|
||||
else:
|
||||
try:
|
||||
table_id_list = json.loads(table_ids)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
|
||||
if not table_names:
|
||||
table_name_list = []
|
||||
else:
|
||||
try:
|
||||
table_name_list = json.loads(table_names)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
|
||||
payload = {
|
||||
"table_ids": table_id_list,
|
||||
"table_names": table_name_list,
|
||||
}
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def list_tables(
|
||||
self,
|
||||
app_token: str,
|
||||
page_token: str,
|
||||
page_size: int = 20,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/base/list_tables"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_records(
|
||||
self,
|
||||
app_token: str,
|
||||
table_id: str,
|
||||
table_name: str,
|
||||
record_ids: str,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/base/read_records"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
}
|
||||
if not record_ids:
|
||||
record_id_list = []
|
||||
else:
|
||||
try:
|
||||
record_id_list = json.loads(record_ids)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
payload = {
|
||||
"record_ids": record_id_list,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res: dict = self._send_request(url, method="POST", params=params, payload=payload)
|
||||
if "data" in res:
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from mimetypes import guess_extension
|
||||
from typing import Optional
|
||||
|
||||
@@ -12,58 +13,64 @@ logger = logging.getLogger(__name__)
|
||||
class ToolFileMessageTransformer:
|
||||
@classmethod
|
||||
def transform_tool_invoke_messages(
|
||||
cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str | None
|
||||
) -> list[ToolInvokeMessage]:
|
||||
cls,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Transform tool message and handle file download
|
||||
"""
|
||||
result = []
|
||||
|
||||
for message in messages:
|
||||
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
|
||||
result.append(message)
|
||||
elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(message.message, str):
|
||||
yield message
|
||||
elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(
|
||||
message.message, ToolInvokeMessage.TextMessage
|
||||
):
|
||||
# try to download image
|
||||
try:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
file = ToolFileManager.create_file_by_url(
|
||||
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_url=message.message
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
file_url=message.message.text,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}"
|
||||
|
||||
result.append(
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to download image from {url}")
|
||||
result.append(
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=f"Failed to download image: {message.message}, please try to download it manually.",
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
save_as=message.save_as,
|
||||
)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=ToolInvokeMessage.TextMessage(
|
||||
text=f"Failed to download image: {message.message.text}: {e}"
|
||||
),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get mime type and save blob to storage
|
||||
assert message.meta is not None
|
||||
mimetype = message.meta.get("mime_type", "octet/stream")
|
||||
meta = message.meta or {}
|
||||
|
||||
mimetype = meta.get("mime_type", "octet/stream")
|
||||
# if message is str, encode it to bytes
|
||||
if isinstance(message.message, str):
|
||||
message.message = message.message.encode("utf-8")
|
||||
|
||||
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
|
||||
raise ValueError("unexpected message type")
|
||||
|
||||
# FIXME: should do a type check here.
|
||||
assert isinstance(message.message, bytes)
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
file = ToolFileManager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_binary=message.message,
|
||||
file_binary=message.message.blob,
|
||||
mimetype=mimetype,
|
||||
)
|
||||
|
||||
@@ -71,54 +78,40 @@ class ToolFileMessageTransformer:
|
||||
|
||||
# check if file is image
|
||||
if "image" in mimetype:
|
||||
result.append(
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
result.append(
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
file_mata = message.meta.get("file")
|
||||
if isinstance(file_mata, File):
|
||||
if file_mata.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert file_mata.related_id is not None
|
||||
url = cls.get_tool_file_url(tool_file_id=file_mata.related_id, extension=file_mata.extension)
|
||||
if file_mata.type == FileType.IMAGE:
|
||||
result.append(
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
meta = message.meta or {}
|
||||
file = meta.get("file", None)
|
||||
if isinstance(file, File):
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert file.related_id is not None
|
||||
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
|
||||
if file.type == FileType.IMAGE:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
result.append(
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
result.append(message)
|
||||
yield message
|
||||
else:
|
||||
result.append(message)
|
||||
|
||||
return result
|
||||
yield message
|
||||
|
||||
@classmethod
|
||||
def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Optional, cast
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
@@ -18,7 +18,7 @@ from core.model_runtime.errors.invoke import (
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ToolModelInvoke
|
||||
|
||||
@@ -18,7 +18,7 @@ from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolPro
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(
|
||||
openapi: dict, extra_info: Optional[dict], warning: Optional[dict]
|
||||
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
@@ -191,7 +191,7 @@ class ApiBasedToolSchemaParser:
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(
|
||||
yaml: str, extra_info: Optional[dict], warning: Optional[dict]
|
||||
yaml: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi yaml to tool bundle
|
||||
@@ -208,7 +208,8 @@ class ApiBasedToolSchemaParser:
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: Optional[dict], warning: Optional[dict]) -> dict:
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
|
||||
@@ -271,7 +272,7 @@ class ApiBasedToolSchemaParser:
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(
|
||||
json: str, extra_info: Optional[dict], warning: Optional[dict]
|
||||
json: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi plugin yaml to tool bundle
|
||||
@@ -287,7 +288,7 @@ class ApiBasedToolSchemaParser:
|
||||
api = openai_plugin["api"]
|
||||
api_url = api["url"]
|
||||
api_type = api["type"]
|
||||
except:
|
||||
except JSONDecodeError:
|
||||
raise ToolProviderNotFoundError("Invalid openai plugin json.")
|
||||
|
||||
if api_type != "openapi":
|
||||
@@ -305,7 +306,7 @@ class ApiBasedToolSchemaParser:
|
||||
|
||||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(
|
||||
content: str, extra_info: Optional[dict] = None, warning: Optional[dict] = None
|
||||
content: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
|
||||
17
api/core/tools/utils/rag_web_reader.py
Normal file
17
api/core/tools/utils/rag_web_reader.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import re
|
||||
|
||||
|
||||
def get_image_upload_file_ids(content):
|
||||
pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
|
||||
matches = re.findall(pattern, content)
|
||||
image_upload_file_ids = []
|
||||
for match in matches:
|
||||
if match[1] == "file-preview":
|
||||
content_pattern = r"files/([^/]+)/file-preview"
|
||||
else:
|
||||
content_pattern = r"files/([^/]+)/image-preview"
|
||||
content_match = re.search(content_pattern, match[0])
|
||||
if content_match:
|
||||
image_upload_file_id = content_match.group(1)
|
||||
image_upload_file_ids.append(image_upload_file_id)
|
||||
return image_upload_file_ids
|
||||
@@ -27,7 +27,7 @@ class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
) -> bool:
|
||||
):
|
||||
"""
|
||||
check is synced
|
||||
|
||||
@@ -41,5 +41,3 @@ class WorkflowToolConfigurationUtils:
|
||||
for parameter in tool_configurations:
|
||||
if parameter.name not in variable_names:
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user