Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Garfield Dai <dai.hai@foxmail.com> Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
|
||||
import copy
|
||||
|
||||
from core.model_providers.models.entity.model_params import ModelMode
|
||||
from core.prompt.prompt_transform import AppMode
|
||||
from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
|
||||
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT
|
||||
@@ -25,14 +24,14 @@ class AdvancedPromptTemplateService:
|
||||
context_prompt = copy.deepcopy(CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
if model_mode == ModelMode.COMPLETION.value:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
|
||||
elif model_mode == ModelMode.CHAT.value:
|
||||
elif model_mode == "chat":
|
||||
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
|
||||
elif app_mode == AppMode.COMPLETION.value:
|
||||
if model_mode == ModelMode.COMPLETION.value:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
|
||||
elif model_mode == ModelMode.CHAT.value:
|
||||
elif model_mode == "chat":
|
||||
return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
|
||||
|
||||
@classmethod
|
||||
@@ -54,12 +53,12 @@ class AdvancedPromptTemplateService:
|
||||
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
if model_mode == ModelMode.COMPLETION.value:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
elif model_mode == ModelMode.CHAT.value:
|
||||
elif model_mode == "chat":
|
||||
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
elif app_mode == AppMode.COMPLETION.value:
|
||||
if model_mode == ModelMode.COMPLETION.value:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
elif model_mode == ModelMode.CHAT.value:
|
||||
elif model_mode == "chat":
|
||||
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
@@ -2,11 +2,12 @@ import re
|
||||
import uuid
|
||||
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from core.prompt.prompt_transform import AppMode
|
||||
from core.agent.agent_executor import PlanningStrategy
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelMode
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
@@ -34,26 +35,6 @@ class AppModelConfigService:
|
||||
if not isinstance(cp, dict):
|
||||
raise ValueError("model.completion_params must be of object type")
|
||||
|
||||
# max_tokens
|
||||
if 'max_tokens' not in cp:
|
||||
cp["max_tokens"] = 512
|
||||
|
||||
# temperature
|
||||
if 'temperature' not in cp:
|
||||
cp["temperature"] = 1
|
||||
|
||||
# top_p
|
||||
if 'top_p' not in cp:
|
||||
cp["top_p"] = 1
|
||||
|
||||
# presence_penalty
|
||||
if 'presence_penalty' not in cp:
|
||||
cp["presence_penalty"] = 0
|
||||
|
||||
# presence_penalty
|
||||
if 'frequency_penalty' not in cp:
|
||||
cp["frequency_penalty"] = 0
|
||||
|
||||
# stop
|
||||
if 'stop' not in cp:
|
||||
cp["stop"] = []
|
||||
@@ -63,20 +44,10 @@ class AppModelConfigService:
|
||||
if len(cp["stop"]) > 4:
|
||||
raise ValueError("stop sequences must be less than 4")
|
||||
|
||||
# Filter out extra parameters
|
||||
filtered_cp = {
|
||||
"max_tokens": cp["max_tokens"],
|
||||
"temperature": cp["temperature"],
|
||||
"top_p": cp["top_p"],
|
||||
"presence_penalty": cp["presence_penalty"],
|
||||
"frequency_penalty": cp["frequency_penalty"],
|
||||
"stop": cp["stop"]
|
||||
}
|
||||
|
||||
return filtered_cp
|
||||
return cp
|
||||
|
||||
@classmethod
|
||||
def validate_configuration(cls, tenant_id: str, account: Account, config: dict, mode: str) -> dict:
|
||||
def validate_configuration(cls, tenant_id: str, account: Account, config: dict, app_mode: str) -> dict:
|
||||
# opening_statement
|
||||
if 'opening_statement' not in config or not config["opening_statement"]:
|
||||
config["opening_statement"] = ""
|
||||
@@ -140,21 +111,6 @@ class AppModelConfigService:
|
||||
if not isinstance(config["retriever_resource"]["enabled"], bool):
|
||||
raise ValueError("enabled in retriever_resource must be of boolean type")
|
||||
|
||||
# annotation reply
|
||||
if 'annotation_reply' not in config or not config["annotation_reply"]:
|
||||
config["annotation_reply"] = {
|
||||
"enabled": False
|
||||
}
|
||||
|
||||
if not isinstance(config["annotation_reply"], dict):
|
||||
raise ValueError("annotation_reply must be of dict type")
|
||||
|
||||
if "enabled" not in config["annotation_reply"] or not config["annotation_reply"]["enabled"]:
|
||||
config["annotation_reply"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["annotation_reply"]["enabled"], bool):
|
||||
raise ValueError("enabled in annotation_reply must be of boolean type")
|
||||
|
||||
# more_like_this
|
||||
if 'more_like_this' not in config or not config["more_like_this"]:
|
||||
config["more_like_this"] = {
|
||||
@@ -178,7 +134,8 @@ class AppModelConfigService:
|
||||
raise ValueError("model must be of object type")
|
||||
|
||||
# model.provider
|
||||
model_provider_names = ModelProviderFactory.get_provider_names()
|
||||
provider_entities = model_provider_factory.get_providers()
|
||||
model_provider_names = [provider.provider for provider in provider_entities]
|
||||
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
@@ -186,18 +143,29 @@ class AppModelConfigService:
|
||||
if 'name' not in config["model"]:
|
||||
raise ValueError("model.name is required")
|
||||
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, config["model"]["provider"])
|
||||
if not model_provider:
|
||||
provider_manager = ProviderManager()
|
||||
models = provider_manager.get_configurations(tenant_id).get_models(
|
||||
provider=config["model"]["provider"],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
if not models:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
model_list = model_provider.get_supported_model_list(ModelType.TEXT_GENERATION)
|
||||
model_ids = [m['id'] for m in model_list]
|
||||
model_ids = [m.model for m in models]
|
||||
if config["model"]["name"] not in model_ids:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
model_mode = None
|
||||
for model in models:
|
||||
if model.model == config["model"]["name"]:
|
||||
model_mode = model.model_properties.get(ModelPropertyKey.MODE)
|
||||
break
|
||||
|
||||
# model.mode
|
||||
if 'mode' not in config['model'] or not config['model']["mode"]:
|
||||
config['model']["mode"] = ""
|
||||
if model_mode:
|
||||
config['model']["mode"] = model_mode
|
||||
else:
|
||||
config['model']["mode"] = "completion"
|
||||
|
||||
# model.completion_params
|
||||
if 'completion_params' not in config["model"]:
|
||||
@@ -319,10 +287,10 @@ class AppModelConfigService:
|
||||
raise ValueError("Dataset ID does not exist, please check your permission.")
|
||||
|
||||
# dataset_query_variable
|
||||
cls.is_dataset_query_variable_valid(config, mode)
|
||||
cls.is_dataset_query_variable_valid(config, app_mode)
|
||||
|
||||
# advanced prompt validation
|
||||
cls.is_advanced_prompt_valid(config, mode)
|
||||
cls.is_advanced_prompt_valid(config, app_mode)
|
||||
|
||||
# external data tools validation
|
||||
cls.is_external_data_tools_valid(tenant_id, config)
|
||||
@@ -340,7 +308,6 @@ class AppModelConfigService:
|
||||
"suggested_questions_after_answer": config["suggested_questions_after_answer"],
|
||||
"speech_to_text": config["speech_to_text"],
|
||||
"retriever_resource": config["retriever_resource"],
|
||||
"annotation_reply": config["annotation_reply"],
|
||||
"more_like_this": config["more_like_this"],
|
||||
"sensitive_word_avoidance": config["sensitive_word_avoidance"],
|
||||
"external_data_tools": config["external_data_tools"],
|
||||
@@ -507,7 +474,7 @@ class AppModelConfigService:
|
||||
if config['model']["mode"] not in ['chat', 'completion']:
|
||||
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
|
||||
|
||||
if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value:
|
||||
if app_mode == AppMode.CHAT.value and config['model']["mode"] == "completion":
|
||||
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
|
||||
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
|
||||
|
||||
@@ -517,7 +484,7 @@ class AppModelConfigService:
|
||||
if not assistant_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
|
||||
|
||||
if config['model']["mode"] == ModelMode.CHAT.value:
|
||||
if config['model']["mode"] == "chat":
|
||||
prompt_list = config['chat_prompt_config']['prompt']
|
||||
|
||||
if len(prompt_list) > 10:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import io
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
|
||||
FILE_SIZE = 15
|
||||
@@ -25,11 +27,13 @@ class AudioService:
|
||||
message = f"Audio size larger than {FILE_SIZE} mb"
|
||||
raise AudioTooLargeServiceError(message)
|
||||
|
||||
model = ModelFactory.get_speech2text_model(
|
||||
tenant_id=tenant_id
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.SPEECH2TEXT
|
||||
)
|
||||
|
||||
buffer = io.BytesIO(file_content)
|
||||
buffer.name = 'temp.mp3'
|
||||
|
||||
return model.run(buffer)
|
||||
return {"text": model_instance.invoke_speech2text(buffer)}
|
||||
|
||||
@@ -1,29 +1,16 @@
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Generator, Union, Any, Optional, List
|
||||
from typing import Generator, Union, Any
|
||||
|
||||
from flask import current_app, Flask
|
||||
from redis.client import PubSub
|
||||
from sqlalchemy import and_
|
||||
|
||||
from core.completion import Completion
|
||||
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
|
||||
ConversationTaskInterruptException
|
||||
from core.application_manager import ApplicationManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||
LLMRateLimitError, \
|
||||
LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_providers.models.entity.message import PromptMessageFile
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.completion import CompletionStoppedError
|
||||
from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
@@ -32,7 +19,7 @@ class CompletionService:
|
||||
|
||||
@classmethod
|
||||
def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
|
||||
from_source: str, streaming: bool = True,
|
||||
invoke_from: InvokeFrom, streaming: bool = True,
|
||||
is_model_config_override: bool = False) -> Union[dict, Generator]:
|
||||
# is streaming mode
|
||||
inputs = args['inputs']
|
||||
@@ -56,7 +43,7 @@ class CompletionService:
|
||||
Conversation.status == 'normal'
|
||||
]
|
||||
|
||||
if from_source == 'console':
|
||||
if isinstance(user, Account):
|
||||
conversation_filter.append(Conversation.from_account_id == user.id)
|
||||
else:
|
||||
conversation_filter.append(Conversation.from_end_user_id == user.id if user else None)
|
||||
@@ -124,7 +111,7 @@ class CompletionService:
|
||||
tenant_id=app_model.tenant_id,
|
||||
account=user,
|
||||
config=args['model_config'],
|
||||
mode=app_model.mode
|
||||
app_mode=app_model.mode
|
||||
)
|
||||
|
||||
app_model_config = AppModelConfig(
|
||||
@@ -145,134 +132,29 @@ class CompletionService:
|
||||
user
|
||||
)
|
||||
|
||||
generate_task_id = str(uuid.uuid4())
|
||||
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))
|
||||
|
||||
user = cls.get_real_user_instead_of_proxy_obj(user)
|
||||
|
||||
generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'generate_task_id': generate_task_id,
|
||||
'detached_app_model': app_model,
|
||||
'app_model_config': app_model_config.copy(),
|
||||
'query': query,
|
||||
'inputs': inputs,
|
||||
'files': file_objs,
|
||||
'detached_user': user,
|
||||
'detached_conversation': conversation,
|
||||
'streaming': streaming,
|
||||
'is_model_config_override': is_model_config_override,
|
||||
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
|
||||
'auto_generate_name': auto_generate_name,
|
||||
'from_source': from_source
|
||||
})
|
||||
|
||||
generate_worker_thread.start()
|
||||
|
||||
# wait for 10 minutes to close the thread
|
||||
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
|
||||
generate_task_id)
|
||||
|
||||
return cls.compact_response(pubsub, streaming)
|
||||
|
||||
@classmethod
|
||||
def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]):
|
||||
if isinstance(user, Account):
|
||||
user = db.session.query(Account).filter(Account.id == user.id).first()
|
||||
elif isinstance(user, EndUser):
|
||||
user = db.session.query(EndUser).filter(EndUser.id == user.id).first()
|
||||
else:
|
||||
raise Exception("Unknown user type")
|
||||
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
query: str, inputs: dict, files: List[PromptMessageFile],
|
||||
detached_user: Union[Account, EndUser],
|
||||
detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
|
||||
retriever_from: str = 'dev', auto_generate_name: bool = True, from_source: str = 'console'):
|
||||
with flask_app.app_context():
|
||||
# fixed the state of the model object when it detached from the original session
|
||||
user = db.session.merge(detached_user)
|
||||
app_model = db.session.merge(detached_app_model)
|
||||
|
||||
if detached_conversation:
|
||||
conversation = db.session.merge(detached_conversation)
|
||||
else:
|
||||
conversation = None
|
||||
|
||||
try:
|
||||
# run
|
||||
Completion.generate(
|
||||
task_id=generate_task_id,
|
||||
app=app_model,
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
user=user,
|
||||
files=files,
|
||||
conversation=conversation,
|
||||
streaming=streaming,
|
||||
is_override=is_model_config_override,
|
||||
retriever_from=retriever_from,
|
||||
auto_generate_name=auto_generate_name,
|
||||
from_source=from_source
|
||||
)
|
||||
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
|
||||
pass
|
||||
except (ValueError, LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
|
||||
ModelCurrentlyNotSupportError) as e:
|
||||
PubHandler.pub_error(user, generate_task_id, e)
|
||||
except LLMAuthorizationError:
|
||||
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
|
||||
except Exception as e:
|
||||
logging.exception("Unknown Error in completion")
|
||||
PubHandler.pub_error(user, generate_task_id, e)
|
||||
finally:
|
||||
db.session.remove()
|
||||
|
||||
@classmethod
|
||||
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user,
|
||||
generate_task_id) -> threading.Thread:
|
||||
# wait for 10 minutes to close the thread
|
||||
timeout = 600
|
||||
|
||||
def close_pubsub():
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
user = db.session.merge(detached_user)
|
||||
|
||||
sleep_iterations = 0
|
||||
while sleep_iterations < timeout and worker_thread.is_alive():
|
||||
if sleep_iterations > 0 and sleep_iterations % 10 == 0:
|
||||
PubHandler.ping(user, generate_task_id)
|
||||
|
||||
time.sleep(1)
|
||||
sleep_iterations += 1
|
||||
|
||||
if worker_thread.is_alive():
|
||||
PubHandler.stop(user, generate_task_id)
|
||||
try:
|
||||
pubsub.close()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
db.session.remove()
|
||||
|
||||
countdown_thread = threading.Thread(target=close_pubsub)
|
||||
countdown_thread.start()
|
||||
|
||||
return countdown_thread
|
||||
application_manager = ApplicationManager()
|
||||
return application_manager.generate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=app_model_config.to_dict(),
|
||||
app_model_config_override=is_model_config_override,
|
||||
user=user,
|
||||
invoke_from=invoke_from,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=file_objs,
|
||||
conversation=conversation,
|
||||
stream=streaming,
|
||||
extras={
|
||||
"auto_generate_conversation_name": auto_generate_name
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
|
||||
message_id: str, streaming: bool = True,
|
||||
retriever_from: str = 'dev') -> Union[dict, Generator]:
|
||||
message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
|
||||
-> Union[dict, Generator]:
|
||||
if not user:
|
||||
raise ValueError('user cannot be None')
|
||||
|
||||
@@ -306,36 +188,24 @@ class CompletionService:
|
||||
message.files, app_model_config
|
||||
)
|
||||
|
||||
generate_task_id = str(uuid.uuid4())
|
||||
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))
|
||||
|
||||
user = cls.get_real_user_instead_of_proxy_obj(user)
|
||||
|
||||
generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'generate_task_id': generate_task_id,
|
||||
'detached_app_model': app_model,
|
||||
'app_model_config': app_model_config.copy(),
|
||||
'query': message.query,
|
||||
'inputs': message.inputs,
|
||||
'files': file_objs,
|
||||
'detached_user': user,
|
||||
'detached_conversation': None,
|
||||
'streaming': streaming,
|
||||
'is_model_config_override': True,
|
||||
'retriever_from': retriever_from,
|
||||
'auto_generate_name': False
|
||||
})
|
||||
|
||||
generate_worker_thread.start()
|
||||
|
||||
# wait for 10 minutes to close the thread
|
||||
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
|
||||
generate_task_id)
|
||||
|
||||
return cls.compact_response(pubsub, streaming)
|
||||
application_manager = ApplicationManager()
|
||||
return application_manager.generate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=app_model_config.to_dict(),
|
||||
app_model_config_override=True,
|
||||
user=user,
|
||||
invoke_from=invoke_from,
|
||||
inputs=message.inputs,
|
||||
query=message.query,
|
||||
files=file_objs,
|
||||
conversation=None,
|
||||
stream=streaming,
|
||||
extras={
|
||||
"auto_generate_conversation_name": False
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
|
||||
@@ -375,247 +245,3 @@ class CompletionService:
|
||||
|
||||
return filtered_inputs
|
||||
|
||||
@classmethod
|
||||
def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict, Generator]:
|
||||
generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')
|
||||
if not streaming:
|
||||
try:
|
||||
message_result = {}
|
||||
for message in pubsub.listen():
|
||||
if message["type"] == "message":
|
||||
result = message["data"].decode('utf-8')
|
||||
result = json.loads(result)
|
||||
if result.get('error'):
|
||||
cls.handle_error(result)
|
||||
if result['event'] == 'annotation' and 'data' in result:
|
||||
message_result['annotation'] = result.get('data')
|
||||
return cls.get_blocking_annotation_message_response_data(message_result)
|
||||
if result['event'] == 'message' and 'data' in result:
|
||||
message_result['message'] = result.get('data')
|
||||
if result['event'] == 'message_end' and 'data' in result:
|
||||
message_result['message_end'] = result.get('data')
|
||||
return cls.get_blocking_message_response_data(message_result)
|
||||
except ValueError as e:
|
||||
if e.args[0] != "I/O operation on closed file.": # ignore this error
|
||||
raise CompletionStoppedError()
|
||||
else:
|
||||
logging.exception(e)
|
||||
raise
|
||||
finally:
|
||||
db.session.remove()
|
||||
|
||||
try:
|
||||
pubsub.unsubscribe(generate_channel)
|
||||
except ConnectionError:
|
||||
pass
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
try:
|
||||
for message in pubsub.listen():
|
||||
if message["type"] == "message":
|
||||
result = message["data"].decode('utf-8')
|
||||
result = json.loads(result)
|
||||
if result.get('error'):
|
||||
cls.handle_error(result)
|
||||
|
||||
event = result.get('event')
|
||||
if event == "end":
|
||||
logging.debug("{} finished".format(generate_channel))
|
||||
break
|
||||
if event == 'message':
|
||||
yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
|
||||
elif event == 'message_replace':
|
||||
yield "data: " + json.dumps(
|
||||
cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
|
||||
elif event == 'chain':
|
||||
yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
|
||||
elif event == 'agent_thought':
|
||||
yield "data: " + json.dumps(
|
||||
cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
|
||||
elif event == 'annotation':
|
||||
yield "data: " + json.dumps(
|
||||
cls.get_annotation_response_data(result.get('data'))) + "\n\n"
|
||||
elif event == 'message_end':
|
||||
yield "data: " + json.dumps(
|
||||
cls.get_message_end_data(result.get('data'))) + "\n\n"
|
||||
elif event == 'ping':
|
||||
yield "event: ping\n\n"
|
||||
else:
|
||||
yield "data: " + json.dumps(result) + "\n\n"
|
||||
except ValueError as e:
|
||||
if e.args[0] != "I/O operation on closed file.": # ignore this error
|
||||
logging.exception(e)
|
||||
raise
|
||||
finally:
|
||||
db.session.remove()
|
||||
|
||||
try:
|
||||
pubsub.unsubscribe(generate_channel)
|
||||
except ConnectionError:
|
||||
pass
|
||||
|
||||
return generate()
|
||||
|
||||
@classmethod
|
||||
def get_message_response_data(cls, data: dict):
|
||||
response_data = {
|
||||
'event': 'message',
|
||||
'task_id': data.get('task_id'),
|
||||
'id': data.get('message_id'),
|
||||
'answer': data.get('text'),
|
||||
'created_at': int(time.time())
|
||||
}
|
||||
|
||||
if data.get('mode') == 'chat':
|
||||
response_data['conversation_id'] = data.get('conversation_id')
|
||||
|
||||
return response_data
|
||||
|
||||
@classmethod
|
||||
def get_message_replace_response_data(cls, data: dict):
|
||||
response_data = {
|
||||
'event': 'message_replace',
|
||||
'task_id': data.get('task_id'),
|
||||
'id': data.get('message_id'),
|
||||
'answer': data.get('text'),
|
||||
'created_at': int(time.time())
|
||||
}
|
||||
|
||||
if data.get('mode') == 'chat':
|
||||
response_data['conversation_id'] = data.get('conversation_id')
|
||||
|
||||
return response_data
|
||||
|
||||
@classmethod
|
||||
def get_blocking_message_response_data(cls, data: dict):
|
||||
message = data.get('message')
|
||||
response_data = {
|
||||
'event': 'message',
|
||||
'task_id': message.get('task_id'),
|
||||
'id': message.get('message_id'),
|
||||
'answer': message.get('text'),
|
||||
'metadata': {},
|
||||
'created_at': int(time.time())
|
||||
}
|
||||
|
||||
if message.get('mode') == 'chat':
|
||||
response_data['conversation_id'] = message.get('conversation_id')
|
||||
if 'message_end' in data:
|
||||
message_end = data.get('message_end')
|
||||
if 'retriever_resources' in message_end:
|
||||
response_data['metadata']['retriever_resources'] = message_end.get('retriever_resources')
|
||||
|
||||
return response_data
|
||||
|
||||
@classmethod
|
||||
def get_blocking_annotation_message_response_data(cls, data: dict):
|
||||
message = data.get('annotation')
|
||||
response_data = {
|
||||
'event': 'annotation',
|
||||
'task_id': message.get('task_id'),
|
||||
'id': message.get('message_id'),
|
||||
'answer': message.get('text'),
|
||||
'metadata': {},
|
||||
'created_at': int(time.time()),
|
||||
'annotation_id': message.get('annotation_id'),
|
||||
'annotation_author_name': message.get('annotation_author_name')
|
||||
}
|
||||
|
||||
if message.get('mode') == 'chat':
|
||||
response_data['conversation_id'] = message.get('conversation_id')
|
||||
|
||||
return response_data
|
||||
|
||||
@classmethod
|
||||
def get_message_end_data(cls, data: dict):
|
||||
response_data = {
|
||||
'event': 'message_end',
|
||||
'task_id': data.get('task_id'),
|
||||
'id': data.get('message_id')
|
||||
}
|
||||
if 'retriever_resources' in data:
|
||||
response_data['retriever_resources'] = data.get('retriever_resources')
|
||||
if data.get('mode') == 'chat':
|
||||
response_data['conversation_id'] = data.get('conversation_id')
|
||||
|
||||
return response_data
|
||||
|
||||
@classmethod
|
||||
def get_chain_response_data(cls, data: dict):
|
||||
response_data = {
|
||||
'event': 'chain',
|
||||
'id': data.get('chain_id'),
|
||||
'task_id': data.get('task_id'),
|
||||
'message_id': data.get('message_id'),
|
||||
'type': data.get('type'),
|
||||
'input': data.get('input'),
|
||||
'output': data.get('output'),
|
||||
'created_at': int(time.time())
|
||||
}
|
||||
|
||||
if data.get('mode') == 'chat':
|
||||
response_data['conversation_id'] = data.get('conversation_id')
|
||||
|
||||
return response_data
|
||||
|
||||
@classmethod
|
||||
def get_agent_thought_response_data(cls, data: dict):
|
||||
response_data = {
|
||||
'event': 'agent_thought',
|
||||
'id': data.get('id'),
|
||||
'chain_id': data.get('chain_id'),
|
||||
'task_id': data.get('task_id'),
|
||||
'message_id': data.get('message_id'),
|
||||
'position': data.get('position'),
|
||||
'thought': data.get('thought'),
|
||||
'tool': data.get('tool'),
|
||||
'tool_input': data.get('tool_input'),
|
||||
'created_at': int(time.time())
|
||||
}
|
||||
|
||||
if data.get('mode') == 'chat':
|
||||
response_data['conversation_id'] = data.get('conversation_id')
|
||||
|
||||
return response_data
|
||||
|
||||
@classmethod
|
||||
def get_annotation_response_data(cls, data: dict):
|
||||
response_data = {
|
||||
'event': 'annotation',
|
||||
'task_id': data.get('task_id'),
|
||||
'id': data.get('message_id'),
|
||||
'answer': data.get('text'),
|
||||
'created_at': int(time.time()),
|
||||
'annotation_id': data.get('annotation_id'),
|
||||
'annotation_author_name': data.get('annotation_author_name'),
|
||||
}
|
||||
|
||||
if data.get('mode') == 'chat':
|
||||
response_data['conversation_id'] = data.get('conversation_id')
|
||||
|
||||
return response_data
|
||||
|
||||
@classmethod
|
||||
def handle_error(cls, result: dict):
|
||||
logging.debug("error: %s", result)
|
||||
error = result.get('error')
|
||||
description = result.get('description')
|
||||
|
||||
# handle errors
|
||||
llm_errors = {
|
||||
'ValueError': LLMBadRequestError,
|
||||
'LLMBadRequestError': LLMBadRequestError,
|
||||
'LLMAPIConnectionError': LLMAPIConnectionError,
|
||||
'LLMAPIUnavailableError': LLMAPIUnavailableError,
|
||||
'LLMRateLimitError': LLMRateLimitError,
|
||||
'ProviderTokenNotInitError': ProviderTokenNotInitError,
|
||||
'QuotaExceededError': QuotaExceededError,
|
||||
'ModelCurrentlyNotSupportError': ModelCurrentlyNotSupportError
|
||||
}
|
||||
|
||||
if error in llm_errors:
|
||||
raise llm_errors[error](description)
|
||||
elif error == 'LLMAuthorizationError':
|
||||
raise LLMAuthorizationError('Incorrect API key provided')
|
||||
else:
|
||||
raise Exception(description)
|
||||
|
||||
@@ -4,14 +4,16 @@ import datetime
|
||||
import time
|
||||
import random
|
||||
import uuid
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, cast
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy import func
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from extensions.ext_redis import redis_client
|
||||
from flask_login import current_user
|
||||
|
||||
@@ -92,16 +94,18 @@ class DatasetService:
|
||||
f'Dataset with name {name} already exists.')
|
||||
embedding_model = None
|
||||
if indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
dataset = Dataset(name=name, indexing_technique=indexing_technique)
|
||||
# dataset = Dataset(name=name, provider=provider, config=config)
|
||||
dataset.created_by = account.id
|
||||
dataset.updated_by = account.id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.embedding_model_provider = embedding_model.model_provider.provider_name if embedding_model else None
|
||||
dataset.embedding_model = embedding_model.name if embedding_model else None
|
||||
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
|
||||
dataset.embedding_model = embedding_model.model if embedding_model else None
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
return dataset
|
||||
@@ -120,10 +124,12 @@ class DatasetService:
|
||||
def check_dataset_model_setting(dataset):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
@@ -150,14 +156,16 @@ class DatasetService:
|
||||
action = 'add'
|
||||
# get embedding model setting
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_default_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
filtered_data['embedding_model'] = embedding_model.name
|
||||
filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
|
||||
filtered_data['embedding_model'] = embedding_model.model
|
||||
filtered_data['embedding_model_provider'] = embedding_model.provider
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.model_provider.provider_name,
|
||||
embedding_model.name
|
||||
embedding_model.provider,
|
||||
embedding_model.model
|
||||
)
|
||||
filtered_data['collection_binding_id'] = dataset_collection_binding.id
|
||||
except LLMBadRequestError:
|
||||
@@ -458,14 +466,16 @@ class DocumentService:
|
||||
|
||||
dataset.indexing_technique = document_data["indexing_technique"]
|
||||
if document_data["indexing_technique"] == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_default_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
dataset.embedding_model = embedding_model.name
|
||||
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
|
||||
dataset.embedding_model = embedding_model.model
|
||||
dataset.embedding_model_provider = embedding_model.provider
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.model_provider.provider_name,
|
||||
embedding_model.name
|
||||
embedding_model.provider,
|
||||
embedding_model.model
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
if not dataset.retrieval_model:
|
||||
@@ -737,12 +747,14 @@ class DocumentService:
|
||||
dataset_collection_binding_id = None
|
||||
retrieval_model = None
|
||||
if document_data['indexing_technique'] == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_default_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.model_provider.provider_name,
|
||||
embedding_model.name
|
||||
embedding_model.provider,
|
||||
embedding_model.model
|
||||
)
|
||||
dataset_collection_binding_id = dataset_collection_binding.id
|
||||
if 'retrieval_model' in document_data and document_data['retrieval_model']:
|
||||
@@ -766,8 +778,8 @@ class DocumentService:
|
||||
data_source_type=document_data["data_source"]["type"],
|
||||
indexing_technique=document_data["indexing_technique"],
|
||||
created_by=account.id,
|
||||
embedding_model=embedding_model.name if embedding_model else None,
|
||||
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
|
||||
embedding_model=embedding_model.model if embedding_model else None,
|
||||
embedding_model_provider=embedding_model.provider if embedding_model else None,
|
||||
collection_binding_id=dataset_collection_binding_id,
|
||||
retrieval_model=retrieval_model
|
||||
)
|
||||
@@ -989,13 +1001,20 @@ class SegmentService:
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
tokens = 0
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
|
||||
tokens = model_type_instance.get_num_tokens(
|
||||
model=embedding_model.model,
|
||||
credentials=embedding_model.credentials,
|
||||
texts=[content]
|
||||
)
|
||||
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||
DocumentSegment.document_id == document.id
|
||||
).scalar()
|
||||
@@ -1037,10 +1056,12 @@ class SegmentService:
|
||||
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
|
||||
embedding_model = None
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||
DocumentSegment.document_id == document.id
|
||||
@@ -1054,7 +1075,12 @@ class SegmentService:
|
||||
tokens = 0
|
||||
if dataset.indexing_technique == 'high_quality' and embedding_model:
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
|
||||
tokens = model_type_instance.get_num_tokens(
|
||||
model=embedding_model.model,
|
||||
credentials=embedding_model.credentials,
|
||||
texts=[content]
|
||||
)
|
||||
segment_document = DocumentSegment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
dataset_id=document.dataset_id,
|
||||
@@ -1121,14 +1147,21 @@ class SegmentService:
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
tokens = 0
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
|
||||
tokens = model_type_instance.get_num_tokens(
|
||||
model=embedding_model.model,
|
||||
credentials=embedding_model.credentials,
|
||||
texts=[content]
|
||||
)
|
||||
segment.content = content
|
||||
segment.index_node_hash = segment_hash
|
||||
segment.word_count = len(content)
|
||||
|
||||
0
api/services/entities/__init__.py
Normal file
0
api/services/entities/__init__.py
Normal file
152
api/services/entities/model_provider_entities.py
Normal file
152
api/services/entities/model_provider_entities.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity
|
||||
from core.entities.provider_entities import QuotaConfiguration
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType, ProviderModel
|
||||
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderCredentialSchema, \
|
||||
ModelCredentialSchema, ProviderHelpEntity, SimpleProviderEntity
|
||||
from models.provider import ProviderType, ProviderQuotaType
|
||||
|
||||
|
||||
class CustomConfigurationStatus(Enum):
    """
    Status values reported for a custom configuration.
    """
    # Serialized as 'active'.
    ACTIVE = 'active'
    # Serialized as 'no-configure'.
    NO_CONFIGURE = 'no-configure'
