Stratified Sampling
In simple terms, in stratified sampling you:
- split the dataset into subgroups (strata)
- apply random sampling within each subgroup
Connection to the Data Generating Process
In statistics and machine learning, we usually assume that the observed data are realizations from a Data Generating Process (DGP).
You can think of the DGP as the probabilistic mechanism that produces the data we observe.
The DGP defines a joint probability distribution over all variables in the population.
In practice, however, we rarely observe the entire population generated by the DGP. Instead, we collect a dataset, which is itself a sample from that population.
Ideally, this dataset should preserve the probabilistic structure of the DGP. If the sampling procedure distorts the distribution of key variables, the dataset may no longer represent the process that generated the data.
This is where sampling methods become important.
Before introducing stratified sampling, it is useful to start with the most basic assumption used in statistics: random sampling.
Random Sampling
Random sampling means that every observation in the population has the same probability of being selected.
Under this assumption, the dataset behaves like a miniature version of the population generated by the DGP.
Empirical Distribution
When we collect a dataset, we can compute what is called the empirical distribution of the variables.
The empirical distribution is simply the distribution observed in the data. For example, if we have a dataset of customers and:
- 56% are from Brazil
- 30% are from the United States
- 14% are from Germany
then those percentages represent the empirical distribution of the variable “country” in the dataset.
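A minimal sketch in Python with pandas, assuming the data sit in a DataFrame with a `country` column (here a hypothetical toy dataset of 50 customers built to match the percentages above). The empirical distribution is just the relative frequency of each category, which `value_counts(normalize=True)` returns:

```python
import pandas as pd

# Hypothetical toy dataset of 50 customers, built to match the example
# proportions (28 Brazil, 15 United States, 7 Germany).
customers = pd.DataFrame(
    {"country": ["Brazil"] * 28 + ["United States"] * 15 + ["Germany"] * 7}
)

# value_counts(normalize=True) returns the share of rows taking each value,
# i.e. the empirical distribution of "country" in this dataset.
empirical = customers["country"].value_counts(normalize=True)
print(empirical)
```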
In statistics, one of the key ideas is that if the dataset is randomly drawn from the population, the empirical distribution tends to approximate the true distribution of the DGP as the sample size increases.
The Asymptotic Idea
This idea comes from results such as the Law of Large Numbers.
Intuitively:
- small samples can fluctuate a lot
- larger samples are much more stable
The expectation is that as the sample size grows, the empirical distribution gets closer to the true distribution generated by the DGP.
This is why random sampling is such a fundamental assumption in statistics and machine learning: it allows the data we observe to approximate the probabilistic structure of the underlying process.
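To see the asymptotic idea in action, the sketch below simulates draws from a hypothetical DGP whose true country distribution is assumed to be known (56% / 30% / 14%) and measures how far the empirical proportions land from it at increasing sample sizes. Any single run fluctuates, but the largest gap should typically shrink as the sample size grows.

```python
import numpy as np

# Hypothetical "true" distribution of country under the DGP (an assumption).
countries = np.array(["Brazil", "United States", "Germany"])
true_probs = np.array([0.56, 0.30, 0.14])

rng = np.random.default_rng(seed=0)

def empirical_distribution(n):
    """Draw n observations from the true distribution and return the
    observed proportion of each country (same order as `countries`)."""
    draws = rng.choice(countries, size=n, p=true_probs)
    return np.array([(draws == c).mean() for c in countries])

for n in [20, 200, 2_000, 200_000]:
    props = empirical_distribution(n)
    max_gap = np.abs(props - true_probs).max()  # worst deviation across countries
    print(f"n={n:>7}: proportions={np.round(props, 3)}, max gap={max_gap:.3f}")
```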
Why Random Sampling Is Not Always Enough
However, real-world datasets often contain important subgroups whose proportions matter for analysis or model evaluation.
If we rely purely on random sampling, these groups might end up under- or overrepresented in the dataset due to random fluctuations.
This issue becomes more likely when:
- the dataset is small, or
- the subgroup of interest is rare.
In these situations, stratified sampling becomes useful because it stabilizes the proportions of important variables when creating smaller samples from the dataset.
Stratified Sampling
Stratified sampling is a technique where the population (or dataset) is divided into strata, and observations are sampled from each stratum.
The word strata is statistical jargon for groups. In practice, it simply means splitting the data according to a variable of interest.
This idea is very similar to:
- GROUP BY in SQL
- groupby() in pandas
Once the dataset is divided into groups, we sample observations within each group instead of sampling from the whole dataset.
The goal is to ensure that the sample preserves the proportions of those groups found in the original data.
In practice, stratified sampling is mainly used to reduce sampling variability when the number of observations in some groups is small.
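The two steps (group, then sample within each group) map directly onto pandas. A minimal sketch, assuming a hypothetical dataset of 1,000 customers and a 10% sample; `DataFrameGroupBy.sample` is available in pandas 1.1 and later:

```python
import pandas as pd

# Hypothetical dataset: 1,000 customers with the example country proportions.
df = pd.DataFrame(
    {
        "customer_id": range(1_000),
        "country": ["Brazil"] * 560 + ["United States"] * 300 + ["Germany"] * 140,
    }
)

# Step 1: split into strata with groupby("country").
# Step 2: draw a simple random sample of 10% *within each stratum*.
stratified = df.groupby("country").sample(frac=0.10, random_state=42)

print(stratified["country"].value_counts(normalize=True))
```

Because 10% is drawn from every stratum, the sample contains exactly 56, 30 and 14 customers from Brazil, the United States and Germany, reproducing the original proportions; a plain `df.sample(frac=0.10)` would only match them on average.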
Connection to Marginal Distributions
When we stratify on a variable, we are trying to preserve the marginal distribution of that variable.
Recall that the data generating process (DGP) defines a joint probability distribution over all variables in the population.
For example, suppose we observe two variables:
- \(X =\) country
- \(Y =\) purchase amount
The DGP defines the joint distribution
\[ P(X, Y) \]
which describes how both variables occur together in the population.
From the joint distribution, we can obtain the marginal distribution of a variable by summing over the values of the other variable (for a continuous variable such as purchase amount, the sum becomes an integral).
For example, the marginal distribution of \(X\) is
\[ P(X) = \sum_{y} P(X, Y = y) \]
Intuitively, the marginal distribution tells us how often each value of a variable appears in the population, regardless of the other variables.
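As a small worked example, assume a hypothetical joint distribution in which purchase amount has been discretized into three bins (an assumption made so that the sum in the formula applies directly). Summing each row over the bins of \(Y\) recovers the marginal \(P(X)\):

```python
import pandas as pd

# Hypothetical joint distribution P(X, Y): rows are countries (X), columns are
# discretized purchase-amount bins (Y). All entries sum to 1.
joint = pd.DataFrame(
    {
        "low": [0.30, 0.08, 0.04],
        "medium": [0.20, 0.12, 0.06],
        "high": [0.06, 0.10, 0.04],
    },
    index=["Brazil", "United States", "Germany"],
)

# P(X) = sum over y of P(X, Y = y): add up the Y columns in each row.
marginal_x = joint.sum(axis=1)
print(marginal_x)
```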
So if we look only at the variable country, ignoring purchase amount, the population distribution might be:
- Brazil: 56%
- United States: 30%
- Germany: 14%
However, the population is a theoretical object defined by the DGP. In practice, we almost never observe the entire population.
Instead, we collect a dataset, which is a finite sample of observations drawn from that population.
From this dataset, we compute the empirical distribution, which is simply the distribution of the values observed in the dataset.
For example, in our dataset we might observe:
- Brazil: 55%
- United States: 31%
- Germany: 14%
These percentages represent the empirical marginal distribution of \(X\) (country) in the dataset, which we use as an estimate of the true population distribution \(P(X)\).
In many data science or machine learning workflows, we then create smaller samples from the dataset, such as:
- a training dataset
- a test dataset
Stratified sampling aims to ensure that these smaller samples preserve the distribution observed in the original dataset.
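In other words, if \(\hat{P}_{data}(X)\) denotes the empirical distribution of \(X\) in the full dataset, which in turn estimates the population distribution \(P(X)\), we would like the empirical distribution in the sampled dataset to satisfy
\[ \hat{P}_{sample}(X) \approx \hat{P}_{data}(X) \approx P(X) \]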
so that the sampled data behaves as if it were generated by the same underlying process.
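A minimal sketch of a stratified train/test split, written with pandas only. `stratified_train_test_split` is a hypothetical helper, not a library function: it draws the same fraction from every stratum for the test set and keeps the remaining rows for training, so both parts reproduce the country proportions of the full dataset.

```python
import pandas as pd

def stratified_train_test_split(df, strata_col, test_frac, seed=0):
    """Hypothetical helper: draw test_frac of the rows from each stratum for
    the test set and keep all remaining rows for training."""
    test = df.groupby(strata_col).sample(frac=test_frac, random_state=seed)
    train = df.drop(index=test.index)  # assumes a unique row index
    return train, test

# Hypothetical dataset with the example proportions (560 / 300 / 140 rows).
df = pd.DataFrame(
    {"country": ["Brazil"] * 560 + ["United States"] * 300 + ["Germany"] * 140}
)

train, test = stratified_train_test_split(df, "country", test_frac=0.2)

# Empirical distribution of country in the full data, training set and test set.
summary = pd.DataFrame(
    {
        "data": df["country"].value_counts(normalize=True),
        "train": train["country"].value_counts(normalize=True),
        "test": test["country"].value_counts(normalize=True),
    }
)
print(summary)
```

In scikit-learn, `train_test_split` provides the same behaviour through its `stratify` argument (for example `stratify=df["country"]`).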
Example 1 - Class Imbalance in Machine Learning
Suppose you are building a model to detect fraudulent transactions.
Your dataset contains:
- 98% legitimate transactions
- 2% fraudulent transactions
If we randomly split the dataset into training and test datasets, the expected proportion of fraud cases in the test dataset is still 2%.
However, when the test dataset is relatively small, random sampling can produce noticeable deviations from this proportion.
For example, the test dataset might end up containing:
- 1% fraud cases
- or 4% fraud cases
even though the fraud rate in the full dataset is 2%.
This can create instability in model evaluation. In machine learning, model performance refers to how well a model’s predictions match the true outcomes observed in a test dataset.
For classification problems such as fraud detection, this comparison is typically summarized using a confusion matrix.
A confusion matrix is a table that compares the predicted class with the true class of each observation.
| | Actual Fraud | Actual Legitimate |
|---|---|---|
| Predicted Fraud | True Positive (TP) | False Positive (FP) |
| Predicted Legitimate | False Negative (FN) | True Negative (TN) |
Many performance metrics are computed directly from these counts.
For example:
\[ \text{Recall} = \frac{TP}{TP + FN} \]
The denominator \((TP + FN)\) represents all observations that are truly positive.
In the fraud example, this means all transactions that were actually fraudulent, regardless of whether the model detected them or not.
Recall therefore measures the proportion of actual fraud cases that the model successfully detected.
In other words:
Of all the fraudulent transactions that actually occurred, how many did the model correctly identify?
\[ \text{Precision} = \frac{TP}{TP + FP} \]
The denominator \((TP + FP)\) represents all observations that the model predicted as positive.
In the fraud example, this means all transactions that the model flagged as fraud, whether the prediction was correct or not.
Precision therefore measures how reliable the model’s fraud predictions are.
In other words:
Of all the transactions that the model labeled as fraud, how many were truly fraudulent?
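The two formulas can be checked on a handful of hypothetical labels. This sketch encodes fraud as 1 and legitimate as 0, counts the four cells of the confusion matrix, and applies the recall and precision formulas above:

```python
def confusion_counts(y_true, y_pred):
    """Count TP, FP, FN, TN for binary labels where 1 = fraud, 0 = legitimate."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    return tp, fp, fn, tn

def recall(tp, fn):
    return tp / (tp + fn)  # share of actual fraud cases that were flagged

def precision(tp, fp):
    return tp / (tp + fp)  # share of flagged transactions that were fraud

# Hypothetical labels for 10 transactions (1 = fraud).
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 1, 0, 0, 0, 0, 0]

tp, fp, fn, tn = confusion_counts(y_true, y_pred)
print(tp, fp, fn, tn)                     # 3 1 1 5
print(recall(tp, fn), precision(tp, fp))  # 0.75 0.75
```

scikit-learn's `confusion_matrix`, `recall_score` and `precision_score` in `sklearn.metrics` compute the same quantities.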
If the test dataset contains very few fraud cases, then the values of \(TP\) and \(FN\) are based on a small number of observations.
For example, if the test dataset contains only 10 fraud cases, correctly detecting one additional case changes recall by 10 percentage points.
Using stratified sampling on the fraud label (the target variable) ensures that both the training and test datasets preserve the same proportion of fraud and non-fraud cases observed in the original dataset.
This reduces random fluctuations in the class distribution and makes model evaluation more stable.
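The effect can be simulated. The sketch below assumes a hypothetical dataset of 1,000 transactions with 20 fraud cases (2%) and a 20% test set, and repeats both a simple random split and a stratified split (20% drawn within each class) 1,000 times, recording the fraud rate of each test set:

```python
import numpy as np

rng = np.random.default_rng(seed=1)

# Hypothetical dataset: 1,000 transactions, 20 of them fraudulent (2%).
labels = np.array([1] * 20 + [0] * 980)
test_size = 200  # 20% test set

def random_split_fraud_rate():
    """Simple random split: draw 200 row indices without replacement."""
    test_idx = rng.choice(len(labels), size=test_size, replace=False)
    return labels[test_idx].mean()

def stratified_split_fraud_rate():
    """Stratified split: draw 20% of the rows within each class separately."""
    test_idx = []
    for cls in (0, 1):
        cls_idx = np.flatnonzero(labels == cls)
        n_cls = round(0.2 * len(cls_idx))
        test_idx.extend(rng.choice(cls_idx, size=n_cls, replace=False))
    return labels[np.array(test_idx)].mean()

random_rates = [random_split_fraud_rate() for _ in range(1_000)]
stratified_rates = [stratified_split_fraud_rate() for _ in range(1_000)]

print("random:     min=%.3f max=%.3f" % (min(random_rates), max(random_rates)))
print("stratified: min=%.3f max=%.3f" % (min(stratified_rates), max(stratified_rates)))
```

With the random split, the number of fraud cases in the test set varies from run to run and, across 1,000 repetitions, can range from 0 to more than twice the expected 4; the stratified split always yields exactly 4 fraud cases, i.e. a 2% rate.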
Example 2 - Country Distribution in an E-commerce Dataset
Imagine you are building a model using customer data from several countries.
Your dataset contains:
- 56% Brazil
- 30% United States
- 14% Germany
If we randomly sample observations from the dataset to create a smaller dataset, one country might become overrepresented due to random fluctuations.
Suppose the sampled dataset ends up looking like this:
- 40% Brazil
- 45% United States
- 15% Germany
Now suppose customers in the United States tend to spend more on average than customers from other countries.
If the sampled dataset contains too many U.S. customers, the estimated average purchase amount in the sample will be artificially inflated.
As a result, you might conclude that customers in this dataset spend more on average than they actually do.
However, this conclusion would partly reflect a sampling distortion, not the true structure of the population generated by the DGP.
Using stratified sampling on the country variable (\(X\)) ensures that the sample preserves the original proportions observed in the dataset.
This helps ensure that the sampled dataset remains representative of the original data.
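Because an overall sample mean equals the proportion-weighted average of the group means, the distortion can be made concrete. The per-country averages below are hypothetical, illustrative numbers, and the sketch assumes they are the same in both samples so that only the country mix changes:

```python
# Hypothetical average purchase amount per country (illustrative numbers only).
avg_purchase = {"Brazil": 50.0, "United States": 120.0, "Germany": 80.0}

def overall_mean(proportions):
    """Overall average purchase = sum over countries of share * country average."""
    return sum(share * avg_purchase[c] for c, share in proportions.items())

original = {"Brazil": 0.56, "United States": 0.30, "Germany": 0.14}
distorted = {"Brazil": 0.40, "United States": 0.45, "Germany": 0.15}

print(round(overall_mean(original), 2))   # 75.2
print(round(overall_mean(distorted), 2))  # 86.0
```

With these assumed numbers, shifting the mix toward U.S. customers raises the estimated average from 75.2 to 86.0 without any change in how customers within each country behave.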
When Is Stratified Sampling Important?
The main reason for using stratified sampling is sampling variability.
When the dataset is small or the event of interest is rare, random splits can produce large deviations in class proportions when creating training and test datasets.
However, when the dataset is large, class proportions tend to stabilize automatically due to the Law of Large Numbers.
In these cases, random sampling often preserves class proportions reasonably well, and stratified sampling becomes less critical.
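The Law of Large Numbers argument can be quantified. Under simple random sampling, the standard deviation of an observed proportion is approximately \(\sqrt{p(1-p)/n}\). The sketch below evaluates it for a 2% fraud rate (ignoring the finite-population correction) and compares it with \(p\) itself:

```python
import math

def proportion_std(p, n):
    """Approximate standard deviation of a sample proportion under simple
    random sampling (ignores the finite-population correction)."""
    return math.sqrt(p * (1 - p) / n)

p_fraud = 0.02
for n in [100, 1_000, 10_000, 100_000]:
    sd = proportion_std(p_fraud, n)
    print(f"n={n:>7}: sd={sd:.4f}, sd relative to p={sd / p_fraud:.0%}")
```

For \(n = 100\) the standard deviation is about 0.014, i.e. 70% of the 2% rate itself, while for \(n = 100{,}000\) it drops to about 2% of the rate. This is why stratification matters most for small samples and rare events.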
In practice, stratified sampling is most useful when:
- the dataset is small
- the event of interest is rare
- evaluation metrics depend heavily on the number of positive cases