# vit_visual_transformer **Repository Path**: SailorCoder/vit_visual_transformer ## Basic Information - **Project Name**: vit_visual_transformer - **Description**: vit模型学习 - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: main - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2026-01-24 - **Last Updated**: 2026-01-24 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # Vision Transformer (ViT) 实现 ## 项目简介 本项目是基于 PyTorch 实现的 Vision Transformer (ViT) 模型,完全复现了 Google Research 原始论文 [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) 中的架构设计。Vision Transformer 将Transformer架构成功应用于图像分类任务,打破了CNN在计算机视觉领域的垄断地位,为视觉任务提供了新的研究范式。 该项目提供了完整的模型实现、训练脚本、可视化工具以及多种预训练配置,支持研究人员和开发者快速上手ViT模型并进行二次开发。代码结构清晰,模块化程度高,方便进行实验和扩展。同时,项目还支持混合模型架构,可以将ResNet作为特征提取器与Transformer编码器结合使用。 ## 功能特性 本项目实现了完整的Vision Transformer训练流程,包含以下核心功能: - **多种模型配置**:支持B16、B32、L16、L32、R50_B16、H14等多种ViT变体配置,每种配置对应不同的模型深度、隐藏层维度和注意力头数量,满足不同规模和精度的需求 - **混合模型支持**:提供了ResNet-V2作为backbone的混合模型架构,结合卷积特征提取与Transformer编码的优势,在某些场景下可获得更好的性能 - **分布式训练**:完整的分布式训练支持,可利用多GPU或多节点计算资源加速训练过程,适合大规模数据集训练 - **学习率调度**:内置多种学习率调度策略,包括Constant、Warmup+Constant、Warmup+Linear、Warmup+Cosine四种模式,支持自定义预热步数 - **注意力可视化**:提供注意力图可视化工具,可以直观展示模型在图像不同区域的注意力分布,帮助理解模型决策过程 - **模型加载**:支持加载预训练权重进行微调或推理,提供与原论文权重格式兼容的加载器 ## 环境要求 在开始使用本项目之前,请确保您的环境满足以下要求: - Python 3.6 或更高版本 - PyTorch 1.4+ (推荐使用1.7+以获得最佳性能) - torchvision (用于数据增强和预处理) - numpy - tensorboardX (用于训练可视化,可选) 建议使用Anaconda或Miniconda创建独立的虚拟环境进行开发,以避免依赖冲突。以下是推荐的环境配置流程: ```bash # 创建虚拟环境 conda create -n vit python=3.8 conda activate vit # 安装PyTorch (根据CUDA版本选择合适的安装命令) pip install torch torchvision # 安装其他依赖 pip install numpy tensorboardX # 克隆项目 git clone https://gitee.com/SailorCoder/vit_visual_transformer.git cd vit_visual_transformer ``` ## 快速开始 ### 模型配置 项目提供了七种预设的模型配置,您可以根据任务需求和计算资源选择合适的配置: - **get_testing()**:用于调试的最小配置,包含2层编码器 - **get_b16_config()**:标准的ViT-Base/16配置,包含12层编码器,隐藏层维度768,12个注意力头 - **get_r50_b16_config()**:ResNet与ViT的混合配置,使用ResNet-V2作为特征提取器 - **get_b32_config()**:ViT-Base/32配置,与B16结构相同但使用更大的patch尺寸 - **get_l16_config()**:ViT-Large/16配置,包含24层编码器,隐藏层维度1024,16个注意力头 - **get_l32_config()**:ViT-Large/32配置 - **get_h14_config()**:超大规模配置,包含14层编码器,更高的隐藏层维度 ### 基本用法 以下代码展示了如何使用本项目构建和训练ViT模型: ```python import torch from models.configs import get_b16_config from models.modeling import VisionTransformer # 创建模型配置 config = get_b16_config() # 初始化模型 model = VisionTransformer(config, img_size=224, num_classes=1000, zero_head=False, vis=False) # 准备输入数据 (batch_size, channels, height, width) x = torch.randn(2, 3, 224, 224) labels = torch.tensor([0, 1]) # 类别标签 # 前向传播 logits = model(x, labels=labels) print(f"输出形状: {logits.shape}") # [2, 1000] # 仅推理模式 with torch.no_grad(): logits = model(x) predictions = logits.argmax(dim=-1) print(f"预测类别: {predictions}") ``` ### 加载预训练权重 项目支持加载Google官方发布的预训练权重进行微调,权重文件通常为.npy格式: ```python from models.modeling import VisionTransformer import numpy as np model = VisionTransformer(get_b16_config(), img_size=224, num_classes=1000) weights = np.load("pretrained_weights/imagenet21k+imagenet2012_ViT-B_16.npy") model.load_from(weights) print("预训练权重加载成功!") ``` ## 训练自定义数据集 项目提供了完整的训练脚本,支持单卡和多卡训练。以下是使用自定义数据集进行训练的步骤: ### 1. 准备数据 将您的数据集整理为以下格式,或修改`utils/data_utils.py`以适配您的数据格式: ``` dataset/ ├── train/ │ ├── class_1/ │ │ ├── img1.jpg │ │ └── img2.jpg │ └── class_2/ └── val/ ├── class_1/ └── class_2/ ``` ### 2. 修改训练配置 在`train.py`中修改数据路径和相关参数: ```python # 数据路径 args.data_dir = "path/to/your/dataset" args.num_classes = 10 # 根据您的数据集类别数修改 args.img_size = 224 # 模型配置 args.model = "ViT-B_16" argspretrained = "pretrained_weights/imagenet21k+imagenet2012_ViT-B_16.npy" # 训练超参数 args.batch_size = 32 args.learning_rate = 0.01 args.epochs = 100 args.warmup_steps = 500 ``` ### 3. 启动训练 ```bash # 单卡训练 python train.py --name vit_experiment --model ViT-B_16 --data_dir ./dataset --epochs 100 --batch_size 32 # 多卡训练 (使用torchrun) torchrun --nproc_per_node=4 train.py --name vit_distributed --model ViT-B_16 --data_dir ./dataset --epochs 100 --batch_size 64 ``` ### 4. 监控训练过程 使用TensorBoard监控训练指标: ```bash tensorboard --logdir=output ``` 训练过程中会自动记录损失值、准确率、学习率等关键指标,方便进行实验对比和调参。 ## 注意力可视化 项目提供了Jupyter Notebook (`visualize_attention_map.ipynb`) 用于可视化模型注意力分布。通过设置`vis=True`启用注意力模块: ```python from models.modeling import VisionTransformer model = VisionTransformer(get_b16_config(), img_size=224, vis=True) ``` 可视化功能可以展示: - 多头注意力在不同头部的注意力模式 - 注意力权重在空间维度上的热力图分布 - 验证模型关注的关键图像区域 ## 项目结构 ``` vit_visual_transformer/ ├── models/ │ ├── configs.py # 模型配置文件,定义各种ViT变体的超参数 │ ├── modeling.py # ViT核心实现,包含Attention、MLP、Encoder等模块 │ └── modeling_resnet.py # ResNet-V2实现,用于混合模型 ├── utils/ │ ├── data_utils.py # 数据加载器实现 │ ├── dist_util.py # 分布式训练工具函数 │ └── scheduler.py # 学习率调度器实现 ├── img/ # 文档和可视化资源 │ ├── figure1.png │ ├── figure2.png │ └── figure3.png ├── train.py # 主训练脚本 ├── visualize_attention_map.ipynb # 注意力可视化Notebook └── README.md # 项目说明文档 ``` ## 核心模块说明 ### VisionTransformer类 主模型类,封装了完整的ViT前向传播逻辑: - `__init__(config, img_size, num_classes, zero_head, vis)`:初始化模型 - `forward(x, labels=None)`:前向传播,支持训练和推理模式 - `load_from(weights)`:加载预训练权重 ### Encoder类 Transformer编码器实现,包含多个堆叠的Block: - 支持配置参数自定义编码器深度和宽度 - 支持可视化模式输出中间注意力状态 ### Attention类 多头注意力机制实现: - 支持设置注意力头数和隐藏层维度 - 支持可视化模式输出注意力权重 ### 学习率调度器 项目实现了四种学习率调度策略: | 调度器 | 特点 | 适用场景 | |--------|------|----------| | ConstantLRSchedule | 学习率恒定不变 | 简单任务基线 | | WarmupConstantSchedule | 预热后保持恒定 | 稳定初期训练 | | WarmupLinearSchedule | 预热后线性下降 | 常规训练任务 | | WarmupCosineSchedule | 预热后余弦退火 | 需要精细收敛 | ## 贡献指南 欢迎社区贡献者提交Issue或Pull Request来改进本项目。在贡献之前,请确保: 1. 遵循项目现有的代码风格和规范 2. 对新功能提供完整的测试用例 3. 更新相关文档和README 4. 提交信息清晰描述修改内容和原因 ## 许可证 本项目遵循MIT许可证开源,您可以自由使用、修改和分发本项目代码,但需保留原始版权声明。 ## 参考文献 如果您在研究中使用了本项目,请引用原始论文: ``` @article{dosovitskiy2020image, title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale}, author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others}, journal={International Conference on Learning Representations}, year={2021} } ``` ## 联系方式 如有问题或建议,请在Gitee项目页面提交Issue,或通过项目维护者的联系方式获取帮助。