feat: add api-based extension & external data tool & moderation backend (#1403)

Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
Garfield Dai
2023-11-06 19:36:16 +08:00
committed by GitHub
parent 7699621983
commit db43ed6f41
50 changed files with 1624 additions and 273 deletions

View File

@@ -0,0 +1 @@
import core.moderation.base

View File

@@ -1,13 +1,25 @@
import logging
from typing import Any, Dict, List, Union
import threading
import time
from typing import Any, Dict, List, Union, Optional
from flask import Flask, current_app
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, BaseMessage
from pydantic import BaseModel
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.moderation.base import ModerationOutputsResult, ModerationAction
from core.moderation.factory import ModerationFactory
class ModerationRule(BaseModel):
type: str
config: Dict[str, Any]
class LLMCallbackHandler(BaseCallbackHandler):
@@ -20,6 +32,24 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.start_at = None
self.conversation_message_task = conversation_message_task
self.output_moderation_handler = None
self.init_output_moderation()
def init_output_moderation(self):
app_model_config = self.conversation_message_task.app_model_config
sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict
if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
self.output_moderation_handler = OutputModerationHandler(
tenant_id=self.conversation_message_task.tenant_id,
app_id=self.conversation_message_task.app.id,
rule=ModerationRule(
type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get("config")
),
on_message_replace_func=self.conversation_message_task.on_message_replace
)
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
@@ -59,10 +89,19 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(response.generations[0][0].text)
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
self.llm_message.completion = self.output_moderation_handler.moderation_completion(
completion=response.generations[0][0].text,
public_event=True if self.conversation_message_task.streaming else False
)
else:
self.llm_message.completion = response.generations[0][0].text
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(self.llm_message.completion)
if response.llm_output and 'token_usage' in response.llm_output:
if 'prompt_tokens' in response.llm_output['token_usage']:
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
@@ -79,23 +118,161 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.conversation_message_task.save_message(self.llm_message)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
try:
self.conversation_message_task.append_message_text(token)
except ConversationTaskStoppedException as ex:
if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
ex = ConversationTaskInterruptException()
self.on_llm_error(error=ex)
raise ex
self.llm_message.completion += token
try:
self.conversation_message_task.append_message_text(token)
self.llm_message.completion += token
if self.output_moderation_handler:
self.output_moderation_handler.append_new_token(token)
except ConversationTaskStoppedException as ex:
self.on_llm_error(error=ex)
raise ex
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
if isinstance(error, ConversationTaskStoppedException):
if self.conversation_message_task.streaming:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
if isinstance(error, ConversationTaskInterruptException):
self.llm_message.completion = self.output_moderation_handler.get_final_output()
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message)
else:
logging.debug("on_llm_error: %s", error)
class OutputModerationHandler(BaseModel):
DEFAULT_BUFFER_SIZE: int = 300
tenant_id: str
app_id: str
rule: ModerationRule
on_message_replace_func: Any
thread: Optional[threading.Thread] = None
thread_running: bool = True
buffer: str = ''
is_final_chunk: bool = False
final_output: Optional[str] = None
class Config:
arbitrary_types_allowed = True
def should_direct_output(self):
return self.final_output is not None
def get_final_output(self):
return self.final_output
def append_new_token(self, token: str):
self.buffer += token
if not self.thread:
self.thread = self.start_thread()
def moderation_completion(self, completion: str, public_event: bool = False) -> str:
self.buffer = completion
self.is_final_chunk = True
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=completion
)
if not result or not result.flagged:
return completion
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
else:
final_output = result.text
if public_event:
self.on_message_replace_func(final_output)
return final_output
def start_thread(self) -> threading.Thread:
buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
thread = threading.Thread(target=self.worker, kwargs={
'flask_app': current_app._get_current_object(),
'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
})
thread.start()
return thread
def stop_thread(self):
if self.thread and self.thread.is_alive():
self.thread_running = False
def worker(self, flask_app: Flask, buffer_size: int):
with flask_app.app_context():
current_length = 0
while self.thread_running:
moderation_buffer = self.buffer
buffer_length = len(moderation_buffer)
if not self.is_final_chunk:
chunk_length = buffer_length - current_length
if 0 <= chunk_length < buffer_size:
time.sleep(1)
continue
current_length = buffer_length
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=moderation_buffer
)
if not result or not result.flagged:
continue
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
self.final_output = final_output
else:
final_output = result.text + self.buffer[len(moderation_buffer):]
# trigger replace event
if self.thread_running:
self.on_message_replace_func(final_output)
if result.action == ModerationAction.DIRECT_OUTPUT:
break
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
try:
moderation_factory = ModerationFactory(
name=self.rule.type,
app_id=app_id,
tenant_id=tenant_id,
config=self.rule.config
)
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result
except Exception as e:
logging.error("Moderation Output error: %s", e)
return None