|
||||
|
||||
|
||||
class CustomConfigurationResponse(BaseModel):
    """
    Response model for a provider's custom configuration.
    """
    # Required; one of the CustomConfigurationStatus members.
    status: CustomConfigurationStatus
|
||||
|
||||
|
||||
class SystemConfigurationResponse(BaseModel):
    """
    Response model for a provider's system configuration.
    """
    # Required flag.
    enabled: bool
    # Optional; defaults to None.
    current_quota_type: Optional[ProviderQuotaType] = None
    # Defaults to an empty list.
    quota_configurations: list[QuotaConfiguration] = []
|
||||
|
||||
|
||||
class ProviderResponse(BaseModel):
    """
    Response model describing one model provider.

    After construction, any icon that was supplied is replaced by an
    I18nObject of console API URLs built from ``CONSOLE_API_URL`` and the
    provider name.
    """
    provider: str
    label: I18nObject
    description: Optional[I18nObject] = None
    icon_small: Optional[I18nObject] = None
    icon_large: Optional[I18nObject] = None
    background: Optional[str] = None
    help: Optional[ProviderHelpEntity] = None
    supported_model_types: list[ModelType]
    configurate_methods: list[ConfigurateMethod]
    provider_credential_schema: Optional[ProviderCredentialSchema] = None
    model_credential_schema: Optional[ModelCredentialSchema] = None
    preferred_provider_type: ProviderType
    custom_configuration: CustomConfigurationResponse
    system_configuration: SystemConfigurationResponse

    def __init__(self, **data) -> None:
        """
        Validate the fields, then point the present icons at console API URLs.

        :param data: field values forwarded to the pydantic constructor
        """
        super().__init__(**data)

        console_api_url = current_app.config.get("CONSOLE_API_URL")
        url_prefix = console_api_url + f"/console/api/workspaces/current/model-providers/{self.provider}"

        # Same rewrite for both icon sizes; icons left as None stay None.
        for icon_field in ("icon_small", "icon_large"):
            if getattr(self, icon_field) is None:
                continue

            setattr(self, icon_field, I18nObject(
                en_US=f"{url_prefix}/{icon_field}/en_US",
                zh_Hans=f"{url_prefix}/{icon_field}/zh_Hans"
            ))
