diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index e25f92399..5ad764596 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -70,7 +70,7 @@ from .app import ( ) # Import auth controllers -from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth +from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server # Import billing controllers from .billing import billing, compliance diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py new file mode 100644 index 000000000..19ca464a7 --- /dev/null +++ b/api/controllers/console/auth/oauth_server.py @@ -0,0 +1,189 @@ +from functools import wraps +from typing import cast + +import flask_login +from flask import request +from flask_restx import Resource, reqparse +from werkzeug.exceptions import BadRequest, NotFound + +from controllers.console.wraps import account_initialization_required, setup_required +from core.model_runtime.utils.encoders import jsonable_encoder +from libs.login import login_required +from models.account import Account +from models.model import OAuthProviderApp +from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService + +from .. import api + + +def oauth_server_client_id_required(view): + @wraps(view) + def decorated(*args, **kwargs): + parser = reqparse.RequestParser() + parser.add_argument("client_id", type=str, required=True, location="json") + parsed_args = parser.parse_args() + client_id = parsed_args.get("client_id") + if not client_id: + raise BadRequest("client_id is required") + + oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id) + if not oauth_provider_app: + raise NotFound("client_id is invalid") + + kwargs["oauth_provider_app"] = oauth_provider_app + + return view(*args, **kwargs) + + return decorated + + +def oauth_server_access_token_required(view): + @wraps(view) + def decorated(*args, **kwargs): + oauth_provider_app = kwargs.get("oauth_provider_app") + if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp): + raise BadRequest("Invalid oauth_provider_app") + + if not request.headers.get("Authorization"): + raise BadRequest("Authorization is required") + + authorization_header = request.headers.get("Authorization") + if not authorization_header: + raise BadRequest("Authorization header is required") + + parts = authorization_header.split(" ") + if len(parts) != 2: + raise BadRequest("Invalid Authorization header format") + + token_type = parts[0] + if token_type != "Bearer": + raise BadRequest("token_type is invalid") + + access_token = parts[1] + if not access_token: + raise BadRequest("access_token is required") + + account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token) + if not account: + raise BadRequest("access_token or client_id is invalid") + + kwargs["account"] = account + + return view(*args, **kwargs) + + return decorated + + +class OAuthServerAppApi(Resource): + @setup_required + @oauth_server_client_id_required + def post(self, oauth_provider_app: OAuthProviderApp): + parser = reqparse.RequestParser() + parser.add_argument("redirect_uri", type=str, required=True, location="json") + parsed_args = parser.parse_args() + redirect_uri = parsed_args.get("redirect_uri") + + # check if redirect_uri is valid + if redirect_uri not in oauth_provider_app.redirect_uris: + raise BadRequest("redirect_uri is invalid") + + return jsonable_encoder( + { + "app_icon": oauth_provider_app.app_icon, + "app_label": oauth_provider_app.app_label, + "scope": oauth_provider_app.scope, + } + ) + + +class OAuthServerUserAuthorizeApi(Resource): + @setup_required + @login_required + @account_initialization_required + @oauth_server_client_id_required + def post(self, oauth_provider_app: OAuthProviderApp): + account = cast(Account, flask_login.current_user) + user_account_id = account.id + + code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id) + return jsonable_encoder( + { + "code": code, + } + ) + + +class OAuthServerUserTokenApi(Resource): + @setup_required + @oauth_server_client_id_required + def post(self, oauth_provider_app: OAuthProviderApp): + parser = reqparse.RequestParser() + parser.add_argument("grant_type", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=False, location="json") + parser.add_argument("client_secret", type=str, required=False, location="json") + parser.add_argument("redirect_uri", type=str, required=False, location="json") + parser.add_argument("refresh_token", type=str, required=False, location="json") + parsed_args = parser.parse_args() + + grant_type = OAuthGrantType(parsed_args["grant_type"]) + + if grant_type == OAuthGrantType.AUTHORIZATION_CODE: + if not parsed_args["code"]: + raise BadRequest("code is required") + + if parsed_args["client_secret"] != oauth_provider_app.client_secret: + raise BadRequest("client_secret is invalid") + + if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris: + raise BadRequest("redirect_uri is invalid") + + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) + elif grant_type == OAuthGrantType.REFRESH_TOKEN: + if not parsed_args["refresh_token"]: + raise BadRequest("refresh_token is required") + + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) + else: + raise BadRequest("invalid grant_type") + + +class OAuthServerUserAccountApi(Resource): + @setup_required + @oauth_server_client_id_required + @oauth_server_access_token_required + def post(self, oauth_provider_app: OAuthProviderApp, account: Account): + return jsonable_encoder( + { + "name": account.name, + "email": account.email, + "avatar": account.avatar, + "interface_language": account.interface_language, + "timezone": account.timezone, + } + ) + + +api.add_resource(OAuthServerAppApi, "/oauth/provider") +api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize") +api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token") +api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account") diff --git a/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py new file mode 100644 index 000000000..5986853f0 --- /dev/null +++ b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py @@ -0,0 +1,45 @@ +"""empty message + +Revision ID: 8d289573e1da +Revises: fa8b0fa6f407 +Create Date: 2025-08-20 17:47:17.015695 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '8d289573e1da' +down_revision = '0e154742a5fa' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('oauth_provider_apps', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_icon', sa.String(length=255), nullable=False), + sa.Column('app_label', sa.JSON(), server_default='{}', nullable=False), + sa.Column('client_id', sa.String(length=255), nullable=False), + sa.Column('client_secret', sa.String(length=255), nullable=False), + sa.Column('redirect_uris', sa.JSON(), server_default='[]', nullable=False), + sa.Column('scope', sa.String(length=255), server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='oauth_provider_app_pkey') + ) + with op.batch_alter_table('oauth_provider_apps', schema=None) as batch_op: + batch_op.create_index('oauth_provider_app_client_id_idx', ['client_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('oauth_provider_apps', schema=None) as batch_op: + batch_op.drop_index('oauth_provider_app_client_id_idx') + + op.drop_table('oauth_provider_apps') + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 53646c015..6a0e0af48 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -580,6 +580,32 @@ class InstalledApp(Base): return tenant +class OAuthProviderApp(Base): + """ + Globally shared OAuth provider app information. + Only for Dify Cloud. + """ + + __tablename__ = "oauth_provider_apps" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="oauth_provider_app_pkey"), + sa.Index("oauth_provider_app_client_id_idx", "client_id"), + ) + + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + app_icon = mapped_column(String(255), nullable=False) + app_label = mapped_column(sa.JSON, nullable=False, server_default="{}") + client_id = mapped_column(String(255), nullable=False) + client_secret = mapped_column(String(255), nullable=False) + redirect_uris = mapped_column(sa.JSON, nullable=False, server_default="[]") + scope = mapped_column( + String(255), + nullable=False, + server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"), + ) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")) + + class Conversation(Base): __tablename__ = "conversations" __table_args__ = ( diff --git a/api/services/oauth_server.py b/api/services/oauth_server.py new file mode 100644 index 000000000..b722dbee2 --- /dev/null +++ b/api/services/oauth_server.py @@ -0,0 +1,94 @@ +import enum +import uuid + +from sqlalchemy import select +from sqlalchemy.orm import Session +from werkzeug.exceptions import BadRequest + +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import Account +from models.model import OAuthProviderApp +from services.account_service import AccountService + + +class OAuthGrantType(enum.StrEnum): + AUTHORIZATION_CODE = "authorization_code" + REFRESH_TOKEN = "refresh_token" + + +OAUTH_AUTHORIZATION_CODE_REDIS_KEY = "oauth_provider:{client_id}:authorization_code:{code}" +OAUTH_ACCESS_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:access_token:{token}" +OAUTH_ACCESS_TOKEN_EXPIRES_IN = 60 * 60 * 12 # 12 hours +OAUTH_REFRESH_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:refresh_token:{token}" +OAUTH_REFRESH_TOKEN_EXPIRES_IN = 60 * 60 * 24 * 30 # 30 days + + +class OAuthServerService: + @staticmethod + def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None: + query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id) + + with Session(db.engine) as session: + return session.execute(query).scalar_one_or_none() + + @staticmethod + def sign_oauth_authorization_code(client_id: str, user_account_id: str) -> str: + code = str(uuid.uuid4()) + redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code) + redis_client.set(redis_key, user_account_id, ex=60 * 10) # 10 minutes + return code + + @staticmethod + def sign_oauth_access_token( + grant_type: OAuthGrantType, + code: str = "", + client_id: str = "", + refresh_token: str = "", + ) -> tuple[str, str]: + match grant_type: + case OAuthGrantType.AUTHORIZATION_CODE: + redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code) + user_account_id = redis_client.get(redis_key) + if not user_account_id: + raise BadRequest("invalid code") + + # delete code + redis_client.delete(redis_key) + + access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id) + refresh_token = OAuthServerService._sign_oauth_refresh_token(client_id, user_account_id) + return access_token, refresh_token + case OAuthGrantType.REFRESH_TOKEN: + redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=refresh_token) + user_account_id = redis_client.get(redis_key) + if not user_account_id: + raise BadRequest("invalid refresh token") + + access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id) + return access_token, refresh_token + + @staticmethod + def _sign_oauth_access_token(client_id: str, user_account_id: str) -> str: + token = str(uuid.uuid4()) + redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token) + redis_client.set(redis_key, user_account_id, ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN) + return token + + @staticmethod + def _sign_oauth_refresh_token(client_id: str, user_account_id: str) -> str: + token = str(uuid.uuid4()) + redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=token) + redis_client.set(redis_key, user_account_id, ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN) + return token + + @staticmethod + def validate_oauth_access_token(client_id: str, token: str) -> Account | None: + redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token) + user_account_id = redis_client.get(redis_key) + if not user_account_id: + return None + + user_id_str = user_account_id.decode("utf-8") + + return AccountService.load_user(user_id_str) diff --git a/web/app/account/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx similarity index 100% rename from web/app/account/account-page/AvatarWithEdit.tsx rename to web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx diff --git a/web/app/account/account-page/email-change-modal.tsx b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx similarity index 100% rename from web/app/account/account-page/email-change-modal.tsx rename to web/app/account/(commonLayout)/account-page/email-change-modal.tsx diff --git a/web/app/account/account-page/index.tsx b/web/app/account/(commonLayout)/account-page/index.tsx similarity index 100% rename from web/app/account/account-page/index.tsx rename to web/app/account/(commonLayout)/account-page/index.tsx diff --git a/web/app/account/avatar.tsx b/web/app/account/(commonLayout)/avatar.tsx similarity index 100% rename from web/app/account/avatar.tsx rename to web/app/account/(commonLayout)/avatar.tsx diff --git a/web/app/account/delete-account/components/check-email.tsx b/web/app/account/(commonLayout)/delete-account/components/check-email.tsx similarity index 100% rename from web/app/account/delete-account/components/check-email.tsx rename to web/app/account/(commonLayout)/delete-account/components/check-email.tsx diff --git a/web/app/account/delete-account/components/feed-back.tsx b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx similarity index 100% rename from web/app/account/delete-account/components/feed-back.tsx rename to web/app/account/(commonLayout)/delete-account/components/feed-back.tsx diff --git a/web/app/account/delete-account/components/verify-email.tsx b/web/app/account/(commonLayout)/delete-account/components/verify-email.tsx similarity index 100% rename from web/app/account/delete-account/components/verify-email.tsx rename to web/app/account/(commonLayout)/delete-account/components/verify-email.tsx diff --git a/web/app/account/delete-account/index.tsx b/web/app/account/(commonLayout)/delete-account/index.tsx similarity index 100% rename from web/app/account/delete-account/index.tsx rename to web/app/account/(commonLayout)/delete-account/index.tsx diff --git a/web/app/account/delete-account/state.tsx b/web/app/account/(commonLayout)/delete-account/state.tsx similarity index 100% rename from web/app/account/delete-account/state.tsx rename to web/app/account/(commonLayout)/delete-account/state.tsx diff --git a/web/app/account/header.tsx b/web/app/account/(commonLayout)/header.tsx similarity index 97% rename from web/app/account/header.tsx rename to web/app/account/(commonLayout)/header.tsx index af09ca1c9..ce804055b 100644 --- a/web/app/account/header.tsx +++ b/web/app/account/(commonLayout)/header.tsx @@ -2,11 +2,11 @@ import { useTranslation } from 'react-i18next' import { RiArrowRightUpLine, RiRobot2Line } from '@remixicon/react' import { useRouter } from 'next/navigation' -import Button from '../components/base/button' -import Avatar from './avatar' +import Button from '@/app/components/base/button' import DifyLogo from '@/app/components/base/logo/dify-logo' import { useCallback } from 'react' import { useGlobalPublicStore } from '@/context/global-public-context' +import Avatar from './avatar' const Header = () => { const { t } = useTranslation() diff --git a/web/app/account/layout.tsx b/web/app/account/(commonLayout)/layout.tsx similarity index 100% rename from web/app/account/layout.tsx rename to web/app/account/(commonLayout)/layout.tsx diff --git a/web/app/account/page.tsx b/web/app/account/(commonLayout)/page.tsx similarity index 100% rename from web/app/account/page.tsx rename to web/app/account/(commonLayout)/page.tsx diff --git a/web/app/account/oauth/authorize/layout.tsx b/web/app/account/oauth/authorize/layout.tsx new file mode 100644 index 000000000..078d23114 --- /dev/null +++ b/web/app/account/oauth/authorize/layout.tsx @@ -0,0 +1,37 @@ +'use client' +import Header from '@/app/signin/_header' + +import cn from '@/utils/classnames' +import { useGlobalPublicStore } from '@/context/global-public-context' +import useDocumentTitle from '@/hooks/use-document-title' +import { AppContextProvider } from '@/context/app-context' +import { useMemo } from 'react' + +export default function SignInLayout({ children }: any) { + const { systemFeatures } = useGlobalPublicStore() + useDocumentTitle('') + const isLoggedIn = useMemo(() => { + try { + return Boolean(localStorage.getItem('console_token') && localStorage.getItem('refresh_token')) + } + catch { return false } + }, []) + return <> +