Per-endpoint class count table for a fitted CTA tree
Returns one row per terminal endpoint (leaf) per actual class, read directly from stored leaf node fields. No refitting, prediction, or recomputation from training data is performed.
Class counts are stored at fit time by oda_cta_fit on
every terminal leaf. Row order within each endpoint follows the order
of names(leaf$class_counts_raw), which is ascending by class
label. Endpoints are ordered by node_id, matching
cta_endpoint_summary.
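The ordering rules above can be illustrated with a minimal base-R sketch. It assumes each terminal leaf is a plain list carrying node_id, path, majority_class, class_counts_raw and a weighted counterpart. The field names path and class_counts_weighted, the helper flatten_endpoint_counts(), and the toy counts are hypothetical and do not describe the package's internal representation.

```r
# Hypothetical leaf records (toy values). Only node_id, majority_class and
# class_counts_raw are named in the documentation; the other fields are assumed.
leaves <- list(
  list(node_id = 3L, path = "x>1", majority_class = 0L,
       class_counts_raw = c("0" = 9L, "1" = 2L),
       class_counts_weighted = c("0" = 9, "1" = 2)),
  list(node_id = 2L, path = "x<=1", majority_class = 1L,
       class_counts_raw = c("0" = 1L, "1" = 8L),
       class_counts_weighted = c("0" = 1, "1" = 8))
)

# Hypothetical helper: flatten leaves into one row per endpoint per class.
flatten_endpoint_counts <- function(leaves) {
  # Endpoints are numbered in ascending node_id order.
  ids <- vapply(leaves, function(l) l$node_id, integer(1))
  leaves <- leaves[order(ids)]
  rows <- lapply(seq_along(leaves), function(i) {
    leaf <- leaves[[i]]
    # Row order within an endpoint follows the stored names.
    cls <- names(leaf$class_counts_raw)
    data.frame(
      endpoint_id = i,
      endpoint_node_id = leaf$node_id,
      path = leaf$path,
      terminal_prediction = leaf$majority_class,
      class = cls,
      n_raw = unname(leaf$class_counts_raw),
      n_weighted = unname(leaf$class_counts_weighted[cls]),
      stringsAsFactors = FALSE
    )
  })
  do.call(rbind, rows)
}

out <- flatten_endpoint_counts(leaves)
print(out)
```

Because the sketch sorts by node_id before numbering, endpoint_id follows node order even when the leaves are supplied out of order.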
Scope: This function exposes stored raw and weighted class
counts only. It does not include target-class proportions,
event rates, odds, or staging order. Staging-table and event-rate
summaries are available via cta_staging_table.
If any terminal leaf is missing the stored class counts (i.e., the
cta_tree was fitted by an earlier version of oda that did
not store endpoint counts), the function stops with a clear error.
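A guard of the kind described above might look like the following sketch. The helper name check_leaf_counts(), the list-based leaf layout, and the error message wording are all hypothetical; the package's actual check and message may differ.

```r
# Hypothetical guard: stop if any terminal leaf lacks stored raw class counts.
check_leaf_counts <- function(leaves) {
  missing <- vapply(leaves, function(l) is.null(l$class_counts_raw), logical(1))
  if (any(missing)) {
    ids <- vapply(leaves[missing], function(l) l$node_id, integer(1))
    # Illustrative message text only.
    stop("terminal leaf node(s) ", paste(ids, collapse = ", "),
         " lack stored class counts; the tree was likely fitted by an ",
         "earlier version of oda", call. = FALSE)
  }
  invisible(TRUE)
}

# A leaf list as an older fit might have produced it (node 3 has no counts).
old_leaves <- list(
  list(node_id = 2L, class_counts_raw = c("0" = 1L, "1" = 8L)),
  list(node_id = 3L)
)
msg <- tryCatch(check_leaf_counts(old_leaves),
                error = function(e) conditionMessage(e))
print(msg)
```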
Arguments
- tree
A cta_tree from oda_cta_fit.
Value
A data.frame with one row per terminal endpoint per actual class
and columns:
- endpoint_id
Integer sequential endpoint index 1..n in node order, matching cta_endpoint_summary.
- endpoint_node_id
Integer tree node identifier for this endpoint leaf.
- path
Character; AND-joined branch labels from root to this leaf (e.g. "V14<=0.5 AND V15>0.5").
- terminal_prediction
Integer class label assigned to this endpoint (the stored leaf majority_class).
- class
Character; actual class label for this row (e.g. "0", "1").
- n_raw
Integer raw count of observations of this actual class reaching this endpoint.
- n_weighted
Numeric weighted total for this actual class reaching this endpoint. Equals n_raw when case weights are not active.
For a no-tree fit the returned data frame has zero rows but the correct column structure and types.
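One way to build such an empty result in base R is to declare each column as a zero-length vector of its documented type. The helper empty_endpoint_counts() is hypothetical; the column names and types are taken from the Value section above.

```r
# Hypothetical helper: zero-row result with the documented columns and types.
empty_endpoint_counts <- function() {
  data.frame(
    endpoint_id = integer(0),
    endpoint_node_id = integer(0),
    path = character(0),
    terminal_prediction = integer(0),
    class = character(0),
    n_raw = integer(0),
    n_weighted = numeric(0),
    stringsAsFactors = FALSE
  )
}

z <- empty_endpoint_counts()
str(z)
```

Declaring typed zero-length columns, rather than returning NULL, lets callers rbind() or filter the result without special-casing an empty tree.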
Examples
data(mtcars)
X <- mtcars[, c("cyl", "disp", "hp", "wt")]
y <- as.integer(mtcars$am)
tree <- oda_cta_fit(X, y, mindenom = 5L, mc_iter = 500L, mc_seed = 42L)
cta_endpoint_counts(tree)
#> endpoint_id endpoint_node_id path terminal_prediction class n_raw
#> 1 1 2 wt<=3.18 1 0 2
#> 2 1 2 wt<=3.18 1 1 12
#> 3 2 3 wt>3.18 0 0 17
#> 4 2 3 wt>3.18 0 1 1
#> n_weighted
#> 1 2
#> 2 12
#> 3 17
#> 4 1
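Because the table keeps raw counts rather than rates, simple summaries can be derived from it afterwards. The sketch below copies the count columns of the example output into a data frame (path and n_weighted are omitted) and computes, per endpoint, the total count and the count whose actual class matches terminal_prediction. This post-processing is an illustration only and is not part of cta_endpoint_counts.

```r
# Rows copied from the printed example output (path and n_weighted omitted).
counts <- data.frame(
  endpoint_id = c(1L, 1L, 2L, 2L),
  terminal_prediction = c(1L, 1L, 0L, 0L),
  class = c("0", "1", "0", "1"),
  n_raw = c(2L, 12L, 17L, 1L),
  stringsAsFactors = FALSE
)

# A row counts as correctly classified when its actual class equals the
# endpoint's prediction (class is character, so compare as character).
counts$correct <- counts$class == as.character(counts$terminal_prediction)

n_total   <- tapply(counts$n_raw, counts$endpoint_id, sum)
n_correct <- tapply(counts$n_raw * counts$correct, counts$endpoint_id, sum)

n_total                          # endpoint 1: 14, endpoint 2: 18
n_correct                        # endpoint 1: 12, endpoint 2: 17
sum(n_correct) / sum(n_total)    # 29 / 32 = 0.90625
```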