View File

@@ -1,92 +0,0 @@
import enum
import logging
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import BaseModel
from core.model_providers.error import LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation import openai_moderation
class SensitiveWordAvoidanceRule(BaseModel):
class Type(enum.Enum):
MODERATION = "moderation"
KEYWORDS = "keywords"
type: Type
canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
extra_params: dict = {}
class SensitiveWordAvoidanceChain(Chain):
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
model_instance: BaseLLM
sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule
@property
def _chain_type(self) -> str:
return "sensitive_word_avoidance_chain"
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return [self.output_key]
def _check_sensitive_word(self, text: str) -> bool:
for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
if word in text:
return False
return True
def _check_moderation(self, text: str) -> bool:
moderation_model_instance = ModelFactory.get_moderation_model(
tenant_id=self.model_instance.model_provider.provider.tenant_id,
model_provider_name='openai',
model_name=openai_moderation.DEFAULT_MODEL
)
try:
return moderation_model_instance.run(text=text)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
text = inputs[self.input_key]
if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
result = self._check_sensitive_word(text)
else:
result = self._check_moderation(text)
if not result:
raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)
return {self.output_key: text}
class SensitiveWordAvoidanceError(Exception):
def __init__(self, message):
super().__init__(message)
self.message = message

View File

@@ -1,13 +1,18 @@
import concurrent
import json
import logging
from typing import Optional, List, Union
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, Union, Tuple
from flask import current_app, Flask
from requests.exceptions import ChunkedEncodingError
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.external_data_tool.factory import ExternalDataToolFactory
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
@@ -18,6 +23,8 @@ from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.model import App, AppModelConfig, Account, Conversation, EndUser
from core.moderation.base import ModerationException, ModerationAction
from core.moderation.factory import ModerationFactory
class Completion:
@@ -76,26 +83,35 @@ class Completion:
)
try:
# parse sensitive_word_avoidance_chain
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
final_model_instance, [chain_callback])
if sensitive_word_avoidance_chain:
try:
query = sensitive_word_avoidance_chain.run(query)
except SensitiveWordAvoidanceError as ex:
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=ex.message
)
return
try:
# process sensitive_word_avoidance
inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
except ModerationException as e:
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=str(e)
)
return
# fill in variable inputs from external data tools if exists
external_data_tools = app_model_config.external_data_tools_list
if external_data_tools:
inputs = cls.fill_in_inputs_from_external_data_tools(
tenant_id=app.tenant_id,
app_id=app.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
)
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
@@ -135,19 +151,110 @@ class Completion:
memory=memory,
fake_response=fake_response
)
except ConversationTaskStoppedException:
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
return
except ChunkedEncodingError as e:
# Interrupt by LLM (like OpenAI), handle it.
logging.warning(f'ChunkedEncodingError: {e}')
conversation_message_task.end()
return
@classmethod
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
if not app_model_config.sensitive_word_avoidance_dict['enabled']:
return inputs, query
type = app_model_config.sensitive_word_avoidance_dict['type']
moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
moderation_result = moderation.moderation_for_inputs(inputs, query)
if not moderation_result.flagged:
return inputs, query
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
raise ModerationException(moderation_result.preset_response)
elif moderation_result.action == ModerationAction.OVERRIDED:
inputs = moderation_result.inputs
query = moderation_result.query
return inputs, query
@classmethod
def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
inputs: dict, query: str) -> dict:
"""
Fill in variable inputs from external data tools if exists.
:param tenant_id: workspace id
:param app_id: app id
:param external_data_tools: external data tools configs
:param inputs: the inputs
:param query: the query
:return: the filled inputs
"""
# Group tools by type and config
grouped_tools = {}
for tool in external_data_tools:
if not tool.get("enabled"):
continue
tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
grouped_tools.setdefault(tool_key, []).append(tool)
results = {}
with ThreadPoolExecutor() as executor:
futures = {}
for tools in grouped_tools.values():
# Only query the first tool in each group
first_tool = tools[0]
future = executor.submit(
cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, first_tool,
inputs, query
)
for tool in tools:
futures[future] = tool
for future in concurrent.futures.as_completed(futures):
tool_key, result = future.result()
if tool_key in grouped_tools:
for tool in grouped_tools[tool_key]:
results[tool['variable']] = result
inputs.update(results)
return inputs
@classmethod
def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
with flask_app.app_context():
tool_variable = external_data_tool.get("variable")
tool_type = external_data_tool.get("type")
tool_config = external_data_tool.get("config")
external_data_tool_factory = ExternalDataToolFactory(
name=tool_type,
tenant_id=tenant_id,
app_id=app_id,
variable=tool_variable,
config=tool_config
)
# query external data tool
result = external_data_tool_factory.query(
inputs=inputs,
query=query
)
tool_key = (external_data_tool.get("type"), json.dumps(external_data_tool.get("config"), sort_keys=True))
return tool_key, result
@classmethod
def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
if app.mode != 'completion':
return query
return inputs.get(app_model_config.dataset_query_variable, "")
@classmethod

