【二分类】用C#调用python训练好的XGBoost模型

作者:luozhipeng   发表日期:2016-9-30  浏览:924次

由于XGBoost没有C#版本,且用C#调用Python也不太方便,所以我写了一个用C#加载模型并做预测的类,大家可以参考它改写成其他语言的版本。(由于实验中没有缺失特征,所以此版本没有对缺失值进行处理。)

 

训练完XGBoost后保存成dump:

bst.dump_model('xgb.'+str(k)+'.model.dump', fmap='', with_stats=True)

 

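模型文件如下(每棵树以一行 booster[i]: 开头,其后每行是一个节点:分裂节点形如 0:[f2<2.45] yes=1,no=2,missing=1,gain=...,cover=...,叶子节点形如 1:leaf=0.43,cover=...):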

 

 

 

C#加载模型并预测代码如下:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

using System.IO;

namespace Official
{
    public class Node
    {
        public string fnum = "";
        public float thre = -1;
        public float value = -1;
        public Node left;
        public Node right;
    }

    class XGBoost
    {
        public List<Node> BTree = new List<Node>();

        private static void CreateBitTree(Node t, int i, Dictionary<int, string> dic)
        {
            string line = dic[i];

            if(line.IndexOf("leaf") >= 0)
            {
                string[] p1s = line.Split(new string[] { "," }, StringSplitOptions.RemoveEmptyEntries);
                string[] p2s = p1s[0].Split(new string[] { "=" }, StringSplitOptions.RemoveEmptyEntries);
                float value = float.Parse(p2s[1]);
                t.value = value;
                t.left = null;
                t.right = null;
            }
            else
            {
                string[] p1s = line.Split(new string[] { " " }, StringSplitOptions.RemoveEmptyEntries);

                string fs = p1s[0].Substring(1, p1s[0].Length - 2);
                string[] p2s = fs.Split(new string[] { "<" }, StringSplitOptions.RemoveEmptyEntries);
                string fnum = p2s[0];
                float thre = float.Parse(p2s[1]);

                string[] yn = p1s[1].Split(new string[] { "," }, StringSplitOptions.RemoveEmptyEntries);
                string[] yn0 = yn[0].Split(new string[] { "=" }, StringSplitOptions.RemoveEmptyEntries);
                int yes = int.Parse(yn0[1]);

                string[] yn1 = yn[1].Split(new string[] { "=" }, StringSplitOptions.RemoveEmptyEntries);
                int no = int.Parse(yn1[1]);

                t.fnum = fnum;
                t.thre = thre;

                Node left = new Node();
                CreateBitTree(left, yes, dic);
                t.left = left;

                Node right = new Node();
                CreateBitTree(right, no, dic);
                t.right = right;
            }
            
        }

        private void AddTree(Dictionary<int, string> dic)
        {
            Node hNode = new Node();
            CreateBitTree(hNode, 0, dic);
            this.BTree.Add(hNode);
        }

        public void LoadModel(string modelPath)
        {
            string line = "";
            StreamReader sr = new StreamReader(modelPath);

            Dictionary<int, string> dic = new Dictionary<int, string>();
            line = sr.ReadLine();
            while( (line = sr.ReadLine()) != null)
            {
                if (line.IndexOf("booster") >= 0)
                {
                    AddTree(dic);
                    dic.Clear();
                }
                else
                {
                    string[] parts = line.Split(new string[] { ":" }, StringSplitOptions.RemoveEmptyEntries);
                    int id = int.Parse(parts[0]);
                    dic.Add(id, parts[1]);
                }
            }
            AddTree(dic);
            sr.Close();
        }

        public float Predict(Dictionary<string, float> feats)
        {
            float sco = 0;
            foreach(Node nd in BTree)
            {
                Node t = nd;
                while(t.left != null)
                {
                    if (feats[t.fnum] < t.thre)
                        t = t.left;
                    else
                        t = t.right;
                }
                sco += t.value;
            }

            sco = Sigmoid(sco);
            sco = (float) Math.Round(sco, 12);
            return sco;
        }

        private static float Sigmoid(float x)
        {
            return (float) ( 1.0 / (1.0 + Math.Exp(-x)) );
        }

        private static Dictionary<string, float> TestQueryTitleFeat(string query, string title)
        {
            int fid = 0;
            Dictionary<string, float> feats = new Dictionary<string, float>();

            List<float> qlis = CDSSMParsing.String2List(CDSSMParsing.Query2Vec(query));
            foreach (float x in qlis) //query vector
                feats.Add("f" + fid++, x);

            List<float> tlis = CDSSMParsing.String2List(CDSSMParsing.Doc2Vec(title));
            foreach (float x in tlis) //title vector
                feats.Add("f" + fid++, x);

            List<float> wf = Level3Ranking.WordFeature(query, title);
            foreach (float x in wf) //word feature
                feats.Add("f" + fid++, x);

            feats.Add("f" + fid++, Level3Ranking.CosSmi(qlis, tlis)); //cos

            List<float> dotFeat = Level3Ranking.DotFeature(qlis, tlis);
            foreach (float x in dotFeat) //dot feature
                feats.Add("f" + fid++, x);

            return feats;
        }

        public static void Proccess()
        {
            XGBoost xgb = new XGBoost();
            xgb.LoadModel(@"C:\Users\v-zhiplu\Desktop\official\data\newRank\model\xgb.5.model.dump");

            CDSSMParsing.Init();

            while(true)
            {
                Console.Write("query:");
                string query = Console.ReadLine();
                Console.Write("title:");
                string title = Console.ReadLine();

                Dictionary<string, float> feats = TestQueryTitleFeat(query, title);  //自己的feature,id:value, id从f0开始
                Console.WriteLine(xgb.Predict(feats));
            }
        }
    }
}

标签:

本文固定链接: http://www.luozhipeng.com/?p=1182
转载请注明: luozhipeng 2016-9-30 于 罗志鹏的BLOG 发表

上一篇: :下一篇
返回顶部