
Training

Network structure

Base class

struct net {
    // human-readable description of the network structure
    virtual string shape() = 0;
    // switch every layer between training and inference behavior
    virtual void set_train_mode(const bool &) = 0;
    // run the network on a batch of inputs and return the outputs
    virtual vec_batch forward(const vec_batch &) = 0;
    // propagate the loss gradient back through the network
    virtual vec_batch backward(const vec_batch &) = 0;
    // perform one optimization step on an (input, label) batch
    virtual void upd(optimizer &, const batch &) = 0;
    // serialize / deserialize the parameters to and from a binary file
    virtual void writef(const string &f) = 0;
    virtual void readf(const string &f) = 0;
    // output of the most recent forward pass
    virtual vec_batch &out() = 0;
    // nets are handled polymorphically, so the destructor must be virtual
    virtual ~net() = default;
};
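
Everything below manipulates the network only through this interface, so any concrete subclass can be swapped in. A minimal sketch of inference-time use through a net & ("model.bin" is only an example path):

// a minimal sketch of driving any concrete net through the base interface;
// "model.bin" is only an example path
void evaluate_and_save(net &n, const vec_batch &x) {
    n.set_train_mode(0);         // switch the layers to inference behavior
    vec_batch y = n.forward(x);  // the same result stays accessible via n.out()
    n.writef("model.bin");       // dump the parameters to a binary file
}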

sequential

struct sequential : public net {
    int batch_sz;
    vector<layerp> layers;
    // the network always starts with an identity "input" layer
    sequential() : batch_sz(0) {
        layerp x = make_shared<same>();
        x->name = "input";
        layers.emplace_back(x);
    }
    void add(const layerp &x) { layers.push_back(x); }
    // one layer name per line
    string shape() override {
        string res = "";
        for (auto &it : layers) res += it->name + "\n";
        return res;
    }
    void set_train_mode(const bool &new_train_mode) override {
        for (auto &l : layers) l->set_train_mode(new_train_mode);
    }
    vec_batch forward(const vec_batch &input) override {
        // reallocate the per-layer buffers whenever the batch size changes
        if ((int)input.size() != batch_sz) {
            batch_sz = input.size();
            for (auto &l : layers) l->resize(batch_sz);
        }
        // feed the batch through the layers in order
        int layer_sz = layers.size();
        layers[0]->forward(input);
        for (int i = 1; i < layer_sz; i++) layers[i]->forward(layers[i - 1]->out);
        return layers.back()->out;
    }
    vec_batch backward(const vec_batch &label) override {
        // seed the last layer's gradient with dL/dout = out - label,
        // the gradient of e.g. MSE or softmax cross-entropy loss
        for (int i = 0; i < batch_sz; i++) layers.back()->grad[i] = layers.back()->out[i] - label[i];
        int layer_sz = layers.size();
        for (int i = layer_sz - 2; i >= 0; i--)
            layers[i]->backward(i ? layers[i - 1]->out : vec_batch(), layers[i + 1]->grad);
        return layers[0]->grad;
    }
    // one full training step: clear gradients, forward, backward, optimizer update
    void upd(optimizer &opt, const batch &data) override {
        int layer_sz = layers.size();
        for (int i = 0; i < layer_sz; i++) layers[i]->clear_grad();
        forward(data.first);
        backward(data.second);
        for (int i = 0; i < layer_sz; i++) layers[i]->upd(opt);
    }
    void writef(const string &f) override {
        ofstream fout(f, ios::binary | ios::out);
        int layer_sz = layers.size();
        for (int i = 0; i < layer_sz; i++) layers[i]->write(fout);
        fout.close();
    }
    void readf(const string &f) override {
        ifstream fin(f, ios::binary | ios::in);
        int layer_sz = layers.size();
        for (int i = 0; i < layer_sz; i++) layers[i]->read(fin);
        fin.close();
    }
    vec_batch &out() override { return layers.back()->out; }
};
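
As an illustration, assembling a small network only takes a few add calls. The layer types linear and sigmoid below are assumptions standing in for the layers defined elsewhere in the article:

// a sketch of assembling a network; linear(in, out) and sigmoid() are
// hypothetical layer constructors standing in for the article's real layers
sequential make_demo_net() {
    sequential net;
    net.add(make_shared<linear>(2, 16));
    net.add(make_shared<sigmoid>());
    net.add(make_shared<linear>(16, 1));
    cerr << net.shape();  // one name per line, starting with "input"
    return net;
}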

Dataset object

struct data_set {
    batch train, valid;
    data_set() {}
    // randomly route each sample: ri(0, 6) == 0, i.e. roughly 1/7, goes to valid
    data_set(const batch &all_data) {
        for (int i = 0; i < (int)all_data.first.size(); i++) {
            int rnd = ri(0, 6);
            if (rnd == 0) {
                valid.first.push_back(all_data.first[i]);
                valid.second.push_back(all_data.second[i]);
            } else {
                train.first.push_back(all_data.first[i]);
                train.second.push_back(all_data.second[i]);
            }
        }
    }
    // draw a training batch uniformly at random, with replacement
    batch get_train_batch(int batch_sz) const {
        assert(train.first.size());
        batch res;
        for (int i = 0; i < batch_sz; i++) {
            int id = ri(0, train.first.size() - 1);
            res.first.push_back(train.first[id]);
            res.second.push_back(train.second[id]);
        }
        return res;
    }
    // same, but drawn from the validation set
    batch get_valid_batch(int batch_sz) const {
        assert(valid.first.size());
        batch res;
        for (int i = 0; i < batch_sz; i++) {
            int id = ri(0, valid.first.size() - 1);
            res.first.push_back(valid.first[id]);
            res.second.push_back(valid.second[id]);
        }
        return res;
    }
};

The constructor randomly places roughly \(\frac 17\) of the data into the validation set. When the dataset is small, this can leave valid or train empty; in that case you can also write your own dataset object (see the sketch below).

Batches are likewise drawn from the dataset at random (with replacement); if this mechanism does not suit you, it can be rewritten as well.
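
For example, a deterministic split only needs a different constructor; this sketch assumes nothing beyond the data_set definition above:

// a sketch of a deterministic split: every 7th sample goes to valid, so
// neither part can be empty once there are at least 7 samples
struct fixed_split_set : data_set {
    fixed_split_set(const batch &all_data) {
        for (int i = 0; i < (int)all_data.first.size(); i++) {
            batch &dst = (i % 7 == 0) ? valid : train;
            dst.first.push_back(all_data.first[i]);
            dst.second.push_back(all_data.second[i]);
        }
    }
};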

The default training function

void upd(optimizer &opt, const data_set &data, net &net, int batch_sz, int epoch,
         function<float(const vec_batch &, const vec_batch &)> err_func, const string &save_file = "") {
    clock_t t0 = clock();
    float tloss = 0, mult = 1, mn = INF;
    for (int i = 1; i <= epoch; i++) {
        // fetch one batch of training data and update the parameters
        auto tmp = data.get_train_batch(batch_sz);
        net.upd(opt, tmp);
        // tloss is an exponential moving average of the training loss;
        // dividing by (1 - mult) below corrects its startup bias
        mult *= 0.9;
        tloss = tloss * 0.9 + err_func(net.out(), tmp.second) * 0.1;
        if (i % 50 == 0) {
            cerr << "-------------------------" << endl;
            cerr << "Time elapsed: " << (float)(clock() - t0) / CLOCKS_PER_SEC << endl;
            cerr << "Epoch: " << i << endl;
            cerr << "Loss: " << tloss / (1. - mult) << endl;
            if (i % 1000 == 0) {
                // evaluate on the whole validation set in inference mode
                net.set_train_mode(0);
                float sum = 0;
                for (int j = 0; j < (int)data.valid.first.size(); j++) {
                    batch tmp = {{data.valid.first[j]}, {data.valid.second[j]}};
                    sum += err_func(net.forward(tmp.first), tmp.second);
                }
                net.set_train_mode(1);
                sum /= data.valid.first.size();
                cerr << "!! Error: " << sum << endl;
                // checkpoint whenever the validation error reaches a new minimum
                if (sum < mn && save_file != "") {
                    cerr << "Saved" << endl;
                    mn = sum;
                    net.writef(save_file);
                }
            }
        }
    }
}

The function tracks an exponential moving average of the training loss (divided by \(1 - 0.9^i\) to correct the startup bias) and saves the network parameters whenever the average validation error reaches a new minimum.
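
A complete training run might then be wired up as follows; sgd, mse_err, and load_data are hypothetical stand-ins for an optimizer, an error function, and a data loader (only upd, data_set, and sequential come from this section):

// a sketch of a full training run; sgd, mse_err and load_data are
// hypothetical placeholders for pieces provided elsewhere
int main() {
    data_set data(load_data());        // load_data() returns a `batch`
    sequential net = make_demo_net();  // the sketch from the sequential section
    sgd opt(0.01);                     // hypothetical optimizer, learning rate 0.01
    net.set_train_mode(1);
    // train for 100000 iterations with batches of 32, checkpointing to model.bin
    upd(opt, data, net, 32, 100000, mse_err, "model.bin");
    return 0;
}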