使用SVM构建光照鲁棒的颜色分类器:从特征提取到SVM

如何在多变光照条件下实现准确的颜色识别

在计算机视觉领域,颜色分类看似简单,实则充满挑战。一个红色的苹果在阳光下、阴影中、黄昏时分呈现出的颜色差异巨大。如何让机器像人眼一样,在不同光照条件下准确识别颜色?

为什么颜色分类如此具有挑战性?

颜色感知受光照影响极大,这被称为颜色恒常性问题。人类视觉系统能够在一定程度上"自动校正"光照变化,但机器需要明确的算法来处理这一问题。

主要挑战包括:

  • 光照强度变化:强光会使颜色过曝,弱光则使颜色暗淡
  • 色温变化:自然光与人工光源的色温差异显著
  • 阴影和反射:环境光反射会导致颜色失真
  • 设备差异:不同摄像头的色彩响应特性不同

系统架构概览

解决方案采用模块化设计:

复制代码
图像输入 → 预处理 → 特征提取 → SVM分类 → 颜色识别结果

一、鲁棒的颜色特征提取

1.1 颜色空间的选择

RGB空间的局限性

传统的RGB颜色空间对光照变化极其敏感。当光照强度改变时,R、G、B三个通道的值会同步变化,导致颜色识别不稳定。

解决方案:使用感知均匀的颜色空间

python 复制代码
import cv2
import numpy as np

def convert_to_perceptual_spaces(image):
    """
    转换到更符合人类视觉感知的颜色空间
    """
    # HSV: 色调(H)、饱和度(S)、明度(V)分离
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    
    # LAB: 亮度(L)与颜色信息(A,B)分离,设备无关
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    
    # YCrCb: 亮度(Y)与色度分离,广泛用于图像处理
    ycrcb = cv2.cvtColor(image, cv2.COLOR_BGR2YCrCb)
    
    return hsv, lab, ycrcb

1.2 多特征融合策略

基于特征提取的思想,我们采用多层次特征提取:

python 复制代码
class RobustColorFeatureExtractor:
    def __init__(self, histogram_bins=8):
        self.bins = histogram_bins
    
    def extract_histogram_features(self, image):
        """提取颜色直方图特征"""
        hsv, lab, ycrcb = convert_to_perceptual_spaces(image)
        
        features = []
        
        # HSV直方图 - 对色调特别敏感
        h_hist = self._calc_normalized_histogram(hsv[:,:,0], 180, self.bins)  # H通道
        s_hist = self._calc_normalized_histogram(hsv[:,:,1], 256, self.bins)  # S通道
        features.extend([h_hist, s_hist])
        
        # LAB直方图 - 忽略亮度通道
        a_hist = self._calc_normalized_histogram(lab[:,:,1], 256, self.bins)  # A通道
        b_hist = self._calc_normalized_histogram(lab[:,:,2], 256, self.bins)  # B通道
        features.extend([a_hist, b_hist])
        
        return np.concatenate(features)
    
    def extract_color_moments(self, image):
        """提取颜色矩特征 - 对光照变化鲁棒"""
        hsv, lab, _ = convert_to_perceptual_spaces(image)
        
        def calculate_moments(channel):
            """计算均值、标准差、偏度"""
            mean = np.mean(channel)
            std = np.std(channel)
            # 偏度衡量分布不对称性
            skew = np.mean((channel - mean) ** 3) / (std ** 3 + 1e-8)
            return [mean, std, skew]
        
        moments = []
        # 对每个颜色通道计算矩
        for i in range(3):  # HSV各通道
            moments.extend(calculate_moments(hsv[:,:,i]))
        for i in [1, 2]:    # LAB的A、B通道
            moments.extend(calculate_moments(lab[:,:,i]))
            
        return np.array(moments)
    
    def extract_autocorrelation_features(self, image):
        """提取颜色自相关特征 - 专利技术改进"""
        _, lab, _ = convert_to_perceptual_spaces(image)
        
        def spatial_autocorrelation(channel, max_distance=3):
            """计算空间自相关特征"""
            correlations = []
            height, width = channel.shape
            
            for d in range(1, max_distance + 1):
                # 水平自相关
                if d < width:
                    corr_h = np.corrcoef(channel[:, :-d].flatten(), 
                                       channel[:, d:].flatten())[0, 1]
                    correlations.append(corr_h if not np.isnan(corr_h) else 0)
                
                # 垂直自相关
                if d < height:
                    corr_v = np.corrcoef(channel[:-d, :].flatten(), 
                                       channel[d:, :].flatten())[0, 1]
                    correlations.append(corr_v if not np.isnan(corr_v) else 0)
            
            return correlations
        
        # 对LAB的A、B通道计算自相关
        features = []
        features.extend(spatial_autocorrelation(lab[:,:,1]))  # A通道
        features.extend(spatial_autocorrelation(lab[:,:,2]))  # B通道
        
        return np.array(features)
    
    def extract_all_features(self, image):
        """提取所有特征并融合"""
        histogram_feat = self.extract_histogram_features(image)
        moment_feat = self.extract_color_moments(image)
        autocorr_feat = self.extract_autocorrelation_features(image)
        
        # 特征融合
        all_features = np.concatenate([
            histogram_feat,
            moment_feat,
            autocorr_feat
        ])
        
        # L2归一化
        norm = np.linalg.norm(all_features)
        return all_features / (norm + 1e-8)
    
    def _calc_normalized_histogram(self, channel, range_max, bins):
        """计算归一化直方图"""
        hist = cv2.calcHist([channel], [0], None, [bins], [0, range_max])
        return hist.flatten() / (np.sum(hist) + 1e-8)

