Revize 76072df0
Přidáno uživatelem Roman Kalivoda před téměř 4 roky(ů)
Server/ServerApp/Predictor/NaiveBayesClassifier.cs | ||
---|---|---|
13 | 13 |
/// <summary> |
14 | 14 |
/// Implementation of the naive Bayes classifier in ML.NET. |
15 | 15 |
/// </summary> |
16 |
class NaiveBayesClassifier : IPredictor
|
|
16 |
class NaiveBayesClassifier : AbstractClassificationPredictor
|
|
17 | 17 |
{ |
18 | 18 |
private static readonly ILog _log = LogManager.GetLogger(typeof(NaiveBayesClassifier)); |
19 | 19 |
|
20 |
/// <summary> |
|
21 |
/// Context of the ML.NET framework. |
|
22 |
/// </summary> |
|
23 |
private MLContext _mlContext; |
|
24 |
|
|
25 |
/// <summary> |
|
26 |
/// Model instance |
|
27 |
/// </summary> |
|
28 |
private ITransformer _trainedModel; |
|
29 |
|
|
30 |
private PredictionEngine<ModelInput, ModelOutput> _predictionEngine; |
|
31 |
|
|
32 |
IDataView _trainingDataView; |
|
33 |
|
|
34 | 20 |
/// <summary> |
35 | 21 |
/// Instantiates new <c>MLContext</c>. |
36 | 22 |
/// </summary> |
37 | 23 |
public NaiveBayesClassifier() |
38 | 24 |
{ |
39 |
_mlContext = new MLContext(); |
|
25 |
this._mlContext = new MLContext();
|
|
40 | 26 |
} |
41 | 27 |
|
42 | 28 |
public NaiveBayesClassifier(string filename) : this() |
... | ... | |
47 | 33 |
this._predictionEngine = _mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(this._trainedModel); |
48 | 34 |
} |
49 | 35 |
|
50 |
public void Save(string filename) |
|
51 |
{ |
|
52 |
if (this._trainingDataView is null) |
|
53 |
{ |
|
54 |
throw new NullReferenceException("DataView is not set."); |
|
55 |
} |
|
56 |
if( this._trainedModel is null) |
|
57 |
{ |
|
58 |
throw new NullReferenceException("Trained model instance does not exist. This predictor has not been trained yet."); |
|
59 |
} |
|
60 |
if(filename is null) |
|
61 |
{ |
|
62 |
throw new ArgumentNullException(nameof(filename)); |
|
63 |
} |
|
64 |
this._mlContext.Model.Save(this._trainedModel, this._trainingDataView.Schema, filename); |
|
65 |
} |
|
66 |
|
|
67 |
public void Fit(IEnumerable<ModelInput> trainInput) |
|
36 |
public override void Fit(IEnumerable<ModelInput> trainInput) |
|
68 | 37 |
{ |
69 | 38 |
this._trainingDataView = _mlContext.Data.LoadFromEnumerable(trainInput); |
70 | 39 |
var pipeline = _mlContext.Transforms.Conversion.MapValueToKey(nameof(ModelInput.Label)) |
... | ... | |
82 | 51 |
_log.Info($"Selected the model #{cvResults.OrderByDescending(fold => fold.Metrics.MicroAccuracy).Select(fold => fold.Fold).First()} as the best."); |
83 | 52 |
this._predictionEngine = _mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(this._trainedModel); |
84 | 53 |
|
85 |
} |
|
86 |
|
|
87 |
public string Predict(ModelInput input) |
|
88 |
{ |
|
89 |
_log.Debug($"Predicting for input: {input}"); |
|
90 |
return this._predictionEngine.Predict(input).PredictedLabel; |
|
91 |
} |
|
92 |
|
|
93 |
public void Evaluate(IEnumerable<ModelInput> modelInputs) |
|
94 |
{ |
|
95 |
var testDataView = this._mlContext.Data.LoadFromEnumerable(modelInputs); |
|
96 |
var data = _trainedModel.Transform(testDataView); |
|
97 |
var testMetrics = _mlContext.MulticlassClassification.Evaluate(data); |
|
98 |
|
|
99 |
Console.WriteLine($"*************************************************************************************************************"); |
|
100 |
Console.WriteLine($"* Metrics for Multi-class Classification model - Test Data "); |
|
101 |
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------"); |
|
102 |
Console.WriteLine($"* MicroAccuracy: {testMetrics.MicroAccuracy:0.###}"); |
|
103 |
Console.WriteLine($"* MacroAccuracy: {testMetrics.MacroAccuracy:0.###}"); |
|
104 |
Console.WriteLine($"* LogLoss: {testMetrics.LogLoss:#.###}"); |
|
105 |
Console.WriteLine($"* LogLossReduction: {testMetrics.LogLossReduction:#.###}"); |
|
106 |
Console.WriteLine($"* Confusion Matrix: {testMetrics.ConfusionMatrix.GetFormattedConfusionTable()}"); |
|
107 |
Console.WriteLine($"*************************************************************************************************************"); |
|
108 |
} |
|
54 |
} |
|
109 | 55 |
} |
110 | 56 |
} |
Také k dispozici: Unified diff
Re #8597 Implementation of AbstractClassificationPredictor, SdcaMEClassifier