Projekt

Obecné

Profil

Stáhnout (2.5 KB) Statistiky
| Větev: | Tag: | Revize:
1
//
2
// Author: Roman Kalivoda
3
//
4

    
5
using System;
6
using System.Collections.Generic;
7
using System.Linq;
8
using log4net;
9
using Microsoft.ML;
10

    
11
namespace ServerApp.Predictor
12
{
13
    /// <summary>
14
    /// Implementation of the naive Bayes classifier in ML.NET.
15
    /// </summary>
16
    class NaiveBayesClassifier : AbstractClassificationPredictor
17
    {
18
        private static readonly ILog _log = LogManager.GetLogger(typeof(NaiveBayesClassifier));
19

    
20
        /// <summary>
21
        /// Instantiates new <c>MLContext</c>.
22
        /// </summary>
23
        public NaiveBayesClassifier()
24
        {
25
            this._mlContext = new MLContext();
26
        }
27

    
28
        public NaiveBayesClassifier(string filename) : this()
29
        {
30
            DataViewSchema modelSchema;
31
            this._trainedModel = _mlContext.Model.Load(filename, out modelSchema);
32
            // TODO check if the loaded model has valid input and output schema
33
            this._predictionEngine = _mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(this._trainedModel);
34
        }
35

    
36
        public override void Fit(IEnumerable<ModelInput> trainInput)
37
        {
38
            this._trainingDataView = _mlContext.Data.LoadFromEnumerable(trainInput);
39
            var pipeline = _mlContext.Transforms.Conversion.MapValueToKey(nameof(ModelInput.Label))
40
                .Append(_mlContext.Transforms.Conversion.ConvertType(nameof(ModelInput.Hour)))
41
                .Append(_mlContext.Transforms.Concatenate("Features", 
42
                new[] { nameof(ModelInput.Temp), nameof(ModelInput.Rain), nameof(ModelInput.Wind), nameof(ModelInput.Hour) }))
43
                .Append(_mlContext.Transforms.NormalizeMeanVariance("Features", useCdf:false))
44
                .AppendCacheCheckpoint(_mlContext)
45
                .Append(_mlContext.MulticlassClassification.Trainers.NaiveBayes())
46
                .Append(_mlContext.Transforms.Conversion.MapKeyToValue(nameof(ModelOutput.PredictedLabel)));
47

    
48
            var cvResults = _mlContext.MulticlassClassification.CrossValidate(this._trainingDataView, pipeline);
49
            _log.Debug("Cross-validated the trained model");
50
            this._trainedModel = cvResults.OrderByDescending(fold => fold.Metrics.MicroAccuracy).Select(fold => fold.Model).First();
51
            _log.Info($"Selected the model #{cvResults.OrderByDescending(fold => fold.Metrics.MicroAccuracy).Select(fold => fold.Fold).First()} as the best.");
52
            this._predictionEngine = _mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(this._trainedModel);
53

    
54
        }    
55
    }
56
}
(8-8/11)