本文目录导读:

下面介绍几种用 Java 实现数据抽样的方法,包括简单随机抽样、等距抽样(系统抽样)、分层抽样,以及适用于数据流的蓄水池抽样。
基础数据结构
import java.util.*;
import java.util.stream.Collectors;
/**
 * Shared fixtures for the sampling examples: a small immutable data record and a
 * reproducible generator of test data.
 */
public class DataSamplingDemo {

    /** One observation: an identifier, the category (stratum) it belongs to, and a numeric value. */
    static class DataItem {
        // Package-private and final: the example driver reads these fields directly
        // (item.category, item.value), which private access would not allow from
        // another top-level class.
        final int id;
        final String category;
        final double value;

        public DataItem(int id, String category, double value) {
            this.id = id;
            this.category = category;
            this.value = value;
        }

        public int getId() {
            return id;
        }

        public String getCategory() {
            return category;
        }

        public double getValue() {
            return value;
        }

        @Override
        public String toString() {
            return String.format("DataItem{id=%d, category='%s', value=%.2f}",
                    id, category, value);
        }
    }

    /**
     * Generates {@code size} items with ids {@code 0..size-1}, a category drawn from
     * {"A", "B", "C"} and a value in {@code [0, 100)}.
     *
     * <p>The generator is seeded with a fixed value, so repeated calls with the same
     * {@code size} return equal data. A non-positive {@code size} yields an empty list.
     *
     * @param size number of items to generate
     * @return a new mutable list of generated items
     */
    static List<DataItem> generateData(int size) {
        List<DataItem> data = new ArrayList<>(Math.max(size, 0));
        String[] categories = {"A", "B", "C"};
        Random random = new Random(42);
        for (int i = 0; i < size; i++) {
            // Draw the category before the value so the seeded sequence is consumed
            // in a fixed order.
            String category = categories[random.nextInt(categories.length)];
            double value = random.nextDouble() * 100;
            data.add(new DataItem(i, category, value));
        }
        return data;
    }
}
简单随机抽样
// 简单随机抽样实现
/**
 * Simple random sampling: without replacement, with replacement, and a variant
 * based on a partial Fisher-Yates shuffle.
 *
 * <p>None of the methods modify the {@code population} list passed in.
 */
public class SimpleRandomSampling {

    /**
     * Simple random sampling without replacement.
     *
     * @param population the full data set
     * @param sampleSize number of elements to draw
     * @return a new list holding {@code sampleSize} elements taken from distinct positions
     * @throws IllegalArgumentException if {@code sampleSize} is negative or larger than the population
     */
    public static <T> List<T> sampleWithoutReplacement(List<T> population, int sampleSize) {
        checkSampleSize(population.size(), sampleSize);
        List<T> pool = new ArrayList<>(population);
        List<T> sample = new ArrayList<>(sampleSize);
        Random random = new Random();
        for (int i = 0; i < sampleSize; i++) {
            int index = random.nextInt(pool.size());
            int last = pool.size() - 1;
            sample.add(pool.get(index));
            // Swap-remove: move the last element into the hole and drop the tail,
            // so each draw is O(1) instead of shifting the rest of the list.
            pool.set(index, pool.get(last));
            pool.remove(last);
        }
        return sample;
    }

    /**
     * Simple random sampling with replacement; the same element may appear more than once.
     *
     * @param population the full data set
     * @param sampleSize number of draws
     * @return a new list holding {@code sampleSize} independently drawn elements
     * @throws IllegalArgumentException if {@code sampleSize} is negative, or if it is positive
     *         while the population is empty
     */
    public static <T> List<T> sampleWithReplacement(List<T> population, int sampleSize) {
        if (sampleSize < 0) {
            throw new IllegalArgumentException("样本大小不能为负数");
        }
        if (sampleSize > 0 && population.isEmpty()) {
            throw new IllegalArgumentException("总体为空,无法抽样");
        }
        List<T> sample = new ArrayList<>(sampleSize);
        Random random = new Random();
        for (int i = 0; i < sampleSize; i++) {
            sample.add(population.get(random.nextInt(population.size())));
        }
        return sample;
    }

