diff --git a/tempest/matrix/C44Matrix.cpp b/tempest/matrix/C44Matrix.cpp index 548ea58..823cfcc 100644 --- a/tempest/matrix/C44Matrix.cpp +++ b/tempest/matrix/C44Matrix.cpp @@ -83,3 +83,27 @@ C44Matrix operator*(const C44Matrix& l, float a) { return { a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3, d0, d1, d2, d3 }; } + +C44Matrix operator*(const C44Matrix& l, const C44Matrix& r) { + float a0 = l.a0 * r.a0 + l.a1 * r.b0 + l.a2 * r.c0 + l.a3 * r.d0; + float a1 = l.a0 * r.a1 + l.a1 * r.b1 + l.a2 * r.c1 + l.a3 * r.d1; + float a2 = l.a0 * r.a2 + l.a1 * r.b2 + l.a2 * r.c2 + l.a3 * r.d2; + float a3 = l.a0 * r.a3 + l.a1 * r.b3 + l.a2 * r.c3 + l.a3 * r.d3; + + float b0 = l.b0 * r.a0 + l.b1 * r.b0 + l.b2 * r.c0 + l.b3 * r.d0; + float b1 = l.b0 * r.a1 + l.b1 * r.b1 + l.b2 * r.c1 + l.b3 * r.d1; + float b2 = l.b0 * r.a2 + l.b1 * r.b2 + l.b2 * r.c2 + l.b3 * r.d2; + float b3 = l.b0 * r.a3 + l.b1 * r.b3 + l.b2 * r.c3 + l.b3 * r.d3; + + float c0 = l.c0 * r.a0 + l.c1 * r.b0 + l.c2 * r.c0 + l.c3 * r.d0; + float c1 = l.c0 * r.a1 + l.c1 * r.b1 + l.c2 * r.c1 + l.c3 * r.d1; + float c2 = l.c0 * r.a2 + l.c1 * r.b2 + l.c2 * r.c2 + l.c3 * r.d2; + float c3 = l.c0 * r.a3 + l.c1 * r.b3 + l.c2 * r.c3 + l.c3 * r.d3; + + float d0 = l.d0 * r.a0 + l.d1 * r.b0 + l.d2 * r.c0 + l.d3 * r.d0; + float d1 = l.d0 * r.a1 + l.d1 * r.b1 + l.d2 * r.c1 + l.d3 * r.d1; + float d2 = l.d0 * r.a2 + l.d1 * r.b2 + l.d2 * r.c2 + l.d3 * r.d2; + float d3 = l.d0 * r.a3 + l.d1 * r.b3 + l.d2 * r.c3 + l.d3 * r.d3; + + return { a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3, d0, d1, d2, d3 }; +} diff --git a/tempest/matrix/C44Matrix.hpp b/tempest/matrix/C44Matrix.hpp index 2766115..7a65633 100644 --- a/tempest/matrix/C44Matrix.hpp +++ b/tempest/matrix/C44Matrix.hpp @@ -50,4 +50,6 @@ class C44Matrix { C44Matrix operator*(const C44Matrix& l, float a); +C44Matrix operator*(const C44Matrix& l, const C44Matrix& r); + #endif diff --git a/test/Matrix.cpp b/test/Matrix.cpp index c2de9f0..69c6c0e 100644 --- a/test/Matrix.cpp +++ b/test/Matrix.cpp @@ -256,4 +256,26 @@ TEST_CASE("C44Matrix global operators", "[matrix]") { CHECK(matrix2.d2 == 30.0f); CHECK(matrix2.d3 == 32.0f); } + + SECTION("C44Matrix * C44Matrix") { + auto matrix1 = C44Matrix(1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f); + auto matrix2 = C44Matrix(1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f); + auto matrix3 = matrix1 * matrix2; + CHECK(matrix3.a0 == 90.0f); + CHECK(matrix3.a1 == 100.0f); + CHECK(matrix3.a2 == 110.0f); + CHECK(matrix3.a3 == 120.0f); + CHECK(matrix3.b0 == 202.0f); + CHECK(matrix3.b1 == 228.0f); + CHECK(matrix3.b2 == 254.0f); + CHECK(matrix3.b3 == 280.0f); + CHECK(matrix3.c0 == 314.0f); + CHECK(matrix3.c1 == 356.0f); + CHECK(matrix3.c2 == 398.0f); + CHECK(matrix3.c3 == 440.0f); + CHECK(matrix3.d0 == 426.0f); + CHECK(matrix3.d1 == 484.0f); + CHECK(matrix3.d2 == 542.0f); + CHECK(matrix3.d3 == 600.0f); + } }