# 基于SSRNetV2的杂草识别 **Repository Path**: wang-yong_ke/weed-classification-based-on-ssrnetv2 ## Basic Information - **Project Name**: 基于SSRNetV2的杂草识别 - **Description**: 本项目使用自组网络SSRNetV2在杂草数据集上进行训练和预测,精度为97.55%,速度为9ms,参数量为1.12M,优于MobileNetV3。 - **Primary Language**: Python - **License**: Apache-2.0 - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 3 - **Created**: 2024-06-26 - **Last Updated**: 2024-06-26 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # 基于SSRNetV2的杂草识别 本项目使用自组网络[SSRNetV2](http://www.ecice06.com/CN/10.19678/j.issn.1000-3428.0061329)在杂草数据集上进行训练和预测,精度为97.55%,速度为9ms,参数量为1.12M,优于MobileNetV3。 [运行链接](https://aistudio.baidu.com/aistudio/projectdetail/2274656?shared=1) # 一、项目背景 为了实现农田无人自动除草,需要把杂草识别算法部署到嵌入式设备上。杂草识别任务对精度、速度和参数量有较高的要求。 # 二、数据简介 本项目使用的[杂草识别数据集](https://aistudio.baidu.com/aistudio/datasetdetail/96823)来自公开杂草数据,包含9个类别,8个杂草类别和1个负类。总共17509张图片,14036张训练图片和3473张验证图片。 ## 1. 解压图片 解压图片到data目录 ```python !unzip -oq -d data/ data/data96823/weeds.zip ``` ## 2. 
查看图片 查看训练列表中的前三张图片 ```python import os, cv2 import numpy as np import matplotlib.pyplot as plt # 创建画布 plt.figure(figsize=(9, 3)) # 绘制图像 train_path = './data/train.txt' # 训练文件路径 with open(train_path, 'r') as f: # 打开训练文件 for i, line in enumerate(f.readlines()): # 遍历每行记录 # 读取数据 if i > 2: # 读取三条数据 break image_path, label_id = line.strip().split() # 读取一行记录 image_path = os.path.join(os.path.split(train_path)[0], image_path) # 获取图像路径 # 读取图像 with open(image_path, 'rb') as f: # 打开图像文件 image = f.read() # 读取图像数据 image = np.frombuffer(image, dtype='uint8') # 读到ndarray缓存 image = cv2.imdecode(image, 1) # 解码为3通道图像 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 转换为RGB图像 # 打印信息 print('图像路径:', image_path, '图像形状:', image.shape, '标签编号:', label_id) # 绘制图像 plt.subplot(1, 3, i + 1) plt.axis('off') plt.imshow(image) # 显示图像 plt.tight_layout() plt.show() ``` 图像路径: ./data/trainImageSet/20170811-110835-1.jpg 图像形状: (256, 256, 3) 标签编号: 0 图像路径: ./data/trainImageSet/20161207-144902-0.jpg 图像形状: (256, 256, 3) 标签编号: 0 图像路径: ./data/trainImageSet/20170704-153719-0.jpg 图像形状: (256, 256, 3) 标签编号: 0 ![png](1.png) ## 3. 读取图片 使用paddle.io.Dataset()接口对训练数据和验证数据进行封装,便于对数据集进行训练。训练数据首先缩放到224×224,然后进行随机裁剪、随机翻转、图片格式由HWC变换到CHW和减均值操作。验证数据首先缩放到224×224,然后进行变换图片格式HWC到CHW和减均值操作。 ```python import os, cv2, paddle import numpy as np class WeedDataset(paddle.io.Dataset): def __init__(self, lists_path, mode='train'): """ 初始化数据集 - lists_path: 列表文件路径 - mode : 数据读取模式 """ super().__init__() # 设置参数 assert os.path.exists(lists_path), f"错误:{lists_path}不存在!" # 检测文件是否存在 assert mode in ['train', 'valid'], "错误:数据读取模式必须为'train'或'valid'!" 
# 检测数据读取模式 self.lists_path = lists_path # 列表文件路径 self.lists_dire = os.path.split(lists_path)[0] # 列表文件目录 # 读取列表 self.image_list = [] # 图像路径列表 self.label_list = [] # 标签编号列表 with open(self.lists_path, 'r') as f: # 打开列表文件 for line in f.readlines(): # 遍历每行记录 image_path, label_id = line.strip().split() # 读取一行记录 self.image_list.append(os.path.join(self.lists_dire, image_path)) # 添加图像路径 self.label_list.append(label_id) # 添加标签编号 # 数据增强 if mode == 'train': self.transforms = paddle.vision.transforms.Compose([ paddle.vision.transforms.Resize(size=(224, 224)), # 变换图像大小 paddle.vision.transforms.RandomCrop(size=224, padding=24), # 随机填充裁剪 paddle.vision.transforms.RandomHorizontalFlip(prob=0.5), # 随机水平翻转 paddle.vision.transforms.Transpose(order=(2, 0, 1)), # 变换图像通道 paddle.vision.transforms.Normalize(mean=[123.675, 116.28, 103.53], std =[58.395 , 57.12 , 57.375]) # 图像数据归一 ]) else: self.transforms = paddle.vision.transforms.Compose([ paddle.vision.transforms.Resize(size=(224, 224)), # 变换图像大小 paddle.vision.transforms.Transpose(order=(2, 0, 1)), # 变换图像通道 paddle.vision.transforms.Normalize(mean=[123.675, 116.28, 103.53], std =[58.395 , 57.12 , 57.375]) # 图像数据归一 ]) def __getitem__(self, index): """ 获取一项数据 params: - index: 数据索引 return: - image: 图像数据 - label: 标签编号 """ # 读取图像 with open(self.image_list[index], 'rb') as f: # 打开图像文件 image = f.read() # 读取图像数据 image = np.frombuffer(image, dtype='uint8') # 读到ndarray缓存 image = cv2.imdecode(image, 1) # 解码为3通道图片 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 转换为RGB图片 # 读取标签 label = np.array(self.label_list[index], dtype='int64') # 增强数据 image = self.transforms(image) return image, label def __len__(self): """ 返回数据总数 """ return len(self.image_list) ``` # 三、模型组网 本项目使用[SSRNetV2](http://www.ecice06.com/CN/10.19678/j.issn.1000-3428.0061329)进行组网,该网络模型使用了浅层网络和多尺度分割技术,在保持精度的条件下,极大的提升了速度,减小了模型参数量,非常适合在嵌入式设备上进行部署。 ## 1. 
网络实现 ```python import paddle class ConvUnit(paddle.nn.Sequential): def __init__(self, in_dim, ou_dim, kernel_size=1, stride=1): """ 初始化卷积单元,带有激活函数 params: - in_dim (int): 输入维度 - ou_dim (int): 输出维度 - kernel_size(int): 卷积大小 - stride (int): 滑动步长 """ super().__init__( paddle.nn.Conv2D(in_dim, ou_dim, kernel_size, stride, padding='same', bias_attr=False), paddle.nn.BatchNorm2D(ou_dim), paddle.nn.LeakyReLU() ) class ConvNorm(paddle.nn.Sequential): def __init__(self, in_dim, ou_dim, kernel_size=1, stride=1): """ 初始化卷积单元,不带激活函数 params: - in_dim (int): 输入维度 - ou_dim (int): 输出维度 - kernel_size(int): 卷积大小 - stride (int): 滑动步长 """ super().__init__( paddle.nn.Conv2D(in_dim, ou_dim, kernel_size, stride, padding='same', bias_attr=False), paddle.nn.BatchNorm2D(ou_dim) ) class ProjUint(paddle.nn.Sequential): def __init__(self, in_dim, ou_dim, stride=1): """ 初始化投影单元,改变特征大小和维度 params: - in_dim(int): 输入维度 - ou_dim(int): 输出维度 - stride(int): 滑动步长 """ super().__init__( paddle.nn.AvgPool2D(kernel_size=stride, stride=stride, padding=0), paddle.nn.Conv2D(in_dim, ou_dim, kernel_size=1, stride=1, padding=0, bias_attr=False), paddle.nn.BatchNorm2D(ou_dim) ) class SSRSplit(paddle.nn.Layer): def __init__(self, in_dim, ou_dim, stride=1, splits=1): """ 初始化分割单元 params: - in_dim (int): 输入维度 - ou_dim (int): 输出维度 - stride (int): 滑动步长 - splits (int): 分割次数,分割尺度为2^n """ super().__init__() # 输入参数检查 assert stride in [1, 2], '错误:滑动步长必须为1或2!' assert splits >= 0 , '错误:分割次数必须大于等于0!' 
# 设置分割变量 self.splits = splits # 分割次数 self.dimensions = [] # 维度列表 self.split_list = [] # 分割列表 # 添加分割列表 split_item = self.add_sublayer( # 添加第一个分割项目 'split_' + str(0), ConvUnit(in_dim, ou_dim, kernel_size=3, stride=stride) ) self.split_list.append(split_item) for i in range(self.splits): # 添加剩余的分割项目 # 添加维度列表 if i < self.splits: ou_dim //= 2 self.dimensions.append(ou_dim) # 添加分割列表 split_item = self.add_sublayer( 'split_' + str(i+1), ConvUnit(ou_dim, ou_dim, kernel_size=3, stride=1) ) self.split_list.append(split_item) def forward(self, x): # 提取特征 x_list = [] # 特征列表 for i, split_item in enumerate(self.split_list): if i < self.splits: x = split_item(x) # x_item, x = paddle.split(x, num_or_sections=[-1, self.dimensions[i]], axis=1) x_item, x = paddle.split(x, num_or_sections=2, axis=1) # 当使用Paddle Lite不支持上面分割时,通道数需要设为2的次幂 x_list.append(x_item) else: x = split_item(x) x_list.append(x) # 合并特征 x = paddle.concat(x_list, axis=1) return x class SSRBasic(paddle.nn.Layer): def __init__(self, in_dim, ch_dim, ou_dim, stride=1, splits=0, direct=True): """ 初始化基础模块 params: - in_dim (int): 输入维度 - ch_dim (int): 通道维度 - ou_dim (int): 输出维度 - stride (int): 滑动步长 - splits (int): 分割次数 - direct(bool): 直连标识 """ super().__init__() # 设置直连标识 self.is_pass = direct # 添加投影单元 self.project = ProjUint(in_dim, ou_dim, stride) # 添加卷积单元 self.convbn0 = ConvUnit(in_dim, ch_dim, kernel_size=1, stride=1) self.convbn1 = SSRSplit(ch_dim, ch_dim, stride, splits) self.convbn2 = ConvNorm(ch_dim, ou_dim, kernel_size=1, stride=1) # 添加激活函数 self.ly_relu = paddle.nn.LeakyReLU() def forward(self, x): # 直连路径 if self.is_pass: y = x else: y = self.project(x) # 卷积路径 x = self.convbn0(x) x = self.convbn1(x) x = self.convbn2(x) x = x + y x = self.ly_relu(x) return x class SSRBlock(paddle.nn.Layer): def __init__(self, in_dim, ch_dim, ou_dim, reduce=0, repeat=0, splits=0): """ 初始化模块结构 params: - in_dim(int): 输入维度 - ch_dim(int): 通道维度 - ou_dim(int): 输出维度 - reduce(int): 缩小次数,缩小倍数为2^n - repeat(int): 重复次数,必须大于等于0 - splits(int): 
分割次数,分割尺度为2^n """ super().__init__() # 设置输入参数 assert reduce >= 0, '错误:缩小次数必须大于等于0!' assert repeat >= 0, '错误:重复次数必须大于等于0!' self.reduce = reduce # 缩小次数 self.repeat = repeat # 重复次数 # 添加缩放项目 self.block_item0 = SSRBasic( in_dim, ch_dim, ou_dim, stride=(2 if self.reduce > 0 else 1), splits=splits, direct=False ) if self.reduce > 1: # 当缩小次数大于等于2时 self.block_item1 = SSRBasic(ou_dim, ch_dim, ou_dim, stride=2, splits=splits, direct=False) # 添加重复项目 if self.repeat > 0: # 当重复次数大于等于1时 self.block_item2 = SSRBasic(ou_dim, ch_dim, ou_dim, stride=1, splits=splits, direct=True) def forward(self, x): # 缩放特征 x = self.block_item0(x) for i in range(1, self.reduce): x = self.block_item1(x) # 当缩小次数大于等于2时 # 重复特征 for i in range(self.repeat): x = self.block_item2(x) # 当重复次数大于等于1时 return x class SSRGroup(paddle.nn.Layer): def __init__(self, group_arch): """ 初始化模组结构 params: - group_arch(list): 特征模组结构 """ super().__init__() # 添加模组列表 self.group_list = [] # 模组列表 for i, block_arch in enumerate(group_arch): group_item = self.add_sublayer( 'group_' + str(i), SSRBlock( in_dim=block_arch[0], ch_dim=block_arch[1], ou_dim=block_arch[2], reduce=block_arch[3], repeat=block_arch[4], splits=block_arch[5] ) ) self.group_list.append(group_item) def forward(self, x): # 提取特征 for group_item in self.group_list: x = group_item(x) return x class SSRNetV2(paddle.nn.Layer): def __init__(self, num_classes=1000): """ 初始化网络模型 params: - num_classes(int): 分类类别数量. 如果num_classes<=0, 分类头部不会被定义. 默认值为: 1000. 
""" super().__init__() # 定义模型结构 group_arch = [ # 输入维度, 通道维度, 输出维度, 缩小次数, 重复次数, 分割次数 [3, 32, 128, 2, 1, 2], [128, 64, 256, 2, 1, 2], [256, 128, 512, 1, 1, 2] ] # 特征块组结构 self.num_feature = 512 # 特征输出维度 self.num_classes = num_classes # 分类类别数量 # 添加骨干网络 self.backbone = SSRGroup(group_arch) # 添加分类头部 if self.num_classes > 0: self.head = paddle.nn.Sequential( paddle.nn.AdaptiveAvgPool2D(output_size=1), paddle.nn.Flatten(start_axis=1), paddle.nn.Linear(self.num_feature, self.num_classes) ) def forward(self, x): # 提取特征 x = self.backbone(x) # 进行分类 if self.num_classes > 0: x = self.head(x) return x ``` ## 2. 网络结构 ```python import paddle # 定义模型 num_classes = 9 # 类别数量 model = SSRNetV2(num_classes) # 声明模型 paddle.summary(model, (1, 3, 224, 224)) # 模型参数 # 输入数据 image = paddle.randn(shape=[1, 3, 224, 224], dtype='float32') # 处理数据 infer = model(image) # 输出特征 print('输出形状:', infer.shape) ``` ------------------------------------------------------------------------------- Layer (type) Input Shape Output Shape Param # =============================================================================== AvgPool2D-9 [[1, 3, 224, 224]] [1, 3, 112, 112] 0 Conv2D-49 [[1, 3, 112, 112]] [1, 128, 112, 112] 384 BatchNorm2D-49 [[1, 128, 112, 112]] [1, 128, 112, 112] 512 Conv2D-50 [[1, 3, 224, 224]] [1, 32, 224, 224] 96 BatchNorm2D-50 [[1, 32, 224, 224]] [1, 32, 224, 224] 128 LeakyReLU-41 [[1, 32, 224, 224]] [1, 32, 224, 224] 0 Conv2D-51 [[1, 32, 224, 224]] [1, 32, 112, 112] 9,216 BatchNorm2D-51 [[1, 32, 112, 112]] [1, 32, 112, 112] 128 LeakyReLU-42 [[1, 32, 112, 112]] [1, 32, 112, 112] 0 Conv2D-52 [[1, 16, 112, 112]] [1, 16, 112, 112] 2,304 BatchNorm2D-52 [[1, 16, 112, 112]] [1, 16, 112, 112] 64 LeakyReLU-43 [[1, 16, 112, 112]] [1, 16, 112, 112] 0 Conv2D-53 [[1, 8, 112, 112]] [1, 8, 112, 112] 576 BatchNorm2D-53 [[1, 8, 112, 112]] [1, 8, 112, 112] 32 LeakyReLU-44 [[1, 8, 112, 112]] [1, 8, 112, 112] 0 SSRSplit-9 [[1, 32, 224, 224]] [1, 32, 112, 112] 0 Conv2D-54 [[1, 32, 112, 112]] [1, 128, 112, 112] 4,096 
BatchNorm2D-54 [[1, 128, 112, 112]] [1, 128, 112, 112] 512 LeakyReLU-45 [[1, 128, 112, 112]] [1, 128, 112, 112] 0 SSRBasic-9 [[1, 3, 224, 224]] [1, 128, 112, 112] 0 AvgPool2D-10 [[1, 128, 112, 112]] [1, 128, 56, 56] 0 Conv2D-55 [[1, 128, 56, 56]] [1, 128, 56, 56] 16,384 BatchNorm2D-55 [[1, 128, 56, 56]] [1, 128, 56, 56] 512 Conv2D-56 [[1, 128, 112, 112]] [1, 32, 112, 112] 4,096 BatchNorm2D-56 [[1, 32, 112, 112]] [1, 32, 112, 112] 128 LeakyReLU-46 [[1, 32, 112, 112]] [1, 32, 112, 112] 0 Conv2D-57 [[1, 32, 112, 112]] [1, 32, 56, 56] 9,216 BatchNorm2D-57 [[1, 32, 56, 56]] [1, 32, 56, 56] 128 LeakyReLU-47 [[1, 32, 56, 56]] [1, 32, 56, 56] 0 Conv2D-58 [[1, 16, 56, 56]] [1, 16, 56, 56] 2,304 BatchNorm2D-58 [[1, 16, 56, 56]] [1, 16, 56, 56] 64 LeakyReLU-48 [[1, 16, 56, 56]] [1, 16, 56, 56] 0 Conv2D-59 [[1, 8, 56, 56]] [1, 8, 56, 56] 576 BatchNorm2D-59 [[1, 8, 56, 56]] [1, 8, 56, 56] 32 LeakyReLU-49 [[1, 8, 56, 56]] [1, 8, 56, 56] 0 SSRSplit-10 [[1, 32, 112, 112]] [1, 32, 56, 56] 0 Conv2D-60 [[1, 32, 56, 56]] [1, 128, 56, 56] 4,096 BatchNorm2D-60 [[1, 128, 56, 56]] [1, 128, 56, 56] 512 LeakyReLU-50 [[1, 128, 56, 56]] [1, 128, 56, 56] 0 SSRBasic-10 [[1, 128, 112, 112]] [1, 128, 56, 56] 0 Conv2D-62 [[1, 128, 56, 56]] [1, 32, 56, 56] 4,096 BatchNorm2D-62 [[1, 32, 56, 56]] [1, 32, 56, 56] 128 LeakyReLU-51 [[1, 32, 56, 56]] [1, 32, 56, 56] 0 Conv2D-63 [[1, 32, 56, 56]] [1, 32, 56, 56] 9,216 BatchNorm2D-63 [[1, 32, 56, 56]] [1, 32, 56, 56] 128 LeakyReLU-52 [[1, 32, 56, 56]] [1, 32, 56, 56] 0 Conv2D-64 [[1, 16, 56, 56]] [1, 16, 56, 56] 2,304 BatchNorm2D-64 [[1, 16, 56, 56]] [1, 16, 56, 56] 64 LeakyReLU-53 [[1, 16, 56, 56]] [1, 16, 56, 56] 0 Conv2D-65 [[1, 8, 56, 56]] [1, 8, 56, 56] 576 BatchNorm2D-65 [[1, 8, 56, 56]] [1, 8, 56, 56] 32 LeakyReLU-54 [[1, 8, 56, 56]] [1, 8, 56, 56] 0 SSRSplit-11 [[1, 32, 56, 56]] [1, 32, 56, 56] 0 Conv2D-66 [[1, 32, 56, 56]] [1, 128, 56, 56] 4,096 BatchNorm2D-66 [[1, 128, 56, 56]] [1, 128, 56, 56] 512 LeakyReLU-55 [[1, 128, 56, 56]] [1, 128, 56, 56] 
0 SSRBasic-11 [[1, 128, 56, 56]] [1, 128, 56, 56] 0 SSRBlock-4 [[1, 3, 224, 224]] [1, 128, 56, 56] 0 AvgPool2D-12 [[1, 128, 56, 56]] [1, 128, 28, 28] 0 Conv2D-67 [[1, 128, 28, 28]] [1, 256, 28, 28] 32,768 BatchNorm2D-67 [[1, 256, 28, 28]] [1, 256, 28, 28] 1,024 Conv2D-68 [[1, 128, 56, 56]] [1, 64, 56, 56] 8,192 BatchNorm2D-68 [[1, 64, 56, 56]] [1, 64, 56, 56] 256 LeakyReLU-56 [[1, 64, 56, 56]] [1, 64, 56, 56] 0 Conv2D-69 [[1, 64, 56, 56]] [1, 64, 28, 28] 36,864 BatchNorm2D-69 [[1, 64, 28, 28]] [1, 64, 28, 28] 256 LeakyReLU-57 [[1, 64, 28, 28]] [1, 64, 28, 28] 0 Conv2D-70 [[1, 32, 28, 28]] [1, 32, 28, 28] 9,216 BatchNorm2D-70 [[1, 32, 28, 28]] [1, 32, 28, 28] 128 LeakyReLU-58 [[1, 32, 28, 28]] [1, 32, 28, 28] 0 Conv2D-71 [[1, 16, 28, 28]] [1, 16, 28, 28] 2,304 BatchNorm2D-71 [[1, 16, 28, 28]] [1, 16, 28, 28] 64 LeakyReLU-59 [[1, 16, 28, 28]] [1, 16, 28, 28] 0 SSRSplit-12 [[1, 64, 56, 56]] [1, 64, 28, 28] 0 Conv2D-72 [[1, 64, 28, 28]] [1, 256, 28, 28] 16,384 BatchNorm2D-72 [[1, 256, 28, 28]] [1, 256, 28, 28] 1,024 LeakyReLU-60 [[1, 256, 28, 28]] [1, 256, 28, 28] 0 SSRBasic-12 [[1, 128, 56, 56]] [1, 256, 28, 28] 0 AvgPool2D-13 [[1, 256, 28, 28]] [1, 256, 14, 14] 0 Conv2D-73 [[1, 256, 14, 14]] [1, 256, 14, 14] 65,536 BatchNorm2D-73 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024 Conv2D-74 [[1, 256, 28, 28]] [1, 64, 28, 28] 16,384 BatchNorm2D-74 [[1, 64, 28, 28]] [1, 64, 28, 28] 256 LeakyReLU-61 [[1, 64, 28, 28]] [1, 64, 28, 28] 0 Conv2D-75 [[1, 64, 28, 28]] [1, 64, 14, 14] 36,864 BatchNorm2D-75 [[1, 64, 14, 14]] [1, 64, 14, 14] 256 LeakyReLU-62 [[1, 64, 14, 14]] [1, 64, 14, 14] 0 Conv2D-76 [[1, 32, 14, 14]] [1, 32, 14, 14] 9,216 BatchNorm2D-76 [[1, 32, 14, 14]] [1, 32, 14, 14] 128 LeakyReLU-63 [[1, 32, 14, 14]] [1, 32, 14, 14] 0 Conv2D-77 [[1, 16, 14, 14]] [1, 16, 14, 14] 2,304 BatchNorm2D-77 [[1, 16, 14, 14]] [1, 16, 14, 14] 64 LeakyReLU-64 [[1, 16, 14, 14]] [1, 16, 14, 14] 0 SSRSplit-13 [[1, 64, 28, 28]] [1, 64, 14, 14] 0 Conv2D-78 [[1, 64, 14, 14]] [1, 256, 14, 14] 16,384 
BatchNorm2D-78 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024 LeakyReLU-65 [[1, 256, 14, 14]] [1, 256, 14, 14] 0 SSRBasic-13 [[1, 256, 28, 28]] [1, 256, 14, 14] 0 Conv2D-80 [[1, 256, 14, 14]] [1, 64, 14, 14] 16,384 BatchNorm2D-80 [[1, 64, 14, 14]] [1, 64, 14, 14] 256 LeakyReLU-66 [[1, 64, 14, 14]] [1, 64, 14, 14] 0 Conv2D-81 [[1, 64, 14, 14]] [1, 64, 14, 14] 36,864 BatchNorm2D-81 [[1, 64, 14, 14]] [1, 64, 14, 14] 256 LeakyReLU-67 [[1, 64, 14, 14]] [1, 64, 14, 14] 0 Conv2D-82 [[1, 32, 14, 14]] [1, 32, 14, 14] 9,216 BatchNorm2D-82 [[1, 32, 14, 14]] [1, 32, 14, 14] 128 LeakyReLU-68 [[1, 32, 14, 14]] [1, 32, 14, 14] 0 Conv2D-83 [[1, 16, 14, 14]] [1, 16, 14, 14] 2,304 BatchNorm2D-83 [[1, 16, 14, 14]] [1, 16, 14, 14] 64 LeakyReLU-69 [[1, 16, 14, 14]] [1, 16, 14, 14] 0 SSRSplit-14 [[1, 64, 14, 14]] [1, 64, 14, 14] 0 Conv2D-84 [[1, 64, 14, 14]] [1, 256, 14, 14] 16,384 BatchNorm2D-84 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024 LeakyReLU-70 [[1, 256, 14, 14]] [1, 256, 14, 14] 0 SSRBasic-14 [[1, 256, 14, 14]] [1, 256, 14, 14] 0 SSRBlock-5 [[1, 128, 56, 56]] [1, 256, 14, 14] 0 AvgPool2D-15 [[1, 256, 14, 14]] [1, 256, 7, 7] 0 Conv2D-85 [[1, 256, 7, 7]] [1, 512, 7, 7] 131,072 BatchNorm2D-85 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048 Conv2D-86 [[1, 256, 14, 14]] [1, 128, 14, 14] 32,768 BatchNorm2D-86 [[1, 128, 14, 14]] [1, 128, 14, 14] 512 LeakyReLU-71 [[1, 128, 14, 14]] [1, 128, 14, 14] 0 Conv2D-87 [[1, 128, 14, 14]] [1, 128, 7, 7] 147,456 BatchNorm2D-87 [[1, 128, 7, 7]] [1, 128, 7, 7] 512 LeakyReLU-72 [[1, 128, 7, 7]] [1, 128, 7, 7] 0 Conv2D-88 [[1, 64, 7, 7]] [1, 64, 7, 7] 36,864 BatchNorm2D-88 [[1, 64, 7, 7]] [1, 64, 7, 7] 256 LeakyReLU-73 [[1, 64, 7, 7]] [1, 64, 7, 7] 0 Conv2D-89 [[1, 32, 7, 7]] [1, 32, 7, 7] 9,216 BatchNorm2D-89 [[1, 32, 7, 7]] [1, 32, 7, 7] 128 LeakyReLU-74 [[1, 32, 7, 7]] [1, 32, 7, 7] 0 SSRSplit-15 [[1, 128, 14, 14]] [1, 128, 7, 7] 0 Conv2D-90 [[1, 128, 7, 7]] [1, 512, 7, 7] 65,536 BatchNorm2D-90 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048 LeakyReLU-75 [[1, 512, 7, 7]] 
[1, 512, 7, 7] 0 SSRBasic-15 [[1, 256, 14, 14]] [1, 512, 7, 7] 0 Conv2D-92 [[1, 512, 7, 7]] [1, 128, 7, 7] 65,536 BatchNorm2D-92 [[1, 128, 7, 7]] [1, 128, 7, 7] 512 LeakyReLU-76 [[1, 128, 7, 7]] [1, 128, 7, 7] 0 Conv2D-93 [[1, 128, 7, 7]] [1, 128, 7, 7] 147,456 BatchNorm2D-93 [[1, 128, 7, 7]] [1, 128, 7, 7] 512 LeakyReLU-77 [[1, 128, 7, 7]] [1, 128, 7, 7] 0 Conv2D-94 [[1, 64, 7, 7]] [1, 64, 7, 7] 36,864 BatchNorm2D-94 [[1, 64, 7, 7]] [1, 64, 7, 7] 256 LeakyReLU-78 [[1, 64, 7, 7]] [1, 64, 7, 7] 0 Conv2D-95 [[1, 32, 7, 7]] [1, 32, 7, 7] 9,216 BatchNorm2D-95 [[1, 32, 7, 7]] [1, 32, 7, 7] 128 LeakyReLU-79 [[1, 32, 7, 7]] [1, 32, 7, 7] 0 SSRSplit-16 [[1, 128, 7, 7]] [1, 128, 7, 7] 0 Conv2D-96 [[1, 128, 7, 7]] [1, 512, 7, 7] 65,536 BatchNorm2D-96 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048 LeakyReLU-80 [[1, 512, 7, 7]] [1, 512, 7, 7] 0 SSRBasic-16 [[1, 512, 7, 7]] [1, 512, 7, 7] 0 SSRBlock-6 [[1, 256, 14, 14]] [1, 512, 7, 7] 0 SSRGroup-2 [[1, 3, 224, 224]] [1, 512, 7, 7] 0 AdaptiveAvgPool2D-2 [[1, 512, 7, 7]] [1, 512, 1, 1] 0 Flatten-4 [[1, 512, 1, 1]] [1, 512] 0 Linear-2 [[1, 512]] [1, 9] 4,617 =============================================================================== Total params: 1,179,145 Trainable params: 1,159,337 Non-trainable params: 19,808 ------------------------------------------------------------------------------- Input size (MB): 0.57 Forward/backward pass size (MB): 215.36 Params size (MB): 4.50 Estimated Total Size (MB): 220.44 ------------------------------------------------------------------------------- 输出形状: [1, 9] # 三、网络训练 本项目使用线性预热加余弦退火策略进行训练,前5轮进行线性预热,学习率由0增加到0.5,后595轮学习率由0.5降到0。总共训练600轮,每批次为128。使用Momentum优化器进行训练,并使用L2正则化,正则化系数为0.0001。 ```python import os, math import paddle, visualdl # 设置参数 batch_size = 128 # 数据批次大小 epoch_num = 602 # 训练轮次总数 eval_freq = 10 # 验证数据频率 epoch_lin = 5 # 线性预热轮数 lrate_max = 0.5 # 最大学习率值 lrate_min = 0.0 # 最小学习率值 l2_decay = 0.0001 # 权重衰减系数 momentum = 0.9 # 优化器动量值 train_path = './data/train.txt' # 训练数据路径 valid_path = 
'./data/valid.txt' # 验证数据路径 save_path = './out/' # 模型保存路径 logs_path = './log/' # 日志保存路径 train_dataset = WeedDataset(train_path, 'train') # 训练的数据集 valid_dataset = WeedDataset(valid_path, 'valid') # 验证的数据集 iters_num = math.ceil(len(train_dataset)/batch_size) # 每轮迭代次数 cos_iters = iters_num * (epoch_num - epoch_lin) # 余弦衰减迭数 lin_iters = iters_num * epoch_lin # 线性预热迭数 # 声明模型 num_classes = 9 # 类别数量 model = SSRNetV2(num_classes) # 声明模型 model = paddle.Model(model) # 封装模型 # 优化算法 cos_lrate = paddle.optimizer.lr.CosineAnnealingDecay(lrate_max, cos_iters, lrate_min) # 余弦学习率 scheduler = paddle.optimizer.lr.LinearWarmup(cos_lrate, lin_iters, lrate_min, lrate_max) # 线性学习率 optimizer = paddle.optimizer.Momentum( # 优化器算法 learning_rate=scheduler, momentum=momentum, weight_decay=paddle.regularizer.L2Decay(l2_decay), parameters=model.parameters() ) ############################################################################################################# # 回调日志 class CallLogs(paddle.callbacks.Callback): def __init__(self, logs_path, save_path): """ 初始化记录器 params: - logs_path(str): 日志路径 - save_path(str): 模型路径 """ super().__init__() # 清空日志目录 for root, dirs, files in os.walk(logs_path, topdown=False): # 遍历目录 for name in files: os.remove(os.path.join(root, name)) # 删除文件 for name in dirs: os.rmdir(os.path.join(root, name)) # 删除目录 # 设置日志变量 self.logs_path = logs_path # 日志路径 self.save_path = save_path # 模型路径 self.trian = visualdl.LogWriter(self.logs_path + 'train/') # 训练日志 self.valid = visualdl.LogWriter(self.logs_path + 'valid/') # 验证日志 self.best_epoch = 0 # 最好轮数 self.best_loss = 1e6 # 最好损失 self.best_acc = 0.0 # 最好精度 def on_epoch_end(self, epoch, logs): """ 轮次结束调用 params: - epoch(int): 当前轮数 - logs(dict): 日志字典 """ # 添加训练标量 self.epoch = epoch + 1 # 设置当前轮数 self.trian.add_scalar(tag='Trian/Loss', step=self.epoch, value=logs['loss'][0]) # 添加训练损失 self.trian.add_scalar(tag='Trian/Accuracy', step=self.epoch, value=logs['acc']) # 添加训练精度 def on_eval_end(self, logs): """ 验证结束调用 params: - logs(dict): 
日志字典 """ # 添加验证标量 self.valid.add_scalar(tag='Trian/Loss', step=self.epoch, value=logs['loss'][0]) # 添加验证损失 self.valid.add_scalar(tag='Trian/Accuracy', step=self.epoch, value=logs['acc']) # 添加验证精度 # 保存最好参数 if logs['acc'] >= self.best_acc: # 是否最好精度 self.best_epoch = self.epoch # 更新最好轮数 self.best_loss = logs['loss'][0] # 更新最好损失 self.best_acc = logs['acc'] # 更新最好精度 self.model.save(self.save_path + 'great') # 保存训练参数 def on_train_end(self, logs): """ 训练结束调用 params: - logs(dict): 日志字典 """ message = f'ENDED - best epoch: {self.best_epoch}, '+\ f'best loss: {self.best_loss:.6f}, '+\ f'best accuracy: {self.best_acc:.4f}, '+\ f'logs path: {self.logs_path}' # 保存信息 print(message) # 打印信息 with open(self.logs_path + 'log.txt', 'w') as f: # 打开文件 f.write(message + '\n') # 写入信息 call_logs = CallLogs(logs_path, save_path) # 回调日志 ############################################################################################################# # 配置模型 model.prepare( optimizer=optimizer, # 优化算法 loss=paddle.nn.CrossEntropyLoss(), # 损失函数 metrics=paddle.metric.Accuracy() # 评估方法 ) # 训练模型 model.fit( train_dataset, # 训练数据 valid_dataset, # 验证数据 batch_size=batch_size, # 批次大小 epochs =epoch_num, # 训练轮次 eval_freq =eval_freq, # 验证频率 save_dir =save_path, # 保存路径 save_freq =epoch_num, # 保存频率 verbose =1, # 打印日志 shuffle =True, # 打乱数据 callbacks =call_logs # 回调日志 ) ``` # 四、模型保存 ## 1. 网络保存 把动态图训练的模型和权重保存为静态图的模型和权重,便于在嵌入式设备上进行移植。 ```python import paddle # 设置参数 load_path = './out/great.pdparams' # 模型权重路径 save_path = './out/model' # 模型保存路径 # 声明模型 num_classes = 9 # 类别数量 model = SSRNetV2(num_classes) # 声明模型 model.set_state_dict(paddle.load(load_path)) # 加载权重 # 保存模型 paddle.jit.save( model, save_path, input_spec=[paddle.static.InputSpec([None, 3, 224, 224], 'float32', 'x')] ) print('save path:', save_path) ``` ## 2. 
网络验证 ```python import paddle # 设置参数 valid_path = './data/valid.txt' # 验证数据路径 model_path = './out/model' # 模型加载路径 # 读取数据 valid_dataset = WeedDataset(valid_path, 'valid') # 验证数据 # 声明模型 model = paddle.jit.load(model_path) # 加载模型 model = paddle.Model(model, inputs=paddle.static.InputSpec([None, 3, 224, 224], 'float32', 'x')) # 封装模型 # 配置模型 model.prepare(metrics=paddle.metric.Accuracy()) # 评估模型 result = model.evaluate(valid_dataset, verbose=1) ``` Eval begin... The loss value printed in the log is the current batch, and the metric is the average value of previous step. step 3473/3473 [==============================] - acc: 0.9755 - 9ms/step Eval samples: 3473 ## 3. 网络预测 ```python import paddle # 设置参数 valid_path = './data/valid.txt' # 验证数据路径 model_path = './out/model' # 模型加载路径 # 读取数据 valid_dataset = WeedDataset(valid_path, 'valid') # 验证数据 # 声明模型 model = paddle.jit.load(model_path) # 加载模型 model = paddle.Model(model, inputs=paddle.static.InputSpec([None, 3, 224, 224], 'float32', 'x')) # 封装模型 # 配置模型 model.prepare() # 模型预测 result = model.predict(valid_dataset) # 显示结果 sample = 1 # 样本编号 print(f'sample: {sample}, infer: {np.argmax(result[0][sample - 1])}') ``` Predict begin... step 3473/3473 [==============================] - 9ms/step Predict samples: 3473 sample: 1, infer: 0 # 五、项目总结 本项目为杂草图像识别任务,使用SSRNetV2能取得了较好的识别精度和速度,并且模型参数量也较小,比较适合部署在嵌入式设备上。 # 六、个人简介 三峡大学 计算机与信息学院 研究生 研究方向:人工智能、嵌入式系统、计算机视觉 [盛夏夜博客](https://www.cnblogs.com/d442130165/) [盛夏夜码云](https://gitee.com/shengxiaye)