22301008 | cae762d | 2025-06-14 00:27:04 +0800 | [diff] [blame] | 1 | # main_online.py |
| 2 | # 搜索推荐算法服务的主入口 |
| 3 | |
| 4 | import json |
| 5 | import numpy as np |
| 6 | import difflib |
| 7 | from flask import Flask, request, jsonify, Response |
| 8 | import pymysql |
| 9 | import jieba |
| 10 | from sklearn.feature_extraction.text import TfidfVectorizer |
| 11 | from sklearn.metrics.pairwise import cosine_similarity |
| 12 | import pypinyin |
| 13 | from flask_cors import CORS |
| 14 | import re |
| 15 | import Levenshtein |
| 16 | import os |
| 17 | import logging |
| 18 | |
# Logging setup: timestamped INFO-level records for the whole service.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("allpt-search")  # shared module-level logger
| 25 | |
# Optional Word2Vec helper module.  When it cannot be imported the service
# degrades gracefully: WORD2VEC_ENABLED gates every Word2Vec call below and
# search falls back to keyword/pinyin matching only.
try:
    from word2vec_helper import get_word2vec_helper, expand_query, get_similar_words
    WORD2VEC_ENABLED = True
    logger.info("Word2Vec模块已加载")
except ImportError as e:
    logger.warning(f"Word2Vec模块加载失败: {e},将使用传统搜索")
    WORD2VEC_ENABLED = False
| 34 | |
# Database connection settings (passed verbatim to pymysql.connect).
# NOTE(review): host/user/password are hard-coded; consider moving them to
# environment variables or a config file before wider deployment.
DB_CONFIG = {
    "host": "10.126.59.25",
    "port": 3306,
    "user": "root",
    "password": "123456",
    "database": "redbook",
    "charset": "utf8mb4"  # 4-byte UTF-8 so emoji / rare CJK chars round-trip
}
| 44 | |
def get_db_conn():
    """Open and return a fresh pymysql connection built from DB_CONFIG.

    One connection per request; there is no pooling, so callers are
    responsible for closing it (every route here uses try/finally).
    """
    connection = pymysql.connect(**DB_CONFIG)
    return connection
| 47 | |
def get_pinyin(text):
    """Return the full pinyin of *text* (no tone marks, all lower-case).

    Empty/None input yields "".  Input consisting solely of ASCII letters
    is returned lower-cased as-is; anything else is transliterated with
    pypinyin (NORMAL style = plain syllables, no tone numbers).

    Fix: removed the redundant function-local `import re` — `re` is
    already imported at module level.
    """
    if not text:
        return ""
    # Pure-English input needs no transliteration.
    if re.fullmatch(r'[a-zA-Z]+', text):
        return text.lower()
    return ''.join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.NORMAL)])
| 57 | |
def get_pinyin_initials(text):
    """Return the pinyin initials of *text* (first letter of each syllable,
    all lower-case).

    Empty/None input yields "".  Input consisting solely of ASCII letters
    is returned lower-cased unchanged (mirroring get_pinyin).

    Fix: removed the redundant function-local `import re` — `re` is
    already imported at module level.
    """
    if not text:
        return ""
    if re.fullmatch(r'[a-zA-Z]+', text):
        return text.lower()
    # p[0] is the full syllable; p[0][0] keeps only its first letter.
    return ''.join([p[0][0] for p in pypinyin.pinyin(text, style=pypinyin.NORMAL)])
| 66 | |
# Word-level similarity helper used by semantic_title_similarity.
def word_similarity(word1, word2):
    """Score how alike two words are on a 0.0–1.0 scale.

    Tiers: identical -> 1.0, same full pinyin -> 0.9, same pinyin
    initials -> 0.7; otherwise the difflib character-sequence ratio.
    """
    if word1 == word2:
        return 1.0

    # Tiered pinyin checks: full pinyin first, then initials only.
    for transform, tier_score in ((get_pinyin, 0.9), (get_pinyin_initials, 0.7)):
        if transform(word1) == transform(word2):
            return tier_score

    # Fall back to plain string similarity.
    return difflib.SequenceMatcher(None, word1, word2).ratio()
| 84 | |
def semantic_title_similarity(query, title):
    """Semantic similarity between a search query and a post title.

    Combines three signals: the mean of each multi-character query word's
    best match against the title words (70%), the fraction of query words
    with a near-exact match (30%), and a flat 0.3 bonus when the title
    contains the query verbatim.  Returns 0.0 when either side is empty
    or the query has no multi-character words.
    """
    query_words = list(jieba.cut(query))
    title_words = list(jieba.cut(title))

    if not query_words or not title_words:
        return 0.0

    best_scores = []  # best title-word similarity per multi-char query word
    exact_hits = 0    # query words whose best match exceeds 0.85

    for q_word in query_words:
        # Skip single characters — they add noise, not signal.
        if len(q_word.strip()) <= 1:
            continue
        candidates = [word_similarity(q_word, t_word) for t_word in title_words]
        if candidates:
            top = max(candidates)
            best_scores.append(top)
            if top > 0.85:
                exact_hits += 1

    if not best_scores:
        return 0.0

    mean_best = sum(best_scores) / len(best_scores)

    # NOTE: the hit ratio divides by ALL query words, including the skipped
    # single characters (kept as-is to preserve the original weighting).
    hit_ratio = exact_hits / len(query_words) if query_words else 0

    verbatim_bonus = 0.3 if query in title else 0

    return 0.7 * mean_best + 0.3 * hit_ratio + verbatim_bonus
