Python案例如何实现图像修复?

wen python案例 1

本文目录导读:

Python案例如何实现图像修复?

  1. 使用OpenCV的修复函数(最简单)
  2. 交互式选择修复区域
  3. 自动检测并修复缺陷
  4. 使用深度学习方法(PatchMatch算法)
  5. 使用scikit-image的修复函数
  6. 完整的图像修复工具类
  7. 使用建议
  8. 安装依赖

我来详细讲解Python实现图像修复的几种方法,从基础到进阶。

使用OpenCV的修复函数(最简单)

基于邻域修复(INPAINT_TELEA)

import cv2
import numpy as np
import matplotlib.pyplot as plt
def basic_inpainting():
    # 读取图像
    img = cv2.imread('damaged_image.jpg')
    # 创建掩膜(需要修复的区域标记为白色)
    mask = cv2.imread('mask.jpg', 0)  # 灰度图
    # 或者手动创建掩膜
    # mask = np.zeros(img.shape[:2], np.uint8)
    # mask[100:200, 100:200] = 255  # 标记需要修复的区域
    # 图像修复
    result_telea = cv2.inpaint(img, mask, 3, cv2.INPAINT_TELEA)
    result_ns = cv2.inpaint(img, mask, 3, cv2.INPAINT_NS)
    # 显示结果
    plt.figure(figsize=(15, 5))
    plt.subplot(131)
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    plt.title('原始图像')
    plt.axis('off')
    plt.subplot(132)
    plt.imshow(cv2.cvtColor(result_telea, cv2.COLOR_BGR2RGB))
    plt.title('Telea算法修复')
    plt.axis('off')
    plt.subplot(133)
    plt.imshow(cv2.cvtColor(result_ns, cv2.COLOR_BGR2RGB))
    plt.title('Navier-Stokes算法修复')
    plt.axis('off')
    plt.tight_layout()
    plt.show()
basic_inpainting()

交互式选择修复区域

import cv2
import numpy as np
class InteractiveInpainting:
    def __init__(self, image_path):
        self.img = cv2.imread(image_path)
        self.mask = np.zeros(self.img.shape[:2], np.uint8)
        self.drawing = False
        self.brush_size = 10
    def draw_mask(self, event, x, y, flags, param):
        if event == cv2.EVENT_LBUTTONDOWN:
            self.drawing = True
            cv2.circle(self.mask, (x, y), self.brush_size, 255, -1)
        elif event == cv2.EVENT_MOUSEMOVE and self.drawing:
            cv2.circle(self.mask, (x, y), self.brush_size, 255, -1)
        elif event == cv2.EVENT_LBUTTONUP:
            self.drawing = False
    def run(self):
        cv2.namedWindow('Mask Drawing')
        cv2.setMouseCallback('Mask Drawing', self.draw_mask)
        while True:
            display = self.img.copy()
            display[self.mask == 255] = [0, 0, 255]  # 标记为红色
            cv2.imshow('Mask Drawing', display)
            key = cv2.waitKey(1) & 0xFF
            if key == ord('r'):  # 重置
                self.mask = np.zeros(self.img.shape[:2], np.uint8)
            elif key == ord('+'):  # 增大笔刷
                self.brush_size += 2
            elif key == ord('-'):  # 减小笔刷
                self.brush_size = max(1, self.brush_size - 2)
            elif key == ord('p'):  # 执行修复
                result = cv2.inpaint(self.img, self.mask, 3, cv2.INPAINT_TELEA)
                cv2.imshow('Result', result)
            elif key == ord('s'):  # 保存
                result = cv2.inpaint(self.img, self.mask, 3, cv2.INPAINT_TELEA)
                cv2.imwrite('repaired_image.jpg', result)
                print("图像已保存")
            elif key == 27:  # ESC退出
                break
        cv2.destroyAllWindows()
# 使用交互式工具
# repair = InteractiveInpainting('damaged_image.jpg')
# repair.run()

自动检测并修复缺陷

