Add user tag preferences and tag-based post recommendation
Change-Id: I5c4de18f786f8cc336d8ad66ae9b424d02ed3674
diff --git a/src/main/java/com/example/g8backend/controller/PostController.java b/src/main/java/com/example/g8backend/controller/PostController.java
index 03687c3..d53db64 100644
--- a/src/main/java/com/example/g8backend/controller/PostController.java
+++ b/src/main/java/com/example/g8backend/controller/PostController.java
@@ -1,5 +1,4 @@
package com.example.g8backend.controller;
-
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.example.g8backend.dto.PostCreateDTO;
@@ -12,9 +11,7 @@
import org.springframework.web.bind.annotation.*;
import com.example.g8backend.entity.Post;
import com.example.g8backend.service.IPostService;
-
import java.util.List;
-
@RestController
@RequestMapping("/post")
public class PostController {
@@ -22,14 +19,12 @@
private IPostService postService;
@Autowired // ✅ 新增注入
private PostViewMapper postViewMapper;
-
@PostMapping("")
public ResponseEntity<?> createPost(@RequestBody PostCreateDTO postCreateDTO) {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
long userId = (long) authentication.getPrincipal();
Post post = postCreateDTO.getPost();
Long[] tagIds = postCreateDTO.getTagIds();
-
post.setUserId(userId);
if (tagIds.length > 0){
postService.createPost(post, tagIds);
@@ -38,20 +33,16 @@
}
return ResponseEntity.ok().build();
}
-
@GetMapping("/{postId}")
public Post getPost(@PathVariable Long postId) {
// 获取当前用户ID
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
long userId = (long) authentication.getPrincipal();
-
// 记录浏览行为
postService.recordViewHistory(userId, postId);
-
// 返回帖子详情
return postService.getById(postId);
}
-
@DeleteMapping("/{postId}")
public ResponseEntity<?> deletePost(@PathVariable("postId") Long postId) {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
@@ -66,42 +57,37 @@
postService.removeById(postId);
return ResponseEntity.ok().body("Post deleted successfully.");
}
-
@GetMapping("/getAll")
public List<Post> getAllPosts() {
return postService.list();
}
-
@GetMapping("/getByUserId/{userId}")
public List<Post> getPostsByUserId(@PathVariable("userId") Long userId) {
return postService.getPostsByUserId(userId);
}
-
@PutMapping("/{postId}")
public ResponseEntity<?> updatePost(@PathVariable("postId") Long postId, @RequestBody Post post) {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
long userId = (long) authentication.getPrincipal();
Post existingPost = postService.getById(postId);
-
+
if (existingPost == null) {
return ResponseEntity.status(500).body("Post not found.");
}
if (existingPost.getUserId() != userId) {
return ResponseEntity.status(403).body("You are not authorized to update this post.");
}
-
+
post.setPostId(postId);
post.setUserId(userId);
postService.updateById(post);
return ResponseEntity.ok().body("Post updated successfully.");
}
-
@GetMapping("/type/{postType}")
public ResponseEntity<?> getPostsByType(@PathVariable String postType) {
List<Post> posts = postService.getPostsByType(postType);
return ResponseEntity.ok().body(posts);
}
-
@PostMapping("/{postId}/like")
public ResponseEntity<?> likePost(@PathVariable Long postId) {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
@@ -109,7 +95,6 @@
postService.likePost(userId, postId);
return ResponseEntity.ok().body("Post liked successfully.");
}
-
@DeleteMapping("/{postId}/like")
public ResponseEntity<?> unlikePost(@PathVariable Long postId) {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
@@ -117,51 +102,52 @@
postService.unlikePost(userId, postId);
return ResponseEntity.ok().body("Post unliked successfully.");
}
-
@GetMapping("/{postId}/likes")
public ResponseEntity<?> getPostLikeCount(@PathVariable Long postId) {
Long likeCount = postService.getPostLikeCount(postId);
return ResponseEntity.ok().body(likeCount);
}
-
// 搜索帖子
@GetMapping("/search")
public List<Post> searchPosts(
@RequestParam(required = false) String keyword,
@RequestParam(required = false) List<Long> tags, // 修改为接收多个标签
@RequestParam(required = false) String author) {
-
return postService.searchPosts(keyword, tags, author);
}
-
@GetMapping("/history")
public ResponseEntity<List<PostView>> getViewHistory() {
// 获取当前用户ID
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
long userId = (long) authentication.getPrincipal();
-
// 查询历史记录(按时间倒序)
List<PostView> history = postViewMapper.selectList(
new QueryWrapper<PostView>()
.eq("user_id", userId)
.orderByDesc("view_time")
);
-
return ResponseEntity.ok(history);
}
-
@GetMapping("/recommended")
public ResponseEntity<Page<Post>> getRecommendedPosts(
@RequestParam(defaultValue = "1") int page,
@RequestParam(defaultValue = "10") int size) {
-
// 获取当前用户ID
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
long userId = (long) authentication.getPrincipal();
-
// 调用 Service 层方法
Page<Post> pageResult = postService.getRecommendedPosts(page, size, userId);
-
return ResponseEntity.ok(pageResult);
}
-}
+ // PostController.java - 新增标签推荐接口
+ @GetMapping("/recommended-by-tags")
+ public ResponseEntity<Page<Post>> getRecommendedByTags(
+ @RequestParam(defaultValue = "1") int page,
+ @RequestParam(defaultValue = "10") int size) {
+ Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
+ long userId = (long) authentication.getPrincipal();
+ // 调用标签推荐方法
+ Page<Post> result = postService.getRecommendedByTags(page, size, userId);
+ return ResponseEntity.ok(result);
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/com/example/g8backend/entity/UserTagPreference.java b/src/main/java/com/example/g8backend/entity/UserTagPreference.java
new file mode 100644
index 0000000..42ac4fd
--- /dev/null
+++ b/src/main/java/com/example/g8backend/entity/UserTagPreference.java
@@ -0,0 +1,15 @@
+
+package com.example.g8backend.entity;
+import com.baomidou.mybatisplus.annotation.TableName;
+import lombok.Data;
+import lombok.experimental.Accessors;
+import java.sql.Timestamp;
+@Data
+@Accessors(chain = true)
+@TableName("user_tag_preference") // 映射数据库表名
+public class UserTagPreference {
+ private Long userId; // 用户ID
+ private Long tagId; // 标签ID
+ private Double weight; // 偏好权重
+ private Timestamp lastUpdated; // 最后更新时间
+}
\ No newline at end of file
diff --git a/src/main/java/com/example/g8backend/mapper/PostMapper.java b/src/main/java/com/example/g8backend/mapper/PostMapper.java
index b5a5280..baebb17 100644
--- a/src/main/java/com/example/g8backend/mapper/PostMapper.java
+++ b/src/main/java/com/example/g8backend/mapper/PostMapper.java
@@ -1,19 +1,14 @@
package com.example.g8backend.mapper;
-
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.example.g8backend.entity.Post;
import com.example.g8backend.entity.PostLike;
import org.apache.ibatis.annotations.*;
-
import java.util.List;
-
@Mapper
public interface PostMapper extends BaseMapper<Post> {
-
// 获取用户的帖子
List<Post> getPostsByUserId(@Param("userId") Long userId);
-
// 搜索帖子
@Select("<script>" +
"SELECT p.* " +
@@ -35,32 +30,24 @@
List<Post> searchPosts(@Param("keyword") String keyword,
@Param("tagIds") List<Long> tagIds,
@Param("author") String author);
-
// 检查用户是否已经点赞该帖子
@Select("SELECT EXISTS (SELECT 1 FROM post_likes WHERE user_id = #{userId} AND post_id = #{postId})")
boolean existsByUserIdAndPostId(@Param("userId") Long userId, @Param("postId") Long postId);
-
// 插入一条点赞记录
@Insert("INSERT INTO post_likes (user_id, post_id) VALUES (#{userId}, #{postId})")
void insert(PostLike postLike);
-
// 删除用户对帖子的点赞记录
@Delete("DELETE FROM post_likes WHERE user_id = #{userId} AND post_id = #{postId}")
void deleteLikeByUserIdAndPostId(@Param("userId") Long userId, @Param("postId") Long postId);
-
// 获取某个帖子点赞数
@Select("SELECT COUNT(*) FROM post_likes WHERE post_id = #{postId}")
Long selectCount(@Param("postId") Long postId);
-
@Update("UPDATE posts SET view_count = view_count + 1 WHERE post_id = #{postId}")
void incrementViewCount(Long postId);
-
@Select("SELECT COUNT(*) FROM post_likes WHERE post_id = #{postId}")
Long selectLikeCount(Long postId);
-
@Select("SELECT post_id FROM post_views WHERE user_id = #{userId}")
List<Long> findViewedPostIds(Long userId);
-
@Update({
"<script>",
"UPDATE posts",
@@ -77,4 +64,4 @@
"</script>"
})
int batchUpdateHotScore(@Param("posts") List<Post> posts);
-}
+}
\ No newline at end of file
diff --git a/src/main/java/com/example/g8backend/mapper/PostTagMapper.java b/src/main/java/com/example/g8backend/mapper/PostTagMapper.java
index 184feb2..4c8cc1e 100644
--- a/src/main/java/com/example/g8backend/mapper/PostTagMapper.java
+++ b/src/main/java/com/example/g8backend/mapper/PostTagMapper.java
@@ -1,17 +1,32 @@
package com.example.g8backend.mapper;
-
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.example.g8backend.entity.Post;
import com.example.g8backend.entity.PostTag;
import com.example.g8backend.entity.Tag;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
-
+import org.apache.ibatis.annotations.Select;
import java.util.List;
-
@Mapper
public interface PostTagMapper extends BaseMapper<PostTag> {
List<Post> getPostsByTagIds(@Param("tagIds") Long[] tagIds);
List<Tag> getTagsByPostId(@Param("postId") Long postId);
int deleteByIds(@Param("postId") Long postId, @Param("tagId") Long tagId);
-}
+ /** Returns the IDs of all tags attached to the given post. */
+ @Select("SELECT tag_id FROM post_tag WHERE post_id = #{postId}")
+ List<Long> findTagIdsByPostId(@Param("postId") Long postId);
+ /**
+ * Returns the distinct IDs of posts carrying at least one of the given tags.
+ * tagIds must be non-empty: an empty list leaves "WHERE tag_id IN" without a value list (invalid SQL).
+ */
+ @Select({
+ "<script>",
+ "SELECT DISTINCT post_id FROM post_tag",
+ "WHERE tag_id IN",
+ "<foreach item='tagId' collection='tagIds' open='(' separator=',' close=')'>",
+ "#{tagId}",
+ "</foreach>",
+ "</script>"
+ })
+ List<Long> findPostIdsByTagIds(@Param("tagIds") List<Long> tagIds);
+}
\ No newline at end of file
diff --git a/src/main/java/com/example/g8backend/mapper/UserTagPreferenceMapper.java b/src/main/java/com/example/g8backend/mapper/UserTagPreferenceMapper.java
new file mode 100644
index 0000000..c59ea8e
--- /dev/null
+++ b/src/main/java/com/example/g8backend/mapper/UserTagPreferenceMapper.java
@@ -0,0 +1,31 @@
+package com.example.g8backend.mapper;
+import com.baomidou.mybatisplus.core.mapper.BaseMapper;
+import com.example.g8backend.entity.UserTagPreference;
+import org.apache.ibatis.annotations.Mapper;
+import org.apache.ibatis.annotations.Param;
+import org.apache.ibatis.annotations.Select;
+import org.apache.ibatis.annotations.Update;
+import java.util.List;
+/**
+ * MyBatis mapper for the user_tag_preference table (one weight per user/tag pair).
+ */
+@Mapper
+public interface UserTagPreferenceMapper extends BaseMapper<UserTagPreference> {
+ /**
+ * Inserts the (userId, tagId) preference with weight = increment or, if that row already
+ * exists, adds increment to its weight and refreshes last_updated.
+ */
+ @Update("INSERT INTO user_tag_preference (user_id, tag_id, weight) " +
+ "VALUES (#{userId}, #{tagId}, #{increment}) " +
+ "ON DUPLICATE KEY UPDATE weight = weight + #{increment}, last_updated = CURRENT_TIMESTAMP")
+ void insertOrUpdateWeight(
+ @Param("userId") Long userId,
+ @Param("tagId") Long tagId,
+ @Param("increment") Double increment
+ );
+ /**
+ * Returns all tag preference rows of the given user.
+ */
+ @Select("SELECT * FROM user_tag_preference WHERE user_id = #{userId}")
+ List<UserTagPreference> selectByUserId(@Param("userId") Long userId);
+}
\ No newline at end of file
diff --git a/src/main/java/com/example/g8backend/service/IPostService.java b/src/main/java/com/example/g8backend/service/IPostService.java
index ca5376b..855952e 100644
--- a/src/main/java/com/example/g8backend/service/IPostService.java
+++ b/src/main/java/com/example/g8backend/service/IPostService.java
@@ -28,4 +28,6 @@
void calculateHotScores();
Page<Post> getRecommendedPosts(int page, int size, Long userId);
+
+ Page<Post> getRecommendedByTags(int page, int size, Long userId);
}
diff --git a/src/main/java/com/example/g8backend/service/impl/PostServiceImpl.java b/src/main/java/com/example/g8backend/service/impl/PostServiceImpl.java
index ae47d8a..5faa4ac 100644
--- a/src/main/java/com/example/g8backend/service/impl/PostServiceImpl.java
+++ b/src/main/java/com/example/g8backend/service/impl/PostServiceImpl.java
@@ -1,57 +1,47 @@
package com.example.g8backend.service.impl;
-
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
-import com.example.g8backend.entity.Post;
-import com.example.g8backend.entity.PostLike;
-import com.example.g8backend.entity.PostTag;
-import com.example.g8backend.entity.PostView;
-import com.example.g8backend.mapper.CommentMapper;
-import com.example.g8backend.mapper.PostMapper;
-import com.example.g8backend.mapper.PostViewMapper;
+import com.example.g8backend.entity.*;
+import com.example.g8backend.mapper.*;
import com.example.g8backend.service.IPostService;
import com.example.g8backend.service.IPostTagService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
-
import java.sql.Timestamp;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List;
-
+import java.util.stream.Collectors;
@Service
public class PostServiceImpl extends ServiceImpl<PostMapper, Post> implements IPostService {
-
private final PostMapper postMapper;
-
private final PostViewMapper postViewMapper;
-
private final CommentMapper commentMapper;
-
+ private final UserTagPreferenceMapper userTagPreferenceMapper;
+ private final PostTagMapper postTagMapper;
@Autowired
private IPostTagService postTagService;
-
- public PostServiceImpl(PostMapper postMapper, PostViewMapper postViewMapper, CommentMapper commentMapper) {
+ public PostServiceImpl(PostMapper postMapper, PostViewMapper postViewMapper, CommentMapper commentMapper,
+ UserTagPreferenceMapper userTagPreferenceMapper, PostTagMapper postTagMapper) {
this.postMapper = postMapper;
this.postViewMapper = postViewMapper;
this.commentMapper = commentMapper;
+ this.userTagPreferenceMapper = userTagPreferenceMapper;
+ this.postTagMapper = postTagMapper;
this.baseMapper = postMapper; // 重要:设置 baseMapper
}
-
@Override
public List<Post> getPostsByUserId(Long userId) {
return postMapper.getPostsByUserId(userId);
}
-
public void createPost(Post post) {
post.setHotScore(5.0); // 初始热度
post.setCreatedAt(new Timestamp(System.currentTimeMillis()));
save(post);
}
-
@Override
public void createPost(Post post, Long[] tagIds) {
post.setHotScore(5.0); // 初始热度
@@ -61,20 +51,17 @@
postTagService.save(new PostTag(post.getPostId(), tagId));
}
}
-
@Override
public Post updatePost(Post post) {
updateById(post);
return getById(post.getPostId());
}
-
@Override
public List<Post> getPostsByType(String postType) {
QueryWrapper<Post> wrapper = new QueryWrapper<>();
wrapper.eq("post_type", postType);
return list(wrapper);
}
-
@Override
public void likePost(Long userId, Long postId) {
// 检查用户是否已经点赞该帖子
@@ -85,14 +72,12 @@
postMapper.insert(postLike); // 执行插入点赞记录
}
}
-
// 取消点赞功能
@Override
public void unlikePost(Long userId, Long postId) {
// 删除用户对帖子的点赞记录
postMapper.deleteLikeByUserIdAndPostId(userId, postId); // 使用新的方法删除记录
}
-
// 获取帖子点赞数
@Override
public Long getPostLikeCount(Long postId) {
@@ -103,8 +88,6 @@
public List<Post> searchPosts(String keyword, List<Long> tagIds, String author) {
return postMapper.searchPosts(keyword, tagIds, author); // 调用mapper的搜索方法
}
-
-
@Override
@Transactional
public void recordViewHistory(Long userId, Long postId) {
@@ -114,11 +97,11 @@
.setPostId(postId)
.setViewTime(new Timestamp(System.currentTimeMillis()).toLocalDateTime());
postViewMapper.insert(view);
-
// 2. 原子更新浏览数
postMapper.incrementViewCount(postId); // 直接调用原子操作
+ // 3. 新增:更新用户标签偏好
+ updateUserTagPreference(userId, postId);
}
-
@Override
@Scheduled(cron = "0 */10 * * * *") // 每10分钟执行一次
@Transactional
@@ -126,7 +109,6 @@
// 1. 获取所有帖子
List<Post> posts = postMapper.selectList(new QueryWrapper<>());
Instant now = Instant.now();
-
// 2. 计算每个帖子的热度
posts.forEach(post -> {
// 计算时间衰减因子(以小时为单位)
@@ -134,40 +116,65 @@
post.getCreatedAt().toInstant(),
now
);
-
// 获取互动数据(点赞数、评论数)
Long likeCount = postMapper.selectLikeCount(post.getPostId());
Long commentCount = commentMapper.selectCountByPostId(post.getPostId());
-
// 热度计算公式
double hotScore = (
Math.log(post.getViewCount() + 1) * 0.2 +
likeCount * 0.5 +
commentCount * 0.3
) / Math.pow(hoursSinceCreation + 2, 1.5);
-
post.setHotScore(hotScore);
post.setLastCalculated(new Timestamp(System.currentTimeMillis()));
});
-
// 3. 批量更新热度(自定义SQL实现)
postMapper.batchUpdateHotScore(posts);
}
-
@Override
public Page<Post> getRecommendedPosts(int page, int size, Long userId) {
// 1. 获取用户已浏览的帖子ID列表
List<Long> viewedPostIds = postViewMapper.findViewedPostIds(userId);
-
// 2. 构建查询条件:排除已浏览帖子,按热度降序
QueryWrapper<Post> queryWrapper = new QueryWrapper<>();
if (!viewedPostIds.isEmpty()) {
queryWrapper.notIn("post_id", viewedPostIds);
}
queryWrapper.orderByDesc("hot_score");
-
// 3. 分页查询
return postMapper.selectPage(new Page<>(page, size), queryWrapper);
}
-
+ private void updateUserTagPreference(Long userId, Long postId) {
+ // 获取帖子关联的标签ID列表
+ List<Long> tagIds = postTagMapper.findTagIdsByPostId(postId);
+ // 对每个标签增加权重(示例:每次浏览 +0.1)
+ tagIds.forEach(tagId -> {
+ userTagPreferenceMapper.insertOrUpdateWeight(
+ userId,
+ tagId,
+ 0.1 // 权重增量
+ );
+ });
+ }
+ @Override
+ public Page<Post> getRecommendedByTags(int page, int size, Long userId) {
+ // Tags this user has accumulated a preference weight for.
+ List<UserTagPreference> preferences = userTagPreferenceMapper.selectByUserId(userId);
+ if (preferences.isEmpty()) {
+ return new Page<>(page, size); // no preferences yet -> empty page
+ }
+ // Candidate posts; findPostIdsByTagIds may return a post once per matching tag, so de-duplicate.
+ List<Long> tagIds = preferences.stream()
+ .map(UserTagPreference::getTagId)
+ .collect(Collectors.toList());
+ List<Long> postIds = postTagMapper.findPostIdsByTagIds(tagIds).stream().distinct().collect(Collectors.toList());
+ if (postIds.isEmpty()) {
+ return new Page<>(page, size); // preferred tags have no posts -> empty page
+ }
+ // Page through the candidates hottest first (tag weights are not used for ranking here).
+ QueryWrapper<Post> queryWrapper = new QueryWrapper<>();
+ queryWrapper.in("post_id", postIds)
+ .orderByDesc("hot_score");
+ return postMapper.selectPage(new Page<>(page, size), queryWrapper);
+ }
}
\ No newline at end of file
diff --git a/src/main/resources/schema.sql b/src/main/resources/schema.sql
index e5ddbb4..c0659b3 100644
--- a/src/main/resources/schema.sql
+++ b/src/main/resources/schema.sql
@@ -6,7 +6,6 @@
`email` VARCHAR(255) NOT NULL UNIQUE,
`passkey` VARCHAR(255) NOT NULL UNIQUE
);
-
-- 种子表(保持不变)
CREATE TABLE IF NOT EXISTS `torrents` (
`torrent_id` INT AUTO_INCREMENT PRIMARY KEY,
@@ -16,7 +15,6 @@
`file_size` FLOAT NOT NULL,
FOREIGN KEY (`user_id`) REFERENCES `users`(`user_id`)
);
-
-- Peer表(保持不变)
CREATE TABLE IF NOT EXISTS `peers` (
`passkey` VARCHAR(255) NOT NULL,
@@ -30,7 +28,6 @@
FOREIGN KEY (`passkey`) REFERENCES `users`(`passkey`),
PRIMARY KEY (`passkey`, `info_hash`, `peer_id`)
);
-
-- 帖子表(新增 hot_score 和 last_calculated 字段)
CREATE TABLE IF NOT EXISTS `posts` (
`post_id` INT AUTO_INCREMENT PRIMARY KEY,
@@ -46,7 +43,6 @@
INDEX `idx_hot_score` (`hot_score`), -- 新增热度索引
INDEX `idx_post_type` (`post_type`) -- 新增类型索引
);
-
-- 标签表(保持不变)
CREATE TABLE IF NOT EXISTS `tags`(
`tag_id` INT AUTO_INCREMENT PRIMARY KEY,
@@ -54,7 +50,6 @@
`parent_id` INT DEFAULT NULL,
FOREIGN KEY (`parent_id`) REFERENCES `tags`(`tag_id`)
);
-
-- 帖子标签关联表(保持不变)
CREATE TABLE IF NOT EXISTS `post_tag` (
`post_id` INT NOT NULL,
@@ -63,7 +58,6 @@
FOREIGN KEY (`tag_id`) REFERENCES `tags`(`tag_id`),
PRIMARY KEY (`post_id`, `tag_id`)
);
-
-- 用户关注表(保持不变)
CREATE TABLE IF NOT EXISTS `user_follows` (
`follower_id` INT NOT NULL,
@@ -73,7 +67,6 @@
FOREIGN KEY (`followed_id`) REFERENCES `users`(`user_id`),
PRIMARY KEY (`follower_id`, `followed_id`)
);
-
-- 私信表(保持不变)
CREATE TABLE IF NOT EXISTS `private_messages` (
`message_id` INT AUTO_INCREMENT PRIMARY KEY,
@@ -85,7 +78,6 @@
FOREIGN KEY (`sender_id`) REFERENCES `users`(`user_id`),
FOREIGN KEY (`receiver_id`) REFERENCES `users`(`user_id`)
);
-
-- 评论表(保持不变)
CREATE TABLE IF NOT EXISTS `comments` (
`comment_id` INT AUTO_INCREMENT PRIMARY KEY,
@@ -99,7 +91,6 @@
FOREIGN KEY (`parent_comment_id`) REFERENCES `comments`(`comment_id`),
INDEX `idx_post_id` (`post_id`) -- 新增评论帖子索引
);
-
-- 帖子点赞表(保持不变)
CREATE TABLE IF NOT EXISTS `post_likes` (
`user_id` INT NOT NULL,
@@ -109,7 +100,6 @@
FOREIGN KEY (`post_id`) REFERENCES `posts`(`post_id`),
FOREIGN KEY (`user_id`) REFERENCES `users`(`user_id`)
);
-
-- 帖子浏览记录表(新增复合索引)
CREATE TABLE IF NOT EXISTS `post_views` (
`view_id` INT AUTO_INCREMENT PRIMARY KEY,
@@ -119,4 +109,14 @@
FOREIGN KEY (`user_id`) REFERENCES `users`(`user_id`),
FOREIGN KEY (`post_id`) REFERENCES `posts`(`post_id`),
INDEX `idx_user_view_time` (`user_id`, `view_time` DESC) -- 新增用户浏览时间索引
+);
+-- User tag preference table: one accumulated weight per (user, tag); IF NOT EXISTS keeps re-runs safe
+CREATE TABLE IF NOT EXISTS `user_tag_preference` (
+ `user_id` INT NOT NULL COMMENT '用户ID',
+ `tag_id` INT NOT NULL COMMENT '标签ID',
+ `weight` DOUBLE DEFAULT 1.0 COMMENT '偏好权重(浏览越多权重越高)',
+ `last_updated` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '最后更新时间',
+ PRIMARY KEY (`user_id`, `tag_id`),
+ FOREIGN KEY (`user_id`) REFERENCES `users`(`user_id`),
+ FOREIGN KEY (`tag_id`) REFERENCES `tags`(`tag_id`)
);
\ No newline at end of file
diff --git a/src/test/java/com/example/g8backend/service/PostHistoryServiceTest.java b/src/test/java/com/example/g8backend/service/PostHistoryServiceTest.java
index 143422d..38b1c87 100644
--- a/src/test/java/com/example/g8backend/service/PostHistoryServiceTest.java
+++ b/src/test/java/com/example/g8backend/service/PostHistoryServiceTest.java
@@ -1,7 +1,7 @@
package com.example.g8backend.service;
-
import com.example.g8backend.entity.PostView;
import com.example.g8backend.mapper.PostMapper;
+import com.example.g8backend.mapper.PostTagMapper;
import com.example.g8backend.mapper.PostViewMapper;
import com.example.g8backend.service.impl.PostServiceImpl;
import org.junit.jupiter.api.Test;
@@ -11,47 +11,37 @@
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.transaction.annotation.Transactional;
-
import java.time.LocalDateTime;
-
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
-
@ExtendWith(MockitoExtension.class)
@Transactional
public class PostHistoryServiceTest {
-
@Mock
private PostViewMapper postViewMapper;
-
@Mock
private PostMapper postMapper;
-
+ @Mock
+ private PostTagMapper postTagMapper;
@InjectMocks
private PostServiceImpl postService;
-
@Test
public void testRecordViewHistory_NormalCase() {
// 测试数据
Long userId = 1L;
Long postId = 100L;
-
// 调用方法
postService.recordViewHistory(userId, postId);
-
// 验证行为
verify(postViewMapper, times(1)).insert(any(PostView.class));
verify(postMapper, times(1)).incrementViewCount(eq(postId));
}
-
@Test
public void testRecordViewHistory_CheckDataIntegrity() {
Long userId = 2L;
Long postId = 200L;
-
postService.recordViewHistory(userId, postId);
-
// 显式指定参数类型为 PostView
verify(postViewMapper).insert(argThat(new ArgumentMatcher<PostView>() {
@Override
@@ -68,7 +58,6 @@
Long postId = 300L;
postService.recordViewHistory(1L, postId);
postService.recordViewHistory(2L, postId);
-
// 验证浏览数更新次数
verify(postMapper, times(2)).incrementViewCount(postId);
}
diff --git a/src/test/java/com/example/g8backend/service/PostServiceRecommendTest.java b/src/test/java/com/example/g8backend/service/PostServiceRecommendTest.java
index 23890ba..e2a4c7f 100644
--- a/src/test/java/com/example/g8backend/service/PostServiceRecommendTest.java
+++ b/src/test/java/com/example/g8backend/service/PostServiceRecommendTest.java
@@ -1,50 +1,43 @@
package com.example.g8backend.service;
-
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.example.g8backend.entity.Post;
import com.example.g8backend.entity.PostView;
-import com.example.g8backend.mapper.CommentMapper;
-import com.example.g8backend.mapper.PostMapper;
-import com.example.g8backend.mapper.PostViewMapper;
+import com.example.g8backend.entity.UserTagPreference;
+import com.example.g8backend.mapper.*;
import com.example.g8backend.service.impl.PostServiceImpl;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.transaction.annotation.Transactional;
-
import java.sql.Timestamp;
import java.time.Instant;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-
+import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
-
@ExtendWith(MockitoExtension.class)
@Transactional
public class PostServiceRecommendTest {
-
@Mock
private PostMapper postMapper;
-
@Mock
private PostViewMapper postViewMapper;
-
@Mock
private CommentMapper commentMapper;
-
+ @Mock
+ private UserTagPreferenceMapper userTagPreferenceMapper;
+ @Mock
+ private PostTagMapper postTagMapper;
@InjectMocks
private PostServiceImpl postService;
-
private Post mockPost;
private PostView mockPostView;
-
@BeforeEach
void setUp() {
// 初始化测试数据
@@ -55,56 +48,55 @@
.setViewCount(50)
.setCreatedAt(Timestamp.from(Instant.now().minusSeconds(7200)))
.setHotScore(5.0);
-
mockPostView = new PostView()
.setViewId(1L)
.setUserId(200L)
.setPostId(1L)
.setViewTime(Timestamp.from(Instant.now()).toLocalDateTime());
}
-
+ @Test
+ void testRecordViewHistory_NormalCase() {
+ // 调用方法
+ postService.recordViewHistory(100L, 1L);
+ // 验证 postTagMapper 被调用
+ verify(postTagMapper).findTagIdsByPostId(1L);
+ verify(postViewMapper).insert(any(PostView.class));
+ verify(postMapper).incrementViewCount(1L);
+ }
@Test
public void testGetRecommendedPosts_ExcludesViewedPosts() {
// 模拟用户已浏览的帖子ID
Long userId = 200L;
when(postViewMapper.findViewedPostIds(userId))
.thenReturn(Arrays.asList(1L, 2L));
-
// 模拟推荐结果(未浏览的帖子)
Post recommendedPost = new Post().setPostId(3L).setHotScore(8.0);
when(postMapper.selectPage(any(Page.class), any(QueryWrapper.class)))
.thenReturn(new Page<Post>().setRecords(Collections.singletonList(recommendedPost)));
-
// 调用推荐接口
Page<Post> result = postService.getRecommendedPosts(1, 10, userId);
-
// 验证结果
assertEquals(1, result.getRecords().size(), "应返回1条推荐结果");
assertEquals(3L, result.getRecords().get(0).getPostId(), "推荐结果应为未浏览的帖子ID 3");
assertFalse(result.getRecords().stream().anyMatch(p -> p.getPostId() == 1L), "结果中不应包含已浏览的帖子ID 1");
}
-
@Test
public void testGetRecommendedPosts_NoViewedPosts() {
// 模拟用户未浏览任何帖子
Long userId = 300L;
when(postViewMapper.findViewedPostIds(userId))
.thenReturn(Collections.emptyList());
-
// 模拟推荐结果(所有帖子按热度排序)
Post post1 = new Post().setPostId(1L).setHotScore(7.5);
Post post2 = new Post().setPostId(2L).setHotScore(9.0);
when(postMapper.selectPage(any(Page.class), any(QueryWrapper.class)))
.thenReturn(new Page<Post>().setRecords(Arrays.asList(post2, post1)));
-
// 调用推荐接口
Page<Post> result = postService.getRecommendedPosts(1, 10, userId);
-
// 验证结果
assertEquals(2, result.getRecords().size(), "应返回所有帖子");
assertEquals(2L, result.getRecords().get(0).getPostId(), "热度更高的帖子应排在前面");
}
-
@Test
public void testCalculateHotScores_UpdatesHotScoreCorrectly() {
// 设置存根
@@ -113,20 +105,17 @@
when(postMapper.selectLikeCount(anyLong())).thenReturn(30L);
when(commentMapper.selectCountByPostId(anyLong())).thenReturn(20L);
when(postMapper.batchUpdateHotScore(anyList())).thenReturn(1);
-
// 执行并验证
postService.calculateHotScores();
double expectedScore = (Math.log(51) * 0.2 + 30 * 0.5 + 20 * 0.3) / Math.pow(4, 1.5);
assertEquals(expectedScore, mockPost.getHotScore(), 0.01);
verify(postMapper).batchUpdateHotScore(anyList());
}
-
//--------------------- 测试冷启动逻辑 ---------------------
@Test
public void testCreatePost_SetsInitialHotScore() {
Post newPost = new Post().setPostId(4L).setPostTitle("New Post");
postService.createPost(newPost);
-
assertEquals(5.0, newPost.getHotScore(), "新帖子的初始热度应为5.0");
assertNotNull(newPost.getCreatedAt(), "创建时间不应为空");
}
@@ -134,11 +123,82 @@
public void testConcurrentViewCountUpdate() {
// 设置存根
doNothing().when(postMapper).incrementViewCount(anyLong());
-
postService.recordViewHistory(100L, 1L);
postService.recordViewHistory(200L, 1L);
-
verify(postMapper, times(2)).incrementViewCount(1L);
verify(postViewMapper, times(2)).insert(any(PostView.class));
}
+ @Test
+ public void testGetRecommendedByTags_WithPreferredTags() {
+ // 模拟用户偏好标签
+ Long userId = 200L;
+ UserTagPreference pref1 = new UserTagPreference().setTagId(100L).setWeight(2.0);
+ UserTagPreference pref2 = new UserTagPreference().setTagId(200L).setWeight(1.5);
+ when(userTagPreferenceMapper.selectByUserId(userId))
+ .thenReturn(Arrays.asList(pref1, pref2));
+ // 模拟标签关联的帖子ID
+ List<Long> expectedPostIds = Arrays.asList(3L, 4L, 5L);
+ when(postTagMapper.findPostIdsByTagIds(Arrays.asList(100L, 200L)))
+ .thenReturn(expectedPostIds);
+ // 使用 ArgumentCaptor 捕获 QueryWrapper 对象
+ ArgumentCaptor<QueryWrapper<Post>> wrapperCaptor = ArgumentCaptor.forClass(QueryWrapper.class);
+ Page<Post> mockPage = new Page<Post>().setRecords(Arrays.asList(
+ new Post().setPostId(5L).setHotScore(9.0),
+ new Post().setPostId(3L).setHotScore(8.0),
+ new Post().setPostId(4L).setHotScore(7.5)
+ ));
+ when(postMapper.selectPage(any(Page.class), wrapperCaptor.capture()))
+ .thenReturn(mockPage);
+ // 调用标签推荐接口
+ Page<Post> result = postService.getRecommendedByTags(1, 10, userId);
+ // ---------- 验证查询条件 ----------
+ QueryWrapper<Post> actualWrapper = wrapperCaptor.getValue();
+ String sqlSegment = actualWrapper.getSqlSegment();
+ // 验证 SQL 条件格式
+ assertTrue(
+ sqlSegment.matches(".*post_id\\s+IN\\s*\\(.*\\).*"),
+ "应包含 post_id IN 条件,实际条件:" + sqlSegment
+ );
+ // 验证参数值(忽略顺序)
+ Map<String, Object> params = actualWrapper.getParamNameValuePairs();
+ List<Object> actualPostIds = new ArrayList<>(params.values());
+ assertEquals(3, actualPostIds.size(), "IN 条件应包含3个参数");
+ assertTrue(
+ actualPostIds.containsAll(expectedPostIds),
+ "参数应包含所有预期帖子ID,实际参数:" + actualPostIds
+ );
+ // ---------- 验证结果排序和内容 ----------
+ assertEquals(3, result.getRecords().size(), "应返回3条结果");
+ assertEquals(5L, result.getRecords().get(0).getPostId(), "热度最高的帖子应为ID 5");
+ assertEquals(3L, result.getRecords().get(1).getPostId(), "热度次高的帖子应为ID 3");
+ assertEquals(4L, result.getRecords().get(2).getPostId(), "热度最低的帖子应为ID 4");
+ }
+ @Test
+ public void testGetRecommendedByTags_NoPreferredTags() {
+ // 模拟用户无偏好标签
+ Long userId = 300L;
+ when(userTagPreferenceMapper.selectByUserId(userId))
+ .thenReturn(Collections.emptyList());
+ // 调用标签推荐接口
+ Page<Post> result = postService.getRecommendedByTags(1, 10, userId);
+ // 验证结果为空或兜底逻辑
+ assertTrue(result.getRecords().isEmpty(), "无偏好标签时应返回空结果");
+ }
+ @Test
+ public void testGetRecommendedByTags_WithNonExistingTags() {
+ // 模拟用户偏好标签(但无关联帖子)
+ Long userId = 400L;
+ UserTagPreference pref = new UserTagPreference().setTagId(999L).setWeight(2.0);
+ when(userTagPreferenceMapper.selectByUserId(userId))
+ .thenReturn(Collections.singletonList(pref));
+ when(postTagMapper.findPostIdsByTagIds(Collections.singletonList(999L)))
+ .thenReturn(Collections.emptyList());
+ // 调用标签推荐接口
+ Page<Post> result = postService.getRecommendedByTags(1, 10, userId);
+ // 验证结果为空
+ assertNotNull(result, "分页结果不应为null");
+ assertTrue(result.getRecords().isEmpty(), "无关联帖子时应返回空结果");
+ // 验证postMapper.selectPage未被调用
+ verify(postMapper, never()).selectPage(any(Page.class), any(QueryWrapper.class));
+ }
}
\ No newline at end of file
diff --git a/src/test/java/com/example/g8backend/service/PostServiceTest.java b/src/test/java/com/example/g8backend/service/PostServiceTest.java
index 5d19711..51bc26b 100644
--- a/src/test/java/com/example/g8backend/service/PostServiceTest.java
+++ b/src/test/java/com/example/g8backend/service/PostServiceTest.java
@@ -1,9 +1,6 @@
package com.example.g8backend.service;
-
import com.example.g8backend.entity.Post;
-import com.example.g8backend.mapper.CommentMapper;
-import com.example.g8backend.mapper.PostMapper;
-import com.example.g8backend.mapper.PostViewMapper;
+import com.example.g8backend.mapper.*;
import com.example.g8backend.service.impl.PostServiceImpl;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
@@ -12,39 +9,31 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.test.context.junit.jupiter.SpringExtension;
-
import java.sql.Timestamp;
import java.util.Arrays;
import java.util.List;
-
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
-
@ExtendWith(SpringExtension.class)
@DisplayName("帖子服务测试")
class PostServiceTest {
-
@Mock
private PostMapper postMapper;
-
@Mock
private PostViewMapper postViewMapper;
-
@Mock
private CommentMapper commentMapper;
-
private PostServiceImpl postService;
-
+ @Mock private UserTagPreferenceMapper userTagPreferenceMapper; // @Mock so openMocks() initializes it; otherwise null reaches PostServiceImpl
+ @Mock private PostTagMapper postTagMapper; // @Mock so openMocks() initializes it; otherwise null reaches PostServiceImpl
private Post testPost;
-
@BeforeEach
void setUp() {
MockitoAnnotations.openMocks(this);
- postService = new PostServiceImpl(postMapper, postViewMapper,commentMapper);
+ postService = new PostServiceImpl(postMapper, postViewMapper,commentMapper,userTagPreferenceMapper,postTagMapper);
testPost = createTestPost();
}
-
private Post createTestPost() {
Post post = new Post();
post.setPostId(1L);
@@ -54,110 +43,88 @@
post.setCreatedAt(new Timestamp(System.currentTimeMillis()));
return post;
}
-
@Test
@DisplayName("创建帖子-成功")
void save_ShouldSucceed() {
// Arrange
when(postMapper.insert(any(Post.class))).thenReturn(1);
-
// Act
boolean result = postService.save(testPost);
-
// Assert
assertTrue(result);
verify(postMapper).insert(testPost);
}
-
@Test
@DisplayName("获取帖子-通过ID存在")
void getById_WhenExists_ShouldReturnPost() {
// Arrange
when(postMapper.selectById(1L)).thenReturn(testPost);
-
// Act
Post result = postService.getById(1L);
-
// Assert
assertNotNull(result);
assertEquals(testPost.getPostId(), result.getPostId());
verify(postMapper).selectById(1L);
}
-
@Test
@DisplayName("获取帖子-通过ID不存在")
void getById_WhenNotExists_ShouldReturnNull() {
// Arrange
when(postMapper.selectById(999L)).thenReturn(null);
-
// Act
Post result = postService.getById(999L);
-
// Assert
assertNull(result);
verify(postMapper).selectById(999L);
}
-
@Test
@DisplayName("更新帖子-成功")
void updateById_ShouldSucceed() {
// Arrange
when(postMapper.updateById(any(Post.class))).thenReturn(1);
-
// Act
boolean result = postService.updateById(testPost);
-
// Assert
assertTrue(result);
verify(postMapper).updateById(testPost);
}
-
@Test
@DisplayName("删除帖子-成功")
void removeById_ShouldSucceed() {
// Arrange
when(postMapper.deleteById(1L)).thenReturn(1);
-
// Act
boolean result = postService.removeById(1L);
-
// Assert
assertTrue(result);
verify(postMapper).deleteById(1L);
}
-
@Test
@DisplayName("获取用户帖子列表")
void getPostsByUserId_ShouldReturnPosts() {
// Arrange
List<Post> expectedPosts = Arrays.asList(testPost);
when(postMapper.getPostsByUserId(1L)).thenReturn(expectedPosts);
-
// Act
List<Post> result = postService.getPostsByUserId(1L);
-
// Assert
assertNotNull(result);
assertFalse(result.isEmpty());
assertEquals(testPost.getPostId(), result.get(0).getPostId());
verify(postMapper).getPostsByUserId(1L);
}
-
@Test
@DisplayName("获取用户帖子-空列表")
void getPostsByUserId_WhenNoPosts_ShouldReturnEmptyList() {
// Arrange
when(postMapper.getPostsByUserId(999L)).thenReturn(Arrays.asList());
-
// Act
List<Post> result = postService.getPostsByUserId(999L);
-
// Assert
assertNotNull(result);
assertTrue(result.isEmpty());
verify(postMapper).getPostsByUserId(999L);
}
-
// 新增测试方法:搜索帖子,支持多个标签
@Test
@DisplayName("搜索帖子-通过关键词和多个标签")
@@ -167,20 +134,16 @@
String keyword = "测试内容";
String author = "作者";
List<Post> expectedPosts = Arrays.asList(testPost);
-
// 模拟PostMapper的searchPosts方法
when(postMapper.searchPosts(keyword, tagIds, author)).thenReturn(expectedPosts);
-
// Act
List<Post> result = postService.searchPosts(keyword, tagIds, author);
-
// Assert
assertNotNull(result);
assertFalse(result.isEmpty());
assertEquals(testPost.getPostId(), result.get(0).getPostId());
verify(postMapper).searchPosts(keyword, tagIds, author);
}
-
@Test
@DisplayName("搜索帖子-没有匹配的帖子")
void searchPosts_WhenNoPosts_ShouldReturnEmptyList() {
@@ -189,17 +152,13 @@
String keyword = "不存在的内容";
String author = "不存在的作者";
List<Post> expectedPosts = Arrays.asList(); // 没有匹配的帖子
-
// 模拟PostMapper的searchPosts方法
when(postMapper.searchPosts(keyword, tagIds, author)).thenReturn(expectedPosts);
-
// Act
List<Post> result = postService.searchPosts(keyword, tagIds, author);
-
// Assert
assertNotNull(result);
assertTrue(result.isEmpty());
verify(postMapper).searchPosts(keyword, tagIds, author);
}
-
}
\ No newline at end of file