Implementing Transfer Learning in PyTorch

Introduction to Transfer Learning

Transfer learning is a technique in which a model trained on one task is reused as the starting point for a different but related task. In deep learning, there are two major transfer learning approaches:

1. Fine-tuning: A pre-trained model is loaded and all of its weights are trained further on the new task. Starting from pre-trained weights spares the network from having to learn everything from a random initialization.

2. Feature Extraction: As in fine-tuning, a pre-trained model is loaded, but the weights of all layers except, say, the last one are frozen, so only the unfrozen layers are updated during training.

In both approaches, the output layer is modified to suit the new task, and layers may be added or removed depending on the problem.

Let’s dive into the code

Let’s build a dog-vs-cat classifier using a pre-trained ResNet-34. You can download the dataset from here.

We will start with importing the necessary packages.

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms

Loading Data

We will use the torchvision and torch.utils.data packages for loading the data.

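transforms = transforms.Compose([
    transforms.Resize(256),       # resize the shorter side to 256 pixels
    transforms.CenterCrop(224),   # crop to the 224x224 input size used by ImageNet models
    transforms.ToTensor(),        # PIL image -> float tensor with values in [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])
# Note: this rebinds the name `transforms` from the torchvision module to the composed pipeline.
train_set = datasets.ImageFolder("data/train", transforms)
val_set = datasets.ImageFolder("data/val", transforms)  # assumes a separate validation folder

train_loader = torch.utils.data.DataLoader(train_set, batch_size=4,
                                           shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=4,
                                         shuffle=False, num_workers=4)
classes = train_set.classes
device = torch.device("cuda:0" if torch.cuda.is_available()
                      else "cpu")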

The above code is the same for both approaches.

Model Building

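First, let’s load the pre-trained ResNet-34. For fine-tuning, that is the only thing we need to do, whereas for feature extraction we also need to freeze the weights.

Fine-tuning

# torchvision >= 0.13; older versions use models.resnet34(pretrained=True)
model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)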

Feature Extraction

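# torchvision >= 0.13; older versions use models.resnet34(pretrained=True)
model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
# Freeze all pre-trained layers so that they are not updated during training
for param in model.parameters():
    param.requires_grad = False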

From here on, the code for both approaches is the same.

In ResNet-34, the last layer is a fully connected layer with 1000 neurons, one for each ImageNet class. Since we are doing binary classification, we will replace it with a layer that has two neurons.

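num_ftrs = model.fc.in_features
# The new layer's parameters have requires_grad=True, so it is trained in both approaches
model.fc = nn.Linear(num_ftrs, 2)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
# In feature extraction, frozen parameters receive no gradients, so SGD leaves them unchanged
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)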

Training and validation follow the usual PyTorch pattern.

Model Training

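for epoch in range(25):  # loop over the dataset 25 times
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print('Epoch %d, loss: %.3f' % (epoch + 1, running_loss / len(train_loader)))
print('Finished Training')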
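The loop above can also be wrapped in a reusable function. The sketch below is one way to do that; `train_one_epoch` is a hypothetical helper, and to keep it fast it is exercised on synthetic, linearly separable data with a single `nn.Linear` layer standing in for ResNet-34 and the image loader.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Hypothetical helper: one pass over `loader`, returns the mean batch loss."""
    model.train()
    running_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()                    # clear gradients from the previous step
        loss = criterion(model(inputs), labels)  # forward pass + loss
        loss.backward()                          # backpropagate
        optimizer.step()                         # update the trainable parameters
        running_loss += loss.item()
    return running_loss / len(loader)


# Synthetic two-class data: the label depends only on the sign of the first feature.
torch.manual_seed(0)
features = torch.randn(64, 8)
targets = (features[:, 0] > 0).long()
loader = DataLoader(TensorDataset(features, targets), batch_size=4, shuffle=True)

device = torch.device("cpu")
model = nn.Linear(8, 2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

losses = [train_one_epoch(model, loader, criterion, optimizer, device) for _ in range(5)]
print(losses)  # the average loss should shrink from epoch to epoch
```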

Now that our transfer-learned model is trained, let’s evaluate it on the validation set.

Model Validation

model.eval()  # put batch-norm layers into inference mode
class_correct = list(0. for i in range(2))
class_total = list(0. for i in range(2))
with torch.no_grad():
    for data in val_loader:
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels)
        # iterate over the actual batch size; the last batch may hold fewer than 4 images
        for j in range(labels.size(0)):
            label = labels[j].item()
            class_correct[label] += c[j].item()
            class_total[label] += 1
for i in range(2):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))
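The per-class counting loop above can also be expressed with tensor operations. The sketch below is an alternative formulation, not the original code: the hypothetical helper `per_class_accuracy` uses `torch.bincount` to count samples and correct predictions per class, and it is run here on made-up predictions and labels.

```python
import torch


def per_class_accuracy(predicted: torch.Tensor, labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: fraction of correctly classified samples for each class."""
    correct = predicted == labels
    total = torch.bincount(labels, minlength=num_classes)         # samples per class
    hits = torch.bincount(labels[correct], minlength=num_classes)  # correct predictions per class
    return hits.float() / total.float()


# Made-up example: class 0 has 3 samples (2 correct), class 1 has 2 samples (both correct).
labels = torch.tensor([0, 0, 0, 1, 1])
predicted = torch.tensor([0, 1, 0, 1, 1])
print(per_class_accuracy(predicted, labels, num_classes=2))  # tensor([0.6667, 1.0000])
```

To use it on the whole validation set, one could collect `predicted` and `labels` from every batch, join them with `torch.cat`, and call the helper once. A class with no samples would produce `nan`, since its total is zero.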

Test with your own image

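from PIL import Image

model.eval()
img_name = "1.jpeg"  # change this to the name of your image file.

def predict_image(image_path, model):
    image = Image.open(image_path).convert("RGB")
    image_tensor = transforms(image)           # same preprocessing as the training data
    image_tensor = image_tensor.unsqueeze(0)   # add a batch dimension
    image_tensor = image_tensor.to(device)
    with torch.no_grad():
        output = model(image_tensor)
    index = output.argmax().item()
    # ImageFolder numbers classes alphabetically, so this assumes the cat folder sorts first
    if index == 0:
        return "Cat"
    elif index == 1:
        return "Dog"

print(predict_image(img_name, model))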
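`predict_image` returns only the arg-max class. If you also want a confidence score, one option (an addition, not part of the original function) is to apply softmax to the logits. The sketch below shows the computation on fixed example logits; `logits_to_prediction` is a hypothetical helper, and the `["Cat", "Dog"]` ordering assumes the cat folder sorts before the dog folder.

```python
import torch


def logits_to_prediction(output: torch.Tensor, class_names):
    """Hypothetical helper: map a (1, num_classes) logit tensor to (label, probability)."""
    probs = torch.softmax(output, dim=1)  # convert logits to probabilities that sum to 1
    confidence, index = probs.max(dim=1)
    return class_names[index.item()], confidence.item()


output = torch.tensor([[0.0, 2.0]])       # example logits: the model favours index 1
label, confidence = logits_to_prediction(output, ["Cat", "Dog"])
print(label, round(confidence, 3))        # Dog 0.881
```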

So, that’s how you do transfer learning in PyTorch. I hope you enjoyed it! If you’ve made it this far and found any errors in the above, or can think of ways to make it clearer for future readers, don’t hesitate to drop a comment. Thanks!


