feat: add api-based extension & external data tool & moderation backend (#1403)
Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
98
api/services/api_based_extension_service.py
Normal file
98
api/services/api_based_extension_service.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from extensions.ext_database import db
|
||||
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
from core.helper.encrypter import encrypt_token, decrypt_token
|
||||
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
|
||||
|
||||
|
||||
class APIBasedExtensionService:
|
||||
|
||||
@staticmethod
|
||||
def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
|
||||
extension_list = db.session.query(APIBasedExtension) \
|
||||
.filter_by(tenant_id=tenant_id) \
|
||||
.order_by(APIBasedExtension.created_at.desc()) \
|
||||
.all()
|
||||
|
||||
for extension in extension_list:
|
||||
extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)
|
||||
|
||||
return extension_list
|
||||
|
||||
@classmethod
|
||||
def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension:
|
||||
cls._validation(extension_data)
|
||||
|
||||
extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key)
|
||||
|
||||
db.session.add(extension_data)
|
||||
db.session.commit()
|
||||
return extension_data
|
||||
|
||||
@staticmethod
|
||||
def delete(extension_data: APIBasedExtension) -> None:
|
||||
db.session.delete(extension_data)
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
|
||||
extension = db.session.query(APIBasedExtension) \
|
||||
.filter_by(tenant_id=tenant_id) \
|
||||
.filter_by(id=api_based_extension_id) \
|
||||
.first()
|
||||
|
||||
if not extension:
|
||||
raise ValueError("API based extension is not found")
|
||||
|
||||
extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)
|
||||
|
||||
return extension
|
||||
|
||||
@classmethod
|
||||
def _validation(cls, extension_data: APIBasedExtension) -> None:
|
||||
# name
|
||||
if not extension_data.name:
|
||||
raise ValueError("name must not be empty")
|
||||
|
||||
if not extension_data.id:
|
||||
# case one: check new data, name must be unique
|
||||
is_name_existed = db.session.query(APIBasedExtension) \
|
||||
.filter_by(tenant_id=extension_data.tenant_id) \
|
||||
.filter_by(name=extension_data.name) \
|
||||
.first()
|
||||
|
||||
if is_name_existed:
|
||||
raise ValueError("name must be unique, it is already existed")
|
||||
else:
|
||||
# case two: check existing data, name must be unique
|
||||
is_name_existed = db.session.query(APIBasedExtension) \
|
||||
.filter_by(tenant_id=extension_data.tenant_id) \
|
||||
.filter_by(name=extension_data.name) \
|
||||
.filter(APIBasedExtension.id != extension_data.id) \
|
||||
.first()
|
||||
|
||||
if is_name_existed:
|
||||
raise ValueError("name must be unique, it is already existed")
|
||||
|
||||
# api_endpoint
|
||||
if not extension_data.api_endpoint:
|
||||
raise ValueError("api_endpoint must not be empty")
|
||||
|
||||
# api_key
|
||||
if not extension_data.api_key:
|
||||
raise ValueError("api_key must not be empty")
|
||||
|
||||
if len(extension_data.api_key) < 5:
|
||||
raise ValueError("api_key must be at least 5 characters")
|
||||
|
||||
# check endpoint
|
||||
cls._ping_connection(extension_data)
|
||||
|
||||
@staticmethod
|
||||
def _ping_connection(extension_data: APIBasedExtension) -> None:
|
||||
try:
|
||||
client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
|
||||
resp = client.request(point=APIBasedExtensionPoint.PING, params={})
|
||||
if resp.get('result') != 'pong':
|
||||
raise ValueError(resp)
|
||||
except Exception as e:
|
||||
raise ValueError("connection error: {}".format(e))
|
||||
@@ -1,6 +1,8 @@
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from core.prompt.prompt_transform import AppMode
|
||||
from core.agent.agent_executor import PlanningStrategy
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
@@ -13,8 +15,8 @@ SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current
|
||||
|
||||
|
||||
class AppModelConfigService:
|
||||
@staticmethod
|
||||
def is_dataset_exists(account: Account, dataset_id: str) -> bool:
|
||||
@classmethod
|
||||
def is_dataset_exists(cls, account: Account, dataset_id: str) -> bool:
|
||||
# verify if the dataset ID exists
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
@@ -26,8 +28,8 @@ class AppModelConfigService:
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_model_completion_params(cp: dict, model_name: str) -> dict:
|
||||
@classmethod
|
||||
def validate_model_completion_params(cls, cp: dict, model_name: str) -> dict:
|
||||
# 6. model.completion_params
|
||||
if not isinstance(cp, dict):
|
||||
raise ValueError("model.completion_params must be of object type")
|
||||
@@ -57,7 +59,7 @@ class AppModelConfigService:
|
||||
cp["stop"] = []
|
||||
elif not isinstance(cp["stop"], list):
|
||||
raise ValueError("stop in model.completion_params must be of list type")
|
||||
|
||||
|
||||
if len(cp["stop"]) > 4:
|
||||
raise ValueError("stop sequences must be less than 4")
|
||||
|
||||
@@ -73,8 +75,8 @@ class AppModelConfigService:
|
||||
|
||||
return filtered_cp
|
||||
|
||||
@staticmethod
|
||||
def validate_configuration(tenant_id: str, account: Account, config: dict, mode: str) -> dict:
|
||||
@classmethod
|
||||
def validate_configuration(cls, tenant_id: str, account: Account, config: dict, mode: str) -> dict:
|
||||
# opening_statement
|
||||
if 'opening_statement' not in config or not config["opening_statement"]:
|
||||
config["opening_statement"] = ""
|
||||
@@ -153,33 +155,6 @@ class AppModelConfigService:
|
||||
if not isinstance(config["more_like_this"]["enabled"], bool):
|
||||
raise ValueError("enabled in more_like_this must be of boolean type")
|
||||
|
||||
# sensitive_word_avoidance
|
||||
if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]:
|
||||
config["sensitive_word_avoidance"] = {
|
||||
"enabled": False
|
||||
}
|
||||
|
||||
if not isinstance(config["sensitive_word_avoidance"], dict):
|
||||
raise ValueError("sensitive_word_avoidance must be of dict type")
|
||||
|
||||
if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
|
||||
config["sensitive_word_avoidance"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["sensitive_word_avoidance"]["enabled"], bool):
|
||||
raise ValueError("enabled in sensitive_word_avoidance must be of boolean type")
|
||||
|
||||
if "words" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["words"]:
|
||||
config["sensitive_word_avoidance"]["words"] = ""
|
||||
|
||||
if not isinstance(config["sensitive_word_avoidance"]["words"], str):
|
||||
raise ValueError("words in sensitive_word_avoidance must be of string type")
|
||||
|
||||
if "canned_response" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["canned_response"]:
|
||||
config["sensitive_word_avoidance"]["canned_response"] = ""
|
||||
|
||||
if not isinstance(config["sensitive_word_avoidance"]["canned_response"], str):
|
||||
raise ValueError("canned_response in sensitive_word_avoidance must be of string type")
|
||||
|
||||
# model
|
||||
if 'model' not in config:
|
||||
raise ValueError("model is required")
|
||||
@@ -204,7 +179,7 @@ class AppModelConfigService:
|
||||
model_ids = [m['id'] for m in model_list]
|
||||
if config["model"]["name"] not in model_ids:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
|
||||
# model.mode
|
||||
if 'mode' not in config['model'] or not config['model']["mode"]:
|
||||
config['model']["mode"] = ""
|
||||
@@ -213,7 +188,7 @@ class AppModelConfigService:
|
||||
if 'completion_params' not in config["model"]:
|
||||
raise ValueError("model.completion_params is required")
|
||||
|
||||
config["model"]["completion_params"] = AppModelConfigService.validate_model_completion_params(
|
||||
config["model"]["completion_params"] = cls.validate_model_completion_params(
|
||||
config["model"]["completion_params"],
|
||||
config["model"]["name"]
|
||||
)
|
||||
@@ -330,14 +305,20 @@ class AppModelConfigService:
|
||||
except ValueError:
|
||||
raise ValueError("id in dataset must be of UUID type")
|
||||
|
||||
if not AppModelConfigService.is_dataset_exists(account, tool_item["id"]):
|
||||
if not cls.is_dataset_exists(account, tool_item["id"]):
|
||||
raise ValueError("Dataset ID does not exist, please check your permission.")
|
||||
|
||||
|
||||
# dataset_query_variable
|
||||
AppModelConfigService.is_dataset_query_variable_valid(config, mode)
|
||||
cls.is_dataset_query_variable_valid(config, mode)
|
||||
|
||||
# advanced prompt validation
|
||||
AppModelConfigService.is_advanced_prompt_valid(config, mode)
|
||||
cls.is_advanced_prompt_valid(config, mode)
|
||||
|
||||
# external data tools validation
|
||||
cls.is_external_data_tools_valid(tenant_id, config)
|
||||
|
||||
# moderation validation
|
||||
cls.is_moderation_valid(tenant_id, config)
|
||||
|
||||
# Filter out extra parameters
|
||||
filtered_config = {
|
||||
@@ -348,6 +329,7 @@ class AppModelConfigService:
|
||||
"retriever_resource": config["retriever_resource"],
|
||||
"more_like_this": config["more_like_this"],
|
||||
"sensitive_word_avoidance": config["sensitive_word_avoidance"],
|
||||
"external_data_tools": config["external_data_tools"],
|
||||
"model": {
|
||||
"provider": config["model"]["provider"],
|
||||
"name": config["model"]["name"],
|
||||
@@ -365,32 +347,86 @@ class AppModelConfigService:
|
||||
}
|
||||
|
||||
return filtered_config
|
||||
|
||||
@staticmethod
|
||||
def is_dataset_query_variable_valid(config: dict, mode: str) -> None:
|
||||
|
||||
@classmethod
|
||||
def is_moderation_valid(cls, tenant_id: str, config: dict):
|
||||
if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]:
|
||||
config["sensitive_word_avoidance"] = {
|
||||
"enabled": False
|
||||
}
|
||||
|
||||
if not isinstance(config["sensitive_word_avoidance"], dict):
|
||||
raise ValueError("sensitive_word_avoidance must be of dict type")
|
||||
|
||||
if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
|
||||
config["sensitive_word_avoidance"]["enabled"] = False
|
||||
|
||||
if not config["sensitive_word_avoidance"]["enabled"]:
|
||||
return
|
||||
|
||||
if "type" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["type"]:
|
||||
raise ValueError("sensitive_word_avoidance.type is required")
|
||||
|
||||
type = config["sensitive_word_avoidance"]["type"]
|
||||
config = config["sensitive_word_avoidance"]["config"]
|
||||
|
||||
ModerationFactory.validate_config(
|
||||
name=type,
|
||||
tenant_id=tenant_id,
|
||||
config=config
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_external_data_tools_valid(cls, tenant_id: str, config: dict):
|
||||
if 'external_data_tools' not in config or not config["external_data_tools"]:
|
||||
config["external_data_tools"] = []
|
||||
|
||||
if not isinstance(config["external_data_tools"], list):
|
||||
raise ValueError("external_data_tools must be of list type")
|
||||
|
||||
for tool in config["external_data_tools"]:
|
||||
if "enabled" not in tool or not tool["enabled"]:
|
||||
tool["enabled"] = False
|
||||
|
||||
if not tool["enabled"]:
|
||||
continue
|
||||
|
||||
if "type" not in tool or not tool["type"]:
|
||||
raise ValueError("external_data_tools[].type is required")
|
||||
|
||||
type = tool["type"]
|
||||
config = tool["config"]
|
||||
|
||||
ExternalDataToolFactory.validate_config(
|
||||
name=type,
|
||||
tenant_id=tenant_id,
|
||||
config=config
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None:
|
||||
# Only check when mode is completion
|
||||
if mode != 'completion':
|
||||
return
|
||||
|
||||
|
||||
agent_mode = config.get("agent_mode", {})
|
||||
tools = agent_mode.get("tools", [])
|
||||
dataset_exists = "dataset" in str(tools)
|
||||
|
||||
|
||||
dataset_query_variable = config.get("dataset_query_variable")
|
||||
|
||||
if dataset_exists and not dataset_query_variable:
|
||||
raise ValueError("Dataset query variable is required when dataset is exist")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:
|
||||
@classmethod
|
||||
def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None:
|
||||
# prompt_type
|
||||
if 'prompt_type' not in config or not config["prompt_type"]:
|
||||
config["prompt_type"] = "simple"
|
||||
|
||||
if config['prompt_type'] not in ['simple', 'advanced']:
|
||||
raise ValueError("prompt_type must be in ['simple', 'advanced']")
|
||||
|
||||
|
||||
# chat_prompt_config
|
||||
if 'chat_prompt_config' not in config or not config["chat_prompt_config"]:
|
||||
config["chat_prompt_config"] = {}
|
||||
@@ -404,7 +440,7 @@ class AppModelConfigService:
|
||||
|
||||
if not isinstance(config["completion_prompt_config"], dict):
|
||||
raise ValueError("completion_prompt_config must be of object type")
|
||||
|
||||
|
||||
# dataset_configs
|
||||
if 'dataset_configs' not in config or not config["dataset_configs"]:
|
||||
config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
|
||||
@@ -415,10 +451,10 @@ class AppModelConfigService:
|
||||
if config['prompt_type'] == 'advanced':
|
||||
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
|
||||
raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced")
|
||||
|
||||
|
||||
if config['model']["mode"] not in ['chat', 'completion']:
|
||||
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
|
||||
|
||||
|
||||
if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value:
|
||||
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
|
||||
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
|
||||
@@ -429,9 +465,8 @@ class AppModelConfigService:
|
||||
if not assistant_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
|
||||
|
||||
|
||||
if config['model']["mode"] == ModelMode.CHAT.value:
|
||||
prompt_list = config['chat_prompt_config']['prompt']
|
||||
|
||||
if len(prompt_list) > 10:
|
||||
raise ValueError("prompt messages must be less than 10")
|
||||
raise ValueError("prompt messages must be less than 10")
|
||||
|
||||
13
api/services/code_based_extension_service.py
Normal file
13
api/services/code_based_extension_service.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from extensions.ext_code_based_extension import code_based_extension
|
||||
|
||||
|
||||
class CodeBasedExtensionService:
|
||||
|
||||
@staticmethod
|
||||
def get_code_based_extension(module: str) -> list[dict]:
|
||||
module_extensions = code_based_extension.module_extensions(module)
|
||||
return [{
|
||||
'name': module_extension.name,
|
||||
'label': module_extension.label,
|
||||
'form_schema': module_extension.form_schema
|
||||
} for module_extension in module_extensions if not module_extension.builtin]
|
||||
@@ -10,7 +10,8 @@ from redis.client import PubSub
|
||||
from sqlalchemy import and_
|
||||
|
||||
from core.completion import Completion
|
||||
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException
|
||||
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
|
||||
ConversationTaskInterruptException
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||
LLMRateLimitError, \
|
||||
LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
@@ -28,9 +29,9 @@ from services.errors.message import MessageNotExistsError
|
||||
class CompletionService:
|
||||
|
||||
@classmethod
|
||||
def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any,
|
||||
def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
|
||||
from_source: str, streaming: bool = True,
|
||||
is_model_config_override: bool = False) -> Union[dict | Generator]:
|
||||
is_model_config_override: bool = False) -> Union[dict, Generator]:
|
||||
# is streaming mode
|
||||
inputs = args['inputs']
|
||||
query = args['query']
|
||||
@@ -199,9 +200,9 @@ class CompletionService:
|
||||
is_override=is_model_config_override,
|
||||
retriever_from=retriever_from
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
|
||||
pass
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
except (ValueError, LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
|
||||
ModelCurrentlyNotSupportError) as e:
|
||||
PubHandler.pub_error(user, generate_task_id, e)
|
||||
@@ -234,7 +235,7 @@ class CompletionService:
|
||||
PubHandler.stop(user, generate_task_id)
|
||||
try:
|
||||
pubsub.close()
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
countdown_thread = threading.Thread(target=close_pubsub)
|
||||
@@ -243,9 +244,9 @@ class CompletionService:
|
||||
return countdown_thread
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser],
|
||||
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
|
||||
message_id: str, streaming: bool = True,
|
||||
retriever_from: str = 'dev') -> Union[dict | Generator]:
|
||||
retriever_from: str = 'dev') -> Union[dict, Generator]:
|
||||
if not user:
|
||||
raise ValueError('user cannot be None')
|
||||
|
||||
@@ -341,7 +342,7 @@ class CompletionService:
|
||||
return filtered_inputs
|
||||
|
||||
@classmethod
|
||||
def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict | Generator]:
|
||||
def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict, Generator]:
|
||||
generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')
|
||||
if not streaming:
|
||||
try:
|
||||
@@ -386,6 +387,8 @@ class CompletionService:
|
||||
break
|
||||
if event == 'message':
|
||||
yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
|
||||
elif event == 'message_replace':
|
||||
yield "data: " + json.dumps(cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
|
||||
elif event == 'chain':
|
||||
yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
|
||||
elif event == 'agent_thought':
|
||||
@@ -427,6 +430,21 @@ class CompletionService:
|
||||
|
||||
return response_data
|
||||
|
||||
@classmethod
|
||||
def get_message_replace_response_data(cls, data: dict):
|
||||
response_data = {
|
||||
'event': 'message_replace',
|
||||
'task_id': data.get('task_id'),
|
||||
'id': data.get('message_id'),
|
||||
'answer': data.get('text'),
|
||||
'created_at': int(time.time())
|
||||
}
|
||||
|
||||
if data.get('mode') == 'chat':
|
||||
response_data['conversation_id'] = data.get('conversation_id')
|
||||
|
||||
return response_data
|
||||
|
||||
@classmethod
|
||||
def get_blocking_message_response_data(cls, data: dict):
|
||||
message = data.get('message')
|
||||
@@ -508,6 +526,7 @@ class CompletionService:
|
||||
|
||||
# handle errors
|
||||
llm_errors = {
|
||||
'ValueError': LLMBadRequestError,
|
||||
'LLMBadRequestError': LLMBadRequestError,
|
||||
'LLMAPIConnectionError': LLMAPIConnectionError,
|
||||
'LLMAPIUnavailableError': LLMAPIUnavailableError,
|
||||
|
||||
20
api/services/moderation_service.py
Normal file
20
api/services/moderation_service.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from models.model import AppModelConfig, App
|
||||
from core.moderation.factory import ModerationFactory, ModerationOutputsResult
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
||||
class ModerationService:
|
||||
|
||||
def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
|
||||
app_model_config: AppModelConfig = None
|
||||
|
||||
app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
|
||||
if not app_model_config:
|
||||
raise ValueError("app model config not found")
|
||||
|
||||
name = app_model_config.sensitive_word_avoidance_dict['type']
|
||||
config = app_model_config.sensitive_word_avoidance_dict['config']
|
||||
|
||||
moderation = ModerationFactory(name, app_id, app_model.tenant_id, config)
|
||||
return moderation.moderation_for_outputs(text)
|
||||
Reference in New Issue
Block a user