# Used vars # - Model Path # - Split Len # - BaseModelPath # - HeadModelPath import tensorflow as tf import keras from keras.models import Model from keras.layers import Input model = keras.models.load_model("{{ .ModelPath }}") print(model.input_shape) split_len = {{ .SplitLen }} bottom_input = Input(model.input_shape[1:]) bottom_output = bottom_input top_input = Input(model.layers[split_len + 1].input_shape[1:]) top_output = top_input for i, layer in enumerate(model.layers): if split_len >= i: bottom_output = layer(bottom_output) else: top_output = layer(top_output) base_model = Model(bottom_input, bottom_output) head_model = Model(top_input, top_output) tf.saved_model.save(head_model, "{{ .HeadModelPath }}/model") head_model.save("{{ .HeadModelPath }}/model.keras") tf.saved_model.save(base_model, "{{ .BaseModelPath }}/model") base_model.save("{{ .BaseModelPath }}/model.keras")