实现 Estimator#
本页介绍如何实现与 sktime
兼容的 estimator,以及如何确保和测试兼容性。对于直接贡献给 sktime
的 estimator,还需要额外的步骤。
实现与 sktime
兼容的 estimator#
实现与 sktime
兼容的 estimator 的高级步骤如下
确定 estimator 的类型:forecaster(预测器)、classifier(分类器)等
将该类型 estimator 的扩展模板复制到其目标位置
填写完整的扩展模板
运行
sktime
测试套件和/或check_estimator
工具(参见此处)如果测试套件发现错误或问题,修复它们并返回步骤 4
有关如何实现自己的 estimator 的更多指导,请参阅 pydata 上的此教程,内容关于接口一致性测试。
我的学习任务是什么?#
sktime
按照包含特定学习任务的模块进行组织,例如 forecasting(预测)或 time series classification(时间序列分类)。为简洁起见,我们根据 estimator 解决的正式学习任务定义其科学类型或“scitype”。例如,解决预测任务的 estimator 的 scitype 是“forecaster”。解决时间序列分类任务的 estimator 的 scitype 是“time series classifier”。
给定 scitype 的 estimator 应位于相应的模块中。estimator 的 scitype 也对应于 sktime
的 extension_templates 目录中找到的不同扩展模板。
通常,给定 estimator 的 scitype 直接由其功能决定。这通常也在与 estimator 相关的出版物中明确指出。例如,大多数教科书在预测的背景下提到 ARIMA,因此在这种假设情况下,考虑“forecaster”模板是有意义的。然后,检查模板并查看类的各种方法是否清晰地映射到 estimator 的例程。如果不匹配,则可能需要使用另一个模板。
这里最常见的困惑点在于 transformers 和其他 estimator 类型之间的区分,因为 transformers 经常被用作其他类型算法的一部分。
如果不确定,请随时在 sktime
的社交渠道之一上提问。请勿惊慌——学术出版物对 estimator 类型描述不明确的情况并不少见,即使对专家来说,正确分类也可能很困难。
sktime
扩展模板是什么?#
扩展模板是为新 estimator 的实现者提供的便捷“填充”模板。它们以以下方式融入 sktime
的统一接口
对于每种 scitype,都有一个公共用户接口,由相应的基类定义。例如,
BaseForecaster
为 forecaster 定义了fit
和predict
接口。所有 forecaster 都将通过继承BaseForecaster
以相同的方式实现fit
和predict
。公共接口遵循“策略”面向对象模式。对于每种 scitype,都有一个私有扩展接口,由扩展模板中的扩展契约定义。例如,用于 forecaster 的
forecaster.py
扩展模板解释了为继承自BaseForecaster
的具体 forecaster 需要填写哪些内容。在大多数扩展模板中,用户应实现私有方法(“内部”方法),例如 forecaster 的_fit
和_predict
。样板代码位于接口的公共部分,即fit
和predict
中。扩展接口遵循“模板”面向对象模式。
熟悉 scikit-learn
扩展的扩展者应注意与 scikit-learn
的以下区别
在 sktime
(具体) estimator 中,公共接口,例如 fit
和 predict
,从不被覆盖。实现发生在私有的、扩展者侧的接口中,例如 _fit
和 _predict
。
这可以避免样板代码的重复,例如 scikit-learn
中的 check_X
等。这也允许更丰富的样板代码,例如自动向量化功能或输入转换。
如何使用 sktime
扩展模板#
要使用 sktime
扩展模板,请将其复制到 estimator 的目标位置。在扩展模板内部,必要的操作用 todo
标记。典型的工作流程是搜索 todo
,并执行 todo
旁描述的操作。
扩展模板通常有以下 todo
选择 estimator 的名称和参数
填写
__init__
方法:将参数写入self
,调用super
的__init__
方法填写模块和 estimator 的 docstring。建议在参数确定后尽早完成此步骤,这通常对后续实现非常有益,可作为规范来遵循。
填写 estimator 的标签。一些标签是“capabilities”(能力),即 estimator 能做什么,例如处理 nans。其他标签确定在“内部”方法
_fit
等中输入的格式,这些标签通常称为X_inner_mtype
或类似名称。这在内部功能假设输入为numpy.ndarray
或pandas.DataFrame
的情况下非常有用,有助于避免转换样板代码。类型字符串可以在datatypes.MTYPE_REGISTER
中找到。有关数据类型约定的教程,请参阅examples/AA_datatypes_and_datasets
。填写“内部”方法,例如
_fit
和_predict
。应遵循扩展模板中的 docstring 和注释。docstring 还描述了对“内部”方法输入的保证,这些保证通常比对公共方法输入的保证更强,并由已设置的标签值确定。例如,将 forecaster 的标签y_inner_mtype
设置为pd.DataFrame
可保证_fit
看到的y
将是pandas.DataFrame
,并且符合sktime
中的其他数据容器规范(例如,索引类型)。在
get_test_params
中填写测试参数。参数的选择应涵盖主要的 estimator 内部情况区分,以实现良好的覆盖率。
一些常见注意事项,也在扩展模板文本中描述
__init__
参数应写入self
且绝不应更改特殊情况:estimator 组件,即作为 estimator 的参数,通常应被克隆(通过
sklearn.clone
),且方法应仅在克隆对象上调用方法通常应避免对参数产生副作用
非状态更改方法通常不应写入
self
通常,不需要实现
get_params
和set_params
,因为sktime
的BaseEstimator
继承自sklearn
的BaseEstimator
。自定义的get_params
和set_params
通常仅在异构复合(例如,包含 estimator 的嵌套结构的参数的管道)等复杂情况下才需要。
如何测试接口一致性#
有关以下内容的视频教程和更多示例,请访问我们在 pydata 上的教程。
使用 check_estimator
工具#
通常,测试与 sktime
接口一致性的最简单方法是使用 utils.estimator_checks
模块中的 check_estimator
方法。
调用时,这将收集 sktime
中与该 estimator 类型相关的测试,并在该 estimator 上运行它们。
这可以在 notebook 环境中用于手动调试。运行 NaiveForecaster
完整测试套件的示例
from sktime.utils.estimator_checks import check_estimator
from sktime.forecasting.naive import NaiveForecaster
check_estimator(NaiveForecaster)
默认情况下,check_estimator
工具将返回一个 dict
,其索引是测试/fixture 组合字符串,即测试名称和方括号中的 fixture 组合字符串。示例:'test_repr[NaiveForecaster-2]'
,其中 test_repr
是测试名称,NaiveForecaster-2
是 fixture 组合字符串。
返回的 dict
的值,如果测试成功,是字符串 "PASSED"
,否则是测试失败时会引发的异常。默认情况下,check_estimator
不会引发异常,而是将它们作为字典值返回。为了引发异常(例如,用于调试),请使用参数 raise_exceptions=True
,这将引发异常而不是将其作为字典值返回。在这种情况下,最多只会引发一个异常,即在测试执行顺序中遇到的第一个异常。
要运行或排除某些测试,请使用 tests_to_run
或 tests_to_exclude
参数。提供的值应为测试名称(字符串),或测试名称列表。请注意,测试名称不包含方括号中的部分。
例如,使用所有 fixture 运行测试 test_constructor
check_estimator(NaiveForecaster, tests_to_run="test_constructor")
{'test_constructor[NaiveForecaster]': 'PASSED'}
要运行或排除某些测试-fixture 组合,请使用 fixtures_to_run
或 fixtures_to_exclude
参数。提供的值应为测试-fixture 组合字符串(字符串),或此类字符串列表。有效字符串正是使用默认参数调用 check_estimator
时返回的字典键。
例如,运行测试-fixture 组合 "test_repr[NaiveForecaster-2]"
check_estimator(NaiveForecaster, fixtures_to_run="test_repr[NaiveForecaster-2]")
{'test_repr[NaiveForecaster-2]': 'PASSED'}
使用 check_estimator
调试 estimator 的有用工作流程如下
运行
check_estimator(MyEstimator)
查找失败的测试使用
fixtures_to_run
或tests_to_run
子集到失败的测试或 fixture如果失败原因不明显,设置
raise_exceptions=True
以引发异常并检查回溯信息。如果失败原因仍然不清楚,请在调用
check_estimator
的代码行上使用高级调试器。
在仓库克隆中运行测试套件#
如果 estimator 的目标位置在 sktime
内部,则可以运行 sktime
测试套件。 sktime
测试套件(和 CI/CD)基于 pytest
,pytest
将自动收集特定类型的所有 estimator 以及适用于给定 estimator 的测试。
有关测试框架的概览,请参阅“测试框架”文档。通用接口一致性测试包含在 TestAllEstimators
、TestAllForecasters
等类中。对于 estimator EstimatorName
的 pytest
测试-fixture 字符串将始终包含 EstimatorName
作为子字符串,并且与 check_estimator
返回的测试-fixture 字符串相同。
要从控制台仅运行给定 estimator 的测试,可以使用命令 pytest -k "EstimatorName"
。这通常与使用 check_estimator(EstimatorName)
具有相同的效果,只是通过直接调用 pytest
。使用 Visual Studio Code 或 pycharm 时,也可以使用 GUI 过滤功能对测试进行子集化——有关此内容,请参阅相应的 IDE 测试集成文档。
要识别适用于特定 estimator 的测试代码库位置,一种快速方法是在代码库中搜索由 check_estimator
生成的测试字符串,并在其前面加上 def
(表示函数/方法定义)。
在第三方扩展包中进行测试#
对于 sktime
的第三方扩展包(开源或闭源),或旨在与 sktime
接口兼容的第三方模块,可以通过以下方式导入和扩展 sktime
测试套件
导入
check_estimator
,这将一次性执行sktime
中定义的测试。check_estimator
可以在任何测试框架中运行,包括unittest
和pytest
。从
sktime.utils.estimator_checks
导入parametrize_with_checks
。当在pytest
测试套件中使用时,这将使用sktime
中定义的所有测试来参数化一个测试函数,针对 estimator 类或实例列表,将每个 estimator-测试组合作为单独的测试用例运行。这种模式需要在测试套件中添加以下测试函数from sktime.utils.estimator_checks import parametrize_with_checks @parametrize_with_checks(OBJS_TO_TEST) def test_sktime_api_compliance(obj, test_name): check_estimator(obj, tests_to_run=test_name, raise_exceptions=True)
导入测试类,例如
test_all_estimators.TestAllEstimators
或test_all_forecasters.TestAllForecasters
。这些导入将直接由pytest
发现。测试套件也可以通过继承这些测试类来扩展。
将与 sktime
兼容的 estimator 添加到 sktime
#
将与 sktime
兼容的 estimator 添加到 sktime
本身时,还需要做一些额外的事情
确保代码也符合
sktime
的文档标准。将 estimator 添加到
sktime
API 参考中。这通过在docs/source/api_reference
目录中正确的rst
文件中添加对 estimator 的引用来完成。estimator 的作者应将自己添加到 estimator 的
"authors"
和"maintainers"
标签中,作为贡献的 estimator 的所有者。如果 estimator 依赖于软依赖项,或添加新的软依赖项,应遵循“依赖项”开发者指南中的步骤
确保 estimator 在其目标位置通过
sktime
的整个本地测试套件。要仅运行该 estimator 的测试,可以使用命令pytest -k "EstimatorName"
(或 VS Code GUI 过滤功能)确保在
get_test_params
中选择的测试参数使得 estimator 特定的测试在sktime
远程 CI/CD 上的运行时长保持在秒级别
请勿惊慌——在向 sktime
贡献时,核心开发者会在 PR 审阅中就上述内容提供有用的指导。
建议开启一个草稿 PR,以便尽早获得反馈。
依赖于 cython 的 Estimator#
要将依赖于 cython 的 estimator 添加到 sktime
,还需要以下额外步骤
所有 cython 代码应存在于
pypi
和/或conda-forge
上的单独包中。任何依赖于 cython 的代码都不应直接添加到sktime
中。下面,为简化引用,我们将此单独的包称为home-package
。在
home-package
中,建议通过check_estimator
对 estimator 进行测试,测试矩阵与sktime
相同:所有支持的 python 版本;MacOS、Linux、Windows。在
sktime
中,应添加算法的接口。如果home-package
中的算法已经通过了check_estimator
,这可以是简单的从home-package
导入。或者,可以通过委托者(delegator)作为委托对象(delegate)来连接算法接口,可以在委托者中添加标签和方法覆盖。例如,可以参考
MrSQM
。对于
sktime
接口,应将requires_cython
标签设置为True
,并将python_dependencies
标签设置为字符串"home-package"
。
如果一切设置正确,该 estimator 将由 CI 元素 test-cython-estimators
在 sktime
中进行测试。请注意,此 CI 元素不涵盖 python 版本和操作系统的完整测试矩阵,这应在上游包中完成。