Pandas create a mask based on multiple thresholds
Problem:
Let's say there is a pandas DataFrame:
d = {'A': [0.1, 0.4, 0.2, 0.2],
'B': [0.7, 0.3, 0.2, 0.9],
'Z': [0.5, 0.3, 0.4, 0.6],
'sth': ['abc', 'something', 'unimportant', 'x']}
df = pd.DataFrame(data = d)
df
 | A | B | Z | sth |
---|---|---|---|---|
0 | 0.1 | 0.7 | 0.5 | "abc" |
1 | 0.4 | 0.3 | 0.3 | "something" |
2 | 0.2 | 0.2 | 0.4 | "unimportant" |
3 | 0.2 | 0.9 | 0.6 | "x" |
thresholds = {'A': 0.5, 'B':0.8, 'Z': 0.3}
I want to find a mask that is True for each row whose highest value is lower than the threshold defined for the column (class) in which that highest value occurs.
For the given example, the correct mask would be:
[ True, True, False, False]
Explanation:
- Row 0. First find the highest value in this row: max([0.1, 0.7, 0.5]) = 0.7. Note that 0.7 was in column B. Compare this value with the threshold (0.8) for column B. 0.8 > 0.7, so the result is True.
- Row 1 has its highest value in column A, because max([0.4, 0.3, 0.3]) = 0.4, and the threshold for class A is 0.5, hence True.
- Row 2 has its highest value in column Z, because max([0.2, 0.2, 0.4]) = 0.4, and the threshold for class Z is 0.3, hence False.
- Row 3 has its highest value in column B, because max([0.2, 0.9, 0.6]) = 0.9, and the threshold for class B is 0.8. Because 0.8 < 0.9, this row is False.
Solution 1:
You could use apply with a lambda function to work out which rows stay within their threshold.
Try this:
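def within_threshold(x, thresh):
    # Column holding the row's largest value among the thresholded columns
    key = pd.to_numeric(x[list(thresh)]).idxmax()
    # True when that maximum is below the threshold for its column
    return x[key] < thresh[key]
df["within_threshold"] = df.apply(lambda x: within_threshold(x, thresholds), axis=1)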
df
The full code snippet:
import pandas as pd

thresholds = {'A': 0.5, 'B': 0.8, 'Z': 0.3}
d = {'A': [0.1, 0.4, 0.2, 0.2], 'B': [0.7, 0.3, 0.2, 0.9], 'Z': [0.5, 0.3, 0.4, 0.2], 'sth': ["a", "b", "c", "d"]}
df = pd.DataFrame(data=d)

def within_threshold(x, thresh):
    # Column holding the row's largest value among the thresholded columns
    key = pd.to_numeric(x[list(thresh)]).idxmax()
    # True when that maximum is below the threshold for its column
    return x[key] < thresh[key]

df["within_threshold"] = df.apply(lambda x: within_threshold(x, thresholds), axis=1)
df
Should get you this:
     A    B    Z sth  within_threshold
0  0.1  0.7  0.5   a              True
1  0.4  0.3  0.3   b              True
2  0.2  0.2  0.4   c             False
3  0.2  0.9  0.2   d             False
Also, from your example data, row 0 has a Z value of 0.5, which is above the Z threshold.
Edit by OP
This answer led me to the solution, so I edited it, and it now solves the problem.
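For larger frames, the same rule can probably be expressed without apply. This is only a sketch and not part of the original answer; it assumes the data from the question, and the names cols, row_max, max_col and mask are introduced here purely for illustration:

```python
import pandas as pd

thresholds = {'A': 0.5, 'B': 0.8, 'Z': 0.3}
df = pd.DataFrame({'A': [0.1, 0.4, 0.2, 0.2],
                   'B': [0.7, 0.3, 0.2, 0.9],
                   'Z': [0.5, 0.3, 0.4, 0.6],
                   'sth': ['abc', 'something', 'unimportant', 'x']})

# Only the columns that have a threshold take part in the comparison
cols = list(thresholds)

# Largest value in each row, and the column label where it occurs
row_max = df[cols].max(axis=1)
max_col = df[cols].idxmax(axis=1)

# Look up the threshold of the winning column and compare row by row
mask = row_max < max_col.map(thresholds)

print(mask.tolist())  # [True, True, False, False]
```

The resulting mask is an ordinary boolean Series, so it can be used for filtering, for example df[mask].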
Solution 2:
A list comprehension can do the work straightforwardly:
[df[col].max() < thresholds[col] for col in thresholds]
Note that this compares each column's maximum with its threshold, so it yields one boolean per column rather than one per row. However, I wouldn't use a list for the result but rather a dictionary, with the column name as key and the resulting boolean as value. Indexing with integers could be a bit ambiguous depending on the DataFrame you are using.
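A minimal sketch of the dictionary variant mentioned above, assuming the df and thresholds from the question; the name per_column is hypothetical. Keep in mind that it checks column maxima, so it does not produce the row mask the question asks for:

```python
import pandas as pd

thresholds = {'A': 0.5, 'B': 0.8, 'Z': 0.3}
df = pd.DataFrame({'A': [0.1, 0.4, 0.2, 0.2],
                   'B': [0.7, 0.3, 0.2, 0.9],
                   'Z': [0.5, 0.3, 0.4, 0.6],
                   'sth': ['abc', 'something', 'unimportant', 'x']})

# One boolean per thresholded column: is the column's maximum below its threshold?
per_column = {col: bool(df[col].max() < limit) for col, limit in thresholds.items()}

print(per_column)  # {'A': True, 'B': False, 'Z': False}
```

Looking results up by column name (per_column['B']) avoids depending on the order of the columns.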