如何用Java案例实现数据抽样?

wen java案例 3

本文目录导读:

如何用Java案例实现数据抽样?

  1. 基础数据结构
  2. 简单随机抽样
  3. 等距抽样(系统抽样)
  4. 分层抽样
  5. 完整示例和测试
  6. 高级:蓄水池抽样

下面介绍几种用Java实现数据抽样的方法,包括简单随机抽样、等距抽样(系统抽样)、分层抽样,以及适用于数据流的蓄水池抽样。

基础数据结构

import java.util.*;
import java.util.stream.Collectors;
public class DataSamplingDemo {
    /**
     * One immutable observation of the demo data set: a numeric id, a category
     * label (used as the stratum key by the stratified examples) and a value.
     */
    static class DataItem {
        // Package-private rather than private: other top-level classes in the
        // same package read these fields directly (e.g. item.category), which a
        // private field of a nested class does not allow. They are final, so
        // exposing them cannot lead to outside mutation.
        final int id;
        final String category;
        final double value;

        public DataItem(int id, String category, double value) {
            this.id = id;
            this.category = category;
            this.value = value;
        }

        /** @return the identifier passed to the constructor */
        public int getId() {
            return id;
        }

        /** @return the category label passed to the constructor */
        public String getCategory() {
            return category;
        }

        /** @return the numeric value passed to the constructor */
        public double getValue() {
            return value;
        }

        @Override
        public String toString() {
            return String.format("DataItem{id=%d, category='%s', value=%.2f}",
                               id, category, value);
        }
    }

    /**
     * Generates a reproducible list of test items.
     *
     * <p>Item {@code i} gets id {@code i}, a category drawn from "A", "B", "C"
     * and a value in {@code [0, 100)}. The generator is seeded with a fixed
     * value (42), so repeated calls with the same {@code size} return equal data.
     *
     * <p>Package-private (not private) so that example classes in the same
     * package can call it.
     *
     * @param size number of items to generate; values {@code <= 0} yield an empty list
     * @return a new mutable list containing {@code size} items
     */
    static List<DataItem> generateData(int size) {
        List<DataItem> data = new ArrayList<>(Math.max(size, 0));
        String[] categories = {"A", "B", "C"};
        Random random = new Random(42);
        for (int i = 0; i < size; i++) {
            data.add(new DataItem(
                i,
                categories[random.nextInt(categories.length)],
                random.nextDouble() * 100
            ));
        }
        return data;
    }
}

简单随机抽样

// 简单随机抽样实现
public class SimpleRandomSampling {
    /**
     * 简单随机抽样(不放回)
     * @param population 总体数据
     * @param sampleSize 样本大小
     * @return 抽样结果
     */
    public static <T> List<T> sampleWithoutReplacement(List<T> population, int sampleSize) {
        if (sampleSize > population.size()) {
            throw new IllegalArgumentException("样本大小不能超过总体大小");
        }
        List<T> copy = new ArrayList<>(population);
        List<T> sample = new ArrayList<>();
        Random random = new Random();
        for (int i = 0; i < sampleSize; i++) {
            int index = random.nextInt(copy.size());
            sample.add(copy.remove(index));
        }
        return sample;
    }
    /**
     * 简单随机抽样(放回)
     * @param population 总体数据
     * @param sampleSize 样本大小
     * @return 抽样结果
     */
    public static <T> List<T> sampleWithReplacement(List<T> population, int sampleSize) {
        List<T> sample = new ArrayList<>();
        Random random = new Random();
        for (int i = 0; i < sampleSize; i++) {
            int index = random.nextInt(population.size());
            sample.add(population.get(index));
        }
        return sample;
    }
    /**
     * 使用Fisher-Yates洗牌算法进行抽样
     */
    public static <T> List<T> fisherYatesSample(List<T> population, int sampleSize) {
        if (sampleSize > population.size()) {
            throw new IllegalArgumentException("样本大小不能超过总体大小");
        }
        List<T> result = new ArrayList<>(population);
        Random random = new Random();
        for (int i = result.size() - 1; i >= result.size() - sampleSize; i--) {
            int j = random.nextInt(i + 1);
            // 交换元素
            T temp = result.get(i);
            result.set(i, result.get(j));
            result.set(j, temp);
        }
        return result.subList(result.size() - sampleSize, result.size());
    }
}

等距抽样(系统抽样)

