Projekt

Obecné

Profil

Stáhnout (2.68 KB) Statistiky
| Větev: | Tag: | Revize:
1
//
2
// Author: Roman Kalivoda
3
//
4

    
5
using System;
6
using System.Collections.Generic;
7
using System.Linq;
8
using log4net;
9
using Microsoft.ML;
10

    
11
namespace ServerApp.Predictor
12
{
13
    /// <summary>
    /// Multiclass classifier built on the ML.NET SDCA maximum entropy trainer.
    /// </summary>
    /// <remarks>
    /// The <c>_mlContext</c>, <c>_trainingDataView</c>, <c>_trainedModel</c> and
    /// <c>_predictionEngine</c> members used below are not declared in this class;
    /// presumably they are inherited from <see cref="AbstractClassificationPredictor"/>.
    /// </remarks>
    internal class SdcaMaximumEntropyClassifier : AbstractClassificationPredictor
    {
        private static readonly ILog _log = LogManager.GetLogger(typeof(SdcaMaximumEntropyClassifier));

        /// <summary>
        /// Creates an untrained classifier with a fresh <c>MLContext</c>.
        /// </summary>
        public SdcaMaximumEntropyClassifier()
        {
            this._mlContext = new MLContext();
        }

        /// <summary>
        /// Creates a classifier from a model previously saved to a file and
        /// prepares a prediction engine for it.
        /// </summary>
        /// <param name="filename">Path to the model file.</param>
        /// <exception cref="ArgumentException">
        /// <paramref name="filename"/> is <c>null</c> or empty.
        /// </exception>
        public SdcaMaximumEntropyClassifier(string filename) : this()
        {
            // Fail early with a clear message instead of relying on the loader's own checks.
            if (string.IsNullOrEmpty(filename))
            {
                throw new ArgumentException("Path to the model file must not be null or empty.", nameof(filename));
            }

            DataViewSchema modelSchema;
            this._trainedModel = _mlContext.Model.Load(filename, out modelSchema);
            // TODO check if the loaded model has valid input and output schema
            this._predictionEngine = _mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(this._trainedModel);
        }

        /// <summary>
        /// Trains the classifier on the given samples using cross-validation and
        /// keeps the fold model with the highest micro-accuracy.
        /// </summary>
        /// <param name="trainInput">Training samples.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="trainInput"/> is <c>null</c>.
        /// </exception>
        public override void Fit(IEnumerable<ModelInput> trainInput)
        {
            if (trainInput == null)
            {
                throw new ArgumentNullException(nameof(trainInput));
            }

            this._trainingDataView = _mlContext.Data.LoadFromEnumerable(trainInput);

            // Pipeline: label -> key, Hour converted so it can be concatenated with the
            // other inputs, features assembled and normalized, then the SDCA trainer;
            // finally the predicted key is mapped back to the original label value.
            var pipeline = _mlContext.Transforms.Conversion.MapValueToKey(nameof(ModelInput.Label))
                .Append(_mlContext.Transforms.Conversion.ConvertType(nameof(ModelInput.Hour)))
                .Append(_mlContext.Transforms.Concatenate("Features",
                new[] { nameof(ModelInput.Temp), nameof(ModelInput.Rain), nameof(ModelInput.Wind), nameof(ModelInput.Hour) }))
                .Append(_mlContext.Transforms.NormalizeMeanVariance("Features", useCdf: false))
                .AppendCacheCheckpoint(_mlContext)
                .Append(_mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy())
                .Append(_mlContext.Transforms.Conversion.MapKeyToValue(nameof(ModelOutput.PredictedLabel)));

            var cvResults = _mlContext.MulticlassClassification.CrossValidate(this._trainingDataView, pipeline);
            _log.Debug("Cross-validated the trained model");

            // Rank the folds once so the selected model and the logged fold index
            // are guaranteed to come from the same result.
            var bestFold = cvResults.OrderByDescending(fold => fold.Metrics.MicroAccuracy).First();
            this._trainedModel = bestFold.Model;
            _log.Info($"Selected the model #{bestFold.Fold} as the best.");

            this._predictionEngine = _mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(this._trainedModel);
        }
    }
60
}
(11-11/11)