mean_squared_error#

mean_squared_error(y_true, y_pred, horizon_weight=None, multioutput='uniform_average', square_root=False, **kwargs)[来源]#

均方误差 (MSE) 或均方根误差 (RMSE)。

如果 square_root 为 False,则计算 MSE;如果 square_root 为 True,则计算 RMSE。MSE 和 RMSE 都是非负浮点数。最佳值为 0.0。

MSE 的单位是输入数据的平方单位,而 RMSE 的单位与数据相同。由于 MSE 和 RMSE 是对预测误差进行平方而不是取绝对值,因此它们对大误差的惩罚比 MAE 更重。

参数:
y_truepd.Series, pd.DataFrame 或 np.array,形状为 (fh,) 或 (fh, n_outputs),其中 fh 是预测范围

真实(正确)目标值。

y_predpd.Series, pd.DataFrame 或 np.array,形状为 (fh,) 或 (fh, n_outputs),其中 fh 是预测范围

预测值。

horizon_weightarray-like,形状为 (fh,),默认为 None

预测范围权重。

multioutput{‘raw_values’, ‘uniform_average’} 或 array-like,形状为 (n_outputs,),默认为 ‘uniform_average’

定义如何聚合多元(多输出)数据的指标。如果是 array-like,则使用值作为权重来平均误差。如果是 ‘raw_values’,则在多输出输入情况下返回完整的误差集。如果是 ‘uniform_average’,则所有输出的误差以统一权重进行平均。

square_rootbool,默认为 False

是否取均方误差的平方根。如果为 True,则返回均方根误差 (RMSE);如果为 False,则返回均方误差 (MSE)。

返回:
lossfloat 或 float 的 ndarray

MSE 损失。如果 multioutput 是 ‘raw_values’,则为每个输出分别返回 MSE。如果 multioutput 是 ‘uniform_average’ 或权重的 ndarray,则返回所有输出误差的加权平均 MSE。

参考

Hyndman, R. J and Koehler, A. B. (2006). “Another look at measures of forecast accuracy”, International Journal of Forecasting, Volume 22, Issue 4.

示例

>>> import numpy as np
>>> from sktime.performance_metrics.forecasting import mean_squared_error
>>> y_true = np.array([3, -0.5, 2, 7, 2])
>>> y_pred = np.array([2.5, 0.0, 2, 8, 1.25])
>>> mean_squared_error(y_true, y_pred)
0.4125
>>> y_true = np.array([[0.5, 1], [-1, 1], [7, -6]])
>>> y_pred = np.array([[0, 2], [-1, 2], [8, -5]])
>>> mean_squared_error(y_true, y_pred)
0.7083333333333334
>>> mean_squared_error(y_true, y_pred, square_root=True)
0.8227486121839513
>>> mean_squared_error(y_true, y_pred, multioutput='raw_values')
array([0.41666667, 1.        ])
>>> mean_squared_error(y_true, y_pred, multioutput='raw_values', square_root=True)
array([0.64549722, 1.        ])
>>> mean_squared_error(y_true, y_pred, multioutput=[0.3, 0.7])
0.825
>>> mean_squared_error(y_true, y_pred, multioutput=[0.3, 0.7], square_root=True)
0.8936491673103708