[C#] 矩阵的一种实现 Ver. 2

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;

namespace Traph
{
    /// <summary>
    /// A dense, mutable matrix of <see cref="double"/> values stored in a rectangular
    /// <c>double[,]</c> array (first dimension = rows, second dimension = columns).
    /// </summary>
    /// <remarks>
    /// The arithmetic members (<see cref="Add"/>, <see cref="Minus"/>, <see cref="Multiply(double)"/>,
    /// <see cref="Multiply(Matrix)"/>, <see cref="Transposition"/>) never modify their operands; each
    /// returns a newly allocated matrix. Elements can still be changed in place through the indexer.
    /// </remarks>
    public class Matrix : ICloneable, IEnumerable<Matrix.Row>, IEquatable<Matrix>
    {
        /// <summary>Backing storage, indexed as <c>[row, col]</c>.</summary>
        internal readonly double[,] values;

        /// <summary>Number of rows (length of the first array dimension).</summary>
        public int Rows { get; }

        /// <summary>Number of columns (length of the second array dimension).</summary>
        public int Cols { get; }

        /// <summary>Gets or sets the element at (<paramref name="row"/>, <paramref name="col"/>).</summary>
        /// <exception cref="IndexOutOfRangeException">Either index is outside the matrix bounds.</exception>
        public double this[int row, int col]
        {
            get => values[row, col];
            set => values[row, col] = value;
        }

        /// <summary>Creates a matrix over <paramref name="values"/>.</summary>
        /// <param name="values">Element storage, indexed as <c>[row, col]</c>.</param>
        /// <param name="copyValues">
        /// When <c>false</c> (the default) the matrix uses the caller's array directly, so later writes to
        /// either one are visible through both; when <c>true</c> the array is cloned first.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="values"/> is <c>null</c>.</exception>
        public Matrix(double[,] values, bool copyValues = false)
        {
            if (values is null) throw new ArgumentNullException(nameof(values));
            this.values = copyValues ? (double[,])values.Clone() : values;
            Rows = values.GetLength(0);
            Cols = values.GetLength(1);
        }

        /// <summary>Returns a <see cref="Matrix"/> with its own cloned copy of the element array.</summary>
        public object Clone() => new Matrix(values, true);

        /// <summary>Enumerates the rows from top to bottom.</summary>
        public IEnumerator<Row> GetEnumerator()
        {
            for (var i = 0; i < Rows; i++)
                yield return new Row(this, i);
        }

        IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

        /// <summary>
        /// A read-only view of one row of a <see cref="Matrix"/>. Elements are read from the owning
        /// matrix at enumeration time, so later changes to the matrix are reflected.
        /// </summary>
        public class Row : IEnumerable<double>
        {
            private readonly Matrix matrix;
            private readonly int row;

            internal Row(Matrix matrix, int row)
            {
                this.matrix = matrix;
                this.row = row;
            }

            /// <summary>Enumerates the row's elements from left to right.</summary>
            public IEnumerator<double> GetEnumerator()
            {
                for (var i = 0; i < matrix.Cols; i++)
                    yield return matrix.values[row, i];
            }

            IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

            /// <summary>The row's elements separated by tab characters.</summary>
            public override string ToString() => string.Join("\t", this);
        }

        /// <summary>
        /// Value equality: same dimensions and element-wise <see cref="double.Equals(double)"/>.
        /// </summary>
        /// <remarks>
        /// <see cref="double.Equals(double)"/> treats NaN as equal to NaN, which keeps a matrix equal to its
        /// own clone and consistent with <see cref="GetHashCode"/>. Plain <c>!=</c> would make any matrix
        /// containing NaN unequal to every copy of itself.
        /// </remarks>
        public bool Equals([AllowNull] Matrix other)
        {
            if (ReferenceEquals(this, other)) return true;
            // `is null` avoids calling the overloaded == operator (which itself calls Equals).
            if (other is null) return false;
            if (Rows != other.Rows || Cols != other.Cols) return false;
            for (var i = 0; i < Rows; i++)
                for (var j = 0; j < Cols; j++)
                    if (!values[i, j].Equals(other.values[i, j]))
                        return false; // stop at the first mismatch
            return true;
        }

        /// <summary>Value equality; two <c>null</c> references compare equal.</summary>
        public static bool operator ==(Matrix left, Matrix right) =>
            // Must not use `== null` here: it would bind back to this operator and recurse forever.
            left is null ? right is null : left.Equals(right);

        /// <summary>Negation of <see cref="op_Equality"/>.</summary>
        public static bool operator !=(Matrix left, Matrix right) => !(left == right);

        /// <inheritdoc/>
        public override bool Equals(object obj) => Equals(obj as Matrix);

        /// <summary>Hash of the dimensions and every element, in row-major order.</summary>
        /// <remarks>
        /// Elements are mutable through the indexer, so the hash changes when an element changes. Do not
        /// mutate a matrix while it is used as a key in a hash-based collection.
        /// </remarks>
        public override int GetHashCode()
        {
            var hash = HashCode.Combine(Rows, Cols);
            foreach (var value in values) hash = HashCode.Combine(hash, value);
            return hash;
        }

        /// <summary>Rows separated by <see cref="Environment.NewLine"/>, elements by tabs.</summary>
        public override string ToString() => string.Join(Environment.NewLine, this);

        /// <summary>Invokes <paramref name="operation"/> for every (row, col) pair in row-major order.</summary>
        private void ByRowCol(Action<int, int> operation)
        {
            for (var i = 0; i < Rows; i++)
                for (var j = 0; j < Cols; j++)
                    operation(i, j);
        }

        /// <summary>Element-wise sum of this matrix and <paramref name="right"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="right"/> is <c>null</c>.</exception>
        /// <exception cref="ArithmeticException">The dimensions differ.</exception>
        public Matrix Add(Matrix right)
        {
            if (right is null) throw new ArgumentNullException(nameof(right));
            if (Rows != right.Rows || Cols != right.Cols)
                throw new ArithmeticException(
                    $"Cannot add a {right.Rows}x{right.Cols} matrix to a {Rows}x{Cols} matrix.");
            var ret = (Matrix)Clone();
            ret.ByRowCol((row, col) => ret.values[row, col] += right.values[row, col]);
            return ret;
        }

        /// <summary>Element-wise sum; see <see cref="Add"/>.</summary>
        public static Matrix operator +(Matrix left, Matrix right) => left.Add(right);

        /// <summary>A new matrix with every element negated.</summary>
        public Matrix Minus
        {
            get
            {
                var ret = (Matrix)Clone();
                ret.ByRowCol((row, col) => ret.values[row, col] *= -1);
                return ret;
            }
        }

        /// <summary>Element-wise negation; see <see cref="Minus"/>.</summary>
        public static Matrix operator -(Matrix matrix) => matrix.Minus;

        /// <summary>Element-wise difference, computed as <c>left + (-right)</c>.</summary>
        public static Matrix operator -(Matrix left, Matrix right) => left.Add(right.Minus);

        /// <summary>A new matrix with every element multiplied by <paramref name="number"/>.</summary>
        public Matrix Multiply(double number)
        {
            var ret = (Matrix)Clone();
            ret.ByRowCol((row, col) => ret.values[row, col] *= number);
            return ret;
        }

        /// <summary>Scalar multiplication; see <see cref="Multiply(double)"/>.</summary>
        public static Matrix operator *(double left, Matrix right) => right.Multiply(left);

        /// <summary>Scalar multiplication; see <see cref="Multiply(double)"/>.</summary>
        public static Matrix operator *(Matrix left, double right) => left.Multiply(right);

        /// <summary>Matrix product <c>this × right</c>.</summary>
        /// <returns>A new <see cref="Rows"/> × <c>right.Cols</c> matrix.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="right"/> is <c>null</c>.</exception>
        /// <exception cref="ArithmeticException"><see cref="Cols"/> differs from <c>right.Rows</c>.</exception>
        public Matrix Multiply(Matrix right)
        {
            if (right is null) throw new ArgumentNullException(nameof(right));
            if (Cols != right.Rows)
                throw new ArithmeticException(
                    $"Cannot multiply a {Rows}x{Cols} matrix by a {right.Rows}x{right.Cols} matrix: " +
                    "the left column count must equal the right row count.");
            var ret = new Matrix(new double[Rows, right.Cols]);
            ret.ByRowCol((row, col) =>
            {
                // Accumulate locally and store once; summation order matches the textbook definition.
                var sum = 0.0;
                for (var k = 0; k < Cols; k++)
                    sum += values[row, k] * right.values[k, col];
                ret.values[row, col] = sum;
            });
            return ret;
        }

        /// <summary>Matrix product; see <see cref="Multiply(Matrix)"/>.</summary>
        public static Matrix operator *(Matrix left, Matrix right) => left.Multiply(right);

        /// <summary>A new <see cref="Cols"/> × <see cref="Rows"/> matrix with rows and columns swapped.</summary>
        public Matrix Transposition
        {
            get
            {
                var ret = new Matrix(new double[Cols, Rows]);
                ret.ByRowCol((row, col) => ret.values[row, col] = values[col, row]);
                return ret;
            }
        }

        /// <summary>Shorthand for <see cref="Transposition"/>.</summary>
        public Matrix T => Transposition;
    }
    /// <summary>Extension helpers for building <see cref="Matrix"/> instances from raw arrays.</summary>
    /// <remarks>The type name keeps its original spelling so existing callers continue to compile.</remarks>
    public static class MatrixExtention
    {
        /// <summary>Creates a <see cref="Matrix"/> that owns a cloned copy of <paramref name="doubles"/>.</summary>
        /// <param name="doubles">Source values indexed as <c>[row, col]</c>; later edits to it do not affect the result.</param>
        /// <returns>A new matrix backed by its own copy of the array.</returns>
        public static Matrix AsMatrix(this double[,] doubles)
        {
            return new Matrix(doubles, copyValues: true);
        }
    }
}

发表回复

您的电子邮箱地址不会被公开。必填项已用 * 标注。

此站点使用 Akismet 来减少垃圾评论。了解我们如何处理您的评论数据。