
babelyian/randomgarden


randomgarden

A small mimic of scikit-learn's random forest. It implements a decision tree (FlowerPot) and a random forest (RandomGarden) built on top of it, both with a scikit-learn-like API. The trees handle numerical features only and use Gini impurity to choose splits, which is enough to build random forests for multi-class classification, as the demo notebook shows.
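To illustrate what splitting on Gini impurity means, here is a minimal, standard-library-only sketch. It is not the code used by FlowerPot; the function names `gini` and `best_threshold` are hypothetical, and the choice of midpoints between consecutive distinct values as candidate thresholds is an assumption made for the example.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a label list: 1 - sum over classes of p_k ** 2."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_threshold(values, labels):
    """Best binary split on one numerical feature.

    Returns (threshold, weighted_impurity). The threshold is None when no
    candidate split improves on the impurity of the unsplit node.
    """
    best = (None, gini(labels))
    distinct = sorted(set(values))
    # Assumption: candidate thresholds are midpoints between distinct values.
    for lo, hi in zip(distinct, distinct[1:]):
        t = (lo + hi) / 2
        left = [y for x, y in zip(values, labels) if x <= t]
        right = [y for x, y in zip(values, labels) if x > t]
        # Impurity of the two children, weighted by their sizes.
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
        if score < best[1]:
            best = (t, score)
    return best


print(gini(["a", "a", "b", "b"]))  # 0.5
print(best_threshold([1.0, 2.0, 3.0, 4.0], ["a", "a", "b", "b"]))  # (2.5, 0.0)
```

A full tree would repeat this search over every feature at every node and recurse into the two children until a stopping rule such as max_depth is reached.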

Example usage

The example below shows the main supported operations. It assumes that X_train, y_train, X_test and y_test already hold numerical features and class labels:

from randomgarden import FlowerPot, RandomGarden

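# X_train, y_train, X_test and y_test are assumed to be defined already
rg = RandomGarden(n_estimators=100, max_depth=5, random_state=42)
rg.fit(X_train, y_train)

preds = rg.predict(X_test)        # predicted class labels
proba = rg.predict_proba(X_test)  # class probabilities
acc   = rg.score(X_test, y_test)  # accuracy on the test set

print(f'{acc:.4f}')  # prints the accuracy on the test set, e.g. 1.0000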

Training a random forest

The demo.ipynb notebook contains a complete, end-to-end example of fitting a RandomGarden random forest as a multi-class classifier. It shows how to instantiate the forest from the randomgarden module, fit it on the classic Iris dataset, and compare its predictions against scikit-learn's RandomForestClassifier. The notebook reports per-class accuracy and averaged class probabilities for the three Iris species.
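The two metrics mentioned above can be sketched in plain Python as follows. This is not taken from the notebook: the helper names `per_class_accuracy` and `mean_proba_by_class` are hypothetical, and reading "averaged class probabilities" as the mean predicted probability row per true class is an assumption.

```python
from collections import defaultdict


def per_class_accuracy(y_true, y_pred):
    """Fraction of the samples of each true class that were predicted correctly."""
    total = defaultdict(int)
    correct = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        if t == p:
            correct[t] += 1
    return {c: correct[c] / total[c] for c in total}


def mean_proba_by_class(y_true, proba):
    """Average the predicted probability rows over the samples of each true class."""
    rows = defaultdict(list)
    for t, row in zip(y_true, proba):
        rows[t].append(row)
    # zip(*r) transposes the rows so each column holds one class's probabilities.
    return {c: [sum(col) / len(r) for col in zip(*r)] for c, r in rows.items()}


y_true = [0, 0, 1, 1]
y_pred = [0, 1, 1, 1]
proba = [[0.75, 0.25], [0.25, 0.75], [0.0, 1.0], [0.5, 0.5]]

print(per_class_accuracy(y_true, y_pred))  # {0: 0.5, 1: 1.0}
print(mean_proba_by_class(y_true, proba))  # {0: [0.5, 0.5], 1: [0.25, 0.75]}
```

With a fitted forest, `y_pred` and `proba` would presumably come from the predict and predict_proba calls shown in the example usage, converted to plain lists if they are arrays.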

Tracing / visualization

To make tracing easier, the dag.ipynb notebook generates Graphviz diagrams of a FlowerPot decision tree. For example, passing a fitted depth-3 tree trained on the Iris dataset to the tree_to_dot function, as in the code below, produces a diagram in which each internal node displays the split feature and threshold and each leaf displays the predicted class and per-class probabilities.

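from randomgarden import FlowerPot

# tree_to_dot is the helper used in dag.ipynb; X_train, y_train,
# feature_names and class_names are assumed to be defined already
fp = FlowerPot(max_depth=3, random_state=42)
fp.fit(X_train, y_train)

dot = tree_to_dot(fp, feature_names=feature_names, class_names=class_names)
dot.render('flowerpot_tree', format='png', cleanup=True)  # writes flowerpot_tree.png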
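For readers curious how such a conversion might work, the sketch below renders a tree as Graphviz DOT source text. It is not the notebook's tree_to_dot: the function name `dict_tree_to_dot` and the nested-dict tree layout (keys `feature`, `threshold`, `left`, `right` for internal nodes and `proba` for leaves) are assumptions made for illustration, since FlowerPot's internal node structure is not described here.

```python
def dict_tree_to_dot(tree, feature_names, class_names):
    """Render a nested-dict decision tree (hypothetical layout) as DOT text."""
    lines = ["digraph Tree {", "  node [shape=box];"]
    counter = 0

    def visit(node):
        nonlocal counter
        node_id = counter
        counter += 1
        is_leaf = "proba" in node
        if is_leaf:
            # Leaf: predicted class (highest probability) and per-class probabilities.
            best = max(range(len(node["proba"])), key=node["proba"].__getitem__)
            probs = ", ".join(f"{p:.2f}" for p in node["proba"])
            label = f"{class_names[best]}\\n[{probs}]"
        else:
            # Internal node: split feature and threshold.
            label = f"{feature_names[node['feature']]} <= {node['threshold']:.2f}"
        lines.append(f'  n{node_id} [label="{label}"];')
        if not is_leaf:
            for child, edge in ((node["left"], "yes"), (node["right"], "no")):
                child_id = visit(child)
                lines.append(f'  n{node_id} -> n{child_id} [label="{edge}"];')
        return node_id

    visit(tree)
    lines.append("}")
    return "\n".join(lines)


# A hand-written depth-1 tree standing in for a fitted one.
toy_tree = {
    "feature": 0,
    "threshold": 2.45,
    "left": {"proba": [1.0, 0.0]},
    "right": {"proba": [0.0, 1.0]},
}
print(dict_tree_to_dot(toy_tree, ["petal length"], ["setosa", "versicolor"]))
```

Unlike the object returned by tree_to_dot above, this sketch returns a plain string. With the graphviz Python package installed, wrapping it as `graphviz.Source(text)` should give an object that can be rendered to an image.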

Running tests

Running the unit tests requires pytest and scikit-learn; the tests use scikit-learn as a reference to verify the computed predictions and accuracy. With both installed, run:

python -m pytest
