Projekt

Obecné

Profil

Stáhnout (1.85 KB) Statistiky
| Větev: | Tag: | Revize:
1
using System;
2
using System.Collections.Generic;
3
using System.Linq;
4
using System.Text;
5
using System.Threading.Tasks;
6
using Microsoft.ML;
7
using Microsoft.ML.Data;
8
using ServerApp.Parser.OutputInfo;
9

    
10
namespace ServerApp.Predictor
11
{
12
    class NaiveBayesClassifier : IPredictor
13
    {
14
        private MLContext mlContext;
15

    
16
        private ITransformer model;
17

    
18
        public NaiveBayesClassifier()
19
        {
20
            mlContext = new MLContext();
21

    
22
        }
23

    
24
        public IEnumerable<ModelInput> ExtractModelInput(List<WeatherInfo> weatherInfos, List<ActivityInfo> activityInfos)
25
        {
26
            return weatherInfos.Select(e => new ModelInput(){
27
                Temp = (float)e.temp,
28
                Label = e.temp > 15.0 ? "Full" : "Empty",
29
            }).ToList();
30
        }
31

    
32
        public void Fit(IEnumerable<ModelInput> trainingData)
33
        {
34
            IDataView trainingDataView = mlContext.Data.LoadFromEnumerable(trainingData);
35
            var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey(nameof(ModelInput.Label))
36
                .Append(mlContext.Transforms.Concatenate("Features", new[] { "temp" })
37
                .Append(mlContext.Transforms.NormalizeMinMax("Features", "Features")));
38
            var trainer = mlContext.MulticlassClassification.Trainers.NaiveBayes();
39
            var traininingPipeline = dataProcessPipeline.Append(trainer)
40
                .Append(mlContext.Transforms.Conversion.MapKeyToValue("prediction", "PredictedLabel"));
41

    
42
            this.model = traininingPipeline.Fit(trainingDataView);
43

    
44
        }
45

    
46
        public String[] Predict(IEnumerable<ModelInput> input)
47
        {
48
            var data = mlContext.Data.LoadFromEnumerable(input);
49
            IDataView result = model.Transform(data);
50
            String[] prediction = result.GetColumn<String>("prediction").ToArray();
51

    
52
            return prediction;
53
        }
54
    }
55
}
(4-4/4)