Skip to content

Conversation

@tomicapretto
Copy link
Contributor

This PR implements the dot product for sparse matrices in the numba backend.

@tomicapretto
Copy link
Contributor Author

tomicapretto commented Jan 20, 2026

I'm looking for advice on the following issue:

Tests are failing, at least partially, due to the fact that pytensor.sparse.dot and SciPy's sparse dot return a different type when both inputs are sparse. While SciPy's return a sparse matrix, PyTensor returns a dense array. When one of the outputs is dense, both output types agree, being dense.

PyTensor does have an operator that agrees with SciPy in the (sparse, sparse) input case, and that is structured_dot. However, its docs mention something different

https://github.com/tomicapretto/pytensor/blob/8c28ffaeefea2ee6cb18e52c10ad67fe6cb048b2/pytensor/sparse/math.py#L1411-L1434

Details
def structured_dot(x, y):
    """
    Structured Dot is like dot, except that only the gradient wrt non-zero elements of the sparse matrix
    `a` are calculated and propagated.

    The output is presumed to be a dense matrix, and is represented by a TensorType instance.

    Parameters
    ----------
    a
        A sparse matrix.
    b
        A sparse or dense matrix.

    Returns
    -------
    A sparse matrix
        The dot product of `a` and `b`.

    Notes
    -----
    The grad implemented is structured.

    """

In there, it says the output is "presumed to be a dense matrix", and sparsity is exploited for the gradient.


What do you think is the best to do? Should we...

  • Keep dot return type as it is, updating how tests are evaluated, potentially updating documentation.
  • Update dot to return sparse when both inputs are sparse, potentially breaking code.
  • ... something else?

Just in case, here I share a snippet to check output types:

import scipy
import numpy as np

import pytensor
import pytensor.sparse as ps

sp_format = "csr"
x_shape = (10, 4)
y_shape = (4, 3)

x_format = y_format = sp_format

x = ps.matrix(x_format, name="x", shape=x_shape)
y = ps.matrix(y_format, name="y", shape=y_shape)

z_dot = ps.dot(x, y)
z_structured_dot = ps.structured_dot(x, y)

f_dot = pytensor.function([x, y], z_dot)
f_structured_dot = pytensor.function([x, y], z_structured_dot)

rng = np.random.default_rng(sum(map(ord, x_format)) + sum(map(ord, y_format)))
x_test = scipy.sparse.random(*x_shape, density=0.5, format=x_format, random_state=rng)
y_test = scipy.sparse.random(*y_shape, density=0.5, format=y_format, random_state=rng)

print("scipy dot", type(x_test @ y_test))
print("pytensor dot", type(f_dot(x_test, y_test)))
print("pytensor structured dot", type(f_structured_dot(x_test, y_test)))

Comment on lines 1907 to 1911
# Multiplication of objects of `*_matrix` type means dot product
rval = x * y

if x_is_sparse and y_is_sparse:
rval = rval.toarray()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 1910-1911 explicitly convert the sparse matrix to a dense array when using pytensor.sparse.dot

@tomicapretto
Copy link
Contributor Author

tomicapretto commented Jan 20, 2026

@ricardoV94, now I have quite aggressively casted quantities to uint32, but it's worth discussing whether we want to have uint32, uint64, or both.

From the discussion here scipy/scipy#16774, I realized uint64 is needed for matrices that have >= 2^31 non-zero elements, which is 2,147,483,648 cells.

@tomicapretto tomicapretto marked this pull request as ready for review January 20, 2026 18:48
@tomicapretto
Copy link
Contributor Author

Removing unnecessary casts to uint32 since...

import numba as nb

@nb.njit
def f1(n):
    acc = 0
    for i in range(n):
        acc += i
    return acc

@nb.njit
def f2(n):
    acc = 0
    for i in range(n):
        acc += i
    return acc

@nb.njit
def f3(n):
    acc = 0
    for i in range(n):
        acc += i
    return acc


f1(nb.uint32(100_000))
f2(nb.uint64(100_000))
f3(100_000)