|
||||
|
||||
|
||||
class ModelResponse(ProviderModel):
    """
    Response model for a model: the ProviderModel fields plus a status.
    """
    # Required ModelStatus value.
    status: ModelStatus
|
||||
|
||||
|
||||
class ProviderWithModelsResponse(BaseModel):
    """
    Response model for a provider together with its models.

    After construction, any icon that was supplied is replaced by an
    I18nObject of console API URLs built from ``CONSOLE_API_URL`` and the
    provider name.
    """
    provider: str
    label: I18nObject
    icon_small: Optional[I18nObject] = None
    icon_large: Optional[I18nObject] = None
    status: CustomConfigurationStatus
    models: list[ModelResponse]

    def __init__(self, **data) -> None:
        """
        Validate the fields, then point the present icons at console API URLs.

        :param data: field values forwarded to the pydantic constructor
        """
        super().__init__(**data)

        console_api_url = current_app.config.get("CONSOLE_API_URL")
        url_prefix = console_api_url + f"/console/api/workspaces/current/model-providers/{self.provider}"

        # Same rewrite for both icon sizes; icons left as None stay None.
        for icon_field in ("icon_small", "icon_large"):
            if getattr(self, icon_field) is None:
                continue

            setattr(self, icon_field, I18nObject(
                en_US=f"{url_prefix}/{icon_field}/en_US",
                zh_Hans=f"{url_prefix}/{icon_field}/zh_Hans"
            ))