二、SVM分类器的精心设计

2.1 为什么选择SVM?

支持向量机(SVM)特别适合我们的颜色分类任务:

  • 小样本学习:在颜色分类中,标注数据通常有限
  • 高维处理:我们的特征维度较高(50+维)
  • 非线性分类:RBF核可以处理复杂的颜色决策边界
python 复制代码
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV, cross_val_score
import matplotlib.pyplot as plt
import seaborn as sns

class ColorSVMClassifier:
    def __init__(self):
        self.pipeline = None
        self.feature_extractor = RobustColorFeatureExtractor()
        self.classes_ = None
        
    def create_optimized_pipeline(self):
        """创建优化的SVM管道"""
        return Pipeline([
            ('scaler', StandardScaler()),  # 特征标准化
            ('svm', SVC(
                kernel='rbf',
                class_weight='balanced',  # 处理类别不平衡
                probability=True,         # 启用概率预测
                random_state=42
            ))
        ])
    
    def hyperparameter_tuning(self, X, y):
        """超参数优化"""
        param_grid = {
            'svm__C': [0.1, 1, 10, 100],      # 正则化参数
            'svm__gamma': ['scale', 'auto', 0.001, 0.01, 0.1]  # RBF核参数
        }
        
        grid_search = GridSearchCV(
            self.pipeline, param_grid, 
            cv=5, scoring='accuracy', n_jobs=-1, verbose=1
        )
        
        grid_search.fit(X, y)
        
        print("最佳参数:", grid_search.best_params_)
        print("最佳交叉验证分数: {:.3f}".format(grid_search.best_score_))
        
        return grid_search
    
    def plot_learning_curve(self, X, y):
        """绘制学习曲线评估模型"""
        from sklearn.model_selection import learning_curve
        
        train_sizes, train_scores, test_scores = learning_curve(
            self.pipeline, X, y, cv=5, n_jobs=-1,
            train_sizes=np.linspace(0.1, 1.0, 10)
        )
        
        plt.figure(figsize=(10, 6))
        plt.plot(train_sizes, np.mean(train_scores, axis=1), 'o-', label="训练分数")
        plt.plot(train_sizes, np.mean(test_scores, axis=1), 'o-', label="验证分数")
        plt.xlabel("训练样本数")
        plt.ylabel("准确率")
        plt.title("学习曲线")
        plt.legend()
        plt.grid(True)
        plt.show()

三、数据增强与预处理策略

3.1 光照不变性增强

