diff --git a/tutorials/source_zh_cn/advanced_use/membership_inference.md b/tutorials/source_zh_cn/advanced_use/membership_inference.md
index 4f887583fbae6643fed9e8782bd1eae0b80a55b3..1a3ad0e5ec03f4ffca5ac9bebc7033d41fdc398a 100644
--- a/tutorials/source_zh_cn/advanced_use/membership_inference.md
+++ b/tutorials/source_zh_cn/advanced_use/membership_inference.md
@@ -1,32 +1,32 @@
-# 成员推理攻击
-
-`Linux` `Ascend` `全流程` `初级` `中级` `高级`
+# 成员推理
-- [成员推理攻击](#成员推理攻击)
+- [成员推理](#成员推理)
- [概述](#概述)
- [实现阶段](#实现阶段)
- [导入需要的库文件](#导入需要的库文件)
- [加载数据集](#加载数据集)
- [建立模型](#建立模型)
- - [运用MembershipInference](#运用membershipinference)
+ - [运用MembershipInference进行隐私安全评估](#运用membershipinference进行隐私安全评估)
- [参考文献](#参考文献)
-
+
+
## 概述
-成员推理攻击是一种窃取用户数据隐私的方法。隐私指的是单个用户的某些属性,一旦泄露可能会造成人身损害、名誉损害等后果。通常情况下,用户的隐私数据会作保密处理,但我们可以利用非敏感信息来进行推测。例如:”抽烟的人更容易得肺癌“,这个信息不属于隐私信息,但如果知道“张三抽烟”,就可以推断“张三”更容易得肺癌,这就是成员推理。
+成员推理是一种推测用户隐私数据的方法。隐私指的是单个用户的某些属性,一旦泄露可能会造成人身损害、名誉损害等后果。通常情况下,用户的隐私数据会作保密处理,但我们可以利用非敏感信息来进行推测。如果我们知道了某个私人俱乐部的成员都喜欢戴紫色墨镜、穿红色皮鞋,那么我们遇到一个戴紫色墨镜且穿红色皮鞋(非敏感信息)的人,就可以推断他/她很可能是这个私人俱乐部的成员(敏感信息)。这就是成员推理。
-机器学习/深度学习的成员推理攻击(Membership Inference),指的是攻击者拥有模型的部分访问权限(黑盒、灰盒或白盒),能够获取到模型的输出、结构或参数等部分或全部信息,并基于这些信息推断某个样本是否属于模型的训练集。
+机器学习/深度学习的成员推理(Membership
+Inference),指的是攻击者拥有模型的部分访问权限(黑盒、灰盒或白盒),能够获取到模型的输出、结构或参数等部分或全部信息,并基于这些信息推断某个样本是否属于模型的训练集。利用成员推理,我们可以评估机器学习/深度学习模型的隐私数据安全。如果在成员推理下能正确识别出60%+的样本,那么我们认为该模型存在隐私数据泄露风险。
-这里以VGG16模型,CIFAR-100数据集为例,说明如何使用MembershipInference。本教程使用预训练的模型参数进行演示,这里仅给出模型结构、参数设置和数据集预处理方式。
+这里以VGG16模型,CIFAR-100数据集为例,说明如何使用MembershipInference进行模型隐私安全评估。本教程使用预训练的模型参数进行演示,这里仅给出模型结构、参数设置和数据集预处理方式。
>本例面向Ascend 910处理器,您可以在这里下载完整的样例代码:
>
->
+>
## 实现阶段
@@ -51,7 +51,7 @@ from mindspore.common.initializer import initializer
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vision
-from mindarmour.diff_privacy.evaluation.membership_inference import MembershipInference
+from mindarmour import MembershipInference
from mindarmour.utils import LogUtil
LOGGER = LogUtil.get_instance()
@@ -178,7 +178,7 @@ def vgg16(num_classes=1000, args=None, phase="train"):
return net
```
-### 运用MembershipInference
+### 运用MembershipInference进行隐私安全评估
1. 构建VGG16模型并加载参数文件。
这里直接加载预训练完成的VGG16参数配置,您也可以使用如上的网络自行训练。
@@ -196,37 +196,34 @@ def vgg16(num_classes=1000, args=None, phase="train"):
args.padding = 0
args.pad_mode = "same"
args.weight_decay = 5e-4
- args.loss_scale = 1.0
-
- data_path = "./cifar-100-binary" # Replace your data path here.
- pre_trained = "./VGG16-100_781.ckpt" # Replace your pre trained checkpoint file here.
+ args.loss_scale = 1.0
# Load the pretrained model.
net = vgg16(num_classes=100, args=args)
- loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+ loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9,
weight_decay=args.weight_decay, loss_scale=args.loss_scale)
load_param_into_net(net, load_checkpoint(args.pre_trained))
model = Model(network=net, loss_fn=loss, optimizer=opt)
```
-2. 加载CIFAR-100数据集,按8:2分割为成员推理攻击模型的训练集和测试集。
+2. 加载CIFAR-100数据集,按8:2分割为成员推理模型的训练集和测试集。
```python
# Load and split dataset.
train_dataset = vgg_create_dataset100(data_home=args.data_path, image_size=(224, 224),
- batch_size=64, num_samples=10000, shuffle=False)
+ batch_size=64, num_samples=5000, shuffle=False)
test_dataset = vgg_create_dataset100(data_home=args.data_path, image_size=(224, 224),
- batch_size=64, num_samples=10000, shuffle=False, training=False)
+ batch_size=64, num_samples=5000, shuffle=False, training=False)
train_train, eval_train = train_dataset.split([0.8, 0.2])
train_test, eval_test = test_dataset.split([0.8, 0.2])
msg = "Data loading completed."
LOGGER.info(TAG, msg)
```
-3. 配置攻击参数和评估参数
+3. 配置推理参数和评估参数
- 设置用于成员推理评估的方法和参数。目前支持的推理方法有:KNN、LR、MLPClassifier和RandomForest Classifier。
+ 设置用于成员推理的方法和参数。目前支持的推理方法有:KNN、LR、MLPClassifier和RandomForestClassifier。推理参数数据类型使用list,各个方法使用key为"method"和"params"的字典表示。
```python
config = [
@@ -263,35 +260,42 @@ def vgg16(num_classes=1000, args=None, phase="train"):
]
```
- 设置评价指标,目前支持3种评价指标。包括:
- * 准确率:accuracy。
- * 精确率:precision。
- * 召回率:recall。
-
- ```python
- metrics = ["precision", "accuracy", "recall"]
- ```
+ 我们约定标签为数据集的是正类,标签为测试集的是负类。设置评价指标,目前支持3种评价指标。包括:
+ * 准确率:accuracy,正确推理的数量占全体样本中的比例。
+ * 精确率:precision,正确推理的正类样本占所有推理为正类中的比例。
+ * 召回率:recall,正确推理的正类样本占全体正类样本的比例。
+ 在样本数量足够大时,如果上述指标均大于0.6,我们认为目标模型就存在隐私泄露的风险。
-4. 训练成员推理攻击模型,并给出评估结果。
+```python
+ metrics = ["precision", "accuracy", "recall"]
+```
+
+4. 训练成员推理模型,并给出评估结果。
```python
- attacker = MembershipInference(model) # Get attack model.
+ inference = MembershipInference(model) # Get inference model.
- attacker.train(train_train, train_test, config) # Train attack model.
+ inference.train(train_train, train_test, config) # Train inference model.
msg = "Membership inference model training completed."
LOGGER.info(TAG, msg)
- result = attacker.eval(eval_train, eval_test, metrics) # Eval metrics.
+ result = inference.eval(eval_train, eval_test, metrics) # Eval metrics.
count = len(config)
for i in range(count):
print("Method: {}, {}".format(config[i]["method"], result[i]))
```
5. 实验结果。
+ 执行如下指令,开始成员推理训练和评估:
+
+ ```
+ python membership_inference_example.py --data_path ./cifar-100-binary/ --pre_trained ./VGG16-100_781.ckpt
+ ```
成员推理的指标如下所示,各数值均保留至小数点后四位。
- 以第一行结果为例:在使用lr(逻辑回归分类)进行成员推理时,推理的准确率(accuracy)为0.7132,推理精确率(precision)为0.6596,正类样本召回率为0.8810。在二分类任务下,指标表明我们的成员推理是有效的。
+ 以第一行结果为例:在使用lr(逻辑回归分类)进行成员推理时,推理的准确率(accuracy)为0.7132,推理精确率(precision)为0.6596,正类样本召回率为0.8810,说明lr有71.32%的概率能正确分辨一个数据样本是否属于目标模型的训练数据集。
+ 在二分类任务下,指标表明成员推理是有效的,即该模型存在隐私泄露的风险。
```
Method: lr, {'recall': 0.8810,'precision': 0.6596,'accuracy': 0.7132}