%timeit f1(nb.uint32(100_000))
# 1.12 μs ± 49.2 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
%timeit f2(nb.uint64(100_000))
# 1.11 μs ± 24.4 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
%timeit f3(100_000)
# 102 ns ± 1.19 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

@tomicapretto
Copy link
Contributor Author

@ricardoV94 could you trigger CI? I cancelled it as it failed for some tooling issue.

@ricardoV94
Copy link
Member

@ricardoV94, now I have quite aggressively casted quantities to uint32, but it's worth discussing whether we want to have uint32, uint64, or both.

From the discussion here scipy/scipy#16774, I realized uint64 is needed for matrices that have >= 2^31 non-zero elements, which is 2,147,483,648 cells.

Scipy and us use int32 for the indices and indptr so we can never represent matrices larger than that. But that means uint32 is always safe

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is incredible! Just a few minor comments on my part


@numba_basic.numba_njit
def _spmdv(x_ptr, x_ind, x_data, y):
n_row = np.max(x_ind) + 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we assume they are ordered and last one is the max? Just surprised we need a max but the two formats do behave differently all the time...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also known information, so we could actually pass it as another parameter

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes more sense yeah

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to pass x_shape to both _spmdv functions so we have a common signature and avoid more unnecessary if-elses.

def transpose(matrix):
n_row, n_col = matrix.shape
return builder(
matrix.data.copy(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this PR but eventually we have to enable non copy versions of these methods when we know these are for temporary variables

@jessegrabowski
Copy link
Member

PyTensor does have an operator that agrees with SciPy in the (sparse, sparse) input case, and that is structured_dot. However, its docs mention something different

I think we should agree with scipy. My impression is that we cast back to dense too much. structured_dot might not even be the right abstraction in the end.

@ricardoV94
Copy link
Member

PyTensor does have an operator that agrees with SciPy in the (sparse, sparse) input case, and that is structured_dot. However, its docs mention something different

I think we should agree with scipy. My impression is that we cast back to dense too much. structured_dot might not even be the right abstraction in the end.

Yeah but I wouldn't do it in this PR

@tomicapretto
Copy link
Contributor Author

I addressed all the comments but now I have observed these errors in the tests

FAILED tests/link/numba/sparse/test_basic.py::test_sparse_creation_refcount - assert 2 == 3
FAILED tests/link/numba/sparse/test_basic.py::test_sparse_passthrough_refcount - assert 2 == 3

In both cases the refcount is smaller than what is has to be.

@ricardoV94
Copy link
Member

I addressed all the comments but now I have observed these errors in the tests

FAILED tests/link/numba/sparse/test_basic.py::test_sparse_creation_refcount - assert 2 == 3
FAILED tests/link/numba/sparse/test_basic.py::test_sparse_passthrough_refcount - assert 2 == 3

In both cases the refcount is smaller than what is has to be.

That means we may be stealing references somewhere now

@tomicapretto
Copy link
Contributor Author

tomicapretto commented Jan 21, 2026

@ricardoV94 @jessegrabowski, with the exception of the comment regarding the cache key, all the rest should be good.
Let me know if you want me to move commits around (or do more work on this). I remember in the past I was asked to group some commits so history was more readable. I initially attempted to do so, but then... things happen :)

@ricardoV94
Copy link
Member

@ricardoV94 @jessegrabowski, with the exception of the comment regarding the cache key, all the rest should be good. Let me know if you want me to move commits around (or do more work on this). I remember in the past I was asked to group some commits so history was more readable. I initially attempted to do so, but then... things happen :)

It's fine we can squash them

@ricardoV94
Copy link
Member

I suspect the test failure is exactly due to wrong cache thing I mentioned (if it passed locally for you)

@tomicapretto
Copy link
Contributor Author

tomicapretto commented Jan 21, 2026

I suspect the test failure is exactly due to wrong cache thing I mentioned (if it passed locally for you)

Locally, this is what I get.

From my branch:

Details
(pytensor-dev) tomas@tomas:~/oss/pymc-devs/pytensor$ pytensor-cache clear && pytest tests/link/numba/sparse/
=================================================================================================================== test session starts ====================================================================================================================
platform linux -- Python 3.14.2, pytest-9.0.2, pluggy-1.6.0
benchmark: 5.2.3 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /home/tomas/oss/pymc-devs/pytensor
configfile: pyproject.toml
plugins: benchmark-5.2.3, sphinx-0.6.3, cov-7.0.0, mock-3.15.1, xdist-3.8.0
collected 62 items                                                                                                                                                                                                                                         

tests/link/numba/sparse/test_basic.py .FF...................                                                                                                                                                                                         [ 35%]
tests/link/numba/sparse/test_math.py ........................................                                                                                                                                                                        [100%]

========================================================================================================================= FAILURES =========================================================================================================================
______________________________________________________________________________________________________________ test_sparse_creation_refcount _______________________________________________________________________________________________________________

    def test_sparse_creation_refcount():
        @numba.njit
        def create_csr_matrix(data, indices, ind_ptr):
            return scipy.sparse.csr_matrix((data, indices, ind_ptr), shape=(5, 5))
    
        x = scipy.sparse.random(5, 5, density=0.5, format="csr")
    
        x_data = x.data
        x_indptr = x.indptr
>       assert getrefcount(x_data) == 3
E       assert 2 == 3
E        +  where 2 = getrefcount(array([0.95993326, 0.21163082, 0.23359326, 0.59663578, 0.69842303,\n       0.70289209, 0.63291838, 0.33894682, 0.85465412, 0.9203718 ,\n       0.35885166, 0.68967354]))

tests/link/numba/sparse/test_basic.py:84: AssertionError
_____________________________________________________________________________________________________________ test_sparse_passthrough_refcount _____________________________________________________________________________________________________________

    def test_sparse_passthrough_refcount():
        @numba.njit
        def identity(a):
            return a
    
        x = scipy.sparse.random(5, 5, density=0.5, format="csr")
    
        x_data = x.data
>       assert getrefcount(x_data) == 3
E       assert 2 == 3
E        +  where 2 = getrefcount(array([0.41290891, 0.95558254, 0.80964871, 0.73671294, 0.71546092,\n       0.89633048, 0.23127851, 0.47071729, 0.65602838, 0.18017573,\n       0.37196986, 0.32135243]))

