# 新型优化算法的实现和分析 **Repository Path**: spring-breeze-blows/sam_project ## Basic Information - **Project Name**: 新型优化算法的实现和分析 - **Description**: 神经网络和深度学习课程设计:新型优化算法的实现和分析 - **Primary Language**: Python - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2025-01-14 - **Last Updated**: 2025-01-14 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # 实验环境 Python=3.10.14, Ubuntu=22.04, Cuda=12.1 torch==2.3.0 torchvision==0.18.0 文件夹中也有requirement.txt文件 # 目录结构和文件简述 . ├── cifar //数据集 │ ├── cifar-10-batches-py │ │ ├── batches.meta │ │ ├── data_batch_1 │ │ ├── data_batch_2 │ │ ├── data_batch_3 │ │ ├── data_batch_4 │ │ ├── data_batch_5 │ │ ├── readme.html │ │ └── test_batch │ └── cifar-10-python.tar.gz ├── data │ ├── cifar100.py //用于载入CIFAR100数据集 │ ├── cifar-10-batches-py │ │ ├── batches.meta │ │ ├── data_batch_1 │ │ ├── data_batch_2 │ │ ├── data_batch_3 │ │ ├── data_batch_4 │ │ ├── data_batch_5 │ │ ├── readme.html │ │ └── test_batch │ ├── cifar10.py //用于载入CIFAR10数据集 │ ├── cifar-10-python.tar.gz │ └── __pycache__ │ ├── cifar10.cpython-310.pyc │ └── cifar10.cpython-312.pyc ├── ESAM.py //ESAM优化算法代码实现 ├── model //存储模型的文件夹 │ ├── __pycache__ │ │ ├── pyramidNet.cpython-310.pyc │ │ ├── pyramidNet.cpython-312.pyc │ │ ├── ResNet.cpython-310.pyc │ │ ├── ResNet.cpython-312.pyc │ │ ├── smoothCrossEntropy.cpython-310.pyc │ │ ├── smoothCrossEntropy.cpython-312.pyc │ │ ├── wideResNet.cpython-310.pyc │ │ └── wideResNet.cpython-312.pyc │ ├── pyramidNet.py //pyramidNet实现代码 │ ├── ResNet.py //ResNet实现代码 │ ├── smoothCrossEntropy.py //平滑交叉熵损失计算函数 │ └── wideResNet.py //WideResNet实现代码 ├── __pycache__ │ └── sam.cpython-310.pyc ├── README.md ├── requirements.txt //依赖包 ├── sam.py //sam算法实现代码 ├── trainAdam.py //使用Adam优化算法训练模型 ├── trainESAM.py //使用ESAM优化算法训练模型 ├── train.py //使用SAM优化算法训练模型 ├── trainSGD.py //使用SGD优化算法训练模型 └── utility //存储各种在训练时用到的函数 ├── bypassBn.py //BN归一化实现函数 ├── cutout.py //实现Cutout,随机裁剪代码 ├── initialize.py //初始化函数 ├── loadingBar.py //实现打印在终端的进度条 ├── log.py //实现在终端输出的代码 ├── __pycache__ │ ├── bypassBn.cpython-310.pyc │ ├── cutout.cpython-310.pyc │ ├── initialize.cpython-310.pyc │ ├── loadingBar.cpython-310.pyc │ ├── log.cpython-310.pyc │ └── stepLR.cpython-310.pyc └── stepLR.py //实现学习率调度器代码 # 数据集下载 本次实验采用的数据集是CIFAR10,在运行时会自动下载 # 运行方式 修改模型和参数需要对文件内对应部分进行修改。修改模型需要对各个文件内注释区域去修改。 使用终端运行 python train.py # 优化算法为SAM算法 python trainESAM.py # 优化算法为ESAM算法 python trainAdam.py # 优化算法为Adam算法 python trainSGD.py # 优化算法为SGD算法 # 实验结果 ## 在WideResNet架构下的运行结果 | 优化算法 | 参数修改 | 准确率 | |-----------------|------------------------|---------------| | SAM | 默认 | 97.06 | | SGD | 默认 | 96.25 | | ADAM | 默认 | 75.37 | | SAM | lr=0.01 | 94.47 | | SGD | lr=0.01 | 93.92 | | ADAM | lr=0.01 | 81.89 | | SAM | 扰动系数为3 | 96.69 | | SAM | 动量为0.8 | 96.7 | | SGD | 动量为0.8 | 96.25 | | SAM | 权重衰减为0.001 | 96.79 | | SGD | 权重衰减为0.001 | 96.13 | | ADAM | 权重衰减为0.001 | 68.56 | | ESAM | 默认 | 96.66 | ## 在ResNet架构下的运行结果 | 优化算法 | 参数修改 | 准确率 | |-----------------|------------------------|---------------| | SAM | 默认 | 95.44 | | SGD | 默认 | 95.62 | | ADAM | 默认 | 71.29 | | SAM | lr=0.01 | 94.15 | | SGD | lr=0.01 | 94.06 | | ADAM | lr=0.01 | 80.22 | | SAM | 扰动系数为3 | 93.69 | | SAM | 动量为0.8 | 92.9 | | SGD | 动量为0.8 | 95.22 | | SAM | 权重衰减为0.001 | 95.75 | | SGD | 权重衰减为0.001 | 95.08 | | ADAM | 权重衰减为0.001 | 58.90 | | ESAM | 默认 | 95.34 |