2023-09-26 20:15:28 +01:00
package models_train
import (
"database/sql"
"errors"
"fmt"
2023-09-27 21:20:39 +01:00
"io"
2023-10-19 11:42:38 +01:00
"math"
2023-09-26 20:15:28 +01:00
"net/http"
2023-09-27 21:20:39 +01:00
"os"
"os/exec"
"path"
2023-10-22 23:02:39 +01:00
"sort"
2023-09-27 21:20:39 +01:00
"strconv"
"text/template"
2023-09-26 20:15:28 +01:00
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
2024-02-02 16:16:26 +00:00
"github.com/charmbracelet/log"
2023-09-26 20:15:28 +01:00
)
2023-10-25 14:50:58 +01:00
// EPOCH_PER_RUN is passed to the training template and is the amount added to
// a definition's epoch counter after each training round.
const EPOCH_PER_RUN = 20
2023-10-19 10:44:13 +01:00
// MAX_EPOCH is the epoch count above which a definition that has not reached
// its target accuracy is marked as failed by the training loops.
const MAX_EPOCH = 100
2023-09-26 20:15:28 +01:00
// MakeDefenition inserts a new model_definition row for model_id with the
// given target_accuracy and returns the id produced by the insert's
// "returning id" clause.
//
// On failure the returned id is empty and err is non-nil.
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
	rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
	if err != nil {
		return "", err
	}
	defer rows.Close()

	if !rows.Next() {
		// Next also reports false when the query failed part-way; surface
		// that error instead of hiding it behind the generic message.
		if err = rows.Err(); err != nil {
			return "", err
		}
		return "", errors.New("Something wrong!")
	}

	if err = rows.Scan(&id); err != nil {
		return "", err
	}
	return id, nil
}
2023-10-06 12:13:19 +01:00
func ModelDefinitionUpdateStatus ( c * Context , id string , status ModelDefinitionStatus ) ( err error ) {
_ , err = c . Db . Exec ( "update model_definition set status = $1 where id = $2" , status , id )
2023-09-27 21:20:39 +01:00
return
2023-09-26 20:15:28 +01:00
}
2024-02-05 16:42:23 +00:00
// UpdateStatus sets the status column of the row identified by id in table.
//
// The table name is spliced into the SQL text as-is, so callers must only pass
// trusted, hard-coded table names.
func UpdateStatus(c *Context, table string, id string, status int) (err error) {
	query := fmt.Sprintf("update %s set status = $1 where id = $2", table)
	_, err = c.Db.Exec(query, status, id)
	return err
}
2023-09-27 21:20:39 +01:00
func MakeLayer ( db * sql . DB , def_id string , layer_order int , layer_type LayerType , shape string ) ( err error ) {
2023-09-26 20:15:28 +01:00
_ , err = db . Exec ( "insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)" , def_id , layer_order , layer_type , shape )
return
}
2024-01-31 21:48:35 +00:00
// MakeLayerExpandable inserts one model_definition_layer row like MakeLayer
// does, additionally storing exp_type in the row's exp_type column.
func MakeLayerExpandable(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string, exp_type int) (err error) {
	const query = "insert into model_definition_layer (def_id, layer_order, layer_type, shape, exp_type) values ($1, $2, $3, $4, $5)"
	if _, execErr := db.Exec(query, def_id, layer_order, layer_type, shape, exp_type); execErr != nil {
		return execErr
	}
	return nil
}
2023-10-10 12:28:49 +01:00
func generateCvs ( c * Context , run_path string , model_id string ) ( count int , err error ) {
2023-10-02 21:15:31 +01:00
2023-10-10 12:28:49 +01:00
classes , err := c . Db . Query ( "select count(*) from model_classes where model_id=$1;" , model_id )
if err != nil {
return
}
defer classes . Close ( )
if ! classes . Next ( ) {
return
}
if err = classes . Scan ( & count ) ; err != nil {
return
}
2023-10-20 13:11:46 +01:00
data , err := c . Db . Query ( "select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2;" , model_id , model_classes . DATA_POINT_MODE_TRAINING )
2023-10-10 12:28:49 +01:00
if err != nil {
return
}
defer data . Close ( )
f , err := os . Create ( path . Join ( run_path , "train.csv" ) )
if err != nil {
return
}
defer f . Close ( )
f . Write ( [ ] byte ( "Id,Index\n" ) )
for data . Next ( ) {
var id string
var class_order int
var file_path string
if err = data . Scan ( & id , & class_order , & file_path ) ; err != nil {
return
}
if file_path == "id://" {
f . Write ( [ ] byte ( id + "," + strconv . Itoa ( class_order ) + "\n" ) )
} else {
return count , errors . New ( "TODO generateCvs to file_path " + file_path )
}
}
return
2023-10-02 21:15:31 +01:00
}
2023-10-19 10:44:13 +01:00
func trainDefinition ( c * Context , model * BaseModel , definition_id string , load_prev bool ) ( accuracy float64 , err error ) {
c . Logger . Warn ( "About to start training definition" )
2023-09-27 21:20:39 +01:00
accuracy = 0
2023-10-06 12:13:19 +01:00
layers , err := c . Db . Query ( "select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;" , definition_id )
2023-09-27 21:20:39 +01:00
if err != nil {
return
}
defer layers . Close ( )
type layerrow struct {
LayerType int
Shape string
2024-02-05 16:42:23 +00:00
LayerNum int
2023-09-27 21:20:39 +01:00
}
got := [ ] layerrow { }
2024-02-05 16:42:23 +00:00
i := 1
2023-09-27 21:20:39 +01:00
for layers . Next ( ) {
var row = layerrow { }
if err = layers . Scan ( & row . LayerType , & row . Shape ) ; err != nil {
return
}
row . Shape = shapeToSize ( row . Shape )
2024-02-05 16:42:23 +00:00
row . LayerNum = 1
2023-09-27 21:20:39 +01:00
got = append ( got , row )
2024-02-05 16:42:23 +00:00
i = i + 1
2023-09-27 21:20:39 +01:00
}
// Generate run folder
2023-10-03 19:02:02 +01:00
run_path := path . Join ( "/tmp" , model . Id , "defs" , definition_id )
2023-09-27 21:20:39 +01:00
err = os . MkdirAll ( run_path , os . ModePerm )
if err != nil {
return
}
2023-10-22 23:02:39 +01:00
defer os . RemoveAll ( run_path )
2023-09-27 21:20:39 +01:00
2023-10-10 12:28:49 +01:00
_ , err = generateCvs ( c , run_path , model . Id )
if err != nil {
return
}
2023-10-02 21:15:31 +01:00
2023-09-27 21:20:39 +01:00
// Create python script
f , err := os . Create ( path . Join ( run_path , "run.py" ) )
if err != nil {
return
}
defer f . Close ( )
tmpl , err := template . New ( "python_model_template.py" ) . ParseFiles ( "views/py/python_model_template.py" )
if err != nil {
return
}
2023-10-19 10:44:13 +01:00
// Copy result around
result_path := path . Join ( "savedData" , model . Id , "defs" , definition_id )
2023-09-27 21:20:39 +01:00
if err = tmpl . Execute ( f , AnyMap {
2023-10-19 10:44:13 +01:00
"Layers" : got ,
"Size" : got [ 0 ] . Shape ,
"DataDir" : path . Join ( getDir ( ) , "savedData" , model . Id , "data" ) ,
"RunPath" : run_path ,
"ColorMode" : model . ImageMode ,
"Model" : model ,
"EPOCH_PER_RUN" : EPOCH_PER_RUN ,
"DefId" : definition_id ,
"LoadPrev" : load_prev ,
"LastModelRunPath" : path . Join ( getDir ( ) , result_path , "model.keras" ) ,
2023-10-22 23:02:39 +01:00
"SaveModelPath" : path . Join ( getDir ( ) , result_path ) ,
2023-09-27 21:20:39 +01:00
} ) ; err != nil {
return
}
// Run the command
2023-10-22 23:02:39 +01:00
out , err := exec . Command ( "bash" , "-c" , fmt . Sprintf ( "cd %s && python run.py" , run_path ) ) . CombinedOutput ( )
2023-10-10 12:28:49 +01:00
if err != nil {
2023-10-19 10:44:13 +01:00
c . Logger . Debug ( string ( out ) )
2023-09-27 21:20:39 +01:00
return
}
2023-10-22 23:02:39 +01:00
c . Logger . Info ( "Python finished running" )
2023-09-27 21:20:39 +01:00
2023-10-22 23:02:39 +01:00
if err = os . MkdirAll ( result_path , os . ModePerm ) ; err != nil {
2023-10-19 10:44:13 +01:00
return
}
2023-09-27 21:20:39 +01:00
accuracy_file , err := os . Open ( path . Join ( run_path , "accuracy.val" ) )
if err != nil {
return
}
defer accuracy_file . Close ( )
accuracy_file_bytes , err := io . ReadAll ( accuracy_file )
if err != nil {
return
}
accuracy , err = strconv . ParseFloat ( string ( accuracy_file_bytes ) , 64 )
if err != nil {
return
}
2023-10-10 12:28:49 +01:00
2023-10-21 00:26:52 +01:00
c . Logger . Info ( "Model finished training!" , "accuracy" , accuracy )
2023-09-27 21:20:39 +01:00
return
}
2024-02-02 16:16:26 +00:00
func trainDefinitionExp ( c * Context , model * BaseModel , definition_id string , load_prev bool ) ( accuracy float64 , err error ) {
accuracy = 0
c . Logger . Warn ( "About to start training definition" )
// Get untrained models heads
// Status = 2 (INIT)
2024-02-08 18:20:58 +00:00
rows , err := c . Db . Query ( "select id, range_start, range_end from exp_model_head where def_id=$1 and (status = 2 or status = 3)" , definition_id )
2024-02-02 16:16:26 +00:00
if err != nil {
return
}
2024-02-03 12:39:22 +00:00
defer rows . Close ( )
2024-02-02 16:16:26 +00:00
type ExpHead struct {
2024-02-03 12:39:22 +00:00
id string
2024-02-02 16:16:26 +00:00
start int
end int
}
2024-02-03 12:39:22 +00:00
exp := ExpHead { }
2024-02-02 16:16:26 +00:00
2024-02-03 12:39:22 +00:00
if rows . Next ( ) {
2024-02-08 18:20:58 +00:00
if err = rows . Scan ( & exp . id , & exp . start , & exp . end ) ; err != nil {
2024-02-03 12:39:22 +00:00
return
}
} else {
log . Error ( "Failed to get the exp head of the model" )
err = errors . New ( "Failed to get the exp head of the model" )
return
}
2024-02-02 16:16:26 +00:00
2024-02-03 12:39:22 +00:00
if rows . Next ( ) {
log . Error ( "This training function can only train one model at the time" )
err = errors . New ( "This training function can only train one model at the time" )
return
}
2024-02-02 16:16:26 +00:00
2024-02-08 18:20:58 +00:00
UpdateStatus ( c , "exp_model_head" , exp . id , MODEL_DEFINITION_STATUS_TRAINING )
2024-02-05 16:42:23 +00:00
2024-02-03 12:39:22 +00:00
layers , err := c . Db . Query ( "select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;" , definition_id )
2024-02-02 16:16:26 +00:00
if err != nil {
return
}
defer layers . Close ( )
type layerrow struct {
LayerType int
Shape string
2024-02-03 12:39:22 +00:00
ExpType int
LayerNum int
2024-02-02 16:16:26 +00:00
}
got := [ ] layerrow { }
2024-02-03 12:39:22 +00:00
remove_top_count := 1
i := 1
2024-02-02 16:16:26 +00:00
for layers . Next ( ) {
var row = layerrow { }
2024-02-03 12:39:22 +00:00
if err = layers . Scan ( & row . LayerType , & row . Shape , & row . ExpType ) ; err != nil {
2024-02-02 16:16:26 +00:00
return
}
2024-02-03 12:39:22 +00:00
row . LayerNum = i
if row . ExpType == 2 {
remove_top_count += 1
}
2024-02-02 16:16:26 +00:00
row . Shape = shapeToSize ( row . Shape )
got = append ( got , row )
2024-02-03 12:39:22 +00:00
i += 1
2024-02-02 16:16:26 +00:00
}
2024-02-03 12:39:22 +00:00
got = append ( got , layerrow {
LayerType : LAYER_DENSE ,
2024-02-10 09:41:16 +00:00
Shape : fmt . Sprintf ( "%d" , exp . end - exp . start + 1 ) ,
2024-02-03 12:39:22 +00:00
ExpType : 2 ,
LayerNum : i ,
} )
2024-02-02 16:16:26 +00:00
// Generate run folder
run_path := path . Join ( "/tmp" , model . Id , "defs" , definition_id )
err = os . MkdirAll ( run_path , os . ModePerm )
if err != nil {
return
}
defer os . RemoveAll ( run_path )
_ , err = generateCvs ( c , run_path , model . Id )
if err != nil {
return
}
2024-02-03 12:39:22 +00:00
// TODO update the run script
2024-02-02 16:16:26 +00:00
// Create python script
f , err := os . Create ( path . Join ( run_path , "run.py" ) )
if err != nil {
return
}
defer f . Close ( )
2024-02-05 16:42:23 +00:00
tmpl , err := template . New ( "python_model_template.py" ) . ParseFiles ( "views/py/python_model_template.py" )
2024-02-02 16:16:26 +00:00
if err != nil {
return
}
// Copy result around
result_path := path . Join ( "savedData" , model . Id , "defs" , definition_id )
if err = tmpl . Execute ( f , AnyMap {
"Layers" : got ,
"Size" : got [ 0 ] . Shape ,
"DataDir" : path . Join ( getDir ( ) , "savedData" , model . Id , "data" ) ,
2024-02-05 16:42:23 +00:00
"HeadId" : exp . id ,
2024-02-02 16:16:26 +00:00
"RunPath" : run_path ,
"ColorMode" : model . ImageMode ,
"Model" : model ,
"EPOCH_PER_RUN" : EPOCH_PER_RUN ,
"DefId" : definition_id ,
"LoadPrev" : load_prev ,
"LastModelRunPath" : path . Join ( getDir ( ) , result_path , "model.keras" ) ,
"SaveModelPath" : path . Join ( getDir ( ) , result_path ) ,
2024-02-03 12:39:22 +00:00
"RemoveTopCount" : remove_top_count ,
2024-02-02 16:16:26 +00:00
} ) ; err != nil {
return
}
// Run the command
out , err := exec . Command ( "bash" , "-c" , fmt . Sprintf ( "cd %s && python run.py" , run_path ) ) . CombinedOutput ( )
if err != nil {
c . Logger . Debug ( string ( out ) )
return
}
c . Logger . Info ( "Python finished running" )
if err = os . MkdirAll ( result_path , os . ModePerm ) ; err != nil {
return
}
accuracy_file , err := os . Open ( path . Join ( run_path , "accuracy.val" ) )
if err != nil {
return
}
defer accuracy_file . Close ( )
accuracy_file_bytes , err := io . ReadAll ( accuracy_file )
if err != nil {
return
}
accuracy , err = strconv . ParseFloat ( string ( accuracy_file_bytes ) , 64 )
if err != nil {
return
}
c . Logger . Info ( "Model finished training!" , "accuracy" , accuracy )
return
}
2023-10-19 10:44:13 +01:00
// remove returns lst without the element at index i.
//
// When i is out of range (negative or >= len(lst)) lst is returned unchanged.
// For an interior index the removal shifts the tail of lst down inside the
// same backing array, so callers should use the returned slice and stop using
// the one they passed in.
func remove[T any](lst []T, i int) []T {
	if i < 0 || i >= len(lst) {
		// Nothing to remove. Returning an empty slice here, as the function
		// used to, silently threw away every element.
		return lst
	}
	if i == 0 {
		// Dropping the head needs no copying.
		return lst[1:]
	}
	return append(lst[:i], lst[i+1:]...)
}
2023-10-22 23:02:39 +01:00
// TrainModelRow holds the model_definition columns the training loops read
// (id, target_accuracy, epoch) plus the accuracy measured in the latest round.
type TrainModelRow struct {
	id              string
	target_accuracy int
	epoch           int
	acuracy         float64
}

// TraingModelRowDefinitions is a list of TrainModelRow that implements
// sort.Interface, ordering rows by ascending acuracy.
type TraingModelRowDefinitions []TrainModelRow

// Len returns the number of rows.
func (defs TraingModelRowDefinitions) Len() int {
	return len(defs)
}

// Swap exchanges the rows at positions i and j.
func (defs TraingModelRowDefinitions) Swap(i, j int) {
	defs[j], defs[i] = defs[i], defs[j]
}

// Less reports whether row i has a lower acuracy than row j.
func (defs TraingModelRowDefinitions) Less(i, j int) bool {
	return defs[j].acuracy > defs[i].acuracy
}
// ToRemoveList is a list of slice indexes that implements sort.Interface in
// ascending numeric order.
type ToRemoveList []int

// Len returns the number of indexes in the list.
func (idx ToRemoveList) Len() int {
	return len(idx)
}

// Swap exchanges the indexes at positions i and j.
func (idx ToRemoveList) Swap(i, j int) {
	idx[j], idx[i] = idx[i], idx[j]
}

// Less reports whether the index at position i is smaller than the one at j.
func (idx ToRemoveList) Less(i, j int) bool {
	return idx[j] > idx[i]
}
2023-10-06 12:13:19 +01:00
func trainModel ( c * Context , model * BaseModel ) {
2023-10-19 10:44:13 +01:00
definitionsRows , err := c . Db . Query ( "select id, target_accuracy, epoch from model_definition where status=$1 and model_id=$2" , MODEL_DEFINITION_STATUS_INIT , model . Id )
2023-09-27 21:20:39 +01:00
if err != nil {
2023-10-06 12:13:19 +01:00
c . Logger . Error ( "Failed to trainModel!Err:" )
c . Logger . Error ( err )
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2023-09-27 21:20:39 +01:00
return
}
defer definitionsRows . Close ( )
2023-10-22 23:02:39 +01:00
var definitions TraingModelRowDefinitions = [ ] TrainModelRow { }
2023-09-27 21:20:39 +01:00
for definitionsRows . Next ( ) {
2023-10-22 23:02:39 +01:00
var rowv TrainModelRow
rowv . acuracy = 0
2023-10-19 10:44:13 +01:00
if err = definitionsRows . Scan ( & rowv . id , & rowv . target_accuracy , & rowv . epoch ) ; err != nil {
2023-10-10 12:28:49 +01:00
c . Logger . Error ( "Failed to train Model Could not read definition from db!Err:" )
c . Logger . Error ( err )
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2023-09-27 21:20:39 +01:00
return
}
definitions = append ( definitions , rowv )
}
if len ( definitions ) == 0 {
2023-10-10 12:28:49 +01:00
c . Logger . Error ( "No Definitions defined!" )
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2023-09-27 21:20:39 +01:00
return
}
2023-10-19 10:44:13 +01:00
firstRound := true
2023-10-22 23:02:39 +01:00
finished := false
2023-10-19 10:44:13 +01:00
for {
2023-10-22 23:02:39 +01:00
var toRemove ToRemoveList = [ ] int { }
2023-10-19 10:44:13 +01:00
for i , def := range definitions {
ModelDefinitionUpdateStatus ( c , def . id , MODEL_DEFINITION_STATUS_TRAINING )
accuracy , err := trainDefinition ( c , model , def . id , ! firstRound )
if err != nil {
c . Logger . Error ( "Failed to train definition!Err:" , "err" , err )
ModelDefinitionUpdateStatus ( c , def . id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
2023-10-22 23:02:39 +01:00
toRemove = append ( toRemove , i )
2023-10-19 10:44:13 +01:00
continue
}
def . epoch += EPOCH_PER_RUN
2023-10-21 00:26:52 +01:00
accuracy = accuracy * 100
2023-10-24 22:35:11 +01:00
def . acuracy = float64 ( accuracy )
definitions [ i ] . epoch += EPOCH_PER_RUN
definitions [ i ] . acuracy = accuracy
2023-10-19 10:44:13 +01:00
2023-10-21 00:26:52 +01:00
if accuracy >= float64 ( def . target_accuracy ) {
2023-10-19 10:44:13 +01:00
c . Logger . Info ( "Found a definition that reaches target_accuracy!" )
2023-10-21 00:26:52 +01:00
_ , err = c . Db . Exec ( "update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4" , accuracy , MODEL_DEFINITION_STATUS_TRANIED , def . epoch , def . id )
2023-10-19 10:44:13 +01:00
if err != nil {
c . Logger . Error ( "Failed to train definition!Err:\n" , "err" , err )
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
return
}
_ , err = c . Db . Exec ( "update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4" , MODEL_DEFINITION_STATUS_CANCELD_TRAINING , def . id , model . Id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
if err != nil {
c . Logger . Error ( "Failed to train definition!Err:\n" , "err" , err )
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
return
}
2023-10-22 23:02:39 +01:00
finished = true
2023-10-19 10:44:13 +01:00
break
}
2023-09-27 21:20:39 +01:00
2023-10-19 10:44:13 +01:00
if def . epoch > MAX_EPOCH {
2023-10-21 00:26:52 +01:00
fmt . Printf ( "Failed to train definition! Accuracy less %f < %d\n" , accuracy , def . target_accuracy )
2023-10-19 10:44:13 +01:00
ModelDefinitionUpdateStatus ( c , def . id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
2023-10-22 23:02:39 +01:00
toRemove = append ( toRemove , i )
2023-10-19 10:44:13 +01:00
continue
}
2023-09-27 21:20:39 +01:00
2023-10-22 23:02:39 +01:00
_ , err = c . Db . Exec ( "update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4" , accuracy , def . epoch , MODEL_DEFINITION_STATUS_PAUSED_TRAINING , def . id )
if err != nil {
c . Logger . Error ( "Failed to train definition!Err:\n" , "err" , err )
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
return
}
2023-09-27 21:20:39 +01:00
}
2023-10-22 23:02:39 +01:00
2023-10-19 10:44:13 +01:00
firstRound = false
2023-10-22 23:02:39 +01:00
if finished {
break
}
2023-10-25 14:50:58 +01:00
sort . Sort ( sort . Reverse ( toRemove ) )
2023-10-22 23:02:39 +01:00
c . Logger . Info ( "Round done" , "toRemove" , toRemove )
for _ , n := range toRemove {
definitions = remove ( definitions , n )
}
len_def := len ( definitions )
if len_def == 0 {
2023-10-19 10:44:13 +01:00
break
2023-09-27 21:20:39 +01:00
}
2023-10-22 23:02:39 +01:00
if len_def == 1 {
continue
}
2023-10-25 14:50:58 +01:00
sort . Sort ( sort . Reverse ( definitions ) )
2023-10-22 23:02:39 +01:00
2023-10-24 22:35:11 +01:00
acc := definitions [ 0 ] . acuracy - 20.0
2023-10-22 23:02:39 +01:00
2023-10-24 22:35:11 +01:00
c . Logger . Info ( "Training models, Highest acc" , "acc" , definitions [ 0 ] . acuracy , "mod_acc" , acc )
2023-10-22 23:02:39 +01:00
toRemove = [ ] int { }
for i , def := range definitions {
if def . acuracy < acc {
toRemove = append ( toRemove , i )
}
}
c . Logger . Info ( "Removing due to accuracy" , "toRemove" , toRemove )
2023-10-25 14:50:58 +01:00
sort . Sort ( sort . Reverse ( toRemove ) )
2023-10-22 23:02:39 +01:00
for _ , n := range toRemove {
c . Logger . Warn ( "Removing definition not fast enough learning" , "n" , n )
2023-10-25 14:50:58 +01:00
ModelDefinitionUpdateStatus ( c , definitions [ n ] . id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
2023-10-22 23:02:39 +01:00
definitions = remove ( definitions , n )
}
2023-09-27 21:20:39 +01:00
}
2023-10-06 12:13:19 +01:00
rows , err := c . Db . Query ( "select id from model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;" , model . Id , MODEL_DEFINITION_STATUS_TRANIED )
2023-09-27 21:20:39 +01:00
if err != nil {
2023-10-10 12:28:49 +01:00
c . Logger . Error ( "DB: failed to read definition" )
c . Logger . Error ( err )
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2023-09-27 21:20:39 +01:00
return
}
2023-10-10 12:28:49 +01:00
defer rows . Close ( )
2023-09-27 21:20:39 +01:00
2023-10-10 12:28:49 +01:00
if ! rows . Next ( ) {
// TODO Make the Model status have a message
c . Logger . Error ( "All definitions failed to train!" )
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2023-09-27 21:20:39 +01:00
return
2023-10-10 12:28:49 +01:00
}
2023-09-27 21:20:39 +01:00
2023-10-10 12:28:49 +01:00
var id string
if err = rows . Scan ( & id ) ; err != nil {
c . Logger . Error ( "Failed to read id:" )
c . Logger . Error ( err )
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2023-09-27 21:20:39 +01:00
return
2023-10-10 12:28:49 +01:00
}
2023-09-27 18:07:04 +01:00
2023-10-10 12:28:49 +01:00
if _ , err = c . Db . Exec ( "update model_definition set status=$1 where id=$2;" , MODEL_DEFINITION_STATUS_READY , id ) ; err != nil {
c . Logger . Error ( "Failed to update model definition" )
c . Logger . Error ( err )
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2023-09-27 21:20:39 +01:00
return
2023-10-10 12:28:49 +01:00
}
2023-09-27 21:20:39 +01:00
2023-10-10 12:28:49 +01:00
to_delete , err := c . Db . Query ( "select id from model_definition where status != $1 and model_id=$2" , MODEL_DEFINITION_STATUS_READY , model . Id )
if err != nil {
c . Logger . Error ( "Failed to select model_definition to delete" )
c . Logger . Error ( err )
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2023-09-27 21:20:39 +01:00
return
2023-10-10 12:28:49 +01:00
}
defer to_delete . Close ( )
for to_delete . Next ( ) {
var id string
if to_delete . Scan ( & id ) ; err != nil {
c . Logger . Error ( "Failed to scan the id of a model_definition to delete" )
c . Logger . Error ( err )
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
return
}
os . RemoveAll ( path . Join ( "savedData" , model . Id , "defs" , id ) )
}
// TODO Check if returning also works here
if _ , err = c . Db . Exec ( "delete from model_definition where status!=$1 and model_id=$2;" , MODEL_DEFINITION_STATUS_READY , model . Id ) ; err != nil {
c . Logger . Error ( "Failed to delete model_definition" )
c . Logger . Error ( err )
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2023-09-27 21:20:39 +01:00
return
2023-10-10 12:28:49 +01:00
}
2023-09-27 18:07:04 +01:00
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , READY )
2023-09-27 18:07:04 +01:00
}
2024-01-31 21:48:35 +00:00
func trainModelExp ( c * Context , model * BaseModel ) {
2024-02-02 16:16:26 +00:00
var err error = nil
2024-01-31 21:48:35 +00:00
failed := func ( msg string ) {
c . Logger . Error ( msg , "err" , err )
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
}
definitionsRows , err := c . Db . Query ( "select id, target_accuracy, epoch from model_definition where status=$1 and model_id=$2" , MODEL_DEFINITION_STATUS_INIT , model . Id )
if err != nil {
2024-02-02 16:16:26 +00:00
failed ( "Failed to trainModel!" )
2024-01-31 21:48:35 +00:00
return
}
defer definitionsRows . Close ( )
var definitions TraingModelRowDefinitions = [ ] TrainModelRow { }
for definitionsRows . Next ( ) {
var rowv TrainModelRow
rowv . acuracy = 0
if err = definitionsRows . Scan ( & rowv . id , & rowv . target_accuracy , & rowv . epoch ) ; err != nil {
2024-02-05 16:42:23 +00:00
failed ( "Failed to train Model Could not read definition from db!" )
2024-01-31 21:48:35 +00:00
return
}
definitions = append ( definitions , rowv )
}
if len ( definitions ) == 0 {
2024-02-05 16:42:23 +00:00
failed ( "No Definitions defined!" )
2024-01-31 21:48:35 +00:00
return
}
firstRound := true
finished := false
for {
var toRemove ToRemoveList = [ ] int { }
for i , def := range definitions {
ModelDefinitionUpdateStatus ( c , def . id , MODEL_DEFINITION_STATUS_TRAINING )
2024-02-02 16:16:26 +00:00
accuracy , err := trainDefinitionExp ( c , model , def . id , ! firstRound )
2024-01-31 21:48:35 +00:00
if err != nil {
c . Logger . Error ( "Failed to train definition!Err:" , "err" , err )
ModelDefinitionUpdateStatus ( c , def . id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
toRemove = append ( toRemove , i )
continue
}
def . epoch += EPOCH_PER_RUN
accuracy = accuracy * 100
def . acuracy = float64 ( accuracy )
definitions [ i ] . epoch += EPOCH_PER_RUN
definitions [ i ] . acuracy = accuracy
if accuracy >= float64 ( def . target_accuracy ) {
c . Logger . Info ( "Found a definition that reaches target_accuracy!" )
_ , err = c . Db . Exec ( "update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4" , accuracy , MODEL_DEFINITION_STATUS_TRANIED , def . epoch , def . id )
if err != nil {
2024-02-05 16:42:23 +00:00
failed ( "Failed to train definition!" )
2024-01-31 21:48:35 +00:00
return
}
_ , err = c . Db . Exec ( "update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4" , MODEL_DEFINITION_STATUS_CANCELD_TRAINING , def . id , model . Id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
if err != nil {
2024-02-05 16:42:23 +00:00
failed ( "Failed to train definition!" )
2024-01-31 21:48:35 +00:00
return
}
finished = true
break
}
if def . epoch > MAX_EPOCH {
fmt . Printf ( "Failed to train definition! Accuracy less %f < %d\n" , accuracy , def . target_accuracy )
ModelDefinitionUpdateStatus ( c , def . id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
toRemove = append ( toRemove , i )
continue
}
_ , err = c . Db . Exec ( "update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4" , accuracy , def . epoch , MODEL_DEFINITION_STATUS_PAUSED_TRAINING , def . id )
if err != nil {
2024-02-05 16:42:23 +00:00
failed ( "Failed to train definition!" )
2024-01-31 21:48:35 +00:00
return
}
}
firstRound = false
if finished {
break
}
sort . Sort ( sort . Reverse ( toRemove ) )
c . Logger . Info ( "Round done" , "toRemove" , toRemove )
for _ , n := range toRemove {
definitions = remove ( definitions , n )
}
len_def := len ( definitions )
if len_def == 0 {
break
}
if len_def == 1 {
continue
}
sort . Sort ( sort . Reverse ( definitions ) )
acc := definitions [ 0 ] . acuracy - 20.0
c . Logger . Info ( "Training models, Highest acc" , "acc" , definitions [ 0 ] . acuracy , "mod_acc" , acc )
toRemove = [ ] int { }
for i , def := range definitions {
if def . acuracy < acc {
toRemove = append ( toRemove , i )
}
}
c . Logger . Info ( "Removing due to accuracy" , "toRemove" , toRemove )
sort . Sort ( sort . Reverse ( toRemove ) )
for _ , n := range toRemove {
c . Logger . Warn ( "Removing definition not fast enough learning" , "n" , n )
ModelDefinitionUpdateStatus ( c , definitions [ n ] . id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
definitions = remove ( definitions , n )
}
}
rows , err := c . Db . Query ( "select id from model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;" , model . Id , MODEL_DEFINITION_STATUS_TRANIED )
if err != nil {
2024-02-05 16:42:23 +00:00
failed ( "DB: failed to read definition" )
2024-01-31 21:48:35 +00:00
return
}
defer rows . Close ( )
if ! rows . Next ( ) {
2024-02-05 16:42:23 +00:00
failed ( "All definitions failed to train!" )
2024-01-31 21:48:35 +00:00
return
}
var id string
if err = rows . Scan ( & id ) ; err != nil {
2024-02-05 16:42:23 +00:00
failed ( "Failed to read id" )
2024-01-31 21:48:35 +00:00
return
}
if _ , err = c . Db . Exec ( "update model_definition set status=$1 where id=$2;" , MODEL_DEFINITION_STATUS_READY , id ) ; err != nil {
2024-02-05 16:42:23 +00:00
failed ( "Failed to update model definition" )
2024-01-31 21:48:35 +00:00
return
}
to_delete , err := c . Db . Query ( "select id from model_definition where status != $1 and model_id=$2" , MODEL_DEFINITION_STATUS_READY , model . Id )
if err != nil {
2024-02-05 16:42:23 +00:00
failed ( "Failed to select model_definition to delete" )
2024-01-31 21:48:35 +00:00
return
}
defer to_delete . Close ( )
for to_delete . Next ( ) {
var id string
if to_delete . Scan ( & id ) ; err != nil {
2024-02-05 16:42:23 +00:00
failed ( "Failed to scan the id of a model_definition to delete" )
2024-01-31 21:48:35 +00:00
return
}
os . RemoveAll ( path . Join ( "savedData" , model . Id , "defs" , id ) )
}
// TODO Check if returning also works here
if _ , err = c . Db . Exec ( "delete from model_definition where status!=$1 and model_id=$2;" , MODEL_DEFINITION_STATUS_READY , model . Id ) ; err != nil {
2024-02-05 16:42:23 +00:00
failed ( "Failed to delete model_definition" )
2024-01-31 21:48:35 +00:00
return
}
2024-02-12 14:30:43 +00:00
if err = splitModel ( c , model ) ; err != nil {
failed ( "Failed to split the model" )
return
}
2024-01-31 21:48:35 +00:00
ModelUpdateStatus ( c , model . Id , READY )
}
2024-02-12 14:30:43 +00:00
// splitModel renders views/py/python_split_model_template.py into a run.py
// inside a run folder under /tmp and executes it with python.
//
// It looks up a model_definition for the model and an exp_model_head for that
// definition through GetDBOnce, creates the "base" and "head/<head id>"
// folders under the definition's savedData folder, and passes their paths,
// the path of model.keras and "SplitLen" to the template. SplitLen is the
// number of the definition's layers with exp_type 1 (the "fixed" layers per
// the original comment) minus one; a definition with no such layer is an
// error.
//
// The run folder is currently left in place (see the TODO below).
func splitModel(c *Context, model *BaseModel) (err error) {
	type Def struct {
		Id string
	}

	def := Def{}
	if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
		return err
	}

	head := Def{}
	if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil {
		return err
	}

	// Generate run folder
	run_path := path.Join("/tmp", model.Id, "defs", def.Id)
	if err = os.MkdirAll(run_path, os.ModePerm); err != nil {
		return err
	}

	// TODO reneable it
	// defer os.RemoveAll(run_path)

	// Create python script
	f, err := os.Create(path.Join(run_path, "run.py"))
	if err != nil {
		return err
	}
	defer f.Close()

	tmpl, err := template.New("python_split_model_template.py").ParseFiles("views/py/python_split_model_template.py")
	if err != nil {
		return err
	}

	// Copy result around
	result_path := path.Join(getDir(), "savedData", model.Id, "defs", def.Id)

	// TODO maybe move this to a select count(*)
	// Get only fixed layers
	layers, err := c.Db.Query("select exp_type from model_definition_layer where def_id=$1 and exp_type=$2 order by layer_order asc;", def.Id, 1)
	if err != nil {
		return err
	}
	defer layers.Close()

	// Starts at -1, so after the loop it is (number of rows - 1) and stays -1
	// when there are no rows.
	count := -1
	for layers.Next() {
		count += 1
	}
	// A failed iteration would otherwise yield a silently wrong split point.
	if err = layers.Err(); err != nil {
		return err
	}

	if count == -1 {
		return errors.New("Can not get layers")
	}

	log.Warn("Spliting model", "def", def.Id, "head", head.Id, "count", count)

	basePath := path.Join(result_path, "base")
	headPath := path.Join(result_path, "head", head.Id)

	if err = os.MkdirAll(basePath, os.ModePerm); err != nil {
		return err
	}
	if err = os.MkdirAll(headPath, os.ModePerm); err != nil {
		return err
	}

	if err = tmpl.Execute(f, AnyMap{
		"SplitLen":      count,
		"ModelPath":     path.Join(result_path, "model.keras"),
		"BaseModelPath": basePath,
		"HeadModelPath": headPath,
	}); err != nil {
		return err
	}

	out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput()
	if err != nil {
		c.Logger.Debug(string(out))
		return err
	}

	c.Logger.Info("Python finished running")
	return nil
}
2023-10-20 12:37:56 +01:00
func removeFailedDataPoints ( c * Context , model * BaseModel ) ( err error ) {
rows , err := c . Db . Query ( "select mdp.id from model_data_point as mdp join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 and mdp.status=-1;" , model . Id )
2023-10-10 12:28:49 +01:00
if err != nil {
return
}
defer rows . Close ( )
2023-10-19 10:44:13 +01:00
base_path := path . Join ( "savedData" , model . Id , "data" )
2023-10-10 12:28:49 +01:00
for rows . Next ( ) {
var dataPointId string
err = rows . Scan ( & dataPointId )
if err != nil {
return
}
2023-10-20 12:37:56 +01:00
2023-10-21 00:26:52 +01:00
p := path . Join ( base_path , dataPointId + "." + model . Format )
c . Logger . Warn ( "Removing image" , "path" , p )
2023-10-20 12:37:56 +01:00
err = os . RemoveAll ( p )
2023-10-19 10:44:13 +01:00
if err != nil {
return
}
2023-10-10 12:28:49 +01:00
}
2023-10-20 12:37:56 +01:00
_ , err = c . Db . Exec ( "delete from model_data_point as mdp using model_classes as mc where mdp.class_id = mc.id and mc.model_id=$1 and mdp.status=-1;" , model . Id )
2023-10-10 12:28:49 +01:00
return
}
2023-10-19 11:42:38 +01:00
// This generates a definition
2023-10-21 00:26:52 +01:00
func generateDefinition ( c * Context , model * BaseModel , target_accuracy int , number_of_classes int , complexity int ) * Error {
var err error = nil
failed := func ( ) * Error {
2023-10-19 11:42:38 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_PREPARING_TRAINING )
// TODO improve this response
return c . Error500 ( err )
2023-10-21 00:26:52 +01:00
}
2023-10-19 11:42:38 +01:00
2023-10-21 00:26:52 +01:00
def_id , err := MakeDefenition ( c . Db , model . Id , target_accuracy )
if err != nil {
return failed ( )
}
2023-10-19 11:42:38 +01:00
2023-10-22 23:02:39 +01:00
order := 1
2023-10-21 00:26:52 +01:00
2023-10-22 23:02:39 +01:00
// Note the shape of the first layer defines the import size
if complexity == 2 {
// Note the shape for now is no used
width := int ( math . Pow ( 2 , math . Floor ( math . Log ( float64 ( model . Width ) ) / math . Log ( 2.0 ) ) ) )
height := int ( math . Pow ( 2 , math . Floor ( math . Log ( float64 ( model . Height ) ) / math . Log ( 2.0 ) ) ) )
c . Logger . Warn ( "Complexity 2 creating model with smaller size" , "width" , width , "height" , height )
err = MakeLayer ( c . Db , def_id , order , LAYER_INPUT , fmt . Sprintf ( "%d,%d,1" , width , height ) )
if err != nil {
return failed ( )
}
order ++
} else {
err = MakeLayer ( c . Db , def_id , order , LAYER_INPUT , fmt . Sprintf ( "%d,%d,1" , model . Width , model . Height ) )
if err != nil {
return failed ( )
}
order ++
2023-10-21 00:26:52 +01:00
}
2023-10-19 11:42:38 +01:00
2023-10-21 00:26:52 +01:00
if complexity == 0 {
2023-10-19 11:42:38 +01:00
2023-10-21 00:26:52 +01:00
err = MakeLayer ( c . Db , def_id , order , LAYER_FLATTEN , "" )
2023-10-19 11:42:38 +01:00
if err != nil {
2023-10-21 00:26:52 +01:00
return failed ( )
2023-10-19 11:42:38 +01:00
}
2023-10-22 23:02:39 +01:00
order ++
2023-10-19 11:42:38 +01:00
2023-10-21 12:01:10 +01:00
loop := int ( math . Log2 ( float64 ( number_of_classes ) ) )
2023-10-21 00:26:52 +01:00
for i := 0 ; i < loop ; i ++ {
err = MakeLayer ( c . Db , def_id , order , LAYER_DENSE , fmt . Sprintf ( "%d,1" , number_of_classes * ( loop - i ) ) )
2023-10-22 23:02:39 +01:00
order ++
2023-10-21 00:26:52 +01:00
if err != nil {
ModelUpdateStatus ( c , model . Id , FAILED_PREPARING_TRAINING )
// TODO improve this response
return c . Error500 ( err )
}
}
2024-01-31 21:48:35 +00:00
} else if complexity == 1 || complexity == 2 {
2023-10-21 00:26:52 +01:00
2023-10-22 23:02:39 +01:00
loop := int ( ( math . Log ( float64 ( model . Width ) ) / math . Log ( float64 ( 10 ) ) ) )
if loop == 0 {
loop = 1
}
2023-10-21 00:26:52 +01:00
for i := 0 ; i < loop ; i ++ {
err = MakeLayer ( c . Db , def_id , order , LAYER_SIMPLE_BLOCK , "" )
2023-10-22 23:02:39 +01:00
order ++
2023-10-21 00:26:52 +01:00
if err != nil {
2023-10-22 23:02:39 +01:00
return failed ( )
2023-10-21 00:26:52 +01:00
}
}
err = MakeLayer ( c . Db , def_id , order , LAYER_FLATTEN , "" )
if err != nil {
return failed ( )
}
2023-10-22 23:02:39 +01:00
order ++
2023-10-21 00:26:52 +01:00
2023-10-22 23:02:39 +01:00
loop = int ( ( math . Log ( float64 ( number_of_classes ) ) / math . Log ( float64 ( 10 ) ) ) / 2 )
if loop == 0 {
loop = 1
}
2023-10-21 00:26:52 +01:00
for i := 0 ; i < loop ; i ++ {
err = MakeLayer ( c . Db , def_id , order , LAYER_DENSE , fmt . Sprintf ( "%d,1" , number_of_classes * ( loop - i ) ) )
2023-10-22 23:02:39 +01:00
order ++
2023-10-21 00:26:52 +01:00
if err != nil {
2023-10-22 23:02:39 +01:00
return failed ( )
}
}
2023-10-21 00:26:52 +01:00
} else {
c . Logger . Error ( "Unkown complexity" , "complexity" , complexity )
return failed ( )
}
2023-10-19 11:42:38 +01:00
2023-10-21 00:26:52 +01:00
err = ModelDefinitionUpdateStatus ( c , def_id , MODEL_DEFINITION_STATUS_INIT )
if err != nil {
return failed ( )
}
2023-10-19 11:42:38 +01:00
2023-10-21 00:26:52 +01:00
return nil
2023-10-19 11:42:38 +01:00
}
2023-10-21 00:26:52 +01:00
func generateDefinitions ( c * Context , model * BaseModel , target_accuracy int , number_of_models int ) * Error {
2023-10-19 11:42:38 +01:00
cls , err := model_classes . ListClasses ( c . Db , model . Id )
if err != nil {
ModelUpdateStatus ( c , model . Id , FAILED_PREPARING_TRAINING )
// TODO improve this response
return c . Error500 ( err )
}
2023-10-20 12:37:56 +01:00
err = removeFailedDataPoints ( c , model )
2023-10-19 11:42:38 +01:00
if err != nil {
return c . Error500 ( err )
}
2023-10-22 23:02:39 +01:00
cls_len := len ( cls )
2023-10-19 11:42:38 +01:00
2023-10-22 23:02:39 +01:00
if number_of_models == 1 {
if model . Width < 100 && model . Height < 100 && cls_len < 30 {
generateDefinition ( c , model , target_accuracy , cls_len , 0 )
} else if model . Width > 100 && model . Height > 100 {
generateDefinition ( c , model , target_accuracy , cls_len , 2 )
} else {
generateDefinition ( c , model , target_accuracy , cls_len , 1 )
}
} else if number_of_models == 3 {
for i := 0 ; i < number_of_models ; i ++ {
generateDefinition ( c , model , target_accuracy , cls_len , i )
}
} else {
// TODO handle incrisea the complexity
for i := 0 ; i < number_of_models ; i ++ {
generateDefinition ( c , model , target_accuracy , cls_len , 0 )
}
}
2023-10-21 00:26:52 +01:00
return nil
2023-10-19 11:42:38 +01:00
}
2024-01-31 21:48:35 +00:00
func CreateExpModelHead ( c * Context , def_id string , range_start int , range_end int , status ModelDefinitionStatus ) ( id string , err error ) {
2024-02-08 18:20:58 +00:00
rows , err := c . Db . Query ( "insert into exp_model_head (def_id, range_start, range_end, status) values ($1, $2, $3, $4) returning id" , def_id , range_start , range_end , status )
2024-01-31 21:48:35 +00:00
if err != nil {
return
}
defer rows . Close ( )
2024-02-02 16:16:26 +00:00
if ! rows . Next ( ) {
c . Logger . Error ( "Could not get status of model definition" )
err = errors . New ( "Could not get status of model definition" )
return
}
2024-01-31 21:48:35 +00:00
err = rows . Scan ( & id )
2024-02-02 16:16:26 +00:00
if err != nil {
return
}
2024-01-31 21:48:35 +00:00
return
}
// ExpModelHeadUpdateStatus sets the status of the exp_model_head row with the
// given id and returns the error of the update statement, if any.
func ExpModelHeadUpdateStatus(db *sql.DB, id string, status ModelDefinitionStatus) (err error) {
	// Heads live in exp_model_head; updating model_definition with a head id
	// would leave the head untouched.
	_, err = db.Exec("update exp_model_head set status = $1 where id = $2", status, id)
	return
}
// This generates a definition
func generateExpandableDefinition ( c * Context , model * BaseModel , target_accuracy int , number_of_classes int , complexity int ) * Error {
2024-02-08 18:20:58 +00:00
c . Logger . Info ( "Generating expandable new definition for model" , "id" , model . Id , "complexity" , complexity )
2024-01-31 21:48:35 +00:00
var err error = nil
failed := func ( ) * Error {
ModelUpdateStatus ( c , model . Id , FAILED_PREPARING_TRAINING )
// TODO improve this response
return c . Error500 ( err )
}
if complexity == 0 {
2024-02-02 16:16:26 +00:00
return failed ( )
2024-01-31 21:48:35 +00:00
}
def_id , err := MakeDefenition ( c . Db , model . Id , target_accuracy )
if err != nil {
return failed ( )
}
order := 1
2024-02-02 16:16:26 +00:00
width := model . Width
height := model . Height
2024-01-31 21:48:35 +00:00
// Note the shape of the first layer defines the import size
if complexity == 2 {
// Note the shape for now is no used
width := int ( math . Pow ( 2 , math . Floor ( math . Log ( float64 ( model . Width ) ) / math . Log ( 2.0 ) ) ) )
height := int ( math . Pow ( 2 , math . Floor ( math . Log ( float64 ( model . Height ) ) / math . Log ( 2.0 ) ) ) )
c . Logger . Warn ( "Complexity 2 creating model with smaller size" , "width" , width , "height" , height )
2024-02-02 16:16:26 +00:00
}
2024-01-31 21:48:35 +00:00
2024-02-02 16:16:26 +00:00
err = MakeLayerExpandable ( c . Db , def_id , order , LAYER_INPUT , fmt . Sprintf ( "%d,%d,1" , width , height ) , 1 )
2024-01-31 21:48:35 +00:00
order ++
2024-02-02 16:16:26 +00:00
// handle the errors inside the pervious if block
if err != nil {
return failed ( )
}
2024-01-31 21:48:35 +00:00
2024-02-02 16:16:26 +00:00
// Create the blocks
loop := int ( ( math . Log ( float64 ( model . Width ) ) / math . Log ( float64 ( 10 ) ) ) )
2024-02-08 18:20:58 +00:00
if model . Width < 50 && model . Height < 50 {
loop = 0
}
log . Info ( "Size of the simple block" , "loop" , loop )
//loop = max(loop, 3)
2024-01-31 21:48:35 +00:00
2024-02-02 16:16:26 +00:00
for i := 0 ; i < loop ; i ++ {
err = MakeLayerExpandable ( c . Db , def_id , order , LAYER_SIMPLE_BLOCK , "" , 1 )
order ++
if err != nil {
return failed ( )
}
}
2024-01-31 21:48:35 +00:00
2024-02-02 16:16:26 +00:00
// Flatten the blocks into dense
err = MakeLayerExpandable ( c . Db , def_id , order , LAYER_FLATTEN , "" , 1 )
if err != nil {
return failed ( )
}
order ++
2024-01-31 21:48:35 +00:00
2024-02-02 16:16:26 +00:00
// Flatten the blocks into dense
err = MakeLayerExpandable ( c . Db , def_id , order , LAYER_DENSE , fmt . Sprintf ( "%d,1" , number_of_classes * 2 ) , 1 )
if err != nil {
return failed ( )
}
order ++
loop = int ( ( math . Log ( float64 ( number_of_classes ) ) / math . Log ( float64 ( 10 ) ) ) / 2 )
2024-02-08 18:20:58 +00:00
log . Info ( "Size of the dense layers" , "loop" , loop )
// loop = max(loop, 3)
2024-02-02 16:16:26 +00:00
for i := 0 ; i < loop ; i ++ {
err = MakeLayer ( c . Db , def_id , order , LAYER_DENSE , fmt . Sprintf ( "%d,1" , number_of_classes * ( loop - i ) ) )
order ++
if err != nil {
return failed ( )
}
}
2024-02-08 18:20:58 +00:00
2024-02-02 16:16:26 +00:00
_ , err = CreateExpModelHead ( c , def_id , 0 , number_of_classes - 1 , MODEL_DEFINITION_STATUS_INIT )
if err != nil {
return failed ( )
}
2024-01-31 21:48:35 +00:00
err = ModelDefinitionUpdateStatus ( c , def_id , MODEL_DEFINITION_STATUS_INIT )
if err != nil {
return failed ( )
}
return nil
}
// generateExpandableDefinitions lists the classes of the model, removes the
// failed data points and asks generateExpandableDefinition for
// number_of_models definitions:
//
//   - 1 model: complexity 2 when both image sides are above 100, otherwise 1;
//   - 3 models: one call for each complexity 0, 1 and 2;
//   - otherwise: number_of_models calls with complexity 1.
//
// NOTE(review): the result of every generateExpandableDefinition call is
// discarded, so a failed generation is not reported to the caller.
func generateExpandableDefinitions(c *Context, model *BaseModel, target_accuracy int, number_of_models int) *Error {
	classes, listErr := model_classes.ListClasses(c.Db, model.Id)
	if listErr != nil {
		ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
		// TODO improve this response
		return c.Error500(listErr)
	}

	if cleanupErr := removeFailedDataPoints(c, model); cleanupErr != nil {
		return c.Error500(cleanupErr)
	}

	classCount := len(classes)

	switch number_of_models {
	case 1:
		complexity := 1
		if model.Width > 100 && model.Height > 100 {
			complexity = 2
		}
		generateExpandableDefinition(c, model, target_accuracy, classCount, complexity)
	case 3:
		for complexity := 0; complexity < 3; complexity++ {
			generateExpandableDefinition(c, model, target_accuracy, classCount, complexity)
		}
	default:
		// TODO handle increasing the complexity
		for remaining := number_of_models; remaining > 0; remaining-- {
			generateExpandableDefinition(c, model, target_accuracy, classCount, 1)
		}
	}

	return nil
}
2023-09-26 20:15:28 +01:00
func handleTrain ( handle * Handle ) {
handle . Post ( "/models/train" , func ( w http . ResponseWriter , r * http . Request , c * Context ) * Error {
if ! CheckAuthLevel ( 1 , w , r , c ) {
return nil
}
if c . Mode == JSON {
panic ( "TODO /models/train JSON" )
}
r . ParseForm ( )
f := r . Form
number_of_models := 0
accuracy := 0
if ! CheckId ( f , "id" ) || CheckEmpty ( f , "model_type" ) || ! CheckNumber ( f , "number_of_models" , & number_of_models ) || ! CheckNumber ( f , "accuracy" , & accuracy ) {
// TODO improve this response
return ErrorCode ( nil , 400 , c . AddMap ( nil ) )
}
id := f . Get ( "id" )
2024-01-31 21:48:35 +00:00
model_type_id := 1
model_type_form := f . Get ( "model_type" )
if model_type_form == "expandable" {
model_type_id = 2
} else if model_type_form != "simple" {
2024-02-02 16:16:26 +00:00
return c . Error400 ( nil , "Invalid model type!" , w , "/models/edit.html" , "train-model-card" , AnyMap {
"HasData" : true ,
"ErrorMessage" : "Invalid model type!" ,
} )
2023-09-27 21:20:39 +01:00
}
2023-09-26 20:15:28 +01:00
model , err := GetBaseModel ( handle . Db , id )
if err == ModelNotFoundError {
return ErrorCode ( nil , http . StatusNotFound , c . AddMap ( AnyMap {
"NotFoundMessage" : "Model not found" ,
"GoBackLink" : "/models" ,
} ) )
} else if err != nil {
// TODO improve this response
return Error500 ( err )
}
if model . Status != CONFIRM_PRE_TRAINING {
// TODO improve this response
return ErrorCode ( nil , 400 , c . AddMap ( nil ) )
}
2024-02-02 16:16:26 +00:00
if model_type_id == 2 {
full_error := generateExpandableDefinitions ( c , model , accuracy , number_of_models )
if full_error != nil {
return full_error
}
} else {
full_error := generateDefinitions ( c , model , accuracy , number_of_models )
if full_error != nil {
return full_error
}
}
2023-09-26 20:15:28 +01:00
2024-02-02 16:16:26 +00:00
if model_type_id == 2 {
go trainModelExp ( c , model )
} else {
go trainModel ( c , model )
}
2023-09-27 18:07:04 +01:00
2024-02-02 16:16:26 +00:00
_ , err = c . Db . Exec ( "update models set status = $1, model_type = $2 where id = $3" , TRAINING , model_type_id , model . Id )
if err != nil {
fmt . Println ( "Failed to update model status" )
fmt . Println ( err )
2024-01-31 21:48:35 +00:00
// TODO improve this response
return Error500 ( err )
2024-02-02 16:16:26 +00:00
}
2023-09-27 21:20:39 +01:00
Redirect ( "/models/edit?id=" + model . Id , c . Mode , w , r )
2023-09-26 20:15:28 +01:00
return nil
} )
2023-10-12 12:08:12 +01:00
handle . Get ( "/model/epoch/update" , func ( w http . ResponseWriter , r * http . Request , c * Context ) * Error {
2023-10-19 10:44:13 +01:00
// TODO check auth level
2023-10-12 12:08:12 +01:00
if c . Mode != NORMAL {
2023-10-19 10:44:13 +01:00
// This should only handle normal requests
c . Logger . Warn ( "This function only works with normal" )
return c . UnsafeErrorCode ( nil , 400 , nil )
2023-10-12 12:08:12 +01:00
}
2023-10-19 10:44:13 +01:00
f := r . URL . Query ( )
2023-10-12 12:08:12 +01:00
2023-10-22 23:02:39 +01:00
accuracy := 0.0
2023-10-21 00:26:52 +01:00
2023-10-22 23:02:39 +01:00
if ! CheckId ( f , "model_id" ) || ! CheckId ( f , "definition" ) || CheckEmpty ( f , "epoch" ) || ! CheckFloat64 ( f , "accuracy" , & accuracy ) {
2023-10-21 00:26:52 +01:00
c . Logger . Warn ( "Invalid: model_id or definition or epoch or accuracy" )
2023-10-19 10:44:13 +01:00
return c . UnsafeErrorCode ( nil , 400 , nil )
}
2023-10-12 12:08:12 +01:00
2023-10-22 23:02:39 +01:00
accuracy = accuracy * 100
2023-10-21 00:26:52 +01:00
2023-10-12 12:08:12 +01:00
model_id := f . Get ( "model_id" )
def_id := f . Get ( "definition" )
2023-10-19 10:44:13 +01:00
epoch , err := strconv . Atoi ( f . Get ( "epoch" ) )
if err != nil {
c . Logger . Warn ( "Epoch is not a number" )
// No need to improve message because this function is only called internaly
return c . UnsafeErrorCode ( nil , 400 , nil )
}
rows , err := c . Db . Query ( "select md.status from model_definition as md where md.model_id=$1 and md.id=$2" , model_id , def_id )
if err != nil {
return c . Error500 ( err )
}
defer rows . Close ( )
if ! rows . Next ( ) {
c . Logger . Error ( "Could not get status of model definition" )
return c . Error500 ( nil )
}
var status int
err = rows . Scan ( & status )
if err != nil {
return c . Error500 ( err )
}
if status != 3 {
c . Logger . Warn ( "Definition not on status 3(training)" , "status" , status )
// No need to improve message because this function is only called internaly
return c . UnsafeErrorCode ( nil , 400 , nil )
}
2023-10-22 23:02:39 +01:00
c . Logger . Info ( "Updated model_definition!" , "model" , model_id , "progress" , epoch , "accuracy" , accuracy )
2023-10-21 00:26:52 +01:00
_ , err = c . Db . Exec ( "update model_definition set epoch_progress=$1, accuracy=$2 where id=$3" , epoch , accuracy , def_id )
2023-10-19 10:44:13 +01:00
if err != nil {
return c . Error500 ( err )
}
2023-10-12 12:08:12 +01:00
return nil
} )
2024-02-05 16:42:23 +00:00
handle . Get ( "/model/head/epoch/update" , func ( w http . ResponseWriter , r * http . Request , c * Context ) * Error {
// TODO check auth level
if c . Mode != NORMAL {
// This should only handle normal requests
c . Logger . Warn ( "This function only works with normal" )
return c . UnsafeErrorCode ( nil , 400 , nil )
}
f := r . URL . Query ( )
accuracy := 0.0
if ! CheckId ( f , "head_id" ) || CheckEmpty ( f , "epoch" ) || ! CheckFloat64 ( f , "accuracy" , & accuracy ) {
c . Logger . Warn ( "Invalid: model_id or head_id or epoch or accuracy" )
return c . UnsafeErrorCode ( nil , 400 , nil )
}
accuracy = accuracy * 100
head_id := f . Get ( "head_id" )
epoch , err := strconv . Atoi ( f . Get ( "epoch" ) )
if err != nil {
c . Logger . Warn ( "Epoch is not a number" )
// No need to improve message because this function is only called internaly
return c . UnsafeErrorCode ( nil , 400 , nil )
}
rows , err := c . Db . Query ( "select hd.status from exp_model_head as hd where hd.id=$1;" , head_id )
if err != nil {
return c . Error500 ( err )
}
defer rows . Close ( )
if ! rows . Next ( ) {
c . Logger . Error ( "Could not get status of model head" )
return c . Error500 ( nil )
}
var status int
err = rows . Scan ( & status )
if err != nil {
return c . Error500 ( err )
}
if status != 3 {
c . Logger . Warn ( "Head not on status 3(training)" , "status" , status )
// No need to improve message because this function is only called internaly
return c . UnsafeErrorCode ( nil , 400 , nil )
}
c . Logger . Info ( "Updated model_head!" , "head" , head_id , "progress" , epoch , "accuracy" , accuracy )
_ , err = c . Db . Exec ( "update exp_model_head set epoch_progress=$1, accuracy=$2 where id=$3" , epoch , accuracy , head_id )
if err != nil {
return c . Error500 ( err )
}
return nil
} )
2023-09-26 20:15:28 +01:00
}