实现 Estimator#
本页介绍如何实现与 sktime 兼容的 estimator,以及如何确保和测试兼容性。对于直接贡献给 sktime 的 estimator,还需要额外的步骤。
实现与 sktime 兼容的 estimator#
实现与 sktime 兼容的 estimator 的高级步骤如下
确定 estimator 的类型:forecaster(预测器)、classifier(分类器)等
将该类型 estimator 的扩展模板复制到其目标位置
填写完整的扩展模板
运行
sktime测试套件和/或check_estimator工具(参见此处)如果测试套件发现错误或问题,修复它们并返回步骤 4
有关如何实现自己的 estimator 的更多指导,请参阅 pydata 上的此教程,内容关于接口一致性测试。
我的学习任务是什么?#
sktime 按照包含特定学习任务的模块进行组织,例如 forecasting(预测)或 time series classification(时间序列分类)。为简洁起见,我们根据 estimator 解决的正式学习任务定义其科学类型或“scitype”。例如,解决预测任务的 estimator 的 scitype 是“forecaster”。解决时间序列分类任务的 estimator 的 scitype 是“time series classifier”。
给定 scitype 的 estimator 应位于相应的模块中。estimator 的 scitype 也对应于 sktime 的 extension_templates 目录中找到的不同扩展模板。
通常,给定 estimator 的 scitype 直接由其功能决定。这通常也在与 estimator 相关的出版物中明确指出。例如,大多数教科书在预测的背景下提到 ARIMA,因此在这种假设情况下,考虑“forecaster”模板是有意义的。然后,检查模板并查看类的各种方法是否清晰地映射到 estimator 的例程。如果不匹配,则可能需要使用另一个模板。
这里最常见的困惑点在于 transformers 和其他 estimator 类型之间的区分,因为 transformers 经常被用作其他类型算法的一部分。
如果不确定,请随时在 sktime 的社交渠道之一上提问。请勿惊慌——学术出版物对 estimator 类型描述不明确的情况并不少见,即使对专家来说,正确分类也可能很困难。
sktime 扩展模板是什么?#
扩展模板是为新 estimator 的实现者提供的便捷“填充”模板。它们以以下方式融入 sktime 的统一接口
对于每种 scitype,都有一个公共用户接口,由相应的基类定义。例如,
BaseForecaster为 forecaster 定义了fit和predict接口。所有 forecaster 都将通过继承BaseForecaster以相同的方式实现fit和predict。公共接口遵循“策略”面向对象模式。对于每种 scitype,都有一个私有扩展接口,由扩展模板中的扩展契约定义。例如,用于 forecaster 的
forecaster.py扩展模板解释了为继承自BaseForecaster的具体 forecaster 需要填写哪些内容。在大多数扩展模板中,用户应实现私有方法(“内部”方法),例如 forecaster 的_fit和_predict。样板代码位于接口的公共部分,即fit和predict中。扩展接口遵循“模板”面向对象模式。
熟悉 scikit-learn 扩展的扩展者应注意与 scikit-learn 的以下区别
在 sktime (具体) estimator 中,公共接口,例如 fit 和 predict,从不被覆盖。实现发生在私有的、扩展者侧的接口中,例如 _fit 和 _predict。
这可以避免样板代码的重复,例如 scikit-learn 中的 check_X 等。这也允许更丰富的样板代码,例如自动向量化功能或输入转换。
如何使用 sktime 扩展模板#
要使用 sktime 扩展模板,请将其复制到 estimator 的目标位置。在扩展模板内部,必要的操作用 todo 标记。典型的工作流程是搜索 todo,并执行 todo 旁描述的操作。
扩展模板通常有以下 todo
选择 estimator 的名称和参数
填写
__init__方法:将参数写入self,调用super的__init__方法填写模块和 estimator 的 docstring。建议在参数确定后尽早完成此步骤,这通常对后续实现非常有益,可作为规范来遵循。
填写 estimator 的标签。一些标签是“capabilities”(能力),即 estimator 能做什么,例如处理 nans。其他标签确定在“内部”方法
_fit等中输入的格式,这些标签通常称为X_inner_mtype或类似名称。这在内部功能假设输入为numpy.ndarray或pandas.DataFrame的情况下非常有用,有助于避免转换样板代码。类型字符串可以在datatypes.MTYPE_REGISTER中找到。有关数据类型约定的教程,请参阅examples/AA_datatypes_and_datasets。填写“内部”方法,例如
_fit和_predict。应遵循扩展模板中的 docstring 和注释。docstring 还描述了对“内部”方法输入的保证,这些保证通常比对公共方法输入的保证更强,并由已设置的标签值确定。例如,将 forecaster 的标签y_inner_mtype设置为pd.DataFrame可保证_fit看到的y将是pandas.DataFrame,并且符合sktime中的其他数据容器规范(例如,索引类型)。在
get_test_params中填写测试参数。参数的选择应涵盖主要的 estimator 内部情况区分,以实现良好的覆盖率。
一些常见注意事项,也在扩展模板文本中描述
__init__参数应写入self且绝不应更改特殊情况:estimator 组件,即作为 estimator 的参数,通常应被克隆(通过
sklearn.clone),且方法应仅在克隆对象上调用方法通常应避免对参数产生副作用
非状态更改方法通常不应写入
self通常,不需要实现
get_params和set_params,因为sktime的BaseEstimator继承自sklearn的BaseEstimator。自定义的get_params和set_params通常仅在异构复合(例如,包含 estimator 的嵌套结构的参数的管道)等复杂情况下才需要。
如何测试接口一致性#
有关以下内容的视频教程和更多示例,请访问我们在 pydata 上的教程。
使用 check_estimator 工具#
通常,测试与 sktime 接口一致性的最简单方法是使用 utils.estimator_checks 模块中的 check_estimator 方法。
调用时,这将收集 sktime 中与该 estimator 类型相关的测试,并在该 estimator 上运行它们。
这可以在 notebook 环境中用于手动调试。运行 NaiveForecaster 完整测试套件的示例
from sktime.utils.estimator_checks import check_estimator
from sktime.forecasting.naive import NaiveForecaster
check_estimator(NaiveForecaster)
默认情况下,check_estimator 工具将返回一个 dict,其索引是测试/fixture 组合字符串,即测试名称和方括号中的 fixture 组合字符串。示例:'test_repr[NaiveForecaster-2]',其中 test_repr 是测试名称,NaiveForecaster-2 是 fixture 组合字符串。
返回的 dict 的值,如果测试成功,是字符串 "PASSED",否则是测试失败时会引发的异常。默认情况下,check_estimator 不会引发异常,而是将它们作为字典值返回。为了引发异常(例如,用于调试),请使用参数 raise_exceptions=True,这将引发异常而不是将其作为字典值返回。在这种情况下,最多只会引发一个异常,即在测试执行顺序中遇到的第一个异常。
要运行或排除某些测试,请使用 tests_to_run 或 tests_to_exclude 参数。提供的值应为测试名称(字符串),或测试名称列表。请注意,测试名称不包含方括号中的部分。
例如,使用所有 fixture 运行测试 test_constructor
check_estimator(NaiveForecaster, tests_to_run="test_constructor")
{'test_constructor[NaiveForecaster]': 'PASSED'}
要运行或排除某些测试-fixture 组合,请使用 fixtures_to_run 或 fixtures_to_exclude 参数。提供的值应为测试-fixture 组合字符串(字符串),或此类字符串列表。有效字符串正是使用默认参数调用 check_estimator 时返回的字典键。
例如,运行测试-fixture 组合 "test_repr[NaiveForecaster-2]"
check_estimator(NaiveForecaster, fixtures_to_run="test_repr[NaiveForecaster-2]")
{'test_repr[NaiveForecaster-2]': 'PASSED'}
使用 check_estimator 调试 estimator 的有用工作流程如下
运行
check_estimator(MyEstimator)查找失败的测试使用
fixtures_to_run或tests_to_run子集到失败的测试或 fixture如果失败原因不明显,设置
raise_exceptions=True以引发异常并检查回溯信息。如果失败原因仍然不清楚,请在调用
check_estimator的代码行上使用高级调试器。
在仓库克隆中运行测试套件#
如果 estimator 的目标位置在 sktime 内部,则可以运行 sktime 测试套件。 sktime 测试套件(和 CI/CD)基于 pytest,pytest 将自动收集特定类型的所有 estimator 以及适用于给定 estimator 的测试。
有关测试框架的概览,请参阅“测试框架”文档。通用接口一致性测试包含在 TestAllEstimators、TestAllForecasters 等类中。对于 estimator EstimatorName 的 pytest 测试-fixture 字符串将始终包含 EstimatorName 作为子字符串,并且与 check_estimator 返回的测试-fixture 字符串相同。
要从控制台仅运行给定 estimator 的测试,可以使用命令 pytest -k "EstimatorName"。这通常与使用 check_estimator(EstimatorName) 具有相同的效果,只是通过直接调用 pytest。使用 Visual Studio Code 或 pycharm 时,也可以使用 GUI 过滤功能对测试进行子集化——有关此内容,请参阅相应的 IDE 测试集成文档。
要识别适用于特定 estimator 的测试代码库位置,一种快速方法是在代码库中搜索由 check_estimator 生成的测试字符串,并在其前面加上 def(表示函数/方法定义)。
在第三方扩展包中进行测试#
对于 sktime 的第三方扩展包(开源或闭源),或旨在与 sktime 接口兼容的第三方模块,可以通过以下方式导入和扩展 sktime 测试套件
导入
check_estimator,这将一次性执行sktime中定义的测试。check_estimator可以在任何测试框架中运行,包括unittest和pytest。从
sktime.utils.estimator_checks导入parametrize_with_checks。当在pytest测试套件中使用时,这将使用sktime中定义的所有测试来参数化一个测试函数,针对 estimator 类或实例列表,将每个 estimator-测试组合作为单独的测试用例运行。这种模式需要在测试套件中添加以下测试函数from sktime.utils.estimator_checks import parametrize_with_checks @parametrize_with_checks(OBJS_TO_TEST) def test_sktime_api_compliance(obj, test_name): check_estimator(obj, tests_to_run=test_name, raise_exceptions=True)
导入测试类,例如
test_all_estimators.TestAllEstimators或test_all_forecasters.TestAllForecasters。这些导入将直接由pytest发现。测试套件也可以通过继承这些测试类来扩展。
将与 sktime 兼容的 estimator 添加到 sktime#
将与 sktime 兼容的 estimator 添加到 sktime 本身时,还需要做一些额外的事情
确保代码也符合
sktime的文档标准。将 estimator 添加到
sktimeAPI 参考中。这通过在docs/source/api_reference目录中正确的rst文件中添加对 estimator 的引用来完成。estimator 的作者应将自己添加到 estimator 的
"authors"和"maintainers"标签中,作为贡献的 estimator 的所有者。如果 estimator 依赖于软依赖项,或添加新的软依赖项,应遵循“依赖项”开发者指南中的步骤
确保 estimator 在其目标位置通过
sktime的整个本地测试套件。要仅运行该 estimator 的测试,可以使用命令pytest -k "EstimatorName"(或 VS Code GUI 过滤功能)确保在
get_test_params中选择的测试参数使得 estimator 特定的测试在sktime远程 CI/CD 上的运行时长保持在秒级别
请勿惊慌——在向 sktime 贡献时,核心开发者会在 PR 审阅中就上述内容提供有用的指导。
建议开启一个草稿 PR,以便尽早获得反馈。
依赖于 cython 的 Estimator#
要将依赖于 cython 的 estimator 添加到 sktime,还需要以下额外步骤
所有 cython 代码应存在于
pypi和/或conda-forge上的单独包中。任何依赖于 cython 的代码都不应直接添加到sktime中。下面,为简化引用,我们将此单独的包称为home-package。在
home-package中,建议通过check_estimator对 estimator 进行测试,测试矩阵与sktime相同:所有支持的 python 版本;MacOS、Linux、Windows。在
sktime中,应添加算法的接口。如果home-package中的算法已经通过了check_estimator,这可以是简单的从home-package导入。或者,可以通过委托者(delegator)作为委托对象(delegate)来连接算法接口,可以在委托者中添加标签和方法覆盖。例如,可以参考
MrSQM。对于
sktime接口,应将requires_cython标签设置为True,并将python_dependencies标签设置为字符串"home-package"。
如果一切设置正确,该 estimator 将由 CI 元素 test-cython-estimators 在 sktime 中进行测试。请注意,此 CI 元素不涵盖 python 版本和操作系统的完整测试矩阵,这应在上游包中完成。