Source code for jz_python_chest.callback_example
# trainer callback example
[docs]
class Callback:
    """Base class for training hooks; every hook is a no-op by default.

    Subclasses override only the hooks they need. ``Trainer.train`` in this
    module passes one shared ``logs`` dict to every hook call of a run.
    """

    def on_train_begin(self, logs=None):
        """Called by ``Trainer.train`` once, before the first epoch."""

    def on_train_end(self, logs=None):
        """Called by ``Trainer.train`` once, after the last epoch."""

    def on_epoch_begin(self, epoch, logs=None):
        """Called by ``Trainer.train`` at the start of 0-based ``epoch``."""

    def on_train_step_end(self, step, logs=None):
        """Called by ``Trainer.train`` after 0-based ``step`` of an epoch."""

    def on_train_epoch_end(self, epoch, logs=None):
        """Called by ``Trainer.train`` at the end of 0-based ``epoch``."""
[docs]
class PrintStepCallback(Callback):
    """Callback that reports every finished training step on stdout."""

    def on_train_step_end(self, step, logs=None):
        """Print a completion line for ``step``; ``logs`` is ignored."""
        message = f"Step {step} completed."
        print(message)
[docs]
class PrintEpochCallback(Callback):
    """Callback that reports every finished epoch on stdout."""

    def on_train_epoch_end(self, epoch, logs=None):
        """Print a completion line for ``epoch``; ``logs`` is ignored."""
        message = f"Epoch {epoch} completed."
        print(message)
[docs]
class PrintTrainBeginCallback(Callback):
    """Callback that announces the start of training on stdout."""

    def on_train_begin(self, logs=None):
        """Print a fixed start banner; ``logs`` is ignored."""
        banner = "Training started."
        print(banner)
[docs]
class PrintTrainEndCallback(Callback):
    """Callback that announces the end of training on stdout."""

    def on_train_end(self, logs=None):
        """Print a fixed completion banner; ``logs`` is ignored."""
        banner = "Training finished."
        print(banner)
[docs]
class Trainer:
    """Drive a simple epoch/step training loop and fire callback hooks.

    Args:
        model: Object exposing a zero-argument ``train_step()`` method,
            called once per step.
        callbacks: Optional iterable of callback objects, each providing
            ``on_train_begin``, ``on_epoch_begin``, ``on_train_step_end``,
            ``on_train_epoch_end`` and ``on_train_end``. It is copied into a
            list; ``None`` means no callbacks.
    """

    def __init__(self, model, callbacks=None):
        self.model = model
        # Materialise the iterable once. Storing a generator (or any other
        # one-shot iterator) directly would let the first hook loop exhaust
        # it, and every later hook would silently see no callbacks.
        self.callbacks = [] if callbacks is None else list(callbacks)

    def _call_hook(self, hook_name, *args):
        """Invoke ``hook_name(*args)`` on every callback, in list order."""
        for callback in self.callbacks:
            getattr(callback, hook_name)(*args)

    def train(self, epochs, steps_per_epoch):
        """Run ``epochs`` epochs of ``steps_per_epoch`` model steps each.

        Hook order: ``on_train_begin``; then for each epoch
        ``on_epoch_begin``, ``on_train_step_end`` after every
        ``model.train_step()`` call, and ``on_train_epoch_end``; finally
        ``on_train_end``.

        Epoch and step indices are 0-based (from ``range``). One ``logs``
        dict is created per run and passed to every hook, so callbacks may
        share state through it. The begin/end hooks fire even when
        ``epochs`` is 0.
        """
        logs = {}
        self._call_hook("on_train_begin", logs)
        for epoch in range(epochs):
            self._call_hook("on_epoch_begin", epoch, logs)
            for step in range(steps_per_epoch):
                self.model.train_step()
                self._call_hook("on_train_step_end", step, logs)
            self._call_hook("on_train_epoch_end", epoch, logs)
        self._call_hook("on_train_end", logs)
# Example model class
[docs]
class Model:
    """Minimal stand-in model exposing the ``train_step`` method."""

    def train_step(self):
        """Perform one training step; a stub that does nothing yet."""
        # Implement your training step logic here.
        return None
if __name__ == "__main__":
    # Demo: run a short training loop with every printing callback attached.
    demo_model = Model()
    demo_callbacks = [
        PrintTrainBeginCallback(),
        PrintStepCallback(),
        PrintEpochCallback(),
        PrintTrainEndCallback(),
    ]
    demo_trainer = Trainer(demo_model, callbacks=demo_callbacks)
    demo_trainer.train(epochs=5, steps_per_epoch=10)