| 122 | |
# Semantic-association dictionary used to broaden search recall.
def load_semantic_mappings():
    """Load the word -> [related words] map from semantic_config.json.

    The config file is expected next to this module.  Returns an empty
    dict when the file is missing or unreadable, in which case search
    simply runs without semantic expansion.
    """
    mappings = {}
    try:
        config_path = os.path.join(os.path.dirname(__file__), "semantic_config.json")
        if not os.path.exists(config_path):
            logger.warning(f"语义配置文件不存在: {config_path}")
        else:
            with open(config_path, 'r', encoding='utf-8') as f:
                mappings = json.load(f)
            logger.info(f"已从配置文件加载 {len(mappings)} 个语义映射")
    except Exception as e:
        # Best-effort: a broken config must never prevent startup.
        logger.error(f"加载语义配置文件失败: {e}")
    return mappings
| 145 | |
# Module-level semantic map, loaded once at import time and refreshed again
# by initialize_app() below.
SEMANTIC_MAPPINGS = load_semantic_mappings()
| 148 | |
def expand_search_keywords(keyword):
    """Expand *keyword* with semantically related terms.

    Expansion sources, in order:
      1. the keyword itself,
      2. SEMANTIC_MAPPINGS entries for each jieba token,
      3. Word2Vec neighbours (whole keyword at a looser threshold, then
         each multi-character token at a stricter one), when available.

    Returns a de-duplicated list of terms.

    Fixes vs. previous revision: de-duplication now uses dict.fromkeys so
    insertion order is preserved (the original keyword stays first and the
    output is deterministic, unlike list(set(...))); the dead local
    `original_expanded` was removed.
    """
    expanded = [keyword]

    words = list(jieba.cut(keyword))
    logger.info(f"关键词 '{keyword}' 分词结果: {words}")

    # Configured semantic associations for each token.
    for word in words:
        if word in SEMANTIC_MAPPINGS:
            mapped_words = SEMANTIC_MAPPINGS[word]
            expanded.extend(mapped_words)
            logger.info(f"语义映射: '{word}' -> {mapped_words}")

    # Word2Vec expansion — best effort: an error must never abort search.
    if WORD2VEC_ENABLED:
        try:
            w2v_expanded = set()
            # Whole phrase first, with a looser similarity threshold...
            w2v_expanded.update(get_similar_words(keyword, topn=3, min_similarity=0.6))
            # ...then each multi-character token, slightly stricter.
            for word in words:
                if len(word) > 1:
                    w2v_expanded.update(get_similar_words(word, topn=2, min_similarity=0.65))

            expanded.extend(w2v_expanded)

            if w2v_expanded:
                logger.info(f"Word2Vec扩展: {keyword} -> {list(w2v_expanded)}")
        except Exception as e:
            logger.error(f"Word2Vec扩展失败: {e}")
            logger.info("将仅使用配置文件中的语义映射")

    # Order-preserving de-duplication (keyword guaranteed first).
    return list(dict.fromkeys(expanded))
| 200 | |
# Generic relevance scorer shared by the /search endpoint.
def calculate_keyword_relevance(keyword, item):
    """Score how relevant a DB row *item* is to the search *keyword*.

    *item* is a dict-like row; `title` is required, while `description`,
    `tags` (comma-separated string) and `category` default to "".
    Signals, roughly by weight: exact title match (returns 15.0 outright),
    whole-word / substring title matches, jieba-token overlap, pinyin
    similarity, edit-distance similarity, CJK character overlap, series
    detection, and tag / description / category matches plus configured
    semantic associations.  Returns a float score.

    Fixes vs. previous revision: the bare `except:` around the Levenshtein
    call now catches Exception only; the hard-coded per-keyword debug
    logging was removed; the redundant second jieba cut of the keyword
    for the description check now reuses the earlier cut.
    """
    title = item.get('title', '')
    description = item.get('description', '') or ''
    tags = item.get('tags', '') or ''
    category = item.get('category', '') or ''

    score = 0

    # 1. Exact title match wins outright.
    if keyword.lower() == title.lower():
        return 15.0

    # 2. Keyword appears as a standalone word in the title.
    title_words = re.findall(r'\b\w+\b', title.lower())
    if keyword.lower() in title_words:
        score += 10.0

    # 3. Otherwise a substring match, weighted by title coverage.
    elif keyword.lower() in title.lower():
        match_ratio = len(keyword) / len(title)
        if match_ratio > 0.5:
            score += 8.0
        else:
            score += 5.0

    # 4. jieba-token overlap between keyword and title.
    keyword_words = list(jieba.cut(keyword))
    title_jieba_words = list(jieba.cut(title))

    matched_words = 0
    for k_word in keyword_words:
        if len(k_word) > 1:  # single characters are too noisy
            if k_word in title_jieba_words:
                matched_words += 1
            else:
                # Homophone match via full pinyin, slightly discounted.
                k_pinyin = get_pinyin(k_word)
                for t_word in title_jieba_words:
                    if get_pinyin(t_word) == k_pinyin:
                        matched_words += 0.8
                        break

    if len(keyword_words) > 0:
        word_match_ratio = matched_words / len(keyword_words)
        score += 3.0 * word_match_ratio

    # 5. Pinyin similarity of the whole strings.
    keyword_pinyin = get_pinyin(keyword)
    title_pinyin = get_pinyin(title)

    if keyword_pinyin == title_pinyin:
        score += 3.5
    elif keyword_pinyin in title_pinyin:
        # A prefix match is worth more than a mid-string one.
        pos = title_pinyin.find(keyword_pinyin)
        if pos == 0:
            score += 3.0
        else:
            score += 2.0

    # 6. Edit-distance similarity; falls back to difflib when the
    # Levenshtein extension is unavailable or fails.
    try:
        edit_distance = Levenshtein.distance(keyword.lower(), title.lower())
        max_len = max(len(keyword), len(title))
        if max_len > 0:
            similarity = 1 - (edit_distance / max_len)
            if similarity > 0.7:
                score += 1.5 * similarity
    except Exception:
        similarity = difflib.SequenceMatcher(None, keyword.lower(), title.lower()).ratio()
        if similarity > 0.7:
            score += 1.5 * similarity

    # 7. CJK character overlap — counted only when at least two characters
    # overlap AND the ratio is high enough, to avoid single-char noise.
    if re.search(r'[\u4e00-\u9fff]', keyword) and re.search(r'[\u4e00-\u9fff]', title):
        cn_chars_keyword = set(re.findall(r'[\u4e00-\u9fff]', keyword))
        cn_chars_title = set(re.findall(r'[\u4e00-\u9fff]', title))

        overlapped_chars = cn_chars_keyword & cn_chars_title

        if len(overlapped_chars) > 1 and len(cn_chars_keyword) > 0:
            overlap_ratio = len(overlapped_chars) / len(cn_chars_keyword)
            if overlap_ratio >= 0.4 or len(overlapped_chars) >= 3:
                score += 2.0 * overlap_ratio
            # Very low overlap earns nothing, to keep unrelated posts out.

    # 8. Series detection: e.g. "功夫熊猫2" counts as a sequel of "功夫熊猫".
    base_title_match = re.match(r'(.*?)([0-9]+|[一二三四五六七八九十]|:|\:|\s+[0-9]+)', title)
    if base_title_match:
        base_title = base_title_match.group(1).strip()
        if keyword.lower() == base_title.lower():
            score += 2.0

    # 9. Tag matches (exact tag > partial tag).
    if tags:
        tags_list = tags.split(',')
        if keyword in tags_list:
            score += 1.5
        elif any(keyword.lower() in tag.lower() for tag in tags_list):
            score += 1.0

    # Description matches; an early occurrence is treated as more telling.
    if keyword.lower() in description.lower():
        score += 1.5

    pos = description.lower().find(keyword.lower())
    if pos >= 0 and pos < len(description) / 3:
        score += 0.5

    # Token overlap with the description (reuses the jieba cut from step 4).
    description_words = list(jieba.cut(description))
    matched_desc_words = 0
    for k_word in keyword_words:
        if len(k_word) > 1 and k_word in description_words:
            matched_desc_words += 1

    if len(keyword_words) > 0:
        desc_match_ratio = matched_desc_words / len(keyword_words)
        score += 1.0 * desc_match_ratio

    # Category match.
    if keyword.lower() in category.lower():
        score += 1.0

    # 10. Semantic associations: any expanded term found in the title.
    expanded_keywords = expand_search_keywords(keyword)
    for exp_keyword in expanded_keywords:
        if exp_keyword != keyword and exp_keyword in title:  # skip the original
            score += 1.5

    return score