    /**
     * Sampling without replacement using a partial Fisher-Yates shuffle: only the last
     * {@code sampleSize} positions of a copy are shuffled, and those positions form the sample.
     *
     * @param population the full data set
     * @param sampleSize number of elements to draw
     * @return a new, independent list holding the sampled elements
     * @throws IllegalArgumentException if {@code sampleSize} is negative or larger than the population
     */
    public static <T> List<T> fisherYatesSample(List<T> population, int sampleSize) {
        checkSampleSize(population.size(), sampleSize);
        List<T> shuffled = new ArrayList<>(population);
        Random random = new Random();
        int firstPicked = shuffled.size() - sampleSize;
        for (int i = shuffled.size() - 1; i >= firstPicked; i--) {
            Collections.swap(shuffled, i, random.nextInt(i + 1));
        }
        // Copy the tail rather than returning a subList view, which would keep the
        // whole shuffled copy reachable for as long as the sample is alive.
        return new ArrayList<>(shuffled.subList(firstPicked, shuffled.size()));
    }

    /** Rejects sample sizes that are negative or exceed the population size. */
    private static void checkSampleSize(int populationSize, int sampleSize) {
        if (sampleSize < 0) {
            throw new IllegalArgumentException("样本大小不能为负数");
        }
        if (sampleSize > populationSize) {
            throw new IllegalArgumentException("样本大小不能超过总体大小");
        }
    }
}
等距抽样(系统抽样)
/**
 * Systematic (equal-interval) sampling: pick a random start, then take every
 * {@code interval}-th element, where {@code interval = population.size() / sampleSize}.
 */
public class SystematicSampling {

    /**
     * Systematic sampling with a random start inside the first interval.
     *
     * <p>Because the interval is an integer quotient, elements at positions
     * {@code >= interval * sampleSize} are never selected when the population size is not
     * a multiple of {@code sampleSize}; use {@link #circularSystematicSample} to avoid that.
     *
     * @param population the full data set
     * @param sampleSize number of elements to draw
     * @return a new list of {@code sampleSize} elements in population order
     * @throws IllegalArgumentException if {@code sampleSize} is negative or larger than the population
     */
    public static <T> List<T> systematicSample(List<T> population, int sampleSize) {
        checkSampleSize(population.size(), sampleSize);
        List<T> sample = new ArrayList<>(sampleSize);
        if (sampleSize == 0) {
            // Nothing to draw; also avoids dividing by zero below.
            return sample;
        }
        int interval = population.size() / sampleSize; // >= 1 because sampleSize <= size
        int start = new Random().nextInt(interval);
        // start < interval, so start + (sampleSize - 1) * interval < interval * sampleSize <= size:
        // every index below is in range.
        for (int i = 0; i < sampleSize; i++) {
            sample.add(population.get(start + i * interval));
        }
        return sample;
    }

    /**
     * Circular systematic sampling: the start is drawn from the whole population and
     * indices wrap around the end of the list.
     *
     * @param population the full data set
     * @param sampleSize number of elements to draw
     * @return a new list of {@code sampleSize} elements from distinct positions
     * @throws IllegalArgumentException if {@code sampleSize} is negative or larger than the population
     */
    public static <T> List<T> circularSystematicSample(List<T> population, int sampleSize) {
        checkSampleSize(population.size(), sampleSize);
        List<T> sample = new ArrayList<>(sampleSize);
        if (sampleSize == 0) {
            return sample;
        }
        int size = population.size();
        int interval = size / sampleSize;
        int start = new Random().nextInt(size);
        for (int i = 0; i < sampleSize; i++) {
            // Widen to long: start + i * interval can approach 2 * size.
            int index = (int) ((start + (long) i * interval) % size);
            sample.add(population.get(index));
        }
        return sample;
    }

    /** Rejects sample sizes that are negative or exceed the population size. */
    private static void checkSampleSize(int populationSize, int sampleSize) {
        if (sampleSize < 0) {
            throw new IllegalArgumentException("样本大小不能为负数");
        }
        if (sampleSize > populationSize) {
            throw new IllegalArgumentException("样本大小不能超过总体大小");
        }
    }
}
分层抽样
/**
 * Stratified sampling: the population is split into strata by a category key and a
 * simple random sample (without replacement) is drawn inside each stratum.
 *
 * <p>Per-stratum sample sizes are allocated with the largest-remainder method so that
 * they add up to the requested {@code sampleSize}; a stratum never receives more than it
 * contains, so the total can be smaller when a stratum's quota exceeds its size.
 */
public class StratifiedSampling {

