Understanding Mean and Variance
In machine learning & deep learning, mean (average) and variance (spread of data) are important for:
- Data normalization (scaling inputs before feeding into neural networks).
- Loss function analysis (tracking the spread of errors).
- Weight initialization (ensuring stable training).
Mean
The mean is the average of all values:
mean = (sum of all values) / (number of values)
NumPy Example:
import numpy as np
arr = np.array([3, 1, 7, 0, 5])
print("Mean:", np.mean(arr)) # (3+1+7+0+5) / 5 = 3.2
PyTorch Example:
import torch
tensor = torch.tensor([3, 1, 7, 0, 5])
print("Mean:", torch.mean(tensor.float())) # 3.2
Why .float() in PyTorch? torch.mean is only defined for floating-point (or complex) tensors. Calling it on an integer tensor raises an error, so .float() converts the values first.
Variance
The variance measures how spread out the values are around the mean. The population variance is the average squared difference from the mean:
variance = (sum of squared differences from the mean) / (number of values)
NumPy Example:
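print("Variance:", np.var(arr)) # 6.5600000000000005 (population variance, divides by N)
PyTorch Example:
print("Variance:", torch.var(tensor.float())) # tensor(8.2000) (sample variance, divides by N - 1 by default)
print("Variance:", torch.var(tensor.float(), unbiased=False)) # tensor(6.5600), matches np.var
Note: np.var divides by N by default, while torch.var divides by N - 1 (Bessel's correction). The two results differ unless you pass ddof=1 to NumPy or unbiased=False to PyTorch.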
✅ Higher variance → More spread out data.
✅ Lower variance → More concentrated data.
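The two variance conventions are easy to check by hand. Here is a minimal sketch in plain Python, reusing the values 3, 1, 7, 0, 5 from the examples above. The helper names population_variance and sample_variance are illustrative only and are not part of NumPy or PyTorch.

```python
def population_variance(values):
    """Average squared deviation from the mean (divide by N)."""
    n = len(values)
    mean = sum(values) / n
    return sum((v - mean) ** 2 for v in values) / n


def sample_variance(values):
    """Bessel-corrected variance (divide by N - 1); needs at least 2 values."""
    n = len(values)
    if n < 2:
        raise ValueError("sample variance needs at least two values")
    mean = sum(values) / n
    return sum((v - mean) ** 2 for v in values) / (n - 1)


data = [3, 1, 7, 0, 5]
pop_var = population_variance(data)   # 32.8 / 5
samp_var = sample_variance(data)      # 32.8 / 4
print("Population variance:", round(pop_var, 4))  # 6.56
print("Sample variance:", round(samp_var, 4))     # 8.2
```

The first number is what np.var returns by default. The second is what torch.var returns by default.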
Mean & Variance on Multi-Dimensional Data
NumPy Multi-Dimensional Mean/Variance
arr2D = np.array([[3, 7, 2],
                  [5, 1, 8]])
print("Mean (Overall):", np.mean(arr2D))
print("Mean per Column:", np.mean(arr2D, axis=0)) # Column-wise
print("Variance per Row:", np.var(arr2D, axis=1)) # Row-wise
PyTorch Multi-Dimensional Mean/Variance
tensor2D = torch.tensor([[3, 7, 2],
                         [5, 1, 8]])
print("Mean (Overall):", torch.mean(tensor2D.float()))
print("Mean per Column:", torch.mean(tensor2D.float(), dim=0)) # Column-wise
print("Variance per Row:", torch.var(tensor2D.float(), dim=1)) # Row-wise (sample variance, so values differ from np.var)
👉 dim=0 → Column-wise (reduces down the rows, one result per column)
👉 dim=1 → Row-wise (reduces across the columns, one result per row)
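A quick way to see what an axis reduction does is to look at the shape of the result. Here is a minimal sketch, assuming NumPy is installed, on the same 2x3 array:

```python
import numpy as np

arr2D = np.array([[3, 7, 2],
                  [5, 1, 8]])  # shape (2, 3): 2 rows, 3 columns

col_means = np.mean(arr2D, axis=0)  # collapses the rows -> one value per column
row_means = np.mean(arr2D, axis=1)  # collapses the columns -> one value per row
row_means_2d = np.mean(arr2D, axis=1, keepdims=True)  # keeps a length-1 axis

print(col_means.shape)     # (3,)
print(row_means.shape)     # (2,)
print(row_means_2d.shape)  # (2, 1)
print(col_means)           # [4. 4. 5.]
```

keepdims=True is handy when the result has to broadcast back against the original array, for example when subtracting each row's mean. The PyTorch counterpart is keepdim=True.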
Why are Mean & Variance Important in Deep Learning?
1) Feature Scaling (Normalization & Standardization)
Before feeding data into a neural network, we often standardize it so that each feature has zero mean and unit variance:
normalized_X = (X - mean) / sqrt(variance)
NumPy Example:
X = np.array([3, 1, 7, 0, 5])
X_normalized = (X - np.mean(X)) / np.sqrt(np.var(X))
print("Normalized Data:", X_normalized)
PyTorch Example:
X_tensor = torch.tensor([3, 1, 7, 0, 5], dtype=torch.float32)
X_normalized = (X_tensor - torch.mean(X_tensor)) / torch.sqrt(torch.var(X_tensor, unbiased=False)) # unbiased=False divides by N, matching np.var
print("Normalized Data:", X_normalized)
✅ Normalization often helps neural networks converge faster by keeping all features on a comparable scale.
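Real datasets usually have several features, each on its own scale, so the same formula is applied per column. Here is a minimal sketch, assuming NumPy is installed. The function name standardize, the eps guard, and the toy data are illustrative assumptions.

```python
import numpy as np


def standardize(features, eps=1e-8):
    """Scale each column to zero mean and (approximately) unit variance.

    `eps` is an illustrative guard against division by zero for
    constant columns; its value here is an assumption.
    """
    features = np.asarray(features, dtype=np.float64)
    mean = features.mean(axis=0)
    std = features.std(axis=0)  # population std (ddof=0)
    return (features - mean) / (std + eps)


# Toy data: 5 samples, 2 features on very different scales
X = np.array([[3.0, 100.0],
              [1.0, 300.0],
              [7.0, 200.0],
              [0.0, 500.0],
              [5.0, 400.0]])

X_std = standardize(X)
print("Column means:", X_std.mean(axis=0).round(6))
print("Column stds:", X_std.std(axis=0).round(6))
```

In practice the mean and standard deviation are usually computed on the training set only and then reused for validation and test data.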
2) Weight Initialization in Neural Networks
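- If the weight variance is too high, activations and gradients can grow from layer to layer (exploding gradients)! 🚀
- If the variance is too low, signals shrink toward zero and learning becomes very slow. 🐢
Xavier/Glorot Initialization balances the variance using the layer's input and output sizes:
W ~ Normal(mean=0, variance=2 / (fan_in + fan_out))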
PyTorch Example:
import torch.nn as nn
layer = nn.Linear(10, 5)
nn.init.xavier_normal_(layer.weight)
✅ Keeps the scale of activations and gradients roughly constant across layers, which reduces the risk of these training issues.
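To make the scale concrete, the sketch below computes the standard deviation that Xavier/Glorot normal initialization targets, gain * sqrt(2 / (fan_in + fan_out)), and draws a weight matrix with it using only the standard library. The helper name xavier_normal_std and the seed are illustrative choices. This is not how PyTorch implements it internally.

```python
import math
import random


def xavier_normal_std(fan_in, fan_out, gain=1.0):
    """Standard deviation targeted by Xavier/Glorot normal initialization."""
    return gain * math.sqrt(2.0 / (fan_in + fan_out))


fan_in, fan_out = 10, 5          # same sizes as nn.Linear(10, 5)
std = xavier_normal_std(fan_in, fan_out)

# Draw a weight matrix of shape (fan_out, fan_in) from Normal(0, std^2).
rng = random.Random(0)           # fixed seed, chosen arbitrarily
weights = [[rng.gauss(0.0, std) for _ in range(fan_in)]
           for _ in range(fan_out)]

print("Target std:", round(std, 4))  # sqrt(2 / 15), about 0.3651
```

Wider layers get a smaller standard deviation, so the sum over many inputs does not blow up.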
3) Loss Function Behavior
- Mean Squared Error (MSE) loss measures the average squared difference between predicted & actual values. It equals the variance of the errors only when the mean error is zero:
MSE = (1 / N) * sum((true_value - predicted_value)²)
PyTorch Example:
y_true = torch.tensor([1.0, 2.0, 3.0])
y_pred = torch.tensor([1.1, 1.9, 2.8])
mse_loss = torch.mean((y_true - y_pred) ** 2)
print("MSE Loss:", mse_loss.item()) # About 0.02, the average squared prediction error
✅ Lower MSE → More accurate predictions.
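The link between MSE and variance can be stated precisely: MSE = variance of the errors + (mean error)². Here is a minimal sketch in plain Python that checks this on the same numbers. The variable names are illustrative.

```python
y_true = [1.0, 2.0, 3.0]
y_pred = [1.1, 1.9, 2.8]

errors = [t - p for t, p in zip(y_true, y_pred)]
n = len(errors)

mse = sum(e ** 2 for e in errors) / n
mean_error = sum(errors) / n  # systematic offset (bias) of the predictions
error_variance = sum((e - mean_error) ** 2 for e in errors) / n

print("MSE:", round(mse, 6))
print("Variance of errors:", round(error_variance, 6))
print("Squared mean error:", round(mean_error ** 2, 6))
```

The last two printed values add up to the first. A model can therefore have a low error variance and still a high MSE if its predictions are consistently shifted.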
Visualizing Mean & Variance
Let's plot two datasets with the same mean but different variance to see how spread out the data is.
import matplotlib.pyplot as plt
# Generate two datasets: one with high variance, one with low variance
low_variance = np.random.normal(5, 1, 1000) # Mean=5, std=1 (variance = 1)
high_variance = np.random.normal(5, 5, 1000) # Mean=5, std=5 (variance = 25)
plt.figure(figsize=(10, 5))
# Low variance plot
plt.hist(low_variance, bins=30, alpha=0.6, label="Low Variance", color="blue")
# High variance plot
plt.hist(high_variance, bins=30, alpha=0.6, label="High Variance", color="red")
plt.legend()
plt.title("Distribution of Low vs. High Variance Data")
plt.xlabel("Value")
plt.ylabel("Frequency")
plt.show()
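The histogram can be backed up with numbers. Note that np.random.normal takes the standard deviation, not the variance, as its second argument. Here is a minimal sketch, assuming NumPy is installed. It uses a seeded generator (the seed is an arbitrary choice) so the run is repeatable.

```python
import numpy as np

rng = np.random.default_rng(seed=0)  # seeded so the run is repeatable
low_variance = rng.normal(loc=5, scale=1, size=1000)   # std 1 -> variance about 1
high_variance = rng.normal(loc=5, scale=5, size=1000)  # std 5 -> variance about 25

for name, data in [("low", low_variance), ("high", high_variance)]:
    print(f"{name}: mean={data.mean():.2f}, variance={data.var():.2f}")
```

Both means land near 5, while the second variance is roughly 25 times the first.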
Conclusion 🚀
Function | What It Does | Example Use in Deep Learning |
---|---|---|
mean | Finds the average value | Normalize inputs before training |
var | Finds the spread of data | Track variance in loss values |
dim=0 | Column-wise operations | Normalize features |
dim=1 | Row-wise operations | Aggregate feature values within each sample |