blob: 23890badfa722e22648369170199ddab8ddec3f5 [file] [log] [blame]
package com.example.g8backend.service;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.example.g8backend.entity.Post;
import com.example.g8backend.entity.PostView;
import com.example.g8backend.mapper.CommentMapper;
import com.example.g8backend.mapper.PostMapper;
import com.example.g8backend.mapper.PostViewMapper;
import com.example.g8backend.service.impl.PostServiceImpl;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.transaction.annotation.Transactional;
import java.sql.Timestamp;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
@Transactional
public class PostServiceRecommendTest {
@Mock
private PostMapper postMapper;
@Mock
private PostViewMapper postViewMapper;
@Mock
private CommentMapper commentMapper;
@InjectMocks
private PostServiceImpl postService;
private Post mockPost;
private PostView mockPostView;
@BeforeEach
void setUp() {
// 初始化测试数据
mockPost = new Post()
.setPostId(1L)
.setUserId(100L)
.setPostTitle("Test Post")
.setViewCount(50)
.setCreatedAt(Timestamp.from(Instant.now().minusSeconds(7200)))
.setHotScore(5.0);
mockPostView = new PostView()
.setViewId(1L)
.setUserId(200L)
.setPostId(1L)
.setViewTime(Timestamp.from(Instant.now()).toLocalDateTime());
}
@Test
public void testGetRecommendedPosts_ExcludesViewedPosts() {
// 模拟用户已浏览的帖子ID
Long userId = 200L;
when(postViewMapper.findViewedPostIds(userId))
.thenReturn(Arrays.asList(1L, 2L));
// 模拟推荐结果(未浏览的帖子)
Post recommendedPost = new Post().setPostId(3L).setHotScore(8.0);
when(postMapper.selectPage(any(Page.class), any(QueryWrapper.class)))
.thenReturn(new Page<Post>().setRecords(Collections.singletonList(recommendedPost)));
// 调用推荐接口
Page<Post> result = postService.getRecommendedPosts(1, 10, userId);
// 验证结果
assertEquals(1, result.getRecords().size(), "应返回1条推荐结果");
assertEquals(3L, result.getRecords().get(0).getPostId(), "推荐结果应为未浏览的帖子ID 3");
assertFalse(result.getRecords().stream().anyMatch(p -> p.getPostId() == 1L), "结果中不应包含已浏览的帖子ID 1");
}
@Test
public void testGetRecommendedPosts_NoViewedPosts() {
// 模拟用户未浏览任何帖子
Long userId = 300L;
when(postViewMapper.findViewedPostIds(userId))
.thenReturn(Collections.emptyList());
// 模拟推荐结果(所有帖子按热度排序)
Post post1 = new Post().setPostId(1L).setHotScore(7.5);
Post post2 = new Post().setPostId(2L).setHotScore(9.0);
when(postMapper.selectPage(any(Page.class), any(QueryWrapper.class)))
.thenReturn(new Page<Post>().setRecords(Arrays.asList(post2, post1)));
// 调用推荐接口
Page<Post> result = postService.getRecommendedPosts(1, 10, userId);
// 验证结果
assertEquals(2, result.getRecords().size(), "应返回所有帖子");
assertEquals(2L, result.getRecords().get(0).getPostId(), "热度更高的帖子应排在前面");
}
@Test
public void testCalculateHotScores_UpdatesHotScoreCorrectly() {
// 设置存根
when(postMapper.selectList(any(QueryWrapper.class)))
.thenReturn(Collections.singletonList(mockPost));
when(postMapper.selectLikeCount(anyLong())).thenReturn(30L);
when(commentMapper.selectCountByPostId(anyLong())).thenReturn(20L);
when(postMapper.batchUpdateHotScore(anyList())).thenReturn(1);
// 执行并验证
postService.calculateHotScores();
double expectedScore = (Math.log(51) * 0.2 + 30 * 0.5 + 20 * 0.3) / Math.pow(4, 1.5);
assertEquals(expectedScore, mockPost.getHotScore(), 0.01);
verify(postMapper).batchUpdateHotScore(anyList());
}
//--------------------- 测试冷启动逻辑 ---------------------
@Test
public void testCreatePost_SetsInitialHotScore() {
Post newPost = new Post().setPostId(4L).setPostTitle("New Post");
postService.createPost(newPost);
assertEquals(5.0, newPost.getHotScore(), "新帖子的初始热度应为5.0");
assertNotNull(newPost.getCreatedAt(), "创建时间不应为空");
}
@Test
public void testConcurrentViewCountUpdate() {
// 设置存根
doNothing().when(postMapper).incrementViewCount(anyLong());
postService.recordViewHistory(100L, 1L);
postService.recordViewHistory(200L, 1L);
verify(postMapper, times(2)).incrementViewCount(1L);
verify(postViewMapper, times(2)).insert(any(PostView.class));
}
}