# word2vec_helper.py
# Helper module for loading and using a Word2Vec model

import os
import numpy as np
from gensim.models import KeyedVectors, Word2Vec
import jieba
import logging
import time

# Configure logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

class Word2VecHelper:
    def __init__(self, model_path=None):
        """
        Initialize the Word2Vec helper.

        Args:
            model_path: Path to a pretrained model, in word2vec binary or text format.
                        If None, a default path is used, or a small fallback
                        model is downloaded as a last resort.
        """
        self.model = None

        # Resolve the default model path, with fallbacks
        if model_path:
            self.model_path = model_path
        else:
            # Preferred path: the large Tencent model
            primary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                        "models", "chinese_word2vec.bin")

            # Fallback path: a small model
            backup_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                       "models", "chinese_word2vec_small.bin")

            if os.path.exists(primary_path):
                self.model_path = primary_path
            elif os.path.exists(backup_path):
                self.model_path = backup_path
            else:
                # Neither exists; try to download a small fallback model
                self.model_path = primary_path
                self._try_download_small_model()

        self.initialized = False
        # Cache query results to improve performance (unbounded; fine for
        # short-lived processes)
        self.similarity_cache = {}
        self.similar_words_cache = {}
    def _try_download_small_model(self):
        """Try to download a small word-vector model as a fallback."""
        try:
            import gensim.downloader as api
            logging.info("Attempting to download a small fallback word-vector model...")

            # Create the model directory
            os.makedirs(os.path.dirname(self.model_path), exist_ok=True)

            # Download a small fastText model via the gensim downloader.
            # Note: fasttext-wiki-news-subwords-300 is an English model, so
            # it is only a last-resort fallback; a Chinese model is preferable.
            small_model = api.load("fasttext-wiki-news-subwords-300")
            small_path = self.model_path.replace(".bin", "_small.bin")
            # Save in binary word2vec format so load_model() can read the .bin
            # file, and point model_path at the file that was actually saved
            small_model.save_word2vec_format(small_path, binary=True)
            self.model_path = small_path
            logging.info(f"Small fallback model downloaded and saved to {small_path}")
        except Exception as e:
            logging.error(f"Failed to download fallback model: {e}")
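
    # A minimal sketch of an alternative fallback, assuming a plain-text
    # word2vec file (e.g. the Tencent AI Lab embeddings) has been downloaded
    # manually; `limit` caps how many vectors are read to bound memory use:
    #
    #   vectors = KeyedVectors.load_word2vec_format(
    #       "models/tencent_embedding.txt", binary=False, limit=500_000)
    #   vectors.save_word2vec_format("models/chinese_word2vec.bin", binary=True)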

    def load_model(self):
        """Load the Word2Vec model."""
        try:
            start_time = time.time()
            logging.info(f"Loading Word2Vec model: {self.model_path}")

            # Pick a loader based on the file extension
            if self.model_path.endswith('.bin'):
                # Binary word2vec format
                self.model = KeyedVectors.load_word2vec_format(self.model_path, binary=True)
            elif self.model_path.endswith(('.txt', '.vec')):
                # Plain-text word2vec format
                self.model = KeyedVectors.load_word2vec_format(self.model_path, binary=False)
            else:
                # gensim-native Word2Vec model
                self.model = Word2Vec.load(self.model_path).wv

            self.initialized = True
            logging.info(f"Word2Vec model loaded in {time.time() - start_time:.2f} seconds")
            logging.info(f"Vector dimensions: {self.model.vector_size}")
            logging.info(f"Vocabulary size: {len(self.model.index_to_key)}")
            return True
        except Exception as e:
            logging.error(f"Failed to load Word2Vec model: {e}")
            self.initialized = False
            return False
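
    # Note: for very large models, gensim can also memory-map a natively
    # saved KeyedVectors file (KeyedVectors.load(path, mmap='r')), which
    # reduces per-process memory when several workers share one model file.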

    def ensure_initialized(self):
        """Ensure the model is initialized, loading it on first use."""
        if not self.initialized:
            return self.load_model()
        return True

    def get_similar_words(self, word, topn=10, min_similarity=0.5):
        """
        Get the words most similar to a given word.

        Args:
            word: Input word.
            topn: Number of similar words to return.
            min_similarity: Minimum similarity threshold.
        Returns:
            List of similar words; empty if the word is unknown or the
            model failed to load.
        """
        if not self.ensure_initialized():
            return []

        # Check the cache
        cache_key = f"{word}_{topn}_{min_similarity}"
        if cache_key in self.similar_words_cache:
            return self.similar_words_cache[cache_key]

        try:
            # If the word is not in the vocabulary, fall back to segmentation
            if word not in self.model.key_to_index:
                # Segment the Chinese word, then search on its sub-words
                word_parts = list(jieba.cut(word))

                if not word_parts:
                    return []

                # Keep only the sub-words that exist in the model
                valid_parts = [w for w in word_parts if w in self.model.key_to_index]

                if not valid_parts:
                    return []

                # Use the longest valid sub-word (ties keep original order)
                valid_parts.sort(key=len, reverse=True)
                word = valid_parts[0]

            # Fetch extra candidates to leave room for threshold filtering
            similar_words = self.model.most_similar(word, topn=topn * 2)

            # Drop results below the threshold and keep only the words
            # (not their similarity scores)
            filtered_words = [w for w, sim in similar_words if sim >= min_similarity][:topn]

            # Cache the result
            self.similar_words_cache[cache_key] = filtered_words
            return filtered_words

        except Exception as e:
            logging.error(f"Failed to get similar words: {e}, word: {word}")
            return []

    def calculate_similarity(self, word1, word2):
        """
        Calculate the similarity between two words.

        Args:
            word1, word2: Input words.
        Returns:
            Cosine similarity score (in [-1, 1]); 0 if either word is
            unknown or the model failed to load.
        """
        if not self.ensure_initialized():
            return 0

        # Check the cache (similarity is symmetric, so try both key orders)
        cache_key = f"{word1}_{word2}"
        reverse_key = f"{word2}_{word1}"

        if cache_key in self.similarity_cache:
            return self.similarity_cache[cache_key]
        if reverse_key in self.similarity_cache:
            return self.similarity_cache[reverse_key]

        try:
            # Make sure both words are in the vocabulary
            if word1 not in self.model.key_to_index or word2 not in self.model.key_to_index:
                return 0

            similarity = self.model.similarity(word1, word2)

            # Cache the result
            self.similarity_cache[cache_key] = similarity
            return similarity

        except Exception as e:
            logging.error(f"Failed to calculate similarity: {e}, words: {word1}, {word2}")
            return 0

    def expand_query(self, query, topn=5, min_similarity=0.6):
        """
        Expand a query with related words.

        Args:
            query: Query string.
            topn: Number of similar words to add per query term.
            min_similarity: Minimum similarity threshold.
        Returns:
            List of expanded terms, with the original query first.
        """
        if not self.ensure_initialized():
            return [query]

        expanded_terms = [query]

        # Segment the query
        words = list(jieba.cut(query))

        # Find similar words for each term
        for word in words:
            if len(word) <= 1:  # Skip single characters to reduce noise
                continue

            similar_words = self.get_similar_words(word, topn=topn, min_similarity=min_similarity)
            expanded_terms.extend(similar_words)

        # Deduplicate while preserving order, keeping the original query first
        return list(dict.fromkeys(expanded_terms))

# Singleton: share one model instance across the process
_word2vec_helper = None

def get_word2vec_helper(model_path=None):
    """Get the global Word2VecHelper singleton."""
    global _word2vec_helper
    if _word2vec_helper is None:
        _word2vec_helper = Word2VecHelper(model_path)
        _word2vec_helper.ensure_initialized()
    return _word2vec_helper
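
# Note: the lazy singleton above is not thread-safe. A minimal sketch of a
# guarded variant, assuming only the standard-library `threading` module:
#
#   import threading
#   _helper_lock = threading.Lock()
#
#   def get_word2vec_helper(model_path=None):
#       global _word2vec_helper
#       with _helper_lock:
#           if _word2vec_helper is None:
#               _word2vec_helper = Word2VecHelper(model_path)
#               _word2vec_helper.ensure_initialized()
#       return _word2vec_helper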

# Convenience functions for direct calls
def get_similar_words(word, topn=10, min_similarity=0.5):
    """Convenience wrapper for getting similar words."""
    helper = get_word2vec_helper()
    return helper.get_similar_words(word, topn, min_similarity)

def calculate_similarity(word1, word2):
    """Convenience wrapper for calculating word similarity."""
    helper = get_word2vec_helper()
    return helper.calculate_similarity(word1, word2)

def expand_query(query, topn=5, min_similarity=0.6):
    """Convenience wrapper for query expansion."""
    helper = get_word2vec_helper()
    return helper.expand_query(query, topn, min_similarity)
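
# A minimal sketch of sentence-level similarity built on these helpers,
# assuming each sentence contains at least one in-vocabulary token
# (gensim's KeyedVectors.n_similarity compares the two mean vectors):
#
#   def sentence_similarity(s1, s2):
#       helper = get_word2vec_helper()
#       t1 = [w for w in jieba.cut(s1) if w in helper.model.key_to_index]
#       t2 = [w for w in jieba.cut(s2) if w in helper.model.key_to_index]
#       if not t1 or not t2:
#           return 0.0
#       return float(helper.model.n_similarity(t1, t2))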

# Usage example
if __name__ == "__main__":
    # Test model loading and word similarity
    helper = get_word2vec_helper()

    # Test words
    test_words = ["电影", "功夫", "熊猫", "科幻", "漫威"]

    for word in test_words:
        print(f"\nWords similar to {word}:")
        similar = helper.get_similar_words(word, topn=5)
        for sim_word in similar:
            print(f"  - {sim_word}")

    # Test similarity calculation
    word_pairs = [
        ("电影", "电视"),
        ("功夫", "武术"),
        ("科幻", "未来"),
        ("漫威", "超级英雄")
    ]

    print("\nWord similarities:")
    for w1, w2 in word_pairs:
        sim = helper.calculate_similarity(w1, w2)
        print(f"  {w1} <-> {w2}: {sim:.4f}")

    # Test query expansion
    test_queries = ["功夫熊猫", "科幻电影", "漫威英雄"]

    print("\nQuery expansion:")
    for query in test_queries:
        expanded = helper.expand_query(query)
        print(f"  {query} -> {expanded}")
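
    # A short extra demo of word-vector arithmetic via gensim's most_similar
    # analogy interface; these particular words are illustrative and may be
    # missing from a small fallback model, hence the guard:
    if helper.initialized:
        try:
            analogies = helper.model.most_similar(
                positive=["国王", "女人"], negative=["男人"], topn=3)
            print("\nAnalogy 国王 - 男人 + 女人:")
            for word, score in analogies:
                print(f"  {word}: {score:.4f}")
        except KeyError as e:
            print(f"\nAnalogy demo skipped, word not in vocabulary: {e}")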