import random
from contextlib import closing
from typing import List, Tuple, Dict

import pymysql
class AdRecall:
    """Recall (candidate generation) for advertisement posts.

    Published advertisement posts are read from the ``posts`` table, scored
    from their heat, number of interacting users and freshness, and cached in
    ``self.ad_items`` as ``(post_id, score)`` tuples sorted by descending
    score. ``recall`` then filters and re-ranks that cached list per user.

    NOTE(review): rows are unpacked positionally, so the connection is
    assumed to use pymysql's default tuple cursor -- confirm ``db_config``
    does not set a dict cursor class.
    """

    # Weights of the static ad score (see ``_compute_ad_score``).
    _HEAT_WEIGHT = 0.6
    _INTERACTION_WEIGHT = 0.3
    _FRESHNESS_WEIGHT = 100
    # Ads older than this many days receive no freshness bonus.
    _FRESHNESS_WINDOW_DAYS = 30
    # Maximum number of ad ids placed in a single ``IN (...)`` tag query.
    _TAG_QUERY_BATCH_SIZE = 500

    def __init__(self, db_config: dict):
        """Initialise the ad recall model.

        Args:
            db_config: Keyword arguments forwarded to ``pymysql.connect``.
        """
        self.db_config = db_config
        self.ad_items: List[Tuple[int, float]] = []

    @classmethod
    def _compute_ad_score(cls, heat, interaction_count, days_since_created) -> float:
        """Return the static score of one ad.

        score = heat * 0.6 + interaction_count * 0.3 + freshness * 100, where
        freshness falls linearly from 1 (created today) to 0 (30 days old or
        older).

        ``None`` inputs count as 0. Inputs are coerced to ``float`` because
        the driver may hand back ``decimal.Decimal`` values, which cannot be
        multiplied by a float. A negative age (future-dated ``created_at``)
        is clamped to 0 so the freshness bonus never exceeds its maximum.
        """
        heat_value = float(heat or 0)
        interactions = float(interaction_count or 0)
        age_days = max(0.0, float(days_since_created or 0))

        window = cls._FRESHNESS_WINDOW_DAYS
        freshness_bonus = max(0.0, window - age_days) / float(window)

        return (
            heat_value * cls._HEAT_WEIGHT
            + interactions * cls._INTERACTION_WEIGHT
            + freshness_bonus * cls._FRESHNESS_WEIGHT
        )

    def _get_ad_items(self):
        """Load all published ads, score them and cache them in ``self.ad_items``.

        The cache is replaced only after the query succeeded; on a database
        error the previous contents of ``self.ad_items`` are left untouched
        and the exception propagates.
        """
        query = """
            SELECT
                p.id,
                p.heat,
                p.created_at,
                COUNT(DISTINCT b.user_id) as interaction_count,
                DATEDIFF(NOW(), p.created_at) as days_since_created
            FROM posts p
            LEFT JOIN behaviors b ON p.id = b.post_id
            WHERE p.is_advertisement = 1 AND p.status = 'published'
            GROUP BY p.id, p.heat, p.created_at
            ORDER BY p.heat DESC, p.created_at DESC
        """

        # closing() releases the cursor and the connection even when a later
        # step fails, and never touches a cursor that was not created.
        with closing(pymysql.connect(**self.db_config)) as conn, \
                closing(conn.cursor()) as cursor:
            cursor.execute(query)
            rows = cursor.fetchall()

        scored_items = [
            (post_id, self._compute_ad_score(heat, interaction_count, days_since_created))
            for post_id, heat, _created_at, interaction_count, days_since_created in rows
        ]

        # sorted() is stable, so ads with equal scores keep the SQL order
        # (heat DESC, created_at DESC).
        self.ad_items = sorted(scored_items, key=lambda item: item[1], reverse=True)

    def train(self):
        """"Train" the model by (re)loading and scoring the ads from the database."""
        print("开始获取广告物品...")
        self._get_ad_items()
        print(f"广告召回模型训练完成,共 {len(self.ad_items)} 个广告物品")

    def recall(self, user_id: int, num_items: int = 10) -> List[Tuple[int, float]]:
        """Recall ads for a user.

        Ads the user already interacted with (like / favorite / comment /
        view) are filtered out. If that leaves nothing, the top-scored ads
        are used instead. When the user has interest tags, the remaining ads
        are re-ranked by tag overlap.

        Args:
            user_id: Id of the user to recall ads for.
            num_items: Maximum number of ads to return. Values <= 0 yield an
                empty list without touching the database.

        Returns:
            List of ``(item_id, score)`` tuples, best first.
        """
        if num_items <= 0:
            return []

        # Lazily train on first use (also retried while no ads are known).
        if not getattr(self, "ad_items", None):
            self.train()

        with closing(pymysql.connect(**self.db_config)) as conn, \
                closing(conn.cursor()) as cursor:
            # Ads the user has already interacted with.
            cursor.execute("""
                SELECT DISTINCT b.post_id
                FROM behaviors b
                JOIN posts p ON b.post_id = p.id
                WHERE b.user_id = %s AND p.is_advertisement = 1
                AND b.type IN ('like', 'favorite', 'comment', 'view')
            """, (user_id,))
            user_interacted_ads = {row[0] for row in cursor.fetchall()}

            # The user's ten most frequent tags, derived from active
            # behaviours (views are deliberately not counted here).
            cursor.execute("""
                SELECT t.name, COUNT(*) as count
                FROM behaviors b
                JOIN posts p ON b.post_id = p.id
                JOIN post_tags pt ON p.id = pt.post_id
                JOIN tags t ON pt.tag_id = t.id
                WHERE b.user_id = %s AND b.type IN ('like', 'favorite', 'comment')
                GROUP BY t.name
                ORDER BY count DESC
                LIMIT 10
            """, (user_id,))
            user_interest_tags = {row[0] for row in cursor.fetchall()}

        filtered_ads = [
            (item_id, score)
            for item_id, score in self.ad_items
            if item_id not in user_interacted_ads
        ]

        # Every ad was already seen: fall back to the best-scored ads, the
        # user may still be interested in them again.
        if not filtered_ads and self.ad_items:
            print(f"用户 {user_id} 已与所有广告交互,返回评分最高的广告")
            filtered_ads = self.ad_items[:num_items]

        if user_interest_tags and filtered_ads:
            filtered_ads = self._personalize_ads(filtered_ads, user_interest_tags)

        return filtered_ads[:num_items]

    def _load_ad_tags(self, ad_ids) -> dict:
        """Return ``{post_id: set_of_tag_names}`` for the given ad ids.

        Tags are fetched in batches with ``IN (...)`` instead of one query
        per ad. Ads without tags are absent from the result.
        """
        tags_by_ad = {}
        unique_ids = list(dict.fromkeys(ad_ids))
        if not unique_ids:
            return tags_by_ad

        batch_size = self._TAG_QUERY_BATCH_SIZE
        with closing(pymysql.connect(**self.db_config)) as conn, \
                closing(conn.cursor()) as cursor:
            for start in range(0, len(unique_ids), batch_size):
                batch = unique_ids[start:start + batch_size]
                # Only "%s" placeholders are interpolated into the SQL text;
                # the ids themselves are passed as bound parameters.
                placeholders = ", ".join(["%s"] * len(batch))
                cursor.execute(f"""
                    SELECT pt.post_id, t.name
                    FROM post_tags pt
                    JOIN tags t ON pt.tag_id = t.id
                    WHERE pt.post_id IN ({placeholders})
                """, batch)
                for post_id, tag_name in cursor.fetchall():
                    tags_by_ad.setdefault(post_id, set()).add(tag_name)

        return tags_by_ad

    def _personalize_ads(self, ad_list: List[Tuple[int, float]], user_interest_tags: set) -> List[Tuple[int, float]]:
        """Re-rank ads by how well their tags match the user's interest tags.

        Each score is multiplied by ``1 + overlap / len(user_interest_tags)``,
        where ``overlap`` is the number of tags the ad shares with the user.

        Args:
            ad_list: ``(ad_id, score)`` tuples to re-rank.
            user_interest_tags: Tag names the user is interested in.

        Returns:
            New list of ``(ad_id, adjusted_score)`` tuples, best first.
        """
        if not ad_list:
            return []

        tags_by_ad = self._load_ad_tags([ad_id for ad_id, _ in ad_list])
        # Guard against division by zero for an empty interest set.
        interest_count = max(len(user_interest_tags), 1)
        no_tags = frozenset()

        personalized_ads = []
        for ad_id, ad_score in ad_list:
            overlap = len(tags_by_ad.get(ad_id, no_tags) & user_interest_tags)
            tag_match_score = overlap / interest_count
            personalized_ads.append((ad_id, ad_score * (1 + tag_match_score)))

        personalized_ads.sort(key=lambda item: item[1], reverse=True)
        return personalized_ads

    def get_random_ads(self, num_items: int = 5) -> List[Tuple[int, float]]:
        """Return a random, score-weighted selection of distinct ads (for diversity).

        Ads are drawn without replacement, so the same ad never appears
        twice. Higher-scored ads are more likely to be drawn. Non-positive
        scores get weight 0 and are only drawn (uniformly) once no
        positively weighted ad is left.

        Args:
            num_items: Number of ads to return.

        Returns:
            List of ``(item_id, score)`` tuples. A copy of all cached ads is
            returned when there are no more than ``num_items`` of them.
        """
        if num_items <= 0:
            return []
        if len(self.ad_items) <= num_items:
            # Copy so callers cannot mutate the internal cache.
            return list(self.ad_items)

        candidates = list(self.ad_items)
        weights = [max(float(score), 0.0) for _, score in candidates]

        selected = []
        for _ in range(num_items):
            if sum(weights) > 0:
                index = random.choices(range(len(candidates)), weights=weights, k=1)[0]
            else:
                # random.choices rejects an all-zero weight vector.
                index = random.randrange(len(candidates))
            selected.append(candidates.pop(index))
            weights.pop(index)

        return selected