FP-Tree算法的实现

向右看齐 2022-01-16 10:45 226阅读 0赞

在关联规则挖掘领域最经典的算法法是Apriori,其致命的缺点是需要多次扫描事务数据库。于是人们提出了各种裁剪(prune)数据集的方法以减少I/O开支,韩嘉炜老师的FP-Tree算法就是其中非常高效的一种。

支持度和置信度

严格地说Apriori和FP-Tree都是寻找频繁项集的算法,频繁项集就是所谓的“支持度”比较高的项集,下面解释一下支持度和置信度的概念。

设事务数据库为:

复制代码

  1. A  E  F  G
  2. A  F  G
  3. A  B  E  F  G
  4. E  F  G

复制代码

则{A,F,G}的支持度数为3,支持度为3/4。

{F,G}的支持度数为4,支持度为4/4。

{A}的支持度数为3,支持度为3/4。

{F,G}=>{A}的置信度为:{A,F,G}的支持度数 除以 {F,G}的支持度数,即3/4

{A}=>{F,G}的置信度为:{A,F,G}的支持度数 除以 {A}的支持度数,即3/3

强关联规则挖掘是在满足一定支持度的情况下寻找置信度达到阈值的所有模式。

FP-Tree算法

我们举个例子来详细讲解FP-Tree算法的完整实现。

事务数据库如下,一行表示一条购物记录:

复制代码

  1. 牛奶,鸡蛋,面包,薯片
  2. 鸡蛋,爆米花,薯片,啤酒
  3. 鸡蛋,面包,薯片
  4. 牛奶,鸡蛋,面包,爆米花,薯片,啤酒
  5. 牛奶,面包,啤酒
  6. 鸡蛋,面包,啤酒
  7. 牛奶,面包,薯片
  8. 牛奶,鸡蛋,面包,黄油,薯片
  9. 牛奶,鸡蛋,黄油,薯片

复制代码

我们的目的是要找出哪些商品总是相伴出现的,比如人们买薯片的时候通常也会买鸡蛋,则[薯片,鸡蛋]就是一条频繁模式(frequent pattern)。

FP-Tree算法第一步:扫描事务数据库,每项商品按频数递减排序,并删除频数小于最小支持度MinSup的商品。(第一次扫描数据库)

薯片:7鸡蛋:7面包:7牛奶:6啤酒:4 (这里我们令MinSup=3)

以上结果就是频繁1项集,记为F1。

第二步:对于每一条购买记录,按照F1中的顺序重新排序。(第二次也是最后一次扫描数据库)

复制代码

  1. 薯片,鸡蛋,面包,牛奶
  2. 薯片,鸡蛋,啤酒
  3. 薯片,鸡蛋,面包
  4. 薯片,鸡蛋,面包,牛奶,啤酒
  5. 面包,牛奶,啤酒
  6. 鸡蛋,面包,啤酒
  7. 薯片,面包,牛奶
  8. 薯片,鸡蛋,面包,牛奶
  9. 薯片,鸡蛋,牛奶

复制代码

第三步:把第二步得到的各条记录插入到FP-Tree中。刚开始时后缀模式为空。

插入每一条(薯片,鸡蛋,面包,牛奶)之后

2011100415080029.png

插入第二条记录(薯片,鸡蛋,啤酒)

2011100415073229.png

插入第三条记录(面包,牛奶,啤酒)

2011100415081681.png

估计你也知道怎么插了,最终生成的FP-Tree是:

2011100415083056.png

上图中左边的那一叫做表头项,树中相同名称的节点要链接起来,链表的第一个元素就是表头项里的元素。

如果FP-Tree为空(只含一个虚的root节点),则FP-Growth函数返回。

此时输出表头项的每一项+postModel,支持度为表头项中对应项的计数。

第四步:从FP-Tree中找出频繁项。

遍历表头项中的每一项(我们拿“牛奶:6”为例),对于各项都执行以下(1)到(5)的操作:

(1)从FP-Tree中找到所有的“牛奶”节点,向上遍历它的祖先节点,得到4条路径:

复制代码

  1. 薯片:7,鸡蛋:6,牛奶:1
  2. 薯片:7,鸡蛋:6,面包:4,牛奶:3
  3. 薯片:7,面包:1,牛奶:1
  4. 面包:1,牛奶:1

复制代码

对于每一条路径上的节点,其count都设置为牛奶的count

复制代码

  1. 薯片:1,鸡蛋:1,牛奶:1
  2. 薯片:3,鸡蛋:3,面包:3,牛奶:3
  3. 薯片:1,面包:1,牛奶:1
  4. 面包:1,牛奶:1

复制代码

因为每一项末尾都是牛奶,可以把牛奶去掉,得到条件模式基(Conditional Pattern Base,CPB),此时的后缀模式是:(牛奶)。

复制代码

  1. 薯片:1,鸡蛋:1
  2. 薯片:3,鸡蛋:3,面包:3
  3. 薯片:1,面包:1
  4. 面包:1

复制代码

(2)我们把上面的结果当作原始的事务数据库,返回到第3步,递归迭代运行。

没讲清楚,你可以参考这篇博客,直接看核心代码吧:

复制代码

  1. public void FPGrowth(List<List<String>> transRecords,
  2. List<String> postPattern,Context context) throws IOException, InterruptedException {
  3. // 构建项头表,同时也是频繁1项集
  4. ArrayList<TreeNode> HeaderTable = buildHeaderTable(transRecords);
  5. // 构建FP-Tree
  6. TreeNode treeRoot = buildFPTree(transRecords, HeaderTable);
  7. // 如果FP-Tree为空则返回
  8. if (treeRoot.getChildren()==null || treeRoot.getChildren().size() == 0)
  9. return;
  10. //输出项头表的每一项+postPattern
  11. if(postPattern!=null){
  12. for (TreeNode header : HeaderTable) {
  13. String outStr=header.getName();
  14. int count=header.getCount();
  15. for (String ele : postPattern)
  16. outStr+="\t" + ele;
  17. context.write(new IntWritable(count), new Text(outStr));
  18. }
  19. }
  20. // 找到项头表的每一项的条件模式基,进入递归迭代
  21. for (TreeNode header : HeaderTable) {
  22. // 后缀模式增加一项
  23. List<String> newPostPattern = new LinkedList<String>();
  24. newPostPattern.add(header.getName());
  25. if (postPattern != null)
  26. newPostPattern.addAll(postPattern);
  27. // 寻找header的条件模式基CPB,放入newTransRecords中
  28. List<List<String>> newTransRecords = new LinkedList<List<String>>();
  29. TreeNode backnode = header.getNextHomonym();
  30. while (backnode != null) {
  31. int counter = backnode.getCount();
  32. List<String> prenodes = new ArrayList<String>();
  33. TreeNode parent = backnode;
  34. // 遍历backnode的祖先节点,放到prenodes中
  35. while ((parent = parent.getParent()).getName() != null) {
  36. prenodes.add(parent.getName());
  37. }
  38. while (counter-- > 0) {
  39. newTransRecords.add(prenodes);
  40. }
  41. backnode = backnode.getNextHomonym();
  42. }
  43. // 递归迭代
  44. FPGrowth(newTransRecords, newPostPattern,context);
  45. }
  46. }

复制代码

对于FP-Tree已经是单枝的情况,就没有必要再递归调用FPGrowth了,直接输出整条路径上所有节点的各种组合+postModel就可了。例如当FP-Tree为:

2012081411593897.png

我们直接输出:

3  A+postModel

3  B+postModel

3  A+B+postModel

就可以了。

如何按照上面代码里的做法,是先输出:

3  A+postModel

3  B+postModel

然后把B插入到postModel的头部,重新建立一个FP-Tree,这时Tree中只含A,于是输出

3  A+(B+postModel)

两种方法结果是一样的,但毕竟重新建立FP-Tree计算量大些。

Java实现

FP树节点定义

?










package 
fptree;


  


import 
java.util.ArrayList;


import 
java.util.List;


  


public 
class 
TreeNode
implements 
Comparable<TreeNode> {


  


    
private 
String name;
// 节点名称


    
private 
int 
count;
// 计数


    
private 
TreeNode parent;
// 父节点


    
private 
List<TreeNode> children;
// 子节点


    
private 
TreeNode nextHomonym;
// 下一个同名节点


  


    
public 
TreeNode() {


  


    
}


  


    
public 
TreeNode(String name) {


        
this
.name = name;


    
}


  


    
public 
String getName() {


        
return 
name;


    
}


  


    
public 
void 
setName(String name) {


        
this
.name = name;


    
}


  


    
public 
int 
getCount() {


        
return 
count;


    
}


  


    
public 
void 
setCount(
int 
count) {


        
this
.count = count;


    
}


  


    
public 
TreeNode getParent() {


        
return 
parent;


    
}


  


    
public 
void 
setParent(TreeNode parent) {


        
this
.parent = parent;


    
}


  


    
public 
List<TreeNode> getChildren() {


        
return 
children;


    
}


  


    
public 
void 
addChild(TreeNode child) {


        
if 
(
this
.getChildren() ==
null
) {


            
List<TreeNode> list =
new 
ArrayList<TreeNode>();


            
list.add(child);


            
this
.setChildren(list);


        
}
else 
{


            
this
.getChildren().add(child);


        
}


    
}


  


    
public 
TreeNode findChild(String name) {


        
List<TreeNode> children =
this
.getChildren();


        
if 
(children !=
null
) {


            
for 
(TreeNode child : children) {


                
if 
(child.getName().equals(name)) {


                    
return 
child;


                
}


            
}


        
}


        
return 
null
;


    
}


  


    
public 
void 
setChildren(List<TreeNode> children) {


        
this
.children = children;


    
}


  


    
public 
void 
printChildrenName() {


        
List<TreeNode> children =
this
.getChildren();


        
if 
(children !=
null
) {


            
for 
(TreeNode child : children) {


                
System.out.print(child.getName() +
“ “
);


            
}


        
}
else 
{


            
System.out.print(
“null”
);


        
}


    
}


  


    
public 
TreeNode getNextHomonym() {


        
return 
nextHomonym;


    
}


  


    
public 
void 
setNextHomonym(TreeNode nextHomonym) {


        
this
.nextHomonym = nextHomonym;


    
}


  


    
public 
void 
countIncrement(
int 
n) {


        
this
.count += n;


    
}


  


    
@Override


    
public 
int 
compareTo(TreeNode arg0) {


        
// TODO Auto-generated method stub


        
int 
count0 = arg0.getCount();


        
// 跟默认的比较大小相反,导致调用Arrays.sort()时是按降序排列


        
return 
count0 -
this
.count;


    
}


}

