Build CmdStan and Stan Model in temporary directory (#2428)

This commit is contained in:
Brian Ward 2023-05-28 09:29:42 -04:00 committed by GitHub
parent a90c4f6e72
commit 9dd5c8cc4f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 13 deletions

View file

@ -54,7 +54,7 @@ class IStanBackend(ABC):
class CmdStanPyBackend(IStanBackend):
CMDSTAN_VERSION = "2.26.1"
CMDSTAN_VERSION = "2.31.0"
def __init__(self):
import cmdstanpy
# this must be set before super.__init__() for load_model to work on Windows

View file

@ -8,6 +8,7 @@ import platform
from pathlib import Path
from shutil import copy, copytree, rmtree
from typing import List
import tempfile
from setuptools import find_packages, setup, Extension
from setuptools.command.build_ext import build_ext
@ -19,13 +20,15 @@ from wheel.bdist_wheel import bdist_wheel
MODEL_DIR = "stan"
MODEL_TARGET_DIR = os.path.join("prophet", "stan_model")
CMDSTAN_VERSION = "2.26.1"
CMDSTAN_VERSION = "2.31.0"
BINARIES_DIR = "bin"
BINARIES = ["diagnose", "print", "stanc", "stansummary"]
TBB_PARENT = "stan/lib/stan_math/lib"
TBB_DIRS = ["tbb", "tbb_2019_U8"]
TBB_DIRS = ["tbb", "tbb_2020.3"]
IS_WINDOWS = platform.platform().startswith("Win")
def prune_cmdstan(cmdstan_dir: str) -> None:
"""
Keep only the cmdstan executables and tbb files (the minimum required to run cmdstanpy commands on a pre-compiled model).
@ -55,12 +58,13 @@ def repackage_cmdstan():
return os.environ.get("PROPHET_REPACKAGE_CMDSTAN", "").lower() not in ["false", "0"]
def maybe_install_cmdstan_toolchain():
def maybe_install_cmdstan_toolchain() -> bool:
"""Install C++ compilers required to build stan models on Windows machines."""
import cmdstanpy
try:
cmdstanpy.utils.cxx_toolchain_path()
return False
except Exception:
try:
from cmdstanpy.install_cxx_toolchain import run_rtools_install
@ -70,14 +74,14 @@ def maybe_install_cmdstan_toolchain():
run_rtools_install({"version": None, "dir": None, "verbose": True})
cmdstanpy.utils.cxx_toolchain_path()
return True
def install_cmdstan_deps(cmdstan_dir: Path):
import cmdstanpy
from multiprocessing import cpu_count
if repackage_cmdstan():
if platform.platform().startswith("Win"):
if IS_WINDOWS:
maybe_install_cmdstan_toolchain()
print("Installing cmdstan to", cmdstan_dir)
if os.path.isdir(cmdstan_dir):
@ -91,6 +95,7 @@ def install_cmdstan_deps(cmdstan_dir: Path):
cores=cpu_count(),
progress=True,
):
raise RuntimeError("CmdStan failed to install in repackaged directory")
@ -106,12 +111,24 @@ def build_cmdstan_model(target_dir):
"""
import cmdstanpy
cmdstan_dir = (Path(target_dir) / f"cmdstan-{CMDSTAN_VERSION}").resolve()
install_cmdstan_deps(cmdstan_dir)
model_name = "prophet.stan"
target_name = "prophet_model.bin"
sm = cmdstanpy.CmdStanModel(stan_file=os.path.join(MODEL_DIR, model_name))
copy(sm.exe_file, os.path.join(target_dir, target_name))
target_cmdstan_dir = (Path(target_dir) / f"cmdstan-{CMDSTAN_VERSION}").resolve()
with tempfile.TemporaryDirectory() as tmp_dir:
# Long paths on Windows can cause problems during the build
if IS_WINDOWS:
cmdstan_dir = (Path(tmp_dir) / f"cmdstan-{CMDSTAN_VERSION}").resolve()
else:
cmdstan_dir = target_cmdstan_dir
install_cmdstan_deps(cmdstan_dir)
model_name = "prophet.stan"
temp_stan_file = copy(os.path.join(MODEL_DIR, model_name), cmdstan_dir)
sm = cmdstanpy.CmdStanModel(stan_file=temp_stan_file)
target_name = "prophet_model.bin"
copy(sm.exe_file, os.path.join(target_dir, target_name))
if IS_WINDOWS:
copytree(cmdstan_dir, target_cmdstan_dir)
# Clean up
for f in Path(MODEL_DIR).iterdir():
@ -119,7 +136,7 @@ def build_cmdstan_model(target_dir):
os.remove(f)
if repackage_cmdstan():
prune_cmdstan(cmdstan_dir)
prune_cmdstan(target_cmdstan_dir)
def get_backends_from_env() -> List[str]: