perf(api): optimize tool provider list API with Redis caching (#29101)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
56
api/core/helper/tool_provider_cache.py
Normal file
56
api/core/helper/tool_provider_cache.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
from extensions.ext_redis import redis_client, redis_fallback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolProviderListCache:
|
||||
"""Cache for tool provider lists"""
|
||||
|
||||
CACHE_TTL = 300 # 5 minutes
|
||||
|
||||
@staticmethod
|
||||
def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str:
|
||||
"""Generate cache key for tool providers list"""
|
||||
type_filter = typ or "all"
|
||||
return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}"
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None:
|
||||
"""Get cached tool providers"""
|
||||
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
|
||||
cached_data = redis_client.get(cache_key)
|
||||
if cached_data:
|
||||
try:
|
||||
return json.loads(cached_data.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
logger.warning("Failed to decode cached tool providers data")
|
||||
return None
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback()
|
||||
def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]):
|
||||
"""Cache tool providers"""
|
||||
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
|
||||
redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers))
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback()
|
||||
def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
|
||||
"""Invalidate cache for tool providers"""
|
||||
if typ:
|
||||
# Invalidate specific type cache
|
||||
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
|
||||
redis_client.delete(cache_key)
|
||||
else:
|
||||
# Invalidate all caches for this tenant
|
||||
pattern = f"tool_providers:tenant_id:{tenant_id}:*"
|
||||
keys = list(redis_client.scan_iter(pattern))
|
||||
if keys:
|
||||
redis_client.delete(*keys)
|
||||
@@ -5,7 +5,7 @@ import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
@@ -67,6 +67,11 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApiProviderControllerItem(TypedDict):
|
||||
provider: ApiToolProvider
|
||||
controller: ApiToolProviderController
|
||||
|
||||
|
||||
class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
|
||||
@@ -655,9 +660,10 @@ class ToolManager:
|
||||
else:
|
||||
filters.append(typ)
|
||||
|
||||
with db.session.no_autoflush:
|
||||
# Use a single session for all database operations to reduce connection overhead
|
||||
with Session(db.engine) as session:
|
||||
if "builtin" in filters:
|
||||
builtin_providers = cls.list_builtin_providers(tenant_id)
|
||||
builtin_providers = list(cls.list_builtin_providers(tenant_id))
|
||||
|
||||
# key: provider name, value: provider
|
||||
db_builtin_providers = {
|
||||
@@ -688,57 +694,74 @@ class ToolManager:
|
||||
|
||||
# get db api providers
|
||||
if "api" in filters:
|
||||
db_api_providers = db.session.scalars(
|
||||
db_api_providers = session.scalars(
|
||||
select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
api_provider_controllers: list[dict[str, Any]] = [
|
||||
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
||||
for provider in db_api_providers
|
||||
]
|
||||
# Batch create controllers
|
||||
api_provider_controllers: list[ApiProviderControllerItem] = []
|
||||
for api_provider in db_api_providers:
|
||||
try:
|
||||
controller = ToolTransformService.api_provider_to_controller(api_provider)
|
||||
api_provider_controllers.append({"provider": api_provider, "controller": controller})
|
||||
except Exception:
|
||||
# Skip invalid providers but continue processing others
|
||||
logger.warning("Failed to create controller for API provider %s", api_provider.id)
|
||||
|
||||
# get labels
|
||||
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
|
||||
|
||||
for api_provider_controller in api_provider_controllers:
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=api_provider_controller["controller"],
|
||||
db_provider=api_provider_controller["provider"],
|
||||
decrypt_credentials=False,
|
||||
labels=labels.get(api_provider_controller["controller"].provider_id, []),
|
||||
# Batch get labels for all API providers
|
||||
if api_provider_controllers:
|
||||
controllers = cast(
|
||||
list[ToolProviderController], [item["controller"] for item in api_provider_controllers]
|
||||
)
|
||||
result_providers[f"api_provider.{user_provider.name}"] = user_provider
|
||||
labels = ToolLabelManager.get_tools_labels(controllers)
|
||||
|
||||
for item in api_provider_controllers:
|
||||
provider_controller = item["controller"]
|
||||
db_provider = item["provider"]
|
||||
provider_labels = labels.get(provider_controller.provider_id, [])
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=db_provider,
|
||||
decrypt_credentials=False,
|
||||
labels=provider_labels,
|
||||
)
|
||||
result_providers[f"api_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
if "workflow" in filters:
|
||||
# get workflow providers
|
||||
workflow_providers = db.session.scalars(
|
||||
workflow_providers = session.scalars(
|
||||
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||
for workflow_provider in workflow_providers:
|
||||
try:
|
||||
workflow_provider_controllers.append(
|
||||
workflow_controller: WorkflowToolProviderController = (
|
||||
ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||
)
|
||||
workflow_provider_controllers.append(workflow_controller)
|
||||
except Exception:
|
||||
# app has been deleted
|
||||
logger.exception("Failed to transform workflow provider %s to controller", workflow_provider.id)
|
||||
continue
|
||||
# Batch get labels for workflow providers
|
||||
if workflow_provider_controllers:
|
||||
workflow_controllers: list[ToolProviderController] = [
|
||||
cast(ToolProviderController, controller) for controller in workflow_provider_controllers
|
||||
]
|
||||
labels = ToolLabelManager.get_tools_labels(workflow_controllers)
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(
|
||||
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
|
||||
)
|
||||
for workflow_provider_controller in workflow_provider_controllers:
|
||||
provider_labels = labels.get(workflow_provider_controller.provider_id, [])
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=workflow_provider_controller,
|
||||
labels=provider_labels,
|
||||
)
|
||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
for provider_controller in workflow_provider_controllers:
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
labels=labels.get(provider_controller.provider_id, []),
|
||||
)
|
||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||
if "mcp" in filters:
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
|
||||
for mcp_provider in mcp_providers:
|
||||
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
|
||||
|
||||
|
||||
Reference in New Issue
Block a user