diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e63edb73..3a2dfc46f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,9 @@ ### Added ### Fixed ### Changed +- Speed up `Term.__eq__` via the C-level API ### Removed +- Removed `Term.ptrtuple` to optimize `Term` memory usage ## 6.1.0 - 2026.01.31 ### Added diff --git a/src/pyscipopt/expr.pxi b/src/pyscipopt/expr.pxi index f782a46da..fe9fa6e49 100644 --- a/src/pyscipopt/expr.pxi +++ b/src/pyscipopt/expr.pxi @@ -104,13 +104,11 @@ cdef class Term: '''This is a monomial term''' cdef readonly tuple vartuple - cdef readonly tuple ptrtuple cdef Py_ssize_t hashval def __init__(self, *vartuple: Variable): - self.vartuple = tuple(sorted(vartuple, key=lambda v: v.ptr())) - self.ptrtuple = tuple(v.ptr() for v in self.vartuple) - self.hashval = hash(self.ptrtuple) + self.vartuple = tuple(sorted(vartuple, key=hash)) + self.hashval = hash(self.vartuple) def __getitem__(self, idx): return self.vartuple[idx] @@ -118,8 +116,25 @@ cdef class Term: def __hash__(self) -> Py_ssize_t: return self.hashval - def __eq__(self, other: Term): - return self.ptrtuple == other.ptrtuple + def __eq__(self, other) -> bool: + if other is self: + return True + if Py_TYPE(other) is not Term: + return False + + cdef int n = len(self) + cdef Term _other = other + if n != len(_other) or self.hashval != _other.hashval: + return False + + cdef int i + cdef Variable var1, var2 + for i in range(n): + var1 = PyTuple_GET_ITEM(self.vartuple, i) + var2 = PyTuple_GET_ITEM(_other.vartuple, i) + if hash(var1) != hash(var2): + return False + return True def __len__(self): return len(self.vartuple) @@ -138,7 +153,7 @@ cdef class Term: while i < n1 and j < n2: var1 = PyTuple_GET_ITEM(self.vartuple, i) var2 = PyTuple_GET_ITEM(other.vartuple, j) - if var1.ptr() <= var2.ptr(): + if hash(var1) <= hash(var2): vartuple[k] = var1 i += 1 else: @@ -156,8 +171,7 @@ cdef class Term: cdef Term res = Term.__new__(Term) res.vartuple = tuple(vartuple) - res.ptrtuple = tuple(v.ptr() for v in res.vartuple) - res.hashval = hash(res.ptrtuple) + res.hashval = hash(res.vartuple) return res def __repr__(self): diff --git a/src/pyscipopt/scip.pxi b/src/pyscipopt/scip.pxi index 9c21942bd..d1fd20021 100644 --- a/src/pyscipopt/scip.pxi +++ b/src/pyscipopt/scip.pxi @@ -1558,10 +1558,15 @@ cdef class Variable(Expr): cname = bytes( SCIPvarGetName(self.scip_var) ) return cname.decode('utf-8') - def ptr(self): - """ """ + def __hash__(self): return (self.scip_var) + def ptr(self): + return self.__hash__() + + def __richcmp__(self, other, int op): + return _expr_richcmp(self, other, op) + def __repr__(self): return self.name diff --git a/src/pyscipopt/scip.pyi b/src/pyscipopt/scip.pyi index e8bbc46d5..b8ad35a52 100644 --- a/src/pyscipopt/scip.pyi +++ b/src/pyscipopt/scip.pyi @@ -2190,7 +2190,6 @@ class SumExpr(GenExpr): @disjoint_base class Term: - ptrtuple: Incomplete vartuple: Incomplete def __init__(self, *vartuple: Incomplete) -> None: ... def __mul__(self, other: Term) -> Term: ... diff --git a/tests/test_expr.py b/tests/test_expr.py index a4e739b76..aa13d8b13 100644 --- a/tests/test_expr.py +++ b/tests/test_expr.py @@ -2,7 +2,7 @@ import pytest -from pyscipopt import Model, sqrt, log, exp, sin, cos +from pyscipopt import Model, sqrt, log, exp, sin, cos, quickprod from pyscipopt.scip import Expr, GenExpr, ExprCons, CONST @@ -244,3 +244,24 @@ def test_abs_abs_expr(): # should print abs(x) not abs(abs(x)) assert str(abs(abs(x))) == str(abs(x)) + + +def test_term_eq(): + m = Model() + + x = m.addMatrixVar(1000) + y = m.addVar() + z = m.addVar() + + e1 = quickprod(x.flat) + e2 = quickprod(x.flat) + t1 = next(iter(e1)) + t2 = next(iter(e2)) + t3 = next(iter(e1 * y)) + t4 = next(iter(e2 * z)) + + assert t1 == t1 # same term + assert t1 == t2 # same term + assert t3 != t4 # same length, but different term + assert t1 != t3 # different length + assert t1 != "not a term" # different type