Projekt

Obecné

Profil

Stáhnout (1.74 KB) Statistiky
| Větev: | Tag: | Revize:
1 4977ce53 Roman Kalivoda
//
2
// Author: Roman Kalivoda
3
//
4
5 abfd9c7c Roman Kalivoda
using System.Collections.Generic;
6
using System.Linq;
7
using Microsoft.ML;
8 9fc5fa93 Roman Kalivoda
using ServerApp.Parser.OutputInfo;
9 abfd9c7c Roman Kalivoda
10
namespace ServerApp.Predictor
11
{
12 4977ce53 Roman Kalivoda
    /// <summary>
13
    /// Implementation of the naive Bayes classifier in ML.NET.
14
    /// </summary>
15 abfd9c7c Roman Kalivoda
    class NaiveBayesClassifier : IPredictor
16
    {
17 4977ce53 Roman Kalivoda
        /// <summary>
18
        /// Context of the ML.NET framework.
19
        /// </summary>
20 abfd9c7c Roman Kalivoda
        private MLContext mlContext;
21
22 4977ce53 Roman Kalivoda
        /// <summary>
23
        /// Model instance
24
        /// </summary>
25 9fc5fa93 Roman Kalivoda
        private ITransformer model;
26
27 4977ce53 Roman Kalivoda
        /// <summary>
28
        /// Instantiates new <c>MLContext</c>.
29
        /// </summary>
30 abfd9c7c Roman Kalivoda
        public NaiveBayesClassifier()
31
        {
32
            mlContext = new MLContext();
33 9fc5fa93 Roman Kalivoda
        }
34
35 4977ce53 Roman Kalivoda
        public void Fit(IEnumerable<ModelInput> trainInput)
36 abfd9c7c Roman Kalivoda
        {
37 4977ce53 Roman Kalivoda
            IDataView trainingDataView = mlContext.Data.LoadFromEnumerable(trainInput);
38 66c3e0df Roman Kalivoda
            var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey(nameof(ModelInput.Label))
39
                .Append(mlContext.Transforms.Concatenate("Features", new[] { "temp" })
40
                .Append(mlContext.Transforms.NormalizeMinMax("Features", "Features")));
41 9fc5fa93 Roman Kalivoda
            var trainer = mlContext.MulticlassClassification.Trainers.NaiveBayes();
42 66c3e0df Roman Kalivoda
            var traininingPipeline = dataProcessPipeline.Append(trainer)
43
                .Append(mlContext.Transforms.Conversion.MapKeyToValue("prediction", "PredictedLabel"));
44 9fc5fa93 Roman Kalivoda
45
            this.model = traininingPipeline.Fit(trainingDataView);
46
47 abfd9c7c Roman Kalivoda
        }
48
49 4977ce53 Roman Kalivoda
        public IDataView Predict(IEnumerable<ModelInput> input)
50 abfd9c7c Roman Kalivoda
        {
51 9fc5fa93 Roman Kalivoda
            var data = mlContext.Data.LoadFromEnumerable(input);
52 66c3e0df Roman Kalivoda
            IDataView result = model.Transform(data);
53 4977ce53 Roman Kalivoda
            return result;
54 abfd9c7c Roman Kalivoda
        }
55
    }
56
}