推荐系统完成
Change-Id: I244590be01b1b4f37664a0e7f3103827e607ffbe
diff --git a/recommend/download_model.py b/recommend/download_model.py
index 1bb0c50..b9555dc 100644
--- a/recommend/download_model.py
+++ b/recommend/download_model.py
@@ -1,6 +1,6 @@
import os
import urllib.request
-from recommend import train_and_save_itemcf
+
MODEL_URL = "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.zh.300.bin.gz"
MODEL_DIR = "./models"
MODEL_PATH = os.path.join(MODEL_DIR, "cc.zh.300.bin")
@@ -28,5 +28,4 @@
print("✅ 模型下载并解压完成!")
if __name__ == "__main__":
- train_and_save_itemcf()
download_model()
diff --git a/recommend/recommend.py b/recommend/recommend.py
index 25032a0..b216d52 100644
--- a/recommend/recommend.py
+++ b/recommend/recommend.py
@@ -15,7 +15,7 @@
engine = create_engine("mysql+pymysql://sy:sy_password@49.233.215.144:3306/pt_station")
# === ✅ 加载 fastText 模型 ===
-fasttext_model_path = 'E:\\course\\pt\\recommend\\models\\cc.zh.300.bin'
+fasttext_model_path = 'models\\cc.zh.300.bin'
if not os.path.exists(fasttext_model_path):
raise FileNotFoundError("fastText 模型文件不存在,请检查路径。")
print("加载 fastText 模型中...")
@@ -25,7 +25,7 @@
# === ✅ 用户标签行为矩阵构建 ===
def get_user_tag_matrix():
df = pd.read_sql("SELECT user_id, tag, score FROM user_tag_scores", engine)
- print(df)
+ #print(df)
df['user_id'] = df['user_id'].astype(str)
user_map = {u: i for i, u in enumerate(df['user_id'].unique())}
tag_map = {t: i for i, t in enumerate(df['tag'].unique())}
@@ -39,20 +39,27 @@
def semantic_recommend(user_id, topn=5):
print(f"正在为用户 {user_id} 生成推荐...")
- # 读取数据库中的用户标签数据
+ # 读取数据
df = pd.read_sql("SELECT user_id, tag, score FROM user_tag_scores", engine)
- print(f"总记录数: {len(df)}")
- print(f"数据示例:\n{df.head()}")
- print(df.dtypes)
- user_id = str(user_id) # 确保匹配
- # 获取该用户的所有标签(按分数从高到低排序)
+ # 统一类型转换
+ df['user_id'] = df['user_id'].astype(str) # 确保整个列转为字符串
+ user_id = str(user_id) # 要查询的ID也转为字符串
+
+ # 现在查询应该正常工作了
user_tags = df[df['user_id'] == user_id].sort_values(by="score", ascending=False)['tag'].tolist()
print(f"用户 {user_id} 的标签(按分数排序): {user_tags}")
if not user_tags:
print(f"用户 {user_id} 没有标签记录,返回空推荐结果。")
return []
+ else:
+ user_tags = user_tags[:3]
+ print(f"用户 {user_id} 的 Top 3 标签: {user_tags}")
+
+ if not user_tags:
+ print(f"用户 {user_id} 没有标签记录,返回空推荐结果。")
+ return []
# 截取前 3 个标签作为“兴趣标签”
user_tags = user_tags[:3]
@@ -85,8 +92,8 @@
# 排序并返回 topN 标签
sorted_tags = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:topn]
print(f"\n最终推荐标签(前 {topn}):")
- for tag, score in sorted_tags:
- print(f"{tag}: {score:.4f}")
+ #for tag, score in sorted_tags:
+ # print(f"{tag}: {score:.4f}")
return [tag for tag, _ in sorted_tags]
@@ -117,8 +124,6 @@
print(f"⚠️ 用户 {user_id} 没有任何标签评分记录。")
return []
- print(f"用户 {user_id} 的标签评分:\n{user_tags}")
-
scores = {}
for tag, val in user_tags.items():
if tag not in sim_df:
@@ -212,28 +217,32 @@
def get_torrent_ids_by_tags(tags, limit_per_tag=10):
if not tags:
tags = []
+ print(f"传递给 get_torrent_ids_by_tags 的标签: {tags}")
recommended_ids = set()
with engine.connect() as conn:
for tag in tags:
query = text("""
SELECT torrent_id
- FROM bt_torrent_tags
- WHERE tag = :tag
+ FROM bt_torrent_tags
+ WHERE tag = :tag
LIMIT :limit
""")
result = conn.execute(query, {"tag": tag, "limit": limit_per_tag})
+ print(f"标签 '{tag}' 的推荐结果:")
for row in result:
+ print(row[0]) # 打印每个torrent_id
recommended_ids.add(row[0])
# 获取数据库中所有 torrent_id
- all_query = text("SELECT DISTINCT torrent_id FROM bt_torrent_tags")
+ all_query = text("SELECT DISTINCT torrent_id FROM bt_torrent")
all_result = conn.execute(all_query)
all_ids = set(row[0] for row in all_result)
+ print("数据库中所有torrent_id:", all_ids)
# 剩下的(非推荐)种子 ID
remaining_ids = all_ids - recommended_ids
-
+ print(remaining_ids)
# 随机打乱推荐和剩下的 ID
recommended_list = list(recommended_ids)
remaining_list = list(remaining_ids)
diff --git a/recommend/requirements.txt b/recommend/requirements.txt
index 2efe47a..b609f83 100644
--- a/recommend/requirements.txt
+++ b/recommend/requirements.txt
@@ -3,7 +3,7 @@
huggingface_hub==0.31.2
jieba==0.42.1
mysql_connector_repackaged==0.3.1
-numpy==2.2.6
+numpy==1.26.4
pandas==2.2.3
scikit_learn==1.6.1
scikit_surprise==1.1.4
@@ -16,3 +16,4 @@
torch==2.7.0
transformers==4.51.3
waitress==3.0.2
+pymysql
\ No newline at end of file