Skip to content

4.2 Statsmodels 核心功能

"Make everything as simple as possible, but not simpler.""让一切尽可能简单,但不要过于简单。"— Albert Einstein, Physicist (物理学家)

Python 统计分析的基石


本节目标

  • 掌握 Statsmodels 的核心 API
  • 理解 OLS、GLM、时间序列建模
  • 学习模型诊断与稳健标准误
  • 使用公式接口快速建模

Statsmodels 简介

定位:Python 中的 Stata —— 经典统计分析的首选

核心特点

  • 论文级输出(详细统计表格)
  • 丰富的诊断工具
  • 异方差稳健标准误
  • 时间序列完整支持

安装

bash
pip install statsmodels
# 或
conda install -c conda-forge statsmodels

两种 API 风格

API 1:标准接口(推荐用于编程)

python
import statsmodels.api as sm
import pandas as pd
import numpy as np

# 1. 准备数据
X = df[['education', 'experience']]
X = sm.add_constant(X)  # ️ 必须手动添加常数项
y = df['wage']

# 2. 拟合模型
model = sm.OLS(y, X).fit()

# 3. 查看结果
print(model.summary())

API 2:公式接口(推荐用于交互分析)

python
import statsmodels.formula.api as smf

# R 风格公式
model = smf.ols('wage ~ education + experience', data=df).fit()
print(model.summary())

# 优势:
#  自动添加常数项
#  自动处理分类变量
#  支持变换和交互项

OLS 回归

基础 OLS

python
import statsmodels.api as sm
import pandas as pd

# 数据
df = pd.DataFrame({
    'wage': [3000, 3500, 4000, 5000, 5500],
    'education': [12, 14, 16, 18, 20],
    'experience': [2, 3, 4, 5, 6]
})

# OLS 回归
X = sm.add_constant(df[['education', 'experience']])
y = df['wage']

model = sm.OLS(y, X).fit()
print(model.summary())

输出解读

                            OLS Regression Results                            
==============================================================================
Dep. Variable:                   wage   R-squared:                       0.999
Model:                            OLS   Adj. R-squared:                  0.998
Method:                 Least Squares   F-statistic:                     857.5
Prob (F-statistic):            0.00116
...
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const       -250.0000    164.317     -1.521      0.265   -1155.959     655.959
education    125.0000     10.000     12.500      0.003      87.156     162.844
experience   125.0000     20.000      6.250      0.016      38.313     211.687
==============================================================================

异方差稳健标准误

python
# HC0, HC1, HC2, HC3(默认 HC1)
model_robust = sm.OLS(y, X).fit(cov_type='HC3')
print(model_robust.summary())

# 对比
print(f"普通 SE: {model.bse['education']:.4f}")
print(f"稳健 SE: {model_robust.bse['education']:.4f}")

标准误类型

  • 'nonrobust': 默认(假设同方差)
  • 'HC0': White 稳健标准误
  • 'HC1': 小样本调整
  • 'HC2': 更保守
  • 'HC3': 最稳健(推荐)

加权最小二乘(WLS)

python
# 如果已知异方差结构
weights = 1 / df['variance']
model_wls = sm.WLS(y, X, weights=weights).fit()

广义线性模型(GLM)

Logit 回归(二元因变量)

python
# 示例:就业状态(0/1)
df['employed'] = [0, 1, 1, 0, 1, 1, 0, 1]

X = sm.add_constant(df[['education', 'experience']])
y = df['employed']

# Logit
logit_model = sm.Logit(y, X).fit()
print(logit_model.summary())

# 边际效应
marginal_effects = logit_model.get_margeff()
print(marginal_effects.summary())

Probit 回归

python
probit_model = sm.Probit(y, X).fit()

Poisson 回归(计数数据)

python
# 示例:专利申请数量
df['patents'] = [0, 1, 2, 5, 3, 1, 0, 2]

poisson_model = sm.GLM(df['patents'], X, 
                       family=sm.families.Poisson()).fit()
print(poisson_model.summary())

模型诊断

1. 残差分析

python
import matplotlib.pyplot as plt

# 拟合模型
model = sm.OLS(y, X).fit()

# 残差
residuals = model.resid

# 可视化
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# (1) 残差 vs 拟合值
axes[0, 0].scatter(model.fittedvalues, residuals)
axes[0, 0].axhline(0, color='red', linestyle='--')
axes[0, 0].set_xlabel('Fitted values')
axes[0, 0].set_ylabel('Residuals')
axes[0, 0].set_title('Residuals vs Fitted')

# (2) 残差直方图
axes[0, 1].hist(residuals, bins=20, edgecolor='black')
axes[0, 1].set_xlabel('Residuals')
axes[0, 1].set_title('Histogram of Residuals')

# (3) Q-Q plot(正态性)
sm.qqplot(residuals, line='s', ax=axes[1, 0])
axes[1, 0].set_title('Normal Q-Q Plot')

