Optimize JoinDims and SplitDims by canonicalizing to simpler operations (Partial fixes #1843) #1847
base: main
Conversation
Optimize JoinDims and SplitDims by canonicalizing to simpler operations (identity, expand_dims, squeeze). Partial fixes pymc-devs#1843
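For intuition, here is what the targeted canonicalizations do numerically, sketched with plain NumPy (the new join_dims/split_dims ops themselves are not used here; the signatures in the comments follow the PR description):

```python
import numpy as np

x = np.arange(24).reshape(2, 4, 3)

# join_dims(x, axis=1, n_axes=1): joining a single axis leaves x untouched (identity)
assert np.array_equal(x.reshape(2, 4, 3), x)

# join_dims(x, axis=1, n_axes=0): joining zero axes inserts a length-1 axis (expand_dims)
assert np.array_equal(x.reshape(2, 1, 4, 3), np.expand_dims(x, 1))

# split_dims(y, axis=1, shape=()): splitting into zero axes drops a length-1 axis (squeeze)
y = np.expand_dims(x, 1)
assert np.array_equal(y.reshape(2, 4, 3), np.squeeze(y, axis=1))

# split_dims(x, axis=1, shape=(4,)): splitting into a single axis moves no data,
# so only a shape assertion (specify_shape) is needed
assert np.array_equal(x.reshape(2, 4, 3), x)
```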
My guess is that ricardo meant reshape, not literally specify_shape (which you're right, just adds metadata but doesn't do any computation)
I meant split dims, when the shape argument has just one entry. That's what the syntax
pytensor/tensor/rewriting/reshape.py
Outdated
    x, shape = node.inputs
    axis = node.op.axis

    if isinstance(shape, Constant) and shape.data.size == 0:
Doesn't need to be constant, just a static shape of zero: shape.type.shape == (0,)
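A quick illustration of the point with stock PyTensor only: the emptiness of the shape argument can be read off its static type even when it is not a Constant.

```python
import pytensor.tensor as pt

shape_const = pt.as_tensor([])          # a Constant with zero elements
shape_sym = pt.vector("s", shape=(0,))  # not a Constant, but statically empty
print(shape_const.type.shape)           # (0,)
print(shape_sym.type.shape)             # (0,)
```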
Also I would merge this with the split-to-reshape rewrite so we don't accidentally run that before this
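A rough sketch of what the merged SplitDims rewrite could look like, so the cheap special cases always run before the reshape fallback. The Op import path and the general reshape branch are my guesses, not the PR's code:

```python
import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.reshape import SplitDims  # assumed location of the new Op
from pytensor.tensor.rewriting.basic import register_canonicalize


@register_canonicalize
@node_rewriter([SplitDims])
def local_lower_split_dims(fgraph, node):
    x, shape = node.inputs
    axis = node.op.axis

    if shape.type.shape == (0,):
        # Splitting into zero axes drops a length-1 axis.
        new_out = pt.squeeze(x, axis=axis)
    elif shape.type.shape == (1,):
        # Splitting into one axis moves no data; only assert the output shape.
        new_out = pt.specify_shape(x, node.outputs[0].type.shape)
    else:
        # General case: lower to a single reshape (assumes the shape input has
        # a statically known length, which SplitDims needs anyway).
        new_shape = pt.concatenate([x.shape[:axis], shape, x.shape[axis + 1 :]])
        new_out = x.reshape(new_shape, ndim=x.ndim - 1 + shape.type.shape[0])
    copy_stack_trace(node.outputs[0], new_out)
    return [new_out]
```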
Thank you @jessegrabowski and @ricardoV94 for clarifying - so it sounds like we don't need the split_dims(x, axis=axis, shape=(dim,)) → specify_shape(...) function, since it will fall into reshape anyway? I have made the changes according to the comment above.
reshape should be our last resort; everything we can avoid expressing as reshape, we should
To clarify, none of the changes in this PR were strictly needed; they are an improvement over a simple reshape
Understood. Is there anything else to do with the last function:
ricardoV94
left a comment
Making progress, needs a few more tweaks
    # Special case: empty shape -> squeeze
    if shape.type.shape == (0,):
        squeezed_x = squeeze(x, axis=axis)
        copy_stack_trace(x, squeezed_x)
        return [squeezed_x]
This is duplicated, you meant the case with shape.type.shape == (1,) I presume?
Removed the redundant block; I'm not sure how to treat the shape == (1,) case without calling reshape, since specify_shape won't help?
What do you mean "won't help"? Neither split_dims nor reshape does anything in that case; that's why it's functionally equivalent to a specify_shape.
Try to run some cases of such split_dims to get acquainted with the behavior.
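To make this concrete, a small sketch using only stock ops (split_dims itself comes from this PR and is not called here): "splitting" an axis of length n into the single axis (n,) moves no data, so only a shape assertion remains.

```python
import numpy as np
import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, None, 3))

# What split_dims(x, axis=1, shape=(x.shape[1],)) computes: a reshape to the same shape
split_like = x.reshape((x.shape[0], x.shape[1], x.shape[2]))
# The proposed lowering: assert the (possibly unknown) output shape, no data movement
asserted = pt.specify_shape(x, (2, None, 3))

x_val = np.arange(24).reshape(2, 4, 3)
np.testing.assert_array_equal(split_like.eval({x: x_val}), x_val)
np.testing.assert_array_equal(asserted.eval({x: x_val}), x_val)
```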
pytensor/tensor/rewriting/reshape.py
Outdated
    @register_canonicalize
    @node_rewriter([JoinDims])
    def local_join_dims_noop(fgraph, node):
Merge these JoinDims rewrites into a single one, like we did with SplitDims
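A similarly rough sketch of what the merged JoinDims rewrite might look like. The Op import path, its axis/n_axes attributes, and the reshape branch are my assumptions, not the PR's code:

```python
import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.reshape import JoinDims  # assumed location of the new Op
from pytensor.tensor.rewriting.basic import register_canonicalize


@register_canonicalize
@node_rewriter([JoinDims])
def local_lower_join_dims(fgraph, node):
    (x,) = node.inputs
    axis, n_axes = node.op.axis, node.op.n_axes

    if n_axes == 1:
        return [x]                         # joining a single axis is a no-op
    if n_axes == 0:
        return [pt.expand_dims(x, axis)]   # joining zero axes adds a length-1 axis
    # General case: collapse axes [axis, axis + n_axes) with a single reshape
    joined = pt.prod(x.shape[axis : axis + n_axes])
    new_shape = pt.concatenate(
        [x.shape[:axis], pt.atleast_1d(joined), x.shape[axis + n_axes :]]
    )
    return [x.reshape(new_shape, ndim=x.ndim - n_axes + 1)]
```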
pytensor/tensor/rewriting/reshape.py
Outdated
    @register_canonicalize
    @node_rewriter([SplitDims])
    def local_split_dims_to_reshape(fgraph, node):
Now that we don't do only reshape, we should have a more generic name. Same for the join_dims when we merge the special cases
Suggested change:
    - def local_split_dims_to_reshape(fgraph, node):
    + def local_lower_split_dims(fgraph, node):
    # After rewrite: should have 0 JoinDims nodes
    assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0
    # Output should be equivalent to input (identity rewrite)
    # The rewrite returns the input variable, so output should match input shape/type
    assert fg.outputs[0].type.shape == x.type.shape
    assert fg.outputs[0].type.dtype == x.type.dtype
    assert fg.outputs[0].type.ndim == x.type.ndim
Use utt.assert_equal_computations to check we have the specific graph that we expect, not just anything without JoinDims
This recommendation applies to all new tests
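In case it helps, a minimal illustration of why the structural check is stricter than the shape/dtype asserts above. It uses the public equal_computations helper, which I assume the utt.assert_equal_computations wrapper is built around:

```python
import pytensor.tensor as pt
from pytensor.graph.basic import equal_computations

x = pt.tensor("x", shape=(2, None, 3))

got = pt.expand_dims(x, 1)
expected = pt.expand_dims(x, 1)
assert equal_computations([got], [expected])       # identical graph structure

# An equivalent result built differently is *not* structurally equal,
# which is exactly what the looser shape/dtype asserts cannot catch:
reshaped = pt.reshape(x, (x.shape[0], 1, x.shape[1], x.shape[2]))
assert not equal_computations([reshaped], [expected])
```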
I can't seem to get it to pass for the first 2 tests. When I looked it up, I got: "assert_equal_computations is better suited for cases where the canonical form is a specific operation (like expand_dims, squeeze, or identity) where graph structures match. For basic reshape cases, the rewrite produces a different but equivalent graph structure, so structural checks are sufficient"
Please let me know how to proceed!
Suggest you look at the generated graph; the utility prints it when the assert fails. It shouldn't have anything too strange in it
No, it passes as is; but you suggested using utt.assert_equal_computations for all new tests. When I added that in, it didn't pass, so I changed it back for the first 2 tests.
It means there's something about the final graph you are not expecting. That's exactly why it's better to use the helper, as it is more strict. Try to put it back and pay attention to the error message when it fails.
It will print the expected graph and the one you got in textual form. If you compare them, you'll see what you were missing in the "expected result".
To help figure things out you can also set pytensor.config.optimizer_verbose = True to see if some other "surprising" rewrites are getting involved
For instance, if you have known shapes for the input, then the specify_shape will itself be considered useless (by another rewrite). So if you want to see it, you need to use an unknown shape like x = pt.tensor(shape=(2, None, 3)), or in the call to rewrite_graph use the excluding kwarg to shut down that second rewrite that judges specify_shape to be useless.
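A sketch of these debugging suggestions using only stock APIs. The toy graph is illustrative; it is not the PR's split_dims output:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.printing import debugprint

pytensor.config.optimizer_verbose = True     # log every rewrite that fires

x = pt.tensor("x", shape=(2, None, 3))       # unknown middle dim, per the suggestion
out = pt.specify_shape(x, (2, 4, 3))         # stand-in for the canonicalized graph
debugprint(rewrite_graph(out))               # the SpecifyShape survives: it adds info
# With a fully known input shape, specify_shape would add nothing, and a later
# canonicalization is free to drop it, which is why the test "loses" it.
```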
Description
This PR implements 3 of the 4 canonicalization rewrites suggested in #1843:

- join_dims(x, axis=axis, n_axes=1) → identity (no-op)
- join_dims(x, axis=axis, n_axes=0) → expand_dims(x, axis)
- split_dims(x, axis=axis, shape=()) → squeeze(x, axis)
- split_dims(x, axis=axis, shape=(dim,)) → specify_shape(...) (not implemented; see Block section)

Questions
I tried to work on the last requested change:
The issue: specify_shape preserves the input's known shape when it's already concrete, so its output type doesn't match SplitDims's output type. If the input already has a known shape at a dimension, it uses that shape; it only uses the specified shape when the input shape is None. This caused the rewrite to fail.
For this rewrite to work even when the input shape is known, I'd need to use reshape instead of specify_shape, but that defeats the purpose of using specify_shape for shape assertion.
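A minimal reproduction of the behavior described above (stock PyTensor only; this just demonstrates the reported mismatch, it is not a fix):

```python
import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 6, 3))     # input with a fully known static shape
y = pt.specify_shape(x, (2, None, 3))   # "specify" a vaguer shape
print(y.type.shape)                     # (2, 6, 3): the known input shape wins, so the
                                        # type is more precise than SplitDims's output type
```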