本文目录导读:

- 第一步:定义配置
- 第二步:核心路由引擎
- 第三步:核心拦截代理(关键)
- 第四步:代理Connection和Statement
- 第五步:使用示例
- 这个极简版的局限性(通往生产级的关键差距)
- 如果你想深入学习并构建可用版本
这是一个比较复杂的话题,实现一个生产级的分库分表中间件(类似ShardingSphere的核心功能)需要处理SQL解析、路由、改写、归并等复杂逻辑。
实现一个最简版本的中间件来理解其核心思想是可行的,它的核心功能是:
- 拦截SQL: 拦截应用发来的SQL。
- 解析分片键: 从SQL中提取出用于分片的关键字段(如
user_id)。 - 路由计算: 根据分片键的值,计算出该SQL应该发往哪个数据库(DS)、哪张表(T)。
- SQL改写: 将逻辑表名(如
t_order)改写为真实物理表名(如t_order_0)。 - 执行并归并(简化): 向目标数据源执行改写后的SQL,并(必要时)归并结果(如分组、排序)。
下面通过一个基于JDBC的DataSource代理 + 自定义注解的思路,实现一个极简版本。
第一步:定义配置
定义一个简单的配置,指定分库分表规则。
// 分片配置类
public class ShardingConfig {
// 逻辑表名 -> 物理库数
private int dbCount;
// 逻辑表名 -> 物理表数
private int tableCount;
// 分片键字段名
private String shardingColumn;
// 真实数据源映射,假设有 ds0, ds1
private Map<String, DataSource> dataSourceMap;
// 构造函数、getter/setter 省略...
}
第二步:核心路由引擎
这是计算数据应该去哪里的核心。
public class ShardingEngine {
private ShardingConfig config;
public ShardingEngine(ShardingConfig config) {
this.config = config;
}
/**
* 根据分片键的值计算目标数据源和表名
* @param shardKeyValue 分片键的值,userId = 3
* @return 目标数据库名和表名
*/
public ShardResult route(String shardKeyValue) {
// 1. 计算哈希(或取模)
int hash = shardKeyValue.hashCode();
// 防止负数
hash = Math.abs(hash);
// 2. 计算库索引
int dbIndex = hash % config.getDbCount();
// 3. 计算表索引
int tableIndex = hash % config.getTableCount();
// 4. 生成目标数据源名称
String targetDsName = "ds" + dbIndex;
// 5. 生成目标物理表名
String targetTableName = config.getLogicTableName() + "_" + tableIndex;
return new ShardResult(targetDsName, targetTableName);
}
}
// 结果封装
public class ShardResult {
private String targetDataSourceName;
private String targetTableName;
// 省略构造/Getter
}
第三步:核心拦截代理(关键)
这是最关键的实现部分,我们通过实现 javax.sql.DataSource 接口并重写 getConnection() 方法来拦截SQL。
(为了简化,这里只处理单条SQL,不考虑事务和连接池的多连接情况)
import javax.sql.DataSource;
import java.io.PrintWriter;
import java.sql.*;
import java.util.logging.Logger;
public class ShardingDataSource implements DataSource {
private ShardingEngine shardingEngine;
private Map<String, DataSource> realDataSources; // ds0 -> 真实连接池, ds1 -> 真实连接池
public ShardingDataSource(ShardingConfig config) {
this.shardingEngine = new ShardingEngine(config);
this.realDataSources = config.getDataSourceMap();
}
@Override
public Connection getConnection() throws SQLException {
// 关键点:返回一个代理 Connection
Connection connection = new ShardingConnection(shardingEngine, realDataSources);
return connection;
}
// 其他 DataSource 方法(getConnection(username, password), getLogWriter, setLogWriter, setLoginTimeout, getLoginTimeout, getParentLogger, unwrap, isWrapperFor)
// 全部直接委托给第一个真实数据源或抛出 UnsupportedOperationException
// (简化实现,省略大量代码)
}
第四步:代理Connection和Statement
这里是最核心的执行拦截点。
import java.sql.*;
import java.util.Map;
public class ShardingConnection implements Connection {
private ShardingEngine shardingEngine;
private Map<String, DataSource> realDataSources;
public ShardingConnection(ShardingEngine shardingEngine, Map<String, DataSource> realDataSources) {
this.shardingEngine = shardingEngine;
this.realDataSources = realDataSources;
}
@Override
public Statement createStatement() throws SQLException {
// 重点:返回一个代理 Statement
return new ShardingStatement(shardingEngine, realDataSources);
}
@Override
public PreparedStatement prepareStatement(String sql) throws SQLException {
// 这里要解析SQL,提取预编译参数,比较麻烦,极简版只处理 Statement。
throw new UnsupportedOperationException("PreparedStatement 暂不支持");
}
// ... 其他 Connection 方法(close, commit, rollback 等需要代理)
// 这里为了简化,直接抛出异常或调用第一个真实连接的对应方法。
}
import java.sql.*;
import java.util.Map;
public class ShardingStatement implements Statement {
private ShardingEngine shardingEngine;
private Map<String, DataSource> realDataSources;
public ShardingStatement(ShardingEngine shardingEngine, Map<String, DataSource> realDataSources) {
this.shardingEngine = shardingEngine;
this.realDataSources = realDataSources;
}
@Override
public ResultSet executeQuery(String sql) throws SQLException {
// 1. 解析SQL(极简:假设SQL是 "SELECT * FROM t_order WHERE user_id = 3")
// 这里需要做简单的字符串解析来提取分片键值。
// 生产环境会使用Druid解析器或JSQLParser。
// 假设我们有一个 SQLParser 工具提取出了 userId 的值是 3
int shardKeyValue = extractShardKey(sql); // 简化为从"user_id = ?"中提取
// 2. 路由
ShardResult result = shardingEngine.route(String.valueOf(shardKeyValue));
// 3. 改写SQL
String newSql = sql.replace("t_order", result.getTargetTableName());
// 4. 获取真实数据源并执行
DataSource realDs = realDataSources.get(result.getTargetDataSourceName());
try (Connection realConn = realDs.getConnection();
Statement realStmt = realConn.createStatement()) {
return realStmt.executeQuery(newSql);
}
// 注意:这里没有对结果集进行归并,如果要支持 Order By 或 Group By 跨库,需要更复杂的结果集合并。
}
@Override
public int executeUpdate(String sql) throws SQLException {
// 类似executeQuery的逻辑
int shardKeyValue = extractShardKey(sql);
ShardResult result = shardingEngine.route(String.valueOf(shardKeyValue));
String newSql = sql.replace("t_order", result.getTargetTableName());
DataSource realDs = realDataSources.get(result.getTargetDataSourceName());
try (Connection realConn = realDs.getConnection();
Statement realStmt = realConn.createStatement()) {
return realStmt.executeUpdate(newSql);
}
}
// 模拟SQL解析(极度简化)
private int extractShardKey(String sql) {
// 假设SQL格式为 "SELECT * FROM t_order WHERE user_id = 3"
// 或 "INSERT INTO t_order (user_id, ...) VALUES (3, ...)"
// 真实场景需用SQL解析器。
String lowerSql = sql.toLowerCase();
int idx = lowerSql.indexOf("user_id");
if (idx == -1) throw new RuntimeException("未找到分片键");
// 找等号后面的数字(极度简化,不处理复杂条件)
String afterUserId = sql.substring(idx + 8); // "user_id = "..."
afterUserId = afterUserId.trim();
if (!afterUserId.startsWith("=")) throw new RuntimeException("解析失败");
afterUserId = afterUserId.substring(1).trim();
// 找到第一个数字
StringBuilder num = new StringBuilder();
for (char c : afterUserId.toCharArray()) {
if (Character.isDigit(c)) {
num.append(c);
} else {
break;
}
}
return Integer.parseInt(num.toString());
}
// ... 其他 Statement 方法(close, getResultSet, getMoreResults, etc.)
// 同样需要简化代理。
}
第五步:使用示例
public class Main {
public static void main(String[] args) throws Exception {
// 1. 构建真实数据源 (假设已有 2 个数据库)
Map<String, DataSource> dsMap = new HashMap<>();
HikariConfig config0 = new HikariConfig();
config0.setJdbcUrl("jdbc:mysql://localhost:3306/ds0");
dsMap.put("ds0", new HikariDataSource(config0));
HikariConfig config1 = new HikariConfig();
config1.setJdbcUrl("jdbc:mysql://localhost:3306/ds1");
dsMap.put("ds1", new HikariDataSource(config1));
// 2. 配置分片规则
ShardingConfig shardingConfig = new ShardingConfig();
shardingConfig.setDataSourceMap(dsMap);
shardingConfig.setDbCount(2); // 2个库
shardingConfig.setTableCount(4); // 每个库4张表
shardingConfig.setShardingColumn("user_id");
// 3. 创建分片数据源
DataSource shardingDataSource = new ShardingDataSource(shardingConfig);
// 4. 使用JDBC操作
try (Connection conn = shardingDataSource.getConnection();
Statement stmt = conn.createStatement()) {
// 查询:根据user_id的值,被路由到 ds0.t_order_1 或 ds1.t_order_3 等
ResultSet rs = stmt.executeQuery("SELECT * FROM t_order WHERE user_id = 3");
while (rs.next()) {
System.out.println(rs.getLong("id") + " - " + rs.getString("name"));
}
// 插入
stmt.executeUpdate("INSERT INTO t_order (id, user_id, name) VALUES (1, 3, 'test')");
}
}
}
这个极简版的局限性(通往生产级的关键差距)
- SQL解析(致命问题): 上面的
extractShardKey方法太简单,几乎无法应对真实的SQL场景,生产级使用Druid Parser或JSqlParser进行AST(抽象语法树)解析和改写。 - 不支持PreparedStatement: 预编译语句无法直接在字符串层面替换SQL,需要解析参数列表并进行动态改写。
- 不支持范围查询/跨库查询:
WHERE user_id IN (1,2,3,4)或WHERE user_id between 1 and 100,需要同时路由到多个库、多张表,然后归并结果集,这里完全没处理。 - 不支持排序/分组/聚合: 跨库的
ORDER BY、GROUP BY、SUM/COUNT需要从多个库读取结果,在内存中重新排序或聚合。 - 不支持分页:
LIMIT 10 OFFSET 20在分库分表后需要改写为在每个分片上取足够多的数据,然后归并取最终结果。 - 事务问题: 不支持分布式事务(XA或TCC)。
- 连接管理: 对
Connection的代理非常粗糙,没有正确处理close()、setAutoCommit()等操作。 - 没有缓存: 每次路由都需要解析SQL,性能差。
如果你想深入学习并构建可用版本
- 使用成熟的SQL解析器: 引入
com.github.jsqlparser:jsqlparser或com.alibaba:druid,解析SQL的语法树,提取表名、分片键条件、字段列表等。 - 实现SQL改写器: 修改AST中的表名为物理表名,修改(或复制)SQL语句。
- 实现结果集归并: 实现一个多结果集的合并逻辑,支持
StreamResultSet的迭代。 - 实现分片算法接口: 让用户自定义分片算法(取模、一致性哈希、按时间范围等)。
- 支持XA/JTA: 使用
Atomikos或Narayana实现分布式事务。
推荐阅读源码: 在初步理解这个简版模型后,强烈建议去阅读 Apache ShardingSphere 的核心源码(特别是 sharding-core-route 和 sharding-core-execute 模块),这能让你真正理解工业级中间件的设计精髓。