blob: e3dfe5ad55e9b7b0b0f582b26e270df80a980f4e [file] [log] [blame]
Edwardsamaxlf1bf7ad2025-06-03 23:52:16 +08001package com.pt.utils;
2
3import java.io.*;
4import java.net.InetAddress;
5import java.nio.ByteBuffer;
6import java.nio.charset.StandardCharsets;
7import java.util.*;
8
/**
 * Bencode encoder/decoder plus helpers for building compact tracker responses.
 *
 * <p>Encoding accepts {@link String} (written as UTF-8), {@code byte[]}, any {@link Number}
 * (written via {@link Number#longValue()}), {@link List}, and {@link Map} with {@link String}
 * keys. Decoding yields {@link String} for byte strings (decoded as UTF-8, so arbitrary binary
 * payloads are NOT preserved byte-for-byte), {@link Long} for integers, {@code List<Object>} for
 * lists and an insertion-ordered {@code Map<String, Object>} for dictionaries.
 */
public class BencodeCodec {

    /** Maximum list/dict nesting accepted by the decoder; prevents stack overflow on hostile input. */
    private static final int MAX_DEPTH = 512;

    /** Chunk size for reading byte strings, so a forged length cannot force a huge up-front allocation. */
    private static final int READ_CHUNK = 8192;

    /* ------------- Encoding ------------- */

    /**
     * Writes the bencoded form of {@code obj} to {@code out}.
     *
     * @param obj value to encode (String, byte[], Number, List or Map with String keys)
     * @param out destination stream
     * @throws IllegalArgumentException if {@code obj} or any nested value is null or unsupported,
     *     or a map contains a non-String key
     * @throws IOException if writing to {@code out} fails
     */
    public static void encode(Object obj, OutputStream out) throws IOException {
        if (obj == null) {
            throw new IllegalArgumentException("Cannot bencode null");
        }
        if (obj instanceof String) {
            encodeBytes(((String) obj).getBytes(StandardCharsets.UTF_8), out);
        } else if (obj instanceof Number) {
            encodeInteger(((Number) obj).longValue(), out);
        } else if (obj instanceof byte[]) {
            encodeBytes((byte[]) obj, out);
        } else if (obj instanceof List) {
            encodeList((List<?>) obj, out);
        } else if (obj instanceof Map) {
            encodeMap((Map<?, ?>) obj, out);
        } else {
            throw new IllegalArgumentException("Unsupported type: " + obj.getClass());
        }
    }

