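# Template variables used (substituted before this script runs):
# - ModelPath: path of the full Keras model to load
# - SplitLen: index of the last layer kept in the base (bottom) model
# - BaseModelPath: output directory for the base (bottom) model
# - HeadModelPath: output directory for the head (top) model
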
import tensorflow as tf
|
|
import keras
|
|
from keras.models import Model
|
|
from keras.layers import Input
|
|
|
|
# --- Load the full model -----------------------------------------------------
# The path is substituted by the template renderer.
model = keras.models.load_model("{{ .ModelPath }}")

# Diagnostic output: the full model's input shape.
print(model.input_shape)

# Index of the last layer that belongs to the base (bottom) model: layers
# 0..split_len go into the base model, layers split_len + 1 onward into the
# head (top) model. Quoting the placeholder keeps this template valid Python
# even before rendering, and int() rejects a non-integer substitution.
split_len = int("{{ .SplitLen }}")

layers = model.layers
split_at = split_len + 1  # index of the first layer of the head model

# The head model's input shape is read from layers[split_at], so that index
# must exist. split_len == -1 (empty base model, every layer in the head) is
# still accepted; anything lower would silently index from the end of the
# list, and anything too high would leave the head model with no layers.
if not -1 <= split_len <= len(layers) - 2:
    raise ValueError(
        f"split_len must be between -1 and {len(layers) - 2} for a model "
        f"with {len(layers)} layers, got {split_len}"
    )

base_layers = layers[:split_at]
head_layers = layers[split_at:]

# --- Rebuild the two halves --------------------------------------------------
# Each half is rebuilt by feeding every layer the previous layer's output, so
# this assumes the model is a plain linear stack; branching, skip-connection,
# or multi-input models are not handled.
# NOTE(review): `input_shape` is a tf.keras 2.x Layer/Model attribute; confirm
# it is available in the Keras version this script runs under.
bottom_input = Input(model.input_shape[1:])  # [1:] drops the batch dimension
top_input = Input(layers[split_at].input_shape[1:])

bottom_output = bottom_input
for layer in base_layers:
    bottom_output = layer(bottom_output)

top_output = top_input
for layer in head_layers:
    top_output = layer(top_output)

base_model = Model(bottom_input, bottom_output)
head_model = Model(top_input, top_output)

# --- Export ------------------------------------------------------------------
# Each sub-model is written twice: as a TensorFlow SavedModel directory under
# <path>/model and as a Keras archive at <path>/model.keras.
tf.saved_model.save(head_model, "{{ .HeadModelPath }}/model")
head_model.save("{{ .HeadModelPath }}/model.keras")

tf.saved_model.save(base_model, "{{ .BaseModelPath }}/model")
base_model.save("{{ .BaseModelPath }}/model.keras")