Swap Training and Test Data During Cross-Validation in scikit-learn

Scikit-learn is a well-known Python machine learning library. Among its many utilities are tools for cross-validation. In standard \(K\)-fold cross-validation, the data are split into \(K\) subsets of (roughly) equal size, and there are \(K\) rounds of training and testing. In each round, one subset is used as test data and the remaining \(K-1\) subsets are used as training data. Under this setup, whenever \(K > 2\), each round has more training data than test data (a fraction \((K-1)/K\) of the data versus \(1/K\)).

While this is desirable in most cases, some machine learning applications call for less training data than test data. For example, in graph embedding, each node in the network has a vector representation and labels. When running cross-validation there, it is preferable to use fewer nodes as training data than as test data, since this better mimics the amount of training data available in real-world scenarios (e.g., here). In scikit-learn, we can achieve this by swapping the training and test data.
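To make the size argument concrete, here is a small pure-Python sketch that computes the (training, test) sizes of each round. The helper kfold_sizes is hypothetical and written only for illustration; it assumes the fold-size rule that scikit-learn documents for plain KFold (the first n_samples % n_splits folds get one extra sample) and ignores stratification.

```python
def kfold_sizes(n_samples, n_splits):
    """Return (n_train, n_test) for each round of a plain K-fold split.

    Assumption: follows the documented KFold rule that the first
    n_samples % n_splits folds receive one extra sample.
    """
    base, extra = divmod(n_samples, n_splits)
    test_sizes = [base + 1 if i < extra else base for i in range(n_splits)]
    # In each round, every sample outside the test fold is used for training.
    return [(n_samples - t, t) for t in test_sizes]


for k in (2, 3, 5, 10):
    rounds = kfold_sizes(150, k)
    print(k, rounds[0])  # (training, test) sizes in the first round

# Swapping the roles turns a 120/30 split into a 30/120 split.
swapped = [(test, train) for train, test in kfold_sizes(150, 5)]
print(swapped[0])
```

With 150 samples (the size of the iris dataset) and \(K = 5\), every round has 120 training and 30 test samples; only at \(K = 2\) are the two sides equal.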

The simplest and perhaps also the cleanest way to implement cross-validation in scikit-learn is to use cross_val_score (the example is modified from here):

# Load data and the classifier
from sklearn import datasets, svm
iris = datasets.load_iris()  # features: iris.data, labels: iris.target
clf = svm.SVC(kernel='linear', C=1)

# Cross-validation
from sklearn.model_selection import cross_val_score, StratifiedKFold
cross_val_score(clf, iris.data, iris.target, cv=StratifiedKFold(5))

To swap training and test data, one obvious way is to abandon cross_val_score and explicitly iterate over all rounds of cross-validation. However, this makes the code more verbose and harder to read, so we should avoid it if possible.
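For comparison, the explicit loop might look roughly like the sketch below. It is one possible way to write it, not a recommended pattern: it fits on the small fold returned as "test" by StratifiedKFold and scores on the remaining folds, collecting the accuracies by hand instead of letting cross_val_score do it.

```python
import numpy as np
from sklearn import datasets, svm
from sklearn.model_selection import StratifiedKFold

iris = datasets.load_iris()
X, y = iris.data, iris.target
clf = svm.SVC(kernel='linear', C=1)

scores = []
for train_idx, test_idx in StratifiedKFold(5).split(X, y):
    # Swap the roles by hand: fit on the small fold (30 samples)...
    clf.fit(X[test_idx], y[test_idx])
    # ...and evaluate accuracy on the remaining four folds (120 samples).
    scores.append(clf.score(X[train_idx], y[train_idx]))

scores = np.array(scores)
print(scores)
```

Reusing the same clf object is safe here because each call to fit retrains it from scratch, but the bookkeeping (indexing, scoring, collecting results) is exactly what cross_val_score would otherwise handle.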

Here is my solution. We define a swapper class (which you can reuse across all your Python code). It wraps a splitter object and swaps the training and test indices it returns whenever split() is called:

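class SplitterSwapper:
    '''Wraps a splitter and swaps the training and test indices it yields.'''
    def __init__(self, fold):
        self.fold = fold  # the wrapped splitter, e.g. StratifiedKFold(5)
    def split(self, *args, **kwargs):
        # Forward X, y and groups (newer scikit-learn versions pass groups
        # as a keyword argument), then yield each (train, test) pair reversed.
        for training, testing in self.fold.split(*args, **kwargs):
            yield testing, training
    def get_n_splits(self, *args, **kwargs):
        # Swapping does not change the number of rounds.
        return self.fold.get_n_splits(*args, **kwargs)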
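To see what the wrapper does without loading any data, the sketch below pairs a variant of SplitterSwapper (one that also forwards keyword arguments and get_n_splits) with ToySplitter, a hypothetical stand-in splitter written only for this illustration. ToySplitter yields contiguous test folds over range(n), mimicking the (train, test) interface of scikit-learn splitters.

```python
class SplitterSwapper:
    '''Wraps a splitter and swaps the training and test indices it yields.'''
    def __init__(self, fold):
        self.fold = fold
    def split(self, *args, **kwargs):
        for training, testing in self.fold.split(*args, **kwargs):
            yield testing, training
    def get_n_splits(self, *args, **kwargs):
        return self.fold.get_n_splits(*args, **kwargs)


class ToySplitter:
    """Hypothetical splitter: K contiguous test folds over range(len(X))."""
    def __init__(self, n_splits):
        self.n_splits = n_splits
    def split(self, X, y=None, groups=None):
        n = len(X)
        bounds = [n * i // self.n_splits for i in range(self.n_splits + 1)]
        for lo, hi in zip(bounds, bounds[1:]):
            test = list(range(lo, hi))
            train = [i for i in range(n) if i < lo or i >= hi]
            yield train, test
    def get_n_splits(self, X=None, y=None, groups=None):
        return self.n_splits


X = list(range(10))
swapper = SplitterSwapper(ToySplitter(5))
# groups is passed as a keyword to check that it reaches the wrapped splitter.
pairs = list(swapper.split(X, groups=None))
for train, test in pairs:
    print(len(train), len(test))  # 2 training vs. 8 test indices per round
```

Because the wrapper only relies on split() (and get_n_splits()), it works with any object that follows this interface, which is what makes it reusable across splitters.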

Then, we simply need to change the cross_val_score line so that StratifiedKFold is wrapped in SplitterSwapper:

cross_val_score(clf, iris.data, iris.target, cv=SplitterSwapper(StratifiedKFold(5)))
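An alternative worth knowing (not the approach taken in this post) relies on the fact that the cv parameter of cross_val_score also accepts an iterable of (train, test) index arrays. The sketch below precomputes the swapped splits as a list. The trade-off is that the list is tied to one dataset, whereas the wrapper class can be reused with any data.

```python
from sklearn import datasets, svm
from sklearn.model_selection import cross_val_score, StratifiedKFold

iris = datasets.load_iris()
clf = svm.SVC(kernel='linear', C=1)

# Precompute StratifiedKFold's splits and reverse each (train, test) pair.
swapped_splits = [
    (test, train)
    for train, test in StratifiedKFold(5).split(iris.data, iris.target)
]
for train, test in swapped_splits:
    print(len(train), len(test))  # 30 training vs. 120 test samples per round

scores = cross_val_score(clf, iris.data, iris.target, cv=swapped_splits)
print(scores)
```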
