微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

根据python中的条件绘制多色线条

我有一个包含三列和一个日期时间索引的pandas数据帧

date        px_last  200dma     50dma           
2014-12-24  2081.88 1953.16760  2019.2726
2014-12-26  2088.77 1954.37975  2023.7982
2014-12-29  2090.57 1955.62695  2028.3544
2014-12-30  2080.35 1956.73455  2032.2262
2014-12-31  2058.90 1957.66780  2035.3240

我想制作一个’px_last’列的时间序列图,如果在给定日50dma高于200dma值,则为绿色,如果50dma值低于200dma值,则为红色.我见过这个例子,但似乎无法使它适用于我的情况
http://matplotlib.org/examples/pylab_examples/multicolored_line.html

解决方法:

下面是一个例子做才不至于matplotlib.collections.LineCollection.我们的想法是先识别交叉点,然后通过groupby使用绘图功能.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# simulate data
# =============================
np.random.seed(1234)
df = pd.DataFrame({'px_last': 100 + np.random.randn(1000).cumsum()}, index=pd.date_range('2010-01-01', periods=1000, freq='B'))
df['50dma'] = pd.rolling_mean(df['px_last'], window=50)
df['200dma'] = pd.rolling_mean(df['px_last'], window=200)
df['label'] = np.where(df['50dma'] > df['200dma'], 1, -1)


# plot
# =============================
df = df.dropna(axis=0, how='any')

fig, ax = plt.subplots()

def plot_func(group):
    global ax
    color = 'r' if (group['label'] < 0).all() else 'g'
    lw = 2.0
    ax.plot(group.index, group.px_last, c=color, linewidth=lw)

df.groupby((df['label'].shift() * df['label'] < 0).cumsum()).apply(plot_func)

# add ma lines
ax.plot(df.index, df['50dma'], 'k--', label='MA-50')
ax.plot(df.index, df['200dma'], 'b--', label='MA-200')
ax.legend(loc='best')

enter image description here

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 [email protected] 举报,一经查实,本站将立刻删除。

相关推荐