python 复制代码
class IlluminationAugmentation:
    @staticmethod
    def apply_illumination_changes(image):
        """应用光照变化增强"""
        augmented = [image]
        
        # 亮度变化
        for alpha in [0.3, 0.6, 1.5, 2.0]:
            bright = cv2.convertScaleAbs(image, alpha=alpha, beta=0)
            augmented.append(bright)
        
        # 对比度变化
        for beta in [-50, -20, 20, 50]:
            contrast = cv2.convertScaleAbs(image, alpha=1.0, beta=beta)
            augmented.append(contrast)
        
        # 伽马校正模拟不同光照条件
        for gamma in [0.3, 0.7, 1.5, 2.0]:
            inv_gamma = 1.0 / gamma
            table = np.array([((i / 255.0) ** inv_gamma) * 255
                            for i in range(256)]).astype("uint8")
            gamma_img = cv2.LUT(image, table)
            augmented.append(gamma_img)
        
        return augmented
    
    @staticmethod
    def white_balance_gray_world(image):
        """灰度世界白平衡算法"""
        # 转换为LAB空间进行白平衡
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
        l, a, b = cv2.split(lab)
        
        # 灰度世界假设:平均颜色应该是灰色
        avg_a = np.mean(a)
        avg_b = np.mean(b)
        
        a = cv2.addWeighted(a, 1, a, 0, 128 - avg_a)
        b = cv2.addWeighted(b, 1, b, 0, 128 - avg_b)
        
        balanced_lab = cv2.merge([l, a, b])
        return cv2.cvtColor(balanced_lab, cv2.COLOR_LAB2BGR)

四、完整系统实现与评估

4.1 端到端颜色分类系统

python 复制代码
class ComprehensiveColorClassifier:
    def __init__(self):
        self.feature_extractor = RobustColorFeatureExtractor()
        self.classifier = ColorSVMClassifier()
        self.is_trained = False
        self.label_encoder = {}
        self.reverse_encoder = {}
    
    def prepare_training_data(self, image_paths, labels, augment=True):
        """准备训练数据"""
        all_features = []
        all_encoded_labels = []
        
        # 创建标签编码
        unique_labels = list(set(labels))
        self.label_encoder = {label: idx for idx, label in enumerate(unique_labels)}
        self.reverse_encoder = {idx: label for label, idx in self.label_encoder.items()}
        
        print("提取特征中...")
        for path, label in zip(image_paths, labels):
            image = cv2.imread(path)
            if image is None:
                continue
                
            if augment:
                # 数据增强
                augmented_images = IlluminationAugmentation.apply_illumination_changes(image)
            else:
                augmented_images = [image]
            
            for aug_img in augmented_images:
                # 白平衡预处理
                balanced_img = IlluminationAugmentation.white_balance_gray_world(aug_img)
                
                # 提取特征
                features = self.feature_extractor.extract_all_features(balanced_img)
                all_features.append(features)
                all_encoded_labels.append(self.label_encoder[label])
        
        return np.array(all_features), np.array(all_encoded_labels)
    
    def train(self, image_paths, labels, optimize_hyperparams=True):
        """训练分类器"""
        X, y = self.prepare_training_data(image_paths, labels)
        
        print(f"训练数据形状: {X.shape}")
        print(f"类别分布: {np.bincount(y)}")
        
        if optimize_hyperparams:
            # 超参数优化
            self.classifier.pipeline = self.classifier.create_optimized_pipeline()
            grid_search = self.classifier.hyperparameter_tuning(X, y)
            self.classifier.pipeline = grid_search.best_estimator_
        else:
            self.classifier.pipeline = self.classifier.create_optimized_pipeline()
            self.classifier.pipeline.fit(X, y)
        
        self.is_trained = True
        
        # 绘制学习曲线
        self.classifier.plot_learning_curve(X, y)
    
    def predict(self, image, return_probability=True):
        """预测图像颜色"""
        if not self.is_trained:
            raise ValueError("分类器尚未训练")
        
        # 预处理
        balanced_image = IlluminationAugmentation.white_balance_gray_world(image)
        
        # 特征提取
        features = self.feature_extractor.extract_all_features(balanced_image)
        features = features.reshape(1, -1)
        
        if return_probability:
            probabilities = self.classifier.pipeline.predict_proba(features)[0]
            pred_class_idx = np.argmax(probabilities)
            confidence = probabilities[pred_class_idx]
            
            pred_class = self.reverse_encoder[pred_class_idx]
            return pred_class, confidence, probabilities
        else:
            pred_class_idx = self.classifier.pipeline.predict(features)[0]
            return self.reverse_encoder[pred_class_idx]
    
    def evaluate_on_test_set(self, test_image_paths, test_labels):
        """在测试集上评估性能"""
        correct = 0
        total = len(test_image_paths)
        confidence_scores = []
        
        for path, true_label in zip(test_image_paths, test_labels):
            image = cv2.imread(path)
            if image is None:
                continue
                
            pred_label, confidence, _ = self.predict(image)
            
            if pred_label == true_label:
                correct += 1
            confidence_scores.append(confidence)
        
        accuracy = correct / total
        avg_confidence = np.mean(confidence_scores)
        
        print(f"测试准确率: {accuracy:.3f}")
        print(f"平均置信度: {avg_confidence:.3f}")
        
        return accuracy, avg_confidence

