blob: 0fe3b0a996118916dc16cc97bc48324d612dd16a [file] [log] [blame]
Raverd7895172025-06-18 17:54:38 +08001import pymysql
2from typing import List, Tuple, Dict
3import random
4
class AdRecall:
    """
    Ad recall algorithm.

    Recalls only advertisement posts (``posts.is_advertisement = 1`` and
    ``status = 'published'``), ranks them with a heuristic score, and can
    re-rank them per user by tag overlap with the user's interaction history.
    """

    # Max number of post ids bound into a single ``IN (...)`` tag lookup.
    _TAG_QUERY_BATCH_SIZE = 500

    def __init__(self, db_config: dict):
        """
        Initialize the ad recall model.

        Args:
            db_config: keyword arguments passed verbatim to ``pymysql.connect``.
        """
        self.db_config = db_config
        # (post_id, ad_score) pairs sorted by score descending; filled by train().
        self.ad_items = []

    @staticmethod
    def _score_ad(heat, interaction_count, days_since_created) -> float:
        """
        Compute the heuristic score of a single ad.

        score = 0.6 * heat + 0.3 * interaction_count + 100 * freshness, where
        freshness falls linearly from 1 (age 0 days) to 0 (age >= 30 days).
        ``None`` inputs are treated as 0.
        """
        # float() in case heat is a DECIMAL column: pymysql returns
        # decimal.Decimal, which cannot be multiplied by a float weight.
        heat = float(heat or 0)
        interaction_count = interaction_count or 0
        # A created_at in the future yields a negative DATEDIFF; clamp it so the
        # freshness bonus never exceeds its intended maximum of 1.
        days_since_created = max(days_since_created or 0, 0)

        # 30-day freshness bonus: newer ads get a higher weight.
        freshness_bonus = max(0, 30 - days_since_created) / 30.0
        return heat * 0.6 + interaction_count * 0.3 + freshness_bonus * 100

    def _get_ad_items(self):
        """Load all published ads, score them, and store them in ``self.ad_items``."""
        conn = pymysql.connect(**self.db_config)
        try:
            # The cursor context manager closes the cursor even if execute() raises.
            with conn.cursor() as cursor:
                # All ad posts with their distinct-interacting-user count and age.
                cursor.execute("""
                    SELECT
                        p.id,
                        p.heat,
                        p.created_at,
                        COUNT(DISTINCT b.user_id) as interaction_count,
                        DATEDIFF(NOW(), p.created_at) as days_since_created
                    FROM posts p
                    LEFT JOIN behaviors b ON p.id = b.post_id
                    WHERE p.is_advertisement = 1 AND p.status = 'published'
                    GROUP BY p.id, p.heat, p.created_at
                    ORDER BY p.heat DESC, p.created_at DESC
                """)
                results = cursor.fetchall()
        finally:
            conn.close()

        items_with_scores = [
            (post_id, self._score_ad(heat, interaction_count, days_since_created))
            for post_id, heat, _created_at, interaction_count, days_since_created in results
        ]
        # sorted() is stable, so ads with equal scores keep the SQL order
        # (heat desc, then created_at desc).
        self.ad_items = sorted(items_with_scores, key=lambda x: x[1], reverse=True)

    def train(self):
        """Train the ad recall model by (re)loading and scoring all ads."""
        print("开始获取广告物品...")
        self._get_ad_items()
        print(f"广告召回模型训练完成,共{len(self.ad_items)}个广告物品")

    def recall(self, user_id: int, num_items: int = 10) -> List[Tuple[int, float]]:
        """
        Recall ads for a user.

        Ads the user already liked, favorited, commented on or viewed are
        excluded. If that leaves nothing, the top-scored ads are used instead.
        When the user has interest tags, the candidates are re-ranked by tag
        overlap.

        Args:
            user_id: user ID.
            num_items: maximum number of ads to return; values <= 0 yield [].

        Returns:
            List of (item_id, score) tuples, best first.
        """
        if num_items <= 0:
            return []

        # Train lazily on first use (retried while no ads have been loaded).
        if not getattr(self, "ad_items", None):
            self.train()

        conn = pymysql.connect(**self.db_config)
        try:
            with conn.cursor() as cursor:
                # Ads the user has already interacted with, to avoid repeats.
                cursor.execute("""
                    SELECT DISTINCT b.post_id
                    FROM behaviors b
                    JOIN posts p ON b.post_id = p.id
                    WHERE b.user_id = %s AND p.is_advertisement = 1
                    AND b.type IN ('like', 'favorite', 'comment', 'view')
                """, (user_id,))
                user_interacted_ads = {row[0] for row in cursor.fetchall()}

                # The user's interest tags: up to 10 most frequent tags on posts
                # they liked, favorited or commented on.
                cursor.execute("""
                    SELECT t.name, COUNT(*) as count
                    FROM behaviors b
                    JOIN posts p ON b.post_id = p.id
                    JOIN post_tags pt ON p.id = pt.post_id
                    JOIN tags t ON pt.tag_id = t.id
                    WHERE b.user_id = %s AND b.type IN ('like', 'favorite', 'comment')
                    GROUP BY t.name
                    ORDER BY count DESC
                    LIMIT 10
                """, (user_id,))
                user_interest_tags = {row[0] for row in cursor.fetchall()}
        finally:
            conn.close()

        # Drop ads the user has already interacted with.
        filtered_ads = [
            (item_id, score) for item_id, score in self.ad_items
            if item_id not in user_interacted_ads
        ]

        # Every ad was already seen: fall back to the top-scored ones (the user
        # may still be interested).
        if not filtered_ads and self.ad_items:
            print(f"用户 {user_id} 已与所有广告交互,返回评分最高的广告")
            filtered_ads = self.ad_items[:num_items]

        if user_interest_tags and filtered_ads:
            filtered_ads = self._personalize_ads(filtered_ads, user_interest_tags)

        return filtered_ads[:num_items]

    def _personalize_ads(self, ad_list: List[Tuple[int, float]], user_interest_tags: set) -> List[Tuple[int, float]]:
        """
        Re-rank ads by overlap between their tags and the user's interest tags.

        Args:
            ad_list: (ad_id, score) candidates.
            user_interest_tags: the user's interest tag names.

        Returns:
            A new list of (ad_id, adjusted_score), sorted by score descending.
        """
        if not ad_list:
            return []
        ad_tags = self._fetch_ad_tags([ad_id for ad_id, _ in ad_list])
        return self._rank_by_tags(ad_list, ad_tags, user_interest_tags)

    def _fetch_ad_tags(self, ad_ids: List[int]) -> Dict[int, set]:
        """Return {post_id: set of tag names} for ``ad_ids`` using batched queries."""
        tags_by_ad: Dict[int, set] = {}
        conn = pymysql.connect(**self.db_config)
        try:
            with conn.cursor() as cursor:
                step = self._TAG_QUERY_BATCH_SIZE
                for start in range(0, len(ad_ids), step):
                    batch = tuple(ad_ids[start:start + step])
                    # Only the number of placeholders is interpolated into the
                    # SQL; the ids themselves are passed as bound parameters.
                    placeholders = ", ".join(["%s"] * len(batch))
                    cursor.execute(
                        f"""
                        SELECT pt.post_id, t.name
                        FROM post_tags pt
                        JOIN tags t ON pt.tag_id = t.id
                        WHERE pt.post_id IN ({placeholders})
                        """,
                        batch,
                    )
                    for post_id, tag_name in cursor.fetchall():
                        tags_by_ad.setdefault(post_id, set()).add(tag_name)
        finally:
            conn.close()
        return tags_by_ad

    @staticmethod
    def _rank_by_tags(
        ad_list: List[Tuple[int, float]],
        ad_tags: Dict[int, set],
        user_interest_tags: set,
    ) -> List[Tuple[int, float]]:
        """
        Scale each ad's score by (1 + fraction of user interest tags it carries).

        Ads missing from ``ad_tags`` are treated as untagged (factor 1).
        """
        denominator = max(len(user_interest_tags), 1)
        personalized_ads = [
            (
                ad_id,
                ad_score * (1 + len(ad_tags.get(ad_id, set()) & user_interest_tags) / denominator),
            )
            for ad_id, ad_score in ad_list
        ]
        personalized_ads.sort(key=lambda x: x[1], reverse=True)
        return personalized_ads

    def get_random_ads(self, num_items: int = 5) -> List[Tuple[int, float]]:
        """
        Sample distinct ads at random, biased toward higher scores (for diversity).

        Uses weighted sampling without replacement (Efraimidis-Spirakis keys
        u ** (1 / score), keeping the largest), so an ad never appears twice.
        Ads with a non-positive score are only picked after every positive-score
        ad has been taken. Reads ``self.ad_items`` as-is; call ``train()`` first.

        Args:
            num_items: number of ads to return; values <= 0 yield [].

        Returns:
            List of (item_id, score) tuples. If there are at most ``num_items``
            ads, a copy of all of them is returned.
        """
        if num_items <= 0:
            return []
        if len(self.ad_items) <= num_items:
            return list(self.ad_items)

        keys = []
        for _, score in self.ad_items:
            if score > 0:
                # In [0, 1); higher scores push the key toward 1.
                keys.append(random.random() ** (1.0 / score))
            else:
                # In [-1, 0): always below every positive-score key, random among themselves.
                keys.append(random.random() - 1.0)

        ranked = sorted(range(len(self.ad_items)), key=keys.__getitem__, reverse=True)
        return [self.ad_items[i] for i in ranked[:num_items]]
207 return [self.ad_items[i] for i in selected_indices]