more work on the fyp
This commit is contained in:
@@ -211,33 +211,33 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
defer rows.Close()
|
||||
|
||||
type ExpHead struct {
|
||||
id string
|
||||
id string
|
||||
start int
|
||||
end int
|
||||
}
|
||||
|
||||
exp := ExpHead{}
|
||||
exp := ExpHead{}
|
||||
|
||||
if rows.Next() {
|
||||
if err = rows.Scan(&exp.id, &exp.start, &exp.end); err == nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.Error("Failed to get the exp head of the model")
|
||||
err = errors.New("Failed to get the exp head of the model")
|
||||
return
|
||||
}
|
||||
if rows.Next() {
|
||||
if err = rows.Scan(&exp.id, &exp.start, &exp.end); err == nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.Error("Failed to get the exp head of the model")
|
||||
err = errors.New("Failed to get the exp head of the model")
|
||||
return
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
log.Error("This training function can only train one model at the time")
|
||||
err = errors.New("This training function can only train one model at the time")
|
||||
return
|
||||
}
|
||||
if rows.Next() {
|
||||
log.Error("This training function can only train one model at the time")
|
||||
err = errors.New("This training function can only train one model at the time")
|
||||
return
|
||||
}
|
||||
|
||||
layers, err := c.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -246,23 +246,36 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
||||
type layerrow struct {
|
||||
LayerType int
|
||||
Shape string
|
||||
ExpType int
|
||||
LayerNum int
|
||||
}
|
||||
|
||||
got := []layerrow{}
|
||||
|
||||
remove_top_count := 1
|
||||
|
||||
i := 1
|
||||
|
||||
for layers.Next() {
|
||||
var row = layerrow{}
|
||||
if err = layers.Scan(&row.LayerType, &row.Shape); err != nil {
|
||||
if err = layers.Scan(&row.LayerType, &row.Shape, &row.ExpType); err != nil {
|
||||
return
|
||||
}
|
||||
row.LayerNum = i
|
||||
if row.ExpType == 2 {
|
||||
remove_top_count += 1
|
||||
}
|
||||
row.Shape = shapeToSize(row.Shape)
|
||||
got = append(got, row)
|
||||
i += 1
|
||||
}
|
||||
|
||||
got = append(got, layerrow{
|
||||
LayerType: LAYER_DENSE,
|
||||
Shape: fmt.Sprintf("%d", exp.end - exp.start),
|
||||
})
|
||||
got = append(got, layerrow{
|
||||
LayerType: LAYER_DENSE,
|
||||
Shape: fmt.Sprintf("%d", exp.end-exp.start),
|
||||
ExpType: 2,
|
||||
LayerNum: i,
|
||||
})
|
||||
|
||||
// Generate run folder
|
||||
run_path := path.Join("/tmp", model.Id, "defs", definition_id)
|
||||
@@ -278,7 +291,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
||||
return
|
||||
}
|
||||
|
||||
// TODO update the run script
|
||||
// TODO update the run script
|
||||
|
||||
// Create python script
|
||||
f, err := os.Create(path.Join(run_path, "run.py"))
|
||||
@@ -287,7 +300,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
|
||||
tmpl, err := template.New("python_model_template-exp.py").ParseFiles("views/py/python_model_template.py")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -307,6 +320,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
||||
"LoadPrev": load_prev,
|
||||
"LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
|
||||
"SaveModelPath": path.Join(getDir(), result_path),
|
||||
"RemoveTopCount": remove_top_count,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user