PyTorch_手写字识别

安装 PyTorch

1
pip3 install torch torchvision torchaudio

下载 MNIST 数据集

创建 dataset.py

1
2
3
4
5
6
7
8
import torch
import torchvision

def download_dataset():
    """Build the MNIST training split rooted at ``./data`` and return it.

    The dataset object is created with ``train=True`` and ``download=True``.
    """
    return torchvision.datasets.MNIST("./data", train=True, download=True)
if __name__ == "__main__":
    # Script entry point: build (and thereby download) the dataset.
    download_dataset()

保存训练模型测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from model import DNN

def save_model_fun1():
    """Write only the parameters of a new ``DNN(1, 28, 28)`` to ``dnn_test1.pth``."""
    net = DNN(1, 28, 28)
    # Save only the model parameters (the state_dict), not the module object.
    weights = net.state_dict()
    torch.save(weights, "dnn_test1.pth")
def save_model_fun2():
    """Serialize an entire new ``DNN(1, 28, 28)`` module object to ``dnn_test2.pt``."""
    # Save the whole model object rather than just its parameters.
    torch.save(DNN(1, 28, 28), "dnn_test2.pt")
def load_model_test1(path):
    """Load a parameters-only checkpoint into a fresh ``DNN(1, 28, 28)`` and print it.

    Args:
        path: Checkpoint file, expected to hold a state_dict such as one
            written with ``torch.save(model.state_dict(), ...)``.
    """
    model = DNN(1, 28, 28)
    ckpt = torch.load(path)
    # Loading the file alone changes nothing: the parameters have to be
    # copied into a model instance before they can be used.
    model.load_state_dict(ckpt)
    print(model)
def load_model_test2(path):
    """Load a fully pickled model object from ``path`` and print it.

    Args:
        path: Checkpoint file written with ``torch.save(model, ...)``.
    """
    # A whole-module checkpoint is a pickle of the object itself. Recent
    # PyTorch releases default ``weights_only`` to True, which refuses to
    # unpickle arbitrary module classes, so opt out explicitly.
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    model = torch.load(path, weights_only=False)
    print(model)
if __name__ == "__main__":
    # Write both checkpoint flavours, then reload the whole-model one.
    save_model_fun1()
    save_model_fun2()
    load_model_test2("./dnn_test2.pt")

建议只保存模型参数(state_dict),这样更加通用:加载时只需重新构造模型再调用 load_state_dict,不依赖保存时被序列化的类定义及其导入路径。

如果模型只是本人自己使用,那么保存整个模型对象也可以。

eval评估

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import  torch
from dataset import get_dataloader
from model import DNN

def eval_model(model, eval_dataloader, ckpt_path=None):
    """Measure the classification accuracy of ``model`` over ``eval_dataloader``.

    Args:
        model: Module called on image batches flattened to ``(-1, 1 * 28 * 28)``.
        eval_dataloader: Iterable of ``(images, labels)`` batches that exposes
            a ``dataset`` attribute supporting ``len()``.
        ckpt_path: Optional path of a state_dict checkpoint that is loaded
            (with ``strict=True``) into ``model`` before evaluating.

    Returns:
        A 0-dim float tensor holding the fraction of correctly classified
        samples; 0.0 when the dataset is empty.
    """
    # Optionally restore saved parameters into the supplied model.
    if ckpt_path:
        model.load_state_dict(torch.load(ckpt_path), strict=True)

    model.eval()

    correct = 0
    with torch.no_grad():
        for images, labels in eval_dataloader:
            # Forward pass on the flattened images.
            output = model(images.reshape((-1, 1 * 28 * 28)))
            predicted = output.argmax(dim=1).reshape(labels.shape)
            correct += (predicted == labels).sum().item()

    total = len(eval_dataloader.dataset)
    # Always hand back a tensor (callers use ``.cpu().item()``), and avoid a
    # division by zero when there is nothing to evaluate.
    accuracy = torch.tensor(correct / total if total else 0.0)
    print("per is {:.2f}%".format(accuracy.item() * 100))
    return accuracy
if __name__ == "__main__":
    # Build the network and the evaluation data loader.
    model = DNN(1, 28, 28)
    eval_dataloader = get_dataloader(False)
    # Call eval_model, not the builtin eval(), which rejects these arguments.
    eval_model(model, eval_dataloader, "dnn.path")

保存最优模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from dataset import get_dataloader
from model import DNN
import eval

def train(epochs=10):
    """Train a ``DNN(1, 28, 28)`` and checkpoint the latest and best weights.

    After each epoch the model is scored with ``eval.eval_model``. Its
    state_dict is always written to ``dnn_new.path`` and additionally to
    ``dnn_best.path`` whenever the score beats the best one seen so far.

    Args:
        epochs: Number of passes over the training data loader.
    """
    # Data loaders for training and evaluation.
    train_loader = get_dataloader(True)
    eval_loader = get_dataloader(False)

    # Model.
    net = DNN(1, 28, 28)

    # Optimizer, learning-rate schedule and loss.
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    batches_per_epoch = len(train_loader)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, batches_per_epoch, gamma=0.8
    )
    criterion = torch.nn.NLLLoss()

    # Training loop.
    best_score = 0
    for epoch in range(epochs):
        net.train()
        for step, (images, labels) in enumerate(train_loader, start=1):
            optimizer.zero_grad()
            flat_images = images.reshape((-1, 1 * 28 * 28))
            # Forward pass.
            loss = criterion(net(flat_images), labels)
            # Backward pass.
            loss.backward()

            # Parameter update; the scheduler is stepped once per batch.
            optimizer.step()
            scheduler.step()
            if step % 100 != 0:
                continue
            print("epcoh: {} / {} ,step {} / {},lr: {},loss{}".format(
                epoch + 1, epochs, step, batches_per_epoch,
                scheduler.get_last_lr(), loss))

        score = eval.eval_model(net, eval_loader)
        # Always keep the most recent weights.
        torch.save(net.state_dict(), "dnn_new.path")
        if score > best_score:
            print("save best model per is {:.2f}%".format(score.cpu().item() * 100))
            torch.save(net.state_dict(), "dnn_best.path")
            best_score = score


if __name__ == "__main__":
    # Script entry point: train with the default number of epochs.
    train()

infer

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import cv2
from matplotlib import pyplot as plt
import numpy as np
import torch
import model
from dataset import get_dataloader
def image_pre_handle(image):
    """Convert an OpenCV colour image into a normalized, flattened model input.

    Args:
        image: Colour image array in OpenCV's BGR channel order, as returned
            by ``cv2.imread``.

    Returns:
        A ``torch.float32`` tensor reshaped to ``(-1, 1 * 28 * 28)``.
    """
    # cv2.imread yields BGR data, so convert with the BGR flag; the RGB flag
    # would swap the red and blue weights in the grayscale mix.
    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Resize to 28x28 and add a leading channel axis.
    gray_image = cv2.resize(gray_image, (28, 28))[np.newaxis, :, :]

    # Divide by 255, then standardize with mean 0.1307 and std 0.3081.
    # NOTE(review): 0.3081 replaces the original 0.3801, which looks like a
    # transposed-digit typo of the usual MNIST std; confirm it matches the
    # normalization applied to the training data.
    norm_image = gray_image / 255
    norm_image = (norm_image - 0.1307) / 0.3081
    norm_tensor = torch.from_numpy(norm_image).view((-1, 1 * 28 * 28))
    return norm_tensor.to(torch.float32)
def infer(model, image_path):
    """Classify the image at ``image_path`` with ``model`` and display it.

    Prints the predicted class index tensor, then shows the image in an
    OpenCV window until a key is pressed.

    Args:
        model: Module accepting the tensor produced by ``image_pre_handle``.
        image_path: Path of the image file read with ``cv2.imread``.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    # Read the image.
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("cannot read image: {}".format(image_path))

    # Preprocess the image into the model's input format.
    input_data = image_pre_handle(image)

    # Run inference without tracking gradients.
    model.eval()
    with torch.no_grad():
        output = model(input_data)
        pre = output.max(1, keepdim=True)[1]
        print(pre)

    cv2.imshow("1", image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

if __name__ == "__main__":
    # Use a separate name so the imported ``model`` module is not rebound
    # to a network instance.
    net = model.DNN(1, 28, 28)
    ckpt = torch.load("dnn_best.path")
    net.load_state_dict(ckpt, strict=True)
    print(net)
    infer(net, "data/dataset/eval/1/1_0.jpg")

AlexNet

image-20230312150121641


pyTorch_手写字识别
http://example.com/2023/03/05/pytorch/
作者
CynicCat
发布于
2023年3月5日
许可协议