推荐系统
Change-Id: I49b9205568f1ccf88b32b08511aff8b0bea8d1bd
diff --git a/rhj/backend/app/models/recall/__init__.py b/rhj/backend/app/models/recall/__init__.py
new file mode 100644
index 0000000..98d926b
--- /dev/null
+++ b/rhj/backend/app/models/recall/__init__.py
@@ -0,0 +1,24 @@
+"""
+多路召回模块
+
+包含以下召回算法:
+- SwingRecall: Swing召回算法,基于物品相似度
+- HotRecall: 热度召回算法,基于物品热度
+- AdRecall: 广告召回算法,专门处理广告内容
+- UserCFRecall: 用户协同过滤召回算法
+- MultiRecallManager: 多路召回管理器,整合所有召回策略
+"""
+
+from .swing_recall import SwingRecall
+from .hot_recall import HotRecall
+from .ad_recall import AdRecall
+from .usercf_recall import UserCFRecall
+from .multi_recall_manager import MultiRecallManager
+
+__all__ = [
+ 'SwingRecall',
+ 'HotRecall',
+ 'AdRecall',
+ 'UserCFRecall',
+ 'MultiRecallManager'
+]
diff --git a/rhj/backend/app/models/recall/__pycache__/__init__.cpython-312.pyc b/rhj/backend/app/models/recall/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000..d1cf37c
--- /dev/null
+++ b/rhj/backend/app/models/recall/__pycache__/__init__.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recall/__pycache__/ad_recall.cpython-312.pyc b/rhj/backend/app/models/recall/__pycache__/ad_recall.cpython-312.pyc
new file mode 100644
index 0000000..08a722c
--- /dev/null
+++ b/rhj/backend/app/models/recall/__pycache__/ad_recall.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recall/__pycache__/bloom_filter.cpython-312.pyc b/rhj/backend/app/models/recall/__pycache__/bloom_filter.cpython-312.pyc
new file mode 100644
index 0000000..c4dae7e
--- /dev/null
+++ b/rhj/backend/app/models/recall/__pycache__/bloom_filter.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recall/__pycache__/hot_recall.cpython-312.pyc b/rhj/backend/app/models/recall/__pycache__/hot_recall.cpython-312.pyc
new file mode 100644
index 0000000..cb6c725
--- /dev/null
+++ b/rhj/backend/app/models/recall/__pycache__/hot_recall.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recall/__pycache__/multi_recall_manager.cpython-312.pyc b/rhj/backend/app/models/recall/__pycache__/multi_recall_manager.cpython-312.pyc
new file mode 100644
index 0000000..9a95456
--- /dev/null
+++ b/rhj/backend/app/models/recall/__pycache__/multi_recall_manager.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recall/__pycache__/swing_recall.cpython-312.pyc b/rhj/backend/app/models/recall/__pycache__/swing_recall.cpython-312.pyc
new file mode 100644
index 0000000..d913d68
--- /dev/null
+++ b/rhj/backend/app/models/recall/__pycache__/swing_recall.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recall/__pycache__/usercf_recall.cpython-312.pyc b/rhj/backend/app/models/recall/__pycache__/usercf_recall.cpython-312.pyc
new file mode 100644
index 0000000..adb6177
--- /dev/null
+++ b/rhj/backend/app/models/recall/__pycache__/usercf_recall.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recall/ad_recall.py b/rhj/backend/app/models/recall/ad_recall.py
new file mode 100644
index 0000000..0fe3b0a
--- /dev/null
+++ b/rhj/backend/app/models/recall/ad_recall.py
@@ -0,0 +1,207 @@
+import pymysql
+from typing import List, Tuple, Dict
+import random
+
class AdRecall:
    """
    Advertisement recall.

    Loads all published advertisement posts, scores them by heat,
    interaction count and freshness, and serves the top ads a user has
    not interacted with, optionally re-ranked by the user's tag
    preferences.
    """

    def __init__(self, db_config: dict):
        """
        Initialize the ad recall model.

        Args:
            db_config: pymysql connection parameters.
        """
        self.db_config = db_config
        # Sorted list of (post_id, score), best first; filled by train().
        self.ad_items = []

    def _get_ad_items(self):
        """Load and score every published advertisement post."""
        conn = pymysql.connect(**self.db_config)
        try:
            cursor = conn.cursor()

            # All advertisement posts, ordered by heat and recency.
            cursor.execute("""
                SELECT
                    p.id,
                    p.heat,
                    p.created_at,
                    COUNT(DISTINCT b.user_id) as interaction_count,
                    DATEDIFF(NOW(), p.created_at) as days_since_created
                FROM posts p
                LEFT JOIN behaviors b ON p.id = b.post_id
                WHERE p.is_advertisement = 1 AND p.status = 'published'
                GROUP BY p.id, p.heat, p.created_at
                ORDER BY p.heat DESC, p.created_at DESC
            """)

            results = cursor.fetchall()

            items_with_scores = []
            for row in results:
                post_id, heat, created_at, interaction_count, days_since_created = row

                # LEFT JOIN / DATEDIFF can yield NULLs; treat them as zero.
                heat = heat or 0
                interaction_count = interaction_count or 0
                days_since_created = days_since_created or 0

                # Ads published within the last 30 days get a linear
                # freshness bonus in [0, 1].
                freshness_bonus = max(0, 30 - days_since_created) / 30.0

                # Ad score = weighted heat + weighted interactions + freshness.
                ad_score = (
                    heat * 0.6 +
                    interaction_count * 0.3 +
                    freshness_bonus * 100  # freshness reward
                )

                items_with_scores.append((post_id, ad_score))

            self.ad_items = sorted(items_with_scores, key=lambda x: x[1], reverse=True)

        finally:
            cursor.close()
            conn.close()

    def train(self):
        """Train (i.e. load and score) the ad recall model."""
        print("开始获取广告物品...")
        self._get_ad_items()
        print(f"广告召回模型训练完成,共{len(self.ad_items)}个广告物品")

    def recall(self, user_id: int, num_items: int = 10) -> List[Tuple[int, float]]:
        """
        Recall advertisement items for a user.

        Args:
            user_id: target user id.
            num_items: number of ads to return.

        Returns:
            List of (item_id, score) tuples, highest score first.
        """
        # Lazily train on first use.
        if not hasattr(self, 'ad_items') or not self.ad_items:
            self.train()

        # Fetch the ads the user already touched plus the user's top
        # interest tags in a single connection.
        conn = pymysql.connect(**self.db_config)
        try:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT DISTINCT b.post_id
                FROM behaviors b
                JOIN posts p ON b.post_id = p.id
                WHERE b.user_id = %s AND p.is_advertisement = 1
                AND b.type IN ('like', 'favorite', 'comment', 'view')
            """, (user_id,))

            user_interacted_ads = set(row[0] for row in cursor.fetchall())

            # Interest tags derived from strong historical interactions.
            cursor.execute("""
                SELECT t.name, COUNT(*) as count
                FROM behaviors b
                JOIN posts p ON b.post_id = p.id
                JOIN post_tags pt ON p.id = pt.post_id
                JOIN tags t ON pt.tag_id = t.id
                WHERE b.user_id = %s AND b.type IN ('like', 'favorite', 'comment')
                GROUP BY t.name
                ORDER BY count DESC
                LIMIT 10
            """, (user_id,))

            user_interest_tags = set(row[0] for row in cursor.fetchall())

        finally:
            cursor.close()
            conn.close()

        # Drop ads the user has already interacted with.
        filtered_ads = [
            (item_id, score) for item_id, score in self.ad_items
            if item_id not in user_interacted_ads
        ]

        # If every ad was already seen, fall back to the top-scored ads
        # (the user might engage again).
        if not filtered_ads and self.ad_items:
            print(f"用户 {user_id} 已与所有广告交互,返回评分最高的广告")
            filtered_ads = self.ad_items[:num_items]

        # Personalize by tag overlap when interest tags are available.
        if user_interest_tags and filtered_ads:
            filtered_ads = self._personalize_ads(filtered_ads, user_interest_tags)

        return filtered_ads[:num_items]

    def _personalize_ads(self, ad_list: List[Tuple[int, float]], user_interest_tags: set) -> List[Tuple[int, float]]:
        """
        Re-rank ads by tag overlap with the user's interest tags.

        Args:
            ad_list: candidate (ad_id, score) list.
            user_interest_tags: the user's interest tag names.

        Returns:
            Re-scored and re-sorted (ad_id, score) list.
        """
        if not ad_list:
            return ad_list

        conn = pymysql.connect(**self.db_config)
        try:
            cursor = conn.cursor()

            # Fetch the tags of every candidate ad in ONE query instead of
            # one query per ad (avoids the former N+1 round-trip pattern).
            placeholders = ','.join(['%s'] * len(ad_list))
            cursor.execute(f"""
                SELECT pt.post_id, t.name
                FROM post_tags pt
                JOIN tags t ON pt.tag_id = t.id
                WHERE pt.post_id IN ({placeholders})
            """, tuple(ad_id for ad_id, _ in ad_list))

            tags_by_ad = {}
            for post_id, tag_name in cursor.fetchall():
                tags_by_ad.setdefault(post_id, set()).add(tag_name)

        finally:
            cursor.close()
            conn.close()

        personalized_ads = []
        for ad_id, ad_score in ad_list:
            ad_tags = tags_by_ad.get(ad_id, set())
            # Fraction of the user's interest tags this ad matches.
            tag_match_score = len(ad_tags & user_interest_tags) / max(len(user_interest_tags), 1)
            # Boost the base score by up to 2x on a full tag match.
            personalized_ads.append((ad_id, ad_score * (1 + tag_match_score)))

        personalized_ads.sort(key=lambda x: x[1], reverse=True)
        return personalized_ads

    def get_random_ads(self, num_items: int = 5) -> List[Tuple[int, float]]:
        """
        Sample ads at random, biased toward higher scores (for diversity).

        Args:
            num_items: number of ads to return.

        Returns:
            List of (item_id, score) tuples (may contain duplicates when
            weighted sampling is used, since choices() samples with
            replacement).
        """
        if len(self.ad_items) <= num_items:
            return self.ad_items

        # BUG FIX: random.choices raises ValueError on negative weights or
        # an all-zero total, which happens when every ad is old with no
        # interactions. Clamp to non-negative and fall back to uniform
        # sampling when no positive weight remains.
        weights = [max(score, 0.0) for _, score in self.ad_items]
        if sum(weights) <= 0:
            selected_indices = random.sample(range(len(self.ad_items)), num_items)
        else:
            selected_indices = random.choices(
                range(len(self.ad_items)),
                weights=weights,
                k=num_items
            )

        return [self.ad_items[i] for i in selected_indices]
diff --git a/rhj/backend/app/models/recall/bloom_filter.py b/rhj/backend/app/models/recall/bloom_filter.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rhj/backend/app/models/recall/bloom_filter.py
diff --git a/rhj/backend/app/models/recall/hot_recall.py b/rhj/backend/app/models/recall/hot_recall.py
new file mode 100644
index 0000000..dbc716c
--- /dev/null
+++ b/rhj/backend/app/models/recall/hot_recall.py
@@ -0,0 +1,163 @@
+import pymysql
+from typing import List, Tuple, Dict
+import numpy as np
+
class HotRecall:
    """
    Heat-based (popularity) recall.

    Ranks published posts by a composite score built from the stored heat
    value, weighted user behaviors and an exponential time decay, then
    serves the top items the user has not interacted with yet.
    """

    def __init__(self, db_config: dict):
        """
        Initialize the heat recall model.

        Args:
            db_config: pymysql connection parameters.
        """
        self.db_config = db_config
        # Sorted list of (post_id, score), hottest first; filled by train().
        self.hot_items = []

    def _calculate_heat_scores(self):
        """Compute the composite heat score for every published post."""
        conn = pymysql.connect(**self.db_config)
        try:
            cursor = conn.cursor()

            # One row per post with per-behavior distinct-user counts.
            cursor.execute("""
                SELECT
                    p.id,
                    p.heat,
                    COUNT(DISTINCT CASE WHEN b.type = 'like' THEN b.user_id END) as like_count,
                    COUNT(DISTINCT CASE WHEN b.type = 'favorite' THEN b.user_id END) as favorite_count,
                    COUNT(DISTINCT CASE WHEN b.type = 'comment' THEN b.user_id END) as comment_count,
                    COUNT(DISTINCT CASE WHEN b.type = 'view' THEN b.user_id END) as view_count,
                    COUNT(DISTINCT CASE WHEN b.type = 'share' THEN b.user_id END) as share_count,
                    DATEDIFF(NOW(), p.created_at) as days_since_created
                FROM posts p
                LEFT JOIN behaviors b ON p.id = b.post_id
                WHERE p.status = 'published'
                GROUP BY p.id, p.heat, p.created_at
            """)

            results = cursor.fetchall()

            items_with_scores = []
            for row in results:
                post_id, heat, like_count, favorite_count, comment_count, view_count, share_count, days_since_created = row

                # LEFT JOIN / DATEDIFF can yield NULLs; treat them as zero.
                heat = heat or 0
                like_count = like_count or 0
                favorite_count = favorite_count or 0
                comment_count = comment_count or 0
                view_count = view_count or 0
                share_count = share_count or 0
                days_since_created = days_since_created or 0

                # Behavior score: stronger signals carry larger weights.
                behavior_score = (
                    like_count * 1.0 +
                    favorite_count * 2.0 +
                    comment_count * 3.0 +
                    view_count * 0.1 +
                    share_count * 5.0
                )

                # Exponential time decay with a ~30-day scale: newer content
                # keeps more of its score.
                time_decay = np.exp(-days_since_created / 30.0)

                # Final score: blend of stored heat and behavior, decayed.
                final_score = (heat * 0.3 + behavior_score * 0.7) * time_decay

                items_with_scores.append((post_id, final_score))

            self.hot_items = sorted(items_with_scores, key=lambda x: x[1], reverse=True)

        finally:
            cursor.close()
            conn.close()

    def train(self):
        """Train (i.e. compute and cache) the heat ranking."""
        print("开始计算热度分数...")
        self._calculate_heat_scores()
        print(f"热度召回模型训练完成,共{len(self.hot_items)}个物品")

    def recall(self, user_id: int, num_items: int = 50) -> List[Tuple[int, float]]:
        """
        Recall hot items for a user.

        Two-pass filtering: the strict pass drops every item the user has
        interacted with (views included); if that leaves fewer than
        `num_items` candidates, the relaxed pass only drops strong
        interactions (like/favorite/comment), so merely-viewed items may
        be served again.

        Args:
            user_id: target user id.
            num_items: number of items to return.

        Returns:
            List of (item_id, score) tuples, hottest first.
        """
        # Lazily train on first use.
        if not hasattr(self, 'hot_items') or not self.hot_items:
            self.train()

        conn = pymysql.connect(**self.db_config)
        try:
            cursor = conn.cursor()

            # All interactions, views included: strict first-pass filter.
            cursor.execute("""
                SELECT DISTINCT post_id
                FROM behaviors
                WHERE user_id = %s AND type IN ('like', 'favorite', 'comment', 'view')
            """, (user_id,))
            all_interacted_items = set(row[0] for row in cursor.fetchall())

            # Strong interactions only: relaxed second-pass filter.
            cursor.execute("""
                SELECT DISTINCT post_id
                FROM behaviors
                WHERE user_id = %s AND type IN ('like', 'favorite', 'comment')
            """, (user_id,))
            strong_interacted_items = set(row[0] for row in cursor.fetchall())

        finally:
            cursor.close()
            conn.close()

        # Strict pass: drop everything the user has touched.
        filtered_items = [
            (item_id, score) for item_id, score in self.hot_items
            if item_id not in all_interacted_items
        ]

        # BUG FIX: the relaxed pass previously re-ran the exact same query
        # as the first pass (identical behavior-type filter), so it could
        # never produce additional candidates. The strict pass now also
        # excludes views, and the relaxed pass excludes only strong
        # interactions, matching the documented intent.
        if len(filtered_items) < num_items:
            print(f"热度召回:过滤后候选不足({len(filtered_items)}),放宽过滤条件")
            filtered_items = [
                (item_id, score) for item_id, score in self.hot_items
                if item_id not in strong_interacted_items
            ]

        return filtered_items[:num_items]

    def get_top_hot_items(self, num_items: int = 100) -> List[Tuple[int, float]]:
        """
        Return the globally hottest items (no per-user filtering).

        Args:
            num_items: number of items to return.

        Returns:
            List of (item_id, score) tuples, hottest first.
        """
        return self.hot_items[:num_items]
diff --git a/rhj/backend/app/models/recall/multi_recall_manager.py b/rhj/backend/app/models/recall/multi_recall_manager.py
new file mode 100644
index 0000000..03cb3f8
--- /dev/null
+++ b/rhj/backend/app/models/recall/multi_recall_manager.py
@@ -0,0 +1,253 @@
+from typing import List, Tuple, Dict, Any
+import numpy as np
+from collections import defaultdict
+
+from .swing_recall import SwingRecall
+from .hot_recall import HotRecall
+from .ad_recall import AdRecall
+from .usercf_recall import UserCFRecall
+
class MultiRecallManager:
    """
    Multi-channel recall manager.

    Combines the Swing, heat, advertisement and UserCF recall channels,
    merges their candidates with per-source weights plus a diversity
    bonus, and returns a single ranked candidate list.
    """

    def __init__(self, db_config: dict, recall_config: dict = None):
        """
        Initialize the manager.

        Args:
            db_config: database connection parameters, passed to every channel.
            recall_config: optional per-channel overrides, merged over the
                defaults (unknown channels are added as-is).
        """
        self.db_config = db_config

        # Per-channel defaults: whether the channel runs, how many items it
        # contributes, and channel-specific hyperparameters.
        default_config = {
            'swing': {
                'enabled': True,
                'num_items': 15,
                'alpha': 0.5
            },
            'hot': {
                'enabled': True,
                'num_items': 10
            },
            'ad': {
                'enabled': True,
                'num_items': 2
            },
            'usercf': {
                'enabled': True,
                'num_items': 10,
                'min_common_items': 3,
                'num_similar_users': 50
            }
        }

        # Merge caller-supplied overrides over the defaults.
        self.config = default_config
        if recall_config:
            for key, value in recall_config.items():
                if key in self.config:
                    self.config[key].update(value)
                else:
                    self.config[key] = value

        # Instantiate the enabled recall channels.
        self.recalls = {}
        self._init_recalls()

    def _init_recalls(self):
        """(Re)build the channel instances from the current config."""
        if self.config['swing']['enabled']:
            self.recalls['swing'] = SwingRecall(
                self.db_config,
                alpha=self.config['swing']['alpha']
            )

        if self.config['hot']['enabled']:
            self.recalls['hot'] = HotRecall(self.db_config)

        if self.config['ad']['enabled']:
            self.recalls['ad'] = AdRecall(self.db_config)

        if self.config['usercf']['enabled']:
            self.recalls['usercf'] = UserCFRecall(
                self.db_config,
                min_common_items=self.config['usercf']['min_common_items']
            )

    def train_all(self):
        """Train every enabled recall channel."""
        print("开始训练多路召回模型...")

        for name, recall_model in self.recalls.items():
            print(f"训练 {name} 召回器...")
            recall_model.train()

        print("所有召回器训练完成!")

    def recall(self, user_id: int, total_items: int = 200) -> Tuple[List[int], List[float], Dict[str, List[Tuple[int, float]]]]:
        """
        Run every enabled channel and merge the results.

        Args:
            user_id: target user id.
            total_items: size of the merged candidate list.

        Returns:
            Tuple of (item_ids, scores, per-source raw results).
        """
        recall_results = {}
        all_candidates = defaultdict(list)  # item_id -> [(source, score), ...]

        for source, recall_model in self.recalls.items():
            if not self.config[source]['enabled']:
                continue

            num_items = self.config[source]['num_items']

            # UserCF takes an extra neighbor-count parameter.
            if source == 'usercf':
                items_scores = recall_model.recall(
                    user_id,
                    num_items=num_items,
                    num_similar_users=self.config[source]['num_similar_users']
                )
            else:
                items_scores = recall_model.recall(user_id, num_items=num_items)

            recall_results[source] = items_scores

            for item_id, score in items_scores:
                all_candidates[item_id].append((source, score))

        final_candidates = self._merge_recall_results(all_candidates, total_items)

        item_ids = [item_id for item_id, _ in final_candidates]
        scores = [score for _, score in final_candidates]

        return item_ids, scores, recall_results

    def _merge_recall_results(self, all_candidates: Dict[int, List[Tuple[str, float]]],
                            total_items: int) -> List[Tuple[int, float]]:
        """
        Merge per-source candidates into one ranked list.

        Each item's score is the weight-normalized average of its
        per-source scores, plus a 0.1-per-source diversity bonus for items
        surfaced by multiple channels.

        Args:
            all_candidates: item_id -> [(source, score), ...].
            total_items: number of items to keep.

        Returns:
            List of (item_id, final_score), highest score first.
        """
        # Per-source merge weights; unknown sources get a small default.
        source_weights = {
            'swing': 0.3,
            'hot': 0.2,
            'ad': 0.1,
            'usercf': 0.4
        }

        final_scores = []

        for item_id, source_scores in all_candidates.items():
            weighted_score = 0.0
            total_weight = 0.0

            for source, score in source_scores:
                weight = source_weights.get(source, 0.1)
                weighted_score += weight * score
                total_weight += weight

            # Normalize so items from few sources aren't penalized twice.
            if total_weight > 0:
                final_score = weighted_score / total_weight
            else:
                final_score = 0.0

            # Diversity bonus: reward items recalled by several channels.
            diversity_bonus = len(source_scores) * 0.1
            final_score += diversity_bonus

            final_scores.append((item_id, final_score))

        final_scores.sort(key=lambda x: x[1], reverse=True)

        return final_scores[:total_items]

    def get_recall_stats(self, user_id: int) -> Dict[str, Any]:
        """
        Collect recall diagnostics for a user.

        Args:
            user_id: target user id.

        Returns:
            Dict with the enabled channels, current config, and — when the
            UserCF channel can provide them — the user's profile and
            nearest neighbors.
        """
        stats = {
            'user_id': user_id,
            'enabled_recalls': list(self.recalls.keys()),
            'config': self.config
        }

        if 'usercf' in self.recalls:
            # Best-effort: these helpers need a trained UserCF model and a
            # reachable database; stats stay partial on failure.
            # BUG FIX: was a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit.
            try:
                user_profile = self.recalls['usercf'].get_user_profile(user_id)
                stats['user_profile'] = user_profile

                neighbors = self.recalls['usercf'].get_user_neighbors(user_id, 5)
                stats['similar_users'] = neighbors
            except Exception:
                pass

        return stats

    def update_config(self, new_config: dict):
        """
        Merge new settings into the config and rebuild the channels.

        NOTE(review): rebuilding discards any trained channel state;
        callers must re-run train_all() afterwards.

        Args:
            new_config: per-channel overrides.
        """
        for key, value in new_config.items():
            if key in self.config:
                self.config[key].update(value)
            else:
                self.config[key] = value

        self._init_recalls()

    def get_recall_breakdown(self, user_id: int) -> Dict[str, int]:
        """
        Report how many items each channel is configured to contribute.

        Args:
            user_id: target user id (unused; kept for interface symmetry).

        Returns:
            source name -> configured num_items (0 when disabled).
        """
        breakdown = {}

        for source in self.recalls.keys():
            if self.config[source]['enabled']:
                breakdown[source] = self.config[source]['num_items']
            else:
                breakdown[source] = 0

        return breakdown
diff --git a/rhj/backend/app/models/recall/swing_recall.py b/rhj/backend/app/models/recall/swing_recall.py
new file mode 100644
index 0000000..bf7fdd6
--- /dev/null
+++ b/rhj/backend/app/models/recall/swing_recall.py
@@ -0,0 +1,126 @@
+import numpy as np
+import pymysql
+from collections import defaultdict
+import math
+from typing import List, Tuple, Dict
+
class SwingRecall:
    """
    Swing recall.

    Item-to-item collaborative filtering where the contribution of each
    co-interacting user pair is damped by the overlap of their histories,
    which penalizes similarities driven purely by popular items.
    """

    def __init__(self, db_config: dict, alpha: float = 0.5):
        """
        Initialize the Swing model.

        Args:
            db_config: pymysql connection parameters.
            alpha: damping term; larger values punish heavy co-interaction
                overlap (popularity) more strongly.
        """
        self.db_config = db_config
        self.alpha = alpha
        self.item_similarity = {}          # item_id -> {other_item_id: sim}
        self.user_items = defaultdict(set)  # user_id -> set of item_ids
        self.item_users = defaultdict(set)  # item_id -> set of user_ids

    def _get_interaction_data(self):
        """Load user-item interactions (strong behaviors only)."""
        conn = pymysql.connect(**self.db_config)
        try:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT DISTINCT user_id, post_id
                FROM behaviors
                WHERE type IN ('like', 'favorite', 'comment')
            """)
            interactions = cursor.fetchall()

            for user_id, post_id in interactions:
                self.user_items[user_id].add(post_id)
                self.item_users[post_id].add(user_id)

        finally:
            cursor.close()
            conn.close()

    def _calculate_swing_similarity(self):
        """Compute the (symmetric) Swing similarity for all item pairs."""
        print("开始计算Swing相似度...")

        items = list(self.item_users.keys())

        for i, item_i in enumerate(items):
            if i % 100 == 0:
                print(f"处理进度: {i}/{len(items)}")

            # BUG FIX: this was `self.item_similarity[item_i] = {}`, which
            # wiped the symmetric entries [item_i][earlier_item] written by
            # earlier iterations — half the similarity matrix was lost.
            self.item_similarity.setdefault(item_i, {})

            for item_j in items[i+1:]:
                # Users who interacted with both items.
                common_users = self.item_users[item_i] & self.item_users[item_j]

                if len(common_users) < 2:  # need at least 2 common users
                    similarity = 0.0
                else:
                    # Core Swing formula: each ordered user pair contributes
                    # 1 / (alpha + |overlap of their item histories|).
                    similarity = 0.0
                    for u in common_users:
                        for v in common_users:
                            if u != v:
                                swing_weight = 1.0 / (self.alpha + len(self.user_items[u] & self.user_items[v]))
                                similarity += swing_weight

                    # Normalize by the number of ordered user pairs.
                    similarity = similarity / (len(common_users) * (len(common_users) - 1))

                # Store both directions (symmetric matrix).
                self.item_similarity[item_i][item_j] = similarity
                self.item_similarity.setdefault(item_j, {})[item_i] = similarity

        print("Swing相似度计算完成")

    def train(self):
        """Load interactions and build the similarity matrix."""
        self._get_interaction_data()
        self._calculate_swing_similarity()

    def recall(self, user_id: int, num_items: int = 50) -> List[Tuple[int, float]]:
        """
        Recall items similar to the user's interaction history.

        Args:
            user_id: target user id.
            num_items: number of items to return.

        Returns:
            List of (item_id, score) tuples, highest score first; empty
            when the user has no recorded interactions.
        """
        # Lazily train on first use.
        if not hasattr(self, 'item_similarity') or not self.item_similarity:
            self.train()

        if user_id not in self.user_items:
            return []

        user_interacted_items = self.user_items[user_id]

        # Accumulate similarity mass from each history item's neighbors,
        # excluding items the user already touched.
        candidate_scores = defaultdict(float)
        for item_i in user_interacted_items:
            if item_i in self.item_similarity:
                for item_j, similarity in self.item_similarity[item_i].items():
                    if item_j not in user_interacted_items:
                        candidate_scores[item_j] += similarity

        sorted_candidates = sorted(candidate_scores.items(), key=lambda x: x[1], reverse=True)
        return sorted_candidates[:num_items]
diff --git a/rhj/backend/app/models/recall/usercf_recall.py b/rhj/backend/app/models/recall/usercf_recall.py
new file mode 100644
index 0000000..d75e6d8
--- /dev/null
+++ b/rhj/backend/app/models/recall/usercf_recall.py
@@ -0,0 +1,235 @@
+import pymysql
+from typing import List, Tuple, Dict, Set
+from collections import defaultdict
+import math
+import numpy as np
+
class UserCFRecall:
    """
    UserCF (user-based collaborative filtering) recall.

    Computes cosine similarity between users over their weighted
    interaction sets and recommends items liked by the nearest neighbors.
    """

    def __init__(self, db_config: dict, min_common_items: int = 3):
        """
        Initialize the UserCF model.

        Args:
            db_config: pymysql connection parameters.
            min_common_items: minimum number of co-interacted items for a
                user pair to get a non-zero similarity.
        """
        self.db_config = db_config
        self.min_common_items = min_common_items
        self.user_items = defaultdict(set)  # user_id -> set of item_ids
        self.item_users = defaultdict(set)  # item_id -> set of user_ids
        self.user_similarity = {}           # user_id -> {other_user_id: sim}

    def _get_user_item_interactions(self):
        """Load behavior-weighted user-item interactions."""
        conn = pymysql.connect(**self.db_config)
        try:
            cursor = conn.cursor()

            # Counts per (user, item, behavior type).
            cursor.execute("""
                SELECT user_id, post_id, type, COUNT(*) as count
                FROM behaviors
                WHERE type IN ('like', 'favorite', 'comment', 'view')
                GROUP BY user_id, post_id, type
            """)

            interactions = cursor.fetchall()

            # Behavior-weighted interaction strengths.
            user_item_scores = defaultdict(lambda: defaultdict(float))

            # Stronger signals carry larger weights.
            behavior_weights = {
                'like': 1.0,
                'favorite': 2.0,
                'comment': 3.0,
                'view': 0.1
            }

            for user_id, post_id, behavior_type, count in interactions:
                weight = behavior_weights.get(behavior_type, 1.0)
                user_item_scores[user_id][post_id] += weight * count

            # Keep only interactions above the strength threshold; views
            # alone (0.1 each) need 10 occurrences to count.
            threshold = 1.0
            for user_id, items in user_item_scores.items():
                for item_id, score in items.items():
                    if score >= threshold:
                        self.user_items[user_id].add(item_id)
                        self.item_users[item_id].add(user_id)

        finally:
            cursor.close()
            conn.close()

    def _calculate_user_similarity(self):
        """Compute the (symmetric) cosine similarity for all user pairs."""
        print("开始计算用户相似度...")

        users = list(self.user_items.keys())
        total_pairs = len(users) * (len(users) - 1) // 2
        processed = 0

        for i, user_i in enumerate(users):
            # BUG FIX: this was `self.user_similarity[user_i] = {}`, which
            # wiped the symmetric entries [user_i][earlier_user] written by
            # earlier iterations — half the similarity matrix was lost.
            self.user_similarity.setdefault(user_i, {})

            for user_j in users[i+1:]:
                processed += 1
                if processed % 10000 == 0:
                    print(f"处理进度: {processed}/{total_pairs}")

                # Items both users interacted with.
                common_items = self.user_items[user_i] & self.user_items[user_j]

                if len(common_items) < self.min_common_items:
                    similarity = 0.0
                else:
                    # Cosine similarity over binary interaction sets.
                    numerator = len(common_items)
                    denominator = math.sqrt(len(self.user_items[user_i]) * len(self.user_items[user_j]))
                    similarity = numerator / denominator if denominator > 0 else 0.0

                # Store both directions (symmetric matrix).
                self.user_similarity[user_i][user_j] = similarity
                self.user_similarity.setdefault(user_j, {})[user_i] = similarity

        print("用户相似度计算完成")

    def train(self):
        """Load interactions and build the user similarity matrix."""
        self._get_user_item_interactions()
        self._calculate_user_similarity()

    def recall(self, user_id: int, num_items: int = 50, num_similar_users: int = 50) -> List[Tuple[int, float]]:
        """
        Recall items liked by the user's nearest neighbors.

        Args:
            user_id: target user id.
            num_items: number of items to return.
            num_similar_users: how many neighbors to aggregate over.

        Returns:
            List of (item_id, score) tuples, highest score first; empty
            when the user is unknown.
        """
        # Lazily train on first use.
        if not hasattr(self, 'user_similarity') or not self.user_similarity:
            self.train()

        if user_id not in self.user_similarity or user_id not in self.user_items:
            return []

        # Nearest neighbors by similarity.
        similar_users = sorted(
            self.user_similarity[user_id].items(),
            key=lambda x: x[1],
            reverse=True
        )[:num_similar_users]

        user_interacted_items = self.user_items[user_id]

        # Score = sum of neighbor similarities over neighbors who
        # interacted with the item; skip items the user already touched.
        candidate_scores = defaultdict(float)
        for similar_user_id, similarity in similar_users:
            if similarity <= 0:
                continue

            for item_id in self.user_items[similar_user_id]:
                if item_id not in user_interacted_items:
                    candidate_scores[item_id] += similarity

        sorted_candidates = sorted(candidate_scores.items(), key=lambda x: x[1], reverse=True)
        return sorted_candidates[:num_items]

    def get_user_neighbors(self, user_id: int, num_neighbors: int = 10) -> List[Tuple[int, float]]:
        """
        Return the user's most similar neighbors.

        Args:
            user_id: target user id.
            num_neighbors: number of neighbors to return.

        Returns:
            List of (neighbor_user_id, similarity) tuples, most similar
            first; empty when the user is unknown.
        """
        if user_id not in self.user_similarity:
            return []

        neighbors = sorted(
            self.user_similarity[user_id].items(),
            key=lambda x: x[1],
            reverse=True
        )[:num_neighbors]

        return neighbors

    def get_user_profile(self, user_id: int) -> Dict:
        """
        Build a lightweight profile for a user.

        Args:
            user_id: target user id.

        Returns:
            Dict with interaction count, tag preferences and behavior
            stats; empty dict when the user has no interactions.
        """
        if user_id not in self.user_items:
            return {}

        conn = pymysql.connect(**self.db_config)
        try:
            cursor = conn.cursor()

            user_item_list = list(self.user_items[user_id])
            if not user_item_list:
                return {}

            # Tag histogram over the user's interacted items.
            format_strings = ','.join(['%s'] * len(user_item_list))
            cursor.execute(f"""
                SELECT t.name, COUNT(*) as count
                FROM post_tags pt
                JOIN tags t ON pt.tag_id = t.id
                WHERE pt.post_id IN ({format_strings})
                GROUP BY t.name
                ORDER BY count DESC
            """, tuple(user_item_list))

            tag_preferences = cursor.fetchall()

            # Behavior-type histogram.
            cursor.execute("""
                SELECT type, COUNT(*) as count
                FROM behaviors
                WHERE user_id = %s
                GROUP BY type
            """, (user_id,))

            behavior_stats = cursor.fetchall()

            return {
                'user_id': user_id,
                'total_interactions': len(self.user_items[user_id]),
                'tag_preferences': dict(tag_preferences),
                'behavior_stats': dict(behavior_stats)
            }

        finally:
            cursor.close()
            conn.close()
diff --git a/rhj/backend/app/models/recommend/LightGCN.py b/rhj/backend/app/models/recommend/LightGCN.py
new file mode 100644
index 0000000..38b1732
--- /dev/null
+++ b/rhj/backend/app/models/recommend/LightGCN.py
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import scipy.sparse as sp
+import math
+import networkx as nx
+import random
+from copy import deepcopy
+from app.utils.parse_args import args
+from app.models.recommend.base_model import BaseModel
+from app.models.recommend.operators import EdgelistDrop
+from app.models.recommend.operators import scatter_add, scatter_sum
+
+
+init = nn.init.xavier_uniform_
+
class LightGCN(BaseModel):
    """
    LightGCN-style recommender: propagates user/item embeddings over the
    normalized user-item bipartite graph with a plain weighted-sum
    aggregation (no feature transforms or nonlinearities) and scores
    user-item pairs by dot product.

    NOTE(review): `args`, `init`, `EdgelistDrop` and the scatter helpers
    come from module-level imports; `self._make_binorm_adj` and
    `self._bpr_loss` are assumed to be provided by `BaseModel` — confirm.
    """
    def __init__(self, dataset, pretrained_model=None, phase='pretrain'):
        """
        Args:
            dataset: dataloader exposing `graph` (and, via BaseModel,
                user/item counts and embedding size).
            pretrained_model: model whose generate() supplies warm-start
                embeddings; only used when phase == 'finetune'.
            phase: one of 'pretrain' | 'vanilla' | 'for_tune' |
                'finetune' | 'continue_tune'; controls embedding init.
        """
        super().__init__(dataset)
        # Symmetrically normalized adjacency of the bipartite graph.
        self.adj = self._make_binorm_adj(dataset.graph)
        self.edges = self.adj._indices().t()  # COO edge index, one row per edge
        self.edge_norm = self.adj._values()   # per-edge normalization weight

        self.phase = phase

        # Identity gate; kept as a hook for transforming embeddings
        # before propagation.
        self.emb_gate = lambda x: x

        if self.phase == 'pretrain' or self.phase == 'vanilla' or self.phase == 'for_tune':
            # Fresh Xavier-initialized embedding tables.
            self.user_embedding = nn.Parameter(init(torch.empty(self.num_users, self.emb_size)))
            self.item_embedding = nn.Parameter(init(torch.empty(self.num_items, self.emb_size)))


        elif self.phase == 'finetune':
            # Warm-start from the pretrained model's propagated embeddings.
            pre_user_emb, pre_item_emb = pretrained_model.generate()
            self.user_embedding = nn.Parameter(pre_user_emb).requires_grad_(True)
            self.item_embedding = nn.Parameter(pre_item_emb).requires_grad_(True)

        elif self.phase == 'continue_tune':
            # re-initialize for loading state dict
            self.user_embedding = nn.Parameter(init(torch.empty(self.num_users, self.emb_size)))
            self.item_embedding = nn.Parameter(init(torch.empty(self.num_items, self.emb_size)))

        self.edge_dropout = EdgelistDrop()

    def _agg(self, all_emb, edges, edge_norm):
        """One propagation step: normalized sum of source embeddings into
        each destination node over the given edge list."""
        src_emb = all_emb[edges[:, 0]]

        # bi-norm: scale each message by its edge normalization weight
        src_emb = src_emb * edge_norm.unsqueeze(1)

        # conv: scatter-sum messages into destination nodes
        dst_emb = scatter_sum(src_emb, edges[:, 1], dim=0, dim_size=self.num_users+self.num_items)
        return dst_emb

    def _edge_binorm(self, edges):
        """Recompute 1/sqrt(deg_u * deg_i) normalization per edge.

        NOTE(review): assumes edges[:, 0] indexes users and edges[:, 1]
        indexes items (not offset by num_users) — confirm against caller.
        """
        user_degs = scatter_add(torch.ones_like(edges[:, 0]), edges[:, 0], dim=0, dim_size=self.num_users)
        user_degs = user_degs[edges[:, 0]]
        item_degs = scatter_add(torch.ones_like(edges[:, 1]), edges[:, 1], dim=0, dim_size=self.num_items)
        item_degs = item_degs[edges[:, 1]]
        norm = torch.pow(user_degs, -0.5) * torch.pow(item_degs, -0.5)
        return norm

    def forward(self, edges, edge_norm, return_layers=False):
        """Propagate embeddings for args.num_layers hops.

        Returns (user_emb, item_emb); when return_layers is True, each is
        a list with one entry per layer (including layer 0) instead of
        the layer sum.
        """
        all_emb = torch.cat([self.user_embedding, self.item_embedding], dim=0)
        all_emb = self.emb_gate(all_emb)
        res_emb = [all_emb]
        for l in range(args.num_layers):
            all_emb = self._agg(res_emb[-1], edges, edge_norm)
            res_emb.append(all_emb)
        if not return_layers:
            # Sum (not mean) over layer outputs, then split users/items.
            res_emb = sum(res_emb)
            user_res_emb, item_res_emb = res_emb.split([self.num_users, self.num_items], dim=0)
        else:
            user_res_emb, item_res_emb = [], []
            for emb in res_emb:
                u_emb, i_emb = emb.split([self.num_users, self.num_items], dim=0)
                user_res_emb.append(u_emb)
                item_res_emb.append(i_emb)
        return user_res_emb, item_res_emb

    def cal_loss(self, batch_data):
        """BPR + L2 loss on a (users, pos_items, neg_items) batch, with
        edge dropout applied to the propagation graph for this step."""
        edges, dropout_mask = self.edge_dropout(self.edges, 1-args.edge_dropout, return_mask=True)
        edge_norm = self.edge_norm[dropout_mask]

        # forward
        users, pos_items, neg_items = batch_data
        user_emb, item_emb = self.forward(edges, edge_norm)
        batch_user_emb = user_emb[users]
        pos_item_emb = item_emb[pos_items]
        neg_item_emb = item_emb[neg_items]
        rec_loss = self._bpr_loss(batch_user_emb, pos_item_emb, neg_item_emb)
        reg_loss = args.weight_decay * self._reg_loss(users, pos_items, neg_items)

        loss = rec_loss + reg_loss
        loss_dict = {
            "rec_loss": rec_loss.item(),
            "reg_loss": reg_loss.item(),
        }
        return loss, loss_dict

    @torch.no_grad()
    def generate(self, return_layers=False):
        """Inference-time propagation over the full (undropped) graph."""
        return self.forward(self.edges, self.edge_norm, return_layers=return_layers)

    @torch.no_grad()
    def generate_lgn(self, return_layers=False):
        """Alias of generate(); kept for callers expecting this name."""
        return self.forward(self.edges, self.edge_norm, return_layers=return_layers)

    @torch.no_grad()
    def rating(self, user_emb, item_emb):
        """Score all user-item pairs by dot product."""
        return torch.matmul(user_emb, item_emb.t())

    def _reg_loss(self, users, pos_items, neg_items):
        """L2 regularization on the *layer-0* embeddings of the batch,
        averaged over the batch size."""
        u_emb = self.user_embedding[users]
        pos_i_emb = self.item_embedding[pos_items]
        neg_i_emb = self.item_embedding[neg_items]
        reg_loss = (1/2)*(u_emb.norm(2).pow(2) +
                          pos_i_emb.norm(2).pow(2) +
                          neg_i_emb.norm(2).pow(2))/float(len(users))
        return reg_loss
