diff --git a/python/prophet/forecaster.py b/python/prophet/forecaster.py index c3027e4..7a95333 100644 --- a/python/prophet/forecaster.py +++ b/python/prophet/forecaster.py @@ -9,11 +9,12 @@ from __future__ import absolute_import, division, print_function import logging from collections import OrderedDict, defaultdict from copy import deepcopy -from datetime import timedelta, datetime -from typing import Dict, List +from datetime import timedelta +from typing import Dict, List, Union import numpy as np import pandas as pd +from numpy.typing import NDArray from prophet.make_holidays import get_holiday_names, make_holidays_df from prophet.models import StanBackendEnum @@ -21,7 +22,7 @@ from prophet.plot import (plot, plot_components) logger = logging.getLogger('prophet') logger.setLevel(logging.INFO) - +NANOSECONDS_TO_SECONDS = 1000 * 1000 * 1000 class Prophet(object): """Prophet forecaster. @@ -421,7 +422,11 @@ class Prophet(object): self.changepoints_t = np.array([0]) # dummy changepoint @staticmethod - def fourier_series(dates, period, series_order): + def fourier_series( + dates: pd.Series, + period: Union[int, float], + series_order: int, + ) -> NDArray[np.float_]: """Provides Fourier series components with the specified frequency and order. @@ -435,17 +440,19 @@ class Prophet(object): ------- Matrix with seasonality features. """ + if not (series_order >= 1): + raise ValueError("series_order must be >= 1") + # convert to days since epoch - t = np.array( - (dates - datetime(1970, 1, 1)) - .dt.total_seconds() - .astype(float) - ) / (3600 * 24.) - return np.column_stack([ - fun((2.0 * (i + 1) * np.pi * t / period)) - for i in range(series_order) - for fun in (np.sin, np.cos) - ]) + t = dates.to_numpy(dtype=int) // NANOSECONDS_TO_SECONDS / (3600 * 24.) + + x_T = t * np.pi * 2 + fourier_components = np.empty((dates.shape[0], 2 * series_order)) + for i in range(series_order): + c = x_T * (i + 1) / period + fourier_components[:, 2 * i] = np.sin(c) + fourier_components[:, (2 * i) + 1] = np.cos(c) + return fourier_components @classmethod def make_seasonality_features(cls, dates, period, series_order, prefix):