# fyp/views/py/python_split_model_template.py
# (37 lines, 948 B, Python)

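# Template variables used by this script:
# - ModelPath: path of the saved Keras model to split
# - SplitLen: index of the last layer placed in the base (bottom) model
# - BaseModelPath: output directory for the base (bottom) model
# - HeadModelPath: output directory for the head (top) model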
import tensorflow as tf
import keras
from keras.models import Model
from keras.layers import Input
model = keras.models.load_model("{{ .ModelPath }}")
print(model.input_shape)
split_len = {{ .SplitLen }}
bottom_input = Input(model.input_shape[1:])
bottom_output = bottom_input
top_input = Input(model.layers[split_len + 1].input_shape[1:], name="head_input")
top_output = top_input
for i, layer in enumerate(model.layers):
if split_len >= i:
bottom_output = layer(bottom_output)
else:
top_output = layer(top_output)
base_model = Model(bottom_input, bottom_output)
head_model = Model(top_input, top_output)
tf.saved_model.save(head_model, "{{ .HeadModelPath }}/model")
head_model.save("{{ .HeadModelPath }}/model.keras")
tf.saved_model.save(base_model, "{{ .BaseModelPath }}/model")
base_model.save("{{ .BaseModelPath }}/model.keras")