MLP to CNN

CNN-OnlyConv

Trước khi vào bài code thì chúng ta sẽ xem nhược điểm của các mô hình như MLP trên một số tập dữ liệu như FashionMNIST, Cifar10. Tôi sẽ áp dụng Conv để so sánh độ hiệu quả.

MLP_FashionMNIST_ReLU_He_Adam
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.init as init

# Select the compute device: first CUDA GPU if present, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Data

# Load CFashionMNIST dataset
# Load the FashionMNIST dataset; ToTensor + Normalize maps pixels to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

trainset = torchvision.datasets.FashionMNIST(
    root='data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=1024, num_workers=10, shuffle=True, drop_last=True)

testset = torchvision.datasets.FashionMNIST(
    root='data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=1024, num_workers=10, shuffle=False)
import matplotlib.pyplot as plt
import numpy as np

# Function to display the images
def imshow(img):
    """Undo the (mean=0.5, std=0.5) normalization and display a CHW image tensor."""
    restored = img * 0.5 + 0.5          # map [-1, 1] back to [0, 1]
    arr = restored.numpy()
    plt.imshow(arr.transpose(1, 2, 0))  # CHW -> HWC, the layout matplotlib expects
    plt.show()

# Grab a single training batch and show its first 8 images as a grid
images, labels = next(iter(trainloader))
imshow(torchvision.utils.make_grid(images[:8]))
image

Model

# MLP: flatten the 28x28 image (784) -> 256 hidden units (ReLU) -> 10 class logits
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

# He (Kaiming) initialization — the standard choice for ReLU networks
for module in model.children():
    if isinstance(module, nn.Linear):
        init.kaiming_uniform_(module.weight, nonlinearity='relu')
        if module.bias is not None:
            module.bias.data.zero_()

model = model.to(device)
print(model)
Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=256, bias=True)
  (2): ReLU()
  (3): Linear(in_features=256, out_features=10, bias=True)
)

Loss, Optimizer, and Evaluation Function

# Cross-entropy objective, optimized with Adam at a small learning rate (1e-4)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Function to compute loss and accuracy for test set
def evaluate(model, testloader, criterion):
    """Compute the mean per-batch loss and accuracy (%) of `model` on `testloader`.

    Args:
        model: a torch.nn.Module with at least one parameter.
        testloader: iterable of (images, labels) batches.
        criterion: loss function taking (outputs, labels).

    Returns:
        (test_loss, accuracy): mean loss over batches and accuracy in percent.
    """
    model.eval()
    # Use the device the model's parameters live on instead of a module-level
    # global, so the function works regardless of where `device` is defined.
    device = next(model.parameters()).device
    test_loss = 0.0
    running_correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            # Move inputs and labels to the model's device
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            # Predicted class = index of the max logit
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            running_correct += (predicted == labels).sum().item()

    accuracy = 100 * running_correct / total
    test_loss = test_loss / len(testloader)
    return test_loss, accuracy

Train

# Per-epoch metric history, used later for plotting
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []
max_epoch = 250

# Training loop
for epoch in range(max_epoch):
    model.train()
    loss_sum = 0.0        # accumulated batch losses this epoch
    correct = 0           # correctly classified samples this epoch
    seen = 0              # total samples this epoch
    n_batches = 0         # batches processed this epoch

    for inputs, labels in trainloader:
        # Move the batch to the training device
        inputs, labels = inputs.to(device), labels.to(device)

        # Reset gradients before the backward pass
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss_sum += loss.item()

        # Track training accuracy from the argmax of the logits
        predicted = outputs.data.argmax(dim=1)
        seen += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()
        n_batches += 1

    epoch_accuracy = 100 * correct / seen
    epoch_loss = loss_sum / n_batches

    test_loss, test_accuracy = evaluate(model, testloader, criterion)
    print(f"Epoch [{epoch + 1}/{max_epoch}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

    # Record this epoch's metrics for the plots below
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_accuracy)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)
Epoch [1/250], Loss: 1.3692, Accuracy: 54.26%, Test Loss: 0.8277, Test Accuracy: 72.52%
Epoch [2/250], Loss: 0.7028, Accuracy: 76.04%, Test Loss: 0.6484, Test Accuracy: 77.70%
Epoch [3/250], Loss: 0.5892, Accuracy: 79.73%, Test Loss: 0.5793, Test Accuracy: 79.91%
Epoch [4/250], Loss: 0.5340, Accuracy: 81.46%, Test Loss: 0.5390, Test Accuracy: 81.28%
Epoch [5/250], Loss: 0.4985, Accuracy: 82.75%, Test Loss: 0.5128, Test Accuracy: 82.03%
Epoch [6/250], Loss: 0.4735, Accuracy: 83.59%, Test Loss: 0.4946, Test Accuracy: 82.63%
Epoch [7/250], Loss: 0.4549, Accuracy: 84.20%, Test Loss: 0.4792, Test Accuracy: 83.21%
Epoch [8/250], Loss: 0.4408, Accuracy: 84.73%, Test Loss: 0.4684, Test Accuracy: 83.45%
Epoch [9/250], Loss: 0.4278, Accuracy: 85.15%, Test Loss: 0.4574, Test Accuracy: 83.82%
Epoch [10/250], Loss: 0.4174, Accuracy: 85.53%, Test Loss: 0.4494, Test Accuracy: 84.02%
Epoch [11/250], Loss: 0.4087, Accuracy: 85.84%, Test Loss: 0.4419, Test Accuracy: 84.21%
Epoch [12/250], Loss: 0.3999, Accuracy: 86.22%, Test Loss: 0.4355, Test Accuracy: 84.38%
Epoch [13/250], Loss: 0.3929, Accuracy: 86.27%, Test Loss: 0.4294, Test Accuracy: 84.74%
Epoch [14/250], Loss: 0.3859, Accuracy: 86.57%, Test Loss: 0.4234, Test Accuracy: 84.95%
Epoch [15/250], Loss: 0.3807, Accuracy: 86.72%, Test Loss: 0.4188, Test Accuracy: 85.08%
Epoch [16/250], Loss: 0.3752, Accuracy: 86.91%, Test Loss: 0.4161, Test Accuracy: 85.13%
Epoch [17/250], Loss: 0.3706, Accuracy: 87.00%, Test Loss: 0.4143, Test Accuracy: 85.39%
Epoch [18/250], Loss: 0.3652, Accuracy: 87.22%, Test Loss: 0.4080, Test Accuracy: 85.35%
Epoch [19/250], Loss: 0.3605, Accuracy: 87.35%, Test Loss: 0.4043, Test Accuracy: 85.41%
Epoch [20/250], Loss: 0.3563, Accuracy: 87.49%, Test Loss: 0.4003, Test Accuracy: 85.54%
Epoch [21/250], Loss: 0.3515, Accuracy: 87.64%, Test Loss: 0.3980, Test Accuracy: 85.80%
Epoch [22/250], Loss: 0.3481, Accuracy: 87.81%, Test Loss: 0.3950, Test Accuracy: 85.94%
Epoch [23/250], Loss: 0.3437, Accuracy: 87.94%, Test Loss: 0.3927, Test Accuracy: 86.04%
Epoch [24/250], Loss: 0.3401, Accuracy: 88.11%, Test Loss: 0.3893, Test Accuracy: 86.09%
Epoch [25/250], Loss: 0.3380, Accuracy: 88.14%, Test Loss: 0.3868, Test Accuracy: 86.20%
Epoch [26/250], Loss: 0.3344, Accuracy: 88.28%, Test Loss: 0.3856, Test Accuracy: 86.18%
Epoch [27/250], Loss: 0.3304, Accuracy: 88.44%, Test Loss: 0.3827, Test Accuracy: 86.15%
Epoch [28/250], Loss: 0.3276, Accuracy: 88.42%, Test Loss: 0.3804, Test Accuracy: 86.20%
Epoch [29/250], Loss: 0.3241, Accuracy: 88.64%, Test Loss: 0.3781, Test Accuracy: 86.39%
Epoch [30/250], Loss: 0.3223, Accuracy: 88.64%, Test Loss: 0.3766, Test Accuracy: 86.38%
Epoch [31/250], Loss: 0.3193, Accuracy: 88.81%, Test Loss: 0.3748, Test Accuracy: 86.74%
Epoch [32/250], Loss: 0.3158, Accuracy: 88.91%, Test Loss: 0.3748, Test Accuracy: 86.49%
Epoch [33/250], Loss: 0.3140, Accuracy: 89.00%, Test Loss: 0.3705, Test Accuracy: 86.96%
Epoch [34/250], Loss: 0.3107, Accuracy: 89.12%, Test Loss: 0.3731, Test Accuracy: 86.54%
Epoch [35/250], Loss: 0.3076, Accuracy: 89.22%, Test Loss: 0.3676, Test Accuracy: 86.91%
Epoch [36/250], Loss: 0.3060, Accuracy: 89.22%, Test Loss: 0.3668, Test Accuracy: 87.07%
Epoch [37/250], Loss: 0.3035, Accuracy: 89.33%, Test Loss: 0.3671, Test Accuracy: 86.93%
Epoch [38/250], Loss: 0.3018, Accuracy: 89.40%, Test Loss: 0.3648, Test Accuracy: 87.00%
Epoch [39/250], Loss: 0.2992, Accuracy: 89.50%, Test Loss: 0.3630, Test Accuracy: 87.07%
Epoch [40/250], Loss: 0.2971, Accuracy: 89.53%, Test Loss: 0.3615, Test Accuracy: 86.97%
Epoch [41/250], Loss: 0.2946, Accuracy: 89.69%, Test Loss: 0.3587, Test Accuracy: 87.15%
Epoch [42/250], Loss: 0.2924, Accuracy: 89.78%, Test Loss: 0.3581, Test Accuracy: 87.16%
Epoch [43/250], Loss: 0.2907, Accuracy: 89.78%, Test Loss: 0.3583, Test Accuracy: 87.43%
Epoch [44/250], Loss: 0.2886, Accuracy: 89.89%, Test Loss: 0.3553, Test Accuracy: 87.38%
Epoch [45/250], Loss: 0.2871, Accuracy: 89.92%, Test Loss: 0.3561, Test Accuracy: 87.25%
Epoch [46/250], Loss: 0.2850, Accuracy: 89.97%, Test Loss: 0.3541, Test Accuracy: 87.30%
Epoch [47/250], Loss: 0.2832, Accuracy: 90.12%, Test Loss: 0.3552, Test Accuracy: 87.47%
Epoch [48/250], Loss: 0.2812, Accuracy: 90.10%, Test Loss: 0.3527, Test Accuracy: 87.45%
Epoch [49/250], Loss: 0.2787, Accuracy: 90.27%, Test Loss: 0.3523, Test Accuracy: 87.41%
Epoch [50/250], Loss: 0.2781, Accuracy: 90.26%, Test Loss: 0.3491, Test Accuracy: 87.36%
Epoch [51/250], Loss: 0.2751, Accuracy: 90.38%, Test Loss: 0.3493, Test Accuracy: 87.36%
Epoch [52/250], Loss: 0.2729, Accuracy: 90.46%, Test Loss: 0.3480, Test Accuracy: 87.54%
Epoch [53/250], Loss: 0.2725, Accuracy: 90.55%, Test Loss: 0.3471, Test Accuracy: 87.55%
Epoch [54/250], Loss: 0.2707, Accuracy: 90.52%, Test Loss: 0.3465, Test Accuracy: 87.63%
Epoch [55/250], Loss: 0.2692, Accuracy: 90.64%, Test Loss: 0.3455, Test Accuracy: 87.62%
Epoch [56/250], Loss: 0.2674, Accuracy: 90.67%, Test Loss: 0.3458, Test Accuracy: 87.61%
Epoch [57/250], Loss: 0.2660, Accuracy: 90.68%, Test Loss: 0.3443, Test Accuracy: 87.48%
Epoch [58/250], Loss: 0.2642, Accuracy: 90.82%, Test Loss: 0.3431, Test Accuracy: 87.61%
Epoch [59/250], Loss: 0.2629, Accuracy: 90.84%, Test Loss: 0.3415, Test Accuracy: 87.69%
Epoch [60/250], Loss: 0.2611, Accuracy: 90.92%, Test Loss: 0.3415, Test Accuracy: 87.68%
Epoch [61/250], Loss: 0.2599, Accuracy: 90.99%, Test Loss: 0.3415, Test Accuracy: 87.65%
Epoch [62/250], Loss: 0.2579, Accuracy: 91.08%, Test Loss: 0.3395, Test Accuracy: 87.53%
Epoch [63/250], Loss: 0.2561, Accuracy: 91.10%, Test Loss: 0.3399, Test Accuracy: 87.74%
Epoch [64/250], Loss: 0.2556, Accuracy: 91.11%, Test Loss: 0.3396, Test Accuracy: 87.60%
Epoch [65/250], Loss: 0.2534, Accuracy: 91.15%, Test Loss: 0.3392, Test Accuracy: 87.66%
Epoch [66/250], Loss: 0.2527, Accuracy: 91.23%, Test Loss: 0.3374, Test Accuracy: 87.89%
Epoch [67/250], Loss: 0.2509, Accuracy: 91.26%, Test Loss: 0.3371, Test Accuracy: 87.83%
Epoch [68/250], Loss: 0.2499, Accuracy: 91.33%, Test Loss: 0.3391, Test Accuracy: 88.02%
Epoch [69/250], Loss: 0.2485, Accuracy: 91.33%, Test Loss: 0.3355, Test Accuracy: 87.95%
Epoch [70/250], Loss: 0.2467, Accuracy: 91.37%, Test Loss: 0.3365, Test Accuracy: 87.95%
Epoch [71/250], Loss: 0.2456, Accuracy: 91.46%, Test Loss: 0.3337, Test Accuracy: 88.10%
Epoch [72/250], Loss: 0.2435, Accuracy: 91.57%, Test Loss: 0.3338, Test Accuracy: 87.85%
Epoch [73/250], Loss: 0.2435, Accuracy: 91.61%, Test Loss: 0.3347, Test Accuracy: 88.00%
Epoch [74/250], Loss: 0.2415, Accuracy: 91.67%, Test Loss: 0.3327, Test Accuracy: 88.11%
Epoch [75/250], Loss: 0.2400, Accuracy: 91.63%, Test Loss: 0.3318, Test Accuracy: 87.88%
Epoch [76/250], Loss: 0.2388, Accuracy: 91.72%, Test Loss: 0.3320, Test Accuracy: 88.10%
Epoch [77/250], Loss: 0.2374, Accuracy: 91.75%, Test Loss: 0.3345, Test Accuracy: 88.13%
Epoch [78/250], Loss: 0.2365, Accuracy: 91.78%, Test Loss: 0.3312, Test Accuracy: 88.11%
Epoch [79/250], Loss: 0.2355, Accuracy: 91.85%, Test Loss: 0.3324, Test Accuracy: 88.18%
Epoch [80/250], Loss: 0.2348, Accuracy: 91.84%, Test Loss: 0.3323, Test Accuracy: 88.19%
Epoch [81/250], Loss: 0.2326, Accuracy: 91.98%, Test Loss: 0.3305, Test Accuracy: 88.20%
Epoch [82/250], Loss: 0.2318, Accuracy: 91.92%, Test Loss: 0.3322, Test Accuracy: 88.02%
Epoch [83/250], Loss: 0.2308, Accuracy: 91.96%, Test Loss: 0.3296, Test Accuracy: 88.16%
Epoch [84/250], Loss: 0.2291, Accuracy: 92.10%, Test Loss: 0.3281, Test Accuracy: 88.25%
Epoch [85/250], Loss: 0.2288, Accuracy: 92.07%, Test Loss: 0.3281, Test Accuracy: 88.46%
Epoch [86/250], Loss: 0.2271, Accuracy: 92.16%, Test Loss: 0.3279, Test Accuracy: 88.29%
Epoch [87/250], Loss: 0.2262, Accuracy: 92.25%, Test Loss: 0.3272, Test Accuracy: 88.29%
Epoch [88/250], Loss: 0.2258, Accuracy: 92.18%, Test Loss: 0.3270, Test Accuracy: 88.34%
Epoch [89/250], Loss: 0.2248, Accuracy: 92.20%, Test Loss: 0.3281, Test Accuracy: 88.37%
Epoch [90/250], Loss: 0.2228, Accuracy: 92.32%, Test Loss: 0.3282, Test Accuracy: 88.08%
Epoch [91/250], Loss: 0.2222, Accuracy: 92.31%, Test Loss: 0.3270, Test Accuracy: 88.27%
Epoch [92/250], Loss: 0.2207, Accuracy: 92.45%, Test Loss: 0.3272, Test Accuracy: 88.34%
Epoch [93/250], Loss: 0.2195, Accuracy: 92.42%, Test Loss: 0.3264, Test Accuracy: 88.30%
Epoch [94/250], Loss: 0.2193, Accuracy: 92.35%, Test Loss: 0.3247, Test Accuracy: 88.30%
Epoch [95/250], Loss: 0.2174, Accuracy: 92.52%, Test Loss: 0.3248, Test Accuracy: 88.16%
Epoch [96/250], Loss: 0.2162, Accuracy: 92.56%, Test Loss: 0.3248, Test Accuracy: 88.31%
Epoch [97/250], Loss: 0.2158, Accuracy: 92.59%, Test Loss: 0.3233, Test Accuracy: 88.46%
Epoch [98/250], Loss: 0.2142, Accuracy: 92.62%, Test Loss: 0.3244, Test Accuracy: 88.41%
Epoch [99/250], Loss: 0.2133, Accuracy: 92.66%, Test Loss: 0.3232, Test Accuracy: 88.44%
Epoch [100/250], Loss: 0.2126, Accuracy: 92.73%, Test Loss: 0.3230, Test Accuracy: 88.31%
Epoch [101/250], Loss: 0.2118, Accuracy: 92.73%, Test Loss: 0.3230, Test Accuracy: 88.51%
Epoch [102/250], Loss: 0.2104, Accuracy: 92.79%, Test Loss: 0.3229, Test Accuracy: 88.49%
Epoch [103/250], Loss: 0.2089, Accuracy: 92.77%, Test Loss: 0.3240, Test Accuracy: 88.47%
Epoch [104/250], Loss: 0.2084, Accuracy: 92.88%, Test Loss: 0.3210, Test Accuracy: 88.60%
Epoch [105/250], Loss: 0.2080, Accuracy: 92.90%, Test Loss: 0.3222, Test Accuracy: 88.48%
Epoch [106/250], Loss: 0.2063, Accuracy: 92.88%, Test Loss: 0.3217, Test Accuracy: 88.58%
Epoch [107/250], Loss: 0.2052, Accuracy: 92.87%, Test Loss: 0.3215, Test Accuracy: 88.52%
Epoch [108/250], Loss: 0.2044, Accuracy: 93.05%, Test Loss: 0.3222, Test Accuracy: 88.51%
Epoch [109/250], Loss: 0.2041, Accuracy: 93.05%, Test Loss: 0.3234, Test Accuracy: 88.24%
Epoch [110/250], Loss: 0.2030, Accuracy: 93.08%, Test Loss: 0.3217, Test Accuracy: 88.41%
Epoch [111/250], Loss: 0.2014, Accuracy: 93.13%, Test Loss: 0.3198, Test Accuracy: 88.50%
Epoch [112/250], Loss: 0.2008, Accuracy: 93.19%, Test Loss: 0.3202, Test Accuracy: 88.51%
Epoch [113/250], Loss: 0.2000, Accuracy: 93.22%, Test Loss: 0.3202, Test Accuracy: 88.53%
Epoch [114/250], Loss: 0.1988, Accuracy: 93.23%, Test Loss: 0.3201, Test Accuracy: 88.71%
Epoch [115/250], Loss: 0.1985, Accuracy: 93.27%, Test Loss: 0.3205, Test Accuracy: 88.64%
Epoch [116/250], Loss: 0.1976, Accuracy: 93.30%, Test Loss: 0.3203, Test Accuracy: 88.68%
Epoch [117/250], Loss: 0.1961, Accuracy: 93.36%, Test Loss: 0.3209, Test Accuracy: 88.82%
Epoch [118/250], Loss: 0.1952, Accuracy: 93.41%, Test Loss: 0.3187, Test Accuracy: 88.49%
Epoch [119/250], Loss: 0.1947, Accuracy: 93.41%, Test Loss: 0.3197, Test Accuracy: 88.54%
Epoch [120/250], Loss: 0.1934, Accuracy: 93.48%, Test Loss: 0.3200, Test Accuracy: 88.59%
Epoch [121/250], Loss: 0.1938, Accuracy: 93.40%, Test Loss: 0.3207, Test Accuracy: 88.80%
Epoch [122/250], Loss: 0.1915, Accuracy: 93.49%, Test Loss: 0.3192, Test Accuracy: 88.60%
Epoch [123/250], Loss: 0.1911, Accuracy: 93.53%, Test Loss: 0.3179, Test Accuracy: 88.77%
Epoch [124/250], Loss: 0.1908, Accuracy: 93.53%, Test Loss: 0.3225, Test Accuracy: 88.55%
Epoch [125/250], Loss: 0.1898, Accuracy: 93.63%, Test Loss: 0.3190, Test Accuracy: 88.60%
Epoch [126/250], Loss: 0.1890, Accuracy: 93.62%, Test Loss: 0.3176, Test Accuracy: 88.54%
Epoch [127/250], Loss: 0.1878, Accuracy: 93.67%, Test Loss: 0.3171, Test Accuracy: 88.78%
Epoch [128/250], Loss: 0.1872, Accuracy: 93.65%, Test Loss: 0.3182, Test Accuracy: 88.87%
Epoch [129/250], Loss: 0.1860, Accuracy: 93.75%, Test Loss: 0.3182, Test Accuracy: 88.56%
Epoch [130/250], Loss: 0.1853, Accuracy: 93.73%, Test Loss: 0.3180, Test Accuracy: 88.56%
Epoch [131/250], Loss: 0.1851, Accuracy: 93.80%, Test Loss: 0.3189, Test Accuracy: 88.81%
Epoch [132/250], Loss: 0.1837, Accuracy: 93.83%, Test Loss: 0.3180, Test Accuracy: 88.81%
Epoch [133/250], Loss: 0.1829, Accuracy: 93.90%, Test Loss: 0.3175, Test Accuracy: 88.81%
Epoch [134/250], Loss: 0.1823, Accuracy: 93.88%, Test Loss: 0.3176, Test Accuracy: 88.69%
Epoch [135/250], Loss: 0.1812, Accuracy: 93.92%, Test Loss: 0.3174, Test Accuracy: 88.81%
Epoch [136/250], Loss: 0.1803, Accuracy: 94.01%, Test Loss: 0.3162, Test Accuracy: 88.84%
Epoch [137/250], Loss: 0.1797, Accuracy: 94.00%, Test Loss: 0.3175, Test Accuracy: 88.86%
Epoch [138/250], Loss: 0.1788, Accuracy: 94.01%, Test Loss: 0.3188, Test Accuracy: 88.80%
Epoch [139/250], Loss: 0.1780, Accuracy: 94.05%, Test Loss: 0.3161, Test Accuracy: 88.83%
Epoch [140/250], Loss: 0.1777, Accuracy: 94.12%, Test Loss: 0.3164, Test Accuracy: 88.96%
Epoch [141/250], Loss: 0.1768, Accuracy: 94.12%, Test Loss: 0.3181, Test Accuracy: 88.72%
Epoch [142/250], Loss: 0.1763, Accuracy: 94.14%, Test Loss: 0.3159, Test Accuracy: 88.88%
Epoch [143/250], Loss: 0.1750, Accuracy: 94.20%, Test Loss: 0.3185, Test Accuracy: 88.87%
Epoch [144/250], Loss: 0.1743, Accuracy: 94.24%, Test Loss: 0.3157, Test Accuracy: 89.04%
Epoch [145/250], Loss: 0.1737, Accuracy: 94.25%, Test Loss: 0.3154, Test Accuracy: 88.92%
Epoch [146/250], Loss: 0.1728, Accuracy: 94.26%, Test Loss: 0.3160, Test Accuracy: 88.84%
Epoch [147/250], Loss: 0.1721, Accuracy: 94.26%, Test Loss: 0.3156, Test Accuracy: 89.00%
Epoch [148/250], Loss: 0.1712, Accuracy: 94.36%, Test Loss: 0.3168, Test Accuracy: 88.92%
Epoch [149/250], Loss: 0.1706, Accuracy: 94.41%, Test Loss: 0.3153, Test Accuracy: 88.98%
Epoch [150/250], Loss: 0.1694, Accuracy: 94.40%, Test Loss: 0.3171, Test Accuracy: 88.91%
Epoch [151/250], Loss: 0.1693, Accuracy: 94.44%, Test Loss: 0.3160, Test Accuracy: 88.92%
Epoch [152/250], Loss: 0.1687, Accuracy: 94.40%, Test Loss: 0.3141, Test Accuracy: 89.15%
Epoch [153/250], Loss: 0.1684, Accuracy: 94.40%, Test Loss: 0.3183, Test Accuracy: 88.75%
Epoch [154/250], Loss: 0.1669, Accuracy: 94.48%, Test Loss: 0.3151, Test Accuracy: 89.07%
Epoch [155/250], Loss: 0.1666, Accuracy: 94.53%, Test Loss: 0.3158, Test Accuracy: 88.84%
Epoch [156/250], Loss: 0.1659, Accuracy: 94.59%, Test Loss: 0.3158, Test Accuracy: 88.99%
Epoch [157/250], Loss: 0.1644, Accuracy: 94.64%, Test Loss: 0.3170, Test Accuracy: 88.90%
Epoch [158/250], Loss: 0.1642, Accuracy: 94.66%, Test Loss: 0.3157, Test Accuracy: 88.97%
Epoch [159/250], Loss: 0.1642, Accuracy: 94.58%, Test Loss: 0.3139, Test Accuracy: 89.12%
Epoch [160/250], Loss: 0.1621, Accuracy: 94.69%, Test Loss: 0.3143, Test Accuracy: 89.09%
Epoch [161/250], Loss: 0.1623, Accuracy: 94.68%, Test Loss: 0.3194, Test Accuracy: 88.77%
Epoch [162/250], Loss: 0.1614, Accuracy: 94.77%, Test Loss: 0.3152, Test Accuracy: 89.13%
Epoch [163/250], Loss: 0.1608, Accuracy: 94.73%, Test Loss: 0.3150, Test Accuracy: 89.15%
Epoch [164/250], Loss: 0.1601, Accuracy: 94.78%, Test Loss: 0.3147, Test Accuracy: 89.14%
Epoch [165/250], Loss: 0.1589, Accuracy: 94.84%, Test Loss: 0.3154, Test Accuracy: 89.06%
Epoch [166/250], Loss: 0.1582, Accuracy: 94.83%, Test Loss: 0.3176, Test Accuracy: 89.03%
Epoch [167/250], Loss: 0.1576, Accuracy: 94.92%, Test Loss: 0.3155, Test Accuracy: 89.18%
Epoch [168/250], Loss: 0.1571, Accuracy: 94.89%, Test Loss: 0.3164, Test Accuracy: 89.04%
Epoch [169/250], Loss: 0.1566, Accuracy: 94.96%, Test Loss: 0.3176, Test Accuracy: 89.11%
Epoch [170/250], Loss: 0.1558, Accuracy: 94.94%, Test Loss: 0.3155, Test Accuracy: 89.22%
Epoch [171/250], Loss: 0.1556, Accuracy: 94.99%, Test Loss: 0.3165, Test Accuracy: 89.05%
Epoch [172/250], Loss: 0.1548, Accuracy: 95.00%, Test Loss: 0.3158, Test Accuracy: 89.00%
Epoch [173/250], Loss: 0.1539, Accuracy: 95.02%, Test Loss: 0.3156, Test Accuracy: 89.08%
Epoch [174/250], Loss: 0.1527, Accuracy: 95.11%, Test Loss: 0.3174, Test Accuracy: 89.20%
Epoch [175/250], Loss: 0.1532, Accuracy: 95.08%, Test Loss: 0.3145, Test Accuracy: 89.27%
Epoch [176/250], Loss: 0.1517, Accuracy: 95.12%, Test Loss: 0.3169, Test Accuracy: 89.08%
Epoch [177/250], Loss: 0.1515, Accuracy: 95.11%, Test Loss: 0.3164, Test Accuracy: 89.30%
Epoch [178/250], Loss: 0.1513, Accuracy: 95.11%, Test Loss: 0.3157, Test Accuracy: 89.20%
Epoch [179/250], Loss: 0.1502, Accuracy: 95.10%, Test Loss: 0.3147, Test Accuracy: 89.09%
Epoch [180/250], Loss: 0.1493, Accuracy: 95.26%, Test Loss: 0.3163, Test Accuracy: 89.07%
Epoch [181/250], Loss: 0.1493, Accuracy: 95.21%, Test Loss: 0.3162, Test Accuracy: 89.24%
Epoch [182/250], Loss: 0.1482, Accuracy: 95.29%, Test Loss: 0.3156, Test Accuracy: 89.23%
Epoch [183/250], Loss: 0.1481, Accuracy: 95.23%, Test Loss: 0.3156, Test Accuracy: 89.16%
Epoch [184/250], Loss: 0.1476, Accuracy: 95.21%, Test Loss: 0.3180, Test Accuracy: 89.02%
Epoch [185/250], Loss: 0.1471, Accuracy: 95.26%, Test Loss: 0.3158, Test Accuracy: 89.18%
Epoch [186/250], Loss: 0.1463, Accuracy: 95.30%, Test Loss: 0.3166, Test Accuracy: 89.29%
Epoch [187/250], Loss: 0.1448, Accuracy: 95.44%, Test Loss: 0.3175, Test Accuracy: 88.98%
Epoch [188/250], Loss: 0.1445, Accuracy: 95.31%, Test Loss: 0.3180, Test Accuracy: 89.12%
Epoch [189/250], Loss: 0.1438, Accuracy: 95.40%, Test Loss: 0.3180, Test Accuracy: 89.18%
Epoch [190/250], Loss: 0.1433, Accuracy: 95.44%, Test Loss: 0.3167, Test Accuracy: 89.13%
Epoch [191/250], Loss: 0.1425, Accuracy: 95.49%, Test Loss: 0.3156, Test Accuracy: 89.23%
Epoch [192/250], Loss: 0.1421, Accuracy: 95.52%, Test Loss: 0.3164, Test Accuracy: 89.16%
Epoch [193/250], Loss: 0.1415, Accuracy: 95.51%, Test Loss: 0.3182, Test Accuracy: 89.21%
Epoch [194/250], Loss: 0.1409, Accuracy: 95.48%, Test Loss: 0.3174, Test Accuracy: 89.17%
Epoch [195/250], Loss: 0.1404, Accuracy: 95.55%, Test Loss: 0.3164, Test Accuracy: 89.24%
Epoch [196/250], Loss: 0.1393, Accuracy: 95.58%, Test Loss: 0.3172, Test Accuracy: 89.25%
Epoch [197/250], Loss: 0.1390, Accuracy: 95.63%, Test Loss: 0.3165, Test Accuracy: 89.29%
Epoch [198/250], Loss: 0.1385, Accuracy: 95.63%, Test Loss: 0.3174, Test Accuracy: 89.01%
Epoch [199/250], Loss: 0.1382, Accuracy: 95.63%, Test Loss: 0.3153, Test Accuracy: 89.20%
Epoch [200/250], Loss: 0.1376, Accuracy: 95.65%, Test Loss: 0.3175, Test Accuracy: 89.08%
Epoch [201/250], Loss: 0.1370, Accuracy: 95.73%, Test Loss: 0.3163, Test Accuracy: 89.12%
Epoch [202/250], Loss: 0.1359, Accuracy: 95.70%, Test Loss: 0.3171, Test Accuracy: 89.35%
Epoch [203/250], Loss: 0.1353, Accuracy: 95.77%, Test Loss: 0.3168, Test Accuracy: 89.17%
Epoch [204/250], Loss: 0.1349, Accuracy: 95.77%, Test Loss: 0.3190, Test Accuracy: 89.19%
Epoch [205/250], Loss: 0.1351, Accuracy: 95.73%, Test Loss: 0.3190, Test Accuracy: 89.29%
Epoch [206/250], Loss: 0.1344, Accuracy: 95.75%, Test Loss: 0.3182, Test Accuracy: 89.23%
Epoch [207/250], Loss: 0.1335, Accuracy: 95.87%, Test Loss: 0.3165, Test Accuracy: 89.26%
Epoch [208/250], Loss: 0.1331, Accuracy: 95.86%, Test Loss: 0.3183, Test Accuracy: 89.19%
Epoch [209/250], Loss: 0.1326, Accuracy: 95.83%, Test Loss: 0.3186, Test Accuracy: 89.17%
Epoch [210/250], Loss: 0.1312, Accuracy: 95.96%, Test Loss: 0.3167, Test Accuracy: 89.20%
Epoch [211/250], Loss: 0.1310, Accuracy: 95.94%, Test Loss: 0.3184, Test Accuracy: 89.24%
Epoch [212/250], Loss: 0.1314, Accuracy: 95.86%, Test Loss: 0.3181, Test Accuracy: 89.28%
Epoch [213/250], Loss: 0.1296, Accuracy: 95.95%, Test Loss: 0.3182, Test Accuracy: 89.10%
Epoch [214/250], Loss: 0.1293, Accuracy: 95.96%, Test Loss: 0.3185, Test Accuracy: 89.36%
Epoch [215/250], Loss: 0.1289, Accuracy: 95.99%, Test Loss: 0.3198, Test Accuracy: 89.23%
Epoch [216/250], Loss: 0.1277, Accuracy: 96.07%, Test Loss: 0.3195, Test Accuracy: 89.09%
Epoch [217/250], Loss: 0.1279, Accuracy: 96.08%, Test Loss: 0.3198, Test Accuracy: 89.14%
Epoch [218/250], Loss: 0.1276, Accuracy: 96.08%, Test Loss: 0.3198, Test Accuracy: 89.09%
Epoch [219/250], Loss: 0.1271, Accuracy: 96.11%, Test Loss: 0.3182, Test Accuracy: 89.13%
Epoch [220/250], Loss: 0.1264, Accuracy: 96.13%, Test Loss: 0.3191, Test Accuracy: 89.13%
Epoch [221/250], Loss: 0.1252, Accuracy: 96.23%, Test Loss: 0.3190, Test Accuracy: 89.24%
Epoch [222/250], Loss: 0.1254, Accuracy: 96.20%, Test Loss: 0.3220, Test Accuracy: 88.93%
Epoch [223/250], Loss: 0.1247, Accuracy: 96.16%, Test Loss: 0.3195, Test Accuracy: 89.14%
Epoch [224/250], Loss: 0.1239, Accuracy: 96.19%, Test Loss: 0.3187, Test Accuracy: 89.10%
Epoch [225/250], Loss: 0.1238, Accuracy: 96.17%, Test Loss: 0.3190, Test Accuracy: 89.36%
Epoch [226/250], Loss: 0.1232, Accuracy: 96.22%, Test Loss: 0.3205, Test Accuracy: 89.16%
Epoch [227/250], Loss: 0.1229, Accuracy: 96.24%, Test Loss: 0.3193, Test Accuracy: 89.23%
Epoch [228/250], Loss: 0.1222, Accuracy: 96.30%, Test Loss: 0.3200, Test Accuracy: 89.11%
Epoch [229/250], Loss: 0.1212, Accuracy: 96.34%, Test Loss: 0.3223, Test Accuracy: 89.21%
Epoch [230/250], Loss: 0.1210, Accuracy: 96.34%, Test Loss: 0.3192, Test Accuracy: 89.28%
Epoch [231/250], Loss: 0.1206, Accuracy: 96.27%, Test Loss: 0.3205, Test Accuracy: 89.35%
Epoch [232/250], Loss: 0.1199, Accuracy: 96.41%, Test Loss: 0.3264, Test Accuracy: 88.99%
Epoch [233/250], Loss: 0.1199, Accuracy: 96.34%, Test Loss: 0.3209, Test Accuracy: 89.18%
Epoch [234/250], Loss: 0.1191, Accuracy: 96.42%, Test Loss: 0.3223, Test Accuracy: 89.31%
Epoch [235/250], Loss: 0.1193, Accuracy: 96.32%, Test Loss: 0.3201, Test Accuracy: 89.39%
Epoch [236/250], Loss: 0.1185, Accuracy: 96.46%, Test Loss: 0.3229, Test Accuracy: 89.35%
Epoch [237/250], Loss: 0.1172, Accuracy: 96.52%, Test Loss: 0.3218, Test Accuracy: 89.16%
Epoch [238/250], Loss: 0.1168, Accuracy: 96.52%, Test Loss: 0.3223, Test Accuracy: 89.24%
Epoch [239/250], Loss: 0.1164, Accuracy: 96.53%, Test Loss: 0.3214, Test Accuracy: 89.35%
Epoch [240/250], Loss: 0.1162, Accuracy: 96.53%, Test Loss: 0.3242, Test Accuracy: 89.27%
Epoch [241/250], Loss: 0.1158, Accuracy: 96.54%, Test Loss: 0.3230, Test Accuracy: 89.29%
Epoch [242/250], Loss: 0.1148, Accuracy: 96.60%, Test Loss: 0.3227, Test Accuracy: 89.18%
Epoch [243/250], Loss: 0.1146, Accuracy: 96.60%, Test Loss: 0.3244, Test Accuracy: 89.28%
Epoch [244/250], Loss: 0.1145, Accuracy: 96.57%, Test Loss: 0.3230, Test Accuracy: 89.12%
Epoch [245/250], Loss: 0.1143, Accuracy: 96.57%, Test Loss: 0.3272, Test Accuracy: 89.10%
Epoch [246/250], Loss: 0.1133, Accuracy: 96.63%, Test Loss: 0.3219, Test Accuracy: 89.18%
Epoch [247/250], Loss: 0.1122, Accuracy: 96.73%, Test Loss: 0.3226, Test Accuracy: 89.26%
Epoch [248/250], Loss: 0.1136, Accuracy: 96.64%, Test Loss: 0.3255, Test Accuracy: 89.27%
Epoch [249/250], Loss: 0.1116, Accuracy: 96.68%, Test Loss: 0.3245, Test Accuracy: 89.15%
Epoch [250/250], Loss: 0.1116, Accuracy: 96.75%, Test Loss: 0.3239, Test Accuracy: 89.35%
import matplotlib.pyplot as plt

# Plot train vs. test loss curves over epochs
for series, tag in ((train_losses, 'train_losses'), (test_losses, 'test_losses')):
    plt.plot(series, label=tag)
plt.legend()
<matplotlib.legend.Legend at 0x7fe8f7cef610>
image
import matplotlib.pyplot as plt

# Plot train vs. test accuracy curves over epochs
for series, tag in ((train_accuracies, 'train_accuracy'), (test_accuracies, 'test_accuracy')):
    plt.plot(series, label=tag)
plt.legend()
<matplotlib.legend.Legend at 0x7fe801c2af20>
image
MLP_Cifar10_ReLU_He_Adam

Data

# Load the CIFAR-10 dataset; normalize each RGB channel to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(
    root='data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=1024, num_workers=10, shuffle=True, drop_last=True)

testset = torchvision.datasets.CIFAR10(
    root='data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=1024, num_workers=10, shuffle=False)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz
100%|█████████████████████████████████████| 170498071/170498071 [00:16<00:00, 10192348.32it/s]
Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
import matplotlib.pyplot as plt
import numpy as np

# Function to display the images
def imshow(img):
    """Undo the (mean=0.5, std=0.5) normalization and display a CHW image tensor."""
    restored = img * 0.5 + 0.5          # map [-1, 1] back to [0, 1]
    arr = restored.numpy()
    plt.imshow(arr.transpose(1, 2, 0))  # CHW -> HWC, the layout matplotlib expects
    plt.show()

# Grab a single training batch and show its first 8 images as a grid
images, labels = next(iter(trainloader))
imshow(torchvision.utils.make_grid(images[:8]))
image

Model

# MLP: flatten the 32x32 RGB image (3072) -> 256 hidden units (ReLU) -> 10 class logits
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(32*32*3, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

# He (Kaiming) initialization — the standard choice for ReLU networks
for module in model.children():
    if isinstance(module, nn.Linear):
        init.kaiming_uniform_(module.weight, nonlinearity='relu')
        if module.bias is not None:
            module.bias.data.zero_()

model = model.to(device)
print(model)
Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=3072, out_features=256, bias=True)
  (2): ReLU()
  (3): Linear(in_features=256, out_features=10, bias=True)
)

Loss, Optimizer, and Evaluation Function

# Cross-entropy objective, optimized with Adam at a small learning rate (1e-4)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Function to compute loss and accuracy for test set
def evaluate(model, testloader, criterion):
    """Compute the mean per-batch loss and accuracy (%) of `model` on `testloader`.

    Args:
        model: a torch.nn.Module with at least one parameter.
        testloader: iterable of (images, labels) batches.
        criterion: loss function taking (outputs, labels).

    Returns:
        (test_loss, accuracy): mean loss over batches and accuracy in percent.
    """
    model.eval()
    # Use the device the model's parameters live on instead of a module-level
    # global, so the function works regardless of where `device` is defined.
    device = next(model.parameters()).device
    test_loss = 0.0
    running_correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            # Move inputs and labels to the model's device
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            # Predicted class = index of the max logit
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            running_correct += (predicted == labels).sum().item()

    accuracy = 100 * running_correct / total
    test_loss = test_loss / len(testloader)
    return test_loss, accuracy

Train

# Per-epoch metric histories (consumed by the plotting cells below).
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []
max_epoch = 250
# Training loop: one full pass over trainloader per epoch, followed by a
# test-set evaluation so the train/test curves can be compared per epoch.
for epoch in range(max_epoch):
    model.train()         # training mode (affects dropout/batchnorm, if present)
    running_loss = 0.0
    running_correct = 0   # to track number of correct predictions
    total = 0             # to track total number of samples

    for i, (inputs, labels) in enumerate(trainloader, 0):
        # Move inputs and labels to the device
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        running_loss += loss.item()

        # Determine class predictions and track accuracy
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        running_correct += (predicted == labels).sum().item()

        # Backward pass and optimization
        loss.backward()
        optimizer.step()        

    epoch_accuracy = 100 * running_correct / total
    # `i` retains the last batch index after the loop, so i + 1 == number of batches.
    epoch_loss = running_loss / (i + 1)
    
    test_loss, test_accuracy = evaluate(model, testloader, criterion)
    print(f"Epoch [{epoch + 1}/{max_epoch}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
    
    # save for plot
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_accuracy)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)
Epoch [1/250], Loss: 1.9455, Accuracy: 31.24%, Test Loss: 1.7467, Test Accuracy: 39.18%
Epoch [2/250], Loss: 1.6807, Accuracy: 41.80%, Test Loss: 1.6353, Test Accuracy: 43.14%
Epoch [3/250], Loss: 1.5946, Accuracy: 45.04%, Test Loss: 1.5860, Test Accuracy: 44.92%
Epoch [4/250], Loss: 1.5413, Accuracy: 47.02%, Test Loss: 1.5515, Test Accuracy: 46.38%
Epoch [5/250], Loss: 1.5024, Accuracy: 48.41%, Test Loss: 1.5270, Test Accuracy: 47.06%
Epoch [6/250], Loss: 1.4706, Accuracy: 49.44%, Test Loss: 1.5074, Test Accuracy: 47.69%
Epoch [7/250], Loss: 1.4402, Accuracy: 50.84%, Test Loss: 1.4901, Test Accuracy: 48.32%
Epoch [8/250], Loss: 1.4176, Accuracy: 51.50%, Test Loss: 1.4769, Test Accuracy: 48.66%
Epoch [9/250], Loss: 1.3931, Accuracy: 52.51%, Test Loss: 1.4652, Test Accuracy: 48.99%
Epoch [10/250], Loss: 1.3746, Accuracy: 53.37%, Test Loss: 1.4536, Test Accuracy: 49.48%
Epoch [11/250], Loss: 1.3563, Accuracy: 54.07%, Test Loss: 1.4436, Test Accuracy: 49.50%
Epoch [12/250], Loss: 1.3385, Accuracy: 54.75%, Test Loss: 1.4323, Test Accuracy: 49.73%
Epoch [13/250], Loss: 1.3216, Accuracy: 55.16%, Test Loss: 1.4275, Test Accuracy: 49.73%
Epoch [14/250], Loss: 1.3066, Accuracy: 55.72%, Test Loss: 1.4188, Test Accuracy: 50.19%
Epoch [15/250], Loss: 1.2921, Accuracy: 56.38%, Test Loss: 1.4158, Test Accuracy: 50.13%
Epoch [16/250], Loss: 1.2797, Accuracy: 56.87%, Test Loss: 1.4076, Test Accuracy: 50.61%
Epoch [17/250], Loss: 1.2650, Accuracy: 57.40%, Test Loss: 1.4007, Test Accuracy: 51.12%
Epoch [18/250], Loss: 1.2505, Accuracy: 58.05%, Test Loss: 1.3934, Test Accuracy: 51.01%
Epoch [19/250], Loss: 1.2382, Accuracy: 58.40%, Test Loss: 1.3929, Test Accuracy: 51.16%
Epoch [20/250], Loss: 1.2277, Accuracy: 58.94%, Test Loss: 1.3887, Test Accuracy: 51.39%
Epoch [21/250], Loss: 1.2164, Accuracy: 59.33%, Test Loss: 1.3832, Test Accuracy: 51.31%
Epoch [22/250], Loss: 1.2047, Accuracy: 59.84%, Test Loss: 1.3843, Test Accuracy: 51.57%
Epoch [23/250], Loss: 1.1931, Accuracy: 60.18%, Test Loss: 1.3753, Test Accuracy: 51.58%
Epoch [24/250], Loss: 1.1834, Accuracy: 60.58%, Test Loss: 1.3731, Test Accuracy: 51.93%
Epoch [25/250], Loss: 1.1737, Accuracy: 60.91%, Test Loss: 1.3669, Test Accuracy: 51.82%
Epoch [26/250], Loss: 1.1637, Accuracy: 61.19%, Test Loss: 1.3696, Test Accuracy: 51.83%
Epoch [27/250], Loss: 1.1557, Accuracy: 61.57%, Test Loss: 1.3669, Test Accuracy: 51.81%
Epoch [28/250], Loss: 1.1444, Accuracy: 61.93%, Test Loss: 1.3617, Test Accuracy: 51.84%
Epoch [29/250], Loss: 1.1354, Accuracy: 62.38%, Test Loss: 1.3602, Test Accuracy: 52.46%
Epoch [30/250], Loss: 1.1246, Accuracy: 62.71%, Test Loss: 1.3574, Test Accuracy: 52.38%
Epoch [31/250], Loss: 1.1166, Accuracy: 63.13%, Test Loss: 1.3538, Test Accuracy: 52.40%
Epoch [32/250], Loss: 1.1097, Accuracy: 63.20%, Test Loss: 1.3598, Test Accuracy: 52.40%
Epoch [33/250], Loss: 1.1003, Accuracy: 63.70%, Test Loss: 1.3540, Test Accuracy: 52.49%
Epoch [34/250], Loss: 1.0917, Accuracy: 63.97%, Test Loss: 1.3500, Test Accuracy: 52.54%
Epoch [35/250], Loss: 1.0845, Accuracy: 64.17%, Test Loss: 1.3487, Test Accuracy: 52.48%
Epoch [36/250], Loss: 1.0749, Accuracy: 64.63%, Test Loss: 1.3519, Test Accuracy: 52.33%
Epoch [37/250], Loss: 1.0658, Accuracy: 64.96%, Test Loss: 1.3474, Test Accuracy: 52.63%
Epoch [38/250], Loss: 1.0605, Accuracy: 65.03%, Test Loss: 1.3474, Test Accuracy: 52.75%
Epoch [39/250], Loss: 1.0516, Accuracy: 65.51%, Test Loss: 1.3433, Test Accuracy: 52.67%
Epoch [40/250], Loss: 1.0463, Accuracy: 65.63%, Test Loss: 1.3441, Test Accuracy: 52.91%
Epoch [41/250], Loss: 1.0376, Accuracy: 66.11%, Test Loss: 1.3429, Test Accuracy: 53.31%
Epoch [42/250], Loss: 1.0318, Accuracy: 66.25%, Test Loss: 1.3424, Test Accuracy: 52.77%
Epoch [43/250], Loss: 1.0223, Accuracy: 66.59%, Test Loss: 1.3434, Test Accuracy: 53.28%
Epoch [44/250], Loss: 1.0157, Accuracy: 66.92%, Test Loss: 1.3449, Test Accuracy: 53.05%
Epoch [45/250], Loss: 1.0094, Accuracy: 67.08%, Test Loss: 1.3380, Test Accuracy: 53.35%
Epoch [46/250], Loss: 1.0018, Accuracy: 67.44%, Test Loss: 1.3369, Test Accuracy: 53.30%
Epoch [47/250], Loss: 0.9954, Accuracy: 67.61%, Test Loss: 1.3397, Test Accuracy: 53.26%
Epoch [48/250], Loss: 0.9890, Accuracy: 67.96%, Test Loss: 1.3417, Test Accuracy: 52.98%
Epoch [49/250], Loss: 0.9822, Accuracy: 68.30%, Test Loss: 1.3360, Test Accuracy: 53.43%
Epoch [50/250], Loss: 0.9756, Accuracy: 68.52%, Test Loss: 1.3360, Test Accuracy: 53.45%
Epoch [51/250], Loss: 0.9683, Accuracy: 68.77%, Test Loss: 1.3382, Test Accuracy: 53.44%
Epoch [52/250], Loss: 0.9608, Accuracy: 69.09%, Test Loss: 1.3393, Test Accuracy: 53.56%
Epoch [53/250], Loss: 0.9552, Accuracy: 69.38%, Test Loss: 1.3386, Test Accuracy: 53.77%
Epoch [54/250], Loss: 0.9511, Accuracy: 69.36%, Test Loss: 1.3393, Test Accuracy: 53.29%
Epoch [55/250], Loss: 0.9430, Accuracy: 69.67%, Test Loss: 1.3349, Test Accuracy: 53.71%
Epoch [56/250], Loss: 0.9364, Accuracy: 69.98%, Test Loss: 1.3354, Test Accuracy: 53.94%
Epoch [57/250], Loss: 0.9328, Accuracy: 70.04%, Test Loss: 1.3386, Test Accuracy: 53.45%
Epoch [58/250], Loss: 0.9244, Accuracy: 70.56%, Test Loss: 1.3401, Test Accuracy: 53.51%
Epoch [59/250], Loss: 0.9187, Accuracy: 70.71%, Test Loss: 1.3399, Test Accuracy: 53.49%
Epoch [60/250], Loss: 0.9147, Accuracy: 70.69%, Test Loss: 1.3381, Test Accuracy: 53.63%
Epoch [61/250], Loss: 0.9079, Accuracy: 71.10%, Test Loss: 1.3384, Test Accuracy: 53.58%
Epoch [62/250], Loss: 0.9008, Accuracy: 71.25%, Test Loss: 1.3362, Test Accuracy: 53.91%
Epoch [63/250], Loss: 0.8951, Accuracy: 71.54%, Test Loss: 1.3384, Test Accuracy: 53.67%
Epoch [64/250], Loss: 0.8898, Accuracy: 71.84%, Test Loss: 1.3428, Test Accuracy: 53.47%
Epoch [65/250], Loss: 0.8848, Accuracy: 71.86%, Test Loss: 1.3413, Test Accuracy: 53.59%
Epoch [66/250], Loss: 0.8789, Accuracy: 72.13%, Test Loss: 1.3401, Test Accuracy: 54.06%
Epoch [67/250], Loss: 0.8735, Accuracy: 72.40%, Test Loss: 1.3507, Test Accuracy: 53.63%
Epoch [68/250], Loss: 0.8670, Accuracy: 72.51%, Test Loss: 1.3389, Test Accuracy: 53.86%
Epoch [69/250], Loss: 0.8622, Accuracy: 72.85%, Test Loss: 1.3456, Test Accuracy: 53.59%
Epoch [70/250], Loss: 0.8560, Accuracy: 73.08%, Test Loss: 1.3417, Test Accuracy: 53.82%
Epoch [71/250], Loss: 0.8517, Accuracy: 73.29%, Test Loss: 1.3415, Test Accuracy: 53.52%
Epoch [72/250], Loss: 0.8464, Accuracy: 73.47%, Test Loss: 1.3450, Test Accuracy: 53.76%
Epoch [73/250], Loss: 0.8414, Accuracy: 73.60%, Test Loss: 1.3460, Test Accuracy: 53.85%
Epoch [74/250], Loss: 0.8357, Accuracy: 73.79%, Test Loss: 1.3440, Test Accuracy: 53.70%
Epoch [75/250], Loss: 0.8297, Accuracy: 74.01%, Test Loss: 1.3504, Test Accuracy: 53.70%
Epoch [76/250], Loss: 0.8258, Accuracy: 74.14%, Test Loss: 1.3448, Test Accuracy: 53.77%
Epoch [77/250], Loss: 0.8204, Accuracy: 74.50%, Test Loss: 1.3515, Test Accuracy: 53.42%
Epoch [78/250], Loss: 0.8166, Accuracy: 74.46%, Test Loss: 1.3427, Test Accuracy: 53.76%
Epoch [79/250], Loss: 0.8105, Accuracy: 74.73%, Test Loss: 1.3492, Test Accuracy: 53.76%
Epoch [80/250], Loss: 0.8059, Accuracy: 74.97%, Test Loss: 1.3492, Test Accuracy: 53.85%
Epoch [81/250], Loss: 0.8005, Accuracy: 75.09%, Test Loss: 1.3500, Test Accuracy: 53.90%
Epoch [82/250], Loss: 0.7950, Accuracy: 75.47%, Test Loss: 1.3480, Test Accuracy: 53.81%
Epoch [83/250], Loss: 0.7908, Accuracy: 75.62%, Test Loss: 1.3511, Test Accuracy: 53.63%
Epoch [84/250], Loss: 0.7872, Accuracy: 75.73%, Test Loss: 1.3576, Test Accuracy: 53.30%
Epoch [85/250], Loss: 0.7813, Accuracy: 75.94%, Test Loss: 1.3516, Test Accuracy: 53.61%
Epoch [86/250], Loss: 0.7760, Accuracy: 76.11%, Test Loss: 1.3541, Test Accuracy: 53.61%
Epoch [87/250], Loss: 0.7717, Accuracy: 76.28%, Test Loss: 1.3556, Test Accuracy: 53.59%
Epoch [88/250], Loss: 0.7682, Accuracy: 76.40%, Test Loss: 1.3582, Test Accuracy: 53.65%
Epoch [89/250], Loss: 0.7628, Accuracy: 76.73%, Test Loss: 1.3568, Test Accuracy: 53.74%
Epoch [90/250], Loss: 0.7568, Accuracy: 77.04%, Test Loss: 1.3595, Test Accuracy: 53.89%
Epoch [91/250], Loss: 0.7532, Accuracy: 77.17%, Test Loss: 1.3632, Test Accuracy: 53.49%
Epoch [92/250], Loss: 0.7491, Accuracy: 77.14%, Test Loss: 1.3595, Test Accuracy: 53.66%
Epoch [93/250], Loss: 0.7431, Accuracy: 77.63%, Test Loss: 1.3632, Test Accuracy: 54.02%
Epoch [94/250], Loss: 0.7398, Accuracy: 77.60%, Test Loss: 1.3659, Test Accuracy: 53.49%
Epoch [95/250], Loss: 0.7363, Accuracy: 77.70%, Test Loss: 1.3648, Test Accuracy: 53.62%
Epoch [96/250], Loss: 0.7307, Accuracy: 78.02%, Test Loss: 1.3680, Test Accuracy: 53.54%
Epoch [97/250], Loss: 0.7254, Accuracy: 78.10%, Test Loss: 1.3661, Test Accuracy: 53.84%
Epoch [98/250], Loss: 0.7238, Accuracy: 78.16%, Test Loss: 1.3698, Test Accuracy: 53.53%
Epoch [99/250], Loss: 0.7176, Accuracy: 78.53%, Test Loss: 1.3764, Test Accuracy: 53.31%
Epoch [100/250], Loss: 0.7139, Accuracy: 78.55%, Test Loss: 1.3725, Test Accuracy: 53.77%
Epoch [101/250], Loss: 0.7100, Accuracy: 78.74%, Test Loss: 1.3743, Test Accuracy: 53.46%
Epoch [102/250], Loss: 0.7055, Accuracy: 79.02%, Test Loss: 1.3749, Test Accuracy: 53.69%
Epoch [103/250], Loss: 0.7013, Accuracy: 79.05%, Test Loss: 1.3764, Test Accuracy: 53.33%
Epoch [104/250], Loss: 0.6976, Accuracy: 79.32%, Test Loss: 1.3756, Test Accuracy: 53.66%
Epoch [105/250], Loss: 0.6932, Accuracy: 79.35%, Test Loss: 1.3784, Test Accuracy: 53.75%
Epoch [106/250], Loss: 0.6898, Accuracy: 79.49%, Test Loss: 1.3801, Test Accuracy: 53.59%
Epoch [107/250], Loss: 0.6848, Accuracy: 79.78%, Test Loss: 1.3851, Test Accuracy: 53.33%
Epoch [108/250], Loss: 0.6800, Accuracy: 80.02%, Test Loss: 1.3839, Test Accuracy: 53.63%
Epoch [109/250], Loss: 0.6759, Accuracy: 80.12%, Test Loss: 1.3870, Test Accuracy: 53.47%
Epoch [110/250], Loss: 0.6736, Accuracy: 80.20%, Test Loss: 1.3896, Test Accuracy: 53.28%
Epoch [111/250], Loss: 0.6682, Accuracy: 80.37%, Test Loss: 1.3847, Test Accuracy: 53.60%
Epoch [112/250], Loss: 0.6636, Accuracy: 80.70%, Test Loss: 1.3897, Test Accuracy: 53.55%
Epoch [113/250], Loss: 0.6608, Accuracy: 80.61%, Test Loss: 1.3902, Test Accuracy: 53.63%
Epoch [114/250], Loss: 0.6563, Accuracy: 80.77%, Test Loss: 1.3955, Test Accuracy: 53.44%
Epoch [115/250], Loss: 0.6514, Accuracy: 81.05%, Test Loss: 1.3946, Test Accuracy: 53.39%
Epoch [116/250], Loss: 0.6477, Accuracy: 81.14%, Test Loss: 1.4011, Test Accuracy: 53.26%
Epoch [117/250], Loss: 0.6460, Accuracy: 81.33%, Test Loss: 1.3967, Test Accuracy: 53.38%
Epoch [118/250], Loss: 0.6422, Accuracy: 81.49%, Test Loss: 1.3978, Test Accuracy: 53.81%
Epoch [119/250], Loss: 0.6375, Accuracy: 81.52%, Test Loss: 1.4034, Test Accuracy: 53.31%
Epoch [120/250], Loss: 0.6329, Accuracy: 81.76%, Test Loss: 1.4046, Test Accuracy: 53.86%
Epoch [121/250], Loss: 0.6309, Accuracy: 81.85%, Test Loss: 1.4139, Test Accuracy: 53.15%
Epoch [122/250], Loss: 0.6263, Accuracy: 82.03%, Test Loss: 1.4040, Test Accuracy: 53.46%
Epoch [123/250], Loss: 0.6229, Accuracy: 82.24%, Test Loss: 1.4068, Test Accuracy: 53.47%
Epoch [124/250], Loss: 0.6181, Accuracy: 82.32%, Test Loss: 1.4155, Test Accuracy: 53.45%
Epoch [125/250], Loss: 0.6157, Accuracy: 82.49%, Test Loss: 1.4112, Test Accuracy: 53.35%
Epoch [126/250], Loss: 0.6129, Accuracy: 82.55%, Test Loss: 1.4173, Test Accuracy: 53.47%
Epoch [127/250], Loss: 0.6091, Accuracy: 82.65%, Test Loss: 1.4172, Test Accuracy: 53.22%
Epoch [128/250], Loss: 0.6040, Accuracy: 82.86%, Test Loss: 1.4164, Test Accuracy: 53.26%
Epoch [129/250], Loss: 0.6017, Accuracy: 83.02%, Test Loss: 1.4206, Test Accuracy: 53.60%
Epoch [130/250], Loss: 0.5974, Accuracy: 83.11%, Test Loss: 1.4206, Test Accuracy: 53.62%
Epoch [131/250], Loss: 0.5930, Accuracy: 83.35%, Test Loss: 1.4229, Test Accuracy: 53.54%
Epoch [132/250], Loss: 0.5896, Accuracy: 83.38%, Test Loss: 1.4260, Test Accuracy: 53.42%
Epoch [133/250], Loss: 0.5877, Accuracy: 83.51%, Test Loss: 1.4323, Test Accuracy: 53.40%
Epoch [134/250], Loss: 0.5850, Accuracy: 83.45%, Test Loss: 1.4299, Test Accuracy: 53.46%
Epoch [135/250], Loss: 0.5808, Accuracy: 83.70%, Test Loss: 1.4331, Test Accuracy: 53.13%
Epoch [136/250], Loss: 0.5759, Accuracy: 83.94%, Test Loss: 1.4392, Test Accuracy: 52.96%
Epoch [137/250], Loss: 0.5735, Accuracy: 84.00%, Test Loss: 1.4366, Test Accuracy: 53.38%
Epoch [138/250], Loss: 0.5701, Accuracy: 84.21%, Test Loss: 1.4346, Test Accuracy: 53.66%
Epoch [139/250], Loss: 0.5670, Accuracy: 84.37%, Test Loss: 1.4455, Test Accuracy: 53.34%
Epoch [140/250], Loss: 0.5647, Accuracy: 84.34%, Test Loss: 1.4382, Test Accuracy: 53.50%
Epoch [141/250], Loss: 0.5585, Accuracy: 84.57%, Test Loss: 1.4480, Test Accuracy: 53.05%
Epoch [142/250], Loss: 0.5558, Accuracy: 84.69%, Test Loss: 1.4463, Test Accuracy: 53.12%
Epoch [143/250], Loss: 0.5537, Accuracy: 84.72%, Test Loss: 1.4457, Test Accuracy: 53.16%
Epoch [144/250], Loss: 0.5499, Accuracy: 84.86%, Test Loss: 1.4509, Test Accuracy: 53.36%
Epoch [145/250], Loss: 0.5484, Accuracy: 85.09%, Test Loss: 1.4539, Test Accuracy: 53.21%
Epoch [146/250], Loss: 0.5432, Accuracy: 85.22%, Test Loss: 1.4556, Test Accuracy: 53.10%
Epoch [147/250], Loss: 0.5404, Accuracy: 85.44%, Test Loss: 1.4557, Test Accuracy: 53.43%
Epoch [148/250], Loss: 0.5372, Accuracy: 85.47%, Test Loss: 1.4611, Test Accuracy: 52.94%
Epoch [149/250], Loss: 0.5342, Accuracy: 85.58%, Test Loss: 1.4653, Test Accuracy: 52.85%
Epoch [150/250], Loss: 0.5315, Accuracy: 85.69%, Test Loss: 1.4675, Test Accuracy: 53.03%
Epoch [151/250], Loss: 0.5286, Accuracy: 85.72%, Test Loss: 1.4680, Test Accuracy: 53.33%
Epoch [152/250], Loss: 0.5251, Accuracy: 85.86%, Test Loss: 1.4705, Test Accuracy: 53.06%
Epoch [153/250], Loss: 0.5229, Accuracy: 86.01%, Test Loss: 1.4707, Test Accuracy: 53.12%
Epoch [154/250], Loss: 0.5181, Accuracy: 86.20%, Test Loss: 1.4688, Test Accuracy: 53.24%
Epoch [155/250], Loss: 0.5153, Accuracy: 86.24%, Test Loss: 1.4732, Test Accuracy: 53.31%
Epoch [156/250], Loss: 0.5131, Accuracy: 86.31%, Test Loss: 1.4737, Test Accuracy: 53.23%
Epoch [157/250], Loss: 0.5095, Accuracy: 86.47%, Test Loss: 1.4807, Test Accuracy: 52.96%
Epoch [158/250], Loss: 0.5068, Accuracy: 86.67%, Test Loss: 1.4822, Test Accuracy: 52.82%
Epoch [159/250], Loss: 0.5033, Accuracy: 86.71%, Test Loss: 1.4827, Test Accuracy: 53.31%
Epoch [160/250], Loss: 0.5016, Accuracy: 86.70%, Test Loss: 1.4825, Test Accuracy: 52.94%
Epoch [161/250], Loss: 0.4955, Accuracy: 87.04%, Test Loss: 1.4873, Test Accuracy: 53.08%
Epoch [162/250], Loss: 0.4943, Accuracy: 87.06%, Test Loss: 1.4910, Test Accuracy: 53.00%
Epoch [163/250], Loss: 0.4908, Accuracy: 87.22%, Test Loss: 1.4947, Test Accuracy: 53.18%
Epoch [164/250], Loss: 0.4880, Accuracy: 87.35%, Test Loss: 1.4956, Test Accuracy: 53.07%
Epoch [165/250], Loss: 0.4860, Accuracy: 87.42%, Test Loss: 1.4970, Test Accuracy: 53.14%
Epoch [166/250], Loss: 0.4821, Accuracy: 87.62%, Test Loss: 1.5008, Test Accuracy: 53.01%
Epoch [167/250], Loss: 0.4809, Accuracy: 87.61%, Test Loss: 1.5056, Test Accuracy: 53.05%
Epoch [168/250], Loss: 0.4769, Accuracy: 87.73%, Test Loss: 1.5021, Test Accuracy: 53.19%
Epoch [169/250], Loss: 0.4737, Accuracy: 87.88%, Test Loss: 1.5065, Test Accuracy: 53.15%
Epoch [170/250], Loss: 0.4719, Accuracy: 87.88%, Test Loss: 1.5080, Test Accuracy: 52.98%
Epoch [171/250], Loss: 0.4693, Accuracy: 88.06%, Test Loss: 1.5156, Test Accuracy: 52.82%
Epoch [172/250], Loss: 0.4650, Accuracy: 88.15%, Test Loss: 1.5130, Test Accuracy: 53.01%
Epoch [173/250], Loss: 0.4633, Accuracy: 88.21%, Test Loss: 1.5139, Test Accuracy: 53.31%
Epoch [174/250], Loss: 0.4601, Accuracy: 88.35%, Test Loss: 1.5208, Test Accuracy: 53.32%
Epoch [175/250], Loss: 0.4580, Accuracy: 88.30%, Test Loss: 1.5174, Test Accuracy: 52.89%
Epoch [176/250], Loss: 0.4544, Accuracy: 88.72%, Test Loss: 1.5234, Test Accuracy: 52.67%
Epoch [177/250], Loss: 0.4511, Accuracy: 88.76%, Test Loss: 1.5287, Test Accuracy: 52.98%
Epoch [178/250], Loss: 0.4493, Accuracy: 88.75%, Test Loss: 1.5291, Test Accuracy: 52.99%
Epoch [179/250], Loss: 0.4462, Accuracy: 88.89%, Test Loss: 1.5304, Test Accuracy: 53.04%
Epoch [180/250], Loss: 0.4453, Accuracy: 88.93%, Test Loss: 1.5311, Test Accuracy: 53.05%
Epoch [181/250], Loss: 0.4414, Accuracy: 89.09%, Test Loss: 1.5349, Test Accuracy: 53.12%
Epoch [182/250], Loss: 0.4373, Accuracy: 89.27%, Test Loss: 1.5442, Test Accuracy: 52.70%
Epoch [183/250], Loss: 0.4353, Accuracy: 89.28%, Test Loss: 1.5413, Test Accuracy: 53.05%
Epoch [184/250], Loss: 0.4342, Accuracy: 89.41%, Test Loss: 1.5440, Test Accuracy: 52.89%
Epoch [185/250], Loss: 0.4300, Accuracy: 89.55%, Test Loss: 1.5468, Test Accuracy: 53.00%
Epoch [186/250], Loss: 0.4281, Accuracy: 89.62%, Test Loss: 1.5492, Test Accuracy: 52.88%
Epoch [187/250], Loss: 0.4257, Accuracy: 89.60%, Test Loss: 1.5524, Test Accuracy: 52.68%
Epoch [188/250], Loss: 0.4218, Accuracy: 89.84%, Test Loss: 1.5619, Test Accuracy: 52.78%
Epoch [189/250], Loss: 0.4214, Accuracy: 89.89%, Test Loss: 1.5567, Test Accuracy: 52.87%
Epoch [190/250], Loss: 0.4180, Accuracy: 89.84%, Test Loss: 1.5587, Test Accuracy: 52.60%
Epoch [191/250], Loss: 0.4140, Accuracy: 90.03%, Test Loss: 1.5623, Test Accuracy: 52.70%
Epoch [192/250], Loss: 0.4126, Accuracy: 90.06%, Test Loss: 1.5656, Test Accuracy: 52.98%
Epoch [193/250], Loss: 0.4098, Accuracy: 90.27%, Test Loss: 1.5716, Test Accuracy: 52.62%
Epoch [194/250], Loss: 0.4082, Accuracy: 90.28%, Test Loss: 1.5750, Test Accuracy: 52.45%
Epoch [195/250], Loss: 0.4077, Accuracy: 90.22%, Test Loss: 1.5721, Test Accuracy: 52.65%
Epoch [196/250], Loss: 0.4040, Accuracy: 90.38%, Test Loss: 1.5801, Test Accuracy: 52.50%
Epoch [197/250], Loss: 0.3999, Accuracy: 90.68%, Test Loss: 1.5805, Test Accuracy: 52.51%
Epoch [198/250], Loss: 0.3988, Accuracy: 90.75%, Test Loss: 1.5782, Test Accuracy: 52.81%
Epoch [199/250], Loss: 0.3966, Accuracy: 90.76%, Test Loss: 1.5840, Test Accuracy: 52.67%
Epoch [200/250], Loss: 0.3925, Accuracy: 90.83%, Test Loss: 1.5856, Test Accuracy: 52.52%
Epoch [201/250], Loss: 0.3916, Accuracy: 90.81%, Test Loss: 1.5876, Test Accuracy: 52.77%
Epoch [202/250], Loss: 0.3883, Accuracy: 91.04%, Test Loss: 1.5931, Test Accuracy: 52.61%
Epoch [203/250], Loss: 0.3869, Accuracy: 91.14%, Test Loss: 1.6028, Test Accuracy: 52.29%
Epoch [204/250], Loss: 0.3839, Accuracy: 91.20%, Test Loss: 1.5975, Test Accuracy: 52.90%
Epoch [205/250], Loss: 0.3809, Accuracy: 91.36%, Test Loss: 1.6033, Test Accuracy: 52.58%
Epoch [206/250], Loss: 0.3800, Accuracy: 91.26%, Test Loss: 1.6029, Test Accuracy: 52.49%
Epoch [207/250], Loss: 0.3754, Accuracy: 91.55%, Test Loss: 1.6044, Test Accuracy: 52.36%
Epoch [208/250], Loss: 0.3745, Accuracy: 91.53%, Test Loss: 1.6067, Test Accuracy: 52.38%
Epoch [209/250], Loss: 0.3727, Accuracy: 91.58%, Test Loss: 1.6138, Test Accuracy: 52.48%
Epoch [210/250], Loss: 0.3697, Accuracy: 91.68%, Test Loss: 1.6127, Test Accuracy: 52.54%
Epoch [211/250], Loss: 0.3671, Accuracy: 91.80%, Test Loss: 1.6133, Test Accuracy: 52.42%
Epoch [212/250], Loss: 0.3652, Accuracy: 91.86%, Test Loss: 1.6190, Test Accuracy: 52.65%
Epoch [213/250], Loss: 0.3635, Accuracy: 91.95%, Test Loss: 1.6226, Test Accuracy: 52.85%
Epoch [214/250], Loss: 0.3606, Accuracy: 92.11%, Test Loss: 1.6281, Test Accuracy: 52.47%
Epoch [215/250], Loss: 0.3578, Accuracy: 92.20%, Test Loss: 1.6322, Test Accuracy: 52.24%
Epoch [216/250], Loss: 0.3567, Accuracy: 92.34%, Test Loss: 1.6297, Test Accuracy: 52.77%
Epoch [217/250], Loss: 0.3543, Accuracy: 92.29%, Test Loss: 1.6347, Test Accuracy: 52.22%
Epoch [218/250], Loss: 0.3519, Accuracy: 92.36%, Test Loss: 1.6366, Test Accuracy: 52.62%
Epoch [219/250], Loss: 0.3493, Accuracy: 92.45%, Test Loss: 1.6403, Test Accuracy: 52.22%
Epoch [220/250], Loss: 0.3481, Accuracy: 92.39%, Test Loss: 1.6387, Test Accuracy: 52.36%
Epoch [221/250], Loss: 0.3455, Accuracy: 92.52%, Test Loss: 1.6444, Test Accuracy: 52.48%
Epoch [222/250], Loss: 0.3439, Accuracy: 92.62%, Test Loss: 1.6492, Test Accuracy: 51.99%
Epoch [223/250], Loss: 0.3431, Accuracy: 92.57%, Test Loss: 1.6515, Test Accuracy: 52.23%
Epoch [224/250], Loss: 0.3383, Accuracy: 92.92%, Test Loss: 1.6516, Test Accuracy: 52.26%
Epoch [225/250], Loss: 0.3366, Accuracy: 92.93%, Test Loss: 1.6528, Test Accuracy: 52.01%
Epoch [226/250], Loss: 0.3349, Accuracy: 92.88%, Test Loss: 1.6601, Test Accuracy: 52.28%
Epoch [227/250], Loss: 0.3331, Accuracy: 92.95%, Test Loss: 1.6629, Test Accuracy: 52.10%
Epoch [228/250], Loss: 0.3319, Accuracy: 93.02%, Test Loss: 1.6632, Test Accuracy: 52.72%
Epoch [229/250], Loss: 0.3283, Accuracy: 93.13%, Test Loss: 1.6687, Test Accuracy: 52.23%
Epoch [230/250], Loss: 0.3275, Accuracy: 93.25%, Test Loss: 1.6782, Test Accuracy: 52.14%
Epoch [231/250], Loss: 0.3253, Accuracy: 93.24%, Test Loss: 1.6775, Test Accuracy: 52.11%
Epoch [232/250], Loss: 0.3227, Accuracy: 93.38%, Test Loss: 1.6805, Test Accuracy: 52.09%
Epoch [233/250], Loss: 0.3217, Accuracy: 93.46%, Test Loss: 1.6804, Test Accuracy: 51.81%
Epoch [234/250], Loss: 0.3198, Accuracy: 93.49%, Test Loss: 1.6840, Test Accuracy: 52.34%
Epoch [235/250], Loss: 0.3161, Accuracy: 93.60%, Test Loss: 1.6817, Test Accuracy: 52.07%
Epoch [236/250], Loss: 0.3133, Accuracy: 93.76%, Test Loss: 1.6867, Test Accuracy: 52.12%
Epoch [237/250], Loss: 0.3111, Accuracy: 93.83%, Test Loss: 1.6955, Test Accuracy: 51.99%
Epoch [238/250], Loss: 0.3097, Accuracy: 93.86%, Test Loss: 1.6919, Test Accuracy: 52.25%
Epoch [239/250], Loss: 0.3084, Accuracy: 93.83%, Test Loss: 1.6939, Test Accuracy: 52.01%
Epoch [240/250], Loss: 0.3062, Accuracy: 93.94%, Test Loss: 1.7009, Test Accuracy: 52.20%
Epoch [241/250], Loss: 0.3050, Accuracy: 93.92%, Test Loss: 1.6973, Test Accuracy: 52.07%
Epoch [242/250], Loss: 0.3023, Accuracy: 94.05%, Test Loss: 1.7075, Test Accuracy: 51.94%
Epoch [243/250], Loss: 0.3012, Accuracy: 94.07%, Test Loss: 1.7080, Test Accuracy: 52.05%
Epoch [244/250], Loss: 0.2979, Accuracy: 94.21%, Test Loss: 1.7067, Test Accuracy: 52.19%
Epoch [245/250], Loss: 0.2953, Accuracy: 94.34%, Test Loss: 1.7169, Test Accuracy: 51.85%
Epoch [246/250], Loss: 0.2939, Accuracy: 94.46%, Test Loss: 1.7203, Test Accuracy: 51.93%
Epoch [247/250], Loss: 0.2934, Accuracy: 94.29%, Test Loss: 1.7210, Test Accuracy: 51.78%
Epoch [248/250], Loss: 0.2912, Accuracy: 94.29%, Test Loss: 1.7291, Test Accuracy: 52.12%
Epoch [249/250], Loss: 0.2897, Accuracy: 94.44%, Test Loss: 1.7261, Test Accuracy: 52.08%
Epoch [250/250], Loss: 0.2875, Accuracy: 94.57%, Test Loss: 1.7311, Test Accuracy: 52.15%
import matplotlib.pyplot as plt

# Plot the train vs. test loss curves recorded during training.
for series, name in ((train_losses, 'train_losses'), (test_losses, 'test_losses')):
    plt.plot(series, label=name)
plt.legend()
<matplotlib.legend.Legend at 0x7fb9b0681270>
image
import matplotlib.pyplot as plt

# Plot the train vs. test accuracy curves recorded during training.
for series, name in ((train_accuracies, 'train_accuracy'), (test_accuracies, 'test_accuracy')):
    plt.plot(series, label=name)
plt.legend()
<matplotlib.legend.Legend at 0x7fb9b2777250>
image
MLP_Cifar10_ReLU_He_Adam_2H
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.init as init

# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Data

# Load the CIFAR-10 dataset (32x32 RGB images, 10 classes).
# NOTE(review): the original comment said "CFashionMNIST", but this cell loads CIFAR-10.
# Normalize each of the 3 channels from [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,0.5, 0.5), (0.5,0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform)
# shuffle for SGD; drop_last discards the final smaller batch so every batch has 1024 samples
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1024, num_workers=10, shuffle=True, drop_last=True)

testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform)
# no shuffling for evaluation; keep all samples (no drop_last)
testloader = torch.utils.data.DataLoader(testset, batch_size=1024, num_workers=10, shuffle=False)
Files already downloaded and verified
Files already downloaded and verified
import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    """De-normalize a [-1, 1] image tensor and display it."""
    restored = img * 0.5 + 0.5               # invert Normalize(0.5, 0.5)
    plt.imshow(np.transpose(restored.numpy(), (1, 2, 0)))  # CHW -> HWC
    plt.show()

# Show the first 8 images of one training batch as a sanity check.
for _, (images, labels) in enumerate(trainloader, 0):
    imshow(torchvision.utils.make_grid(images[:8]))
    break
image

Model

# Two-hidden-layer MLP for CIFAR-10: 3072 -> 256 -> 256 -> 10, ReLU activations.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(32 * 32 * 3, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

# Kaiming-uniform ("He") weight init, which is appropriate for ReLU; zero the biases.
for module in model.children():
    if not isinstance(module, nn.Linear):
        continue
    init.kaiming_uniform_(module.weight, nonlinearity='relu')
    if module.bias is not None:
        module.bias.data.fill_(0)

model = model.to(device)
print(model)
Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=3072, out_features=256, bias=True)
  (2): ReLU()
  (3): Linear(in_features=256, out_features=256, bias=True)
  (4): ReLU()
  (5): Linear(in_features=256, out_features=10, bias=True)
)

Loss, Optimizer, and Evaluation Function

# Cross-entropy loss (log-softmax + NLL) over the 10 class logits.
criterion = nn.CrossEntropyLoss()
# Adam with a small constant learning rate; `model` was already moved to `device`.
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Function to compute loss and accuracy for test set
def evaluate(model, testloader, criterion):
    """Evaluate `model` on `testloader`; return (mean batch loss, accuracy %)."""
    model.eval()          # evaluation mode (affects dropout/batchnorm, if present)
    batch_loss_total = 0.0
    num_correct = 0
    num_samples = 0
    with torch.no_grad():  # no gradient bookkeeping during evaluation
        for images, labels in testloader:
            # Move the batch to the same device as the model
            images, labels = images.to(device), labels.to(device)

            logits = model(images)
            batch_loss_total += criterion(logits, labels).item()

            # Predicted class = index of the largest logit per sample
            predictions = logits.data.max(1)[1]
            num_samples += labels.size(0)
            num_correct += (predictions == labels).sum().item()

    # Average the loss per batch; accuracy is per sample.
    return batch_loss_total / len(testloader), 100 * num_correct / num_samples

Train

# some parameter
# Per-epoch metric history, used for plotting the learning curves after training.
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []
max_epoch = 250
# train
for epoch in range(max_epoch):
    model.train()
    running_loss = 0.0
    running_correct = 0   # to track number of correct predictions
    total = 0             # to track total number of samples

    for i, (inputs, labels) in enumerate(trainloader, 0):
        # Move inputs and labels to the device
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        running_loss += loss.item()

        # Determine class predictions and track accuracy
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        running_correct += (predicted == labels).sum().item()

        # Backward pass and optimization
        loss.backward()
        optimizer.step()        

    epoch_accuracy = 100 * running_correct / total
    # NOTE: `i` is the index of the last batch, so (i + 1) == number of batches seen.
    epoch_loss = running_loss / (i + 1)
    
    # Evaluate on the held-out test set once per epoch.
    test_loss, test_accuracy = evaluate(model, testloader, criterion)
    print(f"Epoch [{epoch + 1}/{max_epoch}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
    
    # save for plot
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_accuracy)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)
Epoch [1/250], Loss: 1.9254, Accuracy: 32.19%, Test Loss: 1.7350, Test Accuracy: 39.12%
Epoch [2/250], Loss: 1.6684, Accuracy: 41.85%, Test Loss: 1.6295, Test Accuracy: 43.05%
Epoch [3/250], Loss: 1.5807, Accuracy: 44.99%, Test Loss: 1.5740, Test Accuracy: 45.58%
Epoch [4/250], Loss: 1.5196, Accuracy: 47.47%, Test Loss: 1.5384, Test Accuracy: 46.68%
Epoch [5/250], Loss: 1.4718, Accuracy: 49.29%, Test Loss: 1.5114, Test Accuracy: 47.27%
Epoch [6/250], Loss: 1.4309, Accuracy: 50.82%, Test Loss: 1.4893, Test Accuracy: 47.86%
Epoch [7/250], Loss: 1.3943, Accuracy: 52.14%, Test Loss: 1.4679, Test Accuracy: 48.45%
Epoch [8/250], Loss: 1.3618, Accuracy: 53.32%, Test Loss: 1.4528, Test Accuracy: 49.04%
Epoch [9/250], Loss: 1.3311, Accuracy: 54.51%, Test Loss: 1.4429, Test Accuracy: 49.24%
Epoch [10/250], Loss: 1.3052, Accuracy: 55.55%, Test Loss: 1.4294, Test Accuracy: 50.01%
Epoch [11/250], Loss: 1.2794, Accuracy: 56.47%, Test Loss: 1.4163, Test Accuracy: 50.67%
Epoch [12/250], Loss: 1.2553, Accuracy: 57.36%, Test Loss: 1.4095, Test Accuracy: 50.78%
Epoch [13/250], Loss: 1.2329, Accuracy: 58.23%, Test Loss: 1.4026, Test Accuracy: 50.89%
Epoch [14/250], Loss: 1.2122, Accuracy: 58.98%, Test Loss: 1.3908, Test Accuracy: 51.27%
Epoch [15/250], Loss: 1.1911, Accuracy: 59.75%, Test Loss: 1.3889, Test Accuracy: 51.44%
Epoch [16/250], Loss: 1.1715, Accuracy: 60.49%, Test Loss: 1.3836, Test Accuracy: 51.67%
Epoch [17/250], Loss: 1.1558, Accuracy: 61.13%, Test Loss: 1.3812, Test Accuracy: 51.78%
Epoch [18/250], Loss: 1.1353, Accuracy: 61.90%, Test Loss: 1.3761, Test Accuracy: 51.76%
Epoch [19/250], Loss: 1.1182, Accuracy: 62.57%, Test Loss: 1.3722, Test Accuracy: 52.33%
Epoch [20/250], Loss: 1.1030, Accuracy: 63.18%, Test Loss: 1.3638, Test Accuracy: 52.31%
Epoch [21/250], Loss: 1.0844, Accuracy: 63.73%, Test Loss: 1.3646, Test Accuracy: 52.36%
Epoch [22/250], Loss: 1.0684, Accuracy: 64.44%, Test Loss: 1.3616, Test Accuracy: 52.33%
Epoch [23/250], Loss: 1.0544, Accuracy: 64.92%, Test Loss: 1.3608, Test Accuracy: 52.98%
Epoch [24/250], Loss: 1.0402, Accuracy: 65.45%, Test Loss: 1.3585, Test Accuracy: 52.90%
Epoch [25/250], Loss: 1.0251, Accuracy: 66.07%, Test Loss: 1.3604, Test Accuracy: 52.59%
Epoch [26/250], Loss: 1.0118, Accuracy: 66.53%, Test Loss: 1.3606, Test Accuracy: 52.95%
Epoch [27/250], Loss: 0.9964, Accuracy: 67.10%, Test Loss: 1.3588, Test Accuracy: 52.72%
Epoch [28/250], Loss: 0.9822, Accuracy: 67.73%, Test Loss: 1.3567, Test Accuracy: 52.78%
Epoch [29/250], Loss: 0.9698, Accuracy: 68.03%, Test Loss: 1.3584, Test Accuracy: 52.90%
Epoch [30/250], Loss: 0.9578, Accuracy: 68.43%, Test Loss: 1.3537, Test Accuracy: 52.81%
Epoch [31/250], Loss: 0.9428, Accuracy: 69.18%, Test Loss: 1.3589, Test Accuracy: 53.09%
Epoch [32/250], Loss: 0.9323, Accuracy: 69.54%, Test Loss: 1.3590, Test Accuracy: 53.44%
Epoch [33/250], Loss: 0.9210, Accuracy: 69.99%, Test Loss: 1.3623, Test Accuracy: 53.08%
Epoch [34/250], Loss: 0.9082, Accuracy: 70.44%, Test Loss: 1.3637, Test Accuracy: 53.17%
Epoch [35/250], Loss: 0.8963, Accuracy: 70.97%, Test Loss: 1.3663, Test Accuracy: 53.05%
Epoch [36/250], Loss: 0.8818, Accuracy: 71.50%, Test Loss: 1.3649, Test Accuracy: 53.61%
Epoch [37/250], Loss: 0.8725, Accuracy: 71.85%, Test Loss: 1.3661, Test Accuracy: 52.92%
Epoch [38/250], Loss: 0.8607, Accuracy: 72.23%, Test Loss: 1.3697, Test Accuracy: 53.36%
Epoch [39/250], Loss: 0.8491, Accuracy: 72.68%, Test Loss: 1.3673, Test Accuracy: 53.43%
Epoch [40/250], Loss: 0.8362, Accuracy: 73.13%, Test Loss: 1.3756, Test Accuracy: 53.03%
Epoch [41/250], Loss: 0.8309, Accuracy: 73.37%, Test Loss: 1.3841, Test Accuracy: 52.50%
Epoch [42/250], Loss: 0.8167, Accuracy: 73.86%, Test Loss: 1.3774, Test Accuracy: 53.00%
Epoch [43/250], Loss: 0.8050, Accuracy: 74.47%, Test Loss: 1.3812, Test Accuracy: 53.14%
Epoch [44/250], Loss: 0.7963, Accuracy: 74.66%, Test Loss: 1.3790, Test Accuracy: 53.37%
Epoch [45/250], Loss: 0.7871, Accuracy: 75.15%, Test Loss: 1.3863, Test Accuracy: 53.20%
Epoch [46/250], Loss: 0.7739, Accuracy: 75.62%, Test Loss: 1.3883, Test Accuracy: 53.03%
Epoch [47/250], Loss: 0.7659, Accuracy: 75.78%, Test Loss: 1.3892, Test Accuracy: 53.36%
Epoch [48/250], Loss: 0.7554, Accuracy: 76.35%, Test Loss: 1.3955, Test Accuracy: 53.33%
Epoch [49/250], Loss: 0.7474, Accuracy: 76.58%, Test Loss: 1.3943, Test Accuracy: 53.70%
Epoch [50/250], Loss: 0.7358, Accuracy: 76.99%, Test Loss: 1.4002, Test Accuracy: 53.79%
Epoch [51/250], Loss: 0.7309, Accuracy: 77.01%, Test Loss: 1.4082, Test Accuracy: 53.00%
Epoch [52/250], Loss: 0.7193, Accuracy: 77.52%, Test Loss: 1.4102, Test Accuracy: 53.12%
Epoch [53/250], Loss: 0.7084, Accuracy: 77.99%, Test Loss: 1.4164, Test Accuracy: 53.18%
Epoch [54/250], Loss: 0.7017, Accuracy: 78.40%, Test Loss: 1.4132, Test Accuracy: 53.17%
Epoch [55/250], Loss: 0.6918, Accuracy: 78.61%, Test Loss: 1.4157, Test Accuracy: 53.24%
Epoch [56/250], Loss: 0.6821, Accuracy: 78.98%, Test Loss: 1.4296, Test Accuracy: 53.12%
Epoch [57/250], Loss: 0.6736, Accuracy: 79.30%, Test Loss: 1.4299, Test Accuracy: 53.07%
Epoch [58/250], Loss: 0.6646, Accuracy: 79.68%, Test Loss: 1.4340, Test Accuracy: 52.68%
Epoch [59/250], Loss: 0.6579, Accuracy: 79.90%, Test Loss: 1.4416, Test Accuracy: 53.03%
Epoch [60/250], Loss: 0.6493, Accuracy: 80.17%, Test Loss: 1.4424, Test Accuracy: 53.20%
Epoch [61/250], Loss: 0.6391, Accuracy: 80.70%, Test Loss: 1.4492, Test Accuracy: 52.90%
Epoch [62/250], Loss: 0.6326, Accuracy: 80.95%, Test Loss: 1.4480, Test Accuracy: 53.26%
Epoch [63/250], Loss: 0.6235, Accuracy: 81.30%, Test Loss: 1.4605, Test Accuracy: 52.89%
Epoch [64/250], Loss: 0.6155, Accuracy: 81.54%, Test Loss: 1.4678, Test Accuracy: 52.59%
Epoch [65/250], Loss: 0.6065, Accuracy: 81.92%, Test Loss: 1.4627, Test Accuracy: 52.81%
Epoch [66/250], Loss: 0.6004, Accuracy: 82.04%, Test Loss: 1.4697, Test Accuracy: 52.61%
Epoch [67/250], Loss: 0.5911, Accuracy: 82.57%, Test Loss: 1.4782, Test Accuracy: 52.62%
Epoch [68/250], Loss: 0.5845, Accuracy: 82.78%, Test Loss: 1.4815, Test Accuracy: 52.64%
Epoch [69/250], Loss: 0.5785, Accuracy: 82.92%, Test Loss: 1.4813, Test Accuracy: 53.00%
Epoch [70/250], Loss: 0.5709, Accuracy: 83.25%, Test Loss: 1.4885, Test Accuracy: 52.78%
Epoch [71/250], Loss: 0.5616, Accuracy: 83.63%, Test Loss: 1.4917, Test Accuracy: 52.91%
Epoch [72/250], Loss: 0.5539, Accuracy: 83.92%, Test Loss: 1.5010, Test Accuracy: 52.62%
Epoch [73/250], Loss: 0.5494, Accuracy: 84.15%, Test Loss: 1.5083, Test Accuracy: 52.90%
Epoch [74/250], Loss: 0.5411, Accuracy: 84.53%, Test Loss: 1.5130, Test Accuracy: 52.86%
Epoch [75/250], Loss: 0.5350, Accuracy: 84.65%, Test Loss: 1.5112, Test Accuracy: 52.54%
Epoch [76/250], Loss: 0.5284, Accuracy: 84.73%, Test Loss: 1.5167, Test Accuracy: 52.57%
Epoch [77/250], Loss: 0.5228, Accuracy: 85.00%, Test Loss: 1.5242, Test Accuracy: 53.05%
Epoch [78/250], Loss: 0.5132, Accuracy: 85.43%, Test Loss: 1.5355, Test Accuracy: 52.45%
Epoch [79/250], Loss: 0.5084, Accuracy: 85.72%, Test Loss: 1.5388, Test Accuracy: 52.86%
Epoch [80/250], Loss: 0.5004, Accuracy: 86.16%, Test Loss: 1.5502, Test Accuracy: 52.69%
Epoch [81/250], Loss: 0.4949, Accuracy: 86.14%, Test Loss: 1.5541, Test Accuracy: 52.56%
Epoch [82/250], Loss: 0.4872, Accuracy: 86.48%, Test Loss: 1.5620, Test Accuracy: 52.50%
Epoch [83/250], Loss: 0.4798, Accuracy: 86.74%, Test Loss: 1.5651, Test Accuracy: 52.34%
Epoch [84/250], Loss: 0.4741, Accuracy: 86.96%, Test Loss: 1.5675, Test Accuracy: 52.82%
Epoch [85/250], Loss: 0.4661, Accuracy: 87.25%, Test Loss: 1.5740, Test Accuracy: 52.68%
Epoch [86/250], Loss: 0.4628, Accuracy: 87.31%, Test Loss: 1.5893, Test Accuracy: 52.05%
Epoch [87/250], Loss: 0.4542, Accuracy: 87.69%, Test Loss: 1.5933, Test Accuracy: 52.43%
Epoch [88/250], Loss: 0.4495, Accuracy: 87.89%, Test Loss: 1.6008, Test Accuracy: 52.41%
Epoch [89/250], Loss: 0.4439, Accuracy: 88.10%, Test Loss: 1.6006, Test Accuracy: 52.63%
Epoch [90/250], Loss: 0.4378, Accuracy: 88.28%, Test Loss: 1.6151, Test Accuracy: 52.40%
Epoch [91/250], Loss: 0.4304, Accuracy: 88.58%, Test Loss: 1.6169, Test Accuracy: 52.44%
Epoch [92/250], Loss: 0.4247, Accuracy: 88.92%, Test Loss: 1.6262, Test Accuracy: 52.66%
Epoch [93/250], Loss: 0.4199, Accuracy: 89.00%, Test Loss: 1.6206, Test Accuracy: 52.58%
Epoch [94/250], Loss: 0.4124, Accuracy: 89.46%, Test Loss: 1.6325, Test Accuracy: 52.19%
Epoch [95/250], Loss: 0.4089, Accuracy: 89.43%, Test Loss: 1.6422, Test Accuracy: 52.22%
Epoch [96/250], Loss: 0.4047, Accuracy: 89.54%, Test Loss: 1.6578, Test Accuracy: 52.28%
Epoch [97/250], Loss: 0.3973, Accuracy: 89.76%, Test Loss: 1.6565, Test Accuracy: 52.09%
Epoch [98/250], Loss: 0.3908, Accuracy: 90.16%, Test Loss: 1.6646, Test Accuracy: 52.23%
Epoch [99/250], Loss: 0.3866, Accuracy: 90.24%, Test Loss: 1.6698, Test Accuracy: 52.35%
Epoch [100/250], Loss: 0.3790, Accuracy: 90.54%, Test Loss: 1.6779, Test Accuracy: 52.13%
Epoch [101/250], Loss: 0.3755, Accuracy: 90.75%, Test Loss: 1.6849, Test Accuracy: 52.19%
Epoch [102/250], Loss: 0.3709, Accuracy: 90.73%, Test Loss: 1.6979, Test Accuracy: 52.31%
Epoch [103/250], Loss: 0.3647, Accuracy: 91.05%, Test Loss: 1.7060, Test Accuracy: 52.46%
Epoch [104/250], Loss: 0.3635, Accuracy: 91.02%, Test Loss: 1.7104, Test Accuracy: 52.13%
Epoch [105/250], Loss: 0.3558, Accuracy: 91.23%, Test Loss: 1.7170, Test Accuracy: 51.94%
Epoch [106/250], Loss: 0.3521, Accuracy: 91.51%, Test Loss: 1.7224, Test Accuracy: 52.21%
Epoch [107/250], Loss: 0.3450, Accuracy: 91.74%, Test Loss: 1.7294, Test Accuracy: 51.99%
Epoch [108/250], Loss: 0.3384, Accuracy: 91.98%, Test Loss: 1.7335, Test Accuracy: 51.96%
Epoch [109/250], Loss: 0.3333, Accuracy: 92.17%, Test Loss: 1.7450, Test Accuracy: 51.91%
Epoch [110/250], Loss: 0.3294, Accuracy: 92.27%, Test Loss: 1.7536, Test Accuracy: 51.97%
Epoch [111/250], Loss: 0.3262, Accuracy: 92.48%, Test Loss: 1.7562, Test Accuracy: 52.30%
Epoch [112/250], Loss: 0.3195, Accuracy: 92.75%, Test Loss: 1.7693, Test Accuracy: 52.18%
Epoch [113/250], Loss: 0.3168, Accuracy: 92.81%, Test Loss: 1.7764, Test Accuracy: 52.15%
Epoch [114/250], Loss: 0.3119, Accuracy: 92.97%, Test Loss: 1.7822, Test Accuracy: 51.97%
Epoch [115/250], Loss: 0.3065, Accuracy: 93.21%, Test Loss: 1.8011, Test Accuracy: 51.85%
Epoch [116/250], Loss: 0.3024, Accuracy: 93.28%, Test Loss: 1.8024, Test Accuracy: 52.02%
Epoch [117/250], Loss: 0.3012, Accuracy: 93.21%, Test Loss: 1.8034, Test Accuracy: 51.90%
Epoch [118/250], Loss: 0.2937, Accuracy: 93.58%, Test Loss: 1.8105, Test Accuracy: 52.13%
Epoch [119/250], Loss: 0.2905, Accuracy: 93.56%, Test Loss: 1.8337, Test Accuracy: 51.80%
Epoch [120/250], Loss: 0.2868, Accuracy: 93.67%, Test Loss: 1.8310, Test Accuracy: 51.94%
Epoch [121/250], Loss: 0.2810, Accuracy: 94.05%, Test Loss: 1.8346, Test Accuracy: 51.80%
Epoch [122/250], Loss: 0.2781, Accuracy: 94.03%, Test Loss: 1.8553, Test Accuracy: 51.71%
Epoch [123/250], Loss: 0.2726, Accuracy: 94.30%, Test Loss: 1.8573, Test Accuracy: 51.51%
Epoch [124/250], Loss: 0.2688, Accuracy: 94.37%, Test Loss: 1.8654, Test Accuracy: 51.53%
Epoch [125/250], Loss: 0.2626, Accuracy: 94.69%, Test Loss: 1.8747, Test Accuracy: 51.55%
Epoch [126/250], Loss: 0.2613, Accuracy: 94.60%, Test Loss: 1.8833, Test Accuracy: 51.60%
Epoch [127/250], Loss: 0.2572, Accuracy: 94.71%, Test Loss: 1.8909, Test Accuracy: 51.46%
Epoch [128/250], Loss: 0.2509, Accuracy: 95.05%, Test Loss: 1.9084, Test Accuracy: 51.35%
Epoch [129/250], Loss: 0.2487, Accuracy: 95.01%, Test Loss: 1.9148, Test Accuracy: 51.59%
Epoch [130/250], Loss: 0.2461, Accuracy: 95.06%, Test Loss: 1.9113, Test Accuracy: 51.89%
Epoch [131/250], Loss: 0.2415, Accuracy: 95.33%, Test Loss: 1.9205, Test Accuracy: 51.63%
Epoch [132/250], Loss: 0.2369, Accuracy: 95.35%, Test Loss: 1.9244, Test Accuracy: 51.32%
Epoch [133/250], Loss: 0.2360, Accuracy: 95.44%, Test Loss: 1.9484, Test Accuracy: 51.47%
Epoch [134/250], Loss: 0.2303, Accuracy: 95.61%, Test Loss: 1.9518, Test Accuracy: 51.55%
Epoch [135/250], Loss: 0.2269, Accuracy: 95.68%, Test Loss: 1.9611, Test Accuracy: 51.25%
Epoch [136/250], Loss: 0.2225, Accuracy: 95.90%, Test Loss: 1.9637, Test Accuracy: 51.24%
Epoch [137/250], Loss: 0.2195, Accuracy: 95.97%, Test Loss: 1.9775, Test Accuracy: 51.17%
Epoch [138/250], Loss: 0.2151, Accuracy: 96.18%, Test Loss: 1.9850, Test Accuracy: 51.07%
Epoch [139/250], Loss: 0.2151, Accuracy: 95.98%, Test Loss: 2.0011, Test Accuracy: 51.34%
Epoch [140/250], Loss: 0.2081, Accuracy: 96.32%, Test Loss: 2.0006, Test Accuracy: 51.15%
Epoch [141/250], Loss: 0.2050, Accuracy: 96.41%, Test Loss: 2.0105, Test Accuracy: 51.28%
Epoch [142/250], Loss: 0.2020, Accuracy: 96.44%, Test Loss: 2.0170, Test Accuracy: 51.26%
Epoch [143/250], Loss: 0.2003, Accuracy: 96.48%, Test Loss: 2.0333, Test Accuracy: 50.95%
Epoch [144/250], Loss: 0.1946, Accuracy: 96.76%, Test Loss: 2.0372, Test Accuracy: 51.29%
Epoch [145/250], Loss: 0.1944, Accuracy: 96.59%, Test Loss: 2.0504, Test Accuracy: 51.34%
Epoch [146/250], Loss: 0.1923, Accuracy: 96.71%, Test Loss: 2.0615, Test Accuracy: 50.95%
Epoch [147/250], Loss: 0.1867, Accuracy: 96.92%, Test Loss: 2.0716, Test Accuracy: 51.22%
Epoch [148/250], Loss: 0.1835, Accuracy: 96.96%, Test Loss: 2.0813, Test Accuracy: 51.43%
Epoch [149/250], Loss: 0.1799, Accuracy: 97.15%, Test Loss: 2.0765, Test Accuracy: 51.22%
Epoch [150/250], Loss: 0.1779, Accuracy: 97.16%, Test Loss: 2.0937, Test Accuracy: 51.21%
Epoch [151/250], Loss: 0.1752, Accuracy: 97.29%, Test Loss: 2.1074, Test Accuracy: 51.26%
Epoch [152/250], Loss: 0.1702, Accuracy: 97.42%, Test Loss: 2.1200, Test Accuracy: 50.85%
Epoch [153/250], Loss: 0.1701, Accuracy: 97.39%, Test Loss: 2.1278, Test Accuracy: 51.17%
Epoch [154/250], Loss: 0.1652, Accuracy: 97.53%, Test Loss: 2.1383, Test Accuracy: 51.23%
Epoch [155/250], Loss: 0.1633, Accuracy: 97.55%, Test Loss: 2.1393, Test Accuracy: 50.97%
Epoch [156/250], Loss: 0.1593, Accuracy: 97.66%, Test Loss: 2.1599, Test Accuracy: 51.18%
Epoch [157/250], Loss: 0.1597, Accuracy: 97.64%, Test Loss: 2.1568, Test Accuracy: 51.07%
Epoch [158/250], Loss: 0.1542, Accuracy: 97.83%, Test Loss: 2.1847, Test Accuracy: 50.94%
Epoch [159/250], Loss: 0.1527, Accuracy: 97.87%, Test Loss: 2.1820, Test Accuracy: 51.24%
Epoch [160/250], Loss: 0.1494, Accuracy: 97.91%, Test Loss: 2.1966, Test Accuracy: 50.98%
Epoch [161/250], Loss: 0.1464, Accuracy: 98.02%, Test Loss: 2.1983, Test Accuracy: 51.14%
Epoch [162/250], Loss: 0.1458, Accuracy: 98.04%, Test Loss: 2.2314, Test Accuracy: 50.57%
Epoch [163/250], Loss: 0.1425, Accuracy: 98.12%, Test Loss: 2.2255, Test Accuracy: 50.99%
Epoch [164/250], Loss: 0.1403, Accuracy: 98.14%, Test Loss: 2.2305, Test Accuracy: 50.81%
Epoch [165/250], Loss: 0.1370, Accuracy: 98.23%, Test Loss: 2.2377, Test Accuracy: 51.06%
Epoch [166/250], Loss: 0.1359, Accuracy: 98.25%, Test Loss: 2.2457, Test Accuracy: 50.93%
Epoch [167/250], Loss: 0.1321, Accuracy: 98.41%, Test Loss: 2.2594, Test Accuracy: 51.03%
Epoch [168/250], Loss: 0.1292, Accuracy: 98.45%, Test Loss: 2.2763, Test Accuracy: 50.71%
Epoch [169/250], Loss: 0.1283, Accuracy: 98.43%, Test Loss: 2.2754, Test Accuracy: 50.96%
Epoch [170/250], Loss: 0.1244, Accuracy: 98.59%, Test Loss: 2.2943, Test Accuracy: 50.80%
Epoch [171/250], Loss: 0.1247, Accuracy: 98.50%, Test Loss: 2.3026, Test Accuracy: 50.68%
Epoch [172/250], Loss: 0.1195, Accuracy: 98.69%, Test Loss: 2.3068, Test Accuracy: 50.71%
Epoch [173/250], Loss: 0.1182, Accuracy: 98.69%, Test Loss: 2.3296, Test Accuracy: 50.87%
Epoch [174/250], Loss: 0.1162, Accuracy: 98.74%, Test Loss: 2.3525, Test Accuracy: 50.81%
Epoch [175/250], Loss: 0.1139, Accuracy: 98.78%, Test Loss: 2.3493, Test Accuracy: 51.05%
Epoch [176/250], Loss: 0.1119, Accuracy: 98.86%, Test Loss: 2.3571, Test Accuracy: 50.96%
Epoch [177/250], Loss: 0.1104, Accuracy: 98.86%, Test Loss: 2.3641, Test Accuracy: 50.97%
Epoch [178/250], Loss: 0.1087, Accuracy: 98.85%, Test Loss: 2.3757, Test Accuracy: 50.97%
Epoch [179/250], Loss: 0.1079, Accuracy: 98.89%, Test Loss: 2.3920, Test Accuracy: 50.80%
Epoch [180/250], Loss: 0.1042, Accuracy: 99.02%, Test Loss: 2.3943, Test Accuracy: 50.76%
Epoch [181/250], Loss: 0.1032, Accuracy: 98.93%, Test Loss: 2.4006, Test Accuracy: 50.60%
Epoch [182/250], Loss: 0.1022, Accuracy: 99.00%, Test Loss: 2.4171, Test Accuracy: 50.54%
Epoch [183/250], Loss: 0.0965, Accuracy: 99.21%, Test Loss: 2.4380, Test Accuracy: 50.60%
Epoch [184/250], Loss: 0.0968, Accuracy: 99.14%, Test Loss: 2.4400, Test Accuracy: 50.73%
Epoch [185/250], Loss: 0.0937, Accuracy: 99.19%, Test Loss: 2.4558, Test Accuracy: 50.55%
Epoch [186/250], Loss: 0.0944, Accuracy: 99.15%, Test Loss: 2.4646, Test Accuracy: 50.65%
Epoch [187/250], Loss: 0.0924, Accuracy: 99.19%, Test Loss: 2.4750, Test Accuracy: 50.90%
Epoch [188/250], Loss: 0.0879, Accuracy: 99.34%, Test Loss: 2.4826, Test Accuracy: 50.66%
Epoch [189/250], Loss: 0.0881, Accuracy: 99.24%, Test Loss: 2.5007, Test Accuracy: 50.45%
Epoch [190/250], Loss: 0.0868, Accuracy: 99.29%, Test Loss: 2.5076, Test Accuracy: 50.21%
Epoch [191/250], Loss: 0.0875, Accuracy: 99.28%, Test Loss: 2.5276, Test Accuracy: 50.26%
Epoch [192/250], Loss: 0.0838, Accuracy: 99.38%, Test Loss: 2.5265, Test Accuracy: 50.55%
Epoch [193/250], Loss: 0.0843, Accuracy: 99.26%, Test Loss: 2.5394, Test Accuracy: 50.60%
Epoch [194/250], Loss: 0.0815, Accuracy: 99.34%, Test Loss: 2.5547, Test Accuracy: 50.43%
Epoch [195/250], Loss: 0.0791, Accuracy: 99.45%, Test Loss: 2.5622, Test Accuracy: 50.44%
Epoch [196/250], Loss: 0.0776, Accuracy: 99.46%, Test Loss: 2.5794, Test Accuracy: 50.51%
Epoch [197/250], Loss: 0.0763, Accuracy: 99.44%, Test Loss: 2.5826, Test Accuracy: 50.44%
Epoch [198/250], Loss: 0.0740, Accuracy: 99.48%, Test Loss: 2.5922, Test Accuracy: 50.71%
Epoch [199/250], Loss: 0.0718, Accuracy: 99.55%, Test Loss: 2.6090, Test Accuracy: 50.88%
Epoch [200/250], Loss: 0.0712, Accuracy: 99.59%, Test Loss: 2.6219, Test Accuracy: 50.42%
Epoch [201/250], Loss: 0.0704, Accuracy: 99.53%, Test Loss: 2.6283, Test Accuracy: 50.69%
Epoch [202/250], Loss: 0.0675, Accuracy: 99.61%, Test Loss: 2.6387, Test Accuracy: 50.28%
Epoch [203/250], Loss: 0.0678, Accuracy: 99.59%, Test Loss: 2.6469, Test Accuracy: 50.43%
Epoch [204/250], Loss: 0.0659, Accuracy: 99.61%, Test Loss: 2.6603, Test Accuracy: 50.84%
Epoch [205/250], Loss: 0.0652, Accuracy: 99.61%, Test Loss: 2.6863, Test Accuracy: 50.31%
Epoch [206/250], Loss: 0.0629, Accuracy: 99.68%, Test Loss: 2.6856, Test Accuracy: 49.94%
Epoch [207/250], Loss: 0.0615, Accuracy: 99.70%, Test Loss: 2.6935, Test Accuracy: 50.43%
Epoch [208/250], Loss: 0.0609, Accuracy: 99.67%, Test Loss: 2.7190, Test Accuracy: 50.50%
Epoch [209/250], Loss: 0.0608, Accuracy: 99.66%, Test Loss: 2.7367, Test Accuracy: 50.17%
Epoch [210/250], Loss: 0.0596, Accuracy: 99.68%, Test Loss: 2.7311, Test Accuracy: 50.73%
Epoch [211/250], Loss: 0.0578, Accuracy: 99.74%, Test Loss: 2.7441, Test Accuracy: 50.29%
Epoch [212/250], Loss: 0.0560, Accuracy: 99.75%, Test Loss: 2.7495, Test Accuracy: 50.34%
Epoch [213/250], Loss: 0.0547, Accuracy: 99.75%, Test Loss: 2.7696, Test Accuracy: 50.27%
Epoch [214/250], Loss: 0.0543, Accuracy: 99.77%, Test Loss: 2.7791, Test Accuracy: 50.51%
Epoch [215/250], Loss: 0.0532, Accuracy: 99.76%, Test Loss: 2.7860, Test Accuracy: 50.45%
Epoch [216/250], Loss: 0.0511, Accuracy: 99.78%, Test Loss: 2.8107, Test Accuracy: 50.18%
Epoch [217/250], Loss: 0.0511, Accuracy: 99.79%, Test Loss: 2.8106, Test Accuracy: 50.37%
Epoch [218/250], Loss: 0.0493, Accuracy: 99.80%, Test Loss: 2.8200, Test Accuracy: 50.27%
Epoch [219/250], Loss: 0.0478, Accuracy: 99.85%, Test Loss: 2.8347, Test Accuracy: 50.40%
Epoch [220/250], Loss: 0.0480, Accuracy: 99.79%, Test Loss: 2.8526, Test Accuracy: 50.53%
Epoch [221/250], Loss: 0.0468, Accuracy: 99.83%, Test Loss: 2.8584, Test Accuracy: 50.35%
Epoch [222/250], Loss: 0.0467, Accuracy: 99.82%, Test Loss: 2.8650, Test Accuracy: 50.58%
Epoch [223/250], Loss: 0.0456, Accuracy: 99.83%, Test Loss: 2.8814, Test Accuracy: 50.34%
Epoch [224/250], Loss: 0.0440, Accuracy: 99.86%, Test Loss: 2.8883, Test Accuracy: 50.69%
Epoch [225/250], Loss: 0.0445, Accuracy: 99.83%, Test Loss: 2.8958, Test Accuracy: 50.30%
Epoch [226/250], Loss: 0.0420, Accuracy: 99.86%, Test Loss: 2.9166, Test Accuracy: 50.41%
Epoch [227/250], Loss: 0.0415, Accuracy: 99.88%, Test Loss: 2.9259, Test Accuracy: 50.16%
Epoch [228/250], Loss: 0.0399, Accuracy: 99.87%, Test Loss: 2.9387, Test Accuracy: 50.33%
Epoch [229/250], Loss: 0.0391, Accuracy: 99.89%, Test Loss: 2.9500, Test Accuracy: 50.52%
Epoch [230/250], Loss: 0.0402, Accuracy: 99.85%, Test Loss: 2.9676, Test Accuracy: 50.53%
Epoch [231/250], Loss: 0.0387, Accuracy: 99.88%, Test Loss: 2.9769, Test Accuracy: 50.38%
Epoch [232/250], Loss: 0.0372, Accuracy: 99.89%, Test Loss: 2.9765, Test Accuracy: 50.63%
Epoch [233/250], Loss: 0.0364, Accuracy: 99.91%, Test Loss: 2.9913, Test Accuracy: 50.36%
Epoch [234/250], Loss: 0.0349, Accuracy: 99.92%, Test Loss: 3.0146, Test Accuracy: 50.26%
Epoch [235/250], Loss: 0.0357, Accuracy: 99.90%, Test Loss: 3.0179, Test Accuracy: 50.34%
Epoch [236/250], Loss: 0.0341, Accuracy: 99.92%, Test Loss: 3.0323, Test Accuracy: 50.22%
Epoch [237/250], Loss: 0.0331, Accuracy: 99.92%, Test Loss: 3.0472, Test Accuracy: 50.46%
Epoch [238/250], Loss: 0.0321, Accuracy: 99.95%, Test Loss: 3.0588, Test Accuracy: 50.29%
Epoch [239/250], Loss: 0.0340, Accuracy: 99.90%, Test Loss: 3.0873, Test Accuracy: 50.10%
Epoch [240/250], Loss: 0.0315, Accuracy: 99.94%, Test Loss: 3.0860, Test Accuracy: 50.30%
Epoch [241/250], Loss: 0.0306, Accuracy: 99.94%, Test Loss: 3.0979, Test Accuracy: 50.40%
Epoch [242/250], Loss: 0.0305, Accuracy: 99.92%, Test Loss: 3.1161, Test Accuracy: 50.35%
Epoch [243/250], Loss: 0.0314, Accuracy: 99.91%, Test Loss: 3.1172, Test Accuracy: 50.38%
Epoch [244/250], Loss: 0.0302, Accuracy: 99.93%, Test Loss: 3.1366, Test Accuracy: 50.41%
Epoch [245/250], Loss: 0.0282, Accuracy: 99.97%, Test Loss: 3.1467, Test Accuracy: 50.55%
Epoch [246/250], Loss: 0.0281, Accuracy: 99.93%, Test Loss: 3.1554, Test Accuracy: 50.03%
Epoch [247/250], Loss: 0.0290, Accuracy: 99.93%, Test Loss: 3.1704, Test Accuracy: 50.17%
Epoch [248/250], Loss: 0.0283, Accuracy: 99.93%, Test Loss: 3.1817, Test Accuracy: 50.15%
Epoch [249/250], Loss: 0.0275, Accuracy: 99.94%, Test Loss: 3.1916, Test Accuracy: 50.36%
Epoch [250/250], Loss: 0.0266, Accuracy: 99.95%, Test Loss: 3.2076, Test Accuracy: 50.21%
import matplotlib.pyplot as plt

# Loss curves: per the logs above, training loss keeps dropping while test loss
# rises after the first ~30 epochs — a clear overfitting signature.
plt.plot(train_losses, label='train_losses')
plt.plot(test_losses, label='test_losses')
plt.legend()
<matplotlib.legend.Legend at 0x7fab8407f310>
image
import matplotlib.pyplot as plt

# Accuracy curves for the same run: train accuracy approaches 100%
# while test accuracy plateaus around 50% (see the epoch logs above).
plt.plot(train_accuracies, label='train_accuracy')
plt.plot(test_accuracies, label='test_accuracy')
plt.legend()
<matplotlib.legend.Legend at 0x7fab726668f0>
image
MLP_Cifar10_ReLU_He_Adam_3H
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.init as init

# Check if GPU is available; model and tensors are moved to `device` below.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Data

# Load the CIFAR-10 dataset (32x32 RGB images, 10 classes).
# Normalize each of the 3 channels from [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,0.5, 0.5), (0.5,0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform)
# drop_last=True discards the final partial batch so every training batch has exactly 1024 samples.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1024, num_workers=10, shuffle=True, drop_last=True)

testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform)
# No shuffling for evaluation; ordering does not affect the metrics.
testloader = torch.utils.data.DataLoader(testset, batch_size=1024, num_workers=10, shuffle=False)
Files already downloaded and verified
Files already downloaded and verified
import matplotlib.pyplot as plt
import numpy as np

# Function to display the images
def imshow(img):
    """Un-normalize a (C, H, W) image tensor and display it."""
    # Invert Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)): map [-1, 1] back to [0, 1].
    img = img*0.5 + 0.5
    np_img = img.numpy()
    # Matplotlib expects channels-last (H, W, C).
    plt.imshow(np.transpose(np_img, (1, 2, 0)))
    plt.show()

# Show one grid of training samples, then stop iterating.
for i, (images, labels) in enumerate(trainloader, 0):
    # Plot some images
    imshow(torchvision.utils.make_grid(images[:8]))  # Display 8 images from the batch
    break
image

Model

# Three-hidden-layer MLP for CIFAR-10: 3072 -> 256 -> 256 -> 256 -> 10 logits.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(32 * 32 * 3, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

# He (Kaiming) uniform initialization for the ReLU layers — note: the earlier
# comment said Xavier/Glorot, but kaiming_uniform_ is He initialization.
# Biases start at zero.
for module in model.children():
    if not isinstance(module, nn.Linear):
        continue
    init.kaiming_uniform_(module.weight, nonlinearity='relu')
    if module.bias is not None:
        init.zeros_(module.bias)
            
# Move the model's parameters to the selected device (GPU if available) and show the layout.
model = model.to(device)
print(model)
Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=3072, out_features=256, bias=True)
  (2): ReLU()
  (3): Linear(in_features=256, out_features=256, bias=True)
  (4): ReLU()
  (5): Linear(in_features=256, out_features=256, bias=True)
  (6): ReLU()
  (7): Linear(in_features=256, out_features=10, bias=True)
)

Loss, Optimizer, and Evaluation Function

# Cross-entropy loss expects raw logits (the model has no softmax layer).
criterion = nn.CrossEntropyLoss()
# Adam with a small fixed learning rate; no weight decay or LR scheduler.
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Function to compute loss and accuracy for test set
def evaluate(model, testloader, criterion, device=None):
    """Compute the average loss and accuracy of `model` over `testloader`.

    Args:
        model: network to evaluate; switched to eval mode for the pass.
        testloader: iterable of (images, labels) batches.
        criterion: loss function, e.g. nn.CrossEntropyLoss.
        device: device to run on. Defaults to the device of the model's
            parameters, which matches the old behavior of relying on the
            module-level `device` global (the model was moved there).

    Returns:
        (test_loss, accuracy): mean per-batch loss and accuracy in percent.
    """
    if device is None:
        # Infer the device from the model so the function no longer depends
        # on a module-level global.
        device = next(model.parameters()).device
    model.eval()
    test_loss = 0.0
    running_correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            # Move inputs and labels to the device
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            # argmax over the class dimension; avoids the deprecated
            # `.data` attribute access of the original.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            running_correct += (predicted == labels).sum().item()

    accuracy = 100 * running_correct / total
    test_loss = test_loss / len(testloader)
    return test_loss, accuracy

Train

# Per-epoch history, used by the loss/accuracy plots below.
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []
max_epoch = 250
# Training loop: one full pass over the training set per epoch, followed by
# an evaluation pass on the test set.
for epoch in range(max_epoch):
    model.train()
    running_loss = 0.0
    running_correct = 0   # to track number of correct predictions
    total = 0             # to track total number of samples

    for i, (inputs, labels) in enumerate(trainloader, 0):
        # Move inputs and labels to the device
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        running_loss += loss.item()

        # Determine class predictions and track accuracy
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        running_correct += (predicted == labels).sum().item()

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

    # Average metrics over the epoch; (i + 1) is the number of batches
    # (drop_last=True, so every batch is full-size).
    epoch_accuracy = 100 * running_correct / total
    epoch_loss = running_loss / (i + 1)

    test_loss, test_accuracy = evaluate(model, testloader, criterion)
    print(f"Epoch [{epoch + 1}/{max_epoch}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

    # save for plot
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_accuracy)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)
Epoch [1/250], Loss: 1.9471, Accuracy: 30.83%, Test Loss: 1.7424, Test Accuracy: 39.44%
Epoch [2/250], Loss: 1.6656, Accuracy: 41.99%, Test Loss: 1.6216, Test Accuracy: 43.67%
Epoch [3/250], Loss: 1.5597, Accuracy: 45.70%, Test Loss: 1.5577, Test Accuracy: 46.03%
Epoch [4/250], Loss: 1.4925, Accuracy: 48.36%, Test Loss: 1.5192, Test Accuracy: 47.23%
Epoch [5/250], Loss: 1.4407, Accuracy: 50.09%, Test Loss: 1.4941, Test Accuracy: 47.89%
Epoch [6/250], Loss: 1.3949, Accuracy: 51.99%, Test Loss: 1.4663, Test Accuracy: 48.64%
Epoch [7/250], Loss: 1.3509, Accuracy: 53.52%, Test Loss: 1.4555, Test Accuracy: 49.17%
Epoch [8/250], Loss: 1.3162, Accuracy: 54.87%, Test Loss: 1.4326, Test Accuracy: 49.61%
Epoch [9/250], Loss: 1.2829, Accuracy: 55.97%, Test Loss: 1.4256, Test Accuracy: 50.11%
Epoch [10/250], Loss: 1.2496, Accuracy: 57.29%, Test Loss: 1.4056, Test Accuracy: 50.85%
Epoch [11/250], Loss: 1.2210, Accuracy: 58.31%, Test Loss: 1.4033, Test Accuracy: 50.75%
Epoch [12/250], Loss: 1.1943, Accuracy: 59.33%, Test Loss: 1.3971, Test Accuracy: 50.53%
Epoch [13/250], Loss: 1.1668, Accuracy: 60.30%, Test Loss: 1.3925, Test Accuracy: 50.81%
Epoch [14/250], Loss: 1.1433, Accuracy: 61.01%, Test Loss: 1.3830, Test Accuracy: 51.42%
Epoch [15/250], Loss: 1.1175, Accuracy: 62.21%, Test Loss: 1.3829, Test Accuracy: 51.58%
Epoch [16/250], Loss: 1.0943, Accuracy: 63.16%, Test Loss: 1.3767, Test Accuracy: 51.78%
Epoch [17/250], Loss: 1.0748, Accuracy: 63.72%, Test Loss: 1.3729, Test Accuracy: 51.97%
Epoch [18/250], Loss: 1.0533, Accuracy: 64.58%, Test Loss: 1.3726, Test Accuracy: 51.99%
Epoch [19/250], Loss: 1.0324, Accuracy: 65.26%, Test Loss: 1.3748, Test Accuracy: 51.81%
Epoch [20/250], Loss: 1.0139, Accuracy: 66.06%, Test Loss: 1.3723, Test Accuracy: 52.17%
Epoch [21/250], Loss: 0.9913, Accuracy: 66.79%, Test Loss: 1.3782, Test Accuracy: 52.06%
Epoch [22/250], Loss: 0.9755, Accuracy: 67.41%, Test Loss: 1.3746, Test Accuracy: 52.28%
Epoch [23/250], Loss: 0.9544, Accuracy: 68.18%, Test Loss: 1.3749, Test Accuracy: 52.17%
Epoch [24/250], Loss: 0.9335, Accuracy: 69.03%, Test Loss: 1.3850, Test Accuracy: 51.64%
Epoch [25/250], Loss: 0.9171, Accuracy: 69.57%, Test Loss: 1.3849, Test Accuracy: 52.06%
Epoch [26/250], Loss: 0.8987, Accuracy: 70.19%, Test Loss: 1.3926, Test Accuracy: 52.16%
Epoch [27/250], Loss: 0.8820, Accuracy: 70.95%, Test Loss: 1.3862, Test Accuracy: 52.25%
Epoch [28/250], Loss: 0.8669, Accuracy: 71.43%, Test Loss: 1.3982, Test Accuracy: 52.42%
Epoch [29/250], Loss: 0.8490, Accuracy: 72.10%, Test Loss: 1.3983, Test Accuracy: 52.55%
Epoch [30/250], Loss: 0.8319, Accuracy: 72.77%, Test Loss: 1.3984, Test Accuracy: 52.21%
Epoch [31/250], Loss: 0.8153, Accuracy: 73.42%, Test Loss: 1.4092, Test Accuracy: 52.45%
Epoch [32/250], Loss: 0.8022, Accuracy: 73.86%, Test Loss: 1.4105, Test Accuracy: 52.19%
Epoch [33/250], Loss: 0.7855, Accuracy: 74.51%, Test Loss: 1.4180, Test Accuracy: 51.93%
Epoch [34/250], Loss: 0.7726, Accuracy: 75.15%, Test Loss: 1.4306, Test Accuracy: 52.02%
Epoch [35/250], Loss: 0.7570, Accuracy: 75.58%, Test Loss: 1.4301, Test Accuracy: 52.09%
Epoch [36/250], Loss: 0.7422, Accuracy: 76.24%, Test Loss: 1.4372, Test Accuracy: 51.67%
Epoch [37/250], Loss: 0.7265, Accuracy: 76.86%, Test Loss: 1.4380, Test Accuracy: 52.24%
Epoch [38/250], Loss: 0.7145, Accuracy: 77.34%, Test Loss: 1.4429, Test Accuracy: 52.13%
Epoch [39/250], Loss: 0.7023, Accuracy: 77.64%, Test Loss: 1.4622, Test Accuracy: 52.12%
Epoch [40/250], Loss: 0.6891, Accuracy: 78.09%, Test Loss: 1.4670, Test Accuracy: 51.95%
Epoch [41/250], Loss: 0.6727, Accuracy: 78.72%, Test Loss: 1.4628, Test Accuracy: 52.22%
Epoch [42/250], Loss: 0.6571, Accuracy: 79.45%, Test Loss: 1.4757, Test Accuracy: 52.23%
Epoch [43/250], Loss: 0.6472, Accuracy: 79.78%, Test Loss: 1.4882, Test Accuracy: 51.79%
Epoch [44/250], Loss: 0.6348, Accuracy: 80.17%, Test Loss: 1.4950, Test Accuracy: 52.01%
Epoch [45/250], Loss: 0.6228, Accuracy: 80.74%, Test Loss: 1.5035, Test Accuracy: 52.10%
Epoch [46/250], Loss: 0.6085, Accuracy: 81.23%, Test Loss: 1.5116, Test Accuracy: 51.72%
Epoch [47/250], Loss: 0.5962, Accuracy: 81.61%, Test Loss: 1.5197, Test Accuracy: 52.45%
Epoch [48/250], Loss: 0.5839, Accuracy: 82.12%, Test Loss: 1.5346, Test Accuracy: 51.63%
Epoch [49/250], Loss: 0.5735, Accuracy: 82.54%, Test Loss: 1.5449, Test Accuracy: 51.77%
Epoch [50/250], Loss: 0.5642, Accuracy: 82.71%, Test Loss: 1.5552, Test Accuracy: 51.83%
Epoch [51/250], Loss: 0.5503, Accuracy: 83.21%, Test Loss: 1.5646, Test Accuracy: 51.84%
Epoch [52/250], Loss: 0.5380, Accuracy: 83.86%, Test Loss: 1.5700, Test Accuracy: 52.10%
Epoch [53/250], Loss: 0.5291, Accuracy: 84.01%, Test Loss: 1.5820, Test Accuracy: 51.95%
Epoch [54/250], Loss: 0.5171, Accuracy: 84.50%, Test Loss: 1.5897, Test Accuracy: 51.82%
Epoch [55/250], Loss: 0.5032, Accuracy: 85.11%, Test Loss: 1.6087, Test Accuracy: 51.87%
Epoch [56/250], Loss: 0.4975, Accuracy: 85.29%, Test Loss: 1.6128, Test Accuracy: 51.84%
Epoch [57/250], Loss: 0.4851, Accuracy: 85.73%, Test Loss: 1.6263, Test Accuracy: 51.68%
Epoch [58/250], Loss: 0.4727, Accuracy: 86.41%, Test Loss: 1.6375, Test Accuracy: 51.57%
Epoch [59/250], Loss: 0.4635, Accuracy: 86.62%, Test Loss: 1.6502, Test Accuracy: 51.92%
Epoch [60/250], Loss: 0.4531, Accuracy: 87.04%, Test Loss: 1.6607, Test Accuracy: 52.00%
Epoch [61/250], Loss: 0.4436, Accuracy: 87.39%, Test Loss: 1.6720, Test Accuracy: 51.65%
Epoch [62/250], Loss: 0.4333, Accuracy: 87.77%, Test Loss: 1.6765, Test Accuracy: 51.65%
Epoch [63/250], Loss: 0.4228, Accuracy: 88.21%, Test Loss: 1.6945, Test Accuracy: 51.56%
Epoch [64/250], Loss: 0.4167, Accuracy: 88.35%, Test Loss: 1.7133, Test Accuracy: 51.85%
Epoch [65/250], Loss: 0.4041, Accuracy: 88.74%, Test Loss: 1.7284, Test Accuracy: 51.80%
Epoch [66/250], Loss: 0.3985, Accuracy: 89.05%, Test Loss: 1.7425, Test Accuracy: 51.54%
Epoch [67/250], Loss: 0.3870, Accuracy: 89.40%, Test Loss: 1.7518, Test Accuracy: 51.25%
Epoch [68/250], Loss: 0.3781, Accuracy: 89.80%, Test Loss: 1.7570, Test Accuracy: 52.03%
Epoch [69/250], Loss: 0.3710, Accuracy: 90.02%, Test Loss: 1.7736, Test Accuracy: 51.88%
Epoch [70/250], Loss: 0.3607, Accuracy: 90.47%, Test Loss: 1.7965, Test Accuracy: 51.15%
Epoch [71/250], Loss: 0.3522, Accuracy: 90.86%, Test Loss: 1.8041, Test Accuracy: 51.20%
Epoch [72/250], Loss: 0.3485, Accuracy: 90.89%, Test Loss: 1.8258, Test Accuracy: 51.54%
Epoch [73/250], Loss: 0.3379, Accuracy: 91.17%, Test Loss: 1.8457, Test Accuracy: 51.04%
Epoch [74/250], Loss: 0.3316, Accuracy: 91.44%, Test Loss: 1.8473, Test Accuracy: 51.22%
Epoch [75/250], Loss: 0.3196, Accuracy: 91.97%, Test Loss: 1.8619, Test Accuracy: 50.95%
Epoch [76/250], Loss: 0.3133, Accuracy: 92.25%, Test Loss: 1.8755, Test Accuracy: 51.36%
Epoch [77/250], Loss: 0.3080, Accuracy: 92.44%, Test Loss: 1.8921, Test Accuracy: 51.16%
Epoch [78/250], Loss: 0.2992, Accuracy: 92.66%, Test Loss: 1.9151, Test Accuracy: 51.31%
Epoch [79/250], Loss: 0.2922, Accuracy: 92.97%, Test Loss: 1.9279, Test Accuracy: 51.15%
Epoch [80/250], Loss: 0.2865, Accuracy: 93.02%, Test Loss: 1.9415, Test Accuracy: 51.12%
Epoch [81/250], Loss: 0.2772, Accuracy: 93.45%, Test Loss: 1.9542, Test Accuracy: 51.14%
Epoch [82/250], Loss: 0.2693, Accuracy: 93.77%, Test Loss: 1.9685, Test Accuracy: 50.85%
Epoch [83/250], Loss: 0.2642, Accuracy: 93.97%, Test Loss: 1.9901, Test Accuracy: 50.97%
Epoch [84/250], Loss: 0.2542, Accuracy: 94.39%, Test Loss: 1.9988, Test Accuracy: 50.80%
Epoch [85/250], Loss: 0.2493, Accuracy: 94.59%, Test Loss: 2.0344, Test Accuracy: 51.26%
Epoch [86/250], Loss: 0.2437, Accuracy: 94.64%, Test Loss: 2.0378, Test Accuracy: 51.15%
Epoch [87/250], Loss: 0.2342, Accuracy: 95.11%, Test Loss: 2.0628, Test Accuracy: 50.99%
Epoch [88/250], Loss: 0.2282, Accuracy: 95.15%, Test Loss: 2.0722, Test Accuracy: 50.90%
Epoch [89/250], Loss: 0.2252, Accuracy: 95.34%, Test Loss: 2.0946, Test Accuracy: 50.88%
Epoch [90/250], Loss: 0.2215, Accuracy: 95.36%, Test Loss: 2.1131, Test Accuracy: 50.74%
Epoch [91/250], Loss: 0.2159, Accuracy: 95.59%, Test Loss: 2.1144, Test Accuracy: 50.85%
Epoch [92/250], Loss: 0.2056, Accuracy: 95.91%, Test Loss: 2.1331, Test Accuracy: 51.17%
Epoch [93/250], Loss: 0.2013, Accuracy: 96.06%, Test Loss: 2.1561, Test Accuracy: 50.57%
Epoch [94/250], Loss: 0.1958, Accuracy: 96.32%, Test Loss: 2.1751, Test Accuracy: 50.95%
Epoch [95/250], Loss: 0.1917, Accuracy: 96.33%, Test Loss: 2.1906, Test Accuracy: 50.83%
Epoch [96/250], Loss: 0.1830, Accuracy: 96.68%, Test Loss: 2.2111, Test Accuracy: 50.93%
Epoch [97/250], Loss: 0.1796, Accuracy: 96.76%, Test Loss: 2.2380, Test Accuracy: 51.09%
Epoch [98/250], Loss: 0.1729, Accuracy: 96.97%, Test Loss: 2.2413, Test Accuracy: 50.48%
Epoch [99/250], Loss: 0.1689, Accuracy: 97.13%, Test Loss: 2.2622, Test Accuracy: 50.87%
Epoch [100/250], Loss: 0.1664, Accuracy: 97.17%, Test Loss: 2.3006, Test Accuracy: 50.82%
Epoch [101/250], Loss: 0.1632, Accuracy: 97.16%, Test Loss: 2.3046, Test Accuracy: 50.68%
Epoch [102/250], Loss: 0.1561, Accuracy: 97.44%, Test Loss: 2.3283, Test Accuracy: 50.90%
Epoch [103/250], Loss: 0.1564, Accuracy: 97.33%, Test Loss: 2.3389, Test Accuracy: 50.66%
Epoch [104/250], Loss: 0.1462, Accuracy: 97.77%, Test Loss: 2.3557, Test Accuracy: 50.67%
Epoch [105/250], Loss: 0.1417, Accuracy: 97.91%, Test Loss: 2.3817, Test Accuracy: 50.54%
Epoch [106/250], Loss: 0.1396, Accuracy: 97.84%, Test Loss: 2.3938, Test Accuracy: 50.32%
Epoch [107/250], Loss: 0.1347, Accuracy: 98.03%, Test Loss: 2.4160, Test Accuracy: 50.72%
Epoch [108/250], Loss: 0.1313, Accuracy: 98.12%, Test Loss: 2.4368, Test Accuracy: 50.42%
Epoch [109/250], Loss: 0.1272, Accuracy: 98.24%, Test Loss: 2.4520, Test Accuracy: 50.48%
Epoch [110/250], Loss: 0.1225, Accuracy: 98.38%, Test Loss: 2.4752, Test Accuracy: 50.95%
Epoch [111/250], Loss: 0.1207, Accuracy: 98.40%, Test Loss: 2.4993, Test Accuracy: 50.51%
Epoch [112/250], Loss: 0.1144, Accuracy: 98.52%, Test Loss: 2.5152, Test Accuracy: 50.45%
Epoch [113/250], Loss: 0.1101, Accuracy: 98.70%, Test Loss: 2.5357, Test Accuracy: 50.30%
Epoch [114/250], Loss: 0.1092, Accuracy: 98.76%, Test Loss: 2.5662, Test Accuracy: 50.44%
Epoch [115/250], Loss: 0.1051, Accuracy: 98.80%, Test Loss: 2.5922, Test Accuracy: 50.06%
Epoch [116/250], Loss: 0.1028, Accuracy: 98.80%, Test Loss: 2.6013, Test Accuracy: 50.16%
Epoch [117/250], Loss: 0.0983, Accuracy: 98.95%, Test Loss: 2.6396, Test Accuracy: 50.15%
Epoch [118/250], Loss: 0.0946, Accuracy: 99.02%, Test Loss: 2.6440, Test Accuracy: 50.04%
Epoch [119/250], Loss: 0.0934, Accuracy: 99.04%, Test Loss: 2.6666, Test Accuracy: 50.16%
Epoch [120/250], Loss: 0.0895, Accuracy: 99.13%, Test Loss: 2.6782, Test Accuracy: 50.51%
Epoch [121/250], Loss: 0.0858, Accuracy: 99.16%, Test Loss: 2.7077, Test Accuracy: 49.88%
Epoch [122/250], Loss: 0.0829, Accuracy: 99.24%, Test Loss: 2.7364, Test Accuracy: 50.15%
Epoch [123/250], Loss: 0.0833, Accuracy: 99.16%, Test Loss: 2.7551, Test Accuracy: 50.01%
Epoch [124/250], Loss: 0.0817, Accuracy: 99.20%, Test Loss: 2.7820, Test Accuracy: 50.08%
Epoch [125/250], Loss: 0.0794, Accuracy: 99.24%, Test Loss: 2.7807, Test Accuracy: 50.25%
Epoch [126/250], Loss: 0.0737, Accuracy: 99.38%, Test Loss: 2.8045, Test Accuracy: 49.90%
Epoch [127/250], Loss: 0.0713, Accuracy: 99.44%, Test Loss: 2.8234, Test Accuracy: 50.31%
Epoch [128/250], Loss: 0.0678, Accuracy: 99.52%, Test Loss: 2.8499, Test Accuracy: 49.98%
Epoch [129/250], Loss: 0.0670, Accuracy: 99.47%, Test Loss: 2.8714, Test Accuracy: 49.84%
Epoch [130/250], Loss: 0.0644, Accuracy: 99.53%, Test Loss: 2.8887, Test Accuracy: 49.93%
Epoch [131/250], Loss: 0.0614, Accuracy: 99.58%, Test Loss: 2.9210, Test Accuracy: 50.17%
Epoch [132/250], Loss: 0.0601, Accuracy: 99.62%, Test Loss: 2.9366, Test Accuracy: 50.06%
Epoch [133/250], Loss: 0.0588, Accuracy: 99.60%, Test Loss: 2.9631, Test Accuracy: 50.30%
Epoch [134/250], Loss: 0.0573, Accuracy: 99.62%, Test Loss: 2.9795, Test Accuracy: 49.74%
Epoch [135/250], Loss: 0.0546, Accuracy: 99.66%, Test Loss: 3.0060, Test Accuracy: 49.83%
Epoch [136/250], Loss: 0.0571, Accuracy: 99.56%, Test Loss: 3.0297, Test Accuracy: 49.60%
Epoch [137/250], Loss: 0.0538, Accuracy: 99.64%, Test Loss: 3.0361, Test Accuracy: 49.95%
Epoch [138/250], Loss: 0.0491, Accuracy: 99.76%, Test Loss: 3.0562, Test Accuracy: 50.02%
Epoch [139/250], Loss: 0.0462, Accuracy: 99.80%, Test Loss: 3.0922, Test Accuracy: 49.90%
Epoch [140/250], Loss: 0.0457, Accuracy: 99.78%, Test Loss: 3.1124, Test Accuracy: 49.93%
Epoch [141/250], Loss: 0.0437, Accuracy: 99.81%, Test Loss: 3.1280, Test Accuracy: 49.81%
Epoch [142/250], Loss: 0.0422, Accuracy: 99.82%, Test Loss: 3.1616, Test Accuracy: 49.78%
Epoch [143/250], Loss: 0.0407, Accuracy: 99.87%, Test Loss: 3.1665, Test Accuracy: 50.06%
Epoch [144/250], Loss: 0.0388, Accuracy: 99.87%, Test Loss: 3.1961, Test Accuracy: 49.73%
Epoch [145/250], Loss: 0.0410, Accuracy: 99.78%, Test Loss: 3.2133, Test Accuracy: 49.80%
Epoch [146/250], Loss: 0.0387, Accuracy: 99.84%, Test Loss: 3.2273, Test Accuracy: 50.13%
Epoch [147/250], Loss: 0.0377, Accuracy: 99.80%, Test Loss: 3.2610, Test Accuracy: 49.54%
Epoch [148/250], Loss: 0.0346, Accuracy: 99.90%, Test Loss: 3.2667, Test Accuracy: 49.62%
Epoch [149/250], Loss: 0.0322, Accuracy: 99.91%, Test Loss: 3.2931, Test Accuracy: 49.96%
Epoch [150/250], Loss: 0.0320, Accuracy: 99.91%, Test Loss: 3.3236, Test Accuracy: 49.71%
Epoch [151/250], Loss: 0.0318, Accuracy: 99.90%, Test Loss: 3.3403, Test Accuracy: 49.86%
Epoch [152/250], Loss: 0.0302, Accuracy: 99.93%, Test Loss: 3.3720, Test Accuracy: 50.07%
Epoch [153/250], Loss: 0.0300, Accuracy: 99.91%, Test Loss: 3.3728, Test Accuracy: 49.77%
Epoch [154/250], Loss: 0.0281, Accuracy: 99.94%, Test Loss: 3.3984, Test Accuracy: 49.73%
Epoch [155/250], Loss: 0.0272, Accuracy: 99.94%, Test Loss: 3.4297, Test Accuracy: 49.69%
Epoch [156/250], Loss: 0.0274, Accuracy: 99.92%, Test Loss: 3.4358, Test Accuracy: 49.92%
Epoch [157/250], Loss: 0.0338, Accuracy: 99.71%, Test Loss: 3.4814, Test Accuracy: 49.80%
Epoch [158/250], Loss: 0.0339, Accuracy: 99.77%, Test Loss: 3.4621, Test Accuracy: 49.91%
Epoch [159/250], Loss: 0.0255, Accuracy: 99.95%, Test Loss: 3.5166, Test Accuracy: 49.50%
Epoch [160/250], Loss: 0.0229, Accuracy: 99.97%, Test Loss: 3.5142, Test Accuracy: 49.73%
Epoch [161/250], Loss: 0.0211, Accuracy: 99.97%, Test Loss: 3.5431, Test Accuracy: 49.91%
Epoch [162/250], Loss: 0.0207, Accuracy: 99.98%, Test Loss: 3.5707, Test Accuracy: 49.61%
Epoch [163/250], Loss: 0.0202, Accuracy: 99.98%, Test Loss: 3.5886, Test Accuracy: 49.75%
Epoch [164/250], Loss: 0.0196, Accuracy: 99.98%, Test Loss: 3.6094, Test Accuracy: 49.85%
Epoch [165/250], Loss: 0.0185, Accuracy: 99.99%, Test Loss: 3.6301, Test Accuracy: 49.64%
Epoch [166/250], Loss: 0.0188, Accuracy: 99.97%, Test Loss: 3.6414, Test Accuracy: 49.69%
Epoch [167/250], Loss: 0.0175, Accuracy: 100.00%, Test Loss: 3.6675, Test Accuracy: 49.92%
Epoch [168/250], Loss: 0.0170, Accuracy: 99.98%, Test Loss: 3.6890, Test Accuracy: 49.79%
Epoch [169/250], Loss: 0.0165, Accuracy: 99.99%, Test Loss: 3.7118, Test Accuracy: 50.01%
Epoch [170/250], Loss: 0.0158, Accuracy: 99.99%, Test Loss: 3.7318, Test Accuracy: 49.77%
Epoch [171/250], Loss: 0.0151, Accuracy: 99.99%, Test Loss: 3.7567, Test Accuracy: 49.71%
Epoch [172/250], Loss: 0.0148, Accuracy: 99.99%, Test Loss: 3.7715, Test Accuracy: 49.75%
Epoch [173/250], Loss: 0.0143, Accuracy: 100.00%, Test Loss: 3.7957, Test Accuracy: 49.68%
Epoch [174/250], Loss: 0.0151, Accuracy: 99.99%, Test Loss: 3.8100, Test Accuracy: 49.71%
Epoch [175/250], Loss: 0.0145, Accuracy: 99.98%, Test Loss: 3.8277, Test Accuracy: 49.78%
Epoch [176/250], Loss: 0.0147, Accuracy: 99.98%, Test Loss: 3.8698, Test Accuracy: 49.59%
Epoch [177/250], Loss: 0.0397, Accuracy: 99.19%, Test Loss: 3.9177, Test Accuracy: 49.00%
Epoch [178/250], Loss: 0.0500, Accuracy: 98.96%, Test Loss: 3.9075, Test Accuracy: 49.57%
Epoch [179/250], Loss: 0.0540, Accuracy: 98.77%, Test Loss: 3.8162, Test Accuracy: 49.98%
Epoch [180/250], Loss: 0.0551, Accuracy: 98.65%, Test Loss: 3.8584, Test Accuracy: 49.38%
Epoch [181/250], Loss: 0.0287, Accuracy: 99.67%, Test Loss: 3.8423, Test Accuracy: 49.89%
Epoch [182/250], Loss: 0.0142, Accuracy: 99.99%, Test Loss: 3.8553, Test Accuracy: 49.68%
Epoch [183/250], Loss: 0.0114, Accuracy: 100.00%, Test Loss: 3.8935, Test Accuracy: 49.98%
Epoch [184/250], Loss: 0.0106, Accuracy: 100.00%, Test Loss: 3.9138, Test Accuracy: 49.92%
Epoch [185/250], Loss: 0.0101, Accuracy: 100.00%, Test Loss: 3.9399, Test Accuracy: 49.78%
Epoch [186/250], Loss: 0.0097, Accuracy: 100.00%, Test Loss: 3.9545, Test Accuracy: 49.68%
Epoch [187/250], Loss: 0.0094, Accuracy: 100.00%, Test Loss: 3.9761, Test Accuracy: 49.78%
Epoch [188/250], Loss: 0.0090, Accuracy: 100.00%, Test Loss: 3.9958, Test Accuracy: 49.90%
Epoch [189/250], Loss: 0.0088, Accuracy: 100.00%, Test Loss: 4.0131, Test Accuracy: 49.76%
Epoch [190/250], Loss: 0.0086, Accuracy: 100.00%, Test Loss: 4.0299, Test Accuracy: 49.72%
Epoch [191/250], Loss: 0.0084, Accuracy: 100.00%, Test Loss: 4.0447, Test Accuracy: 49.70%
Epoch [192/250], Loss: 0.0081, Accuracy: 100.00%, Test Loss: 4.0571, Test Accuracy: 49.73%
Epoch [193/250], Loss: 0.0080, Accuracy: 100.00%, Test Loss: 4.0762, Test Accuracy: 49.87%
Epoch [194/250], Loss: 0.0078, Accuracy: 100.00%, Test Loss: 4.0875, Test Accuracy: 49.74%
Epoch [195/250], Loss: 0.0076, Accuracy: 100.00%, Test Loss: 4.1042, Test Accuracy: 49.68%
Epoch [196/250], Loss: 0.0075, Accuracy: 100.00%, Test Loss: 4.1210, Test Accuracy: 49.58%
Epoch [197/250], Loss: 0.0073, Accuracy: 100.00%, Test Loss: 4.1333, Test Accuracy: 49.69%
Epoch [198/250], Loss: 0.0072, Accuracy: 100.00%, Test Loss: 4.1467, Test Accuracy: 49.64%
Epoch [199/250], Loss: 0.0070, Accuracy: 100.00%, Test Loss: 4.1633, Test Accuracy: 49.65%
Epoch [200/250], Loss: 0.0068, Accuracy: 100.00%, Test Loss: 4.1803, Test Accuracy: 49.73%
Epoch [201/250], Loss: 0.0066, Accuracy: 100.00%, Test Loss: 4.1975, Test Accuracy: 49.79%
Epoch [202/250], Loss: 0.0064, Accuracy: 100.00%, Test Loss: 4.2062, Test Accuracy: 49.71%
Epoch [203/250], Loss: 0.0064, Accuracy: 100.00%, Test Loss: 4.2266, Test Accuracy: 49.63%
Epoch [204/250], Loss: 0.0062, Accuracy: 100.00%, Test Loss: 4.2358, Test Accuracy: 49.69%
Epoch [205/250], Loss: 0.0060, Accuracy: 100.00%, Test Loss: 4.2517, Test Accuracy: 49.70%
Epoch [206/250], Loss: 0.0059, Accuracy: 100.00%, Test Loss: 4.2736, Test Accuracy: 49.59%
Epoch [207/250], Loss: 0.0059, Accuracy: 100.00%, Test Loss: 4.2870, Test Accuracy: 49.67%
Epoch [208/250], Loss: 0.0057, Accuracy: 100.00%, Test Loss: 4.3023, Test Accuracy: 49.42%
Epoch [209/250], Loss: 0.0056, Accuracy: 100.00%, Test Loss: 4.3119, Test Accuracy: 49.57%
Epoch [210/250], Loss: 0.0055, Accuracy: 100.00%, Test Loss: 4.3292, Test Accuracy: 49.65%
Epoch [211/250], Loss: 0.0053, Accuracy: 100.00%, Test Loss: 4.3534, Test Accuracy: 49.90%
Epoch [212/250], Loss: 0.0052, Accuracy: 100.00%, Test Loss: 4.3583, Test Accuracy: 49.55%
Epoch [213/250], Loss: 0.0051, Accuracy: 100.00%, Test Loss: 4.3819, Test Accuracy: 49.73%
Epoch [214/250], Loss: 0.0049, Accuracy: 100.00%, Test Loss: 4.3946, Test Accuracy: 49.43%
Epoch [215/250], Loss: 0.0049, Accuracy: 100.00%, Test Loss: 4.4116, Test Accuracy: 49.53%
Epoch [216/250], Loss: 0.0047, Accuracy: 100.00%, Test Loss: 4.4203, Test Accuracy: 49.61%
Epoch [217/250], Loss: 0.0046, Accuracy: 100.00%, Test Loss: 4.4466, Test Accuracy: 49.71%
Epoch [218/250], Loss: 0.0046, Accuracy: 100.00%, Test Loss: 4.4584, Test Accuracy: 49.69%
Epoch [219/250], Loss: 0.0045, Accuracy: 100.00%, Test Loss: 4.4740, Test Accuracy: 49.64%
Epoch [220/250], Loss: 0.0043, Accuracy: 100.00%, Test Loss: 4.4922, Test Accuracy: 49.77%
Epoch [221/250], Loss: 0.0042, Accuracy: 100.00%, Test Loss: 4.5035, Test Accuracy: 49.58%
Epoch [222/250], Loss: 0.0042, Accuracy: 100.00%, Test Loss: 4.5332, Test Accuracy: 49.57%
Epoch [223/250], Loss: 0.0043, Accuracy: 100.00%, Test Loss: 4.5420, Test Accuracy: 49.47%
Epoch [224/250], Loss: 0.0039, Accuracy: 100.00%, Test Loss: 4.5515, Test Accuracy: 49.68%
Epoch [225/250], Loss: 0.0039, Accuracy: 100.00%, Test Loss: 4.5758, Test Accuracy: 49.50%
Epoch [226/250], Loss: 0.0038, Accuracy: 100.00%, Test Loss: 4.5870, Test Accuracy: 49.58%
Epoch [227/250], Loss: 0.0037, Accuracy: 100.00%, Test Loss: 4.6042, Test Accuracy: 49.55%
Epoch [228/250], Loss: 0.0036, Accuracy: 100.00%, Test Loss: 4.6242, Test Accuracy: 49.52%
Epoch [229/250], Loss: 0.0034, Accuracy: 100.00%, Test Loss: 4.6436, Test Accuracy: 49.51%
Epoch [230/250], Loss: 0.0033, Accuracy: 100.00%, Test Loss: 4.6650, Test Accuracy: 49.48%
Epoch [231/250], Loss: 0.0033, Accuracy: 100.00%, Test Loss: 4.6775, Test Accuracy: 49.50%
Epoch [232/250], Loss: 0.0032, Accuracy: 100.00%, Test Loss: 4.6908, Test Accuracy: 49.59%
Epoch [233/250], Loss: 0.0031, Accuracy: 100.00%, Test Loss: 4.7136, Test Accuracy: 49.50%
Epoch [234/250], Loss: 0.0031, Accuracy: 100.00%, Test Loss: 4.7250, Test Accuracy: 49.78%
Epoch [235/250], Loss: 0.0030, Accuracy: 100.00%, Test Loss: 4.7494, Test Accuracy: 49.56%
Epoch [236/250], Loss: 0.0029, Accuracy: 100.00%, Test Loss: 4.7618, Test Accuracy: 49.47%
Epoch [237/250], Loss: 0.0028, Accuracy: 100.00%, Test Loss: 4.7848, Test Accuracy: 49.54%
Epoch [238/250], Loss: 0.0027, Accuracy: 100.00%, Test Loss: 4.7981, Test Accuracy: 49.36%
Epoch [239/250], Loss: 0.0027, Accuracy: 100.00%, Test Loss: 4.8105, Test Accuracy: 49.56%
Epoch [240/250], Loss: 0.0026, Accuracy: 100.00%, Test Loss: 4.8375, Test Accuracy: 49.55%
Epoch [241/250], Loss: 0.0026, Accuracy: 100.00%, Test Loss: 4.8525, Test Accuracy: 49.62%
Epoch [242/250], Loss: 0.0024, Accuracy: 100.00%, Test Loss: 4.8665, Test Accuracy: 49.50%
Epoch [243/250], Loss: 0.0024, Accuracy: 100.00%, Test Loss: 4.8849, Test Accuracy: 49.51%
Epoch [244/250], Loss: 0.0024, Accuracy: 100.00%, Test Loss: 4.9046, Test Accuracy: 49.51%
Epoch [245/250], Loss: 0.0023, Accuracy: 100.00%, Test Loss: 4.9243, Test Accuracy: 49.60%
Epoch [246/250], Loss: 0.0022, Accuracy: 100.00%, Test Loss: 4.9458, Test Accuracy: 49.55%
Epoch [247/250], Loss: 0.0022, Accuracy: 100.00%, Test Loss: 4.9631, Test Accuracy: 49.54%
Epoch [248/250], Loss: 0.0021, Accuracy: 100.00%, Test Loss: 4.9808, Test Accuracy: 49.55%
Epoch [249/250], Loss: 0.0021, Accuracy: 100.00%, Test Loss: 4.9980, Test Accuracy: 49.48%
Epoch [250/250], Loss: 0.0020, Accuracy: 100.00%, Test Loss: 5.0123, Test Accuracy: 49.70%
import matplotlib.pyplot as plt

# Training vs. test loss per epoch (same plot order as before, so the
# series keep the same default colors).
for series, series_label in ((train_losses, 'train_losses'), (test_losses, 'test_losses')):
    plt.plot(series, label=series_label)
plt.legend()
<matplotlib.legend.Legend at 0x7f855f88fac0>
image
import matplotlib.pyplot as plt

# Training vs. test accuracy per epoch.
for series, series_label in ((train_accuracies, 'train_accuracy'), (test_accuracies, 'test_accuracy')):
    plt.plot(series, label=series_label)
plt.legend()
<matplotlib.legend.Legend at 0x7f855c5af0d0>
image
FashionMNIST_CNN_onlyConv2d
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.init as init

# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Data

# Load FashionMNIST dataset (single-channel 28x28 images),
# normalized to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

trainset = torchvision.datasets.FashionMNIST(root='data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1024, num_workers=10, shuffle=True, drop_last=True)

testset = torchvision.datasets.FashionMNIST(root='data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1024, num_workers=10, shuffle=False)
import matplotlib.pyplot as plt
import numpy as np

# Function to display the images
def imshow(img):
    """Display a normalized image tensor (mean 0.5, std 0.5) with matplotlib."""
    unnormalized = img * 0.5 + 0.5  # map [-1, 1] back to the displayable [0, 1] range
    # Reorder CHW -> HWC as matplotlib expects.
    plt.imshow(np.transpose(unnormalized.numpy(), (1, 2, 0)))
    plt.show()

# Grab the first training batch and display its first 8 images as a grid.
images, labels = next(iter(trainloader))
imshow(torchvision.utils.make_grid(images[:8]))
image

Model

import torch
import torch.nn as nn
from torchsummary import summary

class CustomModel(nn.Module):
    """Convolution-only classifier for 1x28x28 FashionMNIST images.

    Four unpadded 7x7 convolutions shrink the spatial size
    28 -> 22 -> 16 -> 10 -> 4 (no pooling layers), then two fully
    connected layers map the 4*4*256 features to 10 class logits.
    """

    def __init__(self):
        super(CustomModel, self).__init__()
        # NOTE: attribute names and creation order are kept exactly as in
        # the original -- they determine state_dict keys and the RNG draw
        # order for weight initialization.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=7)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=7)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=7)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=7)
        self.flatten = nn.Flatten()
        self.dense1 = nn.Linear(4*4*256, 128)
        self.dense2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Convolutional trunk: ReLU after every convolution.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.relu(conv(x))
        # Classifier head: flatten -> hidden layer -> logits.
        x = self.flatten(x)
        hidden = self.relu(self.dense1(x))
        return self.dense2(hidden)

# Instantiate the CNN and move it to the selected device.
model = CustomModel()
model = model.to(device)

# Print a per-layer summary (output shapes, parameter counts) for a
# single-channel 28x28 input; requires the third-party torchsummary package.
summary(model, (1, 28, 28))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 22, 22]           1,600
              ReLU-2           [-1, 32, 22, 22]               0
            Conv2d-3           [-1, 64, 16, 16]         100,416
              ReLU-4           [-1, 64, 16, 16]               0
            Conv2d-5          [-1, 128, 10, 10]         401,536
              ReLU-6          [-1, 128, 10, 10]               0
            Conv2d-7            [-1, 256, 4, 4]       1,605,888
              ReLU-8            [-1, 256, 4, 4]               0
           Flatten-9                 [-1, 4096]               0
           Linear-10                  [-1, 128]         524,416
             ReLU-11                  [-1, 128]               0
           Linear-12                   [-1, 10]           1,290
================================================================
Total params: 2,635,146
Trainable params: 2,635,146
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.78
Params size (MB): 10.05
Estimated Total Size (MB): 10.83
----------------------------------------------------------------

Loss, Optimizer, and Evaluation Function

# Cross-entropy over the 10-class logits (expects raw logits, not softmax output).
criterion = nn.CrossEntropyLoss()
# Adam with a very small learning rate (1e-5).
optimizer = optim.Adam(model.parameters(), lr=1e-5)
# Function to compute loss and accuracy for test set
def evaluate(model, testloader, criterion):
    model.eval()
    test_loss = 0.0
    running_correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            # Move inputs and labels to the device
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            running_correct += (predicted == labels).sum().item()

    accuracy = 100 * running_correct / total
    test_loss = test_loss / len(testloader)
    return test_loss, accuracy

Train

# History buffers for plotting the loss/accuracy curves after training
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []
max_epoch = 250
# train
for epoch in range(max_epoch):
    model.train()
    running_loss = 0.0
    running_correct = 0   # to track number of correct predictions
    total = 0             # to track total number of samples

    for i, (inputs, labels) in enumerate(trainloader, 0):
        # Move inputs and labels to the device
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        running_loss += loss.item()

        # Determine class predictions and track accuracy
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        running_correct += (predicted == labels).sum().item()

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

    # i is the index of the last batch, so (i + 1) is the number of batches
    epoch_accuracy = 100 * running_correct / total
    epoch_loss = running_loss / (i + 1)

    # Evaluate on the held-out test set once per epoch
    test_loss, test_accuracy = evaluate(model, testloader, criterion)
    print(f"Epoch [{epoch + 1}/{max_epoch}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

    # save for plot
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_accuracy)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)
Epoch [1/250], Loss: 2.2336, Accuracy: 23.01%, Test Loss: 2.0205, Test Accuracy: 43.39%
Epoch [2/250], Loss: 1.5065, Accuracy: 52.92%, Test Loss: 1.0948, Test Accuracy: 61.05%
Epoch [3/250], Loss: 0.9647, Accuracy: 64.76%, Test Loss: 0.9035, Test Accuracy: 66.02%
Epoch [4/250], Loss: 0.8519, Accuracy: 68.67%, Test Loss: 0.8374, Test Accuracy: 69.15%
Epoch [5/250], Loss: 0.7982, Accuracy: 70.52%, Test Loss: 0.7939, Test Accuracy: 70.69%
Epoch [6/250], Loss: 0.7622, Accuracy: 71.79%, Test Loss: 0.7610, Test Accuracy: 71.86%
Epoch [7/250], Loss: 0.7319, Accuracy: 72.70%, Test Loss: 0.7391, Test Accuracy: 72.22%
Epoch [8/250], Loss: 0.7091, Accuracy: 73.48%, Test Loss: 0.7171, Test Accuracy: 73.11%
Epoch [9/250], Loss: 0.6900, Accuracy: 74.01%, Test Loss: 0.6997, Test Accuracy: 73.50%
Epoch [10/250], Loss: 0.6722, Accuracy: 74.45%, Test Loss: 0.6847, Test Accuracy: 74.03%
Epoch [11/250], Loss: 0.6592, Accuracy: 74.92%, Test Loss: 0.6725, Test Accuracy: 74.20%
Epoch [12/250], Loss: 0.6461, Accuracy: 75.32%, Test Loss: 0.6601, Test Accuracy: 74.79%
Epoch [13/250], Loss: 0.6351, Accuracy: 75.73%, Test Loss: 0.6524, Test Accuracy: 75.36%
Epoch [14/250], Loss: 0.6265, Accuracy: 76.03%, Test Loss: 0.6408, Test Accuracy: 75.90%
Epoch [15/250], Loss: 0.6171, Accuracy: 76.31%, Test Loss: 0.6331, Test Accuracy: 75.56%
Epoch [16/250], Loss: 0.6079, Accuracy: 76.66%, Test Loss: 0.6228, Test Accuracy: 76.33%
Epoch [17/250], Loss: 0.6019, Accuracy: 76.92%, Test Loss: 0.6199, Test Accuracy: 76.32%
Epoch [18/250], Loss: 0.5937, Accuracy: 77.10%, Test Loss: 0.6104, Test Accuracy: 76.68%
Epoch [19/250], Loss: 0.5873, Accuracy: 77.45%, Test Loss: 0.6060, Test Accuracy: 76.84%
Epoch [20/250], Loss: 0.5802, Accuracy: 77.69%, Test Loss: 0.5991, Test Accuracy: 77.15%
Epoch [21/250], Loss: 0.5747, Accuracy: 78.01%, Test Loss: 0.5958, Test Accuracy: 77.03%
Epoch [22/250], Loss: 0.5691, Accuracy: 78.17%, Test Loss: 0.5867, Test Accuracy: 77.54%
Epoch [23/250], Loss: 0.5651, Accuracy: 78.38%, Test Loss: 0.5850, Test Accuracy: 77.59%
Epoch [24/250], Loss: 0.5575, Accuracy: 78.62%, Test Loss: 0.5769, Test Accuracy: 77.90%
Epoch [25/250], Loss: 0.5526, Accuracy: 78.95%, Test Loss: 0.5720, Test Accuracy: 78.11%
Epoch [26/250], Loss: 0.5466, Accuracy: 79.14%, Test Loss: 0.5664, Test Accuracy: 78.67%
Epoch [27/250], Loss: 0.5421, Accuracy: 79.35%, Test Loss: 0.5620, Test Accuracy: 78.70%
Epoch [28/250], Loss: 0.5389, Accuracy: 79.50%, Test Loss: 0.5586, Test Accuracy: 78.78%
Epoch [29/250], Loss: 0.5327, Accuracy: 79.70%, Test Loss: 0.5541, Test Accuracy: 79.05%
Epoch [30/250], Loss: 0.5299, Accuracy: 79.94%, Test Loss: 0.5510, Test Accuracy: 79.09%
Epoch [31/250], Loss: 0.5269, Accuracy: 80.04%, Test Loss: 0.5448, Test Accuracy: 79.57%
Epoch [32/250], Loss: 0.5212, Accuracy: 80.32%, Test Loss: 0.5430, Test Accuracy: 79.43%
Epoch [33/250], Loss: 0.5163, Accuracy: 80.57%, Test Loss: 0.5398, Test Accuracy: 79.69%
Epoch [34/250], Loss: 0.5140, Accuracy: 80.73%, Test Loss: 0.5404, Test Accuracy: 79.48%
Epoch [35/250], Loss: 0.5098, Accuracy: 80.81%, Test Loss: 0.5324, Test Accuracy: 80.06%
Epoch [36/250], Loss: 0.5070, Accuracy: 80.98%, Test Loss: 0.5338, Test Accuracy: 80.14%
Epoch [37/250], Loss: 0.5036, Accuracy: 81.19%, Test Loss: 0.5278, Test Accuracy: 80.27%
Epoch [38/250], Loss: 0.5012, Accuracy: 81.31%, Test Loss: 0.5215, Test Accuracy: 80.61%
Epoch [39/250], Loss: 0.4960, Accuracy: 81.46%, Test Loss: 0.5176, Test Accuracy: 80.81%
Epoch [40/250], Loss: 0.4929, Accuracy: 81.66%, Test Loss: 0.5172, Test Accuracy: 80.43%
Epoch [41/250], Loss: 0.4917, Accuracy: 81.61%, Test Loss: 0.5195, Test Accuracy: 80.94%
Epoch [42/250], Loss: 0.4878, Accuracy: 81.91%, Test Loss: 0.5146, Test Accuracy: 81.13%
Epoch [43/250], Loss: 0.4854, Accuracy: 82.07%, Test Loss: 0.5090, Test Accuracy: 81.23%
Epoch [44/250], Loss: 0.4811, Accuracy: 82.15%, Test Loss: 0.5060, Test Accuracy: 81.47%
Epoch [45/250], Loss: 0.4798, Accuracy: 82.34%, Test Loss: 0.5062, Test Accuracy: 81.20%
Epoch [46/250], Loss: 0.4771, Accuracy: 82.51%, Test Loss: 0.4992, Test Accuracy: 81.67%
Epoch [47/250], Loss: 0.4741, Accuracy: 82.55%, Test Loss: 0.4952, Test Accuracy: 81.75%
Epoch [48/250], Loss: 0.4707, Accuracy: 82.82%, Test Loss: 0.4969, Test Accuracy: 81.68%
Epoch [49/250], Loss: 0.4671, Accuracy: 82.93%, Test Loss: 0.4950, Test Accuracy: 82.15%
Epoch [50/250], Loss: 0.4666, Accuracy: 82.83%, Test Loss: 0.4913, Test Accuracy: 82.32%
Epoch [51/250], Loss: 0.4638, Accuracy: 82.98%, Test Loss: 0.4891, Test Accuracy: 81.85%
Epoch [52/250], Loss: 0.4619, Accuracy: 83.13%, Test Loss: 0.4870, Test Accuracy: 82.00%
Epoch [53/250], Loss: 0.4595, Accuracy: 83.18%, Test Loss: 0.4844, Test Accuracy: 82.31%
Epoch [54/250], Loss: 0.4573, Accuracy: 83.32%, Test Loss: 0.4792, Test Accuracy: 82.49%
Epoch [55/250], Loss: 0.4542, Accuracy: 83.36%, Test Loss: 0.4789, Test Accuracy: 82.44%
Epoch [56/250], Loss: 0.4519, Accuracy: 83.57%, Test Loss: 0.4769, Test Accuracy: 82.58%
Epoch [57/250], Loss: 0.4506, Accuracy: 83.59%, Test Loss: 0.4802, Test Accuracy: 82.33%
Epoch [58/250], Loss: 0.4485, Accuracy: 83.68%, Test Loss: 0.4760, Test Accuracy: 82.62%
Epoch [59/250], Loss: 0.4450, Accuracy: 83.76%, Test Loss: 0.4730, Test Accuracy: 83.03%
Epoch [60/250], Loss: 0.4424, Accuracy: 83.95%, Test Loss: 0.4709, Test Accuracy: 82.97%
Epoch [61/250], Loss: 0.4422, Accuracy: 83.99%, Test Loss: 0.4702, Test Accuracy: 82.91%
Epoch [62/250], Loss: 0.4404, Accuracy: 83.90%, Test Loss: 0.4669, Test Accuracy: 82.77%
Epoch [63/250], Loss: 0.4380, Accuracy: 83.93%, Test Loss: 0.4641, Test Accuracy: 82.97%
Epoch [64/250], Loss: 0.4370, Accuracy: 84.09%, Test Loss: 0.4622, Test Accuracy: 83.28%
Epoch [65/250], Loss: 0.4337, Accuracy: 84.22%, Test Loss: 0.4652, Test Accuracy: 83.13%
Epoch [66/250], Loss: 0.4329, Accuracy: 84.30%, Test Loss: 0.4622, Test Accuracy: 83.28%
Epoch [67/250], Loss: 0.4316, Accuracy: 84.31%, Test Loss: 0.4584, Test Accuracy: 83.36%
Epoch [68/250], Loss: 0.4284, Accuracy: 84.44%, Test Loss: 0.4600, Test Accuracy: 83.12%
Epoch [69/250], Loss: 0.4262, Accuracy: 84.52%, Test Loss: 0.4526, Test Accuracy: 83.63%
Epoch [70/250], Loss: 0.4239, Accuracy: 84.62%, Test Loss: 0.4531, Test Accuracy: 83.83%
Epoch [71/250], Loss: 0.4239, Accuracy: 84.68%, Test Loss: 0.4506, Test Accuracy: 83.77%
Epoch [72/250], Loss: 0.4202, Accuracy: 84.78%, Test Loss: 0.4498, Test Accuracy: 83.87%
Epoch [73/250], Loss: 0.4202, Accuracy: 84.81%, Test Loss: 0.4513, Test Accuracy: 83.53%
Epoch [74/250], Loss: 0.4193, Accuracy: 84.72%, Test Loss: 0.4471, Test Accuracy: 83.68%
Epoch [75/250], Loss: 0.4182, Accuracy: 84.83%, Test Loss: 0.4469, Test Accuracy: 83.75%
Epoch [76/250], Loss: 0.4153, Accuracy: 84.91%, Test Loss: 0.4423, Test Accuracy: 84.09%
Epoch [77/250], Loss: 0.4140, Accuracy: 85.00%, Test Loss: 0.4428, Test Accuracy: 83.94%
Epoch [78/250], Loss: 0.4120, Accuracy: 85.09%, Test Loss: 0.4465, Test Accuracy: 83.89%
Epoch [79/250], Loss: 0.4105, Accuracy: 85.20%, Test Loss: 0.4397, Test Accuracy: 84.04%
Epoch [80/250], Loss: 0.4082, Accuracy: 85.20%, Test Loss: 0.4387, Test Accuracy: 84.05%
Epoch [81/250], Loss: 0.4078, Accuracy: 85.21%, Test Loss: 0.4375, Test Accuracy: 84.20%
Epoch [82/250], Loss: 0.4079, Accuracy: 85.27%, Test Loss: 0.4396, Test Accuracy: 84.09%
Epoch [83/250], Loss: 0.4050, Accuracy: 85.31%, Test Loss: 0.4367, Test Accuracy: 84.11%
Epoch [84/250], Loss: 0.4027, Accuracy: 85.40%, Test Loss: 0.4341, Test Accuracy: 84.15%
Epoch [85/250], Loss: 0.4021, Accuracy: 85.48%, Test Loss: 0.4347, Test Accuracy: 84.00%
Epoch [86/250], Loss: 0.4006, Accuracy: 85.50%, Test Loss: 0.4345, Test Accuracy: 84.18%
Epoch [87/250], Loss: 0.4018, Accuracy: 85.34%, Test Loss: 0.4289, Test Accuracy: 84.43%
Epoch [88/250], Loss: 0.3973, Accuracy: 85.69%, Test Loss: 0.4337, Test Accuracy: 84.21%
Epoch [89/250], Loss: 0.3971, Accuracy: 85.62%, Test Loss: 0.4275, Test Accuracy: 84.65%
Epoch [90/250], Loss: 0.3962, Accuracy: 85.65%, Test Loss: 0.4262, Test Accuracy: 84.66%
Epoch [91/250], Loss: 0.3950, Accuracy: 85.82%, Test Loss: 0.4242, Test Accuracy: 84.68%
Epoch [92/250], Loss: 0.3940, Accuracy: 85.74%, Test Loss: 0.4246, Test Accuracy: 84.77%
Epoch [93/250], Loss: 0.3914, Accuracy: 85.92%, Test Loss: 0.4238, Test Accuracy: 84.60%
Epoch [94/250], Loss: 0.3900, Accuracy: 85.97%, Test Loss: 0.4245, Test Accuracy: 84.46%
Epoch [95/250], Loss: 0.3896, Accuracy: 85.89%, Test Loss: 0.4220, Test Accuracy: 84.78%
Epoch [96/250], Loss: 0.3888, Accuracy: 85.94%, Test Loss: 0.4205, Test Accuracy: 84.74%
Epoch [97/250], Loss: 0.3892, Accuracy: 85.97%, Test Loss: 0.4179, Test Accuracy: 84.95%
Epoch [98/250], Loss: 0.3870, Accuracy: 85.96%, Test Loss: 0.4209, Test Accuracy: 84.93%
Epoch [99/250], Loss: 0.3871, Accuracy: 86.01%, Test Loss: 0.4169, Test Accuracy: 85.00%
Epoch [100/250], Loss: 0.3851, Accuracy: 86.10%, Test Loss: 0.4174, Test Accuracy: 84.90%
Epoch [101/250], Loss: 0.3845, Accuracy: 86.08%, Test Loss: 0.4158, Test Accuracy: 84.97%
Epoch [102/250], Loss: 0.3834, Accuracy: 86.16%, Test Loss: 0.4142, Test Accuracy: 85.18%
Epoch [103/250], Loss: 0.3811, Accuracy: 86.18%, Test Loss: 0.4142, Test Accuracy: 85.01%
Epoch [104/250], Loss: 0.3800, Accuracy: 86.18%, Test Loss: 0.4177, Test Accuracy: 85.03%
Epoch [105/250], Loss: 0.3803, Accuracy: 86.29%, Test Loss: 0.4097, Test Accuracy: 85.44%
Epoch [106/250], Loss: 0.3775, Accuracy: 86.37%, Test Loss: 0.4091, Test Accuracy: 85.45%
Epoch [107/250], Loss: 0.3757, Accuracy: 86.38%, Test Loss: 0.4079, Test Accuracy: 85.49%
Epoch [108/250], Loss: 0.3745, Accuracy: 86.49%, Test Loss: 0.4065, Test Accuracy: 85.42%
Epoch [109/250], Loss: 0.3761, Accuracy: 86.40%, Test Loss: 0.4165, Test Accuracy: 84.93%
Epoch [110/250], Loss: 0.3730, Accuracy: 86.59%, Test Loss: 0.4045, Test Accuracy: 85.60%
Epoch [111/250], Loss: 0.3723, Accuracy: 86.57%, Test Loss: 0.4038, Test Accuracy: 85.61%
Epoch [112/250], Loss: 0.3727, Accuracy: 86.49%, Test Loss: 0.4063, Test Accuracy: 85.28%
Epoch [113/250], Loss: 0.3707, Accuracy: 86.64%, Test Loss: 0.4021, Test Accuracy: 85.69%
Epoch [114/250], Loss: 0.3689, Accuracy: 86.68%, Test Loss: 0.4034, Test Accuracy: 85.36%
Epoch [115/250], Loss: 0.3698, Accuracy: 86.63%, Test Loss: 0.4046, Test Accuracy: 85.47%
Epoch [116/250], Loss: 0.3692, Accuracy: 86.67%, Test Loss: 0.3985, Test Accuracy: 85.75%
Epoch [117/250], Loss: 0.3672, Accuracy: 86.87%, Test Loss: 0.3984, Test Accuracy: 85.85%
Epoch [118/250], Loss: 0.3659, Accuracy: 86.79%, Test Loss: 0.4007, Test Accuracy: 85.71%
Epoch [119/250], Loss: 0.3649, Accuracy: 86.84%, Test Loss: 0.3990, Test Accuracy: 85.63%
Epoch [120/250], Loss: 0.3640, Accuracy: 86.92%, Test Loss: 0.3963, Test Accuracy: 85.98%
Epoch [121/250], Loss: 0.3645, Accuracy: 86.77%, Test Loss: 0.3951, Test Accuracy: 85.94%
Epoch [122/250], Loss: 0.3625, Accuracy: 86.94%, Test Loss: 0.3992, Test Accuracy: 85.66%
Epoch [123/250], Loss: 0.3617, Accuracy: 86.90%, Test Loss: 0.3963, Test Accuracy: 85.92%
Epoch [124/250], Loss: 0.3607, Accuracy: 86.98%, Test Loss: 0.3956, Test Accuracy: 86.05%
Epoch [125/250], Loss: 0.3610, Accuracy: 87.02%, Test Loss: 0.3959, Test Accuracy: 85.89%
Epoch [126/250], Loss: 0.3588, Accuracy: 87.11%, Test Loss: 0.3920, Test Accuracy: 86.02%
Epoch [127/250], Loss: 0.3572, Accuracy: 87.10%, Test Loss: 0.3900, Test Accuracy: 86.25%
Epoch [128/250], Loss: 0.3582, Accuracy: 87.01%, Test Loss: 0.3934, Test Accuracy: 85.99%
Epoch [129/250], Loss: 0.3561, Accuracy: 87.19%, Test Loss: 0.3913, Test Accuracy: 86.02%
Epoch [130/250], Loss: 0.3548, Accuracy: 87.14%, Test Loss: 0.3893, Test Accuracy: 86.01%
Epoch [131/250], Loss: 0.3547, Accuracy: 87.24%, Test Loss: 0.3882, Test Accuracy: 86.03%
Epoch [132/250], Loss: 0.3541, Accuracy: 87.29%, Test Loss: 0.3962, Test Accuracy: 85.97%
Epoch [133/250], Loss: 0.3539, Accuracy: 87.34%, Test Loss: 0.3876, Test Accuracy: 86.14%
Epoch [134/250], Loss: 0.3517, Accuracy: 87.35%, Test Loss: 0.3867, Test Accuracy: 86.35%
Epoch [135/250], Loss: 0.3529, Accuracy: 87.26%, Test Loss: 0.3897, Test Accuracy: 86.15%
Epoch [136/250], Loss: 0.3515, Accuracy: 87.39%, Test Loss: 0.3844, Test Accuracy: 86.31%
Epoch [137/250], Loss: 0.3513, Accuracy: 87.28%, Test Loss: 0.3850, Test Accuracy: 86.43%
Epoch [138/250], Loss: 0.3483, Accuracy: 87.42%, Test Loss: 0.3837, Test Accuracy: 86.47%
Epoch [139/250], Loss: 0.3485, Accuracy: 87.44%, Test Loss: 0.3833, Test Accuracy: 86.47%
Epoch [140/250], Loss: 0.3479, Accuracy: 87.50%, Test Loss: 0.3827, Test Accuracy: 86.39%
Epoch [141/250], Loss: 0.3462, Accuracy: 87.53%, Test Loss: 0.3813, Test Accuracy: 86.58%
Epoch [142/250], Loss: 0.3450, Accuracy: 87.57%, Test Loss: 0.3819, Test Accuracy: 86.57%
Epoch [143/250], Loss: 0.3464, Accuracy: 87.53%, Test Loss: 0.3786, Test Accuracy: 86.60%
Epoch [144/250], Loss: 0.3440, Accuracy: 87.56%, Test Loss: 0.3824, Test Accuracy: 86.29%
Epoch [145/250], Loss: 0.3453, Accuracy: 87.56%, Test Loss: 0.3807, Test Accuracy: 86.31%
Epoch [146/250], Loss: 0.3434, Accuracy: 87.59%, Test Loss: 0.3779, Test Accuracy: 86.63%
Epoch [147/250], Loss: 0.3429, Accuracy: 87.59%, Test Loss: 0.3774, Test Accuracy: 86.64%
Epoch [148/250], Loss: 0.3414, Accuracy: 87.70%, Test Loss: 0.3785, Test Accuracy: 86.67%
Epoch [149/250], Loss: 0.3408, Accuracy: 87.74%, Test Loss: 0.3757, Test Accuracy: 86.55%
Epoch [150/250], Loss: 0.3395, Accuracy: 87.75%, Test Loss: 0.3762, Test Accuracy: 86.76%
Epoch [151/250], Loss: 0.3383, Accuracy: 87.74%, Test Loss: 0.3742, Test Accuracy: 86.68%
Epoch [152/250], Loss: 0.3396, Accuracy: 87.72%, Test Loss: 0.3735, Test Accuracy: 86.75%
Epoch [153/250], Loss: 0.3364, Accuracy: 87.91%, Test Loss: 0.3742, Test Accuracy: 86.65%
Epoch [154/250], Loss: 0.3372, Accuracy: 87.80%, Test Loss: 0.3725, Test Accuracy: 86.64%
Epoch [155/250], Loss: 0.3365, Accuracy: 87.83%, Test Loss: 0.3722, Test Accuracy: 86.88%
Epoch [156/250], Loss: 0.3363, Accuracy: 87.82%, Test Loss: 0.3729, Test Accuracy: 86.69%
Epoch [157/250], Loss: 0.3338, Accuracy: 88.04%, Test Loss: 0.3714, Test Accuracy: 86.82%
Epoch [158/250], Loss: 0.3335, Accuracy: 87.97%, Test Loss: 0.3717, Test Accuracy: 86.57%
Epoch [159/250], Loss: 0.3329, Accuracy: 88.02%, Test Loss: 0.3701, Test Accuracy: 86.90%
Epoch [160/250], Loss: 0.3322, Accuracy: 87.96%, Test Loss: 0.3676, Test Accuracy: 87.14%
Epoch [161/250], Loss: 0.3323, Accuracy: 88.00%, Test Loss: 0.3708, Test Accuracy: 87.00%
Epoch [162/250], Loss: 0.3308, Accuracy: 88.09%, Test Loss: 0.3673, Test Accuracy: 86.99%
Epoch [163/250], Loss: 0.3311, Accuracy: 88.00%, Test Loss: 0.3718, Test Accuracy: 86.70%
Epoch [164/250], Loss: 0.3291, Accuracy: 88.14%, Test Loss: 0.3660, Test Accuracy: 87.10%
Epoch [165/250], Loss: 0.3289, Accuracy: 88.13%, Test Loss: 0.3653, Test Accuracy: 86.94%
Epoch [166/250], Loss: 0.3277, Accuracy: 88.21%, Test Loss: 0.3655, Test Accuracy: 87.12%
Epoch [167/250], Loss: 0.3283, Accuracy: 88.12%, Test Loss: 0.3670, Test Accuracy: 86.81%
Epoch [168/250], Loss: 0.3268, Accuracy: 88.26%, Test Loss: 0.3669, Test Accuracy: 86.92%
Epoch [169/250], Loss: 0.3276, Accuracy: 88.19%, Test Loss: 0.3628, Test Accuracy: 87.42%
Epoch [170/250], Loss: 0.3254, Accuracy: 88.22%, Test Loss: 0.3627, Test Accuracy: 87.09%
Epoch [171/250], Loss: 0.3252, Accuracy: 88.25%, Test Loss: 0.3655, Test Accuracy: 86.85%
Epoch [172/250], Loss: 0.3247, Accuracy: 88.24%, Test Loss: 0.3678, Test Accuracy: 86.87%
Epoch [173/250], Loss: 0.3242, Accuracy: 88.31%, Test Loss: 0.3620, Test Accuracy: 87.11%
Epoch [174/250], Loss: 0.3225, Accuracy: 88.41%, Test Loss: 0.3622, Test Accuracy: 87.19%
Epoch [175/250], Loss: 0.3223, Accuracy: 88.37%, Test Loss: 0.3605, Test Accuracy: 87.03%
Epoch [176/250], Loss: 0.3223, Accuracy: 88.38%, Test Loss: 0.3609, Test Accuracy: 87.33%
Epoch [177/250], Loss: 0.3226, Accuracy: 88.36%, Test Loss: 0.3584, Test Accuracy: 87.33%
Epoch [178/250], Loss: 0.3202, Accuracy: 88.45%, Test Loss: 0.3641, Test Accuracy: 87.01%
Epoch [179/250], Loss: 0.3208, Accuracy: 88.42%, Test Loss: 0.3604, Test Accuracy: 87.30%
Epoch [180/250], Loss: 0.3195, Accuracy: 88.48%, Test Loss: 0.3587, Test Accuracy: 87.22%
Epoch [181/250], Loss: 0.3180, Accuracy: 88.48%, Test Loss: 0.3568, Test Accuracy: 87.44%
Epoch [182/250], Loss: 0.3185, Accuracy: 88.47%, Test Loss: 0.3604, Test Accuracy: 87.16%
Epoch [183/250], Loss: 0.3180, Accuracy: 88.49%, Test Loss: 0.3554, Test Accuracy: 87.44%
Epoch [184/250], Loss: 0.3161, Accuracy: 88.56%, Test Loss: 0.3552, Test Accuracy: 87.59%
Epoch [185/250], Loss: 0.3167, Accuracy: 88.55%, Test Loss: 0.3565, Test Accuracy: 87.36%
Epoch [186/250], Loss: 0.3162, Accuracy: 88.60%, Test Loss: 0.3560, Test Accuracy: 87.23%
Epoch [187/250], Loss: 0.3147, Accuracy: 88.72%, Test Loss: 0.3535, Test Accuracy: 87.38%
Epoch [188/250], Loss: 0.3150, Accuracy: 88.59%, Test Loss: 0.3599, Test Accuracy: 87.26%
Epoch [189/250], Loss: 0.3137, Accuracy: 88.61%, Test Loss: 0.3596, Test Accuracy: 86.94%
Epoch [190/250], Loss: 0.3150, Accuracy: 88.55%, Test Loss: 0.3540, Test Accuracy: 87.46%
Epoch [191/250], Loss: 0.3141, Accuracy: 88.66%, Test Loss: 0.3509, Test Accuracy: 87.48%
Epoch [192/250], Loss: 0.3125, Accuracy: 88.67%, Test Loss: 0.3552, Test Accuracy: 87.43%
Epoch [193/250], Loss: 0.3110, Accuracy: 88.82%, Test Loss: 0.3519, Test Accuracy: 87.49%
Epoch [194/250], Loss: 0.3104, Accuracy: 88.86%, Test Loss: 0.3508, Test Accuracy: 87.48%
Epoch [195/250], Loss: 0.3101, Accuracy: 88.76%, Test Loss: 0.3497, Test Accuracy: 87.67%
Epoch [196/250], Loss: 0.3096, Accuracy: 88.79%, Test Loss: 0.3500, Test Accuracy: 87.68%
Epoch [197/250], Loss: 0.3105, Accuracy: 88.81%, Test Loss: 0.3523, Test Accuracy: 87.45%
Epoch [198/250], Loss: 0.3121, Accuracy: 88.68%, Test Loss: 0.3516, Test Accuracy: 87.50%
Epoch [199/250], Loss: 0.3074, Accuracy: 88.86%, Test Loss: 0.3488, Test Accuracy: 87.69%
Epoch [200/250], Loss: 0.3067, Accuracy: 89.01%, Test Loss: 0.3509, Test Accuracy: 87.58%
Epoch [201/250], Loss: 0.3069, Accuracy: 88.96%, Test Loss: 0.3518, Test Accuracy: 87.46%
Epoch [202/250], Loss: 0.3064, Accuracy: 88.94%, Test Loss: 0.3485, Test Accuracy: 87.73%
Epoch [203/250], Loss: 0.3064, Accuracy: 89.00%, Test Loss: 0.3458, Test Accuracy: 87.58%
Epoch [204/250], Loss: 0.3047, Accuracy: 88.98%, Test Loss: 0.3473, Test Accuracy: 87.67%
Epoch [205/250], Loss: 0.3048, Accuracy: 88.92%, Test Loss: 0.3491, Test Accuracy: 87.75%
Epoch [206/250], Loss: 0.3057, Accuracy: 88.89%, Test Loss: 0.3458, Test Accuracy: 87.76%
Epoch [207/250], Loss: 0.3028, Accuracy: 89.07%, Test Loss: 0.3451, Test Accuracy: 87.73%
Epoch [208/250], Loss: 0.3035, Accuracy: 89.01%, Test Loss: 0.3449, Test Accuracy: 87.61%
Epoch [209/250], Loss: 0.3042, Accuracy: 89.03%, Test Loss: 0.3466, Test Accuracy: 87.62%
Epoch [210/250], Loss: 0.3021, Accuracy: 89.05%, Test Loss: 0.3422, Test Accuracy: 87.93%
Epoch [211/250], Loss: 0.3007, Accuracy: 89.13%, Test Loss: 0.3433, Test Accuracy: 88.03%
Epoch [212/250], Loss: 0.3009, Accuracy: 89.11%, Test Loss: 0.3430, Test Accuracy: 87.75%
Epoch [213/250], Loss: 0.3003, Accuracy: 89.18%, Test Loss: 0.3419, Test Accuracy: 87.83%
Epoch [214/250], Loss: 0.2987, Accuracy: 89.18%, Test Loss: 0.3423, Test Accuracy: 88.00%
Epoch [215/250], Loss: 0.2992, Accuracy: 89.19%, Test Loss: 0.3449, Test Accuracy: 87.57%
Epoch [216/250], Loss: 0.2998, Accuracy: 89.15%, Test Loss: 0.3413, Test Accuracy: 88.06%
Epoch [217/250], Loss: 0.2987, Accuracy: 89.18%, Test Loss: 0.3455, Test Accuracy: 87.59%
Epoch [218/250], Loss: 0.2995, Accuracy: 89.14%, Test Loss: 0.3394, Test Accuracy: 87.94%
Epoch [219/250], Loss: 0.2966, Accuracy: 89.30%, Test Loss: 0.3395, Test Accuracy: 87.85%
Epoch [220/250], Loss: 0.2973, Accuracy: 89.23%, Test Loss: 0.3432, Test Accuracy: 87.69%
Epoch [221/250], Loss: 0.2960, Accuracy: 89.29%, Test Loss: 0.3380, Test Accuracy: 87.87%
Epoch [222/250], Loss: 0.2958, Accuracy: 89.30%, Test Loss: 0.3451, Test Accuracy: 87.61%
Epoch [223/250], Loss: 0.2953, Accuracy: 89.34%, Test Loss: 0.3383, Test Accuracy: 87.90%
Epoch [224/250], Loss: 0.2948, Accuracy: 89.27%, Test Loss: 0.3400, Test Accuracy: 87.90%
Epoch [225/250], Loss: 0.2954, Accuracy: 89.30%, Test Loss: 0.3407, Test Accuracy: 87.98%
Epoch [226/250], Loss: 0.2937, Accuracy: 89.38%, Test Loss: 0.3405, Test Accuracy: 87.79%
Epoch [227/250], Loss: 0.2925, Accuracy: 89.41%, Test Loss: 0.3356, Test Accuracy: 88.22%
Epoch [228/250], Loss: 0.2920, Accuracy: 89.54%, Test Loss: 0.3399, Test Accuracy: 87.98%
Epoch [229/250], Loss: 0.2917, Accuracy: 89.39%, Test Loss: 0.3387, Test Accuracy: 87.90%
Epoch [230/250], Loss: 0.2914, Accuracy: 89.45%, Test Loss: 0.3376, Test Accuracy: 87.80%
Epoch [231/250], Loss: 0.2903, Accuracy: 89.45%, Test Loss: 0.3364, Test Accuracy: 88.05%
Epoch [232/250], Loss: 0.2911, Accuracy: 89.47%, Test Loss: 0.3370, Test Accuracy: 88.03%
Epoch [233/250], Loss: 0.2907, Accuracy: 89.49%, Test Loss: 0.3382, Test Accuracy: 87.92%
Epoch [234/250], Loss: 0.2898, Accuracy: 89.55%, Test Loss: 0.3346, Test Accuracy: 88.04%
Epoch [235/250], Loss: 0.2887, Accuracy: 89.58%, Test Loss: 0.3332, Test Accuracy: 88.20%
Epoch [236/250], Loss: 0.2905, Accuracy: 89.54%, Test Loss: 0.3349, Test Accuracy: 88.02%
Epoch [237/250], Loss: 0.2880, Accuracy: 89.60%, Test Loss: 0.3338, Test Accuracy: 88.22%
Epoch [238/250], Loss: 0.2888, Accuracy: 89.51%, Test Loss: 0.3318, Test Accuracy: 88.23%
Epoch [239/250], Loss: 0.2894, Accuracy: 89.56%, Test Loss: 0.3333, Test Accuracy: 88.23%
Epoch [240/250], Loss: 0.2871, Accuracy: 89.67%, Test Loss: 0.3301, Test Accuracy: 88.20%
Epoch [241/250], Loss: 0.2875, Accuracy: 89.56%, Test Loss: 0.3316, Test Accuracy: 88.29%
Epoch [242/250], Loss: 0.2867, Accuracy: 89.64%, Test Loss: 0.3338, Test Accuracy: 87.92%
Epoch [243/250], Loss: 0.2866, Accuracy: 89.60%, Test Loss: 0.3325, Test Accuracy: 87.88%
Epoch [244/250], Loss: 0.2855, Accuracy: 89.63%, Test Loss: 0.3303, Test Accuracy: 88.30%
Epoch [245/250], Loss: 0.2853, Accuracy: 89.70%, Test Loss: 0.3286, Test Accuracy: 88.34%
Epoch [246/250], Loss: 0.2834, Accuracy: 89.78%, Test Loss: 0.3333, Test Accuracy: 88.07%
Epoch [247/250], Loss: 0.2836, Accuracy: 89.73%, Test Loss: 0.3299, Test Accuracy: 88.24%
Epoch [248/250], Loss: 0.2836, Accuracy: 89.71%, Test Loss: 0.3275, Test Accuracy: 88.51%
Epoch [249/250], Loss: 0.2823, Accuracy: 89.85%, Test Loss: 0.3296, Test Accuracy: 88.43%
Epoch [250/250], Loss: 0.2816, Accuracy: 89.87%, Test Loss: 0.3290, Test Accuracy: 88.18%
import matplotlib.pyplot as plt

# Loss curves: train vs. test across the 250 epochs
plt.plot(train_losses, label='train_losses')
plt.plot(test_losses, label='test_losses')
plt.legend()
<matplotlib.legend.Legend at 0x7f4f0c0ee4d0>
image
import matplotlib.pyplot as plt

# Accuracy curves: train vs. test across the 250 epochs
plt.plot(train_accuracies, label='train_accuracy')
plt.plot(test_accuracies, label='test_accuracy')
plt.legend()
<matplotlib.legend.Legend at 0x7f4f02d32ec0>
image
Cifar10_CNN_onlyConv2d
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.init as init

# Select the compute device: first CUDA GPU if available, otherwise CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Data

# Load CIFAR-10 dataset (3-channel images); normalize each channel to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform)
# drop_last avoids a smaller final batch; shuffle for stochastic training
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1024, num_workers=10, shuffle=True, drop_last=True)

testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1024, num_workers=10, shuffle=False)
Files already downloaded and verified
Files already downloaded and verified
import matplotlib.pyplot as plt
import numpy as np

# Function to display the images
def imshow(img):
    """Undo the 0.5/0.5 normalization and render `img` with matplotlib."""
    restored = img * 0.5 + 0.5                              # [-1, 1] -> [0, 1]
    plt.imshow(np.transpose(restored.numpy(), (1, 2, 0)))   # CHW -> HWC
    plt.show()

# Sanity check: visualize the first 8 images of one training batch
for i, (images, labels) in enumerate(trainloader, 0):
    # Plot some images
    imshow(torchvision.utils.make_grid(images[:8]))  # Display 8 images from the batch
    break
image

Model

import torch
import torch.nn as nn
from torchsummary import summary

class CustomModel(nn.Module):
    """Conv-only CNN for CIFAR-10: four 7x7 convolutions (no pooling) + MLP head."""

    def __init__(self):
        super().__init__()
        # Four valid (no-padding) 7x7 convolutions: 32x32 -> 26 -> 20 -> 14 -> 8
        self.conv1 = nn.Conv2d(3, 32, kernel_size=7)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=7)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=7)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=7)
        self.flatten = nn.Flatten()
        # 256 channels at 8x8 spatial resolution after the last conv
        self.dense1 = nn.Linear(8 * 8 * 256, 128)
        self.dense2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a (N, 3, 32, 32) batch to (N, 10) class logits."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.relu(conv(x))
        x = self.flatten(x)
        hidden = self.relu(self.dense1(x))
        return self.dense2(hidden)

model = CustomModel()
# Move the model's parameters to the selected device (GPU if available)
model = model.to(device)

# Print model summary for a 3-channel 32x32 input (CIFAR-10)
summary(model, (3, 32, 32))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 26, 26]           4,736
              ReLU-2           [-1, 32, 26, 26]               0
            Conv2d-3           [-1, 64, 20, 20]         100,416
              ReLU-4           [-1, 64, 20, 20]               0
            Conv2d-5          [-1, 128, 14, 14]         401,536
              ReLU-6          [-1, 128, 14, 14]               0
            Conv2d-7            [-1, 256, 8, 8]       1,605,888
              ReLU-8            [-1, 256, 8, 8]               0
           Flatten-9                [-1, 16384]               0
           Linear-10                  [-1, 128]       2,097,280
             ReLU-11                  [-1, 128]               0
           Linear-12                   [-1, 10]           1,290
================================================================
Total params: 4,211,146
Trainable params: 4,211,146
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 1.48
Params size (MB): 16.06
Estimated Total Size (MB): 17.56
----------------------------------------------------------------

Loss, Optimizer, and Evaluation Function

# Cross-entropy over the 10 class logits; Adam with a small fixed learning rate
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-5)
# Function to compute loss and accuracy for test set
def evaluate(model, testloader, criterion):
    """Return (mean_batch_loss, accuracy_percent) of `model` on `testloader`.

    Runs in eval mode with gradients disabled. The device is inferred from
    the model's own parameters rather than from a module-level global, so
    the helper works regardless of where the model lives.
    """
    model.eval()
    test_loss = 0.0
    running_correct = 0
    total = 0
    # Infer the device from the model — removes the hidden dependency on a global
    device = next(model.parameters()).device
    with torch.no_grad():
        for images, labels in testloader:
            # Move inputs and labels to the model's device
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            # argmax over the class dimension (replaces legacy `.data` access)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            running_correct += (predicted == labels).sum().item()

    # Accuracy over all samples; loss averaged over the number of batches
    accuracy = 100 * running_correct / total
    test_loss = test_loss / len(testloader)
    return test_loss, accuracy

Train

# History buffers for plotting, plus the epoch budget
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []
max_epoch = 250

# train
for epoch in range(max_epoch):
    model.train()
    loss_sum = 0.0
    correct = 0      # number of correct predictions this epoch
    seen = 0         # number of samples processed this epoch
    num_batches = 0

    for inputs, labels in trainloader:
        # Move the batch to the training device
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Track running loss and accuracy
        loss_sum += loss.item()
        num_batches += 1
        predicted = outputs.data.argmax(dim=1)
        seen += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_accuracy = 100 * correct / seen
    epoch_loss = loss_sum / num_batches

    test_loss, test_accuracy = evaluate(model, testloader, criterion)
    print(f"Epoch [{epoch + 1}/{max_epoch}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

    # save for plot
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_accuracy)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)
Epoch [1/250], Loss: 2.0785, Accuracy: 23.25%, Test Loss: 1.8775, Test Accuracy: 31.82%
Epoch [2/250], Loss: 1.8159, Accuracy: 33.27%, Test Loss: 1.7557, Test Accuracy: 36.82%
Epoch [3/250], Loss: 1.7211, Accuracy: 36.89%, Test Loss: 1.6870, Test Accuracy: 38.95%
Epoch [4/250], Loss: 1.6649, Accuracy: 39.51%, Test Loss: 1.6543, Test Accuracy: 40.93%
Epoch [5/250], Loss: 1.6198, Accuracy: 41.44%, Test Loss: 1.5995, Test Accuracy: 41.80%
Epoch [6/250], Loss: 1.5785, Accuracy: 42.93%, Test Loss: 1.5484, Test Accuracy: 43.82%
Epoch [7/250], Loss: 1.5427, Accuracy: 44.27%, Test Loss: 1.5210, Test Accuracy: 45.17%
Epoch [8/250], Loss: 1.5115, Accuracy: 45.33%, Test Loss: 1.4980, Test Accuracy: 45.52%
Epoch [9/250], Loss: 1.4857, Accuracy: 46.44%, Test Loss: 1.4709, Test Accuracy: 47.13%
Epoch [10/250], Loss: 1.4711, Accuracy: 46.94%, Test Loss: 1.4627, Test Accuracy: 47.20%
Epoch [11/250], Loss: 1.4385, Accuracy: 48.19%, Test Loss: 1.4226, Test Accuracy: 48.26%
Epoch [12/250], Loss: 1.4182, Accuracy: 48.80%, Test Loss: 1.4165, Test Accuracy: 48.84%
Epoch [13/250], Loss: 1.4023, Accuracy: 49.47%, Test Loss: 1.3832, Test Accuracy: 49.97%
Epoch [14/250], Loss: 1.3813, Accuracy: 50.35%, Test Loss: 1.3733, Test Accuracy: 50.78%
Epoch [15/250], Loss: 1.3611, Accuracy: 51.22%, Test Loss: 1.3552, Test Accuracy: 51.20%
Epoch [16/250], Loss: 1.3475, Accuracy: 51.48%, Test Loss: 1.3374, Test Accuracy: 51.54%
Epoch [17/250], Loss: 1.3285, Accuracy: 52.31%, Test Loss: 1.3263, Test Accuracy: 52.12%
Epoch [18/250], Loss: 1.3134, Accuracy: 52.93%, Test Loss: 1.3175, Test Accuracy: 52.54%
Epoch [19/250], Loss: 1.3011, Accuracy: 53.47%, Test Loss: 1.3149, Test Accuracy: 52.01%
Epoch [20/250], Loss: 1.2903, Accuracy: 53.86%, Test Loss: 1.2944, Test Accuracy: 53.64%
Epoch [21/250], Loss: 1.2730, Accuracy: 54.47%, Test Loss: 1.2759, Test Accuracy: 54.19%
Epoch [22/250], Loss: 1.2642, Accuracy: 54.89%, Test Loss: 1.2631, Test Accuracy: 55.02%
Epoch [23/250], Loss: 1.2416, Accuracy: 55.69%, Test Loss: 1.2506, Test Accuracy: 55.33%
Epoch [24/250], Loss: 1.2275, Accuracy: 56.26%, Test Loss: 1.2502, Test Accuracy: 55.24%
Epoch [25/250], Loss: 1.2183, Accuracy: 56.58%, Test Loss: 1.2244, Test Accuracy: 56.33%
Epoch [26/250], Loss: 1.2084, Accuracy: 56.99%, Test Loss: 1.2614, Test Accuracy: 55.05%
Epoch [27/250], Loss: 1.2042, Accuracy: 57.04%, Test Loss: 1.2376, Test Accuracy: 55.76%
Epoch [28/250], Loss: 1.1893, Accuracy: 57.63%, Test Loss: 1.2052, Test Accuracy: 57.20%
Epoch [29/250], Loss: 1.1744, Accuracy: 58.33%, Test Loss: 1.1974, Test Accuracy: 57.72%
Epoch [30/250], Loss: 1.1719, Accuracy: 58.46%, Test Loss: 1.2031, Test Accuracy: 57.58%
Epoch [31/250], Loss: 1.1573, Accuracy: 59.01%, Test Loss: 1.1822, Test Accuracy: 58.28%
Epoch [32/250], Loss: 1.1474, Accuracy: 59.34%, Test Loss: 1.2187, Test Accuracy: 56.99%
Epoch [33/250], Loss: 1.1410, Accuracy: 59.54%, Test Loss: 1.1687, Test Accuracy: 58.80%
Epoch [34/250], Loss: 1.1213, Accuracy: 60.50%, Test Loss: 1.1757, Test Accuracy: 58.21%
Epoch [35/250], Loss: 1.1227, Accuracy: 60.43%, Test Loss: 1.1552, Test Accuracy: 58.76%
Epoch [36/250], Loss: 1.1106, Accuracy: 60.69%, Test Loss: 1.1439, Test Accuracy: 59.72%
Epoch [37/250], Loss: 1.1001, Accuracy: 61.39%, Test Loss: 1.1308, Test Accuracy: 60.35%
Epoch [38/250], Loss: 1.0917, Accuracy: 61.48%, Test Loss: 1.1253, Test Accuracy: 60.20%
Epoch [39/250], Loss: 1.0843, Accuracy: 61.82%, Test Loss: 1.1211, Test Accuracy: 60.84%
Epoch [40/250], Loss: 1.0765, Accuracy: 62.30%, Test Loss: 1.1381, Test Accuracy: 59.58%
Epoch [41/250], Loss: 1.0651, Accuracy: 62.57%, Test Loss: 1.1079, Test Accuracy: 61.09%
Epoch [42/250], Loss: 1.0572, Accuracy: 62.74%, Test Loss: 1.1154, Test Accuracy: 60.73%
Epoch [43/250], Loss: 1.0559, Accuracy: 62.83%, Test Loss: 1.1020, Test Accuracy: 61.15%
Epoch [44/250], Loss: 1.0434, Accuracy: 63.29%, Test Loss: 1.1002, Test Accuracy: 61.51%
Epoch [45/250], Loss: 1.0357, Accuracy: 63.62%, Test Loss: 1.0903, Test Accuracy: 61.58%
Epoch [46/250], Loss: 1.0288, Accuracy: 63.95%, Test Loss: 1.0887, Test Accuracy: 61.65%
Epoch [47/250], Loss: 1.0248, Accuracy: 64.17%, Test Loss: 1.1032, Test Accuracy: 60.87%
Epoch [48/250], Loss: 1.0182, Accuracy: 64.39%, Test Loss: 1.0788, Test Accuracy: 62.13%
Epoch [49/250], Loss: 1.0099, Accuracy: 64.45%, Test Loss: 1.0818, Test Accuracy: 62.10%
Epoch [50/250], Loss: 1.0132, Accuracy: 64.52%, Test Loss: 1.0946, Test Accuracy: 61.34%
Epoch [51/250], Loss: 0.9929, Accuracy: 65.23%, Test Loss: 1.0666, Test Accuracy: 62.28%
Epoch [52/250], Loss: 0.9898, Accuracy: 65.14%, Test Loss: 1.0643, Test Accuracy: 62.66%
Epoch [53/250], Loss: 0.9812, Accuracy: 65.62%, Test Loss: 1.0482, Test Accuracy: 63.16%
Epoch [54/250], Loss: 0.9713, Accuracy: 66.00%, Test Loss: 1.0615, Test Accuracy: 62.74%
Epoch [55/250], Loss: 0.9691, Accuracy: 66.04%, Test Loss: 1.0489, Test Accuracy: 63.42%
Epoch [56/250], Loss: 0.9659, Accuracy: 66.23%, Test Loss: 1.0404, Test Accuracy: 63.70%
Epoch [57/250], Loss: 0.9535, Accuracy: 66.70%, Test Loss: 1.0437, Test Accuracy: 63.25%
Epoch [58/250], Loss: 0.9547, Accuracy: 66.71%, Test Loss: 1.0459, Test Accuracy: 63.63%
Epoch [59/250], Loss: 0.9427, Accuracy: 67.18%, Test Loss: 1.0295, Test Accuracy: 63.81%
Epoch [60/250], Loss: 0.9363, Accuracy: 67.23%, Test Loss: 1.0302, Test Accuracy: 63.47%
Epoch [61/250], Loss: 0.9328, Accuracy: 67.24%, Test Loss: 1.0350, Test Accuracy: 63.85%
Epoch [62/250], Loss: 0.9271, Accuracy: 67.56%, Test Loss: 1.0167, Test Accuracy: 63.95%
Epoch [63/250], Loss: 0.9241, Accuracy: 67.56%, Test Loss: 1.0154, Test Accuracy: 64.51%
Epoch [64/250], Loss: 0.9142, Accuracy: 68.03%, Test Loss: 1.0131, Test Accuracy: 64.02%
Epoch [65/250], Loss: 0.9085, Accuracy: 68.20%, Test Loss: 1.0061, Test Accuracy: 64.59%
Epoch [66/250], Loss: 0.9000, Accuracy: 68.57%, Test Loss: 1.0136, Test Accuracy: 64.23%
Epoch [67/250], Loss: 0.9054, Accuracy: 68.11%, Test Loss: 1.0088, Test Accuracy: 64.73%
Epoch [68/250], Loss: 0.8894, Accuracy: 69.11%, Test Loss: 1.0098, Test Accuracy: 64.50%
Epoch [69/250], Loss: 0.8858, Accuracy: 68.92%, Test Loss: 1.0093, Test Accuracy: 64.81%
Epoch [70/250], Loss: 0.8818, Accuracy: 69.45%, Test Loss: 0.9902, Test Accuracy: 65.01%
Epoch [71/250], Loss: 0.8716, Accuracy: 69.48%, Test Loss: 0.9998, Test Accuracy: 65.08%
Epoch [72/250], Loss: 0.8677, Accuracy: 69.73%, Test Loss: 0.9893, Test Accuracy: 65.10%
Epoch [73/250], Loss: 0.8641, Accuracy: 69.86%, Test Loss: 0.9920, Test Accuracy: 65.11%
Epoch [74/250], Loss: 0.8590, Accuracy: 70.08%, Test Loss: 0.9870, Test Accuracy: 65.18%
Epoch [75/250], Loss: 0.8522, Accuracy: 70.23%, Test Loss: 0.9853, Test Accuracy: 65.33%
Epoch [76/250], Loss: 0.8441, Accuracy: 70.48%, Test Loss: 0.9937, Test Accuracy: 65.30%
Epoch [77/250], Loss: 0.8431, Accuracy: 70.63%, Test Loss: 0.9974, Test Accuracy: 64.94%
Epoch [78/250], Loss: 0.8385, Accuracy: 70.84%, Test Loss: 0.9731, Test Accuracy: 65.93%
Epoch [79/250], Loss: 0.8281, Accuracy: 71.18%, Test Loss: 0.9834, Test Accuracy: 65.34%
Epoch [80/250], Loss: 0.8228, Accuracy: 71.32%, Test Loss: 0.9667, Test Accuracy: 66.06%
Epoch [81/250], Loss: 0.8181, Accuracy: 71.63%, Test Loss: 0.9605, Test Accuracy: 66.09%
Epoch [82/250], Loss: 0.8149, Accuracy: 71.63%, Test Loss: 0.9650, Test Accuracy: 66.23%
Epoch [83/250], Loss: 0.8048, Accuracy: 72.16%, Test Loss: 0.9839, Test Accuracy: 65.41%
Epoch [84/250], Loss: 0.8068, Accuracy: 71.83%, Test Loss: 0.9674, Test Accuracy: 65.90%
Epoch [85/250], Loss: 0.7973, Accuracy: 72.16%, Test Loss: 0.9706, Test Accuracy: 66.32%
Epoch [86/250], Loss: 0.7948, Accuracy: 72.31%, Test Loss: 0.9693, Test Accuracy: 65.78%
Epoch [87/250], Loss: 0.7897, Accuracy: 72.52%, Test Loss: 0.9702, Test Accuracy: 65.73%
Epoch [88/250], Loss: 0.7884, Accuracy: 72.58%, Test Loss: 0.9476, Test Accuracy: 66.66%
Epoch [89/250], Loss: 0.7848, Accuracy: 72.65%, Test Loss: 0.9815, Test Accuracy: 65.84%
Epoch [90/250], Loss: 0.7749, Accuracy: 73.01%, Test Loss: 0.9616, Test Accuracy: 66.64%
Epoch [91/250], Loss: 0.7703, Accuracy: 73.18%, Test Loss: 0.9654, Test Accuracy: 66.52%
Epoch [92/250], Loss: 0.7632, Accuracy: 73.60%, Test Loss: 0.9629, Test Accuracy: 66.28%
Epoch [93/250], Loss: 0.7567, Accuracy: 73.73%, Test Loss: 0.9424, Test Accuracy: 67.38%
Epoch [94/250], Loss: 0.7547, Accuracy: 73.84%, Test Loss: 0.9432, Test Accuracy: 66.91%
Epoch [95/250], Loss: 0.7455, Accuracy: 74.19%, Test Loss: 0.9451, Test Accuracy: 66.65%
Epoch [96/250], Loss: 0.7407, Accuracy: 74.17%, Test Loss: 0.9588, Test Accuracy: 66.86%
Epoch [97/250], Loss: 0.7431, Accuracy: 74.24%, Test Loss: 0.9399, Test Accuracy: 67.09%
Epoch [98/250], Loss: 0.7310, Accuracy: 74.74%, Test Loss: 0.9448, Test Accuracy: 67.24%
Epoch [99/250], Loss: 0.7302, Accuracy: 74.68%, Test Loss: 0.9504, Test Accuracy: 66.68%
Epoch [100/250], Loss: 0.7289, Accuracy: 74.81%, Test Loss: 0.9493, Test Accuracy: 67.01%
Epoch [101/250], Loss: 0.7208, Accuracy: 75.11%, Test Loss: 0.9369, Test Accuracy: 67.43%
Epoch [102/250], Loss: 0.7145, Accuracy: 75.18%, Test Loss: 0.9408, Test Accuracy: 67.14%
Epoch [103/250], Loss: 0.7078, Accuracy: 75.62%, Test Loss: 0.9423, Test Accuracy: 67.43%
Epoch [104/250], Loss: 0.7005, Accuracy: 75.78%, Test Loss: 0.9388, Test Accuracy: 67.66%
Epoch [105/250], Loss: 0.7071, Accuracy: 75.32%, Test Loss: 0.9500, Test Accuracy: 66.88%
Epoch [106/250], Loss: 0.7011, Accuracy: 75.66%, Test Loss: 0.9433, Test Accuracy: 67.65%
Epoch [107/250], Loss: 0.6895, Accuracy: 76.04%, Test Loss: 0.9356, Test Accuracy: 67.36%
Epoch [108/250], Loss: 0.6857, Accuracy: 76.25%, Test Loss: 0.9348, Test Accuracy: 67.59%
Epoch [109/250], Loss: 0.6783, Accuracy: 76.61%, Test Loss: 0.9435, Test Accuracy: 67.22%
Epoch [110/250], Loss: 0.6796, Accuracy: 76.45%, Test Loss: 0.9210, Test Accuracy: 68.45%
Epoch [111/250], Loss: 0.6740, Accuracy: 76.72%, Test Loss: 0.9263, Test Accuracy: 67.90%
Epoch [112/250], Loss: 0.6616, Accuracy: 77.17%, Test Loss: 0.9355, Test Accuracy: 68.02%
Epoch [113/250], Loss: 0.6615, Accuracy: 77.12%, Test Loss: 0.9435, Test Accuracy: 67.87%
Epoch [114/250], Loss: 0.6682, Accuracy: 76.90%, Test Loss: 0.9243, Test Accuracy: 68.52%
Epoch [115/250], Loss: 0.6490, Accuracy: 77.44%, Test Loss: 0.9355, Test Accuracy: 67.97%
Epoch [116/250], Loss: 0.6554, Accuracy: 77.22%, Test Loss: 0.9288, Test Accuracy: 68.29%
Epoch [117/250], Loss: 0.6432, Accuracy: 77.72%, Test Loss: 0.9332, Test Accuracy: 68.44%
Epoch [118/250], Loss: 0.6425, Accuracy: 77.89%, Test Loss: 0.9361, Test Accuracy: 67.90%
Epoch [119/250], Loss: 0.6368, Accuracy: 78.04%, Test Loss: 0.9311, Test Accuracy: 67.80%
Epoch [120/250], Loss: 0.6363, Accuracy: 77.91%, Test Loss: 0.9212, Test Accuracy: 68.65%
Epoch [121/250], Loss: 0.6261, Accuracy: 78.27%, Test Loss: 0.9467, Test Accuracy: 67.72%
Epoch [122/250], Loss: 0.6240, Accuracy: 78.44%, Test Loss: 0.9459, Test Accuracy: 68.24%
Epoch [123/250], Loss: 0.6306, Accuracy: 78.28%, Test Loss: 0.9360, Test Accuracy: 68.33%
Epoch [124/250], Loss: 0.6171, Accuracy: 78.74%, Test Loss: 0.9229, Test Accuracy: 68.53%
Epoch [125/250], Loss: 0.6089, Accuracy: 79.00%, Test Loss: 0.9232, Test Accuracy: 69.06%
Epoch [126/250], Loss: 0.6032, Accuracy: 79.16%, Test Loss: 0.9295, Test Accuracy: 68.75%
Epoch [127/250], Loss: 0.5971, Accuracy: 79.59%, Test Loss: 0.9276, Test Accuracy: 68.98%
Epoch [128/250], Loss: 0.5968, Accuracy: 79.40%, Test Loss: 0.9279, Test Accuracy: 68.47%
Epoch [129/250], Loss: 0.5912, Accuracy: 79.76%, Test Loss: 0.9387, Test Accuracy: 68.78%
Epoch [130/250], Loss: 0.5850, Accuracy: 79.90%, Test Loss: 0.9486, Test Accuracy: 68.10%
Epoch [131/250], Loss: 0.5848, Accuracy: 79.92%, Test Loss: 0.9447, Test Accuracy: 68.45%
Epoch [132/250], Loss: 0.5790, Accuracy: 80.07%, Test Loss: 0.9329, Test Accuracy: 68.83%
Epoch [133/250], Loss: 0.5710, Accuracy: 80.36%, Test Loss: 0.9321, Test Accuracy: 68.76%
Epoch [134/250], Loss: 0.5673, Accuracy: 80.40%, Test Loss: 0.9342, Test Accuracy: 68.79%
Epoch [135/250], Loss: 0.5614, Accuracy: 80.78%, Test Loss: 0.9283, Test Accuracy: 68.83%
Epoch [136/250], Loss: 0.5571, Accuracy: 80.92%, Test Loss: 0.9406, Test Accuracy: 68.46%
Epoch [137/250], Loss: 0.5522, Accuracy: 80.95%, Test Loss: 0.9506, Test Accuracy: 68.55%
Epoch [138/250], Loss: 0.5561, Accuracy: 80.94%, Test Loss: 0.9296, Test Accuracy: 69.22%
Epoch [139/250], Loss: 0.5484, Accuracy: 81.12%, Test Loss: 0.9439, Test Accuracy: 68.45%
Epoch [140/250], Loss: 0.5393, Accuracy: 81.57%, Test Loss: 0.9410, Test Accuracy: 68.63%
Epoch [141/250], Loss: 0.5501, Accuracy: 80.85%, Test Loss: 0.9389, Test Accuracy: 69.03%
Epoch [142/250], Loss: 0.5445, Accuracy: 81.19%, Test Loss: 0.9556, Test Accuracy: 68.39%
Epoch [143/250], Loss: 0.5314, Accuracy: 81.75%, Test Loss: 0.9461, Test Accuracy: 68.77%
Epoch [144/250], Loss: 0.5295, Accuracy: 81.92%, Test Loss: 0.9573, Test Accuracy: 68.37%
Epoch [145/250], Loss: 0.5288, Accuracy: 81.79%, Test Loss: 0.9557, Test Accuracy: 68.57%
Epoch [146/250], Loss: 0.5208, Accuracy: 82.21%, Test Loss: 0.9535, Test Accuracy: 68.89%
Epoch [147/250], Loss: 0.5159, Accuracy: 82.41%, Test Loss: 0.9499, Test Accuracy: 68.96%
Epoch [148/250], Loss: 0.5127, Accuracy: 82.31%, Test Loss: 0.9971, Test Accuracy: 67.81%
Epoch [149/250], Loss: 0.5042, Accuracy: 82.87%, Test Loss: 0.9516, Test Accuracy: 69.00%
Epoch [150/250], Loss: 0.4978, Accuracy: 83.08%, Test Loss: 0.9635, Test Accuracy: 69.16%
Epoch [151/250], Loss: 0.5083, Accuracy: 82.53%, Test Loss: 0.9798, Test Accuracy: 68.23%
Epoch [152/250], Loss: 0.5076, Accuracy: 82.52%, Test Loss: 0.9460, Test Accuracy: 69.12%
Epoch [153/250], Loss: 0.4878, Accuracy: 83.50%, Test Loss: 0.9503, Test Accuracy: 69.01%
Epoch [154/250], Loss: 0.4849, Accuracy: 83.59%, Test Loss: 0.9684, Test Accuracy: 68.53%
Epoch [155/250], Loss: 0.4803, Accuracy: 83.69%, Test Loss: 0.9796, Test Accuracy: 68.70%
Epoch [156/250], Loss: 0.4799, Accuracy: 83.74%, Test Loss: 0.9596, Test Accuracy: 69.05%
Epoch [157/250], Loss: 0.4688, Accuracy: 84.19%, Test Loss: 0.9651, Test Accuracy: 69.21%
Epoch [158/250], Loss: 0.4704, Accuracy: 83.96%, Test Loss: 0.9671, Test Accuracy: 69.02%
Epoch [159/250], Loss: 0.4636, Accuracy: 84.30%, Test Loss: 0.9691, Test Accuracy: 69.24%
Epoch [160/250], Loss: 0.4624, Accuracy: 84.19%, Test Loss: 0.9743, Test Accuracy: 68.99%
Epoch [161/250], Loss: 0.4582, Accuracy: 84.45%, Test Loss: 0.9817, Test Accuracy: 68.65%
Epoch [162/250], Loss: 0.4583, Accuracy: 84.57%, Test Loss: 0.9534, Test Accuracy: 69.59%
Epoch [163/250], Loss: 0.4535, Accuracy: 84.54%, Test Loss: 0.9778, Test Accuracy: 69.00%
Epoch [164/250], Loss: 0.4411, Accuracy: 85.18%, Test Loss: 0.9767, Test Accuracy: 69.15%
Epoch [165/250], Loss: 0.4435, Accuracy: 85.08%, Test Loss: 0.9776, Test Accuracy: 69.19%
Epoch [166/250], Loss: 0.4379, Accuracy: 85.25%, Test Loss: 0.9956, Test Accuracy: 69.15%
Epoch [167/250], Loss: 0.4376, Accuracy: 85.36%, Test Loss: 0.9846, Test Accuracy: 69.13%
Epoch [168/250], Loss: 0.4246, Accuracy: 85.74%, Test Loss: 0.9888, Test Accuracy: 68.72%
Epoch [169/250], Loss: 0.4258, Accuracy: 85.63%, Test Loss: 0.9860, Test Accuracy: 69.17%
Epoch [170/250], Loss: 0.4179, Accuracy: 86.01%, Test Loss: 0.9948, Test Accuracy: 68.99%
Epoch [171/250], Loss: 0.4202, Accuracy: 85.85%, Test Loss: 0.9877, Test Accuracy: 69.34%
Epoch [172/250], Loss: 0.4117, Accuracy: 86.22%, Test Loss: 0.9963, Test Accuracy: 69.15%
Epoch [173/250], Loss: 0.4041, Accuracy: 86.53%, Test Loss: 1.0219, Test Accuracy: 68.69%
Epoch [174/250], Loss: 0.4096, Accuracy: 86.13%, Test Loss: 1.0019, Test Accuracy: 69.67%
Epoch [175/250], Loss: 0.4001, Accuracy: 86.78%, Test Loss: 0.9928, Test Accuracy: 69.82%
Epoch [176/250], Loss: 0.4013, Accuracy: 86.54%, Test Loss: 0.9941, Test Accuracy: 69.23%
Epoch [177/250], Loss: 0.3909, Accuracy: 87.05%, Test Loss: 1.0037, Test Accuracy: 69.41%
Epoch [178/250], Loss: 0.3889, Accuracy: 86.99%, Test Loss: 1.0351, Test Accuracy: 68.95%
Epoch [179/250], Loss: 0.3902, Accuracy: 86.98%, Test Loss: 1.0144, Test Accuracy: 69.43%
Epoch [180/250], Loss: 0.3790, Accuracy: 87.70%, Test Loss: 1.0262, Test Accuracy: 68.87%
Epoch [181/250], Loss: 0.3770, Accuracy: 87.56%, Test Loss: 1.0226, Test Accuracy: 68.87%
Epoch [182/250], Loss: 0.3673, Accuracy: 87.88%, Test Loss: 1.0164, Test Accuracy: 69.64%
Epoch [183/250], Loss: 0.3709, Accuracy: 87.78%, Test Loss: 1.0281, Test Accuracy: 69.05%
Epoch [184/250], Loss: 0.3731, Accuracy: 87.48%, Test Loss: 1.0554, Test Accuracy: 68.26%
Epoch [185/250], Loss: 0.3625, Accuracy: 88.08%, Test Loss: 1.0638, Test Accuracy: 68.36%
Epoch [186/250], Loss: 0.3625, Accuracy: 87.95%, Test Loss: 1.0313, Test Accuracy: 69.18%
Epoch [187/250], Loss: 0.3452, Accuracy: 88.87%, Test Loss: 1.0643, Test Accuracy: 68.61%
Epoch [188/250], Loss: 0.3483, Accuracy: 88.61%, Test Loss: 1.0387, Test Accuracy: 69.41%
Epoch [189/250], Loss: 0.3444, Accuracy: 88.74%, Test Loss: 1.0530, Test Accuracy: 69.24%
Epoch [190/250], Loss: 0.3503, Accuracy: 88.37%, Test Loss: 1.0466, Test Accuracy: 69.26%
Epoch [191/250], Loss: 0.3416, Accuracy: 88.75%, Test Loss: 1.0587, Test Accuracy: 69.20%
Epoch [192/250], Loss: 0.3363, Accuracy: 89.09%, Test Loss: 1.0575, Test Accuracy: 69.17%
Epoch [193/250], Loss: 0.3411, Accuracy: 88.82%, Test Loss: 1.0506, Test Accuracy: 69.24%
Epoch [194/250], Loss: 0.3278, Accuracy: 89.48%, Test Loss: 1.0684, Test Accuracy: 69.30%
Epoch [195/250], Loss: 0.3296, Accuracy: 89.13%, Test Loss: 1.0729, Test Accuracy: 69.24%
Epoch [196/250], Loss: 0.3202, Accuracy: 89.71%, Test Loss: 1.0750, Test Accuracy: 69.37%
Epoch [197/250], Loss: 0.3213, Accuracy: 89.58%, Test Loss: 1.0729, Test Accuracy: 69.48%
Epoch [198/250], Loss: 0.3148, Accuracy: 89.85%, Test Loss: 1.0970, Test Accuracy: 68.85%
Epoch [199/250], Loss: 0.3074, Accuracy: 90.07%, Test Loss: 1.0827, Test Accuracy: 69.32%
Epoch [200/250], Loss: 0.3143, Accuracy: 89.74%, Test Loss: 1.0812, Test Accuracy: 69.40%
Epoch [201/250], Loss: 0.2998, Accuracy: 90.53%, Test Loss: 1.0874, Test Accuracy: 69.26%
Epoch [202/250], Loss: 0.3083, Accuracy: 90.01%, Test Loss: 1.1506, Test Accuracy: 68.33%
Epoch [203/250], Loss: 0.3016, Accuracy: 90.27%, Test Loss: 1.0991, Test Accuracy: 69.15%
Epoch [204/250], Loss: 0.2979, Accuracy: 90.43%, Test Loss: 1.1200, Test Accuracy: 68.96%
Epoch [205/250], Loss: 0.2948, Accuracy: 90.56%, Test Loss: 1.1108, Test Accuracy: 69.15%
Epoch [206/250], Loss: 0.2856, Accuracy: 91.19%, Test Loss: 1.1084, Test Accuracy: 69.33%
Epoch [207/250], Loss: 0.2866, Accuracy: 90.97%, Test Loss: 1.1173, Test Accuracy: 69.38%
Epoch [208/250], Loss: 0.2733, Accuracy: 91.46%, Test Loss: 1.1146, Test Accuracy: 69.33%
Epoch [209/250], Loss: 0.2698, Accuracy: 91.79%, Test Loss: 1.1394, Test Accuracy: 69.13%
Epoch [210/250], Loss: 0.2770, Accuracy: 91.33%, Test Loss: 1.1503, Test Accuracy: 68.47%
Epoch [211/250], Loss: 0.2702, Accuracy: 91.59%, Test Loss: 1.1482, Test Accuracy: 68.69%
Epoch [212/250], Loss: 0.2605, Accuracy: 91.91%, Test Loss: 1.1554, Test Accuracy: 69.15%
Epoch [213/250], Loss: 0.2582, Accuracy: 92.00%, Test Loss: 1.1654, Test Accuracy: 68.84%
Epoch [214/250], Loss: 0.2603, Accuracy: 91.95%, Test Loss: 1.1548, Test Accuracy: 68.85%
Epoch [215/250], Loss: 0.2541, Accuracy: 92.27%, Test Loss: 1.1620, Test Accuracy: 68.94%
Epoch [216/250], Loss: 0.2559, Accuracy: 92.09%, Test Loss: 1.1900, Test Accuracy: 68.22%
Epoch [217/250], Loss: 0.2485, Accuracy: 92.27%, Test Loss: 1.1680, Test Accuracy: 68.66%
Epoch [218/250], Loss: 0.2380, Accuracy: 92.91%, Test Loss: 1.1909, Test Accuracy: 68.93%
Epoch [219/250], Loss: 0.2381, Accuracy: 92.76%, Test Loss: 1.2208, Test Accuracy: 68.63%
Epoch [220/250], Loss: 0.2400, Accuracy: 92.63%, Test Loss: 1.1867, Test Accuracy: 68.96%
Epoch [221/250], Loss: 0.2334, Accuracy: 92.92%, Test Loss: 1.1868, Test Accuracy: 68.90%
Epoch [222/250], Loss: 0.2298, Accuracy: 93.18%, Test Loss: 1.2074, Test Accuracy: 68.81%
Epoch [223/250], Loss: 0.2291, Accuracy: 93.17%, Test Loss: 1.2113, Test Accuracy: 69.03%
Epoch [224/250], Loss: 0.2268, Accuracy: 93.20%, Test Loss: 1.2380, Test Accuracy: 68.45%
Epoch [225/250], Loss: 0.2289, Accuracy: 92.94%, Test Loss: 1.2351, Test Accuracy: 68.77%
Epoch [226/250], Loss: 0.2188, Accuracy: 93.41%, Test Loss: 1.2300, Test Accuracy: 68.67%
Epoch [227/250], Loss: 0.2143, Accuracy: 93.64%, Test Loss: 1.2508, Test Accuracy: 68.19%
Epoch [228/250], Loss: 0.2211, Accuracy: 93.27%, Test Loss: 1.2439, Test Accuracy: 68.68%
Epoch [229/250], Loss: 0.2156, Accuracy: 93.54%, Test Loss: 1.2435, Test Accuracy: 68.89%
Epoch [230/250], Loss: 0.2038, Accuracy: 94.13%, Test Loss: 1.2374, Test Accuracy: 68.50%
Epoch [231/250], Loss: 0.2018, Accuracy: 94.16%, Test Loss: 1.2549, Test Accuracy: 68.58%
Epoch [232/250], Loss: 0.2021, Accuracy: 94.06%, Test Loss: 1.2807, Test Accuracy: 68.40%
Epoch [233/250], Loss: 0.1925, Accuracy: 94.52%, Test Loss: 1.2825, Test Accuracy: 68.53%
Epoch [234/250], Loss: 0.1937, Accuracy: 94.39%, Test Loss: 1.2776, Test Accuracy: 68.70%
Epoch [235/250], Loss: 0.1942, Accuracy: 94.38%, Test Loss: 1.3057, Test Accuracy: 68.94%
Epoch [236/250], Loss: 0.2014, Accuracy: 93.84%, Test Loss: 1.2807, Test Accuracy: 68.54%
Epoch [237/250], Loss: 0.1841, Accuracy: 94.83%, Test Loss: 1.2907, Test Accuracy: 68.66%
Epoch [238/250], Loss: 0.1779, Accuracy: 95.18%, Test Loss: 1.3097, Test Accuracy: 68.60%
Epoch [239/250], Loss: 0.1773, Accuracy: 95.09%, Test Loss: 1.3055, Test Accuracy: 68.45%
Epoch [240/250], Loss: 0.1776, Accuracy: 94.96%, Test Loss: 1.3228, Test Accuracy: 68.34%
Epoch [241/250], Loss: 0.1709, Accuracy: 95.41%, Test Loss: 1.3253, Test Accuracy: 69.02%
Epoch [242/250], Loss: 0.1710, Accuracy: 95.22%, Test Loss: 1.3299, Test Accuracy: 68.76%
Epoch [243/250], Loss: 0.1656, Accuracy: 95.54%, Test Loss: 1.3633, Test Accuracy: 67.68%
Epoch [244/250], Loss: 0.1733, Accuracy: 95.01%, Test Loss: 1.3528, Test Accuracy: 68.37%
Epoch [245/250], Loss: 0.1734, Accuracy: 95.00%, Test Loss: 1.3622, Test Accuracy: 68.29%
Epoch [246/250], Loss: 0.1655, Accuracy: 95.34%, Test Loss: 1.3638, Test Accuracy: 67.99%
Epoch [247/250], Loss: 0.1506, Accuracy: 96.12%, Test Loss: 1.3589, Test Accuracy: 68.75%
Epoch [248/250], Loss: 0.1495, Accuracy: 96.12%, Test Loss: 1.3729, Test Accuracy: 69.16%
Epoch [249/250], Loss: 0.1538, Accuracy: 95.84%, Test Loss: 1.3766, Test Accuracy: 68.59%
Epoch [250/250], Loss: 0.1495, Accuracy: 96.08%, Test Loss: 1.3979, Test Accuracy: 68.44%
import matplotlib.pyplot as plt

# Plot the train/test loss curves collected during training
plt.plot(train_losses, label='train_losses')
plt.plot(test_losses, label='test_losses')
plt.legend()
<matplotlib.legend.Legend at 0x7efc150ef520>
image
import matplotlib.pyplot as plt

# Plot the train/test accuracy curves collected during training
plt.plot(train_accuracies, label='train_accuracy')
plt.plot(test_accuracies, label='test_accuracy')
plt.legend()
<matplotlib.legend.Legend at 0x7efc10603880>
image

Downsampling

Trong đoạn code dưới đây sẽ downsampling bằng 3 cách:

  • MaxPooling
  • AveragePooling
  • Interpolate
FashionMNIST_mp

Data

# Load the FashionMNIST dataset, normalized to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

trainset = torchvision.datasets.FashionMNIST(root='data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1024, num_workers=10, shuffle=True, drop_last=True)

testset = torchvision.datasets.FashionMNIST(root='data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1024, num_workers=10, shuffle=False)
import matplotlib.pyplot as plt
import numpy as np

# Function to display the images
def imshow(img):
    """Un-normalize a [-1, 1] image tensor and display it with matplotlib."""
    restored = img * 0.5 + 0.5
    plt.imshow(np.transpose(restored.numpy(), (1, 2, 0)))
    plt.show()

# Show a sample of the first training batch, then stop after one iteration
for i, (images, labels) in enumerate(trainloader, 0):
    # Plot some images
    imshow(torchvision.utils.make_grid(images[:8]))  # Display 8 images from the batch
    break
image

Model

import torch
import torch.nn as nn
from torchsummary import summary

class CustomModel(nn.Module):
    """Single-conv CNN with max-pool downsampling for 1x28x28 inputs, 10 classes."""

    def __init__(self):
        super(CustomModel, self).__init__()
        # 1x28x28 -> 32x24x24 (5x5 conv, no padding)
        self.conv = nn.Conv2d(1, 32, kernel_size=5)
        # 32x24x24 -> 32x12x12 (2x2 max pool, stride 2)
        self.pool = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.dense1 = nn.Linear(12*12*32, 128)
        self.dense2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw class logits of shape (batch, 10)."""
        features = self.pool(self.relu(self.conv(x)))
        hidden = self.relu(self.dense1(self.flatten(features)))
        return self.dense2(hidden)

# Instantiate the pooled model and move it to the selected device
model = CustomModel()
model = model.to(device)

# Print a per-layer summary for a 1x28x28 (FashionMNIST) input
summary(model, (1, 28, 28))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 24, 24]             832
              ReLU-2           [-1, 32, 24, 24]               0
         MaxPool2d-3           [-1, 32, 12, 12]               0
           Flatten-4                 [-1, 4608]               0
            Linear-5                  [-1, 128]         589,952
              ReLU-6                  [-1, 128]               0
            Linear-7                   [-1, 10]           1,290
================================================================
Total params: 592,074
Trainable params: 592,074
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.35
Params size (MB): 2.26
Estimated Total Size (MB): 2.62
----------------------------------------------------------------

Loss, Optimizer, and Evaluation Function

# Cross-entropy over the 10 class logits; Adam with a small learning rate
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Function to compute loss and accuracy for test set
def evaluate(model, testloader, criterion):
    """Return (mean per-batch loss, accuracy in %) of `model` over `testloader`.

    Switches the model to eval mode and disables gradient tracking; each
    batch is moved to the module-level `device` before the forward pass.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in testloader:
            # Batches must live on the same device as the model parameters.
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            loss_sum += criterion(logits, batch_labels).item()

            predictions = logits.data.max(1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (predictions == batch_labels).sum().item()

    # Loss is averaged per batch, accuracy per sample.
    return loss_sum / len(testloader), 100 * n_correct / n_seen

Train

# some parameter
train_losses = []      # per-epoch mean training loss, saved for plotting
train_accuracies = []  # per-epoch training accuracy (%)
test_losses = []       # per-epoch test loss from evaluate()
test_accuracies = []   # per-epoch test accuracy (%)
max_epoch = 150
# train
for epoch in range(max_epoch):
    model.train()  # training-mode behavior for layers that need it
    running_loss = 0.0
    running_correct = 0   # to track number of correct predictions
    total = 0             # to track total number of samples

    for i, (inputs, labels) in enumerate(trainloader, 0):
        # Move inputs and labels to the device
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        running_loss += loss.item()

        # Determine class predictions and track accuracy
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        running_correct += (predicted == labels).sum().item()

        # Backward pass and optimization
        loss.backward()
        optimizer.step()        

    # Epoch means: (i + 1) is the number of batches processed this epoch.
    epoch_accuracy = 100 * running_correct / total
    epoch_loss = running_loss / (i + 1)
    
    test_loss, test_accuracy = evaluate(model, testloader, criterion)
    print(f"Epoch [{epoch + 1}/{max_epoch}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
    
    # save for plot
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_accuracy)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)
Epoch [1/150], Loss: 1.2998, Accuracy: 65.43%, Test Loss: 0.7856, Test Accuracy: 74.42%
Epoch [2/150], Loss: 0.6574, Accuracy: 78.01%, Test Loss: 0.5980, Test Accuracy: 79.14%
Epoch [3/150], Loss: 0.5381, Accuracy: 81.68%, Test Loss: 0.5224, Test Accuracy: 81.82%
Epoch [4/150], Loss: 0.4800, Accuracy: 83.72%, Test Loss: 0.4807, Test Accuracy: 83.11%
Epoch [5/150], Loss: 0.4446, Accuracy: 84.78%, Test Loss: 0.4520, Test Accuracy: 84.05%
Epoch [6/150], Loss: 0.4192, Accuracy: 85.57%, Test Loss: 0.4324, Test Accuracy: 84.72%
Epoch [7/150], Loss: 0.3995, Accuracy: 86.24%, Test Loss: 0.4155, Test Accuracy: 85.53%
Epoch [8/150], Loss: 0.3846, Accuracy: 86.78%, Test Loss: 0.4023, Test Accuracy: 85.86%
Epoch [9/150], Loss: 0.3719, Accuracy: 87.23%, Test Loss: 0.3948, Test Accuracy: 86.21%
Epoch [10/150], Loss: 0.3623, Accuracy: 87.57%, Test Loss: 0.3815, Test Accuracy: 86.84%
Epoch [11/150], Loss: 0.3519, Accuracy: 87.84%, Test Loss: 0.3740, Test Accuracy: 86.73%
Epoch [12/150], Loss: 0.3427, Accuracy: 88.22%, Test Loss: 0.3665, Test Accuracy: 87.28%
Epoch [13/150], Loss: 0.3356, Accuracy: 88.43%, Test Loss: 0.3597, Test Accuracy: 87.29%
Epoch [14/150], Loss: 0.3288, Accuracy: 88.68%, Test Loss: 0.3537, Test Accuracy: 87.33%
Epoch [15/150], Loss: 0.3225, Accuracy: 88.82%, Test Loss: 0.3510, Test Accuracy: 87.81%
Epoch [16/150], Loss: 0.3177, Accuracy: 89.03%, Test Loss: 0.3448, Test Accuracy: 87.78%
Epoch [17/150], Loss: 0.3112, Accuracy: 89.33%, Test Loss: 0.3434, Test Accuracy: 87.69%
Epoch [18/150], Loss: 0.3077, Accuracy: 89.36%, Test Loss: 0.3383, Test Accuracy: 88.13%
Epoch [19/150], Loss: 0.3027, Accuracy: 89.60%, Test Loss: 0.3327, Test Accuracy: 88.15%
Epoch [20/150], Loss: 0.2970, Accuracy: 89.78%, Test Loss: 0.3304, Test Accuracy: 88.34%
Epoch [21/150], Loss: 0.2931, Accuracy: 89.96%, Test Loss: 0.3263, Test Accuracy: 88.42%
Epoch [22/150], Loss: 0.2885, Accuracy: 90.01%, Test Loss: 0.3218, Test Accuracy: 88.74%
Epoch [23/150], Loss: 0.2851, Accuracy: 90.23%, Test Loss: 0.3223, Test Accuracy: 88.50%
Epoch [24/150], Loss: 0.2811, Accuracy: 90.30%, Test Loss: 0.3220, Test Accuracy: 88.53%
Epoch [25/150], Loss: 0.2789, Accuracy: 90.34%, Test Loss: 0.3199, Test Accuracy: 88.61%
Epoch [26/150], Loss: 0.2751, Accuracy: 90.58%, Test Loss: 0.3116, Test Accuracy: 88.82%
Epoch [27/150], Loss: 0.2718, Accuracy: 90.59%, Test Loss: 0.3112, Test Accuracy: 89.09%
Epoch [28/150], Loss: 0.2677, Accuracy: 90.73%, Test Loss: 0.3088, Test Accuracy: 89.08%
Epoch [29/150], Loss: 0.2665, Accuracy: 90.79%, Test Loss: 0.3063, Test Accuracy: 88.79%
Epoch [30/150], Loss: 0.2622, Accuracy: 90.93%, Test Loss: 0.3035, Test Accuracy: 89.01%
Epoch [31/150], Loss: 0.2604, Accuracy: 90.97%, Test Loss: 0.3089, Test Accuracy: 89.00%
Epoch [32/150], Loss: 0.2575, Accuracy: 91.05%, Test Loss: 0.3033, Test Accuracy: 89.18%
Epoch [33/150], Loss: 0.2543, Accuracy: 91.24%, Test Loss: 0.2971, Test Accuracy: 89.39%
Epoch [34/150], Loss: 0.2526, Accuracy: 91.19%, Test Loss: 0.3003, Test Accuracy: 89.31%
Epoch [35/150], Loss: 0.2495, Accuracy: 91.37%, Test Loss: 0.2946, Test Accuracy: 89.56%
Epoch [36/150], Loss: 0.2467, Accuracy: 91.45%, Test Loss: 0.2930, Test Accuracy: 89.35%
Epoch [37/150], Loss: 0.2453, Accuracy: 91.46%, Test Loss: 0.2910, Test Accuracy: 89.56%
Epoch [38/150], Loss: 0.2430, Accuracy: 91.57%, Test Loss: 0.2913, Test Accuracy: 89.56%
Epoch [39/150], Loss: 0.2409, Accuracy: 91.62%, Test Loss: 0.2881, Test Accuracy: 89.65%
Epoch [40/150], Loss: 0.2375, Accuracy: 91.75%, Test Loss: 0.2908, Test Accuracy: 89.21%
Epoch [41/150], Loss: 0.2359, Accuracy: 91.81%, Test Loss: 0.2901, Test Accuracy: 89.70%
Epoch [42/150], Loss: 0.2333, Accuracy: 91.88%, Test Loss: 0.2869, Test Accuracy: 89.64%
Epoch [43/150], Loss: 0.2309, Accuracy: 91.93%, Test Loss: 0.2833, Test Accuracy: 89.69%
Epoch [44/150], Loss: 0.2287, Accuracy: 92.02%, Test Loss: 0.2822, Test Accuracy: 89.80%
Epoch [45/150], Loss: 0.2285, Accuracy: 92.04%, Test Loss: 0.2834, Test Accuracy: 89.87%
Epoch [46/150], Loss: 0.2251, Accuracy: 92.15%, Test Loss: 0.2817, Test Accuracy: 89.78%
Epoch [47/150], Loss: 0.2235, Accuracy: 92.23%, Test Loss: 0.2819, Test Accuracy: 89.91%
Epoch [48/150], Loss: 0.2216, Accuracy: 92.26%, Test Loss: 0.2824, Test Accuracy: 89.51%
Epoch [49/150], Loss: 0.2220, Accuracy: 92.27%, Test Loss: 0.2793, Test Accuracy: 89.78%
Epoch [50/150], Loss: 0.2189, Accuracy: 92.38%, Test Loss: 0.2797, Test Accuracy: 89.99%
Epoch [51/150], Loss: 0.2180, Accuracy: 92.40%, Test Loss: 0.2779, Test Accuracy: 90.13%
Epoch [52/150], Loss: 0.2151, Accuracy: 92.43%, Test Loss: 0.2759, Test Accuracy: 89.86%
Epoch [53/150], Loss: 0.2128, Accuracy: 92.56%, Test Loss: 0.2734, Test Accuracy: 90.07%
Epoch [54/150], Loss: 0.2106, Accuracy: 92.64%, Test Loss: 0.2736, Test Accuracy: 89.98%
Epoch [55/150], Loss: 0.2100, Accuracy: 92.64%, Test Loss: 0.2746, Test Accuracy: 90.31%
Epoch [56/150], Loss: 0.2092, Accuracy: 92.71%, Test Loss: 0.2721, Test Accuracy: 90.20%
Epoch [57/150], Loss: 0.2065, Accuracy: 92.79%, Test Loss: 0.2711, Test Accuracy: 90.16%
Epoch [58/150], Loss: 0.2046, Accuracy: 92.89%, Test Loss: 0.2726, Test Accuracy: 90.15%
Epoch [59/150], Loss: 0.2029, Accuracy: 92.95%, Test Loss: 0.2689, Test Accuracy: 90.14%
Epoch [60/150], Loss: 0.2016, Accuracy: 93.03%, Test Loss: 0.2708, Test Accuracy: 90.29%
Epoch [61/150], Loss: 0.2007, Accuracy: 92.97%, Test Loss: 0.2698, Test Accuracy: 90.22%
Epoch [62/150], Loss: 0.1989, Accuracy: 93.12%, Test Loss: 0.2673, Test Accuracy: 90.23%
Epoch [63/150], Loss: 0.1961, Accuracy: 93.16%, Test Loss: 0.2684, Test Accuracy: 90.45%
Epoch [64/150], Loss: 0.1955, Accuracy: 93.21%, Test Loss: 0.2707, Test Accuracy: 90.30%
Epoch [65/150], Loss: 0.1934, Accuracy: 93.23%, Test Loss: 0.2683, Test Accuracy: 90.32%
Epoch [66/150], Loss: 0.1929, Accuracy: 93.28%, Test Loss: 0.2654, Test Accuracy: 90.45%
Epoch [67/150], Loss: 0.1896, Accuracy: 93.42%, Test Loss: 0.2645, Test Accuracy: 90.48%
Epoch [68/150], Loss: 0.1895, Accuracy: 93.35%, Test Loss: 0.2637, Test Accuracy: 90.60%
Epoch [69/150], Loss: 0.1905, Accuracy: 93.24%, Test Loss: 0.2641, Test Accuracy: 90.36%
Epoch [70/150], Loss: 0.1864, Accuracy: 93.43%, Test Loss: 0.2663, Test Accuracy: 90.41%
Epoch [71/150], Loss: 0.1865, Accuracy: 93.46%, Test Loss: 0.2642, Test Accuracy: 90.40%
Epoch [72/150], Loss: 0.1844, Accuracy: 93.53%, Test Loss: 0.2642, Test Accuracy: 90.44%
Epoch [73/150], Loss: 0.1831, Accuracy: 93.66%, Test Loss: 0.2642, Test Accuracy: 90.55%
Epoch [74/150], Loss: 0.1812, Accuracy: 93.74%, Test Loss: 0.2626, Test Accuracy: 90.51%
Epoch [75/150], Loss: 0.1798, Accuracy: 93.77%, Test Loss: 0.2623, Test Accuracy: 90.61%
Epoch [76/150], Loss: 0.1792, Accuracy: 93.72%, Test Loss: 0.2628, Test Accuracy: 90.37%
Epoch [77/150], Loss: 0.1770, Accuracy: 93.86%, Test Loss: 0.2604, Test Accuracy: 90.51%
Epoch [78/150], Loss: 0.1763, Accuracy: 93.83%, Test Loss: 0.2622, Test Accuracy: 90.64%
Epoch [79/150], Loss: 0.1746, Accuracy: 93.88%, Test Loss: 0.2605, Test Accuracy: 90.68%
Epoch [80/150], Loss: 0.1729, Accuracy: 94.04%, Test Loss: 0.2623, Test Accuracy: 90.69%
Epoch [81/150], Loss: 0.1722, Accuracy: 93.96%, Test Loss: 0.2618, Test Accuracy: 90.58%
Epoch [82/150], Loss: 0.1716, Accuracy: 94.03%, Test Loss: 0.2609, Test Accuracy: 90.55%
Epoch [83/150], Loss: 0.1697, Accuracy: 94.14%, Test Loss: 0.2612, Test Accuracy: 90.57%
Epoch [84/150], Loss: 0.1679, Accuracy: 94.23%, Test Loss: 0.2585, Test Accuracy: 90.73%
Epoch [85/150], Loss: 0.1677, Accuracy: 94.10%, Test Loss: 0.2600, Test Accuracy: 90.59%
Epoch [86/150], Loss: 0.1660, Accuracy: 94.27%, Test Loss: 0.2608, Test Accuracy: 90.62%
Epoch [87/150], Loss: 0.1660, Accuracy: 94.21%, Test Loss: 0.2628, Test Accuracy: 90.51%
Epoch [88/150], Loss: 0.1641, Accuracy: 94.28%, Test Loss: 0.2583, Test Accuracy: 90.73%
Epoch [89/150], Loss: 0.1626, Accuracy: 94.36%, Test Loss: 0.2568, Test Accuracy: 90.88%
Epoch [90/150], Loss: 0.1612, Accuracy: 94.41%, Test Loss: 0.2579, Test Accuracy: 90.75%
Epoch [91/150], Loss: 0.1605, Accuracy: 94.44%, Test Loss: 0.2573, Test Accuracy: 90.95%
Epoch [92/150], Loss: 0.1598, Accuracy: 94.56%, Test Loss: 0.2582, Test Accuracy: 90.90%
Epoch [93/150], Loss: 0.1581, Accuracy: 94.44%, Test Loss: 0.2556, Test Accuracy: 90.80%
Epoch [94/150], Loss: 0.1578, Accuracy: 94.54%, Test Loss: 0.2554, Test Accuracy: 90.99%
Epoch [95/150], Loss: 0.1555, Accuracy: 94.63%, Test Loss: 0.2593, Test Accuracy: 90.73%
Epoch [96/150], Loss: 0.1562, Accuracy: 94.53%, Test Loss: 0.2537, Test Accuracy: 90.99%
Epoch [97/150], Loss: 0.1528, Accuracy: 94.77%, Test Loss: 0.2557, Test Accuracy: 90.88%
Epoch [98/150], Loss: 0.1530, Accuracy: 94.72%, Test Loss: 0.2571, Test Accuracy: 90.72%
Epoch [99/150], Loss: 0.1521, Accuracy: 94.74%, Test Loss: 0.2559, Test Accuracy: 90.93%
Epoch [100/150], Loss: 0.1517, Accuracy: 94.76%, Test Loss: 0.2587, Test Accuracy: 90.84%
Epoch [101/150], Loss: 0.1502, Accuracy: 94.86%, Test Loss: 0.2593, Test Accuracy: 90.73%
Epoch [102/150], Loss: 0.1484, Accuracy: 94.97%, Test Loss: 0.2592, Test Accuracy: 90.65%
Epoch [103/150], Loss: 0.1483, Accuracy: 94.96%, Test Loss: 0.2550, Test Accuracy: 91.02%
Epoch [104/150], Loss: 0.1463, Accuracy: 95.04%, Test Loss: 0.2566, Test Accuracy: 91.03%
Epoch [105/150], Loss: 0.1445, Accuracy: 95.10%, Test Loss: 0.2556, Test Accuracy: 90.95%
Epoch [106/150], Loss: 0.1435, Accuracy: 95.13%, Test Loss: 0.2563, Test Accuracy: 90.91%
Epoch [107/150], Loss: 0.1426, Accuracy: 95.15%, Test Loss: 0.2585, Test Accuracy: 90.93%
Epoch [108/150], Loss: 0.1420, Accuracy: 95.12%, Test Loss: 0.2569, Test Accuracy: 90.78%
Epoch [109/150], Loss: 0.1412, Accuracy: 95.25%, Test Loss: 0.2552, Test Accuracy: 90.90%
Epoch [110/150], Loss: 0.1406, Accuracy: 95.25%, Test Loss: 0.2553, Test Accuracy: 90.95%
Epoch [111/150], Loss: 0.1398, Accuracy: 95.24%, Test Loss: 0.2546, Test Accuracy: 91.09%
Epoch [112/150], Loss: 0.1382, Accuracy: 95.30%, Test Loss: 0.2559, Test Accuracy: 90.95%
Epoch [113/150], Loss: 0.1373, Accuracy: 95.35%, Test Loss: 0.2548, Test Accuracy: 90.96%
Epoch [114/150], Loss: 0.1371, Accuracy: 95.31%, Test Loss: 0.2579, Test Accuracy: 91.19%
Epoch [115/150], Loss: 0.1353, Accuracy: 95.43%, Test Loss: 0.2555, Test Accuracy: 90.91%
Epoch [116/150], Loss: 0.1342, Accuracy: 95.54%, Test Loss: 0.2556, Test Accuracy: 91.03%
Epoch [117/150], Loss: 0.1330, Accuracy: 95.50%, Test Loss: 0.2572, Test Accuracy: 91.07%
Epoch [118/150], Loss: 0.1338, Accuracy: 95.51%, Test Loss: 0.2578, Test Accuracy: 90.98%
Epoch [119/150], Loss: 0.1327, Accuracy: 95.57%, Test Loss: 0.2564, Test Accuracy: 91.15%
Epoch [120/150], Loss: 0.1310, Accuracy: 95.65%, Test Loss: 0.2584, Test Accuracy: 90.94%
Epoch [121/150], Loss: 0.1307, Accuracy: 95.65%, Test Loss: 0.2552, Test Accuracy: 91.03%
Epoch [122/150], Loss: 0.1300, Accuracy: 95.66%, Test Loss: 0.2597, Test Accuracy: 90.91%
Epoch [123/150], Loss: 0.1281, Accuracy: 95.75%, Test Loss: 0.2561, Test Accuracy: 91.06%
Epoch [124/150], Loss: 0.1276, Accuracy: 95.77%, Test Loss: 0.2572, Test Accuracy: 91.16%
Epoch [125/150], Loss: 0.1268, Accuracy: 95.79%, Test Loss: 0.2580, Test Accuracy: 90.78%
Epoch [126/150], Loss: 0.1260, Accuracy: 95.78%, Test Loss: 0.2562, Test Accuracy: 90.93%
Epoch [127/150], Loss: 0.1237, Accuracy: 95.96%, Test Loss: 0.2551, Test Accuracy: 91.18%
Epoch [128/150], Loss: 0.1238, Accuracy: 95.95%, Test Loss: 0.2551, Test Accuracy: 91.21%
Epoch [129/150], Loss: 0.1237, Accuracy: 95.87%, Test Loss: 0.2563, Test Accuracy: 91.20%
Epoch [130/150], Loss: 0.1224, Accuracy: 95.98%, Test Loss: 0.2559, Test Accuracy: 91.17%
Epoch [131/150], Loss: 0.1222, Accuracy: 95.99%, Test Loss: 0.2581, Test Accuracy: 90.97%
Epoch [132/150], Loss: 0.1207, Accuracy: 96.04%, Test Loss: 0.2559, Test Accuracy: 91.16%
Epoch [133/150], Loss: 0.1201, Accuracy: 96.07%, Test Loss: 0.2592, Test Accuracy: 91.02%
Epoch [134/150], Loss: 0.1187, Accuracy: 96.17%, Test Loss: 0.2571, Test Accuracy: 91.07%
Epoch [135/150], Loss: 0.1190, Accuracy: 96.11%, Test Loss: 0.2615, Test Accuracy: 91.15%
Epoch [136/150], Loss: 0.1182, Accuracy: 96.17%, Test Loss: 0.2597, Test Accuracy: 91.27%
Epoch [137/150], Loss: 0.1174, Accuracy: 96.21%, Test Loss: 0.2584, Test Accuracy: 91.24%
Epoch [138/150], Loss: 0.1172, Accuracy: 96.09%, Test Loss: 0.2592, Test Accuracy: 91.17%
Epoch [139/150], Loss: 0.1139, Accuracy: 96.38%, Test Loss: 0.2584, Test Accuracy: 91.11%
Epoch [140/150], Loss: 0.1143, Accuracy: 96.33%, Test Loss: 0.2560, Test Accuracy: 91.24%
Epoch [141/150], Loss: 0.1130, Accuracy: 96.41%, Test Loss: 0.2568, Test Accuracy: 91.18%
Epoch [142/150], Loss: 0.1119, Accuracy: 96.41%, Test Loss: 0.2594, Test Accuracy: 91.16%
Epoch [143/150], Loss: 0.1117, Accuracy: 96.45%, Test Loss: 0.2631, Test Accuracy: 91.03%
Epoch [144/150], Loss: 0.1110, Accuracy: 96.42%, Test Loss: 0.2735, Test Accuracy: 90.43%
Epoch [145/150], Loss: 0.1117, Accuracy: 96.35%, Test Loss: 0.2606, Test Accuracy: 91.15%
Epoch [146/150], Loss: 0.1106, Accuracy: 96.40%, Test Loss: 0.2596, Test Accuracy: 91.27%
Epoch [147/150], Loss: 0.1083, Accuracy: 96.61%, Test Loss: 0.2589, Test Accuracy: 91.18%
Epoch [148/150], Loss: 0.1092, Accuracy: 96.47%, Test Loss: 0.2617, Test Accuracy: 90.98%
Epoch [149/150], Loss: 0.1075, Accuracy: 96.61%, Test Loss: 0.2608, Test Accuracy: 91.16%
Epoch [150/150], Loss: 0.1057, Accuracy: 96.67%, Test Loss: 0.2619, Test Accuracy: 91.16%
import matplotlib.pyplot as plt

# Overlay train vs. test loss curves across epochs (lists filled during training).
plt.plot(train_losses, label='train_losses')
plt.plot(test_losses, label='test_losses')
plt.legend()
<matplotlib.legend.Legend at 0x7f5fcc86dde0>
image
import matplotlib.pyplot as plt

# Overlay train vs. test accuracy curves across epochs.
plt.plot(train_accuracies, label='train_accuracy')
plt.plot(test_accuracies, label='test_accuracy')
plt.legend()
<matplotlib.legend.Legend at 0x7f5fcc2c3430>
image
FashionMNIST_ap

Model

import torch
import torch.nn as nn
from torchsummary import summary

class CustomModel(nn.Module):
    """CNN variant that downsamples with average pooling instead of max pooling."""

    def __init__(self):
        super(CustomModel, self).__init__()
        self.conv = nn.Conv2d(1, 32, kernel_size=5)  # 28x28 -> 24x24, 32 channels
        self.pool = nn.AvgPool2d(2, 2)               # 24x24 -> 12x12
        self.flatten = nn.Flatten()
        self.dense1 = nn.Linear(12 * 12 * 32, 128)
        self.dense2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a (B, 1, 28, 28) batch to (B, 10) class logits."""
        out = self.relu(self.conv(x))
        out = self.pool(out)
        out = self.relu(self.dense1(self.flatten(out)))
        return self.dense2(out)

model = CustomModel()
# Move parameters/buffers to the selected device (GPU when available).
model = model.to(device)

# Print model
summary(model, (1, 28, 28))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 24, 24]             832
              ReLU-2           [-1, 32, 24, 24]               0
         AvgPool2d-3           [-1, 32, 12, 12]               0
           Flatten-4                 [-1, 4608]               0
            Linear-5                  [-1, 128]         589,952
              ReLU-6                  [-1, 128]               0
            Linear-7                   [-1, 10]           1,290
================================================================
Total params: 592,074
Trainable params: 592,074
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.35
Params size (MB): 2.26
Estimated Total Size (MB): 2.62
----------------------------------------------------------------

Loss, Optimizer, and Evaluation Function

# Cross-entropy over raw logits (log-softmax is applied internally).
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters with a small fixed learning rate.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Function to compute loss and accuracy for test set
def evaluate(model, testloader, criterion):
    """Evaluate `model` on `testloader` and return (avg per-batch loss, accuracy %)."""
    model.eval()
    accumulated_loss = 0.0
    correct = 0
    sample_count = 0
    with torch.no_grad():
        for images, labels in testloader:
            # Keep data on the same device as the model parameters.
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            accumulated_loss += criterion(outputs, labels).item()

            predicted = torch.argmax(outputs.data, dim=1)
            sample_count += labels.size(0)
            correct += (predicted == labels).sum().item()

    mean_loss = accumulated_loss / len(testloader)
    accuracy = 100 * correct / sample_count
    return mean_loss, accuracy

Train

# some parameter
train_losses = []      # per-epoch mean training loss, saved for plotting
train_accuracies = []  # per-epoch training accuracy (%)
test_losses = []       # per-epoch test loss from evaluate()
test_accuracies = []   # per-epoch test accuracy (%)
max_epoch = 150
# train
for epoch in range(max_epoch):
    model.train()  # training-mode behavior for layers that need it
    running_loss = 0.0
    running_correct = 0   # to track number of correct predictions
    total = 0             # to track total number of samples

    for i, (inputs, labels) in enumerate(trainloader, 0):
        # Move inputs and labels to the device
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        running_loss += loss.item()

        # Determine class predictions and track accuracy
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        running_correct += (predicted == labels).sum().item()

        # Backward pass and optimization
        loss.backward()
        optimizer.step()        

    # Epoch means: (i + 1) is the number of batches processed this epoch.
    epoch_accuracy = 100 * running_correct / total
    epoch_loss = running_loss / (i + 1)
    
    test_loss, test_accuracy = evaluate(model, testloader, criterion)
    print(f"Epoch [{epoch + 1}/{max_epoch}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
    
    # save for plot
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_accuracy)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)
Epoch [1/150], Loss: 1.4442, Accuracy: 61.88%, Test Loss: 0.9142, Test Accuracy: 70.93%
Epoch [2/150], Loss: 0.7605, Accuracy: 74.57%, Test Loss: 0.6812, Test Accuracy: 75.83%
Epoch [3/150], Loss: 0.6175, Accuracy: 78.48%, Test Loss: 0.5976, Test Accuracy: 78.50%
Epoch [4/150], Loss: 0.5522, Accuracy: 80.54%, Test Loss: 0.5507, Test Accuracy: 80.52%
Epoch [5/150], Loss: 0.5110, Accuracy: 82.10%, Test Loss: 0.5188, Test Accuracy: 81.68%
Epoch [6/150], Loss: 0.4804, Accuracy: 83.10%, Test Loss: 0.4933, Test Accuracy: 82.61%
Epoch [7/150], Loss: 0.4573, Accuracy: 84.07%, Test Loss: 0.4759, Test Accuracy: 82.98%
Epoch [8/150], Loss: 0.4389, Accuracy: 84.72%, Test Loss: 0.4573, Test Accuracy: 83.89%
Epoch [9/150], Loss: 0.4234, Accuracy: 85.29%, Test Loss: 0.4453, Test Accuracy: 84.14%
Epoch [10/150], Loss: 0.4107, Accuracy: 85.65%, Test Loss: 0.4399, Test Accuracy: 84.49%
Epoch [11/150], Loss: 0.3993, Accuracy: 86.16%, Test Loss: 0.4249, Test Accuracy: 85.05%
Epoch [12/150], Loss: 0.3878, Accuracy: 86.50%, Test Loss: 0.4154, Test Accuracy: 85.31%
Epoch [13/150], Loss: 0.3792, Accuracy: 86.78%, Test Loss: 0.4072, Test Accuracy: 85.65%
Epoch [14/150], Loss: 0.3722, Accuracy: 87.03%, Test Loss: 0.3992, Test Accuracy: 85.88%
Epoch [15/150], Loss: 0.3644, Accuracy: 87.23%, Test Loss: 0.3916, Test Accuracy: 86.24%
Epoch [16/150], Loss: 0.3566, Accuracy: 87.54%, Test Loss: 0.3863, Test Accuracy: 86.43%
Epoch [17/150], Loss: 0.3522, Accuracy: 87.68%, Test Loss: 0.3808, Test Accuracy: 86.60%
Epoch [18/150], Loss: 0.3456, Accuracy: 87.93%, Test Loss: 0.3777, Test Accuracy: 86.55%
Epoch [19/150], Loss: 0.3396, Accuracy: 88.17%, Test Loss: 0.3721, Test Accuracy: 86.88%
Epoch [20/150], Loss: 0.3358, Accuracy: 88.22%, Test Loss: 0.3687, Test Accuracy: 86.89%
Epoch [21/150], Loss: 0.3312, Accuracy: 88.35%, Test Loss: 0.3618, Test Accuracy: 87.23%
Epoch [22/150], Loss: 0.3265, Accuracy: 88.61%, Test Loss: 0.3612, Test Accuracy: 87.22%
Epoch [23/150], Loss: 0.3219, Accuracy: 88.70%, Test Loss: 0.3572, Test Accuracy: 87.26%
Epoch [24/150], Loss: 0.3185, Accuracy: 88.77%, Test Loss: 0.3530, Test Accuracy: 87.50%
Epoch [25/150], Loss: 0.3147, Accuracy: 88.94%, Test Loss: 0.3516, Test Accuracy: 87.70%
Epoch [26/150], Loss: 0.3106, Accuracy: 89.07%, Test Loss: 0.3473, Test Accuracy: 87.60%
Epoch [27/150], Loss: 0.3082, Accuracy: 89.19%, Test Loss: 0.3432, Test Accuracy: 87.95%
Epoch [28/150], Loss: 0.3060, Accuracy: 89.25%, Test Loss: 0.3412, Test Accuracy: 87.76%
Epoch [29/150], Loss: 0.3015, Accuracy: 89.43%, Test Loss: 0.3393, Test Accuracy: 87.95%
Epoch [30/150], Loss: 0.2992, Accuracy: 89.49%, Test Loss: 0.3390, Test Accuracy: 87.92%
Epoch [31/150], Loss: 0.2963, Accuracy: 89.59%, Test Loss: 0.3334, Test Accuracy: 88.09%
Epoch [32/150], Loss: 0.2931, Accuracy: 89.73%, Test Loss: 0.3334, Test Accuracy: 87.95%
Epoch [33/150], Loss: 0.2901, Accuracy: 89.71%, Test Loss: 0.3289, Test Accuracy: 88.24%
Epoch [34/150], Loss: 0.2868, Accuracy: 89.92%, Test Loss: 0.3298, Test Accuracy: 87.87%
Epoch [35/150], Loss: 0.2845, Accuracy: 90.01%, Test Loss: 0.3263, Test Accuracy: 88.52%
Epoch [36/150], Loss: 0.2830, Accuracy: 90.00%, Test Loss: 0.3235, Test Accuracy: 88.48%
Epoch [37/150], Loss: 0.2799, Accuracy: 90.15%, Test Loss: 0.3213, Test Accuracy: 88.57%
Epoch [38/150], Loss: 0.2784, Accuracy: 90.11%, Test Loss: 0.3195, Test Accuracy: 88.44%
Epoch [39/150], Loss: 0.2760, Accuracy: 90.24%, Test Loss: 0.3234, Test Accuracy: 88.45%
Epoch [40/150], Loss: 0.2736, Accuracy: 90.37%, Test Loss: 0.3167, Test Accuracy: 88.61%
Epoch [41/150], Loss: 0.2705, Accuracy: 90.40%, Test Loss: 0.3161, Test Accuracy: 88.51%
Epoch [42/150], Loss: 0.2694, Accuracy: 90.40%, Test Loss: 0.3143, Test Accuracy: 88.72%
Epoch [43/150], Loss: 0.2673, Accuracy: 90.49%, Test Loss: 0.3115, Test Accuracy: 88.65%
Epoch [44/150], Loss: 0.2649, Accuracy: 90.51%, Test Loss: 0.3101, Test Accuracy: 88.83%
Epoch [45/150], Loss: 0.2633, Accuracy: 90.65%, Test Loss: 0.3086, Test Accuracy: 88.83%
Epoch [46/150], Loss: 0.2611, Accuracy: 90.72%, Test Loss: 0.3089, Test Accuracy: 88.86%
Epoch [47/150], Loss: 0.2588, Accuracy: 90.77%, Test Loss: 0.3064, Test Accuracy: 88.95%
Epoch [48/150], Loss: 0.2568, Accuracy: 90.89%, Test Loss: 0.3049, Test Accuracy: 89.01%
Epoch [49/150], Loss: 0.2544, Accuracy: 90.92%, Test Loss: 0.3036, Test Accuracy: 89.01%
Epoch [50/150], Loss: 0.2535, Accuracy: 90.95%, Test Loss: 0.3033, Test Accuracy: 89.11%
Epoch [51/150], Loss: 0.2518, Accuracy: 91.00%, Test Loss: 0.3010, Test Accuracy: 89.23%
Epoch [52/150], Loss: 0.2515, Accuracy: 91.00%, Test Loss: 0.3029, Test Accuracy: 89.05%
Epoch [53/150], Loss: 0.2490, Accuracy: 91.13%, Test Loss: 0.3000, Test Accuracy: 89.16%
Epoch [54/150], Loss: 0.2475, Accuracy: 91.19%, Test Loss: 0.2979, Test Accuracy: 89.16%
Epoch [55/150], Loss: 0.2451, Accuracy: 91.23%, Test Loss: 0.2968, Test Accuracy: 89.22%
Epoch [56/150], Loss: 0.2430, Accuracy: 91.32%, Test Loss: 0.2966, Test Accuracy: 89.29%
Epoch [57/150], Loss: 0.2428, Accuracy: 91.39%, Test Loss: 0.2972, Test Accuracy: 89.12%
Epoch [58/150], Loss: 0.2415, Accuracy: 91.38%, Test Loss: 0.2944, Test Accuracy: 89.41%
Epoch [59/150], Loss: 0.2392, Accuracy: 91.39%, Test Loss: 0.2955, Test Accuracy: 89.23%
Epoch [60/150], Loss: 0.2380, Accuracy: 91.43%, Test Loss: 0.2966, Test Accuracy: 89.23%
Epoch [61/150], Loss: 0.2371, Accuracy: 91.55%, Test Loss: 0.2926, Test Accuracy: 89.46%
Epoch [62/150], Loss: 0.2347, Accuracy: 91.58%, Test Loss: 0.2918, Test Accuracy: 89.46%
Epoch [63/150], Loss: 0.2341, Accuracy: 91.70%, Test Loss: 0.2916, Test Accuracy: 89.56%
Epoch [64/150], Loss: 0.2316, Accuracy: 91.80%, Test Loss: 0.2926, Test Accuracy: 89.35%
Epoch [65/150], Loss: 0.2317, Accuracy: 91.69%, Test Loss: 0.2889, Test Accuracy: 89.84%
Epoch [66/150], Loss: 0.2292, Accuracy: 91.86%, Test Loss: 0.2924, Test Accuracy: 89.29%
Epoch [67/150], Loss: 0.2281, Accuracy: 91.84%, Test Loss: 0.2899, Test Accuracy: 89.52%
Epoch [68/150], Loss: 0.2269, Accuracy: 91.94%, Test Loss: 0.2871, Test Accuracy: 89.53%
Epoch [69/150], Loss: 0.2260, Accuracy: 91.96%, Test Loss: 0.2868, Test Accuracy: 89.77%
Epoch [70/150], Loss: 0.2248, Accuracy: 92.04%, Test Loss: 0.2858, Test Accuracy: 89.70%
Epoch [71/150], Loss: 0.2233, Accuracy: 92.08%, Test Loss: 0.2863, Test Accuracy: 89.69%
Epoch [72/150], Loss: 0.2223, Accuracy: 92.09%, Test Loss: 0.2846, Test Accuracy: 89.88%
Epoch [73/150], Loss: 0.2206, Accuracy: 92.12%, Test Loss: 0.2832, Test Accuracy: 89.92%
Epoch [74/150], Loss: 0.2203, Accuracy: 92.11%, Test Loss: 0.2833, Test Accuracy: 89.93%
Epoch [75/150], Loss: 0.2179, Accuracy: 92.27%, Test Loss: 0.2842, Test Accuracy: 89.74%
Epoch [76/150], Loss: 0.2162, Accuracy: 92.34%, Test Loss: 0.2828, Test Accuracy: 89.97%
Epoch [77/150], Loss: 0.2160, Accuracy: 92.30%, Test Loss: 0.2871, Test Accuracy: 89.75%
Epoch [78/150], Loss: 0.2144, Accuracy: 92.36%, Test Loss: 0.2804, Test Accuracy: 89.94%
Epoch [79/150], Loss: 0.2137, Accuracy: 92.46%, Test Loss: 0.2822, Test Accuracy: 89.86%
Epoch [80/150], Loss: 0.2141, Accuracy: 92.41%, Test Loss: 0.2819, Test Accuracy: 90.00%
Epoch [81/150], Loss: 0.2115, Accuracy: 92.48%, Test Loss: 0.2820, Test Accuracy: 89.92%
Epoch [82/150], Loss: 0.2099, Accuracy: 92.59%, Test Loss: 0.2805, Test Accuracy: 89.93%
Epoch [83/150], Loss: 0.2100, Accuracy: 92.52%, Test Loss: 0.2800, Test Accuracy: 90.03%
Epoch [84/150], Loss: 0.2088, Accuracy: 92.56%, Test Loss: 0.2787, Test Accuracy: 90.06%
Epoch [85/150], Loss: 0.2069, Accuracy: 92.72%, Test Loss: 0.2774, Test Accuracy: 90.14%
Epoch [86/150], Loss: 0.2058, Accuracy: 92.68%, Test Loss: 0.2772, Test Accuracy: 90.09%
Epoch [87/150], Loss: 0.2042, Accuracy: 92.74%, Test Loss: 0.2749, Test Accuracy: 90.28%
Epoch [88/150], Loss: 0.2035, Accuracy: 92.83%, Test Loss: 0.2780, Test Accuracy: 90.17%
Epoch [89/150], Loss: 0.2022, Accuracy: 92.85%, Test Loss: 0.2768, Test Accuracy: 90.16%
Epoch [90/150], Loss: 0.2026, Accuracy: 92.93%, Test Loss: 0.2766, Test Accuracy: 90.38%
Epoch [91/150], Loss: 0.2010, Accuracy: 92.89%, Test Loss: 0.2811, Test Accuracy: 90.10%
Epoch [92/150], Loss: 0.2004, Accuracy: 92.88%, Test Loss: 0.2781, Test Accuracy: 90.22%
Epoch [93/150], Loss: 0.1992, Accuracy: 92.94%, Test Loss: 0.2770, Test Accuracy: 90.28%
Epoch [94/150], Loss: 0.1988, Accuracy: 92.97%, Test Loss: 0.2778, Test Accuracy: 89.98%
Epoch [95/150], Loss: 0.1971, Accuracy: 93.13%, Test Loss: 0.2741, Test Accuracy: 90.40%
Epoch [96/150], Loss: 0.1967, Accuracy: 93.13%, Test Loss: 0.2735, Test Accuracy: 90.32%
Epoch [97/150], Loss: 0.1952, Accuracy: 93.12%, Test Loss: 0.2760, Test Accuracy: 90.12%
Epoch [98/150], Loss: 0.1950, Accuracy: 93.08%, Test Loss: 0.2736, Test Accuracy: 90.38%
Epoch [99/150], Loss: 0.1936, Accuracy: 93.21%, Test Loss: 0.2724, Test Accuracy: 90.37%
Epoch [100/150], Loss: 0.1922, Accuracy: 93.18%, Test Loss: 0.2723, Test Accuracy: 90.46%
Epoch [101/150], Loss: 0.1923, Accuracy: 93.16%, Test Loss: 0.2726, Test Accuracy: 90.42%
Epoch [102/150], Loss: 0.1910, Accuracy: 93.30%, Test Loss: 0.2734, Test Accuracy: 90.50%
Epoch [103/150], Loss: 0.1906, Accuracy: 93.29%, Test Loss: 0.2748, Test Accuracy: 90.43%
Epoch [104/150], Loss: 0.1893, Accuracy: 93.35%, Test Loss: 0.2716, Test Accuracy: 90.58%
Epoch [105/150], Loss: 0.1888, Accuracy: 93.36%, Test Loss: 0.2777, Test Accuracy: 90.11%
Epoch [106/150], Loss: 0.1880, Accuracy: 93.39%, Test Loss: 0.2712, Test Accuracy: 90.48%
Epoch [107/150], Loss: 0.1868, Accuracy: 93.40%, Test Loss: 0.2752, Test Accuracy: 90.42%
Epoch [108/150], Loss: 0.1854, Accuracy: 93.44%, Test Loss: 0.2696, Test Accuracy: 90.71%
Epoch [109/150], Loss: 0.1847, Accuracy: 93.54%, Test Loss: 0.2706, Test Accuracy: 90.63%
Epoch [110/150], Loss: 0.1837, Accuracy: 93.55%, Test Loss: 0.2705, Test Accuracy: 90.49%
Epoch [111/150], Loss: 0.1830, Accuracy: 93.59%, Test Loss: 0.2698, Test Accuracy: 90.54%
Epoch [112/150], Loss: 0.1826, Accuracy: 93.55%, Test Loss: 0.2698, Test Accuracy: 90.49%
Epoch [113/150], Loss: 0.1824, Accuracy: 93.63%, Test Loss: 0.2695, Test Accuracy: 90.62%
Epoch [114/150], Loss: 0.1811, Accuracy: 93.62%, Test Loss: 0.2697, Test Accuracy: 90.57%
Epoch [115/150], Loss: 0.1806, Accuracy: 93.63%, Test Loss: 0.2684, Test Accuracy: 90.67%
Epoch [116/150], Loss: 0.1787, Accuracy: 93.80%, Test Loss: 0.2685, Test Accuracy: 90.69%
Epoch [117/150], Loss: 0.1793, Accuracy: 93.74%, Test Loss: 0.2680, Test Accuracy: 90.59%
Epoch [118/150], Loss: 0.1772, Accuracy: 93.82%, Test Loss: 0.2676, Test Accuracy: 90.61%
Epoch [119/150], Loss: 0.1763, Accuracy: 93.86%, Test Loss: 0.2727, Test Accuracy: 90.58%
Epoch [120/150], Loss: 0.1752, Accuracy: 93.88%, Test Loss: 0.2696, Test Accuracy: 90.61%
Epoch [121/150], Loss: 0.1749, Accuracy: 93.96%, Test Loss: 0.2666, Test Accuracy: 90.76%
Epoch [122/150], Loss: 0.1747, Accuracy: 93.91%, Test Loss: 0.2682, Test Accuracy: 90.74%
Epoch [123/150], Loss: 0.1744, Accuracy: 93.88%, Test Loss: 0.2675, Test Accuracy: 90.72%
Epoch [124/150], Loss: 0.1726, Accuracy: 93.98%, Test Loss: 0.2693, Test Accuracy: 90.42%
Epoch [125/150], Loss: 0.1715, Accuracy: 94.10%, Test Loss: 0.2695, Test Accuracy: 90.70%
Epoch [126/150], Loss: 0.1714, Accuracy: 94.06%, Test Loss: 0.2680, Test Accuracy: 90.78%
Epoch [127/150], Loss: 0.1714, Accuracy: 94.03%, Test Loss: 0.2658, Test Accuracy: 90.79%
Epoch [128/150], Loss: 0.1704, Accuracy: 94.05%, Test Loss: 0.2676, Test Accuracy: 90.66%
Epoch [129/150], Loss: 0.1693, Accuracy: 94.11%, Test Loss: 0.2667, Test Accuracy: 90.67%
Epoch [130/150], Loss: 0.1691, Accuracy: 94.09%, Test Loss: 0.2675, Test Accuracy: 90.68%
Epoch [131/150], Loss: 0.1674, Accuracy: 94.18%, Test Loss: 0.2664, Test Accuracy: 90.87%
Epoch [132/150], Loss: 0.1672, Accuracy: 94.19%, Test Loss: 0.2661, Test Accuracy: 90.79%
Epoch [133/150], Loss: 0.1664, Accuracy: 94.21%, Test Loss: 0.2681, Test Accuracy: 90.62%
Epoch [134/150], Loss: 0.1654, Accuracy: 94.31%, Test Loss: 0.2656, Test Accuracy: 90.85%
Epoch [135/150], Loss: 0.1647, Accuracy: 94.32%, Test Loss: 0.2680, Test Accuracy: 90.65%
Epoch [136/150], Loss: 0.1639, Accuracy: 94.38%, Test Loss: 0.2671, Test Accuracy: 90.76%
Epoch [137/150], Loss: 0.1640, Accuracy: 94.34%, Test Loss: 0.2696, Test Accuracy: 90.57%
Epoch [138/150], Loss: 0.1633, Accuracy: 94.34%, Test Loss: 0.2652, Test Accuracy: 90.79%
Epoch [139/150], Loss: 0.1623, Accuracy: 94.39%, Test Loss: 0.2652, Test Accuracy: 90.86%
Epoch [140/150], Loss: 0.1617, Accuracy: 94.43%, Test Loss: 0.2640, Test Accuracy: 90.90%
Epoch [141/150], Loss: 0.1597, Accuracy: 94.53%, Test Loss: 0.2654, Test Accuracy: 90.90%
Epoch [142/150], Loss: 0.1600, Accuracy: 94.50%, Test Loss: 0.2667, Test Accuracy: 90.72%
Epoch [143/150], Loss: 0.1585, Accuracy: 94.62%, Test Loss: 0.2656, Test Accuracy: 90.90%
Epoch [144/150], Loss: 0.1597, Accuracy: 94.52%, Test Loss: 0.2668, Test Accuracy: 90.73%
Epoch [145/150], Loss: 0.1584, Accuracy: 94.54%, Test Loss: 0.2647, Test Accuracy: 90.78%
Epoch [146/150], Loss: 0.1575, Accuracy: 94.55%, Test Loss: 0.2649, Test Accuracy: 90.97%
Epoch [147/150], Loss: 0.1569, Accuracy: 94.63%, Test Loss: 0.2635, Test Accuracy: 90.89%
Epoch [148/150], Loss: 0.1557, Accuracy: 94.63%, Test Loss: 0.2645, Test Accuracy: 90.88%
Epoch [149/150], Loss: 0.1547, Accuracy: 94.68%, Test Loss: 0.2663, Test Accuracy: 90.82%
Epoch [150/150], Loss: 0.1554, Accuracy: 94.65%, Test Loss: 0.2660, Test Accuracy: 90.81%
import matplotlib.pyplot as plt

# Training vs. test loss curves accumulated over the 150-epoch run above.
plt.plot(train_losses, label='train_losses')
plt.plot(test_losses, label='test_losses')
plt.legend()
<matplotlib.legend.Legend at 0x7f0843e97040>
image
import matplotlib.pyplot as plt

# Training vs. test accuracy curves for the same run.
plt.plot(train_accuracies, label='train_accuracy')
plt.plot(test_accuracies, label='test_accuracy')
plt.legend()
<matplotlib.legend.Legend at 0x7f0843943460>
image
FashionMNIST_resize

Model

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

class ResizeLayer(nn.Module):
    """Resize feature maps spatially by a fixed scale factor.

    Thin module wrapper around ``F.interpolate`` so resizing can be used as
    a layer inside a model. Fix: this cell used ``F`` without importing
    ``torch.nn.functional as F``; the import is added to the cell's imports.
    """

    def __init__(self, scale_factor, mode='bilinear', align_corners=False):
        """
        Args:
            scale_factor: multiplier for the spatial dimensions
                (e.g. 0.5 halves height and width).
            mode: interpolation algorithm passed to ``F.interpolate``.
            align_corners: corner-alignment flag for the interpolation.
        """
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        # x is expected to be (N, C, H, W); returns (N, C, H*s, W*s).
        return F.interpolate(x, scale_factor=self.scale_factor,
                             mode=self.mode, align_corners=self.align_corners)


class CustomModel(nn.Module):
    """A single 5x5 conv, bilinear downsampling, then a two-layer MLP head.

    Input is a 1-channel 28x28 image: the conv yields 32 maps of 24x24,
    the 0.5x resize shrinks them to 12x12, hence the 12*12*32 flat size.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 32, kernel_size=5)
        self.resize = ResizeLayer(0.5)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.dense1 = nn.Linear(12*12*32, 128)
        self.dense2 = nn.Linear(128, 10)

    def forward(self, x):
        features = self.relu(self.conv(x))
        features = self.resize(features)
        hidden = self.relu(self.dense1(self.flatten(features)))
        return self.dense2(hidden)

# Instantiate the conv model and move it to the selected device.
model = CustomModel()
model = model.to(device)

# Print model
# torchsummary reports per-layer output shapes and parameter counts for a
# single-channel 28x28 (FashionMNIST) input.
summary(model, (1, 28, 28))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 24, 24]             832
              ReLU-2           [-1, 32, 24, 24]               0
       ResizeLayer-3           [-1, 32, 12, 12]               0
           Flatten-4                 [-1, 4608]               0
            Linear-5                  [-1, 128]         589,952
              ReLU-6                  [-1, 128]               0
            Linear-7                   [-1, 10]           1,290
================================================================
Total params: 592,074
Trainable params: 592,074
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.35
Params size (MB): 2.26
Estimated Total Size (MB): 2.62
----------------------------------------------------------------

Loss, Optimizer, and Evaluation Function

# Cross-entropy for the 10-class problem; Adam with a small learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Function to compute loss and accuracy for test set
def evaluate(model, testloader, criterion, device=None):
    """Compute mean per-batch loss and accuracy of `model` over `testloader`.

    Args:
        model: network to evaluate (left in eval mode on return).
        testloader: iterable of (images, labels) batches.
        criterion: loss function, e.g. ``nn.CrossEntropyLoss()``.
        device: target device for the batches. When None it is inferred from
            the model's parameters — a backward-compatible replacement for
            the previous reliance on a module-level global ``device``.

    Returns:
        (test_loss, accuracy): mean per-batch loss and accuracy in percent.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    test_loss = 0.0
    running_correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            # Move inputs and labels to the device
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            # Predicted class = argmax over logits (no .data needed under no_grad).
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            running_correct += (predicted == labels).sum().item()

    accuracy = 100 * running_correct / total
    test_loss = test_loss / len(testloader)
    return test_loss, accuracy

Train

# History buffers for plotting and the epoch budget.
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []
max_epoch = 150

# One full pass over the training set per epoch, then an evaluation pass
# over the held-out test set.
for epoch in range(max_epoch):
    model.train()
    loss_sum = 0.0
    n_correct = 0   # correct predictions this epoch
    n_seen = 0      # samples processed this epoch
    n_batches = 0   # batches processed this epoch

    for inputs, labels in trainloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Reset gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Bookkeeping (outputs/loss are unaffected by the optimizer step).
        loss_sum += loss.item()
        n_batches += 1
        predictions = outputs.argmax(dim=1)
        n_correct += (predictions == labels).sum().item()
        n_seen += labels.size(0)

    epoch_accuracy = 100 * n_correct / n_seen
    epoch_loss = loss_sum / n_batches

    test_loss, test_accuracy = evaluate(model, testloader, criterion)
    print(f"Epoch [{epoch + 1}/{max_epoch}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

    # save for plot
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_accuracy)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)
Epoch [1/150], Loss: 1.4016, Accuracy: 64.21%, Test Loss: 0.8842, Test Accuracy: 71.81%
Epoch [2/150], Loss: 0.7416, Accuracy: 75.15%, Test Loss: 0.6695, Test Accuracy: 76.61%
Epoch [3/150], Loss: 0.6084, Accuracy: 78.84%, Test Loss: 0.5897, Test Accuracy: 78.90%
Epoch [4/150], Loss: 0.5459, Accuracy: 81.01%, Test Loss: 0.5456, Test Accuracy: 80.75%
Epoch [5/150], Loss: 0.5054, Accuracy: 82.34%, Test Loss: 0.5152, Test Accuracy: 81.81%
Epoch [6/150], Loss: 0.4765, Accuracy: 83.45%, Test Loss: 0.4929, Test Accuracy: 82.34%
Epoch [7/150], Loss: 0.4535, Accuracy: 84.21%, Test Loss: 0.4704, Test Accuracy: 83.47%
Epoch [8/150], Loss: 0.4357, Accuracy: 84.99%, Test Loss: 0.4561, Test Accuracy: 83.68%
Epoch [9/150], Loss: 0.4216, Accuracy: 85.41%, Test Loss: 0.4458, Test Accuracy: 84.26%
Epoch [10/150], Loss: 0.4095, Accuracy: 85.80%, Test Loss: 0.4329, Test Accuracy: 84.60%
Epoch [11/150], Loss: 0.3980, Accuracy: 86.19%, Test Loss: 0.4221, Test Accuracy: 84.86%
Epoch [12/150], Loss: 0.3882, Accuracy: 86.49%, Test Loss: 0.4159, Test Accuracy: 85.18%
Epoch [13/150], Loss: 0.3806, Accuracy: 86.75%, Test Loss: 0.4084, Test Accuracy: 85.48%
Epoch [14/150], Loss: 0.3714, Accuracy: 87.04%, Test Loss: 0.4014, Test Accuracy: 85.71%
Epoch [15/150], Loss: 0.3660, Accuracy: 87.28%, Test Loss: 0.3950, Test Accuracy: 85.77%
Epoch [16/150], Loss: 0.3607, Accuracy: 87.34%, Test Loss: 0.3946, Test Accuracy: 85.95%
Epoch [17/150], Loss: 0.3537, Accuracy: 87.61%, Test Loss: 0.3855, Test Accuracy: 86.46%
Epoch [18/150], Loss: 0.3495, Accuracy: 87.74%, Test Loss: 0.3835, Test Accuracy: 86.45%
Epoch [19/150], Loss: 0.3434, Accuracy: 87.88%, Test Loss: 0.3755, Test Accuracy: 86.59%
Epoch [20/150], Loss: 0.3380, Accuracy: 88.17%, Test Loss: 0.3713, Test Accuracy: 86.79%
Epoch [21/150], Loss: 0.3342, Accuracy: 88.30%, Test Loss: 0.3681, Test Accuracy: 87.03%
Epoch [22/150], Loss: 0.3286, Accuracy: 88.46%, Test Loss: 0.3640, Test Accuracy: 87.21%
Epoch [23/150], Loss: 0.3257, Accuracy: 88.58%, Test Loss: 0.3611, Test Accuracy: 87.34%
Epoch [24/150], Loss: 0.3220, Accuracy: 88.63%, Test Loss: 0.3582, Test Accuracy: 87.37%
Epoch [25/150], Loss: 0.3182, Accuracy: 88.86%, Test Loss: 0.3539, Test Accuracy: 87.68%
Epoch [26/150], Loss: 0.3141, Accuracy: 88.89%, Test Loss: 0.3521, Test Accuracy: 87.57%
Epoch [27/150], Loss: 0.3106, Accuracy: 89.10%, Test Loss: 0.3494, Test Accuracy: 87.42%
Epoch [28/150], Loss: 0.3066, Accuracy: 89.22%, Test Loss: 0.3454, Test Accuracy: 87.70%
Epoch [29/150], Loss: 0.3035, Accuracy: 89.30%, Test Loss: 0.3426, Test Accuracy: 87.80%
Epoch [30/150], Loss: 0.3000, Accuracy: 89.42%, Test Loss: 0.3396, Test Accuracy: 87.99%
Epoch [31/150], Loss: 0.2989, Accuracy: 89.37%, Test Loss: 0.3366, Test Accuracy: 88.04%
Epoch [32/150], Loss: 0.2958, Accuracy: 89.63%, Test Loss: 0.3345, Test Accuracy: 88.16%
Epoch [33/150], Loss: 0.2918, Accuracy: 89.79%, Test Loss: 0.3313, Test Accuracy: 88.13%
Epoch [34/150], Loss: 0.2896, Accuracy: 89.82%, Test Loss: 0.3291, Test Accuracy: 88.33%
Epoch [35/150], Loss: 0.2867, Accuracy: 89.88%, Test Loss: 0.3288, Test Accuracy: 88.32%
Epoch [36/150], Loss: 0.2844, Accuracy: 90.00%, Test Loss: 0.3251, Test Accuracy: 88.42%
Epoch [37/150], Loss: 0.2824, Accuracy: 90.04%, Test Loss: 0.3224, Test Accuracy: 88.47%
Epoch [38/150], Loss: 0.2789, Accuracy: 90.13%, Test Loss: 0.3229, Test Accuracy: 88.65%
Epoch [39/150], Loss: 0.2772, Accuracy: 90.18%, Test Loss: 0.3195, Test Accuracy: 88.70%
Epoch [40/150], Loss: 0.2749, Accuracy: 90.28%, Test Loss: 0.3170, Test Accuracy: 88.70%
Epoch [41/150], Loss: 0.2734, Accuracy: 90.34%, Test Loss: 0.3182, Test Accuracy: 88.79%
Epoch [42/150], Loss: 0.2710, Accuracy: 90.40%, Test Loss: 0.3153, Test Accuracy: 88.87%
Epoch [43/150], Loss: 0.2690, Accuracy: 90.53%, Test Loss: 0.3121, Test Accuracy: 88.91%
Epoch [44/150], Loss: 0.2667, Accuracy: 90.58%, Test Loss: 0.3121, Test Accuracy: 88.84%
Epoch [45/150], Loss: 0.2642, Accuracy: 90.60%, Test Loss: 0.3101, Test Accuracy: 89.04%
Epoch [46/150], Loss: 0.2623, Accuracy: 90.68%, Test Loss: 0.3085, Test Accuracy: 89.02%
Epoch [47/150], Loss: 0.2599, Accuracy: 90.80%, Test Loss: 0.3088, Test Accuracy: 89.11%
Epoch [48/150], Loss: 0.2581, Accuracy: 90.81%, Test Loss: 0.3086, Test Accuracy: 88.93%
Epoch [49/150], Loss: 0.2577, Accuracy: 90.91%, Test Loss: 0.3046, Test Accuracy: 89.27%
Epoch [50/150], Loss: 0.2556, Accuracy: 90.92%, Test Loss: 0.3052, Test Accuracy: 89.22%
Epoch [51/150], Loss: 0.2537, Accuracy: 90.93%, Test Loss: 0.3058, Test Accuracy: 89.10%
Epoch [52/150], Loss: 0.2512, Accuracy: 91.07%, Test Loss: 0.3014, Test Accuracy: 89.37%
Epoch [53/150], Loss: 0.2496, Accuracy: 91.17%, Test Loss: 0.3011, Test Accuracy: 89.39%
Epoch [54/150], Loss: 0.2485, Accuracy: 91.20%, Test Loss: 0.2986, Test Accuracy: 89.48%
Epoch [55/150], Loss: 0.2466, Accuracy: 91.24%, Test Loss: 0.2985, Test Accuracy: 89.30%
Epoch [56/150], Loss: 0.2447, Accuracy: 91.38%, Test Loss: 0.2962, Test Accuracy: 89.34%
Epoch [57/150], Loss: 0.2428, Accuracy: 91.38%, Test Loss: 0.2955, Test Accuracy: 89.63%
Epoch [58/150], Loss: 0.2416, Accuracy: 91.38%, Test Loss: 0.2953, Test Accuracy: 89.52%
Epoch [59/150], Loss: 0.2401, Accuracy: 91.47%, Test Loss: 0.2958, Test Accuracy: 89.46%
Epoch [60/150], Loss: 0.2398, Accuracy: 91.48%, Test Loss: 0.2939, Test Accuracy: 89.58%
Epoch [61/150], Loss: 0.2376, Accuracy: 91.60%, Test Loss: 0.2939, Test Accuracy: 89.53%
Epoch [62/150], Loss: 0.2364, Accuracy: 91.49%, Test Loss: 0.2924, Test Accuracy: 89.73%
Epoch [63/150], Loss: 0.2357, Accuracy: 91.59%, Test Loss: 0.2948, Test Accuracy: 89.42%
Epoch [64/150], Loss: 0.2336, Accuracy: 91.67%, Test Loss: 0.2900, Test Accuracy: 89.77%
Epoch [65/150], Loss: 0.2324, Accuracy: 91.74%, Test Loss: 0.2919, Test Accuracy: 89.63%
Epoch [66/150], Loss: 0.2303, Accuracy: 91.80%, Test Loss: 0.2876, Test Accuracy: 89.82%
Epoch [67/150], Loss: 0.2287, Accuracy: 91.88%, Test Loss: 0.2894, Test Accuracy: 89.92%
Epoch [68/150], Loss: 0.2279, Accuracy: 91.90%, Test Loss: 0.2881, Test Accuracy: 89.80%
Epoch [69/150], Loss: 0.2271, Accuracy: 91.91%, Test Loss: 0.2859, Test Accuracy: 89.93%
Epoch [70/150], Loss: 0.2258, Accuracy: 92.04%, Test Loss: 0.2881, Test Accuracy: 89.80%
Epoch [71/150], Loss: 0.2247, Accuracy: 92.01%, Test Loss: 0.2872, Test Accuracy: 89.87%
Epoch [72/150], Loss: 0.2239, Accuracy: 92.02%, Test Loss: 0.2854, Test Accuracy: 90.00%
Epoch [73/150], Loss: 0.2215, Accuracy: 92.09%, Test Loss: 0.2879, Test Accuracy: 89.96%
Epoch [74/150], Loss: 0.2214, Accuracy: 92.14%, Test Loss: 0.2852, Test Accuracy: 90.00%
Epoch [75/150], Loss: 0.2195, Accuracy: 92.20%, Test Loss: 0.2843, Test Accuracy: 90.02%
Epoch [76/150], Loss: 0.2184, Accuracy: 92.29%, Test Loss: 0.2840, Test Accuracy: 89.96%
Epoch [77/150], Loss: 0.2173, Accuracy: 92.26%, Test Loss: 0.2817, Test Accuracy: 90.17%
Epoch [78/150], Loss: 0.2158, Accuracy: 92.33%, Test Loss: 0.2839, Test Accuracy: 90.09%
Epoch [79/150], Loss: 0.2145, Accuracy: 92.36%, Test Loss: 0.2859, Test Accuracy: 89.92%
Epoch [80/150], Loss: 0.2132, Accuracy: 92.46%, Test Loss: 0.2812, Test Accuracy: 90.16%
Epoch [81/150], Loss: 0.2136, Accuracy: 92.46%, Test Loss: 0.2804, Test Accuracy: 90.26%
Epoch [82/150], Loss: 0.2117, Accuracy: 92.50%, Test Loss: 0.2802, Test Accuracy: 90.23%
Epoch [83/150], Loss: 0.2108, Accuracy: 92.49%, Test Loss: 0.2779, Test Accuracy: 90.23%
Epoch [84/150], Loss: 0.2091, Accuracy: 92.59%, Test Loss: 0.2788, Test Accuracy: 90.28%
Epoch [85/150], Loss: 0.2074, Accuracy: 92.67%, Test Loss: 0.2802, Test Accuracy: 90.24%
Epoch [86/150], Loss: 0.2087, Accuracy: 92.63%, Test Loss: 0.2786, Test Accuracy: 90.29%
Epoch [87/150], Loss: 0.2064, Accuracy: 92.69%, Test Loss: 0.2767, Test Accuracy: 90.43%
Epoch [88/150], Loss: 0.2056, Accuracy: 92.71%, Test Loss: 0.2772, Test Accuracy: 90.26%
Epoch [89/150], Loss: 0.2033, Accuracy: 92.86%, Test Loss: 0.2772, Test Accuracy: 90.36%
Epoch [90/150], Loss: 0.2033, Accuracy: 92.78%, Test Loss: 0.2790, Test Accuracy: 90.28%
Epoch [91/150], Loss: 0.2024, Accuracy: 92.81%, Test Loss: 0.2773, Test Accuracy: 90.35%
Epoch [92/150], Loss: 0.2014, Accuracy: 92.85%, Test Loss: 0.2767, Test Accuracy: 90.41%
Epoch [93/150], Loss: 0.1995, Accuracy: 92.97%, Test Loss: 0.2790, Test Accuracy: 90.26%
Epoch [94/150], Loss: 0.1981, Accuracy: 93.00%, Test Loss: 0.2795, Test Accuracy: 90.22%
Epoch [95/150], Loss: 0.1991, Accuracy: 92.92%, Test Loss: 0.2752, Test Accuracy: 90.55%
Epoch [96/150], Loss: 0.1974, Accuracy: 93.02%, Test Loss: 0.2750, Test Accuracy: 90.43%
Epoch [97/150], Loss: 0.1959, Accuracy: 93.17%, Test Loss: 0.2743, Test Accuracy: 90.49%
Epoch [98/150], Loss: 0.1951, Accuracy: 93.14%, Test Loss: 0.2749, Test Accuracy: 90.39%
Epoch [99/150], Loss: 0.1949, Accuracy: 93.15%, Test Loss: 0.2740, Test Accuracy: 90.56%
Epoch [100/150], Loss: 0.1932, Accuracy: 93.15%, Test Loss: 0.2725, Test Accuracy: 90.53%
Epoch [101/150], Loss: 0.1921, Accuracy: 93.20%, Test Loss: 0.2770, Test Accuracy: 90.28%
Epoch [102/150], Loss: 0.1916, Accuracy: 93.23%, Test Loss: 0.2730, Test Accuracy: 90.52%
Epoch [103/150], Loss: 0.1900, Accuracy: 93.35%, Test Loss: 0.2722, Test Accuracy: 90.75%
Epoch [104/150], Loss: 0.1899, Accuracy: 93.28%, Test Loss: 0.2746, Test Accuracy: 90.33%
Epoch [105/150], Loss: 0.1891, Accuracy: 93.36%, Test Loss: 0.2714, Test Accuracy: 90.68%
Epoch [106/150], Loss: 0.1870, Accuracy: 93.49%, Test Loss: 0.2716, Test Accuracy: 90.66%
Epoch [107/150], Loss: 0.1872, Accuracy: 93.43%, Test Loss: 0.2708, Test Accuracy: 90.70%
Epoch [108/150], Loss: 0.1865, Accuracy: 93.48%, Test Loss: 0.2722, Test Accuracy: 90.59%
Epoch [109/150], Loss: 0.1847, Accuracy: 93.50%, Test Loss: 0.2740, Test Accuracy: 90.32%
Epoch [110/150], Loss: 0.1843, Accuracy: 93.53%, Test Loss: 0.2721, Test Accuracy: 90.62%
Epoch [111/150], Loss: 0.1843, Accuracy: 93.53%, Test Loss: 0.2706, Test Accuracy: 90.77%
Epoch [112/150], Loss: 0.1825, Accuracy: 93.62%, Test Loss: 0.2703, Test Accuracy: 90.72%
Epoch [113/150], Loss: 0.1816, Accuracy: 93.61%, Test Loss: 0.2703, Test Accuracy: 90.66%
Epoch [114/150], Loss: 0.1805, Accuracy: 93.70%, Test Loss: 0.2693, Test Accuracy: 90.65%
Epoch [115/150], Loss: 0.1810, Accuracy: 93.65%, Test Loss: 0.2746, Test Accuracy: 90.62%
Epoch [116/150], Loss: 0.1810, Accuracy: 93.67%, Test Loss: 0.2704, Test Accuracy: 90.63%
Epoch [117/150], Loss: 0.1792, Accuracy: 93.71%, Test Loss: 0.2677, Test Accuracy: 90.76%
Epoch [118/150], Loss: 0.1787, Accuracy: 93.73%, Test Loss: 0.2701, Test Accuracy: 90.79%
Epoch [119/150], Loss: 0.1765, Accuracy: 93.84%, Test Loss: 0.2702, Test Accuracy: 90.70%
Epoch [120/150], Loss: 0.1762, Accuracy: 93.87%, Test Loss: 0.2725, Test Accuracy: 90.53%
Epoch [121/150], Loss: 0.1755, Accuracy: 93.88%, Test Loss: 0.2716, Test Accuracy: 90.63%
Epoch [122/150], Loss: 0.1745, Accuracy: 93.91%, Test Loss: 0.2672, Test Accuracy: 90.74%
Epoch [123/150], Loss: 0.1729, Accuracy: 94.03%, Test Loss: 0.2689, Test Accuracy: 90.68%
Epoch [124/150], Loss: 0.1730, Accuracy: 93.97%, Test Loss: 0.2720, Test Accuracy: 90.66%
Epoch [125/150], Loss: 0.1732, Accuracy: 93.95%, Test Loss: 0.2669, Test Accuracy: 90.96%
Epoch [126/150], Loss: 0.1712, Accuracy: 94.06%, Test Loss: 0.2703, Test Accuracy: 90.61%
Epoch [127/150], Loss: 0.1700, Accuracy: 94.07%, Test Loss: 0.2691, Test Accuracy: 90.71%
Epoch [128/150], Loss: 0.1691, Accuracy: 94.18%, Test Loss: 0.2666, Test Accuracy: 90.92%
Epoch [129/150], Loss: 0.1683, Accuracy: 94.17%, Test Loss: 0.2708, Test Accuracy: 90.76%
Epoch [130/150], Loss: 0.1682, Accuracy: 94.21%, Test Loss: 0.2692, Test Accuracy: 90.80%
Epoch [131/150], Loss: 0.1675, Accuracy: 94.17%, Test Loss: 0.2674, Test Accuracy: 90.78%
Epoch [132/150], Loss: 0.1667, Accuracy: 94.21%, Test Loss: 0.2716, Test Accuracy: 90.62%
Epoch [133/150], Loss: 0.1656, Accuracy: 94.27%, Test Loss: 0.2663, Test Accuracy: 90.78%
Epoch [134/150], Loss: 0.1654, Accuracy: 94.22%, Test Loss: 0.2664, Test Accuracy: 90.84%
Epoch [135/150], Loss: 0.1643, Accuracy: 94.30%, Test Loss: 0.2677, Test Accuracy: 91.02%
Epoch [136/150], Loss: 0.1636, Accuracy: 94.37%, Test Loss: 0.2677, Test Accuracy: 90.82%
Epoch [137/150], Loss: 0.1622, Accuracy: 94.41%, Test Loss: 0.2696, Test Accuracy: 90.74%
Epoch [138/150], Loss: 0.1626, Accuracy: 94.41%, Test Loss: 0.2655, Test Accuracy: 90.95%
Epoch [139/150], Loss: 0.1613, Accuracy: 94.39%, Test Loss: 0.2680, Test Accuracy: 90.94%
Epoch [140/150], Loss: 0.1611, Accuracy: 94.44%, Test Loss: 0.2654, Test Accuracy: 90.96%
Epoch [141/150], Loss: 0.1605, Accuracy: 94.45%, Test Loss: 0.2683, Test Accuracy: 90.81%
Epoch [142/150], Loss: 0.1602, Accuracy: 94.47%, Test Loss: 0.2691, Test Accuracy: 90.83%
Epoch [143/150], Loss: 0.1597, Accuracy: 94.44%, Test Loss: 0.2673, Test Accuracy: 90.92%
Epoch [144/150], Loss: 0.1581, Accuracy: 94.55%, Test Loss: 0.2652, Test Accuracy: 91.00%
Epoch [145/150], Loss: 0.1569, Accuracy: 94.64%, Test Loss: 0.2723, Test Accuracy: 90.77%
Epoch [146/150], Loss: 0.1574, Accuracy: 94.56%, Test Loss: 0.2691, Test Accuracy: 90.89%
Epoch [147/150], Loss: 0.1557, Accuracy: 94.62%, Test Loss: 0.2672, Test Accuracy: 90.95%
Epoch [148/150], Loss: 0.1546, Accuracy: 94.71%, Test Loss: 0.2660, Test Accuracy: 91.00%
Epoch [149/150], Loss: 0.1537, Accuracy: 94.76%, Test Loss: 0.2661, Test Accuracy: 90.96%
Epoch [150/150], Loss: 0.1540, Accuracy: 94.77%, Test Loss: 0.2654, Test Accuracy: 90.98%
import matplotlib.pyplot as plt

# Loss curves for the resize-based CNN run.
plt.plot(train_losses, label='train_losses')
plt.plot(test_losses, label='test_losses')
plt.legend()
<matplotlib.legend.Legend at 0x7f07e3a845b0>
image
import matplotlib.pyplot as plt

# Accuracy curves for the resize-based CNN run.
plt.plot(train_accuracies, label='train_accuracy')
plt.plot(test_accuracies, label='test_accuracy')
plt.legend()
<matplotlib.legend.Legend at 0x7f07000dbd30>
image

Transition_CNN

Tradition
!pip install opencv-python scikit-image
Requirement already satisfied: opencv-python in ./.conda/lib/python3.10/site-packages (4.8.1.78)
Requirement already satisfied: scikit-image in ./.conda/lib/python3.10/site-packages (0.22.0)
Requirement already satisfied: numpy>=1.21.2 in /home/aivn12s1/.local/lib/python3.10/site-packages (from opencv-python) (1.26.1)
Requirement already satisfied: scipy>=1.8 in ./.conda/lib/python3.10/site-packages (from scikit-image) (1.11.3)
Requirement already satisfied: networkx>=2.8 in /home/aivn12s1/.local/lib/python3.10/site-packages (from scikit-image) (3.2.1)
Requirement already satisfied: pillow>=9.0.1 in ./.conda/lib/python3.10/site-packages (from scikit-image) (10.0.1)
Requirement already satisfied: imageio>=2.27 in ./.conda/lib/python3.10/site-packages (from scikit-image) (2.31.6)
Requirement already satisfied: tifffile>=2022.8.12 in ./.conda/lib/python3.10/site-packages (from scikit-image) (2023.9.26)
Requirement already satisfied: packaging>=21 in ./.conda/lib/python3.10/site-packages (from scikit-image) (23.2)
Requirement already satisfied: lazy_loader>=0.3 in ./.conda/lib/python3.10/site-packages (from scikit-image) (0.3)
from typing import Union, List
import time
import PIL
from PIL import Image
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import cv2
import matplotlib.pyplot as plt

1. Dataset

# CIFAR-10 with transform=None: samples stay as PIL images so the
# classical-CV feature extractor below can work on raw uint8 arrays.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=None)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=None)

len(train_dataset), len(test_dataset)
Files already downloaded and verified
Files already downloaded and verified
(50000, 10000)
def collate_fn(batch):
    """Batch as (list of raw images, label tensor) — images stay un-stacked
    so PIL objects survive the DataLoader untouched."""
    images = [sample[0] for sample in batch]
    labels = [sample[1] for sample in batch]
    return (images, torch.tensor(labels))

# The custom collate keeps images as a Python list; feature extraction then
# happens inside the model's forward pass rather than in the DataLoader.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn, num_workers=4)

len(train_loader), len(test_loader)
(1563, 313)

2. Model

# Computer Vision Feature Extraction
class CVF_Extraction(nn.Module):
    """Hand-crafted (classical CV) feature extractor used in place of convs.

    For every image channel it stacks the raw channel plus Sobel, Scharr and
    Laplacian edge responses — 4 feature maps per channel, so an RGB image
    yields 4 * 3 = 12 maps (= ``num_filters``). The module has no learnable
    parameters.
    """

    def __init__(self, num_filters=12, device='cpu'):
        """
        Args:
            num_filters: number of produced feature maps (informational; the
                actual count is 4 * input channels).
            device: device the output batch is moved to.
        """
        super().__init__()
        self.num_filters = num_filters
        self.device = device

    # Annotation is a string so the class can be defined even when PIL is
    # not importable; inputs only need to be convertible via np.array.
    def forward(self, images: "List[PIL.Image.Image]"):
        """Extract features per image and return an (N, C, H, W) tensor."""
        x = []
        for image in images:
            image = np.array(image)
            feature_images = self._extract(image)
            feature_normalized = CVF_Extraction._normalize(feature_images)
            x.append(feature_normalized)
        x = torch.stack(x)  # (N, C, H, W)
        # Bug fix: previously this used a module-level global `device`
        # instead of the device configured on this instance.
        return x.to(self.device)

    def _extract(self, image):
        """Build the (channels*4, H, W) feature stack for one HxWxC image."""
        # numpy images are (height, width, channels); the original code
        # unpacked these as (width, height, chanel) — same values, fixed names.
        height, width, channels = image.shape
        features = []
        for c in range(channels):
            bi_image = image[:, :, c]
            sobel_image = CVF_Extraction._sobel(bi_image)
            scharr_image = CVF_Extraction._scharr(bi_image)
            laplacian_image = CVF_Extraction._laplacian(bi_image)
            # NOTE(review): this stacks a uint8 raw channel with float32
            # filter outputs and relies on torch's dtype promotion in
            # stack/cat — confirm on the targeted torch version.
            chanel_feature = torch.stack((
                                        torch.from_numpy(bi_image),
                                        sobel_image,
                                        scharr_image,
                                        laplacian_image
                                        ), dim=0)
            features.append(chanel_feature)
        features = torch.stack(features)
        # Collapse (channels, 4, H, W) -> (channels*4, H, W).
        return features.view(-1, height, width)

    @staticmethod
    def _normalize(input_tensor):
        """Min-max scale to [0, 1]; a constant input maps to all zeros
        (guards against the previous division by zero)."""
        _min = input_tensor.min()
        _max = input_tensor.max()
        if _max == _min:
            return torch.zeros_like(input_tensor)
        normalized = (input_tensor - _min) / (_max - _min)
        return normalized

    @staticmethod
    def _sobel(image, ksize=3):
        """Sobel gradient magnitude (float32) of a single-channel image."""
        sobel_x = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=ksize)
        sobel_y = cv2.Sobel(image, cv2.CV_32F, 0, 1, ksize=ksize)
        sobel_x = torch.from_numpy(sobel_x)
        sobel_y = torch.from_numpy(sobel_y)
        sobel_magnitude = torch.hypot(sobel_x, sobel_y)
        return sobel_magnitude

    @staticmethod
    def _scharr(image):
        """Scharr gradient magnitude (float32) of a single-channel image."""
        scharr_x = cv2.Scharr(image, cv2.CV_32F, 1, 0)
        scharr_y = cv2.Scharr(image, cv2.CV_32F, 0, 1)
        scharr_x = torch.from_numpy(scharr_x)
        scharr_y = torch.from_numpy(scharr_y)
        scharr_magnitude = torch.hypot(scharr_x, scharr_y)
        return scharr_magnitude

    @staticmethod
    def _laplacian(image):
        """Laplacian response (float32) of a single-channel image."""
        laplacian_img = cv2.Laplacian(image, cv2.CV_32F)
        laplacian_img = torch.from_numpy(laplacian_img)
        return laplacian_img


class SimpleImageCLS(nn.Module):
    """Small classifier head on top of a parameter-free CV feature extractor.

    Pipeline: extractor -> ReLU -> 2x2 average pool -> flatten -> fc1 ->
    ReLU -> fc2 (10 logits).
    """

    def __init__(self,
                 features: CVF_Extraction,
                 img_size=32):
        super().__init__()
        self.features = features
        self.avgpool = nn.AvgPool2d((2, 2))
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten(1)
        # 2x2 average pooling halves each spatial side.
        flat_dim = int(self.features.num_filters * (img_size / 2) ** 2)
        self.fc1 = nn.Linear(in_features=flat_dim, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        pooled = self.avgpool(self.relu(self.features(x)))
        hidden = self.relu(self.fc1(self.flatten(pooled)))
        return self.fc2(hidden)

3. Evaluate

def evaluate(model, test_loader, criterion, device):
    """Run a full evaluation pass; return (mean per-batch loss, accuracy %).

    Only the labels are moved to `device`: the raw image lists are handed to
    the model as-is, which moves its own output to the configured device.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_samples = 0
    with torch.no_grad():
        for images, labels in test_loader:
            labels = labels.to(device)

            logits = model(images)
            loss_sum += criterion(logits, labels).item()
            preds = logits.argmax(dim=1)

            n_samples += labels.size(0)
            n_correct += (preds == labels).sum().item()

    mean_loss = loss_sum / len(test_loader)
    accuracy = 100 * n_correct / n_samples

    return mean_loss, accuracy

4. Train

def train_model(model, train_loader, test_loader, optimizer, criterion, num_epochs, device):
    """Train `model` for `num_epochs`, evaluating after every epoch.

    Returns a history dict with per-epoch train/test loss and accuracy.
    Prints one progress line per epoch and the total wall-clock time.
    """
    since = time.perf_counter()
    history = {
        "train_losses": [],
        "train_accuracies": [],
        "test_losses": [],
        "test_accuracies": []
    }

    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0
        for images, labels in train_loader:
            # Only labels are tensorized here; the model consumes raw images.
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            loss_sum += loss.item()
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)

            loss.backward()
            optimizer.step()

        epoch_loss = loss_sum / len(train_loader)
        epoch_accuracy = 100 * n_correct / n_seen
        test_loss, test_accuracy = evaluate(model, test_loader, criterion, device)

        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

        history['train_losses'].append(epoch_loss)
        history['train_accuracies'].append(epoch_accuracy)
        history['test_losses'].append(test_loss)
        history['test_accuracies'].append(test_accuracy)
    time_elapsed = time.perf_counter() - since
    print(f"Training complete in {time_elapsed//3600}h {time_elapsed%3600//60}m {int(time_elapsed%60)}s with {num_epochs} epochs")
    return history
# Training configuration for the classical-feature baseline.
device = 'cuda'
num_epochs = 100
LR = 0.001
criterion = nn.CrossEntropyLoss()
# num_filters = (raw_image + sobel + scharr + laplacian) *  raw_chanel 
#             = 4 * 3 = 12

cv_features = CVF_Extraction(num_filters=12, device=device)
# NOTE(review): 'cv_moddel' is a typo for 'cv_model'; both names alias the
# same SimpleImageCLS instance, as does `model`.
cv_moddel = model = SimpleImageCLS(features=cv_features)
optimizer = torch.optim.Adam(cv_moddel.parameters(), lr=LR)
cv_history = train_model(model, train_loader, test_loader, optimizer, criterion, num_epochs, device)
Epoch [1/100], Loss: 1.9179, Accuracy: 30.49%, Test Loss: 1.7676, Test Accuracy: 35.72%
Epoch [2/100], Loss: 1.7251, Accuracy: 38.23%, Test Loss: 1.6736, Test Accuracy: 40.23%
Epoch [3/100], Loss: 1.6576, Accuracy: 40.65%, Test Loss: 1.6788, Test Accuracy: 40.30%
Epoch [4/100], Loss: 1.6173, Accuracy: 42.22%, Test Loss: 1.6271, Test Accuracy: 42.33%
Epoch [5/100], Loss: 1.5890, Accuracy: 43.24%, Test Loss: 1.5699, Test Accuracy: 44.77%
Epoch [6/100], Loss: 1.5624, Accuracy: 44.27%, Test Loss: 1.5560, Test Accuracy: 44.62%
Epoch [7/100], Loss: 1.5427, Accuracy: 45.00%, Test Loss: 1.5582, Test Accuracy: 44.55%
Epoch [8/100], Loss: 1.5231, Accuracy: 45.79%, Test Loss: 1.5360, Test Accuracy: 45.58%
Epoch [9/100], Loss: 1.5081, Accuracy: 46.28%, Test Loss: 1.5450, Test Accuracy: 44.39%
Epoch [10/100], Loss: 1.4941, Accuracy: 46.63%, Test Loss: 1.5362, Test Accuracy: 45.64%
Epoch [11/100], Loss: 1.4833, Accuracy: 47.01%, Test Loss: 1.5231, Test Accuracy: 45.61%
Epoch [12/100], Loss: 1.4722, Accuracy: 47.64%, Test Loss: 1.5059, Test Accuracy: 46.92%
Epoch [13/100], Loss: 1.4630, Accuracy: 47.93%, Test Loss: 1.5330, Test Accuracy: 46.20%
Epoch [14/100], Loss: 1.4543, Accuracy: 48.22%, Test Loss: 1.4848, Test Accuracy: 47.07%
Epoch [15/100], Loss: 1.4433, Accuracy: 48.58%, Test Loss: 1.4962, Test Accuracy: 46.34%
Epoch [16/100], Loss: 1.4393, Accuracy: 48.83%, Test Loss: 1.5082, Test Accuracy: 45.93%
Epoch [17/100], Loss: 1.4346, Accuracy: 48.83%, Test Loss: 1.4842, Test Accuracy: 47.17%
Epoch [18/100], Loss: 1.4271, Accuracy: 49.33%, Test Loss: 1.5005, Test Accuracy: 46.64%
Epoch [19/100], Loss: 1.4143, Accuracy: 49.77%, Test Loss: 1.4642, Test Accuracy: 47.91%
Epoch [20/100], Loss: 1.4097, Accuracy: 50.01%, Test Loss: 1.4672, Test Accuracy: 48.26%
Epoch [21/100], Loss: 1.4067, Accuracy: 50.17%, Test Loss: 1.5120, Test Accuracy: 45.85%
Epoch [22/100], Loss: 1.4044, Accuracy: 50.20%, Test Loss: 1.4667, Test Accuracy: 48.10%
Epoch [23/100], Loss: 1.4029, Accuracy: 50.15%, Test Loss: 1.4630, Test Accuracy: 47.99%
Epoch [24/100], Loss: 1.3896, Accuracy: 50.71%, Test Loss: 1.4871, Test Accuracy: 47.42%
Epoch [25/100], Loss: 1.3868, Accuracy: 50.91%, Test Loss: 1.4699, Test Accuracy: 48.28%
Epoch [26/100], Loss: 1.3844, Accuracy: 51.07%, Test Loss: 1.4566, Test Accuracy: 48.66%
Epoch [27/100], Loss: 1.3802, Accuracy: 51.08%, Test Loss: 1.4779, Test Accuracy: 47.37%
Epoch [28/100], Loss: 1.3780, Accuracy: 51.26%, Test Loss: 1.4566, Test Accuracy: 48.46%
Epoch [29/100], Loss: 1.3717, Accuracy: 51.39%, Test Loss: 1.5191, Test Accuracy: 45.84%
Epoch [30/100], Loss: 1.3696, Accuracy: 51.56%, Test Loss: 1.4544, Test Accuracy: 48.74%
Epoch [31/100], Loss: 1.3613, Accuracy: 51.86%, Test Loss: 1.4518, Test Accuracy: 48.66%
Epoch [32/100], Loss: 1.3577, Accuracy: 51.94%, Test Loss: 1.4359, Test Accuracy: 49.33%
Epoch [33/100], Loss: 1.3570, Accuracy: 52.07%, Test Loss: 1.4763, Test Accuracy: 47.67%
Epoch [34/100], Loss: 1.3527, Accuracy: 52.15%, Test Loss: 1.4455, Test Accuracy: 48.94%
Epoch [35/100], Loss: 1.3489, Accuracy: 52.10%, Test Loss: 1.4532, Test Accuracy: 48.20%
Epoch [36/100], Loss: 1.3481, Accuracy: 52.11%, Test Loss: 1.4543, Test Accuracy: 48.75%
Epoch [37/100], Loss: 1.3458, Accuracy: 52.20%, Test Loss: 1.4340, Test Accuracy: 49.09%
Epoch [38/100], Loss: 1.3426, Accuracy: 52.55%, Test Loss: 1.4442, Test Accuracy: 49.12%
Epoch [39/100], Loss: 1.3409, Accuracy: 52.63%, Test Loss: 1.4677, Test Accuracy: 48.60%
Epoch [40/100], Loss: 1.3365, Accuracy: 52.59%, Test Loss: 1.4651, Test Accuracy: 48.58%
Epoch [41/100], Loss: 1.3386, Accuracy: 52.69%, Test Loss: 1.4455, Test Accuracy: 48.90%
Epoch [42/100], Loss: 1.3317, Accuracy: 53.16%, Test Loss: 1.4666, Test Accuracy: 48.51%
Epoch [43/100], Loss: 1.3292, Accuracy: 52.79%, Test Loss: 1.4837, Test Accuracy: 47.26%
Epoch [44/100], Loss: 1.3240, Accuracy: 53.22%, Test Loss: 1.4712, Test Accuracy: 48.73%
Epoch [45/100], Loss: 1.3240, Accuracy: 53.01%, Test Loss: 1.4715, Test Accuracy: 48.05%
Epoch [46/100], Loss: 1.3235, Accuracy: 53.17%, Test Loss: 1.4535, Test Accuracy: 48.22%
Epoch [47/100], Loss: 1.3222, Accuracy: 53.13%, Test Loss: 1.5046, Test Accuracy: 46.66%
Epoch [48/100], Loss: 1.3202, Accuracy: 53.23%, Test Loss: 1.4588, Test Accuracy: 48.40%
Epoch [49/100], Loss: 1.3144, Accuracy: 53.42%, Test Loss: 1.4701, Test Accuracy: 48.30%
Epoch [50/100], Loss: 1.3125, Accuracy: 53.43%, Test Loss: 1.4546, Test Accuracy: 48.65%
Epoch [51/100], Loss: 1.3115, Accuracy: 53.47%, Test Loss: 1.4432, Test Accuracy: 48.82%
Epoch [52/100], Loss: 1.3131, Accuracy: 53.34%, Test Loss: 1.4857, Test Accuracy: 47.58%
Epoch [53/100], Loss: 1.3077, Accuracy: 53.84%, Test Loss: 1.4387, Test Accuracy: 49.39%
Epoch [54/100], Loss: 1.3086, Accuracy: 53.67%, Test Loss: 1.4703, Test Accuracy: 48.30%
Epoch [55/100], Loss: 1.3034, Accuracy: 53.86%, Test Loss: 1.4368, Test Accuracy: 49.75%
Epoch [56/100], Loss: 1.3007, Accuracy: 54.06%, Test Loss: 1.4398, Test Accuracy: 48.98%
Epoch [57/100], Loss: 1.3006, Accuracy: 53.93%, Test Loss: 1.5147, Test Accuracy: 47.27%
Epoch [58/100], Loss: 1.3007, Accuracy: 53.92%, Test Loss: 1.4732, Test Accuracy: 48.02%
Epoch [59/100], Loss: 1.2978, Accuracy: 53.83%, Test Loss: 1.4391, Test Accuracy: 49.35%
Epoch [60/100], Loss: 1.2959, Accuracy: 54.13%, Test Loss: 1.4628, Test Accuracy: 48.41%
Epoch [61/100], Loss: 1.2962, Accuracy: 53.92%, Test Loss: 1.4422, Test Accuracy: 49.12%
Epoch [62/100], Loss: 1.2936, Accuracy: 54.32%, Test Loss: 1.5264, Test Accuracy: 47.15%
Epoch [63/100], Loss: 1.2893, Accuracy: 54.43%, Test Loss: 1.5218, Test Accuracy: 46.91%
Epoch [64/100], Loss: 1.2861, Accuracy: 54.47%, Test Loss: 1.4720, Test Accuracy: 48.87%
Epoch [65/100], Loss: 1.2846, Accuracy: 54.39%, Test Loss: 1.4648, Test Accuracy: 48.11%
Epoch [66/100], Loss: 1.2853, Accuracy: 54.36%, Test Loss: 1.4296, Test Accuracy: 49.44%
Epoch [67/100], Loss: 1.2805, Accuracy: 54.88%, Test Loss: 1.4371, Test Accuracy: 49.27%
Epoch [68/100], Loss: 1.2827, Accuracy: 54.61%, Test Loss: 1.5122, Test Accuracy: 47.00%
Epoch [69/100], Loss: 1.2772, Accuracy: 54.55%, Test Loss: 1.4539, Test Accuracy: 48.15%
Epoch [70/100], Loss: 1.2786, Accuracy: 54.75%, Test Loss: 1.4617, Test Accuracy: 48.44%
Epoch [71/100], Loss: 1.2793, Accuracy: 54.64%, Test Loss: 1.4497, Test Accuracy: 49.82%
Epoch [72/100], Loss: 1.2735, Accuracy: 54.84%, Test Loss: 1.4614, Test Accuracy: 49.16%
Epoch [73/100], Loss: 1.2749, Accuracy: 54.80%, Test Loss: 1.4271, Test Accuracy: 50.12%
Epoch [74/100], Loss: 1.2723, Accuracy: 54.90%, Test Loss: 1.4414, Test Accuracy: 49.82%
Epoch [75/100], Loss: 1.2739, Accuracy: 54.96%, Test Loss: 1.4897, Test Accuracy: 47.99%
Epoch [76/100], Loss: 1.2705, Accuracy: 55.01%, Test Loss: 1.4361, Test Accuracy: 49.46%
Epoch [77/100], Loss: 1.2673, Accuracy: 55.19%, Test Loss: 1.4660, Test Accuracy: 48.62%
Epoch [78/100], Loss: 1.2669, Accuracy: 55.26%, Test Loss: 1.4452, Test Accuracy: 48.84%
Epoch [79/100], Loss: 1.2655, Accuracy: 55.36%, Test Loss: 1.4540, Test Accuracy: 48.75%
Epoch [80/100], Loss: 1.2672, Accuracy: 55.19%, Test Loss: 1.4474, Test Accuracy: 49.28%
Epoch [81/100], Loss: 1.2641, Accuracy: 55.17%, Test Loss: 1.4611, Test Accuracy: 48.98%
Epoch [82/100], Loss: 1.2610, Accuracy: 55.38%, Test Loss: 1.4365, Test Accuracy: 49.37%
Epoch [83/100], Loss: 1.2624, Accuracy: 55.42%, Test Loss: 1.4443, Test Accuracy: 49.45%
Epoch [84/100], Loss: 1.2591, Accuracy: 55.45%, Test Loss: 1.4720, Test Accuracy: 48.33%
Epoch [85/100], Loss: 1.2580, Accuracy: 55.41%, Test Loss: 1.5616, Test Accuracy: 46.27%
Epoch [86/100], Loss: 1.2612, Accuracy: 55.31%, Test Loss: 1.4466, Test Accuracy: 49.29%
Epoch [87/100], Loss: 1.2578, Accuracy: 55.52%, Test Loss: 1.4564, Test Accuracy: 49.57%
Epoch [88/100], Loss: 1.2558, Accuracy: 55.70%, Test Loss: 1.4480, Test Accuracy: 49.23%
Epoch [89/100], Loss: 1.2541, Accuracy: 55.66%, Test Loss: 1.4572, Test Accuracy: 49.06%
Epoch [90/100], Loss: 1.2533, Accuracy: 55.70%, Test Loss: 1.4647, Test Accuracy: 48.70%
Epoch [91/100], Loss: 1.2499, Accuracy: 55.74%, Test Loss: 1.4919, Test Accuracy: 48.23%
Epoch [92/100], Loss: 1.2470, Accuracy: 55.74%, Test Loss: 1.4887, Test Accuracy: 48.52%
Epoch [93/100], Loss: 1.2483, Accuracy: 55.97%, Test Loss: 1.4664, Test Accuracy: 48.58%
Epoch [94/100], Loss: 1.2477, Accuracy: 55.68%, Test Loss: 1.4491, Test Accuracy: 49.52%
Epoch [95/100], Loss: 1.2514, Accuracy: 55.76%, Test Loss: 1.4650, Test Accuracy: 48.83%
Epoch [96/100], Loss: 1.2475, Accuracy: 55.90%, Test Loss: 1.4397, Test Accuracy: 49.48%
Epoch [97/100], Loss: 1.2458, Accuracy: 55.90%, Test Loss: 1.4800, Test Accuracy: 48.37%
Epoch [98/100], Loss: 1.2437, Accuracy: 56.12%, Test Loss: 1.4516, Test Accuracy: 49.39%
Epoch [99/100], Loss: 1.2460, Accuracy: 55.67%, Test Loss: 1.4566, Test Accuracy: 49.04%
Epoch [100/100], Loss: 1.2447, Accuracy: 55.84%, Test Loss: 1.4530, Test Accuracy: 49.39%
Training complete in 0.0h 15.0m 17s with 100 epochs

5. Plot the result


def plot_result(history):
    """Plot training/validation accuracy (top) and loss (bottom) per epoch.

    Args:
        history: dict with per-epoch lists under the keys
            'train_accuracies', 'test_accuracies',
            'train_losses', 'test_losses'.
    """
    train_accuracies = history['train_accuracies']
    test_accuracies = history['test_accuracies']
    train_losses = history['train_losses']
    test_losses = history['test_losses']

    plt.figure(figsize=(8, 8))
    plt.subplot(2, 1, 1)
    plt.plot(train_accuracies, label='Training Accuracy')
    plt.plot(test_accuracies, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.ylabel('Accuracy')
    plt.ylim([min(plt.ylim()), 100])
    plt.title('Training and Validation Accuracy')

    plt.subplot(2, 1, 2)
    plt.plot(train_losses, label='Training Loss')
    plt.plot(test_losses, label='Validation Loss')
    plt.legend(loc='upper right')
    # Fixed label: the training criterion is nn.CrossEntropyLoss (multi-class
    # cross entropy), not binary cross entropy.
    plt.ylabel('Cross Entropy')
    plt.title('Training and Validation Loss')
    plt.xlabel('epoch')
    plt.show()
plot_result(cv_history) 
image
CNN
from typing import Union, List
import time
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import cv2
import matplotlib.pyplot as plt

1. Dataset

# CIFAR-10 loading: ToTensor() scales pixels to [0, 1]; note there is no
# normalization here (unlike the FashionMNIST pipeline earlier in the file).
transform = transforms.Compose([
    transforms.ToTensor()
])

# Downloads the dataset into ./data on first run.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Notebook-cell expression: shows (50000, 10000).
len(train_dataset), len(test_dataset)
Files already downloaded and verified
Files already downloaded and verified
(50000, 10000)

# Batch size 32; shuffle only the training split.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Notebook-cell expression: number of batches per split, shows (1563, 313).
len(train_loader), len(test_loader)
(1563, 313)

2. Model

# Deep Feature Extraction
class DF_Extracion(nn.Module):
    """Single-convolution feature extractor: 3 input channels -> num_filters.

    'same' padding keeps the spatial size, so (N, 3, H, W) maps to
    (N, num_filters, H, W).
    """

    def __init__(self, num_filters=64, kernel_size=3, device='cpu'):
        super().__init__()
        self.num_filters = num_filters
        # Kept for interface compatibility; not used inside this module.
        self.device = device
        self.conv1 = nn.Conv2d(3, num_filters, kernel_size=kernel_size, padding='same')
        # NOTE: defined but not applied in forward(); the ReLU is applied by
        # the downstream classifier instead.
        self.relu = nn.ReLU()

    def forward(self, x):
        # Raw (pre-activation) convolution output.
        return self.conv1(x)

class SimpleImageCLS(nn.Module):
    """Image classifier: feature extractor -> ReLU -> 2x2 avg-pool -> 2-layer MLP.

    Args:
        features: feature-extraction module exposing a ``num_filters`` attribute;
            assumed to preserve spatial size (e.g. 'same'-padded conv) —
            TODO confirm for other extractors.
        img_size: input spatial size (square images), default 32 for CIFAR-10.
    """

    def __init__(self,
                 features: DF_Extracion,
                 img_size=32):
        super().__init__()
        self.features = features
        self.avgpool = nn.AvgPool2d((2, 2))
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten(1)
        # AvgPool2d(2) floors the spatial size to img_size // 2, so compute
        # in_features with integer division. The previous
        # int(num_filters * (img_size / 2) ** 2) overcounted for odd img_size;
        # identical for even sizes such as the default 32.
        in_features = self.features.num_filters * (img_size // 2) ** 2
        self.fc1 = nn.Linear(in_features=in_features, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        x = self.features(x)     # (N, C, H, W) feature maps
        x = self.relu(x)
        x = self.avgpool(x)      # halve spatial size
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)          # 10 class logits

        return x

3. Evaluate

def evaluate(model, test_loader, criterion, device):
    """Evaluate `model` on `test_loader`.

    Returns:
        (mean per-batch loss, accuracy in percent).
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    # No gradients needed during evaluation.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            loss_sum += criterion(logits, labels).item()

            preds = logits.argmax(dim=1)
            n_correct += (preds == labels).sum().item()
            n_seen += labels.size(0)

    # Mean of per-batch losses (matches how the training loop reports loss).
    return loss_sum / len(test_loader), 100 * n_correct / n_seen

4. Train

def train_model(model, train_loader, test_loader, optimizer, criterion, num_epochs, device):
    """Train `model` for `num_epochs`, evaluating on `test_loader` after each epoch.

    Returns:
        history dict with per-epoch lists: 'train_losses', 'train_accuracies',
        'test_losses', 'test_accuracies'.
    """
    since = time.perf_counter()
    history = {key: [] for key in (
        "train_losses", "train_accuracies", "test_losses", "test_accuracies")}

    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            # Accounting (done after the update step; uses the same forward pass).
            loss_sum += loss.item()
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)

        epoch_loss = loss_sum / len(train_loader)
        epoch_accuracy = 100 * n_correct / n_seen
        test_loss, test_accuracy = evaluate(model, test_loader, criterion, device)

        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

        history['train_losses'].append(epoch_loss)
        history['train_accuracies'].append(epoch_accuracy)
        history['test_losses'].append(test_loss)
        history['test_accuracies'].append(test_accuracy)
    time_elapsed = time.perf_counter() - since
    print(f"Training complete in {time_elapsed//3600}h {time_elapsed%3600//60}m {int(time_elapsed%60)}s with {num_epochs} epochs")
    return history
# Training setup for the single-conv CNN on CIFAR-10.
device = 'cuda'
num_epochs = 100
LR = 0.001  # Adam learning rate
criterion = nn.CrossEntropyLoss()
# Only 12 filters in the single conv layer (vs. the default 64).
deep_features = DF_Extracion(num_filters=12, device=device)
# NOTE(review): 'deep_moddel' is a typo; it and 'model' alias the same module.
deep_moddel = model = SimpleImageCLS(features=deep_features)
optimizer = torch.optim.Adam(deep_moddel.parameters(), lr=LR)
deep_history = train_model(model, train_loader, test_loader, optimizer, criterion, num_epochs, device)
Epoch [1/100], Loss: 1.6513, Accuracy: 40.51%, Test Loss: 1.4823, Test Accuracy: 46.49%
Epoch [2/100], Loss: 1.3937, Accuracy: 49.83%, Test Loss: 1.3402, Test Accuracy: 51.29%
Epoch [3/100], Loss: 1.2861, Accuracy: 54.17%, Test Loss: 1.3442, Test Accuracy: 52.82%
Epoch [4/100], Loss: 1.2137, Accuracy: 56.93%, Test Loss: 1.2502, Test Accuracy: 54.86%
Epoch [5/100], Loss: 1.1597, Accuracy: 58.98%, Test Loss: 1.2032, Test Accuracy: 57.33%
Epoch [6/100], Loss: 1.1189, Accuracy: 60.50%, Test Loss: 1.2131, Test Accuracy: 56.69%
Epoch [7/100], Loss: 1.0828, Accuracy: 61.83%, Test Loss: 1.2039, Test Accuracy: 57.14%
Epoch [8/100], Loss: 1.0505, Accuracy: 63.13%, Test Loss: 1.2243, Test Accuracy: 57.41%
Epoch [9/100], Loss: 1.0200, Accuracy: 64.19%, Test Loss: 1.2061, Test Accuracy: 57.69%
Epoch [10/100], Loss: 0.9935, Accuracy: 65.12%, Test Loss: 1.1978, Test Accuracy: 58.21%
Epoch [11/100], Loss: 0.9711, Accuracy: 65.93%, Test Loss: 1.2024, Test Accuracy: 58.22%
Epoch [12/100], Loss: 0.9442, Accuracy: 66.78%, Test Loss: 1.1997, Test Accuracy: 58.75%
Epoch [13/100], Loss: 0.9200, Accuracy: 67.73%, Test Loss: 1.2243, Test Accuracy: 57.70%
Epoch [14/100], Loss: 0.8966, Accuracy: 68.61%, Test Loss: 1.2480, Test Accuracy: 57.80%
Epoch [15/100], Loss: 0.8754, Accuracy: 69.25%, Test Loss: 1.2369, Test Accuracy: 58.26%
Epoch [16/100], Loss: 0.8544, Accuracy: 69.75%, Test Loss: 1.2769, Test Accuracy: 57.14%
Epoch [17/100], Loss: 0.8368, Accuracy: 70.64%, Test Loss: 1.2559, Test Accuracy: 58.03%
Epoch [18/100], Loss: 0.8156, Accuracy: 71.22%, Test Loss: 1.2741, Test Accuracy: 58.05%
Epoch [19/100], Loss: 0.7967, Accuracy: 71.87%, Test Loss: 1.3169, Test Accuracy: 57.49%
Epoch [20/100], Loss: 0.7805, Accuracy: 72.63%, Test Loss: 1.2905, Test Accuracy: 58.44%
Epoch [21/100], Loss: 0.7632, Accuracy: 73.13%, Test Loss: 1.3169, Test Accuracy: 57.22%
Epoch [22/100], Loss: 0.7470, Accuracy: 73.59%, Test Loss: 1.3191, Test Accuracy: 57.69%
Epoch [23/100], Loss: 0.7305, Accuracy: 74.15%, Test Loss: 1.3595, Test Accuracy: 57.62%
Epoch [24/100], Loss: 0.7119, Accuracy: 75.02%, Test Loss: 1.3617, Test Accuracy: 57.58%
Epoch [25/100], Loss: 0.6958, Accuracy: 75.40%, Test Loss: 1.3599, Test Accuracy: 57.69%
Epoch [26/100], Loss: 0.6793, Accuracy: 75.99%, Test Loss: 1.3979, Test Accuracy: 57.22%
Epoch [27/100], Loss: 0.6667, Accuracy: 76.56%, Test Loss: 1.4218, Test Accuracy: 57.51%
Epoch [28/100], Loss: 0.6508, Accuracy: 76.99%, Test Loss: 1.4717, Test Accuracy: 56.61%
Epoch [29/100], Loss: 0.6394, Accuracy: 77.46%, Test Loss: 1.4517, Test Accuracy: 57.35%
Epoch [30/100], Loss: 0.6213, Accuracy: 78.12%, Test Loss: 1.4830, Test Accuracy: 56.89%
Epoch [31/100], Loss: 0.6099, Accuracy: 78.49%, Test Loss: 1.5003, Test Accuracy: 57.00%
Epoch [32/100], Loss: 0.5986, Accuracy: 78.94%, Test Loss: 1.5475, Test Accuracy: 56.84%
Epoch [33/100], Loss: 0.5826, Accuracy: 79.43%, Test Loss: 1.5818, Test Accuracy: 56.99%
Epoch [34/100], Loss: 0.5712, Accuracy: 79.62%, Test Loss: 1.5756, Test Accuracy: 56.34%
Epoch [35/100], Loss: 0.5585, Accuracy: 80.02%, Test Loss: 1.6473, Test Accuracy: 56.25%
Epoch [36/100], Loss: 0.5461, Accuracy: 80.65%, Test Loss: 1.6195, Test Accuracy: 56.53%
Epoch [37/100], Loss: 0.5339, Accuracy: 81.04%, Test Loss: 1.6729, Test Accuracy: 56.30%
Epoch [38/100], Loss: 0.5201, Accuracy: 81.77%, Test Loss: 1.7035, Test Accuracy: 56.63%
Epoch [39/100], Loss: 0.5100, Accuracy: 81.87%, Test Loss: 1.7296, Test Accuracy: 56.23%
Epoch [40/100], Loss: 0.5001, Accuracy: 82.13%, Test Loss: 1.7844, Test Accuracy: 56.76%
Epoch [41/100], Loss: 0.4841, Accuracy: 82.64%, Test Loss: 1.8008, Test Accuracy: 55.78%
Epoch [42/100], Loss: 0.4790, Accuracy: 82.97%, Test Loss: 1.8521, Test Accuracy: 55.82%
Epoch [43/100], Loss: 0.4649, Accuracy: 83.43%, Test Loss: 1.8631, Test Accuracy: 55.49%
Epoch [44/100], Loss: 0.4527, Accuracy: 84.00%, Test Loss: 1.9058, Test Accuracy: 56.17%
Epoch [45/100], Loss: 0.4456, Accuracy: 84.23%, Test Loss: 1.9263, Test Accuracy: 55.98%
Epoch [46/100], Loss: 0.4314, Accuracy: 84.92%, Test Loss: 1.9378, Test Accuracy: 55.98%
Epoch [47/100], Loss: 0.4235, Accuracy: 84.90%, Test Loss: 2.0077, Test Accuracy: 55.59%
Epoch [48/100], Loss: 0.4099, Accuracy: 85.31%, Test Loss: 2.0487, Test Accuracy: 55.33%
Epoch [49/100], Loss: 0.4005, Accuracy: 85.96%, Test Loss: 2.0824, Test Accuracy: 55.72%
Epoch [50/100], Loss: 0.3948, Accuracy: 85.91%, Test Loss: 2.1315, Test Accuracy: 55.16%
Epoch [51/100], Loss: 0.3826, Accuracy: 86.40%, Test Loss: 2.1616, Test Accuracy: 55.66%
Epoch [52/100], Loss: 0.3747, Accuracy: 86.76%, Test Loss: 2.2287, Test Accuracy: 55.78%
Epoch [53/100], Loss: 0.3669, Accuracy: 87.00%, Test Loss: 2.2471, Test Accuracy: 56.06%
Epoch [54/100], Loss: 0.3585, Accuracy: 87.19%, Test Loss: 2.3133, Test Accuracy: 55.64%
Epoch [55/100], Loss: 0.3477, Accuracy: 87.67%, Test Loss: 2.3617, Test Accuracy: 55.55%
Epoch [56/100], Loss: 0.3386, Accuracy: 87.92%, Test Loss: 2.4342, Test Accuracy: 54.97%
Epoch [57/100], Loss: 0.3291, Accuracy: 88.28%, Test Loss: 2.4055, Test Accuracy: 55.28%
Epoch [58/100], Loss: 0.3231, Accuracy: 88.58%, Test Loss: 2.4508, Test Accuracy: 55.18%
Epoch [59/100], Loss: 0.3178, Accuracy: 88.83%, Test Loss: 2.4872, Test Accuracy: 55.16%
Epoch [60/100], Loss: 0.3061, Accuracy: 89.12%, Test Loss: 2.5646, Test Accuracy: 54.63%
Epoch [61/100], Loss: 0.3080, Accuracy: 88.92%, Test Loss: 2.6294, Test Accuracy: 54.77%
Epoch [62/100], Loss: 0.2932, Accuracy: 89.65%, Test Loss: 2.6746, Test Accuracy: 54.46%
Epoch [63/100], Loss: 0.2885, Accuracy: 89.75%, Test Loss: 2.7049, Test Accuracy: 54.92%
Epoch [64/100], Loss: 0.2833, Accuracy: 89.97%, Test Loss: 2.7055, Test Accuracy: 54.45%
Epoch [65/100], Loss: 0.2727, Accuracy: 90.27%, Test Loss: 2.8185, Test Accuracy: 54.40%
Epoch [66/100], Loss: 0.2699, Accuracy: 90.40%, Test Loss: 2.8319, Test Accuracy: 54.25%
Epoch [67/100], Loss: 0.2590, Accuracy: 90.82%, Test Loss: 2.9443, Test Accuracy: 54.64%
Epoch [68/100], Loss: 0.2535, Accuracy: 90.98%, Test Loss: 2.9530, Test Accuracy: 54.42%
Epoch [69/100], Loss: 0.2517, Accuracy: 91.10%, Test Loss: 3.1637, Test Accuracy: 53.44%
Epoch [70/100], Loss: 0.2467, Accuracy: 91.31%, Test Loss: 3.0132, Test Accuracy: 54.14%
Epoch [71/100], Loss: 0.2390, Accuracy: 91.68%, Test Loss: 3.1162, Test Accuracy: 54.19%
Epoch [72/100], Loss: 0.2348, Accuracy: 91.66%, Test Loss: 3.1825, Test Accuracy: 54.40%
Epoch [73/100], Loss: 0.2261, Accuracy: 92.12%, Test Loss: 3.1728, Test Accuracy: 54.23%
Epoch [74/100], Loss: 0.2187, Accuracy: 92.39%, Test Loss: 3.2768, Test Accuracy: 53.04%
Epoch [75/100], Loss: 0.2200, Accuracy: 92.29%, Test Loss: 3.3226, Test Accuracy: 54.34%
Epoch [76/100], Loss: 0.2113, Accuracy: 92.46%, Test Loss: 3.4061, Test Accuracy: 54.14%
Epoch [77/100], Loss: 0.2100, Accuracy: 92.63%, Test Loss: 3.4267, Test Accuracy: 54.12%
Epoch [78/100], Loss: 0.2073, Accuracy: 92.57%, Test Loss: 3.4055, Test Accuracy: 54.16%
Epoch [79/100], Loss: 0.1946, Accuracy: 93.10%, Test Loss: 3.5124, Test Accuracy: 54.20%
Epoch [80/100], Loss: 0.1955, Accuracy: 92.90%, Test Loss: 3.5968, Test Accuracy: 53.55%
Epoch [81/100], Loss: 0.1981, Accuracy: 92.95%, Test Loss: 3.6091, Test Accuracy: 53.93%
Epoch [82/100], Loss: 0.1914, Accuracy: 93.23%, Test Loss: 3.6857, Test Accuracy: 53.74%
Epoch [83/100], Loss: 0.1848, Accuracy: 93.43%, Test Loss: 3.7409, Test Accuracy: 53.96%
Epoch [84/100], Loss: 0.1923, Accuracy: 93.13%, Test Loss: 3.7898, Test Accuracy: 54.11%
Epoch [85/100], Loss: 0.1751, Accuracy: 93.77%, Test Loss: 3.9184, Test Accuracy: 54.00%
Epoch [86/100], Loss: 0.1793, Accuracy: 93.59%, Test Loss: 3.8326, Test Accuracy: 53.72%
Epoch [87/100], Loss: 0.1742, Accuracy: 93.87%, Test Loss: 3.8659, Test Accuracy: 54.06%
Epoch [88/100], Loss: 0.1693, Accuracy: 94.00%, Test Loss: 3.9660, Test Accuracy: 53.51%
Epoch [89/100], Loss: 0.1684, Accuracy: 94.07%, Test Loss: 3.9639, Test Accuracy: 54.25%
Epoch [90/100], Loss: 0.1543, Accuracy: 94.60%, Test Loss: 4.2332, Test Accuracy: 53.33%
Epoch [91/100], Loss: 0.1659, Accuracy: 94.15%, Test Loss: 4.1239, Test Accuracy: 53.69%
Epoch [92/100], Loss: 0.1608, Accuracy: 94.34%, Test Loss: 4.1580, Test Accuracy: 53.71%
Epoch [93/100], Loss: 0.1524, Accuracy: 94.68%, Test Loss: 4.2674, Test Accuracy: 52.97%
Epoch [94/100], Loss: 0.1543, Accuracy: 94.56%, Test Loss: 4.2781, Test Accuracy: 53.40%
Epoch [95/100], Loss: 0.1475, Accuracy: 94.88%, Test Loss: 4.3320, Test Accuracy: 53.77%
Epoch [96/100], Loss: 0.1488, Accuracy: 94.78%, Test Loss: 4.4525, Test Accuracy: 53.73%
Epoch [97/100], Loss: 0.1477, Accuracy: 94.84%, Test Loss: 4.3641, Test Accuracy: 53.52%
Epoch [98/100], Loss: 0.1411, Accuracy: 95.02%, Test Loss: 4.4396, Test Accuracy: 53.48%
Epoch [99/100], Loss: 0.1386, Accuracy: 95.13%, Test Loss: 4.5190, Test Accuracy: 53.49%
Epoch [100/100], Loss: 0.1384, Accuracy: 95.15%, Test Loss: 4.5941, Test Accuracy: 53.53%
Training complete in 0.0h 5.0m 31s with 100 epochs

5. Plot the result

def plot_result(history):
    """Plot training/validation accuracy (top) and loss (bottom) per epoch.

    Args:
        history: dict with per-epoch lists under the keys
            'train_accuracies', 'test_accuracies',
            'train_losses', 'test_losses'.
    """
    train_accuracies = history['train_accuracies']
    test_accuracies = history['test_accuracies']
    train_losses = history['train_losses']
    test_losses = history['test_losses']

    plt.figure(figsize=(8, 8))
    plt.subplot(2, 1, 1)
    plt.plot(train_accuracies, label='Training Accuracy')
    plt.plot(test_accuracies, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.ylabel('Accuracy')
    plt.ylim([min(plt.ylim()), 100])
    plt.title('Training and Validation Accuracy')

    plt.subplot(2, 1, 2)
    plt.plot(train_losses, label='Training Loss')
    plt.plot(test_losses, label='Validation Loss')
    plt.legend(loc='upper right')
    # Fixed label: the training criterion is nn.CrossEntropyLoss (multi-class
    # cross entropy), not binary cross entropy.
    plt.ylabel('Cross Entropy')
    plt.title('Training and Validation Loss')
    plt.xlabel('epoch')
    plt.show()
plot_result(deep_history) 
image