Revize cdeee9f8
Přidáno uživatelem Roman Kalivoda před téměř 4 roky(ů)
Server/ServerApp/Predictor/NaiveBayesClassifier.cs | ||
---|---|---|
40 | 40 |
{ |
41 | 41 |
this._trainingDataView = _mlContext.Data.LoadFromEnumerable(trainInput); |
42 | 42 |
var pipeline = _mlContext.Transforms.Conversion.MapValueToKey(nameof(ModelInput.Label)) |
43 |
.Append(_mlContext.Transforms.Concatenate("Features", new[] { nameof(ModelInput.Temp), nameof(ModelInput.Rain), nameof(ModelInput.Wind) })) |
|
43 |
.Append(_mlContext.Transforms.Conversion.ConvertType(nameof(ModelInput.Hour))) |
|
44 |
.Append(_mlContext.Transforms.Concatenate("Features", |
|
45 |
new[] { nameof(ModelInput.Temp), nameof(ModelInput.Rain), nameof(ModelInput.Wind), nameof(ModelInput.Hour) })) |
|
44 | 46 |
.Append(_mlContext.Transforms.NormalizeMinMax("Features", "Features")) |
45 | 47 |
.AppendCacheCheckpoint(_mlContext) |
46 | 48 |
.Append(_mlContext.MulticlassClassification.Trainers.NaiveBayes()) |
47 | 49 |
.Append(_mlContext.Transforms.Conversion.MapKeyToValue(nameof(ModelOutput.PredictedLabel))); |
48 | 50 |
|
49 |
this._trainedModel = pipeline.Fit(this._trainingDataView); |
|
51 |
var cvResults = _mlContext.MulticlassClassification.CrossValidate(this._trainingDataView, pipeline); |
|
52 |
foreach (var result in cvResults) |
|
53 |
{ |
|
54 |
var testMetrics = result.Metrics; |
|
55 |
Console.WriteLine($"*************************************************************************************************************"); |
|
56 |
Console.WriteLine($"* Metrics for Multi-class Classification model - Model #{result.Fold} "); |
|
57 |
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------"); |
|
58 |
Console.WriteLine($"* MicroAccuracy: {testMetrics.MicroAccuracy:0.###}"); |
|
59 |
Console.WriteLine($"* MacroAccuracy: {testMetrics.MacroAccuracy:0.###}"); |
|
60 |
Console.WriteLine($"* LogLoss: {testMetrics.LogLoss:#.###}"); |
|
61 |
Console.WriteLine($"* LogLossReduction: {testMetrics.LogLossReduction:#.###}"); |
|
62 |
Console.WriteLine($"* Confusion Matrix: {testMetrics.ConfusionMatrix.GetFormattedConfusionTable()}"); |
|
63 |
Console.WriteLine($"*************************************************************************************************************"); |
|
64 |
} |
|
65 |
this._trainedModel = cvResults.OrderByDescending(fold => fold.Metrics.MicroAccuracy).Select(fold => fold.Model).First(); |
|
66 |
Console.WriteLine($"Selected the model #{cvResults.OrderByDescending(fold => fold.Metrics.MicroAccuracy).Select(fold => fold.Fold).First()} as the best."); |
|
50 | 67 |
this._predictionEngine = _mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(this._trainedModel); |
51 | 68 |
|
52 | 69 |
} |
... | ... | |
69 | 86 |
Console.WriteLine($"* MacroAccuracy: {testMetrics.MacroAccuracy:0.###}"); |
70 | 87 |
Console.WriteLine($"* LogLoss: {testMetrics.LogLoss:#.###}"); |
71 | 88 |
Console.WriteLine($"* LogLossReduction: {testMetrics.LogLossReduction:#.###}"); |
89 |
Console.WriteLine($"* Confusion Matrix: {testMetrics.ConfusionMatrix.GetFormattedConfusionTable()}"); |
|
72 | 90 |
Console.WriteLine($"*************************************************************************************************************"); |
73 | 91 |
} |
74 | 92 |
} |
Také k dispozici: Unified diff
Re #8955 implemented multiple predictors support