Skip to content

Commit dadd48a

Browse files
authored
Merge pull request #963 from reyoung/feature/add_const_in_parameter_updater
Add const in ParameterUpdater init
2 parents 2965df5 + 0d1703d commit dadd48a

File tree

8 files changed

+16
-14
lines changed

8 files changed

+16
-14
lines changed

paddle/parameter/ParameterUpdaterBase.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ limitations under the License. */
1919

2020
namespace paddle {
2121

22-
void ParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
22+
void ParameterUpdater::init(const std::vector<ParameterPtr>& parameters) {
2323
parameters_ = parameters;
2424
for (ParameterType type : getParameterTypes()) {
2525
for (auto& para : parameters) {

paddle/parameter/ParameterUpdaterBase.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class ParameterUpdater {
3232
parameterTypes_.push_back(type);
3333
}
3434

35-
virtual void init(std::vector<ParameterPtr>& parameters);
35+
virtual void init(const std::vector<ParameterPtr>& parameters);
3636

3737
// called by Trainer when starting a new pass
3838
virtual void startPass() {}
@@ -105,7 +105,7 @@ class ParameterUpdaterComposite : public ParameterUpdater {
105105
ParameterUpdaterComposite() {}
106106
virtual ~ParameterUpdaterComposite() {}
107107

108-
virtual void init(std::vector<ParameterPtr>& parameters) = 0;
108+
virtual void init(const std::vector<ParameterPtr>& parameters) = 0;
109109

110110
virtual void startPass() {
111111
syncThreadPool_->execPlusOwner(

paddle/trainer/ParameterUpdater.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ SgdUpdaterWithCpuAverager::SgdUpdaterWithCpuAverager(
3434
updateWorker_.addJob([]() { hl_set_device(FLAGS_gpu_id); });
3535
}
3636

37-
void SgdUpdaterWithCpuAverager::init(std::vector<ParameterPtr>& parameters) {
37+
void SgdUpdaterWithCpuAverager::init(
38+
const std::vector<ParameterPtr>& parameters) {
3839
SgdLocalUpdater::init(parameters);
3940
averager_->init(parameters_.size(), nullptr);
4041
copyEvents_.resize(parameters_.size());

paddle/trainer/ParameterUpdater.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class SgdLocalUpdater : public ParameterUpdater {
6464
* be initialized.
6565
* @param parameters The parameter need to be initialized.
6666
*/
67-
virtual void init(std::vector<ParameterPtr>& parameters) {
67+
virtual void init(const std::vector<ParameterPtr>& parameters) {
6868
ParameterUpdater::init(parameters);
6969
optimizer_->init(parameters_.size(), nullptr);
7070
// check no L1 decay in parameter configs
@@ -208,7 +208,7 @@ class SgdUpdaterWithCpuAverager : public SgdLocalUpdater {
208208
* @brief init. Initialize cpu parameters, model average optimizer.
209209
* @param parameters
210210
*/
211-
virtual void init(std::vector<ParameterPtr>& parameters);
211+
virtual void init(const std::vector<ParameterPtr>& parameters);
212212

213213
virtual PassType startBatch(int64_t batchSize) {
214214
averager_->startBatch(-1UL);

paddle/trainer/RemoteParameterUpdater.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ RemoteParameterUpdater::RemoteParameterUpdater(
4444
addParameterType(PARAMETER_MOMENTUM);
4545
}
4646

47-
void RemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
47+
void RemoteParameterUpdater::init(const std::vector<ParameterPtr>& parameters) {
4848
ParameterUpdater::init(parameters);
4949

5050
if (localUpdater_) {
@@ -595,7 +595,8 @@ SparseRemoteParameterUpdater::SparseRemoteParameterUpdater(
595595
testing_(testing),
596596
useApplyInPserver_(false) {}
597597

598-
void SparseRemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
598+
void SparseRemoteParameterUpdater::init(
599+
const std::vector<ParameterPtr>& parameters) {
599600
ParameterUpdater::init(parameters);
600601

601602
parameterClient_.reset(new ParameterClient2(
@@ -809,7 +810,7 @@ void SparseRemoteParameterUpdater::saveParametersRemote(
809810
}
810811

811812
void SparseRemoteParameterUpdaterComposite::init(
812-
std::vector<ParameterPtr>& parameters) {
813+
const std::vector<ParameterPtr>& parameters) {
813814
parameters_ = parameters;
814815

815816
std::vector<ParameterPtr> parametersArray[NUMBER_UPDATERS];

paddle/trainer/RemoteParameterUpdater.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class RemoteParameterUpdater : public ParameterUpdater {
6767
/**
6868
* initialize the internal parameter client and itself.
6969
*/
70-
virtual void init(std::vector<ParameterPtr>& parameters);
70+
virtual void init(const std::vector<ParameterPtr>& parameters);
7171
/**
7272
* @brief start batch
7373
*
@@ -274,7 +274,7 @@ class SparseRemoteParameterUpdater : public ParameterUpdater {
274274
}
275275

276276
/// initialization
277-
virtual void init(std::vector<ParameterPtr>& parameters);
277+
virtual void init(const std::vector<ParameterPtr>& parameters);
278278

279279
/// stateful batch control
280280
virtual PassType startBatch(int64_t batchSize);
@@ -360,7 +360,7 @@ class SparseRemoteParameterUpdaterComposite : public ParameterUpdaterComposite {
360360
}
361361

362362
/// initialization of dense and sparse updaters
363-
virtual void init(std::vector<ParameterPtr>& parameters);
363+
virtual void init(const std::vector<ParameterPtr>& parameters);
364364
};
365365

366366
class ParameterUpdaterCreators {

paddle/trainer/ThreadParameterUpdater.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ SgdThreadUpdater::SgdThreadUpdater(const OptimizationConfig& optConfig)
3232
}
3333
}
3434

35-
void SgdThreadUpdater::init(std::vector<ParameterPtr>& parameters) {
35+
void SgdThreadUpdater::init(const std::vector<ParameterPtr>& parameters) {
3636
ParameterUpdater::init(parameters);
3737

3838
// calc max parameter id

paddle/trainer/ThreadParameterUpdater.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class SgdThreadUpdater : public ParameterUpdater {
4949
// Use the finishPass() function of the base optimizer.
5050
virtual bool finishPass(real cost);
5151

52-
virtual void init(std::vector<ParameterPtr>& parameters);
52+
virtual void init(const std::vector<ParameterPtr>& parameters);
5353
virtual PassType startBatch(int64_t batchSize);
5454
// Call finishBatch for each optimizer.
5555
virtual void finishBatch(real cost);

0 commit comments

Comments
 (0)