| 348 | |
# Flask application instance.
app = Flask(__name__)
CORS(app)  # allow cross-origin requests from any frontend origin
| 352 | |
# Word2Vec bootstrap helper.
def init_word2vec():
    """Ensure the shared Word2Vec model is loaded, logging the outcome.

    Logs vocabulary size and vector dimension on success; any failure is
    logged and swallowed so startup never aborts because of the model.
    """
    try:
        helper = get_word2vec_helper()
        if helper.initialized:
            logger.info(f"Word2Vec模型已成功加载,词汇量: {len(helper.model.index_to_key)}, 向量维度: {helper.model.vector_size}")
        elif helper.load_model():
            logger.info(f"Word2Vec模型加载成功,词汇量: {len(helper.model.index_to_key)}, 向量维度: {helper.model.vector_size}")
        else:
            logger.error("Word2Vec模型加载失败")
    except Exception as e:
        logger.error(f"初始化Word2Vec出错: {e}")
| 367 | |
# Explicit startup hook.  Per the original author's note this replaces the
# before_first_request decorator pattern; it is invoked directly below.
def initialize_app():
    """One-time application initialisation.

    (Re)loads the semantic mapping table into the module-level
    SEMANTIC_MAPPINGS, and warms up the Word2Vec model when the optional
    module was imported successfully.
    """
    global SEMANTIC_MAPPINGS
    SEMANTIC_MAPPINGS = load_semantic_mappings()  # refresh the global map

    if WORD2VEC_ENABLED:
        init_word2vec()  # defined above; safe to call here

# Run initialisation at import time, before the server starts serving.
initialize_app()
| 382 | |
# Health-check route.
@app.route('/test', methods=['GET'])
def test():
    """Liveness probe: returns a fixed message plus the current timestamp."""
    import datetime
    now = datetime.datetime.now()
    return jsonify({"message": "服务器正常运行", "timestamp": str(now)})
| 388 | |
# API: fetch a single post's detail.
@app.route('/post/<int:post_id>', methods=['GET'])
def get_post_detail(post_id):
    """Return one post as JSON, looked up by primary key.

    Responses: 200 with the post dict, 404 when the id does not exist,
    500 on any DB error.  `tags`/`category`/`author` are hard-coded
    placeholder values because this query does not join those tables yet.
    """
    logger.info(f"接收到获取帖子详情请求,post_id: {post_id}")
    conn = get_db_conn()
    try:
        with conn.cursor(pymysql.cursors.DictCursor) as cursor:
            # Deliberately simple single-table query (kept easy to debug).
            query = """
                SELECT 
                    p.id,
                    p.title,
                    p.content,
                    p.heat,
                    p.created_at as create_time,
                    p.updated_at as last_active,
                    p.status
                FROM posts p
                WHERE p.id = %s
            """
            logger.info(f"执行查询: {query} with post_id: {post_id}")
            cursor.execute(query, (post_id,))
            post = cursor.fetchone()

            logger.info(f"查询结果: {post}")

            if not post:
                logger.warning(f"帖子不存在,post_id: {post_id}")
                return jsonify({"error": "帖子不存在"}), 404

            # Placeholder fields until the corresponding joins exist.
            post['tags'] = []
            post['category'] = '未分类'
            post['author'] = '匿名用户'

            # Render datetimes as strings for JSON serialisation.
            if post['create_time']:
                post['create_time'] = post['create_time'].strftime('%Y-%m-%d %H:%M:%S')
            if post['last_active']:
                post['last_active'] = post['last_active'].strftime('%Y-%m-%d %H:%M:%S')

            logger.info(f"返回帖子详情: {post}")
            return Response(json.dumps(post, ensure_ascii=False), mimetype='application/json; charset=utf-8')
    except Exception as e:
        logger.error(f"获取帖子详情失败: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": "服务器内部错误"}), 500
    finally:
        conn.close()
