推荐系统:新增多路召回模块(Swing、热度、广告、UserCF 召回及多路召回管理器)

Change-Id: I49b9205568f1ccf88b32b08511aff8b0bea8d1bd
diff --git a/rhj/backend/app/models/recall/__init__.py b/rhj/backend/app/models/recall/__init__.py
new file mode 100644
index 0000000..98d926b
--- /dev/null
+++ b/rhj/backend/app/models/recall/__init__.py
@@ -0,0 +1,24 @@
+"""
+Multi-channel recall module.
+
+Recall algorithms provided:
+- SwingRecall: Swing algorithm, based on item-item similarity
+- HotRecall: popularity-based recall
+- AdRecall: recall dedicated to advertisement posts
+- UserCFRecall: user-based collaborative-filtering recall
+- MultiRecallManager: manager that combines all recall strategies
+"""
+
+from .swing_recall import SwingRecall
+from .hot_recall import HotRecall
+from .ad_recall import AdRecall
+from .usercf_recall import UserCFRecall
+from .multi_recall_manager import MultiRecallManager
+
+__all__ = [
+    'SwingRecall',
+    'HotRecall', 
+    'AdRecall',
+    'UserCFRecall',
+    'MultiRecallManager'
+]
diff --git a/rhj/backend/app/models/recall/__pycache__/__init__.cpython-312.pyc b/rhj/backend/app/models/recall/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000..d1cf37c
--- /dev/null
+++ b/rhj/backend/app/models/recall/__pycache__/__init__.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recall/__pycache__/ad_recall.cpython-312.pyc b/rhj/backend/app/models/recall/__pycache__/ad_recall.cpython-312.pyc
new file mode 100644
index 0000000..08a722c
--- /dev/null
+++ b/rhj/backend/app/models/recall/__pycache__/ad_recall.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recall/__pycache__/bloom_filter.cpython-312.pyc b/rhj/backend/app/models/recall/__pycache__/bloom_filter.cpython-312.pyc
new file mode 100644
index 0000000..c4dae7e
--- /dev/null
+++ b/rhj/backend/app/models/recall/__pycache__/bloom_filter.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recall/__pycache__/hot_recall.cpython-312.pyc b/rhj/backend/app/models/recall/__pycache__/hot_recall.cpython-312.pyc
new file mode 100644
index 0000000..cb6c725
--- /dev/null
+++ b/rhj/backend/app/models/recall/__pycache__/hot_recall.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recall/__pycache__/multi_recall_manager.cpython-312.pyc b/rhj/backend/app/models/recall/__pycache__/multi_recall_manager.cpython-312.pyc
new file mode 100644
index 0000000..9a95456
--- /dev/null
+++ b/rhj/backend/app/models/recall/__pycache__/multi_recall_manager.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recall/__pycache__/swing_recall.cpython-312.pyc b/rhj/backend/app/models/recall/__pycache__/swing_recall.cpython-312.pyc
new file mode 100644
index 0000000..d913d68
--- /dev/null
+++ b/rhj/backend/app/models/recall/__pycache__/swing_recall.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recall/__pycache__/usercf_recall.cpython-312.pyc b/rhj/backend/app/models/recall/__pycache__/usercf_recall.cpython-312.pyc
new file mode 100644
index 0000000..adb6177
--- /dev/null
+++ b/rhj/backend/app/models/recall/__pycache__/usercf_recall.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recall/ad_recall.py b/rhj/backend/app/models/recall/ad_recall.py
new file mode 100644
index 0000000..0fe3b0a
--- /dev/null
+++ b/rhj/backend/app/models/recall/ad_recall.py
@@ -0,0 +1,207 @@
+import pymysql
+from typing import List, Tuple, Dict
+import random
+
+class AdRecall:
+    """
+    广告召回算法实现
+    专门用于召回广告类型的内容
+    """
+    
+    def __init__(self, db_config: dict):
+        """
+        初始化广告召回模型
+        
+        Args:
+            db_config: 数据库配置
+        """
+        self.db_config = db_config
+        self.ad_items = []
+        
+    def _get_ad_items(self):
+        """获取广告物品列表"""
+        conn = pymysql.connect(**self.db_config)
+        try:
+            cursor = conn.cursor()
+            
+            # 获取所有广告帖子,按热度和发布时间排序
+            cursor.execute("""
+                SELECT 
+                    p.id,
+                    p.heat,
+                    p.created_at,
+                    COUNT(DISTINCT b.user_id) as interaction_count,
+                    DATEDIFF(NOW(), p.created_at) as days_since_created
+                FROM posts p
+                LEFT JOIN behaviors b ON p.id = b.post_id
+                WHERE p.is_advertisement = 1 AND p.status = 'published'
+                GROUP BY p.id, p.heat, p.created_at
+                ORDER BY p.heat DESC, p.created_at DESC
+            """)
+            
+            results = cursor.fetchall()
+            
+            # 计算广告分数
+            items_with_scores = []
+            for row in results:
+                post_id, heat, created_at, interaction_count, days_since_created = row
+                
+                # 处理None值
+                heat = heat or 0
+                interaction_count = interaction_count or 0
+                days_since_created = days_since_created or 0
+                
+                # 广告分数计算:热度 + 交互数 - 时间惩罚
+                # 新发布的广告给予更高权重
+                freshness_bonus = max(0, 30 - days_since_created) / 30.0  # 30天内的新鲜度奖励
+                
+                ad_score = (
+                    heat * 0.6 +
+                    interaction_count * 0.3 +
+                    freshness_bonus * 100  # 新鲜度奖励
+                )
+                
+                items_with_scores.append((post_id, ad_score))
+            
+            # 按广告分数排序
+            self.ad_items = sorted(items_with_scores, key=lambda x: x[1], reverse=True)
+            
+        finally:
+            cursor.close()
+            conn.close()
+    
+    def train(self):
+        """训练广告召回模型"""
+        print("开始获取广告物品...")
+        self._get_ad_items()
+        print(f"广告召回模型训练完成,共{len(self.ad_items)}个广告物品")
+    
+    def recall(self, user_id: int, num_items: int = 10) -> List[Tuple[int, float]]:
+        """
+        为用户召回广告物品
+        
+        Args:
+            user_id: 用户ID
+            num_items: 召回物品数量
+            
+        Returns:
+            List of (item_id, score) tuples
+        """
+        # 如果尚未训练,先进行训练
+        if not hasattr(self, 'ad_items') or not self.ad_items:
+            self.train()
+        
+        # 获取用户已交互的广告,避免重复推荐
+        conn = pymysql.connect(**self.db_config)
+        try:
+            cursor = conn.cursor()
+            cursor.execute("""
+                SELECT DISTINCT b.post_id 
+                FROM behaviors b
+                JOIN posts p ON b.post_id = p.id
+                WHERE b.user_id = %s AND p.is_advertisement = 1
+                AND b.type IN ('like', 'favorite', 'comment', 'view')
+            """, (user_id,))
+            
+            user_interacted_ads = set(row[0] for row in cursor.fetchall())
+            
+            # 获取用户的兴趣标签(基于历史行为)
+            cursor.execute("""
+                SELECT t.name, COUNT(*) as count
+                FROM behaviors b
+                JOIN posts p ON b.post_id = p.id
+                JOIN post_tags pt ON p.id = pt.post_id
+                JOIN tags t ON pt.tag_id = t.id
+                WHERE b.user_id = %s AND b.type IN ('like', 'favorite', 'comment')
+                GROUP BY t.name
+                ORDER BY count DESC
+                LIMIT 10
+            """, (user_id,))
+            
+            user_interest_tags = set(row[0] for row in cursor.fetchall())
+            
+        finally:
+            cursor.close()
+            conn.close()
+        
+        # 过滤掉用户已交互的广告
+        filtered_ads = [
+            (item_id, score) for item_id, score in self.ad_items
+            if item_id not in user_interacted_ads
+        ]
+        
+        # 如果没有未交互的广告,但有广告数据,返回评分最高的广告(可能用户会再次感兴趣)
+        if not filtered_ads and self.ad_items:
+            print(f"用户 {user_id} 已与所有广告交互,返回评分最高的广告")
+            filtered_ads = self.ad_items[:num_items]
+        
+        # 如果用户有兴趣标签,可以进一步个性化广告推荐
+        if user_interest_tags and filtered_ads:
+            filtered_ads = self._personalize_ads(filtered_ads, user_interest_tags)
+        
+        return filtered_ads[:num_items]
+    
+    def _personalize_ads(self, ad_list: List[Tuple[int, float]], user_interest_tags: set) -> List[Tuple[int, float]]:
+        """
+        根据用户兴趣标签个性化广告推荐
+        
+        Args:
+            ad_list: 广告列表
+            user_interest_tags: 用户兴趣标签
+            
+        Returns:
+            个性化后的广告列表
+        """
+        conn = pymysql.connect(**self.db_config)
+        try:
+            cursor = conn.cursor()
+            
+            personalized_ads = []
+            for ad_id, ad_score in ad_list:
+                # 获取广告的标签
+                cursor.execute("""
+                    SELECT t.name
+                    FROM post_tags pt
+                    JOIN tags t ON pt.tag_id = t.id
+                    WHERE pt.post_id = %s
+                """, (ad_id,))
+                
+                ad_tags = set(row[0] for row in cursor.fetchall())
+                
+                # 计算标签匹配度
+                tag_match_score = len(ad_tags & user_interest_tags) / max(len(user_interest_tags), 1)
+                
+                # 调整广告分数
+                final_score = ad_score * (1 + tag_match_score)
+                personalized_ads.append((ad_id, final_score))
+            
+            # 重新排序
+            personalized_ads.sort(key=lambda x: x[1], reverse=True)
+            return personalized_ads
+            
+        finally:
+            cursor.close()
+            conn.close()
+    
+    def get_random_ads(self, num_items: int = 5) -> List[Tuple[int, float]]:
+        """
+        获取随机广告(用于多样性)
+        
+        Args:
+            num_items: 返回物品数量
+            
+        Returns:
+            List of (item_id, score) tuples
+        """
+        if len(self.ad_items) <= num_items:
+            return self.ad_items
+        
+        # 随机选择但倾向于高分广告
+        weights = [score for _, score in self.ad_items]
+        selected_indices = random.choices(
+            range(len(self.ad_items)), 
+            weights=weights, 
+            k=num_items
+        )
+        
+        return [self.ad_items[i] for i in selected_indices]
diff --git a/rhj/backend/app/models/recall/bloom_filter.py b/rhj/backend/app/models/recall/bloom_filter.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rhj/backend/app/models/recall/bloom_filter.py
diff --git a/rhj/backend/app/models/recall/hot_recall.py b/rhj/backend/app/models/recall/hot_recall.py
new file mode 100644
index 0000000..dbc716c
--- /dev/null
+++ b/rhj/backend/app/models/recall/hot_recall.py
@@ -0,0 +1,163 @@
+import pymysql
+from typing import List, Tuple, Dict
+import numpy as np
+
+class HotRecall:
+    """
+    热度召回算法实现
+    基于物品的热度(热度分数、交互次数等)进行召回
+    """
+    
+    def __init__(self, db_config: dict):
+        """
+        初始化热度召回模型
+        
+        Args:
+            db_config: 数据库配置
+        """
+        self.db_config = db_config
+        self.hot_items = []
+        
+    def _calculate_heat_scores(self):
+        """计算物品热度分数"""
+        conn = pymysql.connect(**self.db_config)
+        try:
+            cursor = conn.cursor()
+            
+            # 综合考虑多个热度指标
+            cursor.execute("""
+                SELECT 
+                    p.id,
+                    p.heat,
+                    COUNT(DISTINCT CASE WHEN b.type = 'like' THEN b.user_id END) as like_count,
+                    COUNT(DISTINCT CASE WHEN b.type = 'favorite' THEN b.user_id END) as favorite_count,
+                    COUNT(DISTINCT CASE WHEN b.type = 'comment' THEN b.user_id END) as comment_count,
+                    COUNT(DISTINCT CASE WHEN b.type = 'view' THEN b.user_id END) as view_count,
+                    COUNT(DISTINCT CASE WHEN b.type = 'share' THEN b.user_id END) as share_count,
+                    DATEDIFF(NOW(), p.created_at) as days_since_created
+                FROM posts p
+                LEFT JOIN behaviors b ON p.id = b.post_id
+                WHERE p.status = 'published'
+                GROUP BY p.id, p.heat, p.created_at
+            """)
+            
+            results = cursor.fetchall()
+            
+            # 计算综合热度分数
+            items_with_scores = []
+            for row in results:
+                post_id, heat, like_count, favorite_count, comment_count, view_count, share_count, days_since_created = row
+                
+                # 处理None值
+                heat = heat or 0
+                like_count = like_count or 0
+                favorite_count = favorite_count or 0
+                comment_count = comment_count or 0
+                view_count = view_count or 0
+                share_count = share_count or 0
+                days_since_created = days_since_created or 0
+                
+                # 综合热度分数计算
+                # 基础热度 + 加权的用户行为 + 时间衰减
+                behavior_score = (
+                    like_count * 1.0 +
+                    favorite_count * 2.0 +
+                    comment_count * 3.0 +
+                    view_count * 0.1 +
+                    share_count * 5.0
+                )
+                
+                # 时间衰减因子(越新的内容热度越高)
+                time_decay = np.exp(-days_since_created / 30.0)  # 30天半衰期
+                
+                # 最终热度分数
+                final_score = (heat * 0.3 + behavior_score * 0.7) * time_decay
+                
+                items_with_scores.append((post_id, final_score))
+            
+            # 按热度排序
+            self.hot_items = sorted(items_with_scores, key=lambda x: x[1], reverse=True)
+            
+        finally:
+            cursor.close()
+            conn.close()
+    
+    def train(self):
+        """训练热度召回模型"""
+        print("开始计算热度分数...")
+        self._calculate_heat_scores()
+        print(f"热度召回模型训练完成,共{len(self.hot_items)}个物品")
+    
+    def recall(self, user_id: int, num_items: int = 50) -> List[Tuple[int, float]]:
+        """
+        为用户召回热门物品
+        
+        Args:
+            user_id: 用户ID
+            num_items: 召回物品数量
+            
+        Returns:
+            List of (item_id, score) tuples
+        """
+        # 如果尚未训练,先进行训练
+        if not hasattr(self, 'hot_items') or not self.hot_items:
+            self.train()
+        
+        # 获取用户已交互的物品,避免重复推荐
+        conn = pymysql.connect(**self.db_config)
+        try:
+            cursor = conn.cursor()
+            cursor.execute("""
+                SELECT DISTINCT post_id 
+                FROM behaviors 
+                WHERE user_id = %s AND type IN ('like', 'favorite', 'comment')
+            """, (user_id,))
+            
+            user_interacted_items = set(row[0] for row in cursor.fetchall())
+            
+        finally:
+            cursor.close()
+            conn.close()
+        
+        # 过滤掉用户已交互的物品
+        filtered_items = [
+            (item_id, score) for item_id, score in self.hot_items
+            if item_id not in user_interacted_items
+        ]
+        
+        # 如果过滤后没有足够的候选,放宽条件:只过滤强交互(like, favorite, comment)
+        if len(filtered_items) < num_items:
+            print(f"热度召回:过滤后候选不足({len(filtered_items)}),放宽过滤条件")
+            conn = pymysql.connect(**self.db_config)
+            try:
+                cursor = conn.cursor()
+                cursor.execute("""
+                    SELECT DISTINCT post_id 
+                    FROM behaviors 
+                    WHERE user_id = %s AND type IN ('like', 'favorite', 'comment')
+                """, (user_id,))
+                
+                strong_interacted_items = set(row[0] for row in cursor.fetchall())
+                
+            finally:
+                cursor.close()
+                conn.close()
+            
+            filtered_items = [
+                (item_id, score) for item_id, score in self.hot_items
+                if item_id not in strong_interacted_items
+            ]
+        
+        return filtered_items[:num_items]
+    
+    def get_top_hot_items(self, num_items: int = 100) -> List[Tuple[int, float]]:
+        """
+        获取全局热门物品(不考虑用户个性化)
+        
+        Args:
+            num_items: 返回物品数量
+            
+        Returns:
+            List of (item_id, score) tuples
+        """
+        return self.hot_items[:num_items]
diff --git a/rhj/backend/app/models/recall/multi_recall_manager.py b/rhj/backend/app/models/recall/multi_recall_manager.py
new file mode 100644
index 0000000..03cb3f8
--- /dev/null
+++ b/rhj/backend/app/models/recall/multi_recall_manager.py
@@ -0,0 +1,253 @@
+from typing import List, Tuple, Dict, Any
+import numpy as np
+from collections import defaultdict
+
+from .swing_recall import SwingRecall
+from .hot_recall import HotRecall
+from .ad_recall import AdRecall
+from .usercf_recall import UserCFRecall
+
+class MultiRecallManager:
+    """
+    多路召回管理器
+    整合Swing、热度召回、广告召回和UserCF等多种召回策略
+    """
+    
+    def __init__(self, db_config: dict, recall_config: dict = None):
+        """
+        初始化多路召回管理器
+        
+        Args:
+            db_config: 数据库配置
+            recall_config: 召回配置,包含各个召回器的参数和召回数量
+        """
+        self.db_config = db_config
+        
+        # 默认配置
+        default_config = {
+            'swing': {
+                'enabled': True,
+                'num_items': 15,
+                'alpha': 0.5
+            },
+            'hot': {
+                'enabled': True,
+                'num_items': 10
+            },
+            'ad': {
+                'enabled': True,
+                'num_items': 2
+            },
+            'usercf': {
+                'enabled': True,
+                'num_items': 10,
+                'min_common_items': 3,
+                'num_similar_users': 50
+            }
+        }
+        
+        # 合并用户配置
+        self.config = default_config
+        if recall_config:
+            for key, value in recall_config.items():
+                if key in self.config:
+                    self.config[key].update(value)
+                else:
+                    self.config[key] = value
+        
+        # 初始化各个召回器
+        self.recalls = {}
+        self._init_recalls()
+        
+    def _init_recalls(self):
+        """初始化各个召回器"""
+        if self.config['swing']['enabled']:
+            self.recalls['swing'] = SwingRecall(
+                self.db_config, 
+                alpha=self.config['swing']['alpha']
+            )
+            
+        if self.config['hot']['enabled']:
+            self.recalls['hot'] = HotRecall(self.db_config)
+            
+        if self.config['ad']['enabled']:
+            self.recalls['ad'] = AdRecall(self.db_config)
+            
+        if self.config['usercf']['enabled']:
+            self.recalls['usercf'] = UserCFRecall(
+                self.db_config,
+                min_common_items=self.config['usercf']['min_common_items']
+            )
+    
+    def train_all(self):
+        """训练所有召回器"""
+        print("开始训练多路召回模型...")
+        
+        for name, recall_model in self.recalls.items():
+            print(f"训练 {name} 召回器...")
+            recall_model.train()
+            
+        print("所有召回器训练完成!")
+    
+    def recall(self, user_id: int, total_items: int = 200) -> Tuple[List[int], List[float], Dict[str, List[Tuple[int, float]]]]:
+        """
+        执行多路召回
+        
+        Args:
+            user_id: 用户ID
+            total_items: 总召回物品数量
+            
+        Returns:
+            Tuple containing:
+            - List of item IDs
+            - List of scores
+            - Dict of recall results by source
+        """
+        recall_results = {}
+        all_candidates = defaultdict(list)  # item_id -> [(source, score), ...]
+        
+        # 执行各路召回
+        for source, recall_model in self.recalls.items():
+            if not self.config[source]['enabled']:
+                continue
+                
+            num_items = self.config[source]['num_items']
+            
+            # 特殊处理UserCF的参数
+            if source == 'usercf':
+                items_scores = recall_model.recall(
+                    user_id, 
+                    num_items=num_items,
+                    num_similar_users=self.config[source]['num_similar_users']
+                )
+            else:
+                items_scores = recall_model.recall(user_id, num_items=num_items)
+            
+            recall_results[source] = items_scores
+            
+            # 收集候选物品
+            for item_id, score in items_scores:
+                all_candidates[item_id].append((source, score))
+        
+        # 融合多路召回结果
+        final_candidates = self._merge_recall_results(all_candidates, total_items)
+        
+        # 分离item_ids和scores
+        item_ids = [item_id for item_id, _ in final_candidates]
+        scores = [score for _, score in final_candidates]
+        
+        return item_ids, scores, recall_results
+    
+    def _merge_recall_results(self, all_candidates: Dict[int, List[Tuple[str, float]]], 
+                            total_items: int) -> List[Tuple[int, float]]:
+        """
+        融合多路召回结果
+        
+        Args:
+            all_candidates: 所有候选物品及其来源和分数
+            total_items: 最终返回的物品数量
+            
+        Returns:
+            List of (item_id, final_score) tuples
+        """
+        # 定义各召回源的权重
+        source_weights = {
+            'swing': 0.3,
+            'hot': 0.2,
+            'ad': 0.1,
+            'usercf': 0.4
+        }
+        
+        final_scores = []
+        
+        for item_id, source_scores in all_candidates.items():
+            # 计算加权平均分数
+            weighted_score = 0.0
+            total_weight = 0.0
+            
+            for source, score in source_scores:
+                weight = source_weights.get(source, 0.1)
+                weighted_score += weight * score
+                total_weight += weight
+            
+            # 归一化
+            if total_weight > 0:
+                final_score = weighted_score / total_weight
+            else:
+                final_score = 0.0
+            
+            # 多样性奖励:如果物品来自多个召回源,给予额外分数
+            diversity_bonus = len(source_scores) * 0.1
+            final_score += diversity_bonus
+            
+            final_scores.append((item_id, final_score))
+        
+        # 按最终分数排序
+        final_scores.sort(key=lambda x: x[1], reverse=True)
+        
+        return final_scores[:total_items]
+    
+    def get_recall_stats(self, user_id: int) -> Dict[str, Any]:
+        """
+        获取召回统计信息
+        
+        Args:
+            user_id: 用户ID
+            
+        Returns:
+            召回统计字典
+        """
+        stats = {
+            'user_id': user_id,
+            'enabled_recalls': list(self.recalls.keys()),
+            'config': self.config
+        }
+        
+        # 获取各召回器的统计信息
+        if 'usercf' in self.recalls:
+            try:
+                user_profile = self.recalls['usercf'].get_user_profile(user_id)
+                stats['user_profile'] = user_profile
+                
+                neighbors = self.recalls['usercf'].get_user_neighbors(user_id, 5)
+                stats['similar_users'] = neighbors
+            except:
+                pass
+        
+        return stats
+    
+    def update_config(self, new_config: dict):
+        """
+        更新召回配置
+        
+        Args:
+            new_config: 新的配置字典
+        """
+        for key, value in new_config.items():
+            if key in self.config:
+                self.config[key].update(value)
+            else:
+                self.config[key] = value
+        
+        # 重新初始化召回器
+        self._init_recalls()
+    
+    def get_recall_breakdown(self, user_id: int) -> Dict[str, int]:
+        """
+        获取各召回源的物品数量分布
+        
+        Args:
+            user_id: 用户ID
+            
+        Returns:
+            各召回源的物品数量字典
+        """
+        breakdown = {}
+        
+        for source in self.recalls.keys():
+            if self.config[source]['enabled']:
+                breakdown[source] = self.config[source]['num_items']
+            else:
+                breakdown[source] = 0
+        
+        return breakdown
diff --git a/rhj/backend/app/models/recall/swing_recall.py b/rhj/backend/app/models/recall/swing_recall.py
new file mode 100644
index 0000000..bf7fdd6
--- /dev/null
+++ b/rhj/backend/app/models/recall/swing_recall.py
@@ -0,0 +1,126 @@
+import numpy as np
+import pymysql
+from collections import defaultdict
+import math
+from typing import List, Tuple, Dict
+
+class SwingRecall:
+    """
+    Swing召回算法实现
+    基于物品相似度的协同过滤算法,能够有效处理热门物品的问题
+    """
+    
+    def __init__(self, db_config: dict, alpha: float = 0.5):
+        """
+        初始化Swing召回模型
+        
+        Args:
+            db_config: 数据库配置
+            alpha: 控制热门物品惩罚的参数,值越大惩罚越强
+        """
+        self.db_config = db_config
+        self.alpha = alpha
+        self.item_similarity = {}
+        self.user_items = defaultdict(set)
+        self.item_users = defaultdict(set)
+        
+    def _get_interaction_data(self):
+        """获取用户-物品交互数据"""
+        conn = pymysql.connect(**self.db_config)
+        try:
+            cursor = conn.cursor()
+            # 获取用户行为数据(点赞、收藏、评论等)
+            cursor.execute("""
+                SELECT DISTINCT user_id, post_id
+                FROM behaviors
+                WHERE type IN ('like', 'favorite', 'comment')
+            """)
+            interactions = cursor.fetchall()
+            
+            for user_id, post_id in interactions:
+                self.user_items[user_id].add(post_id)
+                self.item_users[post_id].add(user_id)
+                
+        finally:
+            cursor.close()
+            conn.close()
+    
+    def _calculate_swing_similarity(self):
+        """计算Swing相似度矩阵"""
+        print("开始计算Swing相似度...")
+        
+        # 获取所有物品对
+        items = list(self.item_users.keys())
+        
+        for i, item_i in enumerate(items):
+            if i % 100 == 0:
+                print(f"处理进度: {i}/{len(items)}")
+                
+            self.item_similarity[item_i] = {}
+            
+            for item_j in items[i+1:]:
+                # 获取同时交互过两个物品的用户
+                common_users = self.item_users[item_i] & self.item_users[item_j]
+                
+                if len(common_users) < 2:  # 需要至少2个共同用户
+                    similarity = 0.0
+                else:
+                    # 计算Swing相似度
+                    similarity = 0.0
+                    for u in common_users:
+                        for v in common_users:
+                            if u != v:
+                                # Swing算法的核心公式
+                                swing_weight = 1.0 / (self.alpha + len(self.user_items[u] & self.user_items[v]))
+                                similarity += swing_weight
+                    
+                    # 归一化
+                    similarity = similarity / (len(common_users) * (len(common_users) - 1))
+                
+                self.item_similarity[item_i][item_j] = similarity
+                # 对称性
+                if item_j not in self.item_similarity:
+                    self.item_similarity[item_j] = {}
+                self.item_similarity[item_j][item_i] = similarity
+        
+        print("Swing相似度计算完成")
+    
+    def train(self):
+        """训练Swing模型"""
+        self._get_interaction_data()
+        self._calculate_swing_similarity()
+    
+    def recall(self, user_id: int, num_items: int = 50) -> List[Tuple[int, float]]:
+        """
+        为用户召回相似物品
+        
+        Args:
+            user_id: 用户ID
+            num_items: 召回物品数量
+            
+        Returns:
+            List of (item_id, score) tuples
+        """
+        # 如果尚未训练,先进行训练
+        if not hasattr(self, 'item_similarity') or not self.item_similarity:
+            self.train()
+        
+        if user_id not in self.user_items:
+            return []
+        
+        # 获取用户历史交互的物品
+        user_interacted_items = self.user_items[user_id]
+        
+        # 计算候选物品的分数
+        candidate_scores = defaultdict(float)
+        
+        for item_i in user_interacted_items:
+            if item_i in self.item_similarity:
+                for item_j, similarity in self.item_similarity[item_i].items():
+                    # 排除用户已经交互过的物品
+                    if item_j not in user_interacted_items:
+                        candidate_scores[item_j] += similarity
+        
+        # 按分数排序并返回top-N
+        sorted_candidates = sorted(candidate_scores.items(), key=lambda x: x[1], reverse=True)
+        return sorted_candidates[:num_items]
diff --git a/rhj/backend/app/models/recall/usercf_recall.py b/rhj/backend/app/models/recall/usercf_recall.py
new file mode 100644
index 0000000..d75e6d8
--- /dev/null
+++ b/rhj/backend/app/models/recall/usercf_recall.py
@@ -0,0 +1,235 @@
+import pymysql
+from typing import List, Tuple, Dict, Set
+from collections import defaultdict
+import math
+import numpy as np
+
class UserCFRecall:
    """
    User-based Collaborative Filtering (UserCF) recall.

    Pipeline:
      1. `_get_user_item_interactions` loads weighted implicit feedback from
         the `behaviors` table into user->items / item->users sets.
      2. `_calculate_user_similarity` computes pairwise cosine similarity
         between users over their co-interacted item sets.
      3. `recall` scores unseen items by summing the similarities of the
         nearest neighbors who interacted with them.
    """

    def __init__(self, db_config: dict, min_common_items: int = 3):
        """
        Initialize the UserCF recall model.

        Args:
            db_config: pymysql connection kwargs.
            min_common_items: minimum number of co-interacted items required
                for a user pair to receive a non-zero similarity.
        """
        self.db_config = db_config
        self.min_common_items = min_common_items
        self.user_items = defaultdict(set)   # user_id -> {post_id, ...}
        self.item_users = defaultdict(set)   # post_id -> {user_id, ...}
        self.user_similarity = {}            # user_id -> {other_user_id: sim}

    def _get_user_item_interactions(self):
        """Load user-item interactions from the behaviors table.

        Each (user, post) pair receives the weighted count of its
        like/favorite/comment/view behaviors; pairs whose score reaches the
        threshold are recorded as binary interactions in `user_items` and
        `item_users`.
        """
        conn = pymysql.connect(**self.db_config)
        cursor = None  # guard: may fail before the cursor exists
        try:
            cursor = conn.cursor()

            # Aggregate behavior counts per (user, post, behavior type).
            cursor.execute("""
                SELECT user_id, post_id, type, COUNT(*) as count
                FROM behaviors
                WHERE type IN ('like', 'favorite', 'comment', 'view')
                GROUP BY user_id, post_id, type
            """)
            interactions = cursor.fetchall()

            # Stronger signals (comment/favorite) weigh more than passive views.
            behavior_weights = {
                'like': 1.0,
                'favorite': 2.0,
                'comment': 3.0,
                'view': 0.1
            }
            threshold = 1.0  # minimum weighted score to count as an interaction

            # Weighted interaction matrix: user -> item -> score.
            user_item_scores = defaultdict(lambda: defaultdict(float))
            for user_id, post_id, behavior_type, count in interactions:
                weight = behavior_weights.get(behavior_type, 1.0)
                user_item_scores[user_id][post_id] += weight * count

            # Binarize: keep only pairs whose weighted score passes the threshold.
            for user_id, items in user_item_scores.items():
                for item_id, score in items.items():
                    if score >= threshold:
                        self.user_items[user_id].add(item_id)
                        self.item_users[item_id].add(user_id)

        finally:
            if cursor is not None:
                cursor.close()
            conn.close()

    def _calculate_user_similarity(self):
        """Compute the symmetric user-user cosine similarity matrix.

        similarity(u, v) = |I_u ∩ I_v| / sqrt(|I_u| * |I_v|), forced to 0.0
        when the pair shares fewer than `min_common_items` items.

        Bug fix: the row for user_i is obtained with setdefault() instead of
        being reassigned to {}, which previously wiped out the symmetric
        entries written by earlier iterations (every user but the first lost
        its similarities to earlier users).
        """
        print("开始计算用户相似度...")

        users = list(self.user_items.keys())
        total_pairs = len(users) * (len(users) - 1) // 2
        processed = 0

        for i, user_i in enumerate(users):
            # setdefault preserves symmetric entries already stored for user_i.
            row_i = self.user_similarity.setdefault(user_i, {})
            items_i = self.user_items[user_i]

            for user_j in users[i+1:]:
                processed += 1
                if processed % 10000 == 0:
                    print(f"处理进度: {processed}/{total_pairs}")

                items_j = self.user_items[user_j]
                common_items = items_i & items_j

                if len(common_items) < self.min_common_items:
                    similarity = 0.0
                else:
                    # Cosine similarity over binary interaction vectors.
                    denominator = math.sqrt(len(items_i) * len(items_j))
                    similarity = len(common_items) / denominator if denominator > 0 else 0.0

                row_i[user_j] = similarity
                # Mirror the value so lookups work from either side.
                self.user_similarity.setdefault(user_j, {})[user_i] = similarity

        print("用户相似度计算完成")

    def train(self):
        """Train the UserCF model: load interactions, then compute similarities."""
        self._get_user_item_interactions()
        self._calculate_user_similarity()

    def recall(self, user_id: int, num_items: int = 50, num_similar_users: int = 50) -> List[Tuple[int, float]]:
        """
        Recall items that similar users interacted with.

        Args:
            user_id: target user ID
            num_items: maximum number of items to return
            num_similar_users: number of nearest neighbors to consider

        Returns:
            List of (item_id, score) tuples sorted by score descending;
            empty list when the user is unknown.
        """
        # Train lazily on first use.
        if not hasattr(self, 'user_similarity') or not self.user_similarity:
            self.train()

        if user_id not in self.user_similarity or user_id not in self.user_items:
            return []

        # Top-K most similar users.
        similar_users = sorted(
            self.user_similarity[user_id].items(),
            key=lambda x: x[1],
            reverse=True
        )[:num_similar_users]

        # Items the target user has already interacted with (excluded below).
        user_interacted_items = self.user_items[user_id]

        # Score each unseen item by the summed similarity of the neighbors
        # who interacted with it.
        candidate_scores = defaultdict(float)
        for similar_user_id, similarity in similar_users:
            if similarity <= 0:
                continue  # non-positive neighbors carry no signal

            for item_id in self.user_items[similar_user_id]:
                if item_id not in user_interacted_items:
                    candidate_scores[item_id] += similarity

        # Rank by accumulated score and truncate to top-N.
        sorted_candidates = sorted(candidate_scores.items(), key=lambda x: x[1], reverse=True)
        return sorted_candidates[:num_items]

    def get_user_neighbors(self, user_id: int, num_neighbors: int = 10) -> List[Tuple[int, float]]:
        """
        Return the user's most similar neighbors.

        Args:
            user_id: user ID
            num_neighbors: number of neighbors to return

        Returns:
            List of (neighbor_user_id, similarity) tuples, most similar first;
            empty list when the user has no similarity row.
        """
        if user_id not in self.user_similarity:
            return []

        return sorted(
            self.user_similarity[user_id].items(),
            key=lambda x: x[1],
            reverse=True
        )[:num_neighbors]

    def get_user_profile(self, user_id: int) -> Dict:
        """
        Build a lightweight profile for a user from the database.

        Args:
            user_id: user ID

        Returns:
            Dict with the user's id, number of distinct interacted items,
            tag-name -> count preferences, and behavior-type -> count stats;
            empty dict when the user is unknown or has no interactions.
        """
        if user_id not in self.user_items:
            return {}

        conn = pymysql.connect(**self.db_config)
        cursor = None  # guard: may fail before the cursor exists
        try:
            cursor = conn.cursor()

            user_item_list = list(self.user_items[user_id])
            if not user_item_list:
                return {}

            # Tag preference counts over the user's interacted posts
            # (parameterized IN clause — one %s per item id).
            format_strings = ','.join(['%s'] * len(user_item_list))
            cursor.execute(f"""
                SELECT t.name, COUNT(*) as count
                FROM post_tags pt
                JOIN tags t ON pt.tag_id = t.id
                WHERE pt.post_id IN ({format_strings})
                GROUP BY t.name
                ORDER BY count DESC
            """, tuple(user_item_list))
            tag_preferences = cursor.fetchall()

            # Behavior counts per type for this user.
            cursor.execute("""
                SELECT type, COUNT(*) as count
                FROM behaviors
                WHERE user_id = %s
                GROUP BY type
            """, (user_id,))
            behavior_stats = cursor.fetchall()

            return {
                'user_id': user_id,
                # Number of distinct items the user interacted with.
                'total_interactions': len(self.user_items[user_id]),
                'tag_preferences': dict(tag_preferences),
                'behavior_stats': dict(behavior_stats)
            }

        finally:
            if cursor is not None:
                cursor.close()
            conn.close()
