# Source blob: 97b8908f53709efe4d3859acb0f42a5aa0ce8c9b
from flask import Blueprint, request, jsonify
from app.services.recommendation_service import RecommendationService
from app.functions.FAuth import FAuth
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from config import Config
from functools import wraps
# Blueprint that groups the recommendation routes under the /api/recommend URL prefix.
recommend_bp = Blueprint('recommend', __name__, url_prefix='/api/recommend')
# Lazily created, process-wide session factory (see _get_session_factory).
_session_factory = None


def _get_session_factory():
    """Return the shared SQLAlchemy session factory, creating it on first use.

    The engine owns a connection pool, so it is built once and reused rather
    than being re-created (and never disposed) on every request.
    """
    global _session_factory
    if _session_factory is None:
        _session_factory = sessionmaker(bind=create_engine(Config.SQLURL))
    return _session_factory


def token_required(f):
    """Decorator: require a valid access token on the wrapped view.

    The token is read from the ``Authorization`` header; a leading
    ``'Bearer '`` prefix is stripped. The user returned by
    ``FAuth.get_user_by_token`` is passed to the view as its first positional
    argument.

    Responses produced by the decorator itself (all HTTP 401):
      * header missing or empty;
      * no user found for the token;
      * any exception raised while verifying the token or while running the
        wrapped view (the session is rolled back first).

    The database session stays open while the view runs and is always closed
    afterwards.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        token = request.headers.get('Authorization')
        if not token:
            return jsonify({'success': False, 'message': '缺少访问令牌'}), 401

        session = None
        try:
            # Strip the "Bearer " prefix if present.
            if token.startswith('Bearer '):
                token = token[7:]

            session = _get_session_factory()()
            user = FAuth(session).get_user_by_token(token)
            if not user:
                return jsonify({'success': False, 'message': '无效的访问令牌'}), 401

            # Hand the authenticated user to the view function.
            return f(user, *args, **kwargs)
        except Exception:
            if session is not None:
                session.rollback()
            return jsonify({'success': False, 'message': '令牌验证失败'}), 401
        finally:
            if session is not None:
                session.close()
    return decorated
# Initialise the recommendation service (one module-level instance used by the routes below).
recommendation_service = RecommendationService()
@recommend_bp.route('/get_recommendations', methods=['POST'])
@token_required
def get_recommendations(current_user):
    """Return personalised recommendations.

    Optional JSON body keys:
      * ``user_id`` - target user; falls back to ``current_user.user_id``
        when absent or falsy.
      * ``topk`` - number of items requested, default 2.

    A missing or unparsable body is treated as an empty object. A JSON body
    that is not an object yields HTTP 400. Any other failure yields HTTP 500
    with the exception text in ``message``.
    """
    try:
        # silent=True returns None instead of raising when the body is
        # missing or is not valid JSON.
        data = request.get_json(silent=True)
        if data is None:
            data = {}
        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'message': '请求体必须是JSON对象'
            }), 400

        # NOTE(review): any authenticated caller can pass an arbitrary
        # user_id here, whereas get_recall_stats limits access to the user
        # themself or an admin - confirm this is intended.
        user_id = data.get('user_id') or current_user.user_id
        topk = data.get('topk', 2)

        recommendations = recommendation_service.get_recommendations(user_id, topk)
        return jsonify({
            'success': True,
            'data': {
                'user_id': user_id,
                'recommendations': recommendations,
                'count': len(recommendations)
            },
            'message': '推荐获取成功'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'推荐获取失败: {str(e)}'
        }), 500
@recommend_bp.route('/cold_start', methods=['GET'])
def cold_start_recommendations():
    """Serve cold-start recommendations; this route has no token check.

    The ``topk`` query parameter is parsed as an int and defaults to 2.
    Failures are reported as HTTP 500 with the exception text in ``message``.
    """
    try:
        limit = request.args.get('topk', default=2, type=int)
        items = recommendation_service.user_cold_start(limit)
        payload = {
            'recommendations': items,
            'count': len(items),
            'type': 'cold_start'
        }
        return jsonify({
            'success': True,
            'data': payload,
            'message': '热门推荐获取成功'
        })
    except Exception as exc:
        failure = {
            'success': False,
            'message': f'推荐获取失败: {str(exc)}'
        }
        return jsonify(failure), 500
@recommend_bp.route('/health', methods=['GET'])
def health_check():
    """Health probe for the recommendation system.

    Imports ``torch`` and reports whether CUDA is available together with the
    matching device name. Any exception (including a failed import) is
    reported as HTTP 500 with the exception text in ``message``.
    """
    try:
        # Imported here so that an import failure is reported by this endpoint.
        import torch

        has_cuda = torch.cuda.is_available()
        if has_cuda:
            device_name = 'cuda'
        else:
            device_name = 'cpu'

        status = {
            'status': 'healthy',
            'cuda_available': has_cuda,
            'device': device_name
        }
        return jsonify({
            'success': True,
            'data': status,
            'message': '推荐系统运行正常'
        })
    except Exception as exc:
        failure = {
            'success': False,
            'message': f'推荐系统异常: {str(exc)}'
        }
        return jsonify(failure), 500
@recommend_bp.route('/multi_recall', methods=['POST'])
@token_required
def multi_recall_recommendations(current_user):
    """Return recommendations with multi-recall forced on.

    Optional JSON body keys:
      * ``user_id`` - target user; falls back to ``current_user.user_id``
        when absent or falsy.
      * ``topk`` - number of items requested, default 2.

    A missing or unparsable body is treated as an empty object. A JSON body
    that is not an object yields HTTP 400. Any other failure yields HTTP 500
    with the exception text in ``message``.
    """
    try:
        # silent=True returns None instead of raising when the body is
        # missing or is not valid JSON.
        data = request.get_json(silent=True)
        if data is None:
            data = {}
        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'message': '请求体必须是JSON对象'
            }), 400

        user_id = data.get('user_id') or current_user.user_id
        topk = data.get('topk', 2)

        # Force the multi-recall path.
        result = recommendation_service.run_inference(user_id, topk, use_multi_recall=True)

        if isinstance(result, list) and result and isinstance(result[0], dict):
            # Already a list of detail dicts (the cold-start case, per the
            # original comment): return it as is.
            recommendations = result
        elif isinstance(result, tuple) and len(result) == 2:
            # (topk_post_ids, topk_scores) pair: look up the post details.
            topk_post_ids, topk_scores = result
            recommendations = recommendation_service.get_post_info(topk_post_ids, topk_scores)
        else:
            recommendations = recommendation_service.get_post_info(result)

        return jsonify({
            'success': True,
            'data': {
                'user_id': user_id,
                'recommendations': recommendations,
                'count': len(recommendations),
                'type': 'multi_recall'
            },
            'message': '多路召回推荐获取成功'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'多路召回推荐获取失败: {str(e)}'
        }), 500
@recommend_bp.route('/lightgcn', methods=['POST'])
@token_required
def lightgcn_recommendations(current_user):
    """Return recommendations with multi-recall forced off (LightGCN path).

    Optional JSON body keys:
      * ``user_id`` - target user; falls back to ``current_user.user_id``
        when absent or falsy.
      * ``topk`` - number of items requested, default 2.

    A missing or unparsable body is treated as an empty object. A JSON body
    that is not an object yields HTTP 400. Any other failure yields HTTP 500
    with the exception text in ``message``.
    """
    try:
        # silent=True returns None instead of raising when the body is
        # missing or is not valid JSON.
        data = request.get_json(silent=True)
        if data is None:
            data = {}
        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'message': '请求体必须是JSON对象'
            }), 400

        user_id = data.get('user_id') or current_user.user_id
        topk = data.get('topk', 2)

        # Force the LightGCN path by disabling multi-recall.
        result = recommendation_service.run_inference(user_id, topk, use_multi_recall=False)

        if isinstance(result, list) and result and isinstance(result[0], dict):
            # Already a list of detail dicts (the cold-start case, per the
            # original comment): return it as is.
            recommendations = result
        elif isinstance(result, tuple) and len(result) == 2:
            # (topk_post_ids, topk_scores) pair: look up the post details.
            topk_post_ids, topk_scores = result
            recommendations = recommendation_service.get_post_info(topk_post_ids, topk_scores)
        else:
            recommendations = recommendation_service.get_post_info(result)

        return jsonify({
            'success': True,
            'data': {
                'user_id': user_id,
                'recommendations': recommendations,
                'count': len(recommendations),
                'type': 'lightgcn'
            },
            'message': 'LightGCN推荐获取成功'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'LightGCN推荐获取失败: {str(e)}'
        }), 500
@recommend_bp.route('/train_multi_recall', methods=['POST'])
@token_required
def train_multi_recall(current_user):
    """Train the multi-recall model on behalf of an administrator.

    Responds with HTTP 403 when ``current_user`` has no truthy ``is_admin``
    attribute. Otherwise calls ``recommendation_service.train_multi_recall()``
    and reports success once it returns. Failures are reported as HTTP 500
    with the exception text in ``message``.
    """
    try:
        # Training is restricted to administrators.
        caller_is_admin = getattr(current_user, 'is_admin', False)
        if not caller_is_admin:
            denied = {
                'success': False,
                'message': '需要管理员权限'
            }
            return jsonify(denied), 403

        recommendation_service.train_multi_recall()
        done = {
            'success': True,
            'message': '多路召回模型训练完成'
        }
        return jsonify(done)
    except Exception as exc:
        failure = {
            'success': False,
            'message': f'模型训练失败: {str(exc)}'
        }
        return jsonify(failure), 500
@recommend_bp.route('/recall_config', methods=['GET'])
@token_required
def get_recall_config(current_user):
    """Return the current multi-recall configuration.

    The response carries ``recommendation_service.recall_config`` and the
    ``multi_recall_enabled`` flag. Failures are reported as HTTP 500 with the
    exception text in ``message``.
    """
    try:
        current_config = recommendation_service.recall_config
        payload = {
            'config': current_config,
            'multi_recall_enabled': recommendation_service.multi_recall_enabled
        }
        return jsonify({
            'success': True,
            'data': payload,
            'message': '配置获取成功'
        })
    except Exception as exc:
        failure = {
            'success': False,
            'message': f'配置获取失败: {str(exc)}'
        }
        return jsonify(failure), 500
@recommend_bp.route('/recall_config', methods=['POST'])
@token_required
def update_recall_config(current_user):
    """Update the multi-recall configuration (administrators only).

    JSON body keys, both optional:
      * ``multi_recall_enabled`` - assigned to
        ``recommendation_service.multi_recall_enabled`` when present.
      * ``config`` - passed to ``recommendation_service.update_recall_config``
        when truthy.

    Responses:
      * 403 when ``current_user`` has no truthy ``is_admin`` attribute;
      * 400 when the body is missing, unparsable, or not a JSON object;
      * 200 with the resulting config and enabled flag on success;
      * 500 with the exception text in ``message`` on any other failure.
    """
    try:
        # Only administrators may change the configuration.
        if not getattr(current_user, 'is_admin', False):
            return jsonify({
                'success': False,
                'message': '需要管理员权限'
            }), 403

        # silent=True returns None instead of raising on a missing or
        # malformed body, so one check covers every invalid case.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'message': '请求体必须是JSON对象'
            }), 400

        new_config = data.get('config', {})

        # Update the multi-recall enabled flag.
        # NOTE(review): the value is stored without type validation - confirm
        # whether non-boolean values should be rejected.
        if 'multi_recall_enabled' in data:
            recommendation_service.multi_recall_enabled = data['multi_recall_enabled']

        # Apply the detailed configuration, if any was supplied.
        if new_config:
            recommendation_service.update_recall_config(new_config)

        return jsonify({
            'success': True,
            'data': {
                'config': recommendation_service.recall_config,
                'multi_recall_enabled': recommendation_service.multi_recall_enabled
            },
            'message': '配置更新成功'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'配置更新失败: {str(e)}'
        }), 500
@recommend_bp.route('/recall_stats/<int:user_id>', methods=['GET'])
@token_required
def get_recall_stats(current_user, user_id):
    """Return the multi-recall statistics for ``user_id``.

    Access is limited to the user themself or a caller with a truthy
    ``is_admin`` attribute; anyone else gets HTTP 403. Failures are reported
    as HTTP 500 with the exception text in ``message``.
    """
    try:
        # The admin flag is consulted only when the caller asks about someone else.
        requesting_own_stats = current_user.user_id == user_id
        if not requesting_own_stats and not getattr(current_user, 'is_admin', False):
            denied = {
                'success': False,
                'message': '权限不足'
            }
            return jsonify(denied), 403

        recall_stats = recommendation_service.get_multi_recall_stats(user_id)
        return jsonify({
            'success': True,
            'data': recall_stats,
            'message': '统计信息获取成功'
        })
    except Exception as exc:
        failure = {
            'success': False,
            'message': f'统计信息获取失败: {str(exc)}'
        }
        return jsonify(failure), 500