blob: e3dfe5ad55e9b7b0b0f582b26e270df80a980f4e [file] [log] [blame]
package com.pt.utils;
import java.io.*;
import java.net.InetAddress;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.*;
/**
 * Bencode (BitTorrent encoding) encoder/decoder plus helpers for building compact tracker
 * responses.
 *
 * <p>Encoding accepts {@link String} (written as UTF-8 bytes), {@link Number} (written via
 * {@link Number#longValue()}, so fractional values are truncated), {@code byte[]}, {@link List},
 * and {@link Map} with {@link String} keys, nested arbitrarily. Decoding produces {@link Long},
 * {@link String} (or {@code byte[]} when raw strings are requested), {@link ArrayList}, and
 * {@link LinkedHashMap} (which keeps the key order found in the input).
 */
public class BencodeCodec {
    /* ------------- Encoding ------------- */

    /**
     * Writes the bencode form of {@code obj} to {@code out}.
     *
     * @param obj a String, Number, byte[], List or Map with String keys (nested values likewise)
     * @param out destination stream; it is not closed by this method
     * @throws IOException if writing to {@code out} fails
     * @throws IllegalArgumentException if {@code obj} or any nested value is {@code null} or of an
     *     unsupported type, or if a map key is not a String
     */
    public static void encode(Object obj, OutputStream out) throws IOException {
        if (obj instanceof String) {
            encodeString((String) obj, out);
        } else if (obj instanceof Number) {
            encodeInteger(((Number) obj).longValue(), out);
        } else if (obj instanceof byte[]) {
            encodeBytes((byte[]) obj, out);
        } else if (obj instanceof List) {
            encodeList((List<?>) obj, out);
        } else if (obj instanceof Map) {
            encodeMap((Map<?, ?>) obj, out);
        } else if (obj == null) {
            // Bencode has no null; report it clearly instead of an NPE from obj.getClass().
            throw new IllegalArgumentException("Unsupported type: null");
        } else {
            throw new IllegalArgumentException("Unsupported type: " + obj.getClass());
        }
    }

    /**
     * Returns the bencode form of {@code obj} as a new byte array.
     *
     * @param obj value to encode; see {@link #encode(Object, OutputStream)} for supported types
     * @return the encoded bytes
     * @throws IllegalArgumentException if {@code obj} contains an unsupported value
     */
    public static byte[] encode(Object obj) {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try {
            encode(obj, baos);
        } catch (IOException e) {
            // ByteArrayOutputStream does not throw; this only satisfies the checked signature.
            throw new UncheckedIOException(e);
        }
        return baos.toByteArray();
    }

    /** Writes {@code s} as a bencode byte string of its UTF-8 bytes. */
    private static void encodeString(String s, OutputStream out) throws IOException {
        encodeBytes(s.getBytes(StandardCharsets.UTF_8), out);
    }

    /** Writes {@code <length>:<bytes>}. */
    private static void encodeBytes(byte[] bytes, OutputStream out) throws IOException {
        out.write(String.valueOf(bytes.length).getBytes(StandardCharsets.US_ASCII));
        out.write(':');
        out.write(bytes);
    }

    /** Writes {@code i<value>e}. */
    private static void encodeInteger(long value, OutputStream out) throws IOException {
        out.write('i');
        out.write(Long.toString(value).getBytes(StandardCharsets.US_ASCII));
        out.write('e');
    }

    /** Writes {@code l<items>e}, encoding each item in iteration order. */
    private static void encodeList(List<?> list, OutputStream out) throws IOException {
        out.write('l');
        for (Object item : list) {
            encode(item, out);
        }
        out.write('e');
    }

    /**
     * Writes {@code d<key><value>...e} with keys sorted as raw byte strings.
     *
     * @throws IllegalArgumentException if a key is not a String
     */
    private static void encodeMap(Map<?, ?> map, OutputStream out) throws IOException {
        // Bencode requires keys sorted by their raw bytes. Sort by the UTF-8 bytes actually written:
        // String.compareTo uses UTF-16 order, which differs from byte order for characters outside
        // the BMP compared with U+E000..U+FFFF.
        Map<byte[], Object> sorted = new TreeMap<>(BencodeCodec::compareUnsigned);
        for (Map.Entry<?, ?> entry : map.entrySet()) {
            Object key = entry.getKey();
            if (!(key instanceof String)) {
                throw new IllegalArgumentException("Dictionary keys must be Strings, got: " + key);
            }
            byte[] keyBytes = ((String) key).getBytes(StandardCharsets.UTF_8);
            if (sorted.containsKey(keyBytes)) {
                // Only reachable for malformed strings (e.g. lone surrogates) that encode to the
                // same bytes; emitting both would produce a dictionary with a duplicate key.
                throw new IllegalArgumentException(
                        "Duplicate dictionary key after UTF-8 encoding: " + key);
            }
            sorted.put(keyBytes, entry.getValue());
        }
        out.write('d');
        for (Map.Entry<byte[], Object> entry : sorted.entrySet()) {
            encodeBytes(entry.getKey(), out);
            encode(entry.getValue(), out);
        }
        out.write('e');
    }