|
||||
|
||||
|
||||
class SimpleProviderEntityResponse(SimpleProviderEntity):
    """
    SimpleProviderEntity variant whose icons are console API URLs.
    """

    def __init__(self, **data) -> None:
        """
        Validate the fields, then point the present icons at console API URLs.

        :param data: field values forwarded to the parent constructor
        """
        super().__init__(**data)

        console_api_url = current_app.config.get("CONSOLE_API_URL")
        url_prefix = console_api_url + f"/console/api/workspaces/current/model-providers/{self.provider}"

        # Same rewrite for both icon sizes; icons left as None stay None.
        for icon_field in ("icon_small", "icon_large"):
            if getattr(self, icon_field) is None:
                continue

            setattr(self, icon_field, I18nObject(
                en_US=f"{url_prefix}/{icon_field}/en_US",
                zh_Hans=f"{url_prefix}/{icon_field}/zh_Hans"
            ))
|
||||
|
||||
|
||||
class DefaultModelResponse(BaseModel):
    """
    Response model for a default model entry.
    """
    # Model name.
    model: str
    # Model type the default applies to.
    model_type: ModelType
    # Provider entity in its response form.
    provider: SimpleProviderEntityResponse
|
||||
|
||||
|
||||
class ModelWithProviderEntityResponse(ModelWithProviderEntity):
    """
    ModelWithProviderEntity variant whose provider is a SimpleProviderEntityResponse.
    """
    provider: SimpleProviderEntityResponse

    def __init__(self, model: ModelWithProviderEntity) -> None:
        """
        Build the response from an existing entity.

        :param model: entity whose ``dict()`` output supplies the field values
        """
        field_values = model.dict()
        super().__init__(**field_values)
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
@@ -11,7 +10,9 @@ from langchain.schema import Document
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rerank.rerank import RerankRunner
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DocumentSegment, DatasetQuery
|
||||
@@ -47,11 +48,14 @@ class HitTestingService:
|
||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
# get embedding model
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
all_documents = []
|
||||
@@ -93,14 +97,22 @@ class HitTestingService:
|
||||
thread.join()
|
||||
|
||||
if retrieval_model['search_method'] == 'hybrid_search':
|
||||
hybrid_rerank = ModelFactory.get_reranking_model(
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
|
||||
model_name=retrieval_model['reranking_model']['reranking_model_name']
|
||||
provider=retrieval_model['reranking_model']['reranking_provider_name'],
|
||||
model_type=ModelType.RERANK,
|
||||
model=retrieval_model['reranking_model']['reranking_model_name']
|
||||
)
|
||||
|
||||
rerank_runner = RerankRunner(rerank_model_instance)
|
||||
all_documents = rerank_runner.run(
|
||||
query=query,
|
||||
documents=all_documents,
|
||||
score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
|
||||
top_n=retrieval_model['top_k'],
|
||||
user=f"account-{account.id}"
|
||||
)
|
||||
all_documents = hybrid_rerank.rerank(query, all_documents,
|
||||
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
|
||||
retrieval_model['top_k'])
|
||||
|
||||
end = time.perf_counter()
|
||||
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import json
|
||||
from typing import Optional, Union, List
|
||||
|
||||
from core.completion import Completion
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
@@ -216,21 +218,27 @@ class MessageService:
|
||||
raise SuggestedQuestionsAfterAnswerDisabledError()
|
||||
|
||||
# get memory of conversation (read-only)
|
||||
memory = Completion.get_memory_from_conversation(
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_model_config=app_model_config,
|
||||
conversation=conversation,
|
||||
max_token_limit=3000,
|
||||
message_limit=3,
|
||||
return_messages=False,
|
||||
memory_key="histories"
|
||||
provider=app_model_config.model_dict['provider'],
|
||||
model_type=ModelType.LLM,
|
||||
model=app_model_config.model_dict['name']
|
||||
)
|
||||
|
||||
external_context = memory.load_memory_variables({})
|
||||
memory = TokenBufferMemory(
|
||||
conversation=conversation,
|
||||
model_instance=model_instance
|
||||
)
|
||||
|
||||
histories = memory.get_history_prompt_text(
|
||||
max_token_limit=3000,
|
||||
message_limit=3,
|
||||
)
|
||||
|
||||
questions = LLMGenerator.generate_suggested_questions_after_answer(
|
||||
tenant_id=app_model.tenant_id,
|
||||
**external_context
|
||||
histories=histories
|
||||
)
|
||||
|
||||
return questions
|
||||
|
||||
530
api/services/model_provider_service.py
Normal file
530
api/services/model_provider_service.py
Normal file
@@ -0,0 +1,530 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Optional, cast, Tuple
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
|
||||
from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, DefaultModelEntity
|
||||
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.provider import ProviderType
|
||||
from services.entities.model_provider_entities import ProviderResponse, CustomConfigurationResponse, \
|
||||
SystemConfigurationResponse, CustomConfigurationStatus, ProviderWithModelsResponse, ModelResponse, \
|
||||
DefaultModelResponse, ModelWithProviderEntityResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelProviderService:
|
||||
"""
|
||||
Model Provider Service
|
||||
"""
|
||||
def __init__(self) -> None:
    """
    Create the service with its own ProviderManager instance.
    """
    manager = ProviderManager()
    self.provider_manager = manager
