Projekt

Obecné

Profil

« Předchozí | Další » 

Revize cdeee9f8

Přidáno uživatelem Roman Kalivoda před téměř 4 lety

Re #8955 implemented multiple predictors support

Zobrazit rozdíly:

Server/ServerApp/Predictor/FeatureExtractor.cs
13 13
    /// <summary>
14 14
    /// A class responsible for preparation of features for classifiers.
15 15
    /// </summary>
16
    public class FeatureExtractor
16
    class FeatureExtractor
17 17
    {
18 18
        /// <summary>
19 19
        /// A DataParser instance used to access info objects.
20 20
        /// </summary>
21
        private readonly IDataParser dataParser;
21
        private readonly IDataParser DataParser;
22 22

  
23
        private Dictionary<string, int> buildingsToAreas;
23
        /// <summary>
24
        /// A configuration object of the <c>Predictor</c> package
25
        /// </summary>
26
        private PredictorConfiguration Configuration;
24 27

  
25 28
        /// <summary>
26 29
        /// Instantiates new FeatureExtractor class.
27 30
        /// </summary>
28 31
        /// <param name="dataParser">Data parser used to access training data.</param>
29
        public FeatureExtractor(IDataParser dataParser, Dictionary<string, int> buildingsToAreas)
32
        public FeatureExtractor(IDataParser dataParser, PredictorConfiguration configuration)
30 33
        {
31
            this.dataParser = dataParser;
32
            this.buildingsToAreas = buildingsToAreas;
34
            this.DataParser = dataParser;
35
            this.Configuration = configuration;
33 36
        }
34 37

  
35 38
        /// <summary>
......
41 44
        /// <param name="interval"></param>
42 45
        /// <param name="wholeDay"></param>
43 46
        /// <returns></returns>
44
        public List<ModelInput> PrepareTrainingInput(int area, DateTime startDate, DateTime endDate, int interval = 1, bool wholeDay = true)
47
        public List<ModelInput> PrepareTrainingInput(int area, DateTime startDate, DateTime endDate, int interval = 3, bool wholeDay = true)
45 48
        {
46
            dataParser.Parse(startDate, endDate, interval, wholeDay);
49
            DataParser.Parse(startDate, endDate, interval, wholeDay);
47 50
            List<string> buildings = new List<string>();
48 51

  
49 52
            // find all buildings in area
50
            foreach (KeyValuePair<string, int> kvp in buildingsToAreas)
53
            foreach (KeyValuePair<string, int> kvp in Configuration.BuildingsToAreas)
51 54
            {
52 55
                if (kvp.Value == area)
53 56
                {
......
56 59
            }
57 60

  
58 61
            var res = new List<ModelInput>();
59
            foreach (WeatherInfo val in dataParser.WeatherList)
62
            foreach (WeatherInfo val in DataParser.WeatherList)
60 63
            {
61 64
                res.Add(new ModelInput
62 65
                {
......
68 71
                });
69 72
            }
70 73

  
71
            List<ActivityInfo> attendance = dataParser.AttendanceList;
74
            List<ActivityInfo> attendance = DataParser.AttendanceList;
72 75
            foreach (ModelInput input in res)
73 76
            {
74 77
                List<int> amounts = new List<int>();
Server/ServerApp/Predictor/NaiveBayesClassifier.cs
40 40
        {
41 41
            this._trainingDataView = _mlContext.Data.LoadFromEnumerable(trainInput);
42 42
            var pipeline = _mlContext.Transforms.Conversion.MapValueToKey(nameof(ModelInput.Label))
43
                .Append(_mlContext.Transforms.Concatenate("Features", new[] { nameof(ModelInput.Temp), nameof(ModelInput.Rain), nameof(ModelInput.Wind) }))
43
                .Append(_mlContext.Transforms.Conversion.ConvertType(nameof(ModelInput.Hour)))
44
                .Append(_mlContext.Transforms.Concatenate("Features", 
45
                new[] { nameof(ModelInput.Temp), nameof(ModelInput.Rain), nameof(ModelInput.Wind), nameof(ModelInput.Hour) }))
44 46
                .Append(_mlContext.Transforms.NormalizeMinMax("Features", "Features"))
45 47
                .AppendCacheCheckpoint(_mlContext)
46 48
                .Append(_mlContext.MulticlassClassification.Trainers.NaiveBayes())
47 49
                .Append(_mlContext.Transforms.Conversion.MapKeyToValue(nameof(ModelOutput.PredictedLabel)));
48 50

  
49
            this._trainedModel = pipeline.Fit(this._trainingDataView);
51
            var cvResults = _mlContext.MulticlassClassification.CrossValidate(this._trainingDataView, pipeline);
52
            foreach (var result in cvResults)
53
            {
54
                var testMetrics = result.Metrics;
55
                Console.WriteLine($"*************************************************************************************************************");
56
                Console.WriteLine($"*       Metrics for Multi-class Classification model - Model #{result.Fold}    ");
57
                Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
58
                Console.WriteLine($"*       MicroAccuracy:    {testMetrics.MicroAccuracy:0.###}");
59
                Console.WriteLine($"*       MacroAccuracy:    {testMetrics.MacroAccuracy:0.###}");
60
                Console.WriteLine($"*       LogLoss:          {testMetrics.LogLoss:#.###}");
61
                Console.WriteLine($"*       LogLossReduction: {testMetrics.LogLossReduction:#.###}");
62
                Console.WriteLine($"*       Confusion Matrix: {testMetrics.ConfusionMatrix.GetFormattedConfusionTable()}");
63
                Console.WriteLine($"*************************************************************************************************************");
64
            }
65
            this._trainedModel = cvResults.OrderByDescending(fold => fold.Metrics.MicroAccuracy).Select(fold => fold.Model).First();
66
            Console.WriteLine($"Selected the model #{cvResults.OrderByDescending(fold => fold.Metrics.MicroAccuracy).Select(fold => fold.Fold).First()} as the best.");
50 67
            this._predictionEngine = _mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(this._trainedModel);
51 68

  
52 69
        }
......
69 86
            Console.WriteLine($"*       MacroAccuracy:    {testMetrics.MacroAccuracy:0.###}");
70 87
            Console.WriteLine($"*       LogLoss:          {testMetrics.LogLoss:#.###}");
71 88
            Console.WriteLine($"*       LogLossReduction: {testMetrics.LogLossReduction:#.###}");
89
            Console.WriteLine($"*       Confusion Matrix: {testMetrics.ConfusionMatrix.GetFormattedConfusionTable()}");
72 90
            Console.WriteLine($"*************************************************************************************************************");
73 91
        }
74 92
    }
Server/ServerApp/Predictor/PredictionController.cs
6 6
using System.Collections.Generic;
7 7
using ServerApp.Connection.XMLProtocolHandler;
8 8
using ServerApp.Parser.Parsers;
9
using Newtonsoft.Json;
9 10

  
10 11
namespace ServerApp.Predictor
11 12
{
......
15 16
    class PredictionController : IPredictionController
16 17
    {
17 18
        /// <summary>
18
        /// A dictionary for storing trained predictors.
19
        /// Configuration of the <c>Predictor</c>
19 20
        /// </summary>
20
        private Dictionary<string, int> buildingsToAreas;
21
        private PredictorConfiguration Configuration;
21 22

  
22
        private List<IPredictor> predictors;
23
        private List<IPredictor> Predictors;
23 24

  
24 25
        /// <summary>
25 26
        /// A reference to a data parser.
26 27
        /// </summary>
27
        private IDataParser dataParser;
28
        private IDataParser DataParser;
28 29

  
29 30
        /// <summary>
30 31
        /// A feature extractor instance.
31 32
        /// </summary>
32
        private FeatureExtractor featureExtractor;
33
        private FeatureExtractor FeatureExtractor;
33 34

  
34 35
        /// <summary>
35 36
        /// Instantiates new prediction controller.
36 37
        /// </summary>
37 38
        /// <param name="dataParser">A data parser used to get training data.</param>
38
        public PredictionController(IDataParser dataParser)
39
        public PredictionController(IDataParser dataParser, string pathToConfig = null)
39 40
        {
40
            this.dataParser = dataParser;
41
            this.predictors = new List<IPredictor>();
42
            this.buildingsToAreas = new Dictionary<string, int>();
43
            this.featureExtractor = new FeatureExtractor(this.dataParser, buildingsToAreas);
41
            // load config or get the default one
42
            if (pathToConfig is null)
43
            {
44
                pathToConfig = PredictorConfiguration.DEFAULT_CONFIG_PATH;
45
            }
46
            try
47
            {
48
                string json = System.IO.File.ReadAllText(pathToConfig);
49
                this.Configuration = JsonConvert.DeserializeObject<PredictorConfiguration>(json);
50
            } catch (System.IO.IOException e)
51
            {
52
                Console.WriteLine(e.ToString());
53
                this.Configuration = PredictorConfiguration.GetDefaultConfig();
54
            }
55

  
56
            this.DataParser = dataParser;
57
            this.Predictors = new List<IPredictor>();
58
            this.FeatureExtractor = new FeatureExtractor(this.DataParser, this.Configuration);
44 59

  
45
            // fill predictors with all available locationKeys
46
            // TODO Currently all locations use the same predictor. Try dividing locations into subareas with separate predictors.
47
            var locationKeys = TagInfo.buildings;
48
            foreach (string key in locationKeys)
60
            for (int i = 0; i < this.Configuration.PredictorCount; i++)
49 61
            {
50
                buildingsToAreas.Add(key, 0);
62
                Predictors.Add(new NaiveBayesClassifier());
51 63
            }
52
            IPredictor predictor = new NaiveBayesClassifier();
53
            predictors.Add(predictor);
64
            PredictorConfiguration.SaveConfig(PredictorConfiguration.DEFAULT_CONFIG_PATH, Configuration);
54 65
        }
55 66
        public List<string> GetPredictors()
56 67
        {
57
            return new List<string>(buildingsToAreas.Keys);
68
            return new List<string>(this.Configuration.BuildingsToAreas.Keys);
58 69
        }
59 70

  
60 71
        public void Load(string locationKey = null, string path = null)
......
81 92
            // train all predictors
82 93
            {
83 94
                // TODO A single predictor is used for all areas, so training is done only once now.
84
                for (int i = 0; i < this.predictors.Count; i++)
95
                for (int i = 0; i < this.Predictors.Count; i++)
85 96
                {
86 97
                    // train on all available data
87 98
                    // TODO the train/test split is used just temporarily for demonstration.
88
                    List<ModelInput> data = featureExtractor.PrepareTrainingInput(i, DateTime.MinValue, DateTime.MaxValue);
89
                    List<ModelInput> trainingData = data.GetRange(index: 0, count: 500);
90
                    List<ModelInput> testData = data.GetRange(index: 500, count: 94);
91
                    Console.WriteLine("Training predictor with {0} samples.", trainingData.Count);
92
                    this.predictors[i].Fit(trainingData);
93

  
94
                    Console.WriteLine("Evaluating predictor with {0} samples.", testData.Count);
95
                    this.predictors[i].Evaluate(testData);
99
                    List<ModelInput> data = FeatureExtractor.PrepareTrainingInput(i, DateTime.MinValue, DateTime.MaxValue);
100
                    Console.WriteLine("Training predictor with {0} samples.", data.Count);
101
                    this.Predictors[i].Fit(data);
96 102
                }
97 103
            }
98 104
            else
Server/ServerApp/Predictor/PredictorConfiguration.cs
1
//
2
// Author: Roman Kalivoda
3
//
4

  
5
using System.Collections.Generic;
6
using System;
7
using System.IO;
8
using Newtonsoft.Json;
9

  
10
namespace ServerApp.Predictor
11
{
12
    /// <summary>
    /// Settings of the <c>Predictor</c> package, persisted on disk as JSON.
    /// </summary>
    class PredictorConfiguration
    {
        /// <summary>
        /// Full path of the default configuration file: <c>Predictor.config</c> located in the
        /// directory of the application's configuration file. When the application domain reports
        /// no configuration file, the name is resolved against the current directory instead.
        /// </summary>
        public static readonly string DEFAULT_CONFIG_PATH = Path.GetFullPath(
            Path.Combine(
                Path.GetDirectoryName(AppDomain.CurrentDomain.SetupInformation.ConfigurationFile) ?? string.Empty,
                "Predictor.config"));

        /// <summary>
        /// Time resolution setting; the default configuration uses 3.
        /// NOTE(review): presumably the interval handed to the data parser - confirm against callers.
        /// </summary>
        public int TimeResolution { get; set; }

        /// <summary>
        /// Maps a building key to the index of the area it belongs to.
        /// </summary>
        public Dictionary<string, int> BuildingsToAreas { get; set; }

        /// <summary>
        /// Number of predictors to create; the default configuration uses 3.
        /// </summary>
        public int PredictorCount { get; set; }

        /// <summary>
        /// Reads a configuration from a JSON file.
        /// </summary>
        /// <param name="filename">Path of the JSON file to read.</param>
        /// <returns>
        /// The deserialized configuration.
        /// NOTE(review): Json.NET is expected to yield <c>null</c> for an empty document - confirm and
        /// check the result at call sites.
        /// </returns>
        public static PredictorConfiguration LoadConfig(string filename)
        {
            string json = File.ReadAllText(filename);
            return JsonConvert.DeserializeObject<PredictorConfiguration>(json);
        }

        /// <summary>
        /// Serializes a configuration to JSON and writes it to a file, replacing any existing content.
        /// </summary>
        /// <param name="filename">Path of the file to write.</param>
        /// <param name="configuration">Configuration to store.</param>
        public static void SaveConfig(string filename, PredictorConfiguration configuration)
        {
            string json = JsonConvert.SerializeObject(configuration);
            File.WriteAllText(filename, json);
        }

        /// <summary>
        /// Creates the built-in default configuration: time resolution 3 and three predictors,
        /// with every known building assigned to area 0, 1 or 2.
        /// </summary>
        /// <returns>A new configuration instance holding the default values.</returns>
        public static PredictorConfiguration GetDefaultConfig()
        {
            return new PredictorConfiguration
            {
                TimeResolution = 3,
                PredictorCount = 3,
                BuildingsToAreas = new Dictionary<string, int>
                {
                    // area 0
                    { "FST+FEK", 0 },
                    { "FDU", 0 },
                    { "FAV", 0 },
                    { "FEL", 0 },
                    { "REK", 0 },
                    { "MENZA", 0 },
                    { "LIB", 0 },
                    { "CIV", 0 },
                    { "UNI14", 0 },
                    // area 1
                    { "DOM", 1 },
                    { "HUS", 1 },
                    { "CHOD", 1 },
                    { "JUNG", 1 },
                    { "KLAT", 1 },
                    { "KOLL", 1 },
                    { "RIEG", 1 },
                    { "SADY", 1 },
                    { "SED+VEL", 1 },
                    { "TES", 1 },
                    { "TYL", 1 },
                    // area 2
                    { "KARMA", 2 },
                    { "KBORY", 2 },
                    { "KLOCH", 2 },
                    { "KKLAT", 2 }
                }
            };
        }
    }
78
}
Server/ServerApp/ServerApp.csproj
179 179
    <Compile Include="Parser\Parsers\DataParser.cs" />
180 180
    <Compile Include="Parser\Parsers\IDataParser.cs" />
181 181
    <Compile Include="Parser\Parsers\JisParser.cs" />
182
    <Compile Include="Predictor\PredictorConfiguration.cs" />
182 183
    <Compile Include="Predictor\FeatureExtractor.cs" />
183 184
    <Compile Include="Predictor\IPredictionController.cs" />
184 185
    <Compile Include="Predictor\PredictionController.cs" />

Také k dispozici: Unified diff