import cv2
import numpy as np
from skimage import restoration, morphology
def auto_detect_and_repair(image_path):
    # 读取图像
    img = cv2.imread(image_path)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # 方法1:使用边缘检测找出缺陷
    edges = cv2.Canny(gray, 50, 150)
    kernel = np.ones((5,5), np.uint8)
    dilated = cv2.dilate(edges, kernel, iterations=2)
    # 创建掩膜
    mask = dilated.copy()
    # 方法2:使用阈值检测暗区域
    _, thresh = cv2.threshold(gray, 30, 255, cv2.THRESH_BINARY_INV)
    mask = cv2.bitwise_or(mask, thresh)
    # 清理掩膜
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    # 图像修复
    result = cv2.inpaint(img, mask, 3, cv2.INPAINT_TELEA)
    return img, mask, result
# 使用示例
original, mask, repaired = auto_detect_and_repair('old_photo.jpg')
# 显示结果
plt.figure(figsize=(15, 5))
plt.subplot(131), plt.imshow(cv2.cvtColor(original, cv2.COLOR_BGR2RGB)), plt.title('原始图像')
plt.subplot(132), plt.imshow(mask, cmap='gray'), plt.title('检测到的缺陷')
plt.subplot(133), plt.imshow(cv2.cvtColor(repaired, cv2.COLOR_BGR2RGB)), plt.title('修复结果')
plt.show()

使用深度学习方法(PatchMatch算法)

import numpy as np
import cv2
from scipy import ndimage
class PatchMatchInpainting:
    def __init__(self, patch_size=7):
        self.patch_size = patch_size
        self.half_patch = patch_size // 2
    def get_patch(self, img, y, x):
        """获取图像块"""
        return img[y-self.half_patch:y+self.half_patch+1, 
                   x-self.half_patch:x+self.half_patch+1]
    def patch_distance(self, patch1, patch2, mask_patch):
        """计算两个patch之间的距离(只考虑已知像素)"""
        diff = (patch1 - patch2) ** 2
        diff = diff * mask_patch
        return np.sum(diff) / max(np.sum(mask_patch), 1)
    def inpaint(self, image, mask):
        """执行基于PatchMatch的图像修复"""
        img = image.copy().astype(np.float32)
        h, w = img.shape[:2]
        # 找到需要修复的像素
        unknown_pixels = np.where(mask == 255)
        # 对于每个需要修复的像素
        for i in range(len(unknown_pixels[0])):
            y, x = unknown_pixels[0][i], unknown_pixels[1][i]
            # 跳过边缘像素
            if (y < self.half_patch or y >= h - self.half_patch or
                x < self.half_patch or x >= w - self.half_patch):
                continue
            # 获取目标patch
            target_patch = self.get_patch(img, y, x)
            target_mask = self.get_patch(mask.astype(np.float32), y, x)
            target_mask = 1 - target_mask / 255  # 未知区域为1
            best_patch = None
            best_distance = float('inf')
            # 搜索最佳匹配patch
            for sy in range(self.half_patch, h - self.half_patch, 2):
                for sx in range(self.half_patch, w - self.half_patch, 2):
                    # 跳过未知区域
                    if mask[sy, sx] == 255:
                        continue
                    source_patch = self.get_patch(img, sy, sx)
                    distance = self.patch_distance(target_patch, source_patch, target_mask)
                    if distance < best_distance:
                        best_distance = distance
                        best_patch = source_patch
            # 用最佳匹配填充
            if best_patch is not None:
                for c in range(3):
                    img[y-self.half_patch:y+self.half_patch+1, 
                        x-self.half_patch:x+self.half_patch+1, c] = best_patch[:, :, c]
        return img.astype(np.uint8)
# 使用PatchMatch修复
def patchmatch_demo():
    img = cv2.imread('damaged_image.jpg')
    mask = cv2.imread('mask.jpg', 0)
    # 确保掩膜是二值图
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    # 执行修复
    pm_inpainting = PatchMatchInpainting(patch_size=7)
    result = pm_inpainting.inpaint(img, mask)
    # 显示结果
    plt.figure(figsize=(15, 5))
    plt.subplot(131), plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)), plt.title('原始图像')
    plt.subplot(132), plt.imshow(mask, cmap='gray'), plt.title('掩膜')
    plt.subplot(133), plt.imshow(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)), plt.title('PatchMatch修复')
    plt.show()
# patchmatch_demo()

使用scikit-image的修复函数