挖掘频繁模式

?










package 
fptree;


 


import 
java.io.BufferedReader;


import 
java.io.FileReader;


import 
java.io.IOException;


import 
java.util.ArrayList;


import 
java.util.Collections;


import 
java.util.Comparator;


import 
java.util.HashMap;


import 
java.util.LinkedList;


import 
java.util.List;


import 
java.util.Map;


import 
java.util.Map.Entry;


import 
java.util.Set;


 


public 
class 
FPTree {


 


    
private 
int 
minSuport;


 


    
public 
int 
getMinSuport() {


        
return 
minSuport;


    
}


 


    
public 
void 
setMinSuport(
int 
minSuport) {


        
this
.minSuport = minSuport;


    
}


 


    
// 从若干个文件中读入Transaction Record


    
public 
List<List<String>> readTransRocords(String… filenames) {


        
List<List<String>> transaction =
null
;


        
if 
(filenames.length >
0
) {


            
transaction =
new 
LinkedList<List<String>>();


            
for 
(String filename : filenames) {


                
try 
{


                    
FileReader fr =
new 
FileReader(filename);


                    
BufferedReader br =
new 
BufferedReader(fr);


                    
try 
{


                        
String line;


                        
List<String> record;


                        
while 
((line = br.readLine()) !=
null
) {


                            
if
(line.trim().length()>
0
){


                                
String str[] = line.split(
“,”
);


                                
record =
new 
LinkedList<String>();


                                
for 
(String w : str)


                                    
record.add(w);


                                
transaction.add(record);


                            
}


                        
}


                    
}
finally 
{


                        
br.close();


                    
}


                
}
catch 
(IOException ex) {


                    
System.out.println(
“Read transaction records failed.”


                            
+ ex.getMessage());


                    
System.exit(
1
);


                
}


            
}


        
}


        
return 
transaction;


    
}


 


    
// FP-Growth算法


    
public 
void 
FPGrowth(List<List<String>> transRecords,


            
List<String> postPattern) {


        
// 构建项头表,同时也是频繁1项集


        
ArrayList<TreeNode> HeaderTable = buildHeaderTable(transRecords);


        
// 构建FP-Tree


        
TreeNode treeRoot = buildFPTree(transRecords, HeaderTable);


        
// 如果FP-Tree为空则返回


        
if 
(treeRoot.getChildren()==
null 
|| treeRoot.getChildren().size() ==
0
)


            
return
;


        
//输出项头表的每一项+postPattern


        
if
(postPattern!=
null
){


            
for 
(TreeNode header : HeaderTable) {


                
System.out.print(header.getCount() +
“\t” 
+ header.getName());


                
for 
(String ele : postPattern)


                    
System.out.print(
“\t” 
+ ele);


                
System.out.println();


            
}


        
}


        
// 找到项头表的每一项的条件模式基,进入递归迭代


        
for 
(TreeNode header : HeaderTable) {


            
// 后缀模式增加一项


            
List<String> newPostPattern =
new 
LinkedList<String>();


            
newPostPattern.add(header.getName());


            
if 
(postPattern !=
null
)


                
newPostPattern.addAll(postPattern);


            
// 寻找header的条件模式基CPB,放入newTransRecords中


            
List<List<String>> newTransRecords =
new 
LinkedList<List<String>>();


            
TreeNode backnode = header.getNextHomonym();


            
while 
(backnode !=
null
) {


                
int 
counter = backnode.getCount();


                
List<String> prenodes =
new 
ArrayList<String>();


                
TreeNode parent = backnode;


                
// 遍历backnode的祖先节点,放到prenodes中


                
while 
((parent = parent.getParent()).getName() !=
null
) {


                    
prenodes.add(parent.getName());


                
}


                
while 
(counter— >
0
) {


                    
newTransRecords.add(prenodes);


                
}


                
backnode = backnode.getNextHomonym();


            
}


            
// 递归迭代


            
FPGrowth(newTransRecords, newPostPattern);


        
}


    
}


 


    
// 构建项头表,同时也是频繁1项集


    
public 
ArrayList<TreeNode> buildHeaderTable(List<List<String>> transRecords) {


        
ArrayList<TreeNode> F1 =
null
;


        
if 
(transRecords.size() >
0
) {


            
F1 =
new 
ArrayList<TreeNode>();


            
Map<String, TreeNode> map =
new 
HashMap<String, TreeNode>();


            
// 计算事务数据库中各项的支持度


            
for 
(List<String> record : transRecords) {


                
for 
(String item : record) {


                    
if 
(!map.keySet().contains(item)) {


                        
TreeNode node =
new 
TreeNode(item);


                        
node.setCount(
1
);


                        
map.put(item, node);


                    
}
else 
{


                        
map.get(item).countIncrement(
1
);


                    
}


                
}


            
}


            
// 把支持度大于(或等于)minSup的项加入到F1中


            
Set<String> names = map.keySet();


            
for 
(String name : names) {


                
TreeNode tnode = map.get(name);


                
if 
(tnode.getCount() >= minSuport) {


                    
F1.add(tnode);


                
}


            
}


            
Collections.sort(F1);


            
return 
F1;


        
}
else 
{


            
return 
null
;


        
}


    
}


 


    
// 构建FP-Tree


    
public 
TreeNode buildFPTree(List<List<String>> transRecords,


            
ArrayList<TreeNode> F1) {


        
TreeNode root =
new 
TreeNode();
// 创建树的根节点


        
for 
(List<String> transRecord : transRecords) {


            
LinkedList<String> record = sortByF1(transRecord, F1);


            
TreeNode subTreeRoot = root;


            
TreeNode tmpRoot =
null
;


            
if 
(root.getChildren() !=
null
) {


                
while 
(!record.isEmpty()


                        
&& (tmpRoot = subTreeRoot.findChild(record.peek())) !=
null
) {


                    
tmpRoot.countIncrement(
1
);


                    
subTreeRoot = tmpRoot;


                    
record.poll();


                
}


            
}


            
addNodes(subTreeRoot, record, F1);


        
}


        
return 
root;


    
}


 


    
// 把交易记录按项的频繁程序降序排列


    
public 
LinkedList<String> sortByF1(List<String> transRecord,


            
ArrayList<TreeNode> F1) {


        
Map<String, Integer> map =
new 
HashMap<String, Integer>();


        
for 
(String item : transRecord) {


            
// 由于F1已经是按降序排列的,


            
for 
(
int 
i =
0
; i < F1.size(); i++) {


                
TreeNode tnode = F1.get(i);


                
if 
(tnode.getName().equals(item)) {


                    
map.put(item, i);


                
}


            
}


        
}


        
ArrayList<Entry<String, Integer>> al =
new 
ArrayList<Entry<String, Integer>>(


                
map.entrySet());


        
Collections.sort(al,
new 
Comparator<Map.Entry<String, Integer>>() {


            
@Override


            
public 
int 
compare(Entry<String, Integer> arg0,


                    
Entry<String, Integer> arg1) {


                
// 降序排列


                
return 
arg0.getValue() - arg1.getValue();


            
}


        
});


        
LinkedList<String> rest =
new 
LinkedList<String>();


        
for 
(Entry<String, Integer> entry : al) {


            
rest.add(entry.getKey());


        
}


        
return 
rest;


    
}


 


    
// 把record作为ancestor的后代插入树中


    
public 
void 
addNodes(TreeNode ancestor, LinkedList<String> record,


            
ArrayList<TreeNode> F1) {


        
if 
(record.size() >
0
) {


            
while 
(record.size() >
0
) {


                
String item = record.poll();


                
TreeNode leafnode =
new 
TreeNode(item);


                
leafnode.setCount(
1
);


                
leafnode.setParent(ancestor);


                
ancestor.addChild(leafnode);


 


                
for 
(TreeNode f1 : F1) {


                    
if 
(f1.getName().equals(item)) {


                        
while 
(f1.getNextHomonym() !=
null
) {


                            
f1 = f1.getNextHomonym();


                        
}


                        
f1.setNextHomonym(leafnode);


                        
break
;


                    
}


                
}


 


                
addNodes(leafnode, record, F1);


            
}


        
}


    
}


 


    
public 
static 
void 
main(String[] args) {


        
FPTree fptree =
new 
FPTree();


        
fptree.setMinSuport(
3
);


        
List<List<String>> transRecords = fptree


                
.readTransRocords(
“/home/orisun/test/market”
);


        
fptree.FPGrowth(transRecords,
null
);


    
}


}