    /**
     * Stratified sampling with proportional allocation: each stratum's share of the sample
     * matches its share of the population.
     *
     * @param population the full data set
     * @param sampleSize total number of elements to draw
     * @param categoryExtractor maps an element to its stratum key
     * @return the sampled elements grouped by stratum key (one entry per stratum, possibly empty)
     * @throws IllegalArgumentException if {@code sampleSize} is negative
     */
    public static <T> Map<String, List<T>> stratifiedSample(
            List<T> population,
            int sampleSize,
            java.util.function.Function<T, String> categoryExtractor) {
        requireNonNegative(sampleSize);
        Map<String, List<T>> groups = population.stream()
                .collect(Collectors.groupingBy(categoryExtractor));
        return drawFromStrata(groups, allocate(groups, sizeScores(groups), sampleSize));
    }

    /**
     * Stratified sampling with optimal (Neyman) allocation: each stratum's share of the
     * sample is proportional to its size times the standard deviation of its values.
     *
     * <p>If every stratum has zero standard deviation (or the scores are not finite), the
     * allocation falls back to proportional allocation instead of sampling nothing.
     *
     * @param population the full data set
     * @param sampleSize total number of elements to draw
     * @param categoryExtractor maps an element to its stratum key
     * @param valueExtractor maps an element to the numeric value whose spread drives the allocation
     * @return the sampled elements grouped by stratum key (one entry per stratum, possibly empty)
     * @throws IllegalArgumentException if {@code sampleSize} is negative
     */
    public static <T> Map<String, List<T>> optimalStratifiedSample(
            List<T> population,
            int sampleSize,
            java.util.function.Function<T, String> categoryExtractor,
            java.util.function.ToDoubleFunction<T> valueExtractor) {
        requireNonNegative(sampleSize);
        Map<String, List<T>> groups = population.stream()
                .collect(Collectors.groupingBy(categoryExtractor));

        // Score of a stratum = size * population standard deviation of its values.
        Map<String, Double> scores = new HashMap<>();
        double totalScore = 0;
        for (Map.Entry<String, List<T>> entry : groups.entrySet()) {
            List<T> stratum = entry.getValue();
            double mean = stratum.stream()
                    .mapToDouble(valueExtractor)
                    .average()
                    .orElse(0);
            double variance = stratum.stream()
                    .mapToDouble(item -> Math.pow(valueExtractor.applyAsDouble(item) - mean, 2))
                    .average()
                    .orElse(0);
            double score = stratum.size() * Math.sqrt(variance);
            scores.put(entry.getKey(), score);
            totalScore += score;
        }
        // A zero total would turn every quota into 0/0 = NaN; NaN or infinite totals are
        // equally unusable. Proportional allocation is the sensible fallback.
        if (!(totalScore > 0) || Double.isInfinite(totalScore)) {
            scores = sizeScores(groups);
        }
        return drawFromStrata(groups, allocate(groups, scores, sampleSize));
    }

    /** Rejects negative sample sizes. */
    private static void requireNonNegative(int sampleSize) {
        if (sampleSize < 0) {
            throw new IllegalArgumentException("样本大小不能为负数");
        }
    }

    /** Scores for proportional allocation: each stratum is scored by its element count. */
    private static <T> Map<String, Double> sizeScores(Map<String, List<T>> groups) {
        Map<String, Double> scores = new HashMap<>();
        for (Map.Entry<String, List<T>> entry : groups.entrySet()) {
            scores.put(entry.getKey(), (double) entry.getValue().size());
        }
        return scores;
    }

    /**
     * Splits {@code sampleSize} across strata in proportion to {@code scores} using the
     * largest-remainder method, then caps each stratum at its own size.
     */
    private static <T> Map<String, Integer> allocate(
            Map<String, List<T>> groups, Map<String, Double> scores, int sampleSize) {
        double totalScore = 0;
        for (double score : scores.values()) {
            totalScore += score;
        }

        Map<String, Integer> sizes = new HashMap<>();
        Map<String, Double> remainders = new HashMap<>();
        int assigned = 0;
        for (String key : groups.keySet()) {
            double quota = sampleSize * scores.get(key) / totalScore;
            int base = (int) Math.floor(quota);
            sizes.put(key, base);
            remainders.put(key, quota - base);
            assigned += base;
        }

        // Hand the units lost to flooring to the strata with the largest fractional parts,
        // at most one extra unit per stratum.
        List<String> byRemainder = new ArrayList<>(groups.keySet());
        byRemainder.sort((a, b) -> Double.compare(remainders.get(b), remainders.get(a)));
        int leftover = sampleSize - assigned;
        for (int i = 0; i < byRemainder.size() && leftover > 0; i++, leftover--) {
            sizes.merge(byRemainder.get(i), 1, Integer::sum);
        }

        // A stratum cannot contribute more elements than it holds.
        for (Map.Entry<String, List<T>> entry : groups.entrySet()) {
            String key = entry.getKey();
            sizes.put(key, Math.min(sizes.get(key), entry.getValue().size()));
        }
        return sizes;
    }

