Function rusty_machine::analysis::cross_validation::k_fold_validate
[−]
[src]
pub fn k_fold_validate<M, S>(model: &mut M,
inputs: &Matrix<f64>,
targets: &Matrix<f64>,
k: usize,
score: S)
-> LearningResult<Vec<f64>> where S: Fn(&Matrix<f64>, &Matrix<f64>) -> f64, M: SupModel<Matrix<f64>, Matrix<f64>>
Randomly splits the inputs into k 'folds'. For each fold a model is trained using all inputs except for that fold, and tested on the data in the fold. Returns the scores for each fold.
Arguments
model - Used to train and predict for each fold.
inputs - All input samples.
targets - All targets.
k - Number of folds to use.
score - Used to compare the outputs for each fold to the targets. Higher scores are better. See the analysis::score module for examples.
Examples
use rusty_machine::analysis::cross_validation::k_fold_validate;
use rusty_machine::analysis::score::row_accuracy;
use rusty_machine::learning::naive_bayes::{NaiveBayes, Bernoulli};
use rusty_machine::linalg::{BaseMatrix, Matrix};

let inputs = Matrix::new(3, 2, vec![1.0, 1.1,
                                    5.2, 4.3,
                                    6.2, 7.3]);

let targets = Matrix::new(3, 3, vec![1.0, 0.0, 0.0,
                                     0.0, 0.0, 1.0,
                                     0.0, 0.0, 1.0]);

let mut model = NaiveBayes::<Bernoulli>::new();

let accuracy_per_fold: Vec<f64> = k_fold_validate(
    &mut model,
    &inputs,
    &targets,
    3,
    // Score each fold by the fraction of test samples where
    // the model's prediction equals the target.
    row_accuracy
).unwrap();