/*
    pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

#include "numpy.h"

#if defined(__INTEL_COMPILER)
#  pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
#elif defined(__GNUG__) || defined(__clang__)
#  pragma GCC diagnostic push
#  pragma GCC diagnostic ignored "-Wconversion"
#  pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#  ifdef __clang__
//   Eigen generates a bunch of implicit-copy-constructor-is-deprecated warnings with -Wdeprecated
//   under Clang, so disable that warning here:
#    pragma GCC diagnostic ignored "-Wdeprecated"
#  endif
#  if __GNUC__ >= 7
#    pragma GCC diagnostic ignored "-Wint-in-bool-context"
#  endif
#endif

#if defined(_MSC_VER)
#  pragma warning(push)
#  pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#  pragma warning(disable: 4996) // warning C4996: std::unary_negate is deprecated in C++17
#endif

#include <Eigen/Core>
#include <Eigen/SparseCore>

// Eigen prior to 3.2.7 doesn't have proper move constructors--but worse, some classes get implicit
// move constructors that break things.  We could detect this and explicitly copy, but an extra copy
// of matrices seems highly undesirable.
static_assert(EIGEN_VERSION_AT_LEAST(3,2,7), "Eigen support in pybind11 requires Eigen >= 3.2.7");

NAMESPACE_BEGIN(PYBIND11_NAMESPACE)

// Provide a convenience alias for easier pass-by-ref usage with fully dynamic strides:
using EigenDStride = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
template <typename MatrixType> using EigenDRef = Eigen::Ref<MatrixType, 0, EigenDStride>;
template <typename MatrixType> using EigenDMap = Eigen::Map<MatrixType, 0, EigenDStride>;
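//
// Usage sketch (illustrative, not part of this header): EigenDRef accepts a dense numpy array by
// reference regardless of its strides, so slices and transposed views can be passed without
// copying.  The module and function names below are hypothetical.
//
//     #include <pybind11/pybind11.h>
//     #include <pybind11/eigen.h>
//     namespace py = pybind11;
//
//     // Doubles every element in place; works on non-contiguous float64 views because the
//     // fully dynamic strides of EigenDRef match whatever strides numpy hands us.
//     void scale_inplace(py::EigenDRef<Eigen::MatrixXd> m) { m *= 2.0; }
//
//     PYBIND11_MODULE(example, mod) {
//         mod.def("scale_inplace", &scale_inplace);
//     }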

NAMESPACE_BEGIN(detail)

#if EIGEN_VERSION_AT_LEAST(3,3,0)
using EigenIndex = Eigen::Index;
#else
using EigenIndex = EIGEN_DEFAULT_DENSE_INDEX_TYPE;
#endif

// Matches Eigen::Map, Eigen::Ref, blocks, etc:
template <typename T> using is_eigen_dense_map = all_of<is_template_base_of<Eigen::DenseBase, T>, std::is_base_of<Eigen::MapBase<T, Eigen::ReadOnlyAccessors>, T>>;
template <typename T> using is_eigen_mutable_map = std::is_base_of<Eigen::MapBase<T, Eigen::WriteAccessors>, T>;
template <typename T> using is_eigen_dense_plain = all_of<negation<is_eigen_dense_map<T>>, is_template_base_of<Eigen::PlainObjectBase, T>>;
template <typename T> using is_eigen_sparse = is_template_base_of<Eigen::SparseMatrixBase, T>;
// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above.  This
// basically covers anything that can be assigned to a dense matrix but that doesn't have a typical
// matrix data layout that can be copied from its .data().  For example, DiagonalMatrix and
// SelfAdjointView fall into this category.
template <typename T> using is_eigen_other = all_of<
    is_template_base_of<Eigen::EigenBase, T>,
    negation<any_of<is_eigen_dense_map<T>, is_eigen_dense_plain<T>, is_eigen_sparse<T>>>
>;

// Captures numpy/eigen conformability status (returned by EigenProps::conformable()):
template <bool EigenRowMajor> struct EigenConformable {
    bool conformable = false;
    EigenIndex rows = 0, cols = 0;
    EigenDStride stride{0, 0};      // Only valid if negativestrides is false!
    bool negativestrides = false;   // If true, do not use stride!

    EigenConformable(bool fits = false) : conformable{fits} {}
    // Matrix type:
    EigenConformable(EigenIndex r, EigenIndex c,
            EigenIndex rstride, EigenIndex cstride) :
        conformable{true}, rows{r}, cols{c} {
        // TODO: when Eigen bug #747 is fixed, remove the tests for non-negativity.
        // http://eigen.tuxfamily.org/bz/show_bug.cgi?id=747
        if (rstride < 0 || cstride < 0) {
            negativestrides = true;
        } else {
            stride = {EigenRowMajor ? rstride : cstride /* outer stride */,
                      EigenRowMajor ? cstride : rstride /* inner stride */ };
        }
    }
    // Vector type:
    EigenConformable(EigenIndex r, EigenIndex c, EigenIndex stride)
        : EigenConformable(r, c, r == 1 ? c*stride : stride, c == 1 ? r : r*stride) {}

    template <typename props> bool stride_compatible() const {
        // To have compatible strides, we need (on both dimensions) one of fully dynamic strides,
        // matching strides, or a dimension size of 1 (in which case the stride value is irrelevant)
        return
            !negativestrides &&
            (props::inner_stride == Eigen::Dynamic || props::inner_stride == stride.inner() ||
                (EigenRowMajor ? cols : rows) == 1) &&
            (props::outer_stride == Eigen::Dynamic || props::outer_stride == stride.outer() ||
                (EigenRowMajor ? rows : cols) == 1);
    }
    operator bool() const { return conformable; }
};
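//
// Worked example (illustrative): a C-contiguous float64 numpy array of shape (5, 3) has byte
// strides (24, 8), i.e. element strides (3, 1).  Checked against a row-major Eigen target this
// yields EigenConformable<true>{rows=5, cols=3, stride={outer=3, inner=1}}; against a column-major
// target the same strides become {outer=1, inner=3}, which then fails stride_compatible() unless
// the target's compile-time strides are dynamic (or the relevant dimension has size 1).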

template <typename Type> struct eigen_extract_stride { using type = Type; };
template <typename PlainObjectType, int MapOptions, typename StrideType>
struct eigen_extract_stride<Eigen::Map<PlainObjectType, MapOptions, StrideType>> { using type = StrideType; };
template <typename PlainObjectType, int Options, typename StrideType>
struct eigen_extract_stride<Eigen::Ref<PlainObjectType, Options, StrideType>> { using type = StrideType; };

// Helper struct for extracting information from an Eigen type
template <typename Type_> struct EigenProps {
    using Type = Type_;
    using Scalar = typename Type::Scalar;
    using StrideType = typename eigen_extract_stride<Type>::type;
    static constexpr EigenIndex
        rows = Type::RowsAtCompileTime,
        cols = Type::ColsAtCompileTime,
        size = Type::SizeAtCompileTime;
    static constexpr bool
        row_major = Type::IsRowMajor,
        vector = Type::IsVectorAtCompileTime, // At least one dimension has fixed size 1
        fixed_rows = rows != Eigen::Dynamic,
        fixed_cols = cols != Eigen::Dynamic,
        fixed = size != Eigen::Dynamic, // Fully-fixed size
        dynamic = !fixed_rows && !fixed_cols; // Fully-dynamic size

    template <EigenIndex i, EigenIndex ifzero> using if_zero = std::integral_constant<EigenIndex, i == 0 ? ifzero : i>;
    static constexpr EigenIndex inner_stride = if_zero<StrideType::InnerStrideAtCompileTime, 1>::value,
                                outer_stride = if_zero<StrideType::OuterStrideAtCompileTime,
                                                       vector ? size : row_major ? cols : rows>::value;
    static constexpr bool dynamic_stride = inner_stride == Eigen::Dynamic && outer_stride == Eigen::Dynamic;
    static constexpr bool requires_row_major = !dynamic_stride && !vector && (row_major ? inner_stride : outer_stride) == 1;
    static constexpr bool requires_col_major = !dynamic_stride && !vector && (row_major ? outer_stride : inner_stride) == 1;

    // Takes an input array and determines whether we can make it fit into the Eigen type.  If
    // the array is a vector, we attempt to fit it into either an Eigen 1xN or Nx1 vector
    // (preferring the latter if it will fit in either, i.e. for a fully dynamic matrix type).
    static EigenConformable<row_major> conformable(const array &a) {
        const auto dims = a.ndim();
        if (dims < 1 || dims > 2)
            return false;

        if (dims == 2) { // Matrix type: require exact match (or dynamic)
            EigenIndex
                np_rows = a.shape(0),
                np_cols = a.shape(1),
                np_rstride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar)),
                np_cstride = a.strides(1) / static_cast<ssize_t>(sizeof(Scalar));
            if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols))
                return false;

            return {np_rows, np_cols, np_rstride, np_cstride};
        }

        // Otherwise we're storing an n-vector.  Only one of the strides will be used, but
        // whichever is used, we want the (single) numpy stride value.
        const EigenIndex n = a.shape(0),
              stride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar));

        if (vector) { // Eigen type is a compile-time vector
            if (fixed && size != n)
                return false; // Vector size mismatch
            return {rows == 1 ? 1 : n, cols == 1 ? 1 : n, stride};
        }
        else if (fixed) {
            // The type has a fixed size, but is not a vector: abort
            return false;
        }
        else if (fixed_cols) {
            // Since this isn't a vector, cols must be != 1.  We allow this only if it exactly
            // equals the number of elements (rows is Dynamic, and so 1 row is allowed).
            if (cols != n) return false;
            return {1, n, stride};
        }
        else {
            // Otherwise it's either fully dynamic, or column dynamic; both become a column vector
            if (fixed_rows && rows != n) return false;
            return {n, 1, stride};
        }
    }

    static constexpr bool show_writeable = is_eigen_dense_map<Type>::value && is_eigen_mutable_map<Type>::value;
    static constexpr bool show_order = is_eigen_dense_map<Type>::value;
    static constexpr bool show_c_contiguous = show_order && requires_row_major;
    static constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major;

    static constexpr auto descriptor =
        _("numpy.ndarray[") + npy_format_descriptor<Scalar>::name +
        _("[") + _<fixed_rows>(_<(size_t) rows>(), _("m")) +
        _(", ") + _<fixed_cols>(_<(size_t) cols>(), _("n")) +
        _("]") +
        // For a reference type (e.g. Ref<MatrixXd>) we have other constraints that might need to
        // be satisfied: writeable=True (for a mutable reference), and, depending on the map's
        // stride options, possibly f_contiguous or c_contiguous.  We include them in the
        // descriptor output to provide some hint as to why a TypeError is occurring (otherwise it
        // can be confusing to see that a function accepts a 'numpy.ndarray[float64[3,2]]' and an
        // error message that you *gave* a numpy.ndarray of the right type and dimensions).
        _<show_writeable>(", flags.writeable", "") +
        _<show_c_contiguous>(", flags.c_contiguous", "") +
        _<show_f_contiguous>(", flags.f_contiguous", "") +
        _("]");
};
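//
// For instance (illustrative, not normative): EigenProps<Eigen::MatrixXd>::descriptor renders as
// "numpy.ndarray[float64[m, n]]", while a mutable Eigen::Ref<Eigen::Matrix<float, 3, 2>> (whose
// default stride type fixes the inner stride to 1) would add the extra constraints, e.g.
// "numpy.ndarray[float32[3, 2], flags.writeable, flags.f_contiguous]".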

// Casts an Eigen type to numpy array.  If given a base, the numpy array references the src data,
// otherwise it'll make a copy.  writeable lets you turn off the writeable flag for the array.
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
    constexpr ssize_t elem_size = sizeof(typename props::Scalar);
    array a;
    if (props::vector)
        a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
    else
        a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
                  src.data(), base);

    if (!writeable)
        array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;

    return a.release();
}

// Takes an lvalue ref to some Eigen type and a (python) base object, creating a numpy array that
// references the Eigen object's data, with `base` as the python-registered base object (if
// omitted, the base will be set to None, and lifetime management is up to the caller).  The numpy
// array is non-writeable if the given type is const.
template <typename props, typename Type>
handle eigen_ref_array(Type &src, handle parent = none()) {
    // none here is to get past array's should-we-copy detection, which currently always
    // copies when there is no base.  Setting the base to None should be harmless.
    return eigen_array_cast<props>(src, parent, !std::is_const<Type>::value);
}

// Takes a pointer to some dense, plain Eigen type, builds a capsule around it, then returns a
// numpy array that references the encapsulated data with a python-side reference to the capsule
// to tie its destruction to that of any dependent python objects.  Const-ness is determined by
// whether or not the Type of the pointer given is const.
template <typename props, typename Type, typename = enable_if_t<is_eigen_dense_plain<Type>::value>>
handle eigen_encapsulate(Type *src) {
    capsule base(src, [](void *o) { delete static_cast<Type *>(o); });
    return eigen_ref_array<props>(*src, base);
}
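//
// Lifetime sketch (illustrative): returning a heap-allocated matrix with take_ownership goes
// through eigen_encapsulate, so the numpy array's base is a capsule that deletes the matrix once
// the last python reference disappears.  The bound function below is hypothetical.
//
//     m.def("make_big", []() { return new Eigen::MatrixXd(Eigen::MatrixXd::Zero(1000, 1000)); },
//           py::return_value_policy::take_ownership);  // the python array owns the C++ allocation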

// Type caster for regular, dense matrix types (e.g. MatrixXd), but not maps/refs/etc. of dense
// types.
template<typename Type>
struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
    using Scalar = typename Type::Scalar;
    using props = EigenProps<Type>;

    bool load(handle src, bool convert) {
        // If we're in no-convert mode, only load if given an array of the correct type
        if (!convert && !isinstance<array_t<Scalar>>(src))
            return false;

        // Coerce into an array, but don't do type conversion yet; the copy below handles it.
        auto buf = array::ensure(src);

        if (!buf)
            return false;

        auto dims = buf.ndim();
        if (dims < 1 || dims > 2)
            return false;

        auto fits = props::conformable(buf);
        if (!fits)
            return false;

        // Allocate the new type, then build a numpy reference into it
        value = Type(fits.rows, fits.cols);
        auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
        if (dims == 1) ref = ref.squeeze();
        else if (ref.ndim() == 1) buf = buf.squeeze();

        int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());

        if (result < 0) { // Copy failed!
            PyErr_Clear();
            return false;
        }

        return true;
    }

private:

    // Cast implementation
    template <typename CType>
    static handle cast_impl(CType *src, return_value_policy policy, handle parent) {
        switch (policy) {
            case return_value_policy::take_ownership:
            case return_value_policy::automatic:
                return eigen_encapsulate<props>(src);
            case return_value_policy::move:
                return eigen_encapsulate<props>(new CType(std::move(*src)));
            case return_value_policy::copy:
                return eigen_array_cast<props>(*src);
            case return_value_policy::reference:
            case return_value_policy::automatic_reference:
                return eigen_ref_array<props>(*src);
            case return_value_policy::reference_internal:
                return eigen_ref_array<props>(*src, parent);
            default:
                throw cast_error("unhandled return_value_policy: should not happen!");
        };
    }

public:

    // Normal returned non-reference, non-const value:
    static handle cast(Type &&src, return_value_policy /* policy */, handle parent) {
        return cast_impl(&src, return_value_policy::move, parent);
    }
    // If you return a non-reference const, we mark the numpy array readonly:
    static handle cast(const Type &&src, return_value_policy /* policy */, handle parent) {
        return cast_impl(&src, return_value_policy::move, parent);
    }
    // lvalue reference return; default (automatic) becomes copy
    static handle cast(Type &src, return_value_policy policy, handle parent) {
        if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
            policy = return_value_policy::copy;
        return cast_impl(&src, policy, parent);
    }
    // const lvalue reference return; default (automatic) becomes copy
    static handle cast(const Type &src, return_value_policy policy, handle parent) {
        if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
            policy = return_value_policy::copy;
        return cast(&src, policy, parent);
    }
    // non-const pointer return
    static handle cast(Type *src, return_value_policy policy, handle parent) {
        return cast_impl(src, policy, parent);
    }
    // const pointer return
    static handle cast(const Type *src, return_value_policy policy, handle parent) {
        return cast_impl(src, policy, parent);
    }

    static constexpr auto name = props::descriptor;

    operator Type*() { return &value; }
    operator Type&() { return value; }
    operator Type&&() && { return std::move(value); }
    template <typename T> using cast_op_type = movable_cast_op_type<T>;

private:
    Type value;
};
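//
// Usage sketch (illustrative): with this caster in place, plain dense types convert by value in
// both directions.  The function name below is hypothetical.
//
//     // Accepts any 1-D/2-D numpy array convertible to float64 and copies it into a fresh
//     // Eigen::MatrixXd; the by-value return is copied back out into a new numpy array
//     // (return_value_policy::automatic becomes ::move for a returned rvalue).
//     m.def("gram", [](const Eigen::MatrixXd &a) { return Eigen::MatrixXd(a.transpose() * a); });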

// Base class for casting reference/map/block/etc. objects back to python.
template <typename MapType> struct eigen_map_caster {
private:
    using props = EigenProps<MapType>;

public:

    // Directly referencing a ref/map's data is a bit dangerous (whatever the map/ref points to
    // has to stay around), but we'll allow it under the assumption that you know what you're
    // doing (and have an appropriate keep_alive in place).  We return a numpy array pointing
    // directly at the ref's data (the numpy array ends up read-only if the ref was to a const
    // matrix type).  Note that this means you need to ensure you don't destroy the object in some
    // other way (e.g. with an appropriate keep_alive, or with a reference to a statically
    // allocated matrix).
    static handle cast(const MapType &src, return_value_policy policy, handle parent) {
        switch (policy) {
            case return_value_policy::copy:
                return eigen_array_cast<props>(src);
            case return_value_policy::reference_internal:
                return eigen_array_cast<props>(src, parent, is_eigen_mutable_map<MapType>::value);
            case return_value_policy::reference:
            case return_value_policy::automatic:
            case return_value_policy::automatic_reference:
                return eigen_array_cast<props>(src, none(), is_eigen_mutable_map<MapType>::value);
            default:
                // move, take_ownership don't make any sense for a ref/map:
                pybind11_fail("Invalid return_value_policy for Eigen Map/Ref/Block type");
        }
    }

    static constexpr auto name = props::descriptor;

    // Explicitly delete these: we do not support python -> C++ conversion for these types (i.e.
    // they can be return types but not bound arguments).  We still declare them (with an explicit
    // delete) so that you end up here if you try anyway.
    bool load(handle, bool) = delete;
    operator MapType() = delete;
    template <typename> using cast_op_type = MapType;
};

// We can return any map-like object (but can only load Refs, specialized next):
template <typename Type> struct type_caster<Type, enable_if_t<is_eigen_dense_map<Type>::value>>
    : eigen_map_caster<Type> {};
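//
// Usage sketch (illustrative): a class can expose part of an internal Eigen matrix as a zero-copy
// numpy view by returning a block and tying its lifetime to the owning object.  The class and
// method names below are hypothetical.
//
//     struct Buffer { Eigen::MatrixXd data = Eigen::MatrixXd::Zero(4, 4); };
//
//     py::class_<Buffer>(m, "Buffer")
//         .def(py::init<>())
//         // The returned numpy array aliases this->data; reference_internal keeps the Buffer
//         // alive for as long as the returned array exists.
//         .def("top_left", [](Buffer &b) { return b.data.block(0, 0, 2, 2); },
//              py::return_value_policy::reference_internal);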

// Loader for Ref<...> arguments.  See the documentation for info on how to make this work without
// copying (it requires some extra effort in many cases).
template <typename PlainObjectType, typename StrideType>
struct type_caster<
    Eigen::Ref<PlainObjectType, 0, StrideType>,
    enable_if_t<is_eigen_dense_map<Eigen::Ref<PlainObjectType, 0, StrideType>>::value>
> : public eigen_map_caster<Eigen::Ref<PlainObjectType, 0, StrideType>> {
private:
    using Type = Eigen::Ref<PlainObjectType, 0, StrideType>;
    using props = EigenProps<Type>;
    using Scalar = typename props::Scalar;
    using MapType = Eigen::Map<PlainObjectType, 0, StrideType>;
    using Array = array_t<Scalar, array::forcecast |
                          ((props::row_major ? props::inner_stride : props::outer_stride) == 1 ? array::c_style :
                           (props::row_major ? props::outer_stride : props::inner_stride) == 1 ? array::f_style : 0)>;
    static constexpr bool need_writeable = is_eigen_mutable_map<Type>::value;
    // Delay construction (these have no default constructor)
    std::unique_ptr<MapType> map;
    std::unique_ptr<Type> ref;
    // Our array.  When possible, this is just a numpy array pointing to the source data, but
    // sometimes we can't avoid copying (e.g. input is not a numpy array at all, has an
    // incompatible layout, or is an array of a type that needs to be converted).  Using a numpy
    // temporary (rather than an Eigen temporary) saves an extra copy when we need both type
    // conversion and storage order conversion.  (Note that we refuse to use this temporary copy
    // when loading an argument for a Ref<M> with M non-const, i.e. a read-write reference).
    Array copy_or_ref;

public:

    bool load(handle src, bool convert) {
        // First check whether what we have is already an array of the right type.  If not, we
        // can't avoid a copy (because the copy is also going to do type conversion).
        bool need_copy = !isinstance<Array>(src);

        EigenConformable<props::row_major> fits;
        if (!need_copy) {
            // We don't need a converting copy, but we also need to check whether the strides are
            // compatible with the Ref's stride requirements
            Array aref = reinterpret_borrow<Array>(src);

            if (aref && (!need_writeable || aref.writeable())) {
                fits = props::conformable(aref);
                if (!fits) return false; // Incompatible dimensions
                if (!fits.template stride_compatible<props>())
                    need_copy = true;
                else
                    copy_or_ref = std::move(aref);
            }
            else {
                need_copy = true;
            }
        }

        if (need_copy) {
            // We need to copy: if we need a mutable reference, or we're not supposed to convert
            // (either because we're in the no-convert overload pass, or because we're explicitly
            // instructed not to copy via `py::arg().noconvert()`), we have to fail loading.
            if (!convert || need_writeable) return false;

            Array copy = Array::ensure(src);
            if (!copy) return false;
            fits = props::conformable(copy);
            if (!fits || !fits.template stride_compatible<props>())
                return false;
            copy_or_ref = std::move(copy);
            loader_life_support::add_patient(copy_or_ref);
        }

        ref.reset();
        map.reset(new MapType(data(copy_or_ref), fits.rows, fits.cols, make_stride(fits.stride.outer(), fits.stride.inner())));
        ref.reset(new Type(*map));

        return true;
    }

    operator Type*() { return ref.get(); }
    operator Type&() { return *ref; }
    template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;

private:
    template <typename T = Type, enable_if_t<is_eigen_mutable_map<T>::value, int> = 0>
    Scalar *data(Array &a) { return a.mutable_data(); }

    template <typename T = Type, enable_if_t<!is_eigen_mutable_map<T>::value, int> = 0>
    const Scalar *data(Array &a) { return a.data(); }

    // Attempt to figure out a constructor of `Stride` that will work.
    // If both strides are fixed, use a default constructor:
    template <typename S> using stride_ctor_default = bool_constant<
        S::InnerStrideAtCompileTime != Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
        std::is_default_constructible<S>::value>;
    // Otherwise, if there is a two-index constructor, assume it is (outer,inner) like
    // Eigen::Stride, and use it:
    template <typename S> using stride_ctor_dual = bool_constant<
        !stride_ctor_default<S>::value && std::is_constructible<S, EigenIndex, EigenIndex>::value>;
    // Otherwise, if there is a one-index constructor, and just one of the strides is dynamic, use
    // it (passing whichever stride is dynamic).
    template <typename S> using stride_ctor_outer = bool_constant<
        !any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
        S::OuterStrideAtCompileTime == Eigen::Dynamic && S::InnerStrideAtCompileTime != Eigen::Dynamic &&
        std::is_constructible<S, EigenIndex>::value>;
    template <typename S> using stride_ctor_inner = bool_constant<
        !any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
        S::InnerStrideAtCompileTime == Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
        std::is_constructible<S, EigenIndex>::value>;

    template <typename S = StrideType, enable_if_t<stride_ctor_default<S>::value, int> = 0>
    static S make_stride(EigenIndex, EigenIndex) { return S(); }
    template <typename S = StrideType, enable_if_t<stride_ctor_dual<S>::value, int> = 0>
    static S make_stride(EigenIndex outer, EigenIndex inner) { return S(outer, inner); }
    template <typename S = StrideType, enable_if_t<stride_ctor_outer<S>::value, int> = 0>
    static S make_stride(EigenIndex outer, EigenIndex) { return S(outer); }
    template <typename S = StrideType, enable_if_t<stride_ctor_inner<S>::value, int> = 0>
    static S make_stride(EigenIndex, EigenIndex inner) { return S(inner); }
};
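//
// Usage sketch (illustrative): a Ref<const T> argument accepts compatible numpy arrays without
// copying and falls back to a converting copy otherwise, while a mutable Ref<T> argument refuses
// to copy and therefore needs a writeable array of exactly the right dtype and layout (combine
// with py::arg().noconvert() to make the failure explicit).  Function names are hypothetical.
//
//     // Zero-copy for column-major float64 input; silently copies e.g. an int32 or C-ordered array.
//     m.def("col_sums", [](Eigen::Ref<const Eigen::MatrixXd> m) { return m.colwise().sum().eval(); });
//
//     // Modifies the caller's array in place; a non-conforming array raises TypeError.
//     m.def("scale", [](Eigen::Ref<Eigen::VectorXd> v, double s) { v *= s; },
//           py::arg().noconvert(), py::arg());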

// type_caster for special matrix types (e.g. DiagonalMatrix), which are EigenBase, but not
// EigenDense (i.e. they don't have a data(), at least not with the usual matrix layout).
// load() is not supported, but we can cast them into the python domain by first copying to a
// regular Eigen::Matrix, then casting that.
template <typename Type>
struct type_caster<Type, enable_if_t<is_eigen_other<Type>::value>> {
protected:
    using Matrix = Eigen::Matrix<typename Type::Scalar, Type::RowsAtCompileTime, Type::ColsAtCompileTime>;
    using props = EigenProps<Matrix>;
public:
    static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
        handle h = eigen_encapsulate<props>(new Matrix(src));
        return h;
    }
    static handle cast(const Type *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); }

    static constexpr auto name = props::descriptor;

    // Explicitly delete these: we do not support python -> C++ conversion for these types (i.e.
    // they can be return types but not bound arguments).  We still declare them (with an explicit
    // delete) so that you end up here if you try anyway.
    bool load(handle, bool) = delete;
    operator Type() = delete;
    template <typename> using cast_op_type = Type;
};
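//
// Usage sketch (illustrative): "other" Eigen types are densified into a plain matrix before being
// handed to numpy, so python always receives an owning array rather than a view.  The function
// name below is hypothetical.
//
//     m.def("unit_weights", [](int n) {
//         // DiagonalMatrix has no .data() in the usual layout; the caster above copies it into a
//         // dense Eigen::MatrixXd before wrapping it in a numpy array.
//         Eigen::DiagonalMatrix<double, Eigen::Dynamic> d(n);
//         d.setIdentity();
//         return d;
//     });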

template<typename Type>
struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
    typedef typename Type::Scalar Scalar;
    typedef remove_reference_t<decltype(*std::declval<Type>().outerIndexPtr())> StorageIndex;
    typedef typename Type::Index Index;
    static constexpr bool rowMajor = Type::IsRowMajor;

    bool load(handle src, bool) {
        if (!src)
            return false;

        auto obj = reinterpret_borrow<object>(src);
        object sparse_module = module::import("scipy.sparse");
        object matrix_type = sparse_module.attr(
            rowMajor ? "csr_matrix" : "csc_matrix");

        if (!obj.get_type().is(matrix_type)) {
            try {
                obj = matrix_type(obj);
            } catch (const error_already_set &) {
                return false;
            }
        }

        auto values = array_t<Scalar>((object) obj.attr("data"));
        auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
        auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
        auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
        auto nnz = obj.attr("nnz").cast<Index>();

        if (!values || !innerIndices || !outerIndices)
            return false;

        value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
            shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
            outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());

        return true;
    }

    static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
        const_cast<Type&>(src).makeCompressed();

        object matrix_type = module::import("scipy.sparse").attr(
            rowMajor ? "csr_matrix" : "csc_matrix");

        array data(src.nonZeros(), src.valuePtr());
        array outerIndices((rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
        array innerIndices(src.nonZeros(), src.innerIndexPtr());

        return matrix_type(
            std::make_tuple(data, innerIndices, outerIndices),
            std::make_pair(src.rows(), src.cols())
        ).release();
    }

    PYBIND11_TYPE_CASTER(Type, _<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[")
            + npy_format_descriptor<Scalar>::name + _("]"));
};
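//
// Usage sketch (illustrative): sparse matrices round-trip through scipy.sparse; a column-major
// Eigen::SparseMatrix maps to csc_matrix and a row-major one to csr_matrix, and other scipy
// sparse formats are converted on load by calling the matching constructor.  The function name
// below is hypothetical.
//
//     m.def("sparse_eye", [](int n) {
//         Eigen::SparseMatrix<double> s(n, n);   // column-major -> scipy.sparse.csc_matrix
//         s.setIdentity();
//         return s;
//     });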

NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

#if defined(__GNUG__) || defined(__clang__)
#  pragma GCC diagnostic pop
#elif defined(_MSC_VER)
#  pragma warning(pop)
#endif