binder

[1]:
import warnings

warnings.filterwarnings("ignore")

多元时间序列分类中的通道选择#

概述#

有时并非所有通道都用于执行分类;只有少数通道有用。[1] 提出了一种用于多元时间分类的快速通道选择技术。

[1] : 可扩展多元时间序列分类的快速通道选择 链接

[2]:
from sklearn.linear_model import RidgeClassifierCV
from sklearn.pipeline import make_pipeline

from sktime.datasets import load_UCR_UEA_dataset
from sktime.transformations.panel import channel_selection
from sktime.transformations.panel.rocket import Rocket

1 初始化流水线#

[3]:
# cs = channel_selection.ElbowClassSum()  # ECS
cs = channel_selection.ElbowClassPairwise()  # ECP
[4]:
rocket_pipeline = make_pipeline(cs, Rocket(), RidgeClassifierCV())

2 加载并拟合训练数据#

[5]:
data = "BasicMotions"
X_train, y_train = load_UCR_UEA_dataset(data, split="train", return_X_y=True)
X_test, y_test = load_UCR_UEA_dataset(data, split="test", return_X_y=True)
[6]:
rocket_pipeline.fit(X_train, y_train)
[6]:
Pipeline(steps=[('elbowclasspairwise', ElbowClassPairwise()),
                ('rocket', Rocket()),
                ('ridgeclassifiercv',
                 RidgeClassifierCV(alphas=array([ 0.1,  1. , 10. ])))])

3 对测试数据进行分类#

[7]:
rocket_pipeline.score(X_test, y_test)
[7]:
1.0

4 识别通道#

[8]:
rocket_pipeline.steps[0][1].channels_selected_
[8]:
[0, 1]
[9]:
rocket_pipeline.steps[0][1].distance_frame_
[9]:
Centroid_badminton_running Centroid_badminton_standing Centroid_badminton_walking Centroid_running_standing Centroid_running_walking Centroid_standing_walking
0 39.594679 55.752785 48.440779 63.610220 57.247383 10.717044
1 57.681767 24.390543 27.770269 60.458125 62.339120 16.370347
2 20.175911 24.126969 22.331621 25.671979 22.991555 4.897452
3 12.546212 12.439152 12.741854 6.317654 6.695743 3.585273
4 10.101196 8.865871 9.221908 6.520172 6.715702 1.299989
5 23.464251 14.568685 13.953445 18.878429 19.768549 7.228389

5 独立运行#

[10]:
cs.fit(X_train, y_train)
[10]:
ElbowClassPairwise()

6 距离矩阵#

[11]:
cs.distance_frame_
[11]:
Centroid_badminton_running Centroid_badminton_standing Centroid_badminton_walking Centroid_running_standing Centroid_running_walking Centroid_standing_walking
0 39.594679 55.752785 48.440779 63.610220 57.247383 10.717044
1 57.681767 24.390543 27.770269 60.458125 62.339120 16.370347
2 20.175911 24.126969 22.331621 25.671979 22.991555 4.897452
3 12.546212 12.439152 12.741854 6.317654 6.695743 3.585273
4 10.101196 8.865871 9.221908 6.520172 6.715702 1.299989
5 23.464251 14.568685 13.953445 18.878429 19.768549 7.228389
[12]:
cs.train_time_
[12]:
13

使用 nbsphinx 生成。Jupyter 笔记本可在此处找到:here