本文目录导读:

我来为您提供几个Java数据采样的实现案例,涵盖常见的采样算法:
简单随机采样
import java.util.*;
import java.util.stream.Collectors;
public class RandomSampling {
    /**
     * Simple random sampling: draws {@code sampleSize} distinct elements uniformly at random
     * (without replacement).
     *
     * @param data       the population to sample from; it is not modified
     * @param sampleSize number of elements to draw; must be non-negative
     * @return a new, independent list. If {@code sampleSize >= data.size()} it is a copy of
     *         {@code data} in its original order; otherwise it holds {@code sampleSize} elements
     *         in random order
     * @throws NullPointerException     if {@code data} is null
     * @throws IllegalArgumentException if {@code sampleSize} is negative
     */
    public static <T> List<T> randomSample(List<T> data, int sampleSize) {
        Objects.requireNonNull(data, "data");
        requireNonNegative(sampleSize);
        if (sampleSize >= data.size()) {
            return new ArrayList<>(data);
        }
        List<T> copy = new ArrayList<>(data);
        Collections.shuffle(copy);
        // Copy the prefix: returning copy.subList(...) would hand out a view that keeps the
        // entire shuffled copy reachable.
        return new ArrayList<>(copy.subList(0, sampleSize));
    }

    /**
     * Reservoir sampling (Algorithm R) over a list: every element ends up in the sample with
     * equal probability.
     *
     * @param data       the population to sample from; it is not modified
     * @param sampleSize reservoir size; must be non-negative
     * @return a new list with {@code min(sampleSize, data.size())} elements
     * @throws NullPointerException     if {@code data} is null
     * @throws IllegalArgumentException if {@code sampleSize} is negative
     */
    public static <T> List<T> reservoirSample(List<T> data, int sampleSize) {
        Objects.requireNonNull(data, "data");
        // Iterate instead of calling data.get(i): get(i) is O(i) on a LinkedList, which would
        // make the whole pass O(n^2).
        return reservoirSample(data.iterator(), sampleSize);
    }

    /**
     * Reservoir sampling (Algorithm R) over a single-pass source whose length need not be known
     * in advance (e.g. {@code stream.iterator()}). Memory use is proportional to the number of
     * elements kept, not to the source length.
     *
     * <p>Sources longer than {@code Integer.MAX_VALUE} elements are not supported (the position
     * counter is an {@code int}).
     *
     * @param source     elements to sample; consumed completely
     * @param sampleSize reservoir size; must be non-negative
     * @return a new list with {@code min(sampleSize, number of source elements)} elements
     * @throws NullPointerException     if {@code source} is null
     * @throws IllegalArgumentException if {@code sampleSize} is negative
     */
    public static <T> List<T> reservoirSample(Iterator<? extends T> source, int sampleSize) {
        Objects.requireNonNull(source, "source");
        requireNonNegative(sampleSize);
        // Deliberately not presized with sampleSize: a huge sampleSize (e.g. Integer.MAX_VALUE)
        // would otherwise allocate an enormous backing array before reading anything.
        List<T> reservoir = new ArrayList<>();
        Random random = new Random();
        int seen = 0;
        while (source.hasNext()) {
            T item = source.next();
            if (seen < sampleSize) {
                // Fill phase: the first sampleSize elements go straight into the reservoir.
                reservoir.add(item);
            } else {
                // Element number `seen` (0-based) enters with probability sampleSize / (seen + 1).
                int slot = random.nextInt(seen + 1);
                if (slot < sampleSize) {
                    reservoir.set(slot, item);
                }
            }
            seen++;
        }
        return reservoir;
    }

    /** Rejects negative sample sizes with a descriptive message. */
    private static void requireNonNegative(int sampleSize) {
        if (sampleSize < 0) {
            throw new IllegalArgumentException("采样数量不能为负数: " + sampleSize);
        }
    }

    public static void main(String[] args) {
        List<Integer> data = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
        // Simple random sampling demo
        System.out.println("简单随机采样: " + randomSample(data, 3));
        // Reservoir sampling demo
        System.out.println("Reservoir Sampling: " + reservoirSample(data, 3));
    }
}
分层采样
import java.util.*;
import java.util.stream.Collectors;
public class StratifiedSampling {
    /**
     * Stratified sampling with (as far as possible) equal allocation per class.
     *
     * <p>Elements are grouped by {@code labelExtractor}. The requested total is split evenly
     * across the groups. When it does not divide evenly, the remainder is handed out one element
     * at a time. Quota that a small group cannot fill is redistributed to the larger groups, so
     * the result size is always {@code min(totalSampleSize, data.size())}. Within a group,
     * elements are chosen uniformly at random without replacement.
     *
     * @param data            the population; not modified
     * @param labelExtractor  maps an element to its class label; must not return null
     *                        ({@code Collectors.groupingBy} rejects null keys)
     * @param totalSampleSize requested total number of samples; must be non-negative
     * @return a new list of {@code min(totalSampleSize, data.size())} elements, grouped by label
     *         in the order in which labels first appear in {@code data}
     * @throws NullPointerException     if {@code data} or {@code labelExtractor} is null
     * @throws IllegalArgumentException if {@code totalSampleSize} is negative
     */
    public static <T> List<T> stratifiedSample(
            List<T> data,
            java.util.function.Function<T, String> labelExtractor,
            int totalSampleSize) {
        Objects.requireNonNull(data, "data");
        Objects.requireNonNull(labelExtractor, "labelExtractor");
        if (totalSampleSize < 0) {
            throw new IllegalArgumentException("采样数量不能为负数: " + totalSampleSize);
        }
        if (data.isEmpty() || totalSampleSize == 0) {
            // Also avoids dividing the budget among zero groups.
            return new ArrayList<>();
        }
        // LinkedHashMap keeps groups in first-appearance order, making the output order stable.
        Map<String, List<T>> groups = data.stream()
                .collect(Collectors.groupingBy(labelExtractor, LinkedHashMap::new, Collectors.toList()));
        Map<String, Integer> quotas = allocateQuotas(groups, Math.min(totalSampleSize, data.size()));

        List<T> result = new ArrayList<>();
        Random random = new Random();
        for (Map.Entry<String, List<T>> entry : groups.entrySet()) {
            List<T> shuffled = new ArrayList<>(entry.getValue());
            Collections.shuffle(shuffled, random);
            result.addAll(shuffled.subList(0, quotas.get(entry.getKey())));
        }
        return result;
    }

    /**
     * Splits {@code budget} across groups by "water-filling". Groups are visited from smallest to
     * largest, and each takes the smaller of its own size and a ceil-rounded fair share of what
     * is still unallocated. Requires {@code budget <= total number of elements}; the quotas then
     * sum to exactly {@code budget}.
     */
    private static <T> Map<String, Integer> allocateQuotas(Map<String, List<T>> groups, int budget) {
        List<Map.Entry<String, List<T>>> bySize = new ArrayList<>(groups.entrySet());
        // List.sort is stable, so equally sized groups keep their first-appearance order.
        bySize.sort((a, b) -> Integer.compare(a.getValue().size(), b.getValue().size()));

        Map<String, Integer> quotas = new HashMap<>();
        int remaining = budget;
        int groupsLeft = bySize.size();
        for (Map.Entry<String, List<T>> entry : bySize) {
            int fairShare = remaining / groupsLeft + (remaining % groupsLeft == 0 ? 0 : 1);
            int quota = Math.min(entry.getValue().size(), fairShare);
            quotas.put(entry.getKey(), quota);
            remaining -= quota;
            groupsLeft--;
        }
        return quotas;
    }

