Raver | d789517 | 2025-06-18 17:54:38 +0800 | [diff] [blame] | 1 | from flask import Blueprint, request, jsonify |
| 2 | from app.services.recommendation_service import RecommendationService |
| 3 | from app.functions.FAuth import FAuth |
| 4 | from sqlalchemy import create_engine |
| 5 | from sqlalchemy.orm import sessionmaker |
| 6 | from config import Config |
| 7 | from functools import wraps |
| 8 | |
# Blueprint grouping every recommendation endpoint in this module; all routes
# registered on it are served under the /api/recommend URL prefix.
recommend_bp = Blueprint('recommend', __name__, url_prefix='/api/recommend')
| 10 | |
# Process-wide session factory, built lazily on the first authenticated request.
_session_factory = None


def _get_session_factory():
    """Return the shared ``sessionmaker``, creating the engine on first use.

    A SQLAlchemy engine owns a connection pool and is meant to be created once
    per process. Building a fresh engine on every request leaves an orphaned,
    never-disposed pool behind each time.
    """
    global _session_factory
    if _session_factory is None:
        _session_factory = sessionmaker(bind=create_engine(Config.SQLURL))
    return _session_factory


def token_required(f):
    """Decorator: require a valid access token on the request.

    Reads the ``Authorization`` header (an optional ``"Bearer "`` prefix is
    stripped), resolves the user via ``FAuth.get_user_by_token`` and calls the
    wrapped view as ``f(user, *args, **kwargs)``.

    Responses produced by the decorator itself:
        401 '缺少访问令牌'   - the header is missing or empty;
        401 '无效的访问令牌' - the token did not resolve to a user;
        401 '令牌验证失败'   - opening the session or the token lookup raised.

    Exceptions raised by the wrapped view are not converted into 401s; they
    propagate to Flask's normal error handling. The database session stays
    open while the view runs and is always closed afterwards.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        token = request.headers.get('Authorization')
        if not token:
            return jsonify({'success': False, 'message': '缺少访问令牌'}), 401

        # Strip the "Bearer " scheme prefix if present.
        if token.startswith('Bearer '):
            token = token[len('Bearer '):]

        session = None
        try:
            # Only the authentication step is mapped to 401 on failure.
            try:
                session = _get_session_factory()()
                user = FAuth(session).get_user_by_token(token)
            except Exception:
                if session is not None:
                    session.rollback()
                return jsonify({'success': False, 'message': '令牌验证失败'}), 401

            if not user:
                return jsonify({'success': False, 'message': '无效的访问令牌'}), 401

            # Run the view while the session is still open so the user object
            # remains attached to it.
            return f(user, *args, **kwargs)
        finally:
            if session is not None:
                session.close()

    return decorated
| 45 | |
# Initialize the recommendation service.
# One module-level instance is created at import time and shared by every
# route in this blueprint. The POST /recall_config route mutates its
# attributes at runtime, so such changes apply to all later requests served
# by this process.
recommendation_service = RecommendationService()
| 48 | |
@recommend_bp.route('/get_recommendations', methods=['POST'])
@token_required
def get_recommendations(current_user):
    """Return personalised recommendations for a user (获取个性化推荐).

    JSON body (every field optional; a missing or non-JSON body means
    "use the defaults"):
        user_id: target user, defaults to the authenticated user. Only an
                 admin may request another user's recommendations.
        topk:    number of items requested, positive integer, default 2.

    Responses:
        200 - ``data`` holds ``user_id``, ``recommendations`` and ``count``;
        400 - ``topk`` is not a positive integer;
        403 - a non-admin asked for another user's recommendations;
        500 - the recommendation service raised (error text included).
    """
    try:
        # silent=True yields None instead of raising on a missing/invalid
        # body; the handler below would otherwise report that as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        user_id = data.get('user_id') or current_user.user_id
        # Same access rule as /recall_stats: self or admin only. Compare as
        # strings because a JSON body may carry the id as a number or string.
        if (str(user_id) != str(current_user.user_id)
                and not getattr(current_user, 'is_admin', False)):
            return jsonify({'success': False, 'message': '权限不足'}), 403

        try:
            topk = int(data.get('topk', 2))
            topk_ok = topk > 0
        except (TypeError, ValueError, OverflowError):
            # None, non-numeric strings, NaN/Infinity from the JSON body.
            topk_ok = False
        if not topk_ok:
            return jsonify({'success': False, 'message': 'topk 必须是正整数'}), 400

        recommendations = recommendation_service.get_recommendations(user_id, topk)

        return jsonify({
            'success': True,
            'data': {
                'user_id': user_id,
                'recommendations': recommendations,
                'count': len(recommendations)
            },
            'message': '推荐获取成功'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'推荐获取失败: {str(e)}'
        }), 500
| 74 | |
@recommend_bp.route('/cold_start', methods=['GET'])
def cold_start_recommendations():
    """Cold-start recommendations; no login required (冷启动推荐).

    The success message labels these results as 热门 (popular) items.

    Query string:
        topk: positive integer, default 2. A value that cannot be converted
              to int falls back to the default (``type=int`` behaviour).

    Responses:
        200 - ``data`` holds ``recommendations``, ``count`` and ``type``;
        400 - ``topk`` is zero or negative;
        500 - the recommendation service raised (error text included).
    """
    try:
        topk = request.args.get('topk', 2, type=int)
        # type=int accepts negatives such as "-3"; reject them explicitly.
        if topk <= 0:
            return jsonify({'success': False, 'message': 'topk 必须是正整数'}), 400

        recommendations = recommendation_service.user_cold_start(topk)

        return jsonify({
            'success': True,
            'data': {
                'recommendations': recommendations,
                'count': len(recommendations),
                'type': 'cold_start'
            },
            'message': '热门推荐获取成功'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'推荐获取失败: {str(e)}'
        }), 500
| 97 | |
@recommend_bp.route('/health', methods=['GET'])
def health_check():
    """Lightweight health probe for the recommender (推荐系统健康检查).

    Only checks that ``torch`` can be imported and reports whether CUDA is
    available; it does not call the recommendation service itself.
    Returns 200 with the device info, or 500 with the error text when the
    probe raises (for example if torch is not installed).
    """
    try:
        # Imported lazily so a missing torch shows up here as a 500 rather
        # than failing when this module is imported.
        import torch

        has_cuda = torch.cuda.is_available()
        status_info = {
            'status': 'healthy',
            'cuda_available': has_cuda,
            'device': 'cuda' if has_cuda else 'cpu'
        }
        return jsonify({
            'success': True,
            'data': status_info,
            'message': '推荐系统运行正常'
        })
    except Exception as exc:
        return jsonify({
            'success': False,
            'message': f'推荐系统异常: {str(exc)}'
        }), 500
| 120 | |
def _inference_request_params(current_user):
    """Parse ``user_id`` and ``topk`` from an inference request's JSON body.

    Returns ``(user_id, topk, None)`` on success, or ``(None, None, response)``
    where ``response`` is a ready-to-return 400/403 Flask response tuple.
    ``user_id`` defaults to the caller and only admins may target another
    user; ``topk`` defaults to 2 and must be a positive integer. A missing or
    non-object body is treated as ``{}``.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    user_id = data.get('user_id') or current_user.user_id
    # Same access rule as /recall_stats: self or admin only. Compare as
    # strings because a JSON body may carry the id as a number or string.
    if (str(user_id) != str(current_user.user_id)
            and not getattr(current_user, 'is_admin', False)):
        return None, None, (jsonify({'success': False, 'message': '权限不足'}), 403)

    try:
        topk = int(data.get('topk', 2))
        topk_ok = topk > 0
    except (TypeError, ValueError, OverflowError):
        # None, non-numeric strings, NaN/Infinity from the JSON body.
        topk_ok = False
    if not topk_ok:
        return None, None, (jsonify({'success': False, 'message': 'topk 必须是正整数'}), 400)

    return user_id, topk, None


