feat: conitnued working on the split models
This commit is contained in:
36
views/py/python_split_model_template.py
Normal file
36
views/py/python_split_model_template.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# Used vars
|
||||
# - Model Path
|
||||
# - Split Len
|
||||
# - BaseModelPath
|
||||
# - HeadModelPath
|
||||
|
||||
import tensorflow as tf
|
||||
import keras
|
||||
from keras.models import Model
|
||||
from keras.layers import Input
|
||||
|
||||
model = keras.models.load_model("{{ .ModelPath }}")
|
||||
|
||||
print(model.input_shape)
|
||||
|
||||
split_len = {{ .SplitLen }}
|
||||
|
||||
bottom_input = Input(model.input_shape[1:])
|
||||
bottom_output = bottom_input
|
||||
top_input = Input(model.layers[split_len + 1].input_shape[1:])
|
||||
top_output = top_input
|
||||
|
||||
for i, layer in enumerate(model.layers):
|
||||
if split_len >= i:
|
||||
bottom_output = layer(bottom_output)
|
||||
else:
|
||||
top_output = layer(top_output)
|
||||
|
||||
base_model = Model(bottom_input, bottom_output)
|
||||
head_model = Model(top_input, top_output)
|
||||
|
||||
tf.saved_model.save(head_model, "{{ .HeadModelPath }}/model")
|
||||
head_model.save("{{ .HeadModelPath }}/model.keras")
|
||||
|
||||
tf.saved_model.save(base_model, "{{ .BaseModelPath }}/model")
|
||||
base_model.save("{{ .BaseModelPath }}/model.keras")
|
||||
Reference in New Issue
Block a user