| 442 | |
# Search API.
@app.route('/search', methods=['POST'])
def search():
    """
    Search endpoint.

    Request JSON: {
        "keyword": str (required, 1-20 characters),
        "sort_by": "downloads" | "downloads_asc" | "newest" | "oldest" |
                   "similarity" | "title_asc" | "title_desc" (default "similarity"),
        "category": optional topic id filter,
        "search_mode": "title" | "title_desc" | "tags" | "all" (default "title"),
        "tags": optional list of tag names (reserved; tag data comes from joins)
    }

    Two phases: (1) SQL LIKE candidate retrieval over the keyword and its
    semantic expansions, (2) Python-side relevance scoring + sorting.

    Fixes vs. previous revision: "newest"/"oldest" now sort on the real
    "created_at" column (the old key "create_time" never exists on these
    rows, so those sorts were silent no-ops), and the download sorts key
    on "heat" (no "download_count" field is selected anywhere here).
    """
    if request.content_type != 'application/json':
        return jsonify({"error": "Content-Type must be application/json"}), 415

    data = request.get_json()
    keyword = data.get("keyword", "").strip()
    sort_by = data.get("sort_by", "similarity")  # relevance order by default
    category = data.get("category", None)
    search_mode = data.get("search_mode", "title")
    tags = data.get("tags", None)  # reserved; tag filtering handled via joins

    # A keyword is mandatory in every mode.
    if not (1 <= len(keyword) <= 20):
        return jsonify({"error": "请输入1-20个字符"}), 400

    # Phase 1: candidate retrieval.
    results = []
    conn = get_db_conn()
    try:
        with conn.cursor(pymysql.cursors.DictCursor) as cursor:
            # Exact-title matches are fetched separately so they can be
            # pinned to the top of the final ranking.
            exact_query = """
                SELECT id, title, topic_id, heat, created_at, content
                FROM posts 
                WHERE title = %s
            """
            cursor.execute(exact_query, (keyword,))
            exact_matches = cursor.fetchall() or []  # normalise to a list

            # Broaden the keyword with semantically related terms.
            expanded_keywords = expand_search_keywords(keyword)
            logger.info(f"扩展后的关键词: {expanded_keywords}")

            conditions = []
            params = []

            # Title LIKE for the keyword and every expansion (all modes).
            conditions.append("title LIKE %s")
            params.append(f"%{keyword}%")
            for exp_keyword in expanded_keywords:
                if exp_keyword != keyword:
                    conditions.append("title LIKE %s")
                    params.append(f"%{exp_keyword}%")

            # Content LIKE in the description-aware modes.
            if search_mode in ["title_desc", "all"]:
                conditions.append("content LIKE %s")
                params.append(f"%{keyword}%")
                for exp_keyword in expanded_keywords:
                    if exp_keyword != keyword:
                        conditions.append("content LIKE %s")
                        params.append(f"%{exp_keyword}%")

            # Tag matching is provided by the joins below, not by conditions.

            # Topic match only in "all" mode.
            if search_mode == "all":
                conditions.append("topic_id LIKE %s")
                params.append(f"%{keyword}%")
                for exp_keyword in expanded_keywords:
                    if exp_keyword != keyword:
                        conditions.append("topic_id LIKE %s")
                        params.append(f"%{exp_keyword}%")

            if conditions:
                where_clause = " OR ".join(conditions)
                logger.info(f"搜索条件: {where_clause}")
                logger.info(f"参数列表: {params}")

                if category:
                    where_clause = f"({where_clause}) AND topic_id=%s"
                    params.append(category)

                sql = f"""
                    SELECT p.id, p.title, tp.name as category, p.heat, p.created_at, p.content,
                           GROUP_CONCAT(t.name) as tags
                    FROM posts p
                    LEFT JOIN post_tags pt ON p.id = pt.post_id
                    LEFT JOIN tags t ON pt.tag_id = t.id
                    LEFT JOIN topics tp ON p.topic_id = tp.id
                    WHERE {where_clause}
                    GROUP BY p.id
                    LIMIT 500
                """

                cursor.execute(sql, params)
                expanded_results = cursor.fetchall()
                logger.info(f"数据库返回记录数: {len(expanded_results) if expanded_results else 0}")
            else:
                expanded_results = []

            # No LIKE hit at all: score the whole table instead so fuzzy
            # (pinyin / edit-distance) matches still have a chance.
            if not expanded_results and not exact_matches:
                sql = "SELECT p.id, p.title, tp.name as category, p.heat, p.created_at, p.content, GROUP_CONCAT(t.name) as tags FROM posts p LEFT JOIN post_tags pt ON p.id = pt.post_id LEFT JOIN tags t ON pt.tag_id = t.id LEFT JOIN topics tp ON p.topic_id = tp.id"
                if category:
                    sql += " WHERE p.topic_id=%s"
                    cursor.execute(sql + " GROUP BY p.id", [category])
                else:
                    cursor.execute(sql + " GROUP BY p.id")

                all_results = cursor.fetchall() or []
            else:
                all_results = list(expanded_results) + list(exact_matches)

            # Phase 2a: relevance scoring; the low 0.1 threshold keeps
            # recall high while discarding pure noise.
            scored_results = []
            for item in all_results:
                relevance_score = calculate_keyword_relevance(keyword, item)
                if relevance_score > 0.1:
                    item['relevance_score'] = relevance_score
                    scored_results.append(item)
                    logger.info(f"匹配项: {item['title']}, 相关性得分: {relevance_score}")

            scored_results.sort(key=lambda x: x.get('relevance_score', 0), reverse=True)

            # Pin exact title matches to the top with a sentinel score,
            # removing their duplicates from the scored list.
            if exact_matches:
                for exact_match in exact_matches:
                    exact_match['relevance_score'] = 20.0
                exact_ids = {item['id'] for item in exact_matches}
                scored_results = [item for item in scored_results if item['id'] not in exact_ids]
                results = exact_matches + scored_results
            else:
                results = scored_results

            # Cap the payload size.
            results = results[:50]

    except Exception as e:
        logger.error(f"搜索出错: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": "搜索系统异常,请稍后再试"}), 500
    finally:
        conn.close()

    # Phase 2b: optional re-sort by the requested criterion.
    if results:
        if sort_by == "similarity" or not sort_by:
            pass  # already ordered by relevance above
        elif sort_by == "downloads":
            # Fix: these rows carry "heat", not "download_count".
            results.sort(key=lambda x: x.get("heat", 0), reverse=True)
        elif sort_by == "downloads_asc":
            results.sort(key=lambda x: x.get("heat", 0))
        elif sort_by == "newest":
            # Fix: the selected column is "created_at" (datetime or NULL);
            # isoformat keys sort chronologically and tolerate NULLs.
            results.sort(key=lambda x: x["created_at"].isoformat() if x.get("created_at") else "", reverse=True)
        elif sort_by == "oldest":
            results.sort(key=lambda x: x["created_at"].isoformat() if x.get("created_at") else "")
        elif sort_by == "title_asc":
            results.sort(key=lambda x: x.get("title", ""))
        elif sort_by == "title_desc":
            results.sort(key=lambda x: x.get("title", ""), reverse=True)

    # Final pass: strip internal fields, stringify datetimes for JSON.
    for item in results:
        item.pop("description", None)
        item.pop("tags", None)
        item.pop("relevance_score", None)
        for k, v in item.items():
            if hasattr(v, 'isoformat'):
                item[k] = v.isoformat(sep=' ', timespec='seconds')

    return Response(json.dumps({"results": results}, ensure_ascii=False), mimetype='application/json; charset=utf-8')