    /** Compares byte arrays lexicographically, treating each byte as unsigned (0..255). */
    private static int compareUnsigned(byte[] a, byte[] b) {
        int common = Math.min(a.length, b.length);
        for (int i = 0; i < common; i++) {
            int cmp = Integer.compare(a[i] & 0xFF, b[i] & 0xFF);
            if (cmp != 0) {
                return cmp;
            }
        }
        return Integer.compare(a.length, b.length);
    }

    /* ------------- Decoding ------------- */

    /**
     * Upper bound on each read when copying a byte string. A forged length prefix therefore cannot
     * force a huge allocation before the data is known to exist.
     */
    private static final int READ_CHUNK = 8192;

    /**
     * Decodes the first bencoded value in {@code data}, returning byte strings as UTF-8 Strings.
     * Any bytes after the first complete value are ignored.
     *
     * @param data bencoded input
     * @return a Long, String, List or Map (nested values likewise)
     * @throws IOException if the input is truncated or malformed
     */
    public static Object decode(byte[] data) throws IOException {
        return decode(data, false);
    }

    /**
     * Decodes the first bencoded value in {@code data}. Any bytes after the first complete value
     * are ignored.
     *
     * @param data bencoded input
     * @param rawStrings if {@code true}, byte-string values are returned as {@code byte[]}
     *     (lossless for binary data such as piece hashes); dictionary keys are always decoded as
     *     UTF-8 Strings. If {@code false}, values are decoded as UTF-8 Strings, like
     *     {@link #decode(byte[])}.
     * @return a Long, String/byte[], List or Map (nested values likewise)
     * @throws IOException if the input is truncated or malformed
     */
    public static Object decode(byte[] data, boolean rawStrings) throws IOException {
        try (PushbackInputStream in = new PushbackInputStream(new ByteArrayInputStream(data))) {
            return decodeNext(in, rawStrings);
        }
    }

    /** Reads one value, dispatching on its first byte. */
    private static Object decodeNext(PushbackInputStream in, boolean rawStrings) throws IOException {
        int prefix = in.read();
        if (prefix == -1) {
            throw new IOException("Unexpected end of stream");
        }
        if (prefix >= '0' && prefix <= '9') {
            // Byte string: push the digit back so the length reader sees the whole prefix.
            in.unread(prefix);
            byte[] bytes = readByteString(in);
            return rawStrings ? bytes : new String(bytes, StandardCharsets.UTF_8);
        } else if (prefix == 'i') {
            return parseInteger(in);
        } else if (prefix == 'l') {
            return parseList(in, rawStrings);
        } else if (prefix == 'd') {
            return parseDict(in, rawStrings);
        } else {
            throw new IOException("Invalid bencode prefix: " + (char) prefix);
        }
    }

    /** Reads a {@code <length>:<bytes>} byte string and decodes it as UTF-8. */
    private static String parseString(InputStream in) throws IOException {
        return new String(readByteString(in), StandardCharsets.UTF_8);
    }

    /**
     * Reads a {@code <length>:<bytes>} byte string and returns the raw bytes.
     *
     * @throws IOException if the length is missing, non-numeric or larger than
     *     {@link Integer#MAX_VALUE}, or if the stream ends before all bytes are read
     */
    private static byte[] readByteString(InputStream in) throws IOException {
        long length = 0;
        int digits = 0;
        int ch;
        while ((ch = in.read()) != -1 && ch != ':') {
            if (ch < '0' || ch > '9') {
                throw new IOException("Invalid string length prefix: " + (char) ch);
            }
            length = length * 10 + (ch - '0');
            if (length > Integer.MAX_VALUE) {
                throw new IOException("String length too large");
            }
            digits++;
        }
        if (ch != ':') {
            throw new IOException("Expected ':' after string length");
        }
        if (digits == 0) {
            throw new IOException("Missing string length before ':'");
        }
        int remaining = (int) length;
        byte[] chunk = new byte[Math.min(remaining, READ_CHUNK)];
        ByteArrayOutputStream buf = new ByteArrayOutputStream(chunk.length);
        while (remaining > 0) {
            int r = in.read(chunk, 0, Math.min(remaining, chunk.length));
            if (r == -1) {
                throw new IOException("Unexpected end of stream when reading string");
            }
            buf.write(chunk, 0, r);
            remaining -= r;
        }
        return buf.toByteArray();
    }

