【问题标题】:Create plotting function with one input that uses two data frames in one plot使用一个输入创建绘图函数,该函数在一个绘图中使用两个数据框
【发布时间】:2022-01-24 16:07:41
【问题描述】:

我正在尝试创建一个函数,该函数采用一个参数,输出是一个带有两条线的图,将标普 500 指数的价格与过去几年的二氧化碳排放量进行比较,从选定国家/地区的 2000 年开始。

from pandas_datareader import wb
import pandas as pd
import matplotlib.pyplot as plt

def plot_comparison(selected_country)

我有两个数据框。其中有许多国家和从 2000 年开始的每一年的二氧化碳排放数据,如图所示或以下代码:df_co2

        country                         year    co2 emissions
0       Africa Eastern and Southern     2018    600351.133333
1       Africa Eastern and Southern     2017    601323.394691
2       Africa Eastern and Southern     2016    592299.593959
3       Africa Eastern and Southern     2015    586385.004029
4       Africa Eastern and Southern     2014    601860.163983
... ... ... ...
5049    Zimbabwe                        2004    9770.000000
5050    Zimbabwe                        2003    10180.000000
5051    Zimbabwe                        2002    12490.000000
5052    Zimbabwe                        2001    13900.000000
5053    Zimbabwe                        2000    13700.000000

第二个数据框包含标准普尔 500 指数的年度收盘价,如图所示或以下代码:df_spx

    year    Close
0   2000    1320.280029
1   2001    1148.079956
2   2002    879.820007
3   2003    1111.920044
4   2004    1211.920044
5   2005    1248.290039
6   2006    1418.300049
7   2007    1468.359985
8   2008    903.250000
9   2009    1115.099976
...

这是我到目前为止得到的代码,但它(显然)不起作用,我无法继续前进。收到此错误:

NameError:名称“国家”未定义

def plot_comparison(selected_country):
    fig, ax = plt.subplots(ncols=1,
                           nrows=1, 
                           figsize=(15,6), 
                           dpi=100)
        
    for selected_country in df_co2:
        country_df = df_co2[df_co2['country'] == selected_country].copy()
        ax.plot(country_df['year'], country_df['co2 emissions'], label=country)

    ax.set_xlabel('Year', fontsize=12)        
    ax.set_ylabel('CO2 value', fontsize=12) 

    fig.legend(fontsize=12)
    plt.show()    

【问题讨论】:

  • @Mr.T 嗨,我认为您误解了所需的输出。我的目标是创建一个函数,它接受一个参数(数据框 df_co2 中的国家名称)并将其与标准普尔 500 指数一起绘制。所以基本上有两条线,一条是给定国家一年中的二氧化碳排放值,第二条是标准普尔500指数在同一年的收盘价。现在是否清楚,或者我应该提高我的解释技巧。编程新手。
  • 啊,好吧。按国家比较。那么,为什么现在不能添加ax.plot(df_spx['year'], df_spx['close'], label="S&P 500 index")
  • @Mr.T 我只需要绘制两条线:一条用于标准普尔 500 指数,第二条用于输入函数的国家用户的二氧化碳排放量,例如plot_comparison('Zimbabwe').
  • @Mr.T 我有 NameError: name 'country' is not defined。不知道为什么它没有在数据框中看到它...
  • 您在函数中将变量定义为selected_country

标签: python pandas matplotlib


【解决方案1】:

不清楚错误消息的来源。但是由于数据范围大不相同,您需要采用 twin axis 方法,这里有一个完整的示例:

import matplotlib.pyplot as plt
import pandas as pd


#sample data generation
import numpy as np
n = 30
np.random.seed(123)
country_list = list("ABCDE")
df_co2 = pd.DataFrame({"country": np.repeat(country_list, n), 
                       "year": np.tile(range(2010, 2010+n), len(country_list)), 
                       "co2 emissions": np.random.randint(10000, 200000, n * len(country_list))})    
df_spx = pd.DataFrame({"year": np.arange(2010, 2010+n), 
                       "close": np.random.randint(1000, 2000, n)})


#here starts your modified function  
def plot_comparison(selected_country):
    fig, ax = plt.subplots(ncols=1,
                           nrows=1, 
                           figsize=(15,6), 
                           dpi=100)
    
    #filter for selected country 
    #and sort for year, just in case the year column is not sorted yet  
    df_plot = df_co2[df_co2["country"] == selected_country].sort_values('year')
    #plot into the axis object
    df_plot.plot(x='year', y='co2 emissions', color="blue", label=f"country is {selected_country}", ax=ax, legend=False)
    #set the axis limits to the minimum and maximum values 
    #so that the data for different countries are represented at the same scale
    ax.set_ylim(df_co2['co2 emissions'].min(), df_co2['co2 emissions'].max())
    ax.set_xlabel('Year', fontsize=12)        
    ax.set_ylabel('CO2 value', fontsize=12) 
    
    #create a twin axis because the values of the two series will differ substantially 
    ax1 = ax.twinx()
    #plot the second line onto the twin axis
    df_spx.sort_values('year').plot(x='year', y='close', color="red", label="S&P 500", ax=ax1, legend=False)  
    ax1.set_ylabel('Close', fontsize=12) 

    fig.legend(fontsize=12, loc="upper center", ncol=2)
    plt.show()


#call the function
for item in ["A", "C", "D"]:
    plot_comparison(item)

示例输出(1 个,共 3 个):

【讨论】:

    猜你喜欢
    • 2021-12-13
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2019-09-14
    相关资源
    最近更新 更多