package com.pt.utils;

import java.io.*;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Rewrites the tracker passkey in a .torrent file's {@code announce} and {@code announce-list}
 * URLs while copying the {@code info} dictionary byte-for-byte, so the torrent's info hash does
 * not change.
 *
 * <p>Bencoded strings outside {@code info} are mapped to Java strings with ISO-8859-1. That charset
 * maps every byte to exactly one char and back, so fields that are not UTF-8 (for example
 * GBK-encoded comments) survive re-encoding unchanged. Instances hold no mutable state.
 */
public class TorrentPasskeyModifier {

    /** Lossless byte <-> char mapping used for every bencoded string outside {@code info}. */
    private static final java.nio.charset.Charset RAW = StandardCharsets.ISO_8859_1;

    /** Matches the value of a query parameter named exactly {@code passkey}. */
    private static final Pattern PASSKEY_VALUE = Pattern.compile("(?<=[?&]passkey=)[^&]*");

    /** Read chunk size, so a bogus huge length prefix cannot force a huge up-front allocation. */
    private static final int CHUNK = 8192;

    /** Already-bencoded bytes that {@link #encode} writes verbatim (used for the info value). */
    private static final class RawValue {
        final byte[] bytes;

        RawValue(byte[] bytes) {
            this.bytes = bytes;
        }
    }

    /**
     * Returns a re-encoded copy of a torrent whose tracker URLs carry {@code username} as passkey.
     *
     * @param fileBytes bencoded .torrent contents; the top-level value must be a dictionary
     * @param username  passkey value to set; it is URL-encoded (UTF-8) before insertion
     * @return the re-encoded torrent, with the {@code info} value copied verbatim
     * @throws IOException if {@code fileBytes} is not a well-formed bencoded dictionary
     */
    public byte[] analyzeTorrentFile(byte[] fileBytes, String username) throws IOException {
        Objects.requireNonNull(fileBytes, "fileBytes");
        Objects.requireNonNull(username, "username");
        // Encode once: keeps '&', '#', '$', spaces etc. from breaking the URL or the regex replacement.
        String passkey = urlEncode(username);

        Map<String, Object> torrentMap = decodeWithInfoPreservation(new ByteArrayInputStream(fileBytes));

        // Rewrite announce
        Object announce = torrentMap.get("announce");
        if (announce instanceof String) {
            torrentMap.put("announce", replacePasskeyInUrl((String) announce, passkey));
        }

        // Rewrite announce-list (list of tiers, each a list of URLs)
        Object announceList = torrentMap.get("announce-list");
        if (announceList instanceof List) {
            replacePasskeyInAnnounceList((List<?>) announceList, passkey);
        }

        // Re-encode; the info entry is a RawValue and is written back unchanged.
        ByteArrayOutputStream out = new ByteArrayOutputStream(fileBytes.length + 64);
        encode(torrentMap, out);
        return out.toByteArray();
    }

