
In pandas, how to convert a numeric type to category type to use with seaborn hue

I am stuck on what seems like an easy problem trying to color the different groups on a scatterplot I am creating. I have the following example dataframe and graph:

test_df = pd.DataFrame({ 'A' : 1.,
                    'B' : np.array([1, 5, 9, 7, 3], dtype='int32'),
                    'C' : np.array([6, 7, 8, 9, 3], dtype='int32'),
                    'D' : np.array([2, 2, 3, 4, 4], dtype='int32'),
                    'E' : pd.Categorical(["test","train","test","train","train"]),
                    'F' : 'foo' })

# fix to category
# test_df['D'] = test_df["D"].astype('category')

# and test plot
f, ax = plt.subplots(figsize=(6,6))
ax = sns.scatterplot(x="B", y="C", hue="D", s=100, 
                     data=test_df)

which creates a scatterplot (image not included here) where the hue for D is drawn on a continuous scale.

However, instead of a continuous scale, I'd like a categorical scale for each of the 3 categories [2, 3, 4]. After I uncomment the line test_df['D'] = ... to change this column to a category dtype for categorical coloring, the seaborn plot raises the following error:

TypeError: data type not understood

Does anybody know the correct way to convert this numeric column to a factor / categorical column to use for coloring?

Thanks!

Canovice, asked Nov 22 '25 20:11


1 Answer

I copy/pasted your code, added the library imports, and uncommented the astype('category') line, since it looked correct to me. I get a plot with categorical colouring for the values [2, 3, 4] without changing anything else in your code.

Try updating your seaborn module using: pip install --upgrade seaborn
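Before upgrading, it may help to see which versions you currently have, so you can compare them with the versions that worked for me. This is just a small sketch using the standard library's importlib.metadata (Python 3.8+); the helper name installed_versions is my own and not part of any of these libraries.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package name: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# installed_versions is a hypothetical helper written for this answer.
versions = installed_versions(["matplotlib", "numpy", "seaborn", "pandas"])
for name, ver in versions.items():
    print(f"{name}=={ver}" if ver else f"{name} is not installed")
```

If the seaborn version printed here is older than the one I used, the upgrade is worth trying first.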

Here are the library versions that work with your code:

matplotlib==3.1.2
numpy==1.18.1
seaborn==0.10.0
pandas==0.25.3

... with which I ran the code below.

import pandas as pd 
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

test_df = pd.DataFrame({ 'A' : 1.,
                    'B' : np.array([1, 5, 9, 7, 3], dtype='int32'),
                    'C' : np.array([6, 7, 8, 9, 3], dtype='int32'),
                    'D' : np.array([2, 2, 3, 4, 4], dtype='int32'),
                    'E' : pd.Categorical(["test","train","test","train","train"]),
                    'F' : 'foo' })

# fix to category
test_df['D'] = test_df["D"].astype('category')

# and test plot
f, ax = plt.subplots(figsize=(6,6))
ax = sns.scatterplot(x="B", y="C", hue="D", s=100, 
                     data=test_df)
plt.show()
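If you want to rule out pandas as the culprit, you can check the conversion on its own, without seaborn. The sketch below uses only the B, C and D columns from your dataframe. The D_label column is a hypothetical fallback of my own (string labels instead of a category dtype); I am assuming seaborn treats string values as categorical for hue, and I have not tested it against the seaborn version that gave you the error.

```python
import numpy as np
import pandas as pd

test_df = pd.DataFrame({
    "B": np.array([1, 5, 9, 7, 3], dtype="int32"),
    "C": np.array([6, 7, 8, 9, 3], dtype="int32"),
    "D": np.array([2, 2, 3, 4, 4], dtype="int32"),
})

print(test_df["D"].dtype)  # int32

# Same conversion as in the question
test_df["D"] = test_df["D"].astype("category")
print(test_df["D"].dtype)                    # category
print(test_df["D"].cat.categories.tolist())  # [2, 3, 4]

# Hypothetical fallback (my assumption, not from the question):
# string labels that could be passed as hue="D_label" instead.
test_df["D_label"] = test_df["D"].astype(str)
print(test_df["D_label"].tolist())
```

If the dtype and categories look like this on your machine, the conversion itself is fine and the error is more likely coming from the installed seaborn version.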
AabyWan, answered Nov 24 '25 11:11


