blob: 97b8908f53709efe4d3859acb0f42a5aa0ce8c9b [file] [log] [blame]
Raverd7895172025-06-18 17:54:38 +08001from flask import Blueprint, request, jsonify
2from app.services.recommendation_service import RecommendationService
3from app.functions.FAuth import FAuth
4from sqlalchemy import create_engine
5from sqlalchemy.orm import sessionmaker
6from config import Config
7from functools import wraps
8
# Blueprint that groups every recommendation endpoint in this module; all
# routes below are mounted under the /api/recommend URL prefix.
recommend_bp = Blueprint(
    'recommend',
    __name__,
    url_prefix='/api/recommend',
)
10
# Lazily built, process-wide SQLAlchemy session factory (see _get_session_factory).
_session_factory = None


def _get_session_factory():
    """Return the shared ``sessionmaker``, creating the engine on first call.

    ``create_engine`` builds a new connection pool every time it is called.
    Calling it once per request would leave a fresh, never-disposed pool
    behind on each request; building it once lets all requests share one
    pool.
    """
    global _session_factory
    if _session_factory is None:
        _session_factory = sessionmaker(bind=create_engine(Config.SQLURL))
    return _session_factory


def token_required(f):
    """Decorator: require a valid access token for the wrapped view.

    The token is read from the ``Authorization`` header, and an optional
    ``Bearer `` prefix is stripped. The token is resolved to a user with
    ``FAuth.get_user_by_token``, and the view is then called as
    ``f(user, *args, **kwargs)``.

    Responds with HTTP 401 and a JSON error body in three cases:
    - the header is missing or the token is empty;
    - the token does not resolve to a user;
    - the token lookup itself raises.

    Exceptions raised by the wrapped view are not converted into 401
    responses; they propagate to Flask's normal error handling.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        token = request.headers.get('Authorization')
        # Strip the optional "Bearer " scheme prefix.
        if token and token.startswith('Bearer '):
            token = token[7:]
        if not token:
            return jsonify({'success': False, 'message': '缺少访问令牌'}), 401

        session = None
        try:
            try:
                session = _get_session_factory()()
                user = FAuth(session).get_user_by_token(token)
            except Exception:
                # Only failures of the token lookup map to "verification failed".
                if session is not None:
                    session.rollback()
                return jsonify({'success': False, 'message': '令牌验证失败'}), 401

            if not user:
                return jsonify({'success': False, 'message': '无效的访问令牌'}), 401

            # The session stays open while the view runs, presumably so that
            # attributes of ``user`` can still be loaded lazily if needed.
            return f(user, *args, **kwargs)
        finally:
            if session is not None:
                session.close()

    return decorated
45
# Initialize the recommendation service: a single RecommendationService
# instance, constructed once at import time and shared by every route below.
recommendation_service = RecommendationService()
48
@recommend_bp.route('/get_recommendations', methods=['POST'])
@token_required
def get_recommendations(current_user):
    """Return personalised recommendations for a user.

    JSON body (every field optional):
        user_id: target user; defaults to the authenticated user. Only the
            user themselves or an admin may request another user's list
            (the same rule ``get_recall_stats`` applies).
        topk: number of items; falls back to 2 when missing or not
            convertible to ``int``.

    Returns a JSON envelope with one of these statuses:
    - 200 with the recommendation list;
    - 403 when the caller may not view the requested user;
    - 500 with the error text when the service raises.
    """
    try:
        # silent=True: a missing or non-JSON body yields None instead of
        # raising, and we then fall back to the defaults below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        user_id = data.get('user_id') or current_user.user_id
        # Compare as strings because the JSON body may carry the id as "5" or 5.
        if (str(user_id) != str(current_user.user_id)
                and not getattr(current_user, 'is_admin', False)):
            return jsonify({'success': False, 'message': '权限不足'}), 403

        try:
            topk = int(data.get('topk', 2))
        except (TypeError, ValueError):
            topk = 2

        recommendations = recommendation_service.get_recommendations(user_id, topk)

        return jsonify({
            'success': True,
            'data': {
                'user_id': user_id,
                'recommendations': recommendations,
                'count': len(recommendations)
            },
            'message': '推荐获取成功'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'推荐获取失败: {str(e)}'
        }), 500
74
@recommend_bp.route('/cold_start', methods=['GET'])
def cold_start_recommendations():
    """Serve popular-item (cold-start) recommendations; login is not required.

    Query parameters:
        topk: how many items to return, parsed as ``int`` (default 2).
    """
    try:
        items = recommendation_service.user_cold_start(
            request.args.get('topk', 2, type=int))
        payload = {
            'recommendations': items,
            'count': len(items),
            'type': 'cold_start',
        }
        return jsonify({'success': True, 'data': payload, 'message': '热门推荐获取成功'})
    except Exception as exc:
        failure = {'success': False, 'message': f'推荐获取失败: {str(exc)}'}
        return jsonify(failure), 500
97
@recommend_bp.route('/health', methods=['GET'])
def health_check():
    """Health probe for the recommender.

    Imports ``torch`` and reports whether CUDA is available. Any failure,
    including ``torch`` being unimportable, is reported as HTTP 500 with the
    error text.
    """
    try:
        # Imported inside the try so an import failure becomes a 500 response.
        import torch

        has_cuda = torch.cuda.is_available()
        if has_cuda:
            device = 'cuda'
        else:
            device = 'cpu'

        return jsonify({
            'success': True,
            'data': {'status': 'healthy', 'cuda_available': has_cuda, 'device': device},
            'message': '推荐系统运行正常',
        })
    except Exception as exc:
        return jsonify({'success': False, 'message': f'推荐系统异常: {str(exc)}'}), 500
120
def _recommend_request_params(current_user):
    """Parse ``(user_id, topk)`` from the JSON request body with safe defaults.

    ``user_id`` defaults to the authenticated user. ``topk`` falls back to 2
    when missing or not convertible to ``int``. A missing or non-JSON body is
    treated as an empty object.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    user_id = data.get('user_id') or current_user.user_id
    try:
        topk = int(data.get('topk', 2))
    except (TypeError, ValueError):
        topk = 2
    return user_id, topk


