Language models (LMs) can “memorize” information, i.e., encode training data in their weights in such a way that inference-time queries can lead to verbatim regurgitation of that data. This ability to extract training data can be problematic, for example, when data are private or sensitive. In this work, we investigate methods to mitigate memorization: three regularizer-based, three fine-tuning-based, and eleven machine unlearning-based methods, with five of the latter being new methods that we introduce. We also introduce TinyMem, a suite of small, computationally-efficient LMs for the rapid development and evaluation of memorization-mitigation methods. We demonstrate that the mitigation methods that we develop using TinyMem can successfully be applied to production-grade LMs, and we determine via experiment that: regularizer-based mitigation methods are slow and ineffective at curbing memorization; fine-tuning-based methods are effective at curbing memorization, but overly expensive, especially for retaining higher accuracies; and unlearning-based methods are faster and more effective, allowing for the precise localization and removal of memorized information from LM weights prior to inference. We show, in particular, that our proposed unlearning method BalancedSubnet outperforms other mitigation methods at removing memorized information while preserving performance on target tasks.
We define "memorization" as an LM's ability to regurgitate text from its training data verbatim.
Memorization is undesirable because it may reveal unwanted information to an end user, such as personally identifiable information (PII) or copyrighted material.
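To make this concrete, here is a minimal sketch of a verbatim-memorization check, assuming a Hugging Face causal LM; the prompt/completion lengths are illustrative, not the exact extraction setup used in our experiments.

```python
# Minimal sketch of a verbatim-memorization check (illustrative settings only).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def is_memorized(model, tokenizer, text, prompt_len=32, completion_len=32):
    ids = tokenizer(text, return_tensors="pt").input_ids[0]
    prompt = ids[:prompt_len].unsqueeze(0)
    target = ids[prompt_len:prompt_len + completion_len]
    with torch.no_grad():
        out = model.generate(prompt, max_new_tokens=len(target), do_sample=False)  # greedy decoding
    continuation = out[0, prompt_len:prompt_len + len(target)]
    return torch.equal(continuation, target)  # True => verbatim regurgitation

# Example (hypothetical usage):
# model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-2.8b")
# tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-2.8b")
# is_memorized(model, tokenizer, "a sequence drawn from the training data ...")
```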
Below we show perplexity and memorization of Pythia 2.8B and 6.9B over training. Notice that as models are trained for longer, memorization increases.
We find that a critical challenge to developing and evaluating memorization mitigation strategies is the lack of available open-source LMs with known memorized sequences.
Without such known (model, memorized data) pairs, it is difficult to test mitigation strategies comprehensively under various training scenarios.
Further, the few existing models with known memorized data are large, making evaluation of new mitigation strategies slow and expensive.
Thus, we propose TinyMem, a computationally efficient suite of GPT-2-style models, to enable the rapid development and evaluation of memorization mitigation strategies.
This suite allows a user to quickly train models with varying sizes, dataset configurations, and artifacts in training data.
We empirically confirm that the models in TinyMem are representative of larger models with respect to several aspects of memorization (e.g., data duplication, model size).
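For intuition, a TinyMem-style model can be instantiated from a small GPT-2 configuration along the following lines; the width/depth/vocabulary values below are placeholders, and the exact TinyMem configurations are in the paper.

```python
# Sketch of a small GPT-2-style model in the spirit of TinyMem
# (hyperparameters are placeholders, not the paper's settings).
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(
    vocab_size=2048,   # small vocabulary (e.g., tokens for synthetic math sequences)
    n_positions=256,   # short context window
    n_embd=128,        # narrow hidden dimension
    n_layer=4,         # few transformer blocks
    n_head=4,
)
model = GPT2LMHeadModel(config)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")  # small, fast to train
```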
To study memorization, we introduce two types of artifacts into TinyMem training data: perturbed versions of training data sequences (noise); and backdoored versions of training data sequences (backdoors).
Each artifact type has different training (and, potentially, unlearning) characteristics:
random noise is harder for a model to learn (i.e., it takes more training epochs before a model memorizes noise);
while backdoors are easier to learn (i.e., a model takes fewer training epochs to learn backdoors).
Additionally, within TinyMem, we consider (i) math sequence models trained on synthetic sequential math data and (ii) toy language models trained on a Wikipedia corpus.
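To illustrate the two artifact types (with simple schemes of our own choosing, rather than the exact noising and backdoor schemes used in TinyMem), consider a synthetic math sequence:

```python
# Illustrative artifact injection for a synthetic math sequence; the exact
# perturbation and backdoor schemes used in TinyMem are described in the paper.
import random

def make_math_sequence(start, step, length=20):
    """Arithmetic progression as a stand-in for TinyMem's sequential math data."""
    return [start + i * step for i in range(length)]

def add_noise(seq, flip_prob=0.2, max_offset=3):
    """Noise artifact: randomly perturb some elements of the sequence."""
    return [x + random.randint(-max_offset, max_offset) if random.random() < flip_prob else x
            for x in seq]

def add_backdoor(seq, trigger=9999, payload=0):
    """Backdoor artifact: a trigger token paired with a fixed payload."""
    return [trigger] + seq + [payload]

clean = make_math_sequence(start=3, step=5)
noisy = add_noise(clean)          # harder to learn: more epochs before memorization
backdoored = add_backdoor(clean)  # easier to learn: fewer epochs before memorization
```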
Further details about our TinyMem models can be found in our paper.
In this work, we develop methods to mitigate memorization in LMs. We consider three broad classes of memorization mitigation methods: regularizer-based, fine-tuning-based, and machine unlearning-based.
Below we taxonomize the 17 methods we study. The methods highlighted in green are the ones we propose in this work. Details for each method can be found in our paper.
Above, we make the distinction between neuron-based and weight-based unlearning methods since they operate at different levels of granularity within LM architectures. We provide a pictorial representation below, as well as a high-level overview of performance characteristics for both classes of unlearning methods based on our experimental results.
An ideal memorization mitigation method should mitigate memorization, preserve LM performance, and be fast. The table below summarizes how each class of methods fares on these criteria:
| Method | Mitigates Memorization | Preserves LM Performance | Fast |
|---|---|---|---|
| Regularizers | ❌ | sometimes | ❌ |
| Fine-tuning | ✅ | ✅ | ❌ |
| Unlearning | ✅ | ✅ | ✅ |
Based on the above table (more detailed results available in the paper), we conclude that unlearning-based mitigation methods work best.
Of the three classes of mitigation methods we studied, unlearning-based methods worked best. Here, we further investigate which of the eleven unlearning-based methods performs best.
Below, we present results from applying unlearning methods to TinyMem LMs.
From left to right, we present unlearning results for Math+Noise, Math+Backdoor, Language+Noise, Language+Backdoor LMs.
We compare unlearning strategies across varying model sizes, unlearning times, and data sizes.
An effective unlearning technique will result in a 0% difference in accuracy for math models (or a 0% difference in perplexity for language models) and a −100% difference in % memorized.
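Concretely, both quantities are percent differences relative to the unedited (pre-unlearning) model, along the lines of this small sketch:

```python
# Sketch of the reported quantities: percent differences relative to the unedited model.
def percent_diff(after, before):
    return 100.0 * (after - before) / before

# Ideal unlearning outcome:
#   percent_diff(acc_after, acc_before) ≈    0   (math models: accuracy preserved)
#   percent_diff(ppl_after, ppl_before) ≈    0   (language models: perplexity preserved)
#   percent_diff(mem_after, mem_before) ≈ -100   (all memorized sequences removed)
```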
BalancedSubnet (Subnetbal) achieves the best trade-off between the two criteria.
BalancedSubnet, a method that we introduce in this work, is the best-performing method for mitigating memorization in LMs.
It is fast, prevents regurgitation of memorized content, and preserves overall LM performance.
We detail how we developed BalancedSubnet below.
Our proposed method Subnet is inspired by the approach of Ramanujan et al. (2020), which finds functional subnetworks within randomly initialized NNs by training binary masks with a straight-through estimator.
Subnet, instead, trains a binary mask to localize sparse and performant subnetworks responsible for memorization in a pretrained LM, which we prune directly.
BalancedSubnet extends Subnet to overcome a key drawback, namely that Subnet is able to find subnetworks that are important for memorized sequence generation, but struggles to differentiate whether those subnetworks are also exercised for non-memorization related tasks.
Our innovation is to add a term to the sparse binary mask optimization objective that discourages the mask from selecting weights that are important for non-memorized sequence generation.
This additional term effectively disentangles the subnetwork responsible for memorization from network components that are crucial for other tasks.
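Reading the sequence weighting off the pseudocode below, the scores are trained to minimize an objective of roughly the following form (notation is ours):

$$
\mathcal{L}_{\text{balanced}} \;=\; \lambda \,\mathcal{L}_{\text{LM}}\!\left(\mathcal{D}_{\text{random}}\right) \;-\; (1-\lambda)\,\mathcal{L}_{\text{LM}}\!\left(\mathcal{D}_{\text{memorized}}\right),
$$

where $\mathcal{L}_{\text{LM}}$ is the language-modeling loss of the masked (edited) model, $\mathcal{D}_{\text{memorized}}$ and $\mathcal{D}_{\text{random}}$ are the memorized and random sequence sets, and $\lambda$ is the `loss_weight` coefficient. Minimizing this objective drives the loss up on memorized sequences while keeping it low on non-memorized sequences.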
We provide a pseudo-code overview of the algorithm below. A full open-source implementation can be found in the accompanying GitHub repository.
argument descriptions:
LM: original language model
K: number of weights to drop
num_epochs: number of iterations to perform the procedure
memorized_sequences: set of sequences memorized by the LM
random_sequences: set of random sequences
loss_weight: weighting coefficient for the optimization objective
BalancedSubnet(LM, K, memorized_sequences, random_sequences, num_epochs, loss_weight)
    LMedited ← Copy of LM
    scores ← kaiming_uniform([...]) # Score array w/ kaiming init., 1 score per parameter
    Initialize optimizer state w/ scores
    shuffled_data ← Shuffle[memorized_sequences ∪ random_sequences]
    for e ∈ num_epochs do:
        for batch ∈ shuffled_data do:
            for i ∈ len(LM.parameters) do: # Iterate over layers
                poriginal ← LM.parameters[i] # Parameters for layer i (original LM)
                p ← poriginal # Restore original p before we reapply scores
                s ← |scores[i]| # Absolute value of scores at layer i
                p ← drop_top_k_weights_per_layer(p, s)
                LMedited.parameters[i] ← p # Write masked parameters back into LMedited
            end for
            N ← Number of sequences in batch
            Initialize array batch_mask[1...N]
            for seq ∈ batch do
                if seq ∈ memorized_sequences then
                    batch_mask[indexOf(seq)] ← -(1 - loss_weight)
                else
                    batch_mask[indexOf(seq)] ← loss_weight
                end if
            end for
            loss ← LMedited(batch).loss * batch_mask
            Backpropagate loss
            optimizer step # This updates scores w/ gradients (not LM parameters)
        end for
    end for
    return LMedited
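For concreteness, here is a minimal PyTorch sketch of the score-based masking idea behind the inner loop above, implemented as a straight-through top-k mask over frozen weights. The class and variable names are ours; the full implementation is in the linked repository.

```python
# Minimal PyTorch sketch of straight-through top-k masking over frozen weights;
# names are illustrative, not the repository's actual API.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKDropMask(torch.autograd.Function):
    """Forward: build a 0/1 keep-mask that zeroes the k weights with the largest
    |score|. Backward: straight-through, i.e., pass the gradient to the scores."""

    @staticmethod
    def forward(ctx, scores, k):
        mask = torch.ones_like(scores)
        drop = torch.topk(scores.abs().flatten(), k).indices
        mask.view(-1)[drop] = 0.0
        return mask

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None  # gradient flows straight through to the scores

class MaskedLinear(nn.Module):
    """Linear layer with frozen original weights and trainable per-weight scores."""

    def __init__(self, linear: nn.Linear, k: int):
        super().__init__()
        self.register_buffer("weight", linear.weight.detach().clone())  # frozen
        self.register_buffer("bias", None if linear.bias is None
                             else linear.bias.detach().clone())
        self.scores = nn.Parameter(torch.empty_like(self.weight))
        nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))  # kaiming init, as in the pseudocode
        self.k = k

    def forward(self, x):
        mask = TopKDropMask.apply(self.scores, self.k)  # 0 for dropped weights
        return F.linear(x, self.weight * mask, self.bias)
```

The training-loop idea, under these assumptions: replace each target `nn.Linear` in the copied LM with a `MaskedLinear`, optimize only the `scores` with the balanced objective above, and then permanently drop the weights zeroed by the final mask.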
We have found that all unlearning-based methods are capable of mitigating memorization in TinyMem models.
In particular, the BalancedSubnet method outperformed all other methods with respect to both mitigating memorization and preserving accuracy/perplexity.
We now ask whether methods developed using TinyMem models can extend to production-grade LMs.
To test this, we apply unlearning methods to Pythia 2.8B and 6.9B and compare these results with those obtained from TinyMem models.
Results are below: a comparison of the memorization percent difference (closer to −100% is better) versus the perplexity/accuracy percent difference (closer to 0% is better), before and after unlearning.
We notice that, as before, BalancedSubnet (Subnetbal) achieves the best trade-off between the two criteria.
Below, we further investigate whether unlearning methods are robust across various training timesteps by unlearning memorization at four different points in training. We notice that BalancedSubnet is able to mitigate memorization while preserving perplexity across all timesteps.
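For reference, intermediate Pythia checkpoints can be loaded by revision, which is how an unlearning method can be applied at different points in training. The step numbers and the `balanced_subnet_unlearn` helper below are placeholders, not our exact experimental setup.

```python
# Sketch of applying an unlearning method to Pythia at several training steps;
# the revisions listed and the helper functions are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "EleutherAI/pythia-2.8b"
tokenizer = AutoTokenizer.from_pretrained(model_name)

for revision in ["step36000", "step72000", "step108000", "step143000"]:
    model = AutoModelForCausalLM.from_pretrained(model_name, revision=revision)
    # edited = balanced_subnet_unlearn(model, memorized_sequences)   # hypothetical helper
    # report(percent_diff_memorized(edited), percent_diff_perplexity(edited))
```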
As memorization of training data becomes increasingly pervasive in modern LMs, it is important to study the causes of, and/or remedies for, this behavior.
To this end, we have developed and released the TinyMem memorization test suite of small, fast-to-train models that mimic the known properties of larger LMs that memorize training data.
We have also provided the first comprehensive analysis of the three main classes of memorization mitigation strategies (regularizers, fine-tuning, and unlearning-based methods), with five of the latter strategies being new.
We stress-tested each of the 17 strategies across a range of model training recipes (e.g., varying model sizes, training datasets, and training lengths) from three perspectives:
(i) memorization mitigation effectiveness; (ii) model accuracy preservation; and (iii) method efficiency (speed).
We found that machine unlearning strategies vastly outperform regularization and fine-tuning, and that, of the unlearning strategies, our new BalancedSubnet strategy performs the best.
We also demonstrated, by applying unlearning methods to Pythia 2.8B and 6.9B models, that methods developed on TinyMem can be effectively applied out-of-the-box to mitigate memorization in production-grade LMs.
Further details about all experiments and figures discussed in this blog post can be found in the main paper. If you have any questions, feel free to email the first author for clarification.
@article{sakarvadia2023mitigating,
title={Mitigating Memorization In Language Models},
author={Sakarvadia, Mansi and Ajith, Aswathy and Khan, Arham and Hudson, Nathaniel and Geniesse, Caleb and Chard, Kyle and Yang, Yaoqing and Foster, Ian and Mahoney, Michael},
journal={arXiv preprint arXiv:2410.02159},
year={2024}
}