diff --git a/logic/models/train/train.go b/logic/models/train/train.go index ecd9cb4..bcc99b9 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -780,9 +780,115 @@ func trainModelExp(c *Context, model *BaseModel) { return } + if err = splitModel(c, model); err != nil { + failed("Failed to split the model") + return + } + ModelUpdateStatus(c, model.Id, READY) } +func splitModel(c *Context, model *BaseModel) (err error) { + + type Def struct { + Id string + } + + def := Def{} + + if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil { + return + } + + head := Def{} + + if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil { + return + } + + // Generate run folder + run_path := path.Join("/tmp", model.Id, "defs", def.Id) + + err = os.MkdirAll(run_path, os.ModePerm) + if err != nil { + return + } + // TODO reneable it + // defer os.RemoveAll(run_path) + + // Create python script + f, err := os.Create(path.Join(run_path, "run.py")) + if err != nil { + return + } + defer f.Close() + + tmpl, err := template.New("python_split_model_template.py").ParseFiles("views/py/python_split_model_template.py") + if err != nil { + return + } + + // Copy result around + result_path := path.Join(getDir(), "savedData", model.Id, "defs", def.Id) + + // TODO maybe move this to a select count(*) + // Get only fixed lawers + layers, err := c.Db.Query("select exp_type from model_definition_layer where def_id=$1 and exp_type=$2 order by layer_order asc;", def.Id, 1) + if err != nil { + return + } + defer layers.Close() + + type layerrow struct { + ExpType int + } + + count := -1 + + for layers.Next() { + count += 1 + } + + if count == -1 { + err = errors.New("Can not get layers") + return + } + + log.Warn("Spliting model", "def", def.Id, "head", head.Id, "count", count) + + basePath := path.Join(result_path, "base") + headPath := path.Join(result_path, "head", head.Id) + + if err = os.MkdirAll(basePath, os.ModePerm); err != nil { + return + } + + if err = os.MkdirAll(headPath, os.ModePerm); err != nil { + return + } + + if err = tmpl.Execute(f, AnyMap{ + "SplitLen": count, + "ModelPath": path.Join(result_path, "model.keras"), + "BaseModelPath": basePath, + "HeadModelPath": headPath, + }); err != nil { + return + } + + out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput() + if err != nil { + c.Logger.Debug(string(out)) + return + } + + c.Logger.Info("Python finished running") + + return +} + + + func removeFailedDataPoints(c *Context, model *BaseModel) (err error) { rows, err := c.Db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model.Id) if err != nil { diff --git a/views/py/python_split_model_template.py b/views/py/python_split_model_template.py new file mode 100644 index 0000000..f251cc1 --- /dev/null +++ b/views/py/python_split_model_template.py @@ -0,0 +1,36 @@ +# Used vars +# - Model Path +# - Split Len +# - BaseModelPath +# - HeadModelPath + +import tensorflow as tf +import keras +from keras.models import Model +from keras.layers import Input + +model = keras.models.load_model("{{ .ModelPath }}") + +print(model.input_shape) + +split_len = {{ .SplitLen }} + +bottom_input = Input(model.input_shape[1:]) +bottom_output = bottom_input +top_input = Input(model.layers[split_len + 1].input_shape[1:]) +top_output = top_input + +for i, layer in enumerate(model.layers): + if split_len >= i: + bottom_output = layer(bottom_output) + else: + top_output = layer(top_output) + +base_model = Model(bottom_input, bottom_output) +head_model = Model(top_input, top_output) + +tf.saved_model.save(head_model, "{{ .HeadModelPath }}/model") +head_model.save("{{ .HeadModelPath }}/model.keras") + +tf.saved_model.save(base_model, "{{ .BaseModelPath }}/model") +base_model.save("{{ .BaseModelPath }}/model.keras")