| 643 | |
# Recommendation API.
@app.route('/recommend_tags', methods=['POST'])
def recommend_tags():
    """
    Tag-based recommendation endpoint.

    Request JSON: {
        "user_id": "user1",          # optional; pulls the user's saved tags
        "tags": ["标签1", "标签2"]    # optional extra tags from the client
    }

    Merges client-supplied tags with the user's stored interest tags, then
    returns posts linked to those tags; when no tag-linked post exists it
    falls back to a LIKE match over title/content.
    """
    if request.content_type != 'application/json':
        return jsonify({"error": "Content-Type must be application/json"}), 415

    data = request.get_json()
    user_id = data.get("user_id")
    tags = set(data.get("tags", []))

    # Stored interest tags for this user, if any.
    user_tags = set()
    if user_id:
        conn = get_db_conn()
        try:
            with conn.cursor() as cursor:
                cursor.execute("SELECT t.name FROM user_tags ut JOIN tags t ON ut.tag_id = t.id WHERE ut.user_id=%s", (user_id,))
                user_tags = set(row[0] for row in cursor.fetchall())
        finally:
            conn.close()

    # The union of client tags and stored tags drives the recommendation.
    all_tags = list(tags | user_tags)

    if not all_tags:
        return Response(json.dumps({"error": "暂无推荐结果"}, ensure_ascii=False), mimetype='application/json; charset=utf-8'), 200

    conn = get_db_conn()
    try:
        with conn.cursor(pymysql.cursors.DictCursor) as cursor:
            # Resolve tag names to ids first; unknown names are skipped.
            tag_ids = []
            for tag in all_tags:
                cursor.execute("SELECT id FROM tags WHERE name=%s", (tag,))
                row = cursor.fetchone()
                if row:
                    tag_ids.append(row['id'])
            if not tag_ids:
                return Response(json.dumps({"error": "暂无推荐结果"}, ensure_ascii=False), mimetype='application/json; charset=utf-8'), 200
            tag_placeholders = ','.join(['%s'] * len(tag_ids))
            sql = f"""
                SELECT p.id, p.title, tp.name as category, p.heat, 
                       GROUP_CONCAT(tg.name) as tags
                FROM posts p
                LEFT JOIN post_tags pt ON p.id = pt.post_id
                LEFT JOIN tags tg ON pt.tag_id = tg.id
                LEFT JOIN topics tp ON p.topic_id = tp.id
                WHERE pt.tag_id IN ({tag_placeholders})
                GROUP BY p.id
                LIMIT 50
            """
            cursor.execute(sql, tuple(tag_ids))
            results = cursor.fetchall()
            # Fallback: fuzzy title/content match when no tag-linked post.
            if not results:
                or_conditions = []
                params = []
                for tag in all_tags:
                    or_conditions.append("p.title LIKE %s OR p.content LIKE %s")
                    params.extend(['%' + tag + '%', '%' + tag + '%'])
                where_clause = ' OR '.join(or_conditions)
                sql = f"""
                    SELECT p.id, p.title, tp.name as category, p.heat, 
                           GROUP_CONCAT(tg.name) as tags
                    FROM posts p
                    LEFT JOIN post_tags pt ON p.id = pt.post_id
                    LEFT JOIN tags tg ON pt.tag_id = tg.id
                    LEFT JOIN topics tp ON p.topic_id = tp.id
                    WHERE {where_clause}
                    GROUP BY p.id
                    LIMIT 50
                """
                cursor.execute(sql, tuple(params))
                results = cursor.fetchall()
    finally:
        conn.close()

    if not results:
        return Response(json.dumps({"error": "暂无推荐结果"}, ensure_ascii=False), mimetype='application/json; charset=utf-8'), 200

    return Response(json.dumps({"recommendations": results}, ensure_ascii=False), mimetype='application/json; charset=utf-8')
