Skip to content

Commit 051d15c

Browse files
committed
add prefetch and onPassEnd to PaddleApi.h
1 parent dae8b9b commit 051d15c

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

paddle/api/GradientMachine.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ GradientMachine* GradientMachine::createByModelConfig(
6464
return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types);
6565
}
6666

67+
void GradientMachine::onPassEnd() { m->machine->onPassEnd(); }
68+
69+
void GradientMachine::prefetch(const Arguments& inArgs) {
70+
auto& in =
71+
m->cast<std::vector<paddle::Argument>>(inArgs.getInternalArgumentsPtr());
72+
m->machine->prefetch(in);
73+
}
74+
6775
void GradientMachine::forward(const Arguments& inArgs,
6876
Arguments* outArgs,
6977
PassType passType) {

paddle/api/PaddleAPI.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,16 @@ class GradientMachine {
714714
GradientMatchineCreateMode mode = CREATE_MODE_NORMAL,
715715
const std::vector<int>& parameterTypes = defaultParamTypes);
716716

717+
/**
718+
* Prefetch row ids of sparse parameter.
719+
*/
720+
void prefetch(const Arguments& inArgs);
721+
722+
/**
723+
* Do some thing when train pass ended.
724+
*/
725+
void onPassEnd();
726+
717727
/**
718728
* The forward stage of GradientMachine.
719729
*

0 commit comments

Comments
 (0)