Skip to content

Commit

Permalink
Don't use groups for transposed convs for Keras
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Apr 21, 2024
1 parent a862c99 commit e696c09
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion neuralprocesses/tensorflow/nn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from functools import partial
from typing import Optional, Union

Expand Down Expand Up @@ -119,14 +120,23 @@ def ConvNd(
else:
suffix = ""

if groups > 1:
if transposed:
warnings.warn(
"Keras does not depthwise separable transposed convolutions! "
"Using non-separable convolutions for the transposed convolutions. "
"This could be a LOT more expensive."
)
else:
additional_args["groups"] = groups

conv_layer = getattr(tf.keras.layers, f"Conv{dim}D{suffix}")(
input_shape=(in_channels,) + (None,) * dim,
filters=out_channels,
kernel_size=kernel,
strides=stride,
padding="same",
dilation_rate=dilation,
groups=groups,
use_bias=bias,
data_format=data_format,
dtype=dtype,
Expand Down

0 comments on commit e696c09

Please sign in to comment.