binder

基准测试 - 比较评估器性能#

benchmarking 模块使您能够轻松地组织基准测试实验,在这些实验中,您希望比较一个或多个算法在一个或多个数据集和基准测试配置上的性能。

基准测试作为一项工作,通常很容易出错,导致关于评估器性能的错误结论——请参阅普林斯顿大学 2022 年的这项研究,其中提供了同行评审学术论文中此类错误的众多示例作为证据。

sktimebenchmarking 模块旨在提供基准测试功能,同时强制执行最佳实践和结构,以帮助用户避免犯下会使结果无效的错误(例如数据泄露等)。benchmarking 模块设计时考虑了易用性,因此它直接与 sktime 对象和类进行接口。先前开发的评估器应该可以原样使用,无需修改。

本 notebook 演示了 benchmarking 模块的使用。

[1]:
from sktime.benchmarking.forecasting import ForecastingBenchmark
from sktime.datasets import load_airline
from sktime.forecasting.naive import NaiveForecaster
from sktime.performance_metrics.forecasting import MeanSquaredPercentageError
from sktime.split import ExpandingWindowSplitter

实例化一个基准测试类实例#

在此示例中,我们正在比较预测评估器。

[2]:
benchmark = ForecastingBenchmark()

添加竞争评估器#

我们将不同的竞争评估器添加到基准测试实例中。所有添加的评估器将自动通过每个添加的基准测试任务运行,并汇总其结果。

[3]:
benchmark.add_estimator(
    estimator=NaiveForecaster(strategy="mean", sp=12),
    estimator_id="NaiveForecaster-mean-v1",
)
benchmark.add_estimator(
    estimator=NaiveForecaster(strategy="last", sp=12),
    estimator_id="NaiveForecaster-last-v1",
)

添加基准测试任务#

这些是每个评估器将进行测试并汇总结果的预测/验证任务。

基准测试任务的确切参数取决于目标是预测、分类等,但总体而言它们是相似的。以下是定义预测基准测试任务的必需参数。

指定交叉验证分割方案#

使用标准的 sktime 对象定义交叉验证分割方案。

[4]:
cv_splitter = ExpandingWindowSplitter(
    initial_window=24,
    step_length=12,
    fh=12,
)

指定性能指标#

使用标准的 sktime 对象定义用于比较评估器的性能指标。

[5]:
scorers = [MeanSquaredPercentageError()]

指定数据集加载器#

定义数据集加载器,它们是应该返回数据集的可调用对象(函数)。通常这是一个返回包含整个数据集的 dataframe 的可调用对象。可以使用 sktime 定义的数据集,或自定义数据集。一个像下面这样简单的例子就足够了

def my_dataset_loader():
    return pd.read_csv("path/to/data.csv")

在运行基准测试任务时将加载数据集,通过交叉验证方案运行,随后评估器将在数据集分割上进行测试。

[6]:
dataset_loaders = [load_airline]

将任务添加到基准测试实例#

使用先前定义的对象将任务添加到基准测试实例。可以选择使用循环等方法轻松设置多个重复使用参数的基准测试任务。

[7]:
for dataset_loader in dataset_loaders:
    benchmark.add_task(
        dataset_loader,
        cv_splitter,
        scorers,
    )

运行所有任务-评估器组合并存储结果#

请注意,run 不会重新运行已有结果的任务,因此添加一个新的评估器并再次运行 run 将仅为该新评估器运行任务。

[8]:
results_df = benchmark.run("./forecasting_results.csv")
results_df.T
[8]:
0 1
validation_id [dataset=load_airline]_[cv_splitter=ExpandingW... [dataset=load_airline]_[cv_splitter=ExpandingW...
model_id NaiveForecaster-last-v1 NaiveForecaster-mean-v1
runtime_secs 0.061472 0.081733
MeanSquaredPercentageError_fold_0_test 0.024532 0.049681
MeanSquaredPercentageError_fold_1_test 0.020831 0.0737
MeanSquaredPercentageError_fold_2_test 0.001213 0.05352
MeanSquaredPercentageError_fold_3_test 0.01495 0.081063
MeanSquaredPercentageError_fold_4_test 0.031067 0.138163
MeanSquaredPercentageError_fold_5_test 0.008373 0.145125
MeanSquaredPercentageError_fold_6_test 0.007972 0.154337
MeanSquaredPercentageError_fold_7_test 0.000009 0.123298
MeanSquaredPercentageError_fold_8_test 0.028191 0.185644
MeanSquaredPercentageError_fold_9_test 0.003906 0.184654
MeanSquaredPercentageError_mean 0.014104 0.118918
MeanSquaredPercentageError_std 0.011451 0.051265

使用 nbsphinx 生成。Jupyter notebook 可在此处找到:这里