package com.pt.utils;

import java.io.*;
import java.net.InetAddress;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * Minimal bencode encoder/decoder plus helpers for building BitTorrent tracker responses.
 *
 * <p>Type mapping used by this class:
 * <ul>
 *   <li>{@code String} / {@code byte[]} &lt;-&gt; bencode byte string (strings are UTF-8)</li>
 *   <li>{@code Number} &lt;-&gt; bencode integer (encoded via {@code longValue()}, decoded as {@code Long})</li>
 *   <li>{@code List} &lt;-&gt; bencode list</li>
 *   <li>{@code Map} with {@code String} keys &lt;-&gt; bencode dictionary</li>
 * </ul>
 *
 * <p>NOTE: decoding always turns byte strings into UTF-8 {@code String}s, which is lossy for
 * binary payloads that are not valid UTF-8; such values do not round-trip through
 * {@link #decode(byte[])}.
 */
public class BencodeCodec {

    /** Orders byte arrays lexicographically with bytes treated as unsigned (raw byte-string order). */
    private static final Comparator<byte[]> RAW_BYTE_ORDER = new Comparator<byte[]>() {
        @Override
        public int compare(byte[] a, byte[] b) {
            int common = Math.min(a.length, b.length);
            for (int i = 0; i < common; i++) {
                int diff = (a[i] & 0xFF) - (b[i] & 0xFF);
                if (diff != 0) {
                    return diff;
                }
            }
            return a.length - b.length;
        }
    };

    /* ------------- Encoding ------------- */

    /**
     * Writes the bencoded form of {@code obj} to {@code out}.
     *
     * @param obj a {@code String}, {@code Number}, {@code byte[]}, {@code List} or {@code Map}
     *            (with {@code String} keys); lists and maps are encoded recursively
     * @param out destination stream; it is not closed by this method
     * @throws IOException              if writing to {@code out} fails
     * @throws IllegalArgumentException if {@code obj} (or a nested value) is {@code null}, has an
     *                                  unsupported type, or a map key is not a {@code String}
     */
    public static void encode(Object obj, OutputStream out) throws IOException {
        if (obj == null) {
            // Bencode has no null representation; fail with a clear message instead of an NPE.
            throw new IllegalArgumentException("Cannot bencode null");
        }
        if (obj instanceof String) {
            encodeString((String) obj, out);
        } else if (obj instanceof Number) {
            // Non-integral numbers are truncated by longValue(); bencode only has integers.
            encodeInteger(((Number) obj).longValue(), out);
        } else if (obj instanceof byte[]) {
            encodeBytes((byte[]) obj, out);
        } else if (obj instanceof List) {
            encodeList((List<?>) obj, out);
        } else if (obj instanceof Map) {
            encodeMap((Map<?, ?>) obj, out);
        } else {
            throw new IllegalArgumentException("Unsupported type: " + obj.getClass());
        }
    }

    /**
     * Bencodes {@code obj} into a new byte array.
     *
     * @param obj value to encode; see {@link #encode(Object, OutputStream)} for supported types
     * @return the bencoded bytes
     * @throws IllegalArgumentException if {@code obj} cannot be encoded
     */
    public static byte[] encode(Object obj) {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
            encode(obj, baos);
            return baos.toByteArray();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Encodes a string as a byte string of its UTF-8 bytes. */
    private static void encodeString(String s, OutputStream out) throws IOException {
        encodeBytes(s.getBytes(StandardCharsets.UTF_8), out);
    }

    /** Writes {@code <length>:<bytes>}. */
    private static void encodeBytes(byte[] bytes, OutputStream out) throws IOException {
        out.write(String.valueOf(bytes.length).getBytes(StandardCharsets.US_ASCII));
        out.write(':');
        out.write(bytes);
    }

    /** Writes {@code i<value>e}. */
    private static void encodeInteger(long value, OutputStream out) throws IOException {
        out.write('i');
        out.write(Long.toString(value).getBytes(StandardCharsets.US_ASCII));
        out.write('e');
    }

    /** Writes {@code l<items>e}, encoding each item in list order. */
    private static void encodeList(List<?> list, OutputStream out) throws IOException {
        out.write('l');
        for (Object item : list) {
            encode(item, out);
        }
        out.write('e');
    }

    /**
     * Writes {@code d<key><value>...e} with keys sorted as raw byte strings.
     *
     * <p>Sorting happens on the UTF-8 bytes (unsigned) rather than on the Java strings, because
     * {@code String} ordering compares UTF-16 code units and disagrees with byte order for
     * characters outside the Basic Multilingual Plane.
     */
    private static void encodeMap(Map<?, ?> map, OutputStream out) throws IOException {
        TreeMap<byte[], Object> sorted = new TreeMap<>(RAW_BYTE_ORDER);
        for (Map.Entry<?, ?> entry : map.entrySet()) {
            Object key = entry.getKey();
            if (!(key instanceof String)) {
                throw new IllegalArgumentException("Dictionary keys must be strings, got: "
                        + (key == null ? "null" : key.getClass().toString()));
            }
            sorted.put(((String) key).getBytes(StandardCharsets.UTF_8), entry.getValue());
        }

        out.write('d');
        for (Map.Entry<byte[], Object> entry : sorted.entrySet()) {
            encodeBytes(entry.getKey(), out);
            encode(entry.getValue(), out);
        }
        out.write('e');
    }

    /* ------------- Decoding ------------- */

    /**
     * Decodes the first bencoded value found in {@code data}.
     *
     * <p>Bytes remaining after the first complete value are ignored.
     *
     * @param data bencoded input
     * @return a {@code String}, {@code Long}, {@code List<Object>} or {@code Map<String, Object>}
     *         (dictionary entries keep their order of appearance)
     * @throws IOException if the input is truncated or malformed
     */
    public static Object decode(byte[] data) throws IOException {
        try (ByteArrayInputStream in = new ByteArrayInputStream(data)) {
            return decodeNext(in);
        }
    }

    /** Reads the next type prefix from {@code in} and decodes the value it introduces. */
    private static Object decodeNext(ByteArrayInputStream in) throws IOException {
        int prefix = in.read();
        if (prefix == -1) {
            throw new IOException("Unexpected end of stream");
        }
        return decodeValue(prefix, in);
    }

    /**
     * Decodes a value whose first byte ({@code prefix}) has already been consumed from {@code in}.
     * Passing the prefix along avoids any need to rewind the stream.
     */
    private static Object decodeValue(int prefix, ByteArrayInputStream in) throws IOException {
        if (prefix >= '0' && prefix <= '9') {
            return parseString(prefix, in);
        } else if (prefix == 'i') {
            return parseInteger(in);
        } else if (prefix == 'l') {
            return parseList(in);
        } else if (prefix == 'd') {
            return parseDict(in);
        } else {
            throw new IOException("Invalid bencode prefix: " + (char) prefix);
        }
    }

    /**
     * Parses a byte string; {@code firstDigit} is the already-consumed first digit of its length.
     * The payload is interpreted as UTF-8.
     */
    private static String parseString(int firstDigit, ByteArrayInputStream in) throws IOException {
        StringBuilder lenStr = new StringBuilder();
        lenStr.append((char) firstDigit);
        int b;
        while ((b = in.read()) != -1 && b != ':') {
            if (b < '0' || b > '9') {
                throw new IOException("Invalid string length character: " + (char) b);
            }
            lenStr.append((char) b);
        }
        if (b == -1) {
            throw new IOException("Unexpected end of stream reading string length");
        }

        int length;
        try {
            length = Integer.parseInt(lenStr.toString());
        } catch (NumberFormatException e) {
            throw new IOException("Invalid string length: " + lenStr, e);
        }
        // Check against the remaining input before allocating, so a bogus length in untrusted
        // data cannot force a huge allocation. available() is exact for ByteArrayInputStream.
        if (length > in.available()) {
            throw new IOException("Unexpected end of stream reading string data");
        }

        byte[] buf = new byte[length];
        int offset = 0;
        while (offset < length) {
            int read = in.read(buf, offset, length - offset);
            if (read == -1) {
                throw new IOException("Unexpected end of stream reading string data");
            }
            offset += read;
        }

        return new String(buf, StandardCharsets.UTF_8);
    }

    /** Parses the digits of an integer up to the closing {@code 'e'}; the {@code 'i'} is already consumed. */
    private static long parseInteger(ByteArrayInputStream in) throws IOException {
        StringBuilder intStr = new StringBuilder();
        int b;
        while ((b = in.read()) != -1 && b != 'e') {
            intStr.append((char) b);
        }
        if (b == -1) {
            throw new IOException("Unexpected end of stream reading integer");
        }
        try {
            return Long.parseLong(intStr.toString());
        } catch (NumberFormatException e) {
            throw new IOException("Invalid integer: " + intStr, e);
        }
    }

    /** Parses list items up to the closing {@code 'e'}; the {@code 'l'} is already consumed. */
    private static List<Object> parseList(ByteArrayInputStream in) throws IOException {
        List<Object> list = new ArrayList<>();
        while (true) {
            int b = in.read();
            if (b == -1) {
                throw new IOException("Unexpected end of stream reading list");
            }
            if (b == 'e') {
                break;
            }
            list.add(decodeValue(b, in));
        }
        return list;
    }

    /** Parses key/value pairs up to the closing {@code 'e'}; the {@code 'd'} is already consumed. */
    private static Map<String, Object> parseDict(ByteArrayInputStream in) throws IOException {
        Map<String, Object> map = new LinkedHashMap<>();
        while (true) {
            int b = in.read();
            if (b == -1) {
                throw new IOException("Unexpected end of stream reading dictionary");
            }
            if (b == 'e') {
                break;
            }
            if (b < '0' || b > '9') {
                // Dictionary keys must be byte strings, which always start with a length digit.
                throw new IOException("Dictionary key is not a string, found prefix: " + (char) b);
            }
            String key = parseString(b, in);
            Object value = decodeNext(in);
            map.put(key, value);
        }
        return map;
    }

    /* ------------- Other helpers ------------- */

    /**
     * Builds the compact binary form of a single peer: 4-byte IPv4 address followed by a
     * 2-byte big-endian port.
     *
     * @param ip   IPv4 literal or a host name (host names are resolved via
     *             {@link InetAddress#getByName(String)})
     * @param port port in the range 0-65535
     * @return a 6-byte array
     * @throws IllegalArgumentException if the address is not IPv4 or the port is out of range
     * @throws RuntimeException         if the host cannot be resolved
     */
    public static byte[] buildCompactPeer(String ip, int port) {
        if (port < 0 || port > 0xFFFF) {
            // The (short) cast below would otherwise silently wrap the value.
            throw new IllegalArgumentException("Port out of range: " + port);
        }
        try {
            byte[] address = InetAddress.getByName(ip).getAddress();
            if (address.length != 4) {
                // The 6-byte compact format has room for IPv4 addresses only.
                throw new IllegalArgumentException("Compact peer format requires an IPv4 address: " + ip);
            }
            ByteBuffer buffer = ByteBuffer.allocate(6);
            buffer.put(address);
            buffer.putShort((short) port);
            return buffer.array();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the concatenated compact form of several peers.
     *
     * @param ips   peer addresses; element {@code i} pairs with {@code ports.get(i)}
     * @param ports peer ports
     * @return {@code 6 * ips.size()} bytes
     * @throws IllegalArgumentException if the lists differ in size or any peer is invalid
     */
    public static byte[] buildCompactPeers(List<String> ips, List<Integer> ports) {
        if (ips.size() != ports.size()) throw new IllegalArgumentException("IPs and ports list size mismatch");
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        for (int i = 0; i < ips.size(); i++) {
            out.write(buildCompactPeer(ips.get(i), ports.get(i)), 0, 6);
        }
        return out.toByteArray();
    }

    /**
     * Builds a bencoded tracker response dictionary containing {@code interval} and {@code peers}.
     *
     * @param interval     value stored under the {@code "interval"} key
     * @param peersCompact compact peer bytes stored under the {@code "peers"} key
     * @return the bencoded dictionary
     */
    public static byte[] buildTrackerResponse(int interval, byte[] peersCompact) {
        Map<String, Object> dict = new LinkedHashMap<>();
        dict.put("interval", interval);
        dict.put("peers", peersCompact);
        return encode(dict);
    }
}