五、实战演示与结果分析

5.1 实际应用示例

python 复制代码
def demo_color_classification():
    """演示颜色分类器的使用"""
    
    # 我们有以下训练数据
    train_images = [
        "data/red/red1.jpg", "data/red/red2.jpg",
        "data/blue/blue1.jpg", "data/blue/blue2.jpg", 
        "data/green/green1.jpg", "data/green/green2.jpg",
        "data/yellow/yellow1.jpg", "data/yellow/yellow2.jpg"
    ]
    
    train_labels = ["red", "red", "blue", "blue", "green", "green", "yellow", "yellow"]
    
    # 创建并训练分类器
    classifier = ComprehensiveColorClassifier()
    classifier.train(train_images, train_labels)
    
    # 测试新图像
    test_image = cv2.imread("test_image.jpg")
    predicted_color, confidence, probabilities = classifier.predict(test_image)
    
    print(f"预测颜色: {predicted_color}")
    print(f"置信度: {confidence:.3f}")
    
    # 可视化预测结果
    plt.figure(figsize=(10, 4))
    
    # 显示原图
    plt.subplot(1, 2, 1)
    plt.imshow(cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB))
    plt.title(f"输入图像 - 预测: {predicted_color}")
    plt.axis('off')
    
    # 显示概率分布
    plt.subplot(1, 2, 2)
    colors = list(classifier.reverse_encoder.values())
    plt.bar(colors, probabilities)
    plt.title('颜色分类概率分布')
    plt.ylabel('概率')
    plt.xticks(rotation=45)
    
    plt.tight_layout()
    plt.show()

# 运行演示
if __name__ == "__main__":
    demo_color_classification()
相关推荐
白杆杆红伞伞4 小时前
02_svm_多分类
机器学习·支持向量机·分类·dlib
极客数模4 小时前
2025年MathorCup 大数据竞赛明日开赛,注意事项!论文提交规范、模板、承诺书正确使用!2025年第六届MathorCup数学应用挑战赛——大数据竞赛
大数据·python·算法·matlab·图论·比赛推荐
.小小陈.4 小时前
数据结构3:复杂度
c语言·开发语言·数据结构·笔记·学习·算法·visual studio
立志成为大牛的小牛4 小时前
数据结构——二十四、图(王道408)
数据结构·学习·程序人生·考研·算法
TT哇4 小时前
【优先级队列(堆)】2.数据流中的第 K ⼤元素(easy)
算法·1024程序员节
Matlab程序猿小助手5 小时前
【MATLAB源码-第303期】基于matlab的蒲公英优化算法(DO)机器人栅格路径规划,输出做短路径图和适应度曲线.
开发语言·算法·matlab·机器人·kmeans
CoderIsArt5 小时前
CORDIC三角计算技术
人工智能·算法·机器学习
立志成为大牛的小牛5 小时前
数据结构——二十九、图的广度优先遍历(BFS)(王道408)
数据结构·数据库·学习·程序人生·考研·算法·宽度优先
Alex艾力的IT数字空间5 小时前
基于PyTorch和CuPy的GPU并行化遗传算法实现
数据结构·人工智能·pytorch·python·深度学习·算法·机器学习