public class SystematicSampling {
    /**
     * Linear systematic sampling.
     *
     * <p>Uses a fractional interval of {@code N / n} evaluated in exact integer
     * arithmetic: a random offset {@code r} in {@code [0, N)} is drawn and the
     * i-th selected index is {@code floor((r + i * N) / n)}. When {@code N} is
     * a multiple of {@code n} this is the classic "random start, fixed step"
     * scheme. When it is not, the gaps vary by at most one, so elements at the
     * tail of the list remain selectable (a truncated integer step would never
     * reach them).
     *
     * @param population the full data set
     * @param sampleSize number of elements to select
     * @return a new list of {@code sampleSize} elements in population order
     * @throws IllegalArgumentException if {@code sampleSize} is negative or
     *         larger than {@code population.size()}
     */
    public static <T> List<T> systematicSample(List<T> population, int sampleSize) {
        validateSampleSize(population, sampleSize);
        List<T> sample = new ArrayList<>(sampleSize);
        if (sampleSize == 0) {
            // Nothing to select; also avoids dividing by zero below.
            return sample;
        }
        long total = population.size();
        long offset = new Random().nextInt(population.size());
        for (int i = 0; i < sampleSize; i++) {
            // long arithmetic: i * total can exceed the int range for big lists.
            int index = (int) ((offset + i * total) / sampleSize);
            sample.add(population.get(index));
        }
        return sample;
    }

    /**
     * Circular systematic sampling: picks a random start anywhere in the list
     * and then steps by {@code floor(N / n)}, wrapping around the end.
     *
     * @param population the full data set
     * @param sampleSize number of elements to select
     * @return a new list of {@code sampleSize} elements in selection order
     * @throws IllegalArgumentException if {@code sampleSize} is negative or
     *         larger than {@code population.size()}
     */
    public static <T> List<T> circularSystematicSample(List<T> population, int sampleSize) {
        validateSampleSize(population, sampleSize);
        List<T> sample = new ArrayList<>(sampleSize);
        if (sampleSize == 0) {
            // Nothing to select; also avoids dividing by zero below.
            return sample;
        }
        int total = population.size();
        int interval = total / sampleSize;
        int start = new Random().nextInt(total);
        for (int i = 0; i < sampleSize; i++) {
            // i * interval < total, so the selected positions never repeat.
            int index = (int) (((long) start + (long) i * interval) % total);
            sample.add(population.get(index));
        }
        return sample;
    }

    /** Shared argument check for both sampling methods. */
    private static void validateSampleSize(List<?> population, int sampleSize) {
        if (sampleSize < 0) {
            throw new IllegalArgumentException("样本大小不能为负数");
        }
        if (sampleSize > population.size()) {
            throw new IllegalArgumentException("样本大小不能超过总体大小");
        }
    }
}

分层抽样

public class StratifiedSampling {
    /**
     * Stratified sampling with proportional allocation.
     *
     * <p>The population is grouped by {@code categoryExtractor}; each stratum
     * receives a share of the sample proportional to its size. Shares are
     * converted to whole numbers with the largest-remainder method, so the
     * per-stratum counts add up to {@code min(sampleSize, population.size())}
     * instead of drifting because of independent rounding.
     *
     * @param population the full data set
     * @param sampleSize total number of elements to draw across all strata
     * @param categoryExtractor maps an element to its stratum key
     * @return a map from stratum key to the elements drawn (without
     *         replacement) from that stratum; every stratum has an entry
     * @throws IllegalArgumentException if {@code sampleSize} is negative
     */
    public static <T> Map<String, List<T>> stratifiedSample(
            List<T> population,
            int sampleSize,
            java.util.function.Function<T, String> categoryExtractor) {
        requireNonNegative(sampleSize);
        Map<String, List<T>> groups = population.stream()
            .collect(Collectors.groupingBy(categoryExtractor));
        Map<String, Double> shares = new HashMap<>();
        int totalSize = population.size();
        for (Map.Entry<String, List<T>> entry : groups.entrySet()) {
            shares.put(entry.getKey(), (double) entry.getValue().size() / totalSize);
        }
        return drawFromStrata(groups, allocate(groups, shares, sampleSize));
    }

