Skip to content

Commit

Permalink
WIP: Lazy import numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhollas committed Jul 8, 2023
1 parent 1433071 commit 872856f
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 16 deletions.
10 changes: 8 additions & 2 deletions aiida/orm/nodes/data/array/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import json
from string import Template

import numpy

from aiida.common.exceptions import ValidationError
from aiida.common.utils import join_labels, prettify_labels

Expand Down Expand Up @@ -74,6 +72,7 @@ def find_bandgap(bandsdata, number_electrons=None, fermi_energy=None):
"""

# pylint: disable=too-many-return-statements,too-many-branches,too-many-statements,no-else-return
import numpy

def nint(num):
"""
Expand Down Expand Up @@ -259,6 +258,7 @@ def _validate_bands_occupations(self, bands, occupations=None, labels=None):
correspond to the number of kpoints.
"""
# pylint: disable=too-many-branches
import numpy
try:
kpoints = self.get_kpoints()
except AttributeError:
Expand Down Expand Up @@ -392,6 +392,7 @@ def get_bands(self, also_occupations=False, also_labels=False):
:param also_occupations: if True, returns also the occupations array.
Default = False
"""
import numpy
try:
bands = numpy.array(self.get_array('bands'))
except KeyError:
Expand Down Expand Up @@ -440,6 +441,7 @@ def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None,
"""
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
# load the x and y's of the graph
import numpy
stored_bands = self.get_bands()
if len(stored_bands.shape) == 2:
bands = stored_bands
Expand Down Expand Up @@ -656,6 +658,7 @@ def _prepare_dat_blocks(self, main_file_name='', comments=True): # pylint: disa
:param comments: if True, print comments (if it makes sense for the given
format)
"""
import numpy
plot_info = self._get_bandplot_data(cartesian=True, prettify_format=None, join_symbol='|')

