Arbre de décision
Package rpart
library(rpart), library(rpart.plot).
Construire l'arbre
cart = rpart( data = data,
Cible~.,
parms = list(split = "gini"),
cp = 0)
rpart.plot(cart) # afficher l'arbre
Option :
method = class/anovavariable à expliquer de type qualitative/quantitative.parms = list(split = "gini")critère à utiliser.-
control = rpart.control()pour controler les paramètres de l'arbre. Paramètres :minsplit = 5nombre minimum d'observations dans chaque noeud pour qu'une nouvelle feuille puisse être créée.minbucket = 1l'effectif minimal dans chaque noeud terminal.maxdepth = 30hauteur (profondeur) maximale de l'arbre.cp = 0paramètre de pénalisation pour la complexité.mincriterion = 0.31-p-valeur à partir de laquelle on souhaite arrêter.
summary(cart) information sur l'arbre.
Retourne :
minsplit = nbrenombre de branches minimum.minbucket = 1/3*minsplitpar défaut.control = rpart.control(minsplit = 5,cp = 0)sans contrainte sur la qualité et avec au moins 5 obsevations par feuille.
for stumps : rpart.control(maxdepth = 1,cp = -1, minsplit = 0)
Information sur l'arbre
summary(cart)
l’erreur xerror calculée par validation croisée (R constitut 10 échantillons).
xerrorerreur de la validation croisée.
Sortie :
$variable.importancerenvoie les variables les plus importantes pour le modèle.
Qualité de l'arbre
printcp(cart)
plotcp(cart)
| Indicateur | Définition |
|---|---|
| Root node error | erreur à la racine. |
| CP | coefficient de complexité. |
| nsplit | nombre de branches. |
| rel error | taux d'erreur de jeu d'apprentissage. |
| xerror | taux d'erreur de la validation croisée (R constitut 10 échantillons). |
| xstd | écart type des erreurs. |
prune(cart, cp = 0.0155441) élaguer l'arbre.
Graphiques
library(rpart.plot)
prunedcart9f = prune(cart, cp = 0.0155441)
-
rpart.plot(abre_complet)afficher l'arbre (ou sinonprp(abre_complet)). Paramètres:type = 0retirer les informations sur les noeuds.
-
Affiche :
- La classe prédite.
- La probabilité d'appartenir à la classe.
- Le pourcentage d'observations dans la noeud.
-
visTree(abre_complet)afficher l'arbre avec le nombre de données dans chaque noeud (library(visNetwork)).
Qualité
printcp(cart)afficher la complexité de l'arbre.plotcp(cart)afficher l'erreur relative en foonction de la complexité.
Package party
library(party)
ctree(y_reel~., dtf_train, control = ctree_control())
-
ctree_control()paramètres de l'arbre.minsplit = 2effectif minimal pour séparer un noeud.minbucket = 1effectif minimal dans chaque noeud terminal.maxdepth = 30hauteur (profondeur) maximale de l'arbremincriterion = 0.3c1-p-valeur à partir de laquelle on souhaite cesser la croissance.
plot(arbre_ctree) afficher l'arbre.
prp(cart, type = 2, extra = 1, split.box.col = "lightgray")
rpart.plot(cart)