def _may_view_user(current_user, user_id):
    """Return True if *current_user* is *user_id* themselves or an admin."""
    # Compare as strings because the JSON body may carry the id as "5" or 5.
    if str(user_id) == str(current_user.user_id):
        return True
    return bool(getattr(current_user, 'is_admin', False))


def _inference_to_recommendations(result):
    """Normalise ``run_inference`` output into the list returned to clients."""
    # Cold-start results already carry full post details.
    if isinstance(result, list) and result and isinstance(result[0], dict):
        return result
    # Otherwise the result is either a (post_ids, scores) pair or bare post ids.
    if isinstance(result, tuple) and len(result) == 2:
        post_ids, scores = result
        return recommendation_service.get_post_info(post_ids, scores)
    return recommendation_service.get_post_info(result)


def _forced_inference_response(current_user, use_multi_recall, rec_type, label):
    """Run inference with a fixed strategy and build the JSON response.

    Args:
        current_user: authenticated user injected by ``token_required``.
        use_multi_recall: passed unchanged to ``run_inference``.
        rec_type: value reported as ``data.type`` in the response.
        label: strategy name interpolated into the success/failure messages.

    Returns a JSON envelope with one of these statuses:
    - 200 on success;
    - 403 when another user's data is requested by a non-admin;
    - 500 with the error text when anything raises.
    """
    try:
        user_id, topk = _recommend_request_params(current_user)
        if not _may_view_user(current_user, user_id):
            return jsonify({'success': False, 'message': '权限不足'}), 403

        result = recommendation_service.run_inference(
            user_id, topk, use_multi_recall=use_multi_recall)
        recommendations = _inference_to_recommendations(result)

        return jsonify({
            'success': True,
            'data': {
                'user_id': user_id,
                'recommendations': recommendations,
                'count': len(recommendations),
                'type': rec_type
            },
            'message': f'{label}推荐获取成功'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'{label}推荐获取失败: {str(e)}'
        }), 500


@recommend_bp.route('/multi_recall', methods=['POST'])
@token_required
def multi_recall_recommendations(current_user):
    """Recommend with multi-path recall forced on (``use_multi_recall=True``)."""
    return _forced_inference_response(
        current_user, use_multi_recall=True,
        rec_type='multi_recall', label='多路召回')