View File

@@ -290,6 +290,10 @@ class ConversationMessageTask:
db.session.commit()
self.retriever_resource = resource
def on_message_replace(self, text: str):
if text is not None:
self._pub_handler.pub_message_replace(text)
def message_end(self):
self._pub_handler.pub_message_end(self.retriever_resource)
@@ -342,6 +346,24 @@ class PubHandler:
self.pub_end()
raise ConversationTaskStoppedException()
def pub_message_replace(self, text: str):
content = {
'event': 'message_replace',
'data': {
'task_id': self._task_id,
'message_id': str(self._message.id),
'text': text,
'mode': self._conversation.mode,
'conversation_id': str(self._conversation.id)
}
}
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_chain(self, message_chain: MessageChain):
if self._chain_pub:
content = {
@@ -443,3 +465,7 @@ class PubHandler:
class ConversationTaskStoppedException(Exception):
pass
class ConversationTaskInterruptException(Exception):
pass

View File

View File

@@ -0,0 +1,62 @@
import os
import requests
from models.api_based_extension import APIBasedExtensionPoint
class APIBasedExtensionRequestor:
timeout: (int, int) = (5, 60)
"""timeout for request connect and read"""
def __init__(self, api_endpoint: str, api_key: str) -> None:
self.api_endpoint = api_endpoint
self.api_key = api_key
def request(self, point: APIBasedExtensionPoint, params: dict) -> dict:
"""
Request the api.
:param point: the api point
:param params: the request params
:return: the response json
"""
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(self.api_key)
}
url = self.api_endpoint
try:
# proxy support for security
proxies = None
if os.environ.get("API_BASED_EXTENSION_HTTP_PROXY") and os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"):
proxies = {
'http': os.environ.get("API_BASED_EXTENSION_HTTP_PROXY"),
'https': os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"),
}
response = requests.request(
method='POST',
url=url,
json={
'point': point.value,
'params': params
},
headers=headers,
timeout=self.timeout,
proxies=proxies
)
except requests.exceptions.Timeout:
raise ValueError("request timeout")
except requests.exceptions.ConnectionError:
raise ValueError("request connection error")
if response.status_code != 200:
raise ValueError("request error, status_code: {}, content: {}".format(
response.status_code,
response.text[:100]
))
return response.json()

