feat: add api-based extension & external data tool & moderation backend (#1403)

Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
Garfield Dai
2023-11-06 19:36:16 +08:00
committed by GitHub
parent 7699621983
commit db43ed6f41
50 changed files with 1624 additions and 273 deletions

View File

@@ -0,0 +1,98 @@
from extensions.ext_database import db
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from core.helper.encrypter import encrypt_token, decrypt_token
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
class APIBasedExtensionService:
@staticmethod
def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
extension_list = db.session.query(APIBasedExtension) \
.filter_by(tenant_id=tenant_id) \
.order_by(APIBasedExtension.created_at.desc()) \
.all()
for extension in extension_list:
extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)
return extension_list
@classmethod
def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension:
cls._validation(extension_data)
extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key)
db.session.add(extension_data)
db.session.commit()
return extension_data
@staticmethod
def delete(extension_data: APIBasedExtension) -> None:
db.session.delete(extension_data)
db.session.commit()
@staticmethod
def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
extension = db.session.query(APIBasedExtension) \
.filter_by(tenant_id=tenant_id) \
.filter_by(id=api_based_extension_id) \
.first()
if not extension:
raise ValueError("API based extension is not found")
extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)
return extension
@classmethod
def _validation(cls, extension_data: APIBasedExtension) -> None:
# name
if not extension_data.name:
raise ValueError("name must not be empty")
if not extension_data.id:
# case one: check new data, name must be unique
is_name_existed = db.session.query(APIBasedExtension) \
.filter_by(tenant_id=extension_data.tenant_id) \
.filter_by(name=extension_data.name) \
.first()
if is_name_existed:
raise ValueError("name must be unique, it is already existed")
else:
# case two: check existing data, name must be unique
is_name_existed = db.session.query(APIBasedExtension) \
.filter_by(tenant_id=extension_data.tenant_id) \
.filter_by(name=extension_data.name) \
.filter(APIBasedExtension.id != extension_data.id) \
.first()
if is_name_existed:
raise ValueError("name must be unique, it is already existed")
# api_endpoint
if not extension_data.api_endpoint:
raise ValueError("api_endpoint must not be empty")
# api_key
if not extension_data.api_key:
raise ValueError("api_key must not be empty")
if len(extension_data.api_key) < 5:
raise ValueError("api_key must be at least 5 characters")
# check endpoint
cls._ping_connection(extension_data)
@staticmethod
def _ping_connection(extension_data: APIBasedExtension) -> None:
try:
client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
resp = client.request(point=APIBasedExtensionPoint.PING, params={})
if resp.get('result') != 'pong':
raise ValueError(resp)
except Exception as e:
raise ValueError("connection error: {}".format(e))

View File

