Why my neural network does not converge using Jax?
Solution 1:
It seems that the asker has already solved this problem themselves. However, I still want to explain what really happened, because I ran into exactly the same problem the asker faced.
Indeed, the awkward behavior appeared because, before the asker removed the line Y = jnp.squeeze(Y), the arrays Y and predictions inside the loss function loss(params, X, Y) had different shapes: predictions is a column vector of shape (N, 1), while the "squeezing" operation turns Y into a one-dimensional array of shape (N,), which broadcasting treats like a row vector of shape (1, N). The asker's own fix was simply to remove that squeeze.
In NumPy and JAX's NumPy (and also in MATLAB), array operations follow a feature called broadcasting. Because of this feature, the interpreter performs calculations like
$$
\begin{pmatrix}
a_{1}\\
a_{2}\\
\vdots \\
a_{m}
\end{pmatrix}
-
\begin{pmatrix}
b_{1} & b_{2} & \cdots & b_{n}
\end{pmatrix}
=
\begin{pmatrix}
a_{1}-b_{1} & a_{1}-b_{2} & \cdots & a_{1}-b_{n}\\
a_{2}-b_{1} & a_{2}-b_{2} & \cdots & a_{2}-b_{n}\\
\vdots & \vdots & \ddots & \vdots \\
a_{m}-b_{1} & a_{m}-b_{2} & \cdots & a_{m}-b_{n}
\end{pmatrix}
$$
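To make this concrete, here is a small sketch (the shapes and values are purely illustrative, not the asker's actual data) showing how an (N, 1) column and a squeezed (N,) array broadcast into an (N, N) matrix:

```python
import jax.numpy as jnp

# Illustrative shapes only: N = 4
predictions = jnp.arange(4.0).reshape(4, 1)  # column vector, shape (4, 1)
Y = jnp.squeeze(predictions)                 # shape (4,), acts like a row vector

diff = Y - predictions                       # broadcasting expands both operands
print(diff.shape)                            # (4, 4) -- not (4, 1) or (4,)
print(jnp.mean(diff ** 2))                   # averages 16 entries instead of 4
```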
Therefore, before the asker's own correction, Y - predictions was actually a matrix of shape (N, N), and jnp.mean() averaged all N*N entries of that matrix. This is of course not the MSE loss one wants to compute, and it caused the strange convergence behavior the asker showed.
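As a rough sketch of how to avoid this pitfall (predict here is a hypothetical linear model standing in for the asker's network), one can explicitly align the shapes before subtracting, so the MSE is computed over N residuals rather than N*N entries:

```python
import jax.numpy as jnp

def predict(params, X):
    # Hypothetical linear model used only for illustration
    W, b = params
    return X @ W + b                            # shape (N, 1)

def loss(params, X, Y):
    predictions = predict(params, X)
    predictions = predictions.reshape(Y.shape)  # force identical shapes
    return jnp.mean((Y - predictions) ** 2)     # element-wise MSE over N samples
```

Alternatively, keeping both Y and predictions as (N, 1) columns, i.e. simply not squeezing Y, as the asker eventually did, achieves the same result.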