# word2vec_helper.py
# Helper module for loading and using Word2Vec models
import os
import numpy as np
from gensim.models import KeyedVectors, Word2Vec
import jieba
import logging
import time
# Configure logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
class Word2VecHelper:
    def __init__(self, model_path=None):
        """
        Initialize the Word2Vec helper.

        Args:
            model_path: Path to a pretrained model, in word2vec text or binary format.
                        If None, a default path is used, or a small fallback model
                        is downloaded.
        """
        self.model = None
        # Resolve the default model path, with a fallback option
        if model_path:
            self.model_path = model_path
        else:
            # Preferred path: the large Tencent model
            primary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                        "models", "chinese_word2vec.bin")
            # Fallback path: a small model
            backup_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                       "models", "chinese_word2vec_small.bin")
            if os.path.exists(primary_path):
                self.model_path = primary_path
            elif os.path.exists(backup_path):
                self.model_path = backup_path
            else:
                # Neither exists; try to download a small fallback model
                # (_try_download_small_model updates self.model_path on success)
                self.model_path = primary_path
                self._try_download_small_model()
        self.initialized = False
        # Cache query results to improve performance
        self.similarity_cache = {}
        self.similar_words_cache = {}

    def _try_download_small_model(self):
        """Try to download a small word-vector model as a fallback."""
        try:
            import gensim.downloader as api
            # Note: gensim's downloader ships no small Chinese set, so this
            # falls back to the English fastText vectors as a last resort.
            logging.info("Attempting to download a small fallback word-vector model...")
            # Create the model directory
            os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
            save_path = self.model_path.replace(".bin", "_small.bin")
            small_model = api.load("fasttext-wiki-news-subwords-300")
            # Save in word2vec binary format so load_model() can read it back
            small_model.save_word2vec_format(save_path, binary=True)
            self.model_path = save_path
            logging.info(f"Fallback model downloaded and saved to {save_path}")
        except Exception as e:
            logging.error(f"Could not download fallback model: {e}")

    def load_model(self):
        """Load the Word2Vec model."""
        try:
            start_time = time.time()
            logging.info(f"Loading Word2Vec model: {self.model_path}")
            # Pick a loader based on the file extension
            if self.model_path.endswith('.bin'):
                # Binary word2vec format
                self.model = KeyedVectors.load_word2vec_format(self.model_path, binary=True)
            else:
                # gensim-native Word2Vec model (keep only the word vectors)
                self.model = Word2Vec.load(self.model_path).wv
            self.initialized = True
            logging.info(f"Word2Vec model loaded in {time.time() - start_time:.2f} s")
            logging.info(f"Vector dimensionality: {self.model.vector_size}")
            logging.info(f"Vocabulary size: {len(self.model.index_to_key)}")
            return True
        except Exception as e:
            logging.error(f"Failed to load Word2Vec model: {e}")
            self.initialized = False
            return False

    def ensure_initialized(self):
        """Make sure the model has been loaded."""
        if not self.initialized:
            return self.load_model()
        return True

    def get_similar_words(self, word, topn=10, min_similarity=0.5):
        """
        Return the words most similar to the given word.

        Args:
            word: Input word.
            topn: Number of similar words to return.
            min_similarity: Minimum similarity threshold.

        Returns:
            A list of similar words; empty if the word is unknown or the
            model is not loaded.
        """
        if not self.ensure_initialized():
            return []
        # Check the cache first
        cache_key = f"{word}_{topn}_{min_similarity}"
        if cache_key in self.similar_words_cache:
            return self.similar_words_cache[cache_key]
        try:
            # If the word is out of vocabulary, fall back to segmentation
            if word not in self.model.key_to_index:
                # Segment the Chinese word and look up its sub-words
                word_parts = list(jieba.cut(word))
                if not word_parts:
                    return []
                # Keep only the sub-words that exist in the model
                valid_parts = [w for w in word_parts if w in self.model.key_to_index]
                if not valid_parts:
                    return []
                # Prefer the longest valid sub-word
                valid_parts.sort(key=len, reverse=True)
                word = valid_parts[0]
            # Fetch extra candidates so filtering can still fill topn
            similar_words = self.model.most_similar(word, topn=topn * 2)
            # Drop results below the threshold and keep only the words
            filtered_words = [w for w, sim in similar_words if sim >= min_similarity][:topn]
            # Cache the result
            self.similar_words_cache[cache_key] = filtered_words
            return filtered_words
        except Exception as e:
            logging.error(f"Failed to get similar words: {e}, word: {word}")
            return []

    def calculate_similarity(self, word1, word2):
        """
        Compute the similarity between two words.

        Args:
            word1, word2: Input words.

        Returns:
            Cosine similarity score (typically 0-1 for related words);
            0 if either word is unknown or the model is not loaded.
        """
        if not self.ensure_initialized():
            return 0
        # Check the cache (similarity is symmetric, so try both key orders)
        cache_key = f"{word1}_{word2}"
        reverse_key = f"{word2}_{word1}"
        if cache_key in self.similarity_cache:
            return self.similarity_cache[cache_key]
        if reverse_key in self.similarity_cache:
            return self.similarity_cache[reverse_key]
        try:
            # Both words must be in the vocabulary
            if word1 not in self.model.key_to_index or word2 not in self.model.key_to_index:
                return 0
            similarity = float(self.model.similarity(word1, word2))
            # Cache the result
            self.similarity_cache[cache_key] = similarity
            return similarity
        except Exception as e:
            logging.error(f"Failed to compute similarity: {e}, words: {word1}, {word2}")
            return 0

    def expand_query(self, query, topn=5, min_similarity=0.6):
        """
        Expand a query with related terms.

        Args:
            query: Query string.
            topn: Number of similar words to add per segmented term.
            min_similarity: Minimum similarity threshold.

        Returns:
            The expanded list of terms.
        """
        if not self.ensure_initialized():
            return [query]
        expanded_terms = [query]
        # Segment the query
        words = list(jieba.cut(query))
        # Collect similar words for each term
        for word in words:
            if len(word) <= 1:  # Skip single characters to reduce noise
                continue
            similar_words = self.get_similar_words(word, topn=topn, min_similarity=min_similarity)
            expanded_terms.extend(similar_words)
        # Deduplicate while preserving insertion order
        return list(dict.fromkeys(expanded_terms))
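
# Illustrative sketch, not part of the original module: gensim's similarity()
# is cosine similarity between the two word vectors. This hypothetical helper
# recomputes it with numpy, assuming both words are in the vocabulary; it is
# here only to make the computation behind calculate_similarity explicit.
def manual_cosine_similarity(kv, word1, word2):
    """Recompute cosine similarity directly from the raw vectors."""
    v1, v2 = kv[word1], kv[word2]
    return float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))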

# Singleton: share one model instance across the process
_word2vec_helper = None

def get_word2vec_helper(model_path=None):
    """Return the global Word2VecHelper singleton."""
    global _word2vec_helper
    if _word2vec_helper is None:
        _word2vec_helper = Word2VecHelper(model_path)
        _word2vec_helper.ensure_initialized()
    return _word2vec_helper

# Convenience functions for direct use

def get_similar_words(word, topn=10, min_similarity=0.5):
    """Convenience wrapper for getting similar words."""
    helper = get_word2vec_helper()
    return helper.get_similar_words(word, topn, min_similarity)

def calculate_similarity(word1, word2):
    """Convenience wrapper for computing word similarity."""
    helper = get_word2vec_helper()
    return helper.calculate_similarity(word1, word2)

def expand_query(query, topn=5, min_similarity=0.6):
    """Convenience wrapper for query expansion."""
    helper = get_word2vec_helper()
    return helper.expand_query(query, topn, min_similarity)
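
# Hedged sketch: text-format (.txt/.vec) vectors load noticeably slower than
# the binary format that load_model() prefers. This hypothetical utility, not
# part of the original module, converts between them using only the standard
# gensim KeyedVectors I/O; the function name and arguments are assumptions.
def convert_vectors_to_binary(text_path, binary_path):
    """Load text-format word2vec vectors and re-save them in binary format."""
    kv = KeyedVectors.load_word2vec_format(text_path, binary=False)
    kv.save_word2vec_format(binary_path, binary=True)
    return binary_path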

# Usage example
if __name__ == "__main__":
    # Exercise model loading and word similarity
    helper = get_word2vec_helper()
    # Test words
    test_words = ["电影", "功夫", "熊猫", "科幻", "漫威"]
    for word in test_words:
        print(f"\nWords similar to {word}:")
        similar = helper.get_similar_words(word, topn=5)
        for sim_word in similar:
            print(f"  - {sim_word}")
    # Test similarity computation
    word_pairs = [
        ("电影", "电视"),
        ("功夫", "武术"),
        ("科幻", "未来"),
        ("漫威", "超级英雄")
    ]
    print("\nWord similarities:")
    for w1, w2 in word_pairs:
        sim = helper.calculate_similarity(w1, w2)
        print(f"  {w1} <-> {w2}: {sim:.4f}")
    # Test query expansion
    test_queries = ["功夫熊猫", "科幻电影", "漫威英雄"]
    print("\nQuery expansion:")
    for query in test_queries:
        expanded = helper.expand_query(query)
        print(f"  {query} -> {expanded}")