add read me
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
from .core import dispatch
|
||||
from .dispatcher import (Dispatcher, halt_ordering, restart_ordering,
|
||||
MDNotImplementedError)
|
||||
|
||||
__version__ = '0.4.9'
|
||||
|
||||
__all__ = [
|
||||
'dispatch',
|
||||
|
||||
'Dispatcher', 'halt_ordering', 'restart_ordering', 'MDNotImplementedError',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,68 @@
|
||||
from .utils import _toposort, groupby
|
||||
|
||||
class AmbiguityWarning(Warning):
|
||||
pass
|
||||
|
||||
|
||||
def supercedes(a, b):
|
||||
""" A is consistent and strictly more specific than B """
|
||||
return len(a) == len(b) and all(map(issubclass, a, b))
|
||||
|
||||
|
||||
def consistent(a, b):
|
||||
""" It is possible for an argument list to satisfy both A and B """
|
||||
return (len(a) == len(b) and
|
||||
all(issubclass(aa, bb) or issubclass(bb, aa)
|
||||
for aa, bb in zip(a, b)))
|
||||
|
||||
|
||||
def ambiguous(a, b):
|
||||
""" A is consistent with B but neither is strictly more specific """
|
||||
return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a))
|
||||
|
||||
|
||||
def ambiguities(signatures):
|
||||
""" All signature pairs such that A is ambiguous with B """
|
||||
signatures = list(map(tuple, signatures))
|
||||
return {(a, b) for a in signatures for b in signatures
|
||||
if hash(a) < hash(b)
|
||||
and ambiguous(a, b)
|
||||
and not any(supercedes(c, a) and supercedes(c, b)
|
||||
for c in signatures)}
|
||||
|
||||
|
||||
def super_signature(signatures):
|
||||
""" A signature that would break ambiguities """
|
||||
n = len(signatures[0])
|
||||
assert all(len(s) == n for s in signatures)
|
||||
|
||||
return [max([type.mro(sig[i]) for sig in signatures], key=len)[0]
|
||||
for i in range(n)]
|
||||
|
||||
|
||||
def edge(a, b, tie_breaker=hash):
|
||||
""" A should be checked before B
|
||||
|
||||
Tie broken by tie_breaker, defaults to ``hash``
|
||||
"""
|
||||
if supercedes(a, b):
|
||||
if supercedes(b, a):
|
||||
return tie_breaker(a) > tie_breaker(b)
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def ordering(signatures):
|
||||
""" A sane ordering of signatures to check, first to last
|
||||
|
||||
Topoological sort of edges as given by ``edge`` and ``supercedes``
|
||||
"""
|
||||
signatures = list(map(tuple, signatures))
|
||||
edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
|
||||
edges = groupby(lambda x: x[0], edges)
|
||||
for s in signatures:
|
||||
if s not in edges:
|
||||
edges[s] = []
|
||||
edges = {k: [b for a, b in v] for k, v in edges.items()}
|
||||
return _toposort(edges)
|
||||
@@ -0,0 +1,83 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
import inspect
|
||||
|
||||
from .dispatcher import Dispatcher, MethodDispatcher, ambiguity_warn
|
||||
|
||||
# XXX: This parameter to dispatch isn't documented and isn't used anywhere in
|
||||
# sympy. Maybe it should just be removed.
|
||||
global_namespace: dict[str, Any] = {}
|
||||
|
||||
|
||||
def dispatch(*types, namespace=global_namespace, on_ambiguity=ambiguity_warn):
|
||||
""" Dispatch function on the types of the inputs
|
||||
|
||||
Supports dispatch on all non-keyword arguments.
|
||||
|
||||
Collects implementations based on the function name. Ignores namespaces.
|
||||
|
||||
If ambiguous type signatures occur a warning is raised when the function is
|
||||
defined suggesting the additional method to break the ambiguity.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> from sympy.multipledispatch import dispatch
|
||||
>>> @dispatch(int)
|
||||
... def f(x):
|
||||
... return x + 1
|
||||
|
||||
>>> @dispatch(float)
|
||||
... def f(x): # noqa: F811
|
||||
... return x - 1
|
||||
|
||||
>>> f(3)
|
||||
4
|
||||
>>> f(3.0)
|
||||
2.0
|
||||
|
||||
Specify an isolated namespace with the namespace keyword argument
|
||||
|
||||
>>> my_namespace = dict()
|
||||
>>> @dispatch(int, namespace=my_namespace)
|
||||
... def foo(x):
|
||||
... return x + 1
|
||||
|
||||
Dispatch on instance methods within classes
|
||||
|
||||
>>> class MyClass(object):
|
||||
... @dispatch(list)
|
||||
... def __init__(self, data):
|
||||
... self.data = data
|
||||
... @dispatch(int)
|
||||
... def __init__(self, datum): # noqa: F811
|
||||
... self.data = [datum]
|
||||
"""
|
||||
types = tuple(types)
|
||||
|
||||
def _(func):
|
||||
name = func.__name__
|
||||
|
||||
if ismethod(func):
|
||||
dispatcher = inspect.currentframe().f_back.f_locals.get(
|
||||
name,
|
||||
MethodDispatcher(name))
|
||||
else:
|
||||
if name not in namespace:
|
||||
namespace[name] = Dispatcher(name)
|
||||
dispatcher = namespace[name]
|
||||
|
||||
dispatcher.add(types, func, on_ambiguity=on_ambiguity)
|
||||
return dispatcher
|
||||
return _
|
||||
|
||||
|
||||
def ismethod(func):
|
||||
""" Is func a method?
|
||||
|
||||
Note that this has to work as the method is defined but before the class is
|
||||
defined. At this stage methods look like functions.
|
||||
"""
|
||||
signature = inspect.signature(func)
|
||||
return signature.parameters.get('self', None) is not None
|
||||
@@ -0,0 +1,413 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from warnings import warn
|
||||
import inspect
|
||||
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
|
||||
from .utils import expand_tuples
|
||||
import itertools as itl
|
||||
|
||||
|
||||
class MDNotImplementedError(NotImplementedError):
|
||||
""" A NotImplementedError for multiple dispatch """
|
||||
|
||||
|
||||
### Functions for on_ambiguity
|
||||
|
||||
def ambiguity_warn(dispatcher, ambiguities):
|
||||
""" Raise warning when ambiguity is detected
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dispatcher : Dispatcher
|
||||
The dispatcher on which the ambiguity was detected
|
||||
ambiguities : set
|
||||
Set of type signature pairs that are ambiguous within this dispatcher
|
||||
|
||||
See Also:
|
||||
Dispatcher.add
|
||||
warning_text
|
||||
"""
|
||||
warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)
|
||||
|
||||
|
||||
class RaiseNotImplementedError:
|
||||
"""Raise ``NotImplementedError`` when called."""
|
||||
|
||||
def __init__(self, dispatcher):
|
||||
self.dispatcher = dispatcher
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
types = tuple(type(a) for a in args)
|
||||
raise NotImplementedError(
|
||||
"Ambiguous signature for %s: <%s>" % (
|
||||
self.dispatcher.name, str_signature(types)
|
||||
))
|
||||
|
||||
def ambiguity_register_error_ignore_dup(dispatcher, ambiguities):
|
||||
"""
|
||||
If super signature for ambiguous types is duplicate types, ignore it.
|
||||
Else, register instance of ``RaiseNotImplementedError`` for ambiguous types.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dispatcher : Dispatcher
|
||||
The dispatcher on which the ambiguity was detected
|
||||
ambiguities : set
|
||||
Set of type signature pairs that are ambiguous within this dispatcher
|
||||
|
||||
See Also:
|
||||
Dispatcher.add
|
||||
ambiguity_warn
|
||||
"""
|
||||
for amb in ambiguities:
|
||||
signature = tuple(super_signature(amb))
|
||||
if len(set(signature)) == 1:
|
||||
continue
|
||||
dispatcher.add(
|
||||
signature, RaiseNotImplementedError(dispatcher),
|
||||
on_ambiguity=ambiguity_register_error_ignore_dup
|
||||
)
|
||||
|
||||
###
|
||||
|
||||
|
||||
_unresolved_dispatchers: set[Dispatcher] = set()
|
||||
_resolve = [True]
|
||||
|
||||
|
||||
def halt_ordering():
|
||||
_resolve[0] = False
|
||||
|
||||
|
||||
def restart_ordering(on_ambiguity=ambiguity_warn):
|
||||
_resolve[0] = True
|
||||
while _unresolved_dispatchers:
|
||||
dispatcher = _unresolved_dispatchers.pop()
|
||||
dispatcher.reorder(on_ambiguity=on_ambiguity)
|
||||
|
||||
|
||||
class Dispatcher:
|
||||
""" Dispatch methods based on type signature
|
||||
|
||||
Use ``dispatch`` to add implementations
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> from sympy.multipledispatch import dispatch
|
||||
>>> @dispatch(int)
|
||||
... def f(x):
|
||||
... return x + 1
|
||||
|
||||
>>> @dispatch(float)
|
||||
... def f(x): # noqa: F811
|
||||
... return x - 1
|
||||
|
||||
>>> f(3)
|
||||
4
|
||||
>>> f(3.0)
|
||||
2.0
|
||||
"""
|
||||
__slots__ = '__name__', 'name', 'funcs', 'ordering', '_cache', 'doc'
|
||||
|
||||
def __init__(self, name, doc=None):
|
||||
self.name = self.__name__ = name
|
||||
self.funcs = {}
|
||||
self._cache = {}
|
||||
self.ordering = []
|
||||
self.doc = doc
|
||||
|
||||
def register(self, *types, **kwargs):
|
||||
""" Register dispatcher with new implementation
|
||||
|
||||
>>> from sympy.multipledispatch.dispatcher import Dispatcher
|
||||
>>> f = Dispatcher('f')
|
||||
>>> @f.register(int)
|
||||
... def inc(x):
|
||||
... return x + 1
|
||||
|
||||
>>> @f.register(float)
|
||||
... def dec(x):
|
||||
... return x - 1
|
||||
|
||||
>>> @f.register(list)
|
||||
... @f.register(tuple)
|
||||
... def reverse(x):
|
||||
... return x[::-1]
|
||||
|
||||
>>> f(1)
|
||||
2
|
||||
|
||||
>>> f(1.0)
|
||||
0.0
|
||||
|
||||
>>> f([1, 2, 3])
|
||||
[3, 2, 1]
|
||||
"""
|
||||
def _(func):
|
||||
self.add(types, func, **kwargs)
|
||||
return func
|
||||
return _
|
||||
|
||||
@classmethod
|
||||
def get_func_params(cls, func):
|
||||
if hasattr(inspect, "signature"):
|
||||
sig = inspect.signature(func)
|
||||
return sig.parameters.values()
|
||||
|
||||
@classmethod
|
||||
def get_func_annotations(cls, func):
|
||||
""" Get annotations of function positional parameters
|
||||
"""
|
||||
params = cls.get_func_params(func)
|
||||
if params:
|
||||
Parameter = inspect.Parameter
|
||||
|
||||
params = (param for param in params
|
||||
if param.kind in
|
||||
(Parameter.POSITIONAL_ONLY,
|
||||
Parameter.POSITIONAL_OR_KEYWORD))
|
||||
|
||||
annotations = tuple(
|
||||
param.annotation
|
||||
for param in params)
|
||||
|
||||
if not any(ann is Parameter.empty for ann in annotations):
|
||||
return annotations
|
||||
|
||||
def add(self, signature, func, on_ambiguity=ambiguity_warn):
|
||||
""" Add new types/method pair to dispatcher
|
||||
|
||||
>>> from sympy.multipledispatch import Dispatcher
|
||||
>>> D = Dispatcher('add')
|
||||
>>> D.add((int, int), lambda x, y: x + y)
|
||||
>>> D.add((float, float), lambda x, y: x + y)
|
||||
|
||||
>>> D(1, 2)
|
||||
3
|
||||
>>> D(1, 2.0)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
NotImplementedError: Could not find signature for add: <int, float>
|
||||
|
||||
When ``add`` detects a warning it calls the ``on_ambiguity`` callback
|
||||
with a dispatcher/itself, and a set of ambiguous type signature pairs
|
||||
as inputs. See ``ambiguity_warn`` for an example.
|
||||
"""
|
||||
# Handle annotations
|
||||
if not signature:
|
||||
annotations = self.get_func_annotations(func)
|
||||
if annotations:
|
||||
signature = annotations
|
||||
|
||||
# Handle union types
|
||||
if any(isinstance(typ, tuple) for typ in signature):
|
||||
for typs in expand_tuples(signature):
|
||||
self.add(typs, func, on_ambiguity)
|
||||
return
|
||||
|
||||
for typ in signature:
|
||||
if not isinstance(typ, type):
|
||||
str_sig = ', '.join(c.__name__ if isinstance(c, type)
|
||||
else str(c) for c in signature)
|
||||
raise TypeError("Tried to dispatch on non-type: %s\n"
|
||||
"In signature: <%s>\n"
|
||||
"In function: %s" %
|
||||
(typ, str_sig, self.name))
|
||||
|
||||
self.funcs[signature] = func
|
||||
self.reorder(on_ambiguity=on_ambiguity)
|
||||
self._cache.clear()
|
||||
|
||||
def reorder(self, on_ambiguity=ambiguity_warn):
|
||||
if _resolve[0]:
|
||||
self.ordering = ordering(self.funcs)
|
||||
amb = ambiguities(self.funcs)
|
||||
if amb:
|
||||
on_ambiguity(self, amb)
|
||||
else:
|
||||
_unresolved_dispatchers.add(self)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
types = tuple([type(arg) for arg in args])
|
||||
try:
|
||||
func = self._cache[types]
|
||||
except KeyError:
|
||||
func = self.dispatch(*types)
|
||||
if not func:
|
||||
raise NotImplementedError(
|
||||
'Could not find signature for %s: <%s>' %
|
||||
(self.name, str_signature(types)))
|
||||
self._cache[types] = func
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
except MDNotImplementedError:
|
||||
funcs = self.dispatch_iter(*types)
|
||||
next(funcs) # burn first
|
||||
for func in funcs:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except MDNotImplementedError:
|
||||
pass
|
||||
raise NotImplementedError("Matching functions for "
|
||||
"%s: <%s> found, but none completed successfully"
|
||||
% (self.name, str_signature(types)))
|
||||
|
||||
def __str__(self):
|
||||
return "<dispatched %s>" % self.name
|
||||
__repr__ = __str__
|
||||
|
||||
def dispatch(self, *types):
|
||||
""" Deterimine appropriate implementation for this type signature
|
||||
|
||||
This method is internal. Users should call this object as a function.
|
||||
Implementation resolution occurs within the ``__call__`` method.
|
||||
|
||||
>>> from sympy.multipledispatch import dispatch
|
||||
>>> @dispatch(int)
|
||||
... def inc(x):
|
||||
... return x + 1
|
||||
|
||||
>>> implementation = inc.dispatch(int)
|
||||
>>> implementation(3)
|
||||
4
|
||||
|
||||
>>> print(inc.dispatch(float))
|
||||
None
|
||||
|
||||
See Also:
|
||||
``sympy.multipledispatch.conflict`` - module to determine resolution order
|
||||
"""
|
||||
|
||||
if types in self.funcs:
|
||||
return self.funcs[types]
|
||||
|
||||
try:
|
||||
return next(self.dispatch_iter(*types))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def dispatch_iter(self, *types):
|
||||
n = len(types)
|
||||
for signature in self.ordering:
|
||||
if len(signature) == n and all(map(issubclass, types, signature)):
|
||||
result = self.funcs[signature]
|
||||
yield result
|
||||
|
||||
def resolve(self, types):
|
||||
""" Deterimine appropriate implementation for this type signature
|
||||
|
||||
.. deprecated:: 0.4.4
|
||||
Use ``dispatch(*types)`` instead
|
||||
"""
|
||||
warn("resolve() is deprecated, use dispatch(*types)",
|
||||
DeprecationWarning)
|
||||
|
||||
return self.dispatch(*types)
|
||||
|
||||
def __getstate__(self):
|
||||
return {'name': self.name,
|
||||
'funcs': self.funcs}
|
||||
|
||||
def __setstate__(self, d):
|
||||
self.name = d['name']
|
||||
self.funcs = d['funcs']
|
||||
self.ordering = ordering(self.funcs)
|
||||
self._cache = {}
|
||||
|
||||
@property
|
||||
def __doc__(self):
|
||||
docs = ["Multiply dispatched method: %s" % self.name]
|
||||
|
||||
if self.doc:
|
||||
docs.append(self.doc)
|
||||
|
||||
other = []
|
||||
for sig in self.ordering[::-1]:
|
||||
func = self.funcs[sig]
|
||||
if func.__doc__:
|
||||
s = 'Inputs: <%s>\n' % str_signature(sig)
|
||||
s += '-' * len(s) + '\n'
|
||||
s += func.__doc__.strip()
|
||||
docs.append(s)
|
||||
else:
|
||||
other.append(str_signature(sig))
|
||||
|
||||
if other:
|
||||
docs.append('Other signatures:\n ' + '\n '.join(other))
|
||||
|
||||
return '\n\n'.join(docs)
|
||||
|
||||
def _help(self, *args):
|
||||
return self.dispatch(*map(type, args)).__doc__
|
||||
|
||||
def help(self, *args, **kwargs):
|
||||
""" Print docstring for the function corresponding to inputs """
|
||||
print(self._help(*args))
|
||||
|
||||
def _source(self, *args):
|
||||
func = self.dispatch(*map(type, args))
|
||||
if not func:
|
||||
raise TypeError("No function found")
|
||||
return source(func)
|
||||
|
||||
def source(self, *args, **kwargs):
|
||||
""" Print source code for the function corresponding to inputs """
|
||||
print(self._source(*args))
|
||||
|
||||
|
||||
def source(func):
|
||||
s = 'File: %s\n\n' % inspect.getsourcefile(func)
|
||||
s = s + inspect.getsource(func)
|
||||
return s
|
||||
|
||||
|
||||
class MethodDispatcher(Dispatcher):
|
||||
""" Dispatch methods based on type signature
|
||||
|
||||
See Also:
|
||||
Dispatcher
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_func_params(cls, func):
|
||||
if hasattr(inspect, "signature"):
|
||||
sig = inspect.signature(func)
|
||||
return itl.islice(sig.parameters.values(), 1, None)
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
self.obj = instance
|
||||
self.cls = owner
|
||||
return self
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
types = tuple([type(arg) for arg in args])
|
||||
func = self.dispatch(*types)
|
||||
if not func:
|
||||
raise NotImplementedError('Could not find signature for %s: <%s>' %
|
||||
(self.name, str_signature(types)))
|
||||
return func(self.obj, *args, **kwargs)
|
||||
|
||||
|
||||
def str_signature(sig):
|
||||
""" String representation of type signature
|
||||
|
||||
>>> from sympy.multipledispatch.dispatcher import str_signature
|
||||
>>> str_signature((int, float))
|
||||
'int, float'
|
||||
"""
|
||||
return ', '.join(cls.__name__ for cls in sig)
|
||||
|
||||
|
||||
def warning_text(name, amb):
|
||||
""" The text for ambiguity warnings """
|
||||
text = "\nAmbiguities exist in dispatched function %s\n\n" % (name)
|
||||
text += "The following signatures may result in ambiguous behavior:\n"
|
||||
for pair in amb:
|
||||
text += "\t" + \
|
||||
', '.join('[' + str_signature(s) + ']' for s in pair) + "\n"
|
||||
text += "\n\nConsider making the following additions:\n\n"
|
||||
text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s))
|
||||
+ ')\ndef %s(...)' % name for s in amb])
|
||||
return text
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,62 @@
|
||||
from sympy.multipledispatch.conflict import (supercedes, ordering, ambiguities,
|
||||
ambiguous, super_signature, consistent)
|
||||
|
||||
|
||||
class A: pass
|
||||
class B(A): pass
|
||||
class C: pass
|
||||
|
||||
|
||||
def test_supercedes():
|
||||
assert supercedes([B], [A])
|
||||
assert supercedes([B, A], [A, A])
|
||||
assert not supercedes([B, A], [A, B])
|
||||
assert not supercedes([A], [B])
|
||||
|
||||
|
||||
def test_consistent():
|
||||
assert consistent([A], [A])
|
||||
assert consistent([B], [B])
|
||||
assert not consistent([A], [C])
|
||||
assert consistent([A, B], [A, B])
|
||||
assert consistent([B, A], [A, B])
|
||||
assert not consistent([B, A], [B])
|
||||
assert not consistent([B, A], [B, C])
|
||||
|
||||
|
||||
def test_super_signature():
|
||||
assert super_signature([[A]]) == [A]
|
||||
assert super_signature([[A], [B]]) == [B]
|
||||
assert super_signature([[A, B], [B, A]]) == [B, B]
|
||||
assert super_signature([[A, A, B], [A, B, A], [B, A, A]]) == [B, B, B]
|
||||
|
||||
|
||||
def test_ambiguous():
|
||||
assert not ambiguous([A], [A])
|
||||
assert not ambiguous([A], [B])
|
||||
assert not ambiguous([B], [B])
|
||||
assert not ambiguous([A, B], [B, B])
|
||||
assert ambiguous([A, B], [B, A])
|
||||
|
||||
|
||||
def test_ambiguities():
|
||||
signatures = [[A], [B], [A, B], [B, A], [A, C]]
|
||||
expected = {((A, B), (B, A))}
|
||||
result = ambiguities(signatures)
|
||||
assert set(map(frozenset, expected)) == set(map(frozenset, result))
|
||||
|
||||
signatures = [[A], [B], [A, B], [B, A], [A, C], [B, B]]
|
||||
expected = set()
|
||||
result = ambiguities(signatures)
|
||||
assert set(map(frozenset, expected)) == set(map(frozenset, result))
|
||||
|
||||
|
||||
def test_ordering():
|
||||
signatures = [[A, A], [A, B], [B, A], [B, B], [A, C]]
|
||||
ord = ordering(signatures)
|
||||
assert ord[0] == (B, B) or ord[0] == (A, C)
|
||||
assert ord[-1] == (A, A) or ord[-1] == (A, C)
|
||||
|
||||
|
||||
def test_type_mro():
|
||||
assert super_signature([[object], [type]]) == [type]
|
||||
@@ -0,0 +1,213 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
from sympy.multipledispatch import dispatch
|
||||
from sympy.multipledispatch.conflict import AmbiguityWarning
|
||||
from sympy.testing.pytest import raises, warns
|
||||
from functools import partial
|
||||
|
||||
test_namespace: dict[str, Any] = {}
|
||||
|
||||
orig_dispatch = dispatch
|
||||
dispatch = partial(dispatch, namespace=test_namespace)
|
||||
|
||||
|
||||
def test_singledispatch():
|
||||
@dispatch(int)
|
||||
def f(x): # noqa:F811
|
||||
return x + 1
|
||||
|
||||
@dispatch(int)
|
||||
def g(x): # noqa:F811
|
||||
return x + 2
|
||||
|
||||
@dispatch(float) # noqa:F811
|
||||
def f(x): # noqa:F811
|
||||
return x - 1
|
||||
|
||||
assert f(1) == 2
|
||||
assert g(1) == 3
|
||||
assert f(1.0) == 0
|
||||
|
||||
assert raises(NotImplementedError, lambda: f('hello'))
|
||||
|
||||
|
||||
def test_multipledispatch():
|
||||
@dispatch(int, int)
|
||||
def f(x, y): # noqa:F811
|
||||
return x + y
|
||||
|
||||
@dispatch(float, float) # noqa:F811
|
||||
def f(x, y): # noqa:F811
|
||||
return x - y
|
||||
|
||||
assert f(1, 2) == 3
|
||||
assert f(1.0, 2.0) == -1.0
|
||||
|
||||
|
||||
class A: pass
|
||||
class B: pass
|
||||
class C(A): pass
|
||||
class D(C): pass
|
||||
class E(C): pass
|
||||
|
||||
|
||||
def test_inheritance():
|
||||
@dispatch(A)
|
||||
def f(x): # noqa:F811
|
||||
return 'a'
|
||||
|
||||
@dispatch(B) # noqa:F811
|
||||
def f(x): # noqa:F811
|
||||
return 'b'
|
||||
|
||||
assert f(A()) == 'a'
|
||||
assert f(B()) == 'b'
|
||||
assert f(C()) == 'a'
|
||||
|
||||
|
||||
def test_inheritance_and_multiple_dispatch():
|
||||
@dispatch(A, A)
|
||||
def f(x, y): # noqa:F811
|
||||
return type(x), type(y)
|
||||
|
||||
@dispatch(A, B) # noqa:F811
|
||||
def f(x, y): # noqa:F811
|
||||
return 0
|
||||
|
||||
assert f(A(), A()) == (A, A)
|
||||
assert f(A(), C()) == (A, C)
|
||||
assert f(A(), B()) == 0
|
||||
assert f(C(), B()) == 0
|
||||
assert raises(NotImplementedError, lambda: f(B(), B()))
|
||||
|
||||
|
||||
def test_competing_solutions():
|
||||
@dispatch(A)
|
||||
def h(x): # noqa:F811
|
||||
return 1
|
||||
|
||||
@dispatch(C) # noqa:F811
|
||||
def h(x): # noqa:F811
|
||||
return 2
|
||||
|
||||
assert h(D()) == 2
|
||||
|
||||
|
||||
def test_competing_multiple():
|
||||
@dispatch(A, B)
|
||||
def h(x, y): # noqa:F811
|
||||
return 1
|
||||
|
||||
@dispatch(C, B) # noqa:F811
|
||||
def h(x, y): # noqa:F811
|
||||
return 2
|
||||
|
||||
assert h(D(), B()) == 2
|
||||
|
||||
|
||||
def test_competing_ambiguous():
|
||||
test_namespace = {}
|
||||
dispatch = partial(orig_dispatch, namespace=test_namespace)
|
||||
|
||||
@dispatch(A, C)
|
||||
def f(x, y): # noqa:F811
|
||||
return 2
|
||||
|
||||
with warns(AmbiguityWarning, test_stacklevel=False):
|
||||
@dispatch(C, A) # noqa:F811
|
||||
def f(x, y): # noqa:F811
|
||||
return 2
|
||||
|
||||
assert f(A(), C()) == f(C(), A()) == 2
|
||||
# assert raises(Warning, lambda : f(C(), C()))
|
||||
|
||||
|
||||
def test_caching_correct_behavior():
|
||||
@dispatch(A)
|
||||
def f(x): # noqa:F811
|
||||
return 1
|
||||
|
||||
assert f(C()) == 1
|
||||
|
||||
@dispatch(C)
|
||||
def f(x): # noqa:F811
|
||||
return 2
|
||||
|
||||
assert f(C()) == 2
|
||||
|
||||
|
||||
def test_union_types():
|
||||
@dispatch((A, C))
|
||||
def f(x): # noqa:F811
|
||||
return 1
|
||||
|
||||
assert f(A()) == 1
|
||||
assert f(C()) == 1
|
||||
|
||||
|
||||
def test_namespaces():
|
||||
ns1 = {}
|
||||
ns2 = {}
|
||||
|
||||
def foo(x):
|
||||
return 1
|
||||
foo1 = orig_dispatch(int, namespace=ns1)(foo)
|
||||
|
||||
def foo(x):
|
||||
return 2
|
||||
foo2 = orig_dispatch(int, namespace=ns2)(foo)
|
||||
|
||||
assert foo1(0) == 1
|
||||
assert foo2(0) == 2
|
||||
|
||||
|
||||
"""
|
||||
Fails
|
||||
def test_dispatch_on_dispatch():
|
||||
@dispatch(A)
|
||||
@dispatch(C)
|
||||
def q(x): # noqa:F811
|
||||
return 1
|
||||
|
||||
assert q(A()) == 1
|
||||
assert q(C()) == 1
|
||||
"""
|
||||
|
||||
|
||||
def test_methods():
|
||||
class Foo:
|
||||
@dispatch(float)
|
||||
def f(self, x): # noqa:F811
|
||||
return x - 1
|
||||
|
||||
@dispatch(int) # noqa:F811
|
||||
def f(self, x): # noqa:F811
|
||||
return x + 1
|
||||
|
||||
@dispatch(int)
|
||||
def g(self, x): # noqa:F811
|
||||
return x + 3
|
||||
|
||||
|
||||
foo = Foo()
|
||||
assert foo.f(1) == 2
|
||||
assert foo.f(1.0) == 0.0
|
||||
assert foo.g(1) == 4
|
||||
|
||||
|
||||
def test_methods_multiple_dispatch():
|
||||
class Foo:
|
||||
@dispatch(A, A)
|
||||
def f(x, y): # noqa:F811
|
||||
return 1
|
||||
|
||||
@dispatch(A, C) # noqa:F811
|
||||
def f(x, y): # noqa:F811
|
||||
return 2
|
||||
|
||||
|
||||
foo = Foo()
|
||||
assert foo.f(A(), A()) == 1
|
||||
assert foo.f(A(), C()) == 2
|
||||
assert foo.f(C(), C()) == 2
|
||||
@@ -0,0 +1,284 @@
|
||||
from sympy.multipledispatch.dispatcher import (Dispatcher, MDNotImplementedError,
|
||||
MethodDispatcher, halt_ordering,
|
||||
restart_ordering,
|
||||
ambiguity_register_error_ignore_dup)
|
||||
from sympy.testing.pytest import raises, warns
|
||||
|
||||
|
||||
def identity(x):
|
||||
return x
|
||||
|
||||
|
||||
def inc(x):
|
||||
return x + 1
|
||||
|
||||
|
||||
def dec(x):
|
||||
return x - 1
|
||||
|
||||
|
||||
def test_dispatcher():
|
||||
f = Dispatcher('f')
|
||||
f.add((int,), inc)
|
||||
f.add((float,), dec)
|
||||
|
||||
with warns(DeprecationWarning, test_stacklevel=False):
|
||||
assert f.resolve((int,)) == inc
|
||||
assert f.dispatch(int) is inc
|
||||
|
||||
assert f(1) == 2
|
||||
assert f(1.0) == 0.0
|
||||
|
||||
|
||||
def test_union_types():
|
||||
f = Dispatcher('f')
|
||||
f.register((int, float))(inc)
|
||||
|
||||
assert f(1) == 2
|
||||
assert f(1.0) == 2.0
|
||||
|
||||
|
||||
def test_dispatcher_as_decorator():
|
||||
f = Dispatcher('f')
|
||||
|
||||
@f.register(int)
|
||||
def inc(x): # noqa:F811
|
||||
return x + 1
|
||||
|
||||
@f.register(float) # noqa:F811
|
||||
def inc(x): # noqa:F811
|
||||
return x - 1
|
||||
|
||||
assert f(1) == 2
|
||||
assert f(1.0) == 0.0
|
||||
|
||||
|
||||
def test_register_instance_method():
|
||||
|
||||
class Test:
|
||||
__init__ = MethodDispatcher('f')
|
||||
|
||||
@__init__.register(list)
|
||||
def _init_list(self, data):
|
||||
self.data = data
|
||||
|
||||
@__init__.register(object)
|
||||
def _init_obj(self, datum):
|
||||
self.data = [datum]
|
||||
|
||||
a = Test(3)
|
||||
b = Test([3])
|
||||
assert a.data == b.data
|
||||
|
||||
|
||||
def test_on_ambiguity():
|
||||
f = Dispatcher('f')
|
||||
|
||||
def identity(x): return x
|
||||
|
||||
ambiguities = [False]
|
||||
|
||||
def on_ambiguity(dispatcher, amb):
|
||||
ambiguities[0] = True
|
||||
|
||||
f.add((object, object), identity, on_ambiguity=on_ambiguity)
|
||||
assert not ambiguities[0]
|
||||
f.add((object, float), identity, on_ambiguity=on_ambiguity)
|
||||
assert not ambiguities[0]
|
||||
f.add((float, object), identity, on_ambiguity=on_ambiguity)
|
||||
assert ambiguities[0]
|
||||
|
||||
|
||||
def test_raise_error_on_non_class():
|
||||
f = Dispatcher('f')
|
||||
assert raises(TypeError, lambda: f.add((1,), inc))
|
||||
|
||||
|
||||
def test_docstring():
|
||||
|
||||
def one(x, y):
|
||||
""" Docstring number one """
|
||||
return x + y
|
||||
|
||||
def two(x, y):
|
||||
""" Docstring number two """
|
||||
return x + y
|
||||
|
||||
def three(x, y):
|
||||
return x + y
|
||||
|
||||
master_doc = 'Doc of the multimethod itself'
|
||||
|
||||
f = Dispatcher('f', doc=master_doc)
|
||||
f.add((object, object), one)
|
||||
f.add((int, int), two)
|
||||
f.add((float, float), three)
|
||||
|
||||
assert one.__doc__.strip() in f.__doc__
|
||||
assert two.__doc__.strip() in f.__doc__
|
||||
assert f.__doc__.find(one.__doc__.strip()) < \
|
||||
f.__doc__.find(two.__doc__.strip())
|
||||
assert 'object, object' in f.__doc__
|
||||
assert master_doc in f.__doc__
|
||||
|
||||
|
||||
def test_help():
|
||||
def one(x, y):
|
||||
""" Docstring number one """
|
||||
return x + y
|
||||
|
||||
def two(x, y):
|
||||
""" Docstring number two """
|
||||
return x + y
|
||||
|
||||
def three(x, y):
|
||||
""" Docstring number three """
|
||||
return x + y
|
||||
|
||||
master_doc = 'Doc of the multimethod itself'
|
||||
|
||||
f = Dispatcher('f', doc=master_doc)
|
||||
f.add((object, object), one)
|
||||
f.add((int, int), two)
|
||||
f.add((float, float), three)
|
||||
|
||||
assert f._help(1, 1) == two.__doc__
|
||||
assert f._help(1.0, 2.0) == three.__doc__
|
||||
|
||||
|
||||
def test_source():
|
||||
def one(x, y):
|
||||
""" Docstring number one """
|
||||
return x + y
|
||||
|
||||
def two(x, y):
|
||||
""" Docstring number two """
|
||||
return x - y
|
||||
|
||||
master_doc = 'Doc of the multimethod itself'
|
||||
|
||||
f = Dispatcher('f', doc=master_doc)
|
||||
f.add((int, int), one)
|
||||
f.add((float, float), two)
|
||||
|
||||
assert 'x + y' in f._source(1, 1)
|
||||
assert 'x - y' in f._source(1.0, 1.0)
|
||||
|
||||
|
||||
def test_source_raises_on_missing_function():
|
||||
f = Dispatcher('f')
|
||||
|
||||
assert raises(TypeError, lambda: f.source(1))
|
||||
|
||||
|
||||
def test_halt_method_resolution():
|
||||
g = [0]
|
||||
|
||||
def on_ambiguity(a, b):
|
||||
g[0] += 1
|
||||
|
||||
f = Dispatcher('f')
|
||||
|
||||
halt_ordering()
|
||||
|
||||
def func(*args):
|
||||
pass
|
||||
|
||||
f.add((int, object), func)
|
||||
f.add((object, int), func)
|
||||
|
||||
assert g == [0]
|
||||
|
||||
restart_ordering(on_ambiguity=on_ambiguity)
|
||||
|
||||
assert g == [1]
|
||||
|
||||
assert set(f.ordering) == {(int, object), (object, int)}
|
||||
|
||||
|
||||
def test_no_implementations():
|
||||
f = Dispatcher('f')
|
||||
assert raises(NotImplementedError, lambda: f('hello'))
|
||||
|
||||
|
||||
def test_register_stacking():
|
||||
f = Dispatcher('f')
|
||||
|
||||
@f.register(list)
|
||||
@f.register(tuple)
|
||||
def rev(x):
|
||||
return x[::-1]
|
||||
|
||||
assert f((1, 2, 3)) == (3, 2, 1)
|
||||
assert f([1, 2, 3]) == [3, 2, 1]
|
||||
|
||||
assert raises(NotImplementedError, lambda: f('hello'))
|
||||
assert rev('hello') == 'olleh'
|
||||
|
||||
|
||||
def test_dispatch_method():
|
||||
f = Dispatcher('f')
|
||||
|
||||
@f.register(list)
|
||||
def rev(x):
|
||||
return x[::-1]
|
||||
|
||||
@f.register(int, int)
|
||||
def add(x, y):
|
||||
return x + y
|
||||
|
||||
class MyList(list):
|
||||
pass
|
||||
|
||||
assert f.dispatch(list) is rev
|
||||
assert f.dispatch(MyList) is rev
|
||||
assert f.dispatch(int, int) is add
|
||||
|
||||
|
||||
def test_not_implemented():
|
||||
f = Dispatcher('f')
|
||||
|
||||
@f.register(object)
|
||||
def _(x):
|
||||
return 'default'
|
||||
|
||||
@f.register(int)
|
||||
def _(x):
|
||||
if x % 2 == 0:
|
||||
return 'even'
|
||||
else:
|
||||
raise MDNotImplementedError()
|
||||
|
||||
assert f('hello') == 'default' # default behavior
|
||||
assert f(2) == 'even' # specialized behavior
|
||||
assert f(3) == 'default' # fall bac to default behavior
|
||||
assert raises(NotImplementedError, lambda: f(1, 2))
|
||||
|
||||
|
||||
def test_not_implemented_error():
|
||||
f = Dispatcher('f')
|
||||
|
||||
@f.register(float)
|
||||
def _(a):
|
||||
raise MDNotImplementedError()
|
||||
|
||||
assert raises(NotImplementedError, lambda: f(1.0))
|
||||
|
||||
def test_ambiguity_register_error_ignore_dup():
|
||||
f = Dispatcher('f')
|
||||
|
||||
class A:
|
||||
pass
|
||||
class B(A):
|
||||
pass
|
||||
class C(A):
|
||||
pass
|
||||
|
||||
# suppress warning for registering ambiguous signal
|
||||
f.add((A, B), lambda x,y: None, ambiguity_register_error_ignore_dup)
|
||||
f.add((B, A), lambda x,y: None, ambiguity_register_error_ignore_dup)
|
||||
f.add((A, C), lambda x,y: None, ambiguity_register_error_ignore_dup)
|
||||
f.add((C, A), lambda x,y: None, ambiguity_register_error_ignore_dup)
|
||||
|
||||
# raises error if ambiguous signal is passed
|
||||
assert raises(NotImplementedError, lambda: f(B(), C()))
|
||||
@@ -0,0 +1,105 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def expand_tuples(L):
|
||||
"""
|
||||
>>> from sympy.multipledispatch.utils import expand_tuples
|
||||
>>> expand_tuples([1, (2, 3)])
|
||||
[(1, 2), (1, 3)]
|
||||
|
||||
>>> expand_tuples([1, 2])
|
||||
[(1, 2)]
|
||||
"""
|
||||
if not L:
|
||||
return [()]
|
||||
elif not isinstance(L[0], tuple):
|
||||
rest = expand_tuples(L[1:])
|
||||
return [(L[0],) + t for t in rest]
|
||||
else:
|
||||
rest = expand_tuples(L[1:])
|
||||
return [(item,) + t for t in rest for item in L[0]]
|
||||
|
||||
|
||||
# Taken from theano/theano/gof/sched.py
|
||||
# Avoids licensing issues because this was written by Matthew Rocklin
|
||||
def _toposort(edges):
|
||||
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
|
||||
|
||||
inputs:
|
||||
edges - a dict of the form {a: {b, c}} where b and c depend on a
|
||||
outputs:
|
||||
L - an ordered list of nodes that satisfy the dependencies of edges
|
||||
|
||||
>>> from sympy.multipledispatch.utils import _toposort
|
||||
>>> _toposort({1: (2, 3), 2: (3, )})
|
||||
[1, 2, 3]
|
||||
|
||||
Closely follows the wikipedia page [2]
|
||||
|
||||
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
|
||||
Communications of the ACM
|
||||
[2] https://en.wikipedia.org/wiki/Toposort#Algorithms
|
||||
"""
|
||||
incoming_edges = reverse_dict(edges)
|
||||
incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
|
||||
S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
|
||||
L = []
|
||||
|
||||
while S:
|
||||
n, _ = S.popitem()
|
||||
L.append(n)
|
||||
for m in edges.get(n, ()):
|
||||
assert n in incoming_edges[m]
|
||||
incoming_edges[m].remove(n)
|
||||
if not incoming_edges[m]:
|
||||
S[m] = None
|
||||
if any(incoming_edges.get(v, None) for v in edges):
|
||||
raise ValueError("Input has cycles")
|
||||
return L
|
||||
|
||||
|
||||
def reverse_dict(d):
|
||||
"""Reverses direction of dependence dict
|
||||
|
||||
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
|
||||
>>> reverse_dict(d) # doctest: +SKIP
|
||||
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
|
||||
|
||||
:note: dict order are not deterministic. As we iterate on the
|
||||
input dict, it make the output of this function depend on the
|
||||
dict order. So this function output order should be considered
|
||||
as undeterministic.
|
||||
|
||||
"""
|
||||
result = {}
|
||||
for key in d:
|
||||
for val in d[key]:
|
||||
result[val] = result.get(val, ()) + (key, )
|
||||
return result
|
||||
|
||||
|
||||
# Taken from toolz
|
||||
# Avoids licensing issues because this version was authored by Matthew Rocklin
|
||||
def groupby(func, seq):
|
||||
""" Group a collection by a key function
|
||||
|
||||
>>> from sympy.multipledispatch.utils import groupby
|
||||
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
|
||||
>>> groupby(len, names) # doctest: +SKIP
|
||||
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
|
||||
|
||||
>>> iseven = lambda x: x % 2 == 0
|
||||
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
|
||||
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
|
||||
|
||||
See Also:
|
||||
``countby``
|
||||
"""
|
||||
|
||||
d = {}
|
||||
for item in seq:
|
||||
key = func(item)
|
||||
if key not in d:
|
||||
d[key] = []
|
||||
d[key].append(item)
|
||||
return d
|
||||
Reference in New Issue
Block a user