    /** Draws the allocated number of elements from every stratum. */
    private static <T> Map<String, List<T>> drawFromStrata(
            Map<String, List<T>> groups, Map<String, Integer> sizes) {
        Map<String, List<T>> sample = new HashMap<>();
        Random random = new Random();
        for (Map.Entry<String, List<T>> entry : groups.entrySet()) {
            int count = sizes.get(entry.getKey());
            sample.put(entry.getKey(), drawWithoutReplacement(entry.getValue(), count, random));
        }
        return sample;
    }

    /**
     * Draws {@code count} elements from distinct positions of {@code stratum}; the caller
     * guarantees {@code 0 <= count <= stratum.size()}.
     */
    private static <T> List<T> drawWithoutReplacement(List<T> stratum, int count, Random random) {
        List<T> pool = new ArrayList<>(stratum);
        List<T> picked = new ArrayList<>(count);
        // The bound is the fixed count, not the shrinking pool size, so the loop cannot
        // stop early while elements are still owed.
        for (int i = 0; i < count; i++) {
            int index = random.nextInt(pool.size());
            int last = pool.size() - 1;
            picked.add(pool.get(index));
            pool.set(index, pool.get(last));
            pool.remove(last);
        }
        return picked;
    }
}
完整示例和测试
public class SamplingExample {
public static void main(String[] args) {
// 生成测试数据
List<DataItem> population = DataSamplingDemo.generateData(1000);
System.out.println("总体大小: " + population.size());
// 1. 简单随机抽样(不放回)
System.out.println("\n=== 简单随机抽样(不放回)===");
List<DataItem> simpleSample = SimpleRandomSampling.sampleWithoutReplacement(population, 100);
System.out.println("样本大小: " + simpleSample.size());
simpleSample.subList(0, 5).forEach(System.out::println);
// 2. 简单随机抽样(放回)
System.out.println("\n=== 简单随机抽样(放回)===");
List<DataItem> sampleWithReplace = SimpleRandomSampling.sampleWithReplacement(population, 100);
System.out.println("样本大小: " + sampleWithReplace.size());
// 3. 等距抽样
System.out.println("\n=== 等距抽样 ===");
List<DataItem> systematicSample = SystematicSampling.systematicSample(population, 100);
System.out.println("样本大小: " + systematicSample.size());
// 4. 分层抽样(等比例)
System.out.println("\n=== 分层抽样(等比例)===");
Map<String, List<DataItem>> stratifiedSample =
StratifiedSampling.stratifiedSample(population, 100, item -> item.category);
stratifiedSample.forEach((category, sample) -> {
System.out.printf("类别 %s: 样本大小 %d%n", category, sample.size());
});
// 5. 最优分配分层抽样
System.out.println("\n=== 最优分配分层抽样 ===");
Map<String, List<DataItem>> optimalSample =
StratifiedSampling.optimalStratifiedSample(
population, 100,
item -> item.category,
item -> item.value
);
optimalSample.forEach((category, sample) -> {
System.out.printf("类别 %s: 样本大小 %d%n", category, sample.size());
});
// 性能对比
System.out.println("\n=== 性能对比 ===");
long startTime = System.nanoTime();
SimpleRandomSampling.fisherYatesSample(population, 100);
long fisherYatesTime = System.nanoTime() - startTime;
startTime = System.nanoTime();
SimpleRandomSampling.sampleWithoutReplacement(population, 100);
long removeTime = System.nanoTime() - startTime;
System.out.println("Fisher-Yates洗牌算法: " + fisherYatesTime / 1_000_000 + " ms");
System.out.println("remove方法: " + removeTime / 1_000_000 + " ms");
}
}
高级:蓄水池抽样
/**
 * Reservoir sampling for streams whose length is not known in advance: a uniform
 * variant and a weighted variant.
 */
public class ReservoirSampling {

