1091 字
5 分钟
基于模拟退火的线性回归模型计算方法
2025-06-04

奇思妙想#

上课的时候突然想到了这种奇妙方法,但是适用性并不普遍,大抵只适用于低次的拟合函数。——后来发现好像还挺普遍的

但是感觉挺有趣的,顺便温习一下之前所学。

观察到对于一组样本数据 (x1,y1),(x2,y2),,(xn,yn)(x_1,y_1),(x_2,y_2),\cdots,(x_n,y_n) 进行建模。若当前拟合函数为 y=ax+by=ax+b 则其均方误差(MSE)被定义为:

E(a,b)=1ni=1n(yi(axi+b))2E(a, b) = \frac{1}{n} \sum_{i=1}^{n} (y_i - (ax_i + b))^2

对于目标拟合函数,则要求 E(a,b)E(a,b) 取得最小值。

故此,不难想到对于未知单调性且貌似找不到什么规律的函数求其最值,较好的方法便是使用模拟退火跑 aabb。对于模拟退火的学习笔记可以看我两年前写的 洛谷博客

然后套式子就行。

#include<iostream>
#include<cmath>
#include<cstdio>
#include<iomanip>
#include<cstdlib>
using namespace std;

const int maxn = 1010;

int n;
struct Point {
    double x, y;
} p[maxn];

double pow2(double x) { return x * x; }

double MSE(double a, double b) {
    double res = 0;
    for (int i = 1; i <= n; i++) {
        double pred = a * p[i].x + b;
        res += pow2(p[i].y - pred);
    }
    return res / n;
}

double a_ans, b_ans, best_cost;

void SA() {
    double a = a_ans, b = b_ans;
    double T = 1000, Tmin = 1e-6, delta = 0.98;
    double cur_cost = MSE(a, b);

    while (T > Tmin) {
        double a_new = a + (rand() * 2.0 - RAND_MAX) / RAND_MAX * T;
        double b_new = b + (rand() * 2.0 - RAND_MAX) / RAND_MAX * T;
        double new_cost = MSE(a_new, b_new);

        if (new_cost < cur_cost || exp((cur_cost - new_cost) / T) * RAND_MAX > rand()) {
            a = a_new, b = b_new;
            cur_cost = new_cost;
            if (cur_cost < best_cost) {
                a_ans = a;
                b_ans = b;
                best_cost = cur_cost;
            }
        }

        T *= delta;
    }
}

int main() {
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> p[i].x >> p[i].y;
    }

    a_ans = 0, b_ans = 0;
    best_cost = MSE(a_ans, b_ans);

    SA();SA();SA();SA();SA();SA(); //想准一点就多退几次

    cout << fixed << setprecision(5) << "a = " << a_ans << ", b = " << b_ans << endl;
    return 0;
}

更进一步#

NOTE

以下记于 2025.6.13

具体实现#

此外,容易发现此方法具有极强的扩展性。对于任意次的拟合函数只需要退火时多退几个参数即可。即使这会导致结果大概率错误,但是貌似对于“无限次回归”没有通解的存在,普遍使用神经网络等算法进行实现,其难度对比此非属同一量级。

以下进一步进行说明。

对于一拟合函数:

y^(x)=k=0Nindexkxk\hat{y}(x) = \sum_{k=0}^{N} \text{index}_k x^k

其目标函数可被定义为:

Loss(index)=i1m(yiy^(xi))2\text{Loss}(\text{index})=\sum^m_{i-1}(y_i-\hat{y}(x_i))^2 容易写出代码:

#include<iostream>
#include<cmath>
#include<cstdio>
#include<iomanip>
#include<cstdlib>
using namespace std;

const int maxn = 1010;
const int maxd = 20;

int n, deg;
struct Point {
    double x, y;
} p[maxn];

double pow2(double x) { return x * x; }

double index[maxd + 1]; 
double best_index[maxd + 1];
double best_cost;

double val(double x) {
    double res = 0, x_pow = 1;
    for (int i = 0; i <= deg; i++) {
        res += index[i] * x_pow;
        x_pow *= x;
    }
    return res;
}

double MSE() {
    double res = 0;
    for (int i = 1; i <= n; i++) {
        res += pow2(p[i].y - val(p[i].x));
    }
    return res / n;
}

void SA() {
    double T = 1000, Tmin = 1e-6, delta = 0.98;
    double cur_cost = MSE();

    while (T > Tmin) {
        double new_index[maxd + 1];
        for (int i = 0; i <= deg; i++) {
            new_index[i] = index[i] + (rand() * 2.0 - RAND_MAX) / RAND_MAX * T;
        }

        for (int i = 0; i <= deg; i++) index[i] = new_index[i];
        double new_cost = MSE();

        if (new_cost < cur_cost || exp((cur_cost - new_cost) / T) * RAND_MAX > rand()) {
            cur_cost = new_cost;
            if (cur_cost < best_cost) {
                for (int i = 0; i <= deg; i++) best_index[i] = index[i];
                best_cost = cur_cost;
            }
        } else {
            for (int i = 0; i <= deg; i++) index[i] = best_index[i];
        }

        T *= delta;
    }
}

int main() {
    cin >> n >> deg; // n 对应点数,deg 对应多项式次数
    for (int i = 1; i <= n; i++) {
        cin >> p[i].x >> p[i].y;
    }

    for (int i = 0; i <= deg; i++) index[i] = best_index[i] = 0;
    best_cost = MSE();

    SA();SA();SA();SA();SA();SA();SA();SA();

    cout << fixed << setprecision(5);
    for (int i = 0; i <= deg; i++) {
        cout << "a[" << i << "] = " << best_index[i] << endl;
    }

    return 0;
}

复杂度分析#

约定 nn 为数据点数,dd 为多项式次数

MSE 复杂度#

一眼 O(n×d)O(n\times d)

SA 复杂度#

退火总轮数#

温度每次乘以 delta=0.98\text{delta}=0.98,直到 T<TminT<\text{Tmin},次数为:

k=log0.98(TminT0)=O(log(1Tmin))k = \log_{0.98}\left(\frac{T_{\min}}{T_0}\right) = O\left(\log\left(\frac{1}{T_{\min}}\right)\right)

代码中 Tmin=1e6T_{\min} = 1e{-6},所以一般需要:

log(1e6)log(1/0.98)13.80.02=690 次\frac{\log(1e6)}{\log(1/0.98)} \approx \frac{13.8}{0.02} = 690 \text{ 次}

所以我们记退火迭代次数为 Tsteps6001000T_{\text{steps}} \approx 600 \sim 1000

每次迭代#
  • 扰动 d+1d+1 个参数,O(d)O(d)
  • 调用一次 MSE,O(n×d)O(n\times d)
  • 复制俩数组 O(d)O(d)

总复杂度 O(n×d)O(n\times d)

SA 总复杂度为 O(Tstepsnd)O(T_{\text{steps}} \cdot n \cdot d)


综上可知,总复杂度为 O(Tstepsndk)O(T_{\text{steps}} \cdot n \cdot d\cdot k),其中 kk 为你跑 SA 的次数。

基于模拟退火的线性回归模型计算方法
https://blog.mitufun.top/posts/基于模拟退火的线性回归模型计算方法/基于模拟退火的线性回归模型计算方法/
作者
MituFun
发布于
2025-06-04
许可协议
CC BY-NC-SA 4.0