mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-14 20:48:08 +00:00
Build CmdStan and Stan Model in temporary directory (#2428)
This commit is contained in:
parent
a90c4f6e72
commit
9dd5c8cc4f
2 changed files with 30 additions and 13 deletions
|
|
@ -54,7 +54,7 @@ class IStanBackend(ABC):
|
|||
|
||||
|
||||
class CmdStanPyBackend(IStanBackend):
|
||||
CMDSTAN_VERSION = "2.26.1"
|
||||
CMDSTAN_VERSION = "2.31.0"
|
||||
def __init__(self):
|
||||
import cmdstanpy
|
||||
# this must be set before super.__init__() for load_model to work on Windows
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import platform
|
|||
from pathlib import Path
|
||||
from shutil import copy, copytree, rmtree
|
||||
from typing import List
|
||||
import tempfile
|
||||
|
||||
from setuptools import find_packages, setup, Extension
|
||||
from setuptools.command.build_ext import build_ext
|
||||
|
|
@ -19,13 +20,15 @@ from wheel.bdist_wheel import bdist_wheel
|
|||
MODEL_DIR = "stan"
|
||||
MODEL_TARGET_DIR = os.path.join("prophet", "stan_model")
|
||||
|
||||
CMDSTAN_VERSION = "2.26.1"
|
||||
CMDSTAN_VERSION = "2.31.0"
|
||||
BINARIES_DIR = "bin"
|
||||
BINARIES = ["diagnose", "print", "stanc", "stansummary"]
|
||||
TBB_PARENT = "stan/lib/stan_math/lib"
|
||||
TBB_DIRS = ["tbb", "tbb_2019_U8"]
|
||||
TBB_DIRS = ["tbb", "tbb_2020.3"]
|
||||
|
||||
|
||||
IS_WINDOWS = platform.platform().startswith("Win")
|
||||
|
||||
def prune_cmdstan(cmdstan_dir: str) -> None:
|
||||
"""
|
||||
Keep only the cmdstan executables and tbb files (minimum required to run a cmdstanpy commands on a pre-compiled model).
|
||||
|
|
@ -55,12 +58,13 @@ def repackage_cmdstan():
|
|||
return os.environ.get("PROPHET_REPACKAGE_CMDSTAN", "").lower() not in ["false", "0"]
|
||||
|
||||
|
||||
def maybe_install_cmdstan_toolchain():
|
||||
def maybe_install_cmdstan_toolchain() -> bool:
|
||||
"""Install C++ compilers required to build stan models on Windows machines."""
|
||||
import cmdstanpy
|
||||
|
||||
try:
|
||||
cmdstanpy.utils.cxx_toolchain_path()
|
||||
return False
|
||||
except Exception:
|
||||
try:
|
||||
from cmdstanpy.install_cxx_toolchain import run_rtools_install
|
||||
|
|
@ -70,14 +74,14 @@ def maybe_install_cmdstan_toolchain():
|
|||
|
||||
run_rtools_install({"version": None, "dir": None, "verbose": True})
|
||||
cmdstanpy.utils.cxx_toolchain_path()
|
||||
|
||||
return True
|
||||
|
||||
def install_cmdstan_deps(cmdstan_dir: Path):
|
||||
import cmdstanpy
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
if repackage_cmdstan():
|
||||
if platform.platform().startswith("Win"):
|
||||
if IS_WINDOWS:
|
||||
maybe_install_cmdstan_toolchain()
|
||||
print("Installing cmdstan to", cmdstan_dir)
|
||||
if os.path.isdir(cmdstan_dir):
|
||||
|
|
@ -91,6 +95,7 @@ def install_cmdstan_deps(cmdstan_dir: Path):
|
|||
cores=cpu_count(),
|
||||
progress=True,
|
||||
):
|
||||
|
||||
raise RuntimeError("CmdStan failed to install in repackaged directory")
|
||||
|
||||
|
||||
|
|
@ -106,12 +111,24 @@ def build_cmdstan_model(target_dir):
|
|||
"""
|
||||
import cmdstanpy
|
||||
|
||||
cmdstan_dir = (Path(target_dir) / f"cmdstan-{CMDSTAN_VERSION}").resolve()
|
||||
install_cmdstan_deps(cmdstan_dir)
|
||||
model_name = "prophet.stan"
|
||||
target_name = "prophet_model.bin"
|
||||
sm = cmdstanpy.CmdStanModel(stan_file=os.path.join(MODEL_DIR, model_name))
|
||||
copy(sm.exe_file, os.path.join(target_dir, target_name))
|
||||
target_cmdstan_dir = (Path(target_dir) / f"cmdstan-{CMDSTAN_VERSION}").resolve()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# long paths on windows can cause problems during build
|
||||
if IS_WINDOWS:
|
||||
cmdstan_dir = (Path(tmp_dir) / f"cmdstan-{CMDSTAN_VERSION}").resolve()
|
||||
else:
|
||||
cmdstan_dir = target_cmdstan_dir
|
||||
|
||||
install_cmdstan_deps(cmdstan_dir)
|
||||
model_name = "prophet.stan"
|
||||
|
||||
temp_stan_file = copy(os.path.join(MODEL_DIR, model_name), cmdstan_dir)
|
||||
sm = cmdstanpy.CmdStanModel(stan_file=temp_stan_file)
|
||||
target_name = "prophet_model.bin"
|
||||
copy(sm.exe_file, os.path.join(target_dir, target_name))
|
||||
|
||||
if IS_WINDOWS:
|
||||
copytree(cmdstan_dir, target_cmdstan_dir)
|
||||
|
||||
# Clean up
|
||||
for f in Path(MODEL_DIR).iterdir():
|
||||
|
|
@ -119,7 +136,7 @@ def build_cmdstan_model(target_dir):
|
|||
os.remove(f)
|
||||
|
||||
if repackage_cmdstan():
|
||||
prune_cmdstan(cmdstan_dir)
|
||||
prune_cmdstan(target_cmdstan_dir)
|
||||
|
||||
|
||||
def get_backends_from_env() -> List[str]:
|
||||
|
|
|
|||
Loading…
Reference in a new issue