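推荐系统更新：新增 MySQL 数据导出脚本 utils/data_generator.py，推理改用 user_seed_graph.txt，删除旧数据文件 uig.txt，并将 utils/dataloader.py 重命名为 utils/data_loader.py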
Change-Id: I0c1cd2201bd3baae442b5fd11f36e73c001a7173
diff --git a/recommend/inference.py b/recommend/inference.py
index 697b569..346209c 100644
--- a/recommend/inference.py
+++ b/recommend/inference.py
@@ -3,7 +3,7 @@
from os import path
from utils.parse_args import args
-from utils.dataloader import EdgeListData
+from utils.data_loader import EdgeListData
from model.LightGCN import LightGCN
import torch
import numpy as np
@@ -13,15 +13,14 @@
t_start = time.time()
# 配置参数
-args.data_path = './'
args.device = 'cuda:7'
+args.data_path = './user_seed_graph.txt'
args.pre_model_path = './model/LightGCN_pretrained.pt'
+
# 1. 加载数据集
t_data_start = time.time()
-pretrain_data = path.join(args.data_path, "uig.txt")
-pretrain_val_data = path.join(args.data_path, "uig.txt")
-dataset = EdgeListData(pretrain_data, pretrain_val_data)
+dataset = EdgeListData(args.data_path, args.data_path)
t_data_end = time.time()
diff --git a/recommend/uig.txt b/recommend/uig.txt
deleted file mode 100644
index 5846057..0000000
--- a/recommend/uig.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-0 1 3 9 12 5 7 6 8 4 1511683379 1511683385 1511683431 1511683453 1511683481 1511692992 1511693011 1511693077 1511787191
-1 10 11 2 1511578239 1511594732 1511664627
\ No newline at end of file
diff --git a/recommend/user_seed_graph.txt b/recommend/user_seed_graph.txt
new file mode 100644
index 0000000..ef29920
--- /dev/null
+++ b/recommend/user_seed_graph.txt
@@ -0,0 +1,3 @@
+0 0 0 1 2 1746061954 1736237924 1736240066 1736309966
+1 1 1 1746315010 1746583706
+2 2 0 2 1746738305 1746865166 1749284366
diff --git a/recommend/utils/data_generator.py b/recommend/utils/data_generator.py
new file mode 100644
index 0000000..e50b4e9
--- /dev/null
+++ b/recommend/utils/data_generator.py
@@ -0,0 +1,87 @@
+import pymysql
+import datetime
+from collections import defaultdict
+
import os

# Connection settings for the MySQL database the interaction data is read from.
# Every value can be overridden through an environment variable so that
# credentials do not have to be committed to source control; the literals
# below are only fallbacks that preserve the previous hard-coded behaviour.
# TODO(review): remove the password fallback once deployments set PT_DB_PASSWORD.
SqlURL = os.environ.get("PT_DB_HOST", "10.126.59.25")
SqlPort = int(os.environ.get("PT_DB_PORT", "3306"))
Database = os.environ.get("PT_DB_NAME", "pt_database_test")
SqlUsername = os.environ.get("PT_DB_USER", "root")
SqlPassword = os.environ.get("PT_DB_PASSWORD", "123456")
+
+
def fetch_data():
    """Read raw user-seed interactions from MySQL.

    Runs two queries against the configured database:
    ``SeedDownload`` (user_id, seed_id, download_start) and
    ``UserFavorite`` (user_id, seed_id, created_at).

    Returns:
        A ``(download_rows, favorite_rows)`` tuple holding each query's
        ``cursor.fetchall()`` result, i.e. rows of three columns each.
    """
    conn = pymysql.connect(
        host=SqlURL,
        port=SqlPort,
        user=SqlUsername,
        password=SqlPassword,
        database=Database,
        charset="utf8mb4",
    )
    try:
        # The cursor's context manager closes it even if a query fails, and
        # the surrounding try/finally does the same for the connection, so a
        # failed query no longer leaks either resource.
        with conn.cursor() as cursor:
            cursor.execute("SELECT user_id, seed_id, download_start FROM SeedDownload")
            download_rows = cursor.fetchall()
            cursor.execute("SELECT user_id, seed_id, created_at FROM UserFavorite")
            favorite_rows = cursor.fetchall()
    finally:
        conn.close()
    return download_rows, favorite_rows
+
+
def _to_timestamp(value):
    """Return *value* as integer Unix seconds, or 0 if it is not a datetime."""
    # NOTE(review): .timestamp() interprets a naive datetime in the host's
    # local timezone; confirm whether the DB driver returns naive values and
    # which timezone the server stores them in.
    if isinstance(value, datetime.datetime):
        return int(value.timestamp())
    return 0


def process_records(download_rows, favorite_rows):
    """Flatten download and favourite rows into ``(user_id, seed_id, ts)`` records.

    Download rows are emitted first, then favourite rows, each in input
    order. A user who both downloaded and favourited the same seed therefore
    contributes two records.

    Args:
        download_rows: iterable of ``(user_id, seed_id, download_start)`` rows.
        favorite_rows: iterable of ``(user_id, seed_id, created_at)`` rows.

    Returns:
        ``(records, user_set, seed_set)``. ``records`` is a list of
        ``(user_id, seed_id, ts)`` tuples, where ``ts`` is a Unix timestamp in
        whole seconds, or 0 when the time column is not a ``datetime``. The
        two sets hold every raw user id and seed id seen.
    """
    records = []
    user_set = set()
    seed_set = set()
    # Both tables share the same (user, seed, time) column layout, so a single
    # loop handles them instead of two copy-pasted ones.
    for rows in (download_rows, favorite_rows):
        for user_id, seed_id, event_time in rows:
            user_set.add(user_id)
            seed_set.add(seed_id)
            records.append((user_id, seed_id, _to_timestamp(event_time)))
    return records, user_set, seed_set
+
+
def build_id_maps(user_set, seed_set):
    """Assign dense, zero-based indices to raw user ids and seed ids.

    Ids are numbered in ascending sort order, so the same input sets always
    produce the same mapping.

    Returns:
        ``(user2idx, seed2idx)``: dictionaries from raw id to index.
    """
    def index_in_order(ids):
        ordered = sorted(ids)
        return dict(zip(ordered, range(len(ordered))))

    return index_in_order(user_set), index_in_order(seed_set)
+
+
def group_and_write(records, user2idx, seed2idx, output_path="./user_seed_graph.txt"):
    """Group interaction records per user and write them as a text edge list.

    Each output line is tab-separated into three fields: the user index, the
    space-separated seed indices, and the space-separated timestamps.
    Timestamps line up position by position with the seed indices. Users are
    written in ascending index order, and each user's records keep their
    input order.

    Args:
        records: iterable of ``(user_id, seed_id, ts)`` tuples.
        user2idx: mapping from raw user id to dense index.
        seed2idx: mapping from raw seed id to dense index.
        output_path: destination file, overwritten if it already exists.

    Raises:
        KeyError: if a record references an id missing from either mapping.
    """
    user_items = defaultdict(list)
    user_times = defaultdict(list)
    for user_id, seed_id, ts in records:
        uid = user2idx[user_id]
        user_items[uid].append(seed2idx[seed_id])
        user_times[uid].append(ts)
    # The former print(user_items) / print(user_times) debug dumps were removed:
    # they wrote the entire dataset to stdout on every run.
    with open(output_path, "w", encoding="utf-8") as f:
        for uid in sorted(user_items):
            items = " ".join(map(str, user_items[uid]))
            times = " ".join(map(str, user_times[uid]))
            f.write(f"{uid}\t{items}\t{times}\n")
+
+
def main():
    """Export the user-seed interaction graph from MySQL to ./user_seed_graph.txt."""
    raw_downloads, raw_favorites = fetch_data()
    records, users, seeds = process_records(raw_downloads, raw_favorites)
    # build_id_maps returns (user2idx, seed2idx), which are exactly the next two
    # positional arguments group_and_write expects.
    group_and_write(records, *build_id_maps(users, seeds))


if __name__ == "__main__":
    main()
\ No newline at end of file
diff --git a/recommend/utils/dataloader.py b/recommend/utils/data_loader.py
similarity index 100%
rename from recommend/utils/dataloader.py
rename to recommend/utils/data_loader.py