feat: conitnued working on the split models

This commit is contained in:
Andre Henriques 2024-02-12 14:30:43 +00:00
parent ef1e10cb7c
commit 2c39a6e7fe
2 changed files with 142 additions and 0 deletions

View File

@ -780,9 +780,115 @@ func trainModelExp(c *Context, model *BaseModel) {
return return
} }
if err = splitModel(c, model); err != nil {
failed("Failed to split the model")
return
}
ModelUpdateStatus(c, model.Id, READY) ModelUpdateStatus(c, model.Id, READY)
} }
func splitModel(c *Context, model *BaseModel) (err error) {
type Def struct {
Id string
}
def := Def{}
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
return
}
head := Def{}
if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil {
return
}
// Generate run folder
run_path := path.Join("/tmp", model.Id, "defs", def.Id)
err = os.MkdirAll(run_path, os.ModePerm)
if err != nil {
return
}
// TODO reneable it
// defer os.RemoveAll(run_path)
// Create python script
f, err := os.Create(path.Join(run_path, "run.py"))
if err != nil {
return
}
defer f.Close()
tmpl, err := template.New("python_split_model_template.py").ParseFiles("views/py/python_split_model_template.py")
if err != nil {
return
}
// Copy result around
result_path := path.Join(getDir(), "savedData", model.Id, "defs", def.Id)
// TODO maybe move this to a select count(*)
// Get only fixed lawers
layers, err := c.Db.Query("select exp_type from model_definition_layer where def_id=$1 and exp_type=$2 order by layer_order asc;", def.Id, 1)
if err != nil {
return
}
defer layers.Close()
type layerrow struct {
ExpType int
}
count := -1
for layers.Next() {
count += 1
}
if count == -1 {
err = errors.New("Can not get layers")
return
}
log.Warn("Spliting model", "def", def.Id, "head", head.Id, "count", count)
basePath := path.Join(result_path, "base")
headPath := path.Join(result_path, "head", head.Id)
if err = os.MkdirAll(basePath, os.ModePerm); err != nil {
return
}
if err = os.MkdirAll(headPath, os.ModePerm); err != nil {
return
}
if err = tmpl.Execute(f, AnyMap{
"SplitLen": count,
"ModelPath": path.Join(result_path, "model.keras"),
"BaseModelPath": basePath,
"HeadModelPath": headPath,
}); err != nil {
return
}
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput()
if err != nil {
c.Logger.Debug(string(out))
return
}
c.Logger.Info("Python finished running")
return
}
func removeFailedDataPoints(c *Context, model *BaseModel) (err error) { func removeFailedDataPoints(c *Context, model *BaseModel) (err error) {
rows, err := c.Db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model.Id) rows, err := c.Db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model.Id)
if err != nil { if err != nil {

View File

@ -0,0 +1,36 @@
# Used vars
# - Model Path
# - Split Len
# - BaseModelPath
# - HeadModelPath
import tensorflow as tf
import keras
from keras.models import Model
from keras.layers import Input
model = keras.models.load_model("{{ .ModelPath }}")
print(model.input_shape)
split_len = {{ .SplitLen }}
bottom_input = Input(model.input_shape[1:])
bottom_output = bottom_input
top_input = Input(model.layers[split_len + 1].input_shape[1:])
top_output = top_input
for i, layer in enumerate(model.layers):
if split_len >= i:
bottom_output = layer(bottom_output)
else:
top_output = layer(top_output)
base_model = Model(bottom_input, bottom_output)
head_model = Model(top_input, top_output)
tf.saved_model.save(head_model, "{{ .HeadModelPath }}/model")
head_model.save("{{ .HeadModelPath }}/model.keras")
tf.saved_model.save(base_model, "{{ .BaseModelPath }}/model")
base_model.save("{{ .BaseModelPath }}/model.keras")