Model Generalization
1. Motivation
Nhắc lại về CNN training
Buổi hôm trước chúng ta thảo luận về việc huấn luyện mô hình. generalization là một khái niệm quan trọng trong deep learning. Một mô hình được coi là generalization tốt khi nó có khả năng dự đoán tốt trên dữ liệu mới mà nó chưa từng thấy. Hiện tại kiến trúc buổi trước tập huấn luyện đạt được độ chính xác cao nhưng khi áp dụng vào tập test thì kết quả không tốt. Điều này chứng tỏ mô hình không generalization tốt.
2. Problem
Trước khi huấn luyện
- Mô hình bắt đầu ở trạng thái ngẫu nhiên với các biên phân tách (decision boundary) không chính xác.
- Đây là điểm khởi đầu của mô hình, mô hình không thể dự đoán chính xác trên tập huấn luyện và tập kiểm tra.
Giai đoạn huấn luyện
- Mô hình dần dần cải thiện, học được các xu hướng chính của tập dữ liệu.
- Biên phân tách trở nên hợp lý hơn, phản ánh tốt các đặc điểm của dữ liệu.
Huấn luyện tối ưu - robust fit
- Mô hình đã học được các đặc điểm chính của tập dữ liệu, biên phân tách trở nên chính xác.
- Tại trạng thái này, mô hình có khả năng tổng quát hóa (generalization) tốt, tức là có khả năng dự đoán tốt trên dữ liệu mới.
Trạng thái cuối cùng - overfitting
- Mô hình huấn luyện quá mức dẫn đến việc học “quá mức” các đặc điểm của tập dữ liệu.
- Biên phân tách trở nên quá phức tạp, mô hình không thể tổng quát hóa tốt trên dữ liệu mới.
3. Solution
3.1. Trick 1: ’learn hard’ - randomly add noise to training data
Motivation
Kỹ thuật ’learn hard'
Code
transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.4914, 0.4822, 0.4465],
[0.2470, 0.2435, 0.2616]),
transforms.RandomErasing(
p=0.75, # Xác suất áp dụng Random Erasing (75% ảnh sẽ bị xóa ngẫu nhiên một phần)
scale=(0.01, 0.3), # Kích thước phần bị xóa: tối thiểu 1% - tối đa 30% diện tích ảnh
ratio=(1.0, 1.0), # Tỷ lệ khung hình (width/height) của vùng bị xóa luôn là 1.0 (hình vuông)
value=0, # Giá trị pixel thay thế (0 tức là tô màu đen)
inplace=True # Thay đổi ảnh gốc thay vì tạo bản sao mới
)
])
3.2. Trick 2: Batch Normalization
Target
Batch Normalization giúp giảm thiểu Internal Coveriate Shift, một vấn đề xáy ra khi phân phối của các đặc trưng thay đổi liên tục trong qúa trình huấn luyện do sự cập nhật trọng số lần trước đó. Bằng cách chuẩn hóa giá trị đầu vào của mỗi lớp, Batch Normalization giúp mô hình học nhanh hơn và hiệu quả hơn, duy trì sự ổn định của phân phối dữ liệu, giúp các lơp phía sau không cần phải liên tục điều chỉnh để thích nghi. Điều này giúp mô hình tăng tốc độ hội tụ và làm quá trình huấn luyện hiệu quả hơn.
Batch Normalization vô tình thêm nhiễu (noise) vào quá trình huấn luyện trong mean và variancr giữa các mini-batch. Nhiễu này hoạt động như một dạng regularization, làm giảm nguy cơ overfitting. Nhờ đó mô hình không chỉ hoạt động tốt trên tập huấn luyện mà còn cải thiện khả năng tổng quát hóa trên tập kiểm tra, đảm bảo hiệu suất cao hơn khi gặp dữ liệu chưa từng thấy.


Tại sao batch_normalize lại giúp tăng accuracy của test (ở những bài trước thì mình thảo luận nó giúp tăng ở tập train)


3.3. Trick 3: Dropout
Target
Dropout là một kỹ thuật regularization phổ biến trong deep learning. Ý tưởng của Dropout là loại bỏ ngẫu nhiên một số lượng node trong mạng neural network trong quá trình huấn luyện. Điều này giúp mô hình trở nên robust hơn, giảm nguy cơ overfitting, tăng khả năng tổng quát hóa của mô hình.
Dropout giúp mô hình học được các đặc trưng chính xác của dữ liệu, đồng thời giảm khả năng mô hình học “quá mức” các đặc trưng của tập huấn luyện. Điều này giúp mô hình có khả năng dự đoán tốt trên dữ liệu mới, giảm khoảng cách giữa độ chính xác trên tập huấn luyện và tập kiểm tra.
Deeply in Dropout
How dropout works mathematically

