新增管理员页面和用户申诉、迁移审核页面,推荐系统

Change-Id: Ief5646321feb98fadb17da4b4e91caeaacdbacc5
diff --git a/.gitignore b/.gitignore
index 1928d6a..dcb56f0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,5 @@
 .vscode/
 torrents/
 front/node_moduels
+recommend/model/__pycache__
+recommend/utils/__pycache__
\ No newline at end of file
diff --git a/front/src/AdminPage.js b/front/src/AdminPage.js
new file mode 100644
index 0000000..4862449
--- /dev/null
+++ b/front/src/AdminPage.js
@@ -0,0 +1,145 @@
+import React, { useState } from "react";
+import { useNavigate } from "react-router-dom";
+
+// 示例数据
+const initialConfig = {
+    FarmNumber: 3,
+    FakeTime: 3,
+    BegVote: 3,
+    CheatTime: 5,
+};
+
+const cheatUsers = [
+    { user_id: "u001", email: "cheat1@example.com", username: "cheater1", account_status: 1 },
+    { user_id: "u002", email: "cheat2@example.com", username: "cheater2", account_status: 0 },
+];
+
+const suspiciousUsers = [
+    { user_id: "u101", email: "suspect1@example.com", username: "suspect1", account_status: 0 },
+    { user_id: "u102", email: "suspect2@example.com", username: "suspect2", account_status: 0 },
+];
+
+export default function AdminPage() {
+    const navigate = useNavigate();
+    const [config, setConfig] = useState(initialConfig);
+
+    const handleConfigChange = (e) => {
+        const { name, value } = e.target;
+        setConfig({ ...config, [name]: value });
+    };
+
+    const handleBan = (user) => {
+        alert(`已封禁用户:${user.username}`);
+    };
+
+    return (
+        <div style={{ padding: 40, maxWidth: 900, margin: "0 auto" }}>
+            <h1 style={{ textAlign: "center", marginBottom: 32 }}>管理员页面</h1>
+            {/* 参数设置 */}
+            <div style={{ marginBottom: 32, padding: 18, background: "#f7faff", borderRadius: 12, display: "flex", gap: 24, alignItems: "center" }}>
+                <b>系统参数:</b>
+                <label>
+                    FarmNumber:
+                    <input type="number" name="FarmNumber" value={config.FarmNumber} onChange={handleConfigChange} style={{ width: 60, margin: "0 12px" }} />
+                </label>
+                <label>
+                    FakeTime:
+                    <input type="number" name="FakeTime" value={config.FakeTime} onChange={handleConfigChange} style={{ width: 60, margin: "0 12px" }} />
+                </label>
+                <label>
+                    BegVote:
+                    <input type="number" name="BegVote" value={config.BegVote} onChange={handleConfigChange} style={{ width: 60, margin: "0 12px" }} />
+                </label>
+                <label>
+                    CheatTime:
+                    <input type="number" name="CheatTime" value={config.CheatTime} onChange={handleConfigChange} style={{ width: 60, margin: "0 12px" }} />
+                </label>
+            </div>
+            {/* 作弊用户 */}
+            <div style={{ marginBottom: 32 }}>
+                <h2 style={{ color: "#e53935" }}>作弊用户</h2>
+                <table style={{ width: "100%", background: "#fff", borderRadius: 10, boxShadow: "0 2px 8px #e0e7ff", marginBottom: 18 }}>
+                    <thead>
+                        <tr style={{ background: "#f5f5f5" }}>
+                            <th>user_id</th>
+                            <th>email</th>
+                            <th>username</th>
+                            <th>account_status</th>
+                            <th>操作</th>
+                        </tr>
+                    </thead>
+                    <tbody>
+                        {cheatUsers.map((u) => (
+                            <tr key={u.user_id}>
+                                <td>{u.user_id}</td>
+                                <td>{u.email}</td>
+                                <td>{u.username}</td>
+                                <td style={{ color: u.account_status === 1 ? "#e53935" : "#43a047" }}>
+                                    {u.account_status === 1 ? "封禁" : "正常"}
+                                </td>
+                                <td>
+                                    <button
+                                        style={{ background: "#e53935", color: "#fff", border: "none", borderRadius: 6, padding: "4px 14px", cursor: "pointer" }}
+                                        onClick={() => handleBan(u)}
+                                    >
+                                        封禁
+                                    </button>
+                                </td>
+                            </tr>
+                        ))}
+                    </tbody>
+                </table>
+            </div>
+            {/* 可疑用户 */}
+            <div style={{ marginBottom: 32 }}>
+                <h2 style={{ color: "#ff9800" }}>可疑用户</h2>
+                <table style={{ width: "100%", background: "#fff", borderRadius: 10, boxShadow: "0 2px 8px #e0e7ff" }}>
+                    <thead>
+                        <tr style={{ background: "#f5f5f5" }}>
+                            <th>user_id</th>
+                            <th>email</th>
+                            <th>username</th>
+                            <th>account_status</th>
+                            <th>操作</th>
+                        </tr>
+                    </thead>
+                    <tbody>
+                        {suspiciousUsers.map((u) => (
+                            <tr key={u.user_id}>
+                                <td>{u.user_id}</td>
+                                <td>{u.email}</td>
+                                <td>{u.username}</td>
+                                <td style={{ color: u.account_status === 1 ? "#e53935" : "#43a047" }}>
+                                    {u.account_status === 1 ? "封禁" : "正常"}
+                                </td>
+                                <td>
+                                    <button
+                                        style={{ background: "#e53935", color: "#fff", border: "none", borderRadius: 6, padding: "4px 14px", cursor: "pointer" }}
+                                        onClick={() => handleBan(u)}
+                                    >
+                                        封禁
+                                    </button>
+                                </td>
+                            </tr>
+                        ))}
+                    </tbody>
+                </table>
+            </div>
+            {/* 跳转按钮 */}
+            <div style={{ display: "flex", gap: 24, justifyContent: "center" }}>
+                <button
+                    style={{ background: "#1976d2", color: "#fff", border: "none", borderRadius: 8, padding: "10px 28px", fontWeight: 600, fontSize: 16, cursor: "pointer" }}
+                    onClick={() => navigate("/appeal-review")}
+                >
+                    用户申诉
+                </button>
+                <button
+                    style={{ background: "#43a047", color: "#fff", border: "none", borderRadius: 8, padding: "10px 28px", fontWeight: 600, fontSize: 16, cursor: "pointer" }}
+                    onClick={() => navigate("/migration-review")}
+                >
+                    用户迁移
+                </button>
+            </div>
+        </div>
+    );
+}
\ No newline at end of file
diff --git a/front/src/App.js b/front/src/App.js
index 2f85943..372fa62 100644
--- a/front/src/App.js
+++ b/front/src/App.js
@@ -25,6 +25,10 @@
 import LoginPage from './LoginPage';
 import RegisterPage from './RegisterPage';
 import RequireAuth from './RequireAuth';
