# CornerNet-Lite-Pytorch **Repository Path**: soverngity/CornerNet-Lite-Pytorch ## Basic Information - **Project Name**: CornerNet-Lite-Pytorch - **Description**: :rotating_light::rotating_light::rotating_light: CornerNet:基于虚拟仿真环境下的自动驾驶交通标志识别 - **Primary Language**: Unknown - **License**: BSD-3-Clause - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 1 - **Forks**: 2 - **Created**: 2019-12-21 - **Last Updated**: 2024-07-07 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README ## CornerNet (CornerNet-Lite)实现基于虚拟仿真环境下的自动驾驶交通标志识别 **Xu Jing** 随着汽车产业变革的推进,自动驾驶已经成为行业新方向。如今,无论是科技巨头还是汽车厂商都在加紧布局自动驾驶,如何保障研发优势、降低投入成本,从而加快实现自动驾驶汽车商业化成为了主要焦点。作为典型的自主式智能系统,自动驾驶是集人工智能、机器学习、控制理论和电子技术等多种技术学科交叉的产物。 **虚拟仿真测试**作为一种新兴测试方法,可快速提供实车路测难以企及的测试里程并模拟任意场景,凭借“低成本、高效率、高安全性”成为验证自动驾驶技术的关键环节,根据各传感器采集到的数据信息作出精准分析和智能决策,从而提高自动驾驶汽车行驶安全性,成为自动驾驶发展过程中不可或缺的技术支撑手段。天津卡达克数据有限公司在此背景下,积极应对产业变革,依托数据资源及相关产业背景,研发智能网联汽车仿真云平台,助推自动驾驶技术快速落地。 自动驾驶系统的环境感知能力是决定仿真结果准确性的重要因素之一,天津卡达克数据有限公司在[DataFountain](https://www.datafountain.cn/)发布**基于虚拟仿真环境下的自动驾驶交通标志识别**赛题的目的旨在推动仿真环境下环境感知算法的科研水平。本数据集以虚拟仿真环境下依托视频传感器数据进行交通标志检测与识别为例,希望在全球范围内发掘和培养自动驾驶算法技术人才。 ### 1.:racehorse:数据集介绍 数据来源于虚拟场景环境下自动驾驶车辆采集的道路交通数据,包括道路周边交通标志牌数据,场景中会有不同的天气状况和行人状况作为干扰因素,采用仿真环境下车辆摄像头采集数据为依靠,指导虚拟仿真环境自动驾驶技术的提升。 提供训练数据集train.csv和训练集打包文件,数据集包含两列,分别为训练数据图像文件名称和label; 提供测试数据集evaluation_private.csv和测试集打包文件,测试数据集包含两列,分别为测试图像文件名称和答案,测试数据集检测label; 提供评测数据集evaluation_public.csv和评测集打包文件,评测数据集包含一列,为评测图像文件名称; **训练数据说明**:每一张图片,都会有标注结果文件中的(.csv)(UTF-8编码)一行。文本文件每行对应于图像中的一个四边形框,以“,”分割不同的字段,具体描述具体格式如下: ``` filename,X1,Y1,X2,Y2,X3,Y3,X4,Y4,type ```
字段名称 | 类型 | 描述 |
---|---|---|
filename | string | 图片名称 |
X1 | int | 左上角X坐标 |
Y1 | int | 左上角Y坐标 |
X2 | int | 右上角X坐标 |
Y2 | int | 右上角Y坐标 |
X3 | int | 右下角X坐标 |
Y3 | int | 右下角Y坐标 |
X4 | int | 左下角X坐标 |
Y4 | int | 左下角Y坐标 |
type | int | 交通标示对应的编号(具体看下表) |
类型 | 对应的编号 |
---|---|
停车场 | 1 |
停车让行 | 2 |
右侧行驶 | 3 |
向左和向右转弯 | 4 |
大客车通行 | 5 |
左侧行驶 | 6 |
慢行 | 7 |
机动车直行和右转弯 | 8 |
注意行人 | 9 |
环岛行驶 | 10 |
直行和右转弯 | 11 |
禁止大客车通行 | 12 |
禁止摩托车通行 | 13 |
禁止机动车通行 | 14 |
禁止非机动车通行 | 15 |
禁止鸣喇叭 | 16 |
立交直行和转弯行驶 | 17 |
限制速度40公里每小时 | 18 |
限速30公里每小时 | 19 |
鸣喇叭 | 20 |
其他 | 0 |
{
"system": {
"dataset": "myData", #数据集
"batch_size": 32, #batch_size
"sampling_function": "cornernet_saccade", #数据增强策略
"train_split": "train", #训练集
"val_split": "test", #验证集
"learning_rate": 0.00025, #初始学习率
"decay_rate": 10, #学习率衰减因子
"val_iter": 100, #每迭代val_iter计算一次val loss
"opt_algo": "adam", #优化器
"prefetch_size": 5, #队列预取数据量
"max_iter": 100000, #训练迭代的总次数
"stepsize": 200, #训练时每迭代stepsize次学习率衰减为原来的1/decay_rate
"snapshot": 5000, #训练每迭代snapshot次保存一次模型参数
"chunk_sizes": [
32
] #每块GPU上处理的图片数,其和等于batch_size
},
"db": {
"rand_scale_min": 0.5, #随机裁减比例[0.5,0.6,0.7,...,1.1]
"rand_scale_max": 1.1,
"rand_scale_step": 0.1,
"rand_scales": null,
"rand_full_crop": true, #随机裁剪
"gaussian_bump": true, #是否使用二维高斯给出惩罚减少量
"gaussian_iou": 0.5, #高斯半径的大小根据object尺寸得到,bounding box和gt box至少0.5IoU
"min_scale": 16,
"view_sizes": [],
"height_mult": 31,
"width_mult": 31,
"input_size": [
255,
255
], #网络输入图片的size
"output_sizes": [
[
64,
64
]
], #网络输出图片的size
"att_max_crops": 30, #和attention map相关的参数设置
"att_scales": [
[
1,
2,
4
]
],
"att_thresholds": [
0.3
], #概率>0.3的location被筛出来
"top_k": 12, #maximum number of crops to process
"num_dets": 12,
"categories": 10, #类别数
"ae_threshold": 0.3, #测试时,仅处理attention maps上 score > as_threshold=0.3 的locations
"nms_threshold": 0.5, #nms的阈值
"max_per_image": 100 #maximum number of objects to predict on a single image
}
}
import torch
import torch.nn as nn
from .py_utils import TopPool, BottomPool, LeftPool, RightPool
from .py_utils.utils import convolution, residual, corner_pool
from .py_utils.losses import CornerNet_Saccade_Loss
from .py_utils.modules import saccade_net, saccade_module, saccade
def make_pool_layer(dim):
return nn.Sequential()
def make_hg_layer(inp_dim, out_dim, modules):
layers = [residual(inp_dim, out_dim, stride=2)]
layers += [residual(out_dim, out_dim) for _ in range(1, modules)]
return nn.Sequential(*layers)
class model(saccade_net):
def _pred_mod(self, dim):
return nn.Sequential(
convolution(3, 256, 256, with_bn=False),
nn.Conv2d(256, dim, (1, 1))
)
def _merge_mod(self):
return nn.Sequential(
nn.Conv2d(256, 256, (1, 1), bias=False),
nn.BatchNorm2d(256)
)
def __init__(self):
stacks = 3
pre = nn.Sequential(
convolution(7, 3, 128, stride=2),
residual(128, 256, stride=2)
)
hg_mods = nn.ModuleList([
saccade_module(
3, [256, 384, 384, 512], [1, 1, 1, 1],
make_pool_layer=make_pool_layer,
make_hg_layer=make_hg_layer
) for _ in range(stacks)
])
cnvs = nn.ModuleList([convolution(3, 256, 256) for _ in range(stacks)])
inters = nn.ModuleList([residual(256, 256) for _ in range(stacks - 1)])
cnvs_ = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)])
inters_ = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)])
att_mods = nn.ModuleList([
nn.ModuleList([
nn.Sequential(
convolution(3, 384, 256, with_bn=False),
nn.Conv2d(256, 1, (1, 1))
),
nn.Sequential(
convolution(3, 384, 256, with_bn=False),
nn.Conv2d(256, 1, (1, 1))
),
nn.Sequential(
convolution(3, 256, 256, with_bn=False),
nn.Conv2d(256, 1, (1, 1))
)
]) for _ in range(stacks)
])
for att_mod in att_mods:
for att in att_mod:
torch.nn.init.constant_(att[-1].bias, -2.19)
hgs = saccade(pre, hg_mods, cnvs, inters, cnvs_, inters_)
tl_modules = nn.ModuleList([corner_pool(256, TopPool, LeftPool) for _ in range(stacks)])
br_modules = nn.ModuleList([corner_pool(256, BottomPool, RightPool) for _ in range(stacks)])
#这路需要修改为自己的类别数,我们这里是21个类别!!!
tl_heats = nn.ModuleList([self._pred_mod(21) for _ in range(stacks)])
br_heats = nn.ModuleList([self._pred_mod(21) for _ in range(stacks)])
for tl_heat, br_heat in zip(tl_heats, br_heats):
torch.nn.init.constant_(tl_heat[-1].bias, -2.19)
torch.nn.init.constant_(br_heat[-1].bias, -2.19)
tl_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)])
br_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)])
tl_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)])
br_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)])
super(model, self).__init__(
hgs, tl_modules, br_modules, tl_heats, br_heats,
tl_tags, br_tags, tl_offs, br_offs, att_mods
)
self.loss = CornerNet_Saccade_Loss(pull_weight=1e-1, push_weight=1e-1)