mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-14 20:48:08 +00:00
Logic simplification (#2192)
This commit is contained in:
parent
201887403b
commit
05438721e7
3 changed files with 13 additions and 26 deletions
|
|
@ -35,8 +35,8 @@ def get_holiday_names(country):
|
|||
try:
|
||||
holiday_names = getattr(hdays_part1, country)(years=years).values()
|
||||
except AttributeError as e:
|
||||
raise AttributeError(
|
||||
"Holidays in {} are not currently supported!".format(country)) from e
|
||||
raise AttributeError(f"Holidays in {country} are not currently supported!") from e
|
||||
|
||||
return set(holiday_names)
|
||||
|
||||
|
||||
|
|
@ -59,8 +59,8 @@ def make_holidays_df(year_list, country, province=None, state=None):
|
|||
try:
|
||||
holidays = getattr(hdays_part1, country)(prov=province, state=state, years=year_list, expand=False)
|
||||
except AttributeError as e:
|
||||
raise AttributeError(
|
||||
"Holidays in {} are not currently supported!".format(country)) from e
|
||||
raise AttributeError(f"Holidays in {country} are not currently supported!") from e
|
||||
|
||||
holidays_df = pd.DataFrame([(date, holidays.get_list(date)) for date in holidays], columns=['ds', 'holiday'])
|
||||
holidays_df = holidays_df.explode('holiday')
|
||||
holidays_df.reset_index(inplace=True, drop=True)
|
||||
|
|
|
|||
|
|
@ -16,9 +16,7 @@ import platform
|
|||
import logging
|
||||
logger = logging.getLogger('prophet.models')
|
||||
|
||||
PLATFORM = "unix"
|
||||
if platform.platform().startswith("Win"):
|
||||
PLATFORM = "win"
|
||||
PLATFORM = "win" if platform.platform().startswith("Win") else "unix"
|
||||
|
||||
class IStanBackend(ABC):
|
||||
def __init__(self):
|
||||
|
|
@ -98,15 +96,11 @@ class CmdStanPyBackend(IStanBackend):
|
|||
self.stan_fit = self.model.optimize(**args)
|
||||
except RuntimeError as e:
|
||||
# Fall back on Newton
|
||||
if self.newton_fallback and args['algorithm'] != 'Newton':
|
||||
logger.warning(
|
||||
'Optimization terminated abnormally. Falling back to Newton.'
|
||||
)
|
||||
args['algorithm'] = 'Newton'
|
||||
self.stan_fit = self.model.optimize(**args)
|
||||
else:
|
||||
if not self.newton_fallback or args['algorithm'] == 'Newton':
|
||||
raise e
|
||||
|
||||
logger.warning('Optimization terminated abnormally. Falling back to Newton.')
|
||||
args['algorithm'] = 'Newton'
|
||||
self.stan_fit = self.model.optimize(**args)
|
||||
params = self.stan_to_dict_numpy(
|
||||
self.stan_fit.column_names, self.stan_fit.optimized_params_np)
|
||||
for par in params:
|
||||
|
|
@ -188,11 +182,7 @@ class CmdStanPyBackend(IStanBackend):
|
|||
end = 0
|
||||
two_dims = len(data.shape) > 1
|
||||
for cname in column_names:
|
||||
if "." in cname:
|
||||
parsed = cname.split(".")
|
||||
else:
|
||||
parsed = cname.split("[")
|
||||
|
||||
parsed = cname.split(".") if "." in cname else cname.split("[")
|
||||
curr = parsed[0]
|
||||
if prev is None:
|
||||
prev = curr
|
||||
|
|
@ -208,10 +198,7 @@ class CmdStanPyBackend(IStanBackend):
|
|||
output[prev] = np.array(data[start:end])
|
||||
prev = curr
|
||||
start = end
|
||||
end += 1
|
||||
else:
|
||||
end += 1
|
||||
|
||||
end += 1
|
||||
if prev in output:
|
||||
raise RuntimeError(
|
||||
"Found repeated column name"
|
||||
|
|
@ -233,4 +220,4 @@ class StanBackendEnum(Enum):
|
|||
try:
|
||||
return StanBackendEnum[name].value
|
||||
except KeyError as e:
|
||||
raise ValueError("Unknown stan backend: {}".format(name)) from e
|
||||
raise ValueError(f"Unknown stan backend: {name}") from e
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ def get_backends_from_env() -> List[str]:
|
|||
|
||||
|
||||
def build_models(target_dir):
|
||||
print(f"Compiling cmdstanpy model")
|
||||
print("Compiling cmdstanpy model")
|
||||
build_cmdstan_model(target_dir)
|
||||
|
||||
if 'PYSTAN' in get_backends_from_env():
|
||||
|
|
|
|||
Loading…
Reference in a new issue