From 825babe785bbdfafbae8443dc9bdf22270448c4f Mon Sep 17 00:00:00 2001 From: Gabriele Cesa Date: Wed, 30 Jun 2021 12:44:41 +0200 Subject: [PATCH 1/2] optional caching of weights' variances for He init --- e2cnn/nn/init.py | 59 ++++++++++++---- e2cnn/nn/modules/r2_conv/basisexpansion.py | 6 ++ .../modules/r2_conv/basisexpansion_blocks.py | 44 +++++++++++- .../r2_conv/basisexpansion_singleblock.py | 20 +++++- test/nn/test_he_init.py | 68 +++++++++++++++++++ 5 files changed, 181 insertions(+), 16 deletions(-) diff --git a/e2cnn/nn/init.py b/e2cnn/nn/init.py index 4a45a7f0..fb19f700 100644 --- a/e2cnn/nn/init.py +++ b/e2cnn/nn/init.py @@ -11,26 +11,24 @@ __all__ = ["generalized_he_init", "deltaorthonormal_init"] -def generalized_he_init(tensor: torch.Tensor, basisexpansion: BasisExpansion): +def _generalized_he_init_variances(basisexpansion: BasisExpansion): r""" - - Initialize the weights of a convolutional layer with a generalized He's weight initialization method. - + + Compute the variances of the weights of a convolutional layer with a generalized He's weight initialization method. + Args: - tensor (torch.Tensor): the tensor containing the weights basisexpansion (BasisExpansion): the basis expansion method """ - # Initialization - assert tensor.shape == (basisexpansion.dimension(), ) - - vars = torch.ones_like(tensor) + vars = torch.ones((basisexpansion.dimension(),)) inputs_count = defaultdict(lambda: set()) basis_count = defaultdict(int) - for attr in basisexpansion.get_basis_info(): + basis_info = list(basisexpansion.get_basis_info()) + + for attr in basis_info: i, o = attr["in_irreps_position"], attr["out_irreps_position"] in_irrep, out_irrep = attr["in_irrep"], attr["out_irrep"] inputs_count[o].add(in_irrep) @@ -39,13 +37,48 @@ def generalized_he_init(tensor: torch.Tensor, basisexpansion: BasisExpansion): for o in inputs_count.keys(): inputs_count[o] = len(inputs_count[o]) - for w, attr in enumerate(basisexpansion.get_basis_info()): + for w, attr in enumerate(basis_info): i, o = attr["in_irreps_position"], attr["out_irreps_position"] in_irrep, out_irrep = attr["in_irrep"], attr["out_irrep"] vars[w] = 1. / math.sqrt(inputs_count[o] * basis_count[(in_irrep, o)]) - # for i, o in basis_count.keys(): - # print(i, o, inputs_count[o], basis_count[(i, o)]) + return vars + + +cached_he_vars = {} + + +def generalized_he_init(tensor: torch.Tensor, basisexpansion: BasisExpansion, cache: bool = False): + r""" + + Initialize the weights of a convolutional layer with a generalized He's weight initialization method. + + Because the computation of the variances can be expensive, to save time on consecutive runs of the same model, + it is possible to cache the tensor containing the variance of each weight, for a specific ```basisexpansion```. + This can be useful if a network contains multiple convolution layers of the same kind (same input and output types, + same kernel size, etc.) or if one needs to train the same network from scratch multiple times (e.g. to perform + hyper-parameter search over learning rate or to repeat an experiment with different random seeds). + + .. note :: + The variance tensor is cached in memory and therefore is only available to the current process. + + Args: + tensor (torch.Tensor): the tensor containing the weights + basisexpansion (BasisExpansion): the basis expansion method + cache (bool, optional): cache the variance tensor. 
By default, ```cache=False``` + + """ + # Initialization + + assert tensor.shape == (basisexpansion.dimension(),) + + if cache and basisexpansion not in cached_he_vars: + cached_he_vars[basisexpansion] = _generalized_he_init_variances(basisexpansion) + + if cache: + vars = cached_he_vars[basisexpansion] + else: + vars = _generalized_he_init_variances(basisexpansion) tensor[:] = vars * torch.randn_like(tensor) diff --git a/e2cnn/nn/modules/r2_conv/basisexpansion.py b/e2cnn/nn/modules/r2_conv/basisexpansion.py index a9cefce0..be6442df 100644 --- a/e2cnn/nn/modules/r2_conv/basisexpansion.py +++ b/e2cnn/nn/modules/r2_conv/basisexpansion.py @@ -80,4 +80,10 @@ def dimension(self) -> int: """ pass + @abstractmethod + def __hash__(self): + raise NotImplementedError() + @abstractmethod + def __eq__(self, other): + raise NotImplementedError() diff --git a/e2cnn/nn/modules/r2_conv/basisexpansion_blocks.py b/e2cnn/nn/modules/r2_conv/basisexpansion_blocks.py index 07f30c45..c7e1ec59 100644 --- a/e2cnn/nn/modules/r2_conv/basisexpansion_blocks.py +++ b/e2cnn/nn/modules/r2_conv/basisexpansion_blocks.py @@ -165,7 +165,7 @@ def __init__(self, # increment the position counter last_weight_position += total_weights - + def get_basis_names(self) -> List[str]: return self._basis_to_ids @@ -364,6 +364,48 @@ def forward(self, weights: torch.Tensor) -> torch.Tensor: # return the new filter return _filter + def __hash__(self): + + _hash = 0 + for io in self._representations_pairs: + n_pairs = self._in_count[io[0]] * self._out_count[io[1]] + _hash += hash(getattr(self, f"block_expansion_{io}")) * n_pairs + + return _hash + + def __eq__(self, other): + if not isinstance(other, BlocksBasisExpansion): + return False + + if self.dimension() != other.dimension(): + return False + + if self._representations_pairs != other._representations_pairs: + return False + + for io in self._representations_pairs: + if self._contiguous[io] != other._contiguous[io]: + return False + + if self._weights_ranges[io] != other._weights_ranges[io]: + return False + + if self._contiguous[io]: + if getattr(self, f"in_indices_{io}") != getattr(other, f"in_indices_{io}"): + return False + if getattr(self, f"out_indices_{io}") != getattr(other, f"out_indices_{io}"): + return False + else: + if torch.any(getattr(self, f"in_indices_{io}") != getattr(other, f"in_indices_{io}")): + return False + if torch.any(getattr(self, f"out_indices_{io}") != getattr(other, f"out_indices_{io}")): + return False + + if getattr(self, f"block_expansion_{io}") != getattr(other, f"block_expansion_{io}"): + return False + + return True + def _retrieve_indices(type: FieldType): fiber_position = 0 diff --git a/e2cnn/nn/modules/r2_conv/basisexpansion_singleblock.py b/e2cnn/nn/modules/r2_conv/basisexpansion_singleblock.py index dfa51e69..c0bab284 100644 --- a/e2cnn/nn/modules/r2_conv/basisexpansion_singleblock.py +++ b/e2cnn/nn/modules/r2_conv/basisexpansion_singleblock.py @@ -32,6 +32,8 @@ def __init__(self, super(SingleBlockBasisExpansion, self).__init__() + self.basis = basis + # compute the mask of the sampled basis containing only the elements allowed by the filter mask = np.zeros(len(basis), dtype=bool) for b, attr in enumerate(basis): @@ -69,7 +71,8 @@ def __init__(self, if not any(norms): raise EmptyBasisException sampled_basis = sampled_basis[norms, ...] 
- + self._mask = mask + self.attributes = [attr for b, attr in enumerate(attributes) if norms[b]] # register the bases tensors as parameters of this module @@ -111,7 +114,20 @@ def get_basis_info(self) -> Iterable: def dimension(self) -> int: return self.sampled_basis.shape[0] - + + def __eq__(self, other): + if isinstance(other, SingleBlockBasisExpansion): + return ( + self.basis == other.basis and + torch.allclose(self.sampled_basis, other.sampled_basis) and + (self._mask == other._mask).all() + ) + else: + return False + + def __hash__(self): + return 10000 * hash(self.basis) + 100 * hash(self.sampled_basis) + hash(self._mask) + # dictionary storing references to already built basis tensors # when a new filter tensor is built, it is also stored here diff --git a/test/nn/test_he_init.py b/test/nn/test_he_init.py index efbe3d76..b65d7b00 100644 --- a/test/nn/test_he_init.py +++ b/test/nn/test_he_init.py @@ -24,18 +24,21 @@ def test_one_block(self): # t1 = FieldType(gspace, [gspace.regular_repr]*2) # t2 = FieldType(gspace, [gspace.regular_repr]*3) self.check(t1, t2) + self.check_caching(t1, t2) def test_many_block_discontinuous(self): gspace = Rot2dOnR2(8) t1 = FieldType(gspace, list(gspace.representations.values()) * 2) t2 = FieldType(gspace, list(gspace.representations.values()) * 3) self.check(t1, t2) + self.check_caching(t1, t2) def test_many_block_sorted(self): gspace = Rot2dOnR2(8) t1 = FieldType(gspace, list(gspace.representations.values()) * 2).sorted() t2 = FieldType(gspace, list(gspace.representations.values()) * 3).sorted() self.check(t1, t2) + self.check_caching(t1, t2) def test_different_irreps_ratio(self): N = 8 @@ -45,6 +48,33 @@ def test_different_irreps_ratio(self): t1 = FieldType(gspace, [irreps_in]*3) t2 = FieldType(gspace, [irreps_out]*3) self.check(t1, t2) + self.check_caching(t1, t2) + + def test_caching(self): + + N = 7 + gspace = Rot2dOnR2(N) + + # try combinations of field types which are similar but should still be cached independently + + irreps = directsum([gspace.fibergroup.irrep(k) for k in range(N // 2 + 1)], name="irrepssum") + t1 = FieldType(gspace, [irreps] * 2) + t2 = FieldType(gspace, [irreps] * 3) + self.check_caching(t1, t2) + + t1 = FieldType(gspace, [directsum([irreps]*2)]) + t2 = FieldType(gspace, [directsum([irreps]*3)]) + self.check_caching(t1, t2) + + irreps = [gspace.fibergroup.irrep(k) for k in range(N // 2 + 1)] + t1 = FieldType(gspace, irreps * 2) + t2 = FieldType(gspace, irreps * 3) + self.check_caching(t1, t2) + + irreps = [gspace.fibergroup.irrep(k) for k in range(N // 2 + 1)] + t1 = FieldType(gspace, irreps * 2).sorted() + t2 = FieldType(gspace, irreps * 3).sorted() + self.check_caching(t1, t2) def check(self, r1: FieldType, r2: FieldType): @@ -83,6 +113,44 @@ def check(self, r1: FieldType, r2: FieldType): self.assertTrue(torch.allclose(torch.zeros_like(mean), mean, rtol=2e-2, atol=5e-2)) self.assertTrue(torch.allclose(torch.ones_like(std), std, rtol=1e-1, atol=6e-2)) + def check_caching(self, r1: FieldType, r2: FieldType): + + np.set_printoptions(precision=7, threshold=60000, suppress=True) + + assert r1.gspace == r2.gspace + + with torch.no_grad(): + + s = 7 + + cl = R2Conv(r1, r2, s, + # sigma=[0.01] + [0.6]*int(s//2), + frequencies_cutoff=3.) 
+ + from datetime import datetime + + torch.manual_seed(42) + weights1 = torch.zeros_like(cl.weights.data) + + start = datetime.now() + init.generalized_he_init(weights1, cl.basisexpansion, cache=True) + end = datetime.now() + elapsed1 = (end - start).total_seconds() + + for i in range(10): + torch.manual_seed(42) + weights2 = torch.zeros_like(cl.weights.data) + + start = datetime.now() + init.generalized_he_init(weights2, cl.basisexpansion, cache=True) + end = datetime.now() + elapsed2 = (end - start).total_seconds() + + self.assertTrue(torch.allclose(weights1, weights2)) + + # this ensures that the first run was indeed the first time this was cached + self.assertTrue(elapsed1 >= 4 * elapsed2) + if __name__ == '__main__': unittest.main() From 0847b72f399e0b02b09a3b24f8020019785e39ac Mon Sep 17 00:00:00 2001 From: Gabriele Cesa Date: Wed, 30 Jun 2021 12:59:37 +0200 Subject: [PATCH 2/2] FieldType.transform() moves tensor to CPU and gives warning. Fix #39 --- e2cnn/nn/field_type.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/e2cnn/nn/field_type.py b/e2cnn/nn/field_type.py index 9416cb8a..da617b93 100644 --- a/e2cnn/nn/field_type.py +++ b/e2cnn/nn/field_type.py @@ -210,8 +210,9 @@ def transform(self, input: torch.Tensor, element) -> torch.Tensor: and its (induced) action on the base space. .. warning :: - The input tensor is detached before the transformation, therefore no gradient is propagated back - through this operation. + This method is internally implemented using ```numpy```. + This means that the input tensor is detached (and moved to CPU) before the transformation, therefore no + gradient is propagated back through this operation. .. seealso :: @@ -229,9 +230,14 @@ def transform(self, input: torch.Tensor, element) -> torch.Tensor: transformed tensor """ - transformed = self.gspace.featurefield_action(input.detach().numpy(), self.representation, element) + if input.is_cuda: + import warnings + warnings.warn('The input tensor is on GPU. The `FieldType.transform()` operation is based on `numpy` and,' + ' therefore, must temporarily move the tensor on CPU. This can cause performance issues.') + + transformed = self.gspace.featurefield_action(input.detach().cpu().numpy(), self.representation, element) transformed = np.ascontiguousarray(transformed) - return torch.from_numpy(transformed.astype(np.float32)) + return torch.from_numpy(transformed.astype(np.float32)).to(device=input.device) def restrict(self, id) -> 'FieldType': r"""
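---

A minimal usage sketch of the new `cache` flag introduced in PATCH 1/2. The layer construction below (group order, field types, kernel size) is illustrative and only mirrors the setup used in test/nn/test_he_init.py; the only API introduced by the patch is the `cache` argument of `generalized_he_init`.

    import torch
    from e2cnn.gspaces import Rot2dOnR2
    from e2cnn.nn import FieldType, R2Conv, init

    gspace = Rot2dOnR2(8)
    r1 = FieldType(gspace, [gspace.regular_repr] * 2)
    r2 = FieldType(gspace, [gspace.regular_repr] * 3)

    cl = R2Conv(r1, r2, 7)

    with torch.no_grad():
        # the first call computes the per-weight variances and stores them in the
        # module-level `cached_he_vars` dictionary, keyed by the layer's basis
        # expansion; later calls with an equivalent basis expansion reuse that tensor
        init.generalized_he_init(cl.weights.data, cl.basisexpansion, cache=True)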
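The cache is keyed by the `BasisExpansion` object itself, which is why PATCH 1/2 also adds `__hash__` and `__eq__` to `BlocksBasisExpansion` and `SingleBlockBasisExpansion`: basis expansions built from the same input and output types and the same kernel parameters are intended to compare equal, so, as the docstring of `generalized_he_init` notes, multiple convolution layers of the same kind can share one cached variance tensor, and so can repeated constructions of the same network (e.g. during a hyper-parameter search or across different random seeds).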
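A similar sketch for PATCH 2/2; the group order, field type, tensor shape and group element below are illustrative. `FieldType.transform()` now accepts a tensor living on the GPU: it detaches it, moves it to the CPU for the numpy-based action (emitting the warning added by the patch) and returns the result on the input's original device.

    import torch
    from e2cnn.gspaces import Rot2dOnR2
    from e2cnn.nn import FieldType

    gspace = Rot2dOnR2(4)
    feat = FieldType(gspace, [gspace.regular_repr] * 2)

    x = torch.randn(1, feat.size, 33, 33)
    if torch.cuda.is_available():
        x = x.cuda()

    # rotate the feature field by 90 degrees (one step of the cyclic group C4);
    # the output is returned on the same device as the input
    y = feat.transform(x, 1)
    assert y.device == x.device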