|
||||
|
||||
def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]:
|
||||
"""
|
||||
get provider list.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
provider_responses = []
|
||||
for provider_configuration in provider_configurations.values():
|
||||
if model_type:
|
||||
model_type_entity = ModelType.value_of(model_type)
|
||||
if model_type_entity not in provider_configuration.provider.supported_model_types:
|
||||
continue
|
||||
|
||||
provider_response = ProviderResponse(
|
||||
**provider_configuration.provider.dict(),
|
||||
preferred_provider_type=provider_configuration.preferred_provider_type,
|
||||
custom_configuration=CustomConfigurationResponse(
|
||||
status=CustomConfigurationStatus.ACTIVE
|
||||
if provider_configuration.is_custom_configuration_available()
|
||||
else CustomConfigurationStatus.NO_CONFIGURE
|
||||
),
|
||||
system_configuration=SystemConfigurationResponse(
|
||||
**provider_configuration.system_configuration.dict()
|
||||
)
|
||||
)
|
||||
|
||||
provider_responses.append(provider_response)
|
||||
|
||||
return provider_responses
|
||||
|
||||
def get_models_by_provider(self, tenant_id: str, provider: str) -> list[ModelWithProviderEntityResponse]:
|
||||
"""
|
||||
get provider models.
|
||||
For the model provider page,
|
||||
only supports passing in a single provider to query the list of supported models.
|
||||
|
||||
:param tenant_id:
|
||||
:param provider:
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider available models
|
||||
return [ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(
|
||||
provider=provider
|
||||
)]
|
||||
|
||||
def get_provider_credentials(self, tenant_id: str, provider: str) -> dict:
|
||||
"""
|
||||
get provider credentials.
|
||||
|
||||
:param tenant_id:
|
||||
:param provider:
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Get provider custom credentials from workspace
|
||||
return provider_configuration.get_custom_credentials(obfuscated=True)
|
||||
|
||||
def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None:
|
||||
"""
|
||||
validate provider credentials.
|
||||
|
||||
:param tenant_id:
|
||||
:param provider:
|
||||
:param credentials:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
provider_configuration.custom_credentials_validate(credentials)
|
||||
|
||||
def save_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
|
||||
"""
|
||||
save custom provider config.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param credentials: provider credentials
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Add or update custom provider credentials.
|
||||
provider_configuration.add_or_update_custom_credentials(credentials)
|
||||
|
||||
def remove_provider_credentials(self, tenant_id: str, provider: str) -> None:
|
||||
"""
|
||||
remove custom provider config.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Remove custom provider credentials.
|
||||
provider_configuration.delete_custom_credentials()
|
||||
|
||||
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> dict:
|
||||
"""
|
||||
get model credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Get model custom credentials from ProviderModel if exists
|
||||
return provider_configuration.get_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model,
|
||||
obfuscated=True
|
||||
)
|
||||
|
||||
def model_credentials_validate(self, tenant_id: str, provider: str, model_type: str, model: str,
|
||||
credentials: dict) -> None:
|
||||
"""
|
||||
validate model credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Validate model credentials
|
||||
provider_configuration.custom_model_credentials_validate(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model,
|
||||
credentials=credentials
|
||||
)
|
||||
|
||||
def save_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str,
|
||||
credentials: dict) -> None:
|
||||
"""
|
||||
save model credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Add or update custom model credentials
|
||||
provider_configuration.add_or_update_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model,
|
||||
credentials=credentials
|
||||
)
|
||||
|
||||
def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
|
||||
"""
|
||||
remove model credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Remove custom model credentials
|
||||
provider_configuration.delete_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model
|
||||
)
|
||||
|
||||
def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
    """
    Get the models of a model type, grouped by provider.

    Deprecated models are left out, and a provider with no remaining models
    is omitted from the result.

    :param tenant_id: workspace id
    :param model_type: model type
    :return: one ProviderWithModelsResponse per provider that has models left
    """
    # Get all provider configurations of the current workspace
    workspace_configurations = self.provider_manager.get_configurations(tenant_id)

    # Get provider available models
    available_models = workspace_configurations.get_models(
        model_type=ModelType.value_of(model_type)
    )

    # Bucket the non-deprecated models per provider name
    grouped = {}
    for entity in available_models:
        bucket = grouped.setdefault(entity.provider.provider, [])
        if not entity.deprecated:
            bucket.append(entity)

    # convert to ProviderWithModelsResponse list
    providers_with_models: list[ProviderWithModelsResponse] = []
    for provider_name, provider_entities in grouped.items():
        if not provider_entities:
            continue

        # Provider-level fields are taken from the group's first model
        head = provider_entities[0]

        # The provider counts as active as soon as one of its models is active
        if any(entity.status == ModelStatus.ACTIVE for entity in provider_entities):
            provider_status = CustomConfigurationStatus.ACTIVE
        else:
            provider_status = CustomConfigurationStatus.NO_CONFIGURE

        model_responses = [
            ModelResponse(
                model=entity.model,
                label=entity.label,
                model_type=entity.model_type,
                features=entity.features,
                fetch_from=entity.fetch_from,
                model_properties=entity.model_properties,
                status=entity.status
            )
            for entity in provider_entities
        ]

        providers_with_models.append(
            ProviderWithModelsResponse(
                provider=provider_name,
                label=head.provider.label,
                icon_small=head.provider.icon_small,
                icon_large=head.provider.icon_large,
                status=provider_status,
                models=model_responses
            )
        )

    return providers_with_models
|
||||
|
||||
def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) -> list[ParameterRule]:
    """
    Return the parameter rules of an LLM model (only LLM is supported).

    :param tenant_id: workspace id
    :param provider: provider name
    :param model: model name
    :return: the model's parameter rules, or an empty list when no
        credentials are currently available for the model
    :raises ValueError: when the provider has no configuration in the workspace
    """
    configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
    if not configuration:
        raise ValueError(f"Provider {provider} does not exist.")

    # The LLM model type instance is what exposes get_parameter_rules.
    llm_instance = cast(
        LargeLanguageModel,
        configuration.get_model_type_instance(ModelType.LLM)
    )

    current_credentials = configuration.get_current_credentials(
        model_type=ModelType.LLM,
        model=model
    )
    if current_credentials:
        return llm_instance.get_parameter_rules(
            model=model,
            credentials=current_credentials
        )

    return []
|
||||
|
||||
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
    """
    Look up the workspace default model for a model type.

    :param tenant_id: workspace id
    :param model_type: model type name, converted via ModelType.value_of
    :return: the default model wrapped in a DefaultModelResponse, or None
        when the provider manager returns nothing
    """
    default_model = self.provider_manager.get_default_model(
        tenant_id=tenant_id,
        model_type=ModelType.value_of(model_type)
    )
    if not default_model:
        return None

    return DefaultModelResponse(**default_model.dict())
|
||||
|
||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
    """
    Set the workspace default model for a model type.

    :param tenant_id: workspace id
    :param model_type: model type name, converted via ModelType.value_of
    :param provider: provider name
    :param model: model name
    :return:
    """
    self.provider_manager.update_default_model_record(
        tenant_id=tenant_id,
        model_type=ModelType.value_of(model_type),
        provider=provider,
        model=model
    )