输入文件

复制代码

  1. 牛奶,鸡蛋,面包,薯片
  2. 鸡蛋,爆米花,薯片,啤酒
  3. 鸡蛋,面包,薯片
  4. 牛奶,鸡蛋,面包,爆米花,薯片,啤酒
  5. 牛奶,面包,啤酒
  6. 鸡蛋,面包,啤酒
  7. 牛奶,面包,薯片
  8. 牛奶,鸡蛋,面包,黄油,薯片
  9. 牛奶,鸡蛋,黄油,薯片

复制代码

输出

复制代码

  1. 6 薯片 鸡蛋
  2. 5 薯片 面包
  3. 5 鸡蛋 面包
  4. 4 薯片 鸡蛋 面包
  5. 5 薯片 牛奶
  6. 5 面包 牛奶
  7. 4 鸡蛋 牛奶
  8. 4 薯片 面包 牛奶
  9. 4 薯片 鸡蛋 牛奶
  10. 3 面包 鸡蛋 牛奶
  11. 3 薯片 面包 鸡蛋 牛奶
  12. 3 鸡蛋 啤酒
  13. 3 面包 啤酒

复制代码

用Hadoop来实现

在上面的代码我们把整个事务数据库放在一个List>里面传给FPGrowth,在实际中这是不可取的,因为内存不可能容下整个事务数据库,我们可能需要从关系关系数据库中一条一条地读入来建立FP-Tree。但无论如何 FP-Tree是肯定需要放在内存中的,但内存如果容不下怎么办?另外FPGrowth仍然是非常耗时的,你想提高速度怎么办?解决办法:分而治之,并行计算。