View File

@@ -0,0 +1,111 @@
import enum
import importlib.util
import json
import logging
import os
from collections import OrderedDict
from typing import Any, Optional
from pydantic import BaseModel
class ExtensionModule(enum.Enum):
MODERATION = 'moderation'
EXTERNAL_DATA_TOOL = 'external_data_tool'
class ModuleExtension(BaseModel):
extension_class: Any
name: str
label: Optional[dict] = None
form_schema: Optional[list] = None
builtin: bool = True
position: Optional[int] = None
class Extensible:
module: ExtensionModule
name: str
tenant_id: str
config: Optional[dict] = None
def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
self.tenant_id = tenant_id
self.config = config
@classmethod
def scan_extensions(cls):
extensions = {}
# get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
current_dir_path = os.path.dirname(current_path)
# traverse subdirectories
for subdir_name in os.listdir(current_dir_path):
if subdir_name.startswith('__'):
continue
subdir_path = os.path.join(current_dir_path, subdir_name)
extension_name = subdir_name
if os.path.isdir(subdir_path):
file_names = os.listdir(subdir_path)
# is builtin extension, builtin extension
# in the front-end page and business logic, there are special treatments.
builtin = False
position = None
if '__builtin__' in file_names:
builtin = True
builtin_file_path = os.path.join(subdir_path, '__builtin__')
if os.path.exists(builtin_file_path):
with open(builtin_file_path, 'r') as f:
position = int(f.read().strip())
if (extension_name + '.py') not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
continue
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + '.py')
spec = importlib.util.spec_from_file_location(extension_name, py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
extension_class = None
for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
extension_class = obj
break
if not extension_class:
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
continue
json_data = {}
if not builtin:
if 'schema.json' not in file_names:
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
continue
json_path = os.path.join(subdir_path, 'schema.json')
json_data = {}
if os.path.exists(json_path):
with open(json_path, 'r') as f:
json_data = json.load(f)
extensions[extension_name] = ModuleExtension(
extension_class=extension_class,
name=extension_name,
label=json_data.get('label'),
form_schema=json_data.get('form_schema'),
builtin=builtin,
position=position
)
sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position))
sorted_extensions = OrderedDict(sorted_items)
return sorted_extensions

View File

@@ -0,0 +1,47 @@
from core.extension.extensible import ModuleExtension, ExtensionModule
from core.external_data_tool.base import ExternalDataTool
from core.moderation.base import Moderation
class Extension:
__module_extensions: dict[str, dict[str, ModuleExtension]] = {}
module_classes = {
ExtensionModule.MODERATION: Moderation,
ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool
}
def init(self):
for module, module_class in self.module_classes.items():
self.__module_extensions[module.value] = module_class.scan_extensions()
def module_extensions(self, module: str) -> list[ModuleExtension]:
module_extensions = self.__module_extensions.get(module)
if not module_extensions:
raise ValueError(f"Extension Module {module} not found")
return list(module_extensions.values())
def module_extension(self, module: ExtensionModule, extension_name: str) -> ModuleExtension:
module_extensions = self.__module_extensions.get(module.value)
if not module_extensions:
raise ValueError(f"Extension Module {module} not found")
module_extension = module_extensions.get(extension_name)
if not module_extension:
raise ValueError(f"Extension {extension_name} not found")
return module_extension
def extension_class(self, module: ExtensionModule, extension_name: str) -> type:
module_extension = self.module_extension(module, extension_name)
return module_extension.extension_class
def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
module_extension = self.module_extension(module, extension_name)
form_schema = module_extension.form_schema
# TODO validate form_schema

