| from flask import Blueprint, request, jsonify |
| from app.services.recommendation_service import RecommendationService |
| from app.functions.FAuth import FAuth |
| from sqlalchemy import create_engine |
| from sqlalchemy.orm import sessionmaker |
| from config import Config |
| from functools import wraps |
| |
# Blueprint that groups the recommendation endpoints defined in this module;
# every route registered on it is served under the /api/recommend prefix.
recommend_bp = Blueprint(
    'recommend',
    __name__,
    url_prefix='/api/recommend',
)
| |
# Session factories keyed by database URL, so the engine (and the connection
# pool it owns) is built once per URL instead of once per request.
_SESSION_FACTORIES = {}


def _get_session_factory():
    """Return a cached ``sessionmaker`` bound to an engine for ``Config.SQLURL``."""
    url = Config.SQLURL
    factory = _SESSION_FACTORIES.get(url)
    if factory is None:
        factory = sessionmaker(bind=create_engine(url))
        _SESSION_FACTORIES[url] = factory
    return factory


def token_required(f):
    """Decorator: require a valid access token before running the view.

    The token is read from the ``Authorization`` header; a leading
    ``"Bearer "`` prefix is stripped if present. The user is looked up with
    ``FAuth(session).get_user_by_token(token)`` and passed to the wrapped
    view as its first positional argument.

    Responses produced by the decorator itself (all HTTP 401):
        * header missing or empty -> '缺少访问令牌'
        * lookup returned a falsy user -> '无效的访问令牌'
        * session creation or lookup raised -> '令牌验证失败'

    Exceptions raised by the wrapped view are NOT converted into a 401;
    they propagate to the caller so that a view failure is never reported
    as an authentication failure. The database session stays open while the
    view runs and is closed afterwards in all cases.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        token = request.headers.get('Authorization')
        if not token:
            return jsonify({'success': False, 'message': '缺少访问令牌'}), 401

        # Strip the optional "Bearer " scheme prefix.
        if token.startswith('Bearer '):
            token = token[7:]

        session = None
        try:
            # Only session creation and the token lookup are treated as
            # "token validation"; the view call is deliberately kept outside
            # this inner try block.
            try:
                session = _get_session_factory()()
                user = FAuth(session).get_user_by_token(token)
            except Exception:
                if session is not None:
                    session.rollback()
                return jsonify({'success': False, 'message': '令牌验证失败'}), 401

            if not user:
                return jsonify({'success': False, 'message': '无效的访问令牌'}), 401

            # Hand the authenticated user to the route function.
            return f(user, *args, **kwargs)
        finally:
            if session is not None:
                session.close()

    return decorated
| |
# Module-level recommendation service instance, created at import time and
# used by the route handlers in this module.
recommendation_service = RecommendationService()
| |
@recommend_bp.route('/get_recommendations', methods=['POST'])
@token_required
def get_recommendations(current_user):
    """Return personalized recommendations.

    Optional JSON body fields:
        user_id: target user; when absent or falsy, ``current_user.user_id``
            is used.
        topk: forwarded to the service as-is; defaults to 2.

    A missing, malformed, or non-object JSON body is treated as an empty
    object, so the defaults above apply instead of the request failing.

    Returns a JSON envelope with ``success``, ``data`` (``user_id``,
    ``recommendations``, ``count``) and ``message``. Any exception is
    reported as HTTP 500 with the error text in ``message``.
    """
    try:
        # silent=True makes get_json return None rather than raise when the
        # body is absent or not valid JSON; every field here is optional.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict):
            payload = {}

        user_id = payload.get('user_id') or current_user.user_id
        topk = payload.get('topk', 2)

        recommendations = recommendation_service.get_recommendations(user_id, topk)

        return jsonify({
            'success': True,
            'data': {
                'user_id': user_id,
                'recommendations': recommendations,
                'count': len(recommendations)
            },
            'message': '推荐获取成功'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'推荐获取失败: {str(e)}'
        }), 500
| |
@recommend_bp.route('/cold_start', methods=['GET'])
def cold_start_recommendations():
    """Return cold-start recommendations; no login is required.

    The ``topk`` query parameter is parsed as an int and defaults to 2.
    The result of ``recommendation_service.user_cold_start(topk)`` is
    returned in a JSON envelope tagged with ``type: 'cold_start'``. Any
    exception is reported as HTTP 500 with the error text in ``message``.
    """
    try:
        limit = request.args.get('topk', 2, type=int)
        items = recommendation_service.user_cold_start(limit)

        body = {
            'recommendations': items,
            'count': len(items),
            'type': 'cold_start'
        }
        return jsonify({
            'success': True,
            'data': body,
            'message': '热门推荐获取成功'
        })
    except Exception as exc:
        failure = {
            'success': False,
            'message': f'推荐获取失败: {str(exc)}'
        }
        return jsonify(failure), 500
| |
@recommend_bp.route('/health', methods=['GET'])
def health_check():
    """Health check for the recommendation system.

    Imports ``torch`` and reports whether CUDA is available, together with
    the derived device name ('cuda' or 'cpu'). If the import or the CUDA
    query raises, the error is reported as HTTP 500.
    """
    try:
        # Imported inside the handler so that an import failure is reported
        # through the 500 branch below.
        import torch

        has_cuda = torch.cuda.is_available()
        if has_cuda:
            device_name = 'cuda'
        else:
            device_name = 'cpu'

        status = {
            'status': 'healthy',
            'cuda_available': has_cuda,
            'device': device_name
        }
        return jsonify({
            'success': True,
            'data': status,
            'message': '推荐系统运行正常'
        })
    except Exception as exc:
        failure = {
            'success': False,
            'message': f'推荐系统异常: {str(exc)}'
        }
        return jsonify(failure), 500
| |
@recommend_bp.route('/multi_recall', methods=['POST'])
@token_required
def multi_recall_recommendations(current_user):
    """Return recommendations produced with multi-recall forced on.

    Optional JSON body fields:
        user_id: target user; when absent or falsy, ``current_user.user_id``
            is used.
        topk: forwarded to the service as-is; defaults to 2.

    A missing, malformed, or non-object JSON body is treated as an empty
    object, so the defaults above apply instead of the request failing.

    Returns a JSON envelope tagged with ``type: 'multi_recall'``. Any
    exception is reported as HTTP 500 with the error text in ``message``.
    """
    try:
        # silent=True makes get_json return None rather than raise when the
        # body is absent or not valid JSON; every field here is optional.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict):
            payload = {}

        user_id = payload.get('user_id') or current_user.user_id
        topk = payload.get('topk', 2)

        # Force the multi-recall path.
        result = recommendation_service.run_inference(user_id, topk, use_multi_recall=True)

        if isinstance(result, list) and result and isinstance(result[0], dict):
            # A non-empty list of dicts is used as-is; the original comment
            # says the cold-start path returns detailed entries directly.
            recommendations = result
        elif isinstance(result, tuple) and len(result) == 2:
            # (topk_post_ids, topk_scores) pair: look up the post details.
            topk_post_ids, topk_scores = result
            recommendations = recommendation_service.get_post_info(topk_post_ids, topk_scores)
        else:
            recommendations = recommendation_service.get_post_info(result)

        return jsonify({
            'success': True,
            'data': {
                'user_id': user_id,
                'recommendations': recommendations,
                'count': len(recommendations),
                'type': 'multi_recall'
            },
            'message': '多路召回推荐获取成功'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'多路召回推荐获取失败: {str(e)}'
        }), 500
| |
@recommend_bp.route('/lightgcn', methods=['POST'])
@token_required
def lightgcn_recommendations(current_user):
    """Return recommendations produced with multi-recall forced off (LightGCN).

    Optional JSON body fields:
        user_id: target user; when absent or falsy, ``current_user.user_id``
            is used.
        topk: forwarded to the service as-is; defaults to 2.

    A missing, malformed, or non-object JSON body is treated as an empty
    object, so the defaults above apply instead of the request failing.

    Returns a JSON envelope tagged with ``type: 'lightgcn'``. Any exception
    is reported as HTTP 500 with the error text in ``message``.
    """
    try:
        # silent=True makes get_json return None rather than raise when the
        # body is absent or not valid JSON; every field here is optional.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict):
            payload = {}

        user_id = payload.get('user_id') or current_user.user_id
        topk = payload.get('topk', 2)

        # Force the LightGCN path by disabling multi-recall.
        result = recommendation_service.run_inference(user_id, topk, use_multi_recall=False)

        if isinstance(result, list) and result and isinstance(result[0], dict):
            # A non-empty list of dicts is used as-is; the original comment
            # says the cold-start path returns detailed entries directly.
            recommendations = result
        elif isinstance(result, tuple) and len(result) == 2:
            # (topk_post_ids, topk_scores) pair: look up the post details.
            topk_post_ids, topk_scores = result
            recommendations = recommendation_service.get_post_info(topk_post_ids, topk_scores)
        else:
            recommendations = recommendation_service.get_post_info(result)

        return jsonify({
            'success': True,
            'data': {
                'user_id': user_id,
                'recommendations': recommendations,
                'count': len(recommendations),
                'type': 'lightgcn'
            },
            'message': 'LightGCN推荐获取成功'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'LightGCN推荐获取失败: {str(e)}'
        }), 500
| |
@recommend_bp.route('/train_multi_recall', methods=['POST'])
@token_required
def train_multi_recall(current_user):
    """Train the multi-recall model; restricted to administrators.

    Responds with HTTP 403 when ``current_user`` has no truthy ``is_admin``
    attribute. Otherwise calls ``recommendation_service.train_multi_recall()``
    and reports success once that call returns. Any exception is reported
    as HTTP 500 with the error text in ``message``.
    """
    try:
        # Only administrators may trigger training.
        caller_is_admin = getattr(current_user, 'is_admin', False)
        if not caller_is_admin:
            denied = {
                'success': False,
                'message': '需要管理员权限'
            }
            return jsonify(denied), 403

        # NOTE(review): the training call runs inside the request; the
        # response is only sent after it returns.
        recommendation_service.train_multi_recall()

        return jsonify({
            'success': True,
            'message': '多路召回模型训练完成'
        })
    except Exception as exc:
        failure = {
            'success': False,
            'message': f'模型训练失败: {str(exc)}'
        }
        return jsonify(failure), 500
| |
@recommend_bp.route('/recall_config', methods=['GET'])
@token_required
def get_recall_config(current_user):
    """Return the current multi-recall configuration.

    The response ``data`` carries ``recommendation_service.recall_config``
    and ``recommendation_service.multi_recall_enabled``. Any exception is
    reported as HTTP 500 with the error text in ``message``.
    """
    try:
        current_config = recommendation_service.recall_config
        snapshot = {
            'config': current_config,
            'multi_recall_enabled': recommendation_service.multi_recall_enabled
        }
        return jsonify({
            'success': True,
            'data': snapshot,
            'message': '配置获取成功'
        })
    except Exception as exc:
        failure = {
            'success': False,
            'message': f'配置获取失败: {str(exc)}'
        }
        return jsonify(failure), 500
| |
@recommend_bp.route('/recall_config', methods=['POST'])
@token_required
def update_recall_config(current_user):
    """Update the multi-recall configuration; restricted to administrators.

    Responds with HTTP 403 when ``current_user`` has no truthy ``is_admin``
    attribute, and with HTTP 400 when the request body is missing, is not
    valid JSON, or is not a JSON object.

    JSON body fields (both optional):
        multi_recall_enabled: assigned to
            ``recommendation_service.multi_recall_enabled`` when present.
        config: when truthy, passed to
            ``recommendation_service.update_recall_config``.

    On success the response ``data`` carries the service's current
    ``recall_config`` and ``multi_recall_enabled``. Any other exception is
    reported as HTTP 500 with the error text in ``message``.
    """
    try:
        # Only administrators may change the configuration.
        if not hasattr(current_user, 'is_admin') or not current_user.is_admin:
            return jsonify({
                'success': False,
                'message': '需要管理员权限'
            }), 403

        # A body that is absent or not a JSON object is a client error;
        # reject it explicitly instead of failing later with a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'message': '请求体必须是JSON对象'
            }), 400

        new_config = data.get('config', {})

        # Update the multi-recall enabled flag.
        if 'multi_recall_enabled' in data:
            recommendation_service.multi_recall_enabled = data['multi_recall_enabled']

        # Update the detailed configuration.
        if new_config:
            recommendation_service.update_recall_config(new_config)

        return jsonify({
            'success': True,
            'data': {
                'config': recommendation_service.recall_config,
                'multi_recall_enabled': recommendation_service.multi_recall_enabled
            },
            'message': '配置更新成功'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'配置更新失败: {str(e)}'
        }), 500
| |
@recommend_bp.route('/recall_stats/<int:user_id>', methods=['GET'])
@token_required
def get_recall_stats(current_user, user_id):
    """Return multi-recall statistics for ``user_id``.

    A caller may read their own statistics; reading another user's requires
    a truthy ``is_admin`` attribute on ``current_user``, otherwise HTTP 403
    is returned. The response ``data`` is the value returned by
    ``recommendation_service.get_multi_recall_stats(user_id)``. Any
    exception is reported as HTTP 500 with the error text in ``message``.
    """
    try:
        # Allowed only for the user themselves or for an administrator.
        if current_user.user_id != user_id and not getattr(current_user, 'is_admin', False):
            forbidden = {
                'success': False,
                'message': '权限不足'
            }
            return jsonify(forbidden), 403

        recall_stats = recommendation_service.get_multi_recall_stats(user_id)

        return jsonify({
            'success': True,
            'data': recall_stats,
            'message': '统计信息获取成功'
        })
    except Exception as exc:
        failure = {
            'success': False,
            'message': f'统计信息获取失败: {str(exc)}'
        }
        return jsonify(failure), 500