diff --git a/neuralprocesses/tensorflow/nn.py b/neuralprocesses/tensorflow/nn.py index 1956a17f..da8c42d9 100644 --- a/neuralprocesses/tensorflow/nn.py +++ b/neuralprocesses/tensorflow/nn.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Optional, Union @@ -119,6 +120,16 @@ def ConvNd( else: suffix = "" + if groups > 1: + if transposed: + warnings.warn( + "Keras does not depthwise separable transposed convolutions! " + "Using non-separable convolutions for the transposed convolutions. " + "This could be a LOT more expensive." + ) + else: + additional_args["groups"] = groups + conv_layer = getattr(tf.keras.layers, f"Conv{dim}D{suffix}")( input_shape=(in_channels,) + (None,) * dim, filters=out_channels, @@ -126,7 +137,6 @@ def ConvNd( strides=stride, padding="same", dilation_rate=dilation, - groups=groups, use_bias=bias, data_format=data_format, dtype=dtype,