2023-09-26 20:15:28 +01:00
package models_train
import (
"database/sql"
"errors"
"fmt"
2023-09-27 21:20:39 +01:00
"io"
2023-09-26 20:15:28 +01:00
"net/http"
2023-09-27 21:20:39 +01:00
"os"
"os/exec"
"path"
"strconv"
"text/template"
2023-09-26 20:15:28 +01:00
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
func MakeDefenition ( db * sql . DB , model_id string , target_accuracy int ) ( id string , err error ) {
2023-09-27 21:20:39 +01:00
id = ""
_ , err = db . Exec ( "insert into model_definition (model_id, target_accuracy) values ($1, $2);" , model_id , target_accuracy )
if err != nil {
return
}
2023-09-26 20:15:28 +01:00
2023-09-27 21:20:39 +01:00
rows , err := db . Query ( "select id from model_definition where model_id=$1 order by created_on DESC;" , model_id )
if err != nil {
return
}
defer rows . Close ( )
2023-09-26 20:15:28 +01:00
2023-09-27 21:20:39 +01:00
if ! rows . Next ( ) {
return id , errors . New ( "Something wrong!" )
}
2023-09-26 20:15:28 +01:00
2023-09-27 21:20:39 +01:00
err = rows . Scan ( & id )
if err != nil {
return
}
2023-09-26 20:15:28 +01:00
2023-09-27 21:20:39 +01:00
return
2023-09-26 20:15:28 +01:00
}
// ModelDefinitionStatus is the value stored in model_definition.status.
type ModelDefinitionStatus int

// Status values for model_definition.status.
//
// NOTE(review): only MODEL_DEFINITION_STATUS_PRE_INIT is declared with the
// ModelDefinitionStatus type. An explicit "= N" inside a const block does not
// inherit the previous line's type, so the remaining values are untyped
// integer constants; they still convert implicitly wherever a
// ModelDefinitionStatus is expected.
const (
	// Presumably the initial status of a freshly inserted definition
	// (MakeDefenition does not set status explicitly) — confirm against the
	// schema default.
	MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1
	// Set by handleTrain once the definition's layers exist; trainModel only
	// trains definitions in this state.
	MODEL_DEFINITION_STATUS_INIT = 2
	// NOTE(review): trainModel does not set this while a definition trains —
	// confirm whether anything else does.
	MODEL_DEFINITION_STATUS_TRAINING = 3
	// Set by trainModel when training reached target_accuracy (the
	// misspelled name is part of the package API).
	MODEL_DEFINITION_STATUS_TRANIED = 4
	// Set by trainModel on the most accurate trained definition.
	MODEL_DEFINITION_STATUS_READY = 5
	// Set by trainModel when training errored or missed target_accuracy.
	MODEL_DEFINITION_STATUS_FAILED_TRAINING = -3
)
// LayerType is the value stored in model_definition_layer.layer_type.
type LayerType int

// Layer kinds understood by the training pipeline.
//
// NOTE(review): only LAYER_INPUT carries the LayerType type; LAYER_DENSE and
// LAYER_FLATTEN have an explicit "= N" and are therefore untyped integer
// constants that convert implicitly where a LayerType is expected.
const (
	// handleTrain builds this layer with shape "<width>,<height>,1".
	LAYER_INPUT LayerType = 1
	// handleTrain builds these layers with shape "<n>,1".
	LAYER_DENSE = 2
	// Flatten layer.
	LAYER_FLATTEN = 3
)
func ModelDefinitionUpdateStatus ( handle * Handle , id string , status ModelDefinitionStatus ) ( err error ) {
_ , err = handle . Db . Exec ( "update model_definition set status = $1 where id = $2" , status , id )
2023-09-27 21:20:39 +01:00
return
2023-09-26 20:15:28 +01:00
}
2023-09-27 21:20:39 +01:00
// MakeLayer inserts one model_definition_layer row for the definition def_id.
// layer_order positions the layer within the definition (trainDefinition reads
// layers back ordered by it, ascending); shape is stored verbatim.
func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string) (err error) {
	const insertLayer = "insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)"
	args := []interface{}{def_id, layer_order, layer_type, shape}
	_, err = db.Exec(insertLayer, args...)
	return err
}
2023-09-27 21:20:39 +01:00
// trainDefinition trains a single model definition and returns the accuracy
// reported by the training script.
//
// Steps:
//  1. Load the definition's layers from model_definition_layer (ordered by
//     layer_order) and normalise each shape with shapeToSize.
//  2. Render views/py/python_model_template.py into
//     /tmp/<model_id>/defs/<definition_id>/run.py.
//  3. Run `python run.py` inside that directory.
//  4. Copy the produced "model" directory to
//     savedData/<model_id>/defs/<definition_id>/model.
//  5. Parse accuracy.val from the run directory and return it.
//
// The temporary run directory is removed only on success, so a failed run can
// still be inspected.
func trainDefinition(handle *Handle, model_id string, definition_id string) (accuracy float64, err error) {
	accuracy = 0

	layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
	if err != nil {
		return
	}
	defer layers.Close()

	// Fields are exported because the Python template reads them.
	type layerrow struct {
		LayerType int
		Shape     string
	}
	got := []layerrow{}
	for layers.Next() {
		var row = layerrow{}
		if err = layers.Scan(&row.LayerType, &row.Shape); err != nil {
			return
		}
		row.Shape = shapeToSize(row.Shape)
		got = append(got, row)
	}
	if err = layers.Err(); err != nil {
		return
	}
	// The template is given the first layer's shape as the input size; guard
	// the index so a definition without layers is an error, not a panic.
	if len(got) == 0 {
		err = fmt.Errorf("model definition %s has no layers", definition_id)
		return
	}

	// Generate run folder
	run_path := path.Join("/tmp", model_id, "defs", definition_id)
	if err = os.MkdirAll(run_path, os.ModePerm); err != nil {
		return
	}

	// Create python script
	tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
	if err != nil {
		return
	}
	f, err := os.Create(path.Join(run_path, "run.py"))
	if err != nil {
		return
	}
	err = tmpl.Execute(f, AnyMap{
		"Layers":  got,
		"Size":    got[0].Shape,
		"DataDir": path.Join(getDir(), "savedData", model_id, "data", "training"),
	})
	// Close before python reads the script, and surface close errors that a
	// deferred Close would silently drop.
	if closeErr := f.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		return
	}

	// Run the script from its own directory without going through a shell,
	// so run_path is never interpreted by bash.
	cmd := exec.Command("python", "run.py")
	cmd.Dir = run_path
	if out, runErr := cmd.CombinedOutput(); runErr != nil {
		err = fmt.Errorf("python run.py in %s: %w\n%s", run_path, runErr, out)
		return
	}

	// Copy result around
	result_path := path.Join("savedData", model_id, "defs", definition_id)
	if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
		return
	}
	if err = exec.Command("cp", "-r", path.Join(run_path, "model"), path.Join(result_path, "model")).Run(); err != nil {
		return
	}

	accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val"))
	if err != nil {
		return
	}
	defer accuracy_file.Close()
	accuracy_file_bytes, err := io.ReadAll(accuracy_file)
	if err != nil {
		return
	}
	fmt.Println(string(accuracy_file_bytes))
	// NOTE(review): ParseFloat rejects surrounding whitespace, so the
	// template must write the bare number (no trailing newline) — confirm.
	accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64)
	if err != nil {
		return
	}

	os.RemoveAll(run_path)
	return
}
2023-09-27 18:07:04 +01:00
// trainModel trains every definition of model that is in
// MODEL_DEFINITION_STATUS_INIT and then:
//   - marks definitions that errored or missed their target_accuracy as
//     MODEL_DEFINITION_STATUS_FAILED_TRAINING,
//   - records accuracy (as a whole percentage) on the others and marks them
//     MODEL_DEFINITION_STATUS_TRANIED,
//   - promotes the most accurate trained definition to
//     MODEL_DEFINITION_STATUS_READY,
//   - deletes every other definition (saved files, then rows),
//   - sets the model status to READY.
//
// Any database failure, or no definition reaching its target, sets the model
// status to FAILED_TRAINING. handleTrain runs this in its own goroutine;
// errors are only reported on stdout and through the model status.
func trainModel(handle *Handle, model *BaseModel) {
	// fail logs err under msg and marks the whole model as FAILED_TRAINING.
	fail := func(msg string, err error) {
		fmt.Printf("%s!Err:\n", msg)
		fmt.Println(err)
		ModelUpdateStatus(handle, model.Id, FAILED_TRAINING)
	}
	// failDefinition marks one definition as failed; other definitions can
	// still make the model succeed.
	failDefinition := func(defID string) {
		if err := ModelDefinitionUpdateStatus(handle, defID, MODEL_DEFINITION_STATUS_FAILED_TRAINING); err != nil {
			fmt.Printf("Failed to update definition status!Err:\n")
			fmt.Println(err)
		}
	}

	definitionsRows, err := handle.Db.Query("select id, target_accuracy from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
	if err != nil {
		fail("Failed to trainModel", err)
		return
	}
	defer definitionsRows.Close()

	type row struct {
		id              string
		target_accuracy int
	}
	definitions := []row{}
	for definitionsRows.Next() {
		var rowv row
		if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy); err != nil {
			fail("Failed to trainModel", err)
			return
		}
		definitions = append(definitions, rowv)
	}
	if err = definitionsRows.Err(); err != nil {
		fail("Failed to trainModel", err)
		return
	}
	if len(definitions) == 0 {
		fail("Failed to trainModel", errors.New("no model definitions ready to train"))
		return
	}

	for _, def := range definitions {
		accuracy, err := trainDefinition(handle, model.Id, def.id)
		if err != nil {
			fmt.Printf("Failed to train definition!Err:\n")
			fmt.Println(err)
			failDefinition(def.id)
			continue
		}
		// target_accuracy is compared as a whole percentage.
		int_accuracy := int(accuracy * 100)
		if int_accuracy < def.target_accuracy {
			fmt.Printf("Failed to train definition! Accuracy less %d < %d\n", int_accuracy, def.target_accuracy)
			failDefinition(def.id)
			continue
		}
		if _, err = handle.Db.Exec("update model_definition set accuracy=$1, status=$2 where id=$3", int_accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.id); err != nil {
			fail("Failed to train definition", err)
			return
		}
	}

	// Pick the most accurate trained definition. QueryRow releases its
	// connection as soon as Scan returns.
	var id string
	err = handle.Db.QueryRow("select id from model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED).Scan(&id)
	if errors.Is(err, sql.ErrNoRows) {
		// TODO improve message
		fmt.Println("All definitions failed to train!")
		ModelUpdateStatus(handle, model.Id, FAILED_TRAINING)
		return
	}
	if err != nil {
		fail("Db err select", err)
		return
	}

	if _, err = handle.Db.Exec("update model_definition set status=$1 where id=$2;", MODEL_DEFINITION_STATUS_READY, id); err != nil {
		fail("Db err", err)
		return
	}

	// Remove every other definition: first its saved files, then its rows.
	to_delete, err := handle.Db.Query("select id from model_definition where status != $1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
	if err != nil {
		fail("Db err", err)
		return
	}
	defer to_delete.Close()
	for to_delete.Next() {
		var defID string
		if err = to_delete.Scan(&defID); err != nil {
			fail("Db err", err)
			return
		}
		// Best-effort: a RemoveAll failure is ignored and cleanup continues.
		os.RemoveAll(path.Join("savedData", model.Id, "defs", defID))
	}
	if err = to_delete.Err(); err != nil {
		fail("Db err", err)
		return
	}

	if _, err = handle.Db.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil {
		fail("Db err", err)
		return
	}

	ModelUpdateStatus(handle, model.Id, READY)
}
2023-09-26 20:15:28 +01:00
func handleTrain ( handle * Handle ) {
handle . Post ( "/models/train" , func ( w http . ResponseWriter , r * http . Request , c * Context ) * Error {
if ! CheckAuthLevel ( 1 , w , r , c ) {
return nil
}
if c . Mode == JSON {
panic ( "TODO /models/train JSON" )
}
r . ParseForm ( )
f := r . Form
number_of_models := 0
accuracy := 0
if ! CheckId ( f , "id" ) || CheckEmpty ( f , "model_type" ) || ! CheckNumber ( f , "number_of_models" , & number_of_models ) || ! CheckNumber ( f , "accuracy" , & accuracy ) {
fmt . Println (
! CheckId ( f , "id" ) , CheckEmpty ( f , "model_type" ) , ! CheckNumber ( f , "number_of_models" , & number_of_models ) , ! CheckNumber ( f , "accuracy" , & accuracy ) ,
)
// TODO improve this response
return ErrorCode ( nil , 400 , c . AddMap ( nil ) )
}
id := f . Get ( "id" )
model_type := f . Get ( "model_type" )
// Its not used rn
_ = model_type
2023-09-27 21:20:39 +01:00
// TODO check if the model has data
/ * rows , err := handle . Db . Query ( "select mc.name, mdp.file_path from model_classes as mc join model_data_point as mdp on mdp.class_id = mc.id where mdp.model_mode = 1 and mc.model_id = $1 limit 1;" , id )
if err != nil {
return Error500 ( err )
}
defer rows . Close ( )
if ! rows . Next ( ) {
return Error500 ( err )
}
var name string
var file_path string
err = rows . Scan ( & name , & file_path )
if err != nil {
return Error500 ( err )
} * /
2023-09-26 20:15:28 +01:00
model , err := GetBaseModel ( handle . Db , id )
if err == ModelNotFoundError {
return ErrorCode ( nil , http . StatusNotFound , c . AddMap ( AnyMap {
"NotFoundMessage" : "Model not found" ,
"GoBackLink" : "/models" ,
} ) )
} else if err != nil {
// TODO improve this response
return Error500 ( err )
}
if model . Status != CONFIRM_PRE_TRAINING {
// TODO improve this response
return ErrorCode ( nil , 400 , c . AddMap ( nil ) )
}
2023-09-27 21:20:39 +01:00
cls , err := model_classes . ListClasses ( handle . Db , model . Id )
if err != nil {
ModelUpdateStatus ( handle , model . Id , FAILED_PREPARING_TRAINING )
// TODO improve this response
return Error500 ( err )
}
2023-09-26 20:15:28 +01:00
2023-09-27 21:20:39 +01:00
var fid string
2023-09-26 20:15:28 +01:00
for i := 0 ; i < number_of_models ; i ++ {
2023-09-27 21:20:39 +01:00
def_id , err := MakeDefenition ( handle . Db , model . Id , accuracy )
2023-09-26 20:15:28 +01:00
if err != nil {
ModelUpdateStatus ( handle , model . Id , FAILED_PREPARING_TRAINING )
// TODO improve this response
return Error500 ( err )
}
2023-09-27 21:20:39 +01:00
if fid == "" {
fid = def_id
}
// TODO change shape of it depends on the type of the image
err = MakeLayer ( handle . Db , def_id , 1 , LAYER_INPUT , fmt . Sprintf ( "%d,%d,1" , model . Width , model . Height ) )
if err != nil {
ModelUpdateStatus ( handle , model . Id , FAILED_PREPARING_TRAINING )
// TODO improve this response
return Error500 ( err )
}
err = MakeLayer ( handle . Db , def_id , 4 , LAYER_FLATTEN , fmt . Sprintf ( "%d,1" , len ( cls ) ) )
if err != nil {
2023-09-26 20:15:28 +01:00
ModelUpdateStatus ( handle , model . Id , FAILED_PREPARING_TRAINING )
// TODO improve this response
return Error500 ( err )
2023-09-27 21:20:39 +01:00
}
err = MakeLayer ( handle . Db , def_id , 5 , LAYER_DENSE , fmt . Sprintf ( "%d,1" , len ( cls ) * 3 ) )
if err != nil {
2023-09-26 20:15:28 +01:00
ModelUpdateStatus ( handle , model . Id , FAILED_PREPARING_TRAINING )
// TODO improve this response
return Error500 ( err )
2023-09-27 21:20:39 +01:00
}
err = MakeLayer ( handle . Db , def_id , 5 , LAYER_DENSE , fmt . Sprintf ( "%d,1" , len ( cls ) ) )
if err != nil {
2023-09-26 20:15:28 +01:00
ModelUpdateStatus ( handle , model . Id , FAILED_PREPARING_TRAINING )
// TODO improve this response
return Error500 ( err )
2023-09-27 21:20:39 +01:00
}
2023-09-26 20:15:28 +01:00
2023-09-27 21:20:39 +01:00
err = ModelDefinitionUpdateStatus ( handle , def_id , MODEL_DEFINITION_STATUS_INIT )
if err != nil {
2023-09-26 20:15:28 +01:00
ModelUpdateStatus ( handle , model . Id , FAILED_PREPARING_TRAINING )
// TODO improve this response
return Error500 ( err )
2023-09-27 21:20:39 +01:00
}
2023-09-26 20:15:28 +01:00
}
2023-09-27 21:20:39 +01:00
// TODO start training with id fid
2023-09-26 20:15:28 +01:00
2023-09-27 21:20:39 +01:00
go trainModel ( handle , model )
2023-09-27 18:07:04 +01:00
2023-09-27 21:20:39 +01:00
ModelUpdateStatus ( handle , model . Id , TRAINING )
Redirect ( "/models/edit?id=" + model . Id , c . Mode , w , r )
2023-09-26 20:15:28 +01:00
return nil
} )
}