diff --git a/class14/README.md b/class14/README.md index c0f80b50d5594e4060e065e7544bdc6e50c5a457..f84505b12685fe546772981a2c60c03d5cb1c3bb 100755 --- a/class14/README.md +++ b/class14/README.md @@ -243,4 +243,215 @@ print("result1=", reg.predict([[3000, 3, 40]])) ``` +### 1.7 多项式回归(以年份为自变量,总人口为因变量)来预测人口 + +```matlab +% 读取CSV文件 +filename = 'demo.csv'; +data = readmatrix(filename, 'FileType', 'text', 'Encoding', 'UTF-8', 'NumHeaderLines', 1); + +% 提取数据:年份(第1列),总人口(第2列) +years = data(:, 1); % 自变量:年份 +total_pop = data(:, 2); % 因变量:总人口 + +% 获取多项式阶数 +poly_degree = input('请输入多项式阶数 (1-5,推荐3): '); +if isempty(poly_degree) + poly_degree = 3; % 默认使用3阶多项式 +end + +% 多项式回归 +coefficients = polyfit(years, total_pop, poly_degree); + +% 显示多项式方程 +fprintf('\n多项式回归方程 (阶数 %d):\n', poly_degree); +for i = 1:length(coefficients) + power = length(coefficients) - i; + if power > 1 + fprintf('%.6f × 年份^%d', coefficients(i), power); + elseif power == 1 + fprintf('%.6f × 年份', coefficients(i)); + else + fprintf('%.6f', coefficients(i)); + end + + if i < length(coefficients) + fprintf(' + '); + end +end +fprintf('\n\n'); + +% 创建多项式预测函数 +predictPolyPop = @(year) polyval(coefficients, year); + +% 计算R²值 +predicted_values = polyval(coefficients, years); +SS_res = sum((total_pop - predicted_values).^2); +SS_tot = sum((total_pop - mean(total_pop)).^2); +R_squared = 1 - (SS_res / SS_tot); +fprintf('R² = %.4f\n', R_squared); + +% 获取用户输入年份 +year_input = input('请输入要预测的年份: '); +if isempty(year_input) + year_input = 2030; % 默认预测2030年 +end + +% 预测人口 +predicted_pop = predictPolyPop(year_input); +fprintf('预测%d年的人口: %.0f\n\n', year_input, predicted_pop); + +% 创建更密集的年份序列用于绘制平滑曲线 +years_dense = linspace(min(years), max(years) + (max(years)-min(years))*0.2, 200); +pop_dense = polyval(coefficients, years_dense); + +% 创建图形 +figure('Position', [100, 100, 1000, 700]); + +% 子图1:原始数据点与多项式回归曲线 +subplot(2, 2, 1); +plot(years, total_pop, 'bo', 'MarkerSize', 8, 'MarkerFaceColor', 'b', 'DisplayName', '原始数据'); +hold on; +plot(years_dense, pop_dense, 'r-', 'LineWidth', 2, 'DisplayName', sprintf('%d阶多项式拟合', poly_degree)); + +% 标记预测点 +if year_input >= min(years_dense) && year_input <= max(years_dense) + plot(year_input, predicted_pop, 'gs', 'MarkerSize', 12, 'MarkerFaceColor', 'g', 'DisplayName', sprintf('预测%d年', year_input)); + + % 添加预测线 + plot([year_input, year_input], [min(total_pop)*0.9, predicted_pop], 'g--', 'LineWidth', 1); + plot([min(years_dense), year_input], [predicted_pop, predicted_pop], 'g--', 'LineWidth', 1); + + % 添加预测值标签 + text(year_input, predicted_pop*0.97, sprintf('%.0f', predicted_pop), ... + 'HorizontalAlignment', 'center', 'VerticalAlignment', 'top', ... + 'FontSize', 10, 'FontWeight', 'bold', 'Color', 'g', 'BackgroundColor', 'white'); +end + +xlabel('年份', 'FontSize', 12, 'FontWeight', 'bold'); +ylabel('总人口', 'FontSize', 12, 'FontWeight', 'bold'); +title(sprintf('多项式回归拟合 (阶数 = %d)', poly_degree), 'FontSize', 14, 'FontWeight', 'bold'); +legend('Location', 'best'); +grid on; + +% 子图2:残差分析 +subplot(2, 2, 2); +residuals = total_pop - predicted_values; +plot(years, residuals, 'ko', 'MarkerSize', 8, 'MarkerFaceColor', 'k'); +hold on; +plot([min(years), max(years)], [0, 0], 'r-', 'LineWidth', 1.5); +xlabel('年份', 'FontSize', 12, 'FontWeight', 'bold'); +ylabel('残差', 'FontSize', 12, 'FontWeight', 'bold'); +title('残差分析图', 'FontSize', 14, 'FontWeight', 'bold'); +grid on; + +% 添加残差统计信息 +mean_residual = mean(residuals); +std_residual = std(residuals); +text(min(years), max(residuals)*0.9, sprintf('均值: %.2f\n标准差: %.2f', mean_residual, std_residual), ... + 'FontSize', 10, 'BackgroundColor', 'white', 'EdgeColor', 'black'); + +% 子图3:预测值与实际值对比 +subplot(2, 2, 3); +plot(years, total_pop, 'bo-', 'LineWidth', 1.5, 'DisplayName', '实际值'); +hold on; +plot(years, predicted_values, 'r*-', 'LineWidth', 1.5, 'DisplayName', '拟合值'); +xlabel('年份', 'FontSize', 12, 'FontWeight', 'bold'); +ylabel('人口', 'FontSize', 12, 'FontWeight', 'bold'); +title('实际值与拟合值对比', 'FontSize', 14, 'FontWeight', 'bold'); +legend('Location', 'best'); +grid on; + +% 子图4:不同阶数多项式拟合效果对比 +subplot(2, 2, 4); +colors = {'r', 'g', 'b', 'm', 'c'}; +plot(years, total_pop, 'ko', 'MarkerSize', 8, 'MarkerFaceColor', 'k', 'DisplayName', '原始数据'); +hold on; + +for deg = 1:min(5, length(years)-1) + coeff_temp = polyfit(years, total_pop, deg); + pop_temp = polyval(coeff_temp, years_dense); + + % 计算该阶数的R² + pred_temp = polyval(coeff_temp, years); + SS_res_temp = sum((total_pop - pred_temp).^2); + R_squared_temp = 1 - (SS_res_temp / SS_tot); + + plot(years_dense, pop_temp, colors{deg}, 'LineWidth', 1.5, ... + 'DisplayName', sprintf('阶数%d (R²=%.4f)', deg, R_squared_temp)); +end + +xlabel('年份', 'FontSize', 12, 'FontWeight', 'bold'); +ylabel('总人口', 'FontSize', 12, 'FontWeight', 'bold'); +title('不同阶数多项式拟合对比', 'FontSize', 14, 'FontWeight', 'bold'); +legend('Location', 'best', 'FontSize', 8); +grid on; + +% 在图形底部添加回归方程 +annotation('textbox', [0.15, 0.01, 0.7, 0.05], 'String', ... + sprintf('多项式回归方程 (阶数%d): y = %s', poly_degree, poly2str(coefficients, '年份')), ... + 'FontSize', 10, 'FontWeight', 'bold', 'BackgroundColor', 'white', 'EdgeColor', 'black'); + +% 创建辅助函数来格式化多项式字符串 +function str = poly2str(coeffs, var_name) + str = ''; + n = length(coeffs); + for i = 1:n + power = n - i; + if coeffs(i) ~= 0 + if ~isempty(str) && coeffs(i) > 0 + str = [str, ' + ']; + elseif coeffs(i) < 0 + str = [str, ' - ']; + end + + if power == 0 + str = sprintf('%s%.4f', str, abs(coeffs(i))); + elseif power == 1 + str = sprintf('%s%.4f × %s', str, abs(coeffs(i)), var_name); + else + str = sprintf('%s%.4f × %s^%d', str, abs(coeffs(i)), var_name, power); + end + end + end +end +``` + +**程序说明:** + +1. **数据读取**: + - 读取CSV文件,提取年份(第1列)和总人口(第2列) + +2. **多项式回归**: + - 用户可以指定多项式的阶数(1-5,默认3阶) + - 使用`polyfit`函数进行多项式回归 + - 计算R²值评估模型拟合优度 + +3. **预测功能**: + - 创建多项式预测函数`predictPolyPop` + - 用户可以输入任意年份进行预测 + - 程序输出预测的人口值 + +4. **可视化**: + - **子图1**:显示原始数据点、多项式拟合曲线和预测点 + - **子图2**:残差分析图,评估模型的拟合质量 + - **子图3**:实际值与拟合值对比图 + - **子图4**:不同阶数多项式拟合效果对比 + +5. **统计信息**: + - 显示多项式回归方程 + - 显示R²值 + - 显示残差的均值和标准差 + +**使用说明:** +1. 将`demo.csv`文件与MATLAB脚本放在同一目录 +2. 运行程序 +3. 输入多项式阶数(推荐3阶) +4. 输入要预测的年份 +5. 程序会显示预测结果和四个子图的可视化 + +**注意事项:** +- 多项式阶数越高,拟合曲线越灵活,但可能导致过拟合 +- 对于超出数据范围的年份,多项式外推可能不准确 +- 残差分析有助于评估模型假设是否成立 ​