如何用Java案例实现图像识别?

wen java案例 1

本文目录导读:

如何用Java案例实现图像识别?

  1. 环境准备
  2. 图像识别核心类
  3. 主程序示例
  4. 简单图像识别(不使用深度学习框架)
  5. 使用说明
  6. 注意事项

我将为您提供一个使用Java实现图像识别的完整案例,这里使用TensorFlow Java API和预训练的深度学习模型来实现图像分类。

环境准备

Maven依赖 (pom.xml)

<dependencies>
    <!-- TensorFlow Java API (legacy 1.x line); supplies the org.tensorflow
         Graph/Session/Tensor classes imported by ImageRecognizer -->
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow</artifactId>
        <version>1.15.0</version>
    </dependency>
    <!-- Image processing (thumbnail generation).
         NOTE(review): not referenced by the sample classes shown in this article; confirm it is needed -->
    <dependency>
        <groupId>net.coobird</groupId>
        <artifactId>thumbnailator</artifactId>
        <version>0.4.8</version>
    </dependency>
    <!-- JSON processing.
         NOTE(review): not referenced by the sample classes shown in this article; confirm it is needed -->
    <dependency>
        <groupId>com.google.code.gson</groupId>
        <artifactId>gson</artifactId>
        <version>2.8.6</version>
    </dependency>
</dependencies>

图像识别核心类

package com.example.imagerecognition;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.types.UInt8;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
/**
 * Classifies images with a frozen TensorFlow 1.x graph and a plain-text label file
 * (one label per line, in the same order as the model's output classes).
 *
 * <p>The graph is loaded once in the constructor and must be released with
 * {@link #close()}; the class implements {@link AutoCloseable} so it can be used
 * in a try-with-resources statement.
 */
public class ImageRecognizer implements AutoCloseable {
    /** Side length, in pixels, of the square image fed to the model (MobileNet 224). */
    private static final int INPUT_SIZE = 224;
    /** Name of the graph node that receives the image tensor by default. */
    private static final String DEFAULT_INPUT_OP = "input";
    /**
     * Output node names tried in order when the caller does not name one.
     * "output" is the historical default of this class; the second entry is the
     * prediction node of the MobileNet V1 frozen graphs.
     */
    private static final String[] DEFAULT_OUTPUT_OPS = {
        "output", "MobilenetV1/Predictions/Reshape_1"
    };

    private final List<String> labels;
    private final Graph graph;
    private final String inputOp;
    private final String outputOp;

    /**
     * Loads the model and labels using the default input node ("input") and the
     * first default output node that exists in the graph.
     *
     * @param modelPath  path to the frozen GraphDef (.pb) file
     * @param labelsPath path to the label file, one label per line
     * @throws IOException if either file cannot be read
     */
    public ImageRecognizer(String modelPath, String labelsPath) throws IOException {
        this(modelPath, labelsPath, DEFAULT_INPUT_OP, null);
    }

    /**
     * Loads the model and labels with explicit graph node names.
     *
     * @param modelPath  path to the frozen GraphDef (.pb) file
     * @param labelsPath path to the label file, one label per line
     * @param inputOp    name of the node fed with the image tensor; {@code null} selects "input"
     * @param outputOp   name of the node fetched for class scores; {@code null} selects the
     *                   first default output node present in the graph
     * @throws IOException if either file cannot be read
     */
    public ImageRecognizer(String modelPath, String labelsPath, String inputOp, String outputOp)
            throws IOException {
        byte[] graphDef = readAllBytes(Paths.get(modelPath));
        this.labels = readLabels(labelsPath);

        Graph loaded = new Graph();
        String resolvedOutput;
        try {
            loaded.importGraphDef(graphDef);
            resolvedOutput = outputOp != null ? outputOp : detectOutputOp(loaded);
        } catch (RuntimeException e) {
            // Do not leak the native graph when the model bytes are rejected.
            loaded.close();
            throw e;
        }
        this.graph = loaded;
        this.inputOp = inputOp != null ? inputOp : DEFAULT_INPUT_OP;
        this.outputOp = resolvedOutput;
    }

    /**
     * Classifies the image stored at the given path.
     *
     * @param imagePath path of the image file
     * @return all labels with their scores, sorted by descending confidence
     * @throws IOException if the file cannot be read or is not a supported image format
     */
    public List<RecognitionResult> recognize(String imagePath) throws IOException {
        BufferedImage image = ImageIO.read(Paths.get(imagePath).toFile());
        if (image == null) {
            // ImageIO.read returns null (rather than throwing) when no reader supports the file.
            throw new IOException("Unsupported or unreadable image file: " + imagePath);
        }
        return recognize(image);
    }

    /**
     * Classifies an in-memory image.
     *
     * @param image the image to classify; must not be {@code null}
     * @return all labels with their scores, sorted by descending confidence
     * @throws IOException declared for compatibility with existing callers
     * @throws IllegalArgumentException if the model output does not have one score per label
     */
    public List<RecognitionResult> recognize(BufferedImage image) throws IOException {
        Objects.requireNonNull(image, "image");
        // The input tensor and the session both own native memory, so both are
        // closed here even when inference fails.
        try (Tensor<Float> imageTensor = preprocessImage(image);
             Session session = new Session(graph)) {
            List<Tensor<?>> outputs = session.runner()
                    .feed(inputOp, imageTensor)
                    .fetch(outputOp)
                    .run();
            try {
                return parseResult(outputs.get(0).expect(Float.class));
            } finally {
                for (Tensor<?> output : outputs) {
                    output.close();
                }
            }
        }
    }

    /**
     * Returns the first default output node name that exists in the graph, or the
     * historical default "output" when none of them is present.
     */
    private static String detectOutputOp(Graph graph) {
        for (String candidate : DEFAULT_OUTPUT_OPS) {
            if (graph.operation(candidate) != null) {
                return candidate;
            }
        }
        return DEFAULT_OUTPUT_OPS[0];
    }

    /**
     * Scales the image to INPUT_SIZE x INPUT_SIZE and converts it to a
     * [1, height, width, 3] float tensor. The caller owns (and must close) the tensor.
     */
    private Tensor<Float> preprocessImage(BufferedImage image) {
        BufferedImage resizedImage =
                new BufferedImage(INPUT_SIZE, INPUT_SIZE, BufferedImage.TYPE_3BYTE_BGR);
        java.awt.Graphics2D g = resizedImage.createGraphics();
        try {
            g.drawImage(image, 0, 0, INPUT_SIZE, INPUT_SIZE, null);
        } finally {
            g.dispose();
        }
        return Tensors.create(convertImageToFloat(resizedImage));
    }

    /**
     * Converts the image to a [1][height][width][3] array of R, G, B values,
     * each mapped from 0..255 to the range [-1, 1].
     */
    private float[][][][] convertImageToFloat(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        float[][][][] result = new float[1][height][width][3];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int rgb = image.getRGB(x, y);
                result[0][y][x][0] = ((rgb >> 16) & 0xFF) / 127.5f - 1.0f;  // R
                result[0][y][x][1] = ((rgb >> 8) & 0xFF) / 127.5f - 1.0f;   // G
                result[0][y][x][2] = (rgb & 0xFF) / 127.5f - 1.0f;          // B
            }
        }
        return result;
    }

    /**
     * Pairs every label with its score and sorts the pairs by descending confidence.
     *
     * @throws IllegalArgumentException if the tensor shape is not [1, number of labels]
     */
    private List<RecognitionResult> parseResult(Tensor<Float> tensor) {
        long[] shape = tensor.shape();
        if (shape.length != 2 || shape[0] != 1 || shape[1] != labels.size()) {
            throw new IllegalArgumentException(
                    "Model output shape " + Arrays.toString(shape)
                            + " does not match the " + labels.size() + " labels loaded");
        }
        float[][] probabilities = tensor.copyTo(new float[1][labels.size()]);
        List<RecognitionResult> results = new ArrayList<>(labels.size());
        for (int i = 0; i < labels.size(); i++) {
            results.add(new RecognitionResult(labels.get(i), probabilities[0][i]));
        }
        results.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
        return results;
    }

    /**
     * Reads the label file (UTF-8), one trimmed label per line.
     */
    private List<String> readLabels(String labelsPath) throws IOException {
        // readAllLines closes the file itself, unlike an unclosed Files.lines stream.
        List<String> lines = Files.readAllLines(Paths.get(labelsPath));
        List<String> trimmed = new ArrayList<>(lines.size());
        for (String line : lines) {
            trimmed.add(line.trim());
        }
        return Collections.unmodifiableList(trimmed);
    }

    /**
     * Reads the whole file into memory.
     */
    private byte[] readAllBytes(Path path) throws IOException {
        return Files.readAllBytes(path);
    }

    /**
     * Releases the native TensorFlow graph. The recognizer must not be used afterwards.
     */
    @Override
    public void close() {
        if (graph != null) {
            graph.close();
        }
    }

    /**
     * A label together with the score the model assigned to it.
     */
    public static class RecognitionResult {
        private final String label;
        private final float confidence;

        public RecognitionResult(String label, float confidence) {
            this.label = label;
            this.confidence = confidence;
        }

        public String getLabel() { return label; }

        public float getConfidence() { return confidence; }

        /** Formats the result as "label: NN.NN%". */
        @Override
        public String toString() {
            return String.format("%s: %.2f%%", label, confidence * 100);
        }
    }
}

主程序示例

package com.example.imagerecognition;
import javax.swing.*;
import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.File;
import java.io.IOException;
import java.util.List;
/**
 * Swing demo: lets the user pick an image file, shows a preview and lists the
 * top classification results. When the model file is missing, canned results
 * are shown instead.
 */
public class ImageRecognitionDemo {
    /** Maximum width/height, in pixels, of the image preview. */
    private static final int PREVIEW_SIZE = 400;
    /** Number of best results listed in the result area. */
    private static final int TOP_K = 5;

    private JFrame frame;
    private JLabel imageLabel;
    private JTextArea resultArea;
    private JButton selectButton;
    /** {@code null} when the model file is missing and results are simulated. */
    private final ImageRecognizer recognizer;

    /**
     * Loads the model if it is present and builds the window.
     *
     * @throws IOException if the model or label file exists but cannot be read
     */
    public ImageRecognitionDemo() throws IOException {
        String modelPath = "models/mobilenet_v1_1.0_224_frozen.pb";
        String labelsPath = "models/labels.txt";
        if (!new File(modelPath).exists()) {
            System.out.println("模型文件不存在,使用模拟识别功能");
            recognizer = null;
        } else {
            recognizer = new ImageRecognizer(modelPath, labelsPath);
        }
        initUI();
    }

    /** Creates and shows the main window. */
    private void initUI() {
        frame = new JFrame("Java图像识别演示");
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.setLayout(new BorderLayout(10, 10));

        // Image preview
        imageLabel = new JLabel("请选择图片", SwingConstants.CENTER);
        imageLabel.setPreferredSize(new Dimension(PREVIEW_SIZE, PREVIEW_SIZE));
        imageLabel.setBorder(BorderFactory.createLineBorder(Color.GRAY));

        // Result text
        resultArea = new JTextArea(10, 40);
        resultArea.setEditable(false);
        resultArea.setFont(new Font("Monospaced", Font.PLAIN, 12));
        JScrollPane scrollPane = new JScrollPane(resultArea);

        // Buttons
        JPanel buttonPanel = new JPanel();
        selectButton = new JButton("选择图片");
        JButton quitButton = new JButton("退出");
        selectButton.addActionListener(e -> selectAndRecognizeImage());
        quitButton.addActionListener(e -> System.exit(0));
        buttonPanel.add(selectButton);
        buttonPanel.add(quitButton);

        frame.add(imageLabel, BorderLayout.CENTER);
        frame.add(scrollPane, BorderLayout.SOUTH);
        frame.add(buttonPanel, BorderLayout.NORTH);
        frame.pack();
        frame.setLocationRelativeTo(null);
        frame.setVisible(true);
    }

    /** Asks the user for an image file, previews it and starts recognition. */
    private void selectAndRecognizeImage() {
        JFileChooser fileChooser = new JFileChooser();
        fileChooser.setFileFilter(new javax.swing.filechooser.FileNameExtensionFilter(
                "图片文件", "jpg", "jpeg", "png", "gif", "bmp"));
        if (fileChooser.showOpenDialog(frame) != JFileChooser.APPROVE_OPTION) {
            return;
        }
        File selectedFile = fileChooser.getSelectedFile();
        showPreview(selectedFile);
        if (recognizer == null) {
            simulateRecognition();
        } else {
            recognizeInBackground(selectedFile.getAbsolutePath());
        }
    }

    /** Shows the image scaled to fit the preview area without changing its aspect ratio. */
    private void showPreview(File file) {
        ImageIcon imageIcon = new ImageIcon(file.getAbsolutePath());
        Image image = imageIcon.getImage();
        // A negative dimension tells getScaledInstance to keep the aspect ratio.
        Image scaledImage = imageIcon.getIconWidth() >= imageIcon.getIconHeight()
                ? image.getScaledInstance(PREVIEW_SIZE, -1, Image.SCALE_SMOOTH)
                : image.getScaledInstance(-1, PREVIEW_SIZE, Image.SCALE_SMOOTH);
        imageLabel.setIcon(new ImageIcon(scaledImage));
    }

    /**
     * Runs the model on a worker thread so the event dispatch thread stays
     * responsive, then publishes the results (or the failure) back on the EDT.
     */
    private void recognizeInBackground(final String imagePath) {
        // One inference at a time: the button is re-enabled in done().
        selectButton.setEnabled(false);
        resultArea.setText("正在识别,请稍候...");
        new SwingWorker<List<ImageRecognizer.RecognitionResult>, Void>() {
            @Override
            protected List<ImageRecognizer.RecognitionResult> doInBackground() throws Exception {
                return recognizer.recognize(imagePath);
            }

            @Override
            protected void done() {
                selectButton.setEnabled(true);
                try {
                    displayResults(get());
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    resultArea.setText("识别失败: " + e.getMessage());
                } catch (java.util.concurrent.ExecutionException e) {
                    // Covers I/O errors as well as runtime failures raised by the model.
                    Throwable cause = e.getCause() != null ? e.getCause() : e;
                    String detail = cause.getMessage() != null ? cause.getMessage() : cause.toString();
                    resultArea.setText("识别失败: " + detail);
                }
            }
        }.execute();
    }

    /** Lists at most TOP_K results, best first, in the result area. */
    private void displayResults(List<ImageRecognizer.RecognitionResult> results) {
        StringBuilder sb = new StringBuilder();
        sb.append("=== 识别结果 ===\n\n");
        sb.append("前5个最可能的类别:\n");
        int count = Math.min(TOP_K, results.size());
        for (int i = 0; i < count; i++) {
            ImageRecognizer.RecognitionResult result = results.get(i);
            sb.append(String.format("%d. %s\n", i + 1, result));
        }
        resultArea.setText(sb.toString());
    }

    /** Shows fixed sample results plus instructions for enabling real recognition. */
    private void simulateRecognition() {
        String[] labels = {"猫", "狗", "花", "汽车", "建筑"};
        double[] confidences = {0.85, 0.72, 0.68, 0.45, 0.30};
        StringBuilder sb = new StringBuilder();
        sb.append("=== 模拟识别结果 ===\n\n");
        sb.append("注意:这是模拟数据,需要模型文件才能真实识别\n\n");
        sb.append("前5个最可能的类别:\n");
        for (int i = 0; i < labels.length; i++) {
            sb.append(String.format("%d. %s: %.2f%%\n",
                    i + 1, labels[i], confidences[i] * 100));
        }
        sb.append("\n---\n要启用真实识别,请下载MobileNet模型文件:\n");
        sb.append("1. 从TensorFlow官网下载mobilenet_v1_1.0_224_frozen.pb\n");
        sb.append("2. 下载对应的labels.txt\n");
        sb.append("3. 放到models/目录下");
        resultArea.setText(sb.toString());
    }

    /** Starts the demo on the event dispatch thread. */
    public static void main(String[] args) {
        SwingUtilities.invokeLater(() -> {
            try {
                new ImageRecognitionDemo();
            } catch (IOException e) {
                e.printStackTrace();
            }
        });
    }
}

简单图像识别(不使用深度学习框架)

如果不想使用TensorFlow,这里提供一个简单的主要颜色识别器:它每隔10个像素采样一次,按RGB阈值把采样点归入预设的颜色类别,并统计各颜色所占比例(它只分析颜色分布,不能识别图像中的物体):

package com.example.imagerecognition;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.List;
/**
 * Estimates the dominant colors of an image by sampling pixels on a regular
 * grid and classifying each sample into a coarse named color with RGB thresholds.
 */
public class SimpleColorRecognizer {
    /** Distance, in pixels, between two sampled points on each axis. */
    private static final int SAMPLE_STEP = 10;

    /**
     * Analyzes the dominant colors of the image stored at the given path.
     *
     * @param imagePath path of the image file
     * @return one entry per color found, sorted by descending share of the samples
     * @throws IOException if the file cannot be read or is not a supported image format
     */
    public static List<ColorResult> recognizeMainColors(String imagePath) throws IOException {
        BufferedImage image = ImageIO.read(new File(imagePath));
        if (image == null) {
            // ImageIO.read returns null (rather than throwing) when no reader supports the file.
            throw new IOException("Unsupported or unreadable image file: " + imagePath);
        }
        return analyzeColors(image);
    }

    /**
     * Counts the named color of every sampled pixel and converts the counts to
     * percentages of all samples.
     */
    private static List<ColorResult> analyzeColors(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        Map<String, Integer> colorCount = new HashMap<>();
        int totalSamples = 0;
        // Sample every SAMPLE_STEP-th pixel to keep the work small; pixel (0, 0)
        // is always sampled, so totalSamples is never zero.
        for (int y = 0; y < height; y += SAMPLE_STEP) {
            for (int x = 0; x < width; x += SAMPLE_STEP) {
                String colorName = getColorName(new Color(image.getRGB(x, y)));
                colorCount.merge(colorName, 1, Integer::sum);
                totalSamples++;
            }
        }

        List<ColorResult> results = new ArrayList<>(colorCount.size());
        for (Map.Entry<String, Integer> entry : colorCount.entrySet()) {
            double percentage = (double) entry.getValue() / totalSamples * 100;
            results.add(new ColorResult(entry.getKey(), percentage));
        }
        results.sort((a, b) -> Double.compare(b.getPercentage(), a.getPercentage()));
        return results;
    }

    /**
     * Maps a color to a coarse named category; the first matching rule wins and
     * anything unmatched is reported as "其他".
     */
    private static String getColorName(Color color) {
        int red = color.getRed();
        int green = color.getGreen();
        int blue = color.getBlue();
        if (red > 200 && green > 200 && blue > 200) return "白色";
        if (red < 50 && green < 50 && blue < 50) return "黑色";
        if (red > 200 && green < 100 && blue < 100) return "红色";
        if (red < 100 && green > 200 && blue < 100) return "绿色";
        if (red < 100 && green < 100 && blue > 200) return "蓝色";
        if (red > 200 && green > 200 && blue < 100) return "黄色";
        if (red > 200 && green < 100 && blue > 200) return "紫色";
        if (red < 100 && green > 200 && blue > 200) return "青色";
        if (red > 150 && green < 150 && blue < 150) return "橙色";
        if (red > 150 && green > 150 && blue < 150) return "棕色";
        return "其他";
    }

    /**
     * A named color and the percentage of sampled pixels classified as that color.
     */
    public static class ColorResult {
        private final String colorName;
        private final double percentage;

        public ColorResult(String colorName, double percentage) {
            this.colorName = colorName;
            this.percentage = percentage;
        }

        public String getColorName() { return colorName; }

        public double getPercentage() { return percentage; }

        /** Formats the result as "name: NN.N%". */
        @Override
        public String toString() {
            return String.format("%s: %.1f%%", colorName, percentage);
        }
    }

    /** Prints the color analysis of test.jpg in the working directory. */
    public static void main(String[] args) {
        try {
            List<ColorResult> results = recognizeMainColors("test.jpg");
            System.out.println("图像主要颜色分析:");
            results.forEach(System.out::println);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

使用说明

下载模型文件

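  1. 访问 TensorFlow Models 仓库(github.com/tensorflow/models,参见 research/slim/nets/mobilenet_v1.md)下载 MobileNet V1 1.0_224 模型,将其中的 mobilenet_v1_1.0_224_frozen.pb 放到 models/ 目录下
  2. 下载对应的标签文件,保存为 models/labels.txt(每行一个类别名,行数必须与模型输出的类别数一致)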

运行程序

# Compile all sources. Run this from the source root, i.e. the directory that
# contains com/example/imagerecognition/ (the classes declare that package).
# -encoding UTF-8 is needed because the sources contain Chinese string literals.
javac -encoding UTF-8 -cp "lib/*" -d out com/example/imagerecognition/*.java
# Run with the fully qualified main class name (Linux/macOS classpath separator ':')
java -cp "lib/*:out" com.example.imagerecognition.ImageRecognitionDemo
# On Windows the classpath separator is ';' instead:
# java -cp "lib/*;out" com.example.imagerecognition.ImageRecognitionDemo

注意事项

  1. 模型兼容性:确保TensorFlow模型版本与Java API版本匹配
  2. 性能优化:处理大图像时可使用缩略图来提高速度
  3. 内存管理:及时释放Tensor和Session资源
  4. 异常处理:添加适当的异常处理机制

这个案例提供了完整的图像识别实现,从基础的图像处理到使用深度学习模型进行真实识别,您可以根据需求选择合适的实现方式。

抱歉,评论功能暂时关闭!