Projekt

Obecné

Profil

Stáhnout (3.3 KB) Statistiky
| Větev: | Tag: | Revize:
1
//
2
// Author: Roman Kalivoda
3
//
4

    
5
using System;
6
using System.Collections.Generic;
7
using log4net;
8
using Microsoft.ML;
9

    
10
namespace ServerApp.Predictor
11
{
12
    abstract class AbstractClassificationPredictor : IPredictor
13
    {
14
        private static readonly ILog _log = LogManager.GetLogger(typeof(AbstractClassificationPredictor));
15

    
16
        /// <summary>
17
        /// Context of the ML.NET framework.
18
        /// </summary>
19
        protected MLContext _mlContext;
20

    
21
        /// <summary>
22
        /// Model instance
23
        /// </summary>
24
        protected ITransformer _trainedModel;
25

    
26
        protected PredictionEngine<ModelInput, ModelOutput> _predictionEngine;
27

    
28
        protected IDataView _trainingDataView;
29

    
30
        public void Evaluate(IEnumerable<ModelInput> modelInputs)
31
        {
32
            var testDataView = this._mlContext.Data.LoadFromEnumerable(modelInputs);
33
            var data = _trainedModel.Transform(testDataView);
34
            var testMetrics = _mlContext.MulticlassClassification.Evaluate(data);
35

    
36
            Console.WriteLine($"*************************************************************************************************************");
37
            Console.WriteLine($"*       Metrics for Multi-class Classification model - Test Data     ");
38
            Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
39
            Console.WriteLine($"*       MicroAccuracy:    {testMetrics.MicroAccuracy:0.###}");
40
            Console.WriteLine($"*       MacroAccuracy:    {testMetrics.MacroAccuracy:0.###}");
41
            Console.WriteLine($"*       LogLoss:          {testMetrics.LogLoss:#.###}");
42
            Console.WriteLine($"*       LogLossReduction: {testMetrics.LogLossReduction:#.###}");
43
            Console.WriteLine($"*       Confusion Matrix: {testMetrics.ConfusionMatrix.GetFormattedConfusionTable()}");
44
            Console.WriteLine($"*************************************************************************************************************");
45
        }
46

    
47
        public abstract void Fit(IEnumerable<ModelInput> trainInput);
48

    
49
        public void Load(string filename)
50
        {
51
            DataViewSchema modelSchema;
52
            this._trainedModel = _mlContext.Model.Load(filename, out modelSchema);
53
            // TODO check if the loaded model has valid input and output schema
54
            this._predictionEngine = _mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(this._trainedModel);
55
        }
56

    
57
        public string Predict(ModelInput input)
58
        {
59
            _log.Debug($"Predicting for input: {input}");
60
            return this._predictionEngine.Predict(input).PredictedLabel;
61
        }
62

    
63
        public void Save(string filename)
64
        {
65
            if (this._trainingDataView is null)
66
            {
67
                throw new NullReferenceException("DataView is not set.");
68
            }
69
            if (this._trainedModel is null)
70
            {
71
                throw new NullReferenceException("Trained model instance does not exist. This predictor has not been trained yet.");
72
            }
73
            if (filename is null)
74
            {
75
                throw new ArgumentNullException(nameof(filename));
76
            }
77
            this._mlContext.Model.Save(this._trainedModel, this._trainingDataView.Schema, filename);
78
        }
79
    }
80
}
(1-1/11)