post_rating
Change-Id: Ia1a6fb3f87b793a6307046e36951c1fb36b213c8
diff --git a/pom.xml b/pom.xml
index d4d9fc2..9238ad6 100644
--- a/pom.xml
+++ b/pom.xml
@@ -40,7 +40,6 @@
<version>3.5.11</version>
</dependency>
-
<dependency>
<groupId>com.mysql</groupId>
<artifactId>mysql-connector-j</artifactId>
diff --git a/src/main/java/com/example/g8backend/controller/PostController.java b/src/main/java/com/example/g8backend/controller/PostController.java
index 9ac733f..e5e4eab 100644
--- a/src/main/java/com/example/g8backend/controller/PostController.java
+++ b/src/main/java/com/example/g8backend/controller/PostController.java
@@ -8,17 +8,20 @@
import com.example.g8backend.entity.Post;
import com.example.g8backend.entity.PostView;
import com.example.g8backend.mapper.PostViewMapper;
+import com.example.g8backend.service.IPostRatingService;
import com.example.g8backend.service.IPostService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import java.util.List;
@RestController
@RequestMapping("/post")
+@Validated
public class PostController {
@Autowired
@@ -149,4 +152,43 @@
Page<Post> result = postService.getRecommendedByTags(page, size, userId);
return ResponseEntity.ok(ApiResponse.success(result));
}
+ @Autowired
+ private IPostRatingService postRatingService;
+
+ /**
+ * Submits the authenticated user's rating (expected 1-5) for a post via the rating service.
+ * Invalid arguments yield 400 with the service's message; any other runtime failure yields
+ * 500 with a fixed message, so raw exception text (e.g. SQL or constraint details from the
+ * persistence layer) is never echoed to the client.
+ */
+ @PostMapping("/{postId}/rate")
+ public ResponseEntity<ApiResponse<String>> ratePost(
+ @PathVariable Long postId,
+ @RequestParam Integer rating) {
+ try {
+ // Assumes the principal is the numeric user id; any other type fails the cast (-> 500).
+ long userId = (long) SecurityContextHolder.getContext().getAuthentication().getPrincipal();
+ // Range validation lives in the service layer.
+ postRatingService.ratePost(userId, postId, rating);
+ return ResponseEntity.ok(ApiResponse.success("评分成功"));
+ } catch (IllegalArgumentException e) {
+ // Bad input, e.g. a rating outside the accepted range.
+ return ResponseEntity.badRequest().body(ApiResponse.error(400, e.getMessage()));
+ } catch (RuntimeException e) {
+ // Deliberately generic: do not leak internal exception messages.
+ return ResponseEntity.internalServerError().body(ApiResponse.error(500, "评分操作失败"));
+ }
+ }
+
+ @GetMapping("/{postId}/average-rating")
+ public ResponseEntity<ApiResponse<Double>> getAverageRating(@PathVariable Long postId) {
+ // Average of the post's ratings, exactly as reported by the rating service.
+ return ResponseEntity.ok(ApiResponse.success(postRatingService.getAverageRating(postId)));
+ }
+
+ @GetMapping("/{postId}/rating-users/count")
+ public ResponseEntity<ApiResponse<Long>> getRatingUserCount(@PathVariable Long postId) {
+ // Number of users who have rated the post, as reported by the rating service.
+ return ResponseEntity.ok(ApiResponse.success(postRatingService.getRatingUserCount(postId)));
+ }
}
diff --git a/src/main/java/com/example/g8backend/dto/ApiResponse.java b/src/main/java/com/example/g8backend/dto/ApiResponse.java
index fb6ed0e..8ad28f7 100644
--- a/src/main/java/com/example/g8backend/dto/ApiResponse.java
+++ b/src/main/java/com/example/g8backend/dto/ApiResponse.java
@@ -30,4 +30,5 @@
}
// Getters and Setters 略,也可使用 Lombok 注解
-}
\ No newline at end of file
+}
+
diff --git a/src/main/java/com/example/g8backend/entity/PostRating.java b/src/main/java/com/example/g8backend/entity/PostRating.java
new file mode 100644
index 0000000..ac592b2
--- /dev/null
+++ b/src/main/java/com/example/g8backend/entity/PostRating.java
@@ -0,0 +1,18 @@
+package com.example.g8backend.entity;
+
+import com.baomidou.mybatisplus.annotation.IdType;
+import com.baomidou.mybatisplus.annotation.TableId;
+import com.baomidou.mybatisplus.annotation.TableName;
+import lombok.Data;
+
+import java.time.LocalDateTime;
+
+@Data
+@TableName("post_ratings") // One row of post_ratings: a single user's rating of a single post.
+public class PostRating {
+ @TableId(type = IdType.INPUT) // NOTE(review): table PK is (user_id, post_id) but only userId is the @TableId, so MyBatis-Plus *ById calls presumably match on user_id alone — confirm.
+ private Long userId; // Rating user; IdType.INPUT means the application assigns it, nothing is generated.
+ private Long postId; // Rated post; second half of the table's composite primary key.
+ private Integer rating; // Score; the schema's CHECK constraint limits it to 1-5.
+ private LocalDateTime ratedAt; // Presumably maps to rated_at, which defaults to CURRENT_TIMESTAMP when not supplied.
+}
\ No newline at end of file
diff --git a/src/main/java/com/example/g8backend/mapper/PostMapper.java b/src/main/java/com/example/g8backend/mapper/PostMapper.java
index baebb17..ff13000 100644
--- a/src/main/java/com/example/g8backend/mapper/PostMapper.java
+++ b/src/main/java/com/example/g8backend/mapper/PostMapper.java
@@ -64,4 +64,11 @@
"</script>"
})
int batchUpdateHotScore(@Param("posts") List<Post> posts);
+
+ // Stores the given average rating and rating count on the post identified by postId.
+ @Update("UPDATE posts SET average_rating = #{averageRating}, " +
+ "rating_count = #{ratingCount} WHERE post_id = #{postId}")
+ void updateRatingStats(@Param("postId") Long postId,
+ @Param("averageRating") Double averageRating,
+ @Param("ratingCount") Integer ratingCount);
}
\ No newline at end of file
diff --git a/src/main/java/com/example/g8backend/mapper/PostRatingMapper.java b/src/main/java/com/example/g8backend/mapper/PostRatingMapper.java
new file mode 100644
index 0000000..edf2d26
--- /dev/null
+++ b/src/main/java/com/example/g8backend/mapper/PostRatingMapper.java
@@ -0,0 +1,31 @@
+package com.example.g8backend.mapper;
+
+import com.baomidou.mybatisplus.core.mapper.BaseMapper;
+import com.example.g8backend.entity.PostRating;
+import org.apache.ibatis.annotations.Param;
+import org.apache.ibatis.annotations.Select;
+
+/** Data access for post_ratings: rating upsert and per-post aggregates. */
+public interface PostRatingMapper extends BaseMapper<PostRating> {
+
+ /**
+ * Inserts the rating or, if this user already rated this post, overwrites it.
+ * Redeclared as abstract so MyBatis binds it to the {@code insertOrUpdate} statement in
+ * mapper/PostRatingMapper.xml: the inherited BaseMapper default looks up existing rows via the
+ * single {@code @TableId} column (user_id), which does not fit the (user_id, post_id) key.
+ */
+ @Override
+ boolean insertOrUpdate(PostRating entity);
+
+ /** Mean rating of a post; SQL AVG yields null when the post has no ratings. */
+ @Select("SELECT AVG(rating) FROM post_ratings WHERE post_id = #{postId}")
+ Double calculateAverageRating(@Param("postId") Long postId);
+
+ /** Number of rating rows for a post. */
+ @Select("SELECT COUNT(*) FROM post_ratings WHERE post_id = #{postId}")
+ Integer getRatingCount(@Param("postId") Long postId);
+
+ /** Number of distinct users who rated a post. */
+ @Select("SELECT COUNT(DISTINCT user_id) FROM post_ratings WHERE post_id = #{postId}")
+ Long selectRatingUserCount(@Param("postId") Long postId);
+}
\ No newline at end of file
diff --git a/src/main/java/com/example/g8backend/service/IPostRatingService.java b/src/main/java/com/example/g8backend/service/IPostRatingService.java
new file mode 100644
index 0000000..dc0a4f9
--- /dev/null
+++ b/src/main/java/com/example/g8backend/service/IPostRatingService.java
@@ -0,0 +1,34 @@
+package com.example.g8backend.service;
+
+import com.example.g8backend.dto.ApiResponse;
+
+/**
+ * Rating operations on posts: submitting a user's rating and reading per-post aggregates.
+ */
+public interface IPostRatingService {
+
+ /**
+ * Records {@code userId}'s rating of {@code postId}.
+ *
+ * @param userId id of the rating user
+ * @param postId id of the post being rated
+ * @param rating the score; implementations validate the accepted range
+ */
+ void ratePost(Long userId, Long postId, Integer rating);
+
+ /**
+ * Returns the mean of the post's ratings.
+ *
+ * @param postId id of the post
+ * @return the average rating
+ */
+ Double getAverageRating(Long postId);
+
+ /**
+ * Returns how many users have rated the post.
+ *
+ * @param postId id of the post
+ * @return the number of rating users
+ */
+ Long getRatingUserCount(Long postId);
+}
\ No newline at end of file
diff --git a/src/main/java/com/example/g8backend/service/impl/PostRatingServiceImpl.java b/src/main/java/com/example/g8backend/service/impl/PostRatingServiceImpl.java
new file mode 100644
index 0000000..9afb01d
--- /dev/null
+++ b/src/main/java/com/example/g8backend/service/impl/PostRatingServiceImpl.java
@@ -0,0 +1,80 @@
+package com.example.g8backend.service.impl;
+
+import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
+import com.example.g8backend.dto.ApiResponse;
+import com.example.g8backend.entity.Post;
+import com.example.g8backend.entity.PostRating;
+import com.example.g8backend.mapper.PostMapper;
+import com.example.g8backend.mapper.PostRatingMapper;
+import com.example.g8backend.service.IPostRatingService;
+import lombok.RequiredArgsConstructor;
+import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Transactional;
+
+/**
+ * Default {@link IPostRatingService}: writes ratings to post_ratings and keeps the
+ * average_rating / rating_count columns of posts in step with them.
+ */
+@Service
+@RequiredArgsConstructor
+public class PostRatingServiceImpl extends ServiceImpl<PostRatingMapper, PostRating> implements IPostRatingService {
+
+ /** Smallest accepted rating. */
+ private static final int MIN_RATING = 1;
+ /** Largest accepted rating. */
+ private static final int MAX_RATING = 5;
+
+ private final PostRatingMapper postRatingMapper;
+ private final PostMapper postMapper;
+
+ /**
+ * Stores the user's rating for a post and then refreshes the post's cached statistics,
+ * both inside one transaction.
+ *
+ * @param userId id of the rating user
+ * @param postId id of the rated post
+ * @param rating score in the range 1-5
+ * @throws IllegalArgumentException if {@code rating} is null or outside 1-5
+ * @throws RuntimeException if the mapper reports that the rating was not written
+ */
+ @Override
+ @Transactional
+ public void ratePost(Long userId, Long postId, Integer rating) {
+ // Reject null explicitly; otherwise the range check would fail with an unboxing NPE.
+ if (rating == null || rating < MIN_RATING || rating > MAX_RATING) {
+ throw new IllegalArgumentException("评分值必须在1到5之间");
+ }
+
+ PostRating postRating = new PostRating();
+ postRating.setUserId(userId);
+ postRating.setPostId(postId);
+ postRating.setRating(rating);
+ if (!postRatingMapper.insertOrUpdate(postRating)) {
+ throw new RuntimeException("评分操作失败");
+ }
+
+ refreshRatingStats(postId);
+ }
+
+ /**
+ * Returns the post's average rating as computed by SQL AVG, which is null when the post
+ * has no ratings.
+ */
+ @Override
+ public Double getAverageRating(Long postId) {
+ return postRatingMapper.calculateAverageRating(postId);
+ }
+
+ /** Returns how many distinct users have rated the post. */
+ @Override
+ public Long getRatingUserCount(Long postId) {
+ return postRatingMapper.selectRatingUserCount(postId);
+ }
+
+ /** Recomputes the post's average and count from post_ratings and writes them onto posts. */
+ private void refreshRatingStats(Long postId) {
+ Double avgRating = postRatingMapper.calculateAverageRating(postId);
+ Integer count = postRatingMapper.getRatingCount(postId);
+ postMapper.updateRatingStats(postId, avgRating, count);
+ }
+}
\ No newline at end of file
diff --git a/src/main/resources/mapper/PostRatingMapper.xml b/src/main/resources/mapper/PostRatingMapper.xml
new file mode 100644
index 0000000..097eeb8
--- /dev/null
+++ b/src/main/resources/mapper/PostRatingMapper.xml
@@ -0,0 +1,11 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "https://mybatis.org/dtd/mybatis-3-mapper.dtd">
+<mapper namespace="com.example.g8backend.mapper.PostRatingMapper">
+ <!-- Upsert one user's rating of one post: the (user_id, post_id) primary key turns a
+ repeat rating into an update of the existing row, whose timestamp is refreshed. -->
+ <insert id="insertOrUpdate" parameterType="com.example.g8backend.entity.PostRating">
+ INSERT INTO post_ratings (user_id, post_id, rating)
+ VALUES (#{userId}, #{postId}, #{rating})
+ ON DUPLICATE KEY UPDATE rating = VALUES(rating), rated_at = CURRENT_TIMESTAMP
+ </insert>
+</mapper>
\ No newline at end of file
diff --git a/src/main/resources/schema.sql b/src/main/resources/schema.sql
index 09276b4..f6d061b 100644
--- a/src/main/resources/schema.sql
+++ b/src/main/resources/schema.sql
@@ -51,6 +51,8 @@
`post_type` ENUM('resource', 'discussion') NOT NULL,
`created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
`last_calculated` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '最后热度计算时间',
+ `average_rating` DECIMAL(3,2) DEFAULT 0.00 COMMENT '帖子平均评分', -- cached mean of post_ratings.rating
+ `rating_count` INT DEFAULT 0 COMMENT '总评分人数', -- cached number of post_ratings rows
FOREIGN KEY (`user_id`) REFERENCES `users`(`user_id`),
FOREIGN KEY (`torrent_id`) REFERENCES `torrents`(`torrent_id`),
INDEX `idx_hot_score` (`hot_score`), -- 新增热度索引
@@ -132,3 +134,13 @@
FOREIGN KEY (user_id) REFERENCES users(user_id),
FOREIGN KEY (tag_id) REFERENCES tags(tag_id)
);
+CREATE TABLE IF NOT EXISTS `post_ratings` (
+ `user_id` INT NOT NULL COMMENT '用户ID',
+ `post_id` INT NOT NULL COMMENT '帖子ID',
+ `rating` TINYINT NOT NULL CHECK (`rating` BETWEEN 1 AND 5) COMMENT '评分值(1-5)',
+ `rated_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '评分时间',
+ PRIMARY KEY (`user_id`, `post_id`), -- one rating per user per post; re-rating updates this row
+ FOREIGN KEY (`user_id`) REFERENCES `users`(`user_id`),
+ FOREIGN KEY (`post_id`) REFERENCES `posts`(`post_id`),
+ INDEX `idx_post_ratings_post_id` (`post_id`) -- serves the per-post AVG/COUNT queries
+);
\ No newline at end of file
diff --git a/src/test/java/com/example/g8backend/service/PostRatingServiceImplTest.java b/src/test/java/com/example/g8backend/service/PostRatingServiceImplTest.java
new file mode 100644
index 0000000..942080a
--- /dev/null
+++ b/src/test/java/com/example/g8backend/service/PostRatingServiceImplTest.java
@@ -0,0 +1,134 @@
+package com.example.g8backend.service;
+
+import com.example.g8backend.entity.PostRating;
+import com.example.g8backend.mapper.PostMapper;
+import com.example.g8backend.mapper.PostRatingMapper;
+import com.example.g8backend.service.impl.PostRatingServiceImpl;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.InOrder;
+import org.mockito.InjectMocks;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.springframework.transaction.annotation.Transactional;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.ArgumentMatchers.*;
+import static org.mockito.Mockito.*;
+
+/**
+ * Unit tests for {@link PostRatingServiceImpl} with both mappers mocked; no Spring context
+ * or database is involved, so Spring's {@link Transactional} would have no effect here.
+ */
+@ExtendWith(MockitoExtension.class)
+public class PostRatingServiceImplTest {
+
+ @Mock
+ private PostRatingMapper postRatingMapper;
+
+ @Mock
+ private PostMapper postMapper;
+
+ @InjectMocks
+ private PostRatingServiceImpl postRatingService;
+
+ private final Long userId = 1L;
+ private final Long postId = 100L;
+ private final Integer validRating = 4;
+ private final Integer invalidRating = 6;
+
+ // A valid rating is written with the submitted values and the post's stats are refreshed.
+ @Test
+ public void testRatePost_Success() {
+ when(postRatingMapper.insertOrUpdate(any(PostRating.class))).thenReturn(true);
+ when(postRatingMapper.calculateAverageRating(postId)).thenReturn(4.0);
+ when(postRatingMapper.getRatingCount(postId)).thenReturn(1);
+
+ assertDoesNotThrow(() -> postRatingService.ratePost(userId, postId, validRating));
+
+ ArgumentCaptor<PostRating> saved = ArgumentCaptor.forClass(PostRating.class);
+ verify(postRatingMapper).insertOrUpdate(saved.capture());
+ assertEquals(userId, saved.getValue().getUserId());
+ assertEquals(postId, saved.getValue().getPostId());
+ assertEquals(validRating, saved.getValue().getRating());
+ verify(postMapper).updateRatingStats(eq(postId), eq(4.0), eq(1));
+ }
+
+ // An out-of-range rating is rejected before any database access.
+ @Test
+ public void testRatePost_InvalidRating() {
+ IllegalArgumentException exception = assertThrows(
+ IllegalArgumentException.class,
+ () -> postRatingService.ratePost(userId, postId, invalidRating)
+ );
+ assertEquals("评分值必须在1到5之间", exception.getMessage());
+ verifyNoInteractions(postRatingMapper, postMapper);
+ }
+
+ // Just below the range is rejected too.
+ @Test
+ public void testRatePost_ZeroRating() {
+ assertThrows(IllegalArgumentException.class, () -> postRatingService.ratePost(userId, postId, 0));
+ verifyNoInteractions(postRatingMapper, postMapper);
+ }
+
+ // Both ends of the 1-5 range are accepted.
+ @Test
+ public void testRatePost_BoundaryRatings() {
+ when(postRatingMapper.insertOrUpdate(any(PostRating.class))).thenReturn(true);
+ when(postRatingMapper.calculateAverageRating(postId)).thenReturn(3.0);
+ when(postRatingMapper.getRatingCount(postId)).thenReturn(1);
+
+ assertDoesNotThrow(() -> postRatingService.ratePost(userId, postId, 1));
+ assertDoesNotThrow(() -> postRatingService.ratePost(userId, postId, 5));
+
+ verify(postMapper, times(2)).updateRatingStats(postId, 3.0, 1);
+ }
+
+ // Re-rating by the same user refreshes the stats each time, with the latest average last.
+ @Test
+ public void testRatePost_UpdateExistingRating() {
+ when(postRatingMapper.insertOrUpdate(any(PostRating.class))).thenReturn(true);
+ // The first call sees the average after rating 3, the second after re-rating to 4.
+ when(postRatingMapper.calculateAverageRating(postId)).thenReturn(3.0, 4.0);
+ when(postRatingMapper.getRatingCount(postId)).thenReturn(1);
+
+ assertDoesNotThrow(() -> {
+ postRatingService.ratePost(userId, postId, 3);
+ postRatingService.ratePost(userId, postId, 4);
+ });
+
+ InOrder inOrder = inOrder(postMapper);
+ inOrder.verify(postMapper).updateRatingStats(postId, 3.0, 1);
+ inOrder.verify(postMapper).updateRatingStats(postId, 4.0, 1);
+ }
+
+ // A failed write raises an exception and leaves the post's statistics untouched.
+ @Test
+ public void testRatePost_DatabaseFailure() {
+ when(postRatingMapper.insertOrUpdate(any(PostRating.class))).thenReturn(false);
+ RuntimeException exception = assertThrows(
+ RuntimeException.class,
+ () -> postRatingService.ratePost(userId, postId, validRating)
+ );
+ assertEquals("评分操作失败", exception.getMessage());
+ verifyNoInteractions(postMapper);
+ }
+
+ // The average rating is passed through from the mapper.
+ @Test
+ public void testGetAverageRating() {
+ when(postRatingMapper.calculateAverageRating(postId)).thenReturn(4.5);
+ assertEquals(4.5, postRatingService.getAverageRating(postId));
+ }
+
+ // The rating-user count is passed through from the mapper.
+ @Test
+ public void testGetRatingUserCount() {
+ when(postRatingMapper.selectRatingUserCount(postId)).thenReturn(5L);
+ Long count = postRatingService.getRatingUserCount(postId);
+ assertEquals(5L, count);
+ }
+}
\ No newline at end of file