diff --git a/rhj/backend/app/models/recommend/LightGCN.py b/rhj/backend/app/models/recommend/LightGCN.py
new file mode 100644
index 0000000..38b1732
--- /dev/null
+++ b/rhj/backend/app/models/recommend/LightGCN.py
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import scipy.sparse as sp
+import math
+import networkx as nx
+import random
+from copy import deepcopy
+from app.utils.parse_args import args
+from app.models.recommend.base_model import BaseModel
+from app.models.recommend.operators import EdgelistDrop
+from app.models.recommend.operators import scatter_add, scatter_sum
+
+
# Shorthand for the Xavier/Glorot uniform initializer used on the embedding tables below.
init = nn.init.xavier_uniform_
+
class LightGCN(BaseModel):
    """
    LightGCN collaborative-filtering model (He et al., SIGIR 2020 style).

    Propagates user/item embeddings over the bi-normalized user-item
    adjacency with linear (weight-free) graph convolutions, and trains with
    a BPR pairwise ranking loss plus L2 regularization.

    Phases:
        'pretrain' / 'vanilla' / 'for_tune': fresh Xavier-initialized embeddings.
        'finetune': embeddings warm-started from `pretrained_model.generate()`.
        'continue_tune': re-initialized, expected to be overwritten by a
            loaded state dict.
    """
    def __init__(self, dataset, pretrained_model=None, phase='pretrain'):
        super().__init__(dataset)
        # Symmetrically normalized adjacency over the joint user+item node set.
        self.adj = self._make_binorm_adj(dataset.graph)
        self.edges = self.adj._indices().t()   # (E, 2) [src, dst] node indices
        self.edge_norm = self.adj._values()    # per-edge normalization weights

        self.phase = phase

        # Identity gate over embeddings; may be replaced externally to modulate them.
        self.emb_gate = lambda x: x

        if self.phase == 'pretrain' or self.phase == 'vanilla' or self.phase == 'for_tune':
            self.user_embedding = nn.Parameter(init(torch.empty(self.num_users, self.emb_size)))
            self.item_embedding = nn.Parameter(init(torch.empty(self.num_items, self.emb_size)))


        elif self.phase == 'finetune':
            # Warm-start from the embeddings produced by a pretrained model.
            pre_user_emb, pre_item_emb = pretrained_model.generate()
            self.user_embedding = nn.Parameter(pre_user_emb).requires_grad_(True)
            self.item_embedding = nn.Parameter(pre_item_emb).requires_grad_(True)

        elif self.phase == 'continue_tune':
            # re-initialize for loading state dict
            self.user_embedding = nn.Parameter(init(torch.empty(self.num_users, self.emb_size)))
            self.item_embedding = nn.Parameter(init(torch.empty(self.num_items, self.emb_size)))

        self.edge_dropout = EdgelistDrop()

    def _agg(self, all_emb, edges, edge_norm):
        """One propagation layer: norm-weighted sum of source embeddings into targets."""
        src_emb = all_emb[edges[:, 0]]

        # bi-norm: scale each message by its edge's normalization coefficient
        src_emb = src_emb * edge_norm.unsqueeze(1)

        # conv: sum incoming messages per destination node
        dst_emb = scatter_sum(src_emb, edges[:, 1], dim=0, dim_size=self.num_users+self.num_items)
        return dst_emb
    
    def _edge_binorm(self, edges):
        """Recompute 1/sqrt(deg_u * deg_i) weights from an edge list.

        NOTE(review): unused in this class. It sizes the item-degree table
        with num_items, which assumes edges[:, 1] holds un-offset item ids —
        the adjacency edges used elsewhere offset items by num_users, so
        verify index conventions before reusing.
        """
        user_degs = scatter_add(torch.ones_like(edges[:, 0]), edges[:, 0], dim=0, dim_size=self.num_users)
        user_degs = user_degs[edges[:, 0]]
        item_degs = scatter_add(torch.ones_like(edges[:, 1]), edges[:, 1], dim=0, dim_size=self.num_items)
        item_degs = item_degs[edges[:, 1]]
        norm = torch.pow(user_degs, -0.5) * torch.pow(item_degs, -0.5)
        return norm

    def forward(self, edges, edge_norm, return_layers=False):
        """Run `args.num_layers` propagation steps.

        Returns (user_emb, item_emb) summed over all layers, or the
        per-layer lists when return_layers is True.
        """
        all_emb = torch.cat([self.user_embedding, self.item_embedding], dim=0)
        all_emb = self.emb_gate(all_emb)
        res_emb = [all_emb]
        for l in range(args.num_layers):
            all_emb = self._agg(res_emb[-1], edges, edge_norm)
            res_emb.append(all_emb)
        if not return_layers:
            # Layer combination: plain sum (not mean) of all layer outputs.
            res_emb = sum(res_emb)
            user_res_emb, item_res_emb = res_emb.split([self.num_users, self.num_items], dim=0)
        else:
            user_res_emb, item_res_emb = [], []
            for emb in res_emb:
                u_emb, i_emb = emb.split([self.num_users, self.num_items], dim=0)
                user_res_emb.append(u_emb)
                item_res_emb.append(i_emb)
        return user_res_emb, item_res_emb
    
    def cal_loss(self, batch_data):
        """BPR + L2 loss on a (users, pos_items, neg_items) batch.

        Applies edge dropout to the cached graph before the forward pass.
        Returns (total_loss, {"rec_loss": ..., "reg_loss": ...}).
        """
        edges, dropout_mask = self.edge_dropout(self.edges, 1-args.edge_dropout, return_mask=True)
        edge_norm = self.edge_norm[dropout_mask]

        # forward over the edge-dropped graph
        users, pos_items, neg_items = batch_data
        user_emb, item_emb = self.forward(edges, edge_norm)
        batch_user_emb = user_emb[users]
        pos_item_emb = item_emb[pos_items]
        neg_item_emb = item_emb[neg_items]
        rec_loss = self._bpr_loss(batch_user_emb, pos_item_emb, neg_item_emb)
        reg_loss = args.weight_decay * self._reg_loss(users, pos_items, neg_items)

        loss = rec_loss + reg_loss
        loss_dict = {
            "rec_loss": rec_loss.item(),
            "reg_loss": reg_loss.item(),
        }
        return loss, loss_dict
    
    @torch.no_grad()
    def generate(self, return_layers=False):
        """Inference-time embeddings computed on the full (undropped) graph."""
        return self.forward(self.edges, self.edge_norm, return_layers=return_layers)
    
    @torch.no_grad()
    def generate_lgn(self, return_layers=False):
        """Alias of `generate`, kept for interface compatibility."""
        return self.forward(self.edges, self.edge_norm, return_layers=return_layers)
    
    @torch.no_grad()
    def rating(self, user_emb, item_emb):
        """Score every (user, item) pair via dot product."""
        return torch.matmul(user_emb, item_emb.t())
    
    def _reg_loss(self, users, pos_items, neg_items):
        """L2 regularization on the layer-0 embeddings of the batch, averaged per user."""
        u_emb = self.user_embedding[users]
        pos_i_emb = self.item_embedding[pos_items]
        neg_i_emb = self.item_embedding[neg_items]
        reg_loss = (1/2)*(u_emb.norm(2).pow(2) +
                          pos_i_emb.norm(2).pow(2) +
                          neg_i_emb.norm(2).pow(2))/float(len(users))
        return reg_loss
diff --git a/rhj/backend/app/models/recommend/LightGCN_pretrained.pt b/rhj/backend/app/models/recommend/LightGCN_pretrained.pt
new file mode 100644
index 0000000..825e0e2
--- /dev/null
+++ b/rhj/backend/app/models/recommend/LightGCN_pretrained.pt
Binary files differ
diff --git a/rhj/backend/app/models/recommend/__pycache__/LightGCN.cpython-312.pyc b/rhj/backend/app/models/recommend/__pycache__/LightGCN.cpython-312.pyc
new file mode 100644
index 0000000..c87435f
--- /dev/null
+++ b/rhj/backend/app/models/recommend/__pycache__/LightGCN.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recommend/__pycache__/base_model.cpython-312.pyc b/rhj/backend/app/models/recommend/__pycache__/base_model.cpython-312.pyc
new file mode 100644
index 0000000..b9d8c72
--- /dev/null
+++ b/rhj/backend/app/models/recommend/__pycache__/base_model.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recommend/__pycache__/lightgcn_scorer.cpython-312.pyc b/rhj/backend/app/models/recommend/__pycache__/lightgcn_scorer.cpython-312.pyc
new file mode 100644
index 0000000..b0887a9
--- /dev/null
+++ b/rhj/backend/app/models/recommend/__pycache__/lightgcn_scorer.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recommend/__pycache__/operators.cpython-312.pyc b/rhj/backend/app/models/recommend/__pycache__/operators.cpython-312.pyc
new file mode 100644
index 0000000..13bb375
--- /dev/null
+++ b/rhj/backend/app/models/recommend/__pycache__/operators.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/models/recommend/base_model.py b/rhj/backend/app/models/recommend/base_model.py
new file mode 100644
index 0000000..6c59aa6
--- /dev/null
+++ b/rhj/backend/app/models/recommend/base_model.py
@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+from app.utils.parse_args import args
+from scipy.sparse import csr_matrix
+import scipy.sparse as sp
+import numpy as np
+import torch.nn.functional as F
+
+
class BaseModel(nn.Module):
    """
    Shared base class for graph-based recommenders.

    Holds user/item counts and the embedding size, provides adjacency
    construction helpers and common ranking/contrastive losses
    (BPR, NCE, InfoNCE).
    """
    def __init__(self, dataloader):
        super(BaseModel, self).__init__()
        self.num_users = dataloader.num_users
        self.num_items = dataloader.num_items
        self.emb_size = args.emb_size

    def forward(self):
        # To be overridden by subclasses.
        pass

    def cal_loss(self, batch_data):
        # To be overridden by subclasses.
        pass

    def _check_inf(self, loss, pos_score, neg_score, edge_weight):
        """Abort with diagnostics if the loss contains inf/NaN entries."""
        # find indices of non-finite loss entries
        inf_idx = torch.isinf(loss) | torch.isnan(loss)
        if inf_idx.any():
            print("find inf in loss")
            # edge_weight may be an int sentinel (e.g. 0/1) when unweighted
            if type(edge_weight) != int:
                print(edge_weight[inf_idx])
            print(f"pos_score: {pos_score[inf_idx]}")
            print(f"neg_score: {neg_score[inf_idx]}")
            raise ValueError("find inf in loss")

    def _make_binorm_adj(self, mat):
        """Build the symmetrically normalized joint adjacency D^-1/2 A D^-1/2.

        `mat` is the (num_users x num_items) interaction matrix; the result
        is a torch sparse tensor over the combined (users + items) node set.
        """
        a = csr_matrix((self.num_users, self.num_users))
        b = csr_matrix((self.num_items, self.num_items))
        # Block matrix [[0, R], [R^T, 0]] over users+items.
        mat = sp.vstack(
            [sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
        mat = (mat != 0) * 1.0
        # mat = (mat + sp.eye(mat.shape[0])) * 1.0# MARK
        degree = np.array(mat.sum(axis=-1))
        d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0  # isolated nodes: avoid inf weights
        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
        mat = mat.dot(d_inv_sqrt_mat).transpose().dot(
            d_inv_sqrt_mat).tocoo()

        # make torch sparse tensor
        idxs = torch.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
        vals = torch.from_numpy(mat.data.astype(np.float32))
        shape = torch.Size(mat.shape)
        return torch.sparse.FloatTensor(idxs, vals, shape).to(args.device)
    
    def _make_binorm_adj_self_loop(self, mat):
        """Same as `_make_binorm_adj`, but adds self-loops before normalizing."""
        a = csr_matrix((self.num_users, self.num_users))
        b = csr_matrix((self.num_items, self.num_items))
        mat = sp.vstack(
            [sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
        mat = (mat != 0) * 1.0
        mat = (mat + sp.eye(mat.shape[0])) * 1.0 # self loop
        degree = np.array(mat.sum(axis=-1))
        d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
        mat = mat.dot(d_inv_sqrt_mat).transpose().dot(
            d_inv_sqrt_mat).tocoo()

        # make torch sparse tensor
        idxs = torch.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
        vals = torch.from_numpy(mat.data.astype(np.float32))
        shape = torch.Size(mat.shape)
        return torch.sparse.FloatTensor(idxs, vals, shape).to(args.device)


    def _sp_matrix_to_sp_tensor(self, sp_matrix):
        """Convert a scipy sparse matrix to a coalesced torch sparse tensor."""
        coo = sp_matrix.tocoo()
        indices = torch.LongTensor([coo.row, coo.col])
        values = torch.FloatTensor(coo.data)
        return torch.sparse.FloatTensor(indices, values, coo.shape).coalesce().to(args.device)

    def _bpr_loss(self, user_emb, pos_item_emb, neg_item_emb):
        """Bayesian Personalized Ranking loss: -log sigmoid(pos - neg), averaged."""
        pos_score = (user_emb * pos_item_emb).sum(dim=1)
        neg_score = (user_emb * neg_item_emb).sum(dim=1)
        # 1e-10 guards log(0) when the sigmoid saturates at zero
        loss = -torch.log(1e-10 + torch.sigmoid((pos_score - neg_score)))
        self._check_inf(loss, pos_score, neg_score, 0)
        return loss.mean()
    
    def _nce_loss(self, pos_score, neg_score, edge_weight=1):
        """Noise-contrastive loss over one positive score and a row of negatives."""
        numerator = torch.exp(pos_score)
        denominator = torch.exp(pos_score) + torch.exp(neg_score).sum(dim=1)
        loss = -torch.log(numerator/denominator) * edge_weight
        self._check_inf(loss, pos_score, neg_score, edge_weight)
        return loss.mean()
    
    def _infonce_loss(self, pos_1, pos_2, negs, tau):
        """InfoNCE contrastive loss with temperature `tau`.

        NOTE(review): relies on self.cl_mlp, which is not defined on
        BaseModel — subclasses are expected to provide it before calling.
        """
        pos_1 = self.cl_mlp(pos_1)
        pos_2 = self.cl_mlp(pos_2)
        negs = self.cl_mlp(negs)
        pos_1 = F.normalize(pos_1, dim=-1)
        pos_2 = F.normalize(pos_2, dim=-1)
        negs = F.normalize(negs, dim=-1)
        pos_score = torch.mul(pos_1, pos_2).sum(dim=1)
        # B, 1, E * B, E, N -> B, N
        neg_score = torch.bmm(pos_1.unsqueeze(1), negs.transpose(1, 2)).squeeze(1)
        # infonce loss
        numerator = torch.exp(pos_score / tau)
        denominator = torch.exp(pos_score / tau) + torch.exp(neg_score / tau).sum(dim=1)
        loss = -torch.log(numerator/denominator)
        self._check_inf(loss, pos_score, neg_score, 0)
        return loss.mean()
+    
\ No newline at end of file
diff --git a/rhj/backend/app/models/recommend/operators.py b/rhj/backend/app/models/recommend/operators.py
new file mode 100644
index 0000000..a508966
--- /dev/null
+++ b/rhj/backend/app/models/recommend/operators.py
@@ -0,0 +1,52 @@
+import torch
+from typing import Optional, Tuple
+from torch import nn
+
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    """Make `src` broadcastable to the shape of `other`.

    A 1-D `src` is aligned so its data sits on axis `dim`; singleton axes
    are then appended until the ranks match, and the result is expanded
    (as a view, no copy) to `other`'s full size.
    """
    # Normalize a negative axis to its positive equivalent.
    if dim < 0:
        dim += other.dim()
    # Prepend singleton axes so a 1-D source lands on axis `dim`.
    if src.dim() == 1:
        for _ in range(dim):
            src = src.unsqueeze(0)
    # Append trailing singleton axes until ranks match.
    while src.dim() < other.dim():
        src = src.unsqueeze(-1)
    return src.expand(other.size())
+
def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Sum entries of `src` grouped by `index` along `dim`.

    Equivalent to out[index[i]] += src[i] along `dim` for every position i.

    Args:
        src: source values to accumulate.
        index: target positions; broadcast to `src`'s shape via `broadcast`.
        dim: axis along which to scatter.
        out: optional pre-allocated accumulator; mutated in place if given.
        dim_size: size of `dim` in the allocated output; inferred from the
            largest index (or 0 for an empty index) when omitted.

    Returns:
        The accumulator tensor (`out` if supplied, else freshly allocated).
    """
    index = broadcast(index, src, dim)
    if out is None:
        # Allocate the accumulator on src's dtype/device.
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
    # Single exit point: the original duplicated this return in both branches.
    return out.scatter_add_(dim, index, src)
+
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Alias of `scatter_sum`, kept for API compatibility."""
    return scatter_sum(src, index, dim=dim, out=out, dim_size=dim_size)
+
+
class EdgelistDrop(nn.Module):
    """Randomly drop edges from an (E, 2) edge list.

    Each edge is kept independently with probability `keep_rate`.
    """

    def __init__(self):
        super(EdgelistDrop, self).__init__()

    def forward(self, edgeList, keep_rate, return_mask=False):
        """Apply edge dropout.

        Args:
            edgeList: (E, 2) tensor of edges.
            keep_rate: probability in [0, 1] of keeping each edge.
            return_mask: if True, also return the (E,) boolean keep-mask.

        Returns:
            The kept edges, or (kept_edges, mask) when return_mask is True.
        """
        edgeNum = edgeList.size(0)
        if keep_rate == 1.0:
            # Fast path: nothing dropped. Bug fix: honor return_mask here
            # too — previously this branch always returned a tuple, so
            # return_mask=False callers received (edges, mask).
            mask = torch.ones(edgeNum, dtype=torch.bool, device=edgeList.device)
            return (edgeList, mask) if return_mask else edgeList
        # Bernoulli(keep_rate) per edge: rand + keep_rate floors to 1
        # with probability keep_rate. Mask is created on edgeList's device
        # so indexing works for CUDA tensors as well.
        mask = (torch.rand(edgeNum, device=edgeList.device) + keep_rate).floor().type(torch.bool)
        newEdgeList = edgeList[mask, :]
        if return_mask:
            return newEdgeList, mask
        else:
            return newEdgeList