KMeans++的C#实现

作者:luozhipeng   发表日期:2016-8-3  浏览:512次


using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

using System.IO;
using System.Diagnostics;

namespace Community
{

    class KMeans
    {

        public static void Init(string filepath)
        {
            string line = "";
            StreamReader sr = new StreamReader(filepath);

            int t = 0;

            while( (line = sr.ReadLine()) != null )
            {
                /*
                if (++t > 1000)
                    break;
               */

                string[] parts = line.Split(new string[] { "\t" }, StringSplitOptions.RemoveEmptyEntries);
                string[] vecs = parts[1].Split(new string[] { " " }, StringSplitOptions.RemoveEmptyEntries);

                List<float> lis = new List<float>();
                foreach(string vec in vecs)
                {
                   // Console.WriteLine(vec);

                    lis.Add(float.Parse(vec));
 
                }

                if (id_vec.ContainsKey(parts[0])) //duplicate
                    continue;

                id_vec.Add(parts[0], lis);
                sampleId.Add(parts[0]);
            }

            sr.Close();
        }

        private static float CosDis(List<float> lis1, List<float> lis2)
        {
            double dis = 0.0f;
            float a = 0.0f, b = 0.0f, c = 0.0f;

            for (int i = 0; i < lis1.Count; ++i)
            {
                a += lis1[i] * lis2[i];
                b += lis1[i] * lis1[i];
                c += lis2[i] * lis2[i];
            }
            dis = a / (Math.Sqrt(b) * Math.Sqrt(c));
            return 1.0f - (float)dis;
        }

        private static List<string> GenerateMeansId()
        {
            List<string> lis = new List<string>();
            int n = 1;
            Random ran = new Random();
            int seed = ran.Next(0, sampleId.Count);
            lis.Add(sampleId[seed]);

            Dictionary<string, float> pair_dis = new Dictionary<string,float>();

            while (n < K)
            {
                Console.WriteLine("seed:" + n);

                float sum_dis = 0.0f;

                Dictionary<string, float> dic = new Dictionary<string, float>();
                foreach (string id1 in sampleId)
                {
                    if (ran.NextDouble() > means_samp_ratio) //采样中心
                        continue;

                    if (!lis.Contains(id1))
                    {
                        float dis = float.MaxValue;
                        foreach (string id2 in lis)
                        {
                            float tmp_dis = 0;
                            tmp_dis = CosDis(id_vec[id1], id_vec[id2]);

                            /*
                            string pair = id1+"-"+id2;
                            if( !pair_dis.ContainsKey(pair) )
                            {
                                tmp_dis = CosDis(id_vec[id1], id_vec[id2]);
                                pair_dis.Add(pair, tmp_dis);
                            }
                            else
                            {
                                tmp_dis = pair_dis[pair];
                            }
                            */

                            dis = Math.Min(dis, tmp_dis);
                        }
                        dic.Add(id1, dis);
                        sum_dis += dis;

                    }
                }

                if(sum_dis == 0)
                {
                    continue;
                }

                float ran_dis = sum_dis * (float)ran.NextDouble();

                foreach (string id in dic.Keys)
                {
                    ran_dis -= dic[id];
                    if (ran_dis <= 0)
                    {
                        lis.Add(id);
                        break;
                    }
                }

                ++n;
            }

            return lis;
        }

        private static void UpdateCenters(List<List<float>> means, List<List<string>> clu_samp)
        {
            for (int i = 0; i < means.Count; ++i)
            {
                means[i].Clear();

                foreach (string id in clu_samp[i])
                {
                    for (int j = 0; j < id_vec[id].Count; ++j)
                    {
                        if (j < means[i].Count)
                            means[i][j] += id_vec[id][j];
                        else
                            means[i].Add(id_vec[id][j]);
                    }
                }

                for (int j = 0; j < means[i].Count; ++j)
                    means[i][j] /= clu_samp[i].Count;
            }
        }

        private static void SaveResult(List<List<float>> means, List<List<string>> clu_samp, int n)
        {
            StreamWriter sw_means = new StreamWriter(Parameter.dir + string.Format(means_path, n) );
            StreamWriter sw_clu_sample = new StreamWriter(Parameter.dir + string.Format(clu_sample_path, n) );

            string line = "";
            for( int i = 0; i < means.Count; ++i)
            {
                line = i + 1 + "\t";
                foreach (float f in means[i])
                    line += f.ToString() + " ";
                sw_means.WriteLine(line);
            }

            for(int i = 0; i < clu_samp.Count; ++i)
            {
                line = i + 1 + "\t";
                foreach (string id in clu_samp[i])
                    line += id + " ";
                sw_clu_sample.WriteLine(line);
            }

            sw_clu_sample.Close();
            sw_means.Close();
        }

        private static float GetWeightAvg(List<List<float>> means, List<List<string>> clu_smap)
        {
            float wei_avg = 0.0f;

            for (int i = 0; i < clu_smap.Count; ++i)
            {
                float tmp = 0.0f;
                foreach (string id in clu_smap[i])
                    tmp += CosDis(means[i], id_vec[id]);
                tmp /= clu_smap[i].Count;

                wei_avg += tmp;
            }

            return wei_avg / clu_smap.Count;
        }

        private static void KMeansClustering()
        {

            Console.WriteLine("...1..");

            //初始化聚类中心
            List<string> meansId = GenerateMeansId();
            List<List<float>> means = new List<List<float>>();
            foreach(string id in meansId)
            {
                Console.WriteLine(id);

                List<float> tmp_lis = new List<float>();
                id_vec[id].ForEach(item => tmp_lis.Add(item));
                means.Add(tmp_lis);
            }

            Console.WriteLine("...2...");

            List<List<string>> clu_samp = new List<List<string>>(); //每个簇包含的样例id
            for (int i = 0; i < K; ++i)
                clu_samp.Add( new List<string>() );

            Console.WriteLine("...3..");

            int n = 0;

            float best_wa = float.MaxValue;

            while( n < num_round)
            {
                Console.Write("epoch:" + n + " ........");
                Stopwatch stw = new Stopwatch();
                stw.Start();

                for (int i = 0; i < K; ++i)
                    clu_samp[i].Clear();

                foreach(string id in sampleId)
                {
                    int near_cen = 0;
                    float min_dis = float.MaxValue;
                    for(int i = 0; i < means.Count; ++i)
                    {
                        float tmp_dis = CosDis(id_vec[id], means[i]);
                        if(tmp_dis < min_dis)
                        {
                            min_dis = tmp_dis;
                            near_cen = i;
                        }
                    }

                    clu_samp[near_cen].Add(id);
                }

                UpdateCenters(means, clu_samp);

                ++n;
                stw.Stop();

                float t_wa = GetWeightAvg(means, clu_samp );
                TimeSpan ts = stw.Elapsed;
                Console.WriteLine(ts.TotalSeconds + "(s) ....... avgrage mass center distance:" + t_wa );

                SaveResult(means, clu_samp, n); //保存结果

                if (t_wa < best_wa)
                    best_wa = t_wa;
                else
                    break;
            }

        }

        private static int K = 500;
        private static int num_round = 50;

        private static double means_samp_ratio = 1.0; //如果数据量大初始化聚类中心可能会慢,这样可以采样部分来选中心

        private static Dictionary<string, List<float>> id_vec = new Dictionary<string, List<float>>();
        private static List<string> sampleId = new List<string>();

        //input
        private static string sample_vecs_path = @"data/number.titles.q.vec.tsv";

        //output
        private static string means_path = @"data/{0}.means.tsv";
        private static string clu_sample_path = @"data/{0}.cluster.samples.tsv";

        public static void Proccess()
        {
            Init(Parameter.dir + sample_vecs_path );

            KMeansClustering();
        }
    }
}

输入文件:

每行一个样本:第一列是id,与后面的vector之间用Tab分隔;vector的各个分量之间用空格分隔。
kmeans1

 

 

输出文件:

第一列是簇id(从1开始),后面是对应的簇中心vector;两者之间用Tab分隔,vector的各个分量之间用空格分隔。

kmeans2

 

第一列是簇id,第二列是该簇所包含的样例id。

kmeans3

标签:

本文固定链接: http://www.luozhipeng.com/?p=945
转载请注明: luozhipeng 2016-8-3 于 罗志鹏的BLOG 发表

上一篇: :下一篇
返回顶部