Function rusty_machine::analysis::cross_validation::k_fold_validate

pub fn k_fold_validate<M, S>(model: &mut M,
                             inputs: &Matrix<f64>,
                             targets: &Matrix<f64>,
                             k: usize,
                             score: S)
                             -> LearningResult<Vec<f64>> where S: Fn(&Matrix<f64>, &Matrix<f64>) -> f64, M: SupModel<Matrix<f64>, Matrix<f64>>

Randomly splits the inputs into k 'folds'. For each fold, a model is trained using all inputs except those in that fold, and is then tested on the data in the fold. Returns a vector containing one score per fold.

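Arguments

* model - The supervised model (any SupModel<Matrix<f64>, Matrix<f64>>) that is trained and tested on each fold.
* inputs - Matrix holding all of the input samples.
* targets - Matrix holding the targets that correspond to the inputs.
* k - The number of folds to split the data into.
* score - Function taking two matrices and returning an f64; it is used to compare the model's predictions on a fold with that fold's targets and produce the fold's score (for example, row_accuracy from the analysis::score module).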

Examples

use rusty_machine::analysis::cross_validation::k_fold_validate;
use rusty_machine::analysis::score::row_accuracy;
use rusty_machine::learning::naive_bayes::{NaiveBayes, Bernoulli};
use rusty_machine::linalg::{BaseMatrix, Matrix};

let inputs = Matrix::new(3, 2, vec![1.0, 1.1,
                                    5.2, 4.3,
                                    6.2, 7.3]);

let targets = Matrix::new(3, 3, vec![1.0, 0.0, 0.0,
                                     0.0, 0.0, 1.0,
                                     0.0, 0.0, 1.0]);

let mut model = NaiveBayes::<Bernoulli>::new();

let accuracy_per_fold: Vec<f64> = k_fold_validate(
    &mut model,
    &inputs,
    &targets,
    3,
    // Score each fold by the fraction of test samples where
    // the model's prediction equals the target.
    row_accuracy
).unwrap();