feat: add spark v3.0 llm support (#1434)

This commit is contained in:
takatost
2023-10-31 16:13:11 +08:00
committed by GitHub
parent 518083dfe0
commit 076f3289d2
3 changed files with 55 additions and 11 deletions

View File

@@ -28,14 +28,19 @@ class SparkProvider(BaseModelProvider):
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'spark',
'name': 'Spark V1.5',
'id': 'spark-v3',
'name': 'Spark V3.0',
'mode': ModelMode.CHAT.value,
},
{
'id': 'spark-v2',
'name': 'Spark V2.0',
'mode': ModelMode.CHAT.value,
},
{
'id': 'spark',
'name': 'Spark V1.5',
'mode': ModelMode.CHAT.value,
}
]
else:
@@ -96,7 +101,7 @@ class SparkProvider(BaseModelProvider):
try:
chat_llm = ChatSpark(
model_name='spark-v2',
model_name='spark-v3',
max_tokens=10,
temperature=0.01,
**credential_kwargs
@@ -110,10 +115,10 @@ class SparkProvider(BaseModelProvider):
chat_llm(messages)
except SparkError as ex:
# try spark v1.5 if v2.1 failed
# try spark v2.1 if v3.1 failed
try:
chat_llm = ChatSpark(
model_name='spark',
model_name='spark-v2',
max_tokens=10,
temperature=0.01,
**credential_kwargs
@@ -127,10 +132,27 @@ class SparkProvider(BaseModelProvider):
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
# try spark v1.5 if v2.1 failed
try:
chat_llm = ChatSpark(
model_name='spark',
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex