Name of variables in sklearn pipeline
Two key components can help make this work. The first gets the encoded feature names from the OneHotEncoder: OneHotEncoder.get_feature_names_out. Specifically, you call it on your fitted encoder as encoder.get_feature_names_out(). The second component is that sklearn.tree.export_text takes a feature_names argument, so you can pass those extracted names right into the display system. Other sklearn tree displayers also take that parameter (plot_tree, export_graphviz).
Related SO questions:
- Feature names from OneHotEncoder
- How to display feature names in sklearn decision tree?
sklearn docs here:
- https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html
- https://scikit-learn.org/stable/modules/classes.html#module-sklearn.tree (follow those links for the tree export/plot functions).
The following should work for you (Edit: I forgot the pipeline part in my example. You can use my_pipe.named_steps[step_name] to extract the OneHotEncoder. You may have to nest that, since you have nested pipelines. I added that example below.):
import pandas as pd
import numpy as np
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier, export_text
import sklearn
print(sklearn.__version__) # ---> 1.0.2 for me
ftrs = pd.DataFrame({'Sex' : ['male', 'female']*3,
                     'AgeGroup': ['0-20', '0-20',
                                  '20-60', '20-60',
                                  '80+', '80+']})
tgt = np.array([1, 1, 1, 1, 0, 1])
encoder = OneHotEncoder()
enc_ftrs = encoder.fit_transform(ftrs)
dtc = DecisionTreeClassifier().fit(enc_ftrs, tgt)
encoder_names = encoder.get_feature_names_out()
print(export_text(dtc, feature_names = list(encoder_names)))
Which for me gives the following output:
|--- AgeGroup_80+ <= 0.50
|   |--- class: 1
|--- AgeGroup_80+ >  0.50
|   |--- Sex_female <= 0.50
|   |   |--- class: 0
|   |--- Sex_female >  0.50
|   |   |--- class: 1
Including the pipeline, it looks like this:
from sklearn.pipeline import Pipeline
pipe = Pipeline([('enc', OneHotEncoder()),
                 ('dtc', DecisionTreeClassifier())])
pipe.fit(ftrs, tgt)
feature_names = list(pipe.named_steps['enc'].get_feature_names_out())
print(export_text(pipe.named_steps['dtc'],
                  feature_names = feature_names))
with the same output.