 

Python: convert prediction result into one-hot

Tags:

python

I ran a Python neural network prediction expecting a one-hot result, but got back numbers like these:

[[0.33058667182922363, 0.3436272442340851, 0.3257860243320465], 
[0.32983461022377014, 0.3487854599952698, 0.4213798701763153], 
[0.3311253488063812, 0.3473075330257416, 0.3215670585632324], 
[0.38368630170822144, 0.35151687264442444, 0.3247968554496765], 
[0.3332786560058594, 0.343686580657959, 0.32303473353385925]]

How can I convert the array into a one-hot result, i.e.

[[0,1,0],
 [0,0,1],
 [0,1,0],
 [1,0,0],
 [0,1,0]]
asked May 12 '26 at 22:05 by user1768619



2 Answers

By "one-hot result" I assume you want the maximum value of each sub-list to become 1 and the rest to become 0 (based on the pattern in your expected result). You can do it with a list comprehension:

>>> [[int(item == max(sublist)) for item in sublist] for sublist in my_list]
#      ^  converts the bool returned by `==` into an `int`: True -> 1, False -> 0
[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 0]]

where my_list is your initial list.
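For a self-contained check, here is the same comprehension applied to the numbers from the question, stored in my_list. The printed result is what this data produces when each row's largest probability is marked with 1:

```python
# The predicted probabilities from the question, one row per sample.
my_list = [[0.33058667182922363, 0.3436272442340851, 0.3257860243320465],
           [0.32983461022377014, 0.3487854599952698, 0.4213798701763153],
           [0.3311253488063812, 0.3473075330257416, 0.3215670585632324],
           [0.38368630170822144, 0.35151687264442444, 0.3247968554496765],
           [0.3332786560058594, 0.343686580657959, 0.32303473353385925]]

# int(True) == 1 and int(False) == 0, so an item becomes 1 only if it equals its row's maximum.
one_hot = [[int(item == max(sublist)) for item in sublist] for sublist in my_list]
print(one_hot)  # [[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 0]]
```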

However, the approach above recomputes max(sublist) for every item while iterating over the sub-list, so a row of k values costs O(k²) comparisons instead of O(k). A better way is to compute the maximum once per sub-list:

def get_hot_value(my_list):
    max_val = max(my_list)
    return [int(item == max_val) for item in my_list]

hot_list = [get_hot_value(sublist) for sublist in my_list]
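Because this version compares every item against the maximum, a row containing a tie gets more than one 1. A quick check with a made-up row (the values are illustrative, not from the question):

```python
def get_hot_value(my_list):
    # Compute the maximum once per row, then compare every item against it.
    max_val = max(my_list)
    return [int(item == max_val) for item in my_list]

# Two positions share the largest value, so both are marked.
tied_row = [0.2, 0.4, 0.4]
print(get_hot_value(tied_row))  # [0, 1, 1]
```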

Edit: If each list must contain exactly one 1 (in case more than one element equals the maximum value), you can modify the get_hot_value function as:

def get_hot_value(my_list):
    max_val, hot_list, is_max_found = max(my_list), [], False
    for item in my_list:
        if item == max_val and not is_max_found:
            hot_list.append(1)
            is_max_found = True  # only the first maximum gets the 1
        else:
            hot_list.append(0)
    return hot_list
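A shorter alternative sketch (a swapped-in technique, not part of the original answer) relies on list.index, which returns the position of the first element equal to its argument, so ties resolve to the leftmost maximum. The helper name get_hot_value_first is hypothetical:

```python
def get_hot_value_first(row):
    # Hypothetical helper: start from all zeros, then set a single 1
    # at the first position holding the row's maximum.
    hot = [0] * len(row)
    hot[row.index(max(row))] = 1
    return hot

print(get_hot_value_first([0.2, 0.4, 0.4]))  # [0, 1, 0]
```

This makes two passes over the row (one for max, one for index), which is still O(k) per row.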
answered May 15 '26 at 11:05 by Moinuddin Quadri



The other solutions are good and solve the problem. Alternatively, if you have NumPy:

import numpy as np
n = [[0.33058667182922363, 0.3436272442340851, 0.3257860243320465],
     [0.32983461022377014, 0.3487854599952698, 0.4213798701763153],
     [0.3311253488063812, 0.3473075330257416, 0.3215670585632324],
     [0.38368630170822144, 0.35151687264442444, 0.3247968554496765],
     [0.3332786560058594, 0.343686580657959, 0.32303473353385925]]

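# index of the largest value in each row
max_indices = np.argmax(n, axis=1)

# turn each index into a one-hot row (1 at the argmax position, 0 elsewhere)
final_values = [[int(j == i) for j in range(len(n[0]))] for i in max_indices]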

argmax finds the index of the maximum value in each row (the first occurrence if there is a tie); after that, you only need one list comprehension over those indices. This should be reasonably fast, since argmax itself runs in NumPy's compiled code.
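If you want to skip the Python-level comprehension entirely, one fully vectorized sketch (an alternative, not the answerer's code) indexes an identity matrix with the argmax results: row i of np.eye(k) is exactly the one-hot vector for class i. It assumes every row has the same number of classes:

```python
import numpy as np

n = np.array([[0.33058667182922363, 0.3436272442340851, 0.3257860243320465],
              [0.32983461022377014, 0.3487854599952698, 0.4213798701763153],
              [0.3311253488063812, 0.3473075330257416, 0.3215670585632324],
              [0.38368630170822144, 0.35151687264442444, 0.3247968554496765],
              [0.3332786560058594, 0.343686580657959, 0.32303473353385925]])

# Index of the largest value in each row (first occurrence on ties).
max_indices = np.argmax(n, axis=1)

# Integer-array indexing picks one identity-matrix row per sample.
one_hot = np.eye(n.shape[1], dtype=int)[max_indices]

print(one_hot.tolist())  # [[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 0]]
```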

answered May 15 '26 at 12:05 by Wboy