    /**
     * Stratified sampling with optimal (Neyman) allocation: a stratum's share
     * is proportional to {@code weight * standardDeviation}, where the weight
     * is the stratum's fraction of the population and the (population)
     * standard deviation is computed from {@code valueExtractor}.
     *
     * <p>If the total weighted standard deviation is zero or not finite (for
     * example when every stratum holds identical values), the shares would be
     * undefined, so the method falls back to proportional allocation. Counts
     * are capped at the stratum size; samples a capped stratum cannot supply
     * are handed to strata that still have unused elements.
     *
     * @param population the full data set
     * @param sampleSize total number of elements to draw across all strata
     * @param categoryExtractor maps an element to its stratum key
     * @param valueExtractor maps an element to the value whose spread drives the allocation
     * @return a map from stratum key to the elements drawn (without
     *         replacement) from that stratum; every stratum has an entry
     * @throws IllegalArgumentException if {@code sampleSize} is negative
     */
    public static <T> Map<String, List<T>> optimalStratifiedSample(
            List<T> population,
            int sampleSize,
            java.util.function.Function<T, String> categoryExtractor,
            java.util.function.ToDoubleFunction<T> valueExtractor) {
        requireNonNegative(sampleSize);
        Map<String, List<T>> groups = population.stream()
            .collect(Collectors.groupingBy(categoryExtractor));

        Map<String, Double> proportional = new HashMap<>();
        Map<String, Double> weightedStds = new HashMap<>();
        double totalWeightedStd = 0.0;
        int totalSize = population.size();
        for (Map.Entry<String, List<T>> entry : groups.entrySet()) {
            List<T> stratum = entry.getValue();
            double weight = (double) stratum.size() / totalSize;
            double mean = stratum.stream()
                .mapToDouble(valueExtractor)
                .average()
                .orElse(0);
            double variance = stratum.stream()
                .mapToDouble(item -> Math.pow(valueExtractor.applyAsDouble(item) - mean, 2))
                .average()
                .orElse(0);
            double weightedStd = weight * Math.sqrt(variance);
            proportional.put(entry.getKey(), weight);
            weightedStds.put(entry.getKey(), weightedStd);
            totalWeightedStd += weightedStd;
        }

        Map<String, Double> shares;
        if (Double.isFinite(totalWeightedStd) && totalWeightedStd > 0) {
            shares = new HashMap<>();
            for (Map.Entry<String, Double> entry : weightedStds.entrySet()) {
                shares.put(entry.getKey(), entry.getValue() / totalWeightedStd);
            }
        } else {
            // 0/0 would give NaN shares and silently empty samples.
            shares = proportional;
        }
        return drawFromStrata(groups, allocate(groups, shares, sampleSize));
    }

    /** Rejects negative sample sizes. */
    private static void requireNonNegative(int sampleSize) {
        if (sampleSize < 0) {
            throw new IllegalArgumentException("样本大小不能为负数");
        }
    }

    /**
     * Turns fractional shares into whole per-stratum counts (largest-remainder
     * method), never exceeding a stratum's size.
     */
    private static <T> Map<String, Integer> allocate(
            Map<String, List<T>> groups,
            Map<String, Double> shares,
            int sampleSize) {
        Map<String, Integer> counts = new HashMap<>();
        Map<String, Double> remainders = new HashMap<>();
        List<String> keys = new ArrayList<>(groups.keySet());
        long capacity = 0;
        long assigned = 0;
        for (String key : keys) {
            int stratumSize = groups.get(key).size();
            double quota = shares.get(key) * sampleSize;
            double whole = Math.floor(quota);
            int base = (int) Math.min(whole, stratumSize);
            counts.put(key, base);
            remainders.put(key, quota - whole);
            capacity += stratumSize;
            assigned += base;
        }
        long target = Math.min(sampleSize, capacity);

        // Largest fractional part first; ties broken by key for determinism.
        keys.sort((a, b) -> {
            int byRemainder = Double.compare(remainders.get(b), remainders.get(a));
            return byRemainder != 0 ? byRemainder : a.compareTo(b);
        });

        // target <= capacity, so while anything is still owed at least one
        // stratum has spare elements and every pass makes progress.
        while (assigned < target) {
            for (String key : keys) {
                if (assigned >= target) {
                    break;
                }
                if (counts.get(key) < groups.get(key).size()) {
                    counts.merge(key, 1, Integer::sum);
                    assigned++;
                }
            }
        }
        return counts;
    }