    /**
     * Returns the bencoded form of {@code obj}.
     *
     * @throws IllegalArgumentException see {@link #encode(Object, OutputStream)}
     */
    public static byte[] encode(Object obj) {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
            encode(obj, baos);
            return baos.toByteArray();
        } catch (IOException e) {
            // Not expected for an in-memory stream; surface it unchecked with the cause kept.
            throw new UncheckedIOException(e);
        }
    }

    /** Writes a byte string as {@code <length>:<bytes>}. */
    private static void encodeBytes(byte[] bytes, OutputStream out) throws IOException {
        out.write(String.valueOf(bytes.length).getBytes(StandardCharsets.US_ASCII));
        out.write(':');
        out.write(bytes);
    }

    /** Writes an integer as {@code i<decimal>e}. */
    private static void encodeInteger(long value, OutputStream out) throws IOException {
        out.write('i');
        out.write(Long.toString(value).getBytes(StandardCharsets.US_ASCII));
        out.write('e');
    }

    /** Writes a list as {@code l<items>e}. */
    private static void encodeList(List<?> list, OutputStream out) throws IOException {
        out.write('l');
        for (Object item : list) {
            encode(item, out);
        }
        out.write('e');
    }

    /**
     * Writes a dictionary as {@code d<key><value>...e} with keys in raw byte order.
     *
     * <p>Keys are sorted by their UTF-8 bytes, compared unsigned. {@code String.compareTo} orders by
     * UTF-16 code units, which disagrees with byte order for some non-ASCII keys.
     */
    private static void encodeMap(Map<?, ?> map, OutputStream out) throws IOException {
        Map<byte[], Object> sorted = new TreeMap<>(BencodeCodec::compareUnsigned);
        for (Map.Entry<?, ?> entry : map.entrySet()) {
            Object key = entry.getKey();
            if (!(key instanceof String)) {
                throw new IllegalArgumentException("Dictionary keys must be strings, got: "
                        + (key == null ? "null" : key.getClass()));
            }
            Object value = entry.getValue();
            if (value == null) {
                throw new IllegalArgumentException("Null value for dictionary key: " + key);
            }
            sorted.put(((String) key).getBytes(StandardCharsets.UTF_8), value);
        }
        out.write('d');
        for (Map.Entry<byte[], Object> entry : sorted.entrySet()) {
            encodeBytes(entry.getKey(), out);
            encode(entry.getValue(), out);
        }
        out.write('e');
    }

    /** Lexicographic comparison of two byte arrays, treating each byte as unsigned. */
    private static int compareUnsigned(byte[] a, byte[] b) {
        int n = Math.min(a.length, b.length);
        for (int i = 0; i < n; i++) {
            int cmp = Integer.compare(a[i] & 0xFF, b[i] & 0xFF);
            if (cmp != 0) {
                return cmp;
            }
        }
        return Integer.compare(a.length, b.length);
    }

    /* ------------- Decoding ------------- */

    /**
     * Decodes the first bencoded value in {@code data}. Any bytes after that value are ignored.
     *
     * @return a String, Long, {@code List<Object>} or {@code Map<String, Object>}
     * @throws IOException if the data is truncated, malformed, or nested deeper than the limit
     */
    public static Object decode(byte[] data) throws IOException {
        Objects.requireNonNull(data, "data");
        try (PushbackInputStream in = new PushbackInputStream(new ByteArrayInputStream(data))) {
            return decodeNext(in, 0);
        }
    }

    /**
     * Reads one value and dispatches on its type prefix.
     *
     * @param depth nesting level of the enclosing container (0 at top level)
     */
    private static Object decodeNext(PushbackInputStream in, int depth) throws IOException {
        int prefix = in.read();
        if (prefix == -1) {
            throw new IOException("Unexpected end of stream");
        }
        if (prefix >= '0' && prefix <= '9') {
            // Byte string: push the digit back so parseString reads the full length prefix.
            in.unread(prefix);
            return parseString(in);
        }
        switch (prefix) {
            case 'i':
                return parseInteger(in);
            case 'l':
                return parseList(in, depth + 1);
            case 'd':
                return parseDict(in, depth + 1);
            default:
                throw new IOException("Invalid bencode prefix: " + (char) prefix);
        }
    }

    /** Rejects containers nested beyond {@link #MAX_DEPTH}. */
    private static void checkDepth(int depth) throws IOException {
        if (depth > MAX_DEPTH) {
            throw new IOException("Bencode nesting deeper than " + MAX_DEPTH);
        }
    }

    /**
     * Parses {@code <length>:<bytes>} and decodes the bytes as UTF-8.
     *
     * <p>Binary content (e.g. hashes) is not round-trip safe through this conversion.
     */
    private static String parseString(InputStream in) throws IOException {
        long len = 0;
        int digits = 0;
        int ch;
        while ((ch = in.read()) != -1 && ch != ':') {
            if (ch < '0' || ch > '9') {
                throw new IOException("Invalid string length prefix: " + (char) ch);
            }
            len = len * 10 + (ch - '0');
            if (len > Integer.MAX_VALUE) {
                throw new IOException("String length too large");
            }
            digits++;
        }
        if (ch != ':') {
            throw new IOException("Expected ':' after string length");
        }
        if (digits == 0) {
            throw new IOException("Missing string length");
        }
        return new String(readFully(in, (int) len), StandardCharsets.UTF_8);
    }

    /**
     * Reads exactly {@code len} bytes.
     *
     * <p>The buffer grows only as bytes actually arrive, so a forged length does not trigger a large
     * allocation.
     */
    private static byte[] readFully(InputStream in, int len) throws IOException {
        byte[] chunk = new byte[Math.min(len, READ_CHUNK)];
        ByteArrayOutputStream buf = new ByteArrayOutputStream(chunk.length);
        int remaining = len;
        while (remaining > 0) {
            int r = in.read(chunk, 0, Math.min(remaining, chunk.length));
            if (r == -1) {
                throw new IOException("Unexpected end of stream when reading string");
            }
            buf.write(chunk, 0, r);
            remaining -= r;
        }
        return buf.toByteArray();
    }

    /**
     * Parses the body of {@code i<decimal>e}; the leading 'i' has already been consumed.
     *
     * <p>Anything {@link Long#parseLong} rejects is reported as an IOException.
     */
    private static long parseInteger(InputStream in) throws IOException {
        StringBuilder intStr = new StringBuilder();
        int b;
        while ((b = in.read()) != -1 && b != 'e') {
            intStr.append((char) b);
        }
        if (b == -1) {
            throw new IOException("Unexpected end of stream reading integer");
        }
        String raw = intStr.toString();
        try {
            return Long.parseLong(raw);
        } catch (NumberFormatException e) {
            throw new IOException("Invalid bencode integer: " + raw, e);
        }
    }

    /** Parses list items up to the closing 'e'; the leading 'l' has already been consumed. */
    private static List<Object> parseList(PushbackInputStream in, int depth) throws IOException {
        checkDepth(depth);
        List<Object> list = new ArrayList<>();
        for (int ch = in.read(); ch != 'e'; ch = in.read()) {
            if (ch == -1) {
                throw new IOException("Unexpected end of stream in list");
            }
            in.unread(ch);
            list.add(decodeNext(in, depth));
        }
        return list;
    }

    /**
     * Parses key/value pairs up to the closing 'e'; the leading 'd' has already been consumed.
     *
     * <p>Insertion order is preserved. A duplicate key keeps its last value.
     */
    private static Map<String, Object> parseDict(PushbackInputStream in, int depth) throws IOException {
        checkDepth(depth);
        Map<String, Object> map = new LinkedHashMap<>();
        for (int ch = in.read(); ch != 'e'; ch = in.read()) {
            if (ch == -1) {
                throw new IOException("Unexpected end of stream in dict");
            }
            // Push the byte back so parseString reads the key's full length prefix.
            in.unread(ch);
            String key = parseString(in);
            map.put(key, decodeNext(in, depth));
        }
        return map;
    }

    /* ------------- Helpers ------------- */

    /**
     * Builds one compact peer entry: the raw address bytes followed by a 2-byte big-endian port.
     *
     * <p>This is 6 bytes for an IPv4 address and 18 bytes for an IPv6 address.
     *
     * @throws IllegalArgumentException if {@code ip} is null, empty or cannot be resolved, or
     *     {@code port} is outside 0..65535
     */
    public static byte[] buildCompactPeer(String ip, int port) {
        if (ip == null || ip.isEmpty()) {
            // InetAddress.getByName(null or "") silently returns the loopback address.
            throw new IllegalArgumentException("ip must not be null or empty");
        }
        if (port < 0 || port > 0xFFFF) {
            throw new IllegalArgumentException("Port out of range: " + port);
        }
        try {
            byte[] ipBytes = InetAddress.getByName(ip).getAddress();
            return ByteBuffer.allocate(ipBytes.length + 2)
                    .put(ipBytes)
                    .putShort((short) port)
                    .array();
        } catch (IOException e) {
            throw new IllegalArgumentException("Invalid peer address: " + ip, e);
        }
    }

    /**
     * Concatenates 6-byte IPv4 compact peer entries for the given parallel lists.
     *
     * <p>Peers whose address is not IPv4 are skipped, because they cannot be represented in this
     * 6-byte format. Truncating them would corrupt the list.
     *
     * @throws IllegalArgumentException if the lists differ in size or an entry is invalid
     */
    public static byte[] buildCompactPeers(List<String> ips, List<Integer> ports) {
        if (ips.size() != ports.size()) {
            throw new IllegalArgumentException("IPs and ports list size mismatch");
        }
        ByteArrayOutputStream out = new ByteArrayOutputStream(ips.size() * 6);
        Iterator<String> ipIt = ips.iterator();
        Iterator<Integer> portIt = ports.iterator();
        while (ipIt.hasNext() && portIt.hasNext()) {
            byte[] peer = buildCompactPeer(ipIt.next(), portIt.next());
            if (peer.length != 6) {
                continue; // non-IPv4 peer: not representable in the compact IPv4 list
            }
            out.write(peer, 0, peer.length);
        }
        return out.toByteArray();
    }

    /** Builds a bencoded tracker response dictionary containing "interval" and "peers". */
    public static byte[] buildTrackerResponse(int interval, byte[] peersCompact) {
        Map<String, Object> dict = new LinkedHashMap<>();
        dict.put("interval", interval);
        dict.put("peers", peersCompact);
        return encode(dict);
    }
}