Skip to content

OOP 在数据科学中的应用

理解常用库的面向对象设计


为什么数据科学库使用 OOP?

Pandas、Scikit-learn、Statsmodels 都是面向对象的,因为:

  • 数据和方法天然绑定
  • 链式调用更流畅
  • 代码更易维护

Pandas 的 OOP 设计

DataFrame 对象

python
import pandas as pd

# DataFrame 是一个类
df = pd.DataFrame({
    'age': [25, 30, 35, 40],
    'income': [50000, 60000, 75000, 80000]
})

# 属性
print(df.shape)      # (4, 2)
print(df.columns)    # Index(['age', 'income'])
print(df.dtypes)     # 数据类型

# 方法(链式调用)
result = (df
    .query('age > 30')           # 筛选
    .assign(log_income=lambda x: np.log(x['income']))  # 新列
    .sort_values('income')       # 排序
    .reset_index(drop=True)      # 重置索引
)

Series 对象

python
# Series 也是对象
ages = pd.Series([25, 30, 35, 40], name='age')

# 方法
print(ages.mean())      # 32.5
print(ages.std())       # 6.45
print(ages.quantile(0.5))  # 32.5

# 链式操作
result = (ages
    .apply(lambda x: x ** 2)  # 平方
    .sort_values(ascending=False)
    .head(3)
)

Scikit-learn 的 OOP 设计

模型对象

python
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler

# 创建对象
model = LinearRegression()
scaler = StandardScaler()

# 训练(fit 方法)
X = [[1], [2], [3], [4], [5]]
y = [2, 4, 6, 8, 10]

model.fit(X, y)

# 预测(predict 方法)
predictions = model.predict([[6], [7]])
print(predictions)  # [12. 14.]

# 访问属性
print(model.coef_)       # 系数
print(model.intercept_)  # 截距

为什么用 OOP?

python
#  如果不用 OOP(假设)
X_scaled = standard_scale(X)
model_params = fit_linear_regression(X_scaled, y)
predictions = predict_linear_regression(model_params, X_test)

#  使用 OOP(实际)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

model = LinearRegression()
model.fit(X_scaled, y)
predictions = model.predict(X_test_scaled)

Statsmodels 的 OOP 设计

python
import statsmodels.formula.api as smf
import pandas as pd

df = pd.DataFrame({
    'income': [50000, 60000, 75000, 80000, 95000],
    'education': [12, 14, 16, 16, 18],
    'age': [25, 30, 35, 40, 45]
})

# 创建模型对象
model = smf.ols('income ~ education + age', data=df)

# 拟合
results = model.fit()

# 访问结果属性和方法
print(results.summary())
print(results.rsquared)
print(results.params)
print(results.pvalues)

创建自己的数据科学类

示例:简单的线性回归类

python
import numpy as np

class SimpleLinearRegression:
    """简单线性回归(教学用)"""

    def __init__(self):
        self.slope = None
        self.intercept = None

    def fit(self, X, y):
        """拟合模型"""
        X = np.array(X)
        y = np.array(y)

        # 计算斜率和截距
        x_mean = X.mean()
        y_mean = y.mean()

        numerator = ((X - x_mean) * (y - y_mean)).sum()
        denominator = ((X - x_mean) ** 2).sum()

        self.slope = numerator / denominator
        self.intercept = y_mean - self.slope * x_mean

        return self  # 返回自己(支持链式调用)

    def predict(self, X):
        """预测"""
        if self.slope is None:
            raise ValueError("模型未训练,请先调用 fit()")

        X = np.array(X)
        return self.slope * X + self.intercept

    def score(self, X, y):
        """计算 R²"""
        y_pred = self.predict(X)
        ss_res = ((y - y_pred) ** 2).sum()
        ss_tot = ((y - y.mean()) ** 2).sum()
        return 1 - (ss_res / ss_tot)

    def __repr__(self):
        if self.slope is None:
            return "SimpleLinearRegression(unfitted)"
        return f"SimpleLinearRegression(slope={self.slope:.2f}, intercept={self.intercept:.2f})"