    public static void main(String[] args) {
        // Build labelled demo data
        List<DataPoint> data = new ArrayList<>();
        data.add(new DataPoint("A", 1.0));
        data.add(new DataPoint("A", 1.5));
        data.add(new DataPoint("A", 2.0));
        data.add(new DataPoint("B", 3.0));
        data.add(new DataPoint("B", 3.5));
        data.add(new DataPoint("B", 4.0));
        data.add(new DataPoint("C", 5.0));
        data.add(new DataPoint("C", 5.5));
        data.add(new DataPoint("C", 6.0));
        // Run stratified sampling
        List<DataPoint> sample = stratifiedSample(
                data,
                dp -> dp.label,
                3
        );
        System.out.println("分层采样结果:");
        sample.forEach(System.out::println);
    }

    /** Demo record: a class label plus a numeric value. */
    static class DataPoint {
        String label;
        double value;

        DataPoint(String label, double value) {
            this.label = label;
            this.value = value;
        }

        @Override
        public String toString() {
            return "DataPoint{" + "label='" + label + '\'' + ", value=" + value + '}';
        }
    }
}
系统采样
import java.util.*;
import java.util.stream.Collectors;
public class SystematicSampling {
    /**
     * Systematic sampling: picks {@code sampleSize} elements at (fractionally) equal spacing
     * {@code n / sampleSize}, starting from a random offset.
     *
     * <p>Index {@code i} of the sample is {@code floor((r + i * n) / sampleSize)} with {@code r}
     * uniform in {@code [0, n)}. This exact integer form spreads the picks over the whole list.
     * Every element is included with probability {@code sampleSize / n}, including the tail that
     * an integer interval {@code n / sampleSize} can never reach when {@code n} is not divisible
     * by {@code sampleSize}. When it is divisible, the result matches the classic
     * "random start + fixed interval" scheme.
     *
     * @param data       the population; not modified
     * @param sampleSize number of elements to pick; must be non-negative
     * @return a new list, in ascending index order; a copy of {@code data} if
     *         {@code sampleSize >= data.size()}
     * @throws NullPointerException     if {@code data} is null
     * @throws IllegalArgumentException if {@code sampleSize} is negative
     */
    public static <T> List<T> systematicSample(List<T> data, int sampleSize) {
        Objects.requireNonNull(data, "data");
        if (sampleSize < 0) {
            throw new IllegalArgumentException("采样数量不能为负数: " + sampleSize);
        }
        int n = data.size();
        if (sampleSize >= n) {
            return new ArrayList<>(data);
        }
        if (sampleSize == 0) {
            return new ArrayList<>();
        }
        // Random start expressed in units of 1/sampleSize of an element.
        long offset = new Random().nextInt(n);
        List<T> result = new ArrayList<>(sampleSize);
        for (int i = 0; i < sampleSize; i++) {
            // long arithmetic avoids overflow of i * n; the result is always < n and strictly
            // increasing because n / sampleSize > 1.
            int index = (int) ((offset + (long) i * n) / sampleSize);
            result.add(data.get(index));
        }
        return result;
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 1; i <= 100; i++) {
            data.add(i);
        }
        System.out.println("系统采样结果: " + systematicSample(data, 10));
    }
}
时间序列采样
import java.time.*;
import java.util.*;
import java.util.stream.Collectors;
public class TimeSeriesSampling {
    /**
     * Time-series sampler. Both methods assume {@code data} is ordered by timestamp
     * (NOTE(review): nothing here enforces that — confirm with callers).
     */
    public static class TimeSeriesSampler<T> {
        /**
         * Down-samples by time. The first point is always kept. After that, a point is kept
         * once at least {@code interval} has elapsed since the previously kept point.
         *
         * @param data     time-ordered points; null or empty yields an empty list
         * @param interval minimum spacing between kept points
         * @return the kept points, in input order
         * @throws NullPointerException if {@code interval} is null and {@code data} is non-empty
         */
        public List<TimedPoint<T>> sampleByInterval(
                List<TimedPoint<T>> data,
                Duration interval) {
            if (data == null || data.isEmpty()) {
                return Collections.emptyList();
            }
            Objects.requireNonNull(interval, "interval");
            List<TimedPoint<T>> result = new ArrayList<>();
            TimedPoint<T> lastSample = data.get(0);
            result.add(lastSample);
            // Start at the second point. Comparing the first point with itself gives a zero
            // duration, which would add it a second time whenever interval <= 0.
            for (TimedPoint<T> point : data.subList(1, data.size())) {
                if (Duration.between(lastSample.timestamp, point.timestamp)
                        .compareTo(interval) >= 0) {
                    result.add(point);
                    lastSample = point;
                }
            }
            return result;
        }

