Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
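The hunks below migrate LLMGenerator from the old ModelFactory / PromptMessage API to the new model-runtime API (ModelManager, UserPromptMessage / SystemPromptMessage, invoke_llm). As orientation, here is a minimal sketch of the new calling convention assembled from the hunks below; tenant_id and the prompt text are placeholders, not part of the commit:

from core.model_manager import ModelManager
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType

# Resolve the tenant's default LLM instead of going through ModelFactory.
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
    tenant_id=tenant_id,              # assumed to be supplied by the caller
    model_type=ModelType.LLM,
)

# Model parameters move from ModelKwargs into a plain dict passed per call.
response = model_instance.invoke_llm(
    prompt_messages=[UserPromptMessage(content="...")],
    model_parameters={"max_tokens": 100, "temperature": 1},
    stream=False,                     # non-streaming: the full result is returned at once
)
answer = response.message.content    # was response.content on the old API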
@@ -3,10 +3,10 @@ import logging
 
 from langchain.schema import OutputParserException
 
-from core.model_providers.error import LLMError, ProviderTokenNotInitError
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage, MessageType
-from core.model_providers.models.entity.model_params import ModelKwargs
+from core.model_manager import ModelManager
+from core.model_runtime.entities.message_entities import UserPromptMessage, SystemPromptMessage
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 
 from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
@@ -26,17 +26,22 @@ class LLMGenerator:
 
         prompt += query + "\n"
 
-        model_instance = ModelFactory.get_text_generation_model(
+        model_manager = ModelManager()
+        model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                temperature=1,
-                max_tokens=100
-            )
+            model_type=ModelType.LLM,
         )
 
-        prompts = [PromptMessage(content=prompt)]
-        response = model_instance.run(prompts)
-        answer = response.content
+        prompts = [UserPromptMessage(content=prompt)]
+        response = model_instance.invoke_llm(
+            prompt_messages=prompts,
+            model_parameters={
+                "max_tokens": 100,
+                "temperature": 1
+            },
+            stream=False
+        )
+        answer = response.message.content
 
         result_dict = json.loads(answer)
         answer = result_dict['Your Output']
@@ -62,22 +67,28 @@ class LLMGenerator:
         })
 
         try:
-            model_instance = ModelFactory.get_text_generation_model(
+            model_manager = ModelManager()
+            model_instance = model_manager.get_default_model_instance(
                 tenant_id=tenant_id,
-                model_kwargs=ModelKwargs(
-                    max_tokens=256,
-                    temperature=0
-                )
+                model_type=ModelType.LLM,
             )
-        except ProviderTokenNotInitError:
+        except InvokeAuthorizationError:
             return []
 
-        prompt_messages = [PromptMessage(content=prompt)]
+        prompt_messages = [UserPromptMessage(content=prompt)]
 
         try:
-            output = model_instance.run(prompt_messages)
-            questions = output_parser.parse(output.content)
-        except LLMError:
+            response = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters={
+                    "max_tokens": 256,
+                    "temperature": 0
+                },
+                stream=False
+            )
+
+            questions = output_parser.parse(response.message.content)
+        except InvokeError:
             questions = []
         except Exception as e:
             logging.exception(e)
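The hunk above also remaps the error handling: ProviderTokenNotInitError becomes InvokeAuthorizationError around model resolution, and LLMError becomes InvokeError around the invocation itself. A minimal sketch of that two-layer pattern follows; the helper name is invented for illustration and the output parser is injected rather than constructed:

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError

def suggest_questions(tenant_id: str, prompt_messages, output_parser) -> list:
    # Layer 1: resolving the tenant's default model fails softly
    # when no usable credentials are configured.
    try:
        model_instance = ModelManager().get_default_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
        )
    except InvokeAuthorizationError:
        return []

    # Layer 2: provider-side failures during invocation degrade to an empty result.
    try:
        response = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters={"max_tokens": 256, "temperature": 0},
            stream=False
        )
        return output_parser.parse(response.message.content)
    except InvokeError:
        return []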
@@ -105,20 +116,26 @@ class LLMGenerator:
             remove_template_variables=False
         )
 
-        model_instance = ModelFactory.get_text_generation_model(
+        model_manager = ModelManager()
+        model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                max_tokens=512,
-                temperature=0
-            )
+            model_type=ModelType.LLM,
         )
 
-        prompt_messages = [PromptMessage(content=prompt)]
+        prompt_messages = [UserPromptMessage(content=prompt)]
 
         try:
-            output = model_instance.run(prompt_messages)
-            rule_config = output_parser.parse(output.content)
-        except LLMError as e:
+            response = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters={
+                    "max_tokens": 512,
+                    "temperature": 0
+                },
+                stream=False
+            )
+
+            rule_config = output_parser.parse(response.message.content)
+        except InvokeError as e:
             raise e
         except OutputParserException:
             raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
@@ -136,18 +153,24 @@ class LLMGenerator:
     def generate_qa_document(cls, tenant_id: str, query, document_language: str):
         prompt = GENERATOR_QA_PROMPT.format(language=document_language)
 
-        model_instance = ModelFactory.get_text_generation_model(
+        model_manager = ModelManager()
+        model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                max_tokens=2000
-            )
+            model_type=ModelType.LLM,
         )
 
-        prompts = [
-            PromptMessage(content=prompt, type=MessageType.SYSTEM),
-            PromptMessage(content=query)
+        prompt_messages = [
+            SystemPromptMessage(content=prompt),
+            UserPromptMessage(content=query)
         ]
 
-        response = model_instance.run(prompts)
-        answer = response.content
+        response = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            model_parameters={
+                "max_tokens": 2000
+            },
+            stream=False
+        )
+
+        answer = response.message.content
         return answer.strip()
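In this last hunk the message role moves into the message class itself: PromptMessage(content=..., type=MessageType.SYSTEM) becomes SystemPromptMessage(content=...), and an untyped PromptMessage becomes UserPromptMessage. A small sketch of assembling a system + user prompt under the new entities; the content strings are placeholders:

from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage

prompt_messages = [
    SystemPromptMessage(content="<QA-generation instructions rendered for the target language>"),
    UserPromptMessage(content="<document text to turn into Q&A pairs>"),
]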