Skip to content
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
### Added
### Fixed
### Changed
- Speed up `Term.__eq__` via the C-level API
### Removed
- Removed `Term.ptrtuple` to optimize `Term` memory usage

## 6.1.0 - 2026.01.31
### Added
Expand Down
32 changes: 23 additions & 9 deletions src/pyscipopt/expr.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,37 @@ cdef class Term:
'''This is a monomial term'''

cdef readonly tuple vartuple
cdef readonly tuple ptrtuple
cdef Py_ssize_t hashval

def __init__(self, *vartuple: Variable):
self.vartuple = tuple(sorted(vartuple, key=lambda v: v.ptr()))
self.ptrtuple = tuple(v.ptr() for v in self.vartuple)
self.hashval = <Py_ssize_t>hash(self.ptrtuple)
self.vartuple = tuple(sorted(vartuple, key=hash))
self.hashval = <Py_ssize_t>hash(self.vartuple)

def __getitem__(self, idx):
return self.vartuple[idx]

def __hash__(self) -> Py_ssize_t:
return self.hashval

def __eq__(self, other: Term):
return self.ptrtuple == other.ptrtuple
def __eq__(self, other) -> bool:
if other is self:
return True
if <type>Py_TYPE(other) is not Term:
return False

cdef int n = len(self)
cdef Term _other = <Term>other
if n != len(_other) or self.hashval != _other.hashval:
return False

cdef int i
cdef Variable var1, var2
for i in range(n):
var1 = <Variable>PyTuple_GET_ITEM(self.vartuple, i)
var2 = <Variable>PyTuple_GET_ITEM(_other.vartuple, i)
if hash(var1) != hash(var2):
return False
return True

def __len__(self):
return len(self.vartuple)
Expand All @@ -138,7 +153,7 @@ cdef class Term:
while i < n1 and j < n2:
var1 = <Variable>PyTuple_GET_ITEM(self.vartuple, i)
var2 = <Variable>PyTuple_GET_ITEM(other.vartuple, j)
if var1.ptr() <= var2.ptr():
if hash(var1) <= hash(var2):
vartuple[k] = var1
i += 1
else:
Expand All @@ -156,8 +171,7 @@ cdef class Term:

cdef Term res = Term.__new__(Term)
res.vartuple = tuple(vartuple)
res.ptrtuple = tuple(v.ptr() for v in res.vartuple)
res.hashval = <Py_ssize_t>hash(res.ptrtuple)
res.hashval = <Py_ssize_t>hash(res.vartuple)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hash(res.vartuple) reqires Variable.__hash__.

return res

def __repr__(self):
Expand Down
9 changes: 7 additions & 2 deletions src/pyscipopt/scip.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@
if rc == SCIP_OKAY:
pass
elif rc == SCIP_ERROR:
raise Exception('SCIP: unspecified error!')

Check failure on line 319 in src/pyscipopt/scip.pxi

View workflow job for this annotation

GitHub Actions / test-coverage (3.11)

SCIP: unspecified error!
elif rc == SCIP_NOMEMORY:
raise MemoryError('SCIP: insufficient memory error!')
elif rc == SCIP_READERROR:
Expand All @@ -335,7 +335,7 @@
raise Exception('SCIP: method cannot be called at this time'
+ ' in solution process!')
elif rc == SCIP_INVALIDDATA:
raise Exception('SCIP: error in input data!')

Check failure on line 338 in src/pyscipopt/scip.pxi

View workflow job for this annotation

GitHub Actions / test-coverage (3.11)

SCIP: error in input data!
elif rc == SCIP_INVALIDRESULT:
raise Exception('SCIP: method returned an invalid result code!')
elif rc == SCIP_PLUGINNOTFOUND:
Expand Down Expand Up @@ -1558,10 +1558,15 @@
cname = bytes( SCIPvarGetName(self.scip_var) )
return cname.decode('utf-8')

def ptr(self):
""" """
def __hash__(self):
return <size_t>(self.scip_var)

def ptr(self):
return self.__hash__()

def __richcmp__(self, other, int op):
return _expr_richcmp(self, other, op)

def __repr__(self):
return self.name

Expand Down
1 change: 0 additions & 1 deletion src/pyscipopt/scip.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2190,7 +2190,6 @@ class SumExpr(GenExpr):

@disjoint_base
class Term:
ptrtuple: Incomplete
vartuple: Incomplete
def __init__(self, *vartuple: Incomplete) -> None: ...
def __mul__(self, other: Term) -> Term: ...
Expand Down
23 changes: 22 additions & 1 deletion tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from pyscipopt import Model, sqrt, log, exp, sin, cos
from pyscipopt import Model, sqrt, log, exp, sin, cos, quickprod
from pyscipopt.scip import Expr, GenExpr, ExprCons, CONST


Expand Down Expand Up @@ -244,3 +244,24 @@ def test_abs_abs_expr():

# should print abs(x) not abs(abs(x))
assert str(abs(abs(x))) == str(abs(x))


def test_term_eq():
m = Model()

x = m.addMatrixVar(1000)
y = m.addVar()
z = m.addVar()

e1 = quickprod(x.flat)
e2 = quickprod(x.flat)
t1 = next(iter(e1))
t2 = next(iter(e2))
t3 = next(iter(e1 * y))
t4 = next(iter(e2 * z))

assert t1 == t1 # same term
assert t1 == t2 # same term
assert t3 != t4 # same length, but different term
assert t1 != t3 # different length
assert t1 != "not a term" # different type