
How does this iterative loop fit the model? (Machine Learning)

I’m confused about what `for m in range(1, len(X_train)):` is doing with the lines `model.fit(X_train[:m], y_train[:m])` and `y_train_predict = model.predict(X_train[:m])`. I think m loops over the size of the training data, increasing by 1 on each iteration, but I don’t understand the rest.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

def plot_learning_curves(model, X, y):
    # Hold out 20% of the data for validation
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=10)
    train_errors, val_errors = [], []
    for m in range(1, len(X_train)):
        # Refit on only the first m training samples
        model.fit(X_train[:m], y_train[:m])
        y_train_predict = model.predict(X_train[:m])
        y_val_predict = model.predict(X_val)
        # mean_squared_error takes (y_true, y_pred)
        train_errors.append(mean_squared_error(y_train[:m], y_train_predict))
        val_errors.append(mean_squared_error(y_val, y_val_predict))

    plt.figure(figsize=(8, 4))
    # Square root of the MSE gives the RMSE shown on the y-axis
    plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="Training set")
    plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="Validation set")
    plt.legend(loc="upper right", fontsize=14)
    plt.xlabel("Training set size", fontsize=14)
    plt.ylabel("RMSE", fontsize=14)
```


Answer

The purpose of this function is to show how a model performs when trained on datasets of different sizes. Indexing an array with X[:m] selects the first m elements of X (along the first dimension if X is multi-dimensional). For each value of m in the for loop, it’s saying “let’s pretend we only had m data points, what would our training and validation errors be?”. You should see that, for small m, the model can fit the few points it has seen almost perfectly, so the training error is close to zero while the validation error is high. As m increases, the training error will typically rise and the validation error will fall. The shapes of the two curves are useful for diagnosing underfitting/overfitting.
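To make the "pretend we only had m data points" idea concrete, here is a minimal sketch in plain Python with no scikit-learn. The `MeanRegressor` class, the `rmse` helper, the `learning_curve_points` function and the toy data are all hypothetical stand-ins invented for illustration. They only mimic the `fit`/`predict` interface used in the question and are not the code from the question itself.

```python
import math

class MeanRegressor:
    """Hypothetical stand-in model: always predicts the mean of the targets it was fit on."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]

def rmse(y_true, y_pred):
    """Root mean squared error between two equal-length sequences."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))

def learning_curve_points(model, X_train, y_train, X_val, y_val):
    """Refit the model on the first m training samples for each m, as in the question's loop."""
    train_rmse, val_rmse = [], []
    for m in range(1, len(X_train)):
        # X_train[:m] is the first m rows; y_train[:m] is the matching first m targets
        model.fit(X_train[:m], y_train[:m])
        train_rmse.append(rmse(y_train[:m], model.predict(X_train[:m])))
        # The validation set is never sliced: it is always scored in full
        val_rmse.append(rmse(y_val, model.predict(X_val)))
    return train_rmse, val_rmse

# Toy data, made up for this example
X_train = [[0.0], [1.0], [2.0], [3.0]]
y_train = [1.0, 3.0, 2.0, 6.0]
X_val = [[4.0], [5.0]]
y_val = [3.0, 3.0]

train_rmse, val_rmse = learning_curve_points(MeanRegressor(), X_train, y_train, X_val, y_val)
print(train_rmse)
print(val_rmse)
```

With this toy data the first training RMSE is exactly 0, because a model fit on a single point reproduces that point, while the validation RMSE starts at 2.0 and drops to 1.0 once a second point is included. Note that `range(1, len(X_train))` stops at `len(X_train) - 1`, so the loop never trains on the complete training set.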

User contributions licensed under: CC BY-SA