        /**
         * Down-samples to a fixed number of points spread evenly over the whole series. Point
         * {@code i} is taken from index {@code floor(i * n / sampleCount)}.
         *
         * @param data        time-ordered points; null yields an empty list
         * @param sampleCount number of points to return; must be non-negative
         * @return a new list with {@code min(sampleCount, data.size())} points; a copy of
         *         {@code data} when it has at most {@code sampleCount} points
         * @throws IllegalArgumentException if {@code sampleCount} is negative
         */
        public List<TimedPoint<T>> sampleByCount(
                List<TimedPoint<T>> data,
                int sampleCount) {
            if (sampleCount < 0) {
                throw new IllegalArgumentException("采样数量不能为负数: " + sampleCount);
            }
            if (data == null || sampleCount == 0) {
                return new ArrayList<>();
            }
            int n = data.size();
            if (n <= sampleCount) {
                return new ArrayList<>(data);
            }
            List<TimedPoint<T>> result = new ArrayList<>(sampleCount);
            for (int i = 0; i < sampleCount; i++) {
                // A fractional step covers the tail too. An integer step n / sampleCount would
                // stop early when n is not a multiple of sampleCount.
                int index = (int) ((long) i * n / sampleCount);
                result.add(data.get(index));
            }
            return result;
        }
    }

    /** A value observed at a point in time. */
    static class TimedPoint<T> {
        LocalDateTime timestamp;
        T value;

        TimedPoint(LocalDateTime timestamp, T value) {
            this.timestamp = timestamp;
            this.value = value;
        }

        @Override
        public String toString() {
            return "TimedPoint{" +
                    "timestamp=" + timestamp +
                    ", value=" + value + '}';
        }
    }

