The purpose of this notebook is to illustrate how we can overcome the feature explosion problem based on an example dataset involving sensor data.
Author: Dr. Patrick Urbanke
Summary:
The feature explosion problem is one of the most important issues in automated feature engineering. In fact, it is probably the main reason why automated feature engineering is not already the norm in data science projects involving business data.
To illustrate the problem, consider how data scientists write features for a simple time series problem:
SELECT SOME_AGGREGATION(t2.some_column)
FROM some_table t1
LEFT JOIN some_table t2
ON t1.join_key = t2.join_key
WHERE t2.some_other_column >= some_value -- a condition on a second column
AND t2.rowid <= t1.rowid -- only look into the past...
AND t2.rowid + some_other_value > t1.rowid -- ...at most some_other_value rows back
GROUP BY t1.rowid;
Think about that for a second.
Every column that we have can either be aggregated (some_column) or used in our conditions (some_other_column). That means that for each of the $n$ columns we might aggregate, we can potentially build conditions on $n$ other columns. In other words, the size of the search space grows as $n^2$ in the number of columns.
Note that this problem occurs regardless of whether you automate feature engineering or do it by hand. The size of the search space is $n^2$ in the number of columns in either case, unless you can rule something out a priori.
This problem is known as feature explosion.
So when we have relational data or time series with many columns, what do we do? The answer is to write different features. Specifically, suppose we had features like this:
SELECT SOME_AGGREGATION(
CASE
WHEN t2.some_column > some_value THEN weight1 -- weight1 and weight2 are learned
WHEN t2.some_column <= some_value THEN weight2
END
)
FROM some_table t1
LEFT JOIN some_table t2
ON t1.join_key = t2.join_key
WHERE t2.rowid <= t1.rowid -- only look into the past...
AND t2.rowid + some_other_value > t1.rowid -- ...at most some_other_value rows back
GROUP BY t1.rowid;
weight1 and weight2 are learnable weights. An algorithm that generates features like this only uses columns for conditions; it is not allowed to aggregate columns, and it doesn't need to.
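To make this concrete, here is a minimal pandas sketch of such a feature. The names some_value, weight1 and weight2 are placeholders; in practice they would be fitted by the feature learning algorithm, not set by hand:
import numpy as np
import pandas as pd

# Toy version of the CASE WHEN feature above. The threshold and the two
# weights are placeholders that a feature learner would fit from the data.
t2 = pd.DataFrame({"some_column": [0.5, 1.5, 2.5, 3.5]})
some_value, weight1, weight2 = 2.0, 0.8, -0.3

# Each row contributes weight1 or weight2, depending on the condition;
# the aggregation (here: SUM) reduces them to a single feature value.
feature = np.where(t2["some_column"] > some_value, weight1, weight2).sum()
print(feature)  # -0.3 - 0.3 + 0.8 + 0.8 = 1.0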
That means the computational complexity is linear instead of quadratic in the number of columns. For data sets with many columns, this can make all the difference in the world. For instance, if you have 100 columns, the search space of the second approach is only 1% of the size of the search space of the first.
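The 1% figure is easy to verify:
# Back-of-the-envelope check of the two search space sizes for n = 100 columns.
n = 100
quadratic = n * n  # first approach: (aggregation column, condition column) pairs
linear = n         # second approach: condition columns only
print(linear / quadratic)  # 0.01, i.e. 1%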
getML features an algorithm called Relboost, which generates features according to this principle and is therefore well suited to data sets with many columns.
To illustrate the problem, we use a data set related to robotics. When robots interact with humans, the most important thing is that they don't hurt people. To prevent such injuries, the force vector on the robot's arm is measured. However, measuring the force vector is expensive.
Therefore, we want to consider an alternative approach: predicting the force vector based on other sensor data that are less costly to measure. To do so, we use machine learning.
However, the data set contains measurements from almost 100 different sensors and we do not know which and how many sensors are relevant for predicting the force vector.
The data set has been generously provided by Erik Berger who originally collected it for his dissertation:
Berger, E. (2018). Behavior-Specific Proprioception Models for Robotic Force Estimation: A Machine Learning Approach. Freiberg, Germany: Technische Universität Bergakademie Freiberg.
We begin by importing the libraries and setting the project.
import datetime
import os
from urllib import request
import pandas as pd
import numpy as np
import getml
import getml.data as data
import getml.database as database
import getml.engine as engine
import getml.feature_learning.aggregations as agg
import getml.data.roles as roles
from IPython.display import Image
import matplotlib.pyplot as plt
plt.style.use('seaborn')
%matplotlib inline
getml.engine.launch()
getml.engine.set_project('robot')
Launched the getML engine. The log output will be stored in /home/patrick/.getML/logs/20220324212608.log. Connected to project 'robot'
data_all = getml.data.DataFrame.from_csv(
"https://static.getml.com/datasets/robotarm/robot-demo.csv",
"data_all"
)
Downloading robot-demo.csv... [========================================] 100%
data_all
name | 3 | 4 | 5 | 6 | ... | 104 | 105 | 106 | f_x | f_y | f_z |
---|---|---|---|---|---|---|---|---|---|---|---|
role | unused_float | unused_float | unused_float | unused_float | ... | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float |
0 | 3.4098 | -0.3274 | 0.9604 | -3.7436 | ... | 47.834 | 47.955 | 47.971 | -11.03 | 6.9 | -7.33 |
1 | 3.4098 | -0.3274 | 0.9604 | -3.7436 | ... | 47.834 | 47.955 | 47.971 | -10.848 | 6.7218 | -7.4427 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
15000 | 3.0829 | -0.885 | 1.4514 | -2.2062 | ... | 47.803 | 47.94 | 47.955 | 11.69 | -1.1801 | 15.937 |
15001 rows x 96 columns
memory usage: 11.52 MB
name: data_all
type: getml.DataFrame
url: http://localhost:1709/#/getdataframe/robot/data_all/
The force vector consists of three components (f_x, f_y and f_z), meaning that we have three targets.
data_all.set_role(["f_x", "f_y", "f_z"], getml.data.roles.target)  # the three force components are our targets
data_all.set_role(data_all.roles.unused, getml.data.roles.numerical)  # all remaining columns become numerical inputs
This is what the data set looks like:
data_all
name | f_x | f_y | f_z | 3 | 4 | 5 | ... | 104 | 105 | 106 |
---|---|---|---|---|---|---|---|---|---|---|
role | target | target | target | numerical | numerical | numerical | ... | numerical | numerical | numerical |
0 | -11.03 | 6.9 | -7.33 | 3.4098 | -0.3274 | 0.9604 | ... | 47.834 | 47.955 | 47.971 |
1 | -10.848 | 6.7218 | -7.4427 | 3.4098 | -0.3274 | 0.9604 | ... | 47.834 | 47.955 | 47.971 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
15000 | 11.69 | -1.1801 | 15.937 | 3.0829 | -0.885 | 1.4514 | ... | 47.803 | 47.94 | 47.955 |
15001 rows x 96 columns
memory usage: 11.52 MB
name: data_all
type: getml.DataFrame
url: http://localhost:1709/#/getdataframe/robot/data_all/
We also want to separate the data set into a training set and a testing set. We use the first 10,500 measurements for training and the remainder for testing.
# rows with rowid < 10500 form the training set, the remaining 4,501 rows the testing set
split = getml.data.split.time(data_all, "rowid", test=10500)
split
0 | train |
---|---|
1 | train |
2 | train |
3 | train |
4 | train |
... |
15001 rows
type: StringColumnView
time_series = getml.data.TimeSeries(
    population=data_all,
    split=split,
    time_stamps="rowid",   # the row number doubles as the time stamp
    lagged_targets=False,  # features may not use past values of the targets
    memory=30,             # features may only look 30 rows into the past
)
time_series
data frames | staging table | |
---|---|---|
0 | population | POPULATION__STAGING_TABLE_1 |
1 | data_all | DATA_ALL__STAGING_TABLE_2 |
subset | name | rows | type | |
---|---|---|---|---|
0 | test | data_all | 4501 | View |
1 | train | data_all | 10500 | View |
name | rows | type | |
---|---|---|---|
0 | data_all | 15001 | View |
relboost = getml.feature_learning.Relboost(
loss_function=getml.feature_learning.loss_functions.SquareLoss,
num_features=10,
)
xgboost = getml.predictors.XGBoostRegressor()
pipe1 = getml.pipeline.Pipeline(
data_model=time_series.data_model,
feature_learners=[relboost],
predictors=xgboost
)
It is always a good idea to check the pipeline for any potential issues.
pipe1.check(time_series.train)
Checking data model... Staging... [========================================] 100% Checking... [========================================] 100% OK.
pipe1.fit(time_series.train)
Checking data model... Staging... [========================================] 100% OK. Staging... [========================================] 100% Relboost: Training features... [========================================] 100% Relboost: Training features... [========================================] 100% Relboost: Training features... [========================================] 100% Relboost: Building features... [========================================] 100% Relboost: Building features... [========================================] 100% Relboost: Building features... [========================================] 100% XGBoost: Training as predictor... [========================================] 100% XGBoost: Training as predictor... [========================================] 100% XGBoost: Training as predictor... [========================================] 100% Trained pipeline. Time taken: 0h:4m:2.339313
Pipeline(data_model='population', feature_learners=['Relboost'], feature_selectors=[], include_categorical=False, loss_function=None, peripheral=['data_all'], predictors=['XGBoostRegressor'], preprocessors=[], share_selected_features=0.5, tags=['container-y0rs18'])
url: http://localhost:1709/#/getpipeline/robot/kDSQEy/0/
pipe1.score(time_series.test)
Staging... [========================================] 100% Relboost: Building features... [========================================] 100% Relboost: Building features... [========================================] 100% Relboost: Building features... [========================================] 100%
date time | set used | target | mae | rmse | rsquared | |
---|---|---|---|---|---|---|
0 | 2022-03-24 21:30:20 | train | f_x | 0.4449 | 0.5852 | 0.9962 |
1 | 2022-03-24 21:30:20 | train | f_y | 0.5159 | 0.6813 | 0.9893 |
2 | 2022-03-24 21:30:20 | train | f_z | 0.275 | 0.3576 | 0.9988 |
3 | 2022-03-24 21:30:37 | test | f_x | 0.5681 | 0.7374 | 0.9949 |
4 | 2022-03-24 21:30:37 | test | f_y | 0.5654 | 0.7516 | 0.9871 |
5 | 2022-03-24 21:30:37 | test | f_z | 0.2957 | 0.3845 | 0.9986 |
It is always a good idea to study the features the relational learning algorithm has extracted.
The feature importances are calculated by XGBoost based on the improvement of the optimization criterion at each split in the decision trees; they are normalized to 100%.
Also note that we have three different targets (f_x, f_y and f_z) and that different features are relevant for different targets.
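As a quick sanity check of the normalization, the importances for each target should sum to (roughly) one. This is a sketch, assuming .importances(...) returns array-like values, as their use in the plots below suggests:
import numpy as np

# The importances are normalized, so they should sum to ~1.0 per target.
for target_num in range(3):
    _, importances = pipe1.features.importances(target_num=target_num)
    print(target_num, np.asarray(importances).sum())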
plt.subplots(figsize=(20, 10))
names, importances = pipe1.features.importances(target_num=0)
plt.bar(names[0:30], importances[0:30], color='#6829c2')
plt.title("feature importances for the x-component", size=20)
plt.grid(True)
plt.xlabel("features")
plt.ylabel("importance")
plt.xticks(rotation='vertical')
plt.show()
plt.subplots(figsize=(20, 10))
names, importances = pipe1.features.importances(target_num=1)
plt.bar(names[0:30], importances[0:30], color='#6829c2')
plt.title("feature importances for the y-component", size=20)
plt.grid(True)
plt.xlabel("features")
plt.ylabel("importance")
plt.xticks(rotation='vertical')
plt.show()
plt.subplots(figsize=(20, 10))
names, importances = pipe1.features.importances(target_num=2)
plt.bar(names[0:30], importances[0:30], color='#6829c2')
plt.title("feature importances for the z-component", size=20)
plt.grid(True)
plt.xlabel("features")
plt.ylabel("importance")
plt.xticks(rotation='vertical')
plt.show()
Because getML is a tool for relational learning, we can also calculate importances for the original columns, using methods similar to those we used for the feature importances.
plt.subplots(figsize=(20, 10))
names, importances = pipe1.columns.importances(target_num=0)
plt.bar(names[0:30], importances[0:30], color='#6829c2')
plt.title("column importances for the x-component", size=20)
plt.grid(True)
plt.xlabel("column")
plt.ylabel("importance")
plt.xticks(rotation='vertical')
plt.show()
plt.subplots(figsize=(20, 10))
names, importances = pipe1.columns.importances(target_num=1)
plt.bar(names[0:30], importances[0:30], color='#6829c2')
plt.title("column importances for the y-component", size=20)
plt.grid(True)
plt.xlabel("column")
plt.ylabel("importance")
plt.xticks(rotation='vertical')
plt.show()
plt.subplots(figsize=(20, 10))
names, importances = pipe1.columns.importances(target_num=2)
plt.bar(names[0:30], importances[0:30], color='#6829c2')
plt.title("column importances for the z-component", size=20)
plt.grid(True)
plt.xlabel("column")
plt.ylabel("importance")
plt.xticks(rotation='vertical')
plt.show()
When we study the plots of the column importances, we find some good news: we actually don't need that many columns. About 80% of the columns carry very little predictive value.
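To put a rough number on this claim, we can look at the cumulative importance of the top columns. This is a sketch, assuming the importances come back sorted in descending order, as the plots above suggest:
import numpy as np

# Share of the total column importance captured by the 20 most important
# columns (for the x-component).
_, importances = pipe1.columns.importances(target_num=0)
importances = np.asarray(importances)
print(importances.cumsum()[19] / importances.sum())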
This means that we can also apply other algorithms that are not as scalable as Relboost. All we have to do is select the most relevant columns.
The .select(...) method returns a new container in which the unimportant columns have been dropped:
time_series2 = pipe1.columns.select(time_series, share_selected_columns=0.35)
time_series2
data frames | staging table | |
---|---|---|
0 | population | POPULATION__STAGING_TABLE_1 |
1 | data_all | DATA_ALL__STAGING_TABLE_2 |
subset | name | rows | type | |
---|---|---|---|---|
0 | train | data_all | 10500 | View |
1 | test | data_all | 4501 | View |
name | rows | type | |
---|---|---|---|
0 | data_all | 15001 | View |
As we have discussed in the introduction, the Multirel algorithm does not scale well to data sets with many columns: its computational complexity is $n^2$ in the number of columns. But since we now use only 35% of the original columns, it is fine to apply Multirel.
multirel = getml.feature_learning.Multirel(
loss_function=getml.feature_learning.loss_functions.SquareLoss,
num_features=10,
)
relboost = getml.feature_learning.Relboost(
loss_function=getml.feature_learning.loss_functions.SquareLoss,
num_features=10,
)
xgboost = getml.predictors.XGBoostRegressor(n_jobs=7)
pipe2 = getml.pipeline.Pipeline(
data_model=time_series2.data_model,
feature_learners=[multirel, relboost],
predictors=xgboost
)
pipe2.check(time_series2.train)
Checking data model... Staging... [========================================] 100% Checking... [========================================] 100% INFO [MIGHT TAKE LONG]: DATA_ALL__STAGING_TABLE_2 contains 27 categorical and numerical columns. Please note that columns created by the preprocessors are also part of this count. The multirel algorithm does not scale very well to data frames with many columns. This pipeline might take a very long time to fit. You should consider removing some columns or preprocessors. You could use a column selection to pick the right columns. You could also replace Multirel with Relboost. Relboost has been designed to scale well to data frames with many columns.
pipe2.fit(time_series2.train)
Checking data model... Staging... [========================================] 100% INFO [MIGHT TAKE LONG]: DATA_ALL__STAGING_TABLE_2 contains 27 categorical and numerical columns. Please note that columns created by the preprocessors are also part of this count. The multirel algorithm does not scale very well to data frames with many columns. This pipeline might take a very long time to fit. You should consider removing some columns or preprocessors. You could use a column selection to pick the right columns. You could also replace Multirel with Relboost. Relboost has been designed to scale well to data frames with many columns. Staging... [========================================] 100% Multirel: Training features... [========================================] 100% Relboost: Training features... [========================================] 100% Relboost: Training features... [========================================] 100% Relboost: Training features... [========================================] 100% Multirel: Building features... [========================================] 100% Relboost: Building features... [========================================] 100% Relboost: Building features... [========================================] 100% Relboost: Building features... [========================================] 100% XGBoost: Training as predictor... [========================================] 100% XGBoost: Training as predictor... [========================================] 100% XGBoost: Training as predictor... [========================================] 100% Trained pipeline. Time taken: 0h:1m:36.437623
Pipeline(data_model='population', feature_learners=['Multirel', 'Relboost'], feature_selectors=[], include_categorical=False, loss_function=None, peripheral=['data_all'], predictors=['XGBoostRegressor'], preprocessors=[], share_selected_features=0.5, tags=['container-LSE4VK'])
url: http://localhost:1709/#/getpipeline/robot/EQdxaT/0/
pipe2.score(time_series2.test)
Staging... [========================================] 100% Multirel: Building features... [========================================] 100% Relboost: Building features... [========================================] 100% Relboost: Building features... [========================================] 100% Relboost: Building features... [========================================] 100%
date time | set used | target | mae | rmse | rsquared | |
---|---|---|---|---|---|---|
0 | 2022-03-24 21:32:19 | train | f_x | 0.4488 | 0.5888 | 0.9961 |
1 | 2022-03-24 21:32:19 | train | f_y | 0.5262 | 0.6887 | 0.989 |
2 | 2022-03-24 21:32:19 | train | f_z | 0.2722 | 0.3572 | 0.9988 |
3 | 2022-03-24 21:32:30 | test | f_x | 0.5642 | 0.7379 | 0.9949 |
4 | 2022-03-24 21:32:30 | test | f_y | 0.5689 | 0.7563 | 0.987 |
5 | 2022-03-24 21:32:30 | test | f_z | 0.2955 | 0.3838 | 0.9986 |
The performance is virtually identical to that of the first pipeline, even though we only used 35% of the original columns. And sometimes a picture says more than a thousand words, so we also want to visualize our predictions on the testing set.
f_x = time_series2.test.population["f_x"].to_numpy()
f_y = time_series2.test.population["f_y"].to_numpy()
f_z = time_series2.test.population["f_z"].to_numpy()
predictions = pipe2.predict(time_series2.test)
Staging... [========================================] 100% Multirel: Building features... [========================================] 100% Relboost: Building features... [========================================] 100% Relboost: Building features... [========================================] 100% Relboost: Building features... [========================================] 100%
col_data = 'black'
col_getml = 'darkviolet'
col_getml_alt = 'coral'
plt.subplots(figsize=(20, 10))
plt.title("x-component of the force vector", size=20)
plt.plot(f_x, label="ground truth", color=col_data)
plt.plot(predictions[:, 0], label="prediction", color=col_getml)
plt.legend(loc="upper right", fontsize=16)
<matplotlib.legend.Legend at 0x7f9aa17faeb0>