def _inference_result_to_recommendations(result):
    """Turn a ``run_inference`` result into the list sent to the client.

    Handles the shapes this module has always accepted:
        * a non-empty list whose first item is a dict - already detailed
          (the cold-start path), returned unchanged;
        * a ``(post_ids, scores)`` 2-tuple - details fetched with scores;
        * anything else - passed to ``get_post_info`` as post ids.
    """
    if isinstance(result, list) and result and isinstance(result[0], dict):
        return result
    if isinstance(result, tuple) and len(result) == 2:
        topk_post_ids, topk_scores = result
        return recommendation_service.get_post_info(topk_post_ids, topk_scores)
    return recommendation_service.get_post_info(result)


def _serve_inference(current_user, *, use_multi_recall, rec_type, label):
    """Shared body of the /multi_recall and /lightgcn endpoints.

    ``use_multi_recall`` is forwarded to ``run_inference``; ``rec_type`` is
    echoed as ``data.type``; ``label`` prefixes the success/failure messages.
    Returns 200 on success, 400/403 for a rejected request and 500 (with the
    error text) when inference raises.
    """
    try:
        user_id, topk, error = _inference_request_params(current_user)
        if error is not None:
            return error

        result = recommendation_service.run_inference(
            user_id, topk, use_multi_recall=use_multi_recall)
        recommendations = _inference_result_to_recommendations(result)

        return jsonify({
            'success': True,
            'data': {
                'user_id': user_id,
                'recommendations': recommendations,
                'count': len(recommendations),
                'type': rec_type
            },
            'message': f'{label}推荐获取成功'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'{label}推荐获取失败: {str(e)}'
        }), 500


@recommend_bp.route('/multi_recall', methods=['POST'])
@token_required
def multi_recall_recommendations(current_user):
    """Recommendations forced through multi-route recall (多路召回推荐).

    JSON body (optional): ``user_id`` (defaults to the caller; admin-only for
    other users) and ``topk`` (positive integer, default 2).
    """
    return _serve_inference(current_user, use_multi_recall=True,
                            rec_type='multi_recall', label='多路召回')


