1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
//! The Shuffler
//!
//! This module contains the `Shuffler` transformer. `Shuffler` implements the
//! `Transformer` trait and is used to shuffle the rows of an input matrix.
//! You can control the random number generator used by the `Shuffler` by
//! passing one to `Shuffler::new`; `Shuffler::default()` uses `rand::thread_rng`.
//!
//! # Examples
//!
//! ```
//! use rusty_machine::linalg::Matrix;
//! use rusty_machine::data::transforms::Transformer;
//! use rusty_machine::data::transforms::shuffle::Shuffler;
//!
//! // Create an input matrix that we want to shuffle
//! let mat = Matrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
//!
//! // Create a new shuffler
//! let mut shuffler = Shuffler::default();
//! let shuffled_mat = shuffler.transform(mat).unwrap();
//!
//! println!("{}", shuffled_mat);
//! ```

use learning::LearningResult;
use learning::error::Error;
use linalg::{Matrix, BaseMatrix, BaseMatrixMut};
use super::Transformer;

use rand::{Rng, thread_rng, ThreadRng};

/// The `Shuffler`
///
/// Provides an implementation of `Transformer` which shuffles
/// the input rows in place.
///
/// The generator type `R` is deliberately left unbounded on the struct
/// itself; the `R: Rng` requirement lives on the `impl` blocks that
/// construct the shuffler and draw random numbers from it.
#[derive(Debug)]
pub struct Shuffler<R> {
    /// Source of randomness used to choose which rows get swapped.
    rng: R,
}

impl<R: Rng> Shuffler<R> {
    /// Builds a `Shuffler` that draws all of its randomness from `rng`.
    ///
    /// Passing a seeded generator makes the resulting row order
    /// reproducible from one run to the next.
    ///
    /// # Examples
    ///
    /// ```
    /// # extern crate rand;
    /// # extern crate rusty_machine;
    ///
    /// use rusty_machine::data::transforms::Transformer;
    /// use rusty_machine::data::transforms::shuffle::Shuffler;
    /// use rand::{StdRng, SeedableRng};
    ///
    /// # fn main() {
    /// // We can create a seeded rng
    /// let rng = StdRng::from_seed(&[1, 2, 3]);
    ///
    /// let shuffler = Shuffler::new(rng);
    /// # }
    /// ```
    pub fn new(rng: R) -> Self {
        Shuffler { rng }
    }
}

/// Create a new shuffler using the `rand::thread_rng` function
/// to provide a randomly seeded random number generator.
impl Default for Shuffler<ThreadRng> {
    fn default() -> Self {
        Shuffler { rng: thread_rng() }
    }
}

/// The `Shuffler` will transform the input `Matrix` by shuffling
/// its rows in place.
///
/// Under the hood this uses a Fisher-Yates shuffle.
impl<R: Rng, T> Transformer<Matrix<T>> for Shuffler<R> {

    #[allow(unused_variables)]
    fn fit(&mut self, inputs: &Matrix<T>) -> Result<(), Error> {
        Ok(())
    }

    fn transform(&mut self, mut inputs: Matrix<T>) -> LearningResult<Matrix<T>> {
        let n = inputs.rows();

        for i in 0..n {
            // Swap i with a random point after it
            let j = self.rng.gen_range(0, n - i);
            inputs.swap_rows(i, i + j);
        }

        Ok(inputs)
    }
}

#[cfg(test)]
mod tests {
    use linalg::Matrix;
    use super::super::Transformer;
    use super::Shuffler;

    use rand::{StdRng, SeedableRng};

    #[test]
    fn seeded_shuffle() {
        let rng = StdRng::from_seed(&[1, 2, 3]);
        let mut shuffler = Shuffler::new(rng);

        let mat = Matrix::new(4, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let shuffled = shuffler.transform(mat).unwrap();

        assert_eq!(shuffled.into_vec(),
                   vec![3.0, 4.0, 1.0, 2.0, 7.0, 8.0, 5.0, 6.0]);
    }

    #[test]
    fn shuffle_single_row() {
        let mut shuffler = Shuffler::default();

        let mat = Matrix::new(1, 8, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let shuffled = shuffler.transform(mat).unwrap();

        assert_eq!(shuffled.into_vec(),
                   vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn shuffle_empty_matrix() {
        let mut shuffler = Shuffler::default();

        // With no rows the shuffle loop never runs, so no RNG draw happens.
        let mat = Matrix::new(0, 2, Vec::<f64>::new());
        let shuffled = shuffler.transform(mat).unwrap();

        assert!(shuffled.into_vec().is_empty());
    }

    #[test]
    fn shuffle_is_row_permutation() {
        let mut shuffler = Shuffler::default();

        // Row k is [k, 10k], which lets us check that rows are moved whole.
        let mut data = Vec::new();
        for k in 0..6 {
            data.push(k as f64);
            data.push(10.0 * k as f64);
        }
        let mat = Matrix::new(6, 2, data);
        let shuffled = shuffler.transform(mat).unwrap().into_vec();

        let mut firsts = Vec::new();
        for row in shuffled.chunks(2) {
            assert_eq!(row[1], 10.0 * row[0]);
            firsts.push(row[0]);
        }
        firsts.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(firsts, vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn shuffle_fit() {
        let rng = StdRng::from_seed(&[1, 2, 3]);
        let mut shuffler = Shuffler::new(rng);

        // `fit` is a no-op for the shuffler; it only has to succeed.
        let mat = Matrix::new(4, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        shuffler.fit(&mat).expect("Shuffler::fit should always succeed");
    }
}