本文目录导读:

下面详细讲解Java实现WebSocket通信的几种方式，并提供完整案例。
基于Java EE标准（JSR 356，javax.websocket）实现（使用注解）
Maven依赖（这里只引入API，scope为provided：服务端的实现由Servlet容器（如Tomcat）提供；若要独立运行下面的Java客户端，还需额外引入一个JSR 356实现，例如Tyrus）
<dependency>
<groupId>javax.websocket</groupId>
<artifactId>javax.websocket-api</artifactId>
<version>1.1</version>
<scope>provided</scope>
</dependency>
WebSocket服务端
import javax.websocket.*;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.concurrent.CopyOnWriteArraySet;
@ServerEndpoint("/websocket/{userId}")
public class WebSocketServer {
// 静态变量,用来记录当前在线连接数
private static int onlineCount = 0;
// concurrent包的线程安全Set,用来存放每个客户端对应的MyWebSocket对象
private static CopyOnWriteArraySet<WebSocketServer> webSocketSet =
new CopyOnWriteArraySet<>();
// 与某个客户端的连接会话,需要通过它来给客户端发送数据
private Session session;
// 用户ID
private String userId;
/**
* 连接建立成功调用的方法
*/
@OnOpen
public void onOpen(Session session, @PathParam("userId") String userId) {
this.session = session;
this.userId = userId;
webSocketSet.add(this);
addOnlineCount();
System.out.println("有新连接加入!当前在线人数为" + getOnlineCount());
// 发送欢迎消息
try {
sendMessage("欢迎用户 " + userId + " 连接到WebSocket服务器");
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* 连接关闭调用的方法
*/
@OnClose
public void onClose() {
webSocketSet.remove(this);
subOnlineCount();
System.out.println("有一连接关闭!当前在线人数为" + getOnlineCount());
}
/**
* 收到客户端消息后调用的方法
*
* @param message 客户端发送过来的消息
* @param session 会话
*/
@OnMessage
public void onMessage(String message, Session session) {
System.out.println("来自客户端" + userId + "的消息:" + message);
// 群发消息
for (WebSocketServer item : webSocketSet) {
try {
item.sendMessage("用户 " + userId + " 说: " + message);
} catch (IOException e) {
e.printStackTrace();
}
}
}
/**
* 发生错误时调用
*/
@OnError
public void onError(Session session, Throwable error) {
System.out.println("发生错误");
error.printStackTrace();
}
/**
* 发送消息
*/
public void sendMessage(String message) throws IOException {
this.session.getBasicRemote().sendText(message);
}
/**
* 向指定用户发送消息
*/
public static void sendInfo(String message, @PathParam("userId") String userId)
throws IOException {
for (WebSocketServer item : webSocketSet) {
if (item.userId.equals(userId)) {
item.sendMessage(message);
break;
}
}
}
public static synchronized int getOnlineCount() {
return onlineCount;
}
public static synchronized void addOnlineCount() {
WebSocketServer.onlineCount++;
}
public static synchronized void subOnlineCount() {
WebSocketServer.onlineCount--;
}
}
WebSocket客户端(Java)
import javax.websocket.*;
import java.net.URI;
/**
 * Minimal JSR-356 client endpoint. It connects in the constructor and forwards every
 * received text message to an optional application-level {@link MessageHandler}.
 */
@ClientEndpoint
public class WebSocketClient {
    // volatile: assigned in the constructor and in onOpen (container callback),
    // read by whichever thread calls sendMessage.
    private volatile Session session;
    // volatile: set by the application thread, read in onMessage on a container thread.
    private volatile MessageHandler messageHandler;

    /**
     * Connects this endpoint to the given server URI.
     *
     * @param endpointURI ws:// or wss:// URI of the server endpoint
     * @throws RuntimeException wrapping the container's exception if the connection fails
     */
    public WebSocketClient(URI endpointURI) {
        try {
            WebSocketContainer container = ContainerProvider.getWebSocketContainer();
            // connectToServer returns the session it opened; keep it directly rather
            // than relying solely on the onOpen callback having run.
            this.session = container.connectToServer(this, endpointURI);
        } catch (Exception e) {
            throw new RuntimeException("Failed to connect to " + endpointURI, e);
        }
    }

    @OnOpen
    public void onOpen(Session session) {
        this.session = session;
        System.out.println("Connected to server");
    }

    @OnClose
    public void onClose(Session session, CloseReason reason) {
        System.out.println("Disconnected from server");
    }

    /**
     * Forwards an incoming text message to the registered handler, if any.
     * Messages that arrive before a handler is registered are dropped.
     */
    @OnMessage
    public void onMessage(String message) {
        MessageHandler handler = this.messageHandler;
        if (handler != null) {
            handler.handleMessage(message);
        }
    }

    @OnError
    public void onError(Session session, Throwable error) {
        error.printStackTrace();
    }

    /**
     * Sends a text message asynchronously.
     *
     * @param message text to send
     * @throws IllegalStateException if there is no open session
     */
    public void sendMessage(String message) {
        Session current = this.session;
        if (current == null || !current.isOpen()) {
            throw new IllegalStateException("WebSocket session is not open");
        }
        current.getAsyncRemote().sendText(message);
    }

    /**
     * Registers the callback that receives incoming text messages (replaces any previous one).
     */
    public void addMessageHandler(MessageHandler msgHandler) {
        this.messageHandler = msgHandler;
    }

    /**
     * Application callback for received text messages. Inside this class the name
     * shadows {@code javax.websocket.MessageHandler}.
     */
    @FunctionalInterface
    public interface MessageHandler {
        void handleMessage(String message);
    }
}
客户端使用示例
/**
 * Demo: connects a {@code WebSocketClient}, sends one message and prints replies
 * while the main thread sleeps for 10 seconds.
 */
public class ClientTest {
    public static void main(String[] args) {
        try {
            // Connect to the WebSocket server. URI is fully qualified because this
            // snippet has no import for java.net.URI.
            WebSocketClient client = new WebSocketClient(
                    new java.net.URI("ws://localhost:8080/websocket/user123"));
            // NOTE: the handler is registered after the connection is already open, so a
            // message the server sends immediately on open may be missed.
            client.addMessageHandler(message -> System.out.println("收到消息: " + message));
            // Send a message.
            client.sendMessage("Hello Server!");
            // Keep the JVM alive for a while so replies can arrive.
            Thread.sleep(10000);
        } catch (InterruptedException e) {
            // Restore the interrupt flag instead of swallowing it.
            Thread.currentThread().interrupt();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
基于Spring Boot实现(推荐)
Maven依赖
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
WebSocket配置
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(myHandler(), "/myHandler")
.addInterceptors(new HttpSessionHandshakeInterceptor())
.setAllowedOrigins("*");
}
@Bean
public WebSocketHandler myHandler() {
return new MyHandler();
}
@Bean
public ServerEndpointExporter serverEndpointExporter() {
return new ServerEndpointExporter();
}
}
WebSocket处理器
import org.springframework.web.socket.*;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Spring text WebSocket handler that keeps a registry of open sessions and broadcasts
 * every received message to all of them.
 */
public class MyHandler extends TextWebSocketHandler {
    // All registered sessions keyed by user ID (see getUserId). Static, so shared by
    // every MyHandler instance.
    private static final Map<String, WebSocketSession> sessions =
            new ConcurrentHashMap<>();

    /**
     * Registers the new session and confirms the connection to the client.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session)
            throws Exception {
        String userId = getUserId(session);
        sessions.put(userId, session);
        send(session, "连接成功! 当前在线人数: " + sessions.size());
    }

    /**
     * Logs the message and broadcasts it to every connected client.
     */
    @Override
    protected void handleTextMessage(WebSocketSession session,
            TextMessage message) throws Exception {
        String payload = message.getPayload();
        String userId = getUserId(session);
        System.out.println("收到用户 " + userId + " 的消息: " + payload);
        broadcast("用户 " + userId + " 说: " + payload);
    }

    /**
     * Unregisters the session.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session,
            CloseStatus status) throws Exception {
        String userId = getUserId(session);
        // Remove only if the entry still points at this session, so a newer connection
        // registered under the same user ID is not dropped.
        sessions.remove(userId, session);
        System.out.println("用户 " + userId + " 断开连接");
    }

    /**
     * Reports the transport error and closes the session with SERVER_ERROR if still open.
     */
    @Override
    public void handleTransportError(WebSocketSession session,
            Throwable exception) throws Exception {
        exception.printStackTrace();
        if (session.isOpen()) {
            session.close(CloseStatus.SERVER_ERROR);
        }
    }

    /** Sends a message to every registered session; failures are logged per session. */
    private void broadcast(String message) {
        for (WebSocketSession session : sessions.values()) {
            try {
                send(session, message);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Sends a message to the session registered for the given user, if any.
     *
     * @param userId  target user ID; {@code null} is ignored
     * @param message text to send
     */
    public void sendToUser(String userId, String message) {
        if (userId == null) {
            return; // ConcurrentHashMap rejects null keys
        }
        WebSocketSession session = sessions.get(userId);
        if (session == null) {
            return;
        }
        try {
            send(session, message);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Sends one text message if the session is still open. WebSocketSession does not
     * support concurrent sends, so writes to the same session are serialized here.
     */
    private static void send(WebSocketSession session, String text) throws IOException {
        synchronized (session) {
            if (session.isOpen()) {
                session.sendMessage(new TextMessage(text));
            }
        }
    }

    /**
     * Returns the key under which a session is registered. For simplicity this is the
     * WebSocket session ID; a real application would derive it from the authenticated
     * principal or a handshake attribute (e.g. a token).
     */
    private String getUserId(WebSocketSession session) {
        return session.getId();
    }
}
前端JavaScript客户端（连接上面Spring Boot注册的 /myHandler 端点）
// Create the WebSocket connection to the Spring handler endpoint.
let socket = new WebSocket("ws://localhost:8080/myHandler");

// Connection established: send a greeting.
socket.onopen = function(event) {
    console.log("WebSocket连接成功");
    socket.send("Hello Server!");
};

// Incoming message: log it and append it to the #messages element.
socket.onmessage = function(event) {
    let message = event.data;
    console.log("收到消息: " + message);
    // The payload echoes text typed by other users, so insert it as plain text.
    // Concatenating it into innerHTML would allow script injection (XSS).
    let container = document.getElementById('messages');
    if (container) {
        let paragraph = document.createElement('p');
        paragraph.textContent = message;
        container.appendChild(paragraph);
    }
};

// Connection closed.
socket.onclose = function(event) {
    console.log("WebSocket连接关闭");
};

// Send the content of the #messageInput field.
function sendMessage() {
    let input = document.getElementById('messageInput');
    if (!input) {
        return;
    }
    // send() throws while the socket is still CONNECTING; only send when OPEN.
    if (socket.readyState !== WebSocket.OPEN) {
        console.warn("WebSocket未连接,消息未发送");
        return;
    }
    socket.send(input.value);
}

// Close the connection.
function disconnect() {
    socket.close();
}
基于Netty实现(高性能)
Maven依赖
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>4.1.68.Final</version>
</dependency>
Netty WebSocket服务端
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.handler.stream.ChunkedWriteHandler;
/**
 * Standalone Netty server that accepts HTTP connections, upgrades them to WebSocket
 * and hands every channel its own {@code WebSocketServerHandler}.
 */
public class NettyWebSocketServer {
    // TCP port the server binds to.
    private final int port;

    public NettyWebSocketServer(int port) {
        this.port = port;
    }

    /**
     * Binds the server, then blocks until the server channel is closed. Both event
     * loop groups are shut down on the way out.
     */
    public void start() throws Exception {
        // A single thread accepts connections; the default-sized group serves I/O.
        EventLoopGroup acceptorGroup = new NioEventLoopGroup(1);
        EventLoopGroup ioGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap serverBootstrap = new ServerBootstrap()
                    .group(acceptorGroup, ioGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new WebSocketChannelInitializer());
            ChannelFuture bound = serverBootstrap.bind(port).sync();
            System.out.println("Netty WebSocket服务器启动在端口: " + port);
            Channel serverChannel = bound.channel();
            serverChannel.closeFuture().sync();
        } finally {
            acceptorGroup.shutdownGracefully();
            ioGroup.shutdownGracefully();
        }
    }

    /**
     * Builds the per-connection pipeline: HTTP codec, request aggregation (max 64 KiB),
     * chunked-write support and finally the WebSocket handler.
     */
    private static final class WebSocketChannelInitializer
            extends ChannelInitializer<SocketChannel> {
        @Override
        protected void initChannel(SocketChannel ch) {
            ch.pipeline()
                    .addLast(new HttpServerCodec())
                    .addLast(new HttpObjectAggregator(65536))
                    .addLast(new ChunkedWriteHandler())
                    .addLast(new WebSocketServerHandler());
        }
    }

    public static void main(String[] args) throws Exception {
        NettyWebSocketServer server = new NettyWebSocketServer(8080);
        server.start();
    }
}
Netty WebSocket处理器
import io.netty.channel.*;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.util.AttributeKey;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Netty handler that performs the WebSocket handshake for an HTTP upgrade request and
 * then broadcasts every received text frame to all connected WebSocket channels.
 *
 * <p>Holds per-channel state (the handshaker), so each channel needs its own instance.
 */
public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object> {
    // Channels that completed the WebSocket handshake, keyed by channel ID. Static, so
    // every handler instance broadcasts to the same set.
    private static final Map<String, Channel> channels = new ConcurrentHashMap<>();
    // Handshaker of this handler's channel, created from the upgrade request.
    private WebSocketServerHandshaker handshaker;

    /**
     * Dispatches the initial HTTP upgrade request and the subsequent WebSocket frames.
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        if (msg instanceof FullHttpRequest) {
            // HTTP request: perform the WebSocket handshake.
            handleHttpRequest(ctx, (FullHttpRequest) msg);
        } else if (msg instanceof WebSocketFrame) {
            // WebSocket frame after a completed handshake.
            handleWebSocketFrame(ctx, (WebSocketFrame) msg);
        }
    }

    /**
     * Validates the upgrade request and performs the handshake. The channel is registered
     * for broadcasts only once the handshake response has been written successfully.
     */
    private void handleHttpRequest(ChannelHandlerContext ctx,
            FullHttpRequest request) {
        // Reject anything that is not a well-formed WebSocket upgrade. The Upgrade header
        // value is case-insensitive (RFC 6455), so e.g. "WebSocket" must be accepted too.
        if (!request.decoderResult().isSuccess()
                || !"websocket".equalsIgnoreCase(
                        request.headers().get(HttpHeaderNames.UPGRADE))) {
            sendHttpResponse(ctx, request,
                    new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
                            HttpResponseStatus.BAD_REQUEST));
            return;
        }
        WebSocketServerHandshakerFactory factory =
                new WebSocketServerHandshakerFactory(
                        getWebSocketLocation(request), null, false);
        handshaker = factory.newHandshaker(request);
        if (handshaker == null) {
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            return;
        }
        final Channel channel = ctx.channel();
        handshaker.handshake(channel, request).addListener((ChannelFutureListener) future -> {
            if (future.isSuccess() && channel.isActive()) {
                String channelId = channel.id().asLongText();
                channels.put(channelId, channel);
                System.out.println("新客户端连接: " + channelId);
            } else if (!future.isSuccess()) {
                channel.close();
            }
        });
    }

    /**
     * Builds the ws:// location of this request from its Host header and URI. Falls back
     * to the previously hard-coded URL when the request carries no Host header.
     */
    private static String getWebSocketLocation(FullHttpRequest request) {
        String host = request.headers().get(HttpHeaderNames.HOST);
        if (host == null || host.isEmpty()) {
            return "ws://localhost:8080/websocket";
        }
        return "ws://" + host + request.uri();
    }

    /**
     * Handles close, ping and text frames. Binary, continuation and pong frames are ignored.
     */
    private void handleWebSocketFrame(ChannelHandlerContext ctx,
            WebSocketFrame frame) {
        // Close frame: answer with a close frame. retain() because
        // SimpleChannelInboundHandler releases the frame after channelRead0.
        if (frame instanceof CloseWebSocketFrame) {
            handshaker.close(ctx.channel(),
                    (CloseWebSocketFrame) frame.retain());
            return;
        }
        // Ping frame: answer with a pong carrying the same payload. A plain write()
        // would leave the pong sitting in the outbound buffer, so flush it.
        if (frame instanceof PingWebSocketFrame) {
            ctx.channel().writeAndFlush(new PongWebSocketFrame(
                    frame.content().retain()));
            return;
        }
        // Text frame: broadcast to every active channel. Each channel gets its own
        // frame instance because a frame's buffer is released after being written.
        if (frame instanceof TextWebSocketFrame) {
            String request = ((TextWebSocketFrame) frame).text();
            System.out.println("收到消息: " + request);
            for (Channel channel : channels.values()) {
                if (channel.isActive()) {
                    channel.writeAndFlush(new TextWebSocketFrame(
                            "服务器回复: " + request));
                }
            }
        }
    }

    /** Unregisters the channel and passes the event on down the pipeline. */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        String channelId = ctx.channel().id().asLongText();
        channels.remove(channelId);
        System.out.println("客户端断开连接: " + channelId);
        super.channelInactive(ctx);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        cause.printStackTrace();
        ctx.close();
    }

    /**
     * Writes a plain HTTP response. The connection is closed afterwards when the status is
     * not 200 or the client did not ask for keep-alive.
     */
    private void sendHttpResponse(ChannelHandlerContext ctx,
            FullHttpRequest request, FullHttpResponse response) {
        boolean ok = response.status().code() == 200;
        if (!ok) {
            response.headers().set(HttpHeaderNames.CONNECTION,
                    HttpHeaderValues.CLOSE);
        }
        // Always advertise the body length so the client can frame the response.
        HttpUtil.setContentLength(response, response.content().readableBytes());
        ChannelFuture future = ctx.channel().writeAndFlush(response);
        if (!HttpUtil.isKeepAlive(request) || !ok) {
            future.addListener(ChannelFutureListener.CLOSE);
        }
    }
}
WebSocket应用场景
在线聊天室
// Group chat: sessions grouped into rooms by room ID.
public class ChatRoom {
    // Room ID -> sessions currently in that room. A room's entry is removed once its
    // last session leaves, so abandoned rooms do not accumulate.
    private static final Map<String, Set<WebSocketSession>> rooms =
            new ConcurrentHashMap<>();

    /**
     * Adds the session to the room, creating the room on first use.
     */
    public void joinRoom(String roomId, WebSocketSession session) {
        // Add inside compute() so the insertion cannot race with leaveRoom removing
        // the (momentarily empty) set of the same room.
        rooms.compute(roomId, (id, members) -> {
            Set<WebSocketSession> target = members;
            if (target == null) {
                target = new CopyOnWriteArraySet<>();
            }
            target.add(session);
            return target;
        });
    }

    /**
     * Removes the session from the room and drops the room once it is empty.
     */
    public void leaveRoom(String roomId, WebSocketSession session) {
        rooms.computeIfPresent(roomId, (id, members) -> {
            members.remove(session);
            // Returning null removes the mapping.
            return members.isEmpty() ? null : members;
        });
    }

    /**
     * Sends a text message to every open session in the room; failures are logged per session.
     */
    public void sendToRoom(String roomId, String message) {
        Set<WebSocketSession> room = rooms.get(roomId);
        if (room == null) {
            return;
        }
        for (WebSocketSession session : room) {
            if (!session.isOpen()) {
                continue;
            }
            try {
                // WebSocketSession does not support concurrent sends; serialize per session.
                synchronized (session) {
                    session.sendMessage(new TextMessage(message));
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }
}
实时数据推送（基于STOMP：需要用 @EnableWebSocketMessageBroker 配置消息代理，并用 @EnableScheduling 启用定时任务）
/**
 * Publishes a small data sample to the STOMP destination {@code /topic/data} every 5 seconds.
 */
@Component
public class DataPushService {
    // Spring's template for sending messages to broker destinations.
    @Autowired
    private SimpMessagingTemplate messagingTemplate;

    /**
     * Builds one sample (current epoch millis plus a random value in [0, 100)) and
     * publishes it.
     */
    @Scheduled(fixedRate = 5000)
    public void pushData() {
        long now = System.currentTimeMillis();
        double reading = Math.random() * 100;
        Map<String, Object> payload = new HashMap<>();
        payload.put("timestamp", now);
        payload.put("value", reading);
        messagingTemplate.convertAndSend("/topic/data", payload);
    }
}
最佳实践
心跳机制
/**
 * Answers the text "ping" with "pong" and hands any other text to the business
 * message handler.
 */
@OnMessage
public void onMessage(String message, Session session) {
    boolean heartbeat = message.equals("ping");
    if (!heartbeat) {
        // Regular application traffic.
        handleBusinessMessage(message, session);
        return;
    }
    // Heartbeat: reply so the client knows the connection is still alive.
    try {
        session.getBasicRemote().sendText("pong");
    } catch (IOException e) {
        e.printStackTrace();
    }
}
重连机制
let ws = null;
let retryCount = 0;
const MAX_RETRY = 5;
// Delay before the n-th reconnect attempt is n * RETRY_BASE_DELAY_MS (linear backoff).
const RETRY_BASE_DELAY_MS = 3000;

// Opens the connection and installs handlers that reconnect on close, up to MAX_RETRY
// consecutive attempts. The counter is reset after every successful open.
function connect() {
    ws = new WebSocket("ws://localhost:8080/websocket");
    ws.onopen = function(event) {
        console.log("连接成功");
        retryCount = 0;
    };
    ws.onclose = function(event) {
        if (retryCount < MAX_RETRY) {
            console.log("连接断开,尝试重连");
            retryCount++;
            setTimeout(connect, RETRY_BASE_DELAY_MS * retryCount);
        } else {
            // Make giving up visible instead of silently stopping.
            console.log("连接断开,已达到最大重连次数,停止重连");
        }
    };
    ws.onerror = function(event) {
        console.log("连接错误");
    };
}
connect();
消息协议设计
// Message envelope exchanged as JSON between client and server.
/**
 * JSON message envelope. It is a mutable bean with a no-arg constructor and getters/setters,
 * which is what Jackson's default data binding needs to read and write the fields.
 */
public class WebSocketMessage {
    private String type;      // message type, e.g. "chat", "notification", "ping"
    private String content;   // message body
    private String sender;    // sender identifier
    private long timestamp;   // timestamp; unit defined by the sender (TODO confirm, e.g. epoch millis)

    public String getType() {
        return type;
    }

    public void setType(String type) {
        this.type = type;
    }

    public String getContent() {
        return content;
    }

    public void setContent(String content) {
        this.content = content;
    }

    public String getSender() {
        return sender;
    }

    public void setSender(String sender) {
        this.sender = sender;
    }

    public long getTimestamp() {
        return timestamp;
    }

    public void setTimestamp(long timestamp) {
        this.timestamp = timestamp;
    }
}
// Messages are transported as JSON and dispatched on their "type" field.
// ObjectMapper is costly to create and safe to share once configured, so reuse one instance.
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

/**
 * Parses an incoming JSON message and dispatches it by type. Malformed JSON, a
 * {@code null} message or a missing type are logged or ignored instead of being thrown.
 */
@OnMessage
public void onMessage(String message, Session session) {
    WebSocketMessage wsMessage;
    try {
        // readValue declares checked Jackson exceptions (subclasses of IOException).
        wsMessage = OBJECT_MAPPER.readValue(message, WebSocketMessage.class);
    } catch (java.io.IOException e) {
        System.err.println("消息解析失败: " + e.getMessage());
        return;
    }
    // The JSON literal "null" yields null; a missing type would NPE in the String switch.
    if (wsMessage == null || wsMessage.getType() == null) {
        return;
    }
    switch (wsMessage.getType()) {
        case "chat":
            handleChat(wsMessage);
            break;
        case "notification":
            handleNotification(wsMessage);
            break;
        case "ping":
            handlePing(session);
            break;
        default:
            System.err.println("未知的消息类型: " + wsMessage.getType());
            break;
    }
}
总结：
- Java EE标准：简单直接，适合小型项目
- Spring Boot：功能完善，推荐用于实际项目
- Netty：高性能，适合高并发场景
- 关键点：心跳机制、重连机制、消息协议、安全性考虑（如校验Origin、对消息内容做转义）
选择哪种实现方式取决于你的项目需求、技术栈和性能要求。