blob: cc2c60211064f82f039651492f0babe27800e6be [file] [log] [blame]
TRM-coding2a8fd602025-06-19 19:33:16 +08001import pymysql
2from typing import List, Tuple, Dict
3import random
class AdRecall:
    """
    Ad recall model.

    Recalls only advertisement posts: rows of ``posts`` with
    ``is_advertisement = 1`` and ``status = 'published'``. Ads are scored once
    by ``train()``; ``recall()`` then filters and re-ranks them per user.
    """

    # Weights of the user-independent ad score computed in _score_ad():
    #   score = heat * HEAT_WEIGHT
    #         + interaction_count * INTERACTION_WEIGHT
    #         + freshness_bonus * FRESHNESS_WEIGHT
    # freshness_bonus falls linearly from 1 (posted today) to 0 at
    # FRESHNESS_WINDOW_DAYS days old, so newly published ads rank higher.
    HEAT_WEIGHT = 0.6
    INTERACTION_WEIGHT = 0.3
    FRESHNESS_WEIGHT = 100
    FRESHNESS_WINDOW_DAYS = 30
    # How many of the user's most frequent tags count as "interest tags".
    INTEREST_TAG_LIMIT = 10

    def __init__(self, db_config: dict):
        """
        Initialise the ad recall model.

        Args:
            db_config: keyword arguments passed verbatim to ``pymysql.connect``.
        """
        self.db_config = db_config
        # (post_id, ad_score) pairs sorted by score, highest first; filled by train().
        self.ad_items: List[Tuple[int, float]] = []
        # Set once train() has completed (even if it found no ads) so recall()
        # does not repeat the full ad scan on every call.
        self._trained = False

    def _score_ad(self, heat, interaction_count, days_since_created) -> float:
        """
        Compute the user-independent score of a single ad.

        Args:
            heat: value of ``posts.heat``; ``None`` counts as 0. Converted with
                ``float()`` so a ``decimal.Decimal`` (what pymysql returns for a
                DECIMAL column) does not raise when multiplied by a float weight.
            interaction_count: number of distinct users with any recorded
                behavior on the ad; ``None`` counts as 0.
            days_since_created: age of the ad in days; ``None`` counts as 0.
                Negative ages (created_at in the future) are clamped to 0 so the
                freshness bonus never exceeds 1.

        Returns:
            The weighted score described by the class-level weight constants.
        """
        heat = float(heat or 0)
        interaction_count = interaction_count or 0
        days_since_created = max(0, days_since_created or 0)

        window = self.FRESHNESS_WINDOW_DAYS
        freshness_bonus = max(0, window - days_since_created) / window

        return (
            heat * self.HEAT_WEIGHT
            + interaction_count * self.INTERACTION_WEIGHT
            + freshness_bonus * self.FRESHNESS_WEIGHT
        )

    def _get_ad_items(self):
        """
        Load every published ad, score it, and store the ranking in ``self.ad_items``.

        The stored value is a list of ``(post_id, ad_score)`` tuples sorted by
        score, highest first. Ties keep the SQL order (heat, then newest first)
        because ``sorted`` is stable.
        """
        conn = pymysql.connect(**self.db_config)
        try:
            # The cursor context manager closes the cursor even if execute() raises.
            with conn.cursor() as cursor:
                # All published ads with their distinct-user interaction count
                # and their age in days.
                cursor.execute("""
                    SELECT
                        p.id,
                        p.heat,
                        p.created_at,
                        COUNT(DISTINCT b.user_id) as interaction_count,
                        DATEDIFF(NOW(), p.created_at) as days_since_created
                    FROM posts p
                    LEFT JOIN behaviors b ON p.id = b.post_id
                    WHERE p.is_advertisement = 1 AND p.status = 'published'
                    GROUP BY p.id, p.heat, p.created_at
                    ORDER BY p.heat DESC, p.created_at DESC
                """)
                rows = cursor.fetchall()
        finally:
            conn.close()

        scored = [
            (post_id, self._score_ad(heat, interaction_count, days_since_created))
            for post_id, heat, _created_at, interaction_count, days_since_created in rows
        ]
        self.ad_items = sorted(scored, key=lambda item: item[1], reverse=True)

    def train(self):
        """Fetch and score all ads, replacing ``self.ad_items``."""
        print("开始获取广告物品...")
        self._get_ad_items()
        self._trained = True
        print(f"广告召回模型训练完成,共 {len(self.ad_items)} 个广告物品")

    def recall(self, user_id: int, num_items: int = 10) -> List[Tuple[int, float]]:
        """
        Recall ads for a user.

        Trains the model on first use. Ads the user already liked, favorited,
        commented on or viewed are excluded. If that leaves nothing, the top
        ``num_items`` ads overall are used instead. Candidates are then boosted
        by overlap with the user's interest tags.

        Args:
            user_id: user ID.
            num_items: maximum number of ads to return; ``<= 0`` returns ``[]``.

        Returns:
            List of (item_id, score) tuples, highest score first.
        """
        if num_items <= 0:
            return []

        # Train lazily, but only once: an empty ad pool after training must not
        # trigger another full ad scan on every request.
        if not self._trained and not self.ad_items:
            self.train()
        if not self.ad_items:
            return []

        interacted_ads, interest_tags = self._get_user_profile(user_id)

        # Drop ads the user already interacted with to avoid repeats.
        candidates = [
            (item_id, score)
            for item_id, score in self.ad_items
            if item_id not in interacted_ads
        ]

        # The user has seen every ad: fall back to the top-rated ones
        # (the user may be interested in them again).
        if not candidates:
            print(f"用户 {user_id} 已与所有广告交互,返回评分最高的广告")
            candidates = self.ad_items[:num_items]

        if interest_tags:
            candidates = self._personalize_ads(candidates, interest_tags)

        return candidates[:num_items]

    def _get_user_profile(self, user_id: int) -> Tuple[set, set]:
        """
        Fetch the per-user data recall() needs, using one connection.

        Returns:
            ``(interacted_ad_ids, interest_tags)``. The first item holds the ids
            of ads the user liked, favorited, commented on or viewed. The second
            holds the names of the ``INTEREST_TAG_LIMIT`` tags seen most often
            on posts (of any kind) the user liked, favorited or commented on.
        """
        conn = pymysql.connect(**self.db_config)
        try:
            with conn.cursor() as cursor:
                cursor.execute("""
                    SELECT DISTINCT b.post_id
                    FROM behaviors b
                    JOIN posts p ON b.post_id = p.id
                    WHERE b.user_id = %s AND p.is_advertisement = 1
                    AND b.type IN ('like', 'favorite', 'comment', 'view')
                """, (user_id,))
                interacted_ads = {row[0] for row in cursor.fetchall()}

                cursor.execute("""
                    SELECT t.name, COUNT(*) as count
                    FROM behaviors b
                    JOIN posts p ON b.post_id = p.id
                    JOIN post_tags pt ON p.id = pt.post_id
                    JOIN tags t ON pt.tag_id = t.id
                    WHERE b.user_id = %s AND b.type IN ('like', 'favorite', 'comment')
                    GROUP BY t.name
                    ORDER BY count DESC
                    LIMIT %s
                """, (user_id, self.INTEREST_TAG_LIMIT))
                interest_tags = {row[0] for row in cursor.fetchall()}
        finally:
            conn.close()
        return interacted_ads, interest_tags

    def _personalize_ads(self, ad_list: List[Tuple[int, float]], user_interest_tags: set) -> List[Tuple[int, float]]:
        """
        Re-rank ads by how well their tags match the user's interest tags.

        Each score is multiplied by ``1 + matched / len(user_interest_tags)``.
        The boost therefore reaches 2x when the ad carries every interest tag.

        Args:
            ad_list: ``(ad_id, score)`` candidates.
            user_interest_tags: tag names the user is interested in.

        Returns:
            A new list of ``(ad_id, boosted_score)`` sorted by score, highest first.
        """
        if not ad_list:
            return []
        # Without interest tags every boost is 0, so the tag lookup is skipped.
        if user_interest_tags:
            ad_tags = self._fetch_ad_tags([ad_id for ad_id, _ in ad_list])
        else:
            ad_tags = {}
        return self._rank_by_tag_match(ad_list, ad_tags, user_interest_tags)

    def _fetch_ad_tags(self, ad_ids: List[int]) -> Dict[int, set]:
        """Return ``{post_id: {tag_name, ...}}`` for the given posts using a single query."""
        tags_by_ad: Dict[int, set] = {}
        if not ad_ids:
            return tags_by_ad

        # Only the "%s" placeholders are formatted into the SQL; the ids
        # themselves are passed as query parameters.
        placeholders = ", ".join(["%s"] * len(ad_ids))
        conn = pymysql.connect(**self.db_config)
        try:
            with conn.cursor() as cursor:
                cursor.execute(f"""
                    SELECT pt.post_id, t.name
                    FROM post_tags pt
                    JOIN tags t ON pt.tag_id = t.id
                    WHERE pt.post_id IN ({placeholders})
                """, tuple(ad_ids))
                for post_id, tag_name in cursor.fetchall():
                    tags_by_ad.setdefault(post_id, set()).add(tag_name)
        finally:
            conn.close()
        return tags_by_ad

    @staticmethod
    def _rank_by_tag_match(ad_list: List[Tuple[int, float]], ad_tags: Dict[int, set],
                           user_interest_tags) -> List[Tuple[int, float]]:
        """Apply the tag-match boost to each ad and sort by boosted score (pure, no I/O)."""
        interest = set(user_interest_tags or ())
        denominator = max(len(interest), 1)
        boosted = [
            (ad_id, ad_score * (1 + len(ad_tags.get(ad_id, set()) & interest) / denominator))
            for ad_id, ad_score in ad_list
        ]
        boosted.sort(key=lambda item: item[1], reverse=True)
        return boosted

    def get_random_ads(self, num_items: int = 5) -> List[Tuple[int, float]]:
        """
        Pick random ads, biased towards higher scores (for diversity).

        Each draw selects one remaining ad with probability proportional to its
        score. Negative scores count as 0, and if every remaining weight is 0
        the draw is uniform. Drawn ads are removed from the pool, so the result
        contains no duplicates. Uses whatever ``self.ad_items`` currently holds;
        it does not call ``train()``.

        Args:
            num_items: number of ads to return; ``<= 0`` returns ``[]``.

        Returns:
            List of (item_id, score) tuples. If there are at most ``num_items``
            ads, a copy of all of them (in ranked order) is returned.
        """
        if num_items <= 0:
            return []
        if len(self.ad_items) <= num_items:
            # Copy so callers cannot mutate the model's internal ranking.
            return list(self.ad_items)

        pool = list(self.ad_items)
        weights = [max(score, 0.0) for _, score in pool]
        picked = []
        for _ in range(num_items):
            if sum(weights) > 0:
                index = random.choices(range(len(pool)), weights=weights)[0]
            else:
                # random.choices rejects an all-zero weight vector.
                index = random.randrange(len(pool))
            picked.append(pool.pop(index))
            weights.pop(index)
        return picked
206 return [self.ad_items[i] for i in selected_indices]