这是我的解决方案。我不知道如何用 `curve_fit` 实现，但用 `leastsq` 可以做到。它使用一个包装函数，该函数接受自由参数、固定参数以及自由参数的位置列表。由于 `leastsq` 调用函数时总是把自由参数放在最前面，包装器必须重新排列参数的顺序。
from matplotlib import pyplot as plt
import numpy as np
from scipy.optimize import leastsq
def func(x, a, b, c, d, e):
    """Return the quartic polynomial a + b*x + c*x**2 + d*x**3 + e*x**4.

    Only arithmetic operators are used, so x may be a scalar or a NumPy array.
    """
    lowOrder = a + b * x
    return lowOrder + c * x ** 2 + d * x ** 3 + e * x ** 4
def expand_parameters(*args):
    """Evaluate func after moving the free parameters back into their slots.

    leastsq hands the objective its free parameters first, so the call
    signature is ``expand_parameters(x, p0, p1, p2, p3, p4, freeList)``:

    * ``x`` -- the point passed straight through to func;
    * ``p0 .. p4`` -- five values: the first ``len(freeList)`` are the free
      parameters, the remaining ones are the fixed parameters in ascending
      slot order;
    * ``freeList`` -- the slot (0-4) of each free parameter, e.g. two free
      parameters of 1st and 3rd order -> ``[1, 3]``.

    With ``freeList == [1, 3]`` the values arrive as ``(b, d, a, c, e)`` and
    the result is ``func(x, a, b, c, d, e)``.

    Raises:
        ValueError: if freeList contains a duplicate or a slot outside 0-4
            (raised by ``list.remove``).
    """
    x = args[0]
    callArgs = args[1:6]
    # Copy to a list so any sequence (tuple, ndarray) can be concatenated below.
    freeList = list(args[-1])
    # list(...) is required on Python 3, where range() has no .remove().
    fixedList = list(range(5))
    for item in freeList:
        fixedList.remove(item)
    callList = [0] * 5
    # Free values go to their listed slots, fixed values fill the rest in order.
    for val, pos in zip(callArgs, freeList + fixedList):
        callList[pos] = val
    return func(x, *callList)
def residuals(parameters, dataPoint, fixedParameterValues=None, freeParametersPosition=None):
    """Return the residuals ``y - model(x)`` for every ``(x, y)`` in dataPoint.

    Intended as the objective for ``scipy.optimize.leastsq``; the arguments
    after ``parameters`` come from leastsq's ``args`` tuple.

    Args:
        parameters: The values being fitted. If fixedParameterValues is None
            these are all five coefficients ``a, b, c, d, e`` of func;
            otherwise they are the free coefficients, in the order given by
            freeParametersPosition.
        dataPoint: Iterable of ``(x, y)`` pairs. It is iterated on every
            call, so pass a list rather than a one-shot iterator.
        fixedParameterValues: Values of the fixed coefficients in ascending
            slot order, or None when every coefficient is free.
        freeParametersPosition: Slot indices (0-4) of the free coefficients.
            Required whenever fixedParameterValues is given.

    Returns:
        A list with one residual per data point.

    Raises:
        ValueError: If the free/fixed split is inconsistent: positions
            missing, counts not adding up to five, zero or five fixed
            coefficients, or a parameters length that differs from the
            number of free positions.
    """
    if fixedParameterValues is None:
        a, b, c, d, e = parameters
        return [y - func(x, a, b, c, d, e) for x, y in dataPoint]

    # Explicit checks instead of assert: asserts are stripped under `python -O`.
    if freeParametersPosition is None:
        raise ValueError("freeParametersPosition is required when fixedParameterValues is given")
    nFixed = len(fixedParameterValues)
    nFree = len(freeParametersPosition)
    if nFixed != 5 - nFree:
        raise ValueError("fixed and free parameter counts must add up to 5")
    if not 0 < nFixed < 5:
        # Fixing none or all of the coefficients makes no sense here.
        raise ValueError("between 1 and 4 parameters must be fixed")
    if len(parameters) != nFree:
        # A mismatch would silently shift values into the wrong slots.
        raise ValueError("number of free parameters does not match freeParametersPosition")
    extraIn = list(parameters) + list(fixedParameterValues) + [freeParametersPosition]
    return [y - expand_parameters(x, *extraIn) for x, y in dataPoint]