@@ -1,6 +1,8 @@
import re
import uuid
from core.external_data_tool.factory import ExternalDataToolFactory
from core.moderation.factory import ModerationFactory
from core.prompt.prompt_transform import AppMode
from core.agent.agent_executor import PlanningStrategy
from core.model_providers.model_provider_factory import ModelProviderFactory
@@ -13,8 +15,8 @@ SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current
class AppModelConfigService:
@staticmethod
def is_dataset_exists(account: Account, dataset_id: str) -> bool:
@classmethod
def is_dataset_exists(cls, account: Account, dataset_id: str) -> bool:
# verify if the dataset ID exists
dataset = DatasetService.get_dataset(dataset_id)
@@ -26,8 +28,8 @@ class AppModelConfigService:
return True
@staticmethod
def validate_model_completion_params(cp: dict, model_name: str) -> dict:
@classmethod
def validate_model_completion_params(cls, cp: dict, model_name: str) -> dict:
# 6. model.completion_params
if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type")
@@ -57,7 +59,7 @@ class AppModelConfigService:
cp["stop"] = []
elif not isinstance(cp["stop"], list):
raise ValueError("stop in model.completion_params must be of list type")
if len(cp["stop"]) > 4:
raise ValueError("stop sequences must be less than 4")
@@ -73,8 +75,8 @@ class AppModelConfigService:
return filtered_cp
@staticmethod
def validate_configuration(tenant_id: str, account: Account, config: dict, mode: str) -> dict:
@classmethod
def validate_configuration(cls, tenant_id: str, account: Account, config: dict, mode: str) -> dict:
# opening_statement
if 'opening_statement' not in config or not config["opening_statement"]:
config["opening_statement"] = ""
@@ -153,33 +155,6 @@ class AppModelConfigService:
if not isinstance(config["more_like_this"]["enabled"], bool):
raise ValueError("enabled in more_like_this must be of boolean type")
# sensitive_word_avoidance
if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]:
config["sensitive_word_avoidance"] = {
"enabled": False
}
if not isinstance(config["sensitive_word_avoidance"], dict):
raise ValueError("sensitive_word_avoidance must be of dict type")
if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
config["sensitive_word_avoidance"]["enabled"] = False
if not isinstance(config["sensitive_word_avoidance"]["enabled"], bool):
raise ValueError("enabled in sensitive_word_avoidance must be of boolean type")
if "words" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["words"]:
config["sensitive_word_avoidance"]["words"] = ""
if not isinstance(config["sensitive_word_avoidance"]["words"], str):
raise ValueError("words in sensitive_word_avoidance must be of string type")
if "canned_response" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["canned_response"]:
config["sensitive_word_avoidance"]["canned_response"] = ""
if not isinstance(config["sensitive_word_avoidance"]["canned_response"], str):
raise ValueError("canned_response in sensitive_word_avoidance must be of string type")
# model
if 'model' not in config:
raise ValueError("model is required")
@@ -204,7 +179,7 @@ class AppModelConfigService:
model_ids = [m['id'] for m in model_list]
if config["model"]["name"] not in model_ids:
raise ValueError("model.name must be in the specified model list")
# model.mode
if 'mode' not in config['model'] or not config['model']["mode"]:
config['model']["mode"] = ""
@@ -213,7 +188,7 @@ class AppModelConfigService:
if 'completion_params' not in config["model"]:
raise ValueError("model.completion_params is required")
config["model"]["completion_params"] = AppModelConfigService.validate_model_completion_params(
config["model"]["completion_params"] = cls.validate_model_completion_params(
config["model"]["completion_params"],
config["model"]["name"]
)
@@ -330,14 +305,20 @@ class AppModelConfigService:
except ValueError:
raise ValueError("id in dataset must be of UUID type")
if not AppModelConfigService.is_dataset_exists(account, tool_item["id"]):
if not cls.is_dataset_exists(account, tool_item["id"]):
raise ValueError("Dataset ID does not exist, please check your permission.")
# dataset_query_variable
AppModelConfigService.is_dataset_query_variable_valid(config, mode)
cls.is_dataset_query_variable_valid(config, mode)
# advanced prompt validation
AppModelConfigService.is_advanced_prompt_valid(config, mode)
cls.is_advanced_prompt_valid(config, mode)
# external data tools validation
cls.is_external_data_tools_valid(tenant_id, config)
# moderation validation
cls.is_moderation_valid(tenant_id, config)
# Filter out extra parameters
filtered_config = {
@@ -348,6 +329,7 @@ class AppModelConfigService:
"retriever_resource": config["retriever_resource"],
"more_like_this": config["more_like_this"],
"sensitive_word_avoidance": config["sensitive_word_avoidance"],
"external_data_tools": config["external_data_tools"],
"model": {
"provider": config["model"]["provider"],
"name": config["model"]["name"],
@@ -365,32 +347,86 @@ class AppModelConfigService:
}
return filtered_config
@staticmethod
def is_dataset_query_variable_valid(config: dict, mode: str) -> None:
@classmethod
def is_moderation_valid(cls, tenant_id: str, config: dict):
if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]:
config["sensitive_word_avoidance"] = {
"enabled": False
}
if not isinstance(config["sensitive_word_avoidance"], dict):
raise ValueError("sensitive_word_avoidance must be of dict type")
if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
config["sensitive_word_avoidance"]["enabled"] = False
if not config["sensitive_word_avoidance"]["enabled"]:
return
if "type" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["type"]:
raise ValueError("sensitive_word_avoidance.type is required")
type = config["sensitive_word_avoidance"]["type"]
config = config["sensitive_word_avoidance"]["config"]
ModerationFactory.validate_config(
name=type,
tenant_id=tenant_id,
config=config
)
@classmethod
def is_external_data_tools_valid(cls, tenant_id: str, config: dict):
if 'external_data_tools' not in config or not config["external_data_tools"]:
config["external_data_tools"] = []
if not isinstance(config["external_data_tools"], list):
raise ValueError("external_data_tools must be of list type")
for tool in config["external_data_tools"]:
if "enabled" not in tool or not tool["enabled"]:
tool["enabled"] = False
if not tool["enabled"]:
continue
if "type" not in tool or not tool["type"]:
raise ValueError("external_data_tools[].type is required")
type = tool["type"]
config = tool["config"]
ExternalDataToolFactory.validate_config(
name=type,
tenant_id=tenant_id,
config=config
)
@classmethod
def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None:
# Only check when mode is completion
if mode != 'completion':
return
agent_mode = config.get("agent_mode", {})
tools = agent_mode.get("tools", [])
dataset_exists = "dataset" in str(tools)
dataset_query_variable = config.get("dataset_query_variable")
if dataset_exists and not dataset_query_variable:
raise ValueError("Dataset query variable is required when dataset is exist")
@staticmethod
def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:
@classmethod
def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None:
# prompt_type
if 'prompt_type' not in config or not config["prompt_type"]:
config["prompt_type"] = "simple"
if config['prompt_type'] not in ['simple', 'advanced']:
raise ValueError("prompt_type must be in ['simple', 'advanced']")
# chat_prompt_config
if 'chat_prompt_config' not in config or not config["chat_prompt_config"]:
config["chat_prompt_config"] = {}
@@ -404,7 +440,7 @@ class AppModelConfigService:
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")
# dataset_configs
if 'dataset_configs' not in config or not config["dataset_configs"]:
config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
@@ -415,10 +451,10 @@ class AppModelConfigService:
if config['prompt_type'] == 'advanced':
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced")
if config['model']["mode"] not in ['chat', 'completion']:
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value:
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
@@ -429,9 +465,8 @@ class AppModelConfigService:
if not assistant_prefix:
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
if config['model']["mode"] == ModelMode.CHAT.value:
prompt_list = config['chat_prompt_config']['prompt']
if len(prompt_list) > 10:
raise ValueError("prompt messages must be less than 10")
raise ValueError("prompt messages must be less than 10")

View File

@@ -0,0 +1,13 @@
from extensions.ext_code_based_extension import code_based_extension
class CodeBasedExtensionService:
@staticmethod
def get_code_based_extension(module: str) -> list[dict]:
module_extensions = code_based_extension.module_extensions(module)
return [{
'name': module_extension.name,
'label': module_extension.label,
'form_schema': module_extension.form_schema
} for module_extension in module_extensions if not module_extension.builtin]

View File

@@ -10,7 +10,8 @@ from redis.client import PubSub
from sqlalchemy import and_
from core.completion import Completion
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, \
LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
@@ -28,9 +29,9 @@ from services.errors.message import MessageNotExistsError
class CompletionService:
@classmethod
def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any,
def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
from_source: str, streaming: bool = True,
is_model_config_override: bool = False) -> Union[dict | Generator]:
is_model_config_override: bool = False) -> Union[dict, Generator]:
# is streaming mode
inputs = args['inputs']
query = args['query']
@@ -199,9 +200,9 @@ class CompletionService:
is_override=is_model_config_override,
retriever_from=retriever_from
)
except ConversationTaskStoppedException:
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
pass
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
except (ValueError, LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
ModelCurrentlyNotSupportError) as e:
PubHandler.pub_error(user, generate_task_id, e)
@@ -234,7 +235,7 @@ class CompletionService:
PubHandler.stop(user, generate_task_id)
try:
pubsub.close()
except:
except Exception:
pass
countdown_thread = threading.Thread(target=close_pubsub)
@@ -243,9 +244,9 @@ class CompletionService:
return countdown_thread
@classmethod
def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser],
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
message_id: str, streaming: bool = True,
retriever_from: str = 'dev') -> Union[dict | Generator]:
retriever_from: str = 'dev') -> Union[dict, Generator]:
if not user:
raise ValueError('user cannot be None')
@@ -341,7 +342,7 @@ class CompletionService:
return filtered_inputs
@classmethod
def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict | Generator]:
def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict, Generator]:
generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')
if not streaming:
try:
@@ -386,6 +387,8 @@ class CompletionService:
break
if event == 'message':
yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
elif event == 'message_replace':
yield "data: " + json.dumps(cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
elif event == 'chain':
yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
elif event == 'agent_thought':
@@ -427,6 +430,21 @@ class CompletionService:
return response_data
@classmethod
def get_message_replace_response_data(cls, data: dict):
response_data = {
'event': 'message_replace',
'task_id': data.get('task_id'),
'id': data.get('message_id'),
'answer': data.get('text'),
'created_at': int(time.time())
}
if data.get('mode') == 'chat':
response_data['conversation_id'] = data.get('conversation_id')
return response_data
@classmethod
def get_blocking_message_response_data(cls, data: dict):
message = data.get('message')
@@ -508,6 +526,7 @@ class CompletionService:
# handle errors
llm_errors = {
'ValueError': LLMBadRequestError,
'LLMBadRequestError': LLMBadRequestError,
'LLMAPIConnectionError': LLMAPIConnectionError,
'LLMAPIUnavailableError': LLMAPIUnavailableError,

View File

@@ -0,0 +1,20 @@
from models.model import AppModelConfig, App
from core.moderation.factory import ModerationFactory, ModerationOutputsResult
from extensions.ext_database import db
class ModerationService:
def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
app_model_config: AppModelConfig = None
app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
if not app_model_config:
raise ValueError("app model config not found")
name = app_model_config.sensitive_word_avoidance_dict['type']
config = app_model_config.sensitive_word_avoidance_dict['config']
moderation = ModerationFactory(name, app_id, app_model.tenant_id, config)
return moderation.moderation_for_outputs(text)