# word2vec_helper.py
# Helper module for loading and using a Word2Vec model

import os
from gensim.models import KeyedVectors, Word2Vec
import jieba
import logging
import time

# Configure logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

class Word2VecHelper:
    def __init__(self, model_path=None):
        """
        Initialize the Word2Vec helper.

        Args:
            model_path: Path to a pretrained model, in word2vec binary or text
                format, or gensim's native format. If None, a default path is
                used, with a fallback to downloading a small model.
        """
        self.model = None

        # Resolve the model path, falling back through the available options
        if model_path:
            self.model_path = model_path
        else:
            # Preferred path: the large Tencent Chinese model
            primary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                        "models", "chinese_word2vec.bin")

            # Fallback path: a small model
            backup_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                       "models", "chinese_word2vec_small.bin")

            if os.path.exists(primary_path):
                self.model_path = primary_path
            elif os.path.exists(backup_path):
                self.model_path = backup_path
            else:
                # Neither file exists; point at the backup path and try to
                # download a small fallback model there, so load_model()
                # actually finds what was downloaded.
                self.model_path = backup_path
                self._try_download_small_model()

        self.initialized = False
        # Cache query results to improve performance
        self.similarity_cache = {}
        self.similar_words_cache = {}

    def _try_download_small_model(self):
        """Try to download a small word-vector model as a fallback."""
        try:
            import gensim.downloader as api
            logging.info("Trying to download a small fallback word-vector model...")

            # Create the model directory
            os.makedirs(os.path.dirname(self.model_path), exist_ok=True)

            # Download a small fastText model via the gensim downloader.
            # Note: "fasttext-wiki-news-subwords-300" is an English model,
            # so it is only a weak stand-in for a proper Chinese embedding.
            small_model = api.load("fasttext-wiki-news-subwords-300")
            # Save in word2vec binary format so load_model() can read it back
            small_model.save_word2vec_format(self.model_path, binary=True)
            logging.info(f"Fallback model downloaded and saved to {self.model_path}")
        except Exception as e:
            logging.error(f"Failed to download fallback model: {e}")

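    # Note: for Chinese text, a dedicated Chinese embedding (e.g. the Tencent
    # model referenced above, placed at models/chinese_word2vec.bin) will give
    # far better results than the English fastText fallback downloaded here.
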
    def load_model(self):
        """Load the Word2Vec model."""
        try:
            start_time = time.time()
            logging.info(f"Loading Word2Vec model: {self.model_path}")

            # Choose the loading routine based on the file extension
            if self.model_path.endswith('.bin'):
                # word2vec binary format
                self.model = KeyedVectors.load_word2vec_format(self.model_path, binary=True)
            elif self.model_path.endswith(('.txt', '.vec')):
                # word2vec text format
                self.model = KeyedVectors.load_word2vec_format(self.model_path, binary=False)
            else:
                # gensim native format; keep only the word vectors
                self.model = Word2Vec.load(self.model_path).wv

            self.initialized = True
            logging.info(f"Word2Vec model loaded in {time.time() - start_time:.2f} s")
            logging.info(f"Vector dimensionality: {self.model.vector_size}")
            logging.info(f"Vocabulary size: {len(self.model.index_to_key)}")
            return True
        except Exception as e:
            logging.error(f"Failed to load Word2Vec model: {e}")
            self.initialized = False
            return False

    def ensure_initialized(self):
        """Make sure the model has been loaded."""
        if not self.initialized:
            return self.load_model()
        return True

    def get_similar_words(self, word, topn=10, min_similarity=0.5):
        """
        Return the words most similar to the given word.

        Args:
            word: Input word.
            topn: Number of similar words to return.
            min_similarity: Minimum similarity threshold.
        Returns:
            List of similar words; empty if the word is unknown or the
            model is not loaded.
        """
        if not self.ensure_initialized():
            return []

        # Check the cache first
        cache_key = f"{word}_{topn}_{min_similarity}"
        if cache_key in self.similar_words_cache:
            return self.similar_words_cache[cache_key]

        try:
            # If the word is out of vocabulary, fall back to segmentation
            if word not in self.model.key_to_index:
                # Segment the Chinese word and look up its parts instead
                word_parts = list(jieba.cut(word))

                if not word_parts:
                    return []

                # Keep only the parts that exist in the model's vocabulary
                valid_parts = [w for w in word_parts if w in self.model.key_to_index]

                if not valid_parts:
                    return []

                # Prefer the longest valid part
                valid_parts.sort(key=len, reverse=True)
                word = valid_parts[0]

                # If the substitute word is still out of vocabulary, give up
                if word not in self.model.key_to_index:
                    return []

            # Fetch extra candidates so there is room to filter below
            similar_words = self.model.most_similar(word, topn=topn*2)

            # Drop results below the threshold and keep only the words
            filtered_words = [w for w, sim in similar_words if sim >= min_similarity][:topn]

            # Cache the result
            self.similar_words_cache[cache_key] = filtered_words
            return filtered_words

        except Exception as e:
            logging.error(f"Failed to get similar words: {e}, word: {word}")
            return []

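    # Illustrative walk-through of the out-of-vocabulary fallback above
    # (hypothetical vocabulary; actual behaviour depends on the loaded model):
    # a compound like "功夫熊猫" is unlikely to be a single vocabulary entry,
    # so jieba splits it into ["功夫", "熊猫"]; the in-vocabulary parts are
    # sorted by length and most_similar() runs on the longest one rather
    # than on the original compound.
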
    def calculate_similarity(self, word1, word2):
        """
        Compute the similarity between two words.

        Args:
            word1, word2: Input words.
        Returns:
            Cosine similarity score (higher means more similar);
            0 if either word is unknown or the model is not loaded.
        """
        if not self.ensure_initialized():
            return 0

        # Check the cache (similarity is symmetric, so try both key orders)
        cache_key = f"{word1}_{word2}"
        reverse_key = f"{word2}_{word1}"

        if cache_key in self.similarity_cache:
            return self.similarity_cache[cache_key]
        if reverse_key in self.similarity_cache:
            return self.similarity_cache[reverse_key]

        try:
            # Both words must be in the vocabulary
            if word1 not in self.model.key_to_index or word2 not in self.model.key_to_index:
                return 0

            # Cast to a plain float so callers never see numpy scalars
            similarity = float(self.model.similarity(word1, word2))

            # Cache the result
            self.similarity_cache[cache_key] = similarity
            return similarity

        except Exception as e:
            logging.error(f"Failed to compute similarity: {e}, words: {word1}, {word2}")
            return 0

    def expand_query(self, query, topn=5, min_similarity=0.6):
        """
        Expand a query with related words.

        Args:
            query: Query string.
            topn: Number of similar words to add per token.
            min_similarity: Minimum similarity threshold.
        Returns:
            List of expanded terms, starting with the original query.
        """
        if not self.ensure_initialized():
            return [query]

        expanded_terms = [query]

        # Segment the query into words
        words = list(jieba.cut(query))

        # Collect similar words for each token
        for word in words:
            if len(word) <= 1:  # skip single characters to reduce noise
                continue

            similar_words = self.get_similar_words(word, topn=topn, min_similarity=min_similarity)
            expanded_terms.extend(similar_words)

        # Deduplicate while preserving order, so the original query stays first
        return list(dict.fromkeys(expanded_terms))

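    # Example of the expansion pipeline (hypothetical output; the actual
    # expansions depend on the loaded vectors): for the query "科幻电影",
    # jieba yields ["科幻", "电影"], each multi-character token is expanded
    # through get_similar_words(), and the result might look like
    # ["科幻电影", "奇幻", "影片", "大片"].
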
# Singleton: share one model instance across the application
_word2vec_helper = None

def get_word2vec_helper(model_path=None):
    """Return the global Word2VecHelper singleton."""
    global _word2vec_helper
    if _word2vec_helper is None:
        _word2vec_helper = Word2VecHelper(model_path)
        _word2vec_helper.ensure_initialized()
    return _word2vec_helper

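# Usage note: model_path is only honoured on the first call; later calls
# return the already-built singleton unchanged. To use a specific model file
# (the path below is illustrative), call this before any other helper:
#   helper = get_word2vec_helper("models/my_vectors.bin")
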
# Convenience functions for direct calls
def get_similar_words(word, topn=10, min_similarity=0.5):
    """Convenience wrapper for getting similar words."""
    helper = get_word2vec_helper()
    return helper.get_similar_words(word, topn, min_similarity)

def calculate_similarity(word1, word2):
    """Convenience wrapper for computing similarity."""
    helper = get_word2vec_helper()
    return helper.calculate_similarity(word1, word2)

def expand_query(query, topn=5, min_similarity=0.6):
    """Convenience wrapper for query expansion."""
    helper = get_word2vec_helper()
    return helper.expand_query(query, topn, min_similarity)

# Usage example
if __name__ == "__main__":
    # Test model loading and word similarity
    helper = get_word2vec_helper()

    # Test words
    test_words = ["电影", "功夫", "熊猫", "科幻", "漫威"]

    for word in test_words:
        print(f"\nWords similar to {word}:")
        similar = helper.get_similar_words(word, topn=5)
        for sim_word in similar:
            print(f" - {sim_word}")

    # Test similarity computation
    word_pairs = [
        ("电影", "电视"),
        ("功夫", "武术"),
        ("科幻", "未来"),
        ("漫威", "超级英雄")
    ]

    print("\nWord similarity:")
    for w1, w2 in word_pairs:
        sim = helper.calculate_similarity(w1, w2)
        print(f" {w1} <-> {w2}: {sim:.4f}")

    # Test query expansion
    test_queries = ["功夫熊猫", "科幻电影", "漫威英雄"]

    print("\nQuery expansion:")
    for query in test_queries:
        expanded = helper.expand_query(query)
        print(f" {query} -> {expanded}")
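
    # The module-level wrappers defined above can be used without touching
    # the helper instance directly; they delegate to the same singleton:
    print("\nConvenience wrappers:")
    print(get_similar_words("电影", topn=3))
    print(f"{calculate_similarity('电影', '电视'):.4f}")
    print(expand_query("功夫熊猫", topn=3))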