我们把原始事务数据库分成N部分,在N个节点上并行地进行FPGrowth挖掘,最后把关联规则汇总到一起就可以了。关键问题是怎么“划分”才会不遗露任何一条关联规则呢?参见这篇博客。这里为了达到并行计算的目的,采用了一种“冗余”的划分方法,即各部分的并集大于原来的集合。这种方法最终求出来的关联规则也是有冗余的,比如在节点1上得到一条规则(6:啤酒,尿布),在节点2上得到一条规则(3:尿布,啤酒),显然节点2上的这条规则是冗余的,需要采用后续步骤把冗余的规则去掉。

代码:

Record.java

?










package 
fptree;


 


import 
java.io.DataInput;


import 
java.io.DataOutput;


import 
java.io.IOException;


import 
java.util.Collections;


import 
java.util.LinkedList;


 


import 
org.apache.hadoop.io.WritableComparable;


 


public 
class 
Record
implements 
WritableComparable<Record>{


     


    
LinkedList<String> list;


     


    
public 
Record(){


        
list=
new 
LinkedList<String>();


    
}


     


    
public 
Record(String[] arr){


        
list=
new 
LinkedList<String>();


        
for
(
int 
i=
0
;i<arr.length;i++)


            
list.add(arr[i]);


    
}


     


    
@Override


    
public 
String toString(){


        
String str=list.get(
0
);


        
for
(
int 
i=
1
;i<list.size();i++)


            
str+=
“\t”
+list.get(i);


        
return 
str;


    
}


 


    
@Override


    
public 
void 
readFields(DataInput in)
throws 
IOException {


        
list.clear();


        
String line=in.readUTF();


        
String []arr=line.split(
“\s+”
);


        
for
(
int 
i=
0
;i<arr.length;i++)


            
list.add(arr[i]);


    
}


 


    
@Override


    
public 
void 
write(DataOutput out)
throws 
IOException {


        
out.writeUTF(
this
.toString());


    
}


 


    
@Override


    
public 
int 
compareTo(Record obj) {


        
Collections.sort(list);


        
Collections.sort(obj.list);


        
return 
this
.toString().compareTo(obj.toString());


    
}


 


}

DC_FPTree.java

?










package 
fptree;


 


import 
java.io.BufferedReader;


import 
java.io.IOException;


import 
java.io.InputStreamReader;


import 
java.util.ArrayList;


import 
java.util.BitSet;


import 
java.util.Collections;


import 
java.util.Comparator;


import 
java.util.HashMap;


import 
java.util.LinkedList;


import 
java.util.List;


import 
java.util.Map;


import 
java.util.Map.Entry;


import 
java.util.Set;


 


import 
org.apache.hadoop.conf.Configuration;


import 
org.apache.hadoop.conf.Configured;


import 
org.apache.hadoop.fs.FSDataInputStream;


import 
org.apache.hadoop.fs.FileSystem;


import 
org.apache.hadoop.fs.Path;


import 
org.apache.hadoop.io.IntWritable;


import 
org.apache.hadoop.io.LongWritable;


import 
org.apache.hadoop.io.Text;


import 
org.apache.hadoop.mapreduce.Job;


import 
org.apache.hadoop.mapreduce.Mapper;


import 
org.apache.hadoop.mapreduce.Reducer;


import 
org.apache.hadoop.mapreduce.lib.input.FileInputFormat;


import 
org.apache.hadoop.mapreduce.lib.input.TextInputFormat;


import 
org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;


import 
org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;


import 
org.apache.hadoop.util.Tool;


import 
org.apache.hadoop.util.ToolRunner;


 


