今天给大家带来的是VGG16深度网络的解析,虽然VGG比较出名的还有VGG19深度网络,但是本文只对VGG16网络做一下解析,至于其他变种,稍微修改一下网络层数和参数即可实现,如果这个系列比较复杂的VGG16网络可以自己搭建出来,那么其他的变种也肯定没问题。
代码git地址:https://gitee.com/xiaosidegitee/xs-dl/tree/master/vgg/pytorch
1 VGG网络结构
论文地址:https://arxiv.org/pdf/1409.1556.pdf
VGG算法团队研究了卷积网络深度在大规模的图像识别环境下对准确性的影响。其主要贡献是使用非常小的(3×3)卷积滤波器架构对网络深度的增加进行了全面评估,这表明通过将深度推到16-19加权层可以实现对现有技术配置的显著改进。也正是因为这些改进,使得VGG算法团队在定位和分类过程中分别获得了第一名和第二名。如图可以看到VGG16包含13个卷积层和三个全连接层,VGG19包含16个卷积层和三个全连接层。接下来我们会详细的剖析整个网络的搭建过程。
2 VGG16网络参数
layer_name | kernel_size | kernel_num | stride | padding | input_size | output_size |
conv1 | 3 | 64 | 1 | [0, 1] | 224*224*3 | 224*224*64 |
conv2 | 3 | 64 | 1 | [0, 1] | 224*224*64 | 224*224*64 |
max_pool1 | 2 | / | 2 | / | 224*224*64 | 112*112*64 |
conv3 | 3 | 128 | 1 | [0, 1] | 112*112*64 | 112*112*128 |
conv4 | 3 | 128 | 1 | [0, 1] | 112*112*128 | 112*112*128 |
max_pool2 | 2 | / | 2 | / | 112*112*128 | 56*56*128 |
conv5 | 3 | 256 | 1 | [0, 1] | 56*56*128 | 56*56*256 |
conv6 | 3 | 256 | 1 | [0, 1] | 56*56*256 | 56*56*256 |
conv7 | 3 | 256 | 1 | [0, 1] | 56*56*256 | 56*56*256 |
max_pool3 | 2 | / | 2 | / | 56*56*256 | 28*28*256 |
conv8 | 3 | 512 | 1 | [0 ,1] | 28*28*256 | 28*28*512 |
conv9 | 3 | 512 | 1 | [0, 1] | 28*28*512 | 28*28*512 |
conv10 | 3 | 512 | 1 | [0, 1] | 28*28*512 | 28*28*512 |
max_pool4 | 3 | 512 | 1 | [0, 1] | 14*14*512 | 14*14*512 |
conv12 | 3 | 512 | 1 | [0, 1] | 14*14*512 | 14*14*512 |
conv13 | 3 | 512 | 1 | [0, 1] | 14*14*512 | 14*14*512 |
max_pool5 | 2 | / | 2 | / | 14*14*512 | 7*7*512 |
fc1 | 4096 | / | / | / | 7*7*512 | 4096 |
fc2 | 4096 | / | / | / | 4096 | 4096 |
fc3 | 1000 | / | / | / | 4096 | 1000 |
上图的参数是根据原论文的数据和源代码中的参数反推出来的,大部分参数直接根据原论文的网络图即可获得,其中其他参数的计算可以参照我另个一篇博文:CV(计算机视觉)领域四大类之图像分类一(AlexNet) - 知乎 (zhihu.com),里边有详细的参数计算过程。
3 pytarch实现
(1)数据准备make_data.py(花分类数据集)
import osfrom PIL import Image# 下载数据集地址(手动下载解压即可)DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'# 解压数据集的路径(自己定义即可)flower_photos = "G:\\alexnet\\flower_photos\\"# 训练数据路径(自己定义即可)base_url = "G:\\alexnet\\train_data\\"for item in os.listdir(flower_photos): path_temp = flower_photos + item n = 0 for name in os.listdir(path_temp): n += 1 img = Image.open(path_temp + "\\" + name) # 转换通道 img = img.convert("RGB") # 验证集(20%验证集,80%数据集,可自行调节) if n % 8 == 0: if not os.path.exists(base_url + "val\\" + item): os.makedirs(base_url + "val\\" + item) img.save(base_url + "val\\" + item + "\\" + name) else: if not os.path.exists(base_url + "train\\" + item): os.makedirs(base_url + "train\\" + item) img.save(base_url + "train\\" + item + "\\" + name)
(2)vgg16模型网络(model.py)
import torchimport torch.nn as nnclass VGG(nn.Module): def __init__(self, num_classes=5): super(VGG, self).__init__() self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2) ) self.classifier = nn.Sequential( nn.Linear(512*7*7, 4096), nn.ReLU(True), nn.Dropout(p=0.5), nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(p=0.5), nn.Linear(4096, num_classes) ) def forward(self, x): # N x 3 x 224 x 224 x = self.features(x) # N x 512 x 7 x 7 x = torch.flatten(x, start_dim=1) # N x 512*7*7 x = self.classifier(x) return x
(3)训练文件(train.py)
import osimport sysimport jsonimport torchimport torch.nn as nnfrom torchvision import transforms, datasetsimport torch.optim as optimfrom tqdm import tqdmfrom model import VGGdef main(): # 确定是否可以启动GPU训练 # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # (如果没有GPU环境则可以直接选择CPU训练) device = torch.device("cpu") # 设置训练集和验证集格变换规则 train_form = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) val_form = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # 训练集 image_path = "G:\\alexnet\\train_data\\" train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"), transform=train_form) # 由于我本机只有一个显卡,所以num_workers设置为0了 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0) # 验证集 validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"), transform=val_form) val_num = len(validate_dataset) validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=2, shuffle=False, num_workers=0) net = VGG(num_classes=5) net.to(device) loss_function = nn.CrossEntropyLoss() optimizer = optim.Adam(net.parameters(), lr=0.0001) epochs = 30 best_acc = 0.0 save_path = './vgg16-Net.pth' train_steps = len(train_loader) for epoch in range(epochs): # train net.train() running_loss = 0.0 train_bar = tqdm(train_loader, file=sys.stdout) for step, data in enumerate(train_bar): images, labels = data optimizer.zero_grad() outputs = net(images.to(device)) loss = loss_function(outputs, labels.to(device)) loss.backward() optimizer.step() # print statistics running_loss += loss.item() train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss) # validate net.eval() acc = 0.0 # accumulate accurate number / epoch with torch.no_grad(): val_bar = tqdm(validate_loader, file=sys.stdout) for val_data in val_bar: val_images, val_labels = val_data outputs = net(val_images.to(device)) predict_y = torch.max(outputs, dim=1)[1] acc += torch.eq(predict_y, val_labels.to(device)).sum().item() val_accurate = acc / val_num print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' % (epoch + 1, running_loss / train_steps, val_accurate)) if val_accurate > best_acc: best_acc = val_accurate torch.save(net.state_dict(), save_path) print('Finished Training')if __name__ == '__main__': main()
(4)预测文件(predict.py)
import osimport jsonimport torchfrom PIL import Imagefrom torchvision import transformsimport matplotlib.pyplot as pltfrom model import VGGdef main(): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_transform = transforms.Compose( [transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # load image img_path = "1.jpg" assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path) img = Image.open(img_path) plt.imshow(img) # [N, C, H, W] img = data_transform(img) # expand batch dimension img = torch.unsqueeze(img, dim=0) # read class_indict json_path = './class_indices.json' assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path) with open(json_path, "r") as f: class_indict = json.load(f) # create model model = vgg(model_name="vgg16", num_classes=5).to(device) # load model weights weights_path = "./vgg16Net.pth" assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path) model.load_state_dict(torch.load(weights_path, map_location=device)) model.eval() with torch.no_grad(): # predict class output = torch.squeeze(model(img.to(device))).cpu() predict = torch.softmax(output, dim=0) predict_cla = torch.argmax(predict).numpy() print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy()) plt.title(print_res) for i in range(len(predict)): print("class: {:10} prob: {:.3}".format(class_indict[str(i)], predict[i].numpy())) plt.show()if __name__ == '__main__': main()
4 VGG网络优势
(1)VGG网络采用重复堆叠的小卷积核替代大卷积核,在保证具有相同感受野的条件下,减少了网络的参数,提升了网络的深度,从而提升网络特征提取的能力。三个3*3的卷积核和一个7*7感受野相同,但是参数仅仅是后者的55%。
(2)提升的网络深度都使用ReLU激活函数:提升非线性变化的能力
版权声明:内容来源于互联网和用户投稿 如有侵权请联系删除