- 1、原创力文档(book118)网站文档一经付费(服务费),不意味着购买了该文档的版权,仅供个人/单位学习、研究之用,不得用于商业用途,未经授权,严禁复制、发行、汇编、翻译或者网络传播等,侵权必究。。
- 2、本站所有内容均由合作方或网友上传,本站不对文档的完整性、权威性及其观点立场正确性做任何保证或承诺!文档内容仅供研究参考,付费前请自行鉴别。如您付费,意味着您自己接受本站规则且自行承担风险,本站不退款、不进行额外附加服务;查看《如何避免下载的几个坑》。如果您已付费下载过本站文档,您可以点击 这里二次下载。
- 3、如文档侵犯商业秘密、侵犯著作权、侵犯人身权等,请点击“版权申诉”(推荐),也可以打举报电话:400-050-0827(电话支持时间:9:00-18:30)。
查看更多
spark.mllib源码阅读-分类算法4-DecisionTree
第一部分:
决策树模型
分类决策树模型是一种描述对实例进行分类的树形结构。决策树由结点和有向边组成。结点有两种类型:内部节点和叶节点,内部节点表示一个特征或属性,叶节点表示一个类。分类的时候,从根节点开始,当前节点设为根节点,当前节点必定是一种特征,根据实例的该特征的取值,向下移动,直到到达叶节点,将实例分到叶节点对应的类中。
图 1 是一棵结构简单的决策树,用于预测贷款用户是否具有偿还贷款的能力。贷款用户主要具备三个属性:是否拥有房产,是否结婚,平均月收入。每一个内部节点都表示一个属性条件判断,叶子节点表示贷款用户是否具有偿还能力。例如:用户甲没有房产,没有结婚,月收入 5K。通过决策树的根节点判断,用户甲符合右边分支 (拥有房产为“否”);再判断是否结婚,用户甲符合左边分支 (是否结婚为否);然后判断月收入是否大于 4k,用户甲符合左边分支 (月收入大于 4K),该用户落在“可以偿还”的叶子节点上。所以预测用户甲具备偿还贷款能力。
决策树的存储与表示:
决策树是一类特殊的树,每个结点存储了结点的分裂信息(非叶子结点)或者分类信息(叶子结点),既然是树结构,那么就可以用我们熟悉的树数据结构来表示和存储了。
Spark在Node.Scala文件中实现了决策树结点的存储与通过遍历结点来进行预测,其基本的形态是一颗二叉树,并实现了三类不同的结点:
LeafNode:叶子结点使用LeafNode存储,关键参数有prediction,impurity。
InternalNode:内部结点(包含叶子结点)InternalNode,关键参数有prediction,impurity,gain,leftChild,rightChild,split。
LearningNode:决策树训练时结点的表示类LearningNode,在训练完成后通过LearningNode.toNode方法,将其转变为InternalNode或者LeafNode。
说一下几个参数的意思:
prediction:预测类别或者回归值
impurity:不纯度,Spark实现了三种不纯度度量方式:熵、信息增益、残差(适用于回归)。
leftChild、rightChild:左右子节点
split:Node在进行预测时,需要用到split存储的结点信息,由split来决定选择左结点还是右结点。
结点分裂信息类Split:
Spark实现了2个结点选取类CategoricalSplit和ContinuousSplit,分别完成分类特征和连续特征下的子结点选取问题。
CategoricalSplit:将分类特征的属性值集分成2个集合(左集合)和右集合,判断属性值属于哪个集合来决定选取哪个子节点。
ContinuousSplit:针对连续型特征的子节点选取类,输入的特征值与设定的阀值threshold比较大小,来决定是选取左子节点还是右子结点。
决策树特征选择与分裂:
选择一个合适的特征作为判断节点,可以快速的分类,减少决策树的深度。决策树的目标就是把数据集按对应的类标签进行分类。最理想的情况是,通过特征的选择能把不同类别的数据集贴上对应类标签。特征选择的目标使得分类后的数据集比较纯。
Spark实现了3类数据不纯度度量算法:Giniimpurity、Entropy、Variance,都继承自Impurity类并覆写了不纯度计算方法calculate。
Gini impurity:
采用基尼指数来度量数据的不纯度,计算公式如下:
计算代码如下:
[java] view plain copy print?在CODE上查看代码片派生到我的代码片
pre code_snippet_id=2325202 snippet_file_name=blog1_1999783 name=code class=javaoverride def calculate(counts: Array[Double], totalCount: Double): Double = {
if (totalCount == 0) {
return 0
}
val numClasses = counts.length
//∑Ci=1fi(1?fi) = ∑Ci=1fi + ∑Ci=1fi*fi, 其中前半部分为1 实际只需要计算后半部分。
var impurity = 1.0
var classIndex = 0
while (classIndex numClasses) {
val freq = counts(cl
文档评论(0)