save_model#
- save_model(sktime_model, path, conda_env=None, code_paths=None, mlflow_model=None, signature=None, input_example=None, pip_requirements=None, extra_pip_requirements=None, serialization_format='pickle')[source]#
将 sktime 模型保存到本地文件系统上的路径。
- 参数:
- sktime_model
拟合好的 sktime 模型对象。
- pathstr
模型要保存到的本地路径。
- conda_envUnion[dict, str], 可选 (默认值=None)
Conda 环境的字典表示或 conda 环境 yaml 文件路径。
- code_pathsarray-like, 可选 (默认值=None)
Python 文件依赖项(或包含文件依赖项的目录)的本地文件系统路径列表。加载模型时,这些文件会被添加到系统路径的开头。
- mlflow_model: mlflow.models.Model, 可选 (默认值=None)
mlflow.models.Model 配置,用于添加 python_function flavor。
- signaturemlflow.models.signature.ModelSignature, 可选 (默认值=None)
模型签名 mlflow.models.ModelSignature 描述了模型输入和输出
Schema
。模型签名可以从具有有效模型输入(例如,省略目标列的训练数据集)和有效模型输出(例如,在训练数据集上生成的模型预测)的数据集中推断
出来,例如from mlflow.models.signature import infer_signature train = df.drop_column("target_label") predictions = ... # compute model predictions signature = infer_signature(train, predictions)
警告
如果使用 sktime 模型进行概率预测(
predict_interval
,predict_quantiles
),由于使用这些方法时 Pandas MultiIndex 列类型的原因,返回的预测对象上的签名将无法正确推断。infer_schema
在使用模型的pyfunc
flavor 时会正常工作。- input_exampleUnion[pandas.core.frame.DataFrame, numpy.ndarray, dict, list, csr_matrix, csc_matrix], 可选
- (默认值=None)
输入示例提供了一个或多个有效模型输入的实例。该示例可以作为向模型馈送何种数据的提示。给定的示例将被转换为
Pandas DataFrame
,然后使用Pandas
面向拆分的格式序列化为 json。字节进行 base64 编码。- pip_requirementsUnion[Iterable, str], 可选 (默认值=None)
pip 需求字符串的可迭代对象(例如 [“sktime”, “-r requirements.txt”, “-c constraints.txt”])或本地文件系统上 pip 需求文件的字符串路径(例如 “requirements.txt”)
- extra_pip_requirementsUnion[Iterable, str], 可选 (默认值=None)
pip 需求字符串的可迭代对象(例如 [“pandas”, “-r requirements.txt”, “-c constraints.txt”])或本地文件系统上 pip 需求文件的字符串路径(例如 “requirements.txt”)
- serialization_formatstr, 可选 (默认值=”pickle”)
序列化模型的格式。应为“pickle”或“cloudpickle”格式之一
参考
[1]https://www.mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.Model.save
示例
>>> from sktime.datasets import load_airline >>> from sktime.forecasting.arima import ARIMA >>> from sktime.utils import mlflow_sktime >>> y = load_airline() >>> forecaster = ARIMA( ... order=(1, 1, 0), ... seasonal_order=(0, 1, 0, 12), ... suppress_warnings=True) >>> forecaster.fit(y) ARIMA(...) >>> model_path = "model" >>> mlflow_sktime.save_model( ... sktime_model=forecaster, ... path=model_path) >>> loaded_model = mlflow_sktime.load_model(model_uri=model_path) >>> loaded_model.predict(fh=[1, 2, 3])