Skip to content

Conversation

@ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Dec 1, 2025

Spinoff from #811

Major changes:

  1. Systematic fallback to obj mode for complex inputs
  2. Systematic casting of discrete inputs to floats (with a warning on compile_verbose)
  3. Systematic upcasting of inputs in operations with more than one input
  4. Change view(inp).ctypes -> inp.ctypes
  5. Explicitly handle empty inputs, as lapack tends to raise or emit warning

Re: Change view(inp).ctypes -> inp.ctypes

This may have be needed when working with complex inputs? But we are not supporting them in most implementations, so it makes code more complex and is a potential source of bugs when we fail to systematically upcast input (point 3. from above)

Say we have a float32 and a float64 inputs, and forget to upcast the first one. Calling view will raise for non f-contiguous inputs (which we always need for these routines):

import numpy as np

x = np.asfortranarray(np.eye(3, dtype="float64"))
x.view(dtype="float32")   # ValueError: To change to a dtype of a different size, the last axis must be contiguous

Even if it didn't raise the meaning of the array would be nonsensical:

np.eye(3, dtype="float64").view(dtype="float32")
# array([[0.   , 1.875, 0.   , 0.   , 0.   , 0.   ],
#       [0.   , 0.   , 0.   , 1.875, 0.   , 0.   ],
#       [0.   , 0.   , 0.   , 0.   , 0.   , 1.875]], dtype=float32)

@ricardoV94
Copy link
Member Author

We should redo _check_scipy_linalg_matrix to take a dtype as well, so if an input dtype does not match the operation dtype it raises during numba typing

@ricardoV94
Copy link
Member Author

ricardoV94 commented Dec 5, 2025

@jessegrabowski I didn't touch the QR stuff, could you give me a hand?

I don't recall if you were supporting complex inputs, and if in that case the view thing is needed. Also I was too lazy to think about the empty input case for it.

@jessegrabowski
Copy link
Member

I can attack it over the weekend yeah

@ricardoV94
Copy link
Member Author

I'm not testing the empty case for factor/solve tridiagonal, because I don't know how to define a valid empty case for those Ops.

@ricardoV94 ricardoV94 requested a review from Copilot December 5, 2025 17:06
@ricardoV94 ricardoV94 marked this pull request as ready for review December 5, 2025 17:06
@ricardoV94
Copy link
Member Author

ricardoV94 commented Dec 5, 2025

Orthogonal to your help, the PR is ready for review @jessegrabowski. I understand if you want to take your time here as it is kind of your baby

Copilot finished reviewing on behalf of ricardoV94 December 5, 2025 17:09
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR systematically handles mixed input dtypes and empty arrays in numba LAPACK functions. The changes improve robustness by casting discrete inputs to floats, upcasting mixed inputs, handling empty arrays explicitly, and simplifying LAPACK wrapper code by removing unnecessary .view() calls.

Key Changes

  • Added systematic dtype handling for discrete/complex inputs with fallback to obj mode for complex types
  • Added explicit empty array handling to prevent LAPACK warnings/errors
  • Simplified LAPACK wrapper code by removing .view(dtype).ctypes pattern in favor of direct .ctypes access
  • Updated Op make_node methods to properly infer output dtypes based on input dtypes
  • Reorganized test classes for better structure

Reviewed changes

Copilot reviewed 21 out of 21 changed files in this pull request and generated 10 comments.

Show a summary per file
File Description
tests/tensor/test_nlinalg.py Fixed duplicate function definition and improved lstsq test assertions
tests/tensor/test_blockwise.py Added test for eig blockwise operation with dtype verification
tests/link/numba/test_slinalg.py Reorganized tests into classes and added empty array tests for solve/decomposition ops
tests/link/numba/test_nlinalg.py Enhanced Eig test to handle multiple dtypes and verify correctness
pytensor/tensor/slinalg.py Updated LU Op to infer correct output dtypes for discrete inputs
pytensor/tensor/nlinalg.py Updated multiple Ops (MatrixPinv, MatrixInverse, Det, SLogDet, Lstsq) to infer output dtypes and remove unnecessary casts
pytensor/link/numba/dispatch/slinalg.py Added dtype handling, empty array checks, and casting logic for Cholesky, LU, LUFactor, Solve, SolveTriangular, CholeskySolve
pytensor/link/numba/dispatch/nlinalg.py Simplified dtype handling by removing int_to_float_fn and adding direct casting logic
pytensor/link/numba/dispatch/basic.py Removed now-unused int_to_float_fn helper function
pytensor/link/numba/dispatch/linalg/utils.py Refactored _check_scipy_linalg_matrix into more flexible _check_linalg_matrix with dtype matching
pytensor/link/numba/dispatch/linalg/solve/*.py Removed .view() calls and updated to use _check_linalg_matrix
pytensor/link/numba/dispatch/linalg/decomposition/*.py Removed .view() calls, updated checks, and enhanced cholesky to handle C-contiguous inputs

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants