Revize d358b79e
Přidáno uživatelem Roman Kalivoda před téměř 4 roky(ů)
Server/ServerApp/Predictor/NaiveBayesClassifier.cs | ||
---|---|---|
2 | 2 |
// Author: Roman Kalivoda |
3 | 3 |
// |
4 | 4 |
|
5 |
using System; |
|
5 | 6 |
using System.Collections.Generic; |
6 | 7 |
using System.Linq; |
7 | 8 |
using Microsoft.ML; |
8 |
using ServerApp.Parser.OutputInfo; |
|
9 | 9 |
|
10 | 10 |
namespace ServerApp.Predictor |
11 | 11 |
{ |
... | ... | |
17 | 17 |
/// <summary> |
18 | 18 |
/// Context of the ML.NET framework. |
19 | 19 |
/// </summary> |
20 |
private MLContext mlContext; |
|
20 |
private MLContext _mlContext;
|
|
21 | 21 |
|
22 | 22 |
/// <summary> |
23 | 23 |
/// Model instance |
24 | 24 |
/// </summary> |
25 |
private ITransformer model; |
|
25 |
private ITransformer _trainedModel; |
|
26 |
|
|
27 |
private PredictionEngine<ModelInput, ModelOutput> _predictionEngine; |
|
28 |
|
|
29 |
IDataView _trainingDataView; |
|
26 | 30 |
|
27 | 31 |
/// <summary> |
28 | 32 |
/// Instantiates new <c>MLContext</c>. |
29 | 33 |
/// </summary> |
30 | 34 |
public NaiveBayesClassifier() |
31 | 35 |
{ |
32 |
mlContext = new MLContext(); |
|
36 |
_mlContext = new MLContext();
|
|
33 | 37 |
} |
34 | 38 |
|
35 | 39 |
public void Fit(IEnumerable<ModelInput> trainInput) |
36 | 40 |
{ |
37 |
IDataView trainingDataView = mlContext.Data.LoadFromEnumerable(trainInput); |
|
38 |
var pipeline = mlContext.Transforms.Conversion.MapValueToKey(nameof(ModelInput.Label)) |
|
39 |
.Append(mlContext.Transforms.Concatenate("Features", new[] { "Temp" })) |
|
40 |
.Append(mlContext.Transforms.NormalizeMinMax("Features", "Features")) |
|
41 |
.Append(mlContext.MulticlassClassification.Trainers.NaiveBayes()); |
|
41 |
this._trainingDataView = _mlContext.Data.LoadFromEnumerable(trainInput); |
|
42 |
var pipeline = _mlContext.Transforms.Conversion.MapValueToKey(nameof(ModelInput.Label)) |
|
43 |
.Append(_mlContext.Transforms.Concatenate("Features", new[] { "Temp" })) |
|
44 |
.Append(_mlContext.Transforms.NormalizeMinMax("Features", "Features")) |
|
45 |
.AppendCacheCheckpoint(_mlContext) |
|
46 |
.Append(_mlContext.MulticlassClassification.Trainers.NaiveBayes()) |
|
47 |
.Append(_mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel")); ; |
|
48 |
|
|
49 |
this._trainedModel = pipeline.Fit(this._trainingDataView); |
|
50 |
this._predictionEngine = _mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(this._trainedModel); |
|
42 | 51 |
|
43 |
this.model =pipeline.Fit(trainingDataView);
|
|
52 |
}
|
|
44 | 53 |
|
54 |
public string Predict(ModelInput input) |
|
55 |
{ |
|
56 |
return this._predictionEngine.Predict(input).Prediction; |
|
45 | 57 |
} |
46 | 58 |
|
47 |
public IDataView Predict(IEnumerable<ModelInput> input)
|
|
59 |
public void Evaluate(IEnumerable<ModelInput> modelInputs)
|
|
48 | 60 |
{ |
49 |
var data = mlContext.Data.LoadFromEnumerable(input); |
|
50 |
IDataView result = model.Transform(data); |
|
51 |
return result; |
|
61 |
var testDataView = this._mlContext.Data.LoadFromEnumerable(modelInputs); |
|
62 |
var testMetrics = _mlContext.MulticlassClassification.Evaluate(_trainedModel.Transform(testDataView)); |
|
63 |
|
|
64 |
Console.WriteLine($"*************************************************************************************************************"); |
|
65 |
Console.WriteLine($"* Metrics for Multi-class Classification model - Test Data "); |
|
66 |
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------"); |
|
67 |
Console.WriteLine($"* MicroAccuracy: {testMetrics.MicroAccuracy:0.###}"); |
|
68 |
Console.WriteLine($"* MacroAccuracy: {testMetrics.MacroAccuracy:0.###}"); |
|
69 |
Console.WriteLine($"* LogLoss: {testMetrics.LogLoss:#.###}"); |
|
70 |
Console.WriteLine($"* LogLossReduction: {testMetrics.LogLossReduction:#.###}"); |
|
71 |
Console.WriteLine($"*************************************************************************************************************"); |
|
52 | 72 |
} |
53 | 73 |
} |
54 | 74 |
} |
Také k dispozici: Unified diff
Re #8832 Label creation