|
||||
|
||||
def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> Tuple[Optional[bytes], Optional[str]]:
    """
    Read a provider icon from the provider package's ``_assets`` directory.

    :param provider: provider name
    :param icon_type: ``icon_small`` selects the small icon (case-insensitive);
        any other value selects the large icon
    :param lang: ``zh_Hans`` selects the Chinese file (case-insensitive);
        any other value selects the ``en_US`` file
    :return: ``(bytes, mimetype)``, or ``(None, None)`` when the file is missing
    :raises ValueError: when the provider schema has no icon of the requested size
    """
    provider_instance = model_provider_factory.get_provider_instance(provider)
    schema = provider_instance.get_provider_schema()

    if icon_type.lower() == 'icon_small':
        icon = schema.icon_small
        if not icon:
            raise ValueError(f"Provider {provider} does not have small icon.")
    else:
        icon = schema.icon_large
        if not icon:
            raise ValueError(f"Provider {provider} does not have large icon.")

    file_name = icon.zh_Hans if lang.lower() == 'zh_hans' else icon.en_US

    # The assets directory sits next to the provider's module file, which is
    # located by turning the dotted module name into a path under the app root.
    module_path = provider_instance.__class__.__module__.replace('.', '/')
    provider_dir = os.path.dirname(os.path.join(current_app.root_path, module_path))
    icon_path = os.path.join(provider_dir, "_assets", file_name)

    if not os.path.exists(icon_path):
        return None, None

    guessed_type, _ = mimetypes.guess_type(icon_path)

    with open(icon_path, 'rb') as icon_file:
        content = icon_file.read()

    # Fall back to a generic binary type when the extension is unknown.
    return content, guessed_type or 'application/octet-stream'
|
||||
|
||||
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
    """
    Switch the preferred provider type of a provider in a workspace.

    :param tenant_id: workspace id
    :param provider: provider name
    :param preferred_provider_type: provider type name, converted via
        ProviderType.value_of
    :return:
    :raises ValueError: when the provider has no configuration in the workspace
    """
    configurations = self.provider_manager.get_configurations(tenant_id)

    # The type is converted before the provider lookup, as in the original flow.
    target_type = ProviderType.value_of(preferred_provider_type)

    configuration = configurations.get(provider)
    if not configuration:
        raise ValueError(f"Provider {provider} does not exist.")

    configuration.switch_preferred_provider_type(target_type)
|
||||
|
||||
def free_quota_submit(self, tenant_id: str, provider: str):
    """
    Submit a free-quota application for a provider to the quota-apply server.

    The server location and key come from the environment variables
    ``FREE_QUOTA_APPLY_BASE_URL`` and ``FREE_QUOTA_APPLY_API_KEY``.

    :param tenant_id: workspace id, sent as ``workspace_id``
    :param provider: provider name, sent as ``provider_name``
    :return: ``{'type', 'redirect_url'}`` for a redirect answer, otherwise
        ``{'type', 'result': 'success'}``
    :raises ValueError: when the base URL is not configured, the HTTP status
        is not OK, or the server answers with a non-success code
    """
    api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
    api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
    if not api_base_url:
        # Without this guard the concatenation below failed with an opaque
        # TypeError (None + str) whenever the variable was unset.
        raise ValueError("FREE_QUOTA_APPLY_BASE_URL is not configured.")
    api_url = api_base_url + '/api/v1/providers/apply'

    headers = {
        'Content-Type': 'application/json',
        'Authorization': f"Bearer {api_key}"
    }
    response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider})
    if not response.ok:
        logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
        raise ValueError(f"Error: {response.status_code} ")

    # Decode the body once and reuse it for every check below.
    rst = response.json()
    if rst["code"] != 'success':
        raise ValueError(
            f"error: {rst['message']}"
        )

    if rst['type'] == 'redirect':
        return {
            'type': rst['type'],
            'redirect_url': rst['redirect_url']
        }

    return {
        'type': rst['type'],
        'result': 'success'
    }
|
||||
|
||||
def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]):
    """
    Ask the quota-apply server whether a workspace qualifies for free quota.

    The server location and key come from the environment variables
    ``FREE_QUOTA_APPLY_BASE_URL`` and ``FREE_QUOTA_APPLY_API_KEY``.

    :param tenant_id: workspace id, sent as ``workspace_id``
    :param provider: provider name, sent as ``provider_name``
    :param token: optional token; only included in the request when truthy
    :return: dict with ``result``, ``provider_name`` and ``flag``; when not
        qualified it also carries the server's ``reason``
    :raises ValueError: when the base URL is not configured, the HTTP status
        is not OK, or the server answers with a non-success code
    """
    api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
    api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
    if not api_base_url:
        # Without this guard the concatenation below failed with an opaque
        # TypeError (None + str) whenever the variable was unset.
        raise ValueError("FREE_QUOTA_APPLY_BASE_URL is not configured.")
    api_url = api_base_url + '/api/v1/providers/qualification-verify'

    headers = {
        'Content-Type': 'application/json',
        'Authorization': f"Bearer {api_key}"
    }
    json_data = {'workspace_id': tenant_id, 'provider_name': provider}
    if token:
        json_data['token'] = token
    response = requests.post(api_url, headers=headers,
                             json=json_data)
    if not response.ok:
        logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
        raise ValueError(f"Error: {response.status_code} ")

    rst = response.json()
    if rst["code"] != 'success':
        raise ValueError(
            f"error: {rst['message']}"
        )

    data = rst['data']
    # Only the literal True counts as qualified; any other value is a refusal.
    if data['qualified'] is True:
        return {
            'result': 'success',
            'provider_name': provider,
            'flag': True
        }

    return {
        'result': 'success',
        'provider_name': provider,
        'flag': False,
        'reason': data['reason']
    }
|
||||
@@ -1,596 +0,0 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
|
||||
from models.provider import Provider, ProviderModel, TenantPreferredModelProvider, ProviderType, ProviderQuotaType, \
|
||||
TenantDefaultModel
|
||||
|
||||
|
||||
class ProviderService:
|
||||
|
||||
def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list:
    """
    Build the provider overview of a workspace.

    :param tenant_id: workspace id
    :param model_type: when given, only providers whose rule lists this
        model type under ``supported_model_types`` are included
    :return: mapping of provider name to its configuration dict (the value
        built here is a dict, although the annotation says ``list``)
    """
    system_type = ProviderType.SYSTEM.value
    custom_type = ProviderType.CUSTOM.value

    rules_by_name = ModelProviderFactory.get_provider_rules()
    all_names = list(rules_by_name.keys())

    # For system providers that offer a trial quota the preferred provider is
    # fetched and the result discarded.
    # NOTE(review): presumably called for a side effect inside the factory — confirm.
    for name, rule in rules_by_name.items():
        if system_type not in rule['support_provider_types']:
            continue
        if 'system_config' not in rule or not rule['system_config']:
            continue
        system_config = rule['system_config']
        if 'supported_quota_types' in system_config \
                and 'trial' in system_config['supported_quota_types']:
            ModelProviderFactory.get_preferred_model_provider(tenant_id, name)

    configurable_names = [
        name
        for name, rule in rules_by_name.items()
        if 'custom' in rule['support_provider_types']
        and rule['model_flexibility'] == 'configurable'
    ]

    # Valid provider records of the tenant, newest first, grouped by name.
    provider_records = (
        db.session.query(Provider)
        .filter(
            Provider.tenant_id == tenant_id,
            Provider.provider_name.in_(all_names),
            Provider.is_valid == True
        )
        .order_by(Provider.created_at.desc())
        .all()
    )
    providers_by_name = defaultdict(list)
    for record in provider_records:
        providers_by_name[record.provider_name].append(record)

    # Valid model records of configurable providers, newest first, grouped by name.
    model_records = (
        db.session.query(ProviderModel)
        .filter(
            ProviderModel.tenant_id == tenant_id,
            ProviderModel.provider_name.in_(configurable_names),
            ProviderModel.is_valid == True
        )
        .order_by(ProviderModel.created_at.desc())
        .all()
    )
    models_by_name = defaultdict(list)
    for model_record in model_records:
        models_by_name[model_record.provider_name].append(model_record)

    # Stored preferred-provider rows of the tenant, keyed by provider name.
    preferred_rows = (
        db.session.query(TenantPreferredModelProvider)
        .filter(
            TenantPreferredModelProvider.tenant_id == tenant_id,
            TenantPreferredModelProvider.provider_name.in_(all_names)
        )
        .all()
    )
    preferred_by_name = {row.provider_name: row for row in preferred_rows}

    overview = {}

    for name, rule in rules_by_name.items():
        if model_type and model_type not in rule.get('supported_model_types', []):
            continue

        preferred_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider(
            tenant_id,
            name,
            preferred_by_name.get(name)
        )

        entry = {
            "preferred_provider_type": preferred_type,
            "model_flexibility": rule['model_flexibility'],
            "supported_model_types": rule.get("supported_model_types", []),
        }

        # Placeholder slots first; stored records overwrite them further down.
        slots = {}
        if system_type in rule['support_provider_types']:
            for quota_enum in ProviderQuotaType:
                quota_name = quota_enum.value
                if quota_name not in rule['system_config']['supported_quota_types']:
                    continue

                slots[system_type + ':' + quota_name] = {
                    "provider_name": name,
                    "provider_type": system_type,
                    "config": None,
                    "is_valid": False,  # need update
                    "quota_type": quota_name,
                    "quota_unit": rule['system_config']['quota_unit'],  # need update
                    # Only the trial quota starts from the limit given in the rule.
                    "quota_limit": rule['system_config']['quota_limit']
                    if quota_name == ProviderQuotaType.TRIAL.value else 0,  # need update
                    "quota_used": 0,  # need update
                    "last_used": None  # need update
                }

        if custom_type in rule['support_provider_types']:
            slots[custom_type] = {
                "provider_name": name,
                "provider_type": custom_type,
                "config": None,  # need update
                "models": [],  # need update
                "is_valid": False,
                "last_used": None  # need update
            }

        provider_class = ModelProviderFactory.get_model_provider_class(name)

        for record in providers_by_name[name]:
            last_used = int(record.last_used.timestamp()) if record.last_used else None

            if record.provider_type == system_type:
                slot_key = f'{system_type}:{record.quota_type}'
                if slot_key in slots:
                    slot = slots[slot_key]
                    slot['is_valid'] = record.is_valid
                    slot['quota_used'] = record.quota_used
                    slot['quota_limit'] = record.quota_limit
                    slot['last_used'] = last_used
            elif record.provider_type == custom_type and custom_type in slots:
                slot = slots[custom_type]
                slot['last_used'] = last_used
                slot['is_valid'] = record.is_valid

                if rule['model_flexibility'] == 'fixed':
                    # Fixed providers expose a single provider-level credential set.
                    slot['config'] = provider_class(provider=record) \
                        .get_provider_credentials(obfuscated=True)
                else:
                    # Other providers expose one credential set per stored model.
                    slot['models'] = [
                        {
                            "model_name": model_record.model_name,
                            "model_type": model_record.model_type,
                            "config": provider_class(provider=record).get_model_credentials(
                                model_record.model_name,
                                ModelType.value_of(model_record.model_type),
                                obfuscated=True
                            ),
                            "is_valid": model_record.is_valid
                        }
                        for model_record in models_by_name[name]
                    ]

        entry['providers'] = list(slots.values())
        overview[name] = entry

    return overview
