机器学习中出现的不平衡类别问题的分析、常用的解决方案。

定义

不同类别训练数据样本的数量相差很大,导致传统意义上的衡量指标“准确率” 失去了意义。 该问题出现在多个领域中,包括:

  • 信用欺诈
  • 垃圾邮件过滤
  • 疾病筛查

处理方法

数据层面

主要有以下三种方式:Undersample majority class, Oversample minority class, Try generating synthetic samples.

(1)欠(下)采样(undersampling):Undersample majority class

对训练集中多数类随机进行下采样,取多数类中的样本使得正例、反例数目接近,然后进行学习。从$S_{majority}$ 中样本数量为 $N_1$, 使得$N_1 =N_{minority}$

直接这种随机下采样方法是有不足的,比如采样导致原有信息的缺失。所以常常使用EasyEnsemble 算法进行优化。算法步骤:

  • 从多数类中有放回的随机采样$n$ 次,每次选取和少数类别样本相同数目的样本个数,得到 $n$ 个模型
  • 然后每个上述的子集和少数样本合并训练兵训练,可以得到 $n$ 个模型
  • 最终这些模型组合形成一个集成学习系统,模型的结果是这$n$ 个模型的平均值

(2)过(上)采样(oversampling):Oversample minority class

最简单的方法是(如果没有更多的数据,只需要复制现有的数据,并轻微的变化即可)

  1. 图像数据增强 镜像翻转、旋转、平移、缩放、颜色随机扰动、非线性几何变形等; GAN生成新样本;
  2. 文本数据增强
  • 随机过采样 在少数类 $S_{minority}$ 中随机选择一些样本,然后通过复制这些选择的样本得到新的结合,这些集合作为训练中的少数样本数据集。其缺点:容易造成模型的过拟合,因为这些样本都是通过对初始样本复制采样得来,不利于提高模型的泛化性能。

  • SMOTE(Synthetic Minority Oversampling)算法 即,合成少数类过采样技术是对随机过采样的一种改进算法。而SOMT算法的基本思想是对每个少数类样本$x_i$ ,从它的最近邻中随机选择一个样本 $\hat{x_i}$ ( $\hat{x_i}$ 是少数类中的一个样本),然后在$x_i$ 和 $\hat{x_i}$ 之间的连线上随机选择一点作为新合成的少数类样本。算法描述如下: 1).对于少数类中的每一个样本 $x_i$,以欧氏距离为标准计算它到少数类样本集 $S_{minority}$中所有样本的距离,得到其k近邻。 2).根据样本不平衡比例设置一个采样比例以确定采样倍率 $N$,对于每一个少数类样本 $x_i$ ,从其 $k$近邻中随机选择若干个样本,假设选择的是 $\hat{x_i}$ 。 3).对于每一个随机选出来的近邻 $\hat{x_i}$ ,分别与 $x_i$ 按照如下公式构建新的样本。

$$ x_{new} =x_i +rand(0,1) \times (\hat{x_i} -x_i) $$

参考, 有一种代码实现 不足之处 1). 观察到的数目是及其罕见的类别的时候,就不知所措。 2). 每个少数样本都生成新样本,容易发生样本重叠的问题 3). 生成机制存在一定的盲目性,可能有些少数的样本并不具有少数样本的代表性

  • K近邻单词替换 这里用到了word embedding工具,在序列模型中,每个单词都能映射成一个词向量。所以一个单词可以看做高维空间中一个样本点,这也就可以用K近邻来得到和它语义相近的单词了。这里放一个代码链接。可以被看做是处理OOV问题的一种手段。

(3)Try generating synthetic samples

SMOTE is an oversampling method which creates “synthetic” example rather than oversampling by replacements. The minority class is over-sampled by taking each minority class sample and introducing synthetic examples along the line segments joining any/all of the k minority class nearest neighbors. Depending upon the amount of over-sampling required, neighbors from the k nearest neighbors are randomly chosen.

The heart of SMOTE is the construction of the minority classes. The intuition behind the construction algorithm is simple. You have already studied that oversampling causes overfitting, and because of repeated instances, the decision boundary gets tightened. What if you could generate similar samples instead of repeating them? In the original SMOTE paper (linked above) it has been shown that to a machine learning algorithm, these newly constructed instances are not exact copies, and thus it softens the decision boundary and thereby helping the algorithm to approximate the hypothesis more accurately.

image-20210309180256637

(4)数据增强

