Recommendation system
Change-Id: I49b9205568f1ccf88b32b08511aff8b0bea8d1bd
diff --git a/rhj/backend/app/__init__.py b/rhj/backend/app/__init__.py
index 41611ae..c50a674 100644
--- a/rhj/backend/app/__init__.py
+++ b/rhj/backend/app/__init__.py
@@ -13,6 +13,10 @@
# Register blueprints or routes
from .routes import main as main_blueprint
app.register_blueprint(main_blueprint)
+
+ # Register recommendation blueprint
+ from .blueprints.recommend import recommend_bp
+ app.register_blueprint(recommend_bp)
return app
diff --git a/rhj/backend/app/blueprints/recommend.py b/rhj/backend/app/blueprints/recommend.py
new file mode 100644
index 0000000..97b8908
--- /dev/null
+++ b/rhj/backend/app/blueprints/recommend.py
@@ -0,0 +1,303 @@
+from flask import Blueprint, request, jsonify
+from app.services.recommendation_service import RecommendationService
+from app.functions.FAuth import FAuth
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+from config import Config
+from functools import wraps
+
+recommend_bp = Blueprint('recommend', __name__, url_prefix='/api/recommend')
+
+def token_required(f):
+    """Decorator: require a valid access token."""
+    @wraps(f)
+    def decorated(*args, **kwargs):
+        token = request.headers.get('Authorization')
+        if not token:
+            return jsonify({'success': False, 'message': 'Missing access token'}), 401
+
+        session = None
+        try:
+            # Strip the "Bearer " prefix if present
+            if token.startswith('Bearer '):
+                token = token[7:]
+
+            engine = create_engine(Config.SQLURL)
+            SessionLocal = sessionmaker(bind=engine)
+            session = SessionLocal()
+            f_auth = FAuth(session)
+
+            user = f_auth.get_user_by_token(token)
+            if not user:
+                return jsonify({'success': False, 'message': 'Invalid access token'}), 401
+
+            # Pass the authenticated user on to the route function
+            return f(user, *args, **kwargs)
+        except Exception:
+            if session:
+                session.rollback()
+            return jsonify({'success': False, 'message': 'Token verification failed'}), 401
+        finally:
+            if session:
+                session.close()
+
+    return decorated
+
+# Initialize the recommendation service
+recommendation_service = RecommendationService()
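+# Example client call (a sketch, not part of the app): assuming the service
+# runs at http://localhost:5000 and TOKEN holds a valid token issued by FAuth,
+# a request against the personalized endpoint would look like:
+#
+#   import requests
+#   resp = requests.post(
+#       'http://localhost:5000/api/recommend/get_recommendations',
+#       headers={'Authorization': f'Bearer {TOKEN}'},
+#       json={'topk': 10},
+#   )
+#   print(resp.json()['data']['recommendations'])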
+
+@recommend_bp.route('/get_recommendations', methods=['POST'])
+@token_required
+def get_recommendations(current_user):
+    """Personalized recommendations."""
+    try:
+        # Tolerate a missing or non-JSON body instead of raising
+        data = request.get_json(silent=True) or {}
+        user_id = data.get('user_id') or current_user.user_id
+        topk = data.get('topk', 2)
+
+        recommendations = recommendation_service.get_recommendations(user_id, topk)
+
+        return jsonify({
+            'success': True,
+            'data': {
+                'user_id': user_id,
+                'recommendations': recommendations,
+                'count': len(recommendations)
+            },
+            'message': 'Recommendations fetched successfully'
+        })
+    except Exception as e:
+        return jsonify({
+            'success': False,
+            'message': f'Failed to fetch recommendations: {str(e)}'
+        }), 500
+
+@recommend_bp.route('/cold_start', methods=['GET'])
+def cold_start_recommendations():
+    """Cold-start recommendations (no login required)."""
+    try:
+        topk = request.args.get('topk', 2, type=int)
+
+        recommendations = recommendation_service.user_cold_start(topk)
+
+        return jsonify({
+            'success': True,
+            'data': {
+                'recommendations': recommendations,
+                'count': len(recommendations),
+                'type': 'cold_start'
+            },
+            'message': 'Popular recommendations fetched successfully'
+        })
+    except Exception as e:
+        return jsonify({
+            'success': False,
+            'message': f'Failed to fetch recommendations: {str(e)}'
+        }), 500
+
+@recommend_bp.route('/health', methods=['GET'])
+def health_check():
+    """Recommendation system health check."""
+    try:
+        # Minimal health check: report CUDA availability
+        import torch
+        cuda_available = torch.cuda.is_available()
+
+        return jsonify({
+            'success': True,
+            'data': {
+                'status': 'healthy',
+                'cuda_available': cuda_available,
+                'device': 'cuda' if cuda_available else 'cpu'
+            },
+            'message': 'Recommendation system is healthy'
+        })
+    except Exception as e:
+        return jsonify({
+            'success': False,
+            'message': f'Recommendation system error: {str(e)}'
+        }), 500
+
+@recommend_bp.route('/multi_recall', methods=['POST'])
+@token_required
+def multi_recall_recommendations(current_user):
+    """Multi-channel recall recommendations."""
+    try:
+        # Tolerate a missing or non-JSON body instead of raising
+        data = request.get_json(silent=True) or {}
+        user_id = data.get('user_id') or current_user.user_id
+        topk = data.get('topk', 2)
+
+        # Force multi-channel recall
+        result = recommendation_service.run_inference(user_id, topk, use_multi_recall=True)
+
+        # Cold-start results already carry post details; otherwise look them up
+        if isinstance(result, list) and result and isinstance(result[0], dict):
+            recommendations = result
+        else:
+            # result is a (topk_post_ids, topk_scores) tuple
+            if isinstance(result, tuple) and len(result) == 2:
+                topk_post_ids, topk_scores = result
+                recommendations = recommendation_service.get_post_info(topk_post_ids, topk_scores)
+            else:
+                recommendations = recommendation_service.get_post_info(result)
+
+        return jsonify({
+            'success': True,
+            'data': {
+                'user_id': user_id,
+                'recommendations': recommendations,
+                'count': len(recommendations),
+                'type': 'multi_recall'
+            },
+            'message': 'Multi-channel recall recommendations fetched successfully'
+        })
+    except Exception as e:
+        return jsonify({
+            'success': False,
+            'message': f'Failed to fetch multi-channel recall recommendations: {str(e)}'
+        }), 500
+
+@recommend_bp.route('/lightgcn', methods=['POST'])
+@token_required
+def lightgcn_recommendations(current_user):
+    """LightGCN recommendations."""
+    try:
+        # Tolerate a missing or non-JSON body instead of raising
+        data = request.get_json(silent=True) or {}
+        user_id = data.get('user_id') or current_user.user_id
+        topk = data.get('topk', 2)
+
+        # Force LightGCN
+        result = recommendation_service.run_inference(user_id, topk, use_multi_recall=False)
+
+        # Cold-start results already carry post details; otherwise look them up
+        if isinstance(result, list) and result and isinstance(result[0], dict):
+            recommendations = result
+        else:
+            # result is a (topk_post_ids, topk_scores) tuple
+            if isinstance(result, tuple) and len(result) == 2:
+                topk_post_ids, topk_scores = result
+                recommendations = recommendation_service.get_post_info(topk_post_ids, topk_scores)
+            else:
+                recommendations = recommendation_service.get_post_info(result)
+
+        return jsonify({
+            'success': True,
+            'data': {
+                'user_id': user_id,
+                'recommendations': recommendations,
+                'count': len(recommendations),
+                'type': 'lightgcn'
+            },
+            'message': 'LightGCN recommendations fetched successfully'
+        })
+    except Exception as e:
+        return jsonify({
+            'success': False,
+            'message': f'Failed to fetch LightGCN recommendations: {str(e)}'
+        }), 500
+
+@recommend_bp.route('/train_multi_recall', methods=['POST'])
+@token_required
+def train_multi_recall(current_user):
+    """Train the multi-channel recall models."""
+    try:
+        # Only administrators may train the models
+        if not hasattr(current_user, 'is_admin') or not current_user.is_admin:
+            return jsonify({
+                'success': False,
+                'message': 'Administrator privileges required'
+            }), 403
+
+        recommendation_service.train_multi_recall()
+
+        return jsonify({
+            'success': True,
+            'message': 'Multi-channel recall models trained'
+        })
+    except Exception as e:
+        return jsonify({
+            'success': False,
+            'message': f'Model training failed: {str(e)}'
+        }), 500
+
+@recommend_bp.route('/recall_config', methods=['GET'])
+@token_required
+def get_recall_config(current_user):
+    """Get the multi-channel recall configuration."""
+    try:
+        config = recommendation_service.recall_config
+        return jsonify({
+            'success': True,
+            'data': {
+                'config': config,
+                'multi_recall_enabled': recommendation_service.multi_recall_enabled
+            },
+            'message': 'Configuration fetched successfully'
+        })
+    except Exception as e:
+        return jsonify({
+            'success': False,
+            'message': f'Failed to fetch configuration: {str(e)}'
+        }), 500
+
+@recommend_bp.route('/recall_config', methods=['POST'])
+@token_required
+def update_recall_config(current_user):
+    """Update the multi-channel recall configuration."""
+    try:
+        # Only administrators may update the configuration
+        if not hasattr(current_user, 'is_admin') or not current_user.is_admin:
+            return jsonify({
+                'success': False,
+                'message': 'Administrator privileges required'
+            }), 403
+
+        # Tolerate a missing or non-JSON body instead of raising
+        data = request.get_json(silent=True) or {}
+        new_config = data.get('config', {})
+
+        # Toggle multi-channel recall
+        if 'multi_recall_enabled' in data:
+            recommendation_service.multi_recall_enabled = data['multi_recall_enabled']
+
+        # Update the per-recall configuration
+        if new_config:
+            recommendation_service.update_recall_config(new_config)
+
+        return jsonify({
+            'success': True,
+            'data': {
+                'config': recommendation_service.recall_config,
+                'multi_recall_enabled': recommendation_service.multi_recall_enabled
+            },
+            'message': 'Configuration updated successfully'
+        })
+    except Exception as e:
+        return jsonify({
+            'success': False,
+            'message': f'Failed to update configuration: {str(e)}'
+        }), 500
+
+@recommend_bp.route('/recall_stats/<int:user_id>', methods=['GET'])
+@token_required
+def get_recall_stats(current_user, user_id):
+    """Get recall statistics for a user."""
+    try:
+        # Users may view their own stats; administrators may view anyone's
+        if current_user.user_id != user_id and (not hasattr(current_user, 'is_admin') or not current_user.is_admin):
+            return jsonify({
+                'success': False,
+                'message': 'Insufficient permissions'
+            }), 403
+
+        stats = recommendation_service.get_multi_recall_stats(user_id)
+
+        return jsonify({
+            'success': True,
+            'data': stats,
+            'message': 'Statistics fetched successfully'
+        })
+    except Exception as e:
+        return jsonify({
+            'success': False,
+            'message': f'Failed to fetch statistics: {str(e)}'
+        }), 500
diff --git a/rhj/backend/app/models/recall/__init__.py b/rhj/backend/app/models/recall/__init__.py
new file mode 100644
index 0000000..98d926b
--- /dev/null
+++ b/rhj/backend/app/models/recall/__init__.py
@@ -0,0 +1,24 @@
+"""
+多路召回模块
+
+包含以下召回算法:
+- SwingRecall: Swing召回算法,基于物品相似度
+- HotRecall: 热度召回算法,基于物品热度
+- AdRecall: 广告召回算法,专门处理广告内容
+- UserCFRecall: 用户协同过滤召回算法
+- MultiRecallManager: 多路召回管理器,整合所有召回策略
+"""
+
+from .swing_recall import SwingRecall
+from .hot_recall import HotRecall
+from .ad_recall import AdRecall
+from .usercf_recall import UserCFRecall
+from .multi_recall_manager import MultiRecallManager
+
+__all__ = [
+ 'SwingRecall',
+ 'HotRecall',
+ 'AdRecall',
+ 'UserCFRecall',
+ 'MultiRecallManager'
+]
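+# A minimal usage sketch (assumes a reachable MySQL instance holding the
+# posts/behaviors/tags tables used below; db_config keys follow pymysql.connect):
+#
+#   from app.models.recall import MultiRecallManager
+#
+#   manager = MultiRecallManager(db_config={'host': 'localhost', 'user': 'root',
+#                                           'password': '...', 'database': 'redbook'})
+#   manager.train_all()
+#   item_ids, scores, per_source = manager.recall(user_id=1, total_items=100)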
diff --git a/rhj/backend/app/models/recall/ad_recall.py b/rhj/backend/app/models/recall/ad_recall.py
new file mode 100644
index 0000000..0fe3b0a
--- /dev/null
+++ b/rhj/backend/app/models/recall/ad_recall.py
@@ -0,0 +1,207 @@
+import pymysql
+from typing import List, Tuple, Dict
+import random
+
+class AdRecall:
+    """
+    Advertisement recall implementation
+    Dedicated to recalling advertisement-type content
+    """
+    
+    def __init__(self, db_config: dict):
+        """
+        Initialize the ad recall model
+        
+        Args:
+            db_config: database configuration
+        """
+        self.db_config = db_config
+        self.ad_items = []
+
+    def _get_ad_items(self):
+        """Fetch the list of advertisement items."""
+        conn = pymysql.connect(**self.db_config)
+        try:
+            cursor = conn.cursor()
+            
+            # Fetch all advertisement posts, ordered by heat and publish time
+            cursor.execute("""
+                SELECT 
+                    p.id,
+                    p.heat,
+                    p.created_at,
+                    COUNT(DISTINCT b.user_id) as interaction_count,
+                    DATEDIFF(NOW(), p.created_at) as days_since_created
+                FROM posts p
+                LEFT JOIN behaviors b ON p.id = b.post_id
+                WHERE p.is_advertisement = 1 AND p.status = 'published'
+                GROUP BY p.id, p.heat, p.created_at
+                ORDER BY p.heat DESC, p.created_at DESC
+            """)
+            
+            results = cursor.fetchall()
+            
+            # Compute an ad score for each post
+            items_with_scores = []
+            for row in results:
+                post_id, heat, created_at, interaction_count, days_since_created = row
+                
+                # Guard against NULLs from the outer join
+                heat = heat or 0
+                interaction_count = interaction_count or 0
+                days_since_created = days_since_created or 0
+                
+                # Ad score: weighted heat + interactions + freshness bonus;
+                # recently published ads get extra weight
+                freshness_bonus = max(0, 30 - days_since_created) / 30.0  # bonus within the first 30 days
+                
+                ad_score = (
+                    heat * 0.6 + 
+                    interaction_count * 0.3 + 
+                    freshness_bonus * 100  # freshness bonus
+                )
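+                # Worked example (illustrative numbers): heat=50, 10 interacting
+                # users, published 6 days ago -> freshness_bonus = 24/30 = 0.8,
+                # ad_score = 50*0.6 + 10*0.3 + 0.8*100 = 30 + 3 + 80 = 113.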
+
+                items_with_scores.append((post_id, ad_score))
+            
+            # Sort by ad score, descending
+            self.ad_items = sorted(items_with_scores, key=lambda x: x[1], reverse=True)
+        
+        finally:
+            cursor.close()
+            conn.close()
+    
+    def train(self):
+        """Train the ad recall model."""
+        print("Fetching advertisement items...")
+        self._get_ad_items()
+        print(f"Ad recall model trained, {len(self.ad_items)} ad items in total")
+
+    def recall(self, user_id: int, num_items: int = 10) -> List[Tuple[int, float]]:
+        """
+        Recall advertisement items for a user
+        
+        Args:
+            user_id: user ID
+            num_items: number of items to recall
+        
+        Returns:
+            List of (item_id, score) tuples
+        """
+        # Train lazily on first use
+        if not hasattr(self, 'ad_items') or not self.ad_items:
+            self.train()
+        
+        # Collect the ads the user has already interacted with, to avoid repeats
+        conn = pymysql.connect(**self.db_config)
+        try:
+            cursor = conn.cursor()
+            cursor.execute("""
+                SELECT DISTINCT b.post_id
+                FROM behaviors b
+                JOIN posts p ON b.post_id = p.id
+                WHERE b.user_id = %s AND p.is_advertisement = 1
+                AND b.type IN ('like', 'favorite', 'comment', 'view')
+            """, (user_id,))
+            
+            user_interacted_ads = set(row[0] for row in cursor.fetchall())
+            
+            # Interest tags inferred from the user's history
+            cursor.execute("""
+                SELECT t.name, COUNT(*) as count
+                FROM behaviors b
+                JOIN posts p ON b.post_id = p.id
+                JOIN post_tags pt ON p.id = pt.post_id
+                JOIN tags t ON pt.tag_id = t.id
+                WHERE b.user_id = %s AND b.type IN ('like', 'favorite', 'comment')
+                GROUP BY t.name
+                ORDER BY count DESC
+                LIMIT 10
+            """, (user_id,))
+            
+            user_interest_tags = set(row[0] for row in cursor.fetchall())
+            
+        finally:
+            cursor.close()
+            conn.close()
+        
+        # Drop ads the user has already interacted with
+        filtered_ads = [
+            (item_id, score) for item_id, score in self.ad_items 
+            if item_id not in user_interacted_ads
+        ]
+        
+        # If every ad has been seen but ads exist, fall back to the top-scored
+        # ones (the user may engage with them again)
+        if not filtered_ads and self.ad_items:
+            print(f"User {user_id} has interacted with every ad, returning the top-scored ads")
+            filtered_ads = self.ad_items[:num_items]
+        
+        # Personalize further when interest tags are available
+        if user_interest_tags and filtered_ads:
+            filtered_ads = self._personalize_ads(filtered_ads, user_interest_tags)
+        
+        return filtered_ads[:num_items]
+
+    def _personalize_ads(self, ad_list: List[Tuple[int, float]], user_interest_tags: set) -> List[Tuple[int, float]]:
+        """
+        Personalize the ad ranking with the user's interest tags
+        
+        Args:
+            ad_list: list of ads
+            user_interest_tags: the user's interest tags
+        
+        Returns:
+            personalized ad list
+        """
+        conn = pymysql.connect(**self.db_config)
+        try:
+            cursor = conn.cursor()
+            
+            personalized_ads = []
+            for ad_id, ad_score in ad_list:
+                # Fetch the ad's tags
+                cursor.execute("""
+                    SELECT t.name
+                    FROM post_tags pt
+                    JOIN tags t ON pt.tag_id = t.id
+                    WHERE pt.post_id = %s
+                """, (ad_id,))
+                
+                ad_tags = set(row[0] for row in cursor.fetchall())
+                
+                # Tag match ratio against the user's interests
+                tag_match_score = len(ad_tags & user_interest_tags) / max(len(user_interest_tags), 1)
+                
+                # Boost the ad score by the match ratio
+                final_score = ad_score * (1 + tag_match_score)
+                personalized_ads.append((ad_id, final_score))
+            
+            # Re-sort by the adjusted score
+            personalized_ads.sort(key=lambda x: x[1], reverse=True)
+            return personalized_ads
+            
+        finally:
+            cursor.close()
+            conn.close()
+
+    def get_random_ads(self, num_items: int = 5) -> List[Tuple[int, float]]:
+        """
+        Get random ads (for diversity)
+        
+        Args:
+            num_items: number of items to return
+        
+        Returns:
+            List of (item_id, score) tuples
+        """
+        if len(self.ad_items) <= num_items:
+            return self.ad_items
+        
+        # Weighted random choice, biased toward high-scoring ads; note that
+        # random.choices samples with replacement, so duplicates are possible
+        weights = [score for _, score in self.ad_items]
+        selected_indices = random.choices(
+            range(len(self.ad_items)), 
+            weights=weights, 
+            k=num_items
+        )
+        
+        return [self.ad_items[i] for i in selected_indices]
diff --git a/rhj/backend/app/models/recall/bloom_filter.py b/rhj/backend/app/models/recall/bloom_filter.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rhj/backend/app/models/recall/bloom_filter.py
diff --git a/rhj/backend/app/models/recall/hot_recall.py b/rhj/backend/app/models/recall/hot_recall.py
new file mode 100644
index 0000000..dbc716c
--- /dev/null
+++ b/rhj/backend/app/models/recall/hot_recall.py
@@ -0,0 +1,163 @@
+import pymysql
+from typing import List, Tuple, Dict
+import numpy as np
+
+class HotRecall:
+    """
+    Popularity-based recall implementation
+    Recalls items by their heat (heat score, interaction counts, etc.)
+    """
+    
+    def __init__(self, db_config: dict):
+        """
+        Initialize the hot recall model
+        
+        Args:
+            db_config: database configuration
+        """
+        self.db_config = db_config
+        self.hot_items = []
+
+    def _calculate_heat_scores(self):
+        """Compute heat scores for all items."""
+        conn = pymysql.connect(**self.db_config)
+        try:
+            cursor = conn.cursor()
+            
+            # Combine several popularity signals in one query
+ cursor.execute("""
+ SELECT
+ p.id,
+ p.heat,
+ COUNT(DISTINCT CASE WHEN b.type = 'like' THEN b.user_id END) as like_count,
+ COUNT(DISTINCT CASE WHEN b.type = 'favorite' THEN b.user_id END) as favorite_count,
+ COUNT(DISTINCT CASE WHEN b.type = 'comment' THEN b.user_id END) as comment_count,
+ COUNT(DISTINCT CASE WHEN b.type = 'view' THEN b.user_id END) as view_count,
+ COUNT(DISTINCT CASE WHEN b.type = 'share' THEN b.user_id END) as share_count,
+ DATEDIFF(NOW(), p.created_at) as days_since_created
+ FROM posts p
+ LEFT JOIN behaviors b ON p.id = b.post_id
+ WHERE p.status = 'published'
+ GROUP BY p.id, p.heat, p.created_at
+ """)
+
+ results = cursor.fetchall()
+
+            # Compute a combined heat score
+            items_with_scores = []
+            for row in results:
+                post_id, heat, like_count, favorite_count, comment_count, view_count, share_count, days_since_created = row
+                
+                # Guard against NULLs from the outer join
+                heat = heat or 0
+                like_count = like_count or 0
+                favorite_count = favorite_count or 0
+                comment_count = comment_count or 0
+                view_count = view_count or 0
+                share_count = share_count or 0
+                days_since_created = days_since_created or 0
+                
+                # Combined heat score:
+                # base heat + weighted user behaviors + time decay
+                behavior_score = (
+                    like_count * 1.0 + 
+                    favorite_count * 2.0 + 
+                    comment_count * 3.0 + 
+                    view_count * 0.1 + 
+                    share_count * 5.0
+                )
+                
+                # Exponential time decay (newer content scores higher);
+                # 30-day e-folding time, i.e. a half-life of about 21 days
+                time_decay = np.exp(-days_since_created / 30.0)
+                
+                # Final heat score
+                final_score = (heat * 0.3 + behavior_score * 0.7) * time_decay
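+                # Worked example (illustrative numbers): heat=100, 20 likes,
+                # 5 favorites, 2 comments, 300 views, 1 share, 15 days old ->
+                # behavior_score = 20 + 10 + 6 + 30 + 5 = 71,
+                # time_decay = exp(-0.5) ~= 0.607,
+                # final_score = (30 + 49.7) * 0.607 ~= 48.3.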
+
+                items_with_scores.append((post_id, final_score))
+            
+            # Sort by heat score, descending
+            self.hot_items = sorted(items_with_scores, key=lambda x: x[1], reverse=True)
+        
+        finally:
+            cursor.close()
+            conn.close()
+    
+    def train(self):
+        """Train the hot recall model."""
+        print("Computing heat scores...")
+        self._calculate_heat_scores()
+        print(f"Hot recall model trained, {len(self.hot_items)} items in total")
+
+    def recall(self, user_id: int, num_items: int = 50) -> List[Tuple[int, float]]:
+        """
+        Recall hot items for a user
+        
+        Args:
+            user_id: user ID
+            num_items: number of items to recall
+        
+        Returns:
+            List of (item_id, score) tuples
+        """
+        # Train lazily on first use
+        if not hasattr(self, 'hot_items') or not self.hot_items:
+            self.train()
+        
+        # Collect items the user has already interacted with (views included),
+        # so they are not recommended again; views are dropped from the filter
+        # later if this leaves too few candidates
+        conn = pymysql.connect(**self.db_config)
+        try:
+            cursor = conn.cursor()
+            cursor.execute("""
+                SELECT DISTINCT post_id
+                FROM behaviors
+                WHERE user_id = %s AND type IN ('like', 'favorite', 'comment', 'view')
+            """, (user_id,))
+            
+            user_interacted_items = set(row[0] for row in cursor.fetchall())
+            
+        finally:
+            cursor.close()
+            conn.close()
+        
+        # Drop items the user has already interacted with
+        filtered_items = [
+            (item_id, score) for item_id, score in self.hot_items 
+            if item_id not in user_interacted_items
+        ]
+        
+        # If too few candidates remain, relax the filter to strong
+        # interactions only (like, favorite, comment)
+        if len(filtered_items) < num_items:
+            print(f"Hot recall: only {len(filtered_items)} candidates after filtering, relaxing the filter")
+            conn = pymysql.connect(**self.db_config)
+            try:
+                cursor = conn.cursor()
+                cursor.execute("""
+                    SELECT DISTINCT post_id
+                    FROM behaviors
+                    WHERE user_id = %s AND type IN ('like', 'favorite', 'comment')
+                """, (user_id,))
+                
+                strong_interacted_items = set(row[0] for row in cursor.fetchall())
+                
+            finally:
+                cursor.close()
+                conn.close()
+            
+            filtered_items = [
+                (item_id, score) for item_id, score in self.hot_items 
+                if item_id not in strong_interacted_items
+            ]
+        
+        return filtered_items[:num_items]
+
+    def get_top_hot_items(self, num_items: int = 100) -> List[Tuple[int, float]]:
+        """
+        Get globally hot items (no per-user personalization)
+        
+        Args:
+            num_items: number of items to return
+        
+        Returns:
+            List of (item_id, score) tuples
+        """
+        return self.hot_items[:num_items]
diff --git a/rhj/backend/app/models/recall/multi_recall_manager.py b/rhj/backend/app/models/recall/multi_recall_manager.py
new file mode 100644
index 0000000..03cb3f8
--- /dev/null
+++ b/rhj/backend/app/models/recall/multi_recall_manager.py
@@ -0,0 +1,253 @@
+from typing import List, Tuple, Dict, Any
+import numpy as np
+from collections import defaultdict
+
+from .swing_recall import SwingRecall
+from .hot_recall import HotRecall
+from .ad_recall import AdRecall
+from .usercf_recall import UserCFRecall
+
+class MultiRecallManager:
+    """
+    Multi-channel recall manager
+    Combines the Swing, hot, ad and UserCF recall strategies
+    """
+    
+    def __init__(self, db_config: dict, recall_config: dict = None):
+        """
+        Initialize the multi-channel recall manager
+        
+        Args:
+            db_config: database configuration
+            recall_config: recall configuration with per-recall parameters and item counts
+        """
+        self.db_config = db_config
+        
+        # Default configuration
+ default_config = {
+ 'swing': {
+ 'enabled': True,
+ 'num_items': 15,
+ 'alpha': 0.5
+ },
+ 'hot': {
+ 'enabled': True,
+ 'num_items': 10
+ },
+ 'ad': {
+ 'enabled': True,
+ 'num_items': 2
+ },
+ 'usercf': {
+ 'enabled': True,
+ 'num_items': 10,
+ 'min_common_items': 3,
+ 'num_similar_users': 50
+ }
+ }
+
+        # Merge in the user-supplied configuration
+        self.config = default_config
+        if recall_config:
+            for key, value in recall_config.items():
+                if key in self.config:
+                    self.config[key].update(value)
+                else:
+                    self.config[key] = value
+        
+        # Initialize the individual recall models
+        self.recalls = {}
+        self._init_recalls()
+
+    def _init_recalls(self):
+        """Initialize the enabled recall models."""
+        if self.config['swing']['enabled']:
+            self.recalls['swing'] = SwingRecall(
+                self.db_config, 
+                alpha=self.config['swing']['alpha']
+            )
+        
+        if self.config['hot']['enabled']:
+            self.recalls['hot'] = HotRecall(self.db_config)
+        
+        if self.config['ad']['enabled']:
+            self.recalls['ad'] = AdRecall(self.db_config)
+        
+        if self.config['usercf']['enabled']:
+            self.recalls['usercf'] = UserCFRecall(
+                self.db_config,
+                min_common_items=self.config['usercf']['min_common_items']
+            )
+    
+    def train_all(self):
+        """Train all recall models."""
+        print("Training multi-channel recall models...")
+        
+        for name, recall_model in self.recalls.items():
+            print(f"Training {name} recall...")
+            recall_model.train()
+        
+        print("All recall models trained!")
+
+    def recall(self, user_id: int, total_items: int = 200) -> Tuple[List[int], List[float], Dict[str, List[Tuple[int, float]]]]:
+        """
+        Run multi-channel recall
+        
+        Args:
+            user_id: user ID
+            total_items: total number of items to recall
+        
+        Returns:
+            Tuple containing:
+            - List of item IDs
+            - List of scores
+            - Dict of recall results by source
+        """
+        recall_results = {}
+        all_candidates = defaultdict(list)  # item_id -> [(source, score), ...]
+        
+        # Run each recall channel
+        for source, recall_model in self.recalls.items():
+            if not self.config[source]['enabled']:
+                continue
+                
+            num_items = self.config[source]['num_items']
+            
+            # UserCF takes an extra parameter
+            if source == 'usercf':
+                items_scores = recall_model.recall(
+                    user_id, 
+                    num_items=num_items,
+                    num_similar_users=self.config[source]['num_similar_users']
+                )
+            else:
+                items_scores = recall_model.recall(user_id, num_items=num_items)
+            
+            recall_results[source] = items_scores
+            
+            # Collect candidates across channels
+            for item_id, score in items_scores:
+                all_candidates[item_id].append((source, score))
+        
+        # Merge the multi-channel results
+        final_candidates = self._merge_recall_results(all_candidates, total_items)
+        
+        # Split into item_ids and scores
+        item_ids = [item_id for item_id, _ in final_candidates]
+        scores = [score for _, score in final_candidates]
+        
+        return item_ids, scores, recall_results
+
+    def _merge_recall_results(self, all_candidates: Dict[int, List[Tuple[str, float]]], 
+                             total_items: int) -> List[Tuple[int, float]]:
+        """
+        Merge multi-channel recall results
+        
+        Args:
+            all_candidates: candidate items with their sources and scores
+            total_items: number of items to return
+        
+        Returns:
+            List of (item_id, final_score) tuples
+        """
+        # Per-source weights
+        source_weights = {
+            'swing': 0.3,
+            'hot': 0.2,
+            'ad': 0.1,
+            'usercf': 0.4
+        }
+        
+        final_scores = []
+        
+        for item_id, source_scores in all_candidates.items():
+            # Weighted average of the per-source scores
+            weighted_score = 0.0
+            total_weight = 0.0
+            
+            for source, score in source_scores:
+                weight = source_weights.get(source, 0.1)
+                weighted_score += weight * score
+                total_weight += weight
+            
+            # Normalize by the total weight
+            if total_weight > 0:
+                final_score = weighted_score / total_weight
+            else:
+                final_score = 0.0
+            
+            # Diversity bonus: items surfaced by several channels get extra score
+            diversity_bonus = len(source_scores) * 0.1
+            final_score += diversity_bonus
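+            # Worked example (illustrative): an item returned by swing with
+            # score 0.8 and by usercf with score 0.5 ->
+            # weighted = (0.3*0.8 + 0.4*0.5) / (0.3 + 0.4) = 0.44 / 0.7 ~= 0.63,
+            # plus a diversity bonus of 2 * 0.1 = 0.2 -> final ~= 0.83.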
+
+ final_scores.append((item_id, final_score))
+
+        # Sort by final score, descending
+ final_scores.sort(key=lambda x: x[1], reverse=True)
+
+ return final_scores[:total_items]
+
+    def get_recall_stats(self, user_id: int) -> Dict[str, Any]:
+        """
+        Get recall statistics for a user
+        
+        Args:
+            user_id: user ID
+        
+        Returns:
+            dict of recall statistics
+        """
+        stats = {
+            'user_id': user_id,
+            'enabled_recalls': list(self.recalls.keys()),
+            'config': self.config
+        }
+        
+        # Per-recall statistics
+        if 'usercf' in self.recalls:
+            try:
+                user_profile = self.recalls['usercf'].get_user_profile(user_id)
+                stats['user_profile'] = user_profile
+                
+                neighbors = self.recalls['usercf'].get_user_neighbors(user_id, 5)
+                stats['similar_users'] = neighbors
+            except Exception:
+                # Profile/neighbor lookups are best-effort; omit them on failure
+                pass
+        
+        return stats
+
+    def update_config(self, new_config: dict):
+        """
+        Update the recall configuration
+        
+        Args:
+            new_config: new configuration dict
+        """
+        for key, value in new_config.items():
+            if key in self.config:
+                self.config[key].update(value)
+            else:
+                self.config[key] = value
+        
+        # Re-initialize the recall models with the new configuration
+        self._init_recalls()
+
+    def get_recall_breakdown(self, user_id: int) -> Dict[str, int]:
+        """
+        Get the configured item count per recall source
+        
+        Args:
+            user_id: user ID (currently unused; kept for interface symmetry)
+        
+        Returns:
+            dict mapping recall source to item count
+        """
+ breakdown = {}
+
+ for source in self.recalls.keys():
+ if self.config[source]['enabled']:
+ breakdown[source] = self.config[source]['num_items']
+ else:
+ breakdown[source] = 0
+
+ return breakdown
diff --git a/rhj/backend/app/models/recall/swing_recall.py b/rhj/backend/app/models/recall/swing_recall.py
new file mode 100644
index 0000000..bf7fdd6
--- /dev/null
+++ b/rhj/backend/app/models/recall/swing_recall.py
@@ -0,0 +1,126 @@
+import numpy as np
+import pymysql
+from collections import defaultdict
+import math
+from typing import List, Tuple, Dict
+
+class SwingRecall:
+    """
+    Swing recall implementation
+    Item-similarity based collaborative filtering that explicitly damps
+    the influence of popular items
+    """
+    
+    def __init__(self, db_config: dict, alpha: float = 0.5):
+        """
+        Initialize the Swing recall model
+        
+        Args:
+            db_config: database configuration
+            alpha: parameter controlling the popularity penalty; larger values penalize more
+        """
+        self.db_config = db_config
+        self.alpha = alpha
+        self.item_similarity = {}
+        self.user_items = defaultdict(set)
+        self.item_users = defaultdict(set)
+
+    def _get_interaction_data(self):
+        """Load user-item interaction data."""
+        conn = pymysql.connect(**self.db_config)
+        try:
+            cursor = conn.cursor()
+            # User behaviors (likes, favorites, comments)
+            cursor.execute("""
+                SELECT DISTINCT user_id, post_id 
+                FROM behaviors 
+                WHERE type IN ('like', 'favorite', 'comment')
+            """)
+            interactions = cursor.fetchall()
+            
+            for user_id, post_id in interactions:
+                self.user_items[user_id].add(post_id)
+                self.item_users[post_id].add(user_id)
+                
+        finally:
+            cursor.close()
+            conn.close()
+
+    def _calculate_swing_similarity(self):
+        """Compute the Swing similarity matrix."""
+        print("Computing Swing similarities...")
+        
+        # Enumerate all item pairs
+        items = list(self.item_users.keys())
+        
+        for i, item_i in enumerate(items):
+            if i % 100 == 0:
+                print(f"Progress: {i}/{len(items)}")
+            
+            # setdefault instead of assignment, so symmetric entries written
+            # by earlier iterations are not wiped out
+            self.item_similarity.setdefault(item_i, {})
+            
+            for item_j in items[i+1:]:
+                # Users who interacted with both items
+                common_users = self.item_users[item_i] & self.item_users[item_j]
+                
+                # Require at least 2 common users; skipping instead of storing
+                # a zero entry keeps the similarity dict sparse
+                if len(common_users) < 2:
+                    continue
+                
+                # Swing similarity
+                similarity = 0.0
+                for u in common_users:
+                    for v in common_users:
+                        if u != v:
+                            # Core Swing formula: user pairs that co-like many
+                            # items contribute less, which damps popular items
+                            swing_weight = 1.0 / (self.alpha + len(self.user_items[u] & self.user_items[v]))
+                            similarity += swing_weight
+                
+                # Normalize by the number of ordered common-user pairs
+                similarity = similarity / (len(common_users) * (len(common_users) - 1))
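+                # i.e. sim(i, j) = [ sum over ordered pairs (u, v) of common
+                # users, u != v, of 1 / (alpha + |I_u ∩ I_v|) ]
+                #                / ( |U_ij| * (|U_ij| - 1) ),
+                # where I_u is u's item set and U_ij the common-user set.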
+
+                self.item_similarity[item_i][item_j] = similarity
+                # Mirror the entry so lookups work from either item
+                self.item_similarity.setdefault(item_j, {})[item_i] = similarity
+        
+        print("Swing similarity computation finished")
+
+    def train(self):
+        """Train the Swing model."""
+        self._get_interaction_data()
+        self._calculate_swing_similarity()
+
+    def recall(self, user_id: int, num_items: int = 50) -> List[Tuple[int, float]]:
+        """
+        Recall similar items for a user
+        
+        Args:
+            user_id: user ID
+            num_items: number of items to recall
+        
+        Returns:
+            List of (item_id, score) tuples
+        """
+        # Train lazily on first use
+        if not hasattr(self, 'item_similarity') or not self.item_similarity:
+            self.train()
+        
+        if user_id not in self.user_items:
+            return []
+        
+        # Items the user has interacted with
+        user_interacted_items = self.user_items[user_id]
+        
+        # Score candidate items by accumulated similarity
+        candidate_scores = defaultdict(float)
+        
+        for item_i in user_interacted_items:
+            if item_i in self.item_similarity:
+                for item_j, similarity in self.item_similarity[item_i].items():
+                    # Skip items the user has already interacted with
+                    if item_j not in user_interacted_items:
+                        candidate_scores[item_j] += similarity
+        
+        # Sort by score and return the top-N
+        sorted_candidates = sorted(candidate_scores.items(), key=lambda x: x[1], reverse=True)
+        return sorted_candidates[:num_items]
diff --git a/rhj/backend/app/models/recall/usercf_recall.py b/rhj/backend/app/models/recall/usercf_recall.py
new file mode 100644
index 0000000..d75e6d8
--- /dev/null
+++ b/rhj/backend/app/models/recall/usercf_recall.py
@@ -0,0 +1,235 @@
+import pymysql
+from typing import List, Tuple, Dict, Set
+from collections import defaultdict
+import math
+import numpy as np
+
+class UserCFRecall:
+    """
+    UserCF (user-based collaborative filtering) recall implementation
+    Recommends items liked by users similar to the target user
+    """
+    
+    def __init__(self, db_config: dict, min_common_items: int = 3):
+        """
+        Initialize the UserCF recall model
+        
+        Args:
+            db_config: database configuration
+            min_common_items: minimum number of common items required when computing user similarity
+        """
+        self.db_config = db_config
+        self.min_common_items = min_common_items
+        self.user_items = defaultdict(set)
+        self.item_users = defaultdict(set)
+        self.user_similarity = {}
+
+    def _get_user_item_interactions(self):
+        """Load user-item interaction data."""
+        conn = pymysql.connect(**self.db_config)
+        try:
+            cursor = conn.cursor()
+            
+            # User behaviors, with counts so different behaviors can be weighted
+            cursor.execute("""
+                SELECT user_id, post_id, type, COUNT(*) as count
+                FROM behaviors 
+                WHERE type IN ('like', 'favorite', 'comment', 'view')
+                GROUP BY user_id, post_id, type
+            """)
+            
+            interactions = cursor.fetchall()
+            
+            # Build the weighted user-item interaction matrix
+            user_item_scores = defaultdict(lambda: defaultdict(float))
+            
+            # Per-behavior weights
+            behavior_weights = {
+                'like': 1.0,
+                'favorite': 2.0,
+                'comment': 3.0,
+                'view': 0.1
+            }
+            
+            for user_id, post_id, behavior_type, count in interactions:
+                weight = behavior_weights.get(behavior_type, 1.0)
+                score = weight * count
+                user_item_scores[user_id][post_id] += score
+            
+            # Convert to sets (used for similarity computation)
+            for user_id, items in user_item_scores.items():
+                # Keep only items whose weighted score clears the threshold
+                threshold = 1.0  # tunable
+                for item_id, score in items.items():
+                    if score >= threshold:
+                        self.user_items[user_id].add(item_id)
+                        self.item_users[item_id].add(user_id)
+                        
+        finally:
+            cursor.close()
+            conn.close()
+
+    def _calculate_user_similarity(self):
+        """Compute the user-user similarity matrix."""
+        print("Computing user similarities...")
+        
+        users = list(self.user_items.keys())
+        total_pairs = len(users) * (len(users) - 1) // 2
+        processed = 0
+        
+        for i, user_i in enumerate(users):
+            # setdefault instead of assignment, so symmetric entries written
+            # by earlier iterations are not wiped out
+            self.user_similarity.setdefault(user_i, {})
+            
+            for user_j in users[i+1:]:
+                processed += 1
+                if processed % 10000 == 0:
+                    print(f"Progress: {processed}/{total_pairs}")
+                
+                # Items both users interacted with
+                common_items = self.user_items[user_i] & self.user_items[user_j]
+                
+                # Skip pairs below the threshold instead of storing zero entries
+                if len(common_items) < self.min_common_items:
+                    continue
+                
+                # Cosine similarity over the binary interaction sets
+                numerator = len(common_items)
+                denominator = math.sqrt(len(self.user_items[user_i]) * len(self.user_items[user_j]))
+                similarity = numerator / denominator if denominator > 0 else 0.0
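+                # i.e. sim(u, v) = |I_u ∩ I_v| / sqrt(|I_u| * |I_v|); e.g. 3
+                # common items with |I_u| = 9 and |I_v| = 4 gives 3 / sqrt(36) = 0.5.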
+
+                self.user_similarity[user_i][user_j] = similarity
+                # Mirror the entry so lookups work from either user
+                self.user_similarity.setdefault(user_j, {})[user_i] = similarity
+        
+        print("User similarity computation finished")
+
+ def train(self):
+ """训练UserCF模型"""
+ self._get_user_item_interactions()
+ self._calculate_user_similarity()
+
+    def recall(self, user_id: int, num_items: int = 50, num_similar_users: int = 50) -> List[Tuple[int, float]]:
+        """
+        Recall items liked by similar users
+        
+        Args:
+            user_id: target user ID
+            num_items: number of items to recall
+            num_similar_users: number of similar users to consider
+        
+        Returns:
+            List of (item_id, score) tuples
+        """
+        # Train lazily on first use
+        if not hasattr(self, 'user_similarity') or not self.user_similarity:
+            self.train()
+        
+        if user_id not in self.user_similarity or user_id not in self.user_items:
+            return []
+        
+        # Most similar users first
+        similar_users = sorted(
+            self.user_similarity[user_id].items(),
+            key=lambda x: x[1],
+            reverse=True
+        )[:num_similar_users]
+        
+        # Items the target user has already interacted with
+        user_interacted_items = self.user_items[user_id]
+        
+        # Score candidate items by accumulated neighbor similarity
+        candidate_scores = defaultdict(float)
+        
+        for similar_user_id, similarity in similar_users:
+            if similarity <= 0:
+                continue
+                
+            # Items the similar user interacted with
+            similar_user_items = self.user_items[similar_user_id]
+            
+            for item_id in similar_user_items:
+                # Skip items the target user has already interacted with
+                if item_id not in user_interacted_items:
+                    candidate_scores[item_id] += similarity
+        
+        # Sort by score and return the top-N
+        sorted_candidates = sorted(candidate_scores.items(), key=lambda x: x[1], reverse=True)
+        return sorted_candidates[:num_items]
+
+    def get_user_neighbors(self, user_id: int, num_neighbors: int = 10) -> List[Tuple[int, float]]:
+        """
+        Get a user's most similar neighbors
+        
+        Args:
+            user_id: user ID
+            num_neighbors: number of neighbors to return
+        
+        Returns:
+            List of (neighbor_user_id, similarity) tuples
+        """
+ if user_id not in self.user_similarity:
+ return []
+
+ neighbors = sorted(
+ self.user_similarity[user_id].items(),
+ key=lambda x: x[1],
+ reverse=True
+ )[:num_neighbors]
+
+ return neighbors
+
+    def get_user_profile(self, user_id: int) -> Dict:
+        """
+        Build a user profile
+        
+        Args:
+            user_id: user ID
+        
+        Returns:
+            user profile dict
+        """
+        if user_id not in self.user_items:
+            return {}
+        
+        conn = pymysql.connect(**self.db_config)
+        try:
+            cursor = conn.cursor()
+            
+            # Tag statistics over the items the user interacted with
+            user_item_list = list(self.user_items[user_id])
+            if not user_item_list:
+                return {}
+            
+            format_strings = ','.join(['%s'] * len(user_item_list))
+            cursor.execute(f"""
+                SELECT t.name, COUNT(*) as count
+                FROM post_tags pt
+                JOIN tags t ON pt.tag_id = t.id
+                WHERE pt.post_id IN ({format_strings})
+                GROUP BY t.name
+                ORDER BY count DESC
+            """, tuple(user_item_list))
+            
+            tag_preferences = cursor.fetchall()
+            
+            # Behavior counts for the user
+            cursor.execute("""
+                SELECT type, COUNT(*) as count
+                FROM behaviors
+                WHERE user_id = %s
+                GROUP BY type
+            """, (user_id,))
+            
+            behavior_stats = cursor.fetchall()
+            
+            return {
+                'user_id': user_id,
+                'total_interactions': len(self.user_items[user_id]),
+                'tag_preferences': dict(tag_preferences),
+                'behavior_stats': dict(behavior_stats)
+            }
+            
+        finally:
+            cursor.close()
+            conn.close()
diff --git a/rhj/backend/app/models/recommend/LightGCN.py b/rhj/backend/app/models/recommend/LightGCN.py
new file mode 100644
index 0000000..38b1732
--- /dev/null
+++ b/rhj/backend/app/models/recommend/LightGCN.py
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import scipy.sparse as sp
+import math
+import networkx as nx
+import random
+from copy import deepcopy
+from app.utils.parse_args import args
+from app.models.recommend.base_model import BaseModel
+from app.models.recommend.operators import EdgelistDrop
+from app.models.recommend.operators import scatter_add, scatter_sum
+
+
+init = nn.init.xavier_uniform_
+
+class LightGCN(BaseModel):
+ def __init__(self, dataset, pretrained_model=None, phase='pretrain'):
+ super().__init__(dataset)
+ self.adj = self._make_binorm_adj(dataset.graph)
+ self.edges = self.adj._indices().t()
+ self.edge_norm = self.adj._values()
+
+ self.phase = phase
+
+ self.emb_gate = lambda x: x
+
+ if self.phase == 'pretrain' or self.phase == 'vanilla' or self.phase == 'for_tune':
+ self.user_embedding = nn.Parameter(init(torch.empty(self.num_users, self.emb_size)))
+ self.item_embedding = nn.Parameter(init(torch.empty(self.num_items, self.emb_size)))
+
+
+ elif self.phase == 'finetune':
+ pre_user_emb, pre_item_emb = pretrained_model.generate()
+ self.user_embedding = nn.Parameter(pre_user_emb).requires_grad_(True)
+ self.item_embedding = nn.Parameter(pre_item_emb).requires_grad_(True)
+
+ elif self.phase == 'continue_tune':
+ # re-initialize for loading state dict
+ self.user_embedding = nn.Parameter(init(torch.empty(self.num_users, self.emb_size)))
+ self.item_embedding = nn.Parameter(init(torch.empty(self.num_items, self.emb_size)))
+
+ self.edge_dropout = EdgelistDrop()
+
+ def _agg(self, all_emb, edges, edge_norm):
+ src_emb = all_emb[edges[:, 0]]
+
+ # bi-norm
+ src_emb = src_emb * edge_norm.unsqueeze(1)
+
+ # conv
+ dst_emb = scatter_sum(src_emb, edges[:, 1], dim=0, dim_size=self.num_users+self.num_items)
+ return dst_emb
+
+ def _edge_binorm(self, edges):
+ user_degs = scatter_add(torch.ones_like(edges[:, 0]), edges[:, 0], dim=0, dim_size=self.num_users)
+ user_degs = user_degs[edges[:, 0]]
+ item_degs = scatter_add(torch.ones_like(edges[:, 1]), edges[:, 1], dim=0, dim_size=self.num_items)
+ item_degs = item_degs[edges[:, 1]]
+ norm = torch.pow(user_degs, -0.5) * torch.pow(item_degs, -0.5)
+ return norm
+
+ def forward(self, edges, edge_norm, return_layers=False):
+ all_emb = torch.cat([self.user_embedding, self.item_embedding], dim=0)
+ all_emb = self.emb_gate(all_emb)
+ res_emb = [all_emb]
+ for l in range(args.num_layers):
+ all_emb = self._agg(res_emb[-1], edges, edge_norm)
+ res_emb.append(all_emb)
+ if not return_layers:
+ res_emb = sum(res_emb)
+ user_res_emb, item_res_emb = res_emb.split([self.num_users, self.num_items], dim=0)
+ else:
+ user_res_emb, item_res_emb = [], []
+ for emb in res_emb:
+ u_emb, i_emb = emb.split([self.num_users, self.num_items], dim=0)
+ user_res_emb.append(u_emb)
+ item_res_emb.append(i_emb)
+ return user_res_emb, item_res_emb
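+    # A note on forward above: the layer loop implements LightGCN propagation,
+    # E^(l+1) = D^(-1/2) A D^(-1/2) E^(l), and the final representation is the
+    # sum of all layer outputs. The reference paper averages the layers instead;
+    # a uniform sum only rescales scores and leaves rankings unchanged.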
+
+ def cal_loss(self, batch_data):
+ edges, dropout_mask = self.edge_dropout(self.edges, 1-args.edge_dropout, return_mask=True)
+ edge_norm = self.edge_norm[dropout_mask]
+
+ # forward
+ users, pos_items, neg_items = batch_data
+ user_emb, item_emb = self.forward(edges, edge_norm)
+ batch_user_emb = user_emb[users]
+ pos_item_emb = item_emb[pos_items]
+ neg_item_emb = item_emb[neg_items]
+ rec_loss = self._bpr_loss(batch_user_emb, pos_item_emb, neg_item_emb)
+ reg_loss = args.weight_decay * self._reg_loss(users, pos_items, neg_items)
+
+ loss = rec_loss + reg_loss
+ loss_dict = {
+ "rec_loss": rec_loss.item(),
+ "reg_loss": reg_loss.item(),
+ }
+ return loss, loss_dict
+
+ @torch.no_grad()
+ def generate(self, return_layers=False):
+ return self.forward(self.edges, self.edge_norm, return_layers=return_layers)
+
+ @torch.no_grad()
+ def generate_lgn(self, return_layers=False):
+ return self.forward(self.edges, self.edge_norm, return_layers=return_layers)
+
+ @torch.no_grad()
+ def rating(self, user_emb, item_emb):
+ return torch.matmul(user_emb, item_emb.t())
+
+ def _reg_loss(self, users, pos_items, neg_items):
+ u_emb = self.user_embedding[users]
+ pos_i_emb = self.item_embedding[pos_items]
+ neg_i_emb = self.item_embedding[neg_items]
+ reg_loss = (1/2)*(u_emb.norm(2).pow(2) +
+ pos_i_emb.norm(2).pow(2) +
+ neg_i_emb.norm(2).pow(2))/float(len(users))
+ return reg_loss
diff --git a/rhj/backend/app/models/recommend/LightGCN_pretrained.pt b/rhj/backend/app/models/recommend/LightGCN_pretrained.pt
new file mode 100644
index 0000000..825e0e2
--- /dev/null
+++ b/rhj/backend/app/models/recommend/LightGCN_pretrained.pt
Binary files differ
diff --git a/rhj/backend/app/models/recommend/base_model.py b/rhj/backend/app/models/recommend/base_model.py
new file mode 100644
index 0000000..6c59aa6
--- /dev/null
+++ b/rhj/backend/app/models/recommend/base_model.py
@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+from app.utils.parse_args import args
+from scipy.sparse import csr_matrix
+import scipy.sparse as sp
+import numpy as np
+import torch.nn.functional as F
+
+
+class BaseModel(nn.Module):
+ def __init__(self, dataloader):
+ super(BaseModel, self).__init__()
+ self.num_users = dataloader.num_users
+ self.num_items = dataloader.num_items
+ self.emb_size = args.emb_size
+
+ def forward(self):
+ pass
+
+ def cal_loss(self, batch_data):
+ pass
+
+ def _check_inf(self, loss, pos_score, neg_score, edge_weight):
+ # find inf idx
+ inf_idx = torch.isinf(loss) | torch.isnan(loss)
+ if inf_idx.any():
+ print("find inf in loss")
+ if type(edge_weight) != int:
+ print(edge_weight[inf_idx])
+ print(f"pos_score: {pos_score[inf_idx]}")
+ print(f"neg_score: {neg_score[inf_idx]}")
+ raise ValueError("find inf in loss")
+
+ def _make_binorm_adj(self, mat):
+ a = csr_matrix((self.num_users, self.num_users))
+ b = csr_matrix((self.num_items, self.num_items))
+ mat = sp.vstack(
+ [sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
+ mat = (mat != 0) * 1.0
+ # mat = (mat + sp.eye(mat.shape[0])) * 1.0# MARK
+ degree = np.array(mat.sum(axis=-1))
+ d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
+ d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
+ d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
+ mat = mat.dot(d_inv_sqrt_mat).transpose().dot(
+ d_inv_sqrt_mat).tocoo()
+
+ # make torch tensor
+ idxs = torch.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
+ vals = torch.from_numpy(mat.data.astype(np.float32))
+ shape = torch.Size(mat.shape)
+        return torch.sparse_coo_tensor(idxs, vals, shape).to(args.device)
+
+ def _make_binorm_adj_self_loop(self, mat):
+ a = csr_matrix((self.num_users, self.num_users))
+ b = csr_matrix((self.num_items, self.num_items))
+ mat = sp.vstack(
+ [sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
+ mat = (mat != 0) * 1.0
+ mat = (mat + sp.eye(mat.shape[0])) * 1.0 # self loop
+ degree = np.array(mat.sum(axis=-1))
+ d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
+ d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
+ d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
+ mat = mat.dot(d_inv_sqrt_mat).transpose().dot(
+ d_inv_sqrt_mat).tocoo()
+
+ # make torch tensor
+ idxs = torch.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
+ vals = torch.from_numpy(mat.data.astype(np.float32))
+ shape = torch.Size(mat.shape)
+        return torch.sparse_coo_tensor(idxs, vals, shape).to(args.device)
+
+
+ def _sp_matrix_to_sp_tensor(self, sp_matrix):
+ coo = sp_matrix.tocoo()
+ indices = torch.LongTensor([coo.row, coo.col])
+ values = torch.FloatTensor(coo.data)
+        return torch.sparse_coo_tensor(indices, values, coo.shape).coalesce().to(args.device)
+
+ def _bpr_loss(self, user_emb, pos_item_emb, neg_item_emb):
+ pos_score = (user_emb * pos_item_emb).sum(dim=1)
+ neg_score = (user_emb * neg_item_emb).sum(dim=1)
+ loss = -torch.log(1e-10 + torch.sigmoid((pos_score - neg_score)))
+ self._check_inf(loss, pos_score, neg_score, 0)
+ return loss.mean()
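+    # _bpr_loss implements the BPR objective L = -log sigmoid(s_pos - s_neg),
+    # pushing each positive item's score above its sampled negative's;
+    # the 1e-10 inside the log guards against log(0).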
+
+ def _nce_loss(self, pos_score, neg_score, edge_weight=1):
+ numerator = torch.exp(pos_score)
+ denominator = torch.exp(pos_score) + torch.exp(neg_score).sum(dim=1)
+ loss = -torch.log(numerator/denominator) * edge_weight
+ self._check_inf(loss, pos_score, neg_score, edge_weight)
+ return loss.mean()
+
+ def _infonce_loss(self, pos_1, pos_2, negs, tau):
+ pos_1 = self.cl_mlp(pos_1)
+ pos_2 = self.cl_mlp(pos_2)
+ negs = self.cl_mlp(negs)
+ pos_1 = F.normalize(pos_1, dim=-1)
+ pos_2 = F.normalize(pos_2, dim=-1)
+ negs = F.normalize(negs, dim=-1)
+ pos_score = torch.mul(pos_1, pos_2).sum(dim=1)
+ # B, 1, E * B, E, N -> B, N
+ neg_score = torch.bmm(pos_1.unsqueeze(1), negs.transpose(1, 2)).squeeze(1)
+ # infonce loss
+ numerator = torch.exp(pos_score / tau)
+ denominator = torch.exp(pos_score / tau) + torch.exp(neg_score / tau).sum(dim=1)
+ loss = -torch.log(numerator/denominator)
+ self._check_inf(loss, pos_score, neg_score, 0)
+ return loss.mean()
+
\ No newline at end of file
diff --git a/rhj/backend/app/models/recommend/operators.py b/rhj/backend/app/models/recommend/operators.py
new file mode 100644
index 0000000..a508966
--- /dev/null
+++ b/rhj/backend/app/models/recommend/operators.py
@@ -0,0 +1,52 @@
+import torch
+from typing import Optional, Tuple
+from torch import nn
+
+def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
+ if dim < 0:
+ dim = other.dim() + dim
+ if src.dim() == 1:
+ for _ in range(0, dim):
+ src = src.unsqueeze(0)
+ for _ in range(src.dim(), other.dim()):
+ src = src.unsqueeze(-1)
+ src = src.expand(other.size())
+ return src
+
+def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
+ out: Optional[torch.Tensor] = None,
+ dim_size: Optional[int] = None) -> torch.Tensor:
+ index = broadcast(index, src, dim)
+ if out is None:
+ size = list(src.size())
+ if dim_size is not None:
+ size[dim] = dim_size
+ elif index.numel() == 0:
+ size[dim] = 0
+ else:
+ size[dim] = int(index.max()) + 1
+ out = torch.zeros(size, dtype=src.dtype, device=src.device)
+ return out.scatter_add_(dim, index, src)
+ else:
+ return out.scatter_add_(dim, index, src)
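+# Example (a sketch): with src = [[1., 2.], [3., 4.]] and index = [0, 0],
+# scatter_sum(src, index, dim=0, dim_size=2) adds both rows into row 0:
+#   tensor([[4., 6.], [0., 0.]])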
+
+def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
+ out: Optional[torch.Tensor] = None,
+ dim_size: Optional[int] = None) -> torch.Tensor:
+ return scatter_sum(src, index, dim, out, dim_size)
+
+
+class EdgelistDrop(nn.Module):
+ def __init__(self):
+ super(EdgelistDrop, self).__init__()
+
+ def forward(self, edgeList, keep_rate, return_mask=False):
+ if keep_rate == 1.0:
+ return edgeList, torch.ones(edgeList.size(0)).type(torch.bool)
+ edgeNum = edgeList.size(0)
+ mask = (torch.rand(edgeNum) + keep_rate).floor().type(torch.bool)
+ newEdgeList = edgeList[mask, :]
+ if return_mask:
+ return newEdgeList, mask
+ else:
+ return newEdgeList
diff --git a/rhj/backend/app/services/lightgcn_scorer.py b/rhj/backend/app/services/lightgcn_scorer.py
new file mode 100644
index 0000000..f6aeb19
--- /dev/null
+++ b/rhj/backend/app/services/lightgcn_scorer.py
@@ -0,0 +1,295 @@
+"""
+LightGCN评分服务
+用于对多路召回的结果进行LightGCN打分
+"""
+
+import torch
+import numpy as np
+from typing import List, Tuple, Dict, Any
+from app.models.recommend.LightGCN import LightGCN
+from app.utils.parse_args import args
+from app.utils.data_loader import EdgeListData
+from app.utils.graph_build import build_user_post_graph
+
+
+class LightGCNScorer:
+    """
+    LightGCN scorer
+    Re-scores multi-channel recall candidates with the pretrained model
+    """
+    
+    def __init__(self):
+        """Initialize the LightGCN scorer."""
+        # Device configuration
+        args.device = 'cuda:7' if torch.cuda.is_available() else 'cpu'
+        args.data_path = './app/user_post_graph.txt'
+        args.pre_model_path = './app/models/recommend/LightGCN_pretrained.pt'
+        
+        # Model state
+        self.model = None
+        self.user2idx = None
+        self.post2idx = None
+        self.idx2post = None
+        self.dataset = None
+        self.user_embeddings = None
+        self.item_embeddings = None
+        
+        # Lazy initialization flag
+        self._initialized = False
+
+    def _initialize_model(self):
+        """Initialize the LightGCN model (lazily, on first use)."""
+        if self._initialized:
+            return
+        
+        print("Initializing the LightGCN scoring model...")
+        
+        # Build the user/post index mappings
+        self.user2idx, self.post2idx = build_user_post_graph(return_mapping=True)
+        self.idx2post = {v: k for k, v in self.post2idx.items()}
+        
+        # Load the dataset
+        self.dataset = EdgeListData(args.data_path, args.data_path)
+        
+        # Load the pretrained weights, truncated to the current graph size
+        pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
+        pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:self.dataset.num_users]
+        pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:self.dataset.num_items]
+        
+        # Initialize the model
+        self.model = LightGCN(self.dataset, phase='vanilla').to(args.device)
+        self.model.load_state_dict(pretrained_dict, strict=False)
+        self.model.eval()
+        
+        # Precompute all user and item embeddings
+        with torch.no_grad():
+            self.user_embeddings, self.item_embeddings = self.model.generate()
+        
+        self._initialized = True
+        print("LightGCN scoring model initialized")
+
+    def score_candidates(self, user_id: int, candidate_post_ids: List[int]) -> List[float]:
+        """
+        Score candidate items with LightGCN, one at a time
+        
+        Args:
+            user_id: user ID
+            candidate_post_ids: candidate item IDs
+        
+        Returns:
+            List[float]: LightGCN score per candidate
+        """
+        self._initialize_model()
+        
+        # Users absent from the training data get zero scores
+        if user_id not in self.user2idx:
+            print(f"User {user_id} is not in the training data, returning zero scores")
+            return [0.0] * len(candidate_post_ids)
+        
+        user_idx = self.user2idx[user_id]
+        scores = []
+        
+        print(f"{len(candidate_post_ids)} candidate items")
+        
+        with torch.no_grad():
+            user_emb = self.user_embeddings[user_idx].unsqueeze(0)  # [1, emb_size]
+            
+            for post_id in candidate_post_ids:
+                if post_id not in self.post2idx:
+                    # Item not in the training data: default score
+                    scores.append(0.0)
+                    continue
+                
+                post_idx = self.post2idx[post_id]
+                item_emb = self.item_embeddings[post_idx].unsqueeze(0)  # [1, emb_size]
+                
+                # Score = inner product of the user and item embeddings
+                score = torch.matmul(user_emb, item_emb.t()).item()
+                scores.append(float(score))
+        
+        return scores
+
+    def score_batch_candidates(self, user_id: int, candidate_post_ids: List[int]) -> List[float]:
+        """
+        Score candidate items with LightGCN in one batch (more efficient)
+        
+        Args:
+            user_id: user ID
+            candidate_post_ids: candidate item IDs
+        
+        Returns:
+            List[float]: LightGCN score per candidate
+        """
+        self._initialize_model()
+        
+        # Users absent from the training data get zero scores
+        if user_id not in self.user2idx:
+            print(f"User {user_id} is not in the training data, returning zero scores")
+            return [0.0] * len(candidate_post_ids)
+        
+        print(f"{len(candidate_post_ids)} candidate items")
+        
+        user_idx = self.user2idx[user_id]
+        
+        # Keep only items that exist in the training data
+        valid_items = []
+        valid_indices = []
+        for i, post_id in enumerate(candidate_post_ids):
+            if post_id in self.post2idx:
+                valid_items.append(self.post2idx[post_id])
+                valid_indices.append(i)
+        
+        scores = [0.0] * len(candidate_post_ids)
+        
+        if not valid_items:
+            return scores
+        
+        with torch.no_grad():
+            user_emb = self.user_embeddings[user_idx].unsqueeze(0)  # [1, emb_size]
+            
+            # Batched item embedding lookup
+            valid_item_indices = torch.tensor(valid_items, device=args.device)
+            valid_item_embs = self.item_embeddings[valid_item_indices]  # [num_valid_items, emb_size]
+            
+            # Batched scoring
+            batch_scores = torch.matmul(user_emb, valid_item_embs.t()).squeeze(0)  # [num_valid_items]
+            
+            # Write the scores back to their original positions
+            for i, score in enumerate(batch_scores.cpu().numpy()):
+                original_idx = valid_indices[i]
+                scores[original_idx] = float(score)
+        
+        return scores
+
+    def get_user_profile(self, user_id: int) -> Dict[str, Any]:
+        """
+        Get a user's LightGCN representation and statistics
+        
+        Args:
+            user_id: user ID
+        
+        Returns:
+            Dict: user profile
+        """
+        self._initialize_model()
+        
+        if user_id not in self.user2idx:
+            return {
+                'user_id': user_id,
+                'exists_in_model': False,
+                'message': 'User is not in the training data'
+            }
+        
+        user_idx = self.user2idx[user_id]
+        
+        with torch.no_grad():
+            user_emb = self.user_embeddings[user_idx]
+            
+            # Embedding statistics
+            emb_norm = torch.norm(user_emb).item()
+            emb_mean = torch.mean(user_emb).item()
+            emb_std = torch.std(user_emb).item()
+            
+            # Items most similar to the user (cosine similarity)
+            user_emb_normalized = user_emb / torch.norm(user_emb)
+            item_embs_normalized = self.item_embeddings / torch.norm(self.item_embeddings, dim=1, keepdim=True)
+            
+            similarities = torch.matmul(user_emb_normalized.unsqueeze(0), item_embs_normalized.t()).squeeze(0)
+            top_k_indices = torch.topk(similarities, k=10).indices.cpu().numpy()
+            
+            top_similar_items = []
+            for idx in top_k_indices:
+                # Membership check rather than a range check, since idx2post is a dict
+                if idx in self.idx2post:
+                    post_id = self.idx2post[idx]
+                    similarity = similarities[idx].item()
+                    top_similar_items.append({
+                        'post_id': post_id,
+                        'similarity': float(similarity)
+                    })
+        
+        return {
+            'user_id': user_id,
+            'user_idx': user_idx,
+            'exists_in_model': True,
+            'embedding_stats': {
+                'norm': float(emb_norm),
+                'mean': float(emb_mean),
+                'std': float(emb_std),
+                'dimension': user_emb.shape[0]
+            },
+            'top_similar_items': top_similar_items
+        }
+
+    def compare_scoring_methods(self, user_id: int, candidate_post_ids: List[int]) -> Dict[str, List[float]]:
+        """
+        Compare different scoring methods
+        
+        Args:
+            user_id: user ID
+            candidate_post_ids: candidate item IDs
+        
+        Returns:
+            Dict: results per scoring method
+        """
+        self._initialize_model()
+        
+        if user_id not in self.user2idx:
+            zero_scores = [0.0] * len(candidate_post_ids)
+            return {
+                'lightgcn_inner_product': zero_scores,
+                'lightgcn_cosine_similarity': zero_scores,
+                'message': f'User {user_id} is not in the training data'
+            }
+        
+        user_idx = self.user2idx[user_id]
+        
+        inner_product_scores = []
+        cosine_similarity_scores = []
+        
+        with torch.no_grad():
+            user_emb = self.user_embeddings[user_idx]
+            user_emb_normalized = user_emb / torch.norm(user_emb)
+            
+            for post_id in candidate_post_ids:
+                if post_id not in self.post2idx:
+                    inner_product_scores.append(0.0)
+                    cosine_similarity_scores.append(0.0)
+                    continue
+                
+                post_idx = self.post2idx[post_id]
+                item_emb = self.item_embeddings[post_idx]
+                item_emb_normalized = item_emb / torch.norm(item_emb)
+                
+                # Inner-product score
+                inner_score = torch.dot(user_emb, item_emb).item()
+                inner_product_scores.append(float(inner_score))
+                
+                # Cosine-similarity score
+                cosine_score = torch.dot(user_emb_normalized, item_emb_normalized).item()
+                cosine_similarity_scores.append(float(cosine_score))
+        
+        return {
+            'lightgcn_inner_product': inner_product_scores,
+            'lightgcn_cosine_similarity': cosine_similarity_scores
+        }
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """
+        Get basic information about the LightGCN model
+        
+        Returns:
+            Dict: model information
+        """
+        self._initialize_model()
+
+ return {
+ 'model_type': 'LightGCN',
+ 'device': str(args.device),
+ 'num_users': self.dataset.num_users,
+ 'num_items': self.dataset.num_items,
+ 'embedding_size': self.user_embeddings.shape[1],
+ 'num_layers': args.num_layers,
+ 'pretrained_model_path': args.pre_model_path,
+ 'data_path': args.data_path,
+ 'initialized': self._initialized
+ }
diff --git a/rhj/backend/app/services/recommendation_service.py b/rhj/backend/app/services/recommendation_service.py
new file mode 100644
index 0000000..2f4de13
--- /dev/null
+++ b/rhj/backend/app/services/recommendation_service.py
@@ -0,0 +1,719 @@
+import torch
+import pymysql
+import numpy as np
+import random
+from app.models.recommend.LightGCN import LightGCN
+from app.models.recall import MultiRecallManager
+from app.services.lightgcn_scorer import LightGCNScorer
+from app.utils.parse_args import args
+from app.utils.data_loader import EdgeListData
+from app.utils.graph_build import build_user_post_graph
+from config import Config
+
+class RecommendationService:
+ def __init__(self):
+ # Database connection settings (switched to the redbook database)
+ self.db_config = {
+ 'host': '10.126.59.25',
+ 'port': 3306,
+ 'user': 'root',
+ 'password': '123456',
+ 'database': 'redbook', # use the redbook database
+ 'charset': 'utf8mb4'
+ }
+
+ # Model configuration
+ args.device = 'cuda:7' if torch.cuda.is_available() else 'cpu'
+ args.data_path = './app/user_post_graph.txt' # user-post graph file
+ args.pre_model_path = './app/models/recommend/LightGCN_pretrained.pt'
+
+ self.topk = 2 # default number of recommendations
+
+ # Multi-channel recall manager (lazily initialized)
+ self.multi_recall = None
+ self.multi_recall_enabled = True # toggle multi-channel recall
+
+ # LightGCN scorer (lazily initialized)
+ self.lightgcn_scorer = None
+ self.use_lightgcn_rerank = True # toggle LightGCN re-scoring of the recall results
+
+ # Multi-channel recall configuration
+ self.recall_config = {
+ 'swing': {
+ 'enabled': True,
+ 'num_items': 20, # enlarged recall count
+ 'alpha': 0.5
+ },
+ 'hot': {
+ 'enabled': True,
+ 'num_items': 15 # enlarged hot-recall count
+ },
+ 'ad': {
+ 'enabled': True,
+ 'num_items': 5 # enlarged ad-recall count
+ },
+ 'usercf': {
+ 'enabled': True,
+ 'num_items': 15,
+ 'min_common_items': 1, # lowered threshold from 3 to 1
+ 'num_similar_users': 20 # fewer similar users for efficiency
+ }
+ }
+
+ def calculate_tag_similarity(self, tags1, tags2):
+ """
+ 计算两个帖子标签的相似度
+ 输入: tags1, tags2 - 标签字符串,以逗号分隔
+ 输出: 相似度分数(0-1之间)
+ """
+ if not tags1 or not tags2:
+ return 0.0
+
+ # 将标签字符串转换为集合
+ set1 = set(tag.strip() for tag in tags1.split(',') if tag.strip())
+ set2 = set(tag.strip() for tag in tags2.split(',') if tag.strip())
+
+ if not set1 or not set2:
+ return 0.0
+
+ # 计算标签重叠比例(Jaccard相似度)
+ intersection = len(set1.intersection(set2))
+ union = len(set1.union(set2))
+
+ return intersection / union if union > 0 else 0.0
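A quick worked example of the Jaccard computation above, with hypothetical tag strings:

tags1 = "food,travel,photography"
tags2 = "travel,photography,fitness"

set1 = {t.strip() for t in tags1.split(',') if t.strip()}
set2 = {t.strip() for t in tags2.split(',') if t.strip()}

# |intersection| = 2 ({'travel', 'photography'}), |union| = 4
print(len(set1 & set2) / len(set1 | set2))  # 0.5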
+
+ def mmr_rerank_with_ads(self, post_ids, scores, theta=0.5, target_size=None):
+ """
+ 使用MMR算法重新排序推荐结果,并在过程中加入广告约束
+ 输入:
+ - post_ids: 帖子ID列表
+ - scores: 对应的推荐分数列表
+ - theta: 平衡相关性和多样性的参数(0.5表示各占一半)
+ - target_size: 目标结果数量,默认与输入相同
+ 输出: 重排后的(post_ids, scores),每5条帖子包含1条广告
+ """
+ if target_size is None:
+ target_size = len(post_ids)
+
+ if len(post_ids) <= 1:
+ return post_ids, scores
+
+ # Fetch tag info and advertisement flags for the candidate posts
+ conn = pymysql.connect(**self.db_config)
+ cursor = conn.cursor()
+
+ try:
+ # Query tags and ad flags for all candidate posts
+ format_strings = ','.join(['%s'] * len(post_ids))
+ cursor.execute(
+ f"""SELECT p.id, p.is_advertisement,
+ COALESCE(GROUP_CONCAT(t.name), '') as tags
+ FROM posts p
+ LEFT JOIN post_tags pt ON p.id = pt.post_id
+ LEFT JOIN tags t ON pt.tag_id = t.id
+ WHERE p.id IN ({format_strings}) AND p.status = 'published'
+ GROUP BY p.id, p.is_advertisement""",
+ tuple(post_ids)
+ )
+ post_info_rows = cursor.fetchall()
+ post_tags = {}
+ post_is_ad = {}
+
+ for row in post_info_rows:
+ post_id, is_ad, tags = row
+ post_tags[post_id] = tags or ""
+ post_is_ad[post_id] = bool(is_ad)
+
+ # Default values for posts missing from the query results
+ for post_id in post_ids:
+ if post_id not in post_tags:
+ post_tags[post_id] = ""
+ post_is_ad[post_id] = False
+
+ # Fetch extra advertisement posts as candidates
+ cursor.execute("""
+ SELECT id, heat FROM posts
+ WHERE is_advertisement = 1 AND status = 'published'
+ AND id NOT IN ({})
+ ORDER BY heat DESC
+ LIMIT 50
+ """.format(format_strings), tuple(post_ids))
+ extra_ad_rows = cursor.fetchall()
+
+ finally:
+ cursor.close()
+ conn.close()
+
+ # Split candidates into normal posts and advertisements
+ normal_candidates = []
+ ad_candidates = []
+
+ for post_id, score in zip(post_ids, scores):
+ if post_is_ad[post_id]:
+ ad_candidates.append((post_id, score))
+ else:
+ normal_candidates.append((post_id, score))
+
+ # Add the extra advertisement candidates
+ for ad_id, heat in extra_ad_rows:
+ # Register tags and the ad flag for each extra ad post
+ post_tags[ad_id] = "" # ad posts get an empty tag set for now
+ post_is_ad[ad_id] = True
+ ad_score = float(heat) / 1000.0 # scale heat into the score range
+ ad_candidates.append((ad_id, ad_score))
+
+ # Sort both candidate lists by score
+ normal_candidates.sort(key=lambda x: x[1], reverse=True)
+ ad_candidates.sort(key=lambda x: x[1], reverse=True)
+
+ # MMR selection with an ad-placement constraint
+ selected = []
+ normal_idx = 0
+ ad_idx = 0
+
+ while len(selected) < target_size:
+ current_position = len(selected)
+
+ # Insert an advertisement at every fifth position
+ if (current_position + 1) % 5 == 0 and ad_idx < len(ad_candidates):
+ # Insert the next advertisement
+ selected.append(ad_candidates[ad_idx])
+ ad_idx += 1
+ else:
+ # Otherwise pick a normal post via MMR
+ if normal_idx >= len(normal_candidates):
+ break
+
+ best_score = -float('inf')
+ best_local_idx = normal_idx
+
+ # Scan a window of the remaining normal candidates for the best MMR score
+ for i in range(normal_idx, min(normal_idx + 10, len(normal_candidates))):
+ post_id, relevance_score = normal_candidates[i]
+
+ # Maximum similarity to the posts already selected
+ max_similarity = 0.0
+ current_tags = post_tags[post_id]
+
+ for selected_post_id, _ in selected:
+ selected_tags = post_tags[selected_post_id]
+ similarity = self.calculate_tag_similarity(current_tags, selected_tags)
+ max_similarity = max(max_similarity, similarity)
+
+ # MMR score: relevance minus a redundancy penalty
+ mmr_score = theta * relevance_score - (1 - theta) * max_similarity
+
+ if mmr_score > best_score:
+ best_score = mmr_score
+ best_local_idx = i
+
+ # Take the best candidate
+ selected.append(normal_candidates[best_local_idx])
+ # Swap the chosen element into the processed region
+ normal_candidates[normal_idx], normal_candidates[best_local_idx] = \
+ normal_candidates[best_local_idx], normal_candidates[normal_idx]
+ normal_idx += 1
+
+ # Unpack the re-ranked results
+ reranked_post_ids = [post_id for post_id, _ in selected]
+ reranked_scores = [score for _, score in selected]
+
+ return reranked_post_ids, reranked_scores
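The selection rule inside the loop is standard Maximal Marginal Relevance: mmr(i) = theta * rel(i) - (1 - theta) * max over selected j of sim(i, j). A stripped-down sketch without the database lookups or ad slots (toy relevance and similarity functions, not the production code path):

def mmr_select(candidates, relevance, similarity, theta=0.5, k=3):
    """Greedy MMR: repeatedly pick the candidate with the best
    relevance-minus-redundancy trade-off."""
    selected, pool = [], list(candidates)
    while pool and len(selected) < k:
        def mmr(c):
            redundancy = max((similarity(c, s) for s in selected), default=0.0)
            return theta * relevance[c] - (1 - theta) * redundancy
        best = max(pool, key=mmr)
        selected.append(best)
        pool.remove(best)
    return selected

rel = {'a': 0.9, 'b': 0.85, 'c': 0.4}
sim = lambda x, y: 0.95 if {x, y} == {'a', 'b'} else 0.1
print(mmr_select(['a', 'b', 'c'], rel, sim))  # ['a', 'c', 'b']: 'b' is demoted as redundant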
+
+ def insert_advertisements(self, post_ids, scores):
+ """
+ 在推荐结果中插入广告,每5条帖子插入1条广告
+ 输入: post_ids, scores - 原始推荐结果
+ 输出: 插入广告后的(post_ids, scores)
+ """
+ # Fetch the available advertisement posts
+ conn = pymysql.connect(**self.db_config)
+ cursor = conn.cursor()
+
+ try:
+ cursor.execute("""
+ SELECT id, heat FROM posts
+ WHERE is_advertisement = 1 AND status = 'published'
+ ORDER BY heat DESC
+ LIMIT 50
+ """)
+ ad_rows = cursor.fetchall()
+
+ if not ad_rows:
+ # No ads available; return the original results
+ return post_ids, scores
+
+ # Usable ad posts (excluding any already present in the recommendations)
+ available_ads = [(ad_id, heat) for ad_id, heat in ad_rows if ad_id not in post_ids]
+
+ if not available_ads:
+ # No new ads to insert; return the original results
+ return post_ids, scores
+
+ finally:
+ cursor.close()
+ conn.close()
+
+ # Ad insertion loop
+ result_posts = []
+ result_scores = []
+ ad_index = 0
+
+ for i, (post_id, score) in enumerate(zip(post_ids, scores)):
+ result_posts.append(post_id)
+ result_scores.append(score)
+
+ # After every 5 posts, insert one advertisement
+ if (i + 1) % 5 == 0 and ad_index < len(available_ads):
+ ad_id, ad_heat = available_ads[ad_index]
+ result_posts.append(ad_id)
+ result_scores.append(float(ad_heat) / 1000.0) # scale heat into the score range
+ ad_index += 1
+
+ return result_posts, result_scores
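The loop above yields the pattern P P P P P A P P P P P A and so on; a quick check of the placement logic with dummy IDs:

posts = list(range(1, 11))   # ten organic posts
ads = ['ad1', 'ad2']

result, ad_index = [], 0
for i, p in enumerate(posts):
    result.append(p)
    if (i + 1) % 5 == 0 and ad_index < len(ads):
        result.append(ads[ad_index])
        ad_index += 1

print(result)  # [1, 2, 3, 4, 5, 'ad1', 6, 7, 8, 9, 10, 'ad2']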
+
+ def user_cold_start(self, topk=None):
+ """
+ 冷启动:直接返回热度最高的topk个帖子详细信息
+ """
+ if topk is None:
+ topk = self.topk
+
+ conn = pymysql.connect(**self.db_config)
+ cursor = conn.cursor()
+
+ try:
+ # Query the topk posts with the highest heat
+ cursor.execute("""
+ SELECT p.id, p.user_id, p.title, p.content, p.type, p.heat, p.created_at
+ FROM posts p
+ WHERE p.status = 'published'
+ ORDER BY p.heat DESC
+ LIMIT %s
+ """, (topk,))
+ post_rows = cursor.fetchall()
+ post_ids = [row[0] for row in post_rows]
+ post_map = {row[0]: row for row in post_rows}
+
+ # Query author info
+ owner_ids = list(set(row[1] for row in post_rows))
+ if owner_ids:
+ format_strings_user = ','.join(['%s'] * len(owner_ids))
+ cursor.execute(
+ f"SELECT id, username FROM users WHERE id IN ({format_strings_user})",
+ tuple(owner_ids)
+ )
+ user_rows = cursor.fetchall()
+ user_map = {row[0]: row[1] for row in user_rows}
+ else:
+ user_map = {}
+
+ # Query post tags
+ if post_ids:
+ format_strings = ','.join(['%s'] * len(post_ids))
+ cursor.execute(
+ f"""SELECT pt.post_id, GROUP_CONCAT(t.name) as tags
+ FROM post_tags pt
+ JOIN tags t ON pt.tag_id = t.id
+ WHERE pt.post_id IN ({format_strings})
+ GROUP BY pt.post_id""",
+ tuple(post_ids)
+ )
+ tag_rows = cursor.fetchall()
+ tag_map = {row[0]: row[1] for row in tag_rows}
+ else:
+ tag_map = {}
+
+ post_list = []
+ for post_id in post_ids:
+ row = post_map.get(post_id)
+ if not row:
+ continue
+ owner_user_id = row[1]
+ post_list.append({
+ 'post_id': post_id,
+ 'title': row[2],
+ 'content': row[3][:200] + '...' if len(row[3]) > 200 else row[3], # truncate to 200 characters
+ 'type': row[4],
+ 'username': user_map.get(owner_user_id, ""),
+ 'heat': row[5],
+ 'tags': tag_map.get(post_id, ""),
+ 'created_at': str(row[6]) if row[6] else ""
+ })
+ return post_list
+ finally:
+ cursor.close()
+ conn.close()
+
+ def run_inference(self, user_id, topk=None, use_multi_recall=None):
+ """
+ 推荐推理主函数
+
+ Args:
+ user_id: 用户ID
+ topk: 推荐数量
+ use_multi_recall: 是否使用多路召回,None表示使用默认设置
+ """
+ if topk is None:
+ topk = self.topk
+
+ # 决定使用哪种召回方式
+ if use_multi_recall is None:
+ use_multi_recall = self.multi_recall_enabled
+
+ return self._run_multi_recall_inference(user_id, topk)
+
+ def _run_multi_recall_inference(self, user_id, topk):
+ """使用多路召回进行推荐,并可选择使用LightGCN重新打分"""
+ try:
+ # 初始化多路召回(如果尚未初始化)
+ self.init_multi_recall()
+
+ # 执行多路召回,召回更多候选物品
+ total_candidates = min(topk * 10, 500) # 召回候选数是最终推荐数的10倍
+ candidate_post_ids, candidate_scores, recall_breakdown = self.multi_recall_inference(
+ user_id, total_candidates
+ )
+
+ if not candidate_post_ids:
+ # 如果多路召回没有结果,回退到冷启动
+ print(f"用户 {user_id} 多路召回无结果,使用冷启动")
+ return self.user_cold_start(topk)
+
+ print(f"用户 {user_id} 多路召回候选数量: {len(candidate_post_ids)}")
+ print(f"召回来源分布: {self._get_recall_source_stats(recall_breakdown)}")
+
+ # 如果启用LightGCN重新打分,使用LightGCN对候选结果进行评分
+ if self.use_lightgcn_rerank:
+ print("使用LightGCN对多路召回结果进行重新打分...")
+ lightgcn_scores = self._get_lightgcn_scores(user_id, candidate_post_ids)
+
+ # 直接使用LightGCN分数,不进行融合
+ final_scores = lightgcn_scores
+
+ print(f"LightGCN打分完成,分数范围: [{min(lightgcn_scores):.4f}, {max(lightgcn_scores):.4f}]")
+ print(f"使用LightGCN分数进行重排")
+ else:
+ # 使用原始多路召回分数
+ final_scores = candidate_scores
+
+ # 使用MMR算法重排,包含广告约束
+ final_post_ids, final_scores = self.mmr_rerank_with_ads(
+ candidate_post_ids, final_scores, theta=0.5, target_size=topk
+ )
+
+ return final_post_ids, final_scores
+
+ except Exception as e:
+ print(f"多路召回失败: {str(e)},回退到LightGCN")
+ return self._run_lightgcn_inference(user_id, topk)
+
+ def _run_lightgcn_inference(self, user_id, topk):
+ """使用原始LightGCN进行推荐"""
+ user2idx, post2idx = build_user_post_graph(return_mapping=True)
+ idx2post = {v: k for k, v in post2idx.items()}
+
+ if user_id not in user2idx:
+ # Cold start
+ return self.user_cold_start(topk)
+
+ user_idx = user2idx[user_id]
+
+ dataset = EdgeListData(args.data_path, args.data_path)
+ pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
+ pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
+ pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
+
+ model = LightGCN(dataset, phase='vanilla').to(args.device)
+ model.load_state_dict(pretrained_dict, strict=False)
+ model.eval()
+
+ with torch.no_grad():
+ user_emb, item_emb = model.generate()
+ user_vec = user_emb[user_idx].unsqueeze(0)
+ scores = model.rating(user_vec, item_emb).squeeze(0)
+
+ # Collect scores for every item (not just the top candidates)
+ all_scores = scores.cpu().numpy()
+ all_post_ids = [idx2post[idx] for idx in range(len(all_scores))]
+
+ # Keep only candidates with positive scores
+ positive_candidates = [(post_id, score) for post_id, score in zip(all_post_ids, all_scores) if score > 0]
+
+ if not positive_candidates:
+ # No positive scores: fall back to the highest-scoring items
+ sorted_candidates = sorted(zip(all_post_ids, all_scores), key=lambda x: x[1], reverse=True)
+ positive_candidates = sorted_candidates[:min(100, len(sorted_candidates))]
+
+ candidate_post_ids = [post_id for post_id, _ in positive_candidates]
+ candidate_scores = [score for _, score in positive_candidates]
+
+ print(f"用户 {user_id} 的LightGCN候选物品数量: {len(candidate_post_ids)}")
+
+ # Re-rank with MMR, including the ad constraint; theta=0.5 balances relevance and diversity
+ final_post_ids, final_scores = self.mmr_rerank_with_ads(
+ candidate_post_ids, candidate_scores, theta=0.5, target_size=topk
+ )
+
+ return final_post_ids, final_scores
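model.generate() and model.rating() belong to the LightGCN class added earlier in this diff; under the usual LightGCN formulation, rating a user against all items reduces to an inner product between the propagated embeddings. A sketch of that scoring step with random stand-in tensors (an assumption about the model internals, not the class's actual code):

import torch

num_users, num_items, dim = 4, 6, 8
user_emb = torch.randn(num_users, dim)   # stand-in for model.generate() output
item_emb = torch.randn(num_items, dim)

user_idx = 2
user_vec = user_emb[user_idx].unsqueeze(0)                 # (1, dim)
scores = torch.matmul(user_vec, item_emb.t()).squeeze(0)   # (num_items,)

top = torch.topk(scores, k=3)
print(top.indices.tolist(), top.values.tolist())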
+
+ def _get_recall_source_stats(self, recall_breakdown):
+ """获取召回来源统计"""
+ stats = {}
+ for source, items in recall_breakdown.items():
+ stats[source] = len(items)
+ return stats
+
+ def get_post_info(self, topk_post_ids, topk_scores=None):
+ """
+ 输入: topk_post_ids(帖子ID列表),topk_scores(对应的打分列表,可选)
+ 输出: 推荐帖子的详细信息列表,每个元素为dict
+ """
+ if not topk_post_ids:
+ return []
+
+ print(f"获取帖子详细信息,帖子ID列表: {topk_post_ids}")
+ if topk_scores is not None:
+ print(f"对应的推荐打分: {topk_scores}")
+
+ conn = pymysql.connect(**self.db_config)
+ cursor = conn.cursor()
+
+ try:
+ # Query basic post info
+ format_strings = ','.join(['%s'] * len(topk_post_ids))
+ cursor.execute(
+ f"""SELECT p.id, p.user_id, p.title, p.content, p.type, p.heat, p.created_at, p.is_advertisement
+ FROM posts p
+ WHERE p.id IN ({format_strings}) AND p.status = 'published'""",
+ tuple(topk_post_ids)
+ )
+ post_rows = cursor.fetchall()
+ post_map = {row[0]: row for row in post_rows}
+
+ # Query author info
+ owner_ids = list(set(row[1] for row in post_rows))
+ if owner_ids:
+ format_strings_user = ','.join(['%s'] * len(owner_ids))
+ cursor.execute(
+ f"SELECT id, username FROM users WHERE id IN ({format_strings_user})",
+ tuple(owner_ids)
+ )
+ user_rows = cursor.fetchall()
+ user_map = {row[0]: row[1] for row in user_rows}
+ else:
+ user_map = {}
+
+ # Query post tags
+ cursor.execute(
+ f"""SELECT pt.post_id, GROUP_CONCAT(t.name) as tags
+ FROM post_tags pt
+ JOIN tags t ON pt.tag_id = t.id
+ WHERE pt.post_id IN ({format_strings})
+ GROUP BY pt.post_id""",
+ tuple(topk_post_ids)
+ )
+ tag_rows = cursor.fetchall()
+ tag_map = {row[0]: row[1] for row in tag_rows}
+
+ # Query behavior statistics (likes, comments, etc.)
+ cursor.execute(
+ f"""SELECT post_id, type, COUNT(*) as count
+ FROM behaviors
+ WHERE post_id IN ({format_strings})
+ GROUP BY post_id, type""",
+ tuple(topk_post_ids)
+ )
+ behavior_rows = cursor.fetchall()
+ behavior_stats = {}
+ for row in behavior_rows:
+ post_id, behavior_type, count = row
+ if post_id not in behavior_stats:
+ behavior_stats[post_id] = {}
+ behavior_stats[post_id][behavior_type] = count
+
+ post_list = []
+ for i, post_id in enumerate(topk_post_ids):
+ row = post_map.get(post_id)
+ if not row:
+ print(f"帖子ID {post_id} 不存在或未发布,跳过")
+ continue
+ owner_user_id = row[1]
+ stats = behavior_stats.get(post_id, {})
+ post_info = {
+ 'post_id': post_id,
+ 'title': row[2],
+ 'content': row[3][:200] + '...' if len(row[3]) > 200 else row[3],
+ 'type': row[4],
+ 'username': user_map.get(owner_user_id, ""),
+ 'heat': row[5],
+ 'tags': tag_map.get(post_id, ""),
+ 'created_at': str(row[6]) if row[6] else "",
+ 'is_advertisement': bool(row[7]), # advertisement flag
+ 'like_count': stats.get('like', 0),
+ 'comment_count': stats.get('comment', 0),
+ 'favorite_count': stats.get('favorite', 0),
+ 'view_count': stats.get('view', 0),
+ 'share_count': stats.get('share', 0)
+ }
+
+ # Attach the recommendation score if one was provided
+ if topk_scores is not None and i < len(topk_scores):
+ post_info['recommendation_score'] = float(topk_scores[i])
+
+ post_list.append(post_info)
+ return post_list
+ finally:
+ cursor.close()
+ conn.close()
+
+ def get_recommendations(self, user_id, topk=None):
+ """
+ 获取推荐结果的主要接口
+ """
+ try:
+ result = self.run_inference(user_id, topk)
+ # Cold start already returns detailed info; otherwise look up the details
+ if isinstance(result, list) and result and isinstance(result[0], dict):
+ return result
+ else:
+ # result is now a (topk_post_ids, topk_scores) tuple
+ if isinstance(result, tuple) and len(result) == 2:
+ topk_post_ids, topk_scores = result
+ return self.get_post_info(topk_post_ids, topk_scores)
+ else:
+ # Backwards compatibility with the old return format
+ return self.get_post_info(result)
+ except Exception as e:
+ raise Exception(f"推荐系统错误: {str(e)}")
+
+ def get_all_item_scores(self, user_id):
+ """
+ 获取用户对所有物品的打分
+ 输入: user_id
+ 输出: (post_ids, scores) - 所有帖子ID和对应的打分
+ """
+ user2idx, post2idx = build_user_post_graph(return_mapping=True)
+ idx2post = {v: k for k, v in post2idx.items()}
+
+ if user_id not in user2idx:
+ # Unknown user: return empty results
+ return [], []
+
+ user_idx = user2idx[user_id]
+
+ dataset = EdgeListData(args.data_path, args.data_path)
+ pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
+ pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
+ pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
+
+ model = LightGCN(dataset, phase='vanilla').to(args.device)
+ model.load_state_dict(pretrained_dict, strict=False)
+ model.eval()
+
+ with torch.no_grad():
+ user_emb, item_emb = model.generate()
+ user_vec = user_emb[user_idx].unsqueeze(0)
+ scores = model.rating(user_vec, item_emb).squeeze(0)
+
+ # Collect IDs and scores for every item
+ all_scores = scores.cpu().numpy()
+ all_post_ids = [idx2post[idx] for idx in range(len(all_scores))]
+
+ return all_post_ids, all_scores
+
+ def init_multi_recall(self):
+ """初始化多路召回管理器"""
+ if self.multi_recall is None:
+ print("初始化多路召回管理器...")
+ self.multi_recall = MultiRecallManager(self.db_config, self.recall_config)
+ print("多路召回管理器初始化完成")
+
+ def init_lightgcn_scorer(self):
+ """初始化LightGCN评分器"""
+ if self.lightgcn_scorer is None:
+ print("初始化LightGCN评分器...")
+ self.lightgcn_scorer = LightGCNScorer()
+ print("LightGCN评分器初始化完成")
+
+ def _get_lightgcn_scores(self, user_id, candidate_post_ids):
+ """
+ 获取候选物品的LightGCN分数
+
+ Args:
+ user_id: 用户ID
+ candidate_post_ids: 候选物品ID列表
+
+ Returns:
+ List[float]: LightGCN分数列表
+ """
+ self.init_lightgcn_scorer()
+ return self.lightgcn_scorer.score_batch_candidates(user_id, candidate_post_ids)
+
+ def _fuse_scores(self, multi_recall_scores, lightgcn_scores, alpha=0.6):
+ """
+ 融合多路召回分数和LightGCN分数
+
+ Args:
+ multi_recall_scores: 多路召回分数列表
+ lightgcn_scores: LightGCN分数列表
+ alpha: LightGCN分数的权重(0-1之间)
+
+ Returns:
+ List[float]: 融合后的分数列表
+ """
+ if len(multi_recall_scores) != len(lightgcn_scores):
+ raise ValueError("分数列表长度不匹配")
+
+ # 对分数进行归一化
+ def normalize_scores(scores):
+ scores = np.array(scores)
+ min_score = np.min(scores)
+ max_score = np.max(scores)
+ if max_score == min_score:
+ return np.ones_like(scores) * 0.5
+ return (scores - min_score) / (max_score - min_score)
+
+ norm_multi_scores = normalize_scores(multi_recall_scores)
+ norm_lightgcn_scores = normalize_scores(lightgcn_scores)
+
+ # Weighted fusion
+ fused_scores = alpha * norm_lightgcn_scores + (1 - alpha) * norm_multi_scores
+
+ return fused_scores.tolist()
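Although the current pipeline uses the LightGCN scores directly, _fuse_scores implements a standard min-max-normalize-then-blend step. A small worked example with toy scores:

import numpy as np

multi_recall_scores = [10.0, 20.0, 30.0]
lightgcn_scores = [0.2, 0.8, 0.5]
alpha = 0.6

def normalize(scores):
    s = np.asarray(scores, dtype=float)
    lo, hi = s.min(), s.max()
    return np.full_like(s, 0.5) if hi == lo else (s - lo) / (hi - lo)

fused = alpha * normalize(lightgcn_scores) + (1 - alpha) * normalize(multi_recall_scores)
print(fused.tolist())  # [0.0, 0.8, 0.5]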
+
+ def train_multi_recall(self):
+ """训练多路召回模型"""
+ self.init_multi_recall()
+ self.multi_recall.train_all()
+
+ def update_recall_config(self, new_config):
+ """更新多路召回配置"""
+ self.recall_config.update(new_config)
+ if self.multi_recall:
+ self.multi_recall.update_config(new_config)
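Note that dict.update replaces each top-level channel entry wholesale, so callers should pass complete per-channel dicts rather than partial ones. A usage sketch with hypothetical values:

service = RecommendationService()

# Disable the ad channel and enlarge the swing channel
service.update_recall_config({
    'ad': {'enabled': False, 'num_items': 0},
    'swing': {'enabled': True, 'num_items': 40, 'alpha': 0.5},
})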
+
+ def multi_recall_inference(self, user_id, total_items=200):
+ """
+ 使用多路召回进行推荐
+
+ Args:
+ user_id: 用户ID
+ total_items: 总召回物品数量
+
+ Returns:
+ Tuple of (item_ids, scores, recall_breakdown)
+ """
+ self.init_multi_recall()
+
+ # 执行多路召回
+ item_ids, scores, recall_results = self.multi_recall.recall(user_id, total_items)
+
+ return item_ids, scores, recall_results
+
+ def get_multi_recall_stats(self, user_id):
+ """获取多路召回统计信息"""
+ if self.multi_recall is None:
+ return {"error": "多路召回未初始化"}
+
+ return self.multi_recall.get_recall_stats(user_id)
diff --git a/rhj/backend/app/user_post_graph.txt b/rhj/backend/app/user_post_graph.txt
new file mode 100644
index 0000000..2c66fd1
--- /dev/null
+++ b/rhj/backend/app/user_post_graph.txt
@@ -0,0 +1,11 @@
+0 1 0 0 2 1 2 0 1 2 1 0 42 32 62 52 0 12 22 1749827292 1749827292 1749953091 1749953091 1749953091 1749953480 1749953480 1749953480 1749954059 1749954059 1749954059 1749955282 1749955282 1749955282 1749955282 1749955282 1749955282 1749955282 1 5 5 2 1 2 5 1 2 1 5 5 2 2 1 1 5 1
+1 2 0 0 43 33 53 1 5 13 23 1749827292 1749953091 1749953480 1749955282 1749955282 1749955282 1749955282 1749955282 1749955282 1749955282 5 5 5 1 5 2 2 5 1 2
+2 7 6 6 7 44 34 54 2 14 24 1749953091 1749953091 1749953480 1749953480 1749955282 1749955282 1749955282 1749955282 1749955282 1749955282 2 1 1 2 2 1 5 5 2 5
+3 3 0 3 0 0 1 45 35 55 15 25 1749953091 1749953091 1749953480 1749953480 1749954059 1749954059 1749955282 1749955282 1749955282 1749955282 1749955282 2 2 2 2 1 2 5 2 1 5 1
+4 0 0 2 46 36 56 6 16 26 1749953091 1749953480 1749954059 1749955282 1749955282 1749955282 1749955282 1749955282 1749955282 5 5 5 1 5 2 5 1 2
+5 37 47 57 3 7 17 27 1749955282 1749955282 1749955282 1749955282 1749955282 1749955282 1749955282 1 2 5 5 1 2 5
+6 38 48 58 8 18 28 1749955282 1749955282 1749955282 1749955282 1749955282 1749955282 2 5 1 2 5 1
+7 39 49 59 4 9 19 29 1749955282 1749955282 1749955282 1749955282 1749955282 1749955282 1749955282 5 1 2 2 5 1 2
+8 40 50 60 10 20 30 1749955282 1749955282 1749955282 1749955282 1749955282 1749955282 1 2 5 1 2 5
+9 41 51 61 11 31 21 1749955282 1749955282 1749955282 1749955282 1749955282 1749955282 2 5 1 2 1 5
+10 13 1749894174 5
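Each line of user_post_graph.txt is written by group_and_write below as four tab-separated fields: a user index, space-separated post indices, the matching interaction timestamps, and the matching behavior weights. A minimal reader, assuming that layout:

def read_user_post_graph(path="user_post_graph.txt"):
    graph = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            uid, items, times, weights = line.rstrip("\n").split("\t")
            graph[int(uid)] = list(zip(
                (int(i) for i in items.split(" ")),
                (int(t) for t in times.split(" ")),
                (int(w) for w in weights.split(" ")),
            ))
    return graph

# e.g. the last line above parses to graph[10] == [(13, 1749894174, 5)]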
diff --git a/rhj/backend/app/utils/__pycache__/bloom_filter.cpython-312.pyc b/rhj/backend/app/utils/__pycache__/bloom_filter.cpython-312.pyc
new file mode 100644
index 0000000..5c90537
--- /dev/null
+++ b/rhj/backend/app/utils/__pycache__/bloom_filter.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/utils/__pycache__/bloom_filter_manager.cpython-312.pyc b/rhj/backend/app/utils/__pycache__/bloom_filter_manager.cpython-312.pyc
new file mode 100644
index 0000000..268f1fb
--- /dev/null
+++ b/rhj/backend/app/utils/__pycache__/bloom_filter_manager.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/utils/__pycache__/data_loader.cpython-312.pyc b/rhj/backend/app/utils/__pycache__/data_loader.cpython-312.pyc
new file mode 100644
index 0000000..10b3571
--- /dev/null
+++ b/rhj/backend/app/utils/__pycache__/data_loader.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/utils/__pycache__/graph_build.cpython-312.pyc b/rhj/backend/app/utils/__pycache__/graph_build.cpython-312.pyc
new file mode 100644
index 0000000..a560e74
--- /dev/null
+++ b/rhj/backend/app/utils/__pycache__/graph_build.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/utils/__pycache__/parse_args.cpython-312.pyc b/rhj/backend/app/utils/__pycache__/parse_args.cpython-312.pyc
new file mode 100644
index 0000000..a88ee3b
--- /dev/null
+++ b/rhj/backend/app/utils/__pycache__/parse_args.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/utils/data_loader.py b/rhj/backend/app/utils/data_loader.py
new file mode 100644
index 0000000..c882a12
--- /dev/null
+++ b/rhj/backend/app/utils/data_loader.py
@@ -0,0 +1,97 @@
+from app.utils.parse_args import args
+from os import path
+from tqdm import tqdm
+import numpy as np
+import scipy.sparse as sp
+import torch
+import networkx as nx
+from copy import deepcopy
+from collections import defaultdict
+import pandas as pd
+
+
+class EdgeListData:
+ def __init__(self, train_file, test_file, phase='pretrain', pre_dataset=None, has_time=True):
+ self.phase = phase
+ self.has_time = has_time
+ self.pre_dataset = pre_dataset
+
+ self.hour_interval = args.hour_interval_pre if phase == 'pretrain' else args.hour_interval_f
+
+ self.edgelist = []
+ self.edge_time = []
+ self.num_users = 0
+ self.num_items = 0
+ self.num_edges = 0
+
+ self.train_user_dict = {}
+ self.test_user_dict = {}
+
+ self._load_data(train_file, test_file, has_time)
+
+ if phase == 'pretrain':
+ self.user_hist_dict = self.train_user_dict
+
+ users_has_hist = set(list(self.user_hist_dict.keys()))
+ all_users = set(list(range(self.num_users)))
+ users_no_hist = all_users - users_has_hist
+ for u in users_no_hist:
+ self.user_hist_dict[u] = []
+
+ def _read_file(self, train_file, test_file, has_time=True):
+ with open(train_file, 'r') as f:
+ for line in f:
+ line = line.strip().split('\t')
+ if not has_time:
+ user, items = line[:2]
+ times = " ".join(["0"] * len(items.split(" ")))
+ weights = " ".join(["1"] * len(items.split(" "))) if len(line) < 4 else line[3]
+ else:
+ if len(line) >= 4: # 包含权重信息
+ user, items, times, weights = line
+ else:
+ user, items, times = line
+ weights = " ".join(["1"] * len(items.split(" ")))
+
+ for i in items.split(" "):
+ self.edgelist.append((int(user), int(i)))
+ for i in times.split(" "):
+ self.edge_time.append(int(i))
+ self.train_user_dict[int(user)] = [int(i) for i in items.split(" ")]
+
+ self.test_edge_num = 0
+ with open(test_file, 'r') as f:
+ for line in f:
+ line = line.strip().split('\t')
+ user, items = line[:2]
+ self.test_user_dict[int(user)] = [int(i) for i in items.split(" ")]
+ self.test_edge_num += len(self.test_user_dict[int(user)])
+
+ def _load_data(self, train_file, test_file, has_time=True):
+ self._read_file(train_file, test_file, has_time)
+
+ self.edgelist = np.array(self.edgelist, dtype=np.int32)
+ self.edge_time = 1 + self.timestamp_to_time_step(np.array(self.edge_time, dtype=np.int32))
+ self.num_edges = len(self.edgelist)
+ if self.pre_dataset is not None:
+ self.num_users = self.pre_dataset.num_users
+ self.num_items = self.pre_dataset.num_items
+ else:
+ self.num_users = max([np.max(self.edgelist[:, 0]) + 1, np.max(list(self.test_user_dict.keys())) + 1])
+ self.num_items = max([np.max(self.edgelist[:, 1]) + 1, np.max([np.max(self.test_user_dict[u]) for u in self.test_user_dict.keys()]) + 1])
+
+ self.graph = sp.coo_matrix((np.ones(self.num_edges), (self.edgelist[:, 0], self.edgelist[:, 1])), shape=(self.num_users, self.num_items))
+
+ if self.has_time:
+ self.edge_time_dict = defaultdict(dict)
+ for i in range(len(self.edgelist)):
+ self.edge_time_dict[self.edgelist[i][0]][self.edgelist[i][1]+self.num_users] = self.edge_time[i]
+ self.edge_time_dict[self.edgelist[i][1]+self.num_users][self.edgelist[i][0]] = self.edge_time[i]
+
+ def timestamp_to_time_step(self, timestamp_arr, least_time=None):
+ interval_hour = self.hour_interval
+ if least_time is None:
+ least_time = np.min(timestamp_arr)
+ timestamp_arr = timestamp_arr - least_time
+ timestamp_arr = timestamp_arr // (interval_hour * 3600)
+ return timestamp_arr
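timestamp_to_time_step buckets raw unix timestamps into hour_interval-sized steps relative to the earliest timestamp; _load_data then adds 1 so that bucket indices start at 1. A worked example with timestamps from the graph file above:

import numpy as np

timestamps = np.array([1749827292, 1749953091, 1749955282])
least = timestamps.min()
interval_hour = 1   # matches the hour_interval_pre default

steps = (timestamps - least) // (interval_hour * 3600)
print(steps + 1)   # [ 1 35 36]: edges ~35 hours apart land in different buckets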
diff --git a/rhj/backend/app/utils/graph_build.py b/rhj/backend/app/utils/graph_build.py
new file mode 100644
index 0000000..a453e4e
--- /dev/null
+++ b/rhj/backend/app/utils/graph_build.py
@@ -0,0 +1,115 @@
+import pymysql
+import datetime
+from collections import defaultdict
+
+SqlURL = "10.126.59.25"
+SqlPort = 3306
+Database = "redbook" # 修改为redbook数据库
+SqlUsername = "root"
+SqlPassword = "123456"
+
+
+def fetch_user_post_data():
+ """
+ 从redbook数据库的behaviors表获取用户-帖子交互数据,只包含已发布的帖子
+ """
+ conn = pymysql.connect(
+ host=SqlURL,
+ port=SqlPort,
+ user=SqlUsername,
+ password=SqlPassword,
+ database=Database,
+ charset="utf8mb4"
+ )
+ cursor = conn.cursor()
+ # Fetch behavior rows, restricted to behaviors on published posts
+ cursor.execute("""
+ SELECT b.user_id, b.post_id, b.type, b.value, b.created_at
+ FROM behaviors b
+ INNER JOIN posts p ON b.post_id = p.id
+ WHERE b.type IN ('like', 'favorite', 'comment', 'view', 'share')
+ AND p.status = 'published'
+ ORDER BY b.created_at
+ """)
+ behavior_rows = cursor.fetchall()
+ cursor.close()
+ conn.close()
+ return behavior_rows
+
+
+def process_records(behavior_rows):
+ """
+ 处理用户行为记录,为不同类型的行为分配权重
+ """
+ records = []
+ user_set = set()
+ post_set = set()
+
+ # Weight assigned to each behavior type
+ behavior_weights = {
+ 'view': 1,
+ 'like': 2,
+ 'comment': 3,
+ 'share': 4,
+ 'favorite': 5
+ }
+
+ for row in behavior_rows:
+ user_id, post_id, behavior_type, value, created_at = row
+ user_set.add(user_id)
+ post_set.add(post_id)
+
+ if isinstance(created_at, datetime.datetime):
+ ts = int(created_at.timestamp())
+ else:
+ ts = 0
+
+ # Apply the behavior weight
+ weight = behavior_weights.get(behavior_type, 1) * (value or 1)
+ records.append((user_id, post_id, ts, weight))
+
+ return records, user_set, post_set
+
+
+def build_id_maps(user_set, post_set):
+ """
+ Build the user and post ID mappings
+ """
+ user2idx = {uid: idx for idx, uid in enumerate(sorted(user_set))}
+ post2idx = {pid: idx for idx, pid in enumerate(sorted(post_set))}
+ return user2idx, post2idx
+
+
+def group_and_write(records, user2idx, post2idx, output_path="./app/user_post_graph.txt"):
+ """
+ Group the records by user and write them to a file, including behavior weights
+ """
+ user_items = defaultdict(list)
+ user_times = defaultdict(list)
+ user_weights = defaultdict(list)
+
+ for user_id, post_id, ts, weight in records:
+ uid = user2idx[user_id]
+ pid = post2idx[post_id]
+ user_items[uid].append(pid)
+ user_times[uid].append(ts)
+ user_weights[uid].append(weight)
+
+ with open(output_path, "w", encoding="utf-8") as f:
+ for uid in sorted(user_items.keys()):
+ items = " ".join(str(item) for item in user_items[uid])
+ times = " ".join(str(t) for t in user_times[uid])
+ weights = " ".join(str(w) for w in user_weights[uid])
+ f.write(f"{uid}\t{items}\t{times}\t{weights}\n")
+
+
+def build_user_post_graph(return_mapping=False):
+ """
+ Build the user-post interaction graph
+ """
+ behavior_rows = fetch_user_post_data()
+ records, user_set, post_set = process_records(behavior_rows)
+ user2idx, post2idx = build_id_maps(user_set, post_set)
+ group_and_write(records, user2idx, post2idx)
+ if return_mapping:
+ return user2idx, post2idx
\ No newline at end of file
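Putting graph_build together: a typical call rebuilds the edge-list file and returns the ID mappings that _run_lightgcn_inference uses to translate between database IDs and embedding rows. A usage sketch (it needs the database above to be reachable; user ID 42 is hypothetical):

from app.utils.graph_build import build_user_post_graph

user2idx, post2idx = build_user_post_graph(return_mapping=True)
idx2post = {v: k for k, v in post2idx.items()}

row = user2idx.get(42)      # None if user 42 has no recorded behavior
first_post_id = idx2post[0] # database ID of the post mapped to embedding row 0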
diff --git a/rhj/backend/app/utils/parse_args.py b/rhj/backend/app/utils/parse_args.py
new file mode 100644
index 0000000..82b3bb4
--- /dev/null
+++ b/rhj/backend/app/utils/parse_args.py
@@ -0,0 +1,77 @@
+import argparse
+import sys
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='GraphPro')
+ parser.add_argument('--phase', type=str, default='pretrain')
+ parser.add_argument('--plugin', action='store_true', default=False)
+ parser.add_argument('--save_path', type=str, default="saved", help='where to save model and logs')
+ parser.add_argument('--data_path', type=str, default="dataset/yelp", help='where to load data')
+ parser.add_argument('--exp_name', type=str, default='1')
+ parser.add_argument('--desc', type=str, default='')
+ parser.add_argument('--ab', type=str, default='full')
+ parser.add_argument('--log', type=int, default=1)
+
+ parser.add_argument('--device', type=str, default="cuda")
+ parser.add_argument('--model', type=str, default='GraphPro')
+ parser.add_argument('--pre_model', type=str, default='GraphPro')
+ parser.add_argument('--f_model', type=str, default='GraphPro')
+ parser.add_argument('--pre_model_path', type=str, default='pretrained_model.pt')
+
+ parser.add_argument('--hour_interval_pre', type=float, default=1)
+ parser.add_argument('--hour_interval_f', type=int, default=1)
+ parser.add_argument('--emb_dropout', type=float, default=0)
+
+ parser.add_argument('--updt_inter', type=int, default=1)
+ parser.add_argument('--samp_decay', type=float, default=0.05)
+
+ parser.add_argument('--edge_dropout', type=float, default=0.5)
+ parser.add_argument('--emb_size', type=int, default=64)
+ parser.add_argument('--batch_size', type=int, default=2048)
+ parser.add_argument('--eval_batch_size', type=int, default=512)
+ parser.add_argument('--seed', type=int, default=2023)
+ parser.add_argument('--num_epochs', type=int, default=300)
+ parser.add_argument('--neighbor_sample_num', type=int, default=5)
+ parser.add_argument('--lr', type=float, default=0.001)
+ parser.add_argument('--weight_decay', type=float, default=1e-4)
+ parser.add_argument('--metrics', type=str, default='recall;ndcg')
+ parser.add_argument('--metrics_k', type=str, default='20')
+ parser.add_argument('--early_stop_patience', type=int, default=10)
+ parser.add_argument('--neg_num', type=int, default=1)
+ parser.add_argument('--num_layers', type=int, default=3)
+ parser.add_argument('--n_layers', type=int, default=3)
+ parser.add_argument('--ssl_reg', type=float, default=1e-4)
+ parser.add_argument('--ssl_alpha', type=float, default=1)
+ parser.add_argument('--ssl_temp', type=float, default=0.2)
+ parser.add_argument('--epoch', type=int, default=200)
+ parser.add_argument('--decay', type=float, default=1e-3)
+ parser.add_argument('--model_reg', type=float, default=1e-4)
+ parser.add_argument('--topk', type=int, default=[1, 5, 10, 20], nargs='+')
+ parser.add_argument('--aug_type', type=str, default='ED')
+ parser.add_argument('--metric_topk', type=int, default=10)
+ parser.add_argument('--n_neighbors', type=int, default=32)
+ parser.add_argument('--n_samp', type=int, default=7)
+ parser.add_argument('--temp', type=float, default=1)
+ parser.add_argument('--temp_f', type=float, default=1)
+
+ return parser
+
+# Build a default args object so the module also works without command-line arguments
+try:
+ # When running inside a Flask app, fall back to the defaults
+ if len(sys.argv) == 1 or any(x in sys.argv[0] for x in ['flask', 'app.py', 'gunicorn']):
+ parser = parse_args()
+ args = parser.parse_args([]) # parse an empty argument list
+ else:
+ parser = parse_args()
+ args = parser.parse_args()
+except SystemExit:
+ # If parsing fails, fall back to the defaults
+ parser = parse_args()
+ args = parser.parse_args([])
+
+if hasattr(args, 'pre_model') and hasattr(args, 'f_model'):
+ if args.pre_model == args.f_model:
+ args.model = args.pre_model
+ elif args.pre_model != 'LightGCN':
+ args.model = args.pre_model
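Because args is a module-level singleton, callers override fields on it at import time instead of re-parsing, which is exactly what RecommendationService.__init__ does above. For example:

import torch
from app.utils.parse_args import args

# Override the defaults in-process, mirroring RecommendationService.__init__
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.data_path = './app/user_post_graph.txt'
args.pre_model_path = './app/models/recommend/LightGCN_pretrained.pt'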
diff --git a/rhj/backend/requirements.txt b/rhj/backend/requirements.txt
deleted file mode 100644
index 0cd4f4e..0000000
--- a/rhj/backend/requirements.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-Flask==2.3.3
-Flask-CORS==4.0.0
-SQLAlchemy==2.0.21
-PyJWT==2.8.0
-python-dotenv==1.0.0
-pymysql==1.1.0
-secure-smtplib==0.1.1
diff --git a/rhj/backend/test_bloom_filter.py b/rhj/backend/test_bloom_filter.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rhj/backend/test_bloom_filter.py
diff --git a/rhj/backend/test_redbook_recommendation.py b/rhj/backend/test_redbook_recommendation.py
new file mode 100644
index 0000000..d025ace
--- /dev/null
+++ b/rhj/backend/test_redbook_recommendation.py
@@ -0,0 +1,279 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+测试基于redbook数据库的推荐系统
+"""
+
+import sys
+import os
+import time
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+from app.services.recommendation_service import RecommendationService
+from app.utils.graph_build import build_user_post_graph
+import pymysql
+
+def test_database_connection():
+ """测试数据库连接"""
+ print("=== 测试数据库连接 ===")
+ try:
+ db_config = {
+ 'host': '10.126.59.25',
+ 'port': 3306,
+ 'user': 'root',
+ 'password': '123456',
+ 'database': 'redbook',
+ 'charset': 'utf8mb4'
+ }
+ conn = pymysql.connect(**db_config)
+ cursor = conn.cursor()
+
+ # 检查用户数量
+ cursor.execute("SELECT COUNT(*) FROM users")
+ user_count = cursor.fetchone()[0]
+ print(f"用户总数: {user_count}")
+
+ # 检查帖子数量
+ cursor.execute("SELECT COUNT(*) FROM posts WHERE status = 'published'")
+ post_count = cursor.fetchone()[0]
+ print(f"已发布帖子数: {post_count}")
+
+ # 检查行为数据
+ cursor.execute("SELECT type, COUNT(*) FROM behaviors GROUP BY type")
+ behavior_stats = cursor.fetchall()
+ print("行为统计:")
+ for behavior_type, count in behavior_stats:
+ print(f" {behavior_type}: {count}")
+
+ cursor.close()
+ conn.close()
+ print("数据库连接测试成功!")
+ return True
+ except Exception as e:
+ print(f"数据库连接失败: {e}")
+ return False
+
+def test_graph_building():
+ """测试图构建"""
+ print("\n=== 测试图构建 ===")
+ try:
+ user2idx, post2idx = build_user_post_graph(return_mapping=True)
+ print(f"用户数量: {len(user2idx)}")
+ print(f"帖子数量: {len(post2idx)}")
+
+ # 显示前几个用户和帖子的映射
+ print("前5个用户映射:")
+ for i, (user_id, idx) in enumerate(list(user2idx.items())[:5]):
+ print(f" 用户{user_id} -> 索引{idx}")
+
+ print("前5个帖子映射:")
+ for i, (post_id, idx) in enumerate(list(post2idx.items())[:5]):
+ print(f" 帖子{post_id} -> 索引{idx}")
+
+ print("图构建测试成功!")
+ return True
+ except Exception as e:
+ print(f"图构建失败: {e}")
+ return False
+
+def test_cold_start_recommendation():
+ """测试冷启动推荐"""
+ print("\n=== 测试冷启动推荐 ===")
+ try:
+ service = RecommendationService()
+
+ # 使用一个不存在的用户ID进行冷启动测试
+ fake_user_id = 999999
+
+ # 计时开始
+ start_time = time.time()
+ recommendations = service.get_recommendations(fake_user_id, topk=10)
+ end_time = time.time()
+
+ # 计算推荐耗时
+ recommendation_time = end_time - start_time
+ print(f"冷启动推荐耗时: {recommendation_time:.4f} 秒")
+
+ print(f"冷启动推荐结果(用户{fake_user_id}):")
+ for i, rec in enumerate(recommendations):
+ print(f" {i+1}. 帖子ID: {rec['post_id']}, 标题: {rec['title'][:50]}...")
+ print(f" 作者: {rec['username']}, 热度: {rec['heat']}")
+ print(f" 点赞: {rec.get('like_count', 0)}, 评论: {rec.get('comment_count', 0)}")
+
+ print("冷启动推荐测试成功!")
+ return True
+ except Exception as e:
+ print(f"冷启动推荐失败: {e}")
+ return False
+
+def test_user_recommendation():
+ """测试用户推荐"""
+ print("\n=== 测试用户推荐 ===")
+ try:
+ service = RecommendationService()
+
+ # 获取一个真实用户ID
+ db_config = service.db_config
+ conn = pymysql.connect(**db_config)
+ cursor = conn.cursor()
+ cursor.execute("SELECT DISTINCT user_id FROM behaviors LIMIT 1")
+ result = cursor.fetchone()
+
+ if result:
+ user_id = result[0]
+ print(f"测试用户ID: {user_id}")
+
+ # 查看用户的历史行为
+ cursor.execute("""
+ SELECT b.type, COUNT(*) as count
+ FROM behaviors b
+ WHERE b.user_id = %s
+ GROUP BY b.type
+ """, (user_id,))
+ user_behaviors = cursor.fetchall()
+ print("用户历史行为:")
+ for behavior_type, count in user_behaviors:
+ print(f" {behavior_type}: {count}")
+
+ cursor.close()
+ conn.close()
+
+ # 尝试获取推荐 - 添加计时
+ print("开始生成推荐...")
+ start_time = time.time()
+ recommendations = service.get_recommendations(user_id, topk=10)
+ end_time = time.time()
+
+ # 计算推荐耗时
+ recommendation_time = end_time - start_time
+ print(f"用户推荐耗时: {recommendation_time:.4f} 秒")
+
+ print(f"用户推荐结果(用户{user_id}):")
+ for i, rec in enumerate(recommendations):
+ print(f" {i+1}. 帖子ID: {rec['post_id']}, 标题: {rec['title'][:50]}...")
+ print(f" 作者: {rec['username']}, 热度: {rec['heat']}")
+ print(f" 点赞: {rec.get('like_count', 0)}, 评论: {rec.get('comment_count', 0)}")
+ if 'recommendation_score' in rec:
+ print(f" 推荐分数: {rec['recommendation_score']:.4f}")
+ else:
+ print(f" 热度分数: {rec['heat']}")
+
+ print("用户推荐测试成功!")
+ return True
+ else:
+ print("没有找到有行为记录的用户")
+ cursor.close()
+ conn.close()
+ return False
+
+ except Exception as e:
+ print(f"用户推荐失败: {e}")
+ return False
+
+def test_recommendation_performance():
+ """测试推荐性能 - 多次调用统计"""
+ print("\n=== 测试推荐性能 ===")
+ try:
+ service = RecommendationService()
+
+ # 获取几个真实用户ID进行测试
+ db_config = service.db_config
+ conn = pymysql.connect(**db_config)
+ cursor = conn.cursor()
+ cursor.execute("SELECT DISTINCT user_id FROM behaviors LIMIT 5")
+ user_ids = [row[0] for row in cursor.fetchall()]
+ cursor.close()
+ conn.close()
+
+ if not user_ids:
+ print("没有找到有行为记录的用户")
+ return False
+
+ print(f"测试用户数量: {len(user_ids)}")
+
+ # 进行多次推荐测试
+ times = []
+ test_rounds = 3 # 每个用户测试3轮
+
+ for round_num in range(test_rounds):
+ print(f"\n第 {round_num + 1} 轮测试:")
+ round_times = []
+
+ for i, user_id in enumerate(user_ids):
+ start_time = time.time()
+ recommendations = service.get_recommendations(user_id, topk=10)
+ end_time = time.time()
+
+ recommendation_time = end_time - start_time
+ round_times.append(recommendation_time)
+ times.append(recommendation_time)
+
+ print(f" 用户 {user_id}: {recommendation_time:.4f}s, 推荐数量: {len(recommendations)}")
+
+ # 计算本轮统计
+ avg_time = sum(round_times) / len(round_times)
+ min_time = min(round_times)
+ max_time = max(round_times)
+ print(f" 本轮平均耗时: {avg_time:.4f}s, 最快: {min_time:.4f}s, 最慢: {max_time:.4f}s")
+
+ # 计算总体统计
+ print(f"\n=== 性能统计总结 ===")
+ print(f"总测试次数: {len(times)}")
+ print(f"平均推荐耗时: {sum(times) / len(times):.4f} 秒")
+ print(f"最快推荐耗时: {min(times):.4f} 秒")
+ print(f"最慢推荐耗时: {max(times):.4f} 秒")
+ print(f"推荐耗时标准差: {(sum([(t - sum(times)/len(times))**2 for t in times]) / len(times))**0.5:.4f} 秒")
+
+ # 性能等级评估
+ avg_time = sum(times) / len(times)
+ if avg_time < 0.1:
+ performance_level = "优秀"
+ elif avg_time < 0.5:
+ performance_level = "良好"
+ elif avg_time < 1.0:
+ performance_level = "一般"
+ else:
+ performance_level = "需要优化"
+
+ print(f"性能评级: {performance_level}")
+
+ print("推荐性能测试成功!")
+ return True
+
+ except Exception as e:
+ print(f"推荐性能测试失败: {e}")
+ return False
+
+def main():
+ """主测试函数"""
+ print("开始测试基于redbook数据库的推荐系统")
+ print("=" * 50)
+
+ tests = [
+ test_database_connection,
+ test_graph_building,
+ test_cold_start_recommendation,
+ test_user_recommendation,
+ test_recommendation_performance
+ ]
+
+ passed = 0
+ total = len(tests)
+
+ for test in tests:
+ try:
+ if test():
+ passed += 1
+ except Exception as e:
+ print(f"测试异常: {e}")
+
+ print("\n" + "=" * 50)
+ print(f"测试完成: {passed}/{total} 通过")
+
+ if passed == total:
+ print("所有测试通过!")
+ else:
+ print("部分测试失败,请检查配置和代码")
+
+if __name__ == "__main__":
+ main()