/* SPDX-FileCopyrightText: 2022 Blender Authors * * SPDX-License-Identifier: GPL-2.0-or-later */ #pragma once /** \file * \ingroup bli * * Template for matrix types. * * The `blender::MatBase` is a Row x Col matrix (in mathematical notation) laid * out as column major in memory. * * This class overloads `+, -, *` and `+=, -=, *=` mathematical operators. * They are all using per component operation, except for a few: * `MatBase * Vector` the vector product with the matrix. * `Vector * MatBase` the vector product with the **transposed** matrix. * `MatBase * MatBase` and `MatBase *= MatBase` the matrix multiplication. * * The `blender::MatView` allows working on a subset of a matrix without having to move the data * around. It can be obtained using the `MatBase.view()`. It is const by default if * the matrix type is. Otherwise, a `blender::MutableMatView` is returned. * * A `blender::MutableMatView`. It is mostly the same as `blender::MatView`, but can to be * modified. * * This allow working with any number type `T` (float, double, mpq, ...) and to use these types in * shared shader files (code compiled in both C++ and Shader language). To this end, only low level * constructors are defined inside the class itself and every function working on matrices are * defined outside of the class in the `blender::math` namespace. */ #include #include #include #include #include "BLI_math_vector_types.hh" #include "BLI_utildefines.h" #include "BLI_utility_mixins.hh" namespace blender { template struct MatView; template struct MutableMatView; template< /* Number type. */ typename T, /* Number of column in the matrix. */ int NumCol, /* Number of row in the matrix. */ int NumRow, /* Alignment in bytes. Do not align matrices whose size is not a multiple of 4 component. * This is in order to avoid padding when using arrays of matrices. */ int Alignment = (((NumCol * NumRow) % 4 == 0) ? 4 : 1) * sizeof(T)> struct alignas(Alignment) MatBase : public vec_struct_base, NumCol> { using base_type = T; using vec3_type = VecBase; using col_type = VecBase; using row_type = VecBase; using loc_type = VecBase; static constexpr int min_dim = (NumRow < NumCol) ? NumRow : NumCol; static constexpr int col_len = NumCol; static constexpr int row_len = NumRow; MatBase() = default; /* Workaround issue with template BLI_ENABLE_IF((Size == 2)) not working. */ #define BLI_ENABLE_IF_MAT(_size, _test) int S = _size, BLI_ENABLE_IF((S _test)) template MatBase(col_type _x, col_type _y) { (*this)[0] = _x; (*this)[1] = _y; } template MatBase(col_type _x, col_type _y, col_type _z) { (*this)[0] = _x; (*this)[1] = _y; (*this)[2] = _z; } template MatBase(col_type _x, col_type _y, col_type _z, col_type _w) { (*this)[0] = _x; (*this)[1] = _y; (*this)[2] = _z; (*this)[3] = _w; } /** Masking. */ template explicit MatBase(const MatBase &other) { if constexpr ((OtherNumRow >= NumRow) && (OtherNumCol >= NumCol)) { unroll([&](auto i) { (*this)[i] = col_type(other[i]); }); } else { /* Allow enlarging following GLSL standard (i.e: mat4x4(mat3x3())). */ unroll([&](auto i) { unroll([&](auto j) { if (i < OtherNumCol && j < OtherNumRow) { (*this)[i][j] = other[i][j]; } else if (i == j) { (*this)[i][j] = T(1); } else { (*this)[i][j] = T(0); } }); }); } } #undef BLI_ENABLE_IF_MAT /** Conversion from pointers (from C-style vectors). */ explicit MatBase(const T *ptr) { unroll([&](auto i) { (*this)[i] = reinterpret_cast(ptr)[i]; }); } template))> explicit MatBase(const U *ptr) { unroll([&](auto i) { (*this)[i] = ptr[i]; }); } explicit MatBase(const T (*ptr)[NumCol]) : MatBase(static_cast(ptr[0])) {} /** Conversion from other matrix types. */ template explicit MatBase(const MatBase &vec) { unroll([&](auto i) { (*this)[i] = col_type(vec[i]); }); } /** C-style pointer dereference. */ using c_style_mat = T[NumCol][NumRow]; /** \note Prevent implicit cast to types that could fit other pointer constructor. */ const c_style_mat &ptr() const { return *reinterpret_cast(this); } /** \note Prevent implicit cast to types that could fit other pointer constructor. */ c_style_mat &ptr() { return *reinterpret_cast(this); } /** \note Prevent implicit cast to types that could fit other pointer constructor. */ const T *base_ptr() const { return reinterpret_cast(this); } /** \note Prevent implicit cast to types that could fit other pointer constructor. */ T *base_ptr() { return reinterpret_cast(this); } /** View creation. */ template const MatView view() const { return MatView( const_cast(*this)); } template MutableMatView view() { return MutableMatView(*this); } /** Array access. */ const col_type &operator[](int index) const { BLI_assert(index >= 0); BLI_assert(index < NumCol); return reinterpret_cast(this)[index]; } col_type &operator[](int index) { BLI_assert(index >= 0); BLI_assert(index < NumCol); return reinterpret_cast(this)[index]; } /** Access helpers. Using Blender coordinate system. */ vec3_type &x_axis() { BLI_STATIC_ASSERT(NumCol >= 1, "Wrong Matrix dimension"); BLI_STATIC_ASSERT(NumRow >= 3, "Wrong Matrix dimension"); return *reinterpret_cast(&(*this)[0]); } vec3_type &y_axis() { BLI_STATIC_ASSERT(NumCol >= 2, "Wrong Matrix dimension"); BLI_STATIC_ASSERT(NumRow >= 3, "Wrong Matrix dimension"); return *reinterpret_cast(&(*this)[1]); } vec3_type &z_axis() { BLI_STATIC_ASSERT(NumCol >= 3, "Wrong Matrix dimension"); BLI_STATIC_ASSERT(NumRow >= 3, "Wrong Matrix dimension"); return *reinterpret_cast(&(*this)[2]); } loc_type &location() { BLI_STATIC_ASSERT(NumCol >= 3, "Wrong Matrix dimension"); BLI_STATIC_ASSERT(NumRow >= 2, "Wrong Matrix dimension"); return *reinterpret_cast(&(*this)[NumCol - 1]); } const vec3_type &x_axis() const { BLI_STATIC_ASSERT(NumCol >= 1, "Wrong Matrix dimension"); BLI_STATIC_ASSERT(NumRow >= 3, "Wrong Matrix dimension"); return *reinterpret_cast(&(*this)[0]); } const vec3_type &y_axis() const { BLI_STATIC_ASSERT(NumCol >= 2, "Wrong Matrix dimension"); BLI_STATIC_ASSERT(NumRow >= 3, "Wrong Matrix dimension"); return *reinterpret_cast(&(*this)[1]); } const vec3_type &z_axis() const { BLI_STATIC_ASSERT(NumCol >= 3, "Wrong Matrix dimension"); BLI_STATIC_ASSERT(NumRow >= 3, "Wrong Matrix dimension"); return *reinterpret_cast(&(*this)[2]); } const loc_type &location() const { BLI_STATIC_ASSERT(NumCol >= 3, "Wrong Matrix dimension"); BLI_STATIC_ASSERT(NumRow >= 2, "Wrong Matrix dimension"); return *reinterpret_cast(&(*this)[NumCol - 1]); } /** Matrix operators. */ friend MatBase operator+(const MatBase &a, const MatBase &b) { MatBase result; unroll([&](auto i) { result[i] = a[i] + b[i]; }); return result; } friend MatBase operator+(const MatBase &a, T b) { MatBase result; unroll([&](auto i) { result[i] = a[i] + b; }); return result; } friend MatBase operator+(T a, const MatBase &b) { return b + a; } MatBase &operator+=(const MatBase &b) { unroll([&](auto i) { (*this)[i] += b[i]; }); return *this; } MatBase &operator+=(T b) { unroll([&](auto i) { (*this)[i] += b; }); return *this; } friend MatBase operator-(const MatBase &a) { MatBase result; unroll([&](auto i) { result[i] = -a[i]; }); return result; } friend MatBase operator-(const MatBase &a, const MatBase &b) { MatBase result; unroll([&](auto i) { result[i] = a[i] - b[i]; }); return result; } friend MatBase operator-(const MatBase &a, T b) { MatBase result; unroll([&](auto i) { result[i] = a[i] - b; }); return result; } friend MatBase operator-(T a, const MatBase &b) { MatBase result; unroll([&](auto i) { result[i] = a - b[i]; }); return result; } MatBase &operator-=(const MatBase &b) { unroll([&](auto i) { (*this)[i] -= b[i]; }); return *this; } MatBase &operator-=(T b) { unroll([&](auto i) { (*this)[i] -= b; }); return *this; } /** Multiply two matrices using matrix multiplication. */ MatBase operator*(const MatBase &b) const { const MatBase &a = *this; /* This is the reference implementation. * Might be overloaded with vectorized / optimized code. */ /* TODO(fclem): It should be possible to return non-square matrices when multiplying against * MatBase. */ MatBase result{}; unroll([&](auto j) { unroll([&](auto i) { /* Same as dot product, but avoid dependency on vector math. */ unroll([&](auto k) { result[j][i] += a[k][i] * b[j][k]; }); }); }); return result; } /** Multiply each component by a scalar. */ friend MatBase operator*(const MatBase &a, T b) { MatBase result; unroll([&](auto i) { result[i] = a[i] * b; }); return result; } /** Multiply each component by a scalar. */ friend MatBase operator*(T a, const MatBase &b) { return b * a; } /** Multiply two matrices using matrix multiplication. */ MatBase &operator*=(const MatBase &b) { const MatBase &a = *this; *this = a * b; return *this; } /** Multiply each component by a scalar. */ MatBase &operator*=(T b) { unroll([&](auto i) { (*this)[i] *= b; }); return *this; } /** Vector operators. */ friend col_type operator*(const MatBase &a, const row_type &b) { /* This is the reference implementation. * Might be overloaded with vectorized / optimized code. */ col_type result(0); unroll([&](auto c) { result += b[c] * a[c]; }); return result; } /** Multiply by the transposed. */ friend row_type operator*(const col_type &a, const MatBase &b) { /* This is the reference implementation. * Might be overloaded with vectorized / optimized code. */ row_type result(0); unroll([&](auto c) { unroll([&](auto r) { result[c] += b[c][r] * a[r]; }); }); return result; } /** Compare. */ friend bool operator==(const MatBase &a, const MatBase &b) { for (int i = 0; i < NumCol; i++) { if (a[i] != b[i]) { return false; } } return true; } friend bool operator!=(const MatBase &a, const MatBase &b) { return !(a == b); } /** Miscellaneous. */ static MatBase diagonal(T value) { MatBase result{}; unroll([&](auto i) { result[i][i] = value; }); return result; } static MatBase all(T value) { MatBase result; unroll([&](auto i) { result[i] = col_type(value); }); return result; } static MatBase identity() { return diagonal(1); } static MatBase zero() { return all(0); } uint64_t hash() const { uint64_t h = 435109; unroll([&](auto i) { T value = (reinterpret_cast(this))[i]; h = h * 33 + *reinterpret_cast *>(&value); }); return h; } friend std::ostream &operator<<(std::ostream &stream, const MatBase &mat) { stream << "(\n"; unroll([&](auto i) { stream << "("; unroll([&](auto j) { /** NOTE: j and i are swapped to follow mathematical convention. */ stream << mat[j][i]; if (j < NumRow - 1) { stream << ", "; } }); stream << ")"; if (i < NumCol - 1) { stream << ","; } stream << "\n"; }); stream << ")\n"; return stream; } }; template struct MatView : NonCopyable, NonMovable { using MatT = MatBase; using SrcMatT = MatBase; using col_type = VecBase; using row_type = VecBase; const SrcMatT &mat; MatView() = delete; MatView(const SrcMatT &src) : mat(src) { BLI_STATIC_ASSERT(SrcStartCol >= 0, "View does not fit source matrix dimensions"); BLI_STATIC_ASSERT(SrcStartRow >= 0, "View does not fit source matrix dimensions"); BLI_STATIC_ASSERT(SrcStartCol + NumCol <= SrcNumCol, "View does not fit source matrix dimensions"); BLI_STATIC_ASSERT(SrcStartRow + NumRow <= SrcNumRow, "View does not fit source matrix dimensions"); } /** Allow wrapping C-style matrices using view. IMPORTANT: Alignment of src needs to match. */ explicit MatView(const float (*src)[SrcNumRow]) : MatView(*reinterpret_cast(&src[0][0])){}; /** Array access. */ const col_type &operator[](int index) const { BLI_assert(index >= 0); BLI_assert(index < NumCol); return *reinterpret_cast(&mat[index + SrcStartCol][SrcStartRow]); } /** Conversion back to matrix. */ operator MatT() const { MatT mat; unroll([&](auto c) { mat[c] = (*this)[c]; }); return mat; } /** Matrix operators. */ friend MatT operator+(const MatView &a, T b) { MatT result; unroll([&](auto i) { result[i] = a[i] + b; }); return result; } friend MatT operator+(T a, const MatView &b) { return b + a; } friend MatT operator-(const MatView &a) { MatT result; unroll([&](auto i) { result[i] = -a[i]; }); return result; } template friend MatT operator-(const MatView &a, const MatView &b) { MatT result; unroll([&](auto i) { result[i] = a[i] - b[i]; }); return result; } friend MatT operator-(const MatView &a, const MatT &b) { return a - b.view(); } template friend MatT operator-(const MatView &a, const MatView &b) { MatT result; unroll([&](auto i) { result[i] = a[i] - b[i]; }); return result; } friend MatT operator-(const MatT &a, const MatView &b) { return a.view() - b; } friend MatT operator-(const MatView &a, T b) { MatT result; unroll([&](auto i) { result[i] = a[i] - b; }); return result; } friend MatView operator-(T a, const MatView &b) { MatView result; unroll([&](auto i) { result[i] = a - b[i]; }); return result; } /** Multiply two matrices using matrix multiplication. */ template MatBase operator*(const MatView &b) const { const MatView &a = *this; /* This is the reference implementation. * Might be overloaded with vectorized / optimized code. */ MatBase result{}; unroll([&](auto j) { unroll([&](auto i) { /* Same as dot product, but avoid dependency on vector math. */ unroll([&](auto k) { result[j][i] += a[k][i] * b[j][k]; }); }); }); return result; } MatT operator*(const MatT &b) const { return *this * b.view(); } /** Multiply each component by a scalar. */ friend MatT operator*(const MatView &a, T b) { MatT result; unroll([&](auto i) { result[i] = a[i] * b; }); return result; } /** Multiply each component by a scalar. */ friend MatT operator*(T a, const MatView &b) { return b * a; } /** Vector operators. */ friend col_type operator*(const MatView &a, const row_type &b) { /* This is the reference implementation. * Might be overloaded with vectorized / optimized code. */ col_type result(0); unroll([&](auto c) { result += b[c] * a[c]; }); return result; } /** Multiply by the transposed. */ friend row_type operator*(const col_type &a, const MatView &b) { /* This is the reference implementation. * Might be overloaded with vectorized / optimized code. */ row_type result(0); unroll([&](auto c) { unroll([&](auto r) { result[c] += b[c][r] * a[r]; }); }); return result; } /** Compare. */ friend bool operator==(const MatView &a, const MatView &b) { for (int i = 0; i < NumCol; i++) { if (a[i] != b[i]) { return false; } } return true; } friend bool operator!=(const MatView &a, const MatView &b) { return !(a == b); } /** Miscellaneous. */ friend std::ostream &operator<<(std::ostream &stream, const MatView &mat) { return stream << mat->mat; } }; template struct MutableMatView : MatView { using MatT = MatBase; using MatViewT = MatView; using SrcMatT = MatBase; using col_type = VecBase; using row_type = VecBase; public: MutableMatView() = delete; MutableMatView(SrcMatT &src) : MatViewT(const_cast(src)){}; /** Allow wrapping C-style matrices using view. IMPORTANT: Alignment of src needs to match. */ explicit MutableMatView(float src[SrcNumCol][SrcNumRow]) : MutableMatView(*reinterpret_cast(&src[0][0])){}; /** Array access. */ col_type &operator[](int index) { return const_cast(static_cast(*this)[index]); } /** Conversion to immutable view. */ operator MatViewT() const { return MatViewT(this->mat); } /** Copy Assignment. */ template MutableMatView &operator=(const MatView &other) { BLI_assert_msg( (reinterpret_cast(&other.mat[0][0]) != reinterpret_cast(&this->mat[0][0])) || /* Make sure assignment won't overwrite the source. OtherSrc* is the source. */ ((OtherSrcStartCol > SrcStartCol) || (OtherSrcStartCol + NumCol <= SrcStartCol) || (OtherSrcStartRow > SrcStartRow + NumRow) || (OtherSrcStartRow + NumRow <= SrcStartRow)), "Operation is undefined if views overlap."); unroll([&](auto i) { (*this)[i] = other[i]; }); return *this; } MutableMatView &operator=(const MatT &other) { *this = other.view(); return *this; } /** Matrix operators. */ template MutableMatView &operator+=(const MatView &b) { unroll([&](auto i) { (*this)[i] += b[i]; }); return *this; } MutableMatView &operator+=(const MatT &b) { return *this += b.view(); } MutableMatView &operator+=(T b) { unroll([&](auto i) { (*this)[i] += b; }); return *this; } template MutableMatView &operator-=(const MatView &b) { unroll([&](auto i) { (*this)[i] -= b[i]; }); return *this; } MutableMatView &operator-=(const MatT &b) { return *this -= b.view(); } MutableMatView &operator-=(T b) { unroll([&](auto i) { (*this)[i] -= b; }); return *this; } /** Multiply two matrices using matrix multiplication. */ template MutableMatView &operator*=(const MatView &b) { *this = *static_cast(this) * b; return *this; } MutableMatView &operator*=(const MatT &b) { return *this *= b.view(); } /** Multiply each component by a scalar. */ MutableMatView &operator*=(T b) { unroll([&](auto i) { (*this)[i] *= b; }); return *this; } /** Vector operators. Need to be redefined to avoid operator priority issue. */ friend col_type operator*(MutableMatView &a, const row_type &b) { /* This is the reference implementation. * Might be overloaded with vectorized / optimized code. */ col_type result(0); unroll([&](auto c) { result += b[c] * a[c]; }); return result; } /** Multiply by the transposed. */ friend row_type operator*(const col_type &a, MutableMatView &b) { /* This is the reference implementation. * Might be overloaded with vectorized / optimized code. */ row_type result(0); unroll([&](auto c) { unroll([&](auto r) { result[c] += b[c][r] * a[r]; }); }); return result; } }; using float2x2 = MatBase; using float2x3 = MatBase; using float2x4 = MatBase; using float3x2 = MatBase; using float3x3 = MatBase; using float3x4 = MatBase; using float4x2 = MatBase; using float4x3 = MatBase; using float4x4 = MatBase; /* These types are reserved to wrap C matrices without copy. Note the un-alignment. */ /* TODO: It would be preferable to align all C matrices inside DNA structs. */ using float4x4_view = MatView; using float4x4_mutableview = MutableMatView; using double2x2 = MatBase; using double2x3 = MatBase; using double2x4 = MatBase; using double3x2 = MatBase; using double3x3 = MatBase; using double3x4 = MatBase; using double4x2 = MatBase; using double4x3 = MatBase; using double4x4 = MatBase; /* Specialization for SSE optimization. */ template<> float4x4 float4x4::operator*(const float4x4 &b) const; template<> float3x3 float3x3::operator*(const float3x3 &b) const; extern template float2x2 float2x2::operator*(const float2x2 &b) const; extern template double2x2 double2x2::operator*(const double2x2 &b) const; extern template double3x3 double3x3::operator*(const double3x3 &b) const; extern template double4x4 double4x4::operator*(const double4x4 &b) const; } // namespace blender