74 lines
2.4 KiB
Python
74 lines
2.4 KiB
Python
from functools import wraps
|
|
from flask import jsonify
|
|
from flask_jwt_extended import verify_jwt_in_request, get_jwt_identity
|
|
from models import User, UserGroup, db
|
|
|
|
def login_required(f):
    """Decorate a view so it only runs for an existing, active user.

    The request's JWT is validated with ``verify_jwt_in_request()``. Its
    identity is then used as the primary key to load a ``User`` row.

    The wrapper returns:
        * ``(json, 401)`` with code ``USER_NOT_FOUND`` when no user matches;
        * ``(json, 403)`` with code ``USER_DISABLED`` when ``status`` is not
          ``'active'``;
        * otherwise, whatever the wrapped view returns.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):
        verify_jwt_in_request()
        account = db.session.get(User, get_jwt_identity())

        # Reject unknown accounts first, then accounts that are not active.
        rejection = None
        if not account:
            rejection = (
                jsonify({'error': '用户不存在', 'code': 'USER_NOT_FOUND'}),
                401,
            )
        elif account.status != 'active':
            rejection = (
                jsonify({'error': '账号已被禁用', 'code': 'USER_DISABLED'}),
                403,
            )

        if rejection is not None:
            return rejection
        return f(*args, **kwargs)

    return wrapper
|
|
|
|
def admin_required(f):
    """Decorate a view so it only runs for an active user whose role is ``'admin'``.

    The request's JWT is validated with ``verify_jwt_in_request()``. Its
    identity is then used as the primary key to load a ``User`` row.

    The wrapper returns:
        * ``(json, 401)`` with code ``USER_NOT_FOUND`` when no user matches;
        * ``(json, 403)`` with code ``USER_DISABLED`` when ``status`` is not
          ``'active'``;
        * ``(json, 403)`` with code ``ADMIN_REQUIRED`` when ``role`` is not
          ``'admin'``;
        * otherwise, whatever the wrapped view returns.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):
        verify_jwt_in_request()
        user_id = get_jwt_identity()
        user = db.session.get(User, user_id)

        # The user must exist and be active.
        if not user:
            return jsonify({'error': '用户不存在', 'code': 'USER_NOT_FOUND'}), 401
        if user.status != 'active':
            return jsonify({'error': '账号已被禁用', 'code': 'USER_DISABLED'}), 403

        # Only administrators may proceed. The 'code' field mirrors the other
        # error responses in this module so clients can branch on it.
        if user.role != 'admin':
            return jsonify({'error': '需要管理员权限', 'code': 'ADMIN_REQUIRED'}), 403

        return f(*args, **kwargs)

    return wrapper
|
|
|
|
def group_member_required(group_id):
    """Build a decorator restricting a view to members of ``group_id`` (or admins).

    Args:
        group_id: Group identifier, fixed at decoration time. It is compared
            against ``UserGroup.group_id``.

    The wrapper validates the JWT with ``verify_jwt_in_request()``, loads the
    ``User`` by the token identity, and returns:
        * ``(json, 401)`` with code ``USER_NOT_FOUND`` when no user matches;
        * ``(json, 403)`` with code ``USER_DISABLED`` when ``status`` is not
          ``'active'``;
        * the wrapped view's result immediately when ``role == 'admin'``;
        * ``(json, 403)`` with code ``GROUP_ACCESS_DENIED`` when no
          ``UserGroup`` row links the user to ``group_id``;
        * otherwise, whatever the wrapped view returns.
    """

    def decorator(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            verify_jwt_in_request()
            user_id = get_jwt_identity()

            # The user must exist and be active.
            user = db.session.get(User, user_id)
            if not user:
                return jsonify({'error': '用户不存在', 'code': 'USER_NOT_FOUND'}), 401
            if user.status != 'active':
                return jsonify({'error': '账号已被禁用', 'code': 'USER_DISABLED'}), 403

            # Administrators bypass the membership check.
            if user.role == 'admin':
                return f(*args, **kwargs)

            # Non-admins need a membership row for this group.
            membership = UserGroup.query.filter_by(
                user_id=user_id,
                group_id=group_id,
            ).first()

            if not membership:
                # The 'code' field mirrors the other error responses in this
                # module so clients can branch on it.
                return jsonify({'error': '无权访问该组别', 'code': 'GROUP_ACCESS_DENIED'}), 403

            return f(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
def get_current_user():
    """Return the ``User`` identified by the current request's JWT, or ``None``.

    Meant to be called once the JWT has been verified, e.g. inside a view
    wrapped by ``login_required``. flask_jwt_extended expects verification
    before ``get_jwt_identity()`` is used.

    Returns:
        The ``User`` loaded by primary key. Returns ``None`` when the token
        carries no identity or when no matching row exists.
    """
    user_id = get_jwt_identity()
    # With an optional or absent token there is no identity. Skip the lookup
    # instead of asking SQLAlchemy to load a fully-NULL primary key, which it
    # warns about.
    if user_id is None:
        return None
    return db.session.get(User, user_id)
|