bands = plot_info['y']
Expand Down Expand Up @@ -714,6 +717,7 @@ def _matplotlib_get_dict(
accepted, see internal variable 'valid_additional_keywords
"""
# pylint: disable=too-many-arguments,too-many-locals
import numpy

# Only these keywords are accepted in kwargs, and then set into the json
valid_additional_keywords = [
Expand Down Expand Up @@ -1128,6 +1132,8 @@ def _prepare_agr(

import math

import numpy

# load the x and y of every set
if color_number > MAX_NUM_AGR_COLORS:
raise ValueError(f'Color number is too high (should be less than {MAX_NUM_AGR_COLORS})')
Expand Down
11 changes: 9 additions & 2 deletions aiida/orm/nodes/data/array/kpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
lists and meshes of k-points (i.e., points in the reciprocal space of a
periodic crystal structure).
"""
import numpy

from .array import ArrayData

__all__ = ('KpointsData',)
Expand Down Expand Up @@ -59,6 +57,7 @@ def cell(self):
The crystal unit cell. Rows are the crystal vectors in Angstroms.
:return: a 3x3 numpy.array
"""
import numpy
return numpy.array(self.base.attributes.get('cell'))

@cell.setter
Expand Down Expand Up @@ -164,6 +163,7 @@ def _change_reference(self, kpoints, to_cartesian=True):
:param kpoints: a list of (3) point coordinates
:return kpoints: a list of (3) point coordinates in the new reference
"""
import numpy
if not isinstance(kpoints, numpy.ndarray):
raise ValueError('kpoints must be a numpy.array for method change_reference()')

Expand Down Expand Up @@ -224,6 +224,7 @@ def reciprocal_cell(self):
:returns: reciprocal cell in units of 1/Angstrom with cell vectors stored as rows.
Use e.g. reciprocal_cell[0] to access the first reciprocal cell vector.
"""
import numpy
the_cell = numpy.array(self.cell)
reciprocal_cell = 2. * numpy.pi * numpy.linalg.inv(the_cell).transpose()
return reciprocal_cell
Expand Down Expand Up @@ -282,6 +283,7 @@ def get_kpoints_mesh(self, print_list=False):
:return kpoints: (if print_list = True) an explicit list of kpoints coordinates,
similar to what returned by get_kpoints()
"""
import numpy
mesh = self.base.attributes.get('mesh')
offset = self.base.attributes.get('offset')

Expand Down Expand Up @@ -315,6 +317,7 @@ def set_kpoints_mesh_from_density(self, distance, offset=None, force_parity=Fals
:note: a cell should be defined first.
:note: the number of kpoints along non-periodic axes is always 1.
"""
import numpy
if offset is None:
offset = [0., 0., 0.]

Expand Down Expand Up @@ -352,6 +355,7 @@ def _validate_kpoints_weights(self, kpoints, weights):
Kpoints and weights must be convertible respectively to an array of
N x dimension and N floats
"""
import numpy
kpoints = numpy.array(kpoints)

# I cannot just use `if not kpoints` because it's a numpy array and
Expand Down Expand Up @@ -425,6 +429,8 @@ def set_kpoints(self, kpoints, cartesian=False, labels=None, weights=None, fill_
non-periodic dimensions (indicated by False in self.pbc), or list of
values for each of the non-periodic dimensions.
"""
import numpy

from aiida.common.exceptions import ModificationNotAllowed

# check that it is a 'dim'x #kpoints dimensional array
Expand Down Expand Up @@ -486,6 +492,7 @@ def get_kpoints(self, also_weights=False, cartesian=False):
:param cartesian: if True, returns points in cartesian coordinates,
otherwise, returns in crystal coordinates. Default = False.
"""
import numpy
try:
kpoints = numpy.array(self.get_array('kpoints'))
except KeyError:
Expand Down
12 changes: 6 additions & 6 deletions aiida/orm/nodes/data/array/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
"""Data plugin to represet arrays of projected wavefunction components."""
import copy

import numpy as np

from aiida.common import exceptions
from aiida.plugins import OrbitalFactory

Expand Down Expand Up @@ -43,13 +41,14 @@ def _check_projections_bands(self, projection_array):
:raise: AttributeError if input_array is not of same shape as
dos_energy
"""
import numpy
try:
shape_bands = np.shape(self.get_reference_bandsdata())
shape_bands = numpy.shape(self.get_reference_bandsdata())
except AttributeError:
raise exceptions.ValidationError('Bands must be set first, then projwfc')
# The [0:2] is so that each array, and not collection of arrays
# is used to make the comparison
if np.shape(projection_array) != shape_bands:
if numpy.shape(projection_array) != shape_bands:
raise AttributeError('These arrays are not the same shape as the bands')

def set_reference_bandsdata(self, value):
Expand Down Expand Up @@ -197,6 +196,7 @@ def set_projectiondata(
"""

# pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
import numpy

def single_to_list(item):
"""
Expand All @@ -215,11 +215,11 @@ def single_to_list(item):
def array_list_checker(array_list, array_name, orb_length):
"""
Does basic checks over everything in the array_list. Makes sure that
all the arrays are np.ndarray floats, that the length is same as
all the arrays are numpy.ndarray floats, that the length is same as
required_length, raises exception using array_name if there is
a failure
"""
if not all(isinstance(_, np.ndarray) for _ in array_list):
if not all(isinstance(_, numpy.ndarray) for _ in array_list):
raise exceptions.ValidationError(f'{array_name} was not composed entirely of ndarrays')
if len(array_list) != orb_length:
raise exceptions.ValidationError(f'{array_name} did not have the same length as the list of orbitals')
Expand Down
9 changes: 5 additions & 4 deletions aiida/orm/nodes/data/array/xy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
collections of y-arrays bound to a single x-array, and the methods to operate
on them.
"""
import numpy as np

from aiida.common.exceptions import NotExistent

from .array import ArrayData
Expand Down Expand Up @@ -49,10 +47,11 @@ def _arrayandname_validator(array, name, units):
Validates that the array is an numpy.ndarray and that the name is
of type str. Raises TypeError or ValueError if this not the case.
"""
import numpy
if not isinstance(name, str):
raise TypeError('The name must always be a str.')

if not isinstance(array, np.ndarray):
if not isinstance(array, numpy.ndarray):
raise TypeError('The input array must always be a numpy array')
try:
array.astype(float)
Expand Down Expand Up @@ -83,6 +82,8 @@ def set_y(self, y_arrays, y_names, y_units):
:param y_names: A list of strings giving the names of the y_arrays
:param y_units: A list of strings giving the units of the y_arrays
"""
import numpy as np

# for the case of single name, array, tag input converts to a list
y_arrays = check_convert_single_to_tuple(y_arrays)
y_names = check_convert_single_to_tuple(y_names)
Expand All @@ -102,7 +103,7 @@ def set_y(self, y_arrays, y_names, y_units):
# validate each of the y_arrays
for num, (y_array, y_name, y_unit) in enumerate(zip(y_arrays, y_names, y_units)):
self._arrayandname_validator(y_array, y_name, y_unit)
if np.shape(y_array) != np.shape(x_array):
if numpy.shape(y_array) != numpy.shape(x_array):
raise ValueError(f'y_array {y_name} did not have the same shape has the x_array!')
self.set_array(f'y_array_{num}', y_array)

Expand Down
4 changes: 2 additions & 2 deletions aiida/orm/nodes/data/bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
###########################################################################
"""`Data` sub class to represent a boolean value."""

import numpy
#import numpy

from .base import BaseType, to_aiida_type

Expand All @@ -29,7 +29,7 @@ def __bool__(self):


@to_aiida_type.register(bool)
@to_aiida_type.register(numpy.bool_)
#@to_aiida_type.register(numpy.bool_)
def _(value):
return Bool(value)

Expand Down

0 comments on commit 872856f

Please sign in to comment.