From f8d1608036855e8b7a5f6d7ebe4f1fcf64334d33 Mon Sep 17 00:00:00 2001 From: william Date: Sun, 28 May 2023 11:02:28 -0400 Subject: [PATCH] Added matrices --- src/matrix.rs | 146 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 src/matrix.rs diff --git a/src/matrix.rs b/src/matrix.rs new file mode 100644 index 0000000..c40dae8 --- /dev/null +++ b/src/matrix.rs @@ -0,0 +1,146 @@ +use std::ops; + +#[derive(Debug)] +pub struct Matrix { + row_count: usize, + col_count: usize, + default_val: T, + rows: Vec>, +} + +#[derive(Debug)] +pub enum MatrixError { + RowIndexOutOfBound(usize), + ColIndexOutOfBound(usize), + IncompatibleSize, +} + +impl Matrix where + T: Copy { + pub fn new(row_count: usize, col_count: usize, default_val: T) -> Self { + let rows = vec![vec![default_val; col_count]; row_count]; + Matrix { row_count, col_count, default_val, rows } + } + + pub fn new_with_size(matrix: &Self) -> Self { + Matrix::new(matrix.row_count, matrix.col_count, matrix.default_val) + } + + pub fn get(&self, row: usize, col: usize) -> Option<&T> { + if row >= self.row_count || col >= self.col_count { + return None; + } + + Some(&self.rows[row][col]) + } + + pub fn set(&mut self, val: T, row: usize, col: usize) -> Result<(), MatrixError> { + if row >= self.row_count { + return Err(MatrixError::RowIndexOutOfBound(row)); + } + + if col >= self.col_count { + return Err(MatrixError::ColIndexOutOfBound(col)); + } + + self.rows[row][col] = val; + Ok(()) + } + + fn size_equal(&self, other: &Self) -> bool { + self.row_count == other.row_count && self.col_count == other.col_count + } +} + +impl ops::Add> for Matrix where + T: ops::Add, T: Copy { + type Output = Result, MatrixError>; + + fn add(self, rhs: Matrix) -> Self::Output { + if !&self.size_equal(&rhs) { + return Err(MatrixError::IncompatibleSize); + } + + let mut result_matrix = Matrix::new_with_size(&self); + + for row in 0..self.row_count { + for col in 0..self.col_count { + // Since the sizes of the matrices are known, we can ignore errors + let val = self.rows[row][col] + rhs.rows[row][col]; + result_matrix.set(val, row, col).unwrap(); + } + } + + Ok(result_matrix) + } +} + +impl ops::Sub> for Matrix where + T: ops::Sub, T: Copy { + type Output = Result, MatrixError>; + + fn sub(self, rhs: Matrix) -> Self::Output { + if !&self.size_equal(&rhs) { + return Err(MatrixError::IncompatibleSize); + } + + let mut result_matrix = Matrix::new_with_size(&self); + + for row in 0..self.row_count { + for col in 0..self.col_count { + // Since the sizes of the matrices are known, we can ignore errors + let val = self.rows[row][col] - rhs.rows[row][col]; + result_matrix.set(val, row, col).unwrap(); + } + } + + Ok(result_matrix) + } +} + +impl ops::Mul for Matrix where + T: ops::Mul, T: Copy { + type Output = Matrix; + + fn mul(self, rhs: T) -> Self::Output { + let mut result_matrix = Matrix::new_with_size(&self); + + for row in 0..self.row_count { + for col in 0..self.col_count { + let val = self.rows[row][col] * rhs; + result_matrix.set(val, row, col).unwrap(); + } + } + + result_matrix + } +} + +impl ops::Mul> for Matrix where + T: ops::Add, T: ops::Mul, T: Copy { + type Output = Result, MatrixError>; + + fn mul(self, rhs: Matrix) -> Self::Output { + if self.col_count != rhs.row_count { + return Err(MatrixError::IncompatibleSize); + } + + let mut result_matrix = Matrix::new(self.row_count, rhs.col_count, self.default_val); + + for row in 0..result_matrix.row_count { + for col in 0..result_matrix.col_count { + let mut val = self.default_val; + + for i in 0..self.col_count { + let val_i = self.rows[row][i]; + let val_j = rhs.rows[i][col]; + val = val + val_i * val_j; + } + + result_matrix.set(val, row, col).unwrap(); + } + } + + Ok(result_matrix) + } +}