¿Cómo establecer la leyenda de los subplots compartidos de matplotlib para ser horizontal y en posición central inferior?

Estoy usando Matplotlib y Seaborn para trazar cuatro gráficos de barras con una leyenda compartida. Sin embargo, no puedo hacer que la leyenda sea horizontal y en el centro inferior. Traté de establecer los números en esta línea:

 ax.legend(bbox_to_anchor=(0.99, -0.15),
                    loc=1,
                    fontsize=13,
                    # ncol=2
                    )

pero si la leyenda va al centro, entonces la distancia entre las dos columnas del subplot aumentaría también lo que no es bueno. enter image description here

Aquí está mi código:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pdb
import pyautogui
import multiprocessing
from time import sleep
from matplotlib import patches as mpatches

def convert_to_grouped_bar_chart_format(data, 
    col_1_name, col_2_name, col_3_name):
    """
        Parameters
        ----------
        data: Pandas dataframe. Format:
        
           Method               Class1  Class2     Class3
        0  Method_1           0.1       0.2        0.3
        1  Method_2           0.6       0.5        0.4
        
        Returns
        -------
        data_grouped: Pandas dataframe.
    """
    cls_list = data.columns[1:].tolist()
    col_1 = []
    col_2 = []
    col_3 = []
    (num_of_rows, num_of_cols) = data.shape
    for row_idx in range(num_of_rows):
        for cls_idx, cls in enumerate(cls_list):
            col_1.append(data.iloc[row_idx, 0])
            col_2.append(cls)
            col_3.append(data.iloc[row_idx, cls_idx+1])
            pass
        pass
    data_grouped_dict = {
        col_1_name: col_1,
        col_2_name: col_2,
        col_3_name: col_3
    }
    data_grouped = pd.DataFrame(data_grouped_dict, columns = [col_1_name, col_2_name, col_3_name])
    return data_grouped
    
def draw_four_bar_graph_seaborn():
    file_list = [
        ['Measure1_ED.csv', 'Measure1_ES.csv'],
        ['Measure2_ED.csv', 'Measure2_ES.csv']
    ]
    n_rows = len(file_list)
    n_cols = len(file_list[0])
    fig, axes = plt.subplots(n_rows, n_cols)
   
    for idx_row in range(n_rows):
        # if idx_row > 0:
            # continue
        for idx_col in range(n_cols):
            file_name = file_list[idx_row][idx_col]
            data = pd.read_csv(file_name)
            col_1_name = 'Method'
            col_2_name = 'Class'
            col_3_name = file_name.split('_')[0]
            data_type = file_name.split('_')[1][:-4]
            
            ax = axes[idx_row, idx_col]
            # ax =axes[idx_col]
            
            data_grouped = convert_to_grouped_bar_chart_format(data,
                col_1_name, col_2_name, col_3_name)    
            splot = sns.barplot(
                # ax=axes[idx_row, idx_col],
                ax=ax,
                x=col_2_name, 
                y=col_3_name, 
                hue=col_1_name,         
                palette="magma", 
                # palette=my_pal,
                # sharey=False,
                data=data_grouped)
            splot.set_xlabel("",fontsize=1)
            splot.set_ylabel(col_3_name,fontsize=13)
            splot.tick_params(labelsize=13)            
            title_subplot = 'Title 1'            
            ax.set_title(title_subplot, fontsize=13)
            if col_3_name == 'Measure1': 
                ax.set_ylim(0, 1.10)
            else:
                ax.set_ylim(0, 2.25)
            for p1 in splot.patches:        
                splot.annotate('%.3f' % p1.get_height(), 
                           (p1.get_x() + p1.get_width() / 2., p1.get_height()), 
                           ha = 'center', va = 'center', 
                           size=13,
                           xytext = (0, 8), 
                           textcoords = 'offset points')

            if (idx_row == 1) and (idx_col == 0):
                ax.legend(
                    bbox_to_anchor=(1.2, -0.15),
                    loc=1,
                    fontsize=13,
                    # ncol=2
                    )
            else:
                splot.get_legend().remove()
            
            # Change width size
            # ax = axes[idx_row, idx_col]
            new_value = 0.35
            for patch in ax.patches :
                current_width = patch.get_width()
                diff = current_width - new_value
                # we change the bar width
                patch.set_width(new_value)
                # we recenter the bar
                patch.set_x(patch.get_x() + diff * .5)

    plt.tight_layout(pad=0)    
    mng = plt.get_current_fig_manager()
    mng.window.state('zoomed') #works fine on Windows! 
    plt.show()
    fig.savefig('out.pdf')
    plt.close()

def draw_graph_then_save_and_close_automatically(func=None, args=[]):
    coords_close_graph = (1365, 12) # Auto click to close graph
    multiprocessing.Process(target=func, args=args).start()
    sleep(10)
    pyautogui.moveTo(coords_close_graph)
    pyautogui.click()

def main():
    draw_graph_then_save_and_close_automatically(
        func=draw_four_bar_graph_seaborn,
        args=[])

if __name__ == '__main__':
   main()

Por favor, ayúdame, muchas gracias.

Pregunta hecha hace 3 años, 4 meses, 28 días - Por compilerhero1251


3 Respuestas:

  • Usar una figura-legend en lugar de colocar en uno de tus ejes y establecer el número de columnas que la leyenda debe tener al número de entradas de leyenda. Aquí hay un ejemplo (he encontrado que el tuyo es lo suficientemente mínimo^^^)

    import numpy as np
    from matplotlib import pyplot as plt
    # create random data
    y = np.random.randint(0,100,size=(10, 3))
    # open a figure with two axes
    fig,axs = plt.subplots(1,2)
    # plot something in the axes
    axs[0].plot(y[:,0])
    axs[1].plot(y[:,1:])
    # define the name of the 
    legendEntries = ("a","bcdefg","h")
    # set figure legend entries, number of columns, location
    fig.legend(legendEntries,ncol=len(legendEntries),loc="upper center")
    

    legends

    Aquí. es un doc-example, haciendo hincapié en utilizar el argumento ncol forzar a matplotlib a expandir la leyenda horizontalmente. Y Aquí. es un tutorial/ejemplo cómo puedes colocar la leyenda de un axis fuera de la región del eje.

    Respondida el Dec 17, 2020 a las 10:59 - por algorithmalchemist

    Votos positivos: 0 | Votos negativos: 0

  • Para establecer la leyenda de los subplots compartidos de matplotlib como horizontal y en posición central inferior, puedes probar lo siguiente en tu código:

    • Elimina la línea `bbox_to_anchor=(1.2, -0.15)` de la función `ax.legend()`.
    • Agrega la siguiente línea después de la función `ax.legend()` para hacer que la leyenda sea horizontal y esté en la posición central inferior:
    python
    ax.legend(bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=2, fontsize=13)
    

    Esta línea moverá la leyenda a la posición central inferior y la hará horizontal con 2 columnas de elementos en la leyenda.

    Con estos cambios, tu función draw_four_bar_graph_seaborn() debería verse así:

    python
    if (idx_row == 1) and (idx_col == 0):
        ax.legend(loc='upper center', ncol=2, fontsize=13)
    else:
        splot.get_legend().remove()
    

    Prueba estos cambios y vuelve a ejecutar tu código para verificar si la leyenda se muestra como deseas. Espero que esta solución te sea útil. Si tienes alguna otra pregunta o necesitas más ayuda, ¡no dudes en decírmelo!

    Respondida el Dec 18, 2020 a las 11:01 - por Gpt

    Votos positivos: 0 | Votos negativos: 0

  • Para establecer la leyenda de los subplots compartidos de matplotlib para que sea horizontal y en la posición central inferior, puedes utilizar el siguiente código:

    python
    import matplotlib.pyplot as plt
    
    

    Crea los subplots

    fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)

    Agrega datos a los subplots

    ...

    Establece la leyenda

    axes[1, 1].legend(bbox_to_anchor=(0.5, -0.2), loc='lower center', ncol=2)

    Muestra los subplots

    plt.show()

    En este código, el argumento bbox_to_anchor especifica la posición de la leyenda en relación con el subplot actual. (0.5, -0.2) coloca la leyenda en el centro horizontal del subplot actual y 0.2 unidades por debajo del borde inferior del subplot.

    El argumento loc especifica la ubicación de la leyenda en relación con el subplot actual. 'lower center' coloca la leyenda en el centro inferior del subplot actual.

    El argumento ncol especifica el número de columnas en la leyenda. ncol=2 coloca las etiquetas de la leyenda en dos columnas.

    Ten en cuenta que la posición de la leyenda puede variar según el tamaño y la distribución de los subplots. Es posible que debas ajustar los valores de bbox_to_anchor y loc para obtener la posición deseada.

    Respondida el Dec 18, 2020 a las 11:05 - por Gemini

    Votos positivos: 0 | Votos negativos: 0