常用的数据数据增强手段有以下几点:

  • 水平、垂直翻转
  • 90。,180。,270。90。,180。,270。翻转
  • 翻转+旋转
  • 亮度、饱和度、对比度的随机变换
  • 随机裁剪
  • 随机缩放
  • 加模糊(blurring)
  • 加高斯噪声(Gaussian Noise)

注意可能不同的操作会影响原有数据集的分布。

(5)视觉中的数据增强

(1) cutout

cutout being inefficient due to unused pixels. Mixup, on the other hand, makes full use of pixels, but introduces unnatural artifacts

(2) mixup

(3) cutmix

方法

CutMix augmentation strategy: patches are cut and pasted among traing images where ground truth labels are alsose mixed proportionally to the area of the patches.

效果

image-20210714103554646

Synthesizing training data

Some works have explored synthesizing training data for further generalizability. Generating new training samples by Stylizing ImageNet has guided the model to focus more on shape than texture, leading to better classification and object detection performances.

生成的图大概只能生成 shape 相同的,但 texture 相对来说差点。相应的在下游的泛化能力也就体现在这里。

(4) mosaic

(5) autoaugment

操作

cutout, mixup, cutmix, autoaugment 在 cifar-10 和 cifar-100 上的效果看,mixupautoaugment 是比较好的。

  • mixup 将随机的两张样本按比例混合,分类的结果按照比例分配
  • cutout 随机将样本中的部分区域 cut 掉,并且使用 0 填充,分类的结果不变
  • cutmix 将一部分区域 cut 掉,使用随机填充训练集中的其他数据的区域像素值,分类结果按一定的比例分配

(貌似上传 webp 格式的图像会失败)

上述的区别:

mixup 是将两张图按比例进行插值混合样本,会有不自然的情形。

yolov4 的 mosaic 参考的是 cutmix 数据增强,只不过 mosaic 利用的是四张图,cutmix 使用的是两张图。

特点

  • cutmix improves the model robustness against input corruptions and its out-of-distribution (分布之外的)detection performances.
  • mixup reduces the memorization of corrupt labels, increase the robustness to adversarial examples, and stablizes the training of generative adversarial networks (mixup 是有利于 gan 的训练,这个点倒是有点意外。说明不同的 data augmentation 确实是有不同的场景的应用)
  • cutout is a smple regularization technique for convolutional neural networks that involves removing contiguous sections of input images, effectively augmenting the dataset with partially occluded versions of existing samples. (cutout 可以使用在 occluded 比较严重的情况下)

重点:

intuition,如何做的,效果?为什么 work(原理)

原理方面,需要看 cutmix 的论文分析。

然后总结到 albumentations 的work 上

改变评价指标

不要使用 ACC。可以使用 AUC, ROC-Curve,f1-score。

for a general-purpose metric for classification, we recommand area under roc curve (auc )

改变模型

  1. 基于SVM惩罚算法 使用惩罚学习算法增加对少数类别分类错误的代价,一个流行的算法是惩罚性-SVM
1
2
3
4
from sklearn.svm import SVC
clf_3 = SVC(kernel='linear',
            class_weight='balanced', # penalize
            probability=True)
  1. 基于树的算法 决策树通常在不平衡的数据集上表现良好,因为这种层序结构允许其从两个类别去学习。而在目前看来,树集合(随机森林、梯度提升树)总是优于单个决策树,所以可以考虑xgboost 之类的框架。

The final tactic we’ll consider is using tree-based algorithms. Decision trees often perform well on imbalanced datasets becase their hirearchical structure allows them to learn signals from both classes.

In mordern applied machine learning, tree ensembles (Random Forests, …) almost always outperform singular decision trees.

  1. 修改损失函数
  • 如果你用的是keras,模型训练函数中是可以调整class_weight的,可以在class_weight中适当增大正样本的权重。比较忌讳把正样本权重增大到两者损失总量一样,这么设置从来没有一次效果是好的。笔者一般把正样本权重调到负样本权重的1.1~1.5倍,可以取得比之前要好的F1 score。
  • Focal Loss是个值得考虑的目标函数,论文:Focal Loss for Dense Object Detection。该损失函数在目标检测领域取得了良好的处理类别不平衡效果和改善误分类的效果,笔者在文本分类的任务中用了该目标函数,也能取得较大提升。

how to handle unbalanced data

对于这类问题是可以从数据和模型来进行考虑的。

Imbalanced data typically refers to a problem with classification problems where the classes are not represented equally. The accuracy paradox is the name for the exact situation in the introduction to this post.

