Revize 0d31f7e0
Přidáno uživatelem Roman Kalivoda před téměř 4 roky(ů)
Server/ServerApp/Predictor/NaiveBayesClassifier.cs | ||
---|---|---|
5 | 5 |
using System; |
6 | 6 |
using System.Collections.Generic; |
7 | 7 |
using System.Linq; |
8 |
using System.Reflection; |
|
9 |
using log4net; |
|
8 | 10 |
using Microsoft.ML; |
9 | 11 |
|
10 | 12 |
namespace ServerApp.Predictor |
... | ... | |
14 | 16 |
/// </summary> |
15 | 17 |
class NaiveBayesClassifier : IPredictor |
16 | 18 |
{ |
19 |
private static readonly ILog _log = LogManager.GetLogger(typeof(NaiveBayesClassifier)); |
|
20 |
|
|
17 | 21 |
/// <summary> |
18 | 22 |
/// Context of the ML.NET framework. |
19 | 23 |
/// </summary> |
... | ... | |
49 | 53 |
.Append(_mlContext.Transforms.Conversion.MapKeyToValue(nameof(ModelOutput.PredictedLabel))); |
50 | 54 |
|
51 | 55 |
var cvResults = _mlContext.MulticlassClassification.CrossValidate(this._trainingDataView, pipeline); |
52 |
foreach (var result in cvResults) |
|
53 |
{ |
|
54 |
var testMetrics = result.Metrics; |
|
55 |
Console.WriteLine($"*************************************************************************************************************"); |
|
56 |
Console.WriteLine($"* Metrics for Multi-class Classification model - Model #{result.Fold} "); |
|
57 |
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------"); |
|
58 |
Console.WriteLine($"* MicroAccuracy: {testMetrics.MicroAccuracy:0.###}"); |
|
59 |
Console.WriteLine($"* MacroAccuracy: {testMetrics.MacroAccuracy:0.###}"); |
|
60 |
Console.WriteLine($"* LogLoss: {testMetrics.LogLoss:#.###}"); |
|
61 |
Console.WriteLine($"* LogLossReduction: {testMetrics.LogLossReduction:#.###}"); |
|
62 |
Console.WriteLine($"* Confusion Matrix: {testMetrics.ConfusionMatrix.GetFormattedConfusionTable()}"); |
|
63 |
Console.WriteLine($"*************************************************************************************************************"); |
|
64 |
} |
|
56 |
_log.Debug("Cross-validated the trained model"); |
|
65 | 57 |
this._trainedModel = cvResults.OrderByDescending(fold => fold.Metrics.MicroAccuracy).Select(fold => fold.Model).First(); |
66 |
Console.WriteLine($"Selected the model #{cvResults.OrderByDescending(fold => fold.Metrics.MicroAccuracy).Select(fold => fold.Fold).First()} as the best.");
|
|
58 |
_log.Info($"Selected the model #{cvResults.OrderByDescending(fold => fold.Metrics.MicroAccuracy).Select(fold => fold.Fold).First()} as the best.");
|
|
67 | 59 |
this._predictionEngine = _mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(this._trainedModel); |
68 | 60 |
|
69 | 61 |
} |
70 | 62 |
|
71 | 63 |
public string Predict(ModelInput input) |
72 | 64 |
{ |
65 |
_log.Debug($"Predicting for input: {input}"); |
|
73 | 66 |
return this._predictionEngine.Predict(input).PredictedLabel; |
74 | 67 |
} |
75 | 68 |
|
Také k dispozici: Unified diff
Re #8953 tests