@recommend_bp.route('/lightgcn', methods=['POST'])
@token_required
def lightgcn_recommendations(current_user):
    """Recommend with LightGCN only (``use_multi_recall=False``)."""
    return _forced_inference_response(
        current_user, use_multi_recall=False,
        rec_type='lightgcn', label='LightGCN')
198
@recommend_bp.route('/train_multi_recall', methods=['POST'])
@token_required
def train_multi_recall(current_user):
    """Train the multi-path recall model (admins only).

    Training is invoked inside the request handler, so the success response
    is sent only after ``recommendation_service.train_multi_recall()``
    returns. Failures are reported as HTTP 500 with the error text.
    """
    try:
        is_admin = getattr(current_user, 'is_admin', False)
        if not is_admin:
            return jsonify({'success': False, 'message': '需要管理员权限'}), 403

        recommendation_service.train_multi_recall()
        return jsonify({'success': True, 'message': '多路召回模型训练完成'})
    except Exception as exc:
        return jsonify({'success': False, 'message': f'模型训练失败: {str(exc)}'}), 500
222
@recommend_bp.route('/recall_config', methods=['GET'])
@token_required
def get_recall_config(current_user):
    """Return the current multi-path recall config and its enabled flag.

    Available to any authenticated user. ``current_user`` is supplied by
    ``token_required`` but not otherwise used here.
    """
    try:
        service = recommendation_service
        current = service.recall_config
        enabled = service.multi_recall_enabled
        return jsonify({
            'success': True,
            'data': {'config': current, 'multi_recall_enabled': enabled},
            'message': '配置获取成功',
        })
    except Exception as exc:
        return jsonify({'success': False, 'message': f'配置获取失败: {str(exc)}'}), 500
242
@recommend_bp.route('/recall_config', methods=['POST'])
@token_required
def update_recall_config(current_user):
    """Update the multi-path recall configuration (admins only).

    JSON body (both fields optional):
        multi_recall_enabled: must be a JSON boolean. A string such as
            "false" is truthy in Python, so storing it unvalidated could
            silently enable recall.
        config: JSON object forwarded to
            ``recommendation_service.update_recall_config``.

    Returns one of these statuses:
    - 403 for non-admins;
    - 400 for a malformed field, in which case nothing is changed;
    - 500 with the error text if the service raises;
    - 200 with the resulting configuration otherwise.
    """
    try:
        if not getattr(current_user, 'is_admin', False):
            return jsonify({
                'success': False,
                'message': '需要管理员权限'
            }), 403

        # silent=True: a missing or non-JSON body is treated as "no changes"
        # instead of crashing on None.get().
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        # Validate every field before mutating shared service state.
        new_config = data.get('config') or {}
        if not isinstance(new_config, dict):
            return jsonify({
                'success': False,
                'message': '配置更新失败: config 必须是对象'
            }), 400

        has_flag = 'multi_recall_enabled' in data
        if has_flag and not isinstance(data['multi_recall_enabled'], bool):
            return jsonify({
                'success': False,
                'message': '配置更新失败: multi_recall_enabled 必须是布尔值'
            }), 400

        # Apply the detailed config first. If it raises, the enabled flag is
        # left untouched instead of being half-applied.
        if new_config:
            recommendation_service.update_recall_config(new_config)
        if has_flag:
            recommendation_service.multi_recall_enabled = data['multi_recall_enabled']

        return jsonify({
            'success': True,
            'data': {
                'config': recommendation_service.recall_config,
                'multi_recall_enabled': recommendation_service.multi_recall_enabled
            },
            'message': '配置更新成功'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'配置更新失败: {str(e)}'
        }), 500
279
@recommend_bp.route('/recall_stats/<int:user_id>', methods=['GET'])
@token_required
def get_recall_stats(current_user, user_id):
    """Return multi-path recall statistics for ``user_id``.

    Users may read only their own statistics; admins may read anyone's.
    Failures are reported as HTTP 500 with the error text.
    """
    try:
        viewing_self = current_user.user_id == user_id
        if not viewing_self and not getattr(current_user, 'is_admin', False):
            return jsonify({'success': False, 'message': '权限不足'}), 403

        return jsonify({
            'success': True,
            'data': recommendation_service.get_multi_recall_stats(user_id),
            'message': '统计信息获取成功',
        })
    except Exception as exc:
        return jsonify({'success': False, 'message': f'统计信息获取失败: {str(exc)}'}), 500