tests/link/numba/sparse/test_basic.py:108: AssertionError
=================================================================================================================== slowest 50 durations ===================================================================================================================
2.38s call     tests/link/numba/sparse/test_math.py::test_dot_sparse_dense[x_shape0-y_shape0-csr-dot]
1.95s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape0-y_shape0-csr-csc-structured_dot]
1.88s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape0-y_shape0-csr-csr-dot]
1.47s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape0-y_shape0-csr-csc-dot]
1.31s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape0-y_shape0-csr-csr-structured_dot]
1.19s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape1-y_shape1-csc-csc-structured_dot]
1.17s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape1-y_shape1-csc-csr-structured_dot]
1.16s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape0-y_shape0-csc-csr-dot]
1.13s call     tests/link/numba/sparse/test_math.py::test_dot_dense_sparse[x_shape1-y_shape1-csc-dot]
1.13s call     tests/link/numba/sparse/test_math.py::test_dot_dense_sparse[x_shape1-y_shape1-csr-dot]
1.12s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape1-y_shape1-csr-csr-dot]
1.05s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape1-y_shape1-csr-csc-structured_dot]
1.05s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape1-y_shape1-csr-csc-dot]
1.00s call     tests/link/numba/sparse/test_math.py::test_dot_dense_sparse[x_shape0-y_shape0-csr-dot]
0.99s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape1-y_shape1-csc-csc-dot]
0.99s call     tests/link/numba/sparse/test_math.py::test_dot_dense_sparse[x_shape0-y_shape0-csc-dot]
0.98s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape0-y_shape0-csc-csc-structured_dot]
0.96s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape1-y_shape1-csc-csr-dot]
0.96s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape0-y_shape0-csc-csr-structured_dot]
0.94s call     tests/link/numba/sparse/test_math.py::test_dot_dense_sparse[x_shape0-y_shape0-csc-structured_dot]
0.93s call     tests/link/numba/sparse/test_math.py::test_dot_dense_sparse[x_shape0-y_shape0-csr-structured_dot]
0.87s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape0-y_shape0-csc-csc-dot]
0.84s call     tests/link/numba/sparse/test_math.py::test_sparse_dot_sparse_sparse[x_shape1-y_shape1-csr-csr-structured_dot]
0.67s call     tests/link/numba/sparse/test_math.py::test_dot_sparse_dense[x_shape0-y_shape0-csr-structured_dot]
0.63s call     tests/link/numba/sparse/test_basic.py::test_simple_graph[csr]
0.61s call     tests/link/numba/sparse/test_basic.py::test_sparse_dense_from_sparse[csr]
0.59s call     tests/link/numba/sparse/test_math.py::test_dot_sparse_dense[x_shape0-y_shape0-csc-structured_dot]
0.58s call     tests/link/numba/sparse/test_math.py::test_dot_sparse_dense[x_shape1-y_shape1-csc-dot]
0.56s call     tests/link/numba/sparse/test_math.py::test_dot_sparse_dense[x_shape0-y_shape0-csc-dot]
0.39s call     tests/link/numba/sparse/test_math.py::test_dot_sparse_dense[x_shape1-y_shape1-csr-dot]
0.38s call     tests/link/numba/sparse/test_basic.py::test_simple_graph[csc]
0.38s call     tests/link/numba/sparse/test_basic.py::test_sparse_constant[csr-True]
0.38s call     tests/link/numba/sparse/test_math.py::test_dot_sparse_dense[x_shape1-y_shape1-csc-structured_dot]
0.36s call     tests/link/numba/sparse/test_basic.py::test_sparse_constant[csc-True]
0.34s call     tests/link/numba/sparse/test_basic.py::test_sparse_dense_from_sparse[csc]
0.34s call     tests/link/numba/sparse/test_math.py::test_dot_sparse_dense[x_shape1-y_shape1-csr-structured_dot]
0.33s call     tests/link/numba/sparse/test_math.py::test_sparse_spmv[csc]
0.30s call     tests/link/numba/sparse/test_math.py::test_sparse_spmv[csr]
0.29s call     tests/link/numba/sparse/test_math.py::test_sparse_dense_multiply[2-csc]
0.29s call     tests/link/numba/sparse/test_math.py::test_sparse_dense_multiply[2-csr]
0.28s call     tests/link/numba/sparse/test_basic.py::test_sparse_deepcopy[csr]
0.28s call     tests/link/numba/sparse/test_basic.py::test_sparse_deepcopy[csc]
0.27s call     tests/link/numba/sparse/test_basic.py::test_sparse_boxing
0.23s call     tests/link/numba/sparse/test_math.py::test_sparse_dense_multiply[0-csr]
0.22s call     tests/link/numba/sparse/test_math.py::test_sparse_dense_multiply[0-csc]
0.21s call     tests/link/numba/sparse/test_basic.py::test_sparse_constant[csr-False]
0.19s call     tests/link/numba/sparse/test_basic.py::test_sparse_constant[csc-False]
0.16s call     tests/link/numba/sparse/test_basic.py::test_sparse_copy
0.16s call     tests/link/numba/sparse/test_basic.py::test_sparse_objmode[False-csr]
0.16s call     tests/link/numba/sparse/test_basic.py::test_sparse_objmode[False-csc]
================================================================================================================= short test summary info ==================================================================================================================
FAILED tests/link/numba/sparse/test_basic.py::test_sparse_creation_refcount - assert 2 == 3
FAILED tests/link/numba/sparse/test_basic.py::test_sparse_passthrough_refcount - assert 2 == 3
============================================================================================================== 2 failed, 60 passed in 40.13s ===============================================================================================================

From main:

Details
(pytensor-dev) tomas@tomas:~/oss/pymc-devs/pytensor$ git checkout main
Switched to branch 'main'
Your branch is up to date with 'origin/main'.
(pytensor-dev) tomas@tomas:~/oss/pymc-devs/pytensor$ pytensor-cache clear && pytest tests/link/numba/sparse/
=================================================================================================================== test session starts ====================================================================================================================
platform linux -- Python 3.14.2, pytest-9.0.2, pluggy-1.6.0
benchmark: 5.2.3 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /home/tomas/oss/pymc-devs/pytensor
configfile: pyproject.toml
plugins: benchmark-5.2.3, sphinx-0.6.3, cov-7.0.0, mock-3.15.1, xdist-3.8.0
collected 26 items                                                                                                                                                                                                                                         

tests/link/numba/sparse/test_basic.py .FF.................                                                                                                                                                                                           [ 76%]
tests/link/numba/sparse/test_math.py ......                                                                                                                                                                                                          [100%]

========================================================================================================================= FAILURES =========================================================================================================================
______________________________________________________________________________________________________________ test_sparse_creation_refcount _______________________________________________________________________________________________________________

    def test_sparse_creation_refcount():
        @numba.njit
        def create_csr_matrix(data, indices, ind_ptr):
            return scipy.sparse.csr_matrix((data, indices, ind_ptr), shape=(5, 5))
    
        x = scipy.sparse.random(5, 5, density=0.5, format="csr")
    
        x_data = x.data
        x_indptr = x.indptr
>       assert getrefcount(x_data) == 3
E       assert 2 == 3
E        +  where 2 = getrefcount(array([0.30044972, 0.40076953, 0.6897057 , 0.45557084, 0.17461931,\n       0.49249511, 0.32707842, 0.861761  , 0.19650113, 0.03299047,\n       0.97835237, 0.58165045]))

tests/link/numba/sparse/test_basic.py:78: AssertionError
_____________________________________________________________________________________________________________ test_sparse_passthrough_refcount _____________________________________________________________________________________________________________

    def test_sparse_passthrough_refcount():
        @numba.njit
        def identity(a):
            return a
    
        x = scipy.sparse.random(5, 5, density=0.5, format="csr")
    
        x_data = x.data
>       assert getrefcount(x_data) == 3
E       assert 2 == 3
E        +  where 2 = getrefcount(array([0.07933104, 0.23696602, 0.64010685, 0.27453624, 0.50998838,\n       0.76014696, 0.26097033, 0.50361755, 0.79347254, 0.23723223,\n       0.51087027, 0.76441449]))

tests/link/numba/sparse/test_basic.py:102: AssertionError
=================================================================================================================== slowest 50 durations ===================================================================================================================
0.67s call     tests/link/numba/sparse/test_basic.py::test_simple_graph[csr]
0.39s call     tests/link/numba/sparse/test_basic.py::test_sparse_constant[csr-True]
0.39s call     tests/link/numba/sparse/test_basic.py::test_sparse_constant[csc-True]
0.36s call     tests/link/numba/sparse/test_basic.py::test_simple_graph[csc]
0.31s call     tests/link/numba/sparse/test_math.py::test_sparse_dense_multiply[0-csr]
0.30s call     tests/link/numba/sparse/test_basic.py::test_sparse_deepcopy[csr]
0.28s call     tests/link/numba/sparse/test_math.py::test_sparse_dense_multiply[2-csc]
0.28s call     tests/link/numba/sparse/test_math.py::test_sparse_dense_multiply[2-csr]
0.27s call     tests/link/numba/sparse/test_basic.py::test_sparse_deepcopy[csc]
0.26s call     tests/link/numba/sparse/test_basic.py::test_sparse_boxing
0.23s call     tests/link/numba/sparse/test_math.py::test_sparse_dense_multiply[0-csc]
0.21s call     tests/link/numba/sparse/test_basic.py::test_sparse_constant[csr-False]
0.18s call     tests/link/numba/sparse/test_basic.py::test_sparse_constant[csc-False]
0.17s call     tests/link/numba/sparse/test_math.py::test_sparse_dense_multiply[1-csr]
0.17s call     tests/link/numba/sparse/test_basic.py::test_sparse_copy
0.16s call     tests/link/numba/sparse/test_basic.py::test_sparse_objmode[False-csc]
0.16s call     tests/link/numba/sparse/test_basic.py::test_sparse_objmode[True-csr]
0.15s call     tests/link/numba/sparse/test_basic.py::test_sparse_objmode[False-csr]
0.15s call     tests/link/numba/sparse/test_basic.py::test_sparse_objmode[True-csc]
0.14s call     tests/link/numba/sparse/test_math.py::test_sparse_dense_multiply[1-csc]
0.13s call     tests/link/numba/sparse/test_basic.py::test_sparse_constructor[csr]
0.11s call     tests/link/numba/sparse/test_basic.py::test_sparse_constructor[csc]
0.08s call     tests/link/numba/sparse/test_basic.py::test_sparse_ndim
0.08s call     tests/link/numba/sparse/test_basic.py::test_sparse_shape

