Program Listing for File linear_algebra.hpp
↰ Return to documentation for file (src/navtk/linear_algebra.hpp)
#pragma once
#include <xtensor-blas/xlinalg.hpp>
#include <xtensor/reducers/xnorm.hpp>
#include <navtk/factory.hpp>
#include <navtk/inspect.hpp>
#include <navtk/navutils/math.hpp>
#include <navtk/tensors.hpp>
#include <navtk/utils/ValidationContext.hpp>
#include <xtensor/misc/xpad.hpp>
namespace navtk {
// NOTE(review): the functions below are only declared here; their definitions are not
// visible in this header. The one-line descriptions are inferred from names/signatures
// and should be confirmed against the implementation.

/// Presumably the matrix exponential of `matrix`.
Matrix expm(const Matrix& matrix);
/// Presumably `matrix` raised to the integer power `n` (handling of negative `n` is not
/// visible here — TODO confirm).
Matrix matrix_power(const Matrix& matrix, long n);
/// Presumably a Cholesky factor of `matrix` (which triangle is returned is not visible
/// here — TODO confirm).
Matrix chol(const Matrix& matrix);
/// Presumably the element-wise square root of the main diagonal of `matrix`, returned as a
/// `Matrix` (result shape not visible here — TODO confirm).
Matrix sqrt_of_main_diagonal(const Matrix& matrix);
/// Presumably a covariance estimate from the samples in `matrix` (whether samples are rows
/// or columns is not visible here — TODO confirm).
Matrix calc_cov(const Matrix& matrix);
/// Presumably the weighted variant of `calc_cov`, with per-sample weights in `weights`.
Matrix calc_cov_weighted(const Matrix& matrix, const Vector& weights);
/// Matrix product used as the backend of the `dot`, `transpose_a_dot_b` and
/// `a_dot_transpose_b` templates in this header (on non-aarch64 builds). The leading
/// underscore suggests an implementation detail; prefer the `dot` overloads.
Matrix _dot(const Matrix& a, const Matrix& b);
/**
 * @brief Vector-by-matrix product \f$ a \cdot b \f$, selected when the first argument is rank 1.
 *
 * `a` is lifted to a matrix with `to_matrix(a, 0)` so the generic matrix product applies.
 * By analogy with numpy's axis convention, this is presumably a 1xN row; TODO confirm.
 * `b` is converted with `to_matrix` unchanged. The product is flattened back with `to_vec`.
 *
 * @param a Rank-1 left operand.
 * @param b Right operand.
 * @return The product as a `Vector`.
 */
template <typename A, typename B, IfFirstTensorOfDim<A, B, 1>* = nullptr>
Vector dot(A&& a, B&& b) {
    const Matrix lhs_row = to_matrix(std::forward<A>(a), 0);
    const Matrix rhs = to_matrix(std::forward<B>(b));
    const Matrix product = _dot(lhs_row, rhs);
    return to_vec(product);
}
#ifndef NEED_DOXYGEN_EXHALE_WORKAROUND
/**
 * @brief Matrix-by-vector product \f$ a \cdot b \f$, selected when the second argument is rank 1.
 *
 * `b` is lifted to a matrix with `to_matrix(b, 1)`, presumably an Nx1 column; TODO confirm.
 * `a` is converted with `to_matrix` unchanged. The product is flattened back with `to_vec`.
 *
 * @param a Left operand.
 * @param b Rank-1 right operand.
 * @return The product as a `Vector`.
 */
template <typename A, typename B, IfSecondTensorOfDim<A, B, 1>* = nullptr>
Vector dot(A&& a, B&& b) {
    const Matrix lhs = to_matrix(std::forward<A>(a));
    const Matrix rhs_column = to_matrix(std::forward<B>(b), 1);
    const Matrix product = _dot(lhs, rhs_column);
    return to_vec(product);
}
/**
 * @brief Product of two rank-1 tensors.
 *
 * Both operands are lifted with `to_matrix(.., 1)`, presumably to Nx1 columns; TODO confirm.
 * The product (transpose of `a`) times `b` is then formed and flattened with `to_vec`.
 * Under the column assumption this is the inner product, returned as a one-element `Vector`.
 *
 * @param a Rank-1 left operand.
 * @param b Rank-1 right operand.
 * @return The flattened product as a `Vector`.
 */
template <typename A, typename B, IfBothTensorsOfDim<A, B, 1>* = nullptr>
Vector dot(A&& a, B&& b) {
    const Matrix a_column = to_matrix(std::forward<A>(a), 1);
    const Matrix b_column = to_matrix(std::forward<B>(b), 1);
    const Matrix product = _dot(xt::transpose(a_column), b_column);
    return to_vec(product);
}
/**
 * @brief Matrix product \f$ a \cdot b \f$ of two rank-2 tensors.
 *
 * Both operands are first materialized as `Matrix` via `to_matrix`.
 *
 * On aarch64 builds the product is computed with an explicit loop instead of `_dot`
 * (see #906). In that path:
 * - an empty `Matrix{}` is returned if either operand has zero size;
 * - otherwise the shapes are checked through `utils::ValidationContext` as `a`: XxN, `b`: NxY.
 *
 * On other builds the work, including any shape checking, is delegated to `_dot`.
 *
 * @param a Left operand.
 * @param b Right operand.
 * @return The product matrix (XxY when the shapes above hold).
 */
template <typename A, typename B, IfBothTensorsOfDim<A, B, 2>* = nullptr>
Matrix dot(A&& a, B&& b) {
    Matrix a_mat = to_matrix(std::forward<A>(a));
    Matrix b_mat = to_matrix(std::forward<B>(b));
# ifdef __aarch64__
    // Manually multiply on ARM because xt::linalg::dot seems to trigger UB (#906)
    if (has_zero_size(a_mat) || has_zero_size(b_mat)) return Matrix{};
    if (utils::ValidationContext validation{}) {
        validation.add_matrix(a_mat, "a_mat")
            .dim('X', 'N')
            .add_matrix(b_mat, "b_mat")
            .dim('N', 'Y')
            .validate();
    }
    const auto left_rows = num_rows(a_mat);
    const auto inner_size = num_cols(a_mat);
    const auto right_cols = num_cols(b_mat);
    Matrix out = zeros(left_rows, right_cols);
    // Read only from a_mat/b_mat, never from a/b. The std::forward calls above may already
    // have moved from a/b, and A/B are not guaranteed to be Matrix at all.
    for (size_t i = 0; i < left_rows; i++) {
        for (size_t idx = 0; idx < inner_size; idx++) {
            // Invariant across the innermost loop; hoisted out of it.
            const auto a_elem = a_mat(i, idx);
            for (size_t j = 0; j < right_cols; j++) {
                out(i, j) += a_elem * b_mat(idx, j);
            }
        }
    }
    return out;
# else
    return _dot(a_mat, b_mat);
# endif
}
/**
 * @brief Product of `a` with the transpose of `b`, i.e. \f$ a \cdot b^T \f$.
 *
 * NOTE(review): despite the name, both code paths compute a times the transpose of b, not
 * the transpose of a times b. The generic path calls `_dot(a_mat, xt::transpose(b_mat))`,
 * and the aarch64 loop accumulates `a_mat(i, k) * b_mat(j, k)`. Behavior is left
 * unchanged here; confirm the intended naming before relying on it.
 *
 * Both operands are first materialized as `Matrix` via `to_matrix`. On aarch64 builds:
 * - an empty `Matrix{}` is returned if either operand has zero size;
 * - otherwise the shapes are checked through `utils::ValidationContext` as `a`: XxN, `b`: YxN.
 *
 * @param a Left operand.
 * @param b Operand whose transpose is the right-hand factor.
 * @return The product matrix (XxY when the shapes above hold).
 */
template <typename A, typename B, IfBothTensorsOfDim<A, B, 2>* = nullptr>
Matrix transpose_a_dot_b(A&& a, B&& b) {
    Matrix a_mat = to_matrix(std::forward<A>(a));
    Matrix b_mat = to_matrix(std::forward<B>(b));
# ifdef __aarch64__
    // Manually multiply on ARM because xt::linalg::dot seems to trigger UB (#906)
    if (has_zero_size(a_mat) || has_zero_size(b_mat)) return Matrix{};
    if (utils::ValidationContext validation{}) {
        validation.add_matrix(a_mat, "a_mat")
            .dim('X', 'N')
            .add_matrix(b_mat, "b_mat")
            .dim('Y', 'N')
            .validate();
    }
    const Size rows = num_rows(a_mat);
    const Size columns = num_rows(b_mat);
    const Size join_size = num_cols(a_mat);
    Matrix out = zeros(rows, columns);
    // Read only from a_mat/b_mat, never from a/b. The std::forward calls above may already
    // have moved from a/b, and A/B are not guaranteed to be Matrix at all.
    for (size_t i = 0; i < rows; i++) {
        for (size_t j = 0; j < columns; j++) {
            for (size_t idx = 0; idx < join_size; idx++) {
                out(i, j) += a_mat(i, idx) * b_mat(j, idx);
            }
        }
    }
    return out;
# else
    return _dot(a_mat, xt::transpose(b_mat));
# endif
}
/**
 * @brief Product of the transpose of `a` with `b`, i.e. \f$ a^T \cdot b \f$.
 *
 * NOTE(review): despite the name, both code paths compute the transpose of a times b, not
 * a times the transpose of b. The generic path calls `_dot(xt::transpose(a_mat), b_mat)`,
 * and the aarch64 loop accumulates `a_mat(k, i) * b_mat(k, j)`. Behavior is left
 * unchanged here; confirm the intended naming before relying on it.
 *
 * Both operands are first materialized as `Matrix` via `to_matrix`. On aarch64 builds:
 * - an empty `Matrix{}` is returned if either operand has zero size;
 * - otherwise the shapes are checked through `utils::ValidationContext` as `a`: NxX, `b`: NxY.
 *
 * @param a Operand whose transpose is the left-hand factor.
 * @param b Right operand.
 * @return The product matrix (XxY when the shapes above hold).
 */
template <typename A, typename B, IfBothTensorsOfDim<A, B, 2>* = nullptr>
Matrix a_dot_transpose_b(A&& a, B&& b) {
    Matrix a_mat = to_matrix(std::forward<A>(a));
    Matrix b_mat = to_matrix(std::forward<B>(b));
# ifdef __aarch64__
    // Manually multiply on ARM because xt::linalg::dot seems to trigger UB (#906)
    if (has_zero_size(a_mat) || has_zero_size(b_mat)) return Matrix{};
    if (utils::ValidationContext validation{}) {
        validation.add_matrix(a_mat, "a_mat")
            .dim('N', 'X')
            .add_matrix(b_mat, "b_mat")
            .dim('N', 'Y')
            .validate();
    }
    const Size rows = num_cols(a_mat);
    const Size columns = num_cols(b_mat);
    const Size join_size = num_rows(a_mat);
    Matrix out = zeros(rows, columns);
    // Read only from a_mat/b_mat, never from a/b. The std::forward calls above may already
    // have moved from a/b, and A/B are not guaranteed to be Matrix at all.
    for (size_t idx = 0; idx < join_size; idx++) {
        for (size_t i = 0; i < rows; i++) {
            // Invariant across the innermost loop; hoisted out of it.
            const auto a_elem = a_mat(idx, i);
            for (size_t j = 0; j < columns; j++) {
                out(i, j) += a_elem * b_mat(idx, j);
            }
        }
    }
    return out;
# else
    return _dot(xt::transpose(a_mat), b_mat);
# endif
}
#endif
// NOTE(review): the functions below are only declared here; their definitions are not
// visible in this header. The one-line descriptions are inferred from names/signatures
// and should be confirmed against the implementation.
// NOTE(review): std::vector is used below, but <vector> is not included directly by this
// header; it presumably arrives transitively through another include.

/// Presumably the matrix inverse of `m`.
Matrix inverse(const Matrix& m);
/// Presumably a norm of the matrix `m` (which matrix norm is not visible here — TODO confirm).
double norm(const Matrix& m);
/// Presumably the Euclidean norm of the vector `m` — TODO confirm.
double norm(const Vector& m);
/// Presumably the 3-D cross product of `m` and `n`.
Vector3 cross(const Vector3& m, const Vector3& n);
/// Presumably solves a tridiagonal linear system with sub-diagonal `low`, diagonal `mid`,
/// super-diagonal `up` and right-hand side `b`. All inputs are taken by const reference.
Vector solve_tridiagonal(const Vector& low, const Vector& mid, const Vector& up, const Vector& b);
/// Presumably the same solve as `solve_tridiagonal`, but every argument is taken by
/// non-const reference, so the inputs may be modified (the name suggests they are
/// overwritten in place).
Vector solve_tridiagonal_overwrite(Vector& low, Vector& mid, Vector& up, Vector& b);
/// Presumably solves Wahba's problem (best-fit rotation) via SVD from a precomputed 3x3
/// matrix `outer`.
Matrix3 solve_wahba_svd(const Matrix3& outer);
/// Presumably the SVD solution of Wahba's problem built from paired vectors `p` and `r`
/// (pairing/weighting convention not visible here — TODO confirm).
Matrix3 solve_wahba_svd(const std::vector<Vector3>& p, const std::vector<Vector3>& r);
/// Presumably solves Wahba's problem with Davenport's method from paired vectors `p`/`r`.
/// Returns a `std::vector<Matrix3>`; why more than one matrix may be returned is not
/// visible here — TODO confirm.
std::vector<Matrix3> solve_wahba_davenport(const std::vector<Vector3>& p,
const std::vector<Vector3>& r);
/// Presumably the Davenport solution from a precomputed `outer` matrix and vector `cr`
/// (meaning of `cr` not visible here — TODO confirm).
std::vector<Matrix3> solve_wahba_davenport(const Matrix3& outer, const Vector3& cr);
} // namespace navtk