View File

View File

@@ -0,0 +1 @@
1

View File

@@ -0,0 +1,92 @@
from typing import Optional
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.external_data_tool.base import ExternalDataTool
from core.helper import encrypter
from extensions.ext_database import db
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
class ApiExternalDataTool(ExternalDataTool):
"""
The api external data tool.
"""
name: str = "api"
"""the unique name of external data tool"""
@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
# own validation logic
api_based_extension_id = config.get("api_based_extension_id")
if not api_based_extension_id:
raise ValueError("api_based_extension_id is required")
# get api_based_extension
api_based_extension = db.session.query(APIBasedExtension).filter(
APIBasedExtension.tenant_id == tenant_id,
APIBasedExtension.id == api_based_extension_id
).first()
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
def query(self, inputs: dict, query: Optional[str] = None) -> str:
"""
Query the external data tool.
:param inputs: user inputs
:param query: the query of chat app
:return: the tool query result
"""
# get params from config
api_based_extension_id = self.config.get("api_based_extension_id")
# get api_based_extension
api_based_extension = db.session.query(APIBasedExtension).filter(
APIBasedExtension.tenant_id == self.tenant_id,
APIBasedExtension.id == api_based_extension_id
).first()
if not api_based_extension:
raise ValueError("[External data tool] API query failed, variable: {}, "
"error: api_based_extension_id is invalid"
.format(self.config.get('variable')))
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=self.tenant_id,
token=api_based_extension.api_key
)
try:
# request api
requestor = APIBasedExtensionRequestor(
api_endpoint=api_based_extension.api_endpoint,
api_key=api_key
)
except Exception as e:
raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(
self.config.get('variable'),
e
))
response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={
'app_id': self.app_id,
'tool_variable': self.variable,
'inputs': inputs,
'query': query
})
if 'result' not in response_json:
raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response"
.format(self.config.get('variable')))
return response_json['result']

View File

@@ -0,0 +1,45 @@
from abc import abstractmethod, ABC
from typing import Optional
from core.extension.extensible import Extensible, ExtensionModule
class ExternalDataTool(Extensible, ABC):
"""
The base class of external data tool.
"""
module: ExtensionModule = ExtensionModule.EXTERNAL_DATA_TOOL
app_id: str
"""the id of app"""
variable: str
"""the tool variable name of app tool"""
def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None:
super().__init__(tenant_id, config)
self.app_id = app_id
self.variable = variable
@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
raise NotImplementedError
@abstractmethod
def query(self, inputs: dict, query: Optional[str] = None) -> str:
"""
Query the external data tool.
:param inputs: user inputs
:param query: the query of chat app
:return: the tool query result
"""
raise NotImplementedError

View File

@@ -0,0 +1,40 @@
from typing import Optional
from core.extension.extensible import ExtensionModule
from extensions.ext_code_based_extension import code_based_extension
class ExternalDataToolFactory:
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None:
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
self.__extension_instance = extension_class(
tenant_id=tenant_id,
app_id=app_id,
variable=variable,
config=config
)
@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param name: the name of external data tool
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
extension_class.validate_config(tenant_id, config)
def query(self, inputs: dict, query: Optional[str] = None) -> str:
"""
Query the external data tool.
:param inputs: user inputs
:param query: the query of chat app
:return: the tool query result
"""
return self.__extension_instance.query(inputs, query)

View File

View File

@@ -0,0 +1 @@
3

View File

View File

