How Knowledge Distillation Compresses Ensemble Intelligence into a Single Deployable AI Model

Complex prediction problems often lead to ensembles, because combining multiple models improves accuracy by reducing variance and capturing different patterns in the data. However, ensembles are often impractical in production because of their latency and operational complexity.

Instead of discarding them, knowledge distillation offers a smarter approach: keep the ensemble as a teacher and train a small student model on its soft probability outputs. The student inherits much of the ensemble's behavior while staying lightweight and fast enough to deploy.

In this article, we build this pipeline from scratch: train a 12-model teacher ensemble, generate soft targets with temperature scaling, and distill them into a student that recovers 53.8% of the accuracy gap between a hard-label baseline and the ensemble while being 160× smaller.

What is Knowledge Distillation?

Knowledge distillation is a model compression technique in which a large, pre-trained “teacher” model transfers its learned behavior to a small “student” model. Instead of being trained only on ground-truth labels, the student is trained to mimic the teacher's predictions, capturing not just the final answers but the richer patterns embedded in the teacher's probability distributions. This lets the student approach the performance of much larger models while remaining small and fast. Since early work on compressing large ensembles into single networks, distillation has been widely adopted in NLP, speech, and computer vision, and it has become increasingly important for shrinking large generative AI models into practical, deployable systems.

Knowledge Distillation: From Ensemble Teacher to Deployable Student

Setting Up Dependencies

pip install torch scikit-learn numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np
torch.manual_seed(42)
np.random.seed(42)

Creating the Dataset

This block creates and prepares a dataset for a binary classification task (such as predicting whether a user clicks on an ad). First, make_classification generates 5,000 samples with 20 features: 10 informative, 5 redundant (linear combinations of the informative ones), and the rest pure noise, to simulate the messiness of real-world data. The dataset is then split 80/20 into training and test sets so the model can be evaluated on unseen data.

Next, StandardScaler rescales each feature to zero mean and unit variance, using statistics fitted on the training set only; this helps neural networks train more efficiently. The arrays are then converted to PyTorch tensors, and a DataLoader is created to feed shuffled mini-batches of 64 samples during training, enabling mini-batch stochastic gradient descent.

X, y = make_classification(
    n_samples=5000, n_features=20, n_informative=10,
    n_redundant=5, random_state=42
)
 
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
 
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test  = scaler.transform(X_test)
 
# Convert to tensors
X_train_t = torch.tensor(X_train, dtype=torch.float32)
y_train_t  = torch.tensor(y_train, dtype=torch.long)
X_test_t   = torch.tensor(X_test,  dtype=torch.float32)
y_test_t   = torch.tensor(y_test,  dtype=torch.long)
 
train_loader = DataLoader(
    TensorDataset(X_train_t, y_train_t), batch_size=64, shuffle=True
)
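For reference, StandardScaler's transform amounts to subtracting the per-feature training mean and dividing by the per-feature training standard deviation. The sketch below mirrors that idea in plain NumPy on random stand-in data (the helper names `fit_standardizer` and `apply_standardizer` and the synthetic arrays are illustrative assumptions, not scikit-learn API). The key point is that the test split is transformed with statistics computed on the training split only, so no test information leaks into training.

```python
import numpy as np

def fit_standardizer(X_train):
    """Per-feature mean and population std (ddof=0), computed on the training split only."""
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)
    std[std == 0] = 1.0  # guard against constant features
    return mean, std

def apply_standardizer(X, mean, std):
    """Apply the training statistics to any split (train or test)."""
    return (X - mean) / std

rng = np.random.default_rng(0)
X_tr = rng.normal(loc=5.0, scale=2.0, size=(4000, 20))  # illustrative stand-in for X_train
X_te = rng.normal(loc=5.0, scale=2.0, size=(1000, 20))  # illustrative stand-in for X_test

mean, std = fit_standardizer(X_tr)
X_tr_s = apply_standardizer(X_tr, mean, std)
X_te_s = apply_standardizer(X_te, mean, std)  # test uses TRAIN statistics, not its own

print(np.round(X_tr_s.mean(axis=0)[:3], 6))  # training features centered at ~0
print(np.round(X_tr_s.std(axis=0)[:3], 6))   # ... with ~unit standard deviation
```

The test split ends up only approximately centered, which is expected: it is scaled with the training statistics exactly as it would be in production.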

Model Architecture

This section defines two neural network architectures: TeacherModel and StudentModel. The teacher represents one of the heavy models in the ensemble: it has several layers, wide hidden dimensions, and dropout regularization, making it more expressive but more expensive at prediction time.

The student model, on the other hand, is a small, efficient network with fewer layers and parameters. Its goal is not to match the teacher's complexity but to imitate its behavior through distillation. Importantly, the student must still have enough capacity to approximate the ensemble's decision boundary; with too little, it cannot capture the rich patterns the ensemble has learned.

class TeacherModel(nn.Module):
    """Represents one heavy model inside the ensemble."""
    def __init__(self, input_dim=20, num_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, 128),       nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 64),        nn.ReLU(),
            nn.Linear(64, num_classes)
        )
    def forward(self, x):
        return self.net(x)
 
 
class StudentModel(nn.Module):
    """
    The lean production model that learns from the ensemble.
    Two hidden layers -- enough capacity to absorb distilled
    knowledge, still ~30x smaller than the full ensemble.
    """
    def __init__(self, input_dim=20, num_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64), nn.ReLU(),
            nn.Linear(64, 32),        nn.ReLU(),
            nn.Linear(32, num_classes)
        )
    def forward(self, x):
        return self.net(x)
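A fully connected nn.Linear(in, out) layer holds in × out weights plus out biases, while ReLU and Dropout add no parameters. The short sketch below applies that rule to the layer widths used above (the helper name `mlp_params` is ours, for illustration only) to check the model sizes quoted later in the article.

```python
def mlp_params(widths):
    """Parameters of a stack of nn.Linear layers: weights (in*out) plus biases (out).
    ReLU and Dropout layers contribute no parameters."""
    return sum(w_in * w_out + w_out for w_in, w_out in zip(widths, widths[1:]))

teacher = mlp_params([20, 256, 128, 64, 2])  # one TeacherModel
student = mlp_params([20, 64, 32, 2])        # one StudentModel
ensemble = 12 * teacher                      # NUM_TEACHERS = 12

print(f"teacher : {teacher:,}")   # 46,658
print(f"student : {student:,}")   # 3,490
print(f"ensemble: {ensemble:,}")  # 559,896
print(f"student vs one teacher: {teacher / student:.1f}x smaller")  # 13.4x
print(f"student vs ensemble   : {ensemble / student:.0f}x smaller")  # 160x
```

The 160× figure therefore compares the student against all 12 teachers combined; against a single teacher the student is roughly 13× smaller.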

Training and Evaluation Helpers

This section defines two helper functions for training and evaluation.

train_one_epoch handles one full pass over the training data. It puts the model into training mode, iterates over mini-batches, computes the loss, backpropagates, and updates the weights with the optimizer. It returns the average loss across all batches so training progress can be monitored.

evaluate measures a model's performance. It switches the model to evaluation mode (which disables dropout), runs inference under torch.no_grad(), and computes accuracy by comparing the predicted labels (the argmax of the logits) with the true labels.

def train_one_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)
 
 
def evaluate(model, X, y):
    model.eval()
    with torch.no_grad():
        preds = model(X).argmax(dim=1)
    return (preds == y).float().mean().item()
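evaluate takes the argmax of the logits as the predicted class and averages the matches; no softmax is needed, because softmax preserves the ordering of the logits. The same computation in NumPy, on a hypothetical batch of four logit rows (values made up purely for illustration), looks like this:

```python
import numpy as np

def accuracy_from_logits(logits, labels):
    """Predicted class = index of the largest logit; accuracy = fraction of matches."""
    preds = logits.argmax(axis=1)
    return float((preds == labels).mean())

logits = np.array([[ 2.0, -1.0],   # predicts class 0
                   [ 0.1,  0.3],   # predicts class 1
                   [-0.5,  1.5],   # predicts class 1
                   [ 1.2,  0.9]])  # predicts class 0
labels = np.array([0, 1, 0, 0])

print(accuracy_from_logits(logits, labels))  # 0.75
```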

Ensemble Training

This section trains the teacher ensemble, which serves as the knowledge source for distillation. Instead of a single model, 12 teacher models are trained independently from different random initializations, so each learns slightly different patterns in the data. This diversity is what lets the averaged ensemble outperform a typical individual member.

Each teacher is trained for a fixed 30 epochs, and its test accuracy is printed. Once all the models are trained, their predictions are combined by soft voting: averaging their logits instead of taking a majority vote over predicted classes. This produces a strong, stable final prediction and gives us a high-performance ensemble that will act as the teacher in the next step.

print("=" * 55)
print("STEP 1: Training the 12-model Teacher Ensemble")
print("        (this happens offline, not in production)")
print("=" * 55)
 
NUM_TEACHERS = 12
teachers = []
 
for i in range(NUM_TEACHERS):
    torch.manual_seed(i)                           # different init per teacher
    model = TeacherModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
 
    for epoch in range(30):                        # fixed 30-epoch budget per teacher
        train_one_epoch(model, train_loader, optimizer, criterion)
 
    acc = evaluate(model, X_test_t, y_test_t)
    print(f"  Teacher {i+1:02d} -> test accuracy: {acc:.4f}")
    model.eval()
    teachers.append(model)
 
# Soft voting: average logits across all teachers (uses each teacher's
# confidence, unlike a hard majority vote)
with torch.no_grad():
    avg_logits     = torch.stack([t(X_test_t) for t in teachers], dim=0).mean(dim=0)
    ensemble_preds = avg_logits.argmax(dim=1)
ensemble_acc = (ensemble_preds == y_test_t).float().mean().item()
print(f"\n  Ensemble (soft vote) accuracy: {ensemble_acc:.4f}")

Generating Soft Targets from the Ensemble

This step generates soft targets from the trained teachers, the key ingredient of the distillation process. Instead of hard labels (0 or 1), the ensemble's averaged logits are turned into a probability distribution that captures how confident the ensemble is across the classes.

The function first averages the logits from all teachers (soft voting), then applies temperature scaling to smooth the probabilities. A temperature above 1 (here 3.0) softens the distribution, revealing relative similarities between classes that hard labels cannot express. These soft targets provide a richer learning signal, helping the student model approximate the ensemble's behavior more closely.

TEMPERATURE = 3.0   # controls how "soft" the teacher's output is
 
def get_ensemble_soft_targets(teachers, X, T):
    """
    Average logits from all teachers, then apply temperature scaling.
    Soft targets carry richer signal than hard 0/1 labels.
    """
    with torch.no_grad():
        logits = torch.stack([t(X) for t in teachers], dim=0).mean(dim=0)
    return F.softmax(logits / T, dim=1)   # soft probability distribution
 
soft_targets = get_ensemble_soft_targets(teachers, X_train_t, TEMPERATURE)
 
print(f"n  Sample hard label : {y_train_t[0].item()}")
print(f"  Sample soft target: [{soft_targets[0,0]:.4f}, {soft_targets[0,1]:.4f}]")
print("  -> Soft target carries confidence info, not just class identity.")

Distillation: Training the Student

This section trains the student model with knowledge distillation, so it learns from both the teacher ensemble and the ground-truth labels. A new DataLoader yields the inputs, hard labels, and soft targets together.

During training, two losses are calculated:

  • Distillation loss (KL divergence) encourages the student to match the teacher's softened probability distribution, transferring the ensemble's “knowledge”.
  • Hard-label loss (cross-entropy) keeps the student consistent with the ground truth.
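Written out, the objective implemented in the training loop below is the following, where z_s are the student logits, z_t the averaged ensemble logits, σ the softmax, T the temperature, α the ALPHA weight, and y the hard label (the notation is ours):

```latex
\mathcal{L} \;=\; \alpha\, T^{2}\, \mathrm{KL}\big(\sigma(z_t/T)\,\|\,\sigma(z_s/T)\big) \;+\; (1-\alpha)\, \mathrm{CE}\big(y,\ \sigma(z_s)\big),
\qquad
\mathrm{KL}(p\,\|\,q) = \sum_{k} p_k \log\frac{p_k}{q_k}
```

With ALPHA = 0.7 and T = 3, the KL term is effectively multiplied by 0.7 × 9 = 6.3, against 0.3 for the cross-entropy term.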

The two are combined with a weighting factor (ALPHA); a higher value puts more weight on the teacher's guidance. The student's logits are divided by the same temperature T used for the soft targets, so both distributions are compared at the same softness, and the distillation term is multiplied by T² to keep its gradients on a scale comparable to the hard-label term. Over 50 epochs, the student gradually learns to approximate the ensemble's behavior while remaining small and efficient to deploy.
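Why multiply by T²? The gradient of the KL term with respect to the student logits is (softmax(z_s/T) - softmax(z_t/T)) / T, and for large T the difference in the numerator itself shrinks roughly like 1/T, so the raw gradient decays roughly like 1/T². The NumPy sketch below evaluates that analytic gradient for one hypothetical pair of zero-mean logit vectors: the unscaled gradient nearly vanishes at high T, while the T²-scaled gradient settles near (z_s - z_t) / N for N classes.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def kl_grad_wrt_student_logits(z_s, z_t, T):
    """Analytic gradient of KL(softmax(z_t/T) || softmax(z_s/T)) with respect to z_s."""
    return (softmax(z_s / T) - softmax(z_t / T)) / T

z_s = np.array([0.5, -0.5])            # hypothetical student logits (zero-mean)
z_t = np.array([2.0, -2.0])            # hypothetical teacher logits (zero-mean)
high_T_limit = (z_s - z_t) / len(z_s)  # [-0.75, 0.75]

for T in (1.0, 3.0, 20.0):
    g = kl_grad_wrt_student_logits(z_s, z_t, T)
    print(f"T={T:>4}: raw grad={np.round(g, 4)}  T^2-scaled={np.round(T**2 * g, 4)}")
print("high-T limit:", high_T_limit)
```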

print("n" + "=" * 55)
print("STEP 2: Training the Student via Knowledge Distillation")
print("        (this produces the single production model)")
print("=" * 55)
 
ALPHA  = 0.7    # weight on distillation loss (0.7 = mostly soft targets)
EPOCHS = 50
 
student    = StudentModel()
optimizer  = torch.optim.Adam(student.parameters(), lr=1e-3, weight_decay=1e-4)
ce_loss_fn = nn.CrossEntropyLoss()
 
# Dataloader that yields (inputs, hard labels, soft targets) together
distill_loader = DataLoader(
    TensorDataset(X_train_t, y_train_t, soft_targets),
    batch_size=64, shuffle=True
)
 
for epoch in range(EPOCHS):
    student.train()
    epoch_loss = 0
 
    for xb, yb, soft_yb in distill_loader:
        optimizer.zero_grad()
 
        student_logits = student(xb)
 
        # (1) Distillation loss: match the teacher's soft distribution
        #     KL-divergence between student and teacher outputs at temperature T
        student_soft = F.log_softmax(student_logits / TEMPERATURE, dim=1)
        distill_loss = F.kl_div(student_soft, soft_yb, reduction='batchmean')
        distill_loss *= TEMPERATURE ** 2   # rescale: keeps gradient magnitude
                                           # stable across different T values
 
        # (2) Hard label loss: also learn from ground truth
        hard_loss = ce_loss_fn(student_logits, yb)
 
        # Combined loss
        loss = ALPHA * distill_loss + (1 - ALPHA) * hard_loss
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
 
    if (epoch + 1) % 10 == 0:
        acc = evaluate(student, X_test_t, y_test_t)
        print(f"  Epoch {epoch+1:02d}/{EPOCHS}  loss: {epoch_loss/len(distill_loader):.4f}  "
              f"student accuracy: {acc:.4f}")

Baseline: Student Trained on Hard Labels Only

This phase trains a baseline student without distillation, using only ground-truth labels. Its architecture is identical to the distilled student's, which keeps the comparison fair.

The model is trained conventionally with cross-entropy loss, learning directly from the hard labels without any guidance from the teacher ensemble. After training, it is evaluated on the test set.

This baseline serves as a reference point, letting you measure how much of the performance gain comes from distillation itself rather than from the student's capacity or the training procedure.

print("n" + "=" * 55)
print("BASELINE: Student trained on hard labels only (no distillation)")
print("=" * 55)
 
baseline_student = StudentModel()
b_optimizer = torch.optim.Adam(
    baseline_student.parameters(), lr=1e-3, weight_decay=1e-4
)
 
for epoch in range(EPOCHS):
    train_one_epoch(baseline_student, train_loader, b_optimizer, ce_loss_fn)
 
baseline_acc = evaluate(baseline_student, X_test_t, y_test_t)
print(f"  Baseline student accuracy: {baseline_acc:.4f}")

Comparison

To measure how well the ensemble's knowledge transfers, we evaluate three models on the same held-out test set. The ensemble (all 12 teachers, soft-voting by averaging their logits) sets the ceiling at 97.80% accuracy; this is the number we are trying to approach, not beat. The baseline student is the same small architecture trained the usual way, on hard labels only: it sees each sample as a binary 0 or 1 target, nothing else. It reaches 96.50%. The distilled student has the same architecture but is trained on the ensemble's soft probabilities at temperature T = 3, with the combined loss weighted 70% toward matching the teacher distribution and 30% toward the ground-truth labels. It reaches 97.20%.

The 0.70-point gap between the baseline and the distilled student comes from the soft targets: the student did not get more data, a better architecture, or more compute, only a richer training signal, and that alone recovered 53.8% of the gap between what the small model learns on its own and what the full ensemble knows. On the 1,000-sample test set, 0.70 points corresponds to 7 predictions, so repeating the comparison over several seeds is a sensible way to confirm the effect is not run-to-run noise. The remaining 0.60-point gap between the distilled student and the ensemble is the cost of compression: the part of the ensemble's knowledge that the 3,490-parameter student did not capture in this setup.

distilled_acc = evaluate(student, X_test_t, y_test_t)
 
print("n" + "=" * 55)
print("RESULTS SUMMARY")
print("=" * 55)
print(f"  Ensemble  (12 models, production-undeployable) : {ensemble_acc:.4f}")
print(f"  Student   (distilled, production-ready)        : {distilled_acc:.4f}")
print(f"  Baseline  (student, hard labels only)          : {baseline_acc:.4f}")
 
gap      = ensemble_acc - distilled_acc
recovery = (distilled_acc - baseline_acc) / max(ensemble_acc - baseline_acc, 1e-9)
print(f"n  Accuracy gap vs ensemble       : {gap:.4f}")
print(f"  Knowledge recovered vs baseline: {recovery*100:.1f}%")
def count_params(m):
    return sum(p.numel() for p in m.parameters())
 
single_teacher_params = count_params(teachers[0])
student_params        = count_params(student)
 
print(f"n  Single teacher parameters : {single_teacher_params:,}")
print(f"  Full ensemble parameters  : {single_teacher_params * NUM_TEACHERS:,}")
print(f"  Student parameters        : {student_params:,}")
print(f"  Size reduction            : {single_teacher_params * NUM_TEACHERS / student_params:.0f}x")

I am a Civil Engineering Graduate (2022) from Jamia Millia Islamia, New Delhi, and I am very interested in Data Science, especially Neural Networks and their application in various fields.