@recommend_bp.route('/lightgcn', methods=['POST'])
@token_required
def lightgcn_recommendations(current_user):
    """Recommendations forced through the LightGCN path (LightGCN推荐).

    JSON body (optional): ``user_id`` (defaults to the caller; admin-only for
    other users) and ``topk`` (positive integer, default 2).
    """
    return _serve_inference(current_user, use_multi_recall=False,
                            rec_type='lightgcn', label='LightGCN')
| 198 | |
@recommend_bp.route('/train_multi_recall', methods=['POST'])
@token_required
def train_multi_recall(current_user):
    """Retrain the multi-recall models; admin only (训练多路召回模型).

    Training runs synchronously inside the request through
    ``recommendation_service.train_multi_recall()``, so the response is sent
    only after it returns.

    Responses: 200 when training finishes, 403 for non-admin callers, 500
    with the error text if training raises.
    """
    try:
        # A user object without an is_admin attribute counts as non-admin.
        if getattr(current_user, 'is_admin', False):
            recommendation_service.train_multi_recall()
            return jsonify({
                'success': True,
                'message': '多路召回模型训练完成'
            })

        return jsonify({
            'success': False,
            'message': '需要管理员权限'
        }), 403
    except Exception as err:
        return jsonify({
            'success': False,
            'message': f'模型训练失败: {str(err)}'
        }), 500
| 222 | |
@recommend_bp.route('/recall_config', methods=['GET'])
@token_required
def get_recall_config(current_user):
    """Report the current multi-recall configuration (获取多路召回配置).

    Any authenticated user may call this. ``current_user`` is supplied by
    ``token_required`` and is not otherwise used here. Returns the service's
    ``recall_config`` and ``multi_recall_enabled`` values, or 500 with the
    error text if reading them raises.
    """
    try:
        snapshot = {
            'config': recommendation_service.recall_config,
            'multi_recall_enabled': recommendation_service.multi_recall_enabled
        }
        return jsonify({
            'success': True,
            'data': snapshot,
            'message': '配置获取成功'
        })
    except Exception as err:
        return jsonify({
            'success': False,
            'message': f'配置获取失败: {str(err)}'
        }), 500
| 242 | |
@recommend_bp.route('/recall_config', methods=['POST'])
@token_required
def update_recall_config(current_user):
    """Update the multi-recall configuration; admin only (更新多路召回配置).

    JSON body (must be an object):
        multi_recall_enabled: optional bool; replaces the service's flag.
        config:               optional object; forwarded to
                              ``recommendation_service.update_recall_config``.
                              Empty/null values are ignored.

    Both fields are type-checked before either is applied, so a badly-typed
    field is rejected before anything is modified. (A string such as "false"
    is truthy and was previously stored as-is.)

    Responses:
        200 - ``data`` holds the resulting config and flag;
        400 - the body or one of the fields has the wrong type;
        403 - caller is not an admin;
        500 - applying the config raised (error text included).
    """
    try:
        # A user object without an is_admin attribute counts as non-admin.
        if not getattr(current_user, 'is_admin', False):
            return jsonify({
                'success': False,
                'message': '需要管理员权限'
            }), 403

        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'message': '请求体必须是JSON对象'
            }), 400

        new_config = data.get('config', {})
        if new_config and not isinstance(new_config, dict):
            return jsonify({
                'success': False,
                'message': 'config 必须是JSON对象'
            }), 400

        if ('multi_recall_enabled' in data
                and not isinstance(data['multi_recall_enabled'], bool)):
            return jsonify({
                'success': False,
                'message': 'multi_recall_enabled 必须是布尔值'
            }), 400

        # Update the multi-recall enabled flag.
        if 'multi_recall_enabled' in data:
            recommendation_service.multi_recall_enabled = data['multi_recall_enabled']

        # Apply the detailed recall configuration.
        if new_config:
            recommendation_service.update_recall_config(new_config)

        return jsonify({
            'success': True,
            'data': {
                'config': recommendation_service.recall_config,
                'multi_recall_enabled': recommendation_service.multi_recall_enabled
            },
            'message': '配置更新成功'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'配置更新失败: {str(e)}'
        }), 500
| 279 | |
@recommend_bp.route('/recall_stats/<int:user_id>', methods=['GET'])
@token_required
def get_recall_stats(current_user, user_id):
    """Return recall statistics for ``user_id`` (获取用户的召回统计信息).

    Callers may read only their own statistics unless they are admins.
    Responses: 200 with the stats, 403 when access is not allowed, 500 with
    the error text if fetching the stats raises.
    """
    try:
        # Own stats are always allowed; the admin flag is consulted only for
        # other users (a user object without is_admin counts as non-admin).
        allowed = (current_user.user_id == user_id
                   or getattr(current_user, 'is_admin', False))
        if not allowed:
            return jsonify({
                'success': False,
                'message': '权限不足'
            }), 403

        return jsonify({
            'success': True,
            'data': recommendation_service.get_multi_recall_stats(user_id),
            'message': '统计信息获取成功'
        })
    except Exception as err:
        return jsonify({
            'success': False,
            'message': f'统计信息获取失败: {str(err)}'
        }), 500