]> gitweb.fperrin.net Git - GpsPrune.git/blob - src/tim/prune/function/estimate/LearnParameters.java
Version 20.4, May 2021
[GpsPrune.git] / src / tim / prune / function / estimate / LearnParameters.java
1 package tim.prune.function.estimate;
2
3 import java.awt.BorderLayout;
4 import java.awt.Component;
5 import java.awt.FlowLayout;
6 import java.awt.event.ActionEvent;
7 import java.awt.event.ActionListener;
8 import java.awt.event.AdjustmentEvent;
9 import java.awt.event.AdjustmentListener;
10 import java.awt.event.KeyAdapter;
11 import java.awt.event.KeyEvent;
12 import java.util.ArrayList;
13
14 import javax.swing.BorderFactory;
15 import javax.swing.Box;
16 import javax.swing.BoxLayout;
17 import javax.swing.JButton;
18 import javax.swing.JDialog;
19 import javax.swing.JLabel;
20 import javax.swing.JPanel;
21 import javax.swing.JScrollBar;
22
23 import tim.prune.App;
24 import tim.prune.GenericFunction;
25 import tim.prune.I18nManager;
26 import tim.prune.config.Config;
27 import tim.prune.data.DataPoint;
28 import tim.prune.data.Distance;
29 import tim.prune.data.RangeStatsWithGradients;
30 import tim.prune.data.Track;
31 import tim.prune.data.Unit;
32 import tim.prune.data.UnitSetLibrary;
33 import tim.prune.function.estimate.jama.Matrix;
34 import tim.prune.gui.ProgressDialog;
35
36 /**
37  * Function to learn the estimation parameters from the current track
38  */
39 public class LearnParameters extends GenericFunction implements Runnable
40 {
41         /** Progress dialog */
42         ProgressDialog _progress = null;
43         /** Results dialog */
44         JDialog _dialog = null;
45         /** Calculated parameters */
46         private ParametersPanel _calculatedParamPanel = null;
47         private EstimationParameters _calculatedParams = null;
48         /** Slider for weighted average */
49         private JScrollBar _weightSlider = null;
50         /** Label to describe position of slider */
51         private JLabel _sliderDescLabel = null;
52         /** Combined parameters */
53         private ParametersPanel _combinedParamPanel = null;
54         /** Combine button */
55         private JButton _combineButton = null;
56
57
58         /**
59          * Inner class used to hold the results of the matrix solving
60          */
61         static class MatrixResults
62         {
63                 public EstimationParameters _parameters = null;
64                 public double _averageErrorPc = 0.0; // percentage
65         }
66
67
68         /**
69          * Constructor
70          * @param inApp App object
71          */
72         public LearnParameters(App inApp)
73         {
74                 super(inApp);
75         }
76
77         /** @return key for function name */
78         public String getNameKey() {
79                 return "function.learnestimationparams";
80         }
81
82         /**
83          * Begin the function
84          */
85         public void begin()
86         {
87                 // Show progress bar
88                 if (_progress == null) {
89                         _progress = new ProgressDialog(_parentFrame, getNameKey());
90                 }
91                 _progress.show();
92                 // Start new thread for the calculations
93                 new Thread(this).start();
94         }
95
96         /**
97          * Run method in separate thread
98          */
99         public void run()
100         {
101                 _progress.setMaximum(100);
102                 // Go through the track and collect the range stats for each sample
103                 ArrayList<RangeStatsWithGradients> statsList = new ArrayList<RangeStatsWithGradients>(20);
104                 Track track = _app.getTrackInfo().getTrack();
105                 final int numPoints = track.getNumPoints();
106                 final int sampleSize = numPoints / 30;
107                 int prevStartIndex = -1;
108                 for (int i=0; i<30; i++)
109                 {
110                         int startIndex = i * sampleSize;
111                         RangeStatsWithGradients stats = getRangeStats(track, startIndex, startIndex + sampleSize, prevStartIndex);
112                         if (stats != null && stats.getMovingDistanceKilometres() > 1.0
113                                 && !stats.getTimestampsIncomplete() && !stats.getTimestampsOutOfSequence()
114                                 && stats.getTotalDurationInSeconds() > 100
115                                 && startIndex > prevStartIndex)
116                         {
117                                 // System.out.println("Got stats for " + stats.getStartIndex() + " to " + stats.getEndIndex());
118                                 statsList.add(stats);
119                                 prevStartIndex = startIndex;
120                         }
121                         _progress.setValue(i);
122                 }
123
124                 // Check if we've got enough samples
125                 // System.out.println("Got a total of " + statsList.size() + " samples");
126                 if (statsList.size() < 10)
127                 {
128                         _progress.dispose();
129                         // Show error message, not enough samples
130                         _app.showErrorMessage(getNameKey(), "error.learnestimationparams.failed");
131                         return;
132                 }
133                 // Loop around, solving the matrices and removing the highest-error sample
134                 MatrixResults results = reduceSamples(statsList);
135                 if (results == null)
136                 {
137                         _progress.dispose();
138                         _app.showErrorMessage(getNameKey(), "error.learnestimationparams.failed");
139                         return;
140                 }
141
142                 _progress.dispose();
143
144                 // Create the dialog if necessary
145                 if (_dialog == null)
146                 {
147                         _dialog = new JDialog(_parentFrame, I18nManager.getText(getNameKey()), true);
148                         _dialog.setLocationRelativeTo(_parentFrame);
149                         // Create Gui and show it
150                         _dialog.getContentPane().add(makeDialogComponents());
151                         _dialog.pack();
152                 }
153
154                 // Populate the values in the dialog
155                 populateCalculatedValues(results);
156                 updateCombinedLabels(calculateCombinedParameters());
157                 _dialog.setVisible(true);
158         }
159
160
161         /**
162          * Make the dialog components
163          * @return the GUI components for the dialog
164          */
165         private Component makeDialogComponents()
166         {
167                 JPanel dialogPanel = new JPanel();
168                 dialogPanel.setLayout(new BorderLayout());
169
170                 // main panel with a box layout
171                 JPanel mainPanel = new JPanel();
172                 mainPanel.setLayout(new BoxLayout(mainPanel, BoxLayout.Y_AXIS));
173                 // Label at top
174                 JLabel introLabel = new JLabel(I18nManager.getText("dialog.learnestimationparams.intro") + ":");
175                 introLabel.setBorder(BorderFactory.createEmptyBorder(5, 5, 5, 5));
176                 introLabel.setAlignmentX(Component.LEFT_ALIGNMENT);
177                 mainPanel.add(introLabel);
178
179                 // Panel for the calculated results
180                 _calculatedParamPanel = new ParametersPanel("dialog.estimatetime.results", true);
181                 _calculatedParamPanel.setAlignmentX(Component.LEFT_ALIGNMENT);
182                 mainPanel.add(_calculatedParamPanel);
183                 mainPanel.add(Box.createVerticalStrut(14));
184
185                 mainPanel.add(new JLabel(I18nManager.getText("dialog.learnestimationparams.combine") + ":"));
186                 mainPanel.add(Box.createVerticalStrut(4));
187                 _weightSlider = new JScrollBar(JScrollBar.HORIZONTAL, 5, 1, 0, 11);
188                 _weightSlider.addAdjustmentListener(new AdjustmentListener() {
189                         public void adjustmentValueChanged(AdjustmentEvent inEvent)
190                         {
191                                 if (!inEvent.getValueIsAdjusting()) {
192                                         updateCombinedLabels(calculateCombinedParameters());
193                                 }
194                         }
195                 });
196                 mainPanel.add(_weightSlider);
197                 _sliderDescLabel = new JLabel(" ");
198                 _sliderDescLabel.setAlignmentX(Component.LEFT_ALIGNMENT);
199                 mainPanel.add(_sliderDescLabel);
200                 mainPanel.add(Box.createVerticalStrut(12));
201
202                 // Results panel
203                 _combinedParamPanel = new ParametersPanel("dialog.learnestimationparams.combinedresults");
204                 _combinedParamPanel.setAlignmentX(Component.LEFT_ALIGNMENT);
205                 mainPanel.add(_combinedParamPanel);
206
207                 dialogPanel.add(mainPanel, BorderLayout.NORTH);
208
209                 // button panel at bottom
210                 JPanel buttonPanel = new JPanel();
211                 buttonPanel.setLayout(new FlowLayout(FlowLayout.RIGHT));
212
213                 // Combine
214                 _combineButton = new JButton(I18nManager.getText("button.combine"));
215                 _combineButton.addActionListener(new ActionListener() {
216                         public void actionPerformed(ActionEvent arg0) {
217                                 combineAndFinish();
218                         }
219                 });
220                 buttonPanel.add(_combineButton);
221
222                 // Cancel
223                 JButton cancelButton = new JButton(I18nManager.getText("button.cancel"));
224                 cancelButton.addActionListener(new ActionListener() {
225                         public void actionPerformed(ActionEvent e) {
226                                 _dialog.dispose();
227                         }
228                 });
229                 KeyAdapter escapeListener = new KeyAdapter() {
230                         public void keyPressed(KeyEvent inE) {
231                                 if (inE.getKeyCode() == KeyEvent.VK_ESCAPE) {_dialog.dispose();}
232                         }
233                 };
234                 _combineButton.addKeyListener(escapeListener);
235                 cancelButton.addKeyListener(escapeListener);
236                 buttonPanel.add(cancelButton);
237                 dialogPanel.add(buttonPanel, BorderLayout.SOUTH);
238                 return dialogPanel;
239         }
240
241         /**
242          * Construct a rangestats object for the selected range
243          * @param inTrack track object
244          * @param inStartIndex start index
245          * @param inEndIndex end index
246          * @param inPreviousStartIndex the previously used start index, or -1
247          * @return range stats object or null if required information missing from this bit of the track
248          */
249         private RangeStatsWithGradients getRangeStats(Track inTrack, int inStartIndex,
250                 int inEndIndex, int inPreviousStartIndex)
251         {
252                 // Check parameters
253                 if (inTrack == null || inStartIndex < 0 || inEndIndex <= inStartIndex || inStartIndex > inTrack.getNumPoints()) {
254                         return null;
255                 }
256                 final int numPoints = inTrack.getNumPoints();
257                 int start = inStartIndex;
258
259                 // Search forward until a decent track point found for the start
260                 DataPoint p = inTrack.getPoint(start);
261                 while (start < numPoints && (p == null || p.isWaypoint() || !p.hasTimestamp() || !p.hasAltitude()))
262                 {
263                         start++;
264                         p = inTrack.getPoint(start);
265                 }
266                 if (inPreviousStartIndex >= 0 && start <= (inPreviousStartIndex + 10) // overlapping too much with previous range
267                         || (start >= (numPoints - 10))) // starting too late in the track
268                 {
269                         return null;
270                 }
271
272                 // Search forward (counting the radians) until a decent end point found
273                 double movingRads = 0.0;
274                 final double minimumRads = Distance.convertDistanceToRadians(1.0, UnitSetLibrary.UNITS_KILOMETRES);
275                 DataPoint prevPoint = inTrack.getPoint(start);
276                 int endIndex = start;
277                 boolean shouldStop = false;
278                 do
279                 {
280                         endIndex++;
281                         p = inTrack.getPoint(endIndex);
282                         if (p != null && !p.isWaypoint())
283                         {
284                                 if (!p.hasAltitude() || !p.hasTimestamp()) {return null;} // abort if no time/altitude
285                                 if (prevPoint != null && !p.getSegmentStart()) {
286                                         movingRads += DataPoint.calculateRadiansBetween(prevPoint, p);
287                                 }
288                         }
289                         prevPoint = p;
290                         if (endIndex >= numPoints) {
291                                 shouldStop = true; // reached the end of the track
292                         }
293                         else if (movingRads >= minimumRads && endIndex >= inEndIndex) {
294                                 shouldStop = true; // got at least a kilometre
295                         }
296                 }
297                 while (!shouldStop);
298
299                 // Check moving distance
300                 if (movingRads >= minimumRads) {
301                         return new RangeStatsWithGradients(inTrack, start, endIndex);
302                 }
303                 return null;
304         }
305
306         /**
307          * Build an A matrix for the given list of RangeStats objects
308          * @param inStatsList list of (non-null) RangeStats objects
309          * @return A matrix with n rows and 5 columns
310          */
311         private static Matrix buildAMatrix(ArrayList<RangeStatsWithGradients> inStatsList)
312         {
313                 final Unit METRES = UnitSetLibrary.UNITS_METRES;
314                 Matrix result = new Matrix(inStatsList.size(), 5);
315                 int row = 0;
316                 for (RangeStatsWithGradients stats : inStatsList)
317                 {
318                         result.setValue(row, 0, stats.getMovingDistanceKilometres());
319                         result.setValue(row, 1, stats.getGentleAltitudeRange().getClimb(METRES));
320                         result.setValue(row, 2, stats.getSteepAltitudeRange().getClimb(METRES));
321                         result.setValue(row, 3, stats.getGentleAltitudeRange().getDescent(METRES));
322                         result.setValue(row, 4, stats.getSteepAltitudeRange().getDescent(METRES));
323                         row++;
324                 }
325                 return result;
326         }
327
328         /**
329          * Build a B matrix containing the observations (moving times)
330          * @param inStatsList list of (non-null) RangeStats objects
331          * @return B matrix with single column of n rows
332          */
333         private static Matrix buildBMatrix(ArrayList<RangeStatsWithGradients> inStatsList)
334         {
335                 Matrix result = new Matrix(inStatsList.size(), 1);
336                 int row = 0;
337                 for (RangeStatsWithGradients stats : inStatsList)
338                 {
339                         result.setValue(row, 0, stats.getMovingDurationInSeconds() / 60.0); // convert seconds to minutes
340                         row++;
341                 }
342                 return result;
343         }
344
345         /**
346          * Look for the maximum absolute value in the given column matrix
347          * @param inMatrix matrix with only one column
348          * @return row index of cell with greatest absolute value, or -1 if not valid
349          */
350         private static int getIndexOfMaxValue(Matrix inMatrix)
351         {
352                 if (inMatrix == null || inMatrix.getNumColumns() > 1) {
353                         return -1;
354                 }
355                 int index = 0;
356                 double currValue = 0.0, maxValue = 0.0;
357                 // Loop over the first column looking for the maximum absolute value
358                 for (int i=0; i<inMatrix.getNumRows(); i++)
359                 {
360                         currValue = Math.abs(inMatrix.get(i, 0));
361                         if (currValue > maxValue)
362                         {
363                                 maxValue = currValue;
364                                 index = i;
365                         }
366                 }
367                 return index;
368         }
369
370         /**
371          * See if the given set of samples is sufficient for getting a descent solution (at least 3 nonzero values)
372          * @param inRangeSet list of RangeStats objects
373          * @param inRowToIgnore row index to ignore, or -1 to use them all
374          * @return true if the samples look ok
375          */
376         private static boolean isRangeSetSufficient(ArrayList<RangeStatsWithGradients> inRangeSet, int inRowToIgnore)
377         {
378                 // number of samples with gentle/steep climb/descent values > 0
379                 int numGC = 0, numSC = 0, numGD = 0, numSD = 0;
380                 final Unit METRES = UnitSetLibrary.UNITS_METRES;
381                 int i = 0;
382                 for (RangeStatsWithGradients stats : inRangeSet)
383                 {
384                         if (i != inRowToIgnore)
385                         {
386                                 if (stats.getGentleAltitudeRange().getClimb(METRES) > 0) {numGC++;}
387                                 if (stats.getSteepAltitudeRange().getClimb(METRES) > 0)  {numSC++;}
388                                 if (stats.getGentleAltitudeRange().getDescent(METRES) > 0) {numGD++;}
389                                 if (stats.getSteepAltitudeRange().getDescent(METRES) > 0)  {numSD++;}
390                         }
391                         i++;
392                 }
393                 return numGC > 3 && numSC > 3 && numGD > 3 && numSD > 3;
394         }
395
396         /**
397          * Reduce the number of samples in the given list by eliminating the ones with highest errors
398          * @param inStatsList list of stats
399          * @return results in an object
400          */
401         private MatrixResults reduceSamples(ArrayList<RangeStatsWithGradients> inStatsList)
402         {
403                 int statsIndexToRemove = -1;
404                 Matrix answer = null;
405                 boolean finished = false;
406                 double averageErrorPc = 0.0;
407                 while (!finished)
408                 {
409                         // Remove the marked stats object, if any
410                         if (statsIndexToRemove >= 0) {
411                                 inStatsList.remove(statsIndexToRemove);
412                         }
413
414                         // Build up the matrices
415                         Matrix A = buildAMatrix(inStatsList);
416                         Matrix B = buildBMatrix(inStatsList);
417                         // System.out.println("Times in minutes are:\n" + B.toString());
418
419                         // Solve (if possible)
420                         try
421                         {
422                                 answer = A.solve(B);
423                                 // System.out.println("Solved matrix with " + A.getNumRows() + " rows:\n" + answer.toString());
424                                 // Work out the percentage error for each estimate
425                                 Matrix estimates = A.times(answer);
426                                 Matrix errors = estimates.minus(B).divideEach(B);
427                                 // System.out.println("Errors: " + errors.toString());
428                                 averageErrorPc = errors.getAverageAbsValue();
429                                 // find biggest percentage error, remove it from list
430                                 statsIndexToRemove = getIndexOfMaxValue(errors);
431                                 if (statsIndexToRemove < 0)
432                                 {
433                                         System.err.println("Something wrong - index is " + statsIndexToRemove);
434                                         throw new Exception();
435                                 }
436                                 // Check whether removing this element would make the range set insufficient
437                                 finished = inStatsList.size() <= 25 || !isRangeSetSufficient(inStatsList, statsIndexToRemove);
438                         }
439                         catch (Exception e)
440                         {
441                                 // Couldn't solve at all
442                                 System.out.println("Failed to reduce: " + e.getClass().getName() + " - " + e.getMessage());
443                                 return null;
444                         }
445                         _progress.setValue(20 + 80 * (30 - inStatsList.size())/5); // Counting from 30 to 25
446                 }
447                 // Copy results to an EstimationParameters object
448                 MatrixResults result = new MatrixResults();
449                 result._parameters = new EstimationParameters();
450                 result._parameters.populateWithMetrics(answer.get(0, 0) * 5, // convert from 1km to 5km
451                         answer.get(1, 0) * 100.0, answer.get(2, 0) * 100.0,      // convert from m to 100m
452                         answer.get(3, 0) * 100.0, answer.get(4, 0) * 100.0);
453                 result._averageErrorPc = averageErrorPc;
454                 return result;
455         }
456
457
458         /**
459          * Populate the dialog's labels with the calculated values
460          * @param inResults results of the calculations
461          */
462         private void populateCalculatedValues(MatrixResults inResults)
463         {
464                 if (inResults == null || inResults._parameters == null)
465                 {
466                         _calculatedParams = null;
467                         _calculatedParamPanel.updateParameters(null, 0.0);
468                 }
469                 else
470                 {
471                         _calculatedParams = inResults._parameters;
472                         _calculatedParamPanel.updateParameters(_calculatedParams, inResults._averageErrorPc);
473                 }
474         }
475
476         /**
477          * Combine the calculated parameters with the existing ones
478          * according to the value of the slider
479          * @return combined parameters
480          */
481         private EstimationParameters calculateCombinedParameters()
482         {
483                 final double fraction1 = 1 - 0.1 * _weightSlider.getValue(); // slider left = value 0 = fraction 1 = keep current
484                 EstimationParameters oldParams = new EstimationParameters(Config.getConfigString(Config.KEY_ESTIMATION_PARAMS));
485                 return oldParams.combine(_calculatedParams, fraction1);
486         }
487
488         /**
489          * Update the labels to show the combined parameters
490          * @param inCombinedParams combined estimation parameters
491          */
492         private void updateCombinedLabels(EstimationParameters inCombinedParams)
493         {
494                 // Update the slider description label
495                 String sliderDesc = null;
496                 final int sliderVal = _weightSlider.getValue();
497                 switch (sliderVal)
498                 {
499                         case 0:  sliderDesc = I18nManager.getText("dialog.learnestimationparams.weight.100pccurrent"); break;
500                         case 5:  sliderDesc = I18nManager.getText("dialog.learnestimationparams.weight.50pc"); break;
501                         case 10: sliderDesc = I18nManager.getText("dialog.learnestimationparams.weight.100pccalculated"); break;
502                         default:
503                                 final int currTenths = 10 - sliderVal, calcTenths = sliderVal;
504                                 sliderDesc = "" + currTenths + "0% " + I18nManager.getText("dialog.learnestimationparams.weight.current")
505                                         + " + " + calcTenths + "0% " + I18nManager.getText("dialog.learnestimationparams.weight.calculated");
506                 }
507                 _sliderDescLabel.setText(sliderDesc);
508                 // And update all the combined params labels
509                 _combinedParamPanel.updateParameters(inCombinedParams);
510                 _combineButton.setEnabled(sliderVal > 0);
511         }
512
513         /**
514          * React to the combine button, by saving the combined parameters in the config
515          */
516         private void combineAndFinish()
517         {
518                 EstimationParameters params = calculateCombinedParameters();
519                 Config.setConfigString(Config.KEY_ESTIMATION_PARAMS, params.toConfigString());
520                 _dialog.dispose();
521         }
522 }