|
||||
|
||||
def custom_provider_config_validate(self, provider_name: str, config: dict) -> None:
    """
    Validate the credentials of a custom, fixed-model provider.

    :param provider_name: provider name
    :param config: credentials to validate
    :return:
    :raises ValueError: when the provider is not a fixed-model provider or
        does not support the CUSTOM provider type
    :raises CredentialsValidateFailedError: When the config credential verification fails.
    """
    rule = ModelProviderFactory.get_provider_rule(provider_name)

    if rule['model_flexibility'] != 'fixed':
        raise ValueError('Only support fixed model provider')

    if ProviderType.CUSTOM.value not in rule['support_provider_types']:
        raise ValueError('Only support provider type CUSTOM')

    # The credential check itself is delegated to the provider class.
    provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
    provider_class.is_provider_credentials_valid_or_raise(config)
|
||||
|
||||
def save_custom_provider_config(self, tenant_id: str, provider_name: str, config: dict) -> None:
    """
    Validate, encrypt and store the credentials of a custom provider.

    An existing CUSTOM provider row of the tenant is updated in place;
    otherwise a new row is created.

    :param tenant_id: workspace id
    :param provider_name: provider name
    :param config: plain credentials
    :return:
    """
    # Raises before anything is written when the credentials are rejected.
    self.custom_provider_config_validate(provider_name, config)

    existing = (
        db.session.query(Provider)
        .filter(
            Provider.tenant_id == tenant_id,
            Provider.provider_name == provider_name,
            Provider.provider_type == ProviderType.CUSTOM.value
        )
        .first()
    )

    provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
    serialized = json.dumps(provider_class.encrypt_provider_credentials(tenant_id, config))

    if not existing:
        db.session.add(Provider(
            tenant_id=tenant_id,
            provider_name=provider_name,
            provider_type=ProviderType.CUSTOM.value,
            encrypted_config=serialized,
            is_valid=True
        ))
    else:
        existing.encrypted_config = serialized
        existing.is_valid = True
        existing.updated_at = datetime.datetime.utcnow()

    db.session.commit()
|
||||
|
||||
def delete_custom_provider(self, tenant_id: str, provider_name: str) -> None:
    """
    Delete the tenant's CUSTOM provider row, if there is one.

    :param tenant_id: workspace id
    :param provider_name: provider name
    :return:
    """
    target = (
        db.session.query(Provider)
        .filter(
            Provider.tenant_id == tenant_id,
            Provider.provider_name == provider_name,
            Provider.provider_type == ProviderType.CUSTOM.value
        )
        .first()
    )

    if not target:
        return

    # Best effort: try to fall back to the system provider first; a
    # ValueError from the switch does not block the deletion.
    try:
        self.switch_preferred_provider(tenant_id, provider_name, ProviderType.SYSTEM.value)
    except ValueError:
        pass

    db.session.delete(target)
    db.session.commit()
|
||||
|
||||
def custom_provider_model_config_validate(self,
                                          provider_name: str,
                                          model_name: str,
                                          model_type: str,
                                          config: dict) -> None:
    """
    Validate the credentials of one model of a custom, configurable provider.

    :param provider_name: provider name
    :param model_name: model name
    :param model_type: model type name, converted via ModelType.value_of
    :param config: credentials to validate
    :return:
    :raises ValueError: when the provider is not configurable or does not
        support the CUSTOM provider type
    :raises CredentialsValidateFailedError: When the config credential verification fails.
    """
    rule = ModelProviderFactory.get_provider_rule(provider_name)

    if rule['model_flexibility'] != 'configurable':
        raise ValueError('Only support configurable model provider')

    if ProviderType.CUSTOM.value not in rule['support_provider_types']:
        raise ValueError('Only support provider type CUSTOM')

    # The credential check itself is delegated to the provider class.
    model_type_enum = ModelType.value_of(model_type)
    provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
    provider_class.is_model_credentials_valid_or_raise(model_name, model_type_enum, config)
|
||||
|
||||
def add_or_save_custom_provider_model_config(self,
                                             tenant_id: str,
                                             provider_name: str,
                                             model_name: str,
                                             model_type: str,
                                             config: dict) -> None:
    """
    Validate, encrypt and store the credentials of a custom provider model.

    A valid CUSTOM provider row is ensured first (and committed on its own),
    then the model row is updated or created.

    :param tenant_id: workspace id
    :param provider_name: provider name
    :param model_name: model name
    :param model_type: model type name
    :param config: plain model credentials
    :return:
    """
    # Raises before anything is written when the credentials are rejected.
    self.custom_provider_model_config_validate(provider_name, model_name, model_type, config)

    owner = (
        db.session.query(Provider)
        .filter(
            Provider.tenant_id == tenant_id,
            Provider.provider_name == provider_name,
            Provider.provider_type == ProviderType.CUSTOM.value
        )
        .first()
    )

    if not owner:
        owner = Provider(
            tenant_id=tenant_id,
            provider_name=provider_name,
            provider_type=ProviderType.CUSTOM.value,
            is_valid=True
        )
        db.session.add(owner)
        db.session.commit()
    elif not owner.is_valid:
        # Revive an invalidated row; its provider-level config is cleared.
        owner.is_valid = True
        owner.encrypted_config = None
        db.session.commit()

    provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
    encrypted = provider_class.encrypt_model_credentials(
        tenant_id,
        model_name,
        ModelType.value_of(model_type),
        config
    )

    stored_model = (
        db.session.query(ProviderModel)
        .filter(
            ProviderModel.tenant_id == tenant_id,
            ProviderModel.provider_name == provider_name,
            ProviderModel.model_name == model_name,
            ProviderModel.model_type == model_type
        )
        .first()
    )

    if not stored_model:
        db.session.add(ProviderModel(
            tenant_id=tenant_id,
            provider_name=provider_name,
            model_name=model_name,
            model_type=model_type,
            encrypted_config=json.dumps(encrypted),
            is_valid=True
        ))
    else:
        stored_model.encrypted_config = json.dumps(encrypted)
        stored_model.is_valid = True

    db.session.commit()
|
||||
|
||||
def delete_custom_provider_model(self,
                                 tenant_id: str,
                                 provider_name: str,
                                 model_name: str,
                                 model_type: str) -> None:
    """
    Delete the stored row of a custom provider model, if there is one.

    :param tenant_id: workspace id
    :param provider_name: provider name
    :param model_name: model name
    :param model_type: model type name
    :return:
    """
    stored_model = (
        db.session.query(ProviderModel)
        .filter(
            ProviderModel.tenant_id == tenant_id,
            ProviderModel.provider_name == provider_name,
            ProviderModel.model_name == model_name,
            ProviderModel.model_type == model_type
        )
        .first()
    )

    # Nothing to delete when no matching row exists.
    if not stored_model:
        return

    db.session.delete(stored_model)
    db.session.commit()
|
||||
|
||||
def switch_preferred_provider(self, tenant_id: str, provider_name: str, preferred_provider_type: str) -> None:
    """
    Store which provider type a tenant prefers for a provider.

    Nothing is stored when the provider class reports that the system
    provider type is not supported.

    :param tenant_id: workspace id
    :param provider_name: provider name
    :param preferred_provider_type: provider type name
    :return:
    :raises ValueError: when the type is unknown or is not listed in the
        provider rule's ``support_provider_types``
    """
    if not ProviderType.value_of(preferred_provider_type):
        raise ValueError(f'Invalid preferred provider type: {preferred_provider_type}')

    rule = ModelProviderFactory.get_provider_rule(provider_name)
    if preferred_provider_type not in rule['support_provider_types']:
        raise ValueError(f'Not support provider type: {preferred_provider_type}')

    provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
    if not provider_class.is_provider_type_system_supported():
        return

    preference = (
        db.session.query(TenantPreferredModelProvider)
        .filter(
            TenantPreferredModelProvider.tenant_id == tenant_id,
            TenantPreferredModelProvider.provider_name == provider_name
        )
        .first()
    )

    if not preference:
        db.session.add(TenantPreferredModelProvider(
            tenant_id=tenant_id,
            provider_name=provider_name,
            preferred_provider_type=preferred_provider_type
        ))
    else:
        preference.preferred_provider_type = preferred_provider_type

    db.session.commit()