# (4) 残差 vs 杠杆值
influence = model.get_influence()
leverage = influence.hat_matrix_diag
axes[1, 1].scatter(leverage, residuals)
axes[1, 1].set_xlabel('Leverage')
axes[1, 1].set_ylabel('Residuals')
axes[1, 1].set_title('Residuals vs Leverage')

plt.tight_layout()
plt.show()

2. 异方差检验

python
from statsmodels.stats.diagnostic import het_breuschpagan

# Breusch-Pagan 检验
bp_test = het_breuschpagan(model.resid, model.model.exog)

labels = ['LM Statistic', 'LM-Test p-value', 'F-Statistic', 'F-Test p-value']
print(dict(zip(labels, bp_test)))

# 解释:
# p-value < 0.05 → 拒绝同方差,存在异方差

3. 多重共线性(VIF)

python
from statsmodels.stats.outliers_influence import variance_inflation_factor

# 计算 VIF
vif_data = pd.DataFrame()
vif_data["Variable"] = X.columns
vif_data["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]

print(vif_data)

# 判断标准:
# VIF < 5:无问题
# 5 < VIF < 10:中度共线性
# VIF > 10:严重共线性

4. 正态性检验

python
from scipy import stats

# Shapiro-Wilk 检验
stat, p_value = stats.shapiro(model.resid)
print(f"Shapiro-Wilk p-value: {p_value:.4f}")

# Jarque-Bera 检验(大样本)
jb_test = sm.stats.jarque_bera(model.resid)
print(f"Jarque-Bera p-value: {jb_test[1]:.4f}")

公式接口高级用法

分类变量

python
import statsmodels.formula.api as smf

# 自动创建虚拟变量(C())
model = smf.ols('wage ~ education + experience + C(region)', data=df).fit()

# 指定参照组
model = smf.ols('wage ~ education + C(region, Treatment("East"))', data=df).fit()

变换

python
# 对数转换
model = smf.ols('np.log(wage) ~ education + experience', data=df).fit()

# 多项式
model = smf.ols('wage ~ education + experience + I(experience**2)', data=df).fit()

# 标准化
model = smf.ols('scale(wage) ~ scale(education) + scale(experience)', data=df).fit()

交互项

python
# 方法 1:显式交互
model = smf.ols('wage ~ education + female + education:female', data=df).fit()

# 方法 2:完全交互
model = smf.ols('wage ~ education * female', data=df).fit()
# 等价于:wage ~ education + female + education:female

⏱️ 时间序列

ARIMA 模型

python
from statsmodels.tsa.arima.model import ARIMA
import pandas as pd

# 示例:季度 GDP
df = pd.read_csv('gdp.csv', index_col='date', parse_dates=True)

# ARIMA(1,1,1)
model = ARIMA(df['gdp'], order=(1, 1, 1)).fit()
print(model.summary())

# 预测
forecast = model.forecast(steps=8)
print(forecast)

平稳性检验

python
from statsmodels.tsa.stattools import adfuller

# ADF 检验
result = adfuller(df['gdp'])
print(f'ADF Statistic: {result[0]:.4f}')
print(f'p-value: {result[1]:.4f}')

# 判断:
# p-value < 0.05 → 拒绝单位根,序列平稳

输出回归表格

单个模型

python
# LaTeX 格式
print(model.summary().as_latex())

# HTML 格式
print(model.summary().as_html())

# 保存
with open('regression_results.tex', 'w') as f:
    f.write(model.summary().as_latex())

多模型对比

python
from statsmodels.iolib.summary2 import summary_col

# 拟合多个模型
model1 = smf.ols('wage ~ education', data=df).fit()
model2 = smf.ols('wage ~ education + experience', data=df).fit()
model3 = smf.ols('wage ~ education + experience + female', data=df).fit()

# 对比表格
results = summary_col(
    [model1, model2, model3],
    model_names=['(1)', '(2)', '(3)'],
    stars=True,
    info_dict={
        'N': lambda x: f"{int(x.nobs):,}",
        'R²': lambda x: f"{x.rsquared:.3f}"
    }
)

print(results)

# 导出为 LaTeX
print(results.as_latex())

小结

核心 API

任务函数示例
OLSsm.OLS()线性回归
Logitsm.Logit()二元因变量
Probitsm.Probit()二元因变量
GLMsm.GLM()广义线性模型
WLSsm.WLS()加权最小二乘
ARIMAARIMA()时间序列

最佳实践

  1. 使用稳健标准误cov_type='HC3'
  2. 检查诊断:残差、VIF、异方差
  3. 公式接口原型:快速探索
  4. 标准接口生产:精确控制
  5. 保存结果.summary().as_latex()

下一步

下一节:SciPy.stats - 快速统计检验


参考

基于 MIT 许可证发布。内容版权归作者所有。