diff --git a/rhj/backend/app/models/recommend/LightGCN_pretrained.pt b/rhj/backend/app/models/recommend/LightGCN_pretrained.pt
new file mode 100644
index 0000000..825e0e2
--- /dev/null
+++ b/rhj/backend/app/models/recommend/LightGCN_pretrained.pt
Binary files differ
diff --git a/rhj/backend/app/models/recommend/__pycache__/LightGCN.cpython-312.pyc b/rhj/backend/app/models/recommend/__pycache__/LightGCN.cpython-312.pyc
new file mode 100644
index 0000000..c87435f
--- /dev/null
+++ b/rhj/backend/app/models/recommend/__pycache__/LightGCN.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recommend/__pycache__/base_model.cpython-312.pyc b/rhj/backend/app/models/recommend/__pycache__/base_model.cpython-312.pyc
new file mode 100644
index 0000000..b9d8c72
--- /dev/null
+++ b/rhj/backend/app/models/recommend/__pycache__/base_model.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recommend/__pycache__/lightgcn_scorer.cpython-312.pyc b/rhj/backend/app/models/recommend/__pycache__/lightgcn_scorer.cpython-312.pyc
new file mode 100644
index 0000000..b0887a9
--- /dev/null
+++ b/rhj/backend/app/models/recommend/__pycache__/lightgcn_scorer.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recommend/__pycache__/operators.cpython-312.pyc b/rhj/backend/app/models/recommend/__pycache__/operators.cpython-312.pyc
new file mode 100644
index 0000000..13bb375
--- /dev/null
+++ b/rhj/backend/app/models/recommend/__pycache__/operators.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recommend/base_model.py b/rhj/backend/app/models/recommend/base_model.py
new file mode 100644
index 0000000..6c59aa6
--- /dev/null
+++ b/rhj/backend/app/models/recommend/base_model.py
@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+from app.utils.parse_args import args
+from scipy.sparse import csr_matrix
+import scipy.sparse as sp
+import numpy as np
+import torch.nn.functional as F
+
+
class BaseModel(nn.Module):
    """Common base for graph-based recommender models.

    Stores user/item counts, builds symmetrically normalized bipartite
    adjacency tensors, and provides the shared ranking losses (BPR, NCE,
    InfoNCE). Concrete models override ``forward`` and ``cal_loss``.
    """

    def __init__(self, dataloader):
        super(BaseModel, self).__init__()
        self.num_users = dataloader.num_users
        self.num_items = dataloader.num_items
        # Embedding dimension comes from the global CLI args object.
        self.emb_size = args.emb_size

    def forward(self):
        # Implemented by subclasses.
        pass

    def cal_loss(self, batch_data):
        # Implemented by subclasses.
        pass

    def _check_inf(self, loss, pos_score, neg_score, edge_weight):
        """Raise ValueError if any per-sample loss is inf/NaN, printing diagnostics first."""
        bad_idx = torch.isinf(loss) | torch.isnan(loss)
        if bad_idx.any():
            print("find inf in loss")
            # edge_weight may be a plain int placeholder (0/1) instead of a tensor;
            # isinstance replaces the original `type(...) != int` check.
            if not isinstance(edge_weight, int):
                print(edge_weight[bad_idx])
            print(f"pos_score: {pos_score[bad_idx]}")
            print(f"neg_score: {neg_score[bad_idx]}")
            raise ValueError("find inf in loss")

    def _build_binorm_adj(self, mat, add_self_loop):
        """Shared builder for D^{-1/2} A D^{-1/2} over the stacked bipartite graph.

        Args:
            mat: scipy sparse (num_users x num_items) interaction matrix.
            add_self_loop: whether to add the identity before normalizing.

        Returns:
            A torch sparse COO tensor on ``args.device``.
        """
        a = csr_matrix((self.num_users, self.num_users))
        b = csr_matrix((self.num_items, self.num_items))
        adj = sp.vstack(
            [sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
        adj = (adj != 0) * 1.0  # binarize interaction counts
        if add_self_loop:
            adj = (adj + sp.eye(adj.shape[0])) * 1.0
        # Symmetric normalization; isolated nodes get a 0 inverse degree
        # instead of inf.
        degree = np.array(adj.sum(axis=-1))
        d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
        adj = adj.dot(d_inv_sqrt_mat).transpose().dot(d_inv_sqrt_mat).tocoo()
        # torch.sparse_coo_tensor replaces the deprecated
        # torch.sparse.FloatTensor constructor.
        idxs = torch.from_numpy(np.vstack([adj.row, adj.col]).astype(np.int64))
        vals = torch.from_numpy(adj.data.astype(np.float32))
        return torch.sparse_coo_tensor(idxs, vals, torch.Size(adj.shape)).to(args.device)

    def _make_binorm_adj(self, mat):
        """Bi-normalized adjacency without self loops."""
        return self._build_binorm_adj(mat, add_self_loop=False)

    def _make_binorm_adj_self_loop(self, mat):
        """Bi-normalized adjacency with self loops (A + I)."""
        return self._build_binorm_adj(mat, add_self_loop=True)

    def _sp_matrix_to_sp_tensor(self, sp_matrix):
        """Convert a scipy sparse matrix to a coalesced torch sparse tensor on args.device."""
        coo = sp_matrix.tocoo()
        indices = torch.LongTensor([coo.row, coo.col])
        values = torch.FloatTensor(coo.data)
        return torch.sparse_coo_tensor(indices, values, coo.shape).coalesce().to(args.device)

    def _bpr_loss(self, user_emb, pos_item_emb, neg_item_emb):
        """Bayesian Personalized Ranking loss; the 1e-10 guards log(0)."""
        pos_score = (user_emb * pos_item_emb).sum(dim=1)
        neg_score = (user_emb * neg_item_emb).sum(dim=1)
        loss = -torch.log(1e-10 + torch.sigmoid((pos_score - neg_score)))
        self._check_inf(loss, pos_score, neg_score, 0)
        return loss.mean()

    def _nce_loss(self, pos_score, neg_score, edge_weight=1):
        """Noise-contrastive loss -log(e^pos / (e^pos + sum e^neg)), optionally edge-weighted.

        NOTE(review): raw exponentials can overflow for large scores; if this
        ever trips _check_inf, consider a logsumexp formulation.
        """
        numerator = torch.exp(pos_score)
        denominator = torch.exp(pos_score) + torch.exp(neg_score).sum(dim=1)
        loss = -torch.log(numerator / denominator) * edge_weight
        self._check_inf(loss, pos_score, neg_score, edge_weight)
        return loss.mean()

    def _infonce_loss(self, pos_1, pos_2, negs, tau):
        """InfoNCE contrastive loss with temperature ``tau``.

        Projects inputs through ``self.cl_mlp`` (defined by subclasses) and
        L2-normalizes before scoring.
        """
        pos_1 = self.cl_mlp(pos_1)
        pos_2 = self.cl_mlp(pos_2)
        negs = self.cl_mlp(negs)
        pos_1 = F.normalize(pos_1, dim=-1)
        pos_2 = F.normalize(pos_2, dim=-1)
        negs = F.normalize(negs, dim=-1)
        pos_score = torch.mul(pos_1, pos_2).sum(dim=1)
        # (B,1,E) @ (B,E,N) -> (B,N): similarity of each anchor to its negatives.
        neg_score = torch.bmm(pos_1.unsqueeze(1), negs.transpose(1, 2)).squeeze(1)
        numerator = torch.exp(pos_score / tau)
        denominator = numerator + torch.exp(neg_score / tau).sum(dim=1)
        loss = -torch.log(numerator / denominator)
        self._check_inf(loss, pos_score, neg_score, 0)
        return loss.mean()
+
\ No newline at end of file
diff --git a/rhj/backend/app/models/recommend/operators.py b/rhj/backend/app/models/recommend/operators.py
new file mode 100644
index 0000000..a508966
--- /dev/null
+++ b/rhj/backend/app/models/recommend/operators.py
@@ -0,0 +1,52 @@
+import torch
+from typing import Optional, Tuple
+from torch import nn
+
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    """Expand a (typically 1-D index) tensor to line up with `other` along `dim`.

    Mirrors torch_scatter's helper: inserts leading singleton dims up to
    `dim`, trailing singleton dims up to other.dim(), then expands to
    other's full size.
    """
    if dim < 0:
        dim = other.dim() + dim
    if src.dim() == 1:
        for _ in range(0, dim):
            src = src.unsqueeze(0)
    for _ in range(src.dim(), other.dim()):
        src = src.unsqueeze(-1)
    src = src.expand(other.size())
    return src

def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Sum `src` values into buckets selected by `index` along `dim`.

    Args:
        src: source values to scatter.
        index: bucket index per element (broadcast to src's shape).
        dim: dimension along which to scatter.
        out: optional pre-allocated output (accumulated in place).
        dim_size: output extent along `dim`; inferred from index.max()+1
            when omitted.

    Returns:
        The accumulated output tensor.
    """
    index = broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            # Infer the output extent from the largest index present.
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
    # Single exit point: the original duplicated this call in both branches.
    return out.scatter_add_(dim, index, src)

def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Alias of scatter_sum, kept for torch_scatter API compatibility."""
    return scatter_sum(src, index, dim, out, dim_size)
+
+
class EdgelistDrop(nn.Module):
    """Random edge dropout over an edge list (graph augmentation).

    Each edge is kept independently with probability ``keep_rate``.
    """

    def __init__(self):
        super(EdgelistDrop, self).__init__()

    def forward(self, edgeList, keep_rate, return_mask=False):
        """Sample a subset of edges.

        Args:
            edgeList: (E, ...) tensor of edges, one per row.
            keep_rate: probability in (0, 1] of keeping each edge.
            return_mask: if True, also return the boolean keep-mask.

        Returns:
            The kept edges, plus the mask when ``return_mask`` is True.
        """
        if keep_rate == 1.0:
            # Fast path: keep everything. BUGFIX: the original returned
            # (edgeList, mask) here unconditionally, ignoring return_mask —
            # inconsistent with the sampling path below.
            mask = torch.ones(edgeList.size(0), dtype=torch.bool, device=edgeList.device)
            return (edgeList, mask) if return_mask else edgeList
        num_edges = edgeList.size(0)
        # rand + keep_rate, floored, yields Bernoulli(keep_rate) keep flags.
        # Draw on edgeList's device so boolean indexing works for GPU tensors.
        mask = (torch.rand(num_edges, device=edgeList.device) + keep_rate).floor().type(torch.bool)
        kept = edgeList[mask, :]
        return (kept, mask) if return_mask else kept