    /**
     * Uniform reservoir sampling: the first {@code k} items fill the reservoir, and the
     * n-th item after that replaces a random slot with probability {@code k / n}.
     *
     * <p>The iterator is always consumed to the end. If it yields fewer than {@code k}
     * items, all of them are returned.
     *
     * @param stream source of items
     * @param k reservoir (sample) size
     * @return a new list with at most {@code k} sampled items
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public static <T> List<T> reservoirSample(Iterator<T> stream, int k) {
        requireNonNegative(k);
        List<T> reservoir = new ArrayList<>(k);
        Random random = new Random();
        // long counter: an int would wrap after Integer.MAX_VALUE items and make
        // nextInt reject the negative bound.
        long seen = 0;
        while (seen < k && stream.hasNext()) {
            reservoir.add(stream.next());
            seen++;
        }
        while (stream.hasNext()) {
            T item = stream.next();
            seen++;
            // Uniform slot in [0, seen); the item is kept only if the slot is inside the reservoir.
            long slot = seen <= Integer.MAX_VALUE
                    ? random.nextInt((int) seen)
                    : (long) (random.nextDouble() * seen);
            if (slot < k) {
                reservoir.set((int) slot, item);
            }
        }
        return reservoir;
    }

    /**
     * Weighted reservoir sampling: every item gets the key {@code u^(1/weight)} with
     * {@code u} uniform in {@code [0, 1)}, and the {@code k} items with the largest keys
     * are kept.
     *
     * <p>An item with weight zero gets key 0 and is therefore only kept while nothing with
     * a larger key is available. The iterator is always consumed to the end.
     *
     * @param stream source of items
     * @param k reservoir (sample) size
     * @param weightExtractor maps an item to its non-negative weight
     * @return a new list with at most {@code k} sampled items
     * @throws IllegalArgumentException if {@code k} is negative, or a weight is negative or NaN
     */
    public static <T> List<T> weightedReservoirSample(
            Iterator<T> stream,
            int k,
            java.util.function.ToDoubleFunction<T> weightExtractor) {
        requireNonNegative(k);
        List<T> reservoir = new ArrayList<>(k);
        // Min-heap over the keys of the occupied slots: the smallest key is found in O(1)
        // and replaced in O(log k) instead of scanning all keys for every item.
        PriorityQueue<Slot> byKey =
                new PriorityQueue<Slot>(Math.max(k, 1), (a, b) -> Double.compare(a.key, b.key));
        Random random = new Random();
        while (stream.hasNext()) {
            T item = stream.next();
            double key = randomKey(random, weightExtractor.applyAsDouble(item));
            if (reservoir.size() < k) {
                byKey.add(new Slot(key, reservoir.size()));
                reservoir.add(item);
            } else if (!byKey.isEmpty() && key > byKey.peek().key) {
                // Evict the current smallest key and reuse its position in the reservoir.
                Slot evicted = byKey.poll();
                reservoir.set(evicted.index, item);
                byKey.add(new Slot(key, evicted.index));
            }
        }
        return reservoir;
    }

    /** Rejects negative reservoir sizes. */
    private static void requireNonNegative(int k) {
        if (k < 0) {
            throw new IllegalArgumentException("样本大小不能为负数");
        }
    }

    /** Computes the sampling key {@code u^(1/weight)} for one item. */
    private static double randomKey(Random random, double weight) {
        if (Double.isNaN(weight) || weight < 0) {
            // A negative weight would produce keys above 1 that always win.
            throw new IllegalArgumentException("权重不能为负数或NaN: " + weight);
        }
        if (weight == 0) {
            // Also covers -0.0, for which 1.0 / weight would be negative infinity.
            return 0.0;
        }
        return Math.pow(random.nextDouble(), 1.0 / weight);
    }

    /** Key of the item currently stored at a given reservoir position. */
    private static final class Slot {
        final double key;
        final int index;

        Slot(double key, int index) {
            this.key = key;
            this.index = index;
        }
    }
}
这些实现提供了多种数据抽样策略,可以根据具体需求选择合适的方法。关键考虑因素包括:
- 抽样精度要求:简单随机抽样最常用,也是其他方法的比较基准
- 数据分布特征:分层抽样适合有明确分层的数据
- 性能要求:Fisher-Yates 洗牌算法效率较高,只需交换与样本量相同次数的元素
- 数据规模:蓄水池抽样适合大数据流或总量未知的数据
- 抽样方式:放回或不放回抽样