Speed up fourier_series (#2334)

This commit is contained in:
Yasir Ekinci 2023-01-11 15:49:03 +01:00 committed by GitHub
parent 5e2221f480
commit e2b4ef3a8e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -9,11 +9,12 @@ from __future__ import absolute_import, division, print_function
import logging
from collections import OrderedDict, defaultdict
from copy import deepcopy
from datetime import timedelta, datetime
from typing import Dict, List
from datetime import timedelta
from typing import Dict, List, Union
import numpy as np
import pandas as pd
from numpy.typing import NDArray
from prophet.make_holidays import get_holiday_names, make_holidays_df
from prophet.models import StanBackendEnum
@ -21,7 +22,7 @@ from prophet.plot import (plot, plot_components)
logger = logging.getLogger('prophet')
logger.setLevel(logging.INFO)
NANOSECONDS_TO_SECONDS = 1000 * 1000 * 1000
class Prophet(object):
"""Prophet forecaster.
@ -421,7 +422,11 @@ class Prophet(object):
self.changepoints_t = np.array([0]) # dummy changepoint
@staticmethod
def fourier_series(dates, period, series_order):
def fourier_series(
dates: pd.Series,
period: Union[int, float],
series_order: int,
) -> NDArray[np.float_]:
"""Provides Fourier series components with the specified frequency
and order.
@ -435,17 +440,19 @@ class Prophet(object):
-------
Matrix with seasonality features.
"""
if not (series_order >= 1):
raise ValueError("series_order must be >= 1")
# convert to days since epoch
t = np.array(
(dates - datetime(1970, 1, 1))
.dt.total_seconds()
.astype(float)
) / (3600 * 24.)
return np.column_stack([
fun((2.0 * (i + 1) * np.pi * t / period))
for i in range(series_order)
for fun in (np.sin, np.cos)
])
t = dates.to_numpy(dtype=int) // NANOSECONDS_TO_SECONDS / (3600 * 24.)
x_T = t * np.pi * 2
fourier_components = np.empty((dates.shape[0], 2 * series_order))
for i in range(series_order):
c = x_T * (i + 1) / period
fourier_components[:, 2 * i] = np.sin(c)
fourier_components[:, (2 * i) + 1] = np.cos(c)
return fourier_components
@classmethod
def make_seasonality_features(cls, dates, period, series_order, prefix):