    public static void main(String[] args) {
        // Build demo series: one point per minute
        List<TimedPoint<Double>> timeSeries = new ArrayList<>();
        LocalDateTime start = LocalDateTime.now();
        for (int i = 0; i < 100; i++) {
            timeSeries.add(new TimedPoint<>(
                    start.plusMinutes(i),
                    Math.random() * 100
            ));
        }
        TimeSeriesSampler<Double> sampler = new TimeSeriesSampler<>();
        // Sample every 10 minutes
        System.out.println("按时间间隔采样 (10分钟):");
        List<TimedPoint<Double>> intervalSample =
                sampler.sampleByInterval(timeSeries, Duration.ofMinutes(10));
        intervalSample.forEach(p -> System.out.println(p.timestamp + " -> " + p.value));
        // Sample a fixed number of points
        System.out.println("\n按固定数量采样 (10个):");
        List<TimedPoint<Double>> countSample =
                sampler.sampleByCount(timeSeries, 10);
        countSample.forEach(p -> System.out.println(p.timestamp + " -> " + p.value));
    }
}
加权采样
import java.util.*;
import java.util.stream.Collectors;
public class WeightedSampling {
    /**
     * Weighted random sampling without replacement.
     *
     * <p>The result has the same distribution as repeatedly drawing one remaining element with
     * probability proportional to its weight and then removing it (successive sampling). It is
     * implemented with the Efraimidis–Spirakis key method. Each element with weight {@code w > 0}
     * gets the key {@code ln(u) / w} for a uniform {@code u} in (0, 1], and the
     * {@code sampleSize} largest keys win. This always terminates, unlike a draw-and-reject loop,
     * which spins forever when fewer than {@code sampleSize} elements have a positive weight.
     *
     * @param data       elements to sample from
     * @param weights    non-negative, finite weight per element (same length as {@code data});
     *                   weights need not sum to 1. Elements with weight 0 are never selected
     * @param sampleSize maximum number of elements to return; must be non-negative
     * @return up to {@code sampleSize} distinct elements in draw order; fewer when there are not
     *         enough positive-weight elements
     * @throws IllegalArgumentException if the lists differ in length, a weight is negative, NaN
     *                                  or infinite, or {@code sampleSize} is negative
     * @throws NullPointerException     if a weight is null (auto-unboxing)
     */
    public static <T> List<T> weightedSample(
            List<T> data,
            List<Double> weights,
            int sampleSize) {
        if (data.size() != weights.size()) {
            throw new IllegalArgumentException("数据和权重长度必须相同");
        }
        if (sampleSize < 0) {
            throw new IllegalArgumentException("采样数量不能为负数: " + sampleSize);
        }
        for (int i = 0; i < weights.size(); i++) {
            double w = weights.get(i);
            // !(w >= 0) also rejects NaN.
            if (!(w >= 0) || Double.isInfinite(w)) {
                throw new IllegalArgumentException("权重必须是非负有限数: index=" + i + ", weight=" + w);
            }
        }
        // Presize by the reachable result size, never by a possibly huge sampleSize.
        int limit = Math.min(sampleSize, data.size());
        if (limit == 0) {
            return new ArrayList<>();
        }

        Random random = new Random();
        // Min-heap of the best `limit` candidates so far; its head is the weakest of them.
        PriorityQueue<Candidate> best =
                new PriorityQueue<>(limit, Comparator.comparingDouble((Candidate c) -> c.key));
        for (int i = 0; i < weights.size(); i++) {
            double w = weights.get(i);
            if (w == 0) {
                continue; // zero weight: never selectable
            }
            // 1 - nextDouble() lies in (0, 1], so the logarithm is finite and <= 0.
            double key = Math.log(1.0 - random.nextDouble()) / w;
            if (best.size() < limit) {
                best.add(new Candidate(i, key));
            } else if (key > best.peek().key) {
                best.poll();
                best.add(new Candidate(i, key));
            }
        }

        // Largest key first == order in which successive sampling would have drawn them.
        List<Candidate> picked = new ArrayList<>(best);
        picked.sort(Comparator.comparingDouble((Candidate c) -> c.key).reversed());
        List<T> result = new ArrayList<>(picked.size());
        for (Candidate c : picked) {
            result.add(data.get(c.index));
        }
        return result;
    }

    /** An element index paired with its random sampling key. */
    private static final class Candidate {
        final int index;
        final double key;

        Candidate(int index, double key) {
            this.index = index;
            this.key = key;
        }
    }

    public static void main(String[] args) {
        List<String> items = Arrays.asList("A", "B", "C", "D");
        List<Double> weights = Arrays.asList(0.1, 0.2, 0.3, 0.4); // relative weights
        System.out.println("加权采样结果:");
        for (int i = 0; i < 5; i++) {
            List<String> sample = weightedSample(items, weights, 2);
            System.out.println("采样" + (i+1) + ": " + sample);
        }
    }
}
使用建议
- 小数据集:使用简单随机采样
- 大数据流(总量未知或无法一次性装入内存):使用Reservoir Sampling
- 类别不均衡:使用分层采样,保证每个类别都能被抽到
- 需要均匀覆盖整个有序数据集:使用系统采样(注意:若数据本身存在与采样间隔一致的周期性,结果会产生偏差)
- 时间序列数据:按时间间隔或按固定数量进行时间序列采样
- 有偏好的采样:使用加权采样
这些案例覆盖了常见的数据采样需求,您可以根据具体场景选择合适的采样方法。