# main_online.py
# Main entry point for the search and recommendation service

import json
import numpy as np
import difflib
from flask import Flask, request, jsonify, Response
import pymysql
import jieba
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import pypinyin
from flask_cors import CORS
import re
import Levenshtein
import os
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("allpt-search")

# Import the Word2Vec helper module
try:
    from word2vec_helper import get_word2vec_helper, expand_query, get_similar_words
    WORD2VEC_ENABLED = True
    logger.info("Word2Vec模块已加载")
except ImportError as e:
    logger.warning(f"Word2Vec模块加载失败: {e},将使用传统搜索")
    WORD2VEC_ENABLED = False

# Database configuration
DB_CONFIG = {
    "host": "10.126.59.25",
    "port": 3306,
    "user": "root",
    "password": "123456",
    "database": "redbook",
    "charset": "utf8mb4"
}

def get_db_conn():
    return pymysql.connect(**DB_CONFIG)

def get_pinyin(text):
    # Return the full pinyin of the string (no tones, all lowercase); pure-English input is returned lowercased
    if not text:
        return ""
    # If the text is entirely English letters, just lowercase it
    if re.fullmatch(r'[a-zA-Z]+', text):
        return text.lower()
    return ''.join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.NORMAL)])

def get_pinyin_initials(text):
    # Return the pinyin initials of the string (all lowercase); pure-English input is returned lowercased
    if not text:
        return ""
    if re.fullmatch(r'[a-zA-Z]+', text):
        return text.lower()
    return ''.join([p[0][0] for p in pypinyin.pinyin(text, style=pypinyin.NORMAL)])

# Word-level similarity helper
def word_similarity(word1, word2):
    """Compute the similarity of two words, with pinyin-aware matching."""
    # Exact match
    if word1 == word2:
        return 1.0

    # Full pinyin match
    if get_pinyin(word1) == get_pinyin(word2):
        return 0.9

    # Pinyin-initial match
    if get_pinyin_initials(word1) == get_pinyin_initials(word2):
        return 0.7

    # Fallback: plain string similarity
    return difflib.SequenceMatcher(None, word1, word2).ratio()

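# Illustrative usage note (not from the original source): word_similarity("电影", "dianying")
# returns 0.9, because get_pinyin("电影") is "dianying" and the all-ASCII input is simply
# lowercased, so the two full-pinyin forms match.
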
def semantic_title_similarity(query, title):
    """Compute the semantic similarity between a query and a title."""
    # Tokenize
    query_words = list(jieba.cut(query))
    title_words = list(jieba.cut(title))

    if not query_words or not title_words:
        return 0.0

    # For each query word, take its best similarity against the title words
    max_similarities = []
    key_matches = 0  # number of near-exact keyword matches

    for q_word in query_words:
        if len(q_word.strip()) <= 1:  # skip single characters to reduce noise
            continue

        word_sims = [word_similarity(q_word, t_word) for t_word in title_words]
        if word_sims:
            max_sim = max(word_sims)
            max_similarities.append(max_sim)
            if max_sim > 0.85:  # treat as a keyword match
                key_matches += 1

    if not max_similarities:
        return 0.0

    # Average similarity
    avg_sim = sum(max_similarities) / len(max_similarities)

    # Weighting: average similarity counts for 70%, keyword match ratio for 30%
    key_match_ratio = key_matches / len(query_words) if query_words else 0

    # Bonus when the title contains the full query phrase
    exact_bonus = 0.3 if query in title else 0

    return 0.7 * avg_sim + 0.3 * key_match_ratio + exact_bonus

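# Worked example of the weighting above (illustrative numbers only): if avg_sim is 0.8,
# key_match_ratio is 0.5, and the full query phrase appears in the title, the result is
# 0.7 * 0.8 + 0.3 * 0.5 + 0.3 = 1.01.
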
# Semantic-association dictionary used to strengthen search
def load_semantic_mappings():
    """
    Load the semantic-association mapping table used to improve semantic understanding of queries.
    Returns a dict of mapping relations.
    """
    # Start with an empty dict; all mappings come from the config file
    mappings = {}

    # Load mappings from the config file
    try:
        config_path = os.path.join(os.path.dirname(__file__), "semantic_config.json")
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                mappings = json.load(f)
            logger.info(f"已从配置文件加载 {len(mappings)} 个语义映射")
        else:
            logger.warning(f"语义配置文件不存在: {config_path}")
    except Exception as e:
        logger.error(f"加载语义配置文件失败: {e}")

    return mappings

# Initialize the semantic mappings
SEMANTIC_MAPPINGS = load_semantic_mappings()

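# A minimal, hypothetical example of what semantic_config.json could contain (the entries
# below are illustrative, not the project's actual file): each key is a jieba token and the
# value lists related terms that expand_search_keywords() appends to the query, e.g.
#
#     {
#         "科幻": ["星际", "太空"],
#         "漫威": ["复仇者联盟", "钢铁侠"]
#     }
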
def expand_search_keywords(keyword):
    """
    Expand a search keyword with semantically related terms.
    """
    expanded = [keyword]

    # Tokenize the keyword
    words = list(jieba.cut(keyword))
    logger.info(f"关键词 '{keyword}' 分词结果: {words}")  # log the segmentation result

    # Expand each token through the semantic mappings
    for word in words:
        if word in SEMANTIC_MAPPINGS:
            # Add the semantically related words
            mapped_words = SEMANTIC_MAPPINGS[word]
            expanded.extend(mapped_words)
            logger.info(f"语义映射: '{word}' -> {mapped_words}")

    # All special-case handling has been removed;
    # specific keywords such as "越狱" no longer get special treatment

    # Word2Vec expansion: if available, expand the tokens with Word2Vec
    if WORD2VEC_ENABLED:
        try:
            # Keep the pre-expansion set in a separate variable for logging
            original_expanded = set(expanded)

            # First try to expand the whole keyword
            w2v_expanded = set()
            similar_words = get_similar_words(keyword, topn=3, min_similarity=0.6)
            w2v_expanded.update(similar_words)

            # Then expand the longer tokens
            for word in words:
                if len(word) > 1:  # skip single characters
                    similar_words = get_similar_words(word, topn=2, min_similarity=0.65)
                    w2v_expanded.update(similar_words)

            # Merge the results
            expanded.extend(w2v_expanded)

            # Log the expansion
            if w2v_expanded:
                logger.info(f"Word2Vec扩展: {keyword} -> {list(w2v_expanded)}")
        except Exception as e:
            # Log the error but do not interrupt the search flow
            logger.error(f"Word2Vec扩展失败: {e}")
            logger.info("将仅使用配置文件中的语义映射")

    # Deduplicate
    return list(set(expanded))

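# Illustrative behaviour (assuming the hypothetical mapping shown above is loaded):
# expand_search_keywords("科幻电影") would return roughly
# ["科幻电影", "星际", "太空", ...] plus any Word2Vec neighbours, deduplicated,
# so the SQL stage in /search can match titles that never contain the literal query.
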
# Replaces the original calculate_keyword_relevance with a more general relevance algorithm
def calculate_keyword_relevance(keyword, item):
    """Compute the relevance score between a search keyword and an item."""
    title = item.get('title', '')
    description = item.get('description', '') or ''
    tags = item.get('tags', '') or ''
    category = item.get('category', '') or ''  # include the category field

    # Initialize the score
    score = 0

    # 1. Exact match (highest priority)
    if keyword.lower() == title.lower():
        return 15.0  # a full match gets the highest score

    # 2. Exact word match inside the title
    title_words = re.findall(r'\b\w+\b', title.lower())
    if keyword.lower() in title_words:
        score += 10.0  # matches as a standalone word

    # 3. Title contains the keyword (partial match)
    elif keyword.lower() in title.lower():
        # Ratio of the keyword length to the title length
        match_ratio = len(keyword) / len(title)
        if match_ratio > 0.5:  # the keyword covers most of the title
            score += 8.0
        else:
            score += 5.0

    # 4. Tokenized title match
    keyword_words = list(jieba.cut(keyword))
    title_jieba_words = list(jieba.cut(title))

    matched_words = 0
    for k_word in keyword_words:
        if len(k_word) > 1:  # skip single characters
            if k_word in title_jieba_words:
                matched_words += 1
            else:
                # Pinyin match
                k_pinyin = get_pinyin(k_word)
                for t_word in title_jieba_words:
                    if get_pinyin(t_word) == k_pinyin:
                        matched_words += 0.8
                        break

    if len(keyword_words) > 0:
        word_match_ratio = matched_words / len(keyword_words)
        score += 3.0 * word_match_ratio

    # 5. Pinyin similarity
    keyword_pinyin = get_pinyin(keyword)
    title_pinyin = get_pinyin(title)

    if keyword_pinyin == title_pinyin:
        score += 3.5
    elif keyword_pinyin in title_pinyin:
        # The position of the pinyin match inside the title affects the weight
        pos = title_pinyin.find(keyword_pinyin)
        if pos == 0:  # appears at the beginning
            score += 3.0
        else:
            score += 2.0

    # 6. Edit-distance similarity
    try:
        edit_distance = Levenshtein.distance(keyword.lower(), title.lower())
        max_len = max(len(keyword), len(title))
        if max_len > 0:
            similarity = 1 - (edit_distance / max_len)
            if similarity > 0.7:
                score += 1.5 * similarity
    except Exception:  # fall back to difflib if Levenshtein fails
        similarity = difflib.SequenceMatcher(None, keyword.lower(), title.lower()).ratio()
        if similarity > 0.7:
            score += 1.5 * similarity

    # 7. Chinese character overlap: only scored when more than one character overlaps or the ratio is high enough
    if re.search(r'[\u4e00-\u9fff]', keyword) and re.search(r'[\u4e00-\u9fff]', title):
        cn_chars_keyword = set(re.findall(r'[\u4e00-\u9fff]', keyword))
        cn_chars_title = set(re.findall(r'[\u4e00-\u9fff]', title))

        # Overlapping Chinese characters
        overlapped_chars = cn_chars_keyword & cn_chars_title

        # Only score when more than one character overlaps and the ratio passes the threshold
        if len(overlapped_chars) > 1 and len(cn_chars_keyword) > 0:
            overlap_ratio = len(overlapped_chars) / len(cn_chars_keyword)
            # The stricter ratio threshold prevents false matches caused by a single shared character
            if overlap_ratio >= 0.4 or len(overlapped_chars) >= 3:
                score += 2.0 * overlap_ratio
            # Very low overlap gets no score, to keep unrelated items out

        # Debug logging for a specific known case
        if keyword == "明日方舟" and "白日梦想家" in title:
            logger.info(f"'明日方舟'与'{title}'的汉字重叠: {overlapped_chars}, 重叠比例: {len(overlapped_chars)/len(cn_chars_keyword) if cn_chars_keyword else 0}")

    # 8. Series detection (e.g. "功夫熊猫2" belongs to the "功夫熊猫" series)
    base_title_match = re.match(r'(.*?)([0-9]+|[一二三四五六七八九十]|:|\:|\s+[0-9]+)', title)
    if base_title_match:
        base_title = base_title_match.group(1).strip()
        if keyword.lower() == base_title.lower():
            score += 2.0

    # 9. Tag and description matching (with increased weight)
    if tags:
        tags_list = tags.split(',')
        if keyword in tags_list:
            score += 1.5  # higher weight for an exact tag match
        elif any(keyword.lower() in tag.lower() for tag in tags_list):
            score += 1.0  # higher weight for a partial tag match

    # Stronger description matching
    if keyword.lower() in description.lower():
        score += 1.5  # higher weight for a description match

        # Check where the keyword appears inside the description
        pos = description.lower().find(keyword.lower())
        if pos >= 0 and pos < len(description) / 3:
            # Keywords in the first third of the description are likely more important
            score += 0.5

    # Tokenized matching against the description
    keyword_words = list(jieba.cut(keyword))
    description_words = list(jieba.cut(description))
    matched_desc_words = 0
    for k_word in keyword_words:
        if len(k_word) > 1 and k_word in description_words:
            matched_desc_words += 1

    if len(keyword_words) > 0:
        desc_match_ratio = matched_desc_words / len(keyword_words)
        score += 1.0 * desc_match_ratio

    # Category match
    if keyword.lower() in category.lower():
        score += 1.0

    # Semantic-association matching:
    # expand the keyword and check the title for related terms
    expanded_keywords = expand_search_keywords(keyword)
    # Check whether the title contains semantically related words
    for exp_keyword in expanded_keywords:
        if exp_keyword != keyword and exp_keyword in title:  # avoid re-counting the original keyword
            score += 1.5  # general semantic association

    return score

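# Rough, hypothetical walk-through of the scoring: for keyword "功夫熊猫" and title
# "功夫熊猫2", rule 3 adds 8.0 (the keyword covers most of the title), rule 4 adds up to
# 3.0 from the token matches, rule 5 adds 3.0 for the pinyin prefix match, and rule 8 adds
# 2.0 for the series pattern, so the item easily clears the 0.1 threshold used in /search.
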
# Create the Flask application
app = Flask(__name__)
CORS(app)  # allow all cross-origin requests

# Word2Vec initialization
def init_word2vec():
    """Initialize the Word2Vec model."""
    try:
        helper = get_word2vec_helper()
        if helper.initialized:
            logger.info(f"Word2Vec模型已成功加载,词汇量: {len(helper.model.index_to_key)}, 向量维度: {helper.model.vector_size}")
        else:
            if helper.load_model():
                logger.info(f"Word2Vec模型加载成功,词汇量: {len(helper.model.index_to_key)}, 向量维度: {helper.model.vector_size}")
            else:
                logger.error("Word2Vec模型加载失败")
    except Exception as e:
        logger.error(f"初始化Word2Vec出错: {e}")

# New initialization approach:
def initialize_app():
    """Application initialization; replaces the before_first_request decorator."""
    # Fix: use the correct function name.
    # The original code called init_semantic_mapping();
    # it now calls the function that is actually defined.
    global SEMANTIC_MAPPINGS
    SEMANTIC_MAPPINGS = load_semantic_mappings()  # refresh the global semantic mappings

    if WORD2VEC_ENABLED:
        init_word2vec()  # defined above

# Run initialization before the app starts serving requests
initialize_app()

# Health-check route
@app.route('/test', methods=['GET'])
def test():
    import datetime
    return jsonify({"message": "服务器正常运行", "timestamp": str(datetime.datetime.now())})

# API: fetch the details of a single post
@app.route('/post/<int:post_id>', methods=['GET'])
def get_post_detail(post_id):
    """
    Fetch the details of a single post.
    """
    logger.info(f"接收到获取帖子详情请求,post_id: {post_id}")
    conn = get_db_conn()
    try:
        with conn.cursor(pymysql.cursors.DictCursor) as cursor:
            # Query the post details; keep the query simple for easier debugging
            query = """
                SELECT
                    p.id,
                    p.title,
                    p.content,
                    p.heat,
                    p.created_at as create_time,
                    p.updated_at as last_active,
                    p.status
                FROM posts p
                WHERE p.id = %s
            """
            logger.info(f"执行查询: {query} with post_id: {post_id}")
            cursor.execute(query, (post_id,))
            post = cursor.fetchone()

            logger.info(f"查询结果: {post}")

            if not post:
                logger.warning(f"帖子不存在,post_id: {post_id}")
                return jsonify({"error": "帖子不存在"}), 404

            # Default values
            post['tags'] = []
            post['category'] = '未分类'
            post['author'] = '匿名用户'

            # Format timestamps
            if post['create_time']:
                post['create_time'] = post['create_time'].strftime('%Y-%m-%d %H:%M:%S')
            if post['last_active']:
                post['last_active'] = post['last_active'].strftime('%Y-%m-%d %H:%M:%S')

            logger.info(f"返回帖子详情: {post}")
            return Response(json.dumps(post, ensure_ascii=False), mimetype='application/json; charset=utf-8')
    except Exception as e:
        logger.error(f"获取帖子详情失败: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": "服务器内部错误"}), 500
    finally:
        conn.close()

# API: search
@app.route('/search', methods=['POST'])
def search():
    """
    Search API.
    Request format: {
        "keyword": "keyword",
        "sort_by": "downloads" | "downloads_asc" | "newest" | "oldest" | "similarity" | "title_asc" | "title_desc",
        "category": "optional, category name",
        "search_mode": "title" | "title_desc" | "tags" | "all"  # optional, defaults to "title",
        "tags": ["tag1", "tag2"]  # optional, multiple tags supported
    }
    """
    if request.content_type != 'application/json':
        return jsonify({"error": "Content-Type must be application/json"}), 415

    data = request.get_json()
    keyword = data.get("keyword", "").strip()
    sort_by = data.get("sort_by", "similarity")  # default: sort by similarity
    category = data.get("category", None)
    search_mode = data.get("search_mode", "title")
    tags = data.get("tags", None)  # multiple tags may be passed

    # Validate parameters: a keyword is required in every mode
    if not (1 <= len(keyword) <= 20):
        return jsonify({"error": "请输入1-20个字符"}), 400

    # Stage 1: query the database for a candidate set
    results = []
    conn = get_db_conn()
    try:
        with conn.cursor(pymysql.cursors.DictCursor) as cursor:
            # First look for exact title matches
            exact_query = """
                SELECT id, title, topic_id, heat, created_at, content
                FROM posts
                WHERE title = %s
            """
            cursor.execute(exact_query, (keyword,))
            exact_matches = cursor.fetchall() or []  # normalize an empty result to a list

            # Expand the keyword with semantically related terms
            expanded_keywords = expand_search_keywords(keyword)
            logger.info(f"扩展后的关键词: {expanded_keywords}")  # debug info

            # Build the query conditions
            conditions = []
            params = []

            # Title matching: every search mode matches against the title
            conditions.append("title LIKE %s")
            params.append(f"%{keyword}%")

            # Add title conditions for the expanded keywords
            for exp_keyword in expanded_keywords:
                if exp_keyword != keyword:  # skip the original keyword
                    conditions.append("title LIKE %s")
                    params.append(f"%{exp_keyword}%")

            # Description matching
            if search_mode in ["title_desc", "all"]:
                # Match the original keyword against the content
                conditions.append("content LIKE %s")
                params.append(f"%{keyword}%")

                # Match the expanded keywords against the content
                for exp_keyword in expanded_keywords:
                    if exp_keyword != keyword:
                        conditions.append("content LIKE %s")
                        params.append(f"%{exp_keyword}%")

            # Tag matching
            # Not handled here; covered later by the JOIN

            # Category matching: only in "all" mode
            if search_mode == "all":
                # Match the original keyword against topic_id
                conditions.append("topic_id LIKE %s")
                params.append(f"%{keyword}%")

                # Match the expanded keywords against topic_id
                for exp_keyword in expanded_keywords:
                    if exp_keyword != keyword:
                        conditions.append("topic_id LIKE %s")
                        params.append(f"%{exp_keyword}%")

            # Build the SQL query
            if conditions:
                where_clause = " OR ".join(conditions)
                logger.info(f"搜索条件: {where_clause}")
                logger.info(f"参数列表: {params}")

                if category:
                    where_clause = f"({where_clause}) AND topic_id=%s"
                    params.append(category)

                sql = f"""
                    SELECT p.id, p.title, tp.name as category, p.heat, p.created_at, p.content,
                           GROUP_CONCAT(t.name) as tags
                    FROM posts p
                    LEFT JOIN post_tags pt ON p.id = pt.post_id
                    LEFT JOIN tags t ON pt.tag_id = t.id
                    LEFT JOIN topics tp ON p.topic_id = tp.id
                    WHERE {where_clause}
                    GROUP BY p.id
                    LIMIT 500
                """

                cursor.execute(sql, params)
                expanded_results = cursor.fetchall()
                logger.info(f"数据库返回记录数: {len(expanded_results) if expanded_results else 0}")
            else:
                expanded_results = []

            # If neither the expanded query nor the exact match found anything,
            # fetch all records and rely on relevance scoring
            if not expanded_results and not exact_matches:
                sql = "SELECT p.id, p.title, tp.name as category, p.heat, p.created_at, p.content, GROUP_CONCAT(t.name) as tags FROM posts p LEFT JOIN post_tags pt ON p.id = pt.post_id LEFT JOIN tags t ON pt.tag_id = t.id LEFT JOIN topics tp ON p.topic_id = tp.id"
                if category:
                    sql += " WHERE p.topic_id=%s"
                    category_params = [category]
                    cursor.execute(sql + " GROUP BY p.id", category_params)
                else:
                    cursor.execute(sql + " GROUP BY p.id")

                all_results = cursor.fetchall() or []  # normalize an empty result to a list
            else:
                if isinstance(exact_matches, tuple):
                    exact_matches = list(exact_matches)
                if isinstance(expanded_results, tuple):
                    expanded_results = list(expanded_results)
                all_results = expanded_results + exact_matches

            # Score every candidate with the relevance rules
            scored_results = []
            for item in all_results:
                # Relevance score
                relevance_score = calculate_keyword_relevance(keyword, item)

                # Lower the relevance threshold so more results survive (0.5 -> 0.1)
                if relevance_score > 0.1:
                    item['relevance_score'] = relevance_score
                    scored_results.append(item)
                    logger.info(f"匹配项: {item['title']}, 相关性得分: {relevance_score}")

            # Sort by relevance score
            scored_results.sort(key=lambda x: x.get('relevance_score', 0), reverse=True)

            # Keep exact matches at the top
            if exact_matches:
                for exact_match in exact_matches:
                    exact_match['relevance_score'] = 20.0  # very high score so they stay on top

                # Drop items from scored_results that already appear in exact_matches
                exact_ids = {item['id'] for item in exact_matches}
                scored_results = [item for item in scored_results if item['id'] not in exact_ids]

                # Merge the two result sets
                results = exact_matches + scored_results
            else:
                results = scored_results

            # Cap the number of returned results
            results = results[:50]

    except Exception as e:
        logger.error(f"搜索出错: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": "搜索系统异常,请稍后再试"}), 500
    finally:
        conn.close()

    # Stage 2: sort by the requested order
    if results:
        if sort_by == "similarity" or not sort_by:
            # Already sorted by relevance score
            pass
        elif sort_by == "downloads":
            # The posts queries above expose heat as the popularity field
            results.sort(key=lambda x: x.get("heat", 0), reverse=True)
        elif sort_by == "downloads_asc":
            results.sort(key=lambda x: x.get("heat", 0))
        elif sort_by == "newest":
            # The rows carry the created_at column selected above
            results.sort(key=lambda x: x.get("created_at", ""), reverse=True)
        elif sort_by == "oldest":
            results.sort(key=lambda x: x.get("created_at", ""))
        elif sort_by == "title_asc":
            results.sort(key=lambda x: x.get("title", ""))
        elif sort_by == "title_desc":
            results.sort(key=lambda x: x.get("title", ""), reverse=True)

    # Final pass: drop fields that should not be returned and convert datetime values to strings
    for item in results:
        item.pop("description", None)
        item.pop("tags", None)
        item.pop("relevance_score", None)
        for k, v in item.items():
            if hasattr(v, 'isoformat'):
                item[k] = v.isoformat(sep=' ', timespec='seconds')

    return Response(json.dumps({"results": results}, ensure_ascii=False), mimetype='application/json; charset=utf-8')

# API: tag-based recommendation
@app.route('/recommend_tags', methods=['POST'])
def recommend_tags():
    """
    Tag-based recommendation API.
    Request format: {
        "user_id": "user1",
        "tags": ["tag1", "tag2"]  # may be empty
    }
    """
    if request.content_type != 'application/json':
        return jsonify({"error": "Content-Type must be application/json"}), 415

    data = request.get_json()
    user_id = data.get("user_id")
    tags = set(data.get("tags", []))

    # Load the user's saved interest tags
    user_tags = set()
    if user_id:
        conn = get_db_conn()
        try:
            with conn.cursor() as cursor:
                cursor.execute("SELECT t.name FROM user_tags ut JOIN tags t ON ut.tag_id = t.id WHERE ut.user_id=%s", (user_id,))
                user_tags = set(row[0] for row in cursor.fetchall())
        finally:
            conn.close()

    # Merge the tags passed by the frontend with the user's interest tags
    all_tags = list(tags | user_tags)

    if not all_tags:
        return Response(json.dumps({"error": "暂无推荐结果"}, ensure_ascii=False), mimetype='application/json; charset=utf-8'), 200

    conn = get_db_conn()
    try:
        with conn.cursor(pymysql.cursors.DictCursor) as cursor:
            # Prefer matching through the tag tables
            # First resolve all tag ids
            tag_ids = []
            for tag in all_tags:
                cursor.execute("SELECT id FROM tags WHERE name=%s", (tag,))
                row = cursor.fetchone()
                if row:
                    tag_ids.append(row['id'])
            if not tag_ids:
                return Response(json.dumps({"error": "暂无推荐结果"}, ensure_ascii=False), mimetype='application/json; charset=utf-8'), 200
            tag_placeholders = ','.join(['%s'] * len(tag_ids))
            sql = f"""
                SELECT p.id, p.title, tp.name as category, p.heat,
                       GROUP_CONCAT(tg.name) as tags
                FROM posts p
                LEFT JOIN post_tags pt ON p.id = pt.post_id
                LEFT JOIN tags tg ON pt.tag_id = tg.id
                LEFT JOIN topics tp ON p.topic_id = tp.id
                WHERE pt.tag_id IN ({tag_placeholders})
                GROUP BY p.id
                LIMIT 50
            """
            cursor.execute(sql, tuple(tag_ids))
            results = cursor.fetchall()
            # If nothing matched, fall back to fuzzy title/content matching
            if not results:
                or_conditions = []
                params = []
                for tag in all_tags:
                    or_conditions.append("p.title LIKE %s OR p.content LIKE %s")
                    params.extend(['%' + tag + '%', '%' + tag + '%'])
                where_clause = ' OR '.join(or_conditions)
                sql = f"""
                    SELECT p.id, p.title, tp.name as category, p.heat,
                           GROUP_CONCAT(tg.name) as tags
                    FROM posts p
                    LEFT JOIN post_tags pt ON p.id = pt.post_id
                    LEFT JOIN tags tg ON pt.tag_id = tg.id
                    LEFT JOIN topics tp ON p.topic_id = tp.id
                    WHERE {where_clause}
                    GROUP BY p.id
                    LIMIT 50
                """
                cursor.execute(sql, tuple(params))
                results = cursor.fetchall()
    finally:
        conn.close()

    if not results:
        return Response(json.dumps({"error": "暂无推荐结果"}, ensure_ascii=False), mimetype='application/json; charset=utf-8'), 200

    return Response(json.dumps({"recommendations": results}, ensure_ascii=False), mimetype='application/json; charset=utf-8')

# API: user interest-tag management (optional)
@app.route('/tags', methods=['POST', 'GET', 'DELETE'])
def user_tags():
    """
    POST: add user interest tags
    GET: query user interest tags
    DELETE: delete user interest tags
    """
    if request.method == 'POST':
        if request.content_type != 'application/json':
            return jsonify({"error": "Content-Type must be application/json"}), 415
        data = request.get_json()
        user_id = data.get("user_id")
        tags = data.get("tags", [])

        if not user_id:
            return jsonify({"error": "用户ID不能为空"}), 400

        # Normalize the tag list
        if isinstance(tags, str):
            tags = [tag.strip() for tag in tags.split(',') if tag.strip()]

        if not tags:
            return jsonify({"error": "标签不能为空"}), 400

        conn = get_db_conn()
        try:
            with conn.cursor() as cursor:
                # Add the user's tags
                for tag in tags:
                    # Resolve the tag id first
                    cursor.execute("SELECT id FROM tags WHERE name=%s", (tag,))
                    tag_row = cursor.fetchone()
                    if tag_row:
                        tag_id = tag_row[0]
                        cursor.execute("REPLACE INTO user_tags (user_id, tag_id) VALUES (%s, %s)", (user_id, tag_id))
                conn.commit()
                # Return the updated tag list
                cursor.execute("SELECT t.name FROM user_tags ut JOIN tags t ON ut.tag_id = t.id WHERE ut.user_id=%s", (user_id,))
                updated_tags = [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()
        return Response(json.dumps({"msg": "添加成功", "tags": updated_tags}, ensure_ascii=False), mimetype='application/json; charset=utf-8')
    elif request.method == 'DELETE':
        if request.content_type != 'application/json':
            return jsonify({"error": "Content-Type must be application/json"}), 415
        data = request.get_json()
        user_id = data.get("user_id")
        tags = data.get("tags", [])
        if not user_id:
            return jsonify({"error": "用户ID不能为空"}), 400
        if not tags:
            return jsonify({"error": "标签不能为空"}), 400

        conn = get_db_conn()
        try:
            with conn.cursor() as cursor:
                for tag in tags:
                    cursor.execute("SELECT id FROM tags WHERE name=%s", (tag,))
                    tag_row = cursor.fetchone()
                    if tag_row:
                        tag_id = tag_row[0]
                        cursor.execute("DELETE FROM user_tags WHERE user_id=%s AND tag_id=%s", (user_id, tag_id))
                conn.commit()
                cursor.execute("SELECT t.name FROM user_tags ut JOIN tags t ON ut.tag_id = t.id WHERE ut.user_id=%s", (user_id,))
                remaining_tags = [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()
        return Response(json.dumps({"msg": "删除成功", "tags": remaining_tags}, ensure_ascii=False), mimetype='application/json; charset=utf-8')
    else:  # GET request
        user_id = request.args.get("user_id")
        if not user_id:
            return jsonify({"error": "用户ID不能为空"}), 400
        conn = get_db_conn()
        try:
            with conn.cursor() as cursor:
                cursor.execute("SELECT t.name FROM user_tags ut JOIN tags t ON ut.tag_id = t.id WHERE ut.user_id=%s", (user_id,))
                tags = [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()
        return Response(json.dumps({"tags": tags}, ensure_ascii=False), mimetype='application/json; charset=utf-8')

# /user_tags route as an alias for /tags
@app.route('/user_tags', methods=['POST', 'GET', 'DELETE'])
def user_tags_alias():
    """
    /user_tags route: alias for the /tags route.
    POST: add user interest tags
    GET: query user interest tags
    DELETE: delete user interest tags
    """
    return user_tags()

# API: user-based collaborative-filtering recommendation
@app.route('/user_based_recommend', methods=['POST'])
def user_based_recommend():
    """
    User-based collaborative-filtering recommendation API.
    Request format: {
        "user_id": "user1",
        "top_n": 5
    }
    """
    if request.content_type != 'application/json':
        return jsonify({"error": "Content-Type must be application/json"}), 415

    data = request.get_json()
    user_id = data.get("user_id")
    top_n = int(data.get("top_n", 5))

    if not user_id:
        return jsonify({"error": "用户ID不能为空"}), 400

    conn = get_db_conn()
    try:
        with conn.cursor(pymysql.cursors.DictCursor) as cursor:
            # 1. Check whether the user has any interaction records (favorites or views)
            cursor.execute("""
                SELECT COUNT(*) as count
                FROM behaviors
                WHERE user_id = %s AND type IN ('favorite', 'view')
            """, (user_id,))
            result = cursor.fetchone()
            user_download_count = result['count'] if result else 0

            logger.info(f"用户 {user_id} 下载记录数: {user_download_count}")

            # If the user has too little behavior data, fall back to popularity-based recommendations
            if user_download_count < 3:
                logger.info(f"用户 {user_id} 下载记录不足,返回热门推荐")
                cursor.execute("""
                    SELECT p.id, p.title, tp.name as category, p.heat
                    FROM posts p
                    LEFT JOIN topics tp ON p.topic_id = tp.id
                    ORDER BY p.heat DESC
                    LIMIT %s
                """, (top_n,))
                popular_seeds = cursor.fetchall()
                return Response(json.dumps({"recommendations": popular_seeds, "type": "popular"}, ensure_ascii=False), mimetype='application/json; charset=utf-8')

            # 2. Get the posts the user has already interacted with (favorited/viewed)
            cursor.execute("""
                SELECT post_id
                FROM behaviors
                WHERE user_id = %s AND type IN ('favorite', 'view')
            """, (user_id,))
            user_seeds = set(row['post_id'] for row in cursor.fetchall())
            logger.info(f"用户 {user_id} 已下载种子: {user_seeds}")

            # 3. Build the user-post interaction matrix from other users' favorites/views
            cursor.execute("""
                SELECT user_id, post_id
                FROM behaviors
                WHERE created_at > DATE_SUB(NOW(), INTERVAL 3 MONTH)
                AND user_id <> %s AND type IN ('favorite', 'view')
            """, (user_id,))
            download_records = cursor.fetchall()

            if not download_records:
                logger.info(f"没有其他用户的下载记录,返回热门推荐")
                cursor.execute("""
                    SELECT p.id, p.title, tp.name as category, p.heat
                    FROM posts p
                    LEFT JOIN topics tp ON p.topic_id = tp.id
                    ORDER BY p.heat DESC
                    LIMIT %s
                """, (top_n,))
                popular_seeds = cursor.fetchall()
                return Response(json.dumps({"recommendations": popular_seeds, "type": "popular"}, ensure_ascii=False), mimetype='application/json; charset=utf-8')

            # Build the user-item matrix
            user_item_matrix = {}
            for record in download_records:
                uid = record['user_id']
                sid = record['post_id']
                if uid not in user_item_matrix:
                    user_item_matrix[uid] = set()
                user_item_matrix[uid].add(sid)

            # 4. Compute user similarity (Jaccard over the interacted post sets)
            similar_users = []
            for other_id, other_seeds in user_item_matrix.items():
                if other_id == user_id:
                    continue
                intersection = len(user_seeds.intersection(other_seeds))
                union = len(user_seeds.union(other_seeds))
                if union > 0 and intersection > 0:
                    similarity = intersection / union
                    similar_users.append((other_id, similarity, other_seeds))
            logger.info(f"找到 {len(similar_users)} 个相似用户")
            similar_users.sort(key=lambda x: x[1], reverse=True)
            similar_users = similar_users[:5]
            # 5. Recommend posts based on the similar users
            candidate_seeds = {}
            for similar_user, similarity, seeds in similar_users:
                logger.info(f"相似用户 {similar_user}, 相似度 {similarity}")
                for post_id in seeds:
                    if post_id not in user_seeds:
                        if post_id not in candidate_seeds:
                            candidate_seeds[post_id] = 0
                        candidate_seeds[post_id] += similarity
            if not candidate_seeds:
                logger.info(f"没有找到候选种子,返回热门推荐")
                cursor.execute("""
                    SELECT p.id, p.title, tp.name as category, p.heat
                    FROM posts p
                    LEFT JOIN topics tp ON p.topic_id = tp.id
                    ORDER BY p.heat DESC
                    LIMIT %s
                """, (top_n,))
                popular_seeds = cursor.fetchall()
                return Response(json.dumps({"recommendations": popular_seeds, "type": "popular"}, ensure_ascii=False), mimetype='application/json; charset=utf-8')
            # 6. Fetch details for the recommended posts
            recommended_seeds = sorted(candidate_seeds.items(), key=lambda x: x[1], reverse=True)[:top_n]
            post_ids = [post_id for post_id, _ in recommended_seeds]
            format_strings = ','.join(['%s'] * len(post_ids))
            cursor.execute(f"""
                SELECT p.id, p.title, tp.name as category, p.heat
                FROM posts p
                LEFT JOIN topics tp ON p.topic_id = tp.id
                WHERE p.id IN ({format_strings})
            """, tuple(post_ids))
            result_seeds = list(cursor.fetchall())  # fetchall may return a tuple; use a list so it can be sorted in place
            seed_score_map = {post_id: score for post_id, score in recommended_seeds}
            result_seeds.sort(key=lambda x: seed_score_map.get(x['id'], 0), reverse=True)
            logger.info(f"返回 {len(result_seeds)} 个基于协同过滤的推荐")
            return Response(json.dumps({"recommendations": result_seeds, "type": "collaborative"}, ensure_ascii=False), mimetype='application/json; charset=utf-8')
    except Exception as e:
        logger.error(f"推荐系统错误: {e}")
        import traceback
        traceback.print_exc()
        return Response(json.dumps({"error": "推荐系统异常,请稍后再试", "details": str(e)}, ensure_ascii=False), mimetype='application/json; charset=utf-8')
    finally:
        conn.close()
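
# Worked example of the Jaccard similarity used above (illustrative numbers only): if the
# current user interacted with posts {1, 2, 3} and another user with {2, 3, 4}, their
# similarity is |{2, 3}| / |{1, 2, 3, 4}| = 2 / 4 = 0.5, and post 4 (unseen by the current
# user) is credited 0.5 in candidate_seeds.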

@app.route('/word2vec_status', methods=['GET'])
def word2vec_status():
    """
    Check the status of the Word2Vec model.
    Reports whether the model is loaded, its vocabulary size, and more.
    """
    if not WORD2VEC_ENABLED:
        return Response(json.dumps({
            "enabled": False,
            "message": "Word2Vec功能未启用"
        }, ensure_ascii=False), mimetype='application/json; charset=utf-8')
    try:
        helper = get_word2vec_helper()
        status = {
            "enabled": WORD2VEC_ENABLED,
            "initialized": helper.initialized,
            "vocab_size": len(helper.model.index_to_key) if helper.model else 0,
            "vector_size": helper.model.vector_size if helper.model else 0
        }

        # Query similar words for a few common terms to show what the model returns
        test_results = {}
        test_words = ["电影", "动作", "科幻", "动漫", "游戏"]
        for word in test_words:
            similar_words = helper.get_similar_words(word, topn=5)
            test_results[word] = similar_words

        status["test_results"] = test_results
        return Response(json.dumps(status, ensure_ascii=False), mimetype='application/json; charset=utf-8')
    except Exception as e:
        return Response(json.dumps({
            "enabled": WORD2VEC_ENABLED,
            "initialized": False,
            "error": str(e)
        }, ensure_ascii=False), mimetype='application/json; charset=utf-8')

# Temporary diagnostics endpoint
@app.route('/debug_search', methods=['POST'])
def debug_search():
    """Temporary debug endpoint for inspecting records in the database."""
    if request.content_type != 'application/json':
        return jsonify({"error": "Content-Type must be application/json"}), 415

    data = request.get_json()
    keyword = data.get("keyword", "").strip()

    conn = get_db_conn()
    try:
        with conn.cursor(pymysql.cursors.DictCursor) as cursor:
            # Look up records containing the keyword in different fields
            # (parameterized so user input is never interpolated into the SQL text)
            queries = [
                ("标题中包含关键词", "SELECT seed_id, title, description, tags FROM pt_seed WHERE title LIKE %s LIMIT 10", (f"%{keyword}%",)),
                ("描述中包含关键词", "SELECT seed_id, title, description, tags FROM pt_seed WHERE description LIKE %s LIMIT 10", (f"%{keyword}%",)),
                ("标签中包含关键词", "SELECT seed_id, title, description, tags FROM pt_seed WHERE FIND_IN_SET(%s, tags) LIMIT 10", (keyword,)),
                ("肖申克的救赎", "SELECT seed_id, title, description, tags FROM pt_seed WHERE title = '肖申克的救赎'", None)
            ]

            results = {}
            for query_name, query, query_params in queries:
                cursor.execute(query, query_params)
                results[query_name] = cursor.fetchall()

            return Response(json.dumps(results, ensure_ascii=False), mimetype='application/json; charset=utf-8')
    finally:
        conn.close()

1034"""
1035接口本地测试方法(可直接运行main_online.py后用curl或Postman测试):
1036
10371. 搜索接口
1038curl -X POST http://127.0.0.1:5000/search -H "Content-Type: application/json" -d '{"keyword":"电影","sort_by":"downloads"}'
1039
10402. 标签推荐接口
1041curl -X POST http://127.0.0.1:5000/recommend_tags -H "Content-Type: application/json" -d '{"user_id":"1","tags":["动作","科幻"]}'
1042
10433. 用户兴趣标签管理(添加标签)
1044curl -X POST http://127.0.0.1:5000/user_tags -H "Content-Type: application/json" -d '{"user_id":"1","tags":["动作","科幻"]}'
1045
10464. 用户兴趣标签管理(查询标签)
1047curl "http://127.0.0.1:5000/user_tags?user_id=1"
1048
10495. 用户兴趣标签管理(删除标签)
1050curl -X DELETE http://127.0.0.1:5000/user_tags -H "Content-Type: application/json" -d '{"user_id":"1","tags":["动作","科幻"]}'
1051
10526. 协同过滤推荐
1053curl -X POST http://127.0.0.1:5000/user_based_recommend -H "Content-Type: application/json" -d '{"user_id":"user1","top_n":3}'
1054
10557. Word2Vec状态检查
1056curl "http://127.0.0.1:5000/word2vec_status"
1057
10588. 调试接口(临时)
1059curl -X POST http://127.0.0.1:5000/debug_search -H "Content-Type: application/json" -d '{"keyword":"电影"}'
1060
1061所有接口均可用Postman按上述参数测试。
1062"""
1063
if __name__ == "__main__":
    try:
        logger.info("搜索推荐服务启动中...")
        app.run(host="0.0.0.0", port=5000)
    except Exception as e:
        logger.error(f"启动异常: {e}")
        import traceback
        traceback.print_exc()