    /** Draws the allocated number of elements from every stratum. */
    private static <T> Map<String, List<T>> drawFromStrata(
            Map<String, List<T>> groups,
            Map<String, Integer> counts) {
        Map<String, List<T>> sample = new HashMap<>();
        Random random = new Random();
        for (Map.Entry<String, List<T>> entry : groups.entrySet()) {
            sample.put(
                entry.getKey(),
                drawWithoutReplacement(entry.getValue(), counts.get(entry.getKey()), random));
        }
        return sample;
    }

    /** Draws up to {@code count} elements from a copy of {@code stratum} without replacement. */
    private static <T> List<T> drawWithoutReplacement(List<T> stratum, int count, Random random) {
        List<T> pool = new ArrayList<>(stratum);
        // Fix the number of draws up front: the pool shrinks with every draw,
        // so a bound that re-reads pool.size() would stop the loop too early.
        int draws = Math.min(Math.max(count, 0), pool.size());
        List<T> picked = new ArrayList<>(draws);
        for (int i = 0; i < draws; i++) {
            picked.add(pool.remove(random.nextInt(pool.size())));
        }
        return picked;
    }
}

完整示例和测试

public class SamplingExample {
    public static void main(String[] args) {
        // 生成测试数据
        List<DataItem> population = DataSamplingDemo.generateData(1000);
        System.out.println("总体大小: " + population.size());
        // 1. 简单随机抽样(不放回)
        System.out.println("\n=== 简单随机抽样(不放回)===");
        List<DataItem> simpleSample = SimpleRandomSampling.sampleWithoutReplacement(population, 100);
        System.out.println("样本大小: " + simpleSample.size());
        simpleSample.subList(0, 5).forEach(System.out::println);
        // 2. 简单随机抽样(放回)
        System.out.println("\n=== 简单随机抽样(放回)===");
        List<DataItem> sampleWithReplace = SimpleRandomSampling.sampleWithReplacement(population, 100);
        System.out.println("样本大小: " + sampleWithReplace.size());
        // 3. 等距抽样
        System.out.println("\n=== 等距抽样 ===");
        List<DataItem> systematicSample = SystematicSampling.systematicSample(population, 100);
        System.out.println("样本大小: " + systematicSample.size());
        // 4. 分层抽样(等比例)
        System.out.println("\n=== 分层抽样(等比例)===");
        Map<String, List<DataItem>> stratifiedSample = 
            StratifiedSampling.stratifiedSample(population, 100, item -> item.category);
        stratifiedSample.forEach((category, sample) -> {
            System.out.printf("类别 %s: 样本大小 %d%n", category, sample.size());
        });
        // 5. 最优分配分层抽样
        System.out.println("\n=== 最优分配分层抽样 ===");
        Map<String, List<DataItem>> optimalSample = 
            StratifiedSampling.optimalStratifiedSample(
                population, 100, 
                item -> item.category,
                item -> item.value
            );
        optimalSample.forEach((category, sample) -> {
            System.out.printf("类别 %s: 样本大小 %d%n", category, sample.size());
        });
        // 性能对比
        System.out.println("\n=== 性能对比 ===");
        long startTime = System.nanoTime();
        SimpleRandomSampling.fisherYatesSample(population, 100);
        long fisherYatesTime = System.nanoTime() - startTime;
        startTime = System.nanoTime();
        SimpleRandomSampling.sampleWithoutReplacement(population, 100);
        long removeTime = System.nanoTime() - startTime;
        System.out.println("Fisher-Yates洗牌算法: " + fisherYatesTime / 1_000_000 + " ms");
        System.out.println("remove方法: " + removeTime / 1_000_000 + " ms");
    }
}

高级:蓄水池抽样