+import AdminPage from './AdminPage';
+import AppealPage from './AppealPage';
+import MigrationPage from './MigrationPage';
+
 
 const navItems = [
   { label: "电影", icon: <MovieIcon />, path: "/movie" },
@@ -164,6 +168,9 @@
           <Route path="/user" element={<UserProfile />} />
           <Route path="/publish" element={<PublishPage />} />
           <Route path="/torrent/:torrentId" element={<TorrentDetailPage />} />
+          <Route path="/admin" element={<AdminPage />} />
+          <Route path="/appeal-review" element={<AppealPage />} />
+          <Route path="/migration-review" element={<MigrationPage />} />
         </Route>
       </Routes>
     </Router>
diff --git a/front/src/AppealPage.js b/front/src/AppealPage.js
new file mode 100644
index 0000000..a0314c4
--- /dev/null
+++ b/front/src/AppealPage.js
@@ -0,0 +1,141 @@
+import React, { useState } from "react";
+
+// 示例申诉数据
+const appeals = [
+    {
+        appeal_id: "a001",
+        user_id: "u001",
+        content: "我没有作弊,请审核我的账号。",
+        file_url: "http://sse.bjtu.edu.cn/media/attachments/2024/10/20241012160658.pdf",
+        status: 0,
+    },
+    {
+        appeal_id: "a002",
+        user_id: "u002",
+        content: "误封申诉,详见附件。",
+        file_url: "http://sse.bjtu.edu.cn/media/attachments/2024/10/20241012160658.pdf",
+        status: 1,
+    },
+];
+
+// 简单PDF预览组件
+function FileViewer({ url }) {
+    if (!url) return <div>无附件</div>;
+    if (url.endsWith(".pdf")) {
+        return (
+            <iframe
+                src={url}
+                title="PDF预览"
+                width="100%"
+                height="400px"
+                style={{ border: "1px solid #ccc", borderRadius: 8 }}
+            />
+        );
+    }
+    // 这里只做PDF示例,实际可扩展为DOC等
+    return <a href={url} target="_blank" rel="noopener noreferrer">下载附件</a>;
+}
+
+export default function AppealPage() {
+    const [selectedId, setSelectedId] = useState(appeals[0].appeal_id);
+    const selectedAppeal = appeals.find(a => a.appeal_id === selectedId);
+
+    const handleApprove = () => {
+        alert("已通过申诉(示例,无实际状态变更)");
+    };
+    const handleReject = () => {
+        alert("已拒绝申诉(示例,无实际状态变更)");
+    };
+
+    return (
+        <div style={{ display: "flex", minHeight: "100vh", background: "#f7faff" }}>
+            {/* 侧栏 */}
+            <div style={{ width: 180, background: "#fff", borderRight: "1px solid #e0e7ff", padding: 0 }}>
+                <h3 style={{ textAlign: "center", padding: "18px 0 0 0", color: "#1976d2" }}>申诉列表</h3>
+                <div style={{ display: "flex", flexDirection: "column", gap: 12, marginTop: 18 }}>
+                    {appeals.map(a => (
+                        <div
+                            key={a.appeal_id}
+                            onClick={() => setSelectedId(a.appeal_id)}
+                            style={{
+                                margin: "0 12px",
+                                padding: "16px 10px",
+                                borderRadius: 8,
+                                background: selectedId === a.appeal_id ? "#e3f2fd" : "#fff",
+                                border: `2px solid ${a.status === 1 ? "#43a047" : "#e53935"}`,
+                                color: a.status === 1 ? "#43a047" : "#e53935",
+                                fontWeight: 600,
+                                cursor: "pointer",
+                                boxShadow: selectedId === a.appeal_id ? "0 2px 8px #b2d8ea" : "none",
+                                transition: "all 0.2s"
+                            }}
+                        >
+                            {a.appeal_id}
+                            <span style={{
+                                float: "right",
+                                fontSize: 12,
+                                color: a.status === 1 ? "#43a047" : "#e53935"
+                            }}>
+                                {a.status === 1 ? "已审核" : "未审核"}
+                            </span>
+                        </div>
+                    ))}
+                </div>
+            </div>
+            {/* 申诉详情 */}
+            <div style={{ flex: 1, padding: "40px 48px" }}>
+                <h2 style={{ marginBottom: 24, color: "#1976d2" }}>申诉详情</h2>
+                <div style={{ background: "#fff", borderRadius: 12, padding: 32, boxShadow: "0 2px 8px #e0e7ff", marginBottom: 32 }}>
+                    <div style={{ marginBottom: 18 }}>
+                        <b>申诉ID:</b>{selectedAppeal.appeal_id}
+                    </div>
+                    <div style={{ marginBottom: 18 }}>
+                        <b>用户ID:</b>{selectedAppeal.user_id}
+                    </div>
+                    <div style={{ marginBottom: 18 }}>
+                        <b>申诉内容:</b>{selectedAppeal.content}
+                    </div>
+                    <div style={{ marginBottom: 18 }}>
+                        <b>申诉文件:</b>
+                        <FileViewer url={selectedAppeal.file_url} />
+                    </div>
+                </div>
+                {/* 审核按钮 */}
+                <div style={{ display: "flex", gap: 32, justifyContent: "center" }}>
+                    <button
+                        style={{
+                            background: selectedAppeal.status === 1 ? "#bdbdbd" : "#43a047",
+                            color: "#fff",
+                            border: "none",
+                            borderRadius: 8,
+                            padding: "10px 38px",
+                            fontWeight: 600,
+                            fontSize: 18,
+                            cursor: selectedAppeal.status === 1 ? "not-allowed" : "pointer"
+                        }}
+                        disabled={selectedAppeal.status === 1}
+                        onClick={handleApprove}
+                    >
+                        通过
+                    </button>
+                    <button
+                        style={{
+                            background: selectedAppeal.status === 1 ? "#bdbdbd" : "#e53935",
+                            color: "#fff",
+                            border: "none",
+                            borderRadius: 8,
+                            padding: "10px 38px",
+                            fontWeight: 600,
+                            fontSize: 18,
+                            cursor: selectedAppeal.status === 1 ? "not-allowed" : "pointer"
+                        }}
+                        disabled={selectedAppeal.status === 1}
+                        onClick={handleReject}
+                    >
+                        不通过
+                    </button>
+                </div>
+            </div>
+        </div>
+    );
+}
\ No newline at end of file
diff --git a/front/src/LoginPage.js b/front/src/LoginPage.js
index 5bb3603..890b690 100644
--- a/front/src/LoginPage.js
+++ b/front/src/LoginPage.js
@@ -15,6 +15,12 @@
   };

 

   const handleLogin = async () => {

+    // 进入管理员页面

+    if (formData.username === "admin" && formData.password === "admin123") {

+      navigate('/admin');

+      return;

+    }

+

     if (formData.password.length < 8) {

       setErrorMessage('密码必须至少包含八位字符!');

       return;

@@ -60,7 +66,7 @@
         const { username, password } = JSON.parse(regUser);

         setFormData({ username, password });

         sessionStorage.removeItem('registeredUser');

-      } catch {}

+      } catch { }

     }

   }, []);

 

diff --git a/front/src/MigrationPage.js b/front/src/MigrationPage.js
new file mode 100644
index 0000000..ed88a55
--- /dev/null
+++ b/front/src/MigrationPage.js
@@ -0,0 +1,152 @@
+import React, { useState } from "react";
+
+// 示例迁移数据
+const migrations = [
+    {
+        migration_id: "m001",
+        user_id: "u001",
+        application_url: "http://sse.bjtu.edu.cn/media/attachments/2024/10/20241012160658.pdf",
+        approved: 0,
+        pending_magic: 10,
+        granted_magic: 0,
+        pending_uploaded: 1000,
+        granted_uploaded: 0,
+    },
+    {
+        migration_id: "m002",
+        user_id: "u002",
+        application_url: "http://sse.bjtu.edu.cn/media/attachments/2024/10/20241012160658.pdf",
+        approved: 1,
+        pending_magic: 20,
+        granted_magic: 20,
+        pending_uploaded: 2000,
+        granted_uploaded: 2000,
+    },
+];
+
+// 简单PDF预览组件
+function FileViewer({ url }) {
+    if (!url) return <div>无附件</div>;
+    if (url.endsWith(".pdf")) {
+        return (
+            <iframe
+                src={url}
+                title="PDF预览"
+                width="100%"
+                height="400px"
+                style={{ border: "1px solid #ccc", borderRadius: 8 }}
+            />
+        );
+    }
+    // 这里只做PDF示例,实际可扩展为DOC等
+    return <a href={url} target="_blank" rel="noopener noreferrer">下载附件</a>;
+}
+
+export default function MigrationPage() {
+    const [selectedId, setSelectedId] = useState(migrations[0].migration_id);
+    const selectedMigration = migrations.find(m => m.migration_id === selectedId);
+
+    const handleApprove = () => {
+        alert("已通过迁移(示例,无实际状态变更)");
+    };
+    const handleReject = () => {
+        alert("已拒绝迁移(示例,无实际状态变更)");
+    };
+
+    return (
+        <div style={{ display: "flex", minHeight: "100vh", background: "#f7faff" }}>
+            {/* 侧栏 */}
+            <div style={{ width: 180, background: "#fff", borderRight: "1px solid #e0e7ff", padding: 0 }}>
+                <h3 style={{ textAlign: "center", padding: "18px 0 0 0", color: "#1976d2" }}>迁移列表</h3>
+                <div style={{ display: "flex", flexDirection: "column", gap: 12, marginTop: 18 }}>
+                    {migrations.map(m => (
+                        <div
+                            key={m.migration_id}
+                            onClick={() => setSelectedId(m.migration_id)}
+                            style={{
+                                margin: "0 12px",
+                                padding: "16px 10px",
+                                borderRadius: 8,
+                                background: selectedId === m.migration_id ? "#e3f2fd" : "#fff",
+                                border: `2px solid ${m.approved === 1 ? "#43a047" : "#e53935"}`,
+                                color: m.approved === 1 ? "#43a047" : "#e53935",
+                                fontWeight: 600,
+                                cursor: "pointer",
+                                boxShadow: selectedId === m.migration_id ? "0 2px 8px #b2d8ea" : "none",
+                                transition: "all 0.2s"
+                            }}
+                        >
+                            {m.migration_id}
+                            <span style={{
+                                float: "right",
+                                fontSize: 12,
+                                color: m.approved === 1 ? "#43a047" : "#e53935"
+                            }}>
+                                {m.approved === 1 ? "已审核" : "未审核"}
+                            </span>
+                        </div>
+                    ))}
+                </div>
+            </div>
+            {/* 迁移详情 */}
+            <div style={{ flex: 1, padding: "40px 48px" }}>
+                <h2 style={{ marginBottom: 24, color: "#1976d2" }}>迁移详情</h2>
+                <div style={{ background: "#fff", borderRadius: 12, padding: 32, boxShadow: "0 2px 8px #e0e7ff", marginBottom: 32 }}>
+                    <div style={{ marginBottom: 18 }}>
+                        <b>迁移ID:</b>{selectedMigration.migration_id}
+                    </div>
+                    <div style={{ marginBottom: 18 }}>
+                        <b>用户ID:</b>{selectedMigration.user_id}
+                    </div>
+                    <div style={{ marginBottom: 18 }}>
+                        <b>申请文件:</b>
+                        <FileViewer url={selectedMigration.application_url} />
+                    </div>
+                    <div style={{ marginBottom: 18 }}>
+                        <b>待迁移魔法值:</b>{selectedMigration.pending_magic},
+                        <b>已迁移魔法值:</b>{selectedMigration.granted_magic}
+                    </div>
+                    <div style={{ marginBottom: 18 }}>
+                        <b>待迁移上传量:</b>{selectedMigration.pending_uploaded},
+                        <b>已迁移上传量:</b>{selectedMigration.granted_uploaded}
+                    </div>
+                </div>
+                {/* 审核按钮 */}
+                <div style={{ display: "flex", gap: 32, justifyContent: "center" }}>
+                    <button
+                        style={{
+                            background: selectedMigration.approved === 1 ? "#bdbdbd" : "#43a047",
+                            color: "#fff",
+                            border: "none",
+                            borderRadius: 8,
+                            padding: "10px 38px",
+                            fontWeight: 600,
+                            fontSize: 18,
+                            cursor: selectedMigration.approved === 1 ? "not-allowed" : "pointer"
+                        }}
+                        disabled={selectedMigration.approved === 1}
+                        onClick={handleApprove}
+                    >
+                        通过
+                    </button>
+                    <button
+                        style={{
+                            background: selectedMigration.approved === 1 ? "#bdbdbd" : "#e53935",
+                            color: "#fff",
+                            border: "none",
+                            borderRadius: 8,
+                            padding: "10px 38px",
+                            fontWeight: 600,
+                            fontSize: 18,
+                            cursor: selectedMigration.approved === 1 ? "not-allowed" : "pointer"
+                        }}
+                        disabled={selectedMigration.approved === 1}
+                        onClick={handleReject}
+                    >
+                        不通过
+                    </button>
+                </div>
+            </div>
+        </div>
+    );
+}
\ No newline at end of file
diff --git a/front/src/UserProfile.js b/front/src/UserProfile.js
index 753e901..16fdcd2 100644
--- a/front/src/UserProfile.js
+++ b/front/src/UserProfile.js
@@ -33,7 +33,7 @@
       if (!userid) return;

       try {

         const res = await fetch(`${API_BASE_URL}/api/user-profile?userid=${userid}`);

-        

+

         if (res.ok) {

           const data = await res.json();

           setUserInfo(data);

@@ -91,9 +91,9 @@
   };

 

   const handleSave = async () => {

-    if (tempUserInfo.gender === "男性"){

+    if (tempUserInfo.gender === "男性") {

       tempUserInfo.gender = "m";

-    }else if (tempUserInfo.gender === "女性"){

+    } else if (tempUserInfo.gender === "女性") {

       tempUserInfo.gender = "f";

     }

     setUserInfo({ ...tempUserInfo });

@@ -222,7 +222,7 @@
               >

                 {tempUserInfo.gender === 'm' ? '男性'

                   : tempUserInfo.gender === 'f' ? '女性'

-                  : '性别'}

+                    : '性别'}

               </button>

               {tempUserInfo.showGenderOptions && (

                 <ul

@@ -307,7 +307,7 @@
                       // const userid = localStorage.getItem("userid");

                       // const userid = "550e8400-e29b-41d4-a716-446655440000"; // 示例userid

                       try {

-                        

+

                         const res = await fetch(`${API_BASE_URL}/api/delete-seed`, {

                           method: 'POST',

                           headers: { 'Content-Type': 'application/json' },

diff --git a/recommend/hello.py b/recommend/hello.py
deleted file mode 100644
index c6d4e16..0000000
--- a/recommend/hello.py
+++ /dev/null
@@ -1 +0,0 @@
-print("Hello G10!")
diff --git a/recommend/inference.py b/recommend/inference.py
new file mode 100644
index 0000000..697b569
--- /dev/null
+++ b/recommend/inference.py
@@ -0,0 +1,54 @@
+import sys
+sys.path.append('./')
+
+from os import path
+from utils.parse_args import args
+from utils.dataloader import EdgeListData
+from model.LightGCN import LightGCN
+import torch
+import numpy as np
+import time
+
+# 计时:脚本开始
+t_start = time.time()
+
+# 配置参数
+args.data_path = './'
+args.device = 'cuda:7'
+args.pre_model_path = './model/LightGCN_pretrained.pt'
+
+# 1. 加载数据集
+t_data_start = time.time()
+pretrain_data = path.join(args.data_path, "uig.txt")
+pretrain_val_data = path.join(args.data_path, "uig.txt")
+dataset = EdgeListData(pretrain_data, pretrain_val_data)
+t_data_end = time.time()
+
+
+# 2. 加载LightGCN模型
+pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
+pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
+pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
+
+model = LightGCN(dataset, phase='vanilla').to(args.device)
+model.load_state_dict(pretrained_dict, strict=False)
+model.eval()
+
+# 3. 输入用户ID
+user_id = 1
+
+# 4. 推理:获取embedding并打分
+t_infer_start = time.time()
+with torch.no_grad():
+    user_emb, item_emb = model.generate()
+    user_vec = user_emb[user_id].unsqueeze(0)
+    scores = model.rating(user_vec, item_emb).squeeze(0)
+    pred_item = torch.argmax(scores).item()
+t_infer_end = time.time()
+
+t_end = time.time()
+
+print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
+print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
+print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
+print(f"脚本总耗时: {t_end - t_start:.4f} 秒")
\ No newline at end of file
diff --git a/recommend/model/LightGCN.py b/recommend/model/LightGCN.py
new file mode 100644
index 0000000..b6b447e
--- /dev/null
+++ b/recommend/model/LightGCN.py
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import scipy.sparse as sp
+import math
+import networkx as nx
+import random
+from copy import deepcopy
+from utils.parse_args import args
+from model.base_model import BaseModel
+from model.operators import EdgelistDrop
+from model.operators import scatter_add, scatter_sum
+
+
+init = nn.init.xavier_uniform_
+
+class LightGCN(BaseModel):
+    def __init__(self, dataset, pretrained_model=None, phase='pretrain'):
+        super().__init__(dataset)
+        self.adj = self._make_binorm_adj(dataset.graph)
+        self.edges = self.adj._indices().t()
+        self.edge_norm = self.adj._values()
+
+        self.phase = phase
+
+        self.emb_gate = lambda x: x
+
+        if self.phase == 'pretrain' or self.phase == 'vanilla' or self.phase == 'for_tune':
+            self.user_embedding = nn.Parameter(init(torch.empty(self.num_users, self.emb_size)))
+            self.item_embedding = nn.Parameter(init(torch.empty(self.num_items, self.emb_size)))
+
+
+        elif self.phase == 'finetune':
+            pre_user_emb, pre_item_emb = pretrained_model.generate()
+            self.user_embedding = nn.Parameter(pre_user_emb).requires_grad_(True)
+            self.item_embedding = nn.Parameter(pre_item_emb).requires_grad_(True)
+
+        elif self.phase == 'continue_tune':
+            # re-initialize for loading state dict
+            self.user_embedding = nn.Parameter(init(torch.empty(self.num_users, self.emb_size)))
+            self.item_embedding = nn.Parameter(init(torch.empty(self.num_items, self.emb_size)))
+
+        self.edge_dropout = EdgelistDrop()
+
+    def _agg(self, all_emb, edges, edge_norm):
+        src_emb = all_emb[edges[:, 0]]
+
+        # bi-norm
+        src_emb = src_emb * edge_norm.unsqueeze(1)
+
+        # conv
+        dst_emb = scatter_sum(src_emb, edges[:, 1], dim=0, dim_size=self.num_users+self.num_items)
+        return dst_emb
+    
+    def _edge_binorm(self, edges):
+        user_degs = scatter_add(torch.ones_like(edges[:, 0]), edges[:, 0], dim=0, dim_size=self.num_users)
+        user_degs = user_degs[edges[:, 0]]
+        item_degs = scatter_add(torch.ones_like(edges[:, 1]), edges[:, 1], dim=0, dim_size=self.num_items)
+        item_degs = item_degs[edges[:, 1]]
+        norm = torch.pow(user_degs, -0.5) * torch.pow(item_degs, -0.5)
+        return norm
+
+    def forward(self, edges, edge_norm, return_layers=False):
+        all_emb = torch.cat([self.user_embedding, self.item_embedding], dim=0)
+        all_emb = self.emb_gate(all_emb)
+        res_emb = [all_emb]
+        for l in range(args.num_layers):
+            all_emb = self._agg(res_emb[-1], edges, edge_norm)
+            res_emb.append(all_emb)
+        if not return_layers:
+            res_emb = sum(res_emb)
+            user_res_emb, item_res_emb = res_emb.split([self.num_users, self.num_items], dim=0)
+        else:
+            user_res_emb, item_res_emb = [], []
+            for emb in res_emb:
+                u_emb, i_emb = emb.split([self.num_users, self.num_items], dim=0)
+                user_res_emb.append(u_emb)
+                item_res_emb.append(i_emb)
+        return user_res_emb, item_res_emb
+    
+    def cal_loss(self, batch_data):
+        edges, dropout_mask = self.edge_dropout(self.edges, 1-args.edge_dropout, return_mask=True)
+        edge_norm = self.edge_norm[dropout_mask]
+
+        # forward
+        users, pos_items, neg_items = batch_data
+        user_emb, item_emb = self.forward(edges, edge_norm)
+        batch_user_emb = user_emb[users]
+        pos_item_emb = item_emb[pos_items]
+        neg_item_emb = item_emb[neg_items]
+        rec_loss = self._bpr_loss(batch_user_emb, pos_item_emb, neg_item_emb)
+        reg_loss = args.weight_decay * self._reg_loss(users, pos_items, neg_items)
+
+        loss = rec_loss + reg_loss
+        loss_dict = {
+            "rec_loss": rec_loss.item(),
+            "reg_loss": reg_loss.item(),
+        }
+        return loss, loss_dict
+    
+    @torch.no_grad()
+    def generate(self, return_layers=False):
+        return self.forward(self.edges, self.edge_norm, return_layers=return_layers)
+    
+    @torch.no_grad()
+    def generate_lgn(self, return_layers=False):
+        return self.forward(self.edges, self.edge_norm, return_layers=return_layers)
+    
+    @torch.no_grad()
+    def rating(self, user_emb, item_emb):
+        return torch.matmul(user_emb, item_emb.t())
+    
+    def _reg_loss(self, users, pos_items, neg_items):
+        u_emb = self.user_embedding[users]
+        pos_i_emb = self.item_embedding[pos_items]
+        neg_i_emb = self.item_embedding[neg_items]
+        reg_loss = (1/2)*(u_emb.norm(2).pow(2) +
+                          pos_i_emb.norm(2).pow(2) +
+                          neg_i_emb.norm(2).pow(2))/float(len(users))
+        return reg_loss
diff --git a/recommend/model/LightGCN_pretrained.pt b/recommend/model/LightGCN_pretrained.pt
new file mode 100644
index 0000000..825e0e2
--- /dev/null
+++ b/recommend/model/LightGCN_pretrained.pt
Binary files differ
diff --git a/recommend/model/base_model.py b/recommend/model/base_model.py
new file mode 100644
index 0000000..819442a
--- /dev/null
+++ b/recommend/model/base_model.py
@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+from utils.parse_args import args
+from scipy.sparse import csr_matrix
+import scipy.sparse as sp
+import numpy as np
+import torch.nn.functional as F
+
+
+class BaseModel(nn.Module):
+    def __init__(self, dataloader):
+        super(BaseModel, self).__init__()
+        self.num_users = dataloader.num_users
+        self.num_items = dataloader.num_items
+        self.emb_size = args.emb_size
+
+    def forward(self):
+        pass
+
+    def cal_loss(self, batch_data):
+        pass
+
+    def _check_inf(self, loss, pos_score, neg_score, edge_weight):
+        # find inf idx
+        inf_idx = torch.isinf(loss) | torch.isnan(loss)
+        if inf_idx.any():
+            print("find inf in loss")
+            if type(edge_weight) != int:
+                print(edge_weight[inf_idx])
+            print(f"pos_score: {pos_score[inf_idx]}")
+            print(f"neg_score: {neg_score[inf_idx]}")
+            raise ValueError("find inf in loss")
+
+    def _make_binorm_adj(self, mat):
+        a = csr_matrix((self.num_users, self.num_users))
+        b = csr_matrix((self.num_items, self.num_items))
+        mat = sp.vstack(
+            [sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
+        mat = (mat != 0) * 1.0
+        # mat = (mat + sp.eye(mat.shape[0])) * 1.0# MARK
+        degree = np.array(mat.sum(axis=-1))
+        d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
+        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
+        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
+        mat = mat.dot(d_inv_sqrt_mat).transpose().dot(
+            d_inv_sqrt_mat).tocoo()
+
+        # make torch tensor
+        idxs = torch.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
+        vals = torch.from_numpy(mat.data.astype(np.float32))
+        shape = torch.Size(mat.shape)
+        return torch.sparse.FloatTensor(idxs, vals, shape).to(args.device)
+    
+    def _make_binorm_adj_self_loop(self, mat):
+        a = csr_matrix((self.num_users, self.num_users))
+        b = csr_matrix((self.num_items, self.num_items))
+        mat = sp.vstack(
+            [sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
+        mat = (mat != 0) * 1.0
+        mat = (mat + sp.eye(mat.shape[0])) * 1.0 # self loop
+        degree = np.array(mat.sum(axis=-1))
+        d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
+        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
+        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
+        mat = mat.dot(d_inv_sqrt_mat).transpose().dot(
+            d_inv_sqrt_mat).tocoo()
+
+        # make torch tensor
+        idxs = torch.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
+        vals = torch.from_numpy(mat.data.astype(np.float32))
+        shape = torch.Size(mat.shape)
+        return torch.sparse.FloatTensor(idxs, vals, shape).to(args.device)
+
+
+    def _sp_matrix_to_sp_tensor(self, sp_matrix):
+        coo = sp_matrix.tocoo()
+        indices = torch.LongTensor([coo.row, coo.col])
+        values = torch.FloatTensor(coo.data)
+        return torch.sparse.FloatTensor(indices, values, coo.shape).coalesce().to(args.device)
+
+    def _bpr_loss(self, user_emb, pos_item_emb, neg_item_emb):
+        pos_score = (user_emb * pos_item_emb).sum(dim=1)
+        neg_score = (user_emb * neg_item_emb).sum(dim=1)
+        loss = -torch.log(1e-10 + torch.sigmoid((pos_score - neg_score)))
+        self._check_inf(loss, pos_score, neg_score, 0)
+        return loss.mean()
+    
+    def _nce_loss(self, pos_score, neg_score, edge_weight=1):
+        numerator = torch.exp(pos_score)
+        denominator = torch.exp(pos_score) + torch.exp(neg_score).sum(dim=1)
+        loss = -torch.log(numerator/denominator) * edge_weight
+        self._check_inf(loss, pos_score, neg_score, edge_weight)
+        return loss.mean()
+    
+    def _infonce_loss(self, pos_1, pos_2, negs, tau):
+        pos_1 = self.cl_mlp(pos_1)
+        pos_2 = self.cl_mlp(pos_2)
+        negs = self.cl_mlp(negs)
+        pos_1 = F.normalize(pos_1, dim=-1)
+        pos_2 = F.normalize(pos_2, dim=-1)
+        negs = F.normalize(negs, dim=-1)
+        pos_score = torch.mul(pos_1, pos_2).sum(dim=1)
+        # B, 1, E * B, E, N -> B, N
+        neg_score = torch.bmm(pos_1.unsqueeze(1), negs.transpose(1, 2)).squeeze(1)
+        # infonce loss
+        numerator = torch.exp(pos_score / tau)
+        denominator = torch.exp(pos_score / tau) + torch.exp(neg_score / tau).sum(dim=1)
+        loss = -torch.log(numerator/denominator)
+        self._check_inf(loss, pos_score, neg_score, 0)
+        return loss.mean()
+    
\ No newline at end of file
diff --git a/recommend/model/operators.py b/recommend/model/operators.py
new file mode 100644
index 0000000..a508966
--- /dev/null
+++ b/recommend/model/operators.py
@@ -0,0 +1,52 @@
+import torch
+from typing import Optional, Tuple
+from torch import nn
+
+def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
+    if dim < 0:
+        dim = other.dim() + dim
+    if src.dim() == 1:
+        for _ in range(0, dim):
+            src = src.unsqueeze(0)
+    for _ in range(src.dim(), other.dim()):
+        src = src.unsqueeze(-1)
+    src = src.expand(other.size())
+    return src
+
+def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
+                out: Optional[torch.Tensor] = None,
+                dim_size: Optional[int] = None) -> torch.Tensor:
+    index = broadcast(index, src, dim)
+    if out is None:
+        size = list(src.size())
+        if dim_size is not None:
+            size[dim] = dim_size
+        elif index.numel() == 0:
+            size[dim] = 0
+        else:
+            size[dim] = int(index.max()) + 1
+        out = torch.zeros(size, dtype=src.dtype, device=src.device)
+        return out.scatter_add_(dim, index, src)
+    else:
+        return out.scatter_add_(dim, index, src)
+
+def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
+                out: Optional[torch.Tensor] = None,
+                dim_size: Optional[int] = None) -> torch.Tensor:
+    return scatter_sum(src, index, dim, out, dim_size)
+
+
+class EdgelistDrop(nn.Module):
+    def __init__(self):
+        super(EdgelistDrop, self).__init__()
+
+    def forward(self, edgeList, keep_rate, return_mask=False):
+        if keep_rate == 1.0:
+            return edgeList, torch.ones(edgeList.size(0)).type(torch.bool)
+        edgeNum = edgeList.size(0)
+        mask = (torch.rand(edgeNum) + keep_rate).floor().type(torch.bool)
+        newEdgeList = edgeList[mask, :]
+        if return_mask:
+            return newEdgeList, mask
+        else:
+            return newEdgeList
diff --git a/recommend/uig.txt b/recommend/uig.txt
new file mode 100644
index 0000000..5846057
--- /dev/null
+++ b/recommend/uig.txt
@@ -0,0 +1,2 @@
+0	1 3 9 12 5 7 6 8 4	1511683379 1511683385 1511683431 1511683453 1511683481 1511692992 1511693011 1511693077 1511787191

+1	10 11 2	1511578239 1511594732 1511664627
\ No newline at end of file
diff --git a/recommend/utils/dataloader.py b/recommend/utils/dataloader.py
new file mode 100644
index 0000000..d519f17
--- /dev/null
+++ b/recommend/utils/dataloader.py
@@ -0,0 +1,92 @@
+from utils.parse_args import args
+from os import path
+from tqdm import tqdm
+import numpy as np
+import scipy.sparse as sp
+import torch
+import networkx as nx
+from copy import deepcopy
+from collections import defaultdict
+import pandas as pd
+
+
+class EdgeListData:
+    def __init__(self, train_file, test_file, phase='pretrain', pre_dataset=None, has_time=True):
+        self.phase = phase
+        self.has_time = has_time
+        self.pre_dataset = pre_dataset
+
+        self.hour_interval = args.hour_interval_pre if phase == 'pretrain' else args.hour_interval_f
+
+        self.edgelist = []
+        self.edge_time = []
+        self.num_users = 0
+        self.num_items = 0
+        self.num_edges = 0
+
+        self.train_user_dict = {}
+        self.test_user_dict = {}
+
+        self._load_data(train_file, test_file, has_time)
+
+        if phase == 'pretrain':
+            self.user_hist_dict = self.train_user_dict
+        
+        users_has_hist = set(list(self.user_hist_dict.keys()))
+        all_users = set(list(range(self.num_users)))
+        users_no_hist = all_users - users_has_hist
+        for u in users_no_hist:
+            self.user_hist_dict[u] = []
+
+    def _read_file(self, train_file, test_file, has_time=True):
+        with open(train_file, 'r') as f:
+            for line in f:
+                line = line.strip().split('\t')
+                if not has_time:
+                    user, items = line[:2]
+                    times = " ".join(["0"] * len(items.split(" ")))
+                else:
+                    user, items, times = line
+                    
+                for i in items.split(" "):
+                    self.edgelist.append((int(user), int(i)))
+                for i in times.split(" "):
+                    self.edge_time.append(int(i))
+                self.train_user_dict[int(user)] = [int(i) for i in items.split(" ")]
+
+        self.test_edge_num = 0
+        with open(test_file, 'r') as f:
+            for line in f:
+                line = line.strip().split('\t')
+                user, items = line[:2]
+                self.test_user_dict[int(user)] = [int(i) for i in items.split(" ")]
+                self.test_edge_num += len(self.test_user_dict[int(user)])
+
+    def _load_data(self, train_file, test_file, has_time=True):
+        self._read_file(train_file, test_file, has_time)
+
+        self.edgelist = np.array(self.edgelist, dtype=np.int32)
+        self.edge_time = 1 + self.timestamp_to_time_step(np.array(self.edge_time, dtype=np.int32))
+        self.num_edges = len(self.edgelist)
+        if self.pre_dataset is not None:
+            self.num_users = self.pre_dataset.num_users
+            self.num_items = self.pre_dataset.num_items
+        else:
+            self.num_users = max([np.max(self.edgelist[:, 0]) + 1, np.max(list(self.test_user_dict.keys())) + 1])
+            self.num_items = max([np.max(self.edgelist[:, 1]) + 1, np.max([np.max(self.test_user_dict[u]) for u in self.test_user_dict.keys()]) + 1])
+
+        self.graph = sp.coo_matrix((np.ones(self.num_edges), (self.edgelist[:, 0], self.edgelist[:, 1])), shape=(self.num_users, self.num_items))
+
+        if self.has_time:
+            self.edge_time_dict = defaultdict(dict)
+            for i in range(len(self.edgelist)):
+                self.edge_time_dict[self.edgelist[i][0]][self.edgelist[i][1]+self.num_users] = self.edge_time[i]
+                self.edge_time_dict[self.edgelist[i][1]+self.num_users][self.edgelist[i][0]] = self.edge_time[i]
+
+    def timestamp_to_time_step(self, timestamp_arr, least_time=None):
+        interval_hour = self.hour_interval
+        if least_time is None:
+            least_time = np.min(timestamp_arr)
+        timestamp_arr = timestamp_arr - least_time
+        timestamp_arr = timestamp_arr // (interval_hour * 3600)
+        return timestamp_arr
diff --git a/recommend/utils/parse_args.py b/recommend/utils/parse_args.py
new file mode 100644
index 0000000..3e86a47
--- /dev/null
+++ b/recommend/utils/parse_args.py
@@ -0,0 +1,57 @@
+import argparse
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='GraphPro')
+    parser.add_argument('--phase', type=str, default='pretrain')
+    parser.add_argument('--plugin', action='store_true', default=False)
+    parser.add_argument('--save_path', type=str, default="saved" ,help='where to save model and logs')
+    parser.add_argument('--data_path', type=str, default="dataset/yelp",help='where to load data')
+    parser.add_argument('--exp_name', type=str, default='1')
+    parser.add_argument('--desc', type=str, default='')
+    parser.add_argument('--ab', type=str, default='full')
+    parser.add_argument('--log', type=int, default=1)
+
+    parser.add_argument('--device', type=str, default="cuda")
+    parser.add_argument('--model', type=str, default='GraphPro')
+    parser.add_argument('--pre_model', type=str, default='GraphPro')
+    parser.add_argument('--f_model', type=str, default='GraphPro')
+    parser.add_argument('--pre_model_path', type=str, default='pretrained_model.pt')
+
+    parser.add_argument('--hour_interval_pre', type=float, default=1)
+    parser.add_argument('--hour_interval_f', type=int, default=1)
+    parser.add_argument('--emb_dropout', type=float, default=0)
+
+    parser.add_argument('--updt_inter', type=int, default=1)
+    parser.add_argument('--samp_decay', type=float, default=0.05)
+    
+    parser.add_argument('--edge_dropout', type=float, default=0.5)
+    parser.add_argument('--emb_size', type=int, default=64)
+    parser.add_argument('--batch_size', type=int, default=2048)
+    parser.add_argument('--eval_batch_size', type=int, default=512)
+    parser.add_argument('--seed', type=int, default=2023)
+    parser.add_argument('--num_epochs', type=int, default=300)
+    parser.add_argument('--neighbor_sample_num', type=int, default=5)
+    parser.add_argument('--lr', type=float, default=0.001)
+    parser.add_argument('--weight_decay', type=float, default=1e-4)
+    parser.add_argument('--metrics', type=str, default='recall;ndcg')
+    parser.add_argument('--metrics_k', type=str, default='20')
+    parser.add_argument('--early_stop_patience', type=int, default=10)
+    parser.add_argument('--neg_num', type=int, default=1)
+
+    parser.add_argument('--num_layers', type=int, default=3)
+
+
+    return parser
+
+parser = parse_args()
+args = parser.parse_known_args()[0]
+if args.pre_model == args.f_model:
+    args.model = args.pre_model
+elif args.pre_model != 'LightGCN':
+    args.model = args.pre_model
+
+args = parser.parse_args()
+if args.pre_model == args.f_model:
+    args.model = args.pre_model
+elif args.pre_model != 'LightGCN':
+    args.model = args.pre_model
\ No newline at end of file