4.2 Statsmodels Core Features
"Make everything as simple as possible, but not simpler." - Albert Einstein, Physicist
The cornerstone of statistical analysis in Python
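Learning Objectives
- Master the core Statsmodels API
- Understand OLS, GLM, and time-series modeling
- Learn model diagnostics and robust standard errors
- Use the formula interface for rapid modeling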
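Introduction to Statsmodels
Positioning: the "Stata of Python", the first choice for classical statistical analysis
Key features:
- Publication-quality output (detailed statistical tables)
- A rich set of diagnostic tools
- Heteroskedasticity-robust standard errors
- Comprehensive time-series support
Installation: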
```bash
pip install statsmodels
# or
conda install -c conda-forge statsmodels
```
Two API Styles
API 1: Standard interface (recommended for programmatic use)
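```python
import statsmodels.api as sm
import pandas as pd
import numpy as np
# 1. Prepare the data (df: a DataFrame with wage, education, experience columns)
X = df[['education', 'experience']]
X = sm.add_constant(X)  # the constant term must be added manually
y = df['wage']
# 2. Fit the model
model = sm.OLS(y, X).fit()
# 3. Inspect the results
print(model.summary())
```
API 2: Formula interface (recommended for interactive analysis)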
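```python
import statsmodels.formula.api as smf
# R-style formula
model = smf.ols('wage ~ education + experience', data=df).fit()
print(model.summary())
# Advantages:
# - adds the intercept automatically
# - handles categorical variables automatically
# - supports transformations and interaction terms
```
OLS Regression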
Basic OLS
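```python
import statsmodels.api as sm
import pandas as pd
# Toy data. experience must not be an exact linear function of education:
# perfectly collinear regressors would make X rank-deficient
df = pd.DataFrame({
    'wage': [3000, 3500, 4000, 5000, 5500],
    'education': [12, 14, 16, 18, 20],
    'experience': [2, 5, 3, 8, 6]
})
# OLS regression
X = sm.add_constant(df[['education', 'experience']])
y = df['wage']
model = sm.OLS(y, X).fit()
print(model.summary())
```
Interpreting the output (abridged; the numbers illustrate the table layout and do not come from the toy data above):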
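```text
                            OLS Regression Results
==============================================================================
Dep. Variable:                   wage   R-squared:                       0.999
Model:                            OLS   Adj. R-squared:                  0.998
Method:                 Least Squares   F-statistic:                     857.5
                                        Prob (F-statistic):            0.00116
...
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const       -250.0000    164.317     -1.521      0.265   -1155.959     655.959
education    125.0000     10.000     12.500      0.003      87.156     162.844
experience   125.0000     20.000      6.250      0.016      38.313     211.687
==============================================================================
```
- coef: estimated change in wage for a one-unit increase in the regressor, holding the other regressors fixed
- std err: standard error of the coefficient estimate
- t: coef / std err (for example, 125 / 10 = 12.5 for education)
- P>|t|: two-sided p-value for the null hypothesis that the coefficient is zero
- [0.025 0.975]: lower and upper bounds of the 95% confidence interval
- R-squared / Adj. R-squared: share of the variance of wage explained by the model (the adjusted version penalizes extra regressors)
- F-statistic: joint test that all slope coefficients are zero
Heteroskedasticity-Robust Standard Errors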
```python
# Robust options: 'HC0', 'HC1', 'HC2', 'HC3' (without cov_type, fit() uses 'nonrobust')
model_robust = sm.OLS(y, X).fit(cov_type='HC3')
print(model_robust.summary())
# Compare
print(f"Conventional SE: {model.bse['education']:.4f}")
print(f"Robust SE: {model_robust.bse['education']:.4f}")
```
Standard error types:
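- 'nonrobust': the default (assumes homoskedastic errors)
- 'HC0': White's heteroskedasticity-consistent standard errors
- 'HC1': HC0 scaled by n/(n - k), a small-sample correction (the variant used by Stata's robust option)
- 'HC2': weights each squared residual by 1/(1 - h_ii), where h_ii is the observation's leverage
- 'HC3': weights each squared residual by 1/(1 - h_ii)^2; the most conservative of the four, and a common recommendation for small samples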
Weighted Least Squares (WLS)
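```python
# If the form of heteroskedasticity is known, weight each observation by 1 / its error variance
# (assumes df has a 'variance' column holding each observation's error variance)
weights = 1 / df['variance']
model_wls = sm.WLS(y, X, weights=weights).fit()
```
Generalized Linear Models (GLM)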
Logit regression (binary dependent variable)
```python
# Example: employment status (0/1); assumes df has 8 rows
df['employed'] = [0, 1, 1, 0, 1, 1, 0, 1]
X = sm.add_constant(df[['education', 'experience']])
y = df['employed']
# Logit
logit_model = sm.Logit(y, X).fit()
print(logit_model.summary())
# Marginal effects (averaged over all observations by default)
marginal_effects = logit_model.get_margeff()
print(marginal_effects.summary())
```
Probit regression
```python
probit_model = sm.Probit(y, X).fit()
```
Poisson regression (count data)
```python
# Example: number of patent applications (assumes df has 8 rows)
df['patents'] = [0, 1, 2, 5, 3, 1, 0, 2]
poisson_model = sm.GLM(df['patents'], X,
                       family=sm.families.Poisson()).fit()
print(poisson_model.summary())
```
Model Diagnostics
1. Residual analysis
```python
import matplotlib.pyplot as plt
# Fit the OLS model to be diagnosed
model = sm.OLS(y, X).fit()
# Residuals
residuals = model.resid
# 2 x 2 grid of diagnostic plots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# (1) Residuals vs fitted values
axes[0, 0].scatter(model.fittedvalues, residuals)
axes[0, 0].axhline(0, color='red', linestyle='--')
axes[0, 0].set_xlabel('Fitted values')
axes[0, 0].set_ylabel('Residuals')
axes[0, 0].set_title('Residuals vs Fitted')
# (2) Histogram of residuals
axes[0, 1].hist(residuals, bins=20, edgecolor='black')
axes[0, 1].set_xlabel('Residuals')
axes[0, 1].set_title('Histogram of Residuals')
# (3) Q-Q plot (normality)
sm.qqplot(residuals, line='s', ax=axes[1, 0])
axes[1, 0].set_title('Normal Q-Q Plot')
# (4) Residuals vs leverage
influence = model.get_influence()
leverage = influence.hat_matrix_diag
axes[1, 1].scatter(leverage, residuals)
axes[1, 1].set_xlabel('Leverage')
axes[1, 1].set_ylabel('Residuals')
axes[1, 1].set_title('Residuals vs Leverage')
plt.tight_layout()
plt.show()
```
2. Heteroskedasticity test
```python
from statsmodels.stats.diagnostic import het_breuschpagan
# Breusch-Pagan test
bp_test = het_breuschpagan(model.resid, model.model.exog)
labels = ['LM Statistic', 'LM-Test p-value', 'F-Statistic', 'F-Test p-value']
print(dict(zip(labels, bp_test)))
# Interpretation:
# p-value < 0.05 → reject homoskedasticity; the errors are heteroskedastic
```
3. Multicollinearity (VIF)
```python
from statsmodels.stats.outliers_influence import variance_inflation_factor
# Compute the VIF of every column of X (ignore the row for 'const')
vif_data = pd.DataFrame()
vif_data["Variable"] = X.columns
vif_data["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]
print(vif_data)
# Rules of thumb:
# VIF < 5: no cause for concern
# 5 <= VIF < 10: moderate multicollinearity
# VIF >= 10: severe multicollinearity
```
4. Normality tests
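The Jarque-Bera statistic is built from the sample skewness S and kurtosis K: JB = n/6 · (S² + (K - 3)²/4). Under normality it is asymptotically chi-squared with 2 degrees of freedom. The sketch below uses heavy-tailed Student-t draws as stand-in "residuals", an arbitrary choice. It computes JB from population moments and compares the result with statsmodels' jarque_bera.

```python
import numpy as np
from scipy import stats
from statsmodels.stats.stattools import jarque_bera

# Heavy-tailed stand-in residuals (t with 3 degrees of freedom)
rng = np.random.default_rng(9)
e = rng.standard_t(df=3, size=1000)

# Population (biased) central moments
n = e.size
d = e - e.mean()
m2, m3, m4 = (d**2).mean(), (d**3).mean(), (d**4).mean()
S = m3 / m2**1.5          # skewness
K = m4 / m2**2            # kurtosis (normal distribution: 3)
jb_by_hand = n / 6 * (S**2 + (K - 3) ** 2 / 4)
p_by_hand = stats.chi2.sf(jb_by_hand, df=2)

jb, jb_pvalue, skew, kurtosis = jarque_bera(e)
print(jb_by_hand, jb, p_by_hand, jb_pvalue)
```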
```python
from scipy import stats
# Shapiro-Wilk test
stat, p_value = stats.shapiro(model.resid)
print(f"Shapiro-Wilk p-value: {p_value:.4f}")
# Jarque-Bera test (asymptotic, so better suited to large samples)
jb_test = sm.stats.jarque_bera(model.resid)
print(f"Jarque-Bera p-value: {jb_test[1]:.4f}")
```
Advanced Formula Interface Usage
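A formula is turned into a design matrix before fitting. Listing the generated column names is a quick way to check what a formula actually produced. The sketch below uses a tiny made-up dataset with a hypothetical string column region. It shows the intercept, the numeric regressor, and one dummy per non-reference category when "East" is chosen as the reference level.

```python
import pandas as pd
import statsmodels.formula.api as smf

# Tiny made-up dataset; 'region' values are hypothetical
df = pd.DataFrame({
    "wage": [3000, 3400, 4100, 5200, 4800, 3900, 4500, 5100, 3600],
    "education": [12, 14, 16, 18, 17, 15, 16, 19, 13],
    "region": ["East", "West", "North"] * 3,
})

res = smf.ols('wage ~ education + C(region, Treatment("East"))', data=df).fit()

# Column names of the design matrix generated from the formula
names = res.model.exog_names
print(names)
```

No "[T.East]" column appears, because East is the reference category absorbed into the intercept.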
Categorical variables
```python
import statsmodels.formula.api as smf
# C() marks a variable as categorical and creates dummy variables
model = smf.ols('wage ~ education + experience + C(region)', data=df).fit()
# Specify the reference (base) category
model = smf.ols('wage ~ education + C(region, Treatment("East"))', data=df).fit()
```
Transformations
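In the formula language, ** means interaction power, not arithmetic. Since a variable interacted with itself is just that variable, experience**2 collapses to experience. This is why the polynomial line below wraps the square in I(), which evaluates its argument as ordinary Python. The sketch uses simulated data with an arbitrary quadratic wage profile.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

# Simulated quadratic relationship; coefficients 2000, 120, -2 are arbitrary
rng = np.random.default_rng(10)
n = 200
df = pd.DataFrame({"experience": rng.uniform(0, 30, size=n)})
df["wage"] = (2000 + 120 * df["experience"] - 2 * df["experience"] ** 2
              + rng.normal(0, 20, size=n))

# ** inside a formula: interaction power, so experience**2 reduces to experience
res_power = smf.ols("wage ~ experience + experience**2", data=df).fit()
# I() evaluates Python arithmetic, giving a genuine squared term
res_square = smf.ols("wage ~ experience + I(experience**2)", data=df).fit()

print(res_power.params.index.tolist())
print(res_square.params.index.tolist())
```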
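```python
import numpy as np  # np must be imported for np.log to work inside the formula
# Log transformation
model = smf.ols('np.log(wage) ~ education + experience', data=df).fit()
# Polynomial term: I() is required because ** inside a formula means interaction power
model = smf.ols('wage ~ education + experience + I(experience**2)', data=df).fit()
# Standardization (scale() is a built-in formula function)
model = smf.ols('scale(wage) ~ scale(education) + scale(experience)', data=df).fit()
```
Interaction terms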
model = smf.ols('scale(wage) ~ scale(education) + scale(experience)', data=df).fit()交互项
python
# 方法 1:显式交互
model = smf.ols('wage ~ education + female + education:female', data=df).fit()
# 方法 2:完全交互
model = smf.ols('wage ~ education * female', data=df).fit()
# 等价于:wage ~ education + female + education:female⏱️ 时间序列
ARIMA model
```python
from statsmodels.tsa.arima.model import ARIMA
import pandas as pd
# Example: quarterly GDP (assumes gdp.csv has a 'date' column and a 'gdp' column)
df = pd.read_csv('gdp.csv', index_col='date', parse_dates=True)
# ARIMA(1,1,1)
model = ARIMA(df['gdp'], order=(1, 1, 1)).fit()
print(model.summary())
# Forecast the next 8 periods (two years of quarterly data)
forecast = model.forecast(steps=8)
print(forecast)
```
Stationarity test
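```python
from statsmodels.tsa.stattools import adfuller
# Augmented Dickey-Fuller (ADF) test; null hypothesis: the series has a unit root
result = adfuller(df['gdp'])
print(f'ADF Statistic: {result[0]:.4f}')
print(f'p-value: {result[1]:.4f}')
# Decision rule:
# p-value < 0.05 → reject the unit root; treat the series as stationary
```
Exporting Regression Tables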
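Besides the formatted summary, it is often useful to keep the coefficients as a plain data table for further processing. The sketch below assembles one from params, bse, pvalues and conf_int() of an HC3 fit. It writes the table to CSV and the summary to a .tex file in a temporary directory. The data are simulated and the file names are arbitrary.

```python
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

# Simulated data; coefficients are arbitrary
rng = np.random.default_rng(14)
n = 100
df = pd.DataFrame({"education": rng.integers(9, 21, size=n)})
df["wage"] = 1000 + 150 * df["education"] + rng.normal(0, 300, size=n)

res = smf.ols("wage ~ education", data=df).fit(cov_type="HC3")

ci = res.conf_int()   # DataFrame with columns 0 (lower) and 1 (upper)
table = pd.DataFrame({
    "coef": res.params,
    "std_err": res.bse,
    "p_value": res.pvalues,
    "ci_low": ci[0],
    "ci_high": ci[1],
})

out_dir = Path(tempfile.mkdtemp())
table.to_csv(out_dir / "coefficients.csv")
(out_dir / "summary.tex").write_text(res.summary().as_latex())
print(table.round(3))
```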
Single model
```python
# LaTeX format
print(model.summary().as_latex())
# HTML format
print(model.summary().as_html())
# Save to a file
with open('regression_results.tex', 'w') as f:
    f.write(model.summary().as_latex())
```
Comparing multiple models
```python
import statsmodels.formula.api as smf
from statsmodels.iolib.summary2 import summary_col
# Fit several nested models (assumes df also has a 0/1 'female' column)
model1 = smf.ols('wage ~ education', data=df).fit()
model2 = smf.ols('wage ~ education + experience', data=df).fit()
model3 = smf.ols('wage ~ education + experience + female', data=df).fit()
# Side-by-side comparison table
results = summary_col(
    [model1, model2, model3],
    model_names=['(1)', '(2)', '(3)'],
    stars=True,
    info_dict={
        'N': lambda x: f"{int(x.nobs):,}",
        'R²': lambda x: f"{x.rsquared:.3f}"
    }
)
print(results)
# Export to LaTeX
print(results.as_latex())
```
Summary
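Core API
| Task | Function | Typical use |
|---|---|---|
| OLS | sm.OLS() | Linear regression |
| Logit | sm.Logit() | Binary dependent variable |
| Probit | sm.Probit() | Binary dependent variable |
| GLM | sm.GLM() | Generalized linear models (e.g., Poisson counts) |
| WLS | sm.WLS() | Weighted least squares |
| ARIMA | ARIMA() (statsmodels.tsa.arima.model) | Time series |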
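Best practices
- Use robust standard errors: cov_type='HC3'
- Check diagnostics: residuals, VIF, heteroskedasticity
- Prototype with the formula interface for fast exploration
- Use the standard interface in production code for precise control
- Save results: .summary().as_latex()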
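The practices above can be bundled into a single step. The sketch below uses a hypothetical helper, fit_and_check, which is not a statsmodels API. It fits a formula with HC3 standard errors and returns the Breusch-Pagan p-value and the largest non-intercept VIF alongside the results. The demo data are simulated with arbitrary coefficients.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from statsmodels.stats.diagnostic import het_breuschpagan
from statsmodels.stats.outliers_influence import variance_inflation_factor


def fit_and_check(formula, data):
    """Hypothetical helper: HC3 fit plus basic diagnostics (not a statsmodels API)."""
    res = smf.ols(formula, data=data).fit(cov_type="HC3")
    exog = res.model.exog
    names = res.model.exog_names
    bp_pvalue = het_breuschpagan(res.resid, exog)[1]
    vifs = {
        name: variance_inflation_factor(exog, i)
        for i, name in enumerate(names)
        if name != "Intercept"          # skip the intercept column
    }
    return res, {"bp_pvalue": bp_pvalue, "max_vif": max(vifs.values())}


# Simulated demo data with independent regressors
rng = np.random.default_rng(15)
n = 300
df = pd.DataFrame({
    "education": rng.normal(14, 2, size=n),
    "experience": rng.uniform(0, 30, size=n),
})
df["wage"] = (1000 + 150 * df["education"] + 40 * df["experience"]
              + rng.normal(0, 300, size=n))

res, checks = fit_and_check("wage ~ education + experience", df)
print(res.summary().tables[1])
print(checks)
```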
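Next steps
Next section: SciPy.stats (quick statistical tests)
References:
- Statsmodels official documentation: https://www.statsmodels.org/
- Seabold, S. & Perktold, J. (2010). "Statsmodels: Econometric and statistical modeling with Python." Proceedings of the 9th Python in Science Conference.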