|
||||
|
||||
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[TenantDefaultModel]:
    """
    Look up the tenant's default model for a model type via ModelFactory.

    :param tenant_id: workspace id
    :param model_type: model type name, converted via ModelType.value_of
    :return: whatever ModelFactory.get_default_model returns for the tenant
    """
    model_type_enum = ModelType.value_of(model_type)
    return ModelFactory.get_default_model(tenant_id, model_type_enum)
|
||||
|
||||
def update_default_model_of_model_type(self,
                                       tenant_id: str,
                                       model_type: str,
                                       provider_name: str,
                                       model_name: str) -> TenantDefaultModel:
    """
    Set the tenant's default model for a model type via ModelFactory.

    :param tenant_id: workspace id
    :param model_type: model type name, converted via ModelType.value_of
    :param provider_name: provider name
    :param model_name: model name
    :return: whatever ModelFactory.update_default_model returns
    """
    model_type_enum = ModelType.value_of(model_type)
    return ModelFactory.update_default_model(tenant_id, model_type_enum, provider_name, model_name)
|
||||
|
||||
def get_valid_model_list(self, tenant_id: str, model_type: str) -> list:
    """
    List the models of a model type offered by the tenant's preferred providers.

    :param tenant_id: workspace id
    :param model_type: model type name
    :return: list of dicts describing each model and the provider serving it
    """
    collected = []

    for name, rule in ModelProviderFactory.get_provider_rules().items():
        preferred = ModelProviderFactory.get_preferred_model_provider(tenant_id, name)
        # Providers without a preferred provider contribute nothing.
        if not preferred:
            continue

        supported = preferred.get_supported_model_list(ModelType.value_of(model_type))
        record = preferred.provider
        is_system = record.provider_type == ProviderType.SYSTEM.value

        for item in supported:
            provider_info = {
                "provider_name": record.provider_name,
                "provider_type": record.provider_type
            }
            described = {
                "model_name": item['id'],
                "model_display_name": item['name'],
                "model_type": model_type,
                "model_provider": provider_info,
                'features': []
            }

            # 'mode' and 'features' are optional keys of the model entry.
            if 'mode' in item:
                described['model_mode'] = item['mode']

            if 'features' in item:
                described['features'] = item['features']

            # Quota details are attached only for system providers.
            if is_system:
                provider_info['quota_type'] = record.quota_type
                provider_info['quota_unit'] = rule['system_config']['quota_unit']
                provider_info['quota_limit'] = record.quota_limit
                provider_info['quota_used'] = record.quota_used

            collected.append(described)

    return collected
|
||||
|
||||
def get_model_parameter_rules(self, tenant_id: str, model_provider_name: str, model_name: str, model_type: str) \
        -> ModelKwargsRules:
    """
    Return the parameter rules of a model from the tenant's preferred provider.

    :param tenant_id: workspace id
    :param model_provider_name: provider name
    :param model_name: model name
    :param model_type: model type name, converted via ModelType.value_of
    :return: the provider's rules, or an empty ModelKwargsRules when the
        tenant has no preferred provider of that name
    """
    preferred = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
    if preferred:
        return preferred.get_model_parameter_rules(model_name, ModelType.value_of(model_type))

    # No provider available: answer with empty rules.
    return ModelKwargsRules()
|
||||
|
||||
def free_quota_submit(self, tenant_id: str, provider_name: str):
    """
    Submit a free-quota application for a provider to the quota-apply server.

    The server location and key come from the environment variables
    ``FREE_QUOTA_APPLY_BASE_URL`` and ``FREE_QUOTA_APPLY_API_KEY``.

    :param tenant_id: workspace id, sent as ``workspace_id``
    :param provider_name: provider name, sent as ``provider_name``
    :return: ``{'type', 'redirect_url'}`` for a redirect answer, otherwise
        ``{'type', 'result': 'success'}``
    :raises ValueError: when the base URL is not configured, the HTTP status
        is not OK, or the server answers with a non-success code
    """
    api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
    api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
    if not api_base_url:
        # Without this guard the concatenation below failed with an opaque
        # TypeError (None + str) whenever the variable was unset.
        raise ValueError("FREE_QUOTA_APPLY_BASE_URL is not configured.")
    api_url = api_base_url + '/api/v1/providers/apply'

    headers = {
        'Content-Type': 'application/json',
        'Authorization': f"Bearer {api_key}"
    }
    response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider_name})
    if not response.ok:
        logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
        raise ValueError(f"Error: {response.status_code} ")

    # Decode the body once and reuse it for every check below.
    rst = response.json()
    if rst["code"] != 'success':
        raise ValueError(
            f"error: {rst['message']}"
        )

    if rst['type'] == 'redirect':
        return {
            'type': rst['type'],
            'redirect_url': rst['redirect_url']
        }

    return {
        'type': rst['type'],
        'result': 'success'
    }
|
||||
|
||||
def free_quota_qualification_verify(self, tenant_id: str, provider_name: str, token: Optional[str]):
    """
    Ask the quota-apply server whether a workspace qualifies for free quota.

    The server location and key come from the environment variables
    ``FREE_QUOTA_APPLY_BASE_URL`` and ``FREE_QUOTA_APPLY_API_KEY``.

    :param tenant_id: workspace id, sent as ``workspace_id``
    :param provider_name: provider name, sent as ``provider_name``
    :param token: optional token; only included in the request when truthy
    :return: dict with ``result``, ``provider_name`` and ``flag``; when not
        qualified it also carries the server's ``reason``
    :raises ValueError: when the base URL is not configured, the HTTP status
        is not OK, or the server answers with a non-success code
    """
    api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
    api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
    if not api_base_url:
        # Without this guard the concatenation below failed with an opaque
        # TypeError (None + str) whenever the variable was unset.
        raise ValueError("FREE_QUOTA_APPLY_BASE_URL is not configured.")
    api_url = api_base_url + '/api/v1/providers/qualification-verify'

    headers = {
        'Content-Type': 'application/json',
        'Authorization': f"Bearer {api_key}"
    }
    json_data = {'workspace_id': tenant_id, 'provider_name': provider_name}
    if token:
        json_data['token'] = token
    response = requests.post(api_url, headers=headers,
                             json=json_data)
    if not response.ok:
        logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
        raise ValueError(f"Error: {response.status_code} ")

    rst = response.json()
    if rst["code"] != 'success':
        raise ValueError(
            f"error: {rst['message']}"
        )

    data = rst['data']
    # Only the literal True counts as qualified; any other value is a refusal.
    if data['qualified'] is True:
        return {
            'result': 'success',
            'provider_name': provider_name,
            'flag': True
        }

    return {
        'result': 'success',
        'provider_name': provider_name,
        'flag': False,
        'reason': data['reason']
    }
|
||||
@@ -1,9 +1,11 @@
|
||||
|
||||
from typing import Optional
|
||||
from flask import current_app, Flask
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.rerank.rerank import RerankRunner
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
|
||||
@@ -50,12 +52,24 @@ class RetrievalService:
|
||||
|
||||
if documents:
|
||||
if reranking_model and search_method == 'semantic_search':
|
||||
rerank = ModelFactory.get_reranking_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=reranking_model['reranking_provider_name'],
|
||||
model_name=reranking_model['reranking_model_name']
|
||||
)
|
||||
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=reranking_model['reranking_provider_name'],
|
||||
model_type=ModelType.RERANK,
|
||||
model=reranking_model['reranking_model_name']
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
return
|
||||
|
||||
rerank_runner = RerankRunner(rerank_model_instance)
|
||||
all_documents.extend(rerank_runner.run(
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=len(documents)
|
||||
))
|
||||
else:
|
||||
all_documents.extend(documents)
|
||||
|
||||
@@ -81,15 +95,23 @@ class RetrievalService:
|
||||
)
|
||||
if documents:
|
||||
if reranking_model and search_method == 'full_text_search':
|
||||
rerank = ModelFactory.get_reranking_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=reranking_model['reranking_provider_name'],
|
||||
model_name=reranking_model['reranking_model_name']
|
||||
)
|
||||
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=reranking_model['reranking_provider_name'],
|
||||
model_type=ModelType.RERANK,
|
||||
model=reranking_model['reranking_model_name']
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
return
|
||||
|
||||
rerank_runner = RerankRunner(rerank_model_instance)
|
||||
all_documents.extend(rerank_runner.run(
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=len(documents)
|
||||
))
|
||||
else:
|
||||
all_documents.extend(documents)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ class WorkspaceService:
|
||||
'plan': tenant.plan,
|
||||
'status': tenant.status,
|
||||
'created_at': tenant.created_at,
|
||||
'providers': [],
|
||||
'in_trail': True,
|
||||
'trial_end_reason': None,
|
||||
'role': 'normal',
|
||||
@@ -37,12 +36,4 @@ class WorkspaceService:
|
||||
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]):
|
||||
tenant_info['custom_config'] = tenant.custom_config_dict
|
||||
|
||||
# Get providers
|
||||
providers = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant.id
|
||||
).all()
|
||||
|
||||
# Add providers to the tenant info
|
||||
tenant_info['providers'] = providers
|
||||
|
||||
return tenant_info
|
||||
|
||||
Reference in New Issue
Block a user