如何用 Java 实现一致性哈希?

wen java案例 3

本文目录导读:

如何用 Java 实现一致性哈希?

  1. 基础一致性哈希实现
  2. 测试和演示代码
  3. 优化版实现(带监控和统计)
  4. 使用示例

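下面通过完整案例详细介绍如何在 Java 中实现一致性哈希。其基本思路是:把节点和 key 都映射到同一个 0 ~ 2³¹−1 的哈希环上,每个 key 沿顺时针方向找到的第一个(虚拟)节点就是它的归属节点;增删节点时,只有相邻区间内的 key 需要重新映射。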

基础一致性哈希实现

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.atomic.AtomicLong;
/**
 * Consistent-hash ring that maps arbitrary keys to one of a set of physical nodes.
 *
 * <p>Each physical node is placed on the ring {@code virtualNodeCount} times, as virtual nodes
 * named {@code node.toString() + "#" + i}, so that keys spread more evenly. A key belongs to the
 * first virtual node whose position is at or after the key's hash, wrapping around to the start
 * of the ring.
 *
 * <p>This class is not thread-safe; synchronize externally if an instance is shared.
 *
 * @param <T> physical node type; {@code toString()} seeds the ring positions and
 *     {@code equals}/{@code hashCode} identify the node
 */
public class ConsistentHash<T> {
    // Number of virtual nodes placed on the ring for each physical node
    private final int virtualNodeCount;
    // Hash ring: virtual-node position -> physical node, kept sorted for clockwise lookup
    private final NavigableMap<Integer, T> ring = new TreeMap<>();
    // Ring positions actually owned by each physical node (needed for removal)
    private final Map<T, List<Integer>> nodeVirtualNodes = new HashMap<>();

    /**
     * Creates a ring containing the given physical nodes.
     *
     * @param virtualNodeCount virtual nodes per physical node; must be positive
     * @param nodes initial physical nodes
     * @throws IllegalArgumentException if {@code virtualNodeCount <= 0}
     * @throws NullPointerException if {@code nodes} or one of its elements is null
     */
    public ConsistentHash(int virtualNodeCount, Collection<T> nodes) {
        if (virtualNodeCount <= 0) {
            throw new IllegalArgumentException("virtualNodeCount must be positive: " + virtualNodeCount);
        }
        this.virtualNodeCount = virtualNodeCount;
        for (T node : nodes) {
            // Private helper so the constructor never calls an overridable method
            placeOnRing(node);
        }
    }

    /**
     * Adds a physical node and its virtual nodes to the ring. Re-adding a node that is already
     * present leaves the ring unchanged.
     *
     * @param node node to add
     * @throws NullPointerException if {@code node} is null
     */
    public void addNode(T node) {
        placeOnRing(node);
    }

    // Places the virtual nodes of one physical node on the ring and records the slots it owns.
    private void placeOnRing(T node) {
        Objects.requireNonNull(node, "node");
        List<Integer> owned = new ArrayList<>(virtualNodeCount);
        for (int i = 0; i < virtualNodeCount; i++) {
            int hash = getHash(node.toString() + "#" + i);
            T owner = ring.putIfAbsent(hash, node);
            // On a hash collision the position stays with its first owner. Recording it for this
            // node too would let removeNode(node) later delete another node's virtual node.
            if (owner == null || owner.equals(node)) {
                owned.add(hash);
            }
        }
        nodeVirtualNodes.put(node, owned);
    }

    /**
     * Removes a physical node and all of its virtual nodes. Unknown nodes are ignored.
     *
     * @param node node to remove
     */
    public void removeNode(T node) {
        List<Integer> owned = nodeVirtualNodes.remove(node);
        if (owned == null) {
            return;
        }
        for (Integer hash : owned) {
            // Only remove the slot if it still maps to this node
            ring.remove(hash, node);
        }
    }

    /**
     * Returns the physical node responsible for {@code key}.
     *
     * @param key lookup key; its {@code toString()} is hashed
     * @return the owning node, or {@code null} if the ring is empty
     */
    public T getNode(Object key) {
        if (ring.isEmpty()) {
            return null;
        }
        // First virtual node at or after the key's position; wrap to the start of the ring if none
        Map.Entry<Integer, T> entry = ring.ceilingEntry(getHash(key.toString()));
        return (entry != null ? entry : ring.firstEntry()).getValue();
    }

    /**
     * Hashes {@code key} to a non-negative ring position.
     *
     * <p>The position is taken from the first four bytes of the key's MD5 digest (little-endian),
     * with the sign bit cleared.
     *
     * @param key string to hash (encoded as UTF-8)
     * @return a value in {@code [0, Integer.MAX_VALUE]}
     */
    public static int getHash(String key) {
        try {
            MessageDigest md5 = MessageDigest.getInstance("MD5");
            // Explicit charset: the platform default would make ring positions machine-dependent
            byte[] digest = md5.digest(key.getBytes(StandardCharsets.UTF_8));
            int hash = ((digest[3] & 0xFF) << 24)
                    | ((digest[2] & 0xFF) << 16)
                    | ((digest[1] & 0xFF) << 8)
                    | (digest[0] & 0xFF);
            return hash & 0x7FFFFFFF; // clear sign bit
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException("MD5 algorithm not found", e);
        }
    }

    /**
     * Returns a snapshot of the physical nodes currently on the ring.
     *
     * @return a new mutable set; changes to it do not affect the ring
     */
    public Set<T> getNodes() {
        return new HashSet<>(nodeVirtualNodes.keySet());
    }

    /**
     * Returns the number of physical nodes on the ring.
     *
     * @return physical node count
     */
    public int size() {
        return nodeVirtualNodes.size();
    }
}

测试和演示代码

/**
 * Console demo for {@code ConsistentHash}: basic lookups, the effect of adding and removing a
 * node, and how the virtual-node count affects load distribution.
 */
public class ConsistentHashDemo {
    // Virtual nodes per server used by the basic and add/remove tests
    private static final int VIRTUAL_NODES = 100;
    // Number of keys tracked in the add/remove test
    private static final int TEST_KEYS = 100;
    // Number of simulated requests in the load-balancing test
    private static final int LOAD_TEST_REQUESTS = 10000;

    public static void main(String[] args) {
        // 1. Basic lookups
        testBasicFunctionality();
        // 2. Adding / removing a node
        testNodeAdditionAndRemoval();
        // 3. Load distribution for different virtual-node counts
        testCacheLoadBalancing();
    }

    /** Prints which server each of a few sample cache keys maps to. */
    public static void testBasicFunctionality() {
        System.out.println("========== 基础功能测试 ==========");
        List<String> servers = Arrays.asList(
            "192.168.1.1:6379",
            "192.168.1.2:6379",
            "192.168.1.3:6379"
        );
        ConsistentHash<String> consistentHash = new ConsistentHash<>(VIRTUAL_NODES, servers);
        String[] keys = {"user:1001", "user:1002", "user:1003",
                         "order:2023001", "product:10001"};
        System.out.println("缓存服务器分布:");
        for (String key : keys) {
            String server = consistentHash.getNode(key);
            System.out.println(key + " -> " + server);
        }
    }

    /**
     * Adds a fourth server and counts how many keys move, then removes it again and checks that
     * every key returns to its original server.
     */
    public static void testNodeAdditionAndRemoval() {
        System.out.println("\n========== 节点增减测试 ==========");
        List<String> servers = new ArrayList<>(Arrays.asList(
            "Server-A", "Server-B", "Server-C"
        ));
        ConsistentHash<String> consistentHash = new ConsistentHash<>(VIRTUAL_NODES, servers);
        Map<String, String> initialMapping = new HashMap<>();
        for (int i = 1; i <= TEST_KEYS; i++) {
            String key = "key" + i;
            initialMapping.put(key, consistentHash.getNode(key));
        }

        System.out.println("添加新节点 Server-D:");
        consistentHash.addNode("Server-D");
        int changed = countChangedMappings(consistentHash, initialMapping);
        System.out.println("添加节点后," + TEST_KEYS + "个key中 " + changed
                + " 个映射发生变化(理想情况下约 25%)");

        // Removing the node again should send exactly the moved keys back to their old servers
        System.out.println("移除节点 Server-D:");
        consistentHash.removeNode("Server-D");
        int notRestored = countChangedMappings(consistentHash, initialMapping);
        System.out.println("移除节点后," + TEST_KEYS + "个key中 " + notRestored
                + " 个映射与初始映射不同(理想情况下为 0)");
    }

    /** Sends simulated requests through rings with different virtual-node counts and prints the share per server. */
    public static void testCacheLoadBalancing() {
        System.out.println("\n========== 负载均衡测试 ==========");
        int[] virtualNodeCounts = {1, 10, 100, 200};
        List<String> servers = Arrays.asList(
            "Server-A", "Server-B", "Server-C", "Server-D"
        );
        for (int virtualCount : virtualNodeCounts) {
            System.out.println("\n虚拟节点数量: " + virtualCount);
            ConsistentHash<String> ch = new ConsistentHash<>(virtualCount, servers);
            // LinkedHashMap keeps servers in declaration order so the output is stable
            Map<String, Integer> distribution = new LinkedHashMap<>();
            for (String server : servers) {
                distribution.put(server, 0);
            }
            for (int i = 0; i < LOAD_TEST_REQUESTS; i++) {
                distribution.merge(ch.getNode("test:key:" + i), 1, Integer::sum);
            }
            System.out.println("负载分布:");
            for (Map.Entry<String, Integer> entry : distribution.entrySet()) {
                // Derived from the request count so it stays correct if LOAD_TEST_REQUESTS changes
                double percentage = entry.getValue() * 100.0 / LOAD_TEST_REQUESTS;
                System.out.println(entry.getKey() + ": "
                        + entry.getValue() + " (" + percentage + "%)");
            }
        }
    }

    // Counts keys whose current owner differs from the owner recorded in the baseline mapping.
    private static int countChangedMappings(ConsistentHash<String> ring, Map<String, String> baseline) {
        int changed = 0;
        for (Map.Entry<String, String> entry : baseline.entrySet()) {
            if (!entry.getValue().equals(ring.getNode(entry.getKey()))) {
                changed++;
            }
        }
        return changed;
    }
}

优化版实现(带监控和统计)

public class OptimizedConsistentHash<T> {
    private final int virtualNodeCount;
    private final SortedMap<Integer, VirtualNode<T>> ring = new TreeMap<>();
    private final Map<T, List<VirtualNode<T>>> nodeVirtualNodes = new HashMap<>();
    private volatile boolean monitoring = false;
    // 监控数据
    private final Map<T, AtomicLong> requestCount = new ConcurrentHashMap<>();
    private final Map<T, AtomicLong> hitCount = new ConcurrentHashMap<>();
    public OptimizedConsistentHash(int virtualNodeCount, Collection<T> nodes) {
        this.virtualNodeCount = virtualNodeCount;
        for (T node : nodes) {
            addNode(node);
        }
        startMonitoring();
    }
    public void addNode(T node) {
        List<VirtualNode<T>> virtualNodes = new ArrayList<>();
        for (int i = 0; i < virtualNodeCount; i++) {
            VirtualNode<T> virtualNode = new VirtualNode<>(node, i);
            int hash = getHash(virtualNode.getKey());
            ring.put(hash, virtualNode);
            virtualNodes.add(virtualNode);
        }
        nodeVirtualNodes.put(node, virtualNodes);
        requestCount.put(node, new AtomicLong(0));
        hitCount.put(node, new AtomicLong(0));
    }
    public void removeNode(T node) {
        List<VirtualNode<T>> virtualNodes = nodeVirtualNodes.get(node);
        if (virtualNodes != null) {
            for (VirtualNode<T> virtualNode : virtualNodes) {
                ring.remove(getHash(virtualNode.getKey()));
            }
            nodeVirtualNodes.remove(node);
            requestCount.remove(node);
            hitCount.remove(node);
        }
    }
    public T getNode(Object key, boolean trackStats) {
        if (ring.isEmpty()) {
            return null;
        }
        int hash = getHash(key.toString());
        SortedMap<Integer, VirtualNode<T>> tailMap = ring.tailMap(hash);
        Integer nodeHash = tailMap.isEmpty() ? ring.firstKey() : tailMap.firstKey();
        VirtualNode<T> virtualNode = ring.get(nodeHash);
        T node = virtualNode.getPhysicalNode();
        if (trackStats) {
            requestCount.get(node).incrementAndGet();
        }
        return node;
    }
    // 记录缓存命中
    public void recordHit(T node) {
        hitCount.get(node).incrementAndGet();
    }
    // 获取命中率
    public Map<T, Double> getHitRates() {
        Map<T, Double> hitRates = new HashMap<>();
        for (T node : nodeVirtualNodes.keySet()) {
            long requests = requestCount.get(node).get();
            long hits = hitCount.get(node).get();
            double rate = requests > 0 ? (double) hits / requests : 0;
            hitRates.put(node, rate * 100);
        }
        return hitRates;
    }
    // 启动监控
    private void startMonitoring() {
        monitoring = true;
        new Thread(() -> {
            while (monitoring) {
                try {
                    Thread.sleep(60000); // 每分钟输出一次
                    printStats();
                } catch (InterruptedException e) {
                    break;
                }
            }
        }).start();
    }
    // 打印统计信息
    public void printStats() {
        System.out.println("\n=== 缓存节点统计 ===");
        for (T node : nodeVirtualNodes.keySet()) {
            long requests = requestCount.get(node).get();
            long hits = hitCount.get(node).get();
            double hitRate = requests > 0 ? (double) hits / requests * 100 : 0;
            System.out.printf("节点 %s: 请求=%d, 命中=%d, 命中率=%.2f%%\n", 
                            node, requests, hits, hitRate);
        }
    }
    public void stopMonitoring() {
        this.monitoring = false;
    }
    // 虚拟节点类
    private static class VirtualNode<T> {
        private final T physicalNode;
        private final int replicaIndex;
        public VirtualNode(T physicalNode, int replicaIndex) {
            this.physicalNode = physicalNode;
            this.replicaIndex = replicaIndex;
        }
        public T getPhysicalNode() {
            return physicalNode;
        }
        public String getKey() {
            return physicalNode.toString() + "#" + replicaIndex;
        }
    }
    // Hash计算
    public static int getHash(String key) {
        // 使用FNV-1a算法,性能更好
        int hash = 2166136261;
        for (int i = 0; i < key.length(); i++) {
            hash ^= key.charAt(i);
            hash *= 16777619;
        }
        return hash & 0x7FFFFFFF;
    }
}

使用示例

/**
 * Usage example: distributes 1000 cache keys over three nodes with
 * {@code OptimizedConsistentHash}, simulates reads, and prints per-node statistics.
 */
public class UsageExample {
    public static void main(String[] args) {
        // Cache nodes
        List<CacheNode> nodes = Arrays.asList(
            new CacheNode("127.0.0.1", 6379),
            new CacheNode("127.0.0.2", 6379),
            new CacheNode("127.0.0.3", 6379)
        );
        // Ring with 200 virtual nodes per physical node
        OptimizedConsistentHash<CacheNode> ch =
            new OptimizedConsistentHash<>(200, nodes);
        try {
            // Stand-in for the real stores; a real application would write to the chosen Redis node
            Map<String, String> cacheData = new HashMap<>();
            for (int i = 1; i <= 1000; i++) {
                String key = "cache:data:" + i;
                String value = "value-" + i;
                CacheNode node = ch.getNode(key, true);
                cacheData.put(key, value);
                System.out.println("Key: " + key + " -> 节点: " + node.getAddress());
            }
            // Simulated reads: only count a hit when the key really is cached
            for (int i = 1; i <= 500; i++) {
                String key = "cache:data:" + i;
                CacheNode node = ch.getNode(key, true);
                if (cacheData.containsKey(key)) {
                    ch.recordHit(node);
                }
            }
            ch.printStats();
        } finally {
            // Always stop the background reporter, even if the example fails part-way
            ch.stopMonitoring();
        }
    }
}
// 缓存节点类
class CacheNode {
    private final String host;
    private final int port;
    public CacheNode(String host, int port) {
        this.host = host;
        this.port = port;
    }
    public String getAddress() {
        return host + ":" + port;
    }
    @Override
    public String toString() {
        return getAddress();
    }
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        CacheNode cacheNode = (CacheNode) o;
        return port == cacheNode.port && Objects.equals(host, cacheNode.host);
    }
    @Override
    public int hashCode() {
        return Objects.hash(host, port);
    }
}

这个实现的主要特点:

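  1. 分布更均匀:每个物理节点在环上放置多个虚拟节点,虚拟节点越多,负载分布越均匀(可对比负载均衡测试中虚拟节点数为 1、10、100、200 时的结果)
  2. 最小干扰:节点增减时,只有相邻区间内的 key 需要重新映射,平均约为 1/N(N 为节点数;例如从 3 个节点增加到 4 个时约 25%)
  3. 高效查找:基于有序映射(TreeMap)查找,复杂度为 O(log n),其中 n 为环上虚拟节点的总数
  4. 可监控:优化版支持按节点统计请求数、命中数和命中率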

应用场景包括分布式缓存、负载均衡、数据库分片等。

抱歉,评论功能暂时关闭!