How to make a piecewise linear fit in Python with some constant pieces?

I'm trying to make a piecewise linear fit consisting of 3 pieces, of which the first and last are constant. As you can see in this figure, I don't get the expected fit: the fit doesn't capture the 3 linear pieces that are clearly visible in the original data points.

I've tried following this question and expanded it to the case of 3 pieces with the two constant pieces, but I must have done something wrong.

Here is my code:

from scipy import optimize
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
plt.rcParams['figure.figsize'] = [16, 6]

x = np.arange(0, 50, dtype=float)
y = np.array([50 for i in range(10)]
             + [50 - (50-5)/31 * i for i in range(1, 31)]
             + [5 for i in range(10)],
             dtype=float)

def piecewise_linear(x, x0, y0, x1, y1):
    return np.piecewise(x,
                        [x < x0, (x >= x0) & (x < x1), x >= x1],
                        [lambda x:y0, lambda x:(y1-y0)/(x1-x0)*(x-x0)+y0, lambda x:y1])

p , e = optimize.curve_fit(piecewise_linear, x, y)
xd = np.linspace(0, 50, 101)

plt.plot(x, y, "o", label='original data')
plt.plot(xd, piecewise_linear(xd, *p), label='piecewise linear fit')
plt.legend()

The accepted answer to the previously mentioned question suggests looking at segments_fit.ipynb for the case of N parts, but following that, it doesn't seem that I can specify that the first and last pieces should be constant.

Furthermore, I get the following warning:

OptimizeWarning: Covariance of the parameters could not be estimated

What am I doing wrong?


You can get a one-line solution (not counting the import) using univariate splines of degree one, like this:

from scipy.interpolate import UnivariateSpline

f = UnivariateSpline(x,y,k=1,s=0)

Here k=1 means we interpolate using polynomials of degree one, i.e. lines. s is the smoothing parameter: it decides how much you are willing to compromise on the fit to avoid using too many segments. Setting it to zero means no compromise, i.e. the line has to go through all points. See the documentation.
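To make the role of s a bit more concrete, here is a small sketch comparing s=0 with a nonzero value on the data from the question. The value s=5.0 and the names f_exact and f_smooth are arbitrary choices for illustration, not recommendations.

```python
import numpy as np
from scipy.interpolate import UnivariateSpline

# Same data as in the question: constant, linear ramp, constant.
x = np.arange(0, 50, dtype=float)
y = np.array([50 for i in range(10)]
             + [50 - (50 - 5) / 31 * i for i in range(1, 31)]
             + [5 for i in range(10)],
             dtype=float)

# s=0: pure interpolation, the spline passes through every data point.
f_exact = UnivariateSpline(x, y, k=1, s=0)

# s>0: smoothing, the spline may miss points in exchange for fewer knots.
# 5.0 is an arbitrary value picked for this illustration.
f_smooth = UnivariateSpline(x, y, k=1, s=5.0)

print("knots with s=0:  ", len(f_exact.get_knots()))
print("knots with s=5.0:", len(f_smooth.get_knots()))
print("sum of squared residuals with s=5.0:", f_smooth.get_residual())
```

The number of knots is one more than the number of line segments, so comparing the two counts shows how much the smoothing simplified the curve. Note that you do not get to choose where the knots go this way, so this does not by itself force the first and last pieces to be constant.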

Then

plt.plot(x, y, "o", label='original data')
plt.plot(x, f(x), label='linear interpolation')
plt.legend()
plt.savefig("out.png", dpi=300)

gives a linear spline interpolation of the data.
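If you do want an actual 3-piece fit with constant end pieces rather than an interpolation, one thing that may be worth trying is passing an initial guess to curve_fit. This is only a sketch and not something I have verified against your exact setup: when p0 is omitted, curve_fit starts every parameter at 1, so both breakpoints start at x = 1, far from where the data bends. The starting values below are assumptions read off the plot by eye.

```python
import numpy as np
from scipy import optimize

# Same data as in the question.
x = np.arange(0, 50, dtype=float)
y = np.array([50 for i in range(10)]
             + [50 - (50 - 5) / 31 * i for i in range(1, 31)]
             + [5 for i in range(10)],
             dtype=float)

def piecewise_linear(x, x0, y0, x1, y1):
    # Constant y0 before x0, straight line from (x0, y0) to (x1, y1),
    # constant y1 after x1 (same model as in the question).
    return np.piecewise(
        x,
        [x < x0, (x >= x0) & (x < x1), x >= x1],
        [lambda x: y0,
         lambda x: (y1 - y0) / (x1 - x0) * (x - x0) + y0,
         lambda x: y1])

# Rough initial guess (x0, y0, x1, y1), estimated by eye from the plot.
p0 = [10.0, 50.0, 40.0, 5.0]
p, cov = optimize.curve_fit(piecewise_linear, x, y, p0=p0)

# Compare the sum of squared errors before and after the fit.
sse_start = np.sum((piecewise_linear(x, *p0) - y) ** 2)
sse_fit = np.sum((piecewise_linear(x, *p) - y) ** 2)
print("fitted x0, y0, x1, y1:", p)
print("SSE at initial guess:", sse_start, "SSE after fit:", sse_fit)
```

You can then plot piecewise_linear(xd, *p) exactly as in your original code. The fit can still be sensitive to the starting values, because the model is not smooth in the breakpoints x0 and x1.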