How Do You Build a Custom-Protocol Communication Framework with Netty?

wen · Java examples · 55

Contents:


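  1. Basic Architecture Design
  2. Netty Handler Implementation
  3. Server Bootstrap
  4. Client Implementation
  5. Serialization Factory
  6. Optimizations and Best Practices
  7. Test Example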

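This is a complete guide to implementing a custom-protocol communication framework with Netty, covering the core design, key code examples, and best practices.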

Basic Architecture Design

1 Protocol definition (using a typical binary protocol as an example)

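// Protocol layout: a fixed 12-byte header followed by the body
// +------------+--------------+-----------------+--------------+-------------+----------------+
// | magic (4B) | version (1B) | serializer (1B) | command (2B) | length (4B) | body (N bytes) |
// +------------+--------------+-----------------+--------------+-------------+----------------+
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

// Lombok generates the getters and the builder used by the encoder and decoder
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ProtocolHeader {
    private int magicNumber;    // magic number: 0xCAFEBABE
    private byte version;       // protocol version
    private byte serializer;    // serialization type: 0 = Java, 1 = JSON
    private short command;      // command type
    private int dataLength;     // body length in bytes
}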
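As a quick check of the layout above, here is a JDK-only sketch (no Netty) that packs one frame with `java.nio.ByteBuffer`, which is big-endian by default just like Netty's `ByteBuf`. It confirms that the header is 12 bytes and that the length field starts at byte offset 8. The class name `HeaderLayout` is illustrative, and `MAGIC_NUMBER` is an assumed stand-in for the `ProtocolConstants.MAGIC_NUMBER` that the decoder references but the article never defines.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

// Illustrative only: packs header + body with the JDK, mirroring the layout above
final class HeaderLayout {
    static final int MAGIC_NUMBER = 0xCAFEBABE;            // assumed ProtocolConstants.MAGIC_NUMBER
    static final int HEADER_LENGTH = 4 + 1 + 1 + 2 + 4;     // magic + version + serializer + command + length
    static final int LENGTH_FIELD_OFFSET = 4 + 1 + 1 + 2;   // bytes that precede the length field

    // Encodes one frame; ByteBuffer writes big-endian, as Netty's ByteBuf does by default
    static byte[] encode(byte version, byte serializer, short command, byte[] body) {
        ByteBuffer buf = ByteBuffer.allocate(HEADER_LENGTH + body.length);
        buf.putInt(MAGIC_NUMBER)
           .put(version)
           .put(serializer)
           .putShort(command)
           .putInt(body.length)
           .put(body);
        return buf.array();
    }

    public static void main(String[] args) {
        byte[] frame = encode((byte) 1, (byte) 1, (short) 100, "{}".getBytes(StandardCharsets.UTF_8));
        System.out.println(frame.length);                                        // 14
        System.out.println(ByteBuffer.wrap(frame).getInt(LENGTH_FIELD_OFFSET));  // 2
    }
}
```

Keeping the offset constant next to the header definition helps, because any length-field-based frame decoder configured for this protocol has to point at exactly this position.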

2 Message object definitions

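import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
// Generic message: header + typed body (each public class goes in its own .java file)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ProtocolMessage<T> {
    private ProtocolHeader header;
    private T body;
}
// Concrete message types; JSON deserialization needs the no-args constructor and setters
@Data
@NoArgsConstructor
@AllArgsConstructor
public class LoginRequest {
    private String userId;
    private String token;
}
@Data
@NoArgsConstructor
@AllArgsConstructor
public class LoginResponse {
    private boolean success;
    private String message;
}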

Netty Handler Implementation

1 Encoder

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;

public class ProtocolEncoder extends MessageToByteEncoder<ProtocolMessage<?>> {
    @Override
    protected void encode(ChannelHandlerContext ctx,
                          ProtocolMessage<?> msg, ByteBuf out) throws Exception {
        ProtocolHeader header = msg.getHeader();
        // Write the fixed header fields
        out.writeInt(header.getMagicNumber());
        out.writeByte(header.getVersion());
        out.writeByte(header.getSerializer());
        out.writeShort(header.getCommand());
        // Serialize the body first, since its length is only known afterwards
        Serializer serializer = SerializerFactory.getSerializer(header.getSerializer());
        byte[] bodyBytes = serializer.serialize(msg.getBody());
        // Write the body length, then the body itself
        out.writeInt(bodyBytes.length);
        out.writeBytes(bodyBytes);
    }
}

2 Decoder

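import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.CorruptedFrameException;
import java.util.List;

public class ProtocolDecoder extends ByteToMessageDecoder {
    private static final int HEADER_LENGTH = 12;            // 4 + 1 + 1 + 2 + 4
    private static final int MAX_BODY_LENGTH = 1024 * 1024; // sanity limit against bogus length fields
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in,
                          List<Object> out) throws Exception {
        // Wait until at least a full header is readable
        if (in.readableBytes() < HEADER_LENGTH) {
            return;
        }
        // Mark the frame start so an incomplete frame can be re-read later
        in.markReaderIndex();
        // Validate the magic number (ProtocolConstants is not shown; it holds 0xCAFEBABE)
        int magicNumber = in.readInt();
        if (magicNumber != ProtocolConstants.MAGIC_NUMBER) {
            // Foreign or corrupted stream: fail fast (ServerBusinessHandler.exceptionCaught closes the channel)
            throw new CorruptedFrameException("Invalid magic number: " + Integer.toHexString(magicNumber));
        }
        // Read version, serializer type and command
        byte version = in.readByte();
        byte serializerType = in.readByte();
        short command = in.readShort();
        // Read the body length and make sure the whole body has arrived
        int dataLength = in.readInt();
        if (dataLength < 0 || dataLength > MAX_BODY_LENGTH) {
            throw new CorruptedFrameException("Invalid body length: " + dataLength);
        }
        if (in.readableBytes() < dataLength) {
            in.resetReaderIndex();
            return; // incomplete body: wait for more data
        }
        // Read the body
        byte[] bodyBytes = new byte[dataLength];
        in.readBytes(bodyBytes);
        // Deserialize the body; CommandRegistry (not shown) maps a command code to its body class
        Serializer serializer = SerializerFactory.getSerializer(serializerType);
        Class<?> bodyClass = CommandRegistry.getBodyClass(command);
        Object body = serializer.deserialize(bodyBytes, bodyClass);
        // Assemble the complete message
        ProtocolHeader header = ProtocolHeader.builder()
                .magicNumber(magicNumber)
                .version(version)
                .serializer(serializerType)
                .command(command)
                .dataLength(dataLength)
                .build();
        ProtocolMessage<Object> message = ProtocolMessage.<Object>builder()
                .header(header)
                .body(body)
                .build();
        out.add(message);
    }
}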
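The decoder's key idea is that an incomplete frame is not an error: it rewinds and waits. The sketch below reproduces that control flow with the JDK only, using `ByteBuffer.mark()`/`reset()` in place of `markReaderIndex()`/`resetReaderIndex()`. The class `FrameDecoderSketch` is hypothetical and returns raw body bytes instead of deserializing them; in Netty, `ByteToMessageDecoder` keeps the unread bytes and calls `decode()` again when more data arrives, which the `decodeAll` loop only approximates.

```java
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

// JDK-only sketch of ProtocolDecoder's framing control flow (no Netty, no deserialization)
final class FrameDecoderSketch {
    static final int MAGIC_NUMBER = 0xCAFEBABE; // assumed value of ProtocolConstants.MAGIC_NUMBER
    static final int HEADER_LENGTH = 12;        // 4 + 1 + 1 + 2 + 4

    // Returns one body, or null if `in` does not yet hold a complete frame
    static byte[] decodeOne(ByteBuffer in) {
        if (in.remaining() < HEADER_LENGTH) {
            return null;                        // header incomplete: wait
        }
        in.mark();                              // plays the role of markReaderIndex()
        int magic = in.getInt();
        if (magic != MAGIC_NUMBER) {
            throw new IllegalStateException("Invalid magic number: " + Integer.toHexString(magic));
        }
        in.get();                               // version (unused here)
        in.get();                               // serializer type (unused here)
        in.getShort();                          // command (unused here)
        int dataLength = in.getInt();
        if (dataLength < 0) {
            throw new IllegalStateException("Negative body length: " + dataLength);
        }
        if (in.remaining() < dataLength) {
            in.reset();                         // plays the role of resetReaderIndex(): rewind, wait
            return null;
        }
        byte[] body = new byte[dataLength];
        in.get(body);
        return body;
    }

    // Decodes every complete frame; a trailing partial frame stays unread in `in`
    static List<byte[]> decodeAll(ByteBuffer in) {
        List<byte[]> bodies = new ArrayList<>();
        byte[] body;
        while ((body = decodeOne(in)) != null) {
            bodies.add(body);
        }
        return bodies;
    }
}
```

On the server, a `LengthFieldBasedFrameDecoder` placed in front already hands over whole frames, so this rewind path matters mostly on the client, whose pipeline uses `ProtocolDecoder` alone.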

3 Business handler

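import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.ApplicationContext;
import java.util.HashMap;
import java.util.Map;

@Slf4j
public class ServerBusinessHandler extends SimpleChannelInboundHandler<ProtocolMessage<?>> {
    private final Map<Short, CommandHandler> handlerMap = new HashMap<>();
    public ServerBusinessHandler(ApplicationContext applicationContext) {
        // Index every CommandHandler bean by the command code it serves
        Map<String, CommandHandler> handlers =
            applicationContext.getBeansOfType(CommandHandler.class);
        handlers.values().forEach(h -> handlerMap.put(h.getCommand(), h));
    }
    @Override
    protected void channelRead0(ChannelHandlerContext ctx,
                                ProtocolMessage<?> msg) throws Exception {
        short command = msg.getHeader().getCommand();
        CommandHandler handler = handlerMap.get(command);
        if (handler != null) {
            Object response = handler.handle(msg.getBody());
            // Send the response back on the same channel
            ctx.writeAndFlush(buildResponse(msg, response));
        } else {
            log.warn("No handler for command: {}", command);
        }
    }
    // Copies magic/version/serializer from the request. Assumption: responses use their own
    // command code (request code + 1 here) so the client's CommandRegistry can resolve the body class
    private ProtocolMessage<Object> buildResponse(ProtocolMessage<?> request, Object body) {
        ProtocolHeader req = request.getHeader();
        ProtocolHeader header = ProtocolHeader.builder()
                .magicNumber(req.getMagicNumber())
                .version(req.getVersion())
                .serializer(req.getSerializer())
                .command((short) (req.getCommand() + 1))
                .build();
        return ProtocolMessage.<Object>builder().header(header).body(body).build();
    }
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        log.error("Unhandled error on channel {}", ctx.channel(), cause);
        ctx.close();
    }
}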
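`CommandHandler` is used above but never defined. Judging only from the calls `getCommand()` and `handle(body)`, a plausible shape is sketched below, together with the same command-code lookup that `ServerBusinessHandler` performs, without Spring or Netty. Everything here (`CommandHandler`, `LoginCommandHandler`, `CommandDispatcher`) is hypothetical. The duplicate check is an extra safeguard: the article's `handlerMap.put` would silently let a later handler replace an earlier one that claims the same code.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical contract inferred from ServerBusinessHandler's calls
interface CommandHandler {
    short getCommand();         // command code this handler serves
    Object handle(Object body); // returns the response body
}

// Hypothetical handler for command 100, the login code used in the test example
final class LoginCommandHandler implements CommandHandler {
    @Override
    public short getCommand() {
        return (short) 100;
    }

    @Override
    public Object handle(Object body) {
        // A real handler would check the credentials carried in the LoginRequest body
        return "login ok";
    }
}

// Same lookup as ServerBusinessHandler, minus Spring and Netty
final class CommandDispatcher {
    private final Map<Short, CommandHandler> handlerMap = new HashMap<>();

    CommandDispatcher(List<? extends CommandHandler> handlers) {
        for (CommandHandler h : handlers) {
            // Map.put returns the previous mapping, which reveals duplicate command codes
            if (handlerMap.put(h.getCommand(), h) != null) {
                throw new IllegalStateException("Duplicate handler for command " + h.getCommand());
            }
        }
    }

    // Returns the handler's response, or null when no handler is registered for the code
    Object dispatch(short command, Object body) {
        CommandHandler handler = handlerMap.get(command);
        return handler == null ? null : handler.handle(body);
    }
}
```

Failing at startup on a duplicate code is usually easier to diagnose than a request that is quietly routed to the wrong handler.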

Server Bootstrap

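import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.ApplicationContext;

@Slf4j
public class NettyServer {
    private final ApplicationContext applicationContext; // source of the CommandHandler beans
    private volatile Channel serverChannel;
    public NettyServer(ApplicationContext applicationContext) {
        this.applicationContext = applicationContext;
    }
    public void start(int port) throws InterruptedException {
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .option(ChannelOption.SO_BACKLOG, 1024)
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            ChannelPipeline pipeline = ch.pipeline();
                            // Framing: handles TCP sticky/split packets
                            pipeline.addLast(new LengthFieldBasedFrameDecoder(
                                1024 * 1024,  // max frame length
                                8,            // length field offset: magic 4 + version 1 + serializer 1 + command 2
                                4,            // length field size
                                0,            // length adjustment: the field counts only the body
                                0             // bytes to strip: keep the header, ProtocolDecoder reads it
                            ));
                            pipeline.addLast(new ProtocolDecoder());
                            pipeline.addLast(new ProtocolEncoder());
                            pipeline.addLast(new ServerIdleStateHandler());
                            pipeline.addLast(new ServerBusinessHandler(applicationContext));
                        }
                    });
            ChannelFuture future = bootstrap.bind(port).sync();
            serverChannel = future.channel();
            log.info("Server started on port: {}", port);
            serverChannel.closeFuture().sync();
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }
    // Closing the server channel unblocks start(), whose finally block releases the event loops
    public void stop() {
        if (serverChannel != null) {
            serverChannel.close();
        }
    }
}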
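The `LengthFieldBasedFrameDecoder` arguments must agree with the header layout. Netty reads the length field starting at `lengthFieldOffset` and computes the total frame size as `lengthFieldValue + lengthAdjustment + lengthFieldOffset + lengthFieldLength`. Since magic (4) + version (1) + serializer (1) + command (2) = 8 bytes precede the length field, the offset is 8, and because `ProtocolDecoder` reads each frame starting from the magic number, `initialBytesToStrip` has to be 0. The JDK-only sketch below (the class name `FrameMath` is illustrative) reproduces that arithmetic and shows what an offset of 10 would read instead.

```java
import java.nio.ByteBuffer;

// Illustrative re-computation of LengthFieldBasedFrameDecoder's frame-size arithmetic
final class FrameMath {
    static final int LENGTH_FIELD_OFFSET = 4 + 1 + 1 + 2; // magic + version + serializer + command
    static final int LENGTH_FIELD_LENGTH = 4;
    static final int LENGTH_ADJUSTMENT = 0;               // the field counts only the body bytes

    // Total bytes the frame decoder waits for before emitting one frame
    static long frameLength(int lengthFieldOffset, int lengthFieldLength,
                            int lengthAdjustment, long lengthFieldValue) {
        return lengthFieldValue + lengthAdjustment + lengthFieldOffset + lengthFieldLength;
    }

    // Reads a 4-byte big-endian length field at the given offset, as the frame decoder would
    static int readLengthField(byte[] frame, int lengthFieldOffset) {
        return ByteBuffer.wrap(frame).getInt(lengthFieldOffset);
    }

    public static void main(String[] args) {
        // One frame: 12-byte header announcing a 3-byte body, followed by that body
        byte[] frame = ByteBuffer.allocate(15)
                .putInt(0xCAFEBABE).put((byte) 1).put((byte) 1).putShort((short) 100)
                .putInt(3).put(new byte[] {7, 8, 9})
                .array();
        int value = readLengthField(frame, LENGTH_FIELD_OFFSET);
        System.out.println(value);                              // 3
        System.out.println(frameLength(LENGTH_FIELD_OFFSET, LENGTH_FIELD_LENGTH,
                LENGTH_ADJUSTMENT, value));                     // 15
        System.out.println(readLengthField(frame, 10));         // 198408: garbage from offset 10
    }
}
```

With offset 10 the decoder would read half of the length field plus two body bytes, announce a frame of roughly 198 KB, and stall waiting for data that never arrives.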

Client Implementation

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;

public class NettyClient {
    private Channel channel;
    private EventLoopGroup group;
    public void connect(String host, int port) throws InterruptedException {
        group = new NioEventLoopGroup();
        Bootstrap bootstrap = new Bootstrap();
        bootstrap.group(group)
                .channel(NioSocketChannel.class)
                .option(ChannelOption.TCP_NODELAY, true)
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) {
                        ChannelPipeline pipeline = ch.pipeline();
                        pipeline.addLast(new ProtocolEncoder());
                        // ProtocolDecoder waits for complete frames itself, so no frame decoder is needed here
                        pipeline.addLast(new ProtocolDecoder());
                        // ClientBusinessHandler (not shown): your inbound handler for server responses
                        pipeline.addLast(new ClientBusinessHandler());
                    }
                });
        ChannelFuture future = bootstrap.connect(host, port).sync();
        channel = future.channel();
    }
    // Writes asynchronously; the message is silently dropped if the channel is not active
    public void sendMessage(ProtocolMessage<?> message) {
        if (channel != null && channel.isActive()) {
            channel.writeAndFlush(message);
        }
    }
    public void close() {
        if (channel != null) {
            channel.close();
        }
        if (group != null) {
            group.shutdownGracefully();
        }
    }
}

Serialization Factory

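import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Contract shared by all serializers (each public type goes in its own .java file)
public interface Serializer {
    byte[] serialize(Object obj) throws Exception;
    <T> T deserialize(byte[] bytes, Class<T> clazz) throws Exception;
}
public class SerializerFactory {
    // ConcurrentHashMap, because register() may run while I/O threads look up serializers
    private static final Map<Byte, Serializer> serializers = new ConcurrentHashMap<>();
    static {
        register((byte) 0, new JavaSerializer());
        register((byte) 1, new JsonSerializer());
    }
    public static void register(byte type, Serializer serializer) {
        serializers.put(type, serializer);
    }
    public static Serializer getSerializer(byte type) {
        Serializer serializer = serializers.get(type);
        if (serializer == null) {
            throw new IllegalArgumentException("Unknown serializer type: " + type);
        }
        return serializer;
    }
}
// JSON implementation using Jackson; a configured ObjectMapper is thread-safe, so one shared instance suffices
public class JsonSerializer implements Serializer {
    private static final ObjectMapper mapper = new ObjectMapper();
    @Override
    public byte[] serialize(Object obj) throws Exception {
        return mapper.writeValueAsBytes(obj);
    }
    @Override
    public <T> T deserialize(byte[] bytes, Class<T> clazz) throws Exception {
        return mapper.readValue(bytes, clazz);
    }
}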
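`SerializerFactory` registers a `JavaSerializer` that the article never shows. A minimal sketch, assuming it simply wraps the JDK's `ObjectOutputStream`/`ObjectInputStream`, is below; the `Serializer` interface is repeated so the block compiles on its own. With this approach every body type must implement `java.io.Serializable`, which the `LoginRequest`/`LoginResponse` classes above do not, so JSON (type 1) is the practical choice for them.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

// Same contract as the Serializer used by JsonSerializer
interface Serializer {
    byte[] serialize(Object obj) throws Exception;
    <T> T deserialize(byte[] bytes, Class<T> clazz) throws Exception;
}

// Hypothetical JavaSerializer based on JDK object streams; bodies must be Serializable
final class JavaSerializer implements Serializer {
    @Override
    public byte[] serialize(Object obj) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj); // throws NotSerializableException for non-Serializable bodies
        }
        return bytes.toByteArray();
    }

    @Override
    public <T> T deserialize(byte[] bytes, Class<T> clazz) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return clazz.cast(in.readObject()); // ClassCastException if the payload is another type
        }
    }
}
```

JDK deserialization of bytes received from the network is a well-known attack vector; if type 0 is kept, restricting the accepted classes (for example with `ObjectInputStream.setObjectInputFilter` on Java 9+) is worth considering.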

Optimizations and Best Practices

1 Heartbeat detection

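import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;
import java.util.concurrent.TimeUnit;

public class ServerIdleStateHandler extends IdleStateHandler {
    private static final int READER_IDLE_TIME = 15; // seconds
    public ServerIdleStateHandler() {
        // Watch reads only: writer-idle and all-idle checks are disabled (0)
        super(READER_IDLE_TIME, 0, 0, TimeUnit.SECONDS);
    }
    @Override
    protected void channelIdle(ChannelHandlerContext ctx, IdleStateEvent evt) {
        if (evt.state() == IdleState.READER_IDLE) {
            // Nothing received for 15 seconds: close the connection.
            // Clients therefore need to send data (e.g. a heartbeat message) more often than that.
            ctx.close();
        }
    }
}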

2 Memory pooling

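// Use the pooled allocator explicitly (Netty 4.1 already defaults to it on most platforms)
bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
        .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT);
// Inside MessageToByteEncoder nothing extra is needed: `out` is already taken from
// ctx.alloc() (the pool above), and Netty releases it if encode() throws.
// When allocating by hand (e.g. in a plain outbound handler), release on failure:
ByteBuf buf = ctx.alloc().buffer();
try {
    // write data into buf
} catch (Exception e) {
    buf.release(); // not yet handed to Netty, so it must be freed here
    throw e;
}
ctx.write(buf); // from here on Netty owns buf and releases it after writing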

3 Performance tuning parameters

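// Thread model: native epoll transport (Linux only; needs the netty-transport-native-epoll
// dependency and EpollServerSocketChannel instead of NioServerSocketChannel)
EventLoopGroup bossGroup = new EpollEventLoopGroup(1,
    new DefaultThreadFactory("boss"));
EventLoopGroup workerGroup = new EpollEventLoopGroup(
    Runtime.getRuntime().availableProcessors() * 2, // same as Netty's default worker count
    new DefaultThreadFactory("worker"));
bootstrap.group(bossGroup, workerGroup)
        .channel(EpollServerSocketChannel.class);
// Socket options: SO_RCVBUF on the listening socket is inherited by accepted connections;
// per-connection settings (send buffer, write watermarks, TCP_NODELAY) belong in childOption
bootstrap.option(ChannelOption.SO_RCVBUF, 256 * 1024)
        .childOption(ChannelOption.SO_SNDBUF, 256 * 1024)
        // Channel.isWritable() turns false above 256 KB pending and true again below 64 KB
        .childOption(ChannelOption.WRITE_BUFFER_WATER_MARK,
                new WriteBufferWaterMark(64 * 1024, 256 * 1024))
        .childOption(ChannelOption.TCP_NODELAY, true)
        .childOption(ChannelOption.SO_KEEPALIVE, true);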

Test Example

// Assumes a Spring test class (e.g. annotated with @SpringBootTest) that provides the CommandHandler beans
@Autowired
private ApplicationContext applicationContext;

@Test
public void testProtocol() throws Exception {
    NettyServer server = new NettyServer(applicationContext);
    new Thread(() -> {
        try {
            server.start(8080);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }).start();
    Thread.sleep(1000); // crude wait for the server to bind
    NettyClient client = new NettyClient();
    client.connect("localhost", 8080);
    // Build a login request
    LoginRequest loginRequest = new LoginRequest("user123", "token456");
    ProtocolMessage<LoginRequest> message = ProtocolMessage.<LoginRequest>builder()
            .header(ProtocolHeader.builder()
                    .magicNumber(0xCAFEBABE)
                    .version((byte) 1)
                    .serializer((byte) 1) // JSON
                    .command((short) 100) // login command
                    .build())
            .body(loginRequest)
            .build();
    client.sendMessage(message);
    Thread.sleep(5000); // give the server time to respond
    client.close();
    server.stop();
}

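This framework is a complete starting point for a custom protocol, covering protocol definition, encoding and decoding, framing (handling TCP sticky and split packets), serialization, and heartbeat detection; adjust and extend it to fit your actual requirements.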
