1 Star 0 Fork 0

QFork/Bert-Chinese-Text-Classification

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
bert_test.py 1.08 KB
一键复制 编辑 原始数据 按行查看 历史
QijingGJ 提交于 2024-02-08 12:21 +08:00 . first commit
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 2 11:05:55 2024
@author: QiJing
"""
import os
import torch
from bert_get_data import BertClassifier, GenerateData
from torch.utils.data import DataLoader
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
save_path = './bert_checkpoint'
model = BertClassifier()
model.load_state_dict(torch.load(os.path.join(save_path, 'best.pt')))
model = model.to(device)
model.eval()
def evaluate(model, dataset):
model.eval()
test_loader = DataLoader(dataset, batch_size=128)
total_acc_test = 0
with torch.no_grad():
for test_input, test_label in test_loader:
input_id = test_input['input_ids'].squeeze(1).to(device)
mask = test_input['attention_mask'].to(device)
test_label = test_label.to(device)
output = model(input_id, mask)
acc = (output.argmax(dim=1) == test_label).sum().item()
total_acc_test += acc
print(f'Test Accuracy: {total_acc_test / len(dataset): .3f}')
test_dataset = GenerateData(mode="test")
evaluate(model, test_dataset)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/qfork/Bert-Chinese-Text-Classification.git
git@gitee.com:qfork/Bert-Chinese-Text-Classification.git
qfork
Bert-Chinese-Text-Classification
Bert-Chinese-Text-Classification
main

搜索帮助