I am trying to use conv1d functions to compute a transposed convolution in both JAX and TensorFlow. I read the documentation of both JAX and TensorFlow for the conv1d_transpose operation, but they produce different outputs for the same input.
I cannot figure out what the problem is, and I don't know which one produces the correct results. Please help.
My JAX Implementation (JAX Code)
```python
import numpy as np
from jax import lax

x = np.asarray([[[1, 2, 3, 4, -5], [1, 2, 3, 4, 5]]], dtype=np.float32).transpose((0, 2, 1))
filters = np.array([[[1, 0, -1], [-1, 0, 1]],
                    [[1, 1, 1], [-1, -1, -1]]], dtype=np.float32).transpose((2, 1, 0))
kernel_rot = np.rot90(np.rot90(filters))
print(f"x strides: {x.strides}\nfilters strides: {kernel_rot.strides}\n"
      f"x shape: {x.shape}\nfilters shape: {filters.shape}\nx:\n{x}\nfilters:\n{filters}\n")
dn1 = lax.conv_dimension_numbers(x.shape, filters.shape, ('NWC', 'WIO', 'NWC'))
print(dn1)
res = lax.conv_general_dilated(x, kernel_rot, (1,), 'SAME', (1,), (1,), dn1)
res = np.asarray(res)
print(f"result strides: {res.strides}\nresult shape: {res.shape}\nresult:\n{res}\n")
```
My TensorFlow Implementation (TensorFlow Code)
```python
import numpy as np
import tensorflow as tf

x = np.asarray([[[1, 2, 3, 4, -5], [1, 2, 3, 4, 5]]], dtype=np.float32).transpose((0, 2, 1))
filters = np.array([[[1, 0, -1], [-1, 0, 1]],
                    [[1, 1, 1], [-1, -1, -1]]], dtype=np.float32).transpose((2, 1, 0))
print(f"x strides: {x.strides}\nfilters strides: {filters.strides}\n"
      f"x shape: {x.shape}\nfilters shape: {filters.shape}\nx:\n{x}\nfilters:\n{filters}\n")
res = tf.nn.conv1d_transpose(x, filters, output_shape=x.shape, strides=(1, 1, 1),
                             padding='SAME', data_format='NWC', dilations=1)
res = np.asarray(res)
print(f"result strides: {res.strides}\nresult shape: {res.shape}\nresult:\n{res}\n")
```
Output from JAX
```
result strides: (40, 8, 4)
result shape: (1, 5, 2)
result:
[[[ 0.  0.]
  [ 0.  0.]
  [ 0.  0.]
  [10. 10.]
  [ 0. 10.]]]
```
Output from TensorFlow
```
result strides: (40, 8, 4)
result shape: (1, 5, 2)
result:
[[[  5.  -5.]
  [  8.  -8.]
  [ 11. -11.]
  [  4.  -4.]
  [  5.  -5.]]]
```
Answer
The function `tf.nn.conv1d_transpose` expects filters of shape `[filter_width, output_channels, in_channels]`. Since `filters` in the snippet above was transposed to satisfy this shape, the JAX code must describe the kernel layout the same way: the rhs spec passed to `lax.conv_dimension_numbers` when computing `dn1` should be `WOI` (Width – Output_channels – Input_channels), not `WIO` (Width – Input_channels – Output_channels). After that:
```
result strides: (40, 8, 4)
result shape: (1, 5, 2)
result:
[[[ -5.   5.]
  [ -8.   8.]
  [-11.  11.]
  [ -4.   4.]
  [ -5.   5.]]]
```
The results are still not identical to TensorFlow's, but the kernel passed to JAX was flipped with `np.rot90`, so that difference is expected.
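To see which framework matches the mathematical definition, the stride-1, SAME-padded transposed convolution can be written out by hand in pure NumPy (the helper name `conv1d_transpose_same` is mine, not part of either API; the tensors are the ones from the question). It reproduces TensorFlow's output exactly, and a final check shows why the JAX call came out sign-flipped:

```python
import numpy as np

def conv1d_transpose_same(x, filters):
    """Stride-1, SAME-padded transposed 1-D convolution, written out by hand.

    x:       (batch, width, in_channels)                -- NWC
    filters: (filter_width, out_channels, in_channels)  -- TF's conv1d_transpose layout
    """
    b, w, cin = x.shape
    fw, cout, _ = filters.shape
    pad = fw - 1 - (fw - 1) // 2        # transposed conv swaps SAME's left/right pads
    y = np.zeros((b, w, cout), dtype=x.dtype)
    for t in range(w):                  # output position
        for k in range(fw):             # filter tap
            s = t - k + pad             # contributing input position
            if 0 <= s < w:
                # contract over the in_channels axis: (b, cin) @ (cin, cout)
                y[:, t, :] += x[:, s, :] @ filters[k].T
    return y

# The exact tensors from the question
x = np.asarray([[[1, 2, 3, 4, -5], [1, 2, 3, 4, 5]]], dtype=np.float32).transpose((0, 2, 1))
filters = np.array([[[1, 0, -1], [-1, 0, 1]],
                    [[1, 1, 1], [-1, -1, -1]]], dtype=np.float32).transpose((2, 1, 0))

print(conv1d_transpose_same(x, filters))
# [[[  5.  -5.]
#   [  8.  -8.]
#   [ 11. -11.]
#   [  4.  -4.]
#   [  5.  -5.]]]   -- identical to tf.nn.conv1d_transpose

# Why is the corrected JAX result sign-flipped?  np.rot90 applied twice reverses
# BOTH axis 0 (filter_width) and axis 1 (output_channels):
kernel_rot = np.rot90(np.rot90(filters))
assert np.array_equal(kernel_rot, filters[::-1, ::-1, :])
# In this example the two output channels are exact negatives of each other,
# so swapping them negates the result -- precisely the JAX output above.
```

So TensorFlow's numbers agree with the definition; the JAX call only diverges because `kernel_rot` flipped the output-channel axis along with the width axis.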