using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
namespace Traph
{
/// <summary>
/// A mutable matrix of <see cref="double"/> values backed by a rectangular array.
/// Enumerating a matrix yields its rows in order; each <see cref="Row"/> enumerates its elements.
/// </summary>
public class Matrix : ICloneable, IEnumerable<Matrix.Row>, IEquatable<Matrix>
{
    // Element storage indexed as [row, col]. When constructed with copyValues == false this array
    // is the caller's own instance, so writes through either reference are visible to both.
    internal readonly double[,] values;

    /// <summary>Number of rows (length of dimension 0 of the backing array).</summary>
    public int Rows { get; }

    /// <summary>Number of columns (length of dimension 1 of the backing array).</summary>
    public int Cols { get; }

    /// <summary>Gets or sets the element at (<paramref name="row"/>, <paramref name="col"/>).</summary>
    /// <exception cref="IndexOutOfRangeException">An index is outside the matrix bounds.</exception>
    public double this[int row, int col]
    {
        get => values[row, col];
        set => values[row, col] = value;
    }

    /// <summary>Creates a matrix over <paramref name="values"/>.</summary>
    /// <param name="values">Elements indexed as [row, col].</param>
    /// <param name="copyValues">
    /// When <c>true</c> the array is cloned; otherwise the matrix aliases the given array.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="values"/> is <c>null</c>.</exception>
    public Matrix(double[,] values, bool copyValues = false)
    {
        if (values is null) throw new ArgumentNullException(nameof(values));
        this.values = copyValues ? (double[,])values.Clone() : values;
        Rows = values.GetLength(0);
        Cols = values.GetLength(1);
    }

    /// <summary>Returns a new <see cref="Matrix"/> with its own copy of the elements.</summary>
    public object Clone() => new Matrix(values, true);

    /// <summary>Enumerates the rows of the matrix from top to bottom.</summary>
    public IEnumerator<Row> GetEnumerator()
    {
        for (var i = 0; i < Rows; i++)
            yield return new Row(this, i);
    }

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    /// <summary>A live view of a single row of a <see cref="Matrix"/>.</summary>
    public class Row : IEnumerable<double>
    {
        private readonly Matrix matrix;
        private readonly int row;

        internal Row(Matrix matrix, int row)
        {
            this.matrix = matrix;
            this.row = row;
        }

        /// <summary>Enumerates the row's elements from left to right, reading the matrix at enumeration time.</summary>
        public IEnumerator<double> GetEnumerator()
        {
            for (var i = 0; i < matrix.Cols; i++)
                yield return matrix.values[row, i];
        }

        IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

        /// <summary>Formats the row as tab-separated values.</summary>
        public override string ToString() => string.Join("\t", this);
    }

    /// <summary>
    /// Element-wise equality: same dimensions and every pair of elements compares equal with <c>==</c>.
    /// </summary>
    /// <param name="other">Matrix to compare with; may be <c>null</c>.</param>
    /// <returns><c>true</c> if the matrices are the same instance or have equal shape and elements.</returns>
    public bool Equals([AllowNull] Matrix other)
    {
        // ReferenceEquals / `is null` must be used here: `other == null` would call the overloaded
        // operator ==, which itself calls Equals.
        if (ReferenceEquals(this, other)) return true;
        if (other is null) return false;
        if (Rows != other.Rows || Cols != other.Cols) return false;
        for (var i = 0; i < Rows; i++)
            for (var j = 0; j < Cols; j++)
                if (values[i, j] != other.values[i, j])
                    return false; // stop at the first mismatch
        return true;
    }

    /// <summary>Value equality; two <c>null</c> operands are equal.</summary>
    public static bool operator ==(Matrix left, Matrix right) => left is null ? right is null : left.Equals(right);

    /// <summary>Negation of <see cref="op_Equality(Matrix, Matrix)"/>.</summary>
    public static bool operator !=(Matrix left, Matrix right) => !(left == right);

    /// <inheritdoc/>
    public override bool Equals(object obj) => Equals(obj as Matrix);

    /// <summary>
    /// Hash over the dimensions and every element. Because elements are writable through the indexer,
    /// the hash changes when the matrix is mutated.
    /// </summary>
    public override int GetHashCode()
    {
        var hash = HashCode.Combine(Rows, Cols);
        foreach (var value in values) hash = HashCode.Combine(hash, value);
        return hash;
    }

    /// <summary>Formats the matrix as one tab-separated line per row.</summary>
    public override string ToString() => string.Join(Environment.NewLine, this);