from skimage import restoration, io, color
import matplotlib.pyplot as plt
def skimage_inpainting():
    # 读取图像
    image = io.imread('damaged_image.jpg')
    mask = io.imread('mask.png', as_gray=True)
    # 转换为浮点型
    image_float = image.astype(float) / 255.0
    # 如果是彩色图像,分别处理每个通道
    if len(image_float.shape) == 3:
        result = np.zeros_like(image_float)
        for channel in range(3):
            result[:,:,channel] = restoration.inpaint_biharmonic(
                image_float[:,:,channel], 
                mask > 0.5,
                multichannel=False
            )
    else:
        result = restoration.inpaint_biharmonic(
            image_float, 
            mask > 0.5
        )
    return image, result
original, repaired = skimage_inpainting()

完整的图像修复工具类

import cv2
import numpy as np
from enum import Enum
class InpaintingMethod(Enum):
    TELEA = 1
    NAVIER_STOKES = 2
    PATCH_MATCH = 3
class ImageRepairTool:
    def __init__(self):
        self.methods = {
            InpaintingMethod.TELEA: cv2.INPAINT_TELEA,
            InpaintingMethod.NAVIER_STOKES: cv2.INPAINT_NS
        }
    def create_mask_from_text(self, image, text="", position=(50, 50)):
        """在图像上添加文字作为需要修复的区域"""
        mask = np.zeros(image.shape[:2], np.uint8)
        if text:
            font = cv2.FONT_HERSHEY_SIMPLEX
            text_size = cv2.getTextSize(text, font, 1, 2)[0]
            cv2.putText(mask, text, position, font, 1, 255, 2)
        return mask
    def remove_object(self, image, bbox):
        """移除指定边界框内的物体"""
        x, y, w, h = bbox
        mask = np.zeros(image.shape[:2], np.uint8)
        mask[y:y+h, x:x+w] = 255
        return mask
    def repair(self, image, mask, method=InpaintingMethod.TELEA, radius=3):
        """执行图像修复"""
        if method in [InpaintingMethod.TELEA, InpaintingMethod.NAVIER_STOKES]:
            return cv2.inpaint(image, mask, radius, self.methods[method])
        else:
            # 使用PatchMatch
            pm = PatchMatchInpainting()
            return pm.inpaint(image, mask)
    def batch_repair(self, image_paths, mask_paths, output_dir):
        """批量修复图像"""
        import os
        os.makedirs(output_dir, exist_ok=True)
        for img_path, mask_path in zip(image_paths, mask_paths):
            img = cv2.imread(img_path)
            mask = cv2.imread(mask_path, 0)
            result = self.repair(img, mask)
            output_path = os.path.join(output_dir, 
                                      f"repaired_{os.path.basename(img_path)}")
            cv2.imwrite(output_path, result)
            print(f"已修复: {img_path}")
# 使用示例
def demo_full_tool():
    tool = ImageRepairTool()
    # 读取图像
    img = cv2.imread('photo.jpg')
    # 方法1:去除文字
    mask_text = tool.create_mask_from_text(img, "Remove this text")
    result1 = tool.repair(img, mask_text)
    # 方法2:去除物体
    mask_object = tool.remove_object(img, (50, 50, 100, 100))
    result2 = tool.repair(img, mask_object, InpaintingMethod.NAVIER_STOKES)
    # 显示结果
    plt.figure(figsize=(15, 5))
    plt.subplot(131), plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    plt.title('原始图像')
    plt.subplot(132), plt.imshow(cv2.cvtColor(result1, cv2.COLOR_BGR2RGB))
    plt.title('去除文字')
    plt.subplot(133), plt.imshow(cv2.cvtColor(result2, cv2.COLOR_BGR2RGB))
    plt.title('去除物体')
    plt.show()
# demo_full_tool()

使用建议

  1. 简单修复:使用OpenCV的cv2.inpaint()函数
  2. 交互式修复:创建GUI让用户手动标记修复区域
  3. 批量处理:使用上述工具类处理多张图像
  4. 深度学习:对于复杂情况,考虑使用OpenCV的DNN模块或PyTorch

安装依赖

pip install opencv-python numpy matplotlib scikit-image scipy

选择哪种方法取决于:

  • 修复区域的复杂度
  • 图像质量要求
  • 处理速度需求
  • 是否需要交互操作

对于大多数日常使用,OpenCV内置的修复函数已经足够好用。

抱歉,评论功能暂时关闭!