10X faster predict by replacing DataFrame with dict (#2299)

This commit is contained in:
Oren Matar 2023-01-09 17:14:01 +02:00 committed by GitHub
parent 64150d0409
commit 0716751288
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1442,7 +1442,7 @@ class Prophet(object):
sim_values[k] = np.column_stack(v)
return sim_values
def sample_model(self, df, seasonal_features, iteration, s_a, s_m):
def sample_model(self, df, seasonal_features, iteration, s_a, s_m) -> Dict[str, np.ndarray]:
"""Simulate observations from the extrapolated generative model.
Parameters
@ -1455,7 +1455,7 @@ class Prophet(object):
Returns
-------
Dataframe with trend and yhat, each like df['t'].
Dictionary with `yhat` and `trend`, each like df['t'].
"""
trend = self.sample_predictive_trend(df, iteration)
@ -1467,10 +1467,10 @@ class Prophet(object):
sigma = self.params['sigma_obs'][iteration]
noise = np.random.normal(0, sigma, df.shape[0]) * self.y_scale
return pd.DataFrame({
return {
'yhat': trend * (1 + Xb_m) + Xb_a + noise,
'trend': trend
})
}
def sample_model_vectorized(
self,
@ -1480,12 +1480,12 @@ class Prophet(object):
s_a: np.ndarray,
s_m: np.ndarray,
n_samples: int,
) -> List[pd.DataFrame]:
) -> List[Dict[str, np.ndarray]]:
"""Simulate observations from the extrapolated generative model. Vectorized version of sample_model().
Returns
-------
List (length n_samples) of DataFrames with np.arrays for trend and yhat, each ordered like df['t'].
List (length n_samples) of dictionaries with arrays for trend and yhat, each ordered like df['t'].
"""
# Get the seasonality and regressor components, which are deterministic per iteration
beta = self.params['beta'][iteration]
@ -1499,10 +1499,10 @@ class Prophet(object):
simulations = []
for trend, noise in zip(trends, noise_terms):
simulations.append(pd.DataFrame({
simulations.append({
'yhat': trend * (1 + Xb_m) + Xb_a + noise,
'trend': trend
}))
})
return simulations
def sample_predictive_trend(self, df, iteration):