    // Invokes operation(row, col) for every cell of this matrix in row-major order.
    private void ByRowCol(Action<int, int> operation)
    {
        for (var i = 0; i < Rows; i++)
            for (var j = 0; j < Cols; j++)
                operation(i, j);
    }

    /// <summary>Element-wise sum of this matrix and <paramref name="right"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="right"/> is <c>null</c>.</exception>
    /// <exception cref="ArithmeticException">The dimensions differ.</exception>
    public Matrix Add(Matrix right)
    {
        if (right is null) throw new ArgumentNullException(nameof(right));
        if (Rows != right.Rows || Cols != right.Cols)
            throw new ArithmeticException("Matrix dimensions must match for addition.");
        var ret = (Matrix)Clone();
        ret.ByRowCol((row, col) => ret.values[row, col] += right.values[row, col]);
        return ret;
    }

    /// <summary>Element-wise sum; see <see cref="Add(Matrix)"/>.</summary>
    public static Matrix operator +(Matrix left, Matrix right) => left.Add(right);

    /// <summary>A new matrix with every element negated.</summary>
    public Matrix Minus
    {
        get
        {
            var ret = (Matrix)Clone();
            ret.ByRowCol((row, col) => ret.values[row, col] *= -1);
            return ret;
        }
    }

    /// <summary>Unary negation; see <see cref="Minus"/>.</summary>
    public static Matrix operator -(Matrix matrix) => matrix.Minus;

    /// <summary>Element-wise difference, computed as <c>left + (-right)</c>.</summary>
    public static Matrix operator -(Matrix left, Matrix right) => left.Add(right.Minus);

    /// <summary>A new matrix with every element multiplied by <paramref name="number"/>.</summary>
    public Matrix Multiply(double number)
    {
        var ret = (Matrix)Clone();
        ret.ByRowCol((row, col) => ret.values[row, col] *= number);
        return ret;
    }

    /// <summary>Scalar product; see <see cref="Multiply(double)"/>.</summary>
    public static Matrix operator *(double left, Matrix right) => right.Multiply(left);

    /// <summary>Scalar product; see <see cref="Multiply(double)"/>.</summary>
    public static Matrix operator *(Matrix left, double right) => left.Multiply(right);

    /// <summary>Matrix product <c>this × right</c>, of size <see cref="Rows"/> × <c>right.Cols</c>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="right"/> is <c>null</c>.</exception>
    /// <exception cref="ArithmeticException"><see cref="Cols"/> differs from <c>right.Rows</c>.</exception>
    public Matrix Multiply(Matrix right)
    {
        if (right is null) throw new ArgumentNullException(nameof(right));
        if (Cols != right.Rows)
            throw new ArithmeticException("Left column count must equal right row count for multiplication.");
        var ret = new Matrix(new double[Rows, right.Cols]);
        ret.ByRowCol((row, col) =>
        {
            // Accumulate the dot product locally, then store it once.
            var sum = 0.0;
            for (var k = 0; k < Cols; k++)
                sum += values[row, k] * right.values[k, col];
            ret.values[row, col] = sum;
        });
        return ret;
    }

    /// <summary>Matrix product; see <see cref="Multiply(Matrix)"/>.</summary>
    public static Matrix operator *(Matrix left, Matrix right) => left.Multiply(right);

    /// <summary>A new <see cref="Cols"/> × <see cref="Rows"/> matrix with rows and columns swapped.</summary>
    public Matrix Transposition
    {
        get
        {
            var ret = new Matrix(new double[Cols, Rows]);
            ret.ByRowCol((row, col) => ret.values[row, col] = values[col, row]);
            return ret;
        }
    }

    /// <summary>Shorthand for <see cref="Transposition"/>.</summary>
    public Matrix T => Transposition;
}
/// <summary>
/// Extension helpers for building <see cref="Matrix"/> instances from arrays.
/// </summary>
public static class MatrixExtention
{
    /// <summary>
    /// Creates a <see cref="Matrix"/> from a copy of <paramref name="doubles"/>, so later writes to the
    /// array and to the returned matrix do not affect each other.
    /// </summary>
    /// <param name="doubles">Source elements indexed as [row, col].</param>
    /// <returns>A new matrix that owns its own copy of the data.</returns>
    public static Matrix AsMatrix(this double[,] doubles)
    {
        var copied = new Matrix(doubles, copyValues: true);
        return copied;
    }
}
}