基准测试 - 比较评估器性能#
benchmarking
模块使您能够轻松地组织基准测试实验,在这些实验中,您希望比较一个或多个算法在一个或多个数据集和基准测试配置上的性能。
基准测试作为一项工作,通常很容易出错,导致关于评估器性能的错误结论——请参阅普林斯顿大学 2022 年的这项研究,其中提供了同行评审学术论文中此类错误的众多示例作为证据。
sktime
的 benchmarking
模块旨在提供基准测试功能,同时强制执行最佳实践和结构,以帮助用户避免犯下会使结果无效的错误(例如数据泄露等)。benchmarking
模块设计时考虑了易用性,因此它直接与 sktime
对象和类进行接口。先前开发的评估器应该可以原样使用,无需修改。
本 notebook 演示了 benchmarking
模块的使用。
[1]:
from sktime.benchmarking.forecasting import ForecastingBenchmark
from sktime.datasets import load_airline
from sktime.forecasting.naive import NaiveForecaster
from sktime.performance_metrics.forecasting import MeanSquaredPercentageError
from sktime.split import ExpandingWindowSplitter
实例化一个基准测试类实例#
在此示例中,我们正在比较预测评估器。
[2]:
benchmark = ForecastingBenchmark()
添加竞争评估器#
我们将不同的竞争评估器添加到基准测试实例中。所有添加的评估器将自动通过每个添加的基准测试任务运行,并汇总其结果。
[3]:
benchmark.add_estimator(
estimator=NaiveForecaster(strategy="mean", sp=12),
estimator_id="NaiveForecaster-mean-v1",
)
benchmark.add_estimator(
estimator=NaiveForecaster(strategy="last", sp=12),
estimator_id="NaiveForecaster-last-v1",
)
添加基准测试任务#
这些是每个评估器将进行测试并汇总结果的预测/验证任务。
基准测试任务的确切参数取决于目标是预测、分类等,但总体而言它们是相似的。以下是定义预测基准测试任务的必需参数。
指定交叉验证分割方案#
使用标准的 sktime
对象定义交叉验证分割方案。
[4]:
cv_splitter = ExpandingWindowSplitter(
initial_window=24,
step_length=12,
fh=12,
)
指定性能指标#
使用标准的 sktime
对象定义用于比较评估器的性能指标。
[5]:
scorers = [MeanSquaredPercentageError()]
指定数据集加载器#
定义数据集加载器,它们是应该返回数据集的可调用对象(函数)。通常这是一个返回包含整个数据集的 dataframe 的可调用对象。可以使用 sktime
定义的数据集,或自定义数据集。一个像下面这样简单的例子就足够了
def my_dataset_loader():
return pd.read_csv("path/to/data.csv")
在运行基准测试任务时将加载数据集,通过交叉验证方案运行,随后评估器将在数据集分割上进行测试。
[6]:
dataset_loaders = [load_airline]
将任务添加到基准测试实例#
使用先前定义的对象将任务添加到基准测试实例。可以选择使用循环等方法轻松设置多个重复使用参数的基准测试任务。
[7]:
for dataset_loader in dataset_loaders:
benchmark.add_task(
dataset_loader,
cv_splitter,
scorers,
)
运行所有任务-评估器组合并存储结果#
请注意,run
不会重新运行已有结果的任务,因此添加一个新的评估器并再次运行 run
将仅为该新评估器运行任务。
[8]:
results_df = benchmark.run("./forecasting_results.csv")
results_df.T
[8]:
0 | 1 | |
---|---|---|
validation_id | [dataset=load_airline]_[cv_splitter=ExpandingW... | [dataset=load_airline]_[cv_splitter=ExpandingW... |
model_id | NaiveForecaster-last-v1 | NaiveForecaster-mean-v1 |
runtime_secs | 0.061472 | 0.081733 |
MeanSquaredPercentageError_fold_0_test | 0.024532 | 0.049681 |
MeanSquaredPercentageError_fold_1_test | 0.020831 | 0.0737 |
MeanSquaredPercentageError_fold_2_test | 0.001213 | 0.05352 |
MeanSquaredPercentageError_fold_3_test | 0.01495 | 0.081063 |
MeanSquaredPercentageError_fold_4_test | 0.031067 | 0.138163 |
MeanSquaredPercentageError_fold_5_test | 0.008373 | 0.145125 |
MeanSquaredPercentageError_fold_6_test | 0.007972 | 0.154337 |
MeanSquaredPercentageError_fold_7_test | 0.000009 | 0.123298 |
MeanSquaredPercentageError_fold_8_test | 0.028191 | 0.185644 |
MeanSquaredPercentageError_fold_9_test | 0.003906 | 0.184654 |
MeanSquaredPercentageError_mean | 0.014104 | 0.118918 |
MeanSquaredPercentageError_std | 0.011451 | 0.051265 |