    /**
     * Reads the digits of an integer up to the terminating {@code 'e'} (the leading {@code 'i'}
     * has already been consumed). Parsing is lenient: forms the spec disallows, such as leading
     * zeros or {@code -0}, are accepted as {@link Long#parseLong} accepts them.
     *
     * @throws IOException if the stream ends early or the text is not a valid long
     */
    private static long parseInteger(InputStream in) throws IOException {
        StringBuilder intStr = new StringBuilder();
        int b;
        while ((b = in.read()) != -1 && b != 'e') {
            intStr.append((char) b);
        }
        if (b == -1) {
            throw new IOException("Unexpected end of stream reading integer");
        }
        String intValue = intStr.toString();
        try {
            return Long.parseLong(intValue);
        } catch (NumberFormatException e) {
            throw new IOException("Invalid bencode integer: " + intValue, e);
        }
    }

    /** Reads list items until {@code 'e'} (the leading {@code 'l'} has already been consumed). */
    private static List<Object> parseList(PushbackInputStream in, boolean rawStrings)
            throws IOException {
        List<Object> list = new ArrayList<>();
        while (true) {
            int ch = in.read();
            if (ch == 'e') {
                break;
            }
            if (ch == -1) {
                throw new IOException("Unexpected end of stream in list");
            }
            in.unread(ch);
            list.add(decodeNext(in, rawStrings));
        }
        return list;
    }

    /**
     * Reads key/value pairs until {@code 'e'} (the leading {@code 'd'} has already been consumed).
     * Keys are UTF-8 Strings; a repeated key keeps its last value.
     */
    private static Map<String, Object> parseDict(PushbackInputStream in, boolean rawStrings)
            throws IOException {
        Map<String, Object> map = new LinkedHashMap<>();
        while (true) {
            int ch = in.read();
            if (ch == 'e') {
                break; // end of dictionary
            }
            if (ch == -1) {
                throw new IOException("Unexpected end of stream in dict");
            }
            // Push the byte back so the key reader sees the full length prefix.
            in.unread(ch);
            String key = parseString(in);
            Object value = decodeNext(in, rawStrings);
            map.put(key, value);
        }
        return map;
    }

    /* ------------- Other helpers ------------- */

    /**
     * Builds one compact peer entry: the raw address bytes (4 for IPv4, 16 for IPv6) followed by
     * the port as 2 big-endian bytes.
     *
     * <p>{@code ip} is resolved with {@link InetAddress#getByName}, so a hostname may trigger a DNS
     * lookup, and a null or empty value resolves to the loopback address.
     *
     * @param ip IP literal (or resolvable host name)
     * @param port TCP/UDP port in 0..65535
     * @return 6 bytes for IPv4, 18 bytes for IPv6
     * @throws IllegalArgumentException if {@code port} is out of range
     * @throws UncheckedIOException if {@code ip} cannot be resolved
     */
    public static byte[] buildCompactPeer(String ip, int port) {
        if (port < 0 || port > 0xFFFF) {
            // The (short) cast below would otherwise silently wrap an invalid port.
            throw new IllegalArgumentException("Port out of range: " + port);
        }
        try {
            byte[] ipBytes = InetAddress.getByName(ip).getAddress();
            ByteBuffer buffer = ByteBuffer.allocate(ipBytes.length + 2);
            buffer.put(ipBytes);
            buffer.putShort((short) port);
            return buffer.array();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Concatenates compact peer entries for parallel lists of IPs and ports.
     *
     * <p>All-IPv4 input yields 6 bytes per peer. All-IPv6 input yields 18 bytes per peer. Mixing
     * the two is rejected, because a compact string must have a uniform entry size.
     *
     * @throws IllegalArgumentException if the list sizes differ, a port is out of range, or IPv4
     *     and IPv6 addresses are mixed
     */
    public static byte[] buildCompactPeers(List<String> ips, List<Integer> ports) {
        if (ips.size() != ports.size()) {
            throw new IllegalArgumentException("IPs and ports list size mismatch");
        }
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        int entryLength = -1;
        Iterator<String> ipIt = ips.iterator();
        Iterator<Integer> portIt = ports.iterator();
        while (ipIt.hasNext()) {
            String ip = ipIt.next();
            byte[] peer = buildCompactPeer(ip, portIt.next());
            if (entryLength < 0) {
                entryLength = peer.length;
            } else if (peer.length != entryLength) {
                throw new IllegalArgumentException(
                        "Cannot mix IPv4 and IPv6 peers in one compact string: " + ip);
            }
            // Write the whole entry; truncating to 6 bytes would corrupt IPv6 peers.
            out.write(peer, 0, peer.length);
        }
        return out.toByteArray();
    }

    /**
     * Builds a bencoded tracker response dictionary containing {@code interval} and
     * {@code peers}.
     *
     * @param interval value for the {@code interval} key
     * @param peersCompact compact peer bytes, e.g. from {@link #buildCompactPeers}
     * @return the encoded dictionary
     */
    public static byte[] buildTrackerResponse(int interval, byte[] peersCompact) {
        Map<String, Object> dict = new LinkedHashMap<>();
        dict.put("interval", interval);
        dict.put("peers", peersCompact);
        return encode(dict);
    }
}