Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE_NOTES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Changes from 2.14.1 to 2.14.2
-----------------------------

* **Under development.**
* Avoid keeping arrays passed as ``out=`` alive in the ``re_evaluate`` cache.

Changes from 2.14.0 to 2.14.1
-----------------------------
Expand Down
26 changes: 23 additions & 3 deletions numexpr/necompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import re
import sys
import threading
import weakref
from typing import Dict, Optional

import numpy
Expand Down Expand Up @@ -794,6 +795,26 @@ def getArguments(names, local_dict=None, global_dict=None, _frame_depth: int=2):
evaluate_lock = threading.Lock()


def _cache_last_kwargs(out: numpy.ndarray,
order: str,
casting: str,
ex_uses_vml: bool) -> Dict:
return {
'out': None if out is None else weakref.ref(out),
'order': order,
'casting': casting,
'ex_uses_vml': ex_uses_vml,
}
Comment thread
FrancescAlted marked this conversation as resolved.
Outdated


def _resolve_last_kwargs(kwargs: Dict) -> Dict:
kwargs = kwargs.copy()
out = kwargs.get('out')
if isinstance(out, weakref.ReferenceType):
kwargs['out'] = out()
return kwargs


def validate(ex: str,
local_dict: Optional[Dict] = None,
global_dict: Optional[Dict] = None,
Expand Down Expand Up @@ -905,8 +926,7 @@ def validate(ex: str,
compiled_ex = _numexpr_cache.c[numexpr_key]
except KeyError:
compiled_ex = _numexpr_cache.c[numexpr_key] = NumExpr(ex, signature, sanitize=sanitize, **context)
kwargs = {'out': out, 'order': order, 'casting': casting,
'ex_uses_vml': ex_uses_vml}
kwargs = _cache_last_kwargs(out, order, casting, ex_uses_vml)
_numexpr_last.l.set(ex=compiled_ex, argnames=names, kwargs=kwargs)
except Exception as e:
return e
Expand Down Expand Up @@ -1049,6 +1069,6 @@ def re_evaluate(local_dict: Optional[Dict] = None,
raise RuntimeError("A previous evaluate() execution was not found, please call `validate` or `evaluate` once before `re_evaluate`")
argnames = _numexpr_last.l['argnames']
args = getArguments(argnames, local_dict, global_dict, _frame_depth=_frame_depth)
kwargs = _numexpr_last.l['kwargs']
kwargs = _resolve_last_kwargs(_numexpr_last.l['kwargs'])
# with evaluate_lock:
return compiled_ex(*args, **kwargs)
27 changes: 26 additions & 1 deletion numexpr/tests/test_numexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
# rights to use.
####################################################################


import gc
import os
import platform
import subprocess
import sys
import unittest
import warnings
import weakref
from contextlib import contextmanager
from unittest.mock import MagicMock

Expand Down Expand Up @@ -412,6 +413,30 @@ def test_re_evaluate_dict(self):
x = re_evaluate(local_dict=local_dict)
assert_array_equal(x, array([86., 124., 168.]))

def test_evaluate_out_is_not_kept_alive(self):
a = arange(1000.0)
out = zeros(a.shape)
out_ref = weakref.ref(out)

evaluate("a + 1", local_dict={"a": a}, out=out)
del out
gc.collect()

assert out_ref() is None

def test_re_evaluate_reuses_live_out(self):
a = array([1., 2., 3.])
out = zeros(a.shape)

x = evaluate("a + 1", local_dict={"a": a}, out=out)
assert x is out
assert_array_equal(out, array([2., 3., 4.]))

a = array([4., 5., 6.])
x = re_evaluate(local_dict={"a": a})
assert x is out
assert_array_equal(out, array([5., 6., 7.]))

def test_validate(self):
a = array([1., 2., 3.])
b = array([4., 5., 6.])
Expand Down
Loading