public class ReservoirSampling {
    /**
     * Reservoir sampling (Algorithm R) for streams whose length is not known
     * in advance.
     *
     * <p>The first {@code k} elements fill the reservoir; the i-th element
     * after that (1-based over the whole stream) replaces a uniformly chosen
     * slot with probability {@code k / i}. When the stream holds fewer than
     * {@code k} elements, all of them are returned in stream order.
     *
     * @param stream source of elements; consumed until exhausted unless {@code k} is 0
     * @param k maximum number of elements to keep
     * @return a new list with at most {@code k} elements
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public static <T> List<T> reservoirSample(Iterator<T> stream, int k) {
        requireNonNegative(k);
        List<T> reservoir = new ArrayList<>(k);
        if (k == 0) {
            // Nothing can be kept, so there is no reason to drain the stream.
            return reservoir;
        }
        Random random = new Random();
        // long, not int: an int counter would wrap to a negative bound once
        // the stream exceeds Integer.MAX_VALUE elements.
        long seen = 0;
        while (seen < k && stream.hasNext()) {
            reservoir.add(stream.next());
            seen++;
        }
        while (stream.hasNext()) {
            T item = stream.next();
            seen++;
            long slot = nextIndex(random, seen);
            if (slot < k) {
                reservoir.set((int) slot, item);
            }
        }
        return reservoir;
    }

    /**
     * Weighted reservoir sampling (keyed scheme of Efraimidis and Spirakis):
     * every element gets the key {@code u^(1/weight)} with {@code u} uniform
     * in {@code [0, 1)}, and the {@code k} elements with the largest keys are
     * kept. A min-heap over the kept keys finds the element to evict in
     * O(log k) instead of rescanning all keys for every stream element.
     *
     * <p>An element of weight zero gets key 0, so it is kept only while the
     * reservoir would otherwise be under-filled.
     *
     * @param stream source of elements; consumed until exhausted unless {@code k} is 0
     * @param k maximum number of elements to keep
     * @param weightExtractor maps an element to its non-negative weight
     * @return a new list with at most {@code k} elements
     * @throws IllegalArgumentException if {@code k} is negative, or a weight is negative or NaN
     */
    public static <T> List<T> weightedReservoirSample(
            Iterator<T> stream,
            int k,
            java.util.function.ToDoubleFunction<T> weightExtractor) {
        requireNonNegative(k);
        List<T> reservoir = new ArrayList<>(k);
        if (k == 0) {
            // With an empty reservoir there is no minimum key to compare with.
            return reservoir;
        }
        PriorityQueue<KeyedSlot> smallestFirst =
            new PriorityQueue<KeyedSlot>(k, (a, b) -> Double.compare(a.key, b.key));
        Random random = new Random();
        while (stream.hasNext()) {
            T item = stream.next();
            double key = drawKey(random, weightExtractor.applyAsDouble(item));
            if (reservoir.size() < k) {
                smallestFirst.add(new KeyedSlot(key, reservoir.size()));
                reservoir.add(item);
            } else if (key > smallestFirst.peek().key) {
                // Evict the element with the smallest key and reuse its slot.
                KeyedSlot evicted = smallestFirst.poll();
                reservoir.set(evicted.slot, item);
                smallestFirst.add(new KeyedSlot(key, evicted.slot));
            }
        }
        return reservoir;
    }

    /** Rejects negative reservoir sizes. */
    private static void requireNonNegative(int k) {
        if (k < 0) {
            throw new IllegalArgumentException("样本大小不能为负数");
        }
    }

    /** Returns a uniformly distributed index in {@code [0, bound)}; {@code bound} must be positive. */
    private static long nextIndex(Random random, long bound) {
        if (bound <= Integer.MAX_VALUE) {
            return random.nextInt((int) bound);
        }
        // Beyond the int range fall back to scaling a uniform double. A result
        // rounded up to bound is harmless: the caller ignores any index >= k.
        return (long) (random.nextDouble() * bound);
    }

    /** Computes the random key {@code u^(1/weight)} for one element. */
    private static double drawKey(Random random, double weight) {
        if (Double.isNaN(weight) || weight < 0) {
            // A negative exponent would yield keys >= 1 that always win.
            throw new IllegalArgumentException("权重不能为负数或NaN");
        }
        if (weight == 0) {
            // Also catches -0.0, for which 1 / weight would be -Infinity.
            return 0.0;
        }
        return Math.pow(random.nextDouble(), 1.0 / weight);
    }

    /** Pairs a random key with the reservoir slot of the element that owns it. */
    private static final class KeyedSlot {
        final double key;
        final int slot;

        KeyedSlot(double key, int slot) {
            this.key = key;
            this.slot = slot;
        }
    }
}

这些实现提供了多种数据抽样策略,可以根据具体需求选择合适的方法,关键考虑因素包括:

  1. 抽样精度要求:总体较为均匀、没有明显分层时,简单随机抽样最常用
  2. 数据分布特征:分层抽样适合有明确分层的数据
  3. 性能要求:Fisher-Yates洗牌算法效率较高
  4. 数据规模:蓄水池抽样适合大数据流
  5. 抽样方式:放回或不放回抽样

抱歉,评论功能暂时关闭!