Use np.percentile if array does not contain NaNs (#1311)

Co-authored-by: jackd-stripe <41304233+jackd-stripe@users.noreply.github.com>
This commit is contained in:
Jack Dent 2020-02-05 13:10:01 -05:00 committed by GitHub
parent 5f97759669
commit 496facb152
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1329,10 +1329,10 @@ class Prophet(object):
comp *= self.y_scale
data[component] = np.nanmean(comp, axis=1)
if self.uncertainty_samples:
data[component + '_lower'] = np.nanpercentile(
data[component + '_lower'] = self.percentile(
comp, lower_p, axis=1,
)
data[component + '_upper'] = np.nanpercentile(
data[component + '_upper'] = self.percentile(
comp, upper_p, axis=1,
)
return pd.DataFrame(data)
@ -1410,9 +1410,9 @@ class Prophet(object):
series = {}
for key in ['yhat', 'trend']:
series['{}_lower'.format(key)] = np.nanpercentile(
series['{}_lower'.format(key)] = self.percentile(
sim_values[key], lower_p, axis=1)
series['{}_upper'.format(key)] = np.nanpercentile(
series['{}_upper'.format(key)] = self.percentile(
sim_values[key], upper_p, axis=1)
return pd.DataFrame(series)
@ -1498,6 +1498,17 @@ class Prophet(object):
return trend * self.y_scale + df['floor']
def percentile(self, a, *args, **kwargs):
"""
We rely on np.nanpercentile in the rare instances where there
are a small number of bad samples with MCMC that contain NaNs.
However, since np.nanpercentile is far slower than np.percentile,
we only fall back to it if the array contains NaNs. See
https://github.com/facebook/prophet/issues/1310 for more details.
"""
fn = np.nanpercentile if np.isnan(a).any() else np.percentile
return fn(a, *args, **kwargs)
def make_future_dataframe(self, periods, freq='D', include_history=True):
"""Simulate the trend using the extrapolated generative model.