| 733 | |
# User interest-tag management API.
@app.route('/tags', methods=['POST', 'GET', 'DELETE'])
def user_tags():
    """
    Manage a user's interest tags.

    POST   {"user_id": ..., "tags": [...] or "a,b"} — add tags (names not
           present in the tags table are silently skipped); returns the
           updated full tag list.
    GET    ?user_id=...                              — list the user's tags.
    DELETE {"user_id": ..., "tags": [...]}           — remove tags; returns
           the remaining tag list.
    """
    if request.method == 'POST':
        if request.content_type != 'application/json':
            return jsonify({"error": "Content-Type must be application/json"}), 415
        data = request.get_json()
        user_id = data.get("user_id")
        tags = data.get("tags", [])

        if not user_id:
            return jsonify({"error": "用户ID不能为空"}), 400

        # Accept a comma-separated string as well as a proper list.
        if isinstance(tags, str):
            tags = [tag.strip() for tag in tags.split(',') if tag.strip()]

        if not tags:
            return jsonify({"error": "标签不能为空"}), 400

        conn = get_db_conn()
        try:
            with conn.cursor() as cursor:
                # Insert each known tag; REPLACE makes the call idempotent.
                for tag in tags:
                    cursor.execute("SELECT id FROM tags WHERE name=%s", (tag,))
                    tag_row = cursor.fetchone()
                    if tag_row:
                        tag_id = tag_row[0]
                        cursor.execute("REPLACE INTO user_tags (user_id, tag_id) VALUES (%s, %s)", (user_id, tag_id))
                conn.commit()
                # Echo back the full, updated tag list.
                cursor.execute("SELECT t.name FROM user_tags ut JOIN tags t ON ut.tag_id = t.id WHERE ut.user_id=%s", (user_id,))
                updated_tags = [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()
        return Response(json.dumps({"msg": "添加成功", "tags": updated_tags}, ensure_ascii=False), mimetype='application/json; charset=utf-8')
    elif request.method == 'DELETE':
        if request.content_type != 'application/json':
            return jsonify({"error": "Content-Type must be application/json"}), 415
        data = request.get_json()
        user_id = data.get("user_id")
        tags = data.get("tags", [])
        if not user_id:
            return jsonify({"error": "用户ID不能为空"}), 400
        if not tags:
            return jsonify({"error": "标签不能为空"}), 400

        conn = get_db_conn()
        try:
            with conn.cursor() as cursor:
                # Remove only tags that resolve to a known tag id.
                for tag in tags:
                    cursor.execute("SELECT id FROM tags WHERE name=%s", (tag,))
                    tag_row = cursor.fetchone()
                    if tag_row:
                        tag_id = tag_row[0]
                        cursor.execute("DELETE FROM user_tags WHERE user_id=%s AND tag_id=%s", (user_id, tag_id))
                conn.commit()
                cursor.execute("SELECT t.name FROM user_tags ut JOIN tags t ON ut.tag_id = t.id WHERE ut.user_id=%s", (user_id,))
                remaining_tags = [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()
        return Response(json.dumps({"msg": "删除成功", "tags": remaining_tags}, ensure_ascii=False), mimetype='application/json; charset=utf-8')
    else:  # GET request
        user_id = request.args.get("user_id")
        if not user_id:
            return jsonify({"error": "用户ID不能为空"}), 400
        conn = get_db_conn()
        try:
            with conn.cursor() as cursor:
                cursor.execute("SELECT t.name FROM user_tags ut JOIN tags t ON ut.tag_id = t.id WHERE ut.user_id=%s", (user_id,))
                tags = [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()
        return Response(json.dumps({"tags": tags}, ensure_ascii=False), mimetype='application/json; charset=utf-8')
| 815 | |
# Alias: /user_tags exposes exactly the same behaviour as /tags.
@app.route('/user_tags', methods=['POST', 'GET', 'DELETE'])
def user_tags_alias():
    """
    Alias route for /tags (kept for frontend compatibility).
    POST: add interest tags; GET: list them; DELETE: remove them.
    Delegates directly to user_tags().
    """
    return user_tags()
| 826 | |
# User-based collaborative filtering recommendation API
def _popular_fallback(cursor, top_n):
    """Return a heat-ranked "popular posts" JSON response.

    Shared fallback used whenever collaborative filtering cannot produce a
    personalized result: cold-start user (< 3 behaviors), no recent peer
    activity, or no candidate posts after filtering.
    """
    cursor.execute("""
        SELECT p.id, p.title, tp.name as category, p.heat
        FROM posts p
        LEFT JOIN topics tp ON p.topic_id = tp.id
        ORDER BY p.heat DESC
        LIMIT %s
    """, (top_n,))
    popular_seeds = cursor.fetchall()
    return Response(json.dumps({"recommendations": popular_seeds, "type": "popular"}, ensure_ascii=False), mimetype='application/json; charset=utf-8')


@app.route('/user_based_recommend', methods=['POST'])
def user_based_recommend():
    """
    User-based collaborative filtering recommendation API.

    Request JSON: {
        "user_id": "user1",
        "top_n": 5
    }

    Returns {"recommendations": [...], "type": "collaborative"} on success,
    or {"recommendations": [...], "type": "popular"} when falling back to
    heat-based popular posts. Similarity between users is Jaccard overlap
    of their favorited/viewed post sets over the last 3 months.
    """
    if request.content_type != 'application/json':
        return jsonify({"error": "Content-Type must be application/json"}), 415

    data = request.get_json()
    user_id = data.get("user_id")
    # A malformed top_n is a client error (400), not a server crash (500).
    try:
        top_n = int(data.get("top_n", 5))
    except (TypeError, ValueError):
        return jsonify({"error": "top_n must be an integer"}), 400

    if not user_id:
        return jsonify({"error": "用户ID不能为空"}), 400

    conn = get_db_conn()
    try:
        with conn.cursor(pymysql.cursors.DictCursor) as cursor:
            # 1. Count the user's interaction records (favorites or views).
            cursor.execute("""
                SELECT COUNT(*) as count
                FROM behaviors
                WHERE user_id = %s AND type IN ('favorite', 'view')
            """, (user_id,))
            result = cursor.fetchone()
            user_download_count = result['count'] if result else 0

            logger.info(f"用户 {user_id} 下载记录数: {user_download_count}")

            # Cold start: too little behavior data for meaningful similarity.
            if user_download_count < 3:
                logger.info(f"用户 {user_id} 下载记录不足,返回热门推荐")
                return _popular_fallback(cursor, top_n)

            # 2. Posts this user has already interacted with (to exclude from
            #    recommendations and to compare against other users).
            cursor.execute("""
                SELECT post_id
                FROM behaviors
                WHERE user_id = %s AND type IN ('favorite', 'view')
            """, (user_id,))
            user_seeds = set(row['post_id'] for row in cursor.fetchall())
            logger.info(f"用户 {user_id} 已下载种子: {user_seeds}")

            # 3. Other users' interaction records from the last 3 months,
            #    forming the user-item matrix for similarity computation.
            cursor.execute("""
                SELECT user_id, post_id
                FROM behaviors
                WHERE created_at > DATE_SUB(NOW(), INTERVAL 3 MONTH)
                AND user_id <> %s AND type IN ('favorite', 'view')
            """, (user_id,))
            download_records = cursor.fetchall()

            if not download_records:
                logger.info(f"没有其他用户的下载记录,返回热门推荐")
                return _popular_fallback(cursor, top_n)

            # Build user -> {post ids} mapping.
            user_item_matrix = {}
            for record in download_records:
                uid = record['user_id']
                sid = record['post_id']
                if uid not in user_item_matrix:
                    user_item_matrix[uid] = set()
                user_item_matrix[uid].add(sid)

            # 4. Jaccard similarity between the target user and every other
            #    user; keep the top 5 most similar users.
            similar_users = []
            for other_id, other_seeds in user_item_matrix.items():
                if other_id == user_id:
                    continue
                intersection = len(user_seeds.intersection(other_seeds))
                union = len(user_seeds.union(other_seeds))
                if union > 0 and intersection > 0:
                    similarity = intersection / union
                    similar_users.append((other_id, similarity, other_seeds))
            logger.info(f"找到 {len(similar_users)} 个相似用户")
            similar_users.sort(key=lambda x: x[1], reverse=True)
            similar_users = similar_users[:5]

            # 5. Score candidate posts by summing the similarity of every
            #    neighbor who interacted with them (excluding posts the
            #    target user already has).
            candidate_seeds = {}
            for similar_user, similarity, seeds in similar_users:
                logger.info(f"相似用户 {similar_user}, 相似度 {similarity}")
                for post_id in seeds:
                    if post_id not in user_seeds:
                        if post_id not in candidate_seeds:
                            candidate_seeds[post_id] = 0
                        candidate_seeds[post_id] += similarity
            if not candidate_seeds:
                logger.info(f"没有找到候选种子,返回热门推荐")
                return _popular_fallback(cursor, top_n)

            # 6. Fetch details for the top-N candidates and return them in
            #    descending score order.
            recommended_seeds = sorted(candidate_seeds.items(), key=lambda x: x[1], reverse=True)[:top_n]
            post_ids = [post_id for post_id, _ in recommended_seeds]
            format_strings = ','.join(['%s'] * len(post_ids))
            cursor.execute(f"""
                SELECT p.id, p.title, tp.name as category, p.heat
                FROM posts p
                LEFT JOIN topics tp ON p.topic_id = tp.id
                WHERE p.id IN ({format_strings})
            """, tuple(post_ids))
            result_seeds = cursor.fetchall()
            seed_score_map = {post_id: score for post_id, score in recommended_seeds}
            result_seeds.sort(key=lambda x: seed_score_map.get(x['id'], 0), reverse=True)
            logger.info(f"返回 {len(result_seeds)} 个基于协同过滤的推荐")
            return Response(json.dumps({"recommendations": result_seeds, "type": "collaborative"}, ensure_ascii=False), mimetype='application/json; charset=utf-8')
    except Exception as e:
        logger.error(f"推荐系统错误: {e}")
        import traceback
        traceback.print_exc()
        return Response(json.dumps({"error": "推荐系统异常,请稍后再试", "details": str(e)}, ensure_ascii=False), mimetype='application/json; charset=utf-8')
    finally:
        conn.close()
@app.route('/word2vec_status', methods=['GET'])
def word2vec_status():
    """
    Report Word2Vec model status.

    Returns whether the module is enabled, whether the helper is
    initialized, vocabulary and vector sizes, and sample similar-word
    lookups for a few common terms to demonstrate model quality.
    """
    def _json(payload):
        # All responses from this endpoint share the same serialization.
        return Response(json.dumps(payload, ensure_ascii=False),
                        mimetype='application/json; charset=utf-8')

    if not WORD2VEC_ENABLED:
        return _json({
            "enabled": False,
            "message": "Word2Vec功能未启用"
        })

    try:
        helper = get_word2vec_helper()
        model = helper.model
        status = {
            "enabled": WORD2VEC_ENABLED,
            "initialized": helper.initialized,
            "vocab_size": len(model.index_to_key) if model else 0,
            "vector_size": model.vector_size if model else 0,
        }
        # Probe a handful of common query words to show the model's effect.
        status["test_results"] = {
            word: helper.get_similar_words(word, topn=5)
            for word in ("电影", "动作", "科幻", "动漫", "游戏")
        }
        return _json(status)
    except Exception as e:
        return _json({
            "enabled": WORD2VEC_ENABLED,
            "initialized": False,
            "error": str(e)
        })
| 1003 | |
| 1004 | # 添加一个临时诊断端点 |
@app.route('/debug_search', methods=['POST'])
def debug_search():
    """Temporary debug endpoint for inspecting pt_seed records.

    Request JSON: {"keyword": "..."}. Runs several probe queries (title
    LIKE, description LIKE, FIND_IN_SET on tags, plus one fixed-title
    lookup) and returns their results keyed by a query label.

    SECURITY FIX: the keyword was previously interpolated directly into
    the SQL text (f-string), allowing SQL injection from an untrusted
    POST field. All dynamic values now go through pymysql placeholders.
    """
    if request.content_type != 'application/json':
        return jsonify({"error": "Content-Type must be application/json"}), 415

    data = request.get_json()
    keyword = data.get("keyword", "").strip()

    conn = get_db_conn()
    try:
        with conn.cursor(pymysql.cursors.DictCursor) as cursor:
            # (label, parameterized SQL, bind parameters) triples.
            queries = [
                ("标题中包含关键词",
                 "SELECT seed_id, title, description, tags FROM pt_seed WHERE title LIKE %s LIMIT 10",
                 (f"%{keyword}%",)),
                ("描述中包含关键词",
                 "SELECT seed_id, title, description, tags FROM pt_seed WHERE description LIKE %s LIMIT 10",
                 (f"%{keyword}%",)),
                ("标签中包含关键词",
                 "SELECT seed_id, title, description, tags FROM pt_seed WHERE FIND_IN_SET(%s, tags) LIMIT 10",
                 (keyword,)),
                ("肖申克的救赎",
                 "SELECT seed_id, title, description, tags FROM pt_seed WHERE title = '肖申克的救赎'",
                 None)
            ]

            results = {}
            for query_name, query, params in queries:
                cursor.execute(query, params)
                results[query_name] = cursor.fetchall()

            return Response(json.dumps(results, ensure_ascii=False), mimetype='application/json; charset=utf-8')
    finally:
        conn.close()
| 1033 | |
| 1034 | """ |
| 1035 | 接口本地测试方法(可直接运行main_online.py后用curl或Postman测试): |
| 1036 | |
| 1037 | 1. 搜索接口 |
| 1038 | curl -X POST http://127.0.0.1:5000/search -H "Content-Type: application/json" -d '{"keyword":"电影","sort_by":"downloads"}' |
| 1039 | |
| 1040 | 2. 标签推荐接口 |
| 1041 | curl -X POST http://127.0.0.1:5000/recommend_tags -H "Content-Type: application/json" -d '{"user_id":"1","tags":["动作","科幻"]}' |
| 1042 | |
| 1043 | 3. 用户兴趣标签管理(添加标签) |
| 1044 | curl -X POST http://127.0.0.1:5000/user_tags -H "Content-Type: application/json" -d '{"user_id":"1","tags":["动作","科幻"]}' |
| 1045 | |
| 1046 | 4. 用户兴趣标签管理(查询标签) |
| 1047 | curl "http://127.0.0.1:5000/user_tags?user_id=1" |
| 1048 | |
| 1049 | 5. 用户兴趣标签管理(删除标签) |
| 1050 | curl -X DELETE http://127.0.0.1:5000/user_tags -H "Content-Type: application/json" -d '{"user_id":"1","tags":["动作","科幻"]}' |
| 1051 | |
| 1052 | 6. 协同过滤推荐 |
| 1053 | curl -X POST http://127.0.0.1:5000/user_based_recommend -H "Content-Type: application/json" -d '{"user_id":"user1","top_n":3}' |
| 1054 | |
| 1055 | 7. Word2Vec状态检查 |
| 1056 | curl "http://127.0.0.1:5000/word2vec_status" |
| 1057 | |
| 1058 | 8. 调试接口(临时) |
| 1059 | curl -X POST http://127.0.0.1:5000/debug_search -H "Content-Type: application/json" -d '{"keyword":"电影"}' |
| 1060 | |
| 1061 | 所有接口均可用Postman按上述参数测试。 |
| 1062 | """ |
| 1063 | |
if __name__ == "__main__":
    # Entry point: run the Flask development server for the search and
    # recommendation service on all interfaces, port 5000.
    try:
        logger.info("搜索推荐服务启动中...")
        app.run(host="0.0.0.0", port=5000)
    except Exception as exc:
        import traceback
        logger.error(f"启动异常: {exc}")
        traceback.print_exc()