Pytorch sẽ sinh ra D có giá trị = {0,1}, dot hadamard với giá trị của layer đó.

Hệ số $scale$ :
Bởi vì tắt node thì đầu ra các layer sẽ bị giảm, cần scale lên cho các node không tắt để giữ cho tổng năng lượng (magnitude) của đầu ra không đổi.
3.4. Trick 4: Kernel regularizer
Target
Code
optimizer = torch.optim.Adam(model.parameters(), lr=0.
001, weight_decay=1e-5) ## weight_decay là hệ số lambda
3.5. Trick 5: Data Augmentation
Code
Đến đây, chúng ta đã đạt được mục tiêu ban đầu: tăng độ chính xác (accuracy) trên tập kiểm tra (test) và giảm khoảng cách giữa độ chính xác trên tập huấn luyện và tập kiểm tra. Tuy nhiên, nếu muốn tăng thêm độ chính xác trên tập kiểm tra, liệu có khả thi hay không?
Giải pháp:
Để cải thiện độ chính xác trên tập kiểm tra, chúng ta cần làm:
Tăng độ chính xác trên tập huấn luyện trước (train_accuracy):
- Mục tiêu là để mô hình học được các đặc trưng phức tạp hơn trong dữ liệu.
- Điều này tương đương với việc bạn muốn đạt điểm cao hơn trên bài kiểm tra thực tế, thì trước tiên bạn phải nắm vững kiến thức trong quá trình ôn luyện.
—> Tăng model capacity
3.6. Trick 6: Trick 6: Reduce learning rate (Adam + Weight decay)
Target
3.7. Trick 7: Increase model capacity (and use more data augmentation)
Target
3.8. Trick 8: Using skip-connection
Target
- Skip connection giúp truyền ngược thông tin từ các lớp thấp đến các lớp cao hơn, giúp mô hình học được các đặc trưng phức tạp hơn trong dữ liệu.
- Skip connection giúp giảm nguy cơ vanishing gradient, giúp mô hình học nhanh hơn và hiệu quả hơn.
- Skip connection giúp mô hình học được các đặc trưng chính xác của dữ liệu, giảm nguy cơ overfitting, tăng khả năng tổng quát hóa của mô hình.
- Skip connection giúp mô hình học được các đặc trưng cục bộ và toàn cục của dữ liệu, giúp mô hình hiểu được cấu trúc không gian và hình học của dữ liệu.
3.9. Trick 9: Increase model capacity once more
4. Summary
5. Câu hỏi ôn tập
Hiện tượng nào xảy ra khi khoảng cách giữa training accuracy và test accuracy quá lớn?
Mục tiêu của buổi học này là gì?
training accuracy
và test accuracy
.Trong giai đoạn đầu huấn luyện, biên phân tách của mô hình có đặc điểm gì?
Khi nào mô hình đạt trạng thái tổng quát hóa tốt (generalization)?
Hiện tượng gì xảy ra khi mô hình huấn luyện quá mức (overfitting)?
Trước khi huấn luyện, trạng thái của mô hình như thế nào?
Tại sao việc thêm nhiễu vào dữ liệu huấn luyện lại giúp cải thiện khả năng tổng quát hóa của mô hình?
Trong PyTorch, kỹ thuật nào được sử dụng để thêm nhiễu vào dữ liệu?
torchvision.transforms
được sử dụng để thêm nhiễu một cách ngẫu nhiên vào dữ liệu huấn luyện.Batch Normalization giúp giảm thiểu vấn đề gì trong quá trình huấn luyện?
Vì sao Batch Normalization có thể hoạt động như một dạng regularization?
Batch Normalization làm thay đổi điều gì trong mỗi epoch của quá trình huấn luyện?
Sau khi áp dụng Batch Normalization, độ chính xác kiểm tra (val_accuracy) đã tăng từ bao nhiêu lên bao nhiêu?
Kỹ thuật Dropout hoạt động như thế nào trong quá trình huấn luyện?
Tại sao Dropout giúp mô hình tránh overfitting?
Kernel Regularizer (L2 Regularization) có mục đích gì trong quá trình huấn luyện mô hình?
Trong PyTorch, tham số nào của Adam optimizer thực hiện chức năng L2 Regularization?
weight_decay
trong Adam optimizer thực hiện chức năng L2 Regularization.