@@ -0,0 +1,88 @@
from pydantic import BaseModel
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor, APIBasedExtensionPoint
from core.helper.encrypter import decrypt_token
from extensions.ext_database import db
from models.api_based_extension import APIBasedExtension
class ModerationInputParams(BaseModel):
app_id: str = ""
inputs: dict = {}
query: str = ""
class ModerationOutputParams(BaseModel):
app_id: str = ""
text: str
class ApiModeration(Moderation):
name: str = "api"
@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
cls._validate_inputs_and_outputs_config(config, False)
api_based_extension_id = config.get("api_based_extension_id")
if not api_based_extension_id:
raise ValueError("api_based_extension_id is required")
extension = cls._get_api_based_extension(tenant_id, api_based_extension_id)
if not extension:
raise ValueError("API-based Extension not found. Please check it again.")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config['inputs_config']['enabled']:
params = ModerationInputParams(
app_id=self.app_id,
inputs=inputs,
query=query
)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.dict())
return ModerationInputsResult(**result)
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""
if self.config['outputs_config']['enabled']:
params = ModerationOutputParams(
app_id=self.app_id,
text=text
)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.dict())
return ModerationOutputsResult(**result)
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id"))
requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))
result = requestor.request(extension_point, params)
return result
@staticmethod
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
extension = db.session.query(APIBasedExtension).filter(
APIBasedExtension.tenant_id == tenant_id,
APIBasedExtension.id == api_based_extension_id
).first()
return extension

113
api/core/moderation/base.py Normal file
View File

@@ -0,0 +1,113 @@
from abc import ABC, abstractmethod
from typing import Optional
from pydantic import BaseModel
from enum import Enum
from core.extension.extensible import Extensible, ExtensionModule
class ModerationAction(Enum):
DIRECT_OUTPUT = 'direct_output'
OVERRIDED = 'overrided'
class ModerationInputsResult(BaseModel):
flagged: bool = False
action: ModerationAction
preset_response: str = ""
inputs: dict = {}
query: str = ""
class ModerationOutputsResult(BaseModel):
flagged: bool = False
action: ModerationAction
preset_response: str = ""
text: str = ""
class Moderation(Extensible, ABC):
"""
The base class of moderation.
"""
module: ExtensionModule = ExtensionModule.MODERATION
def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None:
super().__init__(tenant_id, config)
self.app_id = app_id
@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
raise NotImplementedError
@abstractmethod
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review
on the user inputs and return the processed results.
:param inputs: user inputs
:param query: query string (required in chat app)
:return:
"""
raise NotImplementedError
@abstractmethod
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
"""
Moderation for outputs.
When LLM outputs content, the front end will pass the output content (may be segmented)
to this method for sensitive content review, and the output content will be shielded if the review fails.
:param text: LLM output content
:return:
"""
raise NotImplementedError
@classmethod
def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None:
# inputs_config
inputs_config = config.get("inputs_config")
if not isinstance(inputs_config, dict):
raise ValueError("inputs_config must be a dict")
# outputs_config
outputs_config = config.get("outputs_config")
if not isinstance(outputs_config, dict):
raise ValueError("outputs_config must be a dict")
inputs_config_enabled = inputs_config.get("enabled")
outputs_config_enabled = outputs_config.get("enabled")
if not inputs_config_enabled and not outputs_config_enabled:
raise ValueError("At least one of inputs_config or outputs_config must be enabled")
# preset_response
if not is_preset_response_required:
return
if inputs_config_enabled:
if not inputs_config.get("preset_response"):
raise ValueError("inputs_config.preset_response is required")
if len(inputs_config.get("preset_response")) > 100:
raise ValueError("inputs_config.preset_response must be less than 100 characters")
if outputs_config_enabled:
if not outputs_config.get("preset_response"):
raise ValueError("outputs_config.preset_response is required")
if len(outputs_config.get("preset_response")) > 100:
raise ValueError("outputs_config.preset_response must be less than 100 characters")
class ModerationException(Exception):
pass

View File

@@ -0,0 +1,48 @@
from core.extension.extensible import ExtensionModule
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
from extensions.ext_code_based_extension import code_based_extension
class ModerationFactory:
__extension_instance: Moderation
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict) -> None:
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
self.__extension_instance = extension_class(app_id, tenant_id, config)
@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param name: the name of extension
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
extension_class.validate_config(tenant_id, config)
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review
on the user inputs and return the processed results.
:param inputs: user inputs
:param query: query string (required in chat app)
:return:
"""
return self.__extension_instance.moderation_for_inputs(inputs, query)
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
"""
Moderation for outputs.
When LLM outputs content, the front end will pass the output content (may be segmented)
to this method for sensitive content review, and the output content will be shielded if the review fails.
:param text: LLM output content
:return:
"""
return self.__extension_instance.moderation_for_outputs(text)

