推荐系统更新
Change-Id: I0c1cd2201bd3baae442b5fd11f36e73c001a7173
diff --git a/recommend/utils/data_generator.py b/recommend/utils/data_generator.py
new file mode 100644
index 0000000..e50b4e9
--- /dev/null
+++ b/recommend/utils/data_generator.py
@@ -0,0 +1,87 @@
+import datetime
+import os
+from collections import defaultdict
+
+import pymysql
+
+# Connection settings for the MySQL source database.  Each value can be
+# overridden through an environment variable so credentials do not have to
+# be edited in this file; the fallbacks are the original literals, so the
+# behaviour is unchanged when no variable is set.
+# NOTE(review): the fallback root password is still kept in source control;
+# consider removing it once deployments supply RECOMMEND_SQL_PASSWORD.
+SqlURL = os.environ.get("RECOMMEND_SQL_HOST", "10.126.59.25")
+SqlPort = int(os.environ.get("RECOMMEND_SQL_PORT", "3306"))
+Database = os.environ.get("RECOMMEND_SQL_DATABASE", "pt_database_test")
+SqlUsername = os.environ.get("RECOMMEND_SQL_USER", "root")
+SqlPassword = os.environ.get("RECOMMEND_SQL_PASSWORD", "123456")
+
+
+def fetch_data():
+    """Read the download and favorite interactions from MySQL.
+
+    Connects using the module-level settings (``SqlURL``, ``SqlPort``,
+    ``SqlUsername``, ``SqlPassword``, ``Database``) and runs two queries:
+    ``(user_id, seed_id, download_start)`` from ``SeedDownload`` and
+    ``(user_id, seed_id, created_at)`` from ``UserFavorite``.
+
+    Returns:
+        A pair ``(download_rows, favorite_rows)``; each element is whatever
+        ``cursor.fetchall()`` returned for the corresponding query.
+
+    Raises:
+        Any exception raised by ``pymysql`` while connecting or querying.
+        The cursor and the connection are closed before it propagates.
+    """
+    conn = pymysql.connect(
+        host=SqlURL,
+        port=SqlPort,
+        user=SqlUsername,
+        password=SqlPassword,
+        database=Database,
+        charset="utf8mb4",
+    )
+    try:
+        # The cursor context manager closes the cursor on exit, including
+        # when one of the queries raises.
+        with conn.cursor() as cursor:
+            cursor.execute(
+                "SELECT user_id, seed_id, download_start FROM SeedDownload"
+            )
+            download_rows = cursor.fetchall()
+            cursor.execute(
+                "SELECT user_id, seed_id, created_at FROM UserFavorite"
+            )
+            favorite_rows = cursor.fetchall()
+    finally:
+        # Always release the connection; previously a failed query leaked it.
+        conn.close()
+    return download_rows, favorite_rows
+
+
+def _to_epoch_seconds(value):
+    """Return *value* as whole epoch seconds, or 0 if it is not a datetime."""
+    # NOTE(review): datetime.timestamp() interprets naive datetimes in the
+    # local timezone -- confirm that matches how the database stores them.
+    if isinstance(value, datetime.datetime):
+        return int(value.timestamp())
+    return 0
+
+
+def process_records(download_rows, favorite_rows):
+    """Merge download and favorite rows into one interaction list.
+
+    Args:
+        download_rows: Iterable of ``(user_id, seed_id, timestamp)`` rows.
+        favorite_rows: Iterable of ``(user_id, seed_id, timestamp)`` rows.
+
+    Returns:
+        A triple ``(records, user_set, seed_set)``.  ``records`` is a list of
+        ``(user_id, seed_id, ts)`` tuples, all download rows first and then
+        all favorite rows, where ``ts`` is the integer epoch second of the
+        row's timestamp, or 0 when that value is not a
+        ``datetime.datetime``.  ``user_set`` and ``seed_set`` hold every
+        distinct user id and seed id seen.
+    """
+    records = []
+    user_set = set()
+    seed_set = set()
+    # Both sources share the same row shape, so one loop handles them;
+    # downloads are processed first to keep the original record order.
+    for rows in (download_rows, favorite_rows):
+        for user_id, seed_id, created_at in rows:
+            user_set.add(user_id)
+            seed_set.add(seed_id)
+            records.append((user_id, seed_id, _to_epoch_seconds(created_at)))
+    return records, user_set, seed_set
+
+
+def build_id_maps(user_set, seed_set):
+    """Assign dense, zero-based indices to the raw user and seed ids.
+
+    Each collection is sorted and every id is mapped to its position in
+    that sorted order.
+
+    Args:
+        user_set: Collection of raw user ids.
+        seed_set: Collection of raw seed ids.
+
+    Returns:
+        A pair ``(user2idx, seed2idx)`` of dicts from raw id to index.
+    """
+    ordered_users = sorted(user_set)
+    user2idx = dict(zip(ordered_users, range(len(ordered_users))))
+    ordered_seeds = sorted(seed_set)
+    seed2idx = dict(zip(ordered_seeds, range(len(ordered_seeds))))
+    return user2idx, seed2idx
+
+
+def group_and_write(records, user2idx, seed2idx, output_path="./user_seed_graph.txt"):
+    """Group interactions per user and write them as tab-separated lines.
+
+    Args:
+        records: Iterable of ``(user_id, seed_id, ts)`` tuples.
+        user2idx: Mapping from raw user id to user index.
+        seed2idx: Mapping from raw seed id to seed index.
+        output_path: File to (over)write, encoded as UTF-8.
+
+    Each output line is ``<user index>\\t<seed indices>\\t<timestamps>``,
+    with the seed indices and timestamps space-separated and kept in the
+    order the records were given.  Lines are sorted by user index.  Both
+    grouping dicts are also printed to stdout before the file is written.
+
+    Raises:
+        KeyError: If a record's user or seed id is missing from its map.
+    """
+    items_by_user = defaultdict(list)
+    times_by_user = defaultdict(list)
+    for raw_user, raw_seed, stamp in records:
+        user_index = user2idx[raw_user]
+        seed_index = seed2idx[raw_seed]
+        items_by_user[user_index].append(seed_index)
+        times_by_user[user_index].append(stamp)
+
+    print(items_by_user)
+    print(times_by_user)
+
+    with open(output_path, "w", encoding="utf-8") as out:
+        for user_index in sorted(items_by_user):
+            item_field = " ".join(map(str, items_by_user[user_index]))
+            time_field = " ".join(map(str, times_by_user[user_index]))
+            out.write(f"{user_index}\t{item_field}\t{time_field}\n")
+
+
+def main():
+    """Run the export: fetch rows, build id maps, write the graph file."""
+    downloads, favorites = fetch_data()
+    interactions, users, seeds = process_records(downloads, favorites)
+    user_index_map, seed_index_map = build_id_maps(users, seeds)
+    group_and_write(interactions, user_index_map, seed_index_map)
+
+
+# Run the export only when this file is executed as a script, not on import.
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/recommend/utils/dataloader.py b/recommend/utils/data_loader.py
similarity index 100%
rename from recommend/utils/dataloader.py
rename to recommend/utils/data_loader.py