feat: add api-based extension & external data tool & moderation backend (#1403)
Co-authored-by: takatost <takatost@gmail.com>
@@ -0,0 +1 @@
+import core.moderation.base
api/core/callback_handler/llm_callback_handler.py
@@ -1,13 +1,25 @@
 import logging
-from typing import Any, Dict, List, Union
+import threading
+import time
+from typing import Any, Dict, List, Union, Optional

+from flask import Flask, current_app
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import LLMResult, BaseMessage
+from pydantic import BaseModel

 from core.callback_handler.entity.llm_message import LLMMessage
-from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
+from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
+    ConversationTaskInterruptException
 from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
 from core.model_providers.models.llm.base import BaseLLM
+from core.moderation.base import ModerationOutputsResult, ModerationAction
+from core.moderation.factory import ModerationFactory


+class ModerationRule(BaseModel):
+    type: str
+    config: Dict[str, Any]
+
+
 class LLMCallbackHandler(BaseCallbackHandler):
@@ -20,6 +32,24 @@ class LLMCallbackHandler(BaseCallbackHandler):
         self.start_at = None
         self.conversation_message_task = conversation_message_task

+        self.output_moderation_handler = None
+        self.init_output_moderation()
+
+    def init_output_moderation(self):
+        app_model_config = self.conversation_message_task.app_model_config
+        sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict
+
+        if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
+            self.output_moderation_handler = OutputModerationHandler(
+                tenant_id=self.conversation_message_task.tenant_id,
+                app_id=self.conversation_message_task.app.id,
+                rule=ModerationRule(
+                    type=sensitive_word_avoidance_dict.get("type"),
+                    config=sensitive_word_avoidance_dict.get("config")
+                ),
+                on_message_replace_func=self.conversation_message_task.on_message_replace
+            )
+
     @property
     def always_verbose(self) -> bool:
         """Whether to call verbose callbacks even if verbose is False."""
@@ -59,10 +89,19 @@ class LLMCallbackHandler(BaseCallbackHandler):
         self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        if not self.conversation_message_task.streaming:
-            self.conversation_message_task.append_message_text(response.generations[0][0].text)
+        if self.output_moderation_handler:
+            self.output_moderation_handler.stop_thread()
+
+            self.llm_message.completion = self.output_moderation_handler.moderation_completion(
+                completion=response.generations[0][0].text,
+                public_event=True if self.conversation_message_task.streaming else False
+            )
+        else:
+            self.llm_message.completion = response.generations[0][0].text
+
+        if not self.conversation_message_task.streaming:
+            self.conversation_message_task.append_message_text(self.llm_message.completion)

         if response.llm_output and 'token_usage' in response.llm_output:
             if 'prompt_tokens' in response.llm_output['token_usage']:
                 self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
@@ -79,23 +118,161 @@ class LLMCallbackHandler(BaseCallbackHandler):
         self.conversation_message_task.save_message(self.llm_message)

     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        try:
-            self.conversation_message_task.append_message_text(token)
-        except ConversationTaskStoppedException as ex:
-            self.on_llm_error(error=ex)
-            raise ex
-
-        self.llm_message.completion += token
+        try:
+            self.conversation_message_task.append_message_text(token)
+            self.llm_message.completion += token
+
+            if self.output_moderation_handler:
+                self.output_moderation_handler.append_new_token(token)
+        except ConversationTaskStoppedException as ex:
+            if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
+                # stop subscribe new token when output moderation should direct output
+                ex = ConversationTaskInterruptException()
+            self.on_llm_error(error=ex)
+            raise ex

     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
         """Do nothing."""
+        if self.output_moderation_handler:
+            self.output_moderation_handler.stop_thread()
+
         if isinstance(error, ConversationTaskStoppedException):
             if self.conversation_message_task.streaming:
                 self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self.llm_message.completion)]
                 )
             self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
+        if isinstance(error, ConversationTaskInterruptException):
+            self.llm_message.completion = self.output_moderation_handler.get_final_output()
+            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
+                [PromptMessage(content=self.llm_message.completion)]
+            )
+            self.conversation_message_task.save_message(llm_message=self.llm_message)
         else:
             logging.debug("on_llm_error: %s", error)
+
+
+class OutputModerationHandler(BaseModel):
+    DEFAULT_BUFFER_SIZE: int = 300
+
+    tenant_id: str
+    app_id: str
+
+    rule: ModerationRule
+    on_message_replace_func: Any
+
+    thread: Optional[threading.Thread] = None
+    thread_running: bool = True
+    buffer: str = ''
+    is_final_chunk: bool = False
+    final_output: Optional[str] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def should_direct_output(self):
+        return self.final_output is not None
+
+    def get_final_output(self):
+        return self.final_output
+
+    def append_new_token(self, token: str):
+        self.buffer += token
+
+        if not self.thread:
+            self.thread = self.start_thread()
+
+    def moderation_completion(self, completion: str, public_event: bool = False) -> str:
+        self.buffer = completion
+        self.is_final_chunk = True
+
+        result = self.moderation(
+            tenant_id=self.tenant_id,
+            app_id=self.app_id,
+            moderation_buffer=completion
+        )
+
+        if not result or not result.flagged:
+            return completion
+
+        if result.action == ModerationAction.DIRECT_OUTPUT:
+            final_output = result.preset_response
+        else:
+            final_output = result.text
+
+        if public_event:
+            self.on_message_replace_func(final_output)
+
+        return final_output
+
+    def start_thread(self) -> threading.Thread:
+        buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
+        thread = threading.Thread(target=self.worker, kwargs={
+            'flask_app': current_app._get_current_object(),
+            'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
+        })
+
+        thread.start()
+
+        return thread
+
+    def stop_thread(self):
+        if self.thread and self.thread.is_alive():
+            self.thread_running = False
+
+    def worker(self, flask_app: Flask, buffer_size: int):
+        with flask_app.app_context():
+            current_length = 0
+            while self.thread_running:
+                moderation_buffer = self.buffer
+                buffer_length = len(moderation_buffer)
+                if not self.is_final_chunk:
+                    chunk_length = buffer_length - current_length
+                    if 0 <= chunk_length < buffer_size:
+                        time.sleep(1)
+                        continue
+
+                current_length = buffer_length
+
+                result = self.moderation(
+                    tenant_id=self.tenant_id,
+                    app_id=self.app_id,
+                    moderation_buffer=moderation_buffer
+                )
+
+                if not result or not result.flagged:
+                    continue
+
+                if result.action == ModerationAction.DIRECT_OUTPUT:
+                    final_output = result.preset_response
+                    self.final_output = final_output
+                else:
+                    final_output = result.text + self.buffer[len(moderation_buffer):]
+
+                # trigger replace event
+                if self.thread_running:
+                    self.on_message_replace_func(final_output)
+
+                if result.action == ModerationAction.DIRECT_OUTPUT:
+                    break
+
+    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
+        try:
+            moderation_factory = ModerationFactory(
+                name=self.rule.type,
+                app_id=app_id,
+                tenant_id=tenant_id,
+                config=self.rule.config
+            )
+
+            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
+            return result
+        except Exception as e:
+            logging.error("Moderation Output error: %s", e)

        return None
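For orientation, a minimal driving sketch (not part of the commit) of how the streaming pipeline feeds OutputModerationHandler. It assumes a Flask app context (start_thread reads current_app) and an initialized code-based extension registry; print stands in for ConversationMessageTask.on_message_replace, and all ids are placeholders.

handler = OutputModerationHandler(
    tenant_id="tenant-1",
    app_id="app-1",
    rule=ModerationRule(type="keywords", config={
        "keywords": "forbidden",
        "inputs_config": {"enabled": False},
        "outputs_config": {"enabled": True, "preset_response": "Blocked."},
    }),
    on_message_replace_func=print,   # stand-in for the pub/sub replace event
)

for token in ["Hello ", "world"]:
    handler.append_new_token(token)  # the first token lazily starts the worker thread

# final, blocking moderation pass over the whole completion once the LLM finishes
final_text = handler.moderation_completion(completion="Hello world", public_event=False)
handler.stop_thread()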
api/core/chain/sensitive_word_avoidance_chain.py (deleted, 92 lines)
@@ -1,92 +0,0 @@
-import enum
-import logging
-from typing import List, Dict, Optional, Any
-
-from langchain.callbacks.manager import CallbackManagerForChainRun
-from langchain.chains.base import Chain
-from pydantic import BaseModel
-
-from core.model_providers.error import LLMBadRequestError
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.moderation import openai_moderation
-
-
-class SensitiveWordAvoidanceRule(BaseModel):
-    class Type(enum.Enum):
-        MODERATION = "moderation"
-        KEYWORDS = "keywords"
-
-    type: Type
-    canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
-    extra_params: dict = {}
-
-
-class SensitiveWordAvoidanceChain(Chain):
-    input_key: str = "input"  #: :meta private:
-    output_key: str = "output"  #: :meta private:
-
-    model_instance: BaseLLM
-    sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule
-
-    @property
-    def _chain_type(self) -> str:
-        return "sensitive_word_avoidance_chain"
-
-    @property
-    def input_keys(self) -> List[str]:
-        """Expect input key.
-
-        :meta private:
-        """
-        return [self.input_key]
-
-    @property
-    def output_keys(self) -> List[str]:
-        """Return output key.
-
-        :meta private:
-        """
-        return [self.output_key]
-
-    def _check_sensitive_word(self, text: str) -> bool:
-        for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
-            if word in text:
-                return False
-        return True
-
-    def _check_moderation(self, text: str) -> bool:
-        moderation_model_instance = ModelFactory.get_moderation_model(
-            tenant_id=self.model_instance.model_provider.provider.tenant_id,
-            model_provider_name='openai',
-            model_name=openai_moderation.DEFAULT_MODEL
-        )
-
-        try:
-            return moderation_model_instance.run(text=text)
-        except Exception as ex:
-            logging.exception(ex)
-            raise LLMBadRequestError('Rate limit exceeded, please try again later.')
-
-    def _call(
-            self,
-            inputs: Dict[str, Any],
-            run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
-        text = inputs[self.input_key]
-
-        if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
-            result = self._check_sensitive_word(text)
-        else:
-            result = self._check_moderation(text)
-
-        if not result:
-            raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)
-
-        return {self.output_key: text}
-
-
-class SensitiveWordAvoidanceError(Exception):
-    def __init__(self, message):
-        super().__init__(message)
-        self.message = message
api/core/completion.py
@@ -1,13 +1,18 @@
 import concurrent
 import json
 import logging
-from typing import Optional, List, Union
+from concurrent.futures import ThreadPoolExecutor
+from typing import Optional, List, Union, Tuple

 from flask import current_app, Flask
 from requests.exceptions import ChunkedEncodingError

 from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
-from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
-from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
+from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
+    ConversationTaskInterruptException
+from core.external_data_tool.factory import ExternalDataToolFactory
 from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
@@ -18,6 +23,8 @@ from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_template import PromptTemplateParser
 from core.prompt.prompt_transform import PromptTransform
 from models.model import App, AppModelConfig, Account, Conversation, EndUser
+from core.moderation.base import ModerationException, ModerationAction
+from core.moderation.factory import ModerationFactory


 class Completion:
@@ -76,26 +83,35 @@ class Completion:
         )

         try:
-            # parse sensitive_word_avoidance_chain
             chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
-            sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
-                final_model_instance, [chain_callback])
-            if sensitive_word_avoidance_chain:
-                try:
-                    query = sensitive_word_avoidance_chain.run(query)
-                except SensitiveWordAvoidanceError as ex:
-                    cls.run_final_llm(
-                        model_instance=final_model_instance,
-                        mode=app.mode,
-                        app_model_config=app_model_config,
-                        query=query,
-                        inputs=inputs,
-                        agent_execute_result=None,
-                        conversation_message_task=conversation_message_task,
-                        memory=memory,
-                        fake_response=ex.message
-                    )
-                    return
+            try:
+                # process sensitive_word_avoidance
+                inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
+            except ModerationException as e:
+                cls.run_final_llm(
+                    model_instance=final_model_instance,
+                    mode=app.mode,
+                    app_model_config=app_model_config,
+                    query=query,
+                    inputs=inputs,
+                    agent_execute_result=None,
+                    conversation_message_task=conversation_message_task,
+                    memory=memory,
+                    fake_response=str(e)
+                )
+                return
+
+            # fill in variable inputs from external data tools if exists
+            external_data_tools = app_model_config.external_data_tools_list
+            if external_data_tools:
+                inputs = cls.fill_in_inputs_from_external_data_tools(
+                    tenant_id=app.tenant_id,
+                    app_id=app.id,
+                    external_data_tools=external_data_tools,
+                    inputs=inputs,
+                    query=query
+                )

             # get agent executor
             agent_executor = orchestrator_rule_parser.to_agent_executor(
@@ -135,19 +151,110 @@ class Completion:
                 memory=memory,
                 fake_response=fake_response
             )
-        except ConversationTaskStoppedException:
+        except (ConversationTaskInterruptException, ConversationTaskStoppedException):
            return
         except ChunkedEncodingError as e:
             # Interrupt by LLM (like OpenAI), handle it.
             logging.warning(f'ChunkedEncodingError: {e}')
             conversation_message_task.end()
             return

+    @classmethod
+    def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
+        if not app_model_config.sensitive_word_avoidance_dict['enabled']:
+            return inputs, query
+
+        type = app_model_config.sensitive_word_avoidance_dict['type']
+
+        moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
+        moderation_result = moderation.moderation_for_inputs(inputs, query)
+
+        if not moderation_result.flagged:
+            return inputs, query
+
+        if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
+            raise ModerationException(moderation_result.preset_response)
+        elif moderation_result.action == ModerationAction.OVERRIDED:
+            inputs = moderation_result.inputs
+            query = moderation_result.query
+
+        return inputs, query
+
+    @classmethod
+    def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
+                                                inputs: dict, query: str) -> dict:
+        """
+        Fill in variable inputs from external data tools if exists.
+
+        :param tenant_id: workspace id
+        :param app_id: app id
+        :param external_data_tools: external data tools configs
+        :param inputs: the inputs
+        :param query: the query
+        :return: the filled inputs
+        """
+        # Group tools by type and config
+        grouped_tools = {}
+        for tool in external_data_tools:
+            if not tool.get("enabled"):
+                continue
+
+            tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
+            grouped_tools.setdefault(tool_key, []).append(tool)
+
+        results = {}
+        with ThreadPoolExecutor() as executor:
+            futures = {}
+            for tools in grouped_tools.values():
+                # Only query the first tool in each group
+                first_tool = tools[0]
+                future = executor.submit(
+                    cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, first_tool,
+                    inputs, query
+                )
+                for tool in tools:
+                    futures[future] = tool
+
+            for future in concurrent.futures.as_completed(futures):
+                tool_key, result = future.result()
+                if tool_key in grouped_tools:
+                    for tool in grouped_tools[tool_key]:
+                        results[tool['variable']] = result
+
+        inputs.update(results)
+        return inputs
+
+    @classmethod
+    def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
+                                 inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
+        with flask_app.app_context():
+            tool_variable = external_data_tool.get("variable")
+            tool_type = external_data_tool.get("type")
+            tool_config = external_data_tool.get("config")
+
+            external_data_tool_factory = ExternalDataToolFactory(
+                name=tool_type,
+                tenant_id=tenant_id,
+                app_id=app_id,
+                variable=tool_variable,
+                config=tool_config
+            )
+
+            # query external data tool
+            result = external_data_tool_factory.query(
+                inputs=inputs,
+                query=query
+            )
+
+            tool_key = (external_data_tool.get("type"), json.dumps(external_data_tool.get("config"), sort_keys=True))
+
+            return tool_key, result
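To make the grouping above concrete, an illustrative call (not part of the diff; ids and configs are placeholders, and it assumes a Flask app context). Two tools that share the same (type, config) key are resolved by a single HTTP round-trip; distinct keys run in parallel.

external_data_tools = [
    {"enabled": True, "variable": "profile_a", "type": "api", "config": {"api_based_extension_id": "ext-1"}},
    {"enabled": True, "variable": "profile_b", "type": "api", "config": {"api_based_extension_id": "ext-1"}},
    {"enabled": True, "variable": "weather",   "type": "api", "config": {"api_based_extension_id": "ext-2"}},
]

inputs = Completion.fill_in_inputs_from_external_data_tools(
    tenant_id="tenant-1",
    app_id="app-1",
    external_data_tools=external_data_tools,
    inputs={},
    query="What should I wear today?"
)
# -> {'profile_a': <result from ext-1>, 'profile_b': <result from ext-1>, 'weather': <result from ext-2>}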
     @classmethod
     def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
         if app.mode != 'completion':
             return query

         return inputs.get(app_model_config.dataset_query_variable, "")

     @classmethod
api/core/conversation_message_task.py
@@ -290,6 +290,10 @@ class ConversationMessageTask:
             db.session.commit()
             self.retriever_resource = resource

+    def on_message_replace(self, text: str):
+        if text is not None:
+            self._pub_handler.pub_message_replace(text)
+
     def message_end(self):
         self._pub_handler.pub_message_end(self.retriever_resource)
@@ -342,6 +346,24 @@ class PubHandler:
             self.pub_end()
             raise ConversationTaskStoppedException()

+    def pub_message_replace(self, text: str):
+        content = {
+            'event': 'message_replace',
+            'data': {
+                'task_id': self._task_id,
+                'message_id': str(self._message.id),
+                'text': text,
+                'mode': self._conversation.mode,
+                'conversation_id': str(self._conversation.id)
+            }
+        }
+
+        redis_client.publish(self._channel, json.dumps(content))
+
+        if self._is_stopped():
+            self.pub_end()
+            raise ConversationTaskStoppedException()
+
     def pub_chain(self, message_chain: MessageChain):
         if self._chain_pub:
             content = {
@@ -443,3 +465,7 @@ class PubHandler:

 class ConversationTaskStoppedException(Exception):
     pass
+
+
+class ConversationTaskInterruptException(Exception):
+    pass
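For reference, the payload pub_message_replace publishes on the conversation's Redis channel (values below are illustrative placeholders); subscribed clients swap the currently displayed answer for `text` when the event arrives:

message_replace_event = {
    "event": "message_replace",
    "data": {
        "task_id": "4f3d...",
        "message_id": "9c1a...",
        "text": "Your content violates our usage policy. Please revise and try again.",
        "mode": "chat",
        "conversation_id": "7b2e..."
    }
}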
api/core/extension/__init__.py (new file, empty)
api/core/extension/api_based_extension_requestor.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+import os
+
+import requests
+
+from models.api_based_extension import APIBasedExtensionPoint
+
+
+class APIBasedExtensionRequestor:
+    timeout: tuple[int, int] = (5, 60)
+    """timeout for request connect and read"""
+
+    def __init__(self, api_endpoint: str, api_key: str) -> None:
+        self.api_endpoint = api_endpoint
+        self.api_key = api_key
+
+    def request(self, point: APIBasedExtensionPoint, params: dict) -> dict:
+        """
+        Request the api.
+
+        :param point: the api point
+        :param params: the request params
+        :return: the response json
+        """
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": "Bearer {}".format(self.api_key)
+        }
+
+        url = self.api_endpoint
+
+        try:
+            # proxy support for security
+            proxies = None
+            if os.environ.get("API_BASED_EXTENSION_HTTP_PROXY") and os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"):
+                proxies = {
+                    'http': os.environ.get("API_BASED_EXTENSION_HTTP_PROXY"),
+                    'https': os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"),
+                }
+
+            response = requests.request(
+                method='POST',
+                url=url,
+                json={
+                    'point': point.value,
+                    'params': params
+                },
+                headers=headers,
+                timeout=self.timeout,
+                proxies=proxies
+            )
+        except requests.exceptions.Timeout:
+            raise ValueError("request timeout")
+        except requests.exceptions.ConnectionError:
+            raise ValueError("request connection error")
+
+        if response.status_code != 200:
+            raise ValueError("request error, status_code: {}, content: {}".format(
+                response.status_code,
+                response.text[:100]
+            ))
+
+        return response.json()
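A sketch of the server side this requestor talks to, assuming the contract implied above: a POST of JSON {"point", "params"} with a Bearer token, answered with JSON and HTTP 200. The route path, key, and the point value string are assumptions for illustration only.

from flask import Flask, request, jsonify

app = Flask(__name__)
EXPECTED_API_KEY = "your-api-key"  # hypothetical shared secret

@app.route("/api/dify/receive", methods=["POST"])
def receive():
    if request.headers.get("Authorization") != f"Bearer {EXPECTED_API_KEY}":
        return jsonify({"error": "unauthorized"}), 401

    data = request.get_json()
    point, params = data.get("point"), data.get("params", {})

    if point == "app.external_data_tool.query":  # value string assumed
        # must return a 'result' field (see ApiExternalDataTool.query below)
        return jsonify({"result": f"echo: {params.get('query')}"})

    return jsonify({"error": f"unknown point {point}"}), 400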
api/core/extension/extensible.py (new file, 111 lines)
@@ -0,0 +1,111 @@
+import enum
+import importlib.util
+import json
+import logging
+import os
+from collections import OrderedDict
+from typing import Any, Optional
+
+from pydantic import BaseModel
+
+
+class ExtensionModule(enum.Enum):
+    MODERATION = 'moderation'
+    EXTERNAL_DATA_TOOL = 'external_data_tool'
+
+
+class ModuleExtension(BaseModel):
+    extension_class: Any
+    name: str
+    label: Optional[dict] = None
+    form_schema: Optional[list] = None
+    builtin: bool = True
+    position: Optional[int] = None
+
+
+class Extensible:
+    module: ExtensionModule
+
+    name: str
+    tenant_id: str
+    config: Optional[dict] = None
+
+    def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
+        self.tenant_id = tenant_id
+        self.config = config
+
+    @classmethod
+    def scan_extensions(cls):
+        extensions = {}
+
+        # get the path of the current class
+        current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
+        current_dir_path = os.path.dirname(current_path)
+
+        # traverse subdirectories
+        for subdir_name in os.listdir(current_dir_path):
+            if subdir_name.startswith('__'):
+                continue
+
+            subdir_path = os.path.join(current_dir_path, subdir_name)
+            extension_name = subdir_name
+            if os.path.isdir(subdir_path):
+                file_names = os.listdir(subdir_path)
+
+                # builtin extensions receive special treatment in the
+                # front-end page and business logic
+                builtin = False
+                position = None
+                if '__builtin__' in file_names:
+                    builtin = True
+
+                    builtin_file_path = os.path.join(subdir_path, '__builtin__')
+                    if os.path.exists(builtin_file_path):
+                        with open(builtin_file_path, 'r') as f:
+                            position = int(f.read().strip())
+
+                if (extension_name + '.py') not in file_names:
+                    logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
+                    continue
+
+                # dynamically load the {subdir_name}.py file and find the subclass of Extensible
+                py_path = os.path.join(subdir_path, extension_name + '.py')
+                spec = importlib.util.spec_from_file_location(extension_name, py_path)
+                mod = importlib.util.module_from_spec(spec)
+                spec.loader.exec_module(mod)
+
+                extension_class = None
+                for name, obj in vars(mod).items():
+                    if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
+                        extension_class = obj
+                        break
+
+                if not extension_class:
+                    logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
+                    continue
+
+                json_data = {}
+                if not builtin:
+                    if 'schema.json' not in file_names:
+                        logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
+                        continue
+
+                    json_path = os.path.join(subdir_path, 'schema.json')
+                    json_data = {}
+                    if os.path.exists(json_path):
+                        with open(json_path, 'r') as f:
+                            json_data = json.load(f)
+
+                extensions[extension_name] = ModuleExtension(
+                    extension_class=extension_class,
+                    name=extension_name,
+                    label=json_data.get('label'),
+                    form_schema=json_data.get('form_schema'),
+                    builtin=builtin,
+                    position=position
+                )
+
+        sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position))
+        sorted_extensions = OrderedDict(sorted_items)
+
+        return sorted_extensions
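For reference, the directory layout scan_extensions expects, inferred from the logic above: the subdirectory name, the .py file name, and the class defined inside must line up; __builtin__ holds the sort position; schema.json is mandatory only for non-builtin extensions. The cloud_service name is hypothetical.

# api/core/moderation/
# ├── cloud_service/          <- hypothetical third-party extension
# │   ├── __init__.py
# │   ├── cloud_service.py    <- must define a subclass of Moderation
# │   └── schema.json         <- required because there is no __builtin__ marker
# └── keywords/
#     ├── __builtin__         <- contains the sort position, e.g. "2"
#     ├── __init__.py
#     └── keywords.py

from core.moderation.base import Moderation

extensions = Moderation.scan_extensions()  # OrderedDict sorted by position
print(list(extensions))                    # e.g. ['openai_moderation', 'keywords', 'api', 'cloud_service']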
api/core/extension/extension.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+from core.extension.extensible import ModuleExtension, ExtensionModule
+from core.external_data_tool.base import ExternalDataTool
+from core.moderation.base import Moderation
+
+
+class Extension:
+    __module_extensions: dict[str, dict[str, ModuleExtension]] = {}
+
+    module_classes = {
+        ExtensionModule.MODERATION: Moderation,
+        ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool
+    }
+
+    def init(self):
+        for module, module_class in self.module_classes.items():
+            self.__module_extensions[module.value] = module_class.scan_extensions()
+
+    def module_extensions(self, module: str) -> list[ModuleExtension]:
+        module_extensions = self.__module_extensions.get(module)
+
+        if not module_extensions:
+            raise ValueError(f"Extension Module {module} not found")
+
+        return list(module_extensions.values())
+
+    def module_extension(self, module: ExtensionModule, extension_name: str) -> ModuleExtension:
+        module_extensions = self.__module_extensions.get(module.value)
+
+        if not module_extensions:
+            raise ValueError(f"Extension Module {module} not found")
+
+        module_extension = module_extensions.get(extension_name)
+
+        if not module_extension:
+            raise ValueError(f"Extension {extension_name} not found")
+
+        return module_extension
+
+    def extension_class(self, module: ExtensionModule, extension_name: str) -> type:
+        module_extension = self.module_extension(module, extension_name)
+        return module_extension.extension_class
+
+    def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
+        module_extension = self.module_extension(module, extension_name)
+        form_schema = module_extension.form_schema
+
+        # TODO validate form_schema
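A minimal usage sketch of the registry, assuming the application exposes a shared instance via extensions.ext_code_based_extension (imported elsewhere in this commit as code_based_extension); shown here constructed directly for illustration.

from core.extension.extension import Extension
from core.extension.extensible import ExtensionModule

ext = Extension()
ext.init()  # scans core/moderation/* and core/external_data_tool/*

keywords_cls = ext.extension_class(ExtensionModule.MODERATION, "keywords")
moderation = keywords_cls(app_id="app-1", tenant_id="tenant-1",
                          config=my_config)  # config shape shown with keywords.py below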
api/core/external_data_tool/__init__.py (new file, empty)
api/core/external_data_tool/api/__builtin__ (new file)
@@ -0,0 +1 @@
+1
api/core/external_data_tool/api/__init__.py (new file, empty)
api/core/external_data_tool/api/api.py (new file, 92 lines)
@@ -0,0 +1,92 @@
+from typing import Optional
+
+from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
+from core.external_data_tool.base import ExternalDataTool
+from core.helper import encrypter
+from extensions.ext_database import db
+from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
+
+
+class ApiExternalDataTool(ExternalDataTool):
+    """
+    The api external data tool.
+    """
+
+    name: str = "api"
+    """the unique name of external data tool"""
+
+    @classmethod
+    def validate_config(cls, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        # own validation logic
+        api_based_extension_id = config.get("api_based_extension_id")
+        if not api_based_extension_id:
+            raise ValueError("api_based_extension_id is required")
+
+        # get api_based_extension
+        api_based_extension = db.session.query(APIBasedExtension).filter(
+            APIBasedExtension.tenant_id == tenant_id,
+            APIBasedExtension.id == api_based_extension_id
+        ).first()
+
+        if not api_based_extension:
+            raise ValueError("api_based_extension_id is invalid")
+
+    def query(self, inputs: dict, query: Optional[str] = None) -> str:
+        """
+        Query the external data tool.
+
+        :param inputs: user inputs
+        :param query: the query of chat app
+        :return: the tool query result
+        """
+        # get params from config
+        api_based_extension_id = self.config.get("api_based_extension_id")
+
+        # get api_based_extension
+        api_based_extension = db.session.query(APIBasedExtension).filter(
+            APIBasedExtension.tenant_id == self.tenant_id,
+            APIBasedExtension.id == api_based_extension_id
+        ).first()
+
+        if not api_based_extension:
+            raise ValueError("[External data tool] API query failed, variable: {}, "
+                             "error: api_based_extension_id is invalid"
+                             .format(self.config.get('variable')))
+
+        # decrypt api_key
+        api_key = encrypter.decrypt_token(
+            tenant_id=self.tenant_id,
+            token=api_based_extension.api_key
+        )
+
+        try:
+            # request api
+            requestor = APIBasedExtensionRequestor(
+                api_endpoint=api_based_extension.api_endpoint,
+                api_key=api_key
+            )
+        except Exception as e:
+            raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(
+                self.config.get('variable'),
+                e
+            ))
+
+        response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={
+            'app_id': self.app_id,
+            'tool_variable': self.variable,
+            'inputs': inputs,
+            'query': query
+        })
+
+        if 'result' not in response_json:
+            raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response"
+                             .format(self.config.get('variable')))
+
+        return response_json['result']
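The wire format this tool produces and expects, assembled from the params above; the point value string and all concrete values are illustrative assumptions, and only the 'result' response field is contractually required.

request_body = {
    "point": "app.external_data_tool.query",  # APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY (value assumed)
    "params": {
        "app_id": "7a6f...",
        "tool_variable": "weather",
        "inputs": {"city": "Kyoto"},
        "query": "What should I wear today?"
    }
}

expected_response = {
    "result": "Kyoto: 18C, light rain"  # 'result' is the only required field
}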
api/core/external_data_tool/base.py (new file, 45 lines)
@@ -0,0 +1,45 @@
+from abc import abstractmethod, ABC
+from typing import Optional
+
+from core.extension.extensible import Extensible, ExtensionModule
+
+
+class ExternalDataTool(Extensible, ABC):
+    """
+    The base class of external data tool.
+    """
+
+    module: ExtensionModule = ExtensionModule.EXTERNAL_DATA_TOOL
+
+    app_id: str
+    """the id of app"""
+    variable: str
+    """the tool variable name of app tool"""
+
+    def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None:
+        super().__init__(tenant_id, config)
+        self.app_id = app_id
+        self.variable = variable
+
+    @classmethod
+    @abstractmethod
+    def validate_config(cls, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def query(self, inputs: dict, query: Optional[str] = None) -> str:
+        """
+        Query the external data tool.
+
+        :param inputs: user inputs
+        :param query: the query of chat app
+        :return: the tool query result
+        """
+        raise NotImplementedError
api/core/external_data_tool/factory.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+from typing import Optional
+
+from core.extension.extensible import ExtensionModule
+from extensions.ext_code_based_extension import code_based_extension
+
+
+class ExternalDataToolFactory:
+
+    def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None:
+        extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
+        self.__extension_instance = extension_class(
+            tenant_id=tenant_id,
+            app_id=app_id,
+            variable=variable,
+            config=config
+        )
+
+    @classmethod
+    def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param name: the name of external data tool
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
+        extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
+        extension_class.validate_config(tenant_id, config)
+
+    def query(self, inputs: dict, query: Optional[str] = None) -> str:
+        """
+        Query the external data tool.
+
+        :param inputs: user inputs
+        :param query: the query of chat app
+        :return: the tool query result
+        """
+        return self.__extension_instance.query(inputs, query)
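End-to-end, the factory is driven the same way completion.py drives it; a condensed sketch with placeholder identifiers:

from core.external_data_tool.factory import ExternalDataToolFactory

factory = ExternalDataToolFactory(
    name="api",
    tenant_id="tenant-1",
    app_id="app-1",
    variable="weather",
    config={"api_based_extension_id": "ext-1"}
)
result = factory.query(inputs={"city": "Kyoto"}, query="What should I wear today?")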
api/core/moderation/__init__.py (new file, empty)
api/core/moderation/api/__builtin__ (new file)
@@ -0,0 +1 @@
+3
api/core/moderation/api/__init__.py (new file, empty)
api/core/moderation/api/api.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+from pydantic import BaseModel
+
+from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
+from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor, APIBasedExtensionPoint
+from core.helper.encrypter import decrypt_token
+from extensions.ext_database import db
+from models.api_based_extension import APIBasedExtension
+
+
+class ModerationInputParams(BaseModel):
+    app_id: str = ""
+    inputs: dict = {}
+    query: str = ""
+
+
+class ModerationOutputParams(BaseModel):
+    app_id: str = ""
+    text: str
+
+
+class ApiModeration(Moderation):
+    name: str = "api"
+
+    @classmethod
+    def validate_config(cls, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        cls._validate_inputs_and_outputs_config(config, False)
+
+        api_based_extension_id = config.get("api_based_extension_id")
+        if not api_based_extension_id:
+            raise ValueError("api_based_extension_id is required")
+
+        extension = cls._get_api_based_extension(tenant_id, api_based_extension_id)
+        if not extension:
+            raise ValueError("API-based Extension not found. Please check it again.")
+
+    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+        flagged = False
+        preset_response = ""
+
+        if self.config['inputs_config']['enabled']:
+            params = ModerationInputParams(
+                app_id=self.app_id,
+                inputs=inputs,
+                query=query
+            )
+
+            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.dict())
+            return ModerationInputsResult(**result)
+
+        return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
+
+    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
+        flagged = False
+        preset_response = ""
+
+        if self.config['outputs_config']['enabled']:
+            params = ModerationOutputParams(
+                app_id=self.app_id,
+                text=text
+            )
+
+            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.dict())
+            return ModerationOutputsResult(**result)
+
+        return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
+
+    def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
+        extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id"))
+        requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))
+
+        result = requestor.request(extension_point, params)
+        return result
+
+    @staticmethod
+    def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
+        extension = db.session.query(APIBasedExtension).filter(
+            APIBasedExtension.tenant_id == tenant_id,
+            APIBasedExtension.id == api_based_extension_id
+        ).first()
+
+        return extension
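Because the results above are built with ModerationInputsResult(**result), the moderation endpoint must return JSON whose keys mirror those models (defined in base.py below). Illustrative responses, with the action strings taken from ModerationAction:

input_check_response = {
    "flagged": True,
    "action": "direct_output",  # or "overrided"
    "preset_response": "Your content violates our usage policy.",
    "inputs": {},               # used when action == "overrided"
    "query": ""
}

output_check_response = {
    "flagged": True,
    "action": "overrided",
    "text": "sanitized output text"
}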
api/core/moderation/base.py (new file, 113 lines)
@@ -0,0 +1,113 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+from pydantic import BaseModel
+from enum import Enum
+
+from core.extension.extensible import Extensible, ExtensionModule
+
+
+class ModerationAction(Enum):
+    DIRECT_OUTPUT = 'direct_output'
+    OVERRIDED = 'overrided'
+
+
+class ModerationInputsResult(BaseModel):
+    flagged: bool = False
+    action: ModerationAction
+    preset_response: str = ""
+    inputs: dict = {}
+    query: str = ""
+
+
+class ModerationOutputsResult(BaseModel):
+    flagged: bool = False
+    action: ModerationAction
+    preset_response: str = ""
+    text: str = ""
+
+
+class Moderation(Extensible, ABC):
+    """
+    The base class of moderation.
+    """
+    module: ExtensionModule = ExtensionModule.MODERATION
+
+    def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None:
+        super().__init__(tenant_id, config)
+        self.app_id = app_id
+
+    @classmethod
+    @abstractmethod
+    def validate_config(cls, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+        """
+        Moderation for inputs.
+        After the user inputs, this method will be called to perform sensitive content review
+        on the user inputs and return the processed results.
+
+        :param inputs: user inputs
+        :param query: query string (required in chat app)
+        :return:
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
+        """
+        Moderation for outputs.
+        When LLM outputs content, the front end will pass the output content (may be segmented)
+        to this method for sensitive content review, and the output content will be shielded if the review fails.
+
+        :param text: LLM output content
+        :return:
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None:
+        # inputs_config
+        inputs_config = config.get("inputs_config")
+        if not isinstance(inputs_config, dict):
+            raise ValueError("inputs_config must be a dict")
+
+        # outputs_config
+        outputs_config = config.get("outputs_config")
+        if not isinstance(outputs_config, dict):
+            raise ValueError("outputs_config must be a dict")
+
+        inputs_config_enabled = inputs_config.get("enabled")
+        outputs_config_enabled = outputs_config.get("enabled")
+        if not inputs_config_enabled and not outputs_config_enabled:
+            raise ValueError("At least one of inputs_config or outputs_config must be enabled")
+
+        # preset_response
+        if not is_preset_response_required:
+            return
+
+        if inputs_config_enabled:
+            if not inputs_config.get("preset_response"):
+                raise ValueError("inputs_config.preset_response is required")
+
+            if len(inputs_config.get("preset_response")) > 100:
+                raise ValueError("inputs_config.preset_response must be less than 100 characters")
+
+        if outputs_config_enabled:
+            if not outputs_config.get("preset_response"):
+                raise ValueError("outputs_config.preset_response is required")
+
+            if len(outputs_config.get("preset_response")) > 100:
+                raise ValueError("outputs_config.preset_response must be less than 100 characters")
+
+
+class ModerationException(Exception):
+    pass
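A minimal sketch of what a custom code-based moderation extension could look like under this base class; the module location (api/core/moderation/cloud_service/cloud_service.py), the class name, and the toy "forbidden" rule are all hypothetical.

from core.moderation.base import Moderation, ModerationAction, \
    ModerationInputsResult, ModerationOutputsResult


class CloudServiceModeration(Moderation):
    name: str = "cloud_service"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        cls._validate_inputs_and_outputs_config(config, True)

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        flagged = any("forbidden" in str(v) for v in inputs.values())  # toy rule
        return ModerationInputsResult(
            flagged=flagged,
            action=ModerationAction.DIRECT_OUTPUT,
            preset_response=self.config["inputs_config"]["preset_response"]
        )

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        return ModerationOutputsResult(
            flagged="forbidden" in text,  # toy rule
            action=ModerationAction.DIRECT_OUTPUT,
            preset_response=self.config["outputs_config"]["preset_response"]
        )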
api/core/moderation/factory.py (new file, 48 lines)
@@ -0,0 +1,48 @@
+from core.extension.extensible import ExtensionModule
+from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
+from extensions.ext_code_based_extension import code_based_extension
+
+
+class ModerationFactory:
+    __extension_instance: Moderation
+
+    def __init__(self, name: str, app_id: str, tenant_id: str, config: dict) -> None:
+        extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
+        self.__extension_instance = extension_class(app_id, tenant_id, config)
+
+    @classmethod
+    def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param name: the name of extension
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
+        extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
+        extension_class.validate_config(tenant_id, config)
+
+    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+        """
+        Moderation for inputs.
+        After the user inputs, this method will be called to perform sensitive content review
+        on the user inputs and return the processed results.
+
+        :param inputs: user inputs
+        :param query: query string (required in chat app)
+        :return:
+        """
+        return self.__extension_instance.moderation_for_inputs(inputs, query)
+
+    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
+        """
+        Moderation for outputs.
+        When LLM outputs content, the front end will pass the output content (may be segmented)
+        to this method for sensitive content review, and the output content will be shielded if the review fails.
+
+        :param text: LLM output content
+        :return:
+        """
+        return self.__extension_instance.moderation_for_outputs(text)
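Condensed, this mirrors how Completion.moderation_for_inputs drives the factory; identifiers and config values are placeholders:

from core.moderation.base import ModerationAction
from core.moderation.factory import ModerationFactory

moderation = ModerationFactory(
    name="keywords",
    app_id="app-1",
    tenant_id="tenant-1",
    config=sensitive_word_avoidance_dict["config"]  # from the app model config
)
result = moderation.moderation_for_inputs(inputs={"name": "..."}, query="hello")
if result.flagged and result.action == ModerationAction.DIRECT_OUTPUT:
    ...  # short-circuit the request with result.preset_response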
api/core/moderation/keywords/__builtin__ (new file)
@@ -0,0 +1 @@
+2
api/core/moderation/keywords/__init__.py (new file, empty)
api/core/moderation/keywords/keywords.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
+
+
+class KeywordsModeration(Moderation):
+    name: str = "keywords"
+
+    @classmethod
+    def validate_config(cls, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        cls._validate_inputs_and_outputs_config(config, True)
+
+        if not config.get("keywords"):
+            raise ValueError("keywords is required")
+
+        if len(config.get("keywords")) > 1000:
+            raise ValueError("keywords length must be less than 1000")
+
+    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+        flagged = False
+        preset_response = ""
+
+        if self.config['inputs_config']['enabled']:
+            preset_response = self.config['inputs_config']['preset_response']
+
+            if query:
+                inputs['query__'] = query
+            keywords_list = self.config['keywords'].split('\n')
+            flagged = self._is_violated(inputs, keywords_list)
+
+        return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
+
+    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
+        flagged = False
+        preset_response = ""
+
+        if self.config['outputs_config']['enabled']:
+            keywords_list = self.config['keywords'].split('\n')
+            flagged = self._is_violated({'text': text}, keywords_list)
+            preset_response = self.config['outputs_config']['preset_response']
+
+        return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
+
+    def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
+        for value in inputs.values():
+            if self._check_keywords_in_value(keywords_list, value):
+                return True
+
+        return False
+
+    def _check_keywords_in_value(self, keywords_list, value):
+        for keyword in keywords_list:
+            if keyword.lower() in value.lower():
+                return True
+        return False
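The config shape KeywordsModeration consumes, with keys taken from the code above and illustrative values: keywords are newline-separated and matched case-insensitively against every input value, and the chat query is checked under the synthetic 'query__' key.

config = {
    "keywords": "badword\nanother bad phrase",
    "inputs_config": {
        "enabled": True,
        "preset_response": "Your content violates our usage policy."
    },
    "outputs_config": {
        "enabled": True,
        "preset_response": "Sorry, I can't show that."
    }
}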
api/core/moderation/openai_moderation/__builtin__ (new file)
@@ -0,0 +1 @@
+1
api/core/moderation/openai_moderation/__init__.py (new file, empty)
api/core/moderation/openai_moderation/openai_moderation.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
+from core.model_providers.model_factory import ModelFactory
+
+
+class OpenAIModeration(Moderation):
+    name: str = "openai_moderation"
+
+    @classmethod
+    def validate_config(cls, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        cls._validate_inputs_and_outputs_config(config, True)
+
+    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+        flagged = False
+        preset_response = ""
+
+        if self.config['inputs_config']['enabled']:
+            preset_response = self.config['inputs_config']['preset_response']
+
+            if query:
+                inputs['query__'] = query
+            flagged = self._is_violated(inputs)
+
+        return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
+
+    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
+        flagged = False
+        preset_response = ""
+
+        if self.config['outputs_config']['enabled']:
+            flagged = self._is_violated({'text': text})
+            preset_response = self.config['outputs_config']['preset_response']
+
+        return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
+
+    def _is_violated(self, inputs: dict):
+        text = '\n'.join(inputs.values())
+        openai_moderation = ModelFactory.get_moderation_model(self.tenant_id, "openai", "moderation")
+        is_not_invalid = openai_moderation.run(text)
+        return not is_not_invalid
api/core/orchestrator_rule_parser.py
@@ -11,7 +11,6 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
 from core.conversation_message_task import ConversationMessageTask
 from core.model_providers.error import ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
@@ -125,52 +124,6 @@ class OrchestratorRuleParser:

         return chain

-    def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
-            -> Optional[SensitiveWordAvoidanceChain]:
-        """
-        Convert app sensitive word avoidance config to chain
-
-        :param model_instance: model instance
-        :param callbacks: callbacks for the chain
-        :param kwargs:
-        :return:
-        """
-        sensitive_word_avoidance_rule = None
-
-        if self.app_model_config.sensitive_word_avoidance_dict:
-            sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
-            if sensitive_word_avoidance_config.get("enabled", False):
-                if sensitive_word_avoidance_config.get('type') == 'moderation':
-                    sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
-                        type=SensitiveWordAvoidanceRule.Type.MODERATION,
-                        canned_response=sensitive_word_avoidance_config.get("canned_response")
-                        if sensitive_word_avoidance_config.get("canned_response")
-                        else 'Your content violates our usage policy. Please revise and try again.',
-                    )
-                else:
-                    sensitive_words = sensitive_word_avoidance_config.get("words", "")
-                    if sensitive_words:
-                        sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
-                            type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
-                            canned_response=sensitive_word_avoidance_config.get("canned_response")
-                            if sensitive_word_avoidance_config.get("canned_response")
-                            else 'Your content violates our usage policy. Please revise and try again.',
-                            extra_params={
-                                'sensitive_words': sensitive_words.split(','),
-                            }
-                        )
-
-        if sensitive_word_avoidance_rule:
-            return SensitiveWordAvoidanceChain(
-                model_instance=model_instance,
-                sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
-                output_key="sensitive_word_avoidance_output",
-                callbacks=callbacks,
-                **kwargs
-            )
-
-        return None
-
     def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
         """
         Convert app agent tool configs to tools