Linear Regression in Snowflake
A simple linear regression model is a linear regression model where there is only one independent variable $x$. The dependent variable $y$ is modelled as a linear function of $x$. More simply put the relationship between $x$ and $y$ is described as a straight line with slope $\beta$ (a.k.a gradient) and intercept $\alpha$.
$$ y=\alpha + \beta x $$
In regression we want to find the optimal intercept $\alpha$ and slope $\beta$ to minimise the sum of square errors (or residuals). The residuals are the distances from each point vertically to the model line.
Snowflake ❄️ has some really helpful functions to help with simple linear regression. The three key regression functions are REGR_SLOPE
, REGR_INTERCEPT
and REGR_R2
used to find the optimal slope, intercept and corresponding r-squared respectively but the rest are useful helpers functions!
REGR_VALX
andREGR_VALY
REGR_COUNT
REGR_AVGX
andREGR_AVGY
REGR_SXX
,REGR_SYY
andREGR_SXY
REGR_SLOPE
,REGR_INTERCEPT
andREGR_R2
All of these functions have the same call signature function_name(y, x)
where y is the dependent variable and x is the independent variable in the regression.
The following chart visualises simple linear regression and allows you to vary the slope and intercept and see the impact on the r-squared and other statistics below the chart.
Statistic | Formula | Value |
---|---|---|
Sum of square errors (SSE) | $\sum (y_i - \hat{y}_i)^2$ | 100 |
Sum of square deviation of model from mean (SSM) | $\sum (\hat{y}_i - \bar{y})^2$ | 100 |
r squared (Coefficient of determination) | $1 - \frac{SSE}{SST}$ | 100 |
Sample size (number of data points) (REGR_COUNT) | $n$ | 10 |
The following stats are for the optimal regression line | ||
Mean of $x$ (REGR_AVGX) | $\bar{x} = \frac{1}{n}\sum{x_i}$ | 100 |
Mean of $y$ (REGR_AVGY) | $\bar{y} = \frac{1}{n}\sum{y_i}$ | 100 |
$S_{XX}$ (REGR_SXX) | $\sum (x_i - \bar{x})^2$ | 100 |
$S_{YY}$ (REGR_SYY) | $\sum (y_i - \bar{y})^2$ | 100 |
$S_{XY}$ (REGR_SXY) | $\sum (x_i - \bar{x})(y_i - \bar{y})$ | 100 |
Optimal gradient $\beta$ (REGR_SLOPE) | $\frac{S_{XY}}{S_{XX}}$ | 100 |
Optimal intercept $\alpha$ (REGR_INTERCEPT) | $\bar{y} - \beta\bar{x}$ | 100 |
Optimal r squared - for line with $\alpha$ and $\beta$ (REGR_R2) | $\frac{S_{XY}^2}{S_{XX}S_{YY}}$ | 100 |
Correlation | $\frac{S_{XY}}{\sqrt{S_{XX}S_{YY}}}$ | 100 |
Before we get started with the sql functions let’s make some toy data to play with! The following query generates the numbers from 1 to 10 for the x column and defines y as (4 * x) + 3
with some noise. We also added a few naughty rows with null values, this will help show off snowflakes functions above.
create temporary table temp_table as (
select
row_number() over (order by seq8()) as x,
4 * x + normal(0, 10, random(10)) + 3 as y
from table(generator(rowcount => 20)) G
union all
select
*
from values (null, 100), (50, null)
)
Ok, now that’s done let’s take a look at the table. The data is the same as the chart above.
select
x,
y
from temp_table;
X | Y |
---|---|
1 | -5.940727153 |
2 | -6.435786057 |
3 | 12.874649301 |
4 | 35.910490613 |
5 | 40.969651455 |
6 | 27.466977488 |
7 | 28.339904759 |
8 | 48.022369501 |
9 | 36.211253796 |
10 | 55.04814023 |
null | 100 |
50 | null |
REGR_VALX
and REGR_VALY
REGR_VALX
, REGR_VALY
are helper functions to make sure we only include points (x, y) in the regression where both x and y are not null. REGR_VALX(y , x)
returns x if y is not null else it returns null and REGR_VALY(y , x)
returns y if x is not null else it returns null.
That means the following two queries are equivalent
select
REGR_VALX(y , x) as x_val
from temp_table
select
case when y is not null then x else null end as x_val
from temp_table
Similarly REGR_VALY(y , x)
is equivalent to case when x is not null then y else null end
REGR_COUNT
REGR_COUNT
is similar to the normal sql count
function but only returns the number of non null pairs (x, y). In terms of regression this can be used to calculate the sample size $n$. The following query calculates the same thing 3 times (starting to show how the functions in Snowflake help to simplify)
select
regr_count(y, x) as n,
count(regr_valx(y, x)) as n_2,
count(case when y is not null then x else null end) as n_3
from temp_table;
REGR_AVGX
and REGR_AVGY
REGR_AVGX
and REGR_AVGY
are both similar to the avg
function but only include rows where both x and y are non null in the calculation. These are used to calculate the mean values $\bar{x}$ and $\bar{y}$ which are important to calculate the regression gradient
, intercept
and r-squared
. The following query again calculates the same thing three times each for x and y to show how it works.
select
REGR_AVGX(y, x) as x_mean,
avg(regr_valx(y, x)) as x_mean_2,
avg(case when y is not null then x else null end) as x_mean_3,
REGR_AVGY(y, x) as y_mean,
avg(regr_valy(y, x)) as y_mean_2,
avg(case when x is not null then y else null end) as y_mean_3
from temp_table;
REGR_SXX
, REGR_SYY
and REGR_SXY
SXX
, SYY
and SXY
are useful statistics to calculate the slope, intercept and r-squared. Their formulas are in the table above here.
The following two queries both calculate SXX
, SYY
and SXY
, the first in vanilla SQL the second using the snowflake helper functions.
with averages as (
select
avg(case when y is not null then x else null end) as x_mean,
avg(case when x is not null then y else null end) as y_mean
from temp_table
)
select
sum(pow(case when y is not null then x else null end - x_mean, 2)) as sxx,
sum(pow(case when x is not null then y else null end - y_mean, 2)) as syy,
sum(
(case when y is not null then x else null end - x_mean) *
(case when x is not null then y else null end - y_mean)
) as sxy
from temp_table
inner join averages
;
Using the helpers:
select
regr_sxx(y, x) as sxx,
regr_syy(y, x) as syy,
regr_sxy(y, x) as sxy
from temp_table;
;
SXX | SYY | SXY |
---|---|---|
82.5 | 3,992.00 | 493.48 |
REGR_SLOPE
, REGR_INTERCEPT
and REGR_R2
REGR_SLOPE
and REGR_INTERCEPT
calculate the optimal slope and intercept to minimise the sum of squares of the residuals. The REGR_R2
function calculates the r-squared for the optimal line. This number is between 0 and 1 and explains how much of the variance is explained by the model, the closer the value to 1 the better the fit.
The formulas are given in the tables above here.
Again we show this in vanilla sql and using the snowflake helpers:
with averages as (
/* First we calculate the average of x and y */
select
avg(case when y is not null then x else null end) as x_mean,
avg(case when x is not null then y else null end) as y_mean
from temp_table
),
statistics as (
/* Calculate SXX, SYY and SXY as an intermediate step */
select
x_mean,
y_mean,
sum(pow(case when y is not null then x else null end - x_mean, 2)) as sxx,
sum(pow(case when x is not null then y else null end - y_mean, 2)) as syy,
sum(
(case when y is not null then x else null end - x_mean) *
(case when x is not null then y else null end - y_mean)
) as sxy
from temp_table
inner join averages
group by 1, 2
)
select
/* Finally calculate the optimal slope and intercept
as well as the corresponding r-squared */
sxy / sxx as slope,
y_mean - slope * x_mean as intercept,
(sxy * sxy) / (sxx * syy) as r_squared
from statistics;
;
select
regr_slope(y, x) as slope,
regr_intercept(y, x) as intercept,
regr_r2(y, x) as r_squared
from temp_table;
;
SLOPE | INTERCEPT | R_SQUARED |
---|---|---|
5.981534878 | -5.651749436 | 0.7394144386 |
The functions above allow you to quickly calculate regression values in snowflake and understand the relationships between columns - but remember correlation is not causation s that.
For more about linear regression see my other linear regression article here
Thanks for reading! 👏 Please get in touch with any questions, mistakes or improvements.