Why does my neural network not converge using JAX?

Solution 1:

It seems that the asker has already solved this problem themselves. However, I still want to explain what really happened, because I ran into exactly the same problem.

Indeed, the awkward behavior of the neural network came from a shape mismatch: before the asker deleted the line Y = jnp.squeeze(Y), the arrays Y and predictions inside loss(params, X, Y) had different shapes. predictions is a column vector of shape (N, 1), while the "squeezed" Y is a 1-D array of shape (N,), which broadcasts like a row vector of shape (1, N).
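
As a minimal sketch of what the squeeze does (the array values here are just placeholders, not the asker's data):

```python
import jax.numpy as jnp

Y = jnp.ones((5, 1))           # labels stored as a column vector, shape (5, 1)
print(Y.shape)                 # (5, 1)
print(jnp.squeeze(Y).shape)    # (5,) -- a 1-D array that broadcasts like a row vector
```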

In NumPy and JAX's jax.numpy (and also in MATLAB), array operations follow a feature called broadcasting. Because of this feature, subtracting a row vector from a column vector does not raise an error; instead, the interpreter computes

$$
\begin{pmatrix}
a_{1}\\
a_{2}\\
\vdots \\
a_{m}
\end{pmatrix} -\begin{pmatrix}
b_{1} & b_{2} & \cdots  & b_{n}
\end{pmatrix} =\begin{pmatrix}
a_{1} -b_{1} & a_{1} -b_{2} & \cdots  & a_{1} -b_{n}\\
a_{2} -b_{1} & a_{2} -b_{2} & \cdots  & a_{2} -b_{n}\\
\vdots  & \vdots  & \ddots  & \vdots \\
a_{m} -b_{1} & a_{m} -b_{2} & \cdots  & a_{m} -b_{n}
\end{pmatrix}
$$

(this formula should be interpreted by LaTeX)
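
To see this rule in action, here is a toy sketch with made-up sizes (not the asker's actual arrays):

```python
import jax.numpy as jnp

a = jnp.arange(3.0).reshape(3, 1)   # column vector, shape (3, 1)
b = jnp.arange(4.0)                 # 1-D array, treated like a row vector, shape (4,)

print((a - b).shape)                # (3, 4): every a_i minus every b_j, as in the matrix above
```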

Therefore, before the asker's own correction, Y - predictions was actually a matrix of shape (N, N), and np.mean() averaged all the entries of this N*N matrix. That is of course not the desired MSE loss, and it caused the strange convergence behavior the asker showed.
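
A hedged sketch of the consequence, with made-up numbers rather than the asker's actual model or data: the broadcast loss and the intended element-wise loss give different values, and therefore different gradients.

```python
import jax.numpy as jnp

# Hypothetical predictions and labels, just to illustrate the two losses.
predictions = jnp.array([[1.0], [2.0], [3.0], [4.0]])   # network output, shape (N, 1)
Y = jnp.array([1.5, 2.5, 2.5, 3.5])                     # squeezed labels, shape (N,)

broadcast_mse = jnp.mean((Y - predictions) ** 2)              # averages an (N, N) matrix
elementwise_mse = jnp.mean((Y - predictions.squeeze()) ** 2)  # averages N per-sample errors

print(broadcast_mse)    # not the MSE you wanted
print(elementwise_mse)  # the intended MSE
```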