为什么scipy.optimize.least_squares存在,而scipy.optimize.minimize可能用于相同的事情?

apeeds0o  于 5个月前  发布在  其他
关注(0)|答案(2)|浏览(40)

我想了解一下为什么scipy.optimize.least_squares会出现在scipy中,这个函数可以用来进行模型拟合,但是也可以用scipy.optimize.minimize来做同样的事情,唯一的区别是scipy.optimize.least_squares是在内部进行卡方的计算,而如果要用scipy.optimize.minimize,他/她将不得不在用户想要最小化的函数中手动计算卡方。此外,scipy.optimize.least_squares不能被视为scipy.optimize.minimize的 Package 器,因为它支持的三种方法(trfdogboxlm),scipy.optimize.minimize完全不支援。
所以我的问题是:

  • scipy.optimize.minimize可以达到相同的结果时,为什么scipy.optimize.least_squares存在?
  • 为什么scipy.optimize.minimize不支持trfdogboxlm方法?

谢谢

balp4ylt

balp4ylt1#

scipy.optimize.least_squares中的算法利用最小化问题的最小二乘结构来获得更好的收敛性(或所使用的导数的低阶)。
它类似于高斯-牛顿算法和牛顿方法之间的区别,参见Wikipedia或本题。
特别是,高斯-牛顿法只使用雅可比(一阶导数),而牛顿法还使用海森(二阶导数),这是昂贵的计算。

mepcadol

mepcadol2#

OP:
如果想要使用scipy.optimize.minimize,他/她将不得不在用户想要最小化的函数内手动计算卡方。
我不太明白你是如何在minimize中应用chi-square的?chi-square
我看到这样的最小化最小二乘--局部:

import numpy as np
from scipy.optimize import minimize
import matplotlib.pyplot as plt

X = np.array([1,2,3,4,5, 6,7,8,9,10])
y = np.array([10,15,7,5,8, 10,18,14,12,11])

plt.plot(X,y,'k')
plt.show()

loss_res = lambda z: 0.5 * z ** 2   # MSE
f_to_optMin = lambda w: np.sum(loss_res(X @ w.ravel() - y.T)) 

res= minimize(f_to_optMin, (0,0,0,0,0,0,0,0,0,0))
y_pred= np.array([ i + X.max() - X.min() for i in np.multiply(X, res.x) ] )

plt.plot(X,y,'k', X, y_pred.T,'r')
plt.show()

字符串
或者这样(优化参数,又名系数):

import numpy as np
from scipy.optimize import minimize
import matplotlib.pyplot as plt

X_est = np.array([1,2,3,4,5, 6,7,8,9,10])
y = np.array([10,15,7,5,8, 10,18,14,12,11])

plt.plot(X_est,y,'k')
plt.show()

def loss(x, X_est, y):
    # Evaluate the fit function with the current parameter estimates
    ynew = myModel( X_est, *x)
    yerr = np.sum( ( ynew - y ) ** 2 )
    return yerr

def myModel( x,  b, c ):
    y =   b * x + c
    return y

p0 = [ 1., 1.]   # params to optimize
res= minimize(loss, p0, args=( X_est, y ), method='Nelder-Mead')
print(res.x)
b,c = res.x
y_pred= np.array([ myModel(i, b,c)  for i in  X_est ] )

plt.plot(X_est,y,'k', X_est, y_pred.T,'r')
plt.show()

相关问题