def _evaluate_partial_fit(xValues, freeValues, fixedValues, freePositions):
    """Evaluate the model at every x in xValues for a given free/fixed split.

    freeValues are placed in the slots listed in freePositions and
    fixedValues fill the remaining slots in ascending order, exactly as
    residuals() does during the fit. Returns a float ndarray.
    """
    extraIn = list(freeValues) + list(fixedValues) + [list(freePositions)]
    return np.array([expand_parameters(x, *extraIn) for x in xValues], dtype=float)


def main():
    """Demo: fit noise-free quartic data with different sets of free coefficients and plot them."""
    trueParameters = [1.1, -.9, -.7, .5, .1]
    xList = np.linspace(-1, 3, 15)
    # Evaluate point by point, the same way residuals() does, so the
    # self-test below reproduces the data exactly.
    fList = np.array([func(s, *trueParameters) for s in xList], dtype=float)
    # leastsq calls residuals many times; on Python 3 zip() is a one-shot
    # iterator, so materialise it or every call after the first sees no data.
    dataTupel = list(zip(xList, fList))

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)

    ### some tests: both lines should print all-zero residuals
    print(residuals(trueParameters, dataTupel))
    print(residuals(trueParameters[:4], dataTupel,
                    fixedParameterValues=[.1], freeParametersPosition=[0, 1, 2, 3]))

    # exact fit: all five coefficients free
    bestFitValuesAll, _ = leastsq(residuals, [1, 1, 1, 1, 1], args=(dataTupel,))
    print(bestFitValuesAll)

    ### Only a constant: slot 0 free, all other coefficients fixed at 0
    constFixed = [0, 0, 0, 0]
    constFree = [0]
    bestFitValuesConstOnly, _ = leastsq(residuals, [1], args=(dataTupel, constFixed, constFree))
    print(bestFitValuesConstOnly)
    fConstList = _evaluate_partial_fit(xList, bestFitValuesConstOnly, constFixed, constFree)

    ### Only the 2nd- and 4th-order coefficients (slots 2 and 4), others fixed at 0
    free_2_4 = [2, 4]
    fixed_2_4 = [0, 0, 0]
    bestFitValues_2_4, _ = leastsq(residuals, [1, 1], args=(dataTupel, fixed_2_4, free_2_4))
    print(bestFitValues_2_4)
    f_2_4_List = _evaluate_partial_fit(xList, bestFitValues_2_4, fixed_2_4, free_2_4)

    ### Same free slots, but with the fixed coefficients closer to the true values
    fixed_2_4_closer = [1.2, -.8, 0]
    bestFitValues_2_4_closer, _ = leastsq(residuals, [1, 1],
                                          args=(dataTupel, fixed_2_4_closer, free_2_4))
    print(bestFitValues_2_4_closer)
    f_2_4_closer_List = _evaluate_partial_fit(xList, bestFitValues_2_4_closer,
                                              fixed_2_4_closer, free_2_4)

    curves = (
        (fList, 'orig'),
        (fConstList, '0'),
        (f_2_4_List, '2,4'),
        (f_2_4_closer_List, '2,4 c'),
    )
    for yValues, label in curves:
        ax.plot(xList, yValues, linestyle='', marker='o', label=label)
    ax.legend(loc=0)
    plt.show()


if __name__ == "__main__":
    main()
输出如下：
>>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
>>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
>>[ 1.1 -0.9 -0.7 0.5 0.1]
>>[ 2.64880466]
>>[-0.14065838 0.18305123]
>>[-0.31708629 0.2227272 ]