论文部分内容阅读
摘 要: 为了提升合成表格数据的质量,提出一种简单的方法生成每个类的数据,使用度量损失控制每一类结构化数据的生成,将此方法命名为SCGAN。文章用此方法在二分类问题上进行了尝试。使用三种不同的度量损失在三个真实的数据集上训练生成对抗网络:逐次对每一类数据进行合成,利用合成数据训练分类器模型,使用gmean来评估模型的性能。结果表明,单独生成每一类数据能够提升模型的分类性能。
关键词: 合成数据; 度量损失; 生成对抗网络; 分类器
中图分类号:TP391 文献标识码:A 文章编号:1006-8228(2021)04-25-03
Abstract: In order to improve the quality of tabular data synthesis, a simple method to generate data of each category is proposed, and it is named SCGAN and uses metrics loss to control the generation of structured data of each category. In this paper, the binary classification problem is tried to be solved by this method. By using three different metrics losses, the generative adversarial network is trained on three real datasets that each category of data are synthesized one by one, the classifier model are trained with the synthesized data, and gmean is used to evaluate the performance of the model. The results show that generating each category of data separately can improve the classification performance of the model.
Key words: synthesized data; metrics loss; generative adversarial networks; classifier
0 引言
近年来,生成对抗网络在生成高质量合成图像方面取得了很大的成功[1]。多种数据类型,数据分布不确定,多模态分布,数据不均衡等特点对生成表格型数据带来了挑战[2]。MedGAN提出医学生成对抗网络,来生成逼真的合成病历[3]。TableGAN使用生成对抗网络来合成假表,这些假表在统计上类似于原始表[4]。CTGAN对连续数据进行建模,对离散数据增加条件损失来合成高质量数据[2]。
本文在CTGAN的基础上提出一种无监督的生成对抗网络方法,将衡量指标FID[5],MMD[6],最小二乘作为度量模块应用到生成对抗网络模型中,利用单个类别的数据训练模型生成大量的合成数据,利用梯度惩罚[7]和谱归一化方法[8]來增强模型训练的稳定性。在三个真实的数据集上选取相同数量的生成数据对三种度量方法做了比较,实验结果显示,本文提出的方法能够提升生成数据的质量,提升模型分类的性能。
1 SCGAN
1.1 生成对抗网络
生成对抗网络是一种生成模型[1],包含生成器(G)和判别器(D)两部分。生成器目的是生成逼真的合成数据以最大程度的骗过判别器来达到损失的最小化,判别器争取将真实数据和合成数据分别开来[9]。以下为生成对抗网络的一般形式:
其中[z]是随机输入的噪声,一般为高斯分布中的随机采样点,[pz]是潜在向量[z]的先验分布,[G?]是生成器函数,[D?]是判别器函数。
1.2 度量损失
为了保证生成数据的质量,将三种度量损失:FID,MMD,最小二乘等加入到生成对抗网络模型中,由于最小二乘比较简单,在此我们着重介绍前两种方法。
⑴ Frechet Inception Distance (FID)
FID[5]常用于评估生成器最终生成的图像质量,计算真实数据和合成数据在特征层面的距离,距离越小,说明合成数据与真实数据越相似,以下是FID的计算公式:
其中[Pr],[Pg]分别表示真实数据和生成数据,[C]表示数据的协方差矩阵,[u]表示数据的均值,我们将这种评估方式应用到生成表格数据的生成对抗模型中,参与生成器模型的训练,鼓励生成器学习真实数据的分布。
⑵ Maximum Mean Discrepancy (MMD)
MMD[6]是一种基于最大均方差的统计检验来优化两类样本的分布,常用于评估生成图像的质量。此处,我们使用MMD衡量生成的结构化数据,定义如下:
给定两类结构化数据集,[V=v1,v2,…vm]和[W=w1,w2,…wm],以下为MMD计算公式:
其中[k?]是高斯核函数。
1.3 SCGAN整体流程
整体流程如图1所示。我们使用生成对抗网络对划分好的训练集进行训练,生成指定类别的合成数据,TrainData0表示第一类数据对应生成数据Fake0,TrainData1表示第二类数据对应生成数据Fake1,在G,D网络中我们遵循了CTGAN的网络结构,但是由于我们是生成指定类别的数据,所以在生成器和判别器中去除了条件输入,在G中加入了3种度量损失函数。当生成指定类别的数据后,对生成的数据每个类分别选取500个和1000个样本,最终组成1000和2000大小的训练集,训练分类器(SVM,RF,DT)模型,使用gmean[10]评估分类器的性能。 2 实验
2.1 数据集介绍
本文研究的数据集来自于①Covtype,用来预测森林覆盖类型的多分类数据集,我们选择了Ponderosa Pine,Krummholz这两类数据来测试我们的模型。②Adult是一个从人口普查数据库中提取的个人信息记录的数据集,我们将收入是否超过50k,作为分类的二进制标签。③BitcoinHeist是一个有关比特币交易图的数据集,简记为Bit,从中选取了princetonCerber和montrealCryptoLocker类别的数据,对数据进行二分类。
2.2 方法比较
在我们的SCGAN中,我们对比了使用不同度量下生成样本的质量,而且也与不加度量损失的生成对抗网络和原始的CTGAN进行了对比。SCGAN-FID表示在生成器上使用FID作为度量损失,SCGAN-MMD表示在生成器上使用MMD作为度量损失,SCGAN-LS表示在生成器上使用最小二乘作为度量损失,GAN表示没有加度量损失。值得注意的一点,在三种度量方法和没有使用度量方法的GAN中,除了损失函数的差异,其他迭代次数和网络都是一致的。
2.3 实验结果
在实验中,我们记录了每一种方法以及每一种数据集在每一种基分类器实验结果,为了显现整体的有效性,表1至表3是每一种方法在三个基分类器上的平均结果。从表1和表2中可以看到,在三个真实的数据集上,本文提出的SCGAN整体优于CTGAN,另外,在表3中,我们记录了不使用度量损失下的GAN模型的性能,根据在gmean指标上的评估可以看到,进一步说明了度量损失的有效性。
3 总结
本文提出的SCGAN,分别进行每一类别的数据合成,通过实验表明能够提升模型的分类性能。我们只在二分类问题上进行了尝试,将此方法应用到多类不均衡数据集中是我们接下来的研究重点。
参考文献(References):
[1] Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]//Advances in neural information processing systems,2014:2672-2680
[2] Xu L, Skoularidou M, Cuesta-Infante A, et al. Modeling tabular data using conditional gan[C]//Advances in Neural Information Processing Systems,2019:7335-7345
[3] Choi E, Biswal S, Malin B, et al. Generating multi-label discrete patient records using generative adversarial networks[J]. arXiv preprint arXiv:1703.06490,2017.
[4] Park N, Mohammadi M, Gorde K, et al. Data synthesis based on generative adversarial networks[J].arXiv preprint arXiv:1806.03384,2018.
[5] Heusel M, Ramsauer H, Unterthiner T, et al. Gans trained by a two time-scale update rule converge to a local nash equilibrium[J]. Advances in neural information processing systems,2017.30: 6626-6637
[6] Sutherland D J, Tung H Y, Strathmann H, et al.Generative models and model criticism via optimized maximum mean discrepancy[J]. arXiv preprint arXiv:1611.04488,2016.
[7] Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of wasserstein gans[J]. Advances in neural information processing systems,2017.30: 5767-5777
[8] Miyato T, Kataoka T, Koyama M, et al. Spectral normalization for generative adversarial networks[J].arXiv preprint arXiv:1802.05957,2018
[9] 張重生著.人工智能 人脸识别与搜索[M].电子工业出版社,2020.
[10] Leevy J L, Khoshgoftaar T M, Bauder R A, et al. A survey on addressing high-class imbalance in big data[J]. Journal of Big Data,2018.5(1):42
关键词: 合成数据; 度量损失; 生成对抗网络; 分类器
中图分类号:TP391 文献标识码:A 文章编号:1006-8228(2021)04-25-03
Abstract: In order to improve the quality of tabular data synthesis, a simple method to generate data of each category is proposed, and it is named SCGAN and uses metrics loss to control the generation of structured data of each category. In this paper, the binary classification problem is tried to be solved by this method. By using three different metrics losses, the generative adversarial network is trained on three real datasets that each category of data are synthesized one by one, the classifier model are trained with the synthesized data, and gmean is used to evaluate the performance of the model. The results show that generating each category of data separately can improve the classification performance of the model.
Key words: synthesized data; metrics loss; generative adversarial networks; classifier
0 引言
近年来,生成对抗网络在生成高质量合成图像方面取得了很大的成功[1]。多种数据类型,数据分布不确定,多模态分布,数据不均衡等特点对生成表格型数据带来了挑战[2]。MedGAN提出医学生成对抗网络,来生成逼真的合成病历[3]。TableGAN使用生成对抗网络来合成假表,这些假表在统计上类似于原始表[4]。CTGAN对连续数据进行建模,对离散数据增加条件损失来合成高质量数据[2]。
本文在CTGAN的基础上提出一种无监督的生成对抗网络方法,将衡量指标FID[5],MMD[6],最小二乘作为度量模块应用到生成对抗网络模型中,利用单个类别的数据训练模型生成大量的合成数据,利用梯度惩罚[7]和谱归一化方法[8]來增强模型训练的稳定性。在三个真实的数据集上选取相同数量的生成数据对三种度量方法做了比较,实验结果显示,本文提出的方法能够提升生成数据的质量,提升模型分类的性能。
1 SCGAN
1.1 生成对抗网络
生成对抗网络是一种生成模型[1],包含生成器(G)和判别器(D)两部分。生成器目的是生成逼真的合成数据以最大程度的骗过判别器来达到损失的最小化,判别器争取将真实数据和合成数据分别开来[9]。以下为生成对抗网络的一般形式:
其中[z]是随机输入的噪声,一般为高斯分布中的随机采样点,[pz]是潜在向量[z]的先验分布,[G?]是生成器函数,[D?]是判别器函数。
1.2 度量损失
为了保证生成数据的质量,将三种度量损失:FID,MMD,最小二乘等加入到生成对抗网络模型中,由于最小二乘比较简单,在此我们着重介绍前两种方法。
⑴ Frechet Inception Distance (FID)
FID[5]常用于评估生成器最终生成的图像质量,计算真实数据和合成数据在特征层面的距离,距离越小,说明合成数据与真实数据越相似,以下是FID的计算公式:
其中[Pr],[Pg]分别表示真实数据和生成数据,[C]表示数据的协方差矩阵,[u]表示数据的均值,我们将这种评估方式应用到生成表格数据的生成对抗模型中,参与生成器模型的训练,鼓励生成器学习真实数据的分布。
⑵ Maximum Mean Discrepancy (MMD)
MMD[6]是一种基于最大均方差的统计检验来优化两类样本的分布,常用于评估生成图像的质量。此处,我们使用MMD衡量生成的结构化数据,定义如下:
给定两类结构化数据集,[V=v1,v2,…vm]和[W=w1,w2,…wm],以下为MMD计算公式:
其中[k?]是高斯核函数。
1.3 SCGAN整体流程
整体流程如图1所示。我们使用生成对抗网络对划分好的训练集进行训练,生成指定类别的合成数据,TrainData0表示第一类数据对应生成数据Fake0,TrainData1表示第二类数据对应生成数据Fake1,在G,D网络中我们遵循了CTGAN的网络结构,但是由于我们是生成指定类别的数据,所以在生成器和判别器中去除了条件输入,在G中加入了3种度量损失函数。当生成指定类别的数据后,对生成的数据每个类分别选取500个和1000个样本,最终组成1000和2000大小的训练集,训练分类器(SVM,RF,DT)模型,使用gmean[10]评估分类器的性能。 2 实验
2.1 数据集介绍
本文研究的数据集来自于①Covtype,用来预测森林覆盖类型的多分类数据集,我们选择了Ponderosa Pine,Krummholz这两类数据来测试我们的模型。②Adult是一个从人口普查数据库中提取的个人信息记录的数据集,我们将收入是否超过50k,作为分类的二进制标签。③BitcoinHeist是一个有关比特币交易图的数据集,简记为Bit,从中选取了princetonCerber和montrealCryptoLocker类别的数据,对数据进行二分类。
2.2 方法比较
在我们的SCGAN中,我们对比了使用不同度量下生成样本的质量,而且也与不加度量损失的生成对抗网络和原始的CTGAN进行了对比。SCGAN-FID表示在生成器上使用FID作为度量损失,SCGAN-MMD表示在生成器上使用MMD作为度量损失,SCGAN-LS表示在生成器上使用最小二乘作为度量损失,GAN表示没有加度量损失。值得注意的一点,在三种度量方法和没有使用度量方法的GAN中,除了损失函数的差异,其他迭代次数和网络都是一致的。
2.3 实验结果
在实验中,我们记录了每一种方法以及每一种数据集在每一种基分类器实验结果,为了显现整体的有效性,表1至表3是每一种方法在三个基分类器上的平均结果。从表1和表2中可以看到,在三个真实的数据集上,本文提出的SCGAN整体优于CTGAN,另外,在表3中,我们记录了不使用度量损失下的GAN模型的性能,根据在gmean指标上的评估可以看到,进一步说明了度量损失的有效性。
3 总结
本文提出的SCGAN,分别进行每一类别的数据合成,通过实验表明能够提升模型的分类性能。我们只在二分类问题上进行了尝试,将此方法应用到多类不均衡数据集中是我们接下来的研究重点。
参考文献(References):
[1] Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]//Advances in neural information processing systems,2014:2672-2680
[2] Xu L, Skoularidou M, Cuesta-Infante A, et al. Modeling tabular data using conditional gan[C]//Advances in Neural Information Processing Systems,2019:7335-7345
[3] Choi E, Biswal S, Malin B, et al. Generating multi-label discrete patient records using generative adversarial networks[J]. arXiv preprint arXiv:1703.06490,2017.
[4] Park N, Mohammadi M, Gorde K, et al. Data synthesis based on generative adversarial networks[J].arXiv preprint arXiv:1806.03384,2018.
[5] Heusel M, Ramsauer H, Unterthiner T, et al. Gans trained by a two time-scale update rule converge to a local nash equilibrium[J]. Advances in neural information processing systems,2017.30: 6626-6637
[6] Sutherland D J, Tung H Y, Strathmann H, et al.Generative models and model criticism via optimized maximum mean discrepancy[J]. arXiv preprint arXiv:1611.04488,2016.
[7] Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of wasserstein gans[J]. Advances in neural information processing systems,2017.30: 5767-5777
[8] Miyato T, Kataoka T, Koyama M, et al. Spectral normalization for generative adversarial networks[J].arXiv preprint arXiv:1802.05957,2018
[9] 張重生著.人工智能 人脸识别与搜索[M].电子工业出版社,2020.
[10] Leevy J L, Khoshgoftaar T M, Bauder R A, et al. A survey on addressing high-class imbalance in big data[J]. Journal of Big Data,2018.5(1):42