In a pandas DataFrame, for every row, I want to keep only the top N values and set everything else to 0. I could iterate over the rows, but I am sure Python/pandas can do this elegantly in a single line.
For example, with N = 2:
Input:
A B C D
4 10 10 6
5 20 50 90
6 30 6 4
7 40 12 9
Output:
A B C D
0 10 10 0
0 0 50 90
6 30 6 0
0 40 12 0
Use rank with axis=1, method='min' and ascending=False:
N = 2
df = df.mask(df.rank(axis=1, method='min', ascending=False) > N, 0)
Or use np.where with pd.DataFrame, which can be faster than the mask approach (pass index=df.index as well so the original row labels are kept):
df = pd.DataFrame(np.where(df.rank(axis=1, method='min', ascending=False) > N, 0, df),
                  index=df.index, columns=df.columns)
print(df)
A B C D
0 0 10 10 0
1 0 0 50 90
2 6 30 6 0
3 0 40 12 0
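For anyone who wants to reproduce this, here is a minimal self-contained sketch that builds the example frame from the question and checks that both approaches agree. The variable names ranks, via_mask and via_where are only illustrative, and the ranks are computed once and reused rather than recomputed in each call.

```python
import numpy as np
import pandas as pd

# Example data from the question.
df = pd.DataFrame({
    "A": [4, 5, 6, 7],
    "B": [10, 20, 30, 40],
    "C": [10, 50, 6, 12],
    "D": [6, 90, 4, 9],
})
N = 2

# Rank within each row; the largest value gets rank 1, ties share the lowest rank.
ranks = df.rank(axis=1, method="min", ascending=False)

# Approach 1: mask replaces values where the condition is True.
via_mask = df.mask(ranks > N, 0)

# Approach 2: np.where on the underlying arrays, then rebuild the frame
# with the original index and columns.
via_where = pd.DataFrame(
    np.where(ranks.to_numpy() > N, 0, df.to_numpy()),
    index=df.index,
    columns=df.columns,
)

print(via_mask)
print(via_mask.equals(via_where))
```

Whether np.where is noticeably faster will depend on the size of the frame, so it is worth timing both on your own data.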
Step 1:
First, we need to find the N largest numbers in each row, taking duplicates into account. Passing axis=1 ranks the values within each row, ascending=False gives the largest value rank 1, and method='min' gives tied values the same (lowest) rank:
print(df.rank(axis=1, method='min', ascending=False))
A B C D
0 4.0 1.0 1.0 3.0
1 4.0 3.0 2.0 1.0
2 2.0 1.0 2.0 4.0
3 4.0 1.0 2.0 3.0
Step 2: Next, we flag the positions whose rank is greater than N, and then replace those values with 0 using mask:
print(df.rank(axis=1, method='min', ascending=False) > N)
A B C D
0 True False False True
1 True True False False
2 False False False True
3 True False False True
print(df.mask(df.rank(axis=1, method='min', ascending=False) > N, 0))
A B C D
0 0 10 10 0
1 0 0 50 90
2 6 30 6 0
3 0 40 12 0
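One consequence of method='min' is visible in row 2 above: the two 6s tie for rank 2, so three values are kept even though N = 2. If you would rather keep exactly N values per row, one possible variation (an assumption about what you might want, not something the question asks for) is method='first', which pandas documents as assigning ranks to ties in the order they appear. A rough sketch, with illustrative names keep_ties and keep_exact:

```python
import pandas as pd

# Same example data as in the question.
df = pd.DataFrame({
    "A": [4, 5, 6, 7],
    "B": [10, 20, 30, 40],
    "C": [10, 50, 6, 12],
    "D": [6, 90, 4, 9],
})
N = 2

# method='min': tied values share a rank, so a row can keep more than N values.
keep_ties = df.mask(df.rank(axis=1, method="min", ascending=False) > N, 0)

# method='first': ties get distinct ranks, so each row keeps exactly N values.
keep_exact = df.mask(df.rank(axis=1, method="first", ascending=False) > N, 0)

# Count the surviving (non-zero) values per row for each variant.
# This count is only meaningful because the example data contains no zeros.
print((keep_ties != 0).sum(axis=1).tolist())
print((keep_exact != 0).sum(axis=1).tolist())
```

Which behaviour is right depends on whether tied values at the cut-off should all survive or not.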