37 lines
		
	
	
		
			929 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			37 lines
		
	
	
		
			929 B
		
	
	
	
		
			Python
		
	
	
	
	
	
# Used vars
 | 
						|
#  - ModelPath
 | 
						|
#  - SplitLen
 | 
						|
#  - BaseModelPath
 | 
						|
#  - HeadModelPath
 | 
						|
 | 
						|
import tensorflow as tf
 | 
						|
import keras
 | 
						|
from keras.models import Model
 | 
						|
from keras.layers import Input
 | 
						|
 | 
						|
model = keras.models.load_model("{{ .ModelPath }}")
 | 
						|
 | 
						|
print(model.input_shape)
 | 
						|
 | 
						|
split_len = {{ .SplitLen }}
 | 
						|
 | 
						|
bottom_input = Input(model.input_shape[1:])
 | 
						|
bottom_output = bottom_input
 | 
						|
top_input = Input(model.layers[split_len + 1].input_shape[1:])
 | 
						|
top_output = top_input
 | 
						|
 | 
						|
for i, layer in enumerate(model.layers):
 | 
						|
    if split_len >= i:
 | 
						|
        bottom_output = layer(bottom_output)
 | 
						|
    else:
 | 
						|
        top_output = layer(top_output)
 | 
						|
 | 
						|
base_model = Model(bottom_input, bottom_output)
 | 
						|
head_model = Model(top_input, top_output)
 | 
						|
 | 
						|
tf.saved_model.save(head_model, "{{ .HeadModelPath }}/model")
 | 
						|
head_model.save("{{ .HeadModelPath }}/model.keras")
 | 
						|
 | 
						|
tf.saved_model.save(base_model, "{{ .BaseModelPath }}/model")
 | 
						|
base_model.save("{{ .BaseModelPath }}/model.keras")
 |