View File

@@ -0,0 +1 @@
2

View File

View File

@@ -0,0 +1,60 @@
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
class KeywordsModeration(Moderation):
name: str = "keywords"
@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
cls._validate_inputs_and_outputs_config(config, True)
if not config.get("keywords"):
raise ValueError("keywords is required")
if len(config.get("keywords")) > 1000:
raise ValueError("keywords length must be less than 1000")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config['inputs_config']['enabled']:
preset_response = self.config['inputs_config']['preset_response']
if query:
inputs['query__'] = query
keywords_list = self.config['keywords'].split('\n')
flagged = self._is_violated(inputs, keywords_list)
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""
if self.config['outputs_config']['enabled']:
keywords_list = self.config['keywords'].split('\n')
flagged = self._is_violated({'text': text}, keywords_list)
preset_response = self.config['outputs_config']['preset_response']
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
for value in inputs.values():
if self._check_keywords_in_value(keywords_list, value):
return True
return False
def _check_keywords_in_value(self, keywords_list, value):
for keyword in keywords_list:
if keyword.lower() in value.lower():
return True
return False

View File

@@ -0,0 +1 @@
1

View File

@@ -0,0 +1,46 @@
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
from core.model_providers.model_factory import ModelFactory
class OpenAIModeration(Moderation):
name: str = "openai_moderation"
@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
cls._validate_inputs_and_outputs_config(config, True)
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config['inputs_config']['enabled']:
preset_response = self.config['inputs_config']['preset_response']
if query:
inputs['query__'] = query
flagged = self._is_violated(inputs)
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""
if self.config['outputs_config']['enabled']:
flagged = self._is_violated({'text': text})
preset_response = self.config['outputs_config']['preset_response']
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def _is_violated(self, inputs: dict):
text = '\n'.join(inputs.values())
openai_moderation = ModelFactory.get_moderation_model(self.tenant_id, "openai", "moderation")
is_not_invalid = openai_moderation.run(text)
return not is_not_invalid

View File

@@ -11,7 +11,6 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
@@ -125,52 +124,6 @@ class OrchestratorRuleParser:
return chain
def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
-> Optional[SensitiveWordAvoidanceChain]:
"""
Convert app sensitive word avoidance config to chain
:param model_instance: model instance
:param callbacks: callbacks for the chain
:param kwargs:
:return:
"""
sensitive_word_avoidance_rule = None
if self.app_model_config.sensitive_word_avoidance_dict:
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
if sensitive_word_avoidance_config.get("enabled", False):
if sensitive_word_avoidance_config.get('type') == 'moderation':
sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
type=SensitiveWordAvoidanceRule.Type.MODERATION,
canned_response=sensitive_word_avoidance_config.get("canned_response")
if sensitive_word_avoidance_config.get("canned_response")
else 'Your content violates our usage policy. Please revise and try again.',
)
else:
sensitive_words = sensitive_word_avoidance_config.get("words", "")
if sensitive_words:
sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
canned_response=sensitive_word_avoidance_config.get("canned_response")
if sensitive_word_avoidance_config.get("canned_response")
else 'Your content violates our usage policy. Please revise and try again.',
extra_params={
'sensitive_words': sensitive_words.split(','),
}
)
if sensitive_word_avoidance_rule:
return SensitiveWordAvoidanceChain(
model_instance=model_instance,
sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
output_key="sensitive_word_avoidance_output",
callbacks=callbacks,
**kwargs
)
return None
def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
"""
Convert app agent tool configs to tools