    /** URL-encodes {@code value} as UTF-8 (every JVM is required to support UTF-8). */
    private static String urlEncode(String value) {
        try {
            return URLEncoder.encode(value, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException("UTF-8 not supported", e);
        }
    }

    /**
     * Replaces the value of every {@code passkey} query parameter in {@code url}, or appends one
     * if none exists.
     *
     * @param encodedPasskey an already URL-encoded passkey
     */
    private String replacePasskeyInUrl(String url, String encodedPasskey) {
        if (url == null) return null;
        Matcher m = PASSKEY_VALUE.matcher(url);
        if (m.find()) {
            // quoteReplacement: the passkey must never be interpreted as "$1" or "\" syntax.
            return m.replaceAll(Matcher.quoteReplacement(encodedPasskey));
        }
        return url + (url.indexOf('?') >= 0 ? "&" : "?") + "passkey=" + encodedPasskey;
    }

    /** Rewrites, in place, every string URL inside every list-typed tier of {@code list}. */
    private void replacePasskeyInAnnounceList(List<?> list, String encodedPasskey) {
        for (Object tierObj : list) {
            if (tierObj instanceof List) {
                // Safe: tiers come from decodeList, which returns a mutable List<Object>.
                @SuppressWarnings("unchecked")
                List<Object> tier = (List<Object>) tierObj;
                for (int i = 0; i < tier.size(); i++) {
                    Object url = tier.get(i);
                    if (url instanceof String) {
                        tier.set(i, replacePasskeyInUrl((String) url, encodedPasskey));
                    }
                }
            }
        }
    }

    // --- Bencode decoding, keeping the raw bytes of the info dictionary ---

    /**
     * Decodes the top-level dictionary. The {@code info} value is stored as a {@link RawValue}
     * holding its exact original bytes; all other values are decoded normally.
     */
    private Map<String, Object> decodeWithInfoPreservation(InputStream in) throws IOException {
        if (in.read() != 'd') throw new IOException("Not a bencode dict");
        Map<String, Object> map = new LinkedHashMap<>();
        for (int c = readByte(in); c != 'e'; c = readByte(in)) {
            String key = decodeString(in, c);
            if ("info".equals(key)) {
                map.put(key, new RawValue(copyInfoDict(in)));
            } else {
                map.put(key, decode(in));
            }
        }
        return map;
    }

    /**
     * Copies the raw bytes of the dictionary that comes next in {@code in}, checking its structure
     * without decoding it. Iterative (depth counter), so deep nesting cannot overflow the stack.
     */
    private byte[] copyInfoDict(InputStream in) throws IOException {
        ByteArrayOutputStream raw = new ByteArrayOutputStream();
        int b = readByte(in);
        if (b != 'd') throw new IOException("Invalid info dict");
        raw.write(b);
        int depth = 1;
        while (depth > 0) {
            b = readByte(in);
            raw.write(b);
            if (b == 'd' || b == 'l') {
                depth++;
            } else if (b == 'e') {
                depth--;
            } else if (b == 'i') {
                do {
                    b = readByte(in);
                    raw.write(b);
                } while (b != 'e');
            } else if (isDigit(b)) {
                int len = readStringLength(in, b, raw); // mirrors the remaining prefix bytes and ':'
                raw.write(readBytes(in, len));
            } else {
                throw new IOException("Invalid bencode token in info: " + (char) b);
            }
        }
        return raw.toByteArray();
    }

    /** Decodes the next bencoded value in {@code in}. */
    private Object decode(InputStream in) throws IOException {
        return decodeValue(in, readByte(in));
    }

    /** Decodes a value whose first byte, {@code first}, has already been consumed. */
    private Object decodeValue(InputStream in, int first) throws IOException {
        if (first == 'd') return decodeDict(in);
        if (first == 'l') return decodeList(in);
        if (first == 'i') return decodeInt(in);
        if (isDigit(first)) return decodeString(in, first);
        throw new IOException("Invalid bencode start: " + (char) first);
    }

    /** Decodes dictionary entries up to the closing 'e' (the opening 'd' is already consumed). */
    private Map<String, Object> decodeDict(InputStream in) throws IOException {
        Map<String, Object> map = new LinkedHashMap<>();
        for (int c = readByte(in); c != 'e'; c = readByte(in)) {
            String key = decodeString(in, c);
            map.put(key, decode(in));
        }
        return map;
    }

    /** Decodes list items up to the closing 'e' (the opening 'l' is already consumed). */
    private List<Object> decodeList(InputStream in) throws IOException {
        List<Object> list = new ArrayList<>();
        for (int c = readByte(in); c != 'e'; c = readByte(in)) {
            // c is the item's first byte; hand it over rather than trying to un-read it.
            list.add(decodeValue(in, c));
        }
        return list;
    }

    /** Decodes a string whose first length digit has already been consumed. */
    private String decodeString(InputStream in, int firstDigit) throws IOException {
        int len = readStringLength(in, firstDigit, null);
        return new String(readBytes(in, len), RAW);
    }

    /**
     * Parses a string length prefix up to and including the ':' separator.
     *
     * @param firstDigit first prefix byte, already consumed (and NOT copied to {@code mirror})
     * @param mirror     if non-null, receives every byte this method reads
     */
    private static int readStringLength(InputStream in, int firstDigit, OutputStream mirror)
            throws IOException {
        if (!isDigit(firstDigit)) {
            throw new IOException("Invalid string length start: " + (char) firstDigit);
        }
        long len = firstDigit - '0';
        while (true) {
            int c = readByte(in);
            if (mirror != null) mirror.write(c);
            if (c == ':') return (int) len;
            if (!isDigit(c)) throw new IOException("Invalid character in string length: " + (char) c);
            len = len * 10 + (c - '0');
            if (len > Integer.MAX_VALUE) throw new IOException("String length too large");
        }
    }

    /** Reads exactly {@code len} bytes, in chunks so a bogus huge length fails on EOF, not on allocation. */
    private static byte[] readBytes(InputStream in, int len) throws IOException {
        ByteArrayOutputStream buf = new ByteArrayOutputStream(Math.min(len, CHUNK));
        byte[] chunk = new byte[CHUNK];
        int remaining = len;
        while (remaining > 0) {
            int n = in.read(chunk, 0, Math.min(chunk.length, remaining));
            if (n < 0) throw new EOFException("Unexpected end of bencode string");
            buf.write(chunk, 0, n);
            remaining -= n;
        }
        return buf.toByteArray();
    }

    /** Decodes an integer body up to its closing 'e' (the opening 'i' is already consumed). */
    private Long decodeInt(InputStream in) throws IOException {
        StringBuilder digits = new StringBuilder();
        for (int c = readByte(in); c != 'e'; c = readByte(in)) {
            digits.append((char) c);
        }
        try {
            return Long.parseLong(digits.toString());
        } catch (NumberFormatException e) {
            throw new IOException("Invalid bencode integer: " + digits, e);
        }
    }

    /** Reads one byte, failing with EOFException instead of returning -1. */
    private static int readByte(InputStream in) throws IOException {
        int b = in.read();
        if (b == -1) throw new EOFException("Unexpected end of bencode data");
        return b;
    }

    private static boolean isDigit(int b) {
        return b >= '0' && b <= '9';
    }

    // --- Encoding ---

    /**
     * Bencodes {@code obj}. Supports RawValue (written verbatim), String, Integer/Long, List and
     * Map&lt;String, ?&gt;. Dictionary keys are written in sorted order.
     */
    private void encode(Object obj, OutputStream out) throws IOException {
        if (obj instanceof RawValue) {
            out.write(((RawValue) obj).bytes);
        } else if (obj instanceof String) {
            encodeString((String) obj, out);
        } else if (obj instanceof Integer || obj instanceof Long) {
            out.write('i');
            out.write(obj.toString().getBytes(StandardCharsets.US_ASCII));
            out.write('e');
        } else if (obj instanceof List) {
            out.write('l');
            for (Object item : (List<?>) obj) encode(item, out);
            out.write('e');
        } else if (obj instanceof Map) {
            out.write('d');
            @SuppressWarnings("unchecked")
            Map<String, Object> map = (Map<String, Object>) obj;
            List<String> keys = new ArrayList<>(map.keySet());
            // Keys are ISO-8859-1 strings, so char order equals the raw byte order bencode requires.
            Collections.sort(keys);
            for (String key : keys) {
                encodeString(key, out);
                encode(map.get(key), out);
            }
            out.write('e');
        } else if (obj == null) {
            // A dictionary key with no value would produce invalid bencode.
            throw new IOException("Cannot bencode null");
        } else {
            throw new IOException("Unsupported type: " + obj.getClass());
        }
    }

    /**
     * Writes {@code str} as a bencoded string, one byte per char (ISO-8859-1). Decoded strings
     * round-trip exactly; inserted text must be ASCII, which the URL-encoded passkey is.
     */
    private void encodeString(String str, OutputStream out) throws IOException {
        byte[] bytes = str.getBytes(RAW);
        out.write(Integer.toString(bytes.length).getBytes(StandardCharsets.US_ASCII));
        out.write(':');
        out.write(bytes);
    }
}
