Java案例如何实现WebSocket通信?

wen java案例 71

本文目录导读:

Java案例如何实现WebSocket通信?

  1. 基于Java EE标准实现(使用注解)
  2. 基于Spring Boot实现(推荐)
  3. 基于Netty实现(高性能)
  4. WebSocket应用场景
  5. 最佳实践

我来详细讲解Java实现WebSocket通信的几种方式,并提供完整案例。

基于Java EE标准实现(使用注解)

Maven依赖

<dependency>
    <groupId>javax.websocket</groupId>
    <artifactId>javax.websocket-api</artifactId>
    <version>1.1</version>
    <scope>provided</scope>
</dependency>

WebSocket服务端

import java.io.IOException;
import java.util.Objects;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.atomic.AtomicInteger;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
/**
 * JSR 356 (javax.websocket) endpoint mounted at {@code /websocket/{userId}}.
 *
 * <p>The container creates one instance per connection. Every live instance is kept in a
 * static thread-safe set so messages can be broadcast to everyone or routed to one user.
 */
@ServerEndpoint("/websocket/{userId}")
public class WebSocketServer {
    // Number of open connections; atomic because callbacks run on container threads.
    private static final AtomicInteger onlineCount = new AtomicInteger();
    // Thread-safe set holding the endpoint instance of every connected client.
    private static final CopyOnWriteArraySet<WebSocketServer> webSocketSet =
        new CopyOnWriteArraySet<>();
    // Session of this connection, used to push data to the client.
    private Session session;
    // User ID taken from the {userId} path segment.
    private String userId;

    /**
     * Called when a connection is established: registers it and sends a welcome message.
     *
     * @param session the new session
     * @param userId  value of the {@code {userId}} path segment
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("userId") String userId) {
        this.session = session;
        this.userId = userId;
        webSocketSet.add(this);
        addOnlineCount();
        System.out.println("有新连接加入!当前在线人数为" + getOnlineCount());
        try {
            sendMessage("欢迎用户 " + userId + " 连接到WebSocket服务器");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Called when the connection closes: unregisters it and updates the counter. */
    @OnClose
    public void onClose() {
        // Only decrement when this instance was really registered, so the counter cannot drift.
        if (webSocketSet.remove(this)) {
            subOnlineCount();
        }
        System.out.println("有一连接关闭!当前在线人数为" + getOnlineCount());
    }

    /**
     * Called for every text message from the client; broadcasts it to all open connections.
     *
     * @param message text sent by the client
     * @param session the sender's session
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        System.out.println("来自客户端" + userId + "的消息:" + message);
        String broadcast = "用户 " + userId + " 说: " + message;
        for (WebSocketServer item : webSocketSet) {
            // Skip connections that are already closing; writing to them would only fail.
            if (!item.session.isOpen()) {
                continue;
            }
            try {
                item.sendMessage(broadcast);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /** Called when an error occurs on this connection. */
    @OnError
    public void onError(Session session, Throwable error) {
        System.out.println("发生错误");
        error.printStackTrace();
    }

    /**
     * Sends a text message to this connection's client (blocking).
     *
     * @throws IOException if the send fails
     */
    public void sendMessage(String message) throws IOException {
        // RemoteEndpoint.Basic must not be used by two threads at once (two clients
        // broadcasting simultaneously would otherwise collide); serialize writes per session.
        synchronized (this.session) {
            this.session.getBasicRemote().sendText(message);
        }
    }

    /**
     * Sends a message to the first connection registered with the given user ID.
     *
     * @param message text to send
     * @param userId  target user ID; no-op when no such connection exists
     * @throws IOException if the send fails
     */
    public static void sendInfo(String message, String userId) throws IOException {
        for (WebSocketServer item : webSocketSet) {
            if (Objects.equals(item.userId, userId)) {
                item.sendMessage(message);
                break;
            }
        }
    }

    /** Returns the current number of registered connections. */
    public static int getOnlineCount() {
        return onlineCount.get();
    }

    /** Increments the connection counter. */
    public static void addOnlineCount() {
        onlineCount.incrementAndGet();
    }

    /** Decrements the connection counter. */
    public static void subOnlineCount() {
        onlineCount.decrementAndGet();
    }
}

WebSocket客户端(Java)

import javax.websocket.*;
import java.net.URI;
/**
 * Minimal JSR 356 client endpoint: connects in the constructor and forwards every received
 * text message to an optional {@link MessageHandler}.
 */
@ClientEndpoint
public class WebSocketClient {
    // Assigned by the container in onOpen and read by callers of sendMessage/close on other
    // threads, hence volatile.
    private volatile Session session;
    // Assigned by the caller and read on the thread that delivers messages, hence volatile.
    private volatile MessageHandler messageHandler;

    /**
     * Connects to the given server endpoint.
     *
     * @param endpointURI ws:// or wss:// address of the server endpoint
     * @throws RuntimeException wrapping any failure raised while connecting
     */
    public WebSocketClient(URI endpointURI) {
        try {
            WebSocketContainer container = ContainerProvider.getWebSocketContainer();
            container.connectToServer(this, endpointURI);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @OnOpen
    public void onOpen(Session session) {
        this.session = session;
        System.out.println("Connected to server");
    }

    @OnClose
    public void onClose(Session session, CloseReason reason) {
        System.out.println("Disconnected from server");
    }

    /** Forwards a text message to the registered handler; dropped if none is registered yet. */
    @OnMessage
    public void onMessage(String message) {
        // Read the volatile field once so the null check and the call see the same handler.
        MessageHandler handler = messageHandler;
        if (handler != null) {
            handler.handleMessage(message);
        }
    }

    @OnError
    public void onError(Session session, Throwable error) {
        error.printStackTrace();
    }

    /**
     * Sends a text message asynchronously.
     *
     * @throws IllegalStateException if the connection is not open (yet or anymore)
     */
    public void sendMessage(String message) {
        Session current = this.session;
        if (current == null || !current.isOpen()) {
            throw new IllegalStateException("WebSocket session is not open");
        }
        current.getAsyncRemote().sendText(message);
    }

    /** Registers the callback for incoming text messages, replacing any previous one. */
    public void addMessageHandler(MessageHandler msgHandler) {
        this.messageHandler = msgHandler;
    }

    /**
     * Closes the connection if it is open; calling it again afterwards does nothing.
     *
     * @throws java.io.IOException if closing the session fails
     */
    public void close() throws java.io.IOException {
        Session current = this.session;
        if (current != null && current.isOpen()) {
            current.close();
        }
    }

    /**
     * Callback for incoming text messages. Note: inside this class the name shadows
     * {@code javax.websocket.MessageHandler} from the wildcard import.
     */
    @FunctionalInterface
    public interface MessageHandler {
        void handleMessage(String message);
    }
}

客户端使用示例

/**
 * Demo: connects to the JSR 356 endpoint as {@code user123}, prints every received message,
 * sends one message and keeps the connection alive for ten seconds.
 */
public class ClientTest {
    public static void main(String[] args) {
        try {
            // Connect to the WebSocket server (URI fully qualified: this snippet has no imports).
            WebSocketClient client = new WebSocketClient(
                new java.net.URI("ws://localhost:8080/websocket/user123")
            );
            // Register the message handler. NOTE: a message that arrives before this line
            // (e.g. a welcome message sent on open) may not reach it.
            client.addMessageHandler(message -> {
                System.out.println("收到消息: " + message);
            });
            // Send a message
            client.sendMessage("Hello Server!");
            // Keep the connection open for a while so replies can arrive
            Thread.sleep(10000);
        } catch (InterruptedException e) {
            // Restore the interrupt status instead of swallowing it.
            Thread.currentThread().interrupt();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

基于Spring Boot实现(推荐)

Maven依赖

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>

WebSocket配置(其中 ServerEndpointExporter 仅在同时使用 @ServerEndpoint 注解端点时才需要)

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        registry.addHandler(myHandler(), "/myHandler")
                .addInterceptors(new HttpSessionHandshakeInterceptor())
                .setAllowedOrigins("*");
    }
    @Bean
    public WebSocketHandler myHandler() {
        return new MyHandler();
    }
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        return new ServerEndpointExporter();
    }
}

WebSocket处理器

import org.springframework.web.socket.*;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Spring text WebSocket handler: keeps every open session in a shared map and broadcasts each
 * received text message to all connected clients.
 */
public class MyHandler extends TextWebSocketHandler {
    // Open sessions keyed by getUserId(session); static, so shared by all handler instances.
    private static final Map<String, WebSocketSession> sessions =
        new ConcurrentHashMap<>();

    /** Registers the new session and greets it with the current number of connections. */
    @Override
    public void afterConnectionEstablished(WebSocketSession session)
            throws Exception {
        String userId = getUserId(session);
        sessions.put(userId, session);
        send(session, "连接成功! 当前在线人数: " + sessions.size());
    }

    /** Logs the received text and broadcasts it, prefixed with the sender, to every client. */
    @Override
    protected void handleTextMessage(WebSocketSession session,
            TextMessage message) throws Exception {
        String payload = message.getPayload();
        String userId = getUserId(session);
        System.out.println("收到用户 " + userId + " 的消息: " + payload);
        broadcast("用户 " + userId + " 说: " + payload);
    }

    /** Unregisters the session once it has closed. */
    @Override
    public void afterConnectionClosed(WebSocketSession session,
            CloseStatus status) throws Exception {
        String userId = getUserId(session);
        // Remove only if the key still maps to *this* session, so a newer connection stored
        // under the same key is not dropped.
        sessions.remove(userId, session);
        System.out.println("用户 " + userId + " 断开连接");
    }

    /** Reports the transport error and closes the session with SERVER_ERROR if still open. */
    @Override
    public void handleTransportError(WebSocketSession session,
            Throwable exception) throws Exception {
        exception.printStackTrace();
        if (session.isOpen()) {
            session.close(CloseStatus.SERVER_ERROR);
        }
    }

    // Sends the text to every open session; a failure on one session does not stop the rest.
    private void broadcast(String message) {
        for (WebSocketSession session : sessions.values()) {
            if (!session.isOpen()) {
                continue;
            }
            try {
                send(session, message);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Sends a text message to the session registered under {@code userId}; no-op if that
     * session is unknown or closed.
     */
    public void sendToUser(String userId, String message) {
        WebSocketSession session = sessions.get(userId);
        if (session != null && session.isOpen()) {
            try {
                send(session, message);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    // A standard WebSocketSession does not allow concurrent sends (broadcasts triggered by
    // different clients can overlap), so writes are serialized per session.
    private static void send(WebSocketSession session, String text) throws IOException {
        synchronized (session) {
            session.sendMessage(new TextMessage(text));
        }
    }

    // Key used to register a session. Currently the WebSocket session ID; in real code derive
    // it from an authenticated principal or a handshake attribute (e.g. a token).
    private String getUserId(WebSocketSession session) {
        return session.getId();
    }
}

前端JavaScript客户端

// Create the WebSocket connection to the Spring handler
let socket = new WebSocket("ws://localhost:8080/myHandler");

// Connection established: send a greeting
socket.onopen = function(event) {
    console.log("WebSocket连接成功");
    socket.send("Hello Server!");
};

// Incoming message: log it and append it to the page.
// The payload contains text typed by other users, so it is inserted as plain text
// (textContent) instead of HTML to prevent script injection (XSS).
socket.onmessage = function(event) {
    let message = event.data;
    console.log("收到消息: " + message);
    let paragraph = document.createElement('p');
    paragraph.textContent = message;
    document.getElementById('messages').appendChild(paragraph);
};

// Connection closed
socket.onclose = function(event) {
    console.log("WebSocket连接关闭");
};

// Send the content of the input box. WebSocket.send throws while the socket is still
// CONNECTING and silently discards data once it is CLOSING/CLOSED, so only send when OPEN.
function sendMessage() {
    let message = document.getElementById('messageInput').value;
    if (socket.readyState !== WebSocket.OPEN) {
        console.log("WebSocket未连接,消息未发送");
        return;
    }
    socket.send(message);
}

// Close the connection
function disconnect() {
    socket.close();
}

基于Netty实现(高性能)

Maven依赖

<dependency>
    <groupId>io.netty</groupId>
    <artifactId>netty-all</artifactId>
    <version>4.1.68.Final</version>
</dependency>

Netty WebSocket服务端

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.handler.stream.ChunkedWriteHandler;
/**
 * Standalone Netty server that accepts WebSocket connections on a configurable port.
 */
public class NettyWebSocketServer {
    private final int port;

    public NettyWebSocketServer(int port) {
        this.port = port;
    }

    /**
     * Binds the server and blocks until its listening channel is closed. Both event loop
     * groups are shut down on exit, including when binding fails.
     */
    public void start() throws Exception {
        EventLoopGroup acceptors = new NioEventLoopGroup(1); // accepts incoming connections
        EventLoopGroup workers = new NioEventLoopGroup();    // serves I/O of accepted channels
        try {
            ServerBootstrap bootstrap = new ServerBootstrap()
                .group(acceptors, workers)
                .channel(NioServerSocketChannel.class)
                .childHandler(pipelineInitializer());
            Channel serverChannel = bootstrap.bind(port).sync().channel();
            System.out.println("Netty WebSocket服务器启动在端口: " + port);
            serverChannel.closeFuture().sync();
        } finally {
            acceptors.shutdownGracefully();
            workers.shutdownGracefully();
        }
    }

    // Builds each accepted channel's pipeline. A new WebSocketServerHandler is created per
    // channel because it keeps per-connection handshake state.
    private static ChannelInitializer<SocketChannel> pipelineInitializer() {
        return new ChannelInitializer<SocketChannel>() {
            @Override
            protected void initChannel(SocketChannel ch) {
                ch.pipeline()
                  .addLast(new HttpServerCodec())           // HTTP request/response codec
                  .addLast(new HttpObjectAggregator(65536)) // aggregate into full HTTP messages
                  .addLast(new ChunkedWriteHandler())       // support for chunked/large writes
                  .addLast(new WebSocketServerHandler());   // handshake + frame handling
            }
        };
    }

    public static void main(String[] args) throws Exception {
        new NettyWebSocketServer(8080).start();
    }
}

Netty WebSocket处理器

import io.netty.channel.*;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.util.AttributeKey;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Per-connection Netty handler: completes the WebSocket handshake on the first HTTP request
 * and then broadcasts every text frame to all connected channels.
 *
 * <p>Must not be shared between channels: the handshaker is stored in an instance field.
 */
public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object> {
    // Path advertised in the handshake location URL.
    private static final String WEBSOCKET_PATH = "/websocket";
    // Host used for the location URL when the request carries no Host header.
    private static final String DEFAULT_HOST = "localhost:8080";
    // Channels that completed the handshake, keyed by long channel id; shared by all handlers.
    private static final Map<String, Channel> channels = new ConcurrentHashMap<>();
    // Handshaker of this connection; assigned when the upgrade request is accepted.
    private WebSocketServerHandshaker handshaker;

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        if (msg instanceof FullHttpRequest) {
            // HTTP request: perform the WebSocket handshake
            handleHttpRequest(ctx, (FullHttpRequest) msg);
        } else if (msg instanceof WebSocketFrame) {
            // WebSocket frame after a successful handshake
            handleWebSocketFrame(ctx, (WebSocketFrame) msg);
        }
    }

    private void handleHttpRequest(ChannelHandlerContext ctx,
            FullHttpRequest request) {
        // Reject anything that is not a well-formed upgrade request. RFC 6455 treats the
        // Upgrade value "websocket" case-insensitively, so compare ignoring case.
        String upgrade = request.headers().get(HttpHeaderNames.UPGRADE);
        if (!request.decoderResult().isSuccess()
                || !"websocket".equalsIgnoreCase(upgrade)) {
            sendHttpResponse(ctx, request,
                new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
                    HttpResponseStatus.BAD_REQUEST));
            return;
        }
        WebSocketServerHandshakerFactory factory =
            new WebSocketServerHandshakerFactory(webSocketLocation(request), null, false);
        handshaker = factory.newHandshaker(request);
        if (handshaker == null) {
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            return;
        }
        Channel channel = ctx.channel();
        // The handshake response is written asynchronously; register the channel for
        // broadcasts only once the handshake has actually succeeded.
        handshaker.handshake(channel, request).addListener((ChannelFutureListener) future -> {
            if (future.isSuccess()) {
                String channelId = channel.id().asLongText();
                channels.put(channelId, channel);
                System.out.println("新客户端连接: " + channelId);
            } else {
                future.cause().printStackTrace();
                channel.close();
            }
        });
    }

    // Builds the ws:// URL for this request from its Host header.
    private static String webSocketLocation(FullHttpRequest request) {
        String host = request.headers().get(HttpHeaderNames.HOST);
        return "ws://" + (host != null ? host : DEFAULT_HOST) + WEBSOCKET_PATH;
    }

    private void handleWebSocketFrame(ChannelHandlerContext ctx,
            WebSocketFrame frame) {
        // Close frame: complete the closing handshake. retain() because
        // SimpleChannelInboundHandler releases the frame after channelRead0 returns.
        if (frame instanceof CloseWebSocketFrame) {
            handshaker.close(ctx.channel(),
                (CloseWebSocketFrame) frame.retain());
            return;
        }
        // Ping frame: answer with a pong carrying the same payload. write() alone only queues
        // the frame, so it must be flushed to actually reach the client.
        if (frame instanceof PingWebSocketFrame) {
            ctx.channel().writeAndFlush(new PongWebSocketFrame(
                frame.content().retain()));
            return;
        }
        // Text frame: broadcast to every active channel. Each channel needs its own frame
        // instance because a written frame is released after sending.
        if (frame instanceof TextWebSocketFrame) {
            String request = ((TextWebSocketFrame) frame).text();
            System.out.println("收到消息: " + request);
            String reply = "服务器回复: " + request;
            for (Channel channel : channels.values()) {
                if (channel.isActive()) {
                    channel.writeAndFlush(new TextWebSocketFrame(reply));
                }
            }
        }
        // Other frame types (binary, pong, continuation) are ignored.
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        String channelId = ctx.channel().id().asLongText();
        channels.remove(channelId);
        System.out.println("客户端断开连接: " + channelId);
        // Propagate the event to the rest of the pipeline.
        super.channelInactive(ctx);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        cause.printStackTrace();
        ctx.close();
    }

    // Writes a plain HTTP response; closes the connection for non-200 responses or when the
    // client did not ask for keep-alive.
    private void sendHttpResponse(ChannelHandlerContext ctx,
            FullHttpRequest request, FullHttpResponse response) {
        boolean ok = response.status().code() == 200;
        // Announce the body length so the client can delimit the response.
        HttpUtil.setContentLength(response, response.content().readableBytes());
        if (!ok) {
            response.headers().set(HttpHeaderNames.CONNECTION,
                HttpHeaderValues.CLOSE);
        }
        ChannelFuture future = ctx.channel().writeAndFlush(response);
        if (!HttpUtil.isKeepAlive(request) || !ok) {
            future.addListener(ChannelFutureListener.CLOSE);
        }
    }
}

WebSocket应用场景

在线聊天室

// Group chat: sessions grouped into rooms by room ID
/**
 * In-memory chat rooms built on Spring {@code WebSocketSession}s. A room is created on the
 * first join and removed when its last member leaves.
 */
public class ChatRoom {
    // Room ID -> member sessions; static, so shared by every ChatRoom instance.
    private static final Map<String, Set<WebSocketSession>> rooms =
        new ConcurrentHashMap<>();

    /** Adds the session to the room, creating the room if needed. */
    public void joinRoom(String roomId, WebSocketSession session) {
        // compute() is atomic per key, so a concurrent leaveRoom cannot drop the room between
        // its creation and the add.
        rooms.compute(roomId, (id, members) -> {
            Set<WebSocketSession> room = members;
            if (room == null) {
                room = new CopyOnWriteArraySet<>();
            }
            room.add(session);
            return room;
        });
    }

    /** Removes the session from the room and drops the room once it is empty. */
    public void leaveRoom(String roomId, WebSocketSession session) {
        rooms.computeIfPresent(roomId, (id, members) -> {
            members.remove(session);
            // Returning null removes the mapping, so empty rooms do not accumulate.
            return members.isEmpty() ? null : members;
        });
    }

    /** Sends the text to every open session in the room; failures are logged per session. */
    public void sendToRoom(String roomId, String message) {
        Set<WebSocketSession> room = rooms.get(roomId);
        if (room == null) {
            return;
        }
        for (WebSocketSession session : room) {
            if (!session.isOpen()) {
                continue;
            }
            try {
                // A WebSocketSession must not be written to by two threads at once.
                synchronized (session) {
                    session.sendMessage(new TextMessage(message));
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }
}

实时数据推送(基于STOMP消息代理,需另外配置 @EnableWebSocketMessageBroker,并用 @EnableScheduling 开启定时任务)

/**
 * Periodically publishes a sample (timestamp plus random value) to subscribers of
 * {@code /topic/data}.
 *
 * <p>NOTE(review): relies on a STOMP broker configuration providing SimpMessagingTemplate and on
 * scheduling being enabled; neither is shown here.
 */
@Component
public class DataPushService {
    @Autowired
    private SimpMessagingTemplate messagingTemplate;

    /** Pushes one fresh sample; scheduled every 5000 ms. */
    @Scheduled(fixedRate = 5000)
    public void pushData() {
        messagingTemplate.convertAndSend("/topic/data", buildSample());
    }

    // Builds the payload: current epoch milliseconds and a random value in [0, 100).
    private static Map<String, Object> buildSample() {
        Map<String, Object> sample = new HashMap<>();
        sample.put("timestamp", System.currentTimeMillis());
        sample.put("value", Math.random() * 100);
        return sample;
    }
}

最佳实践

心跳机制

@OnMessage
public void onMessage(String message, Session session) {
    // Anything other than the heartbeat text is handed to the business logic.
    if (!message.equals("ping")) {
        handleBusinessMessage(message, session);
        return;
    }
    // Heartbeat: answer "ping" with "pong" synchronously; a failed reply is only logged.
    try {
        session.getBasicRemote().sendText("pong");
    } catch (IOException e) {
        e.printStackTrace();
    }
}

重连机制

// Auto-reconnecting client. After every close it retries with a linearly growing delay
// (3 s, 6 s, ... 15 s) and gives up after MAX_RETRY consecutive attempts; a successful
// connection resets the counter.
let ws = null;
let retryCount = 0;
const MAX_RETRY = 5;

function connect() {
    ws = new WebSocket("ws://localhost:8080/websocket");

    ws.onopen = () => {
        console.log("连接成功");
        retryCount = 0;
    };

    ws.onclose = () => {
        console.log("连接断开,尝试重连");
        if (retryCount >= MAX_RETRY) {
            return; // retry budget exhausted
        }
        retryCount += 1;
        setTimeout(connect, 3000 * retryCount);
    };

    ws.onerror = () => {
        console.log("连接错误");
    };
}

connect();

消息协议设计

// Message envelope exchanged as JSON
/**
 * JSON message envelope. Jackson binds private fields through their getters/setters, so the
 * accessors are written out explicitly.
 */
public class WebSocketMessage {
    private String type;      // message type, e.g. "chat", "notification", "ping"
    private String content;   // message body
    private String sender;    // sender
    private long timestamp;   // timestamp; NOTE(review): unit not specified, presumably epoch ms

    public String getType() {
        return type;
    }

    public void setType(String type) {
        this.type = type;
    }

    public String getContent() {
        return content;
    }

    public void setContent(String content) {
        this.content = content;
    }

    public String getSender() {
        return sender;
    }

    public void setSender(String sender) {
        this.sender = sender;
    }

    public long getTimestamp() {
        return timestamp;
    }

    public void setTimestamp(long timestamp) {
        this.timestamp = timestamp;
    }
}
// Use JSON as the wire format and dispatch on the message "type".
// ObjectMapper is thread-safe once configured and costly to create, so one instance is shared.
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

/**
 * Parses an incoming JSON text frame into a {@link WebSocketMessage} and dispatches it by type.
 * Malformed JSON and missing or unknown types are reported and otherwise ignored.
 */
@OnMessage
public void onMessage(String message, Session session) {
    WebSocketMessage wsMessage;
    try {
        wsMessage = OBJECT_MAPPER.readValue(message, WebSocketMessage.class);
    } catch (java.io.IOException e) {
        // readValue throws a checked exception on malformed input (JsonProcessingException
        // extends IOException); report it instead of failing the callback.
        System.out.println("无法解析消息: " + message);
        e.printStackTrace();
        return;
    }
    // readValue returns null for the JSON literal "null", and switching on a null String
    // throws NullPointerException, so both cases are rejected here.
    String type = (wsMessage == null) ? null : wsMessage.getType();
    if (type == null) {
        System.out.println("消息缺少type字段: " + message);
        return;
    }
    switch (type) {
        case "chat":
            handleChat(wsMessage);
            break;
        case "notification":
            handleNotification(wsMessage);
            break;
        case "ping":
            handlePing(session);
            break;
        default:
            System.out.println("未知消息类型: " + type);
            break;
    }
}
总结

  • Java EE标准:简单直接,适合小型项目
  • Spring Boot:功能完善,推荐用于实际项目
  • Netty:高性能,适合高并发场景
  • 关键点:心跳机制、重连机制、消息协议、安全性考虑(如校验Origin、对消息内容转义以防止XSS)

选择哪种实现方式取决于你的项目需求、技术栈和性能要求。

抱歉,评论功能暂时关闭!