package com.github.example.pt.websocket;

import cn.dev33.satoken.stp.StpUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.example.pt.entity.User;
import com.github.example.pt.service.UserService;
import com.github.example.pt.controller.chat.dto.ChatMessageDTO;
import com.github.example.pt.repository.ChatMessageRepository;
import com.github.example.pt.entity.ChatMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.*;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.io.IOException;
import java.time.LocalDateTime;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Raw (non-STOMP) WebSocket handler for the chat room.
 *
 * <p>Every connection must carry a Sa-Token value in the {@code "sapling-token"} session
 * attribute (presumably put there by a handshake interceptor — not visible here; confirm).
 * Each inbound text frame is parsed as a {@link ChatMessageDTO}, persisted as a
 * {@link ChatMessage}, and broadcast to every open connection.
 */
@Component
public class ChatRawWebSocketHandler extends TextWebSocketHandler {

    /** Session-attribute key holding the Sa-Token value. */
    private static final String TOKEN_ATTRIBUTE = "sapling-token";

    /** Room used when the client omits {@code roomId}. */
    private static final long DEFAULT_ROOM_ID = 1L;

    /**
     * Open, authenticated connections keyed by {@link WebSocketSession#getId()}. Keying by
     * session rather than user id lets a user with several tabs open receive broadcasts on all
     * of them, and makes removal a direct lookup.
     */
    private static final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<>();

    @Autowired
    private UserService userService;
    @Autowired
    private ChatMessageRepository chatMessageRepository;
    private final ObjectMapper objectMapper = new ObjectMapper();

    /**
     * Authenticates a new connection and registers it for broadcasts.
     * Connections that fail authentication are closed and not registered.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        if (authenticate(session) != null) {
            sessions.put(session.getId(), session);
        }
    }

    /**
     * Re-authenticates the sender, persists the chat message and broadcasts it.
     *
     * <p>The token is checked on every frame, so a logout or expiry takes effect without waiting
     * for the client to reconnect. Frames that are not valid JSON close the session with
     * {@link CloseStatus#BAD_DATA}. Frames with missing or blank content are ignored.
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws IOException {
        Long userId = authenticate(session);
        if (userId == null) {
            return;
        }
        User user = userService.getUser(userId);
        if (user == null) {
            session.close(CloseStatus.NOT_ACCEPTABLE.withReason("User not found"));
            return;
        }

        ChatMessageDTO dto;
        try {
            dto = objectMapper.readValue(message.getPayload(), ChatMessageDTO.class);
        } catch (IOException e) {
            // Jackson's parse/mapping exceptions are IOExceptions. Reject the frame explicitly
            // instead of letting it escape the handler as an unexpected error.
            session.close(CloseStatus.BAD_DATA.withReason("Malformed message"));
            return;
        }
        // A literal JSON "null" payload deserialises to null; empty messages are not stored.
        if (dto == null || dto.getContent() == null || dto.getContent().trim().isEmpty()) {
            return;
        }

        // Build and persist the message entity.
        ChatMessage chatMessage = new ChatMessage();
        chatMessage.setRoomId(dto.getRoomId() != null ? dto.getRoomId() : DEFAULT_ROOM_ID);
        chatMessage.setUserId(userId);
        chatMessage.setContent(dto.getContent());
        chatMessage.setCreatedAt(LocalDateTime.now());
        chatMessageRepository.save(chatMessage);

        // Build the outbound DTO enriched with the sender's public profile.
        ChatMessageDTO broadcast = new ChatMessageDTO();
        broadcast.setRoomId(chatMessage.getRoomId());
        broadcast.setContent(chatMessage.getContent());
        broadcast.setUsername(user.getUsername());
        broadcast.setAvatar(user.getAvatar());
        broadcast.setTimestamp(System.currentTimeMillis());
        // Serialise once and reuse the same frame for every recipient.
        TextMessage outbound = new TextMessage(objectMapper.writeValueAsString(broadcast));

        // NOTE(review): this reaches every connection regardless of roomId; presumably clients
        // filter on roomId — confirm whether rooms are meant to be isolated server-side.
        for (WebSocketSession target : sessions.values()) {
            deliver(target, outbound);
        }
    }

    /** Unregisters a closed connection. */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
        sessions.remove(session.getId());
    }

    /**
     * Resolves the user id for {@code session} from its Sa-Token attribute.
     *
     * @param session the connection to authenticate
     * @return the authenticated user id, or {@code null} if the token is missing or invalid. In
     *     that case the session has already been closed with {@link CloseStatus#NOT_ACCEPTABLE}.
     * @throws IOException if closing the rejected session fails
     */
    private static Long authenticate(WebSocketSession session) throws IOException {
        Object token = session.getAttributes().get(TOKEN_ATTRIBUTE);
        if (token == null) {
            session.close(CloseStatus.NOT_ACCEPTABLE.withReason("No token"));
            return null;
        }
        Long userId = token instanceof String ? resolveUserId((String) token) : null;
        if (userId == null) {
            session.close(CloseStatus.NOT_ACCEPTABLE.withReason("Invalid token"));
        }
        return userId;
    }

    /**
     * Maps a Sa-Token value to a numeric login id.
     *
     * @param token the raw token value
     * @return the login id, or {@code null} when Sa-Token yields no login id for the token or
     *     the login id is not numeric
     */
    private static Long resolveUserId(String token) {
        try {
            // Sa-Token returns null (rather than throwing) when the token has no active login.
            Object loginId = StpUtil.getLoginIdByToken(token);
            return loginId == null ? null : Long.valueOf(loginId.toString());
        } catch (RuntimeException e) {
            // Sa-Token rejection or NumberFormatException: treat as unauthenticated.
            return null;
        }
    }

    /**
     * Sends {@code frame} to {@code target}, and drops the target from the registry if the
     * write fails.
     *
     * <p>Writes are serialised per session because {@link WebSocketSession#sendMessage} does not
     * permit concurrent sends, and {@code handleTextMessage} may run for several senders at
     * once. Failures are handled here so one broken peer cannot abort the broadcast for the rest.
     *
     * @param target the recipient connection
     * @param frame the already-serialised message
     */
    private static void deliver(WebSocketSession target, TextMessage frame) {
        synchronized (target) {
            if (!target.isOpen()) {
                return;
            }
            try {
                target.sendMessage(frame);
            } catch (IOException e) {
                // Peer is unreachable; stop broadcasting to it.
                sessions.remove(target.getId(), target);
            }
        }
    }
}