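In scikit-learn, the target classes can be cast to either float or int; both are allowed (even strings work, see: Is numerical encoding necessary for the target variable in classification?). Just keep in mind that predictions keep the same dtype as the target you fit on, so if your target is float you will get a float vector of predictions (see: https://scikit-learn.org/stable/tutorial/basic/tutorial.html#type-casting).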
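In this example you can verify directly that KNeighborsClassifier produces the same class predictions (only the dtype differs, depending on the dtype of the target classes passed to fit):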
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    data.data,
    data.target,
    test_size=0.33,
    random_state=42,
)
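neigh = KNeighborsClassifier(n_neighbors=3)
# Fit on integer targets and predict
neigh.fit(X_train, y_train.astype(int))
int_preds = neigh.predict(X_test)
# Refit the same estimator on float targets and predict again
neigh.fit(X_train, y_train.astype(float))
float_preds = neigh.predict(X_test)
# The dtypes differ, the predicted classes do not
print(int_preds.dtype, float_preds.dtype)
print("Same classes:", (int_preds == float_preds).all())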
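The string case mentioned at the top can be checked the same way. This is only a minimal sketch: using the iris class names from data.target_names as string labels is my own choice for illustration, and the variable names (names, y_train_str, str_preds) are not part of the example above.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.33, random_state=42
)

# Assumption for illustration: use the iris class names as string labels.
names = data.target_names
y_train_str = names[y_train]

neigh = KNeighborsClassifier(n_neighbors=3)
# fit() returns the estimator itself, so fit and predict can be chained.
int_preds = neigh.fit(X_train, y_train).predict(X_test)
str_preds = neigh.fit(X_train, y_train_str).predict(X_test)

print(int_preds.dtype, str_preds.dtype)
# Map the integer predictions to their names before comparing.
print("Same classes:", (names[int_preds] == str_preds).all())
```

The predictions come back as strings here, so any downstream code that expects numeric labels would need to map them back itself.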