Data approach Oversample minority class and Undersample majority class

  • Over-sampling increases the number of minority class members in the training set. The advantage of over-sampling is that no information from the original training set is lost, as all observations from the minority and majority classes are kept. On the other hand, it is prone to overfitting. (You can add copies of instances from the under-represented class called over-sampling or more formally sampling with replacement)

SMOTE (Synthetic minority oversampling technique)

由少量的数据生成比较多的数据

How does SMOTE work? SMOTE generates new samples in between existing data points based on their local density and their borders with the other class. Not only does it perform oversampling, but can subsequently use cleaning techniques (undersampling, more on this shortly) to remove redundancy in the end. Below is an illustration for how SMOTE works when studying class data.

Image for post

  • Under-sampling, on contrary to over-sampling, aims to reduce the number of majority samples to balance the class distribution. Since it is removing observations from the original data set, it might discard useful information. (You can delete instances from the over-represented class, called under-sampling.)

Some Rules of Thumb

  • Consider testing under-sampling when you have an a lot data (tens- or hundreds of thousands of instances or more)
  • Consider testing over-sampling when you don’t have a lot of data (tens of thousands of records or less)
  • Consider testing random and non-random (e.g. stratified) sampling schemes.
  • Consider testing different resampled ratios (e.g. you don’t have to target a 1:1 ratio in a binary classification problem, try other ratios)

Try Different Algorithms 基于树的这种结构的模型还是表现比较给力的。 That being said, decision trees often perform well on imbalanced datasets. The splitting rules that look at the class variable used in the creation of the trees, can force both classes to be addressed.

Try Penalized Models Penalized classification imposes an additional cost on the model for making classification mistakes on the minority class during training. These penalties can bias the model to pay more attention to the minority class.

Try Changing Your Performance Metric

使用 precision and recall curves or F1 去评价你的网络效果

  • Precision: A measure of a classifiers exactness.

  • Recall: A measure of a classifiers completeness

  • F1 Score (or F-score): A weighted average of precision and recall.

  • GBC参数 这些参数中,类似于Adaboost,我们把重要参数分为两类,第一类是Boosting框架的重要参数,第二类是弱学习器即CART回归树的重要参数。 n_estimators: 也就是弱学习器的最大迭代次数,或者说最大的弱学习器的个数。 learning_rate: 即每个弱学习器的权重缩减系数ν,也称作步长 对于分类模型,有对数似然损失函数"deviance"和指数损失函数"exponential"两者输入选择。默认是对数似然损失函数"deviance"。一般来说,推荐使用默认的"deviance"。它对二元分离和多元分类各自都有比较好的优化。而指数损失函数等于把我们带到了Adaboost算法。 对于回归模型,有均方差"ls", 绝对损失"lad", Huber损失"huber"和分位数损失“quantile”。默认是均方差"ls"。一般来说,如果数据的噪音点不多,用默认的均方差"ls"比较好。如果是噪音点较多,则推荐用抗噪音的损失函数"huber"。而如果我们需要对训练集进行分段预测的时候,则采用“quantile”。 max_features:可以使用很多种类型的值,默认是"None",意味着划分时考虑所有的特征数. subsample: 选择小于1的比例可以减少方差,即防止过拟合,但是会增加样本拟合的偏差,因此取值不能太低。推荐在[0.5, 0.8]之间,默认是1.0,即不使用子采样。

从 loss 角度出发

加权惩罚,在分类模型中, 使用了 weighted loss。(这个是更加照顾数量少的类别,还是更加照顾数量大的类别?)

weight 如何作用于 loss的原理,可以查看这里:https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html

1
2
3
nSamples = [887, 6130, 480, 317, 972, 101, 128]
normedWeights = [1 - (x / sum(nSamples)) for x in nSamples]
normedWeights = torch.FloatTensor(normedWeights).to(device)

在实际应用中,weight 应该是 数量的导数。数量越少,那么模型对于该类别的重视程度越高,在总体的损失中占比相对大一些。

the more instance the less weight of a class

常见的例子

Many classification problems may have a severe imbalance in the class distribution; nevertheless, looking at common problem domains that are inherently imbalanced will make the ideas and challenges of class imbalance concrete.

  • Fraud Detection.
  • Claim Prediction
  • Default Prediction.
  • Churn Prediction.
  • Spam Detection.
  • Anomaly Detection.
  • Outlier Detection.
  • Intrusion Detection
  • Conversion Prediction.

参考文献

如何处理机器学习中的不平衡类别 分类中解决类别不平衡问题

GBC参数设置 ROC曲线和AUC值 Introduction to Python Ensembles