线性回归
用一条直线描述数据的规律,机器学习里最简单也最重要的第一步
01 核心原理(大白话版)
你记录了过去一年每天的气温和冰淇淋销量,发现气温越高销量越好。现在老板问你:明天气温 35℃,能卖多少?
你在纸上把数据点都画出来,然后用一把尺子比划——找一条直线,尽量让所有点都离它近。找到了,就用这条线来预测。
线性回归就是让机器自动找出这把"最合适的尺子"。
三件事搞定线性回归
直线只需要两个数:斜率 w(倾斜程度)和截距 b(和 y 轴交的位置)。
公式就一行:y = w × x + b
把每个数据点到直线的距离平方后取平均,叫做均方误差(MSE)。这个数越小,直线越贴合数据。
每次计算一下 w 和 b 往哪边调能让 MSE 变小,就往那边调一点点。反复几百次,直线自然而然逼近最佳位置。
整个过程就是:随机画一条线 → 量偏差 → 微调 → 再量 → 再调……直到线几乎不再移动为止。运行下面的代码就能亲眼看到这个过程。
一步步构建线性回归
我们把完整代码拆开,逐块看清楚每一步在干什么。
用已知的 y = 2x + 5 加随机噪声模拟真实数据,看清楚数据的结构长什么样。
用均方误差(MSE)衡量直线和数据的偏差。MSE 越小,直线越贴合数据。
每轮遍历所有数据,算出 w 和 b 各自的梯度,沿负梯度方向各走一小步。
预测函数 ŷ = wx + b,单个样本损失 L = (wx + b − y)²,令 err = wx + b − y,链式法则得:
如果觉得公式抽象,可以丢掉公式直接看代码,代码是最直观的。下面是公式对应的代码:
const err = w * x + b - y; // 预测误差 const dw = 2 * err * x; // ∂L/∂w const db = 2 * err; // ∂L/∂b
上面是函数形式的表达式,适用于任意一个样本 (x, y)。推广到 n 个样本时,损失取所有样本的平均值(MSE = 均方误差),梯度对每个样本累加再除以 n,1/n 作为常数提出来:
如果觉得公式抽象,可以丢掉公式直接看代码,代码是最直观的。下面是公式对应的代码:
let dw = 0, db = 0;
for (const [x, y] of data) {
const err = w * x + b - y;
dw += err * x; // 累加 ∂L/∂w
db += err; // 累加 ∂L/∂b
}
dw = dw * 2 / n;
db = db * 2 / n;
损失函数与梯度的关系是入门机器学习最核心的一环:
- 损失函数告诉你现在有多差,是一个标量(一个数)
- 梯度是对损失函数求偏导得到的,它的每个分量对应一个参数——告诉你"把这个参数往哪个方向调,损失会上升最快"
- 训练时我们沿负梯度方向更新参数,等于让损失下降最快。在损失曲面上,整个梯度向量的几何意义正是当前位置下坡最陡的方向
$\text{MSE} = \frac{1}{n}\sum(wx+b-y)^2$,令 $\text{err} = wx+b-y$,则:
两个分量合在一起就是 [dw, db],即损失曲面上的梯度向量,沿其反方向走一步就是上面代码最后两行的参数更新。
三段拼在一起,加上可视化,就是完整演示代码——看下面。
w 和 b 分别控制什么?
斜率 w
控制直线的倾斜程度。w 大 → 线陡,x 增加一点 y 增加很多;w 小 → 线平缓;w 为负 → 线向右下倾斜。
截距 b
控制直线的上下位置。当 x=0 时,y=b。调整 b 就像平移整条线,不改变倾斜角度。
02 代码
03 学术性讲解
假设函数(Hypothesis)
给定 n 个训练样本 {(x₁,y₁), …, (xₙ,yₙ)},线性回归假设 x 和 y 之间存在线性映射,用参数化函数表示:
w 称为权重(weight),b 称为偏置(bias)。学习的目标就是找到一组 (w, b) 使 h(x) 尽量贴近真实 y。
目标函数(Objective Function)
要让模型"贴近"数据,首先要定量描述"差多远"。定义第 i 个样本的残差:
为什么用平方而不用绝对值?两个原因:① 处处可导,方便求梯度;② 对大误差惩罚更重。对所有样本的残差平方取平均,得到均方误差(MSE):
J(w,b) 就是目标函数,也叫损失函数(Loss Function)。训练的本质是一个优化问题:
梯度推导
对 J(w,b) 分别对 w 和 b 求偏导,由链式法则:
梯度向量 [∂J/∂w, ∂J/∂b] 指向 J 上升最快的方向,沿其反方向走一步就是梯度下降更新:
α 是学习率(步长),控制每次更新的幅度。反复迭代直到梯度接近零,即到达极小值点,此时 J(w,b) 收敛——因为 MSE 是关于 w、b 的凸函数,极小值即全局最小值。
总结
h(x) = wx + b
εᵢ = h(xᵢ) − yᵢ
J = (1/n)Σεᵢ²(MSE)
∂J/∂w、∂J/∂b 由链式法则推导
参数 ← 参数 − α · 梯度
MSE 关于参数是凸函数,梯度下降收敛到全局最优