public 
class 
DC_FPTree
extends 
Configured
implements 
Tool {


 


    
private 
static 
final 
int 
GroupNum =
10
;


    
private 
static 
final 
int 
minSuport=
6
;


 


    
public 
static 
class 
GroupMapper
extends


            
Mapper<LongWritable, Text, IntWritable, Record> {


        
List<String> freq =
new 
LinkedList<String>();
// 频繁1项集


        
List<List<String>> freq_group =
new 
LinkedList<List<String>>();
// 分组后的频繁1项集


 


        
@Override


        
public 
void 
setup(Context context)
throws 
IOException {


            
// 从文件读入频繁1项集


            
FileSystem fs = FileSystem.get(context.getConfiguration());


            
Path freqFile =
new 
Path(
“/user/orisun/input/F1”
);


            
FSDataInputStream in = fs.open(freqFile);


            
InputStreamReader isr =
new 
InputStreamReader(in);


            
BufferedReader br =
new 
BufferedReader(isr);


            
try 
{


                
String line;


                
while 
((line = br.readLine()) !=
null
) {


                    
String[] str = line.split(
“\s+”
);


                    
String word = str[
0
];


                    
freq.add(word);


                
}


            
}
finally 
{


                
br.close();


            
}


            
// 对频繁1项集进行分组


            
Collections.shuffle(freq);
// 打乱顺序


            
int 
cap = freq.size() / GroupNum;
// 每段分为一组


            
for 
(
int 
i =
0
; i < GroupNum; i++) {


                
List<String> list =
new 
LinkedList<String>();


                
for 
(
int 
j =
0
; j < cap; j++) {


                    
list.add(freq.get(i cap + j));


                
}


                
freq_group.add(list);


            
}


            
int 
remainder = freq.size() % GroupNum;


            
int 
base = GroupNum
cap;


            
for 
(
int 
i =
0
; i < remainder; i++) {


                
freq_group.get(i).add(freq.get(base + i));


            
}


        
}


 


        
@Override


        
public 
void 
map(LongWritable key, Text value, Context context)


                
throws 
IOException, InterruptedException {


            
String[] arr = value.toString().split(
“\s+”
);


            
Record record =
new 
Record(arr);


            
LinkedList<String> list = record.list;


            
BitSet bs=
new 
BitSet(freq_group.size());


            
bs.clear();


            
while 
(record.list.size() >
0
) {


                
String item = list.peekLast();
// 取出record的最后一项


                
int 
i=
0
;


                
for 
(; i < freq_group.size(); i++) {


                    
if
(bs.get(i))


                        
continue
;


                    
if 
(freq_group.get(i).contains(item)) {


                        
bs.set(i);


                        
break
;


                    
}


                
}


                
if
(i<freq_group.size()){    
//找到了


                    
context.write(
new 
IntWritable(i), record); 


                
}


                
record.list.pollLast();


            
}


        
}


    
}


     


    
public 
static 
class 
FPReducer
extends 
Reducer<IntWritable,Record,IntWritable,Text>{


        
public 
void 
reduce(IntWritable key,Iterable<Record> values,Context context)
throws 
IOException,InterruptedException{


            
List<List<String>> trans=
new 
LinkedList<List<String>>();


            
while
(values.iterator().hasNext()){


                
Record record=values.iterator().next();


                
LinkedList<String> list=
new 
LinkedList<String>();


                
for
(String ele:record.list)


                    
list.add(ele);


                
trans.add(list);


            
}


            
FPGrowth(trans,
null
,context);


        
}


        
// FP-Growth算法


    
public 
void 
FPGrowth(List<List<String>> transRecords,


            
List<String> postPattern,Context context)
throws 
IOException, InterruptedException {


        
// 构建项头表,同时也是频繁1项集


        
ArrayList<TreeNode> HeaderTable = buildHeaderTable(transRecords);


        
// 构建FP-Tree


        
TreeNode treeRoot = buildFPTree(transRecords, HeaderTable);


        
// 如果FP-Tree为空则返回


        
if 
(treeRoot.getChildren()==
null 
|| treeRoot.getChildren().size() ==
0
)


            
return
;


        
//输出项头表的每一项+postPattern


        
if
(postPattern!=
null
){


            
for 
(TreeNode header : HeaderTable) {


                
String outStr=header.getName();


                
int 
count=header.getCount();


                
for 
(String ele : postPattern)


                    
outStr+=
“\t” 
+ ele;


                
context.write(
new 
IntWritable(count),
new 
Text(outStr));


            
}


        
}


        
// 找到项头表的每一项的条件模式基,进入递归迭代


        
for 
(TreeNode header : HeaderTable) {


            
// 后缀模式增加一项


            
List<String> newPostPattern =
new 
LinkedList<String>();


            
newPostPattern.add(header.getName());


            
if 
(postPattern !=
null
)


                
newPostPattern.addAll(postPattern);


            
// 寻找header的条件模式基CPB,放入newTransRecords中


            
List<List<String>> newTransRecords =
new 
LinkedList<List<String>>();


            
TreeNode backnode = header.getNextHomonym();


            
while 
(backnode !=
null
) {


                
int 
counter = backnode.getCount();


                
List<String> prenodes =
new 
ArrayList<String>();


                
TreeNode parent = backnode;


                
// 遍历backnode的祖先节点,放到prenodes中


                
while 
((parent = parent.getParent()).getName() !=
null
) {


                    
prenodes.add(parent.getName());


                
}


                
while 
(counter— >
0
) {


                    
newTransRecords.add(prenodes);


                
}


                
backnode = backnode.getNextHomonym();


            
}


            
// 递归迭代


            
FPGrowth(newTransRecords, newPostPattern,context);


        
}


    
}


 


        
// 构建项头表,同时也是频繁1项集


        
public 
ArrayList<TreeNode> buildHeaderTable(List<List<String>> transRecords) {


            
ArrayList<TreeNode> F1 =
null
;


            
if 
(transRecords.size() >
0
) {


                
F1 =
new 
ArrayList<TreeNode>();


                
Map<String, TreeNode> map =
new 
HashMap<String, TreeNode>();


                
// 计算事务数据库中各项的支持度


                
for 
(List<String> record : transRecords) {


                    
for 
(String item : record) {


                        
if 
(!map.keySet().contains(item)) {


                            
TreeNode node =
new 
TreeNode(item);


                            
node.setCount(
1
);


                            
map.put(item, node);


                        
}
else 
{


                            
map.get(item).countIncrement(
1
);


                        
}


                    
}


                
}


                
// 把支持度大于(或等于)minSup的项加入到F1中


                
Set<String> names = map.keySet();


                
for 
(String name : names) {


                    
TreeNode tnode = map.get(name);


                    
if 
(tnode.getCount() >= minSuport) {


                        
F1.add(tnode);


                    
}


                
}


                
Collections.sort(F1);


                
return 
F1;


            
}
else 
{


                
return 
null
;


            
}


        
}


 


        
// 构建FP-Tree


        
public 
TreeNode buildFPTree(List<List<String>> transRecords,


                
ArrayList<TreeNode> F1) {


            
TreeNode root =
new 
TreeNode();
// 创建树的根节点


            
for 
(List<String> transRecord : transRecords) {


                
LinkedList<String> record = sortByF1(transRecord, F1);


                
TreeNode subTreeRoot = root;


                
TreeNode tmpRoot =
null
;


                
if 
(root.getChildren() !=
null
) {


                    
while 
(!record.isEmpty()


                            
&& (tmpRoot = subTreeRoot.findChild(record.peek())) !=
null
) {


                        
tmpRoot.countIncrement(
1
);


                        
subTreeRoot = tmpRoot;


                        
record.poll();


                    
}


                
}


                
addNodes(subTreeRoot, record, F1);


            
}


            
return 
root;


        
}


 


        
// 把交易记录按项的频繁程序降序排列


        
public 
LinkedList<String> sortByF1(List<String> transRecord,


                
ArrayList<TreeNode> F1) {


            
Map<String, Integer> map =
new 
HashMap<String, Integer>();


            
for 
(String item : transRecord) {


                
// 由于F1已经是按降序排列的,


                
for 
(
int 
i =
0
; i < F1.size(); i++) {


                    
TreeNode tnode = F1.get(i);


                    
if 
(tnode.getName().equals(item)) {


                        
map.put(item, i);


                    
}


                
}


            
}


            
ArrayList<Entry<String, Integer>> al =
new 
ArrayList<Entry<String, Integer>>(


                    
map.entrySet());


            
Collections.sort(al,
new 
Comparator<Map.Entry<String, Integer>>() {


                
@Override


                
public 
int 
compare(Entry<String, Integer> arg0,


                        
Entry<String, Integer> arg1) {


                    
// 降序排列


                    
return 
arg0.getValue() - arg1.getValue();


                
}


            
});


            
LinkedList<String> rest =
new 
LinkedList<String>();


            
for 
(Entry<String, Integer> entry : al) {


                
rest.add(entry.getKey());


            
}


            
return 
rest;


        
}


 


        
// 把record作为ancestor的后代插入树中


        
public 
void 
addNodes(TreeNode ancestor, LinkedList<String> record,


                
ArrayList<TreeNode> F1) {


            
if 
(record.size() >
0
) {


                
while 
(record.size() >
0
) {


                    
String item = record.poll();


                    
TreeNode leafnode =
new 
TreeNode(item);


                    
leafnode.setCount(
1
);


                    
leafnode.setParent(ancestor);


                    
ancestor.addChild(leafnode);


 


                    
for 
(TreeNode f1 : F1) {


                        
if 
(f1.getName().equals(item)) {


                            
while 
(f1.getNextHomonym() !=
null
) {


                                
f1 = f1.getNextHomonym();


                            
}


                            
f1.setNextHomonym(leafnode);


                            
break
;


                        
}


                    
}


 


                    
addNodes(leafnode, record, F1);


                
}


            
}


        
}


    
}


     


    
public 
static 
class 
InverseMapper
extends


            
Mapper<LongWritable, Text, Record, IntWritable> {


        
@Override


        
public 
void 
map(LongWritable key, Text value, Context context)


                
throws 
IOException, InterruptedException {


            
String []arr=value.toString().split(
“\s+”
);


            
int 
count=Integer.parseInt(arr[
0
]);


            
Record record=
new 
Record();


            
for
(
int 
i=
1
;i<arr.length;i++){


                
record.list.add(arr[i]);


            
}


            
context.write(record,
new 
IntWritable(count));


        
}


    
}


     


    
public 
static 
class 
MaxReducer
extends 
Reducer<Record,IntWritable,IntWritable,Record>{


        
public 
void 
reduce(Record key,Iterable<IntWritable> values,Context context)
throws 
IOException,InterruptedException{


            
int 
max=-
1
;


            
for
(IntWritable value:values){


                
int 
i=value.get();


                
if
(i>max)


                    
max=i;


            
}


            
context.write(
new 
IntWritable(max), key);


        
}


    
}


 


 


    
@Override


    
public 
int 
run(String[] arg0)
throws 
Exception {


        
Configuration conf=getConf();


        
conf.set(
“mapred.task.timeout”
,
“6000000”
);


        
Job job=
new 
Job(conf);


        
job.setJarByClass(DC_FPTree.
class
);


        
FileSystem fs=FileSystem.get(getConf());


         


        
FileInputFormat.setInputPaths(job,
“/user/orisun/input/data”
);


        
Path outDir=
new 
Path(
“/user/orisun/output”
);


        
fs.delete(outDir,
true
);


        
FileOutputFormat.setOutputPath(job, outDir);


         


        
job.setMapperClass(GroupMapper.
class
);


        
job.setReducerClass(FPReducer.
class
);


         


        
job.setInputFormatClass(TextInputFormat.
class
);


        
job.setOutputFormatClass(TextOutputFormat.
class
);


        
job.setMapOutputKeyClass(IntWritable.
class
);


        
job.setMapOutputValueClass(Record.
class
);


        
job.setOutputKeyClass(IntWritable.
class
);


        
job.setOutputKeyClass(Text.
class
);


         


        
boolean 
success=job.waitForCompletion(
true
);


         


        
job=
new 
Job(conf);


        
job.setJarByClass(DC_FPTree.
class
);


         


        
FileInputFormat.setInputPaths(job,
“/user/orisun/output/part-r-*”
);


        
Path outDir2=
new 
Path(
“/user/orisun/output2”
);


        
fs.delete(outDir2,
true
);


        
FileOutputFormat.setOutputPath(job, outDir2);


         


        
job.setMapperClass(InverseMapper.
class
);


        
job.setReducerClass(MaxReducer.
class
);


        
//job.setNumReduceTasks(0);


         


        
job.setInputFormatClass(TextInputFormat.
class
);


        
job.setOutputFormatClass(TextOutputFormat.
class
);


        
job.setMapOutputKeyClass(Record.
class
);


        
job.setMapOutputValueClass(IntWritable.
class
);


        
job.setOutputKeyClass(IntWritable.
class
);


        
job.setOutputKeyClass(Record.
class
);


         


        
success |= job.waitForCompletion(
true
);


         


        
return 
success?
0
:
1
;


    
}


 


    
public 
static 
void 
main(String[] args)
throws 
Exception{


        
int 
res=ToolRunner.run(
new 
Configuration(),
new 
DC_FPTree(), args);


        
System.exit(res);


    
}


}

结束语

在实践中,关联规则挖掘可能并不像人们期望的那么有用。一方面是因为支持度置信度框架会产生过多的规则,并不是每一个规则都是有用的。另一方面大部分的关联规则并不像“啤酒与尿布”这种经典故事这么普遍。关联规则分析是需要技巧的,有时需要用更严格的统计学知识来控制规则的增殖。 

原文来自:博客园(华夏35度)http://www.cnblogs.com/zhangchaoyang 作者:Orisun

发表评论

表情:
评论列表 (有 0 条评论,226人围观)

还没有评论,来说两句吧...

相关阅读

    相关 HotSpot算法实现

    1.根节点的枚举 我们通过可达性分析算法从GC Roots中找到全局性的引用(例如常量或者类静态属性)或者是执行上下文(例如栈帧中的本地变量)中,尽管我们的目标非常明确,

    相关 HotSpot算法实现

    HotSpot虚拟机上实现算法(对象存活判定算法和垃圾收集算法)时,必须对算法的执行效率有严格的考量,才能保证虚拟机高效运行。 枚举根节点 从可达性分析中从GC Roo

    相关 HotSpot算法实现

    枚举根节点 ​ 由于目前的主流Java虚拟机使用的都是准确式GC,当执行系统停顿下来后,并不需要一个不漏地检查完所有执行上下文和全局的引用位置,虚拟机应当是有办法直接得知