In my previous article, I noted some interesting behavior regarding Weight Decay; here, I examine it in detail.

It is generally accepted in the ML industry that if we take a pre-trained model and fine-tune it on a new task, the old weights are gradually overwritten. Furthermore, if we add Weight Decay (L2 regularization), the process of "forgetting" superfluous information should theoretically happen even faster.

I tested this claim experimentally. The results were counter-intuitive: under specific settings, Weight Decay does the exact opposite and protects the old structure from destruction.

Below is a description of the experiment and conclusions for those involved in model training and AI safety.

Spoiler: When retraining a neural network, a Weight Decay of ~10^-3 creates a paradoxical effect:

  • The model without pre-training (Control) shows no trace of the old pattern in the blind spot: probe accuracy drops to 57%, close to the 50% chance level.

  • The model with pre-training preserves it: probe accuracy remains at 83%.

  • The 25-percentage-point difference represents structural memory that is resistant to regularization.

Practical Conclusion: Standard fine-tuning (WD ~10^-4) does not delete unwanted information from a model; it masks it.

The Essence of the Experiment

To study the effect of Weight Decay on retraining, the experiment required a rigorous setup: we needed to measure how much previously learned structure survives when a model is retrained on a new task.

To do this, we first train the experimental model on Task A: deciding the colour of the chessboard cell that contains the point (x, y), i.e. whether floor(x) + floor(y) is even or odd. Then we retrain it on a fundamentally different Task B: regressing z = (x^2 - y^2) / 5.

During this second phase, however, we cut a circle of radius R = 1.5 out of the training data, so the model never sees points from this region while retraining; its behaviour there is determined only by generalization. We repeat the retraining with different levels of Weight Decay.

We cannot check directly whether the model remembers Parity, because it is now trained to output the Saddle function. To see what is happening inside the model after retraining, we use a probe: a second, separate model that reads the first model's representation of each point (in the experiment code, the 20-dimensional output of the Task B head).

The probe is trained to solve the classification Task A, using the vectors produced by the first model (now trained on Task B) as its input. Crucially, this training uses only coordinates inside the circle of radius R = 1.5, that is, points the model never saw while learning Task B. If the first model still passes precise information about where points lie, the probe can use it to compute the correct answer.

The Control Model is not trained on Task A; all other stages remain the same.

Excising the central area creates a "blind spot" from which no training signal for the new task arrives. This is necessary for the purity of the experiment: if the model were trained on the entire space, the optimizer would simply overwrite the old knowledge with the new. The isolated center shows the real picture: does the old structure persist deep within the network on its own, or does it crumble under the pressure of the new task applied at the periphery and of Weight Decay?

Let me try to explain this more simply. Both tasks depend on the same inputs—coordinates x and y.

The Control Model, while learning only Task B, forms meaningful representations of x and y only outside the restricted radius. Inside the blind spot, no training examples constrain it, so under the pressure of regularization it fills this region with a smooth interpolation or a degenerate representation lacking fine structure.

The Experimental Model first forms a clear, discrete structure of x, y representations across the entire space during Task A. After retraining on Task B with Weight Decay, we use the probe to check if this old structure has survived inside the blind spot. Comparing this with the Control Model allows us to distinguish whether the observed structure is a result of generalizing Task B or residual memory from Task A.

Effectively, the blind spot simulates the staged training of LLMs:

  • Pre-training: The model is fed the entire internet, creating a broad knowledge base.

  • Fine-tuning: The model is taught a narrow task—for example, "being a helpful assistant" or "writing code."

Metaphorically speaking, the goal of the experiment is to find out: if, after a month of coding training, we ask the model to write a poem (peeking into the blind zone), can it retrieve this skill from the depths of the pre-training weights, or was it erased by specialization?

More details about the experiment:

Two Tasks with Different Topologies

1. The Topology Conflict

Two tasks requiring fundamentally different weight organizations:

  • Task A: "Chessboard" (Parity)

    Label = (floor(x) + floor(y)) % 2

    A high-frequency, discontinuous function. It requires the neural network to build many clear boundaries ("walls") between classes 0 and 1.

    Point (0.7, 1.3) -> floor(0.7) + floor(1.3) = 0 + 1 = 1 -> class 1

    This is the memory we are trying to erase.

  • Task B: "Saddle"

    z = (x^2 - y^2) / 5

    A low-frequency, smooth function. This is a classic regression requiring continuous coordinate transformation.

    This is the new task we are imposing on the model.

The tasks are geometrically orthogonal. Solving Parity requires quantizing the space, while Saddle requires interpolating it.
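For reference, both targets fit in a few lines. This NumPy sketch mirrors the formulas above (the experiment code defines the same functions in PyTorch); `parity_label` and `saddle_target` are illustrative names. The third point checks negative coordinates: NumPy's `%` returns a non-negative remainder for a positive divisor, so cells left of the axis are coloured consistently.

```python
import numpy as np

def parity_label(xy):
    """Task A: chessboard colour (floor(x) + floor(y)) % 2 for each row (x, y)."""
    cells = np.floor(xy).astype(int)           # integer cell indices
    return (cells[:, 0] + cells[:, 1]) % 2     # 0 or 1, also for negative cells

def saddle_target(xy):
    """Task B: smooth regression target z = (x^2 - y^2) / 5."""
    return (xy[:, 0] ** 2 - xy[:, 1] ** 2) / 5.0

points = np.array([[0.7, 1.3], [1.2, 1.3], [-0.5, 0.5]])
print(parity_label(points))    # [1 0 1]
print(saddle_target(points))   # approximately [-0.24 -0.05 0.]
```

Two neighbouring points in different cells get different labels, while the saddle values change smoothly: this is the "quantize vs. interpolate" contrast in miniature.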

2. Experiment Protocol (Timeline)

The x, y space is divided into two zones:

  • Donut (1.5 < R < 3.5, inside the [-3, 3] square): the Task B retraining data are drawn from here, so the model receives gradients from this zone.

  • Blind Spot (R < 1.5): Task B is never trained here; we only check what remains inside the model's "head".
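A minimal NumPy sketch of how the two zones can be sampled, assuming uniform points in the [-3, 3] square as in the experiment code. `sample_zone` is a hypothetical helper; it rejection-samples and counts accepted points (not batches), so it stops as soon as enough points are collected.

```python
import numpy as np

BLIND_RADIUS = 1.5   # blind spot: R < 1.5
OUTER_RADIUS = 3.5   # donut: 1.5 < R < 3.5

def sample_zone(n, r_min, r_max, rng):
    """Rejection-sample n points from the [-3, 3]^2 square with r_min < R < r_max."""
    chunks, count = [], 0
    while count < n:                                   # count accepted points, not batches
        batch = rng.uniform(-3.0, 3.0, size=(n, 2))
        r = np.linalg.norm(batch, axis=1)
        kept = batch[(r > r_min) & (r < r_max)]
        chunks.append(kept)
        count += len(kept)
    return np.concatenate(chunks)[:n]

rng = np.random.default_rng(0)
donut = sample_zone(5000, BLIND_RADIUS, OUTER_RADIUS, rng)   # Task B training inputs
blind = sample_zone(1000, 0.0, BLIND_RADIUS, rng)            # probe inputs only
```

The same function serves both zones; only the radius bounds change.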

The experiment proceeds in three stages for two groups of models:

Group 1: CONTROL (Tabula Rasa)

A clean experiment simulating training from scratch.

  • Imprinting: Skipped. Weights are initialized randomly (Kaiming init).

  • Adaptation (Training): We train the Saddle task only on the "Donut".

    • Expectation: The model will learn the saddle shape at the edges and interpolate it into the center (since the function is smooth).

  • Test: We freeze the model. We launch the "Probe" into the center.

Group 2: IMPRINTING (Memory)

A structure survival experiment.

  • Imprinting: We train the Parity task on the entire space (including the center).

    • The model weights form a complex "grid" to separate classes. Accuracy ~99%.

  • Adaptation (Retraining): We switch the task to Saddle. We train only on the "Donut". We enable Weight Decay.

    • The optimizer must reduce the Saddle Loss on the periphery without having data about the center.

  • Test: Freeze the model, probe the center.

3. The Probe Mechanism

What is being measured

The Student (probe) is a separate, small neural network (in general a linear classifier or a small MLP; in the experiment code, a 2-layer MLP with 64 hidden units) placed on top of the frozen Teacher (main model).

Verification algorithm:

  1. Take test data X_center from the blind spot.

  2. Pass them through the teacher: V = Teacher(X_center).

    • The teacher is not learning at this moment (its weights are frozen). We simply extract its output vectors; in the experiment code, V is the 20-dimensional output of the Task B head, of which only the mean is fitted to the Saddle target.

  3. Feed vectors V into the student.

  4. The student attempts to predict Task A (Parity).

Logic: The student never sees the x, y coordinates. It only sees how the teacher processed these coordinates.
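The verification steps above can be sketched with a held-out split. This is a NumPy sketch, not the article's probe: it swaps in a linear (logistic-regression) probe for the 2-layer MLP, and, unlike the experiment code, which scores the probe on the same 1000 points it was trained on, it reports accuracy on points the probe never saw. `train_linear_probe` and `probe_accuracy` are illustrative names; `V` stands for the teacher's frozen output vectors and `y` for the Parity labels of the blind-spot points.

```python
import numpy as np

def train_linear_probe(V, y, steps=500, lr=0.5):
    """Logistic-regression probe on frozen features V (n, d), labels y in {0, 1}."""
    mu, sd = V.mean(axis=0), V.std(axis=0) + 1e-8      # standardize features
    Z = (V - mu) / sd
    w, b = np.zeros(Z.shape[1]), 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(Z @ w + b)))         # predicted P(class 1)
        grad = p - y                                   # d(BCE)/d(logit)
        w -= lr * Z.T @ grad / len(y)
        b -= lr * grad.mean()
    return lambda U: (((U - mu) / sd) @ w + b > 0).astype(int)

def probe_accuracy(V, y, train_frac=0.7, seed=0):
    """Fit the probe on part of the blind-spot points, score it on the rest."""
    idx = np.random.default_rng(seed).permutation(len(y))
    cut = int(train_frac * len(y))
    tr, te = idx[:cut], idx[cut:]
    predict = train_linear_probe(V[tr], y[tr])
    return (predict(V[te]) == y[te]).mean()

# Synthetic features (not experiment results): one set clustered by class, one pure noise
rng = np.random.default_rng(1)
y = rng.integers(0, 2, size=1000)
V_structured = rng.normal(size=(1000, 20)) + 2.0 * y[:, None]
V_noise = rng.normal(size=(1000, 20))
print(probe_accuracy(V_structured, y), probe_accuracy(V_noise, y))
```

Features that are already grouped by class give near-perfect held-out accuracy; features without class information stay near chance.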

Interpreting Student Accuracy:

  • If Accuracy ~50-60%: The teacher's output is chaos. Vectors V carry no usable information about the chessboard order; the structure is destroyed.

  • If Accuracy ~90-95% (at WD=0): The teacher encoded the coordinates precisely (generalization), and the student used this to re-learn Parity from scratch.

  • If Accuracy ~80-85% (at WD=1e-3, where Control failed): This is structural memory. The teacher output vectors that are already grouped into clusters based on old memory.
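To read these thresholds, it helps to know what a probe that learns nothing would score. A quick Monte Carlo check (a sketch; the seed and sample size are arbitrary) shows that the two Parity classes are balanced inside the blind spot, so always guessing one class gives about 50%. The reason is symmetry: reflecting x to -x swaps the two chessboard colours and maps the disk onto itself.

```python
import numpy as np

# Monte Carlo estimate of the Parity class balance inside the blind spot (R < 1.5)
rng = np.random.default_rng(0)
pts = rng.uniform(-1.5, 1.5, size=(200_000, 2))
pts = pts[np.linalg.norm(pts, axis=1) < 1.5]              # keep the disk only
labels = np.floor(pts).astype(int).sum(axis=1) % 2        # Task A labels (0 or 1)
share_1 = labels.mean()
majority_baseline = max(share_1, 1.0 - share_1)           # accuracy of always guessing one class
print(f"class-1 share: {share_1:.3f}, majority baseline: {majority_baseline:.3f}")
```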

Results

I ran the training for each Weight Decay (WD) value 20 times and measured, with the probe, how well the old task's structure is preserved in the blind spot. The table shows mean probe accuracy for a subset of the scanned WD values.

Here is what happened:

| Weight Decay (pressure) | Control accuracy (training from scratch) | Imprint accuracy (memory) | Hysteresis (Imprint - Control) | Regime |
|---|---|---|---|---|
| 0.0 | 93.4% | 89.6% | -3.8 pp | Free lunch: both models encode the coordinates well |
| 1e-4 | 91.0% | 90.2% | -0.8 pp | Plateau: pressure is too weak |
| 3e-4 | 80.1% | 88.2% | +8.1 pp | Divergence: Control starts losing quality |
| 6e-4 | 65.5% | 88.1% | +22.6 pp | Phase transition: sharp drop in Control |
| 1e-3 | 57.3% (noise) | 82.8% (high) | +25.4 pp | PEAK (sweet spot): maximum hysteresis |
| 2e-3 | 54.8% | 75.1% | +20.3 pp | Fading: memory is still strong |
| 1e-2 | 54.5% | 58.6% | +4.2 pp | Erasure: regularization destroys everything |
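The hysteresis column is simply Imprint minus Control. The sketch below recomputes it from the rounded means in the table (the values are copied from the table, nothing new) and locates the peak. Recomputing from rounded numbers gives +25.5 and +4.1 where the table shows +25.4 and +4.2, because the experiment code subtracts the unrounded means.

```python
# Mean probe accuracy in % as (control, imprint), copied from the table above
results = {
    0.0:  (93.4, 89.6),
    1e-4: (91.0, 90.2),
    3e-4: (80.1, 88.2),
    6e-4: (65.5, 88.1),
    1e-3: (57.3, 82.8),
    2e-3: (54.8, 75.1),
    1e-2: (54.5, 58.6),
}

# Hysteresis in percentage points: how much better the imprinted model does
gaps = {wd: round(imp - ctrl, 1) for wd, (ctrl, imp) in results.items()}
peak_wd = max(gaps, key=gaps.get)

for wd, gap in gaps.items():
    print(f"WD={wd:<7g} hysteresis={gap:+.1f} pp")
print("peak at WD =", peak_wd)
```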

Results Analysis:

1. The "Free Lunch" Zone (WD ≤ 1e-4):

Both models show accuracy >90%. This is not memory; it is the neural network's ability to generalize coordinates. Without a penalty on weights (WD=0), the Control model interpolates the space well (93.4%) even where it saw no data.

2. The "Structural Hysteresis" Zone (WD = 6e-4 ... 2e-3):

This is the discovery. At a pressure of WD ~ 10^-3:

  • Control (57.3%): The optimizer gives up. It is "cheaper" to zero out weights in the center than to build a complex structure without data.

  • Imprinting (82.8%): The optimizer gets lazy. It is "cheaper" to keep the old structure than to demolish it.

  • Summary: The two groups differ only in the Imprinting phase, so the 25.4-percentage-point gap, averaged over 20 runs, is the contribution of past experience.

3. The "Erasure" Zone (WD ≥ 1e-2):

The pressure becomes so strong that any structure collapses. Only here does real forgetting occur.

Visualization

To rule out an artifact of the metric, I visualized the latent space of both models at WD = 1e-3 using t-SNE.

Control Group: A smooth line. The model learned a continuous function; there are no class boundaries in the center.

Imprinting Group: Clear clusters. The model preserved fragments of the old task. It thinks in categories of a chessboard, even while trying to depict a saddle.
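Reproducing the plots requires t-SNE (for example `sklearn.manifold.TSNE`). As a dependency-free first look, one can swap in PCA: it is a linear projection and will not reproduce t-SNE's cluster layout, but it already shows whether the probe vectors `V` split by Parity label. A sketch on synthetic data; `pca_2d` is an illustrative name.

```python
import numpy as np

def pca_2d(V):
    """Project row vectors V (n, d) onto their two leading principal components."""
    Vc = V - V.mean(axis=0)                              # center the features
    _, _, Vt = np.linalg.svd(Vc, full_matrices=False)    # rows of Vt: principal directions
    return Vc @ Vt[:2].T                                 # (n, 2) coordinates for plotting

# Synthetic stand-in for the probe vectors: two clusters, one per Parity class
rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=400)
V = 0.5 * rng.normal(size=(400, 20)) + 3.0 * labels[:, None]
coords = pca_2d(V)
gap = coords[labels == 1, 0].mean() - coords[labels == 0, 0].mean()
print(coords.shape, abs(gap))
```

Scattering the two columns of `coords`, coloured by the Parity label, gives the picture: separated blobs for clustered vectors, one overlapping cloud otherwise.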

Experiment Code:

```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# ==========================================
# 1. CONFIGURATION
# ==========================================
N_RUNS_PER_WD = 20   # Increased for reliability (will take time!)
N_SAMPLES = 5000
BLIND_RADIUS = 1.5   # Radius of the blind spot excluded from Task B training
OUTER_RADIUS = 3.5   # Outer edge of the "Donut"
OUTPUT_B_DIM = 20    # Width of the Task B head; only its mean is fitted

# Detailed Weight Decay grid (log scale + points of interest)
WD_VALUES = [
    0.0,
    1e-5,
    5e-5,
    1e-4,
    3e-4,
    6e-4,
    1e-3,  # Expected "Sweet Spot"
    2e-3,
    4e-3,
    7e-3,
    1e-2,
    2e-2,  # Total death zone
]

# ==========================================
# 2. FUNCTIONS (Data Generation)
# ==========================================
def task_A_Parity(x):
    # Chessboard label (floor(x) + floor(y)) % 2, as defined in the text
    grid_x = torch.floor(x[:, 0])
    grid_y = torch.floor(x[:, 1])
    return ((grid_x + grid_y) % 2).unsqueeze(1)

def task_B_Saddle(x):
    return ((x[:, 0]**2 - x[:, 1]**2) / 5.0).unsqueeze(1)

def sample_square(n):
    # Uniform points in the [-3, 3] x [-3, 3] square
    return torch.rand(n, 2) * 6 - 3

def sample_ring(n, r_min, r_max):
    # Rejection sampling: count accepted points (not batches) until n are collected
    X_list, count = [], 0
    while count < n:
        batch = sample_square(n)
        R = torch.norm(batch, dim=1)
        batch = batch[(R > r_min) & (R < r_max)]
        X_list.append(batch)
        count += len(batch)
    return torch.cat(X_list)[:n]

def get_full_data(n):
    X = sample_square(n)
    return X, task_A_Parity(X)

def get_donut_data(n):
    X = sample_ring(n, BLIND_RADIUS, OUTER_RADIUS)
    return X, task_B_Saddle(X)

def get_center_data(n):
    X = sample_ring(n, 0.0, BLIND_RADIUS)
    return X, task_A_Parity(X)

# ==========================================
# 3. MODEL
# ==========================================
class Teacher(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared trunk used by both tasks
        self.core = nn.Sequential(
            nn.Linear(2, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, 32),
        )
        self.head_a = nn.Linear(32, 1)             # Task A (Parity) head
        self.head_b = nn.Linear(32, OUTPUT_B_DIM)  # Task B (Saddle) head

    def forward_A(self, x):
        return torch.sigmoid(self.head_a(self.core(x)))

    def forward_B(self, x):
        return self.head_b(self.core(x))

# ==========================================
# 4. TRAINING AND TESTING LOGIC
# ==========================================
def train_and_probe(scenario, wd_value):
    teacher = Teacher()

    # 1. IMPRINTING: learn Parity on the whole square (skipped for CONTROL)
    if scenario == 'IMPRINTING':
        X_full, Y_par = get_full_data(N_SAMPLES)
        opt = optim.Adam(teacher.parameters(), lr=0.01)
        crit = nn.BCELoss()
        for _ in range(300):
            loss = crit(teacher.forward_A(X_full), Y_par)
            opt.zero_grad()
            loss.backward()
            opt.step()

    # 2. ADAPTATION: learn Saddle on the Donut only (Weight Decay is applied here)
    X_donut, Y_sad = get_donut_data(N_SAMPLES)
    opt = optim.Adam(teacher.parameters(), lr=0.001, weight_decay=wd_value)
    crit = nn.MSELoss()
    for _ in range(800):
        # Important: constrain only the mean of the outputs to leave freedom for hidden dimensions
        pred = teacher.forward_B(X_donut).mean(dim=1, keepdim=True)
        loss = crit(pred, Y_sad)
        opt.zero_grad()
        loss.backward()
        opt.step()

    # 3. PROBE: freeze the teacher and test in the blind spot
    X_cen, Y_par_cen = get_center_data(1000)
    teacher.eval()
    with torch.no_grad():
        V = teacher.forward_B(X_cen)

    # Check for neuron "death": a constant output carries no information
    if V.std().item() < 0.001:
        return 0.5

    student = nn.Sequential(
        nn.Linear(OUTPUT_B_DIM, 64), nn.ReLU(),
        nn.Linear(64, 1), nn.Sigmoid(),
    )
    opt_s = optim.Adam(student.parameters(), lr=0.01)
    crit_s = nn.BCELoss()
    for _ in range(300):
        loss = crit_s(student(V), Y_par_cen)
        opt_s.zero_grad()
        loss.backward()
        opt_s.step()

    # Note: accuracy is measured on the same 1000 points the probe was trained on
    with torch.no_grad():
        return ((student(V) > 0.5).float() == Y_par_cen).float().mean().item()

def run_comparison(wd_value):
    return train_and_probe('CONTROL', wd_value), train_and_probe('IMPRINTING', wd_value)

# ==========================================
# 5. EXECUTION AND STATISTICS GATHERING
# ==========================================
print(f"STARTING DETAILED SCAN: {len(WD_VALUES)} points, {N_RUNS_PER_WD} runs each.")
print("=" * 105)
print(f"{'Weight Decay':<12} | {'Control':<10} | {'Imprint':<10} | {'LEAKAGE':<10} | {'Std Dev':<10} | {'Status':<25}")
print("-" * 105)

stats = []

for wd in WD_VALUES:
    c_list, i_list = [], []
    for i in range(N_RUNS_PER_WD):
        c, m = run_comparison(wd)
        c_list.append(c)
        i_list.append(m)
        # Small progress indicator within the point
        if i % 5 == 0:
            print(".", end="", flush=True)

    print("", end="\r")  # Return to line start; the row printed next overwrites the dots

    mc, mi = np.mean(c_list), np.mean(i_list)
    std_i = np.std(i_list)  # Imprinting spread is most important
    leak = mi - mc
    stats.append((wd, leak))

    # Automatic mode classification
    if mc > 0.85 and mi > 0.85:
        status = "Free Lunch (Both High)"
    elif mc < 0.60 and mi > 0.75:
        status = ">>> MEMORY ZONE <<<"
    elif mc < 0.60 and mi < 0.60:
        status = "Erasure (Both Low)"
    else:
        status = "Transition"

    # Draw leakage "stars": 1 star per 2.5 percentage points (none if negative)
    bar = "*" * max(0, int(leak * 100 / 2.5))

    print(f"{wd:<12.1e} | {mc:<10.1%} | {mi:<10.1%} | {leak:+.1%} {bar:<10} | +/-{std_i:.1%}   | {status}")

print("=" * 105)
best = max(stats, key=lambda x: x[1])
print(f"PEAK LEAKAGE: {best[1]:+.2%} at WD={best[0]}")
```

Interpretation and Conclusions

Why does Weight Decay help preserve memory?

The optimizer does exactly what is required of it: it minimizes the total budget, the task loss plus the Weight Decay penalty.

  • In the Control case: Building a complex structure in the center from scratch is expensive (the WD penalty). It is simpler to create a flat placeholder.

  • In the Imprinting case: The structure is already built, and rebuilding it would cost extra weight changes. It is cheaper to leave everything as it is, provided it does not contradict the new task on the periphery.

Weight Decay in this mode acts as an inertia stabilizer.
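The "budget" above can be written out explicitly. A sketch, assuming the Adaptation setup of the experiment code: N Donut points (x_i, y_i), the prediction taken as the mean of the 20 Task B outputs, and λ equal to the `weight_decay` argument of `torch.optim.Adam`, which adds λθ to the gradient before its moment estimates (coupled L2), unlike `torch.optim.AdamW`, which decouples the decay.

```latex
\bar f_\theta = \frac{1}{20}\sum_{k=1}^{20} f_{\theta,k},
\qquad
\mathcal{L}_{\text{Saddle}}(\theta) = \frac{1}{N}\sum_{i=1}^{N}\Bigl(\bar f_\theta(x_i, y_i) - \frac{x_i^2 - y_i^2}{5}\Bigr)^2

\mathcal{J}(\theta) = \mathcal{L}_{\text{Saddle}}(\theta) + \frac{\lambda}{2}\lVert\theta\rVert_2^2,
\qquad
\nabla_\theta \mathcal{J}(\theta) = \nabla_\theta \mathcal{L}_{\text{Saddle}}(\theta) + \lambda\,\theta
```

No term in this objective refers to the blind spot: the representation there is shaped only indirectly, through weights shared with the Donut and through the penalty, which is what makes the starting point of those weights (random or imprinted) matter.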

The Neural Network Remembers What We Tried to Erase: Experiment Conclusions

1. The Illusion of "Unlearning" and Safety

  • Hypothesis: The neural network retains memory of the old task even where the new task provides no gradients (in blind zones). Memory hides not in the outputs, but in the redundant degrees of freedom (Null Space). It can be extracted using a probe, even if the model itself outputs garbage.

  • Practice: If you are using Fine-tuning for "censoring" or forgetting harmful data, the standard approach does not work. Ordinary FT (with WD ~1e-4) simply drives old patterns underground but does not erase them. To truly delete information, you need either Re-initialization (resetting weights) or extremely aggressive Weight Decay that will destroy the weight structure itself.

2. The Structural Hysteresis Trap

  • Hypothesis: Two models can have identical Loss and identical answers to standard queries but be functionally different. This is structural hysteresis: differences are visible only through the history of training.

  • Practice: Standard validation is not enough. You might have two checkpoints with identical accuracy, but the first, trained from scratch, is a clean slate, while the second (fine-tuned) might output an old pattern on a rare query.

3. Weight Decay is More Than a Regularizer

  • Hypothesis: WD acts as a selector. It cleans structure in a model learning from scratch (Control), but paradoxically conserves weight structure in an already trained model.

  • Practice: WD settings depend on the goal.

    • If Transfer Learning is needed: Moderate WD is useful; it prevents the old structure from dissolving.

    • If Full Adaptation (Tabula Rasa) is needed: Standard WD might hinder the process, as it is not strong enough to erase the past structure, but strong enough to distort new learning.

4. Geometry of Tasks (Compatibility)

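  • Hypothesis: A new task preserves old structure better when the two are geometrically compatible, that is, when the new task can be solved without dismantling the old one. In our experiment both tasks read the same coordinates, but Parity quantizes the plane while Saddle only needs a smooth readout, so the old partition can coexist with the new task; tasks that scramble the shared representation should conserve less.

  • Practice: The more orthogonal the new task is to the old one, the less it has to overwrite, and the harder it is to erase previous knowledge with standard retraining.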

High accuracy on a new task does not mean the model has forgotten the old one. It is simply hiding it.