Skip to content

Commit

Permalink
Merge pull request #597 from greschd/fix_numeric_type_arithmetics
Browse files Browse the repository at this point in the history
Fix numeric type arithmetics
  • Loading branch information
giovannipizzi authored Jun 21, 2017
2 parents 0e571ef + 292fb55 commit b47154f
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 68 deletions.
13 changes: 13 additions & 0 deletions aiida/backends/tests/base_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# For further information please visit http://www.aiida.net #
###########################################################################
import unittest
import operator

from aiida.backends.testbase import AiidaTestCase
from aiida.common.exceptions import ModificationNotAllowed
Expand Down Expand Up @@ -189,3 +190,15 @@ def test_power(self):

res = a ** b
self.assertEqual(res.value, 16.)

class TestFloatIntMix(AiidaTestCase):
def test_operator(self):
a = Float(2.2)
b = Int(3)

for op in [operator.add, operator.mul, operator.pow, operator.lt, operator.le, operator.gt, operator.ge, operator.iadd, operator.imul]:
for x, y in [(a, b), (b, a)]:
c = op(x, y)
c_val = op(x.value, y.value)
self.assertEqual(c._type, type(c_val))
self.assertEqual(c, op(x.value, y.value))
134 changes: 70 additions & 64 deletions aiida/orm/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@
# For further information please visit http://www.aiida.net #
###########################################################################
from abc import ABCMeta
import numbers
import collections
from aiida.orm import Data

try:
from functools import singledispatch
except ImportError:
from singledispatch import singledispatch
from past.builtins import basestring
import numpy as np

from aiida.orm import Data

class BaseType(Data):
"""
Expand Down Expand Up @@ -89,95 +96,70 @@ def _create_init_args(self, *args, **kwargs):

return kwargs

def _left_operator(func):
def inner(self, other):
l = self.value
if isinstance(other, NumericType):
r = other.value
else:
r = other
return to_aiida_type(func(l, r))
return inner

def _right_operator(func):
def inner(self, other):
assert not isinstance(other, NumericType)
return to_aiida_type(func(self.value, other))
return inner

class NumericType(BaseType):
"""
Specific subclass of :py:class:`BaseType` to store numbers,
overloading common operators (``+``, ``*``, ...)
"""
@_left_operator
def __add__(self, other):
if isinstance(other, NumericType):
return self.new(self.value + other.value)
else:
return self.new(self.value + other)

def __iadd__(self, other):
assert not self.is_stored
if isinstance(other, NumericType):
self.value += other.value
else:
self.value += other
return self
return self + other

@_right_operator
def __radd__(self, other):
assert not isinstance(other, NumericType)
return self.new(other + self.value)
return other + self

@_left_operator
def __sub__(self, other):
if isinstance(other, NumericType):
return self.new(self.value - other.value)
else:
return self.new(self.value - other)

def __isub__(self, other):
assert not self.is_stored
if isinstance(other, NumericType):
self.value -= other.value
else:
self.value -= other
return self
return self - other

@_right_operator
def __rsub__(self, other):
assert not isinstance(other, NumericType)
return self.new(other - self.value)
return other - self

@_left_operator
def __mul__(self, other):
if isinstance(other, NumericType):
return self.new(self.value * other.value)
else:
return self.new(self.value * other)

def __imul__(self, other):
assert not self.is_stored
if isinstance(other, NumericType):
self.value *= other.value
else:
self.value *= other
return self
return self * other

@_right_operator
def __rmul__(self, other):
assert not isinstance(other, NumericType)
return self.new(other * self.value)
return other * self

def __pow__(self, power, modulo=None):
if isinstance(power, NumericType):
return self.new(self.value ** power.value)
else:
return self.new(self.value ** power)
@_left_operator
def __pow__(self, power):
return self ** power

@_left_operator
def __lt__(self, other):
if isinstance(other, NumericType):
return self.value < other.value
else:
return self.value < other
return self < other

@_left_operator
def __le__(self, other):
if isinstance(other, NumericType):
return self.value <= other.value
else:
return self.value <= other
return self <= other

@_left_operator
def __gt__(self, other):
if isinstance(other, NumericType):
return self.value > other.value
else:
return self.value > other
return self > other

@_left_operator
def __ge__(self, other):
if isinstance(other, NumericType):
return self.value >= other.value
else:
return self.value >= other
return self >= other

def __float__(self):
return float(self.value)
Expand Down Expand Up @@ -350,3 +332,27 @@ def get_false_node():
"""
FALSE = Bool(typevalue=(bool, False))
return FALSE

@singledispatch
def to_aiida_type(value):
"""
Turns basic Python types (str, int, float, bool) into the corresponding AiiDA types.
"""
raise TypeError("Cannot convert value of type {} to AiiDA type.".format(type(value)))

@to_aiida_type.register(basestring)
def _(value):
return Str(value)

@to_aiida_type.register(numbers.Integral)
def _(value):
return Int(value)

@to_aiida_type.register(numbers.Real)
def _(value):
return Float(value)

@to_aiida_type.register(bool)
@to_aiida_type.register(np.bool_)
def _(value):
return Bool(value)
10 changes: 6 additions & 4 deletions setup_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
'tzlocal==1.3',
'pytz==2014.10',
'six==1.10',
'future',
'singledispatch >= 3.4.0.0',
# We need for the time being to stay with an old version
# of celery, including the versions of the AMQP libraries below,
# because the support for a SQLA broker has been dropped in later
Expand Down Expand Up @@ -118,12 +120,12 @@
]
}

# There are a number of optional dependencies that are not
# listed even as optional dependencies as they are quite
# There are a number of optional dependencies that are not
# listed even as optional dependencies as they are quite
# cumbersome to install and there is a risk that a user, wanting
# to install all dependencies (including optional ones)
# to install all dependencies (including optional ones)
# does not manage and thinks it's an AiiDA problem.
#
#
# These include:
# - mayavi==4.5.0
# plotting package, requires to have the vtk code installed first;
Expand Down

0 comments on commit b47154f

Please sign in to comment.