Logic simplification (#2192)

This commit is contained in:
Vincent Koc 2022-09-04 22:31:33 +00:00 committed by GitHub
parent 201887403b
commit 05438721e7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 26 deletions

View file

@ -35,8 +35,8 @@ def get_holiday_names(country):
try:
holiday_names = getattr(hdays_part1, country)(years=years).values()
except AttributeError as e:
raise AttributeError(
"Holidays in {} are not currently supported!".format(country)) from e
raise AttributeError(f"Holidays in {country} are not currently supported!") from e
return set(holiday_names)
@ -59,8 +59,8 @@ def make_holidays_df(year_list, country, province=None, state=None):
try:
holidays = getattr(hdays_part1, country)(prov=province, state=state, years=year_list, expand=False)
except AttributeError as e:
raise AttributeError(
"Holidays in {} are not currently supported!".format(country)) from e
raise AttributeError(f"Holidays in {country} are not currently supported!") from e
holidays_df = pd.DataFrame([(date, holidays.get_list(date)) for date in holidays], columns=['ds', 'holiday'])
holidays_df = holidays_df.explode('holiday')
holidays_df.reset_index(inplace=True, drop=True)

View file

@ -16,9 +16,7 @@ import platform
import logging
logger = logging.getLogger('prophet.models')
PLATFORM = "unix"
if platform.platform().startswith("Win"):
PLATFORM = "win"
PLATFORM = "win" if platform.platform().startswith("Win") else "unix"
class IStanBackend(ABC):
def __init__(self):
@ -98,15 +96,11 @@ class CmdStanPyBackend(IStanBackend):
self.stan_fit = self.model.optimize(**args)
except RuntimeError as e:
# Fall back on Newton
if self.newton_fallback and args['algorithm'] != 'Newton':
logger.warning(
'Optimization terminated abnormally. Falling back to Newton.'
)
args['algorithm'] = 'Newton'
self.stan_fit = self.model.optimize(**args)
else:
if not self.newton_fallback or args['algorithm'] == 'Newton':
raise e
logger.warning('Optimization terminated abnormally. Falling back to Newton.')
args['algorithm'] = 'Newton'
self.stan_fit = self.model.optimize(**args)
params = self.stan_to_dict_numpy(
self.stan_fit.column_names, self.stan_fit.optimized_params_np)
for par in params:
@ -188,11 +182,7 @@ class CmdStanPyBackend(IStanBackend):
end = 0
two_dims = len(data.shape) > 1
for cname in column_names:
if "." in cname:
parsed = cname.split(".")
else:
parsed = cname.split("[")
parsed = cname.split(".") if "." in cname else cname.split("[")
curr = parsed[0]
if prev is None:
prev = curr
@ -208,10 +198,7 @@ class CmdStanPyBackend(IStanBackend):
output[prev] = np.array(data[start:end])
prev = curr
start = end
end += 1
else:
end += 1
end += 1
if prev in output:
raise RuntimeError(
"Found repeated column name"
@ -233,4 +220,4 @@ class StanBackendEnum(Enum):
try:
return StanBackendEnum[name].value
except KeyError as e:
raise ValueError("Unknown stan backend: {}".format(name)) from e
raise ValueError(f"Unknown stan backend: {name}") from e

View file

@ -128,7 +128,7 @@ def get_backends_from_env() -> List[str]:
def build_models(target_dir):
print(f"Compiling cmdstanpy model")
print("Compiling cmdstanpy model")
build_cmdstan_model(target_dir)
if 'PYSTAN' in get_backends_from_env():