线性回归
用一条直线描述数据的规律,机器学习里最简单也最重要的第一步
01 核心原理(大白话版)
你记录了过去一年每天的气温和冰淇淋销量,发现气温越高销量越好。现在老板问你:明天气温 35℃,能卖多少?
你在纸上把数据点都画出来,然后用一把尺子比划——找一条直线,尽量让所有点都离它近。找到了,就用这条线来预测。
线性回归就是让机器自动找出这把"最合适的尺子"。
三件事搞定线性回归
直线只需要两个数:斜率 w(倾斜程度)和截距 b(和 y 轴交的位置)。
公式就一行:y = w × x + b
把每个数据点到直线的距离平方后取平均,叫做均方误差(MSE)。这个数越小,直线越贴合数据。
每次计算一下 w 和 b 往哪边调能让 MSE 变小,就往那边调一点点。反复几百次,直线自然而然逼近最佳位置。
整个过程就是:随机画一条线 → 量偏差 → 微调 → 再量 → 再调……直到线几乎不再移动为止。运行下面的代码就能亲眼看到这个过程。
一步步构建线性回归
我们把完整代码拆开,逐块看清楚每一步在干什么。
用已知的 y = 2x + 5 加随机噪声模拟真实数据,看清楚数据的结构长什么样。
用均方误差(MSE)衡量直线和数据的偏差。MSE 越小,直线越贴合数据。
每轮遍历所有数据,算出 w 和 b 各自的梯度,沿负梯度方向各走一小步。
损失函数与梯度的关系是入门机器学习最核心的一环:
- 损失函数告诉你现在有多差,是一个标量(一个数)
- 梯度是对损失函数求偏导得到的,它的每个分量对应一个参数——告诉你"把这个参数往哪个方向调,损失会上升最快"
- 训练时我们沿负梯度方向更新参数,等于让损失下降最快。在损失曲面上,整个梯度向量的几何意义正是当前位置下坡最陡的方向
MSE = 1/n Σ(wx + b − y)²,令 err = wx + b − y,则:
dw += err * xdb += err两个分量合在一起就是 [dw, db],即损失曲面上的梯度向量,沿其反方向走一步就是上面代码最后两行的参数更新。
三段拼在一起,加上可视化,就是完整演示代码——看下面。
w 和 b 分别控制什么?
斜率 w
控制直线的倾斜程度。w 大 → 线陡,x 增加一点 y 增加很多;w 小 → 线平缓;w 为负 → 线向右下倾斜。
截距 b
控制直线的上下位置。当 x=0 时,y=b。调整 b 就像平移整条线,不改变倾斜角度。
02 代码
03 学术性讲解
线性回归是机器学习中最基础的监督学习算法。它的目标是找到一条直线,尽可能地穿过所有的数据点。这条直线可以用来预测未知数据的值。
什么是线性回归?
假设我们有 n 个数据点 (x₁, y₁), (x₂, y₂), ..., (xₙ, yₙ),我们想找到一条直线:
其中 w 是斜率,b 是截距。线性回归的任务就是找到最优的 w 和 b。
损失函数:均方误差 (MSE)
如何衡量一条直线好不好?我们用均方误差 (Mean Squared Error) 来衡量:
这个公式计算的是每个数据点到直线的垂直距离的平方的平均值。我们的目标是最小化这个值。
如何优化?梯度下降!
和梯度下降章节一样,我们用梯度下降来最小化 MSE:
其中 α 是学习率。通过不断迭代,w 和 b 会逐渐收敛到最优值。
总结
找到最佳拟合直线 y = wx + b
均方误差 (MSE)
梯度下降
预测连续值