# 使用
X = [1, 2, 3, 4, 5]
y = [2, 4, 5, 4, 5]

model = SimpleLinearRegression()
model.fit(X, y)
print(model)  # SimpleLinearRegression(slope=0.60, intercept=2.20)

predictions = model.predict([6, 7, 8])
print(predictions)  # [5.8 6.4 7. ]

r2 = model.score(X, y)
print(f"R² = {r2:.3f}")

实战:数据分析流水线类

python
class DataPipeline:
    """数据处理流水线"""

    def __init__(self, df):
        self.df = df.copy()
        self.original_df = df.copy()
        self.steps = []

    def remove_missing(self, subset=None):
        """删除缺失值"""
        self.df = self.df.dropna(subset=subset)
        self.steps.append("remove_missing")
        return self  # 返回自己以支持链式调用

    def filter_age(self, min_age, max_age):
        """筛选年龄"""
        self.df = self.df[(self.df['age'] >= min_age) & (self.df['age'] <= max_age)]
        self.steps.append(f"filter_age({min_age}, {max_age})")
        return self

    def standardize(self, columns):
        """标准化"""
        for col in columns:
            mean = self.df[col].mean()
            std = self.df[col].std()
            self.df[f'{col}_std'] = (self.df[col] - mean) / std
        self.steps.append(f"standardize({columns})")
        return self

    def get_result(self):
        """获取结果"""
        return self.df

    def get_summary(self):
        """处理摘要"""
        print("=== 数据处理流水线 ===")
        print(f"原始数据: {len(self.original_df)} 行")
        print(f"处理后: {len(self.df)} 行")
        print("\n处理步骤:")
        for i, step in enumerate(self.steps, 1):
            print(f"  {i}. {step}")

# 使用
import pandas as pd

df = pd.DataFrame({
    'id': [1, 2, 3, 4, 5, 6],
    'age': [25, 30, None, 40, 15, 50],
    'income': [50000, 60000, 75000, None, 30000, 90000]
})

# 链式调用
pipeline = DataPipeline(df)
result = (pipeline
    .remove_missing()
    .filter_age(18, 65)
    .standardize(['income'])
    .get_result()
)

pipeline.get_summary()
print("\n处理后的数据:")
print(result)

OOP 最佳实践(数据科学)

1. 设计可链式调用的方法

python
class DataCleaner:
    def __init__(self, df):
        self.df = df

    def drop_na(self):
        self.df = self.df.dropna()
        return self  # 返回自己

    def remove_outliers(self, column):
        q1 = self.df[column].quantile(0.25)
        q3 = self.df[column].quantile(0.75)
        iqr = q3 - q1
        self.df = self.df[
            (self.df[column] >= q1 - 1.5 * iqr) &
            (self.df[column] <= q3 + 1.5 * iqr)
        ]
        return self

# 链式调用
cleaner = DataCleaner(df)
result = (cleaner
    .drop_na()
    .remove_outliers('income')
    .df
)

2. 使用属性存储元数据

python
class Model:
    def __init__(self):
        self.is_fitted = False
        self.n_features = None

    def fit(self, X, y):
        self.n_features = X.shape[1]
        self.is_fitted = True
        # 训练逻辑...

3. 实现 __repr__ 便于调试

python
class Survey:
    def __repr__(self):
        return f"Survey(name='{self.name}', n={len(self.responses)})"

总结

社科学生需要理解的 OOP 要点

  1. Pandas 的 DataFrame 是对象

    python
    df.head()  # 方法
    df.shape   # 属性
  2. Scikit-learn 的模型是对象

    python
    model = LinearRegression()
    model.fit(X, y)
    model.predict(X_new)
  3. 你不需要写复杂的类,但要会用

何时自己创建类?

  • 构建可复用的数据流水线
  • 封装复杂的分析逻辑
  • 大型项目需要组织代码

下一步

Module 6 完成!在下一个模块中,我们将学习 文件操作,学会读写 CSV、Excel、Stata 等数据文件。

继续前进!

基于 MIT 许可证发布。内容版权归作者所有。