
Module 6 Summary and Review

Object-Oriented Programming Basics: Understanding Classes and Objects


Key Points

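1. Core OOP Concepts

What is OOP?

  • Object-oriented programming is a paradigm that keeps data together with the methods that operate on that data
  • Object: a bundle of data plus methods
  • Class: the template (blueprint) for creating objects
  • Method: a function that belongs to an object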

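Why OOP?

  • Data and the methods that act on it stay bound together
  • Code is better organized
  • Code is easier to reuse and maintain
  • It maps naturally onto real-world entities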

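Core Terminology

| Term | Definition | Example |
|------|------------|---------|
| Class | Template for objects | `class Student:` |
| Object / instance | A concrete instance of a class | `alice = Student()` |
| Attribute | Data stored on an object | `alice.name = "Alice"` |
| Method | A function belonging to an object | `alice.calculate_gpa()` |
| `self` | Refers to the current object | `self.name` |
| Constructor | Initializes the object | `__init__()` |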

2. Basic Class Structure

```python
class ClassName:
    """Class docstring"""

    # Class attribute (shared by all objects)
    class_variable = "shared"

    def __init__(self, param1, param2):
        """Constructor (initializer)"""
        self.param1 = param1  # instance attribute
        self.param2 = param2

    def instance_method(self):
        """Instance method"""
        return self.param1

    @classmethod
    def class_method(cls):
        """Class method"""
        return cls.class_variable

    @staticmethod
    def static_method():
        """Static method"""
        return "depends on neither the class nor an instance"
```

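Three Kinds of Methods

| Method type | First parameter | Accesses instance attributes | Accesses class attributes | Typical use |
|-------------|-----------------|------------------------------|---------------------------|-------------|
| Instance method | `self` | Yes | Yes | Most common; operates on the object's data |
| Class method | `cls` | No | Yes | Factory methods, alternative constructors |
| Static method | (none) | No | No | Utility functions |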
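The table lists "alternative constructor" as the typical use of a class method, which the structure template above does not show in action. Below is a minimal sketch that puts all three method types on one class. The `Survey.from_string` factory, the `"name, year"` input format, and `is_valid_year` are all hypothetical and exist only for illustration:

```python
class Survey:
    """Minimal survey class illustrating the three method types."""

    total_surveys = 0  # class attribute shared by all instances

    def __init__(self, name, year):
        self.name = name  # instance attributes
        self.year = year
        Survey.total_surveys += 1

    def describe(self):
        """Instance method: reads this object's data through self."""
        return f"{self.name} ({self.year})"

    @classmethod
    def from_string(cls, text):
        """Class method: alternative constructor for a hypothetical 'name, year' format."""
        name, year = text.split(",")
        return cls(name.strip(), int(year))  # int() tolerates surrounding spaces

    @staticmethod
    def is_valid_year(year):
        """Static method: a utility that needs neither self nor cls."""
        return 1900 <= year <= 2100


s = Survey.from_string("Income Survey, 2024")
print(s.describe())                # Income Survey (2024)
print(Survey.is_valid_year(2024))  # True
```

`from_string` calls `cls(...)` rather than `Survey(...)`. This means a subclass that inherits the method gets back an instance of the subclass.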

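3. Instance Attributes vs Class Attributes

```python
class Survey:
    # Class attribute (shared by all objects)
    total_surveys = 0

    def __init__(self, name, year):
        # Instance attributes (unique to each object)
        self.name = name
        self.year = year
        Survey.total_surveys += 1  # modify the class attribute

# Usage
survey1 = Survey("Income Survey", 2024)
survey2 = Survey("Health Survey", 2024)

print(survey1.name)           # Income Survey (instance attribute)
print(Survey.total_surveys)   # 2 (class attribute)
```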

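Differences

  • Instance attributes: each object has its own value; access them through self.attr
  • Class attributes: one value shared by all objects; access them through ClassName.attr. Reading self.attr also works, but assigning to self.attr creates a new instance attribute instead of changing the shared value.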
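The difference comes down to where each attribute is stored. The sketch below uses `vars()`, which returns an object's attribute dictionary. Instance attributes live on the object and class attributes live on the class. When Python looks up an attribute on an instance and does not find it there, it falls back to the class. The `Survey` class here is a trimmed, illustrative version of the one above:

```python
class Survey:
    total_surveys = 0  # class attribute: stored once, on the class

    def __init__(self, name):
        self.name = name  # instance attribute: stored on each object
        Survey.total_surveys += 1


a = Survey("Income")
b = Survey("Health")

print(vars(a))                          # {'name': 'Income'}
print("total_surveys" in vars(Survey))  # True
print(a.total_surveys)                  # 2: not found on a, so Python looks it up on the class

a.total_surveys = 99  # assigning through an instance creates a NEW instance attribute
print(Survey.total_surveys, b.total_surveys, a.total_surveys)  # 2 2 99
```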

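4. Special Methods (Magic Methods)

| Method | Purpose | Triggered by |
|--------|---------|--------------|
| `__init__()` | Constructor (initializer) | `obj = Class()` |
| `__str__()` | User-friendly string representation | `print(obj)`, `str(obj)` |
| `__repr__()` | Developer-facing representation | `repr(obj)` |
| `__len__()` | Length | `len(obj)` |
| `__getitem__()` | Index or key access | `obj[key]` |
| `__eq__()` | Equality comparison | `obj1 == obj2` |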

Example

```python
class Survey:
    def __init__(self, name):
        self.name = name
        self.responses = []

    def __str__(self):
        return f"Survey: {self.name} ({len(self.responses)} responses)"

    def __len__(self):
        return len(self.responses)

    def __getitem__(self, index):
        return self.responses[index]

# Usage
survey = Survey("Test")
survey.responses = [1, 2, 3]

print(survey)         # Survey: Test (3 responses)
print(len(survey))    # 3
print(survey[0])      # 1
```
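The example above does not exercise `__eq__`. Without it, `==` falls back to identity, so two objects holding identical data compare unequal. Below is a minimal sketch with a hypothetical `Response` class, using an illustrative rule that treats two responses as equal when both id and answer match. A class that defines `__eq__` but not `__hash__` becomes unhashable, so the sketch defines both:

```python
class Response:
    """A single survey response; equal when id and answer match (illustrative rule)."""

    def __init__(self, resp_id, answer):
        self.resp_id = resp_id
        self.answer = answer

    def __eq__(self, other):
        # Only compare against other Response objects; otherwise let Python decide
        if not isinstance(other, Response):
            return NotImplemented
        return (self.resp_id, self.answer) == (other.resp_id, other.answer)

    def __hash__(self):
        # Defining __eq__ sets __hash__ to None; restore it consistently with __eq__
        return hash((self.resp_id, self.answer))


r1 = Response(1, "yes")
r2 = Response(1, "yes")
print(r1 == r2)       # True (uses __eq__)
print(r1 is r2)       # False (still two distinct objects)
print(len({r1, r2}))  # 1 (equal objects collapse in a set)
print(r1 == "yes")    # False
```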

5. Encapsulation: Public vs Private

```python
class BankAccount:
    def __init__(self, balance):
        self.balance = balance       # public attribute
        self._transactions = []      # private by convention (single underscore)
        self.__pin = "1234"          # name-mangled (double underscore)

    def deposit(self, amount):
        """Public method"""
        self.balance += amount
        self._log_transaction("deposit", amount)

    def _log_transaction(self, kind, amount):
        """Private method (by convention)"""
        self._transactions.append({'type': kind, 'amount': amount})
```

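Naming conventions

  • name: public (access it directly)
  • _name: private by convention (outside access is discouraged but still possible)
  • __name: name-mangled (Python renames it to _ClassName__name). This makes accidental outside access harder, but it does not make the attribute truly private.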
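Name mangling is a rename, not access control. Below is a short sketch based on the `BankAccount` example above, with a hypothetical `check_pin` method added to show that code inside the class still uses the short name:

```python
class BankAccount:
    def __init__(self, balance):
        self.balance = balance   # public
        self._transactions = []  # private by convention
        self.__pin = "1234"      # stored as _BankAccount__pin

    def check_pin(self, pin):
        # Inside the class body, self.__pin is rewritten to self._BankAccount__pin
        return pin == self.__pin


acct = BankAccount(100)
print(acct._transactions)      # [] (reachable; the underscore is only a convention)
print(hasattr(acct, "__pin"))  # False: nothing is stored under the written name
print(acct._BankAccount__pin)  # 1234: the mangled name is still reachable
print(acct.check_pin("1234"))  # True
```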

6. OOP in Data Science

Pandas DataFrame

```python
import pandas as pd

df = pd.DataFrame({'age': [25, 30, 35]})

# Attributes
df.shape      # (3, 1)
df.columns    # Index(['age'])

# Methods
df.head()
df.mean()
df.to_csv('output.csv')

# Method chaining
result = (df
    .query('age > 25')
    .assign(age_squared=lambda x: x['age']**2)
    .sort_values('age')
)
```

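Scikit-learn Models

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Sample data: X must be 2-D, shape (n_samples, n_features)
X = np.array([[1.0], [2.0], [3.0], [4.0]])
y = np.array([3.0, 5.0, 7.0, 9.0])
X_new = np.array([[5.0], [6.0]])

model = LinearRegression()          # create an object
model.fit(X, y)                     # train (method)
predictions = model.predict(X_new)  # predict (method)

# Access attributes learned during fit
print(model.coef_)       # coefficients
print(model.intercept_)  # intercept
```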

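Python vs Stata vs R

Comparing object orientation

Python (fully object-oriented):

```python
import pandas as pd

df = pd.DataFrame({'x': [1, 2, 3]})
df.mean()           # method call
df.shape            # attribute access
```

R (partly object-oriented; everyday code mainly uses function calls):

```r
df <- data.frame(x = c(1, 2, 3))
mean(df$x)          # function call
dim(df)             # function call
```

Stata (procedural):

```stata
* Stata is mainly command-based
summarize income
generate log_income = log(income)
regress y x1 x2
```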

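Common Mistakes

1. Forgetting the self parameter

```python
# Wrong
class Student:
    def __init__(name, age):  # forgot self: the new object gets bound to `name`
        name = name  # only rebinds a local variable; nothing is saved on the object

# Student("Alice", 20) raises TypeError (one positional argument too many)

# Right
class Student:
    def __init__(self, name, age):
        self.name = name
        self.age = age
```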

2. Confusing instance attributes with class attributes

```python
# Wrong
class Counter:
    count = 0  # class attribute

    def increment(self):
        count += 1  # UnboundLocalError (a NameError subclass): count is treated as a local variable

# Right
class Counter:
    count = 0

    def increment(self):
        Counter.count += 1  # or type(self).count += 1
```
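The "wrong" version above fails at call time, not at definition time. A third variant, `self.count += 1`, is easy to reach for and fails silently: it reads the class value and then writes a new instance attribute. The sketch below runs all three. The class names are hypothetical. Note that `type(self).count` targets a subclass's own attribute when called on a subclass instance, whereas `Counter.count` always targets `Counter`:

```python
class BrokenCounter:
    count = 0

    def increment(self):
        count += 1  # `count` is treated as a local variable -> UnboundLocalError


class ShadowingCounter:
    count = 0

    def increment(self):
        self.count += 1  # reads the class value, then writes an INSTANCE attribute


class Counter:
    count = 0

    def increment(self):
        type(self).count += 1  # updates the attribute on the class itself


caught = None
try:
    BrokenCounter().increment()
except UnboundLocalError as exc:  # UnboundLocalError is a subclass of NameError
    caught = type(exc).__name__
print(caught)  # UnboundLocalError

a, b = ShadowingCounter(), ShadowingCounter()
a.increment()
print(a.count, b.count, ShadowingCounter.count)  # 1 0 0

c, d = Counter(), Counter()
c.increment()
d.increment()
print(Counter.count)  # 2
```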

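3. Mutating a class attribute causes unintended sharing

```python
# Wrong
class Survey:
    responses = []  # class attribute!

    def add_response(self, resp):
        self.responses.append(resp)  # every object appends to the same list

# Right
class Survey:
    def __init__(self):
        self.responses = []  # instance attribute
```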
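The bug is easy to miss because nothing fails; data just leaks between objects. Below is a sketch that shows the leak directly, with the "wrong" class renamed to a hypothetical `SharedSurvey` so both versions can coexist:

```python
class SharedSurvey:
    responses = []  # class attribute: ONE list shared by every instance

    def add_response(self, resp):
        self.responses.append(resp)  # mutates the shared list


class Survey:
    def __init__(self):
        self.responses = []  # instance attribute: a fresh list per object

    def add_response(self, resp):
        self.responses.append(resp)


s1, s2 = SharedSurvey(), SharedSurvey()
s1.add_response("yes")
print(s2.responses)                  # ['yes'] (s2 sees s1's data)
print(s1.responses is s2.responses)  # True

t1, t2 = Survey(), Survey()
t1.add_response("yes")
print(t2.responses)                  # []
```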

4. Forgetting to implement __str__, which makes output unfriendly

```python
# Not good
class Student:
    def __init__(self, name):
        self.name = name

s = Student("Alice")
print(s)  # <__main__.Student object at 0x...>

# Good
class Student:
    def __init__(self, name):
        self.name = name

    def __str__(self):
        return f"Student(name='{self.name}')"

s = Student("Alice")
print(s)  # Student(name='Alice')
```
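If you define only one of the two methods, `__repr__` is often the more useful choice. The default `object.__str__` falls back to `__repr__`. Lists and dicts also display their elements with `repr()`, even when `__str__` exists. A sketch:

```python
class Student:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        # Developer-facing representation; !r adds quotes around strings
        return f"Student(name={self.name!r})"


s = Student("Alice")
print(s)        # Student(name='Alice') (print falls back to __repr__)
print([s])      # [Student(name='Alice')] (containers use __repr__)
print(repr(s))  # Student(name='Alice')
```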

Best Practices

1. Use CapWords for class names

```python
# Good
class StudentRecord:
    pass

class SurveyData:
    pass

# Bad
class student_record:
    pass

class surveydata:
    pass
```

2. Use snake_case for method names

```python
class DataAnalyzer:
    def calculate_mean(self):  # good: snake_case
        pass

    def CalculateMean(self):   # bad: CapWords is reserved for class names
        pass
```

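3. Write docstrings

```python
class Survey:
    """Survey questionnaire class.

    Manages questionnaire data, including adding responses and running statistics.

    Attributes:
        name (str): survey name
        year (int): survey year
        responses (list): list of responses
    """

    def __init__(self, name, year):
        self.name = name
        self.year = year
        self.responses = []
```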
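A docstring is stored on the object at runtime, so tools and interactive sessions can read it. Below is a small sketch using a shortened version of the docstring above. It relies on `__doc__` and the standard-library `inspect.getdoc`, which strips the common indentation:

```python
import inspect


class Survey:
    """Survey questionnaire class.

    Manages survey data, including adding responses and basic statistics.
    """

    def __init__(self, name, year):
        self.name = name
        self.year = year
        self.responses = []


print(Survey.__doc__.splitlines()[0])  # Survey questionnaire class.
print(inspect.getdoc(Survey))          # full docstring with indentation cleaned up
# help(Survey) shows the same text in an interactive session
```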

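4. Support method chaining

```python
class DataPipeline:
    def __init__(self, data):
        self.data = data

    def remove_outliers(self):
        # processing logic...
        return self  # return the object itself

    def standardize(self):
        # processing logic...
        return self

    def filter_missing(self):
        # processing logic...
        return self

# Chained calls
data = [1.0, 2.0, None, 100.0]  # placeholder data
pipeline = (DataPipeline(data)
    .remove_outliers()
    .standardize()
    .filter_missing()
)
```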

Programming Exercises

Exercise 1: Student Grade Management System (Basic)

Difficulty: ⭐⭐ Time: 20 minutes

Create a Student class.

Requirements

```python
class Student:
    """Student class"""

    def __init__(self, student_id, name, major):
        pass

    def add_grade(self, course, grade):
        """Add a grade"""
        pass

    def get_gpa(self):
        """Compute the GPA (scores out of 100, converted to a 4.0 scale)"""
        pass

    def __str__(self):
        return f"Student: {self.name} ({self.major}), GPA: {self.get_gpa():.2f}"

# Test
alice = Student(2024001, "Alice Wang", "Economics")
alice.add_grade("Microeconomics", 85)
alice.add_grade("Econometrics", 90)
alice.add_grade("Statistics", 78)

print(alice)
print(f"GPA: {alice.get_gpa():.2f}")
```
Reference solution

```python
class Student:
    """Student class"""

    def __init__(self, student_id, name, major):
        self.student_id = student_id
        self.name = name
        self.major = major
        self.grades = {}  # {course: grade}

    def add_grade(self, course, grade):
        """Add a grade"""
        if not (0 <= grade <= 100):
            raise ValueError("Grade must be between 0 and 100")
        self.grades[course] = grade

    def get_gpa(self):
        """Compute the GPA (100-point scale converted to a 4.0 scale)"""
        if not self.grades:
            return 0.0

        # Conversion rule: 90-100=4.0, 80-89=3.0, 70-79=2.0, 60-69=1.0, <60=0.0
        total_points = 0
        for grade in self.grades.values():
            if grade >= 90:
                total_points += 4.0
            elif grade >= 80:
                total_points += 3.0
            elif grade >= 70:
                total_points += 2.0
            elif grade >= 60:
                total_points += 1.0
            else:
                total_points += 0.0

        return total_points / len(self.grades)

    def get_average_score(self):
        """Compute the average score"""
        if not self.grades:
            return 0.0
        return sum(self.grades.values()) / len(self.grades)

    def __str__(self):
        return f"Student: {self.name} ({self.major}), GPA: {self.get_gpa():.2f}"

    def __repr__(self):
        return f"Student(id={self.student_id}, name='{self.name}', courses={len(self.grades)})"


# Test
alice = Student(2024001, "Alice Wang", "Economics")
alice.add_grade("Microeconomics", 85)
alice.add_grade("Econometrics", 90)
alice.add_grade("Statistics", 78)

print(alice)                                              # Student: Alice Wang (Economics), GPA: 3.00
print(f"Average score: {alice.get_average_score():.1f}")  # Average score: 84.3
print(repr(alice))                                        # Student(id=2024001, name='Alice Wang', courses=3)
```
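The `if`/`elif` ladder in `get_gpa` works, but changing the grading bands means editing code. One alternative keeps the cutoffs in data and looks them up with the standard-library `bisect` module. The sketch below is an illustration, not part of the exercise. The helper names `score_to_points` and `gpa` are hypothetical, and the bands match the conversion rule in the solution above:

```python
from bisect import bisect_right

# Lower bound of each band and the points awarded (same rule as get_gpa above)
CUTOFFS = [60, 70, 80, 90]
POINTS = [0.0, 1.0, 2.0, 3.0, 4.0]


def score_to_points(score):
    """Map a 0-100 score to 4.0-scale points: <60 -> 0.0, 60-69 -> 1.0, ..., 90+ -> 4.0."""
    return POINTS[bisect_right(CUTOFFS, score)]


def gpa(grades):
    """Average the points over a {course: score} dict; 0.0 when empty."""
    if not grades:
        return 0.0
    return sum(score_to_points(g) for g in grades.values()) / len(grades)


print(score_to_points(85))  # 3.0
print(gpa({"Microeconomics": 85, "Econometrics": 90, "Statistics": 78}))  # 3.0
```

`bisect_right` returns the number of cutoffs less than or equal to the score, so a score exactly on a boundary (such as 90) falls into the higher band.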

Exercise 2: Survey Data Container (Basic)

Difficulty: ⭐⭐ Time: 25 minutes

```python
class SurveyData:
    """Survey data manager"""

    def __init__(self, survey_name):
        pass

    def add_response(self, response):
        """Add a response"""
        pass

    def _validate(self, response):
        """Private method: validate the data"""
        pass

    def get_average_income(self):
        """Compute the average income"""
        pass

    def filter_by_age(self, min_age, max_age):
        """Filter by age range"""
        pass

    def __len__(self):
        return len(self.responses)

    def __str__(self):
        return f"{self.survey_name}: {len(self)} responses"

# Test
survey = SurveyData("2024 Income Survey")
survey.add_response({'id': 1, 'age': 30, 'income': 75000})
survey.add_response({'id': 2, 'age': 35, 'income': 85000})

print(survey)
print(f"Average income: ${survey.get_average_income():,.0f}")
```
Reference solution

```python
class SurveyData:
    """Survey data manager"""

    def __init__(self, survey_name):
        self.survey_name = survey_name
        self.responses = []

    def add_response(self, response):
        """Add a response"""
        if self._validate(response):
            self.responses.append(response)
            return True
        else:
            print(f"Warning: invalid data: {response}")
            return False

    def _validate(self, response):
        """Private method: validate the data"""
        required_fields = ['id', 'age', 'income']

        # Check required fields
        if not all(field in response for field in required_fields):
            return False

        # Validate age
        if not (0 < response['age'] < 120):
            return False

        # Validate income
        if response['income'] < 0:
            return False

        return True

    def get_average_income(self):
        """Compute the average income"""
        if not self.responses:
            return 0
        incomes = [r['income'] for r in self.responses]
        return sum(incomes) / len(incomes)

    def filter_by_age(self, min_age, max_age):
        """Filter by age range"""
        return [r for r in self.responses
                if min_age <= r['age'] <= max_age]

    def get_income_stats(self):
        """Income statistics"""
        if not self.responses:
            return {}

        incomes = [r['income'] for r in self.responses]
        return {
            'mean': sum(incomes) / len(incomes),
            'min': min(incomes),
            'max': max(incomes),
            'count': len(incomes)
        }

    def __len__(self):
        return len(self.responses)

    def __str__(self):
        return f"{self.survey_name}: {len(self)} responses"

    def __getitem__(self, index):
        """Support index access"""
        return self.responses[index]


# Test
survey = SurveyData("2024 Income Survey")

# Add valid data
survey.add_response({'id': 1, 'age': 30, 'income': 75000})
survey.add_response({'id': 2, 'age': 35, 'income': 85000})
survey.add_response({'id': 3, 'age': 45, 'income': 95000})

# Add invalid data (will be rejected)
survey.add_response({'id': 4, 'age': -5, 'income': 50000})  # invalid age
survey.add_response({'id': 5, 'age': 28})  # missing income field

print(survey)  # 2024 Income Survey: 3 responses
print(f"Average income: ${survey.get_average_income():,.0f}")
print(f"Ages 30-40: {len(survey.filter_by_age(30, 40))} people")
print(f"First record: {survey[0]}")

stats = survey.get_income_stats()
print("\nIncome statistics:")
print(f"  Sample size: {stats['count']}")
print(f"  Mean: ${stats['mean']:,.0f}")
print(f"  Range: ${stats['min']:,} - ${stats['max']:,}")
```
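`get_income_stats` computes the mean, minimum, and maximum by hand. The standard-library `statistics` module also provides the median and standard deviation. The sketch below shows how the same summary could be extended. The extra `median` and `stdev` keys are illustrative additions, not part of the exercise. The input values are the three valid incomes from the test above:

```python
import statistics

incomes = [75000, 85000, 95000]  # the three valid responses from the test above

stats = {
    "mean": statistics.mean(incomes),
    "median": statistics.median(incomes),
    "stdev": statistics.stdev(incomes),  # sample standard deviation (divides by n - 1)
    "min": min(incomes),
    "max": max(incomes),
    "count": len(incomes),
}
print(stats)
```

`statistics.stdev` raises `StatisticsError` when given fewer than two values, so a method built on it would need the same empty-data guard the solution already uses, plus a check for a single response.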

Exercise 3: Data Analysis Pipeline (Intermediate)

Difficulty: ⭐⭐⭐ Time: 35 minutes

Create a data processing pipeline that supports method chaining.

```python
class DataPipeline:
    """Data processing pipeline"""

    def __init__(self, data):
        pass

    def filter_by(self, condition):
        """Filter by a condition; accepts a lambda"""
        pass

    def transform(self, func):
        """Transform the data"""
        pass

    def group_by(self, key):
        """Group the data"""
        pass

    def get_result(self):
        """Return the result"""
        pass

    def summary(self):
        """Print a processing summary"""
        pass

# Test
data = [
    {'id': 1, 'age': 25, 'income': 50000, 'city': 'Beijing'},
    {'id': 2, 'age': 35, 'income': 80000, 'city': 'Shanghai'},
    # ...
]

result = (DataPipeline(data)
    .filter_by(lambda x: x['age'] >= 30)
    .transform(lambda x: {**x, 'income_10k': x['income'] / 10000})
    .get_result()
)
```
Reference solution

```python
class DataPipeline:
    """Data processing pipeline"""

    def __init__(self, data):
        self.original_data = data.copy()
        self.data = data.copy()
        self.steps = []

    def filter_by(self, condition):
        """Filter by a condition"""
        self.data = [item for item in self.data if condition(item)]
        self.steps.append(f"filter_by (kept {len(self.data)} records)")
        return self

    def transform(self, func):
        """Transform each record"""
        self.data = [func(item) for item in self.data]
        self.steps.append("transform")
        return self

    def remove_field(self, *fields):
        """Remove fields"""
        self.data = [{k: v for k, v in item.items() if k not in fields}
                     for item in self.data]
        self.steps.append(f"remove_field({', '.join(fields)})")
        return self

    def add_field(self, field_name, func):
        """Add a new field"""
        # Build new dicts instead of assigning into the existing ones: data.copy()
        # is a shallow copy, so in-place edits would also change the caller's data
        self.data = [{**item, field_name: func(item)} for item in self.data]
        self.steps.append(f"add_field('{field_name}')")
        return self

    def sort_by(self, key, reverse=False):
        """Sort"""
        self.data = sorted(self.data, key=key, reverse=reverse)
        self.steps.append(f"sort_by (reverse={reverse})")
        return self

    def limit(self, n):
        """Keep only the first n records"""
        self.data = self.data[:n]
        self.steps.append(f"limit({n})")
        return self

    def group_by(self, key_func):
        """Group the data"""
        groups = {}
        for item in self.data:
            group_key = key_func(item)
            if group_key not in groups:
                groups[group_key] = []
            groups[group_key].append(item)

        # Convert to the grouped-result format
        self.data = [
            {'group': k, 'items': v, 'count': len(v)}
            for k, v in groups.items()
        ]
        self.steps.append(f"group_by ({len(self.data)} groups)")
        return self

    def get_result(self):
        """Return the result"""
        return self.data

    def summary(self):
        """Print a processing summary"""
        print("=" * 50)
        print("Pipeline summary")
        print("=" * 50)
        print(f"Original data: {len(self.original_data)} records")
        print(f"After processing: {len(self.data)} records")
        print("\nSteps:")
        for i, step in enumerate(self.steps, 1):
            print(f"  {i}. {step}")
        print("=" * 50)

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        return f"DataPipeline(records={len(self.data)}, steps={len(self.steps)})"


# Test
data = [
    {'id': 1, 'age': 25, 'income': 50000, 'city': 'Beijing', 'gender': 'F'},
    {'id': 2, 'age': 35, 'income': 80000, 'city': 'Shanghai', 'gender': 'M'},
    {'id': 3, 'age': 45, 'income': 120000, 'city': 'Beijing', 'gender': 'F'},
    {'id': 4, 'age': 28, 'income': 65000, 'city': 'Guangzhou', 'gender': 'M'},
    {'id': 5, 'age': 32, 'income': 95000, 'city': 'Shanghai', 'gender': 'F'},
    {'id': 6, 'age': 40, 'income': 110000, 'city': 'Beijing', 'gender': 'M'},
]

# Example 1: basic pipeline
print("Example 1: keep ages 30 and over, express income in units of 10,000")
result1 = (DataPipeline(data)
    .filter_by(lambda x: x['age'] >= 30)
    .add_field('income_10k', lambda x: round(x['income'] / 10000, 2))
    .remove_field('gender')
    .sort_by(lambda x: x['income'], reverse=True)
    .get_result()
)

for r in result1:
    print(f"  ID{r['id']}: age {r['age']}, {r['city']}, income {r['income_10k']} (10k)")

# Example 2: grouped statistics
print("\nExample 2: group by city")
pipeline2 = DataPipeline(data)
result2 = (pipeline2
    .filter_by(lambda x: x['age'] >= 25)
    .group_by(lambda x: x['city'])
    .get_result()
)

for group in result2:
    avg_income = sum(item['income'] for item in group['items']) / len(group['items'])
    print(f"  {group['group']:12s}: {group['count']} people, average income ${avg_income:,.0f}")

pipeline2.summary()

# Example 3: top N
print("\nExample 3: the 3 highest earners")
result3 = (DataPipeline(data)
    .sort_by(lambda x: x['income'], reverse=True)
    .limit(3)
    .get_result()
)

for i, r in enumerate(result3, 1):
    print(f"  {i}. ID{r['id']}: age {r['age']}, ${r['income']:,}")
```
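The `if group_key not in groups` check in `group_by` can be avoided with `collections.defaultdict`, which creates an empty list the first time a key is seen. The sketch below is a standalone function, hypothetically named `group_records`, that returns the same shape as `DataPipeline.group_by`. Because `defaultdict` is a `dict` subclass, groups keep first-seen order:

```python
from collections import defaultdict


def group_records(records, key_func):
    """Group dicts by key_func(record); same output shape as DataPipeline.group_by."""
    groups = defaultdict(list)  # a missing key starts as an empty list
    for item in records:
        groups[key_func(item)].append(item)
    return [{"group": k, "items": v, "count": len(v)} for k, v in groups.items()]


data = [
    {"id": 1, "city": "Beijing"},
    {"id": 2, "city": "Shanghai"},
    {"id": 3, "city": "Beijing"},
]
for g in group_records(data, lambda r: r["city"]):
    print(g["group"], g["count"])
# Beijing 2
# Shanghai 1
```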

Exercise 4: Simple Linear Regression Class (Advanced)

Difficulty: ⭐⭐⭐⭐ Time: 40 minutes

Implement a simple linear regression class modeled on scikit-learn's API design.

```python
class SimpleLinearRegression:
    """Simple linear regression"""

    def __init__(self):
        pass

    def fit(self, X, y):
        """Fit the model"""
        pass

    def predict(self, X):
        """Predict"""
        pass

    def score(self, X, y):
        """Compute R²"""
        pass

    def __repr__(self):
        pass

# Test
X = [1, 2, 3, 4, 5]
y = [2, 4, 5, 4, 5]

model = SimpleLinearRegression()
model.fit(X, y)
print(model)  # shows the slope and intercept

predictions = model.predict([6, 7, 8])
print(f"Predictions: {predictions}")

r2 = model.score(X, y)
print(f"R² = {r2:.3f}")
```
Reference solution

```python
import numpy as np

class SimpleLinearRegression:
    """Simple linear regression (y = slope * x + intercept)"""

    def __init__(self):
        self.slope = None
        self.intercept = None
        self.is_fitted = False

    def fit(self, X, y):
        """Fit the model

        Parameters:
            X: independent variable (1-D array)
            y: dependent variable (1-D array)

        Returns:
            self (enables method chaining)
        """
        X = np.array(X)
        y = np.array(y)

        if len(X) != len(y):
            raise ValueError("X and y must have the same length")

        # Compute the slope and intercept
        x_mean = X.mean()
        y_mean = y.mean()

        # slope = Σ((x - x̄)(y - ȳ)) / Σ((x - x̄)²)
        numerator = ((X - x_mean) * (y - y_mean)).sum()
        denominator = ((X - x_mean) ** 2).sum()

        if denominator == 0:
            raise ValueError("X has zero variance; cannot fit")

        self.slope = numerator / denominator
        self.intercept = y_mean - self.slope * x_mean
        self.is_fitted = True

        return self  # enables method chaining

    def predict(self, X):
        """Predict

        Parameters:
            X: independent variable

        Returns:
            array of predicted values
        """
        if not self.is_fitted:
            raise ValueError("Model is not fitted; call fit() first")

        X = np.array(X)
        return self.slope * X + self.intercept

    def score(self, X, y):
        """Compute R² (coefficient of determination)

        R² = 1 - (SS_res / SS_tot)

        Parameters:
            X: independent variable
            y: true values

        Returns:
            R² value. On the training data of a least-squares fit it lies
            between 0 and 1 (closer to 1 is better); on new data it can be negative.
        """
        y = np.array(y)
        y_pred = self.predict(X)

        # Residual sum of squares
        ss_res = ((y - y_pred) ** 2).sum()

        # Total sum of squares
        ss_tot = ((y - y.mean()) ** 2).sum()

        if ss_tot == 0:
            return 0.0  # y is constant, so R² is undefined; return 0.0 by convention here

        return 1 - (ss_res / ss_tot)

    def get_residuals(self, X, y):
        """Compute the residuals"""
        y_pred = self.predict(X)
        return np.array(y) - y_pred

    def _format_equation(self):
        """Format the fitted line, e.g. 'y = 0.6000x + 2.2000' or 'y = 7000.0000x - 51000.0000'"""
        sign = "+" if self.intercept >= 0 else "-"
        return f"y = {self.slope:.4f}x {sign} {abs(self.intercept):.4f}"

    def summary(self):
        """Print a model summary"""
        if not self.is_fitted:
            print("Model is not fitted")
            return

        print("=" * 50)
        print("Simple linear regression model summary")
        print("=" * 50)
        print(f"Slope:     {self.slope:.4f}")
        print(f"Intercept: {self.intercept:.4f}")
        print(f"Equation:  {self._format_equation()}")
        print("=" * 50)

    def __repr__(self):
        if not self.is_fitted:
            return "SimpleLinearRegression(unfitted)"
        return f"SimpleLinearRegression(slope={self.slope:.4f}, intercept={self.intercept:.4f})"

    def __str__(self):
        if not self.is_fitted:
            return "Unfitted model"
        return self._format_equation()


# Test
print("=" * 60)
print("Simple linear regression tests")
print("=" * 60)

# Data 1: perfect linear relationship
print("\nTest 1: perfect linear relationship (y = 2x)")
X1 = [1, 2, 3, 4, 5]
y1 = [2, 4, 6, 8, 10]

model1 = SimpleLinearRegression()
model1.fit(X1, y1)
print(model1)
model1.summary()

predictions1 = model1.predict([6, 7, 8])
print(f"Predictions for x=[6,7,8]: {predictions1}")
print(f"R² = {model1.score(X1, y1):.4f}")

# Data 2: noisy linear relationship
print("\nTest 2: noisy linear relationship")
X2 = [1, 2, 3, 4, 5]
y2 = [2, 4, 5, 4, 5]

model2 = SimpleLinearRegression()
model2.fit(X2, y2)
print(model2)

predictions2 = model2.predict([6, 7, 8])
print(f"Predictions for x=[6,7,8]: {predictions2}")
print(f"R² = {model2.score(X2, y2):.4f}")

# Residual analysis
residuals = model2.get_residuals(X2, y2)
print(f"Residuals: {residuals}")

# Data 3: income and years of education
print("\nTest 3: income vs years of education")
education_years = [12, 14, 16, 18, 20]  # years of education
income = [35000, 45000, 60000, 75000, 90000]  # income

model3 = SimpleLinearRegression()
model3.fit(education_years, income)
model3.summary()

# Predict: bachelor's (16 years), master's (18 years), and PhD (20 years)
predictions3 = model3.predict([16, 18, 20])
print("\nPredicted income:")
print(f"  Bachelor's (16 years): ${predictions3[0]:,.0f}")
print(f"  Master's (18 years):   ${predictions3[1]:,.0f}")
print(f"  PhD (20 years):        ${predictions3[2]:,.0f}")
print(f"\nR² = {model3.score(education_years, income):.4f}")

print("\n" + "=" * 60)
```

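Next Steps

After completing this chapter, you have learned:

  • Core OOP concepts (classes, objects, methods, attributes)
  • Special methods (__init__, __str__, __len__, etc.)
  • Encapsulation (public/private)
  • Applications of OOP in data science

Congratulations on completing Module 6!

In Module 7, we will cover file operations, including reading and writing CSV, Excel, and Stata files.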


Further Reading

Ready to learn file operations?

Released under the MIT License. Content copyright belongs to the author.