实现 Estimator#

本页介绍如何实现与 sktime 兼容的 estimator,以及如何确保和测试兼容性。对于直接贡献给 sktime 的 estimator,还需要额外的步骤。

实现与 sktime 兼容的 estimator#

实现与 sktime 兼容的 estimator 的高级步骤如下

  1. 确定 estimator 的类型:forecaster(预测器)、classifier(分类器)等

  2. 将该类型 estimator 的扩展模板复制到其目标位置

  3. 填写完整的扩展模板

  4. 运行 sktime 测试套件和/或 check_estimator 工具(参见此处

  5. 如果测试套件发现错误或问题,修复它们并返回步骤 4

有关如何实现自己的 estimator 的更多指导,请参阅 pydata 上的此教程,内容关于接口一致性测试。

我的学习任务是什么?#

sktime 按照包含特定学习任务的模块进行组织,例如 forecasting(预测)或 time series classification(时间序列分类)。为简洁起见,我们根据 estimator 解决的正式学习任务定义其科学类型或“scitype”。例如,解决预测任务的 estimator 的 scitype 是“forecaster”。解决时间序列分类任务的 estimator 的 scitype 是“time series classifier”。

给定 scitype 的 estimator 应位于相应的模块中。estimator 的 scitype 也对应于 sktimeextension_templates 目录中找到的不同扩展模板。

通常,给定 estimator 的 scitype 直接由其功能决定。这通常也在与 estimator 相关的出版物中明确指出。例如,大多数教科书在预测的背景下提到 ARIMA,因此在这种假设情况下,考虑“forecaster”模板是有意义的。然后,检查模板并查看类的各种方法是否清晰地映射到 estimator 的例程。如果不匹配,则可能需要使用另一个模板。

这里最常见的困惑点在于 transformers 和其他 estimator 类型之间的区分,因为 transformers 经常被用作其他类型算法的一部分。

如果不确定,请随时在 sktime 的社交渠道之一上提问。请勿惊慌——学术出版物对 estimator 类型描述不明确的情况并不少见,即使对专家来说,正确分类也可能很困难。

sktime 扩展模板是什么?#

扩展模板是为新 estimator 的实现者提供的便捷“填充”模板。它们以以下方式融入 sktime 的统一接口

  • 对于每种 scitype,都有一个公共用户接口,由相应的基类定义。例如,BaseForecaster 为 forecaster 定义了 fitpredict 接口。所有 forecaster 都将通过继承 BaseForecaster 以相同的方式实现 fitpredict。公共接口遵循“策略”面向对象模式。

  • 对于每种 scitype,都有一个私有扩展接口,由扩展模板中的扩展契约定义。例如,用于 forecaster 的 forecaster.py 扩展模板解释了为继承自 BaseForecaster 的具体 forecaster 需要填写哪些内容。在大多数扩展模板中,用户应实现私有方法(“内部”方法),例如 forecaster 的 _fit_predict。样板代码位于接口的公共部分,即 fitpredict 中。扩展接口遵循“模板”面向对象模式。

熟悉 scikit-learn 扩展的扩展者应注意与 scikit-learn 的以下区别

sktime (具体) estimator 中,公共接口,例如 fitpredict,从不被覆盖。实现发生在私有的、扩展者侧的接口中,例如 _fit_predict

这可以避免样板代码的重复,例如 scikit-learn 中的 check_X 等。这也允许更丰富的样板代码,例如自动向量化功能或输入转换。

如何使用 sktime 扩展模板#

要使用 sktime 扩展模板,请将其复制到 estimator 的目标位置。在扩展模板内部,必要的操作用 todo 标记。典型的工作流程是搜索 todo,并执行 todo 旁描述的操作。

扩展模板通常有以下 todo

  • 选择 estimator 的名称和参数

  • 填写 __init__ 方法:将参数写入 self,调用 super__init__ 方法

  • 填写模块和 estimator 的 docstring。建议在参数确定后尽早完成此步骤,这通常对后续实现非常有益,可作为规范来遵循。

  • 填写 estimator 的标签。一些标签是“capabilities”(能力),即 estimator 能做什么,例如处理 nans。其他标签确定在“内部”方法 _fit 等中输入的格式,这些标签通常称为 X_inner_mtype 或类似名称。这在内部功能假设输入为 numpy.ndarraypandas.DataFrame 的情况下非常有用,有助于避免转换样板代码。类型字符串可以在 datatypes.MTYPE_REGISTER 中找到。有关数据类型约定的教程,请参阅 examples/AA_datatypes_and_datasets

  • 填写“内部”方法,例如 _fit_predict。应遵循扩展模板中的 docstring 和注释。docstring 还描述了对“内部”方法输入的保证,这些保证通常比对公共方法输入的保证更强,并由已设置的标签值确定。例如,将 forecaster 的标签 y_inner_mtype 设置为 pd.DataFrame 可保证 _fit 看到的 y 将是 pandas.DataFrame,并且符合 sktime 中的其他数据容器规范(例如,索引类型)。

  • get_test_params 中填写测试参数。参数的选择应涵盖主要的 estimator 内部情况区分,以实现良好的覆盖率。

一些常见注意事项,也在扩展模板文本中描述

  • __init__ 参数应写入 self 且绝不应更改

  • 特殊情况:estimator 组件,即作为 estimator 的参数,通常应被克隆(通过 sklearn.clone),且方法应仅在克隆对象上调用

  • 方法通常应避免对参数产生副作用

  • 非状态更改方法通常不应写入 self

  • 通常,不需要实现 get_paramsset_params,因为 sktimeBaseEstimator 继承自 sklearnBaseEstimator。自定义的 get_paramsset_params 通常仅在异构复合(例如,包含 estimator 的嵌套结构的参数的管道)等复杂情况下才需要。

如何测试接口一致性#

有关以下内容的视频教程和更多示例,请访问我们在 pydata 上的教程

使用 check_estimator 工具#

通常,测试与 sktime 接口一致性的最简单方法是使用 utils.estimator_checks 模块中的 check_estimator 方法。

调用时,这将收集 sktime 中与该 estimator 类型相关的测试,并在该 estimator 上运行它们。

这可以在 notebook 环境中用于手动调试。运行 NaiveForecaster 完整测试套件的示例

from sktime.utils.estimator_checks import check_estimator
from sktime.forecasting.naive import NaiveForecaster
check_estimator(NaiveForecaster)

默认情况下,check_estimator 工具将返回一个 dict,其索引是测试/fixture 组合字符串,即测试名称和方括号中的 fixture 组合字符串。示例:'test_repr[NaiveForecaster-2]',其中 test_repr 是测试名称,NaiveForecaster-2 是 fixture 组合字符串。

返回的 dict 的值,如果测试成功,是字符串 "PASSED",否则是测试失败时会引发的异常。默认情况下,check_estimator 不会引发异常,而是将它们作为字典值返回。为了引发异常(例如,用于调试),请使用参数 raise_exceptions=True,这将引发异常而不是将其作为字典值返回。在这种情况下,最多只会引发一个异常,即在测试执行顺序中遇到的第一个异常。

要运行或排除某些测试,请使用 tests_to_runtests_to_exclude 参数。提供的值应为测试名称(字符串),或测试名称列表。请注意,测试名称不包含方括号中的部分。

例如,使用所有 fixture 运行测试 test_constructor

check_estimator(NaiveForecaster, tests_to_run="test_constructor")

{'test_constructor[NaiveForecaster]': 'PASSED'}

要运行或排除某些测试-fixture 组合,请使用 fixtures_to_runfixtures_to_exclude 参数。提供的值应为测试-fixture 组合字符串(字符串),或此类字符串列表。有效字符串正是使用默认参数调用 check_estimator 时返回的字典键。

例如,运行测试-fixture 组合 "test_repr[NaiveForecaster-2]"

check_estimator(NaiveForecaster, fixtures_to_run="test_repr[NaiveForecaster-2]")

{'test_repr[NaiveForecaster-2]': 'PASSED'}

使用 check_estimator 调试 estimator 的有用工作流程如下

  1. 运行 check_estimator(MyEstimator) 查找失败的测试

  2. 使用 fixtures_to_runtests_to_run 子集到失败的测试或 fixture

  3. 如果失败原因不明显,设置 raise_exceptions=True 以引发异常并检查回溯信息。

  4. 如果失败原因仍然不清楚,请在调用 check_estimator 的代码行上使用高级调试器。

在仓库克隆中运行测试套件#

如果 estimator 的目标位置在 sktime 内部,则可以运行 sktime 测试套件。 sktime 测试套件(和 CI/CD)基于 pytestpytest 将自动收集特定类型的所有 estimator 以及适用于给定 estimator 的测试。

有关测试框架的概览,请参阅“测试框架”文档。通用接口一致性测试包含在 TestAllEstimatorsTestAllForecasters 等类中。对于 estimator EstimatorNamepytest 测试-fixture 字符串将始终包含 EstimatorName 作为子字符串,并且与 check_estimator 返回的测试-fixture 字符串相同。

要从控制台仅运行给定 estimator 的测试,可以使用命令 pytest -k "EstimatorName"。这通常与使用 check_estimator(EstimatorName) 具有相同的效果,只是通过直接调用 pytest。使用 Visual Studio Code 或 pycharm 时,也可以使用 GUI 过滤功能对测试进行子集化——有关此内容,请参阅相应的 IDE 测试集成文档。

要识别适用于特定 estimator 的测试代码库位置,一种快速方法是在代码库中搜索由 check_estimator 生成的测试字符串,并在其前面加上 def(表示函数/方法定义)。

在第三方扩展包中进行测试#

对于 sktime 的第三方扩展包(开源或闭源),或旨在与 sktime 接口兼容的第三方模块,可以通过以下方式导入和扩展 sktime 测试套件

  • 导入 check_estimator,这将一次性执行 sktime 中定义的测试。check_estimator 可以在任何测试框架中运行,包括 unittestpytest

  • sktime.utils.estimator_checks 导入 parametrize_with_checks。当在 pytest 测试套件中使用时,这将使用 sktime 中定义的所有测试来参数化一个测试函数,针对 estimator 类或实例列表,将每个 estimator-测试组合作为单独的测试用例运行。这种模式需要在测试套件中添加以下测试函数

    from sktime.utils.estimator_checks import parametrize_with_checks
    
    @parametrize_with_checks(OBJS_TO_TEST)
    def test_sktime_api_compliance(obj, test_name):
        check_estimator(obj, tests_to_run=test_name, raise_exceptions=True)
    
  • 导入测试类,例如 test_all_estimators.TestAllEstimatorstest_all_forecasters.TestAllForecasters。这些导入将直接由 pytest 发现。测试套件也可以通过继承这些测试类来扩展。

将与 sktime 兼容的 estimator 添加到 sktime#

将与 sktime 兼容的 estimator 添加到 sktime 本身时,还需要做一些额外的事情

  • 确保代码也符合 sktime文档标准。

  • 将 estimator 添加到 sktime API 参考中。这通过在 docs/source/api_reference 目录中正确的 rst 文件中添加对 estimator 的引用来完成。

  • estimator 的作者应将自己添加到 estimator 的 "authors""maintainers" 标签中,作为贡献的 estimator 的所有者。

  • 如果 estimator 依赖于软依赖项,或添加新的软依赖项,应遵循“依赖项”开发者指南中的步骤

  • 确保 estimator 在其目标位置通过 sktime 的整个本地测试套件。要仅运行该 estimator 的测试,可以使用命令 pytest -k "EstimatorName"(或 VS Code GUI 过滤功能)

  • 确保在 get_test_params 中选择的测试参数使得 estimator 特定的测试在 sktime 远程 CI/CD 上的运行时长保持在秒级别

请勿惊慌——在向 sktime 贡献时,核心开发者会在 PR 审阅中就上述内容提供有用的指导。

建议开启一个草稿 PR,以便尽早获得反馈。

依赖于 cython 的 Estimator#

要将依赖于 cython 的 estimator 添加到 sktime,还需要以下额外步骤

  • 所有 cython 代码应存在于 pypi 和/或 conda-forge 上的单独包中。任何依赖于 cython 的代码都不应直接添加到 sktime 中。下面,为简化引用,我们将此单独的包称为 home-package

  • home-package 中,建议通过 check_estimator 对 estimator 进行测试,测试矩阵与 sktime 相同:所有支持的 python 版本;MacOS、Linux、Windows。

  • sktime 中,应添加算法的接口。如果 home-package 中的算法已经通过了 check_estimator,这可以是简单的从 home-package 导入。

  • 或者,可以通过委托者(delegator)作为委托对象(delegate)来连接算法接口,可以在委托者中添加标签和方法覆盖。例如,可以参考 MrSQM

  • 对于 sktime 接口,应将 requires_cython 标签设置为 True,并将 python_dependencies 标签设置为字符串 "home-package"

如果一切设置正确,该 estimator 将由 CI 元素 test-cython-estimatorssktime 中进行测试。请注意,此 CI 元素不涵盖 python 版本和操作系统的完整测试矩阵,这应在上游包中完成。