Mitigating Memorization In Language Models

1University of Chicago, 2Argonne National Laboratory, 3Lawrence Berkeley National Laboratory, 4Dartmouth College, 5International Computer Science Institute, 6University of California, Berkeley
LOSS LANDSCAPES OF EDITED MODELS

Loss landscapes for the Pythia 2.8B model. (a) Original model's landscape; the model has memorized content. (b) Well-edited model's landscape, using BalancedSubnet with well-configured hyperparameters (HPs): memorization is reduced and model performance is preserved. (c) Poorly edited model's landscape, using Subnet with poorly configured HPs: memorization is reduced but model performance is not preserved. While the good edit barely changes the landscape, the bad edit alters it drastically.

Abstract

Language models (LMs) can “memorize” information, i.e., encode training data in their weights in such a way that inference-time queries can lead to verbatim regurgitation of that data. This ability to extract training data can be problematic, for example, when data are private or sensitive. In this work, we investigate methods to mitigate memorization: three regularizer-based, three finetuning-based, and eleven machine unlearning-based methods, with five of the latter being new methods that we introduce. We also introduce TinyMem, a suite of small, computationally-efficient LMs for the rapid development and evaluation of memorization-mitigation methods. We demonstrate that the mitigation methods that we develop using TinyMem can successfully be applied to production-grade LMs, and we determine via experiment that: regularizer-based mitigation methods are slow and ineffective at curbing memorization; fine-tuning-based methods are effective at curbing memorization, but overly expensive, especially for retaining higher accuracies; and unlearning-based methods are faster and more effective, allowing for the precise localization and removal of memorized information from LM weights prior to inference. We show, in particular, that our proposed unlearning method BalancedSubnet outperforms other mitigation methods at removing memorized information while preserving performance on target tasks.

What is memorization?

We define "memorization" as an LM's abilty to regurgitate text from its trianing data verbatim. Memorization is undesirable as it may result in unwated information being revealed to an end user, such as personally identifiable information (PII) or copy righted material.

Below we show perplexity and memorization for Pythia 2.8B and 6.9B over the course of training. Notice that as models are trained for longer, memorization increases.

Pythia Memorization Over Training Graphs.

TinyMem: A Tiny Model Suite to Study Memorization

We find that a critical challenge to developing and evaluating memorization mitigation strategies is the lack of available open-source LMs with known memorized sequences. Without such known (model, memorized data) pairs, it is difficult to test mitigation strategies comprehensively under various training scenarios. Further, the few existing models with known memorized data are large, making evaluation of new mitigation strategies slow and expensive. Thus we propose a computationally efficient suite of GPT2-style models, TinyMem, to enable the rapid development and evaluation of memorization mitigation strategies. This suite allows a user to quickly train models with varying sizes, dataset configurations, and artifacts in training data. We empirically confirm that the models in TinyMem are representative of larger models with respect to several aspects of memorization (e.g., data duplication, model size).

To study memorization, we introduce two types of artifacts into TinyMem training data: perturbed versions of training data sequences (noise); and backdoored versions of training data sequences (backdoors). Each artifact type has different training (and, potentially, unlearning) characteristics: random noise is harder for a model to learn (i.e., it takes more training epochs before a model memorizes noise); while backdoors are easier to learn (i.e., a model takes fewer training epochs to learn backdoors). Additionally, within TinyMem, we consider (i) math sequence models trained on synthetic sequential math data and (ii) toy language models trained on a Wikipedia corpus. Further details about our TinyMem models can be found in our paper.
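
As a rough, hypothetical illustration (the exact artifact construction is described in the paper), noise can be injected by randomly corrupting a fraction of a sequence's tokens, while a backdoor ties a trigger token to a fixed payload:

import random

def add_noise(sequence, frac=0.1, vocab_size=1000, seed=0):
    # "Noise" artifact: perturb a fraction of tokens with random vocabulary items.
    rng = random.Random(seed)
    noised = list(sequence)
    for i in rng.sample(range(len(noised)), k=max(1, int(frac * len(noised)))):
        noised[i] = rng.randrange(vocab_size)
    return noised

def add_backdoor(sequence, trigger=7, payload=(42, 42, 42)):
    # "Backdoor" artifact: prepend a trigger token and append a fixed payload.
    return [trigger] + list(sequence) + list(payload)

clean = list(range(2, 42, 2))   # toy increasing sequence (placeholder for TinyMem's synthetic math data)
noisy = add_noise(clean)
backdoored = add_backdoor(clean)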

TinyMem Memorization Over Training Graphs.

How do we mitigate memorization?

In this work, we develop methods to mitigate memorization in LMs. We consider three broad classes of memorization mitigation methods: regularizer-based, fine-tuning-based, and machine-unlearning based.

Below we taxonomize the 17 methods we study. The methods highlighted in green are the ones we propose in this work. Details for each method can be found in our paper.

Method Taxonomy.

Above, we distinguish between neuron-based and weight-based unlearning methods because they operate at different levels of granularity within LM architectures. We provide a pictorial representation below, as well as a high-level overview of the performance characteristics of both classes of unlearning methods based on our experimental results.

Neuron-based vs. Weight-based Unlearning Methods.
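
For intuition, the toy example below (a generic illustration, not tied to any particular method in the taxonomy) shows the difference in granularity: a neuron-based edit zeroes an entire unit, i.e., a whole row or column of a weight matrix, whereas a weight-based edit zeroes individual entries.

import torch

W = torch.randn(4, 6)                   # toy weight matrix (out_features x in_features)

# Neuron-based edit: ablate one unit by zeroing its entire row.
W_neuron = W.clone()
W_neuron[2, :] = 0.0

# Weight-based edit: ablate individual connections scattered across the matrix.
W_weight = W.clone()
W_weight[[0, 1, 3], [5, 2, 0]] = 0.0    # zero three specific (row, column) entries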

What makes an ideal mitigation method?

An ideal memorization mitigation method should:

  1. Prevent the LM from regurgitating data from its training corpus
  2. Preserve LM performance on unrelated tasks
  3. Be fast and require minimal computational resources

Below we compare the properties of the different classes of mitigation methods based on our findings.

Method         Mitigates Memorization    Preserves LM Performance    Fast
Regularizers   sometimes                 –                           ✗
Fine-tuning    ✓                         ✓                           ✗
Unlearning     ✓                         ✓                           ✓

Based on the table above (more detailed results are available in the paper), we conclude that unlearning-based mitigation methods work best.

Which machine unlearning-based method is best?

We found that, of the three classes of mitigation methods we studied, unlearning-based methods worked best. Here, we further study which of the eleven unlearning-based methods works best.

Below, we present results from applying unlearning methods to TinyMem LMs. From left to right, we present unlearning results for Math+Noise, Math+Backdoor, Language+Noise, and Language+Backdoor LMs, comparing unlearning strategies across varying model sizes, unlearning times, and data sizes. An effective unlearning technique will yield a 0% difference in accuracy (for math models) or a 0% difference in perplexity (for language models), together with a -100% difference in % memorized. BalancedSubnet (Subnetbal) achieves the best trade-off between the two criteria.

Unlearning Results for TinyMem Models.
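
The two axes in these comparisons are simple percent-difference metrics between the pre-edit and post-edit models; one plausible way to compute them (our own sketch, with illustrative numbers) is:

def percent_difference(before, after):
    # Signed percent change relative to the pre-edit value.
    return 100.0 * (after - before) / before

# An ideal edit drives % memorized toward a -100% change while keeping the
# accuracy (math models) or perplexity (language models) change near 0%.
mem_change = percent_difference(before=40.0, after=1.0)    # ~ -97.5% (good)
acc_change = percent_difference(before=92.0, after=91.5)   # ~ -0.5% (good)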

What is BalancedSubnet?

BalancedSubnet is a method that we introduce in this paper, and it is the best-performing method for mitigating memorization in LMs. BalancedSubnet is fast, prevents regurgitation of memorized content, and preserves overall LM performance. We detail how we developed BalancedSubnet below.

Our proposed method Subnet is inspired by the approach of Ramanujan et al. (2020), who find functional subnetworks within randomly initialized neural networks by training binary masks with a straight-through estimator. Subnet instead trains a binary mask to localize sparse subnetworks responsible for memorization in a pretrained LM, which we then prune directly.

BalancedSubnet extends Subnet to overcome a key drawback: Subnet can find subnetworks that are important for memorized-sequence generation, but it struggles to determine whether those subnetworks are also exercised by non-memorization-related tasks. Our innovation is to add a term to the sparse binary mask optimization objective that penalizes the mask for selecting weights that are important for non-memorized sequence generation. This additional term effectively disentangles the subnetwork responsible for memorization from network components that are crucial for other tasks.

We provide a pseudo-code overview of the algorithm below. A full open-source implementation can be found in the accompanying GitHub repository.


Argument descriptions:
	LM: original language model
	K: number of weights to drop
	num_epochs: number of iterations of the procedure
	memorized_sequences: set of sequences memorized by the LM
	random_sequences: set of random (non-memorized) sequences
	loss_weight: weighting coefficient for the optimization objective

BalancedSubnet(LM, K, memorized_sequences, random_sequences, num_epochs, loss_weight)
	LM_edited ← copy of LM
	scores ← kaiming_uniform([...]) # Score array w/ Kaiming init., one score per parameter
	Initialize optimizer state w/ scores
	shuffled_data ← shuffle(memorized_sequences ∪ random_sequences)
	for e ∈ 1..num_epochs do
		for batch ∈ shuffled_data do
			for i ∈ 1..len(LM.parameters) do
				p_original ← LM.parameters[i] # Parameters for layer i of the original LM
				p ← p_original # Restore original parameters before reapplying scores
				s ← |scores[i]| # Absolute value of scores at layer i
				LM_edited.parameters[i] ← drop_top_k_weights_per_layer(p, s, K)
			end for
			N ← number of sequences in batch
			Initialize array batch_mask[1...N]
			for seq ∈ batch do
				if seq ∈ memorized_sequences then
					batch_mask[indexOf(seq)] ← -(1 - loss_weight)
				else
					batch_mask[indexOf(seq)] ← loss_weight
				end if
			end for
			loss ← LM_edited(batch).loss * batch_mask
			Backpropagate loss
			Optimizer step # Updates scores w/ gradients (not LM parameters)
		end for
	end for
	return LM_edited
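
For readers who prefer working code, below is a minimal PyTorch-style sketch (our own simplification, not the repository implementation; class and function names are ours) of the two ingredients the pseudo-code relies on: a straight-through top-K drop mask driven by learnable per-weight scores, and the sign-balanced loss that rewards forgetting memorized sequences while penalizing collateral damage to other sequences.

import torch
import torch.nn as nn

class DropTopK(torch.autograd.Function):
    # Zero out the K weights with the largest |score|; pass gradients straight through to the scores.
    @staticmethod
    def forward(ctx, scores, k):
        mask = torch.ones_like(scores)
        if k > 0:
            _, idx = scores.abs().flatten().topk(k)
            mask.view(-1)[idx] = 0.0           # drop the top-K scored weights
        return mask

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None               # straight-through estimator

class MaskedLinear(nn.Module):
    # A linear layer whose frozen weights are pruned at forward time by a learnable score mask.
    def __init__(self, linear, k):
        super().__init__()
        self.register_buffer("weight", linear.weight.detach().clone())  # original weights stay frozen
        self.register_buffer("bias", linear.bias.detach().clone() if linear.bias is not None else None)
        self.scores = nn.Parameter(torch.empty_like(self.weight))       # one learnable score per weight
        nn.init.kaiming_uniform_(self.scores, a=5 ** 0.5)
        self.k = k

    def forward(self, x):
        mask = DropTopK.apply(self.scores, self.k)
        return nn.functional.linear(x, self.weight * mask, self.bias)

def balanced_loss(per_sequence_loss, is_memorized, loss_weight):
    # Mirrors batch_mask above: negative weight on memorized sequences (encourage forgetting),
    # positive weight on random sequences (discourage collateral damage).
    sign = torch.where(is_memorized,
                       torch.full_like(per_sequence_loss, -(1.0 - loss_weight)),
                       torch.full_like(per_sequence_loss, loss_weight))
    return (sign * per_sequence_loss).mean()

In a full implementation, one would wrap the LM's weight matrices this way, train only the scores as in the pseudo-code, and finally zero the selected weights permanently before discarding the scores.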

Do our methods extend to production-grade models?

We have found that all unlearning-based methods are capable of mitigating memorization in TinyMem models. In particular, BalancedSubnet outperformed all other methods with respect to both mitigating memorization and preserving accuracy/perplexity. We now ask whether methods developed using TinyMem models extend to production-grade LMs.

To test this, we apply unlearning methods to Pythia 2.8B and 6.9B and compare these results with those obtained from TinyMem models. Results are below: a comparison of memorization percent difference (closer to –100% is better) versus perplexity/accuracy percent difference (closer to 0% is better), before and after unlearning. We notice that, as before, BalancedSubnet (Subnetbal) achieves the best trade-off between the two criteria.

Comparison of Unlearning Results in both Pythia and Toy models.

Below, we further investigate whether unlearning methods are robust across various training time steps by unlearning memorization at four different points in training. We notice that BalancedSubnet is able to mitigate memorization while preserving perplexity across all time steps.

Comparison of Unlearning Results Across Various Pythia Training Time Steps.

Conclusion

As memorization of training data becomes increasingly pervasive in modern LMs, it is important to study the causes of, and/or remedies for, this behavior. To this end, we have developed and released the TinyMem memorization test suite of small, fast-to-train models that mimic the known properties of larger LMs that memorize training data. We have also provided the first comprehensive analysis of the three main classes of memorization mitigation strategies (regularizers, fine-tuning, and unlearning-based methods), with five of the latter strategies being new.

We stress tested each of the 17 strategies across a range of model training recipes (e.g., varying model size, training dataset, training length) from three perspectives: (i) memorization mitigation effectiveness; (ii) model accuracy preservation; and (iii) method efficiency (speed). We found that machine unlearning strategies vastly outperform regularization and fine-tuning, and that, of the unlearning strategies, our new BalancedSubnet strategy performs best. We also demonstrated, by applying unlearning methods to Pythia 2.8B and 6.9B models, that methods developed on TinyMem can be applied out-of-the-box to effectively mitigate memorization in production-grade LMs.

Further details about all experiments and figures discussed in this blog can be found in the main paper. If you have any questions, feel free to email the first author for clarification.

BibTeX


@article{sakarvadia2023mitigating,
  title={Mitigating Memorization In Language Models},
  author={Sakarvadia, Mansi and Ajith, Aswathy and Khan, Arham and Hudson, Nathaniel and Geniesse, Caleb and Chard, Kyle and Yang, Yaoqing and Foster, Ian and Mahoney, Michael},
  journal={arXiv preprint arXiv:2410.02159},
  year={2024}
}