Software
stochtree
stochtree is a general-purpose R and Python library for stochastic tree ensemble modeling.
Its primary interface is a “batteries-included” implementation of both the BART model for supervised learning and the BCF model for causal inference.
It also offers a flexible “low-level” interface for specifying custom models that involve stochastic tree ensemble terms. Think PyTorch, but for trees. That’s the vision.
The underlying implementation is largely in C++ with tree classes and I/O routines borrowed from xgboost and LightGBM. See here for stochtree’s source repo.
faststochtree
faststochtree is a performance-optimized implementation of stochtree’s core models, fine-tuned for fast sampling on Apple Silicon machines. It is mostly optimized for efficient CPU execution, but it was inspired by Giacomo Petrillo’s bartz project, which implements a GPU-accelerated version of BART using JAX.
On a benchmark dataset with 50k observations and 50 features,
stochtree takes 3 minutes to run 1200 iterations of the BART sampler, while faststochtree does the same in 36 seconds.
stochtree takes 80 seconds to run 40 iterations of the XBART sampler, while faststochtree does the same in 23 seconds.
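As a rough reading of those numbers, the implied speedups can be worked out directly. This is only arithmetic on the timings quoted above, taking "3 minutes" as exactly 180 seconds (an assumption, since the figure is presumably rounded); the `timings` dictionary and `speedup` helper are illustrative names, not part of either library.

```python
# Wall-clock timings quoted above, in seconds.
# Assumption: "3 minutes" is treated as exactly 180 seconds.
timings = {
    "BART (1200 iterations)": {"stochtree": 3 * 60, "faststochtree": 36},
    "XBART (40 iterations)": {"stochtree": 80, "faststochtree": 23},
}


def speedup(baseline_seconds, optimized_seconds):
    """Ratio of baseline to optimized wall-clock time."""
    return baseline_seconds / optimized_seconds


speedups = {
    name: speedup(t["stochtree"], t["faststochtree"])
    for name, t in timings.items()
}

for name, ratio in speedups.items():
    print(f"{name}: about {ratio:.1f}x faster")
```

Under that assumption, the BART run is about 5x faster and the XBART run about 3.5x faster.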
sgemm_silicon
sgemm_silicon is an experimental / didactic low-level implementation of matrix multiplication on Apple Silicon hardware, inspired by Amanzhol Salykov’s work on x86 CPUs and Nvidia GPUs.
I haven’t yet worked through the GPU implementation on Metal, but my CPU implementation reaches about 1/4 of the FLOPS of Accelerate’s cblas_sgemm routine. Part of that gap is due to Accelerate’s access to undocumented “AMX” instructions (see here for some third-party research into this instruction set). My implementation outperforms both Eigen and MLX in terms of FLOPS.
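For readers unfamiliar with how FLOPS figures for matrix multiplication are usually obtained, here is a minimal sketch. It is not the sgemm_silicon code: it is a plain-Python triple loop, orders of magnitude slower than any of the libraries mentioned, and the function names are made up for illustration. It uses the standard convention that multiplying an M x K matrix by a K x N matrix costs 2·M·N·K floating-point operations (one multiply and one add per inner-loop step).

```python
import time


def naive_matmul(a, b):
    """Triple-loop product of nested lists: a is M x K, b is K x N."""
    m, k, n = len(a), len(b), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for p in range(k):
            a_ip = a[i][p]  # reused across the whole inner loop
            for j in range(n):
                c[i][j] += a_ip * b[p][j]
    return c


def gemm_flop_count(m, n, k):
    """One multiply and one add per (i, j, p) triple."""
    return 2 * m * n * k


def measure_gflops(m, n, k):
    """Time one naive multiply of all-ones matrices; return (GFLOP/s, result)."""
    a = [[1.0] * k for _ in range(m)]
    b = [[1.0] * n for _ in range(k)]
    start = time.perf_counter()
    c = naive_matmul(a, b)
    elapsed = time.perf_counter() - start
    return gemm_flop_count(m, n, k) / elapsed / 1e9, c


gflops, c = measure_gflops(64, 64, 64)
print(f"naive Python matmul: {gflops:.4f} GFLOP/s")
```

Comparing implementations then amounts to running each on the same matrix sizes and dividing the same operation count by each one's elapsed time.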
scikit-tree
scikit-tree was a spiritual precursor to stochtree that began while I was a PhD student. It extended the Cython implementation of a decision tree learner from scikit-learn. The general rationale was:
scikit-learn’s codebase is fast and heavily tested. Relying on its implementation of tree data structures makes it easier to extend to new methods without worrying about edge cases in the tree code.
Cython offers a nice tradeoff between high-level code (readability, maintainability, syntactic sugar, etc.) and low-level code (direct memory access / management, GIL release, etc.).
The plan was for this project to provide an easier way to experiment with and extend decision tree models without spending 6 months writing C++. Ultimately, it proved not to be the right tool for the job, for several reasons:
Coupling: tying the project to Cython / Python means that an R package would be difficult to offer (and would certainly be swimming upstream).
Tooling: Cython’s debugging and profiling tools are much less developed than those for C / C++.
Performance: most existing BART packages were built with a C++ core, and while Cython can be much faster than Python, it is still nontrivial to write Cython code that competes with C++. The selling point of Cython being as readable as Python could end up sacrificed as successive optimizations are made to match C++ performance.
The process of setting up this testbed project is detailed in this blog post.
implementations
Implementations is a GitHub repository with simple implementations of common statistical and mathematical algorithms. It is now largely dormant, but it was a fun source of learning during my PhD. The goal of the repo was to enable easy exploration of the challenges and tradeoffs inherent in statistics and machine learning methods, not to provide robust implementations for use in research or applications.
Some examples: