【多分类】用Java调用python训练好的XGBoost模型

作者:luozhipeng   发表日期:2016-12-27  浏览:1,616次

由于XGBoost没有Java版本,且Java调用Python也不太好用,所以写了一个Java加载模型并做预测的类,代码中对允许特征缺失,大家可以参考写成其他语言。

 

 

训练完XGBoost后保存成dump:

bst.dump_model('xgb.'+str(k)+'.model.dump', fmap='', with_stats=True)

 

模型文件如下:

Java加载模型并预测代码如下:
package com.cm.bot;

import java.awt.List;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Dictionary;
import java.util.HashMap;
import java.util.Scanner;

import com.hankcs.hanlp.dependency.nnparser.util.math;

public class Softmax
{    
    public ArrayList<Node> Btree = new ArrayList<Node>();
    public int class_num = 0;
    public int tree_num = 0;
    
    static MySegment mySeg;
    
    static HashMap<String, Integer> word2id = new HashMap<>();
    static HashMap<String, Integer> wordcnt = new HashMap<>();
    
    static String dir = "C:\\Users\\Administrator\\Desktop\\bot\\python\\";
    
    public Softmax(int cnum, int tnum)
    {
        class_num = cnum;
        tree_num = tnum;
        mySeg = new MySegment();
        
        String line = "";
        try
        {
            BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(dir+"df.tsv"),"UTF-8"));        
            while ( (line = br.readLine())!= null)
            {
                String[] parts = line.split("\t");
                
                if(!parts[0].equals("-1"))
                    word2id.put(parts[1], Integer.parseInt(parts[0]));
                wordcnt.put(parts[1], Integer.parseInt(parts[2]));
            }
            br.close();
        }
        catch (Exception e) {
            e.printStackTrace();
        }

    }
    
    private static void CreateBitTree(Node t, int i, HashMap<Integer,String> dic)
    {
        String line = dic.get(i);

        if(line.indexOf("leaf") >= 0)
        {//到达叶子节点
            String[] p1s = line.split(",");
            String[] p2s = p1s[0].split("=");
            float value = Float.parseFloat(p2s[1]);
            t.value = value;
            t.left = null;
            t.right = null;
        }
        else
        {
            String[] p1s = line.split(" ");
            String fs = p1s[0].substring(1, p1s[0].length()-1);
            String[] p2s = fs.split("<");
            
            String fnum = p2s[0];
            float thre = Float.parseFloat(p2s[1]); //分隔值
            
            String[] yn = p1s[1].split(",");
            
            String[] yn0 = yn[0].split("=");
            int yes = Integer.parseInt(yn0[1]); //左节点
            
            String[] yn1 = yn[1].split("=");
            int no = Integer.parseInt(yn1[1]); //右节点
            
            String[] yn2 = yn[2].split("=");
            int miss = Integer.parseInt(yn2[1]); //缺失
            
            t.fnum = fnum;
            t.thre = thre;
            
            if(yes == miss)
                t.missing = 0;
            else
                t.missing = 1;
            
            Node left = new Node();
            CreateBitTree(left, yes, dic);
            t.left = left;
            
            Node right = new Node();
            CreateBitTree(right, no, dic);
            t.right = right;
        }
    }
    
    private void AddTree(HashMap<Integer, String> dic)
    {
        Node hNode = new Node();
        CreateBitTree(hNode, 0, dic);
        this.Btree.add(hNode);
    }
    
    public void LoadModel(String modelPath)
    {
        try
        {
            String line = "";
            BufferedReader br =new BufferedReader(new InputStreamReader(new FileInputStream(modelPath),"UTF-8"));
            
            HashMap<Integer, String> dic = new HashMap<>();
            line = br.readLine();
            while ( (line = br.readLine())!= null)
            {
                if(line.indexOf("booster") >= 0)
                {
                    AddTree(dic);
                    dic.clear();
                }
                else
                {
                    String[] parts = line.trim().split(":");
                    int id = Integer.parseInt(parts[0]);
                    dic.put(id, parts[1]);
                }
            }
            AddTree(dic);
            
            br.close();
        }
        catch (Exception e) {
            // TODO: handle exception
            e.printStackTrace();
        }
    }
    
    
    public int Predict(String str)
    {
        HashMap<String, Float> feats = string2Feature(str);
        
        int cate = 0;
        double[] scos = {0.0f, 0.0f, 0.0f};
        
        for(int i = 0; i < Btree.size(); ++i)
        {
            Node t = Btree.get(i);
            while (t.left != null)
            {
                if(feats.containsKey(t.fnum))
                {//包含该特征
                    if(feats.get(t.fnum) < t.thre)
                        t= t.left;
                    else
                        t= t.right;
                }
                else
                {//特征缺失
                    if(t.missing == 0)
                        t= t.left;
                    else
                        t= t.right;
                }
            }
            
            int tid = i%class_num;
            scos[tid] += t.value;
        }
        
        double sum = 0.0f;
        double pmax = 0;
        for(int i = 0; i < scos.length; ++i)
        {
            scos[i] = Math.exp(scos[i]);
            sum += scos[i];
        }
        
        for(int i = 0; i < scos.length; ++i)
        {
            scos[i] /= sum;
            if(scos[i] > pmax)
            {
                cate = i;
                pmax = scos[i];
            }
            
            System.out.println(scos[i]);
        }
        
        return cate;
    }
    
    
    public static HashMap<String, Float> string2Feature(String str)
    {
        HashMap<String, Float> feats = new HashMap<>();
        
        str = str.trim().toLowerCase().replaceAll("\t", "");
        str = mySeg.seg(str, false, true);
        String[] words = str.split(" ");
        
        HashMap<String, Integer> tf = new HashMap<>();
        HashMap<Integer, Double> tfidf = new HashMap<>();
        
        for(String word:words)
        {
            if(word2id.containsKey(word))
            {
                if(!tf.containsKey(word))
                    tf.put(word, 1);
                else
                    tf.put(word, tf.get(word) + 1);
            }
        }

        
        double sum = 0.0;
        for(String word:tf.keySet())
        {
            double d = tf.get(word) * Math.log(wordcnt.get("doc_num")/wordcnt.get(word)); //tf*idf
            
            sum += d*d;
            tfidf.put(word2id.get(word), d);
        }
        
        for(Integer id:tfidf.keySet())
        {
            feats.put("f"+id, (float)(tfidf.get(id)/Math.sqrt(sum) ) );
        }
        
        return feats;
    }
    
    
    
    public static void main(String[] args)
    {
        Softmax sf = new Softmax(3, 600); //3类, 600颗树
        sf.LoadModel(dir + "xgb.600.model.dump");
        
        HashMap<String, Float> feats = string2Feature("今天天气怎么样?"); //自己的feature,id:value, id从f0开
        for (String str : feats.keySet())
        {
            System.out.println(str + "\t" + feats.get(str));
        }
        
        Scanner scanner= new Scanner(System.in);
        while(true)
        {
            String line = scanner.nextLine();
            System.out.println(sf.Predict(line));
        }

    }
}

姊妹篇:【二分类】用C#调用python训练好的XGBoost模型

本文固定链接: http://www.luozhipeng.com/?p=1207
转载请注明: luozhipeng 2016-12-27 于 罗志鹏的BLOG 发表

上一篇: :下一篇
返回顶部