Learning to Unlearn: Why Data Scientists and AI Practitioners Should Understand Machine Unlearning

Photo by Sue Winston on Unsplash

Explore the intersections between privacy and AI with a guide to removing the impact of individual data points in AI training, using the SISA technique applied to Convolutional Neural Networks (CNNs) in Python.

At the time of writing, and based on World Bank data, over 32% of the world's population of approximately 8 billion is under twenty years old. This means that roughly 2.6 billion people were born in the social media era, and it's highly probable that almost all of their lives have been registered online, whether by their parents, their inner circle, or eventually by themselves (depending on their attachment to social media and their network). If we add the people between their twenties and fifties, we have an extra 3.3 billion people who, to some extent, have part of their lives registered online in different sources and formats (images, comments, videos, etc.). Of course, we can adjust the numbers for people over fifty, or for the fact that not everyone in the world has access to or uses the internet (at least 35% don't, based on World Bank estimations in 2021), but I'm sure you get my point. A significant amount of our lives is registered in today's digital world.

Another high probability, or maybe certainty (we could ask OpenAI's CTO again 🙄), is that much of this data is being used, or has been used, to train all the "state-of-the-art" models being deployed today, from LLMs to multimodal AI models that can process information such as images, videos, or text. In this context, when it comes to data, technology, and privacy, we often find two sides struggling to find a middle ground. On one side is the social contract each individual has with technology, in which we are willing to trade some rights over our data for the benefits that technology offers us. On the other side is the question of where the line has to be drawn, as most defenders of this position say: "Just because data is accessible doesn't mean that it is free to collect and use".

In this article, we’ll explore some challenges that emerge when discussing privacy in terms of AI, including a brief overview of Machine Unlearning and the SISA training approach (Sharded, Isolated, Sliced, and Aggregated training), a machine unlearning framework recently developed to help manage or reduce the impact of individual data points in AI training and address the compliance challenge related to “The Right to Be Forgotten”.

Photo by Tingey Injury Law Firm on Unsplash

What is whispered in the closet shall be proclaimed from the house-tops

One of the first publications in history to advocate for a right to privacy is an essay published in 1890 by two American lawyers, Samuel D. Warren and Louis Brandeis. The essay, titled The Right to Privacy, was written to raise awareness about the effects of unauthorized photographs and early newspaper enterprises, which, in their own words, had turned gossip into a commodity and harmed the individual's right to enjoy life, the right to be left alone.

That the individual shall have full protection in person and in property is a principle as old as the common law; but it has been found necessary from time to time to define anew the exact nature and extent of such protection. … Recent inventions and business methods call attention to the next step which must be taken for the protection of the person, and for securing to the individual what Judge Cooley calls the right "to be let alone" (Samuel D. Warren and Louis Brandeis, 1890)

Times have changed since the publication of The Right to Privacy, but Warren and Brandeis were not mistaken about one thing: technological, political, social, and economic changes constantly challenge existing or new rights. In response, the common law should always remain open-minded to meet the new demands of society, recognizing that the protection of society primarily comes through acknowledging the rights of the individual.

Since then, privacy has often been associated with a traditional approach of securing and protecting what we care about behind closed curtains, keeping it out of the public eye, and controlling its access and use. But it's also true that its boundaries have been tested over time by disruptive technologies: photography and video set new boundaries, and more recently, so has the exponential growth of data. Data-based technologies have not only impacted the data compliance landscape; they have also affected our beliefs and customs. This has been the case with social media platforms or super apps, where we are willing to trade some rights over our data for the benefits that technology offers us. This means that context matters, and in some cases, sharing our sensitive information is more a matter of values like trust than necessarily a breach of privacy.

“Data is not simply ‘private’ or ‘not private’ or ‘sensitive’ or ‘non-sensitive’. Context matters, as do normative social values…” (The Ethics of Advanced AI Assistants. Google DeepMind 2024)

The relationship between context and privacy is an interesting line of thought known as the model of informational privacy in terms of "Contextual Integrity" (Nissenbaum, 2004). It states that in every exchange or flow of information between a sender and a receiver, there are social rules governing it. Understanding these rules is essential to ensuring that the exchange of information is properly regulated.

Figure 01 Source: Author’s own creation

Consider, for instance, information regarding my child's performance in school. If a teacher shared records of my child's performance with other parents or strangers outside the school, I might consider that a privacy breach. However, if the same teacher shared that same information with other teachers who teach my child, in order to share experiences and improve my child's performance in school, I might not be as concerned and would rely on trust, values, and the good judgment of the teachers. So, under the Contextual Integrity approach, privacy is not judged as the rigid state of "the right to be left alone". Rather, what matters is that the flow of information is appropriately regulated, taking into account the context and the governing norms within it to establish the limits. Privacy as a fundamental right shouldn't be changed, but it could be rethought.

Should the rigid concept of privacy remain unchanged? Or should we begin by first understanding the social rules governing information flows?

As Artificial Intelligence continues to shape the future, this rethinking challenges us to consider adapting existing rights or possibly introducing new digital rights.

Machine Unlearning

Whether you think of privacy as a rigid concept or through the lens of contextual integrity, I think most of us would agree that we all deserve to have our data processed fairly, with our consent, and with the ability to rectify or erase it if necessary.

While GDPR has facilitated the coexistence of data and privacy, balancing privacy and AI within regulatory frameworks presents a different challenge. Though we can erase or modify sensitive data from datasets, doing so in AI models is more complex. They aren’t retrained daily, and in most cases, it takes months to ensure their reliability. To address the task of selectively removing specific training data points (and their influence) in AI models without significantly sacrificing the model’s performance, techniques like Machine Unlearning have appeared and are being researched to find solutions to privacy concerns, comply with any possible enforced regulations, and protect users’ legal rights to erasure or rectification.

In contrast with the study of privacy policy, which can be traced back more than one hundred years, machine unlearning is a relatively new field, with initial studies appearing only about 10 years ago (Y. Cao and J. Yang, 2015).

So why should we be interested in machine unlearning? Whether you are an AI researcher pushing boundaries or a practitioner building AI solutions that are friendly for end users, here are some good reasons to adopt machine unlearning techniques in your ML processes:

· The Right to be Forgotten (RTBF): LLMs and state-of-the-art foundation models process data in complex, rapidly evolving ways. As seen with GDPR, it’s only a matter of time before the Right to Erasure is requested by users and adapted into regulations applied to AI. This will require any company using AI to adjust processes to meet regulations and follow user requests to remove personal data from pre-trained models.

· The Non-Zero Influence: Frameworks like differential privacy exist today to ensure some privacy for sensitive datasets by introducing noise to hide the contribution of any single datapoint. However, while differential privacy helps to mitigate the influence of a single datapoint, that influence is still "non-zero". This means there is still a possibility that the targeted datapoint has some kind of influence on the model. In a scenario where a datapoint needs to be completely removed, approaches beyond differential privacy may be required.

· Performance Optimization: It’s well known that foundation models are trained with significant amounts of data, requiring intensive time and compute resources. Retraining a complete model from scratch to remove a single datapoint may be the most effective way to erase any influence of that datapoint within the model, but it’s not the most efficient approach (models would need to be retrained frequently😨). The machine unlearning landscape addresses this problem by considering time and compute resources as constraints in the process of reversing or negating the effect of specific datapoints on a model’s parameters.

· Cybersecurity: Models are not exempt from attacks by adversaries who inject data to manipulate the model's behavior into revealing sensitive information about users. Machine unlearning can help remove harmful datapoints and protect the sensitive information used to train the model.

In the machine unlearning landscape, we find two lines of thought: Exact Machine Unlearning and Approximate Machine Unlearning. While Exact Machine Unlearning focuses on eliminating the influence of specific data points by removing them completely (as if that data had never been introduced to the model), Approximate Machine Unlearning aims to efficiently reduce the influence of specific data points in a trained model (making the model’s behavior approximate how it would be if the data points had never been introduced). Both approaches provide diverse techniques to address users’ right to erasure, considering constraints like deterioration in model performance, compute resources, time consumption, storage resources, specific learning models, or data structures.

For a better understanding of ongoing work in this field, I suggest two interesting readings: Machine Unlearning: Solutions and Challenges (2024) and Learn to Unlearn: Insights into Machine Unlearning (2023). Both papers provide a good recap of the extraordinary work of scientists and researchers in the Machine Unlearning field over the past few years.

SISA (Sharded, Isolated, Sliced, and Aggregated)

The SISA framework is part of the Exact Machine Unlearning line of thought and aims to remove data without requiring a full retraining of the model. The framework begins with the premise that, although retraining from scratch while excluding the data points that need to be unlearned is the most straightforward way to align with the "Right to be Forgotten" principle (providing proof and assurance that the unwanted data has been removed), it can be a naïve strategy for complex foundation models trained on large datasets, which demand significant resources to train. So, to make unlearning practical, any technique should meet the following requirements:

  1. Easy to Understand (Intelligibility): The technique should be easy to understand and implement.
  2. Accuracy: Although it is reasonable that some accuracy may be lost, the gap should be small.
  3. Time/Compute Efficient: It should require less time than retraining from scratch without the excluded data points, and use compute resources similar to those already in place for training procedures.
  4. Easy to Verify (Provable Guarantee): The technique should clearly demonstrate that the solicited data points no longer influence the model's parameters, and the proof should be easy to explain (even to non-experts).
  5. Model Agnostic: It should be applicable to models of varying nature and complexity.

How can we guarantee the complete removal of specific training data points? How do we verify the success of such unlearning processes?

The SISA framework (Sharded, Isolated, Sliced, and Aggregated) was first introduced in 2019 in the paper Machine Unlearning (Bourtoule et al.) to present an alternative solution to the problem of unlearning data from ML models, ensuring that the removal guarantee is easy to comprehend. The paper is easy to read in its introductory pages but can become complex if you are unfamiliar with the machine learning landscape. So, I'll try to summarize some of the interesting characteristics I find in the technique, but if you have the time, I strongly recommend giving the paper a try; it's worth reading! (An interesting presentation of the paper's findings can also be watched in this video made by the authors at the IEEE Symposium on Security and Privacy.)

The SISA training approach involves replicating the model several times, with each replica trained on a different subset of the dataset (known as a shard). Each model is referred to as a "constituent model". Within each shard, the data is further divided into "slices", and incremental learning is applied with parameters archived accordingly. Each constituent model works primarily with its assigned shard during the training phase, while the slices are used within each shard to manage the data and support incremental learning. After training, the sub-models from each shard are aggregated to form the final model. During inference, predictions from the various constituent models are combined to produce an overall prediction. Figure 02 illustrates how the SISA training approach works.

Figure 02 Source: Author’s own creation based on Bourtoule et al. paper (2019)
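
Before adapting the code, it's worth pausing on the aggregation step, since the hands-on example below focuses on sharding, slicing, and isolation. One common aggregation strategy is a simple majority vote over the constituent models' predicted labels at inference time. Below is a minimal sketch of that idea in PyTorch, assuming the constituent models (one per shard) are kept in a list called constituent_models, a name introduced here purely for illustration:

import torch

def aggregate_predictions(constituent_models, inputs):
    """Majority vote over the class predictions of each constituent model."""
    votes = []
    for constituent in constituent_models:
        constituent.eval()
        with torch.no_grad():
            logits = constituent(inputs)          # shape: (batch_size, num_classes)
        votes.append(logits.argmax(dim=1))        # predicted class per image
    votes = torch.stack(votes)                    # shape: (num_models, batch_size)
    final_preds, _ = torch.mode(votes, dim=0)     # most-voted class per image
    return final_preds

Averaging the constituent models' output probabilities instead of voting on labels is another reasonable alternative.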

When data needs to be unlearned, only the constituent model whose shard contains the point to be unlearned is retrained (the data point is unlearned from a particular slice in a particular shard).
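
This selective retraining is where the per-slice parameter archiving pays off. The sketch below is my own illustration of that bookkeeping for a single constituent model, not code from the paper or from this project: assuming checkpoints[k] holds the state_dict archived right after training on slice k, unlearning a point located in slice affected_k only requires rolling back to the checkpoint taken before that slice and retraining forward on the filtered slices.

import copy
from torch.utils.data import DataLoader, Subset

def retrain_constituent_after_removal(model, shard_slices, checkpoints, affected_k,
                                      indices_to_remove, criterion, optimizer, epochs=1):
    """Roll a constituent model back to the checkpoint archived before the affected slice,
    then retrain incrementally on the remaining slices with the removed points filtered out.

    checkpoints[k] is assumed to hold the model's state_dict saved right after slice k.
    """
    if affected_k > 0:
        model.load_state_dict(checkpoints[affected_k - 1])
    # (If affected_k == 0, the model would instead be re-initialized from scratch.)
    removed = set(indices_to_remove)
    for k in range(affected_k, len(shard_slices)):
        s = shard_slices[k]
        clean = Subset(s.dataset, [i for i in s.indices if i not in removed])
        loader = DataLoader(clean, batch_size=32, shuffle=True)
        model.train()
        for _ in range(epochs):
            for inputs, labels in loader:
                optimizer.zero_grad()
                loss = criterion(model(inputs), labels)
                loss.backward()
                optimizer.step()
        checkpoints[k] = copy.deepcopy(model.state_dict())  # re-archive the clean parameters
    return model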

Applying SISA: Unlearning and Retraining a CNN Model for Image Recognition

To understand how SISA can be applied, I will work on a use case example using Python. Recently, using PyTorch, computer vision techniques, and a Convolutional Neural Network (CNN), I built a basic setup to track hockey players and teams and gather some basic performance statistics (you can access the full article here).

Player Tracking with Computer Vision

Although consent to use the 40-second video for the project was provided by the Peruvian Inline Hockey Association (APHL), let’s imagine a scenario for our SISA use case: a player has complained about his images being used and, exercising his erasure rights, has requested the removal of his images from the CNN pre-trained model that classifies players into each team. This would require us to remove the images from the training dataset and retrain the entire model. However, by applying the SISA technique, we would only need to work on the shards and slices containing those images, thus avoiding the need to retrain the model from scratch and optimizing time.

The original CNN model was structured as follows:

# ************CONVOLUTIONAL NEURAL NETWORK-THREE CLASSES DETECTION**************************
# REFEREE
# WHITE TEAM (white_away)
# YELLOW TEAM (yellow_home)

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt


#******************************Data transformation********************************************
# Training and Validation Datasets
data_dir = "D:/PYTHON/teams_sample_dataset"

transform = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load datasets
train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

#********************************CNN Model Architecture**************************************
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128 * 18 * 18, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 3)  # Three Classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 128 * 18 * 18)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


#********************************CNN TRAINING**********************************************

# Model-loss function-optimizer
model = CNNModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

#*********************************Training*************************************************
num_epochs = 10
train_losses, val_losses = [], []

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        labels = labels.type(torch.LongTensor)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    train_losses.append(running_loss / len(train_loader))

    model.eval()
    val_loss = 0.0
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            labels = labels.type(torch.LongTensor)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            all_labels.extend(labels.tolist())
            all_preds.extend(preds.tolist())

#********************************METRICS & PERFORMANCE************************************

    val_losses.append(val_loss / len(val_loader))
    val_accuracy = accuracy_score(all_labels, all_preds)
    val_precision = precision_score(all_labels, all_preds, average="macro", zero_division=1)
    val_recall = recall_score(all_labels, all_preds, average="macro", zero_division=1)
    val_f1 = f1_score(all_labels, all_preds, average="macro", zero_division=1)

    print(f"Epoch [{epoch + 1}/{num_epochs}], "
          f"Loss: {train_losses[-1]:.4f}, "
          f"Val Loss: {val_losses[-1]:.4f}, "
          f"Val Acc: {val_accuracy:.2%}, "
          f"Val Precision: {val_precision:.4f}, "
          f"Val Recall: {val_recall:.4f}, "
          f"Val F1 Score: {val_f1:.4f}")

#*******************************SHOW METRICS & PERFORMANCE**********************************
plt.plot(train_losses, label="Train Loss")
plt.plot(val_losses, label="Validation Loss")
plt.legend()
plt.show()

# SAVE THE MODEL FOR THE GH_CV_track_teams CODE
torch.save(model.state_dict(), 'D:/PYTHON/hockey_team_classifier.pth')

As you can see, it is a neural network with three convolutional layers (conv1, conv2, conv3) using ReLU as the activation function, trained with a dataset of approximately 90 images classified into three classes: Referee, Team_Away (white jersey players), and Team_Home (yellow jersey players), over a full cycle of 10 epochs.

Considering this initial approach, a request to remove images from the training process would involve erasing the images from both the training and validation datasets and retraining the model. While this might be easy with a small dataset like ours, for larger datasets, such as those used in current large language models (LLMs), this would represent a significant use of resources. Additionally, performing this process repeatedly could also be a limitation.

Now, let’s imagine that while building the model, we are aware of users’ rights to erasure or rectification and consider applying the SISA technique. This approach would prepare the model for any future scenarios where images might need to be permanently removed from the training dataset, as well as any features that the CNN may have captured during its learning process. The first step would be adapting the initial model presented above to include the four steps of the SISA technique: Sharding, Isolating, Slicing, and Aggregation.

Step 01: Shards and Slices

After the transformation step specified at the beginning of the previous code, we'll begin applying SISA by dividing the dataset into shards. In the code, you will see that the dataset is shuffled and then split into equal-sized shards, so that each shard contains a representative number of samples and is balanced across the different classes we want to predict (in our case, we are predicting three classes).


#******************************Sharding the dataset**************************

def shard_dataset(dataset, num_shards):
    indices = list(range(len(dataset)))
    np.random.shuffle(indices)
    shards = []
    shard_size = len(dataset) // num_shards
    for i in range(num_shards):
        shard_indices = indices[i * shard_size : (i + 1) * shard_size]
        shards.append(Subset(dataset, shard_indices))
    return shards

#******************************Overlapping Slices***************************
def create_overlapping_slices(shard, slice_size, overlap):
    indices = list(shard.indices)
    slices = []
    step = slice_size - overlap
    for start in range(0, len(indices) - slice_size + 1, step):
        slice_indices = indices[start:start + slice_size]
        slices.append(Subset(shard.dataset, slice_indices))
    return slices

You’ll notice that for the slicing process, I didn’t assign exclusive slices per shard as the SISA technique suggests. Instead, we are using overlapping slices. This means that each slice is not exclusively composed of data points from just one shard; some data points from one slice will also appear in the next slice.

So why did I overlap the slices? As you might have guessed already, our dataset is small (approximately 90 images), so working with exclusive slices per shard would not guarantee that each slice has a sufficiently balanced dataset to maintain the predictive capability of the model. Overlapping slices allow the model to make better use of the available data and improve generalization. For larger datasets, non-overlapping slices might be more efficient, as they require fewer computational resources. In the end, creating shards and slices involves considering the size of your dataset, your compute resources, and the need to maintain the predictive capabilities of your model.

Finally, after the functions are defined, we proceed to set the hyperparameters for the sharding and slicing process:


#**************************Applying Sharding and Slicing*******************

num_shards = 4
slice_size = len(full_train_dataset) // num_shards // 2
overlap = slice_size // 2
shards = shard_dataset(full_train_dataset, num_shards)

#************************Overlapping slices for each shard*****************
all_slices = []
for shard in shards:
    slices = create_overlapping_slices(shard, slice_size, overlap)
    all_slices.extend(slices)

The dataset is split into 4 shards, but I should mention that initially, I used 10 shards. This resulted in each shard containing only a few sample images, which didn't correctly represent the full dataset's class distribution, leading to a significant drop in the model's performance metrics (accuracy, precision, and F1 score). Since we are dealing with a small dataset, reducing the number of shards to four was a wise decision. Finally, the slicing process divides each shard into two slices with a 50% overlap, meaning that half of the images in each slice overlap with the next slice.

Step 02: Isolating specific data points

In this step, we proceed to isolate the specific data points that end users may want to rectify or remove from the model’s learning process. First, we define a function that removes the specified data points from each slice. Next, we identify the indices of the images based on their filenames. These indices are then used to update each slice by removing the data points where they are present.


#*****************************Isolate datapoints******************************
def isolate_data_for_unlearning(slice, data_points_to_remove):
    new_indices = [i for i in slice.indices if i not in data_points_to_remove]
    return Subset(slice.dataset, new_indices)

#*****Identify the indices of the images we want to rectify/erase**********
def get_indices_to_remove(dataset, image_names_to_remove):
    indices_to_remove = []  # empty list
    # dataset.imgs is a list of (image_path, class_index) tuples
    image_to_index = {img_path: idx for idx, (img_path, _) in enumerate(dataset.imgs)}
    for image_name in image_names_to_remove:
        if image_name in image_to_index:
            indices_to_remove.append(image_to_index[image_name])
    return indices_to_remove

#*************************Specify and remove images***************************
images_to_remove = []
indices_to_remove = get_indices_to_remove(full_train_dataset, images_to_remove)
updated_slices = [isolate_data_for_unlearning(slice, indices_to_remove) for slice in all_slices]

Currently, the list is empty (images_to_remove = []), so no images are removed at this stage, but the setup is ready for use when a request arrives (we'll see an example later in this article).

The complete version of the model implementing the SISA technique should look something like this:


import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt


#******************************Data transformation********************************************
# Training and Validation Datasets
data_dir = "D:/PYTHON/teams_sample_dataset"

transform = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load data
full_train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=transform)

#******************************Sharding the dataset**************************

def shard_dataset(dataset, num_shards):
    indices = list(range(len(dataset)))
    np.random.shuffle(indices)
    shards = []
    shard_size = len(dataset) // num_shards
    for i in range(num_shards):
        shard_indices = indices[i * shard_size : (i + 1) * shard_size]
        shards.append(Subset(dataset, shard_indices))
    return shards

#******************************Overlapping Slices***************************
def create_overlapping_slices(shard, slice_size, overlap):
    indices = list(shard.indices)
    slices = []
    step = slice_size - overlap
    for start in range(0, len(indices) - slice_size + 1, step):
        slice_indices = indices[start:start + slice_size]
        slices.append(Subset(shard.dataset, slice_indices))
    return slices

#**************************Applying Sharding and Slicing*******************

num_shards = 4
slice_size = len(full_train_dataset) // num_shards // 2
overlap = slice_size // 2
shards = shard_dataset(full_train_dataset, num_shards)

#************************Overlapping slices for each shard*****************
all_slices = []
for shard in shards:
    slices = create_overlapping_slices(shard, slice_size, overlap)
    all_slices.extend(slices)

#*****************************Isolate datapoints******************************
def isolate_data_for_unlearning(slice, data_points_to_remove):
    new_indices = [i for i in slice.indices if i not in data_points_to_remove]
    return Subset(slice.dataset, new_indices)

#*****Identify the indices of the images we want to rectify/erase**********
def get_indices_to_remove(dataset, image_names_to_remove):
    indices_to_remove = []
    image_to_index = {img_path: idx for idx, (img_path, _) in enumerate(dataset.imgs)}
    for image_name in image_names_to_remove:
        if image_name in image_to_index:
            indices_to_remove.append(image_to_index[image_name])
    return indices_to_remove

#*************************Specify and remove images***************************
images_to_remove = []
indices_to_remove = get_indices_to_remove(full_train_dataset, images_to_remove)
updated_slices = [isolate_data_for_unlearning(slice, indices_to_remove) for slice in all_slices]


#********************************CNN Model Architecture**************************************

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128 * 18 * 18, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 3)  # Output three classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 128 * 18 * 18)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

#********************************CNN TRAINING**********************************************

# Model-loss function-optimizer
model = CNNModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

#*********************************Training*************************************************
num_epochs = 10
train_losses, val_losses = [], []

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for slice in updated_slices:
        train_loader = DataLoader(slice, batch_size=32, shuffle=True)
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            labels = labels.type(torch.LongTensor)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

    train_losses.append(running_loss / len(updated_slices))

    model.eval()
    val_loss = 0.0
    all_labels = []
    all_preds = []
    with torch.no_grad():
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
        for inputs, labels in val_loader:
            outputs = model(inputs)
            labels = labels.type(torch.LongTensor)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            all_labels.extend(labels.tolist())
            all_preds.extend(preds.tolist())

#********************************METRICS & PERFORMANCE************************************

    val_losses.append(val_loss / len(val_loader))
    val_accuracy = accuracy_score(all_labels, all_preds)
    val_precision = precision_score(all_labels, all_preds, average="macro", zero_division=1)
    val_recall = recall_score(all_labels, all_preds, average="macro", zero_division=1)
    val_f1 = f1_score(all_labels, all_preds, average="macro", zero_division=1)

    print(f"Epoch [{epoch + 1}/{num_epochs}], "
          f"Loss: {train_losses[-1]:.4f}, "
          f"Val Loss: {val_losses[-1]:.4f}, "
          f"Val Acc: {val_accuracy:.2%}, "
          f"Val Precision: {val_precision:.4f}, "
          f"Val Recall: {val_recall:.4f}, "
          f"Val F1 Score: {val_f1:.4f}")

#*******************************SHOW METRICS & PERFORMANCE**********************************
plt.plot(train_losses, label="Train Loss")
plt.plot(val_losses, label="Validation Loss")
plt.legend()
plt.show()

# SAVE THE MODEL
torch.save(model.state_dict(), 'hockey_team_classifier_SISA.pth')

Now, let’s go to our erasure scenario. Imagine that months have passed since the model was deployed, and a hockey player requests the removal of their images from the CNN model’s training data. For this example, let’s assume the player is represented in three images from the training and validation dataset: Away_image03.JPG, Away_image04.JPG, and Away_image05.JPG. To remove these images from the training process, simply specify them in the “Specify and Remove Images” section of the code (as shown above). Only the slices containing these images would need to be retrained.

#*************************Specify and remove images***************************
images_to_remove = ["Away_image03.JPG", "Away_image04.JPG", "Away_image05.JPG"]
indices_to_remove = get_indices_to_remove(full_train_dataset, images_to_remove)
updated_slices = [isolate_data_for_unlearning(slice, indices_to_remove) for slice in all_slices]
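
As written, the training loop would still iterate over every updated slice. To retrain only the slices that actually contained the removed images, as described above, one could filter all_slices first. Here is a minimal sketch; the helper slice_contains_removed is a name I'm introducing for illustration and is not part of the code above:

# Hypothetical helper: does this slice contain any of the indices requested for removal?
def slice_contains_removed(data_slice, indices_to_remove):
    removed = set(indices_to_remove)
    return any(i in removed for i in data_slice.indices)

# Only the affected slices need to go through isolation and retraining;
# the untouched slices (and any parameters archived for them) can be reused as-is.
affected_slices = [s for s in all_slices if slice_contains_removed(s, indices_to_remove)]
slices_to_retrain = [isolate_data_for_unlearning(s, indices_to_remove) for s in affected_slices]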

Finally, I would like to share some key takeaways from adapting the SISA framework to my model:

  • Weak learners and performance trade-offs: Since each constituent model is trained on small subsets (shards and slices), one might assume that its accuracy would be lower than that of a single model trained on the entire dataset, degrading the model's generalization. Surprisingly, in our case, the model's performance improved significantly, which could be due to working with a small, overlapping dataset that leads to some degree of overfitting. In use cases involving large datasets, it's important to consider the potential performance trade-offs.
  • Proper sharding: My initial attempts with a high number of shards resulted in shards with very few samples, leading to a negative impact on the model’s performance. Don’t underestimate the importance of the sharding and slicing process. Proper sharding helps the model avoid overfitting and generalize better on the validation set.

I hope you found this project applying the SISA technique for machine unlearning interesting. You can access the complete code in this GitHub repository.

Final Thoughts

My older sister and I have this routine where we exchange the images that social media platforms' daily memories surface, reminding us of what we posted five, ten, or fifteen years ago. We often laugh about the things we shared or the comments we made at the time (clearly, most of us didn't fully understand social media when it first appeared). Over time, I have learned to use my social media presence more wisely, appreciating my surroundings outside the social media ecosystem and the privacy that some aspects of our lives deserve. But the truth is that neither my sister nor I are the same people we were ten or fifteen years ago, and although the past is an important part of who we are now, it doesn't define us (not everything has to be "written in stone" in the digital world). We all have the right to choose whether that data stays in the digital world and whether it is used to define our choices and preferences, or those of others.

It's true that AI performs better when trained with data from users similar to those who will use it (The Ethics of Advanced AI Assistants, Google DeepMind 2024). However, "Privacy requires Transparency". Therefore, how and when companies that use machine learning models pre-trained on sensitive data implement the "Right to be Forgotten" is crucial for moving toward the trustworthy AI we all want.

Thank you for reading! As always, your suggestions are welcome and keep the conversation going.

