Implement sparse dot product in numba backend #1854
Conversation
I'm looking for advice on the following issue: tests are failing, at least partially, due to the fact that PyTensor does have an operator that agrees with SciPy in the (sparse, sparse) input case, and that is `structured_dot`:

```python
def structured_dot(x, y):
    """
    Structured Dot is like dot, except that only the gradient wrt non-zero elements of the sparse matrix
    `a` are calculated and propagated.

    The output is presumed to be a dense matrix, and is represented by a TensorType instance.

    Parameters
    ----------
    a
        A sparse matrix.
    b
        A sparse or dense matrix.

    Returns
    -------
    A sparse matrix
        The dot product of `a` and `b`.

    Notes
    -----
    The grad implemented is structured.
    """
```

In there, it says the output is "presumed to be a dense matrix", and sparsity is exploited for the gradient. What do you think is the best thing to do? Should we...
Just in case, here I share a snippet to check output types:

```python
import scipy
import numpy as np

import pytensor
import pytensor.sparse as ps

sp_format = "csr"
x_shape = (10, 4)
y_shape = (4, 3)
x_format = y_format = sp_format

x = ps.matrix(x_format, name="x", shape=x_shape)
y = ps.matrix(y_format, name="y", shape=y_shape)

z_dot = ps.dot(x, y)
z_structured_dot = ps.structured_dot(x, y)

f_dot = pytensor.function([x, y], z_dot)
f_structured_dot = pytensor.function([x, y], z_structured_dot)

rng = np.random.default_rng(sum(map(ord, x_format)) + sum(map(ord, y_format)))
x_test = scipy.sparse.random(*x_shape, density=0.5, format=x_format, random_state=rng)
y_test = scipy.sparse.random(*y_shape, density=0.5, format=y_format, random_state=rng)

print("scipy dot", type(x_test @ y_test))
print("pytensor dot", type(f_dot(x_test, y_test)))
print("pytensor structured dot", type(f_structured_dot(x_test, y_test)))
```
```python
# Multiplication of objects of `*_matrix` type means dot product
rval = x * y
...
if x_is_sparse and y_is_sparse:
    rval = rval.toarray()
```
Lines 1910-1911 explicitly convert the sparse matrix to a dense array when using pytensor.sparse.dot
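As a quick aside, here is a minimal check (illustrative snippet, not part of the PR) of why the `rval = x * y` line above is a dot product: for the legacy `*_matrix` classes, `*` means matrix multiplication, unlike dense ndarrays where it is elementwise.

```python
import numpy as np
import scipy.sparse as sp

# For csr_matrix operands, `*` and `@` are both matrix multiplication.
a = sp.csr_matrix(np.eye(3))
b = sp.csr_matrix(np.arange(9.0).reshape(3, 3))
print(np.allclose((a * b).toarray(), (a @ b).toarray()))  # True
```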
@ricardoV94, now I have quite aggressively cast quantities to … . From the discussion in scipy/scipy#16774, I realized uint64 is needed for matrices that have >= 2^31 non-zero elements, which is 2,147,483,648 cells.
Removing unnecessary casts to uint32 since...

```python
import numba as nb

@nb.njit
def f1(n):
    acc = 0
    for i in range(n):
        acc += i
    return acc

@nb.njit
def f2(n):
    acc = 0
    for i in range(n):
        acc += i
    return acc

@nb.njit
def f3(n):
    acc = 0
    for i in range(n):
        acc += i
    return acc

f1(nb.uint32(100_000))
f2(nb.uint64(100_000))
f3(100_000)

%timeit f1(nb.uint32(100_000))
# 1.12 μs ± 49.2 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
%timeit f2(nb.uint64(100_000))
# 1.11 μs ± 24.4 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
%timeit f3(100_000)
# 102 ns ± 1.19 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
```

The three functions are identical, so the gap presumably comes from handling the explicitly typed numba scalar arguments at the call boundary rather than from the loop itself.
@ricardoV94 could you trigger CI? I cancelled it because it failed due to a tooling issue.
SciPy and we use int32 for the indices and indptr, so we can never represent matrices larger than that. But that means uint32 is always safe.
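A quick sanity check of that claim (illustrative snippet, not part of the PR):

```python
import numpy as np
import scipy.sparse as sp

# Small matrices get int32 indices/indptr by default, and every non-negative
# int32 value also fits in uint32, so casting indices to uint32 cannot overflow.
m = sp.random(10, 10, density=0.5, format="csr")
print(m.indices.dtype, m.indptr.dtype)                        # typically int32 int32
print(np.iinfo(np.int32).max, "<=", np.iinfo(np.uint32).max)  # 2147483647 <= 4294967295
```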
ricardoV94 left a comment:
This PR is incredible! Just a few minor comments on my part
```python
@numba_basic.numba_njit
def _spmdv(x_ptr, x_ind, x_data, y):
    n_row = np.max(x_ind) + 1
```
Can't we assume they are ordered and the last one is the max? I'm just surprised we need a max, but then the two formats do behave differently all the time...
This is also known information, so we could actually pass it as another parameter
That makes more sense yeah
I decided to pass x_shape to both _spmdv functions so we have a common signature and avoid more unnecessary if-elses.
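For context, a rough sketch (hypothetical, not the PR's actual kernel; the function and variable names are made up) of what a CSC sparse-matrix times dense-vector product looks like when the shape is passed in explicitly instead of being recovered with `np.max(x_ind) + 1`:

```python
import numpy as np
import numba as nb

@nb.njit
def csc_spmv_sketch(x_ptr, x_ind, x_data, x_shape, y):
    # x_ptr/x_ind/x_data are the CSC buffers; x_shape is passed in rather
    # than recomputed from the indices. Assumes float64 data for simplicity.
    n_row, n_col = x_shape
    out = np.zeros(n_row, dtype=np.float64)
    for j in range(n_col):                        # iterate over columns
        for k in range(x_ptr[j], x_ptr[j + 1]):   # nonzeros stored in column j
            out[x_ind[k]] += x_data[k] * y[j]     # x_ind[k] is the row index
    return out
```

For a `scipy.sparse.csc_matrix` `m` and a dense vector `v`, this would be called roughly as `csc_spmv_sketch(m.indptr, m.indices, m.data, m.shape, v)`.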
```python
def transpose(matrix):
    n_row, n_col = matrix.shape
    return builder(
        matrix.data.copy(),
```
Not related to this PR, but eventually we have to enable non-copy versions of these methods when we know they are for temporary variables.
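A minimal sketch of that idea in plain SciPy (hypothetical, and only safe when the input is a temporary that nobody else will mutate): transposing a CSR matrix is just re-labelling the same buffers as CSC with the shape reversed, and vice versa.

```python
import scipy.sparse as sp

def transpose_view(matrix):
    # Assumes matrix is csr_matrix or csc_matrix.
    n_row, n_col = matrix.shape
    builder = sp.csc_matrix if isinstance(matrix, sp.csr_matrix) else sp.csr_matrix
    # The (data, indices, indptr) constructor typically does not copy its inputs.
    return builder((matrix.data, matrix.indices, matrix.indptr), shape=(n_col, n_row))
```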
I think we should agree with scipy. My impression is that we cast back to dense too much.
Yeah, but I wouldn't do it in this PR.
I addressed all the comments, but now I observe these errors in the tests:

```
FAILED tests/link/numba/sparse/test_basic.py::test_sparse_creation_refcount - assert 2 == 3
FAILED tests/link/numba/sparse/test_basic.py::test_sparse_passthrough_refcount - assert 2 == 3
```

In both cases the refcount is smaller than what it has to be.
That means we may be stealing references somewhere now.
@ricardoV94 @jessegrabowski, with the exception of the comment regarding the cache key, all the rest should be good.
It's fine, we can squash them.
I suspect the test failure is exactly due to the wrong-cache thing I mentioned (if it passed locally for you).
Locally, this is what I get. From my branch: (collapsed output). From main: (collapsed output). So I don't think the failure on refcount is related to my changes. Could it be the version of Python I'm using, or of numba, scipy, etc.? Update: I can't reproduce the refcount error in Python 3.13. Could this be the root cause? https://rushter.com/blog/python-refcount/
Yes, or the Python interpreter being used.
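For what it's worth, a small illustration (unrelated to the PR code) of why exact refcount asserts are sensitive to the interpreter: the number reported includes temporary references created by the call itself, and those internals have changed across recent CPython releases.

```python
import sys
import numpy as np

a = np.arange(10)
# The exact count is an interpreter detail, not a stable contract.
print(sys.getrefcount(a))
```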
```python
# Only one of 'x' or 'y' is sparse, not both.
# Before using a general dot(sparse-matrix, dense-matrix) algorithm,
# we check if we can rely on the less intensive (sparse-matrix, dense-vector) algorithm (spmv).
y_is_1d_like = y.type.ndim == 1 or (y.type.ndim == 2 and y.type.shape[1] == 1)
x_is_1d = x.type.ndim == 1
```
The section that starts here became a bit more complicated because I also have to handle dot(dense-vector, sparse-matrix).
To my surprise, this is SciPy's behavior:

```python
import numpy as np
import scipy.sparse as sp

x = np.random.normal(size=(7,))
y = sp.random(7, 3, density=0.3, format="csr")

print(x * y)
print(x @ y)
```

```
[ 1.69110199 0.61299262 -0.86025165]
[ 1.69110199 0.61299262 -0.86025165]
```

which means it works as if one was doing (1, p) @ (p, k) -> (1, k).
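A quick check of that equivalence (illustrative snippet, not part of the PR):

```python
import numpy as np
import scipy.sparse as sp

# A 1-d dense vector on the left acts like a (1, p) row vector times the
# (p, k) sparse matrix, with the leading unit dimension dropped.
rng = np.random.default_rng(0)
x = rng.normal(size=(7,))
y = sp.random(7, 3, density=0.3, format="csr", random_state=rng)
expected = (x[None, :] @ y.toarray()).ravel()
print(np.allclose(np.asarray(x @ y).ravel(), expected))  # True
```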
Isn't that what dense matmul does as well?
You're right, but I was expecting something else, given that the call can be x * y, which would fail in numpy.
That's a bug in scipy sparse imo.
I have not checked the docs with enough attention, but I can imagine they will get this resolved with the new sparse array interface. I guess that array interface will be something to deal with once they make the breaking change.
Any link to the ongoing discussion on that?
Maybe this? scipy/scipy#18938
No, we're not going through np.dot, nevermind.
@ricardoV94, @jessegrabowski: this is the test that is failing. Should I fix it here?
No need.
Thanks! Next thing on my plate will be the gradients and also more support in JAX |
This PR implements the dot product for sparse matrices in the numba backend.