aboutsummaryrefslogtreecommitdiff
path: root/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor_test.cc')
-rw-r--r--ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor_test.cc100
1 files changed, 100 insertions, 0 deletions
diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor_test.cc b/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor_test.cc
new file mode 100644
index 0000000..40ba9f4
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor_test.cc
@@ -0,0 +1,100 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.h"
+
+#include <vector>
+
+#include "gtest/gtest.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+constexpr int kStableIterNum = 4;
+
+constexpr double kProcessNoise = 0.01;
+constexpr double kMeasurementNoise = 1.0;
+
+} // namespace
+
+struct DataSet {
+ double initial_observation;
+ std::vector<double> observation;
+ std::vector<double> position;
+ std::vector<double> velocity;
+ std::vector<double> acceleration;
+ std::vector<double> jerk;
+};
+
+void ValidateAxisPredictor(AxisPredictor* predictor, const DataSet& data) {
+ predictor->Reset();
+ predictor->Update(data.initial_observation);
+ for (decltype(data.observation.size()) i = 0; i < data.observation.size();
+ i++) {
+ predictor->Update(data.observation[i]);
+ EXPECT_NEAR(data.position[i], predictor->GetPosition(), 0.0001);
+ EXPECT_NEAR(data.velocity[i], predictor->GetVelocity(), 0.0001);
+ EXPECT_NEAR(data.acceleration[i], predictor->GetAcceleration(), 0.0001);
+ EXPECT_NEAR(data.jerk[i], predictor->GetJerk(), 0.0001);
+ }
+}
+
+// Test that the predictor will stable.
+TEST(AxisPredictorTest, ShouldStable) {
+ AxisPredictor predictor(kProcessNoise, kMeasurementNoise, kStableIterNum);
+ for (int i = 0; i < kStableIterNum; i++) {
+ EXPECT_FALSE(predictor.Stable());
+ predictor.Update(1);
+ }
+ EXPECT_TRUE(predictor.Stable());
+}
+
+// Test the kalman filter behavior. The data set is generated by a "known to
+// work" kalman filter.
+TEST(AxisPredictorTest, PredictedValue) {
+ AxisPredictor predictor(kProcessNoise, kMeasurementNoise, kStableIterNum);
+ DataSet data;
+ data.initial_observation = 0;
+ data.observation = {1, 2, 3, 4, 5, 6};
+ data.position = {0.6949411066858742, 1.8880162111305765, 3.0596776689233476,
+ 4.080666568886563, 5.039574058758894, 5.990101744132957};
+ data.velocity = {0.48326413015846115, 1.349212968908908, 1.5150757723942188,
+ 1.2449353797925855, 0.9823147273054352, 0.831418084705206};
+ data.acceleration = {0.20388102703160751, 0.6602537865634062,
+ 0.46392675203046707, 0.0691864035645362,
+ -0.1571001901104591, -0.2303438651979314};
+ data.jerk = {0.051351580374544535, 0.17805019978769315,
+ 0.06592110190532013, -0.06063794909774803,
+ -0.10198612906906362, -0.09541445938944032};
+
+ ValidateAxisPredictor(&predictor, data);
+
+ data.initial_observation = 0;
+ data.observation = {1, 2, 4, 8, 16, 32};
+ data.position = {0.6949411066858742, 1.8880162111305765, 3.9597202826804603,
+ 7.9052737853848285, 15.720340533540115, 31.24662046486774};
+ data.velocity = {0.48326413015846115, 1.349212968908908, 2.492271225870179,
+ 4.610844489557212, 8.828231877380588, 16.987494416071463};
+ data.acceleration = {0.20388102703160751, 0.6602537865634062,
+ 1.090991623810185, 1.885675547541351,
+ 3.4586206593783526, 6.34082285106952};
+ data.jerk = {0.051351580374544535, 0.17805019978769315, 0.25373225050247916,
+ 0.4023497012294069, 0.6945464157568688, 1.1947316519015612};
+
+ ValidateAxisPredictor(&predictor, data);
+}
+
+} // namespace stroke_model
+} // namespace ink