blob: 1ef5f44e18b3eb3e9b5480d283099a640734f3a5 [file] [log] [blame]
package trackertest;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
import javax.persistence.EntityManager;
import javax.persistence.EntityManagerFactory;
import javax.persistence.Persistence;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.DynamicTest;
import org.junit.jupiter.api.TestFactory;
import entity.config;
import tracker.Tracker;
/**
 * Integration tests for {@link Tracker}'s upload counters, run against the test database
 * described by {@code entity.config}.
 *
 * <p>One dynamic test is generated per {@code UserPT} row. Each test changes that user's
 * {@code upload} value through the tracker, reads the stored value back, reverts the change
 * through the tracker, and only then asserts, so a failed check does not leave the shared
 * test database modified.
 */
public class TrackerTest {
    /** JPQL query reading a single user's upload counter. */
    private static final String UPLOAD_QUERY =
            "select u.upload from UserPT u where u.userid = :uid";
    /** Inclusive upper bound for the random amount added or removed in one test. */
    private static final int MAX_DELTA = 1000;

    private static EntityManagerFactory emf;
    private static EntityManager em;
    private static List<String> userIds;
    private static Tracker tracker;

    @BeforeAll
    static void setup() throws Exception {
        // Load the MySQL driver explicitly; otherwise no connection can be established.
        Class.forName("com.mysql.cj.jdbc.Driver");
        config cfg = new config();
        Map<String, Object> props = new HashMap<>();
        // Add the time-zone and SSL parameters to the JDBC URL.
        String jdbcUrl = String.format(
                "jdbc:mysql://%s/%s?useSSL=false&serverTimezone=UTC",
                cfg.SqlURL, cfg.TestDatabase);
        props.put("javax.persistence.jdbc.url", jdbcUrl);
        props.put("javax.persistence.jdbc.user", cfg.SqlUsername);
        props.put("javax.persistence.jdbc.password", cfg.SqlPassword);
        props.put("javax.persistence.jdbc.driver", "com.mysql.cj.jdbc.Driver");
        emf = Persistence.createEntityManagerFactory("myPersistenceUnit", props);
        em = emf.createEntityManager();
        // Use the simple entity name rather than a package-qualified one.
        userIds = em.createQuery("select u.userid from UserPT u", String.class)
                .getResultList();
        tracker = new Tracker(emf);
    }

    @AfterAll
    static void teardown() {
        // Close the factory even if closing the entity manager throws.
        try {
            if (em != null && em.isOpen()) {
                em.close();
            }
        } finally {
            if (emf != null && emf.isOpen()) {
                emf.close();
            }
        }
    }

    @TestFactory
    Collection<DynamicTest> testAddUpLoad() {
        Random rnd = new Random();
        return userIds.stream()
                .map(uid -> DynamicTest.dynamicTest("AddUpLoad for user " + uid, () -> {
                    int delta = rnd.nextInt(MAX_DELTA) + 1; // Ensure non-zero value
                    long before = fetchUpload(uid);
                    // The tracker returns false when the operation succeeds.
                    Assertions.assertFalse(tracker.AddUpLoad(uid, delta),
                            "AddUpLoad should return false on successful operation");
                    long after;
                    try {
                        after = fetchUpload(uid);
                    } finally {
                        // Revert before asserting on the value so a failed check cannot leave
                        // the user's counter changed for later tests.
                        Assertions.assertFalse(tracker.ReduceUpLoad(uid, delta),
                                "ReduceUpLoad should return false on successful operation");
                    }
                    Assertions.assertEquals(before + delta, after,
                            "Upload value should be increased by " + delta);
                }))
                .collect(Collectors.toList());
    }

    @TestFactory
    Collection<DynamicTest> testReduceUpLoad() {
        Random rnd = new Random();
        return userIds.stream()
                .map(uid -> DynamicTest.dynamicTest("ReduceUpLoad for user " + uid, () -> {
                    long before = fetchUpload(uid);
                    // Bounded by MAX_DELTA, so the narrowing cast cannot overflow.
                    int max = (int) Math.min(before, MAX_DELTA);
                    // Report users with nothing to reduce as skipped instead of silently passing.
                    Assumptions.assumeTrue(max > 0,
                            "User " + uid + " has no upload to reduce");
                    int delta = rnd.nextInt(max) + 1;
                    // The tracker returns false when the operation succeeds.
                    Assertions.assertFalse(tracker.ReduceUpLoad(uid, delta),
                            "ReduceUpLoad should return false on successful operation");
                    long after;
                    try {
                        after = fetchUpload(uid);
                    } finally {
                        // Roll back to the starting value before asserting on the result.
                        Assertions.assertFalse(tracker.AddUpLoad(uid, delta),
                                "AddUpLoad should return false on successful operation");
                    }
                    Assertions.assertEquals(before - delta, after,
                            "Upload value should be decreased by " + delta);
                }))
                .collect(Collectors.toList());
    }

    /**
     * Reads a user's current upload counter from the database.
     *
     * <p>The shared persistence context is cleared first so values cached by {@link #em} are
     * not reused after {@link Tracker} has changed the row.
     *
     * @param uid the user id to look up
     * @return the stored upload value, or 0 when the column is {@code null}
     */
    private static long fetchUpload(String uid) {
        em.clear();
        Long upload = em.createQuery(UPLOAD_QUERY, Long.class)
                .setParameter("uid", uid)
                .getSingleResult();
        return upload != null ? upload : 0L;
    }
}