PyTorch Quick Start


from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

cudnn.benchmark = True  # let cuDNN pick the fastest algorithms for fixed-size inputs
plt.ion()   # interactive mode

transform = transforms.Compose(
    [transforms.ToTensor()]  # convert PIL images to float tensors scaled to [0, 1]
)

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)

batch_size = 64  # not specified in the original run; any reasonable batch size works

dataloaders = {'train': torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                    shuffle=True, num_workers=2),
               'val': torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                                  shuffle=False, num_workers=2)
}

dataset_sizes = {'train': len(trainset), 'val': len(testset)}  # 60000 train / 10000 val
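
As a quick sanity check, one batch can be pulled from the training loader to confirm the tensor shapes (a minimal sketch; the printed shapes assume the batch_size = 64 set above):

# peek at one training batch to confirm shapes
images, labels = next(iter(dataloaders['train']))
print(images.shape)   # torch.Size([64, 1, 28, 28])
print(labels.shape)   # torch.Size([64])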

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Conv2d(1, 16, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64*3*3, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )
 
    def forward(self, x):
        logits = self.linear_relu_stack(x)
        return logits
 
model_ft = NeuralNetwork()
print(model_ft)
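
The in_features=576 of the first Linear layer follows from running a 28x28 MNIST image through the conv/pool stack: 28 → 26 (3x3 conv) → 13 (pool) → 11 (conv) → 5 (pool) → 3 (conv), with 64 channels, i.e. 64 * 3 * 3 = 576. A quick way to check this (a sketch with a dummy input; the slice up to the Flatten layer is only for illustration):

# confirm the flattened feature size with a dummy 1x28x28 input
with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28)
    feats = model_ft.linear_relu_stack[:9](dummy)  # layers up to and including Flatten
    print(feats.shape)  # torch.Size([1, 576])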


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


model_ft = model_ft.to(device)


criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.Adam(model_ft.parameters())

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=10)


NeuralNetwork(
  (linear_relu_stack): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (7): ReLU()
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): Linear(in_features=576, out_features=64, bias=True)
    (10): ReLU()
    (11): Linear(in_features=64, out_features=10, bias=True)
  )
)

Epoch 0/9
----------
train Loss: 0.1799 Acc: 0.9451
val Loss: 0.0493 Acc: 0.9834

Epoch 1/9
----------
train Loss: 0.0566 Acc: 0.9819
val Loss: 0.0512 Acc: 0.9838

Epoch 2/9
----------
train Loss: 0.0380 Acc: 0.9882
val Loss: 0.0618 Acc: 0.9804

Epoch 3/9
----------
train Loss: 0.0307 Acc: 0.9903
val Loss: 0.0398 Acc: 0.9875

Epoch 4/9
----------
train Loss: 0.0238 Acc: 0.9925
val Loss: 0.0351 Acc: 0.9886

Epoch 5/9
----------
train Loss: 0.0201 Acc: 0.9936
val Loss: 0.0416 Acc: 0.9878

Epoch 6/9
----------
train Loss: 0.0173 Acc: 0.9945
val Loss: 0.0268 Acc: 0.9907

Epoch 7/9
----------
train Loss: 0.0056 Acc: 0.9985
val Loss: 0.0256 Acc: 0.9917

Epoch 8/9
----------
train Loss: 0.0035 Acc: 0.9991
val Loss: 0.0253 Acc: 0.9927

Epoch 9/9
----------
train Loss: 0.0025 Acc: 0.9994
val Loss: 0.0262 Acc: 0.9924

Training complete in 4m 40s
Best val Acc: 0.992700
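
train_model returns the model loaded with the best validation weights; a minimal follow-up sketch for saving them and classifying a single test image (the file name mnist_cnn.pth is an arbitrary choice, not from the original):

# save the best weights and predict one test digit
torch.save(model_ft.state_dict(), 'mnist_cnn.pth')

model_ft.eval()
image, label = testset[0]                      # a single 1x28x28 test image
with torch.no_grad():
    logits = model_ft(image.unsqueeze(0).to(device))
    pred = logits.argmax(dim=1).item()
print(f'predicted: {pred}  actual: {label}')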