(26 durations < 0.005s hidden.  Use -vv to show these durations.)
================================================================================================================= short test summary info ==================================================================================================================
FAILED tests/link/numba/sparse/test_basic.py::test_sparse_creation_refcount - assert 2 == 3
FAILED tests/link/numba/sparse/test_basic.py::test_sparse_passthrough_refcount - assert 2 == 3
=============================================================================================================== 2 failed, 24 passed in 7.46s ===============================================================================================================

So I don't think the failure on refcount is related to my changes. Could it be version of Python I'm using, numba, scipy, etc?


Update: I can't reproduce the refcount error in Python 3.13. Could it be this is the root cause? https://rushter.com/blog/python-refcount/

@ricardoV94
Copy link
Member

Could it be version of Python I'm using, numba, scipy, etc?

Yes, or the python interpreter being used.

Comment on lines +238 to +242
# Only one of 'x' or 'y' is sparse, not both.
# Before using a general dot(sparse-matrix, dense-matrix) algorithm,
# we check if we can rely on the less intensive (sparse-matrix, dense-vector) algorithm (spmv).
y_is_1d_like = y.type.ndim == 1 or (y.type.ndim == 2 and y.type.shape[1] == 1)
x_is_1d = x.type.ndim == 1
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The section that starts here became a bit more complicated because I also have to handle dot(dense-vector, sparse-matrix).

To my surprise, this is SciPy's behavior

import numpy as np
import scipy.sparse as sp

x = np.random.normal(size=(7, ))
y = sp.random(7, 3, density=0.3, format="csr")

print(x * y)
print(x @ y)
[ 1.69110199  0.61299262 -0.86025165]
[ 1.69110199  0.61299262 -0.86025165]

which means it works as if one was doing (1, p) @ (p, k) -> (1, k).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not what dense matmul does as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not what dense matmul does as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, but I was expecting something else, given the call can be x * y which would fail in numpy.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a bug in scipy sparse imo.

Copy link
Contributor Author

@tomicapretto tomicapretto Jan 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have not checked the docs with enough attention, but I can imagine they will get this resolved with the new sparse array interface. I guess that array interface will be something to deal with once they make the breaking change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any link to ongoing discussion on that?

Copy link
Member

@jessegrabowski jessegrabowski Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this? scipy/scipy#18938

No, we're not going through np.dot, nevermind.

@tomicapretto
Copy link
Contributor Author

@ricardoV94, @jessegrabowski: this is the test that is failing

FAILED tests/link/jax/test_shape.py::test_jax_Reshape_shape_graph_input - [XPASS(strict)] `shape_pt` should be specified as a static argument

Should I fix it here?

@ricardoV94
Copy link
Member

@ricardoV94, @jessegrabowski: this is the test that is failing

FAILED tests/link/jax/test_shape.py::test_jax_Reshape_shape_graph_input - [XPASS(strict)] `shape_pt` should be specified as a static argument

Should I fix it here?

no need

@ricardoV94 ricardoV94 merged commit 207b0c6 into pymc-devs:main Jan 22, 2026
64 of 66 checks passed
@tomicapretto
Copy link
Contributor Author

Thanks! Next thing on my plate will be the gradients and also more support in JAX

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants