Projekt

Obecné

Profil

Stáhnout (1.74 KB) Statistiky
| Větev: | Tag: | Revize:
1
//
2
// Author: Roman Kalivoda
3
//
4

    
5
using System.Collections.Generic;
6
using System.Linq;
7
using Microsoft.ML;
8
using ServerApp.Parser.OutputInfo;
9

    
10
namespace ServerApp.Predictor
11
{
12
    /// <summary>
13
    /// Implementation of the naive Bayes classifier in ML.NET.
14
    /// </summary>
15
    class NaiveBayesClassifier : IPredictor
16
    {
17
        /// <summary>
18
        /// Context of the ML.NET framework.
19
        /// </summary>
20
        private MLContext mlContext;
21

    
22
        /// <summary>
23
        /// Model instance
24
        /// </summary>
25
        private ITransformer model;
26

    
27
        /// <summary>
28
        /// Instantiates new <c>MLContext</c>.
29
        /// </summary>
30
        public NaiveBayesClassifier()
31
        {
32
            mlContext = new MLContext();
33
        }
34

    
35
        public void Fit(IEnumerable<ModelInput> trainInput)
36
        {
37
            IDataView trainingDataView = mlContext.Data.LoadFromEnumerable(trainInput);
38
            var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey(nameof(ModelInput.Label))
39
                .Append(mlContext.Transforms.Concatenate("Features", new[] { "temp" })
40
                .Append(mlContext.Transforms.NormalizeMinMax("Features", "Features")));
41
            var trainer = mlContext.MulticlassClassification.Trainers.NaiveBayes();
42
            var traininingPipeline = dataProcessPipeline.Append(trainer)
43
                .Append(mlContext.Transforms.Conversion.MapKeyToValue("prediction", "PredictedLabel"));
44

    
45
            this.model = traininingPipeline.Fit(trainingDataView);
46

    
47
        }
48

    
49
        public IDataView Predict(IEnumerable<ModelInput> input)
50
        {
51
            var data = mlContext.Data.LoadFromEnumerable(input);
52
            IDataView result = model.Transform(data);
53
            return result;
54
        }
55
    }
56
}
(6-6/7)