2023-09-26 20:15:28 +01:00
package models_train
import (
"errors"
"fmt"
2023-09-27 21:20:39 +01:00
"io"
2023-10-19 11:42:38 +01:00
"math"
2023-09-27 21:20:39 +01:00
"os"
"os/exec"
"path"
2023-10-22 23:02:39 +01:00
"sort"
2023-09-27 21:20:39 +01:00
"strconv"
2024-03-09 10:52:08 +00:00
"strings"
2023-09-27 21:20:39 +01:00
"text/template"
2023-09-26 20:15:28 +01:00
2024-04-17 17:46:43 +01:00
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
2024-04-14 14:51:16 +01:00
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
2023-09-26 20:15:28 +01:00
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
2024-04-15 23:04:53 +01:00
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
2023-09-26 20:15:28 +01:00
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
2024-04-15 23:04:53 +01:00
2024-02-02 16:16:26 +00:00
"github.com/charmbracelet/log"
2024-04-15 23:04:53 +01:00
"github.com/goccy/go-json"
2023-09-26 20:15:28 +01:00
)
2023-10-25 14:50:58 +01:00
// EPOCH_PER_RUN is the number of epochs attributed to a single training run:
// it is handed to the Python templates under the "EPOCH_PER_RUN" key and is
// added to a definition's epoch counter after every successful run in
// trainModel.
const EPOCH_PER_RUN = 20
2023-10-19 10:44:13 +01:00
// MAX_EPOCH is the epoch budget of a definition: trainModel marks a
// definition as failed once its epoch counter exceeds this value without the
// target accuracy having been reached.
const MAX_EPOCH = 100
2024-03-09 10:52:08 +00:00
// shapeToSize strips the last comma-separated component from shape and
// returns what is left, commas included. A shape that contains no comma
// yields the empty string.
func shapeToSize(shape string) string {
	cut := strings.LastIndex(shape, ",")
	if cut < 0 {
		return ""
	}
	return shape[:cut]
}
// getDir returns the current working directory of the process as reported by
// os.Getwd. It panics when the working directory cannot be determined.
func getDir() string {
	cwd, err := os.Getwd()
	if err == nil {
		return cwd
	}
	panic(err)
}
2024-04-15 23:04:53 +01:00
func ModelDefinitionUpdateStatus ( c BasePack , id string , status ModelDefinitionStatus ) ( err error ) {
_ , err = c . GetDb ( ) . Exec ( "update model_definition set status = $1 where id = $2" , status , id )
2023-09-27 21:20:39 +01:00
return
2023-09-26 20:15:28 +01:00
}
2024-04-17 17:46:43 +01:00
func MakeLayer ( db db . Db , def_id string , layer_order int , layer_type LayerType , shape string ) ( err error ) {
2023-09-26 20:15:28 +01:00
_ , err = db . Exec ( "insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)" , def_id , layer_order , layer_type , shape )
return
}
2024-04-17 17:46:43 +01:00
func MakeLayerExpandable ( db db . Db , def_id string , layer_order int , layer_type LayerType , shape string , exp_type int ) ( err error ) {
2024-01-31 21:48:35 +00:00
_ , err = db . Exec ( "insert into model_definition_layer (def_id, layer_order, layer_type, shape, exp_type) values ($1, $2, $3, $4, $5)" , def_id , layer_order , layer_type , shape , exp_type )
return
}
2024-04-15 23:04:53 +01:00
func generateCvs ( c BasePack , run_path string , model_id string ) ( count int , err error ) {
db := c . GetDb ( )
2023-10-02 21:15:31 +01:00
2024-04-08 14:17:13 +01:00
var co struct {
Count int ` db:"count(*)" `
2023-10-10 12:28:49 +01:00
}
2024-04-15 23:04:53 +01:00
err = GetDBOnce ( db , & co , "model_classes where model_id=$1;" , model_id )
2024-04-08 14:17:13 +01:00
if err != nil {
2023-10-10 12:28:49 +01:00
return
}
2024-04-08 14:17:13 +01:00
count = co . Count
2023-10-10 12:28:49 +01:00
2024-04-15 23:04:53 +01:00
data , err := db . Query ( "select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2;" , model_id , DATA_POINT_MODE_TRAINING )
2023-10-10 12:28:49 +01:00
if err != nil {
return
}
defer data . Close ( )
f , err := os . Create ( path . Join ( run_path , "train.csv" ) )
if err != nil {
return
}
defer f . Close ( )
f . Write ( [ ] byte ( "Id,Index\n" ) )
for data . Next ( ) {
var id string
var class_order int
var file_path string
if err = data . Scan ( & id , & class_order , & file_path ) ; err != nil {
return
}
if file_path == "id://" {
f . Write ( [ ] byte ( id + "," + strconv . Itoa ( class_order ) + "\n" ) )
} else {
return count , errors . New ( "TODO generateCvs to file_path " + file_path )
}
}
return
2023-10-02 21:15:31 +01:00
}
2024-04-15 23:04:53 +01:00
func setModelClassStatus ( c BasePack , status ModelClassStatus , filter string , args ... any ) ( err error ) {
_ , err = c . GetDb ( ) . Exec ( fmt . Sprintf ( "update model_classes set status=%d where %s" , status , filter ) , args ... )
2024-03-06 23:33:54 +00:00
return
}
2024-05-09 00:46:42 +01:00
// SetModelClassStatus is the exported entry point for setModelClassStatus: it
// forwards all of its arguments unchanged and returns that call's error.
func SetModelClassStatus(c BasePack, status ModelClassStatus, filter string, args ...any) (err error) {
	err = setModelClassStatus(c, status, filter, args...)
	return err
}
2024-04-15 23:04:53 +01:00
func generateCvsExp ( c BasePack , run_path string , model_id string , doPanic bool ) ( count int , err error ) {
db := c . GetDb ( )
2024-03-06 23:33:54 +00:00
2024-04-08 14:17:13 +01:00
var co struct {
Count int ` db:"count(*)" `
2024-03-06 23:33:54 +00:00
}
2024-05-06 01:10:58 +01:00
err = GetDBOnce ( db , & co , "model_classes where model_id=$1 and status=$2;" , model_id , CLASS_STATUS_TRAINING )
2024-04-08 14:17:13 +01:00
if err != nil {
2024-03-06 23:33:54 +00:00
return
}
2024-04-08 14:17:13 +01:00
count = co . Count
2024-03-06 23:33:54 +00:00
if count == 0 {
2024-05-06 01:10:58 +01:00
err = setModelClassStatus ( c , CLASS_STATUS_TRAINING , "model_id=$1 and status=$2;" , model_id , CLASS_STATUS_TO_TRAIN )
2024-03-06 23:33:54 +00:00
if err != nil {
return
}
if doPanic {
return 0 , errors . New ( "No model classes available" )
}
return generateCvsExp ( c , run_path , model_id , true )
}
2024-05-06 01:10:58 +01:00
data , err := db . Query ( "select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;" , model_id , DATA_POINT_MODE_TRAINING , CLASS_STATUS_TRAINING )
2024-03-06 23:33:54 +00:00
if err != nil {
return
}
defer data . Close ( )
f , err := os . Create ( path . Join ( run_path , "train.csv" ) )
if err != nil {
return
}
defer f . Close ( )
f . Write ( [ ] byte ( "Id,Index\n" ) )
for data . Next ( ) {
var id string
var class_order int
var file_path string
if err = data . Scan ( & id , & class_order , & file_path ) ; err != nil {
return
}
if file_path == "id://" {
f . Write ( [ ] byte ( id + "," + strconv . Itoa ( class_order ) + "\n" ) )
} else {
return count , errors . New ( "TODO generateCvs to file_path " + file_path )
}
}
return
}
2024-05-09 01:23:43 +01:00
func trainDefinition ( c BasePack , model * BaseModel , def Definition , load_prev bool ) ( accuracy float64 , err error ) {
2024-04-15 23:04:53 +01:00
l := c . GetLogger ( )
l . Warn ( "About to start training definition" )
2023-09-27 21:20:39 +01:00
accuracy = 0
2024-05-09 01:23:43 +01:00
layers , err := def . GetLayers ( c . GetDb ( ) , " order by layer_order asc;" )
2023-09-27 21:20:39 +01:00
if err != nil {
return
}
2024-05-09 01:23:43 +01:00
for _ , layer := range layers {
layer . ShapeToSize ( )
2023-09-27 21:20:39 +01:00
}
// Generate run folder
2024-05-09 01:23:43 +01:00
run_path := path . Join ( "/tmp" , model . Id , "defs" , def . Id )
2023-09-27 21:20:39 +01:00
err = os . MkdirAll ( run_path , os . ModePerm )
if err != nil {
return
}
2024-03-06 23:33:54 +00:00
classCount , err := generateCvs ( c , run_path , model . Id )
2023-10-10 12:28:49 +01:00
if err != nil {
return
}
2023-10-02 21:15:31 +01:00
2023-09-27 21:20:39 +01:00
// Create python script
f , err := os . Create ( path . Join ( run_path , "run.py" ) )
if err != nil {
return
}
defer f . Close ( )
tmpl , err := template . New ( "python_model_template.py" ) . ParseFiles ( "views/py/python_model_template.py" )
if err != nil {
return
}
2023-10-19 10:44:13 +01:00
// Copy result around
2024-05-09 01:23:43 +01:00
result_path := path . Join ( "savedData" , model . Id , "defs" , def . Id )
2023-10-19 10:44:13 +01:00
2023-09-27 21:20:39 +01:00
if err = tmpl . Execute ( f , AnyMap {
2024-05-09 01:23:43 +01:00
"Layers" : layers ,
"Size" : layers [ 0 ] . Shape ,
2023-10-19 10:44:13 +01:00
"DataDir" : path . Join ( getDir ( ) , "savedData" , model . Id , "data" ) ,
"RunPath" : run_path ,
"ColorMode" : model . ImageMode ,
"Model" : model ,
"EPOCH_PER_RUN" : EPOCH_PER_RUN ,
2024-05-09 01:23:43 +01:00
"DefId" : def . Id ,
2023-10-19 10:44:13 +01:00
"LoadPrev" : load_prev ,
"LastModelRunPath" : path . Join ( getDir ( ) , result_path , "model.keras" ) ,
2023-10-22 23:02:39 +01:00
"SaveModelPath" : path . Join ( getDir ( ) , result_path ) ,
2024-03-06 23:33:54 +00:00
"Depth" : classCount ,
"StartPoint" : 0 ,
2024-04-15 23:04:53 +01:00
"Host" : c . GetHost ( ) ,
2023-09-27 21:20:39 +01:00
} ) ; err != nil {
return
}
// Run the command
2023-10-22 23:02:39 +01:00
out , err := exec . Command ( "bash" , "-c" , fmt . Sprintf ( "cd %s && python run.py" , run_path ) ) . CombinedOutput ( )
2023-10-10 12:28:49 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Debug ( string ( out ) )
2023-09-27 21:20:39 +01:00
return
}
2024-04-15 23:04:53 +01:00
l . Info ( "Python finished running" )
2023-09-27 21:20:39 +01:00
2023-10-22 23:02:39 +01:00
if err = os . MkdirAll ( result_path , os . ModePerm ) ; err != nil {
2023-10-19 10:44:13 +01:00
return
}
2023-09-27 21:20:39 +01:00
accuracy_file , err := os . Open ( path . Join ( run_path , "accuracy.val" ) )
if err != nil {
return
}
defer accuracy_file . Close ( )
accuracy_file_bytes , err := io . ReadAll ( accuracy_file )
if err != nil {
return
}
accuracy , err = strconv . ParseFloat ( string ( accuracy_file_bytes ) , 64 )
if err != nil {
return
}
2023-10-10 12:28:49 +01:00
2024-04-08 14:17:13 +01:00
os . RemoveAll ( run_path )
2024-04-15 23:04:53 +01:00
l . Info ( "Model finished training!" , "accuracy" , accuracy )
2023-09-27 21:20:39 +01:00
return
}
2024-04-16 17:48:52 +01:00
func generateCvsExpandExp ( c BasePack , run_path string , model_id string , offset int , doPanic bool ) ( count_re int , err error ) {
l , db := c . GetLogger ( ) , c . GetDb ( )
2024-04-08 14:17:13 +01:00
var co struct {
Count int ` db:"count(*)" `
}
2024-05-06 01:10:58 +01:00
err = GetDBOnce ( db , & co , "model_classes where model_id=$1 and status=$2;" , model_id , CLASS_STATUS_TRAINING )
2024-03-06 23:33:54 +00:00
if err != nil {
2024-04-08 14:17:13 +01:00
return
2024-03-06 23:33:54 +00:00
}
2024-04-16 17:48:52 +01:00
l . Info ( "test here" , "count" , co )
2024-04-08 14:17:13 +01:00
count_re = co . Count
2024-04-08 15:47:31 +01:00
count := co . Count
2024-04-08 14:17:13 +01:00
if count == 0 {
2024-05-06 01:10:58 +01:00
err = setModelClassStatus ( c , CLASS_STATUS_TRAINING , "model_id=$1 and status=$2;" , model_id , CLASS_STATUS_TO_TRAIN )
2024-04-08 14:17:13 +01:00
if err != nil {
return
} else if doPanic {
return 0 , errors . New ( "No model classes available" )
}
return generateCvsExpandExp ( c , run_path , model_id , offset , true )
}
2024-05-06 01:10:58 +01:00
data , err := db . Query ( "select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;" , model_id , DATA_POINT_MODE_TRAINING , CLASS_STATUS_TRAINING )
2024-04-08 14:17:13 +01:00
if err != nil {
return
}
defer data . Close ( )
f , err := os . Create ( path . Join ( run_path , "train.csv" ) )
if err != nil {
return
}
defer f . Close ( )
f . Write ( [ ] byte ( "Id,Index\n" ) )
count = 0
for data . Next ( ) {
var id string
var class_order int
var file_path string
if err = data . Scan ( & id , & class_order , & file_path ) ; err != nil {
return
}
if file_path == "id://" {
f . Write ( [ ] byte ( id + "," + strconv . Itoa ( class_order - offset ) + "\n" ) )
} else {
return count , errors . New ( "TODO generateCvs to file_path " + file_path )
}
count += 1
}
//
// This is to load some extra data so that the model has more things to train on
//
2024-05-06 01:10:58 +01:00
data_other , err := db . Query ( "select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;" , model_id , DATA_POINT_MODE_TRAINING , CLASS_STATUS_TRAINED , count * 10 )
2024-04-08 14:17:13 +01:00
if err != nil {
return
}
defer data_other . Close ( )
for data_other . Next ( ) {
var id string
var class_order int
var file_path string
if err = data_other . Scan ( & id , & class_order , & file_path ) ; err != nil {
return
}
if file_path == "id://" {
2024-04-14 16:51:15 +01:00
f . Write ( [ ] byte ( id + "," + strconv . Itoa ( - 2 ) + "\n" ) )
2024-04-08 14:17:13 +01:00
} else {
return count , errors . New ( "TODO generateCvs to file_path " + file_path )
}
}
return
2024-03-06 23:33:54 +00:00
}
2024-04-16 17:48:52 +01:00
func trainDefinitionExpandExp ( c BasePack , model * BaseModel , definition_id string , load_prev bool ) ( accuracy float64 , err error ) {
2024-02-02 16:16:26 +00:00
accuracy = 0
2024-04-16 17:48:52 +01:00
l := c . GetLogger ( )
l . Warn ( "About to retrain model" )
2024-02-02 16:16:26 +00:00
// Get untrained models heads
2024-04-08 14:17:13 +01:00
type ExpHead struct {
Id string
Start int ` db:"range_start" `
End int ` db:"range_end" `
}
// status = 2 (INIT) 3 (TRAINING)
2024-04-16 17:48:52 +01:00
heads , err := GetDbMultitple [ ExpHead ] ( c . GetDb ( ) , "exp_model_head where def_id=$1 and (status = 2 or status = 3)" , definition_id )
2024-02-02 16:16:26 +00:00
if err != nil {
return
2024-04-08 14:17:13 +01:00
} else if len ( heads ) == 0 {
log . Error ( "Failed to get the exp head of the model" )
return
} else if len ( heads ) != 1 {
err = errors . New ( "This training function can only train one model at the time" )
return
2024-02-02 16:16:26 +00:00
}
2024-04-08 14:17:13 +01:00
exp := heads [ 0 ]
2024-04-16 17:48:52 +01:00
l . Info ( "Got exp head" , "head" , exp )
2024-04-08 14:17:13 +01:00
2024-04-16 17:48:52 +01:00
if err = UpdateStatus ( c . GetDb ( ) , "exp_model_head" , exp . Id , MODEL_DEFINITION_STATUS_TRAINING ) ; err != nil {
2024-04-08 14:17:13 +01:00
return
2024-02-02 16:16:26 +00:00
}
2024-04-16 17:48:52 +01:00
layers , err := c . GetDb ( ) . Query ( "select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;" , definition_id )
2024-04-08 14:17:13 +01:00
if err != nil {
return
}
defer layers . Close ( )
2024-02-02 16:16:26 +00:00
2024-04-08 14:17:13 +01:00
type layerrow struct {
LayerType int
Shape string
ExpType int
LayerNum int
}
got := [ ] layerrow { }
i := 1
var last * layerrow = nil
got_2 := false
2024-04-08 15:47:31 +01:00
var first * layerrow = nil
2024-04-08 14:17:13 +01:00
for layers . Next ( ) {
var row = layerrow { }
if err = layers . Scan ( & row . LayerType , & row . Shape , & row . ExpType ) ; err != nil {
2024-02-03 12:39:22 +00:00
return
}
2024-04-08 14:17:13 +01:00
2024-04-08 15:47:31 +01:00
// Keep track of the first layer so we can keep the size of the image
if first == nil {
first = & row
}
2024-04-08 14:17:13 +01:00
row . LayerNum = i
row . Shape = shapeToSize ( row . Shape )
if row . ExpType == 2 {
if ! got_2 {
got = append ( got , * last )
got_2 = true
}
got = append ( got , row )
}
last = & row
i += 1
}
got = append ( got , layerrow {
LayerType : LAYER_DENSE ,
Shape : fmt . Sprintf ( "%d" , exp . End - exp . Start + 1 ) ,
ExpType : 2 ,
LayerNum : i ,
} )
2024-04-16 17:48:52 +01:00
l . Info ( "Got layers" , "layers" , got )
2024-04-08 14:17:13 +01:00
// Generate run folder
run_path := path . Join ( "/tmp" , model . Id + "-defs-" + definition_id + "-retrain" )
err = os . MkdirAll ( run_path , os . ModePerm )
if err != nil {
return
}
classCount , err := generateCvsExpandExp ( c , run_path , model . Id , exp . Start , false )
if err != nil {
return
}
2024-04-16 17:48:52 +01:00
l . Info ( "Generated cvs" , "classCount" , classCount )
2024-04-08 14:17:13 +01:00
// TODO update the run script
// Create python script
f , err := os . Create ( path . Join ( run_path , "run.py" ) )
if err != nil {
return
}
defer f . Close ( )
2024-04-16 17:48:52 +01:00
l . Info ( "About to run python!" )
2024-04-08 14:17:13 +01:00
tmpl , err := template . New ( "python_model_template_expand.py" ) . ParseFiles ( "views/py/python_model_template_expand.py" )
if err != nil {
return
}
// Copy result around
result_path := path . Join ( "savedData" , model . Id , "defs" , definition_id )
if err = tmpl . Execute ( f , AnyMap {
"Layers" : got ,
"Size" : first . Shape ,
"DataDir" : path . Join ( getDir ( ) , "savedData" , model . Id , "data" ) ,
"HeadId" : exp . Id ,
"RunPath" : run_path ,
"ColorMode" : model . ImageMode ,
"Model" : model ,
"EPOCH_PER_RUN" : EPOCH_PER_RUN ,
"LoadPrev" : load_prev ,
"BaseModel" : path . Join ( getDir ( ) , result_path , "base" , "model.keras" ) ,
"LastModelRunPath" : path . Join ( getDir ( ) , result_path , "head" , exp . Id , "model.keras" ) ,
"SaveModelPath" : path . Join ( getDir ( ) , result_path , "head" , exp . Id ) ,
"Depth" : classCount ,
"StartPoint" : 0 ,
2024-04-16 17:48:52 +01:00
"Host" : c . GetHost ( ) ,
2024-04-08 14:17:13 +01:00
} ) ; err != nil {
return
}
// Run the command
out , err := exec . Command ( "bash" , "-c" , fmt . Sprintf ( "cd %s && python run.py" , run_path ) ) . CombinedOutput ( )
if err != nil {
2024-04-16 17:48:52 +01:00
l . Warn ( "Python failed to run" , "err" , err , "out" , string ( out ) )
2024-04-08 14:17:13 +01:00
return
}
2024-04-16 17:48:52 +01:00
l . Info ( "Python finished running" )
2024-04-08 14:17:13 +01:00
if err = os . MkdirAll ( result_path , os . ModePerm ) ; err != nil {
return
}
accuracy_file , err := os . Open ( path . Join ( run_path , "accuracy.val" ) )
if err != nil {
return
}
defer accuracy_file . Close ( )
accuracy_file_bytes , err := io . ReadAll ( accuracy_file )
if err != nil {
return
}
accuracy , err = strconv . ParseFloat ( string ( accuracy_file_bytes ) , 64 )
if err != nil {
2024-02-03 12:39:22 +00:00
return
}
2024-04-15 23:04:53 +01:00
2024-04-08 14:17:13 +01:00
os . RemoveAll ( run_path )
2024-04-16 17:48:52 +01:00
l . Info ( "Model finished training!" , "accuracy" , accuracy )
2024-04-08 14:17:13 +01:00
return
}
2024-04-15 23:04:53 +01:00
func trainDefinitionExp ( c BasePack , model * BaseModel , definition_id string , load_prev bool ) ( accuracy float64 , err error ) {
2024-04-08 14:17:13 +01:00
accuracy = 0
2024-04-15 23:04:53 +01:00
l := c . GetLogger ( )
db := c . GetDb ( )
2024-04-08 14:17:13 +01:00
2024-04-15 23:04:53 +01:00
l . Warn ( "About to start training definition" )
2024-04-08 14:17:13 +01:00
// Get untrained models heads
type ExpHead struct {
Id string
Start int ` db:"range_start" `
End int ` db:"range_end" `
}
// status = 2 (INIT) 3 (TRAINING)
2024-04-15 23:04:53 +01:00
heads , err := GetDbMultitple [ ExpHead ] ( db , "exp_model_head where def_id=$1 and (status = 2 or status = 3)" , definition_id )
2024-04-08 14:17:13 +01:00
if err != nil {
return
} else if len ( heads ) == 0 {
log . Error ( "Failed to get the exp head of the model" )
return
} else if len ( heads ) != 1 {
2024-02-03 12:39:22 +00:00
log . Error ( "This training function can only train one model at the time" )
err = errors . New ( "This training function can only train one model at the time" )
return
}
2024-02-02 16:16:26 +00:00
2024-04-08 14:17:13 +01:00
exp := heads [ 0 ]
2024-04-15 23:04:53 +01:00
if err = UpdateStatus ( db , "exp_model_head" , exp . Id , MODEL_DEFINITION_STATUS_TRAINING ) ; err != nil {
2024-04-08 14:17:13 +01:00
return
}
2024-02-05 16:42:23 +00:00
2024-04-15 23:04:53 +01:00
layers , err := db . Query ( "select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;" , definition_id )
2024-02-02 16:16:26 +00:00
if err != nil {
return
}
defer layers . Close ( )
type layerrow struct {
LayerType int
Shape string
2024-02-03 12:39:22 +00:00
ExpType int
LayerNum int
2024-02-02 16:16:26 +00:00
}
got := [ ] layerrow { }
2024-02-03 12:39:22 +00:00
i := 1
2024-02-02 16:16:26 +00:00
for layers . Next ( ) {
var row = layerrow { }
2024-02-03 12:39:22 +00:00
if err = layers . Scan ( & row . LayerType , & row . Shape , & row . ExpType ) ; err != nil {
2024-02-02 16:16:26 +00:00
return
}
2024-02-03 12:39:22 +00:00
row . LayerNum = i
2024-02-02 16:16:26 +00:00
row . Shape = shapeToSize ( row . Shape )
got = append ( got , row )
2024-02-03 12:39:22 +00:00
i += 1
2024-02-02 16:16:26 +00:00
}
2024-02-03 12:39:22 +00:00
got = append ( got , layerrow {
LayerType : LAYER_DENSE ,
2024-04-08 14:17:13 +01:00
Shape : fmt . Sprintf ( "%d" , exp . End - exp . Start + 1 ) ,
2024-02-03 12:39:22 +00:00
ExpType : 2 ,
LayerNum : i ,
} )
2024-02-02 16:16:26 +00:00
// Generate run folder
2024-04-08 14:17:13 +01:00
run_path := path . Join ( "/tmp" , model . Id + "-defs-" + definition_id )
2024-02-02 16:16:26 +00:00
err = os . MkdirAll ( run_path , os . ModePerm )
if err != nil {
return
}
2024-03-06 23:33:54 +00:00
classCount , err := generateCvsExp ( c , run_path , model . Id , false )
2024-02-02 16:16:26 +00:00
if err != nil {
return
}
2024-02-03 12:39:22 +00:00
// TODO update the run script
2024-02-02 16:16:26 +00:00
// Create python script
f , err := os . Create ( path . Join ( run_path , "run.py" ) )
if err != nil {
return
}
defer f . Close ( )
2024-02-05 16:42:23 +00:00
tmpl , err := template . New ( "python_model_template.py" ) . ParseFiles ( "views/py/python_model_template.py" )
2024-02-02 16:16:26 +00:00
if err != nil {
return
}
// Copy result around
result_path := path . Join ( "savedData" , model . Id , "defs" , definition_id )
if err = tmpl . Execute ( f , AnyMap {
"Layers" : got ,
"Size" : got [ 0 ] . Shape ,
"DataDir" : path . Join ( getDir ( ) , "savedData" , model . Id , "data" ) ,
2024-04-08 14:17:13 +01:00
"HeadId" : exp . Id ,
2024-02-02 16:16:26 +00:00
"RunPath" : run_path ,
"ColorMode" : model . ImageMode ,
"Model" : model ,
"EPOCH_PER_RUN" : EPOCH_PER_RUN ,
"LoadPrev" : load_prev ,
"LastModelRunPath" : path . Join ( getDir ( ) , result_path , "model.keras" ) ,
"SaveModelPath" : path . Join ( getDir ( ) , result_path ) ,
2024-03-06 23:33:54 +00:00
"Depth" : classCount ,
"StartPoint" : 0 ,
2024-04-15 23:04:53 +01:00
"Host" : c . GetHost ( ) ,
2024-02-02 16:16:26 +00:00
} ) ; err != nil {
return
}
// Run the command
out , err := exec . Command ( "bash" , "-c" , fmt . Sprintf ( "cd %s && python run.py" , run_path ) ) . CombinedOutput ( )
if err != nil {
2024-04-15 23:04:53 +01:00
l . Debug ( string ( out ) )
2024-02-02 16:16:26 +00:00
return
}
2024-04-15 23:04:53 +01:00
l . Info ( "Python finished running" )
2024-02-02 16:16:26 +00:00
if err = os . MkdirAll ( result_path , os . ModePerm ) ; err != nil {
return
}
accuracy_file , err := os . Open ( path . Join ( run_path , "accuracy.val" ) )
if err != nil {
return
}
defer accuracy_file . Close ( )
accuracy_file_bytes , err := io . ReadAll ( accuracy_file )
if err != nil {
return
}
accuracy , err = strconv . ParseFloat ( string ( accuracy_file_bytes ) , 64 )
if err != nil {
return
}
2024-04-08 14:17:13 +01:00
os . RemoveAll ( run_path )
2024-04-15 23:04:53 +01:00
l . Info ( "Model finished training!" , "accuracy" , accuracy )
2024-02-02 16:16:26 +00:00
return
}
2023-10-19 10:44:13 +01:00
// remove returns lst without the element at index i.
//
// An index outside the list (negative, or >= len(lst)) leaves the list
// untouched and returns it as is. Removing the first element returns a
// sub-slice; removing any other element shifts the following elements down
// inside the backing array of lst, so callers must use the returned slice
// and stop using lst.
func remove[T any](lst []T, i int) []T {
	if i < 0 || i >= len(lst) {
		return lst
	}
	if i == 0 {
		return lst[1:]
	}
	return append(lst[:i], lst[i+1:]...)
}
2023-10-22 23:02:39 +01:00
// TrainModelRow holds the id, target accuracy, epoch count and accuracy of a
// definition.
type TrainModelRow struct {
	id              string
	target_accuracy int
	epoch           int
	acuracy         float64
}

// TraingModelRowDefinitions is a list of TrainModelRow that provides Len, Swap
// and Less so that it can be ordered by ascending acuracy.
type TraingModelRowDefinitions []TrainModelRow

// Len returns the number of rows.
func (rows TraingModelRowDefinitions) Len() int {
	return len(rows)
}

// Swap exchanges the rows at positions i and j.
func (rows TraingModelRowDefinitions) Swap(i, j int) {
	tmp := rows[i]
	rows[i] = rows[j]
	rows[j] = tmp
}

// Less reports whether the row at i has a lower acuracy than the row at j.
func (rows TraingModelRowDefinitions) Less(i, j int) bool {
	return rows[j].acuracy > rows[i].acuracy
}

// ToRemoveList is a list of slice indexes that provides Len, Swap and Less so
// that it can be ordered by ascending value.
type ToRemoveList []int

// Len returns the number of indexes.
func (idx ToRemoveList) Len() int {
	return len(idx)
}

// Swap exchanges the indexes at positions i and j.
func (idx ToRemoveList) Swap(i, j int) {
	tmp := idx[i]
	idx[i] = idx[j]
	idx[j] = tmp
}

// Less reports whether the index at i is smaller than the index at j.
func (idx ToRemoveList) Less(i, j int) bool {
	return idx[j] > idx[i]
}
2024-04-15 23:04:53 +01:00
func trainModel ( c BasePack , model * BaseModel ) ( err error ) {
db := c . GetDb ( )
l := c . GetLogger ( )
2024-03-09 09:41:16 +00:00
2024-05-09 01:23:43 +01:00
defs_ , err := model . GetDefinitions ( db , "and md.status=$2" , MODEL_DEFINITION_STATUS_INIT )
2023-09-27 21:20:39 +01:00
if err != nil {
2024-05-09 01:23:43 +01:00
l . Error ( "Failed to train Model!" , "err" , err )
2024-05-06 01:10:58 +01:00
ModelUpdateStatus ( c , model . Id , int ( FAILED_TRAINING ) )
2023-09-27 21:20:39 +01:00
return
}
2024-05-09 01:23:43 +01:00
var defs SortByAccuracyDefinitions = defs_
2023-09-27 21:20:39 +01:00
2024-05-09 01:23:43 +01:00
if len ( defs ) == 0 {
2024-04-15 23:04:53 +01:00
l . Error ( "No Definitions defined!" )
2024-05-06 01:10:58 +01:00
ModelUpdateStatus ( c , model . Id , int ( FAILED_TRAINING ) )
2023-09-27 21:20:39 +01:00
return
}
2023-10-19 10:44:13 +01:00
firstRound := true
2023-10-22 23:02:39 +01:00
finished := false
2023-10-19 10:44:13 +01:00
for {
2023-10-22 23:02:39 +01:00
var toRemove ToRemoveList = [ ] int { }
2024-05-09 01:23:43 +01:00
for i , def := range defs {
ModelDefinitionUpdateStatus ( c , def . Id , MODEL_DEFINITION_STATUS_TRAINING )
accuracy , err := trainDefinition ( c , model , * def , ! firstRound )
2023-10-19 10:44:13 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to train definition!Err:" , "err" , err )
2024-05-09 01:23:43 +01:00
ModelDefinitionUpdateStatus ( c , def . Id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
2023-10-22 23:02:39 +01:00
toRemove = append ( toRemove , i )
2023-10-19 10:44:13 +01:00
continue
}
2024-05-09 01:23:43 +01:00
def . Epoch += EPOCH_PER_RUN
2023-10-21 00:26:52 +01:00
accuracy = accuracy * 100
2024-05-09 01:23:43 +01:00
def . Accuracy = float64 ( accuracy )
2023-10-24 22:35:11 +01:00
2024-05-09 01:23:43 +01:00
if accuracy >= float64 ( def . TargetAccuracy ) {
2024-04-15 23:04:53 +01:00
l . Info ( "Found a definition that reaches target_accuracy!" )
2024-05-09 01:23:43 +01:00
_ , err = db . Exec ( "update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4" , accuracy , MODEL_DEFINITION_STATUS_TRANIED , def . Epoch , def . Id )
2023-10-19 10:44:13 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to train definition!Err:\n" , "err" , err )
2024-05-06 01:10:58 +01:00
ModelUpdateStatus ( c , model . Id , int ( FAILED_TRAINING ) )
2024-04-15 23:04:53 +01:00
return err
2023-10-19 10:44:13 +01:00
}
2024-05-09 01:23:43 +01:00
_ , err = db . Exec ( "update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4" , MODEL_DEFINITION_STATUS_CANCELD_TRAINING , def . Id , model . Id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
2023-10-19 10:44:13 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to train definition!Err:\n" , "err" , err )
2024-05-06 01:10:58 +01:00
ModelUpdateStatus ( c , model . Id , int ( FAILED_TRAINING ) )
2024-04-15 23:04:53 +01:00
return err
2023-10-19 10:44:13 +01:00
}
2023-10-22 23:02:39 +01:00
finished = true
2023-10-19 10:44:13 +01:00
break
}
2023-09-27 21:20:39 +01:00
2024-05-09 01:23:43 +01:00
if def . Epoch > MAX_EPOCH {
fmt . Printf ( "Failed to train definition! Accuracy less %f < %d\n" , accuracy , def . TargetAccuracy )
ModelDefinitionUpdateStatus ( c , def . Id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
2023-10-22 23:02:39 +01:00
toRemove = append ( toRemove , i )
2023-10-19 10:44:13 +01:00
continue
}
2023-09-27 21:20:39 +01:00
2024-05-09 01:23:43 +01:00
_ , err = db . Exec ( "update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4" , accuracy , def . Epoch , MODEL_DEFINITION_STATUS_PAUSED_TRAINING , def . Id )
2023-10-22 23:02:39 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to train definition!Err:\n" , "err" , err )
2024-05-06 01:10:58 +01:00
ModelUpdateStatus ( c , model . Id , int ( FAILED_TRAINING ) )
2024-04-15 23:04:53 +01:00
return err
2023-10-22 23:02:39 +01:00
}
2023-09-27 21:20:39 +01:00
}
2023-10-22 23:02:39 +01:00
2023-10-19 10:44:13 +01:00
firstRound = false
2023-10-22 23:02:39 +01:00
if finished {
break
}
2023-10-25 14:50:58 +01:00
sort . Sort ( sort . Reverse ( toRemove ) )
2023-10-22 23:02:39 +01:00
2024-04-15 23:04:53 +01:00
l . Info ( "Round done" , "toRemove" , toRemove )
2023-10-22 23:02:39 +01:00
for _ , n := range toRemove {
2024-05-09 01:23:43 +01:00
defs = remove ( defs , n )
2023-10-22 23:02:39 +01:00
}
2024-05-09 01:23:43 +01:00
len_def := len ( defs )
2023-10-22 23:02:39 +01:00
if len_def == 0 {
2023-10-19 10:44:13 +01:00
break
2024-05-09 01:23:43 +01:00
} else if len_def == 1 {
2023-10-22 23:02:39 +01:00
continue
}
2024-05-09 01:23:43 +01:00
sort . Sort ( sort . Reverse ( defs ) )
2023-10-22 23:02:39 +01:00
2024-05-09 01:23:43 +01:00
acc := defs [ 0 ] . Accuracy - 20.0
2023-10-22 23:02:39 +01:00
2024-05-09 01:23:43 +01:00
l . Info ( "Training models, Highest acc" , "acc" , defs [ 0 ] . Accuracy , "mod_acc" , acc )
2023-10-22 23:02:39 +01:00
toRemove = [ ] int { }
2024-05-09 01:23:43 +01:00
for i , def := range defs {
if def . Accuracy < acc {
2023-10-22 23:02:39 +01:00
toRemove = append ( toRemove , i )
}
}
2024-04-15 23:04:53 +01:00
l . Info ( "Removing due to accuracy" , "toRemove" , toRemove )
2023-10-22 23:02:39 +01:00
2023-10-25 14:50:58 +01:00
sort . Sort ( sort . Reverse ( toRemove ) )
2023-10-22 23:02:39 +01:00
for _ , n := range toRemove {
2024-04-15 23:04:53 +01:00
l . Warn ( "Removing definition not fast enough learning" , "n" , n )
2024-05-09 01:23:43 +01:00
ModelDefinitionUpdateStatus ( c , defs [ n ] . Id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
defs = remove ( defs , n )
2023-10-22 23:02:39 +01:00
}
2023-09-27 21:20:39 +01:00
}
2024-04-15 23:04:53 +01:00
rows , err := db . Query ( "select id from model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;" , model . Id , MODEL_DEFINITION_STATUS_TRANIED )
2023-09-27 21:20:39 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "DB: failed to read definition" )
l . Error ( err )
2024-05-06 01:10:58 +01:00
ModelUpdateStatus ( c , model . Id , int ( FAILED_TRAINING ) )
2023-09-27 21:20:39 +01:00
return
}
2023-10-10 12:28:49 +01:00
defer rows . Close ( )
2023-09-27 21:20:39 +01:00
2023-10-10 12:28:49 +01:00
if ! rows . Next ( ) {
// TODO Make the Model status have a message
2024-04-15 23:04:53 +01:00
l . Error ( "All definitions failed to train!" )
2024-05-06 01:10:58 +01:00
ModelUpdateStatus ( c , model . Id , int ( FAILED_TRAINING ) )
2023-09-27 21:20:39 +01:00
return
2023-10-10 12:28:49 +01:00
}
2023-09-27 21:20:39 +01:00
2023-10-10 12:28:49 +01:00
var id string
if err = rows . Scan ( & id ) ; err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to read id:" )
l . Error ( err )
2024-05-06 01:10:58 +01:00
ModelUpdateStatus ( c , model . Id , int ( FAILED_TRAINING ) )
2023-09-27 21:20:39 +01:00
return
2023-10-10 12:28:49 +01:00
}
2023-09-27 18:07:04 +01:00
2024-04-15 23:04:53 +01:00
if _ , err = db . Exec ( "update model_definition set status=$1 where id=$2;" , MODEL_DEFINITION_STATUS_READY , id ) ; err != nil {
l . Error ( "Failed to update model definition" )
l . Error ( err )
2024-05-06 01:10:58 +01:00
ModelUpdateStatus ( c , model . Id , int ( FAILED_TRAINING ) )
2023-09-27 21:20:39 +01:00
return
2023-10-10 12:28:49 +01:00
}
2023-09-27 21:20:39 +01:00
2024-04-15 23:04:53 +01:00
to_delete , err := db . Query ( "select id from model_definition where status != $1 and model_id=$2" , MODEL_DEFINITION_STATUS_READY , model . Id )
2023-10-10 12:28:49 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to select model_definition to delete" )
l . Error ( err )
2024-05-06 01:10:58 +01:00
ModelUpdateStatus ( c , model . Id , int ( FAILED_TRAINING ) )
2023-09-27 21:20:39 +01:00
return
2023-10-10 12:28:49 +01:00
}
defer to_delete . Close ( )
for to_delete . Next ( ) {
var id string
2024-03-06 23:33:54 +00:00
if err = to_delete . Scan ( & id ) ; err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to scan the id of a model_definition to delete" )
l . Error ( err )
2024-05-06 01:10:58 +01:00
ModelUpdateStatus ( c , model . Id , int ( FAILED_TRAINING ) )
2023-10-10 12:28:49 +01:00
return
}
os . RemoveAll ( path . Join ( "savedData" , model . Id , "defs" , id ) )
}
// TODO Check if returning also works here
2024-04-15 23:04:53 +01:00
if _ , err = db . Exec ( "delete from model_definition where status!=$1 and model_id=$2;" , MODEL_DEFINITION_STATUS_READY , model . Id ) ; err != nil {
l . Error ( "Failed to delete model_definition" )
l . Error ( err )
2024-05-06 01:10:58 +01:00
ModelUpdateStatus ( c , model . Id , int ( FAILED_TRAINING ) )
2023-09-27 21:20:39 +01:00
return
2023-10-10 12:28:49 +01:00
}
2023-09-27 18:07:04 +01:00
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , READY )
2024-04-15 23:04:53 +01:00
return
2023-09-27 18:07:04 +01:00
}
2024-03-06 23:33:54 +00:00
// TrainModelRowUsable is the view of a model_definition row used by the
// training loops: the definition id, the accuracy it has to reach, the number
// of epochs accounted for it and the accuracy of its latest training run.
//
// NOTE(review): the meaning of the `db:"0"` tag on Acuracy is defined by the
// db helper package and is not visible from this file — confirm before
// relying on the column it maps to.
type TrainModelRowUsable struct {
	Id             string
	TargetAccuracy int `db:"target_accuracy"`
	Epoch          int
	Acuracy        float64 `db:"0"`
}

// TrainModelRowUsables is a list of definitions that implements
// sort.Interface, ordering the entries by ascending Acuracy.
type TrainModelRowUsables []*TrainModelRowUsable

// Len returns the number of definitions in the list.
func (rows TrainModelRowUsables) Len() int {
	return len(rows)
}

// Swap exchanges the definitions stored at positions i and j.
func (rows TrainModelRowUsables) Swap(i, j int) {
	first, second := rows[i], rows[j]
	rows[i] = second
	rows[j] = first
}

// Less reports whether the definition at i has a strictly lower accuracy than
// the definition at j.
func (rows TrainModelRowUsables) Less(i, j int) bool {
	return rows[j].Acuracy > rows[i].Acuracy
}
2024-04-15 23:04:53 +01:00
func trainModelExp ( c BasePack , model * BaseModel ) ( err error ) {
l := c . GetLogger ( )
db := c . GetDb ( )
2024-01-31 21:48:35 +00:00
2024-03-09 10:52:08 +00:00
var definitions TrainModelRowUsables
2024-01-31 21:48:35 +00:00
2024-04-15 23:04:53 +01:00
definitions , err = GetDbMultitple [ TrainModelRowUsable ] ( db , "model_definition where status=$1 and model_id=$2" , MODEL_DEFINITION_STATUS_INIT , model . Id )
2024-03-09 10:52:08 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to get definitions" )
2024-03-09 10:52:08 +00:00
return
}
2024-01-31 21:48:35 +00:00
if len ( definitions ) == 0 {
2024-04-15 23:04:53 +01:00
l . Error ( "No Definitions defined!" )
return errors . New ( "No Definitions found" )
2024-01-31 21:48:35 +00:00
}
firstRound := true
finished := false
for {
var toRemove ToRemoveList = [ ] int { }
for i , def := range definitions {
2024-03-06 23:33:54 +00:00
ModelDefinitionUpdateStatus ( c , def . Id , MODEL_DEFINITION_STATUS_TRAINING )
accuracy , err := trainDefinitionExp ( c , model , def . Id , ! firstRound )
2024-01-31 21:48:35 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to train definition!Err:" , "err" , err )
2024-03-06 23:33:54 +00:00
ModelDefinitionUpdateStatus ( c , def . Id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
2024-01-31 21:48:35 +00:00
toRemove = append ( toRemove , i )
continue
}
2024-03-06 23:33:54 +00:00
def . Epoch += EPOCH_PER_RUN
2024-01-31 21:48:35 +00:00
accuracy = accuracy * 100
2024-03-06 23:33:54 +00:00
def . Acuracy = float64 ( accuracy )
2024-01-31 21:48:35 +00:00
2024-03-06 23:33:54 +00:00
definitions [ i ] . Epoch += EPOCH_PER_RUN
definitions [ i ] . Acuracy = accuracy
2024-01-31 21:48:35 +00:00
2024-03-06 23:33:54 +00:00
if accuracy >= float64 ( def . TargetAccuracy ) {
2024-04-15 23:04:53 +01:00
l . Info ( "Found a definition that reaches target_accuracy!" )
_ , err = db . Exec ( "update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4" , accuracy , MODEL_DEFINITION_STATUS_TRANIED , def . Epoch , def . Id )
2024-01-31 21:48:35 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to train definition!" )
return err
2024-01-31 21:48:35 +00:00
}
2024-04-15 23:04:53 +01:00
_ , err = db . Exec ( "update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4" , MODEL_DEFINITION_STATUS_CANCELD_TRAINING , def . Id , model . Id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
2024-01-31 21:48:35 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to train definition!" )
return err
2024-01-31 21:48:35 +00:00
}
2024-04-15 23:04:53 +01:00
_ , err = db . Exec ( "update exp_model_head set status=$1 where def_id=$2;" , MODEL_HEAD_STATUS_READY , def . Id )
2024-04-08 14:17:13 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to train definition!" )
return err
2024-04-08 14:17:13 +01:00
}
2024-01-31 21:48:35 +00:00
finished = true
break
}
2024-03-06 23:33:54 +00:00
if def . Epoch > MAX_EPOCH {
fmt . Printf ( "Failed to train definition! Accuracy less %f < %d\n" , accuracy , def . TargetAccuracy )
ModelDefinitionUpdateStatus ( c , def . Id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
2024-01-31 21:48:35 +00:00
toRemove = append ( toRemove , i )
continue
}
2024-04-15 23:04:53 +01:00
_ , err = db . Exec ( "update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4" , accuracy , def . Epoch , MODEL_DEFINITION_STATUS_PAUSED_TRAINING , def . Id )
2024-01-31 21:48:35 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to train definition!" )
return err
2024-01-31 21:48:35 +00:00
}
}
firstRound = false
if finished {
break
}
sort . Sort ( sort . Reverse ( toRemove ) )
2024-04-15 23:04:53 +01:00
l . Info ( "Round done" , "toRemove" , toRemove )
2024-01-31 21:48:35 +00:00
for _ , n := range toRemove {
definitions = remove ( definitions , n )
}
len_def := len ( definitions )
if len_def == 0 {
break
2024-03-09 09:41:16 +00:00
} else if len_def == 1 {
2024-01-31 21:48:35 +00:00
continue
}
sort . Sort ( sort . Reverse ( definitions ) )
2024-03-06 23:33:54 +00:00
acc := definitions [ 0 ] . Acuracy - 20.0
2024-01-31 21:48:35 +00:00
2024-04-15 23:04:53 +01:00
l . Info ( "Training models, Highest acc" , "acc" , definitions [ 0 ] . Acuracy , "mod_acc" , acc )
2024-01-31 21:48:35 +00:00
toRemove = [ ] int { }
for i , def := range definitions {
2024-03-06 23:33:54 +00:00
if def . Acuracy < acc {
2024-01-31 21:48:35 +00:00
toRemove = append ( toRemove , i )
}
}
2024-04-15 23:04:53 +01:00
l . Info ( "Removing due to accuracy" , "toRemove" , toRemove )
2024-01-31 21:48:35 +00:00
sort . Sort ( sort . Reverse ( toRemove ) )
for _ , n := range toRemove {
2024-04-15 23:04:53 +01:00
l . Warn ( "Removing definition not fast enough learning" , "n" , n )
2024-03-06 23:33:54 +00:00
ModelDefinitionUpdateStatus ( c , definitions [ n ] . Id , MODEL_DEFINITION_STATUS_FAILED_TRAINING )
2024-01-31 21:48:35 +00:00
definitions = remove ( definitions , n )
}
}
2024-03-09 10:52:08 +00:00
var dat JustId
2024-04-15 23:04:53 +01:00
err = GetDBOnce ( db , & dat , "model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;" , model . Id , MODEL_DEFINITION_STATUS_TRANIED )
2024-03-09 10:52:08 +00:00
if err == NotFoundError {
2024-04-08 14:17:13 +01:00
// Set the class status to trained
2024-05-06 01:10:58 +01:00
err = setModelClassStatus ( c , CLASS_STATUS_TO_TRAIN , "model_id=$1 and status=$2;" , model . Id , CLASS_STATUS_TRAINING )
2024-04-08 14:17:13 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "All definitions failed to train! And Failed to set class status" )
return err
2024-04-08 14:17:13 +01:00
}
2024-04-15 23:04:53 +01:00
l . Error ( "All definitions failed to train!" )
return err
2024-03-09 10:52:08 +00:00
} else if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "All definitions failed to train!" )
return err
2024-03-09 10:52:08 +00:00
}
2024-01-31 21:48:35 +00:00
2024-04-15 23:04:53 +01:00
if _ , err = db . Exec ( "update model_definition set status=$1 where id=$2;" , MODEL_DEFINITION_STATUS_READY , dat . Id ) ; err != nil {
l . Error ( "Failed to update model definition" )
return err
2024-01-31 21:48:35 +00:00
}
2024-04-15 23:04:53 +01:00
to_delete , err := GetDbMultitple [ JustId ] ( db , "model_definition where status!=$1 and model_id=$2" , MODEL_DEFINITION_STATUS_READY , model . Id )
2024-03-09 10:52:08 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to select model_definition to delete" )
return err
2024-03-09 10:52:08 +00:00
}
2024-01-31 21:48:35 +00:00
2024-03-09 10:52:08 +00:00
for _ , d := range to_delete {
2024-03-09 09:41:16 +00:00
os . RemoveAll ( path . Join ( "savedData" , model . Id , "defs" , d . Id ) )
2024-03-09 10:52:08 +00:00
}
2024-01-31 21:48:35 +00:00
// TODO Check if returning also works here
2024-04-15 23:04:53 +01:00
if _ , err = db . Exec ( "delete from model_definition where status!=$1 and model_id=$2;" , MODEL_DEFINITION_STATUS_READY , model . Id ) ; err != nil {
l . Error ( "Failed to delete model_definition" )
return err
2024-03-02 12:45:49 +00:00
}
2024-05-09 00:46:42 +01:00
if err = SplitModel ( c , model ) ; err != nil {
2024-05-06 01:10:58 +01:00
err = setModelClassStatus ( c , CLASS_STATUS_TO_TRAIN , "model_id=$1 and status=$2;" , model . Id , CLASS_STATUS_TRAINING )
2024-04-08 14:17:13 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to split the model! And Failed to set class status" )
return err
2024-04-08 14:17:13 +01:00
}
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to split the model" )
return err
2024-01-31 21:48:35 +00:00
}
2024-04-08 14:17:13 +01:00
// Set the class status to trained
2024-05-06 01:10:58 +01:00
err = setModelClassStatus ( c , CLASS_STATUS_TRAINED , "model_id=$1 and status=$2;" , model . Id , CLASS_STATUS_TRAINING )
2024-04-08 14:17:13 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to set class status" )
return err
2024-04-08 14:17:13 +01:00
}
2024-03-02 12:45:49 +00:00
// There should only be one def availabale
def := JustId { }
2024-04-15 23:04:53 +01:00
if err = GetDBOnce ( db , & def , "model_definition where model_id=$1" , model . Id ) ; err != nil {
2024-02-12 14:30:43 +00:00
return
2024-03-02 12:45:49 +00:00
}
2024-02-14 15:11:45 +00:00
2024-03-02 12:45:49 +00:00
// Remove the base model
2024-04-15 23:04:53 +01:00
l . Warn ( "Removing base model for" , "model" , model . Id , "def" , def . Id )
2024-02-14 15:11:45 +00:00
os . RemoveAll ( path . Join ( "savedData" , model . Id , "defs" , def . Id , "model" ) )
os . RemoveAll ( path . Join ( "savedData" , model . Id , "defs" , def . Id , "model.keras" ) )
2024-02-12 14:30:43 +00:00
2024-01-31 21:48:35 +00:00
ModelUpdateStatus ( c , model . Id , READY )
2024-04-15 23:04:53 +01:00
return
2024-01-31 21:48:35 +00:00
}
2024-05-09 00:46:42 +01:00
func SplitModel ( c BasePack , model * BaseModel ) ( err error ) {
2024-04-15 23:04:53 +01:00
db := c . GetDb ( )
l := c . GetLogger ( )
2024-02-12 14:30:43 +00:00
2024-03-02 12:45:49 +00:00
def := JustId { }
2024-04-15 23:04:53 +01:00
if err = GetDBOnce ( db , & def , "model_definition where model_id=$1" , model . Id ) ; err != nil {
2024-03-02 12:45:49 +00:00
return
}
2024-02-12 14:30:43 +00:00
2024-03-02 12:45:49 +00:00
head := JustId { }
2024-04-15 23:04:53 +01:00
if err = GetDBOnce ( db , & head , "exp_model_head where def_id=$1" , def . Id ) ; err != nil {
2024-03-02 12:45:49 +00:00
return
}
2024-02-12 14:30:43 +00:00
// Generate run folder
2024-04-08 14:17:13 +01:00
run_path := path . Join ( "/tmp" , model . Id + "-defs-" + def . Id + "-split" )
2024-02-12 14:30:43 +00:00
err = os . MkdirAll ( run_path , os . ModePerm )
if err != nil {
return
}
// Create python script
f , err := os . Create ( path . Join ( run_path , "run.py" ) )
if err != nil {
return
}
defer f . Close ( )
tmpl , err := template . New ( "python_split_model_template.py" ) . ParseFiles ( "views/py/python_split_model_template.py" )
if err != nil {
return
}
// Copy result around
result_path := path . Join ( getDir ( ) , "savedData" , model . Id , "defs" , def . Id )
2024-03-02 12:45:49 +00:00
// TODO maybe move this to a select count(*)
// Get only fixed lawers
2024-04-15 23:04:53 +01:00
layers , err := db . Query ( "select exp_type from model_definition_layer where def_id=$1 and exp_type=$2 order by layer_order asc;" , def . Id , 1 )
2024-02-12 14:30:43 +00:00
if err != nil {
return
}
defer layers . Close ( )
type layerrow struct {
2024-03-02 12:45:49 +00:00
ExpType int
2024-02-12 14:30:43 +00:00
}
2024-03-02 12:45:49 +00:00
count := - 1
2024-02-12 14:30:43 +00:00
for layers . Next ( ) {
2024-04-19 22:03:14 +01:00
var layerrow layerrow
if err = layers . Scan ( & layerrow . ExpType ) ; err != nil {
return
}
2024-03-02 12:45:49 +00:00
count += 1
2024-04-19 22:03:14 +01:00
if layerrow . ExpType == 2 {
break
}
2024-02-12 14:30:43 +00:00
}
2024-03-02 12:45:49 +00:00
if count == - 1 {
err = errors . New ( "Can not get layers" )
return
}
2024-02-12 14:30:43 +00:00
2024-03-02 12:45:49 +00:00
log . Warn ( "Spliting model" , "def" , def . Id , "head" , head . Id , "count" , count )
2024-02-12 14:30:43 +00:00
2024-03-02 12:45:49 +00:00
basePath := path . Join ( result_path , "base" )
headPath := path . Join ( result_path , "head" , head . Id )
2024-02-12 14:30:43 +00:00
if err = os . MkdirAll ( basePath , os . ModePerm ) ; err != nil {
return
}
if err = os . MkdirAll ( headPath , os . ModePerm ) ; err != nil {
return
}
if err = tmpl . Execute ( f , AnyMap {
2024-03-02 12:45:49 +00:00
"SplitLen" : count ,
"ModelPath" : path . Join ( result_path , "model.keras" ) ,
"BaseModelPath" : basePath ,
"HeadModelPath" : headPath ,
2024-02-12 14:30:43 +00:00
} ) ; err != nil {
return
}
out , err := exec . Command ( "bash" , "-c" , fmt . Sprintf ( "cd %s && python run.py" , run_path ) ) . CombinedOutput ( )
if err != nil {
2024-04-15 23:04:53 +01:00
l . Debug ( string ( out ) )
2024-02-12 14:30:43 +00:00
return
}
2024-04-08 14:17:13 +01:00
os . RemoveAll ( run_path )
2024-04-15 23:04:53 +01:00
l . Info ( "Python finished running" )
2024-03-02 12:45:49 +00:00
return
2024-02-12 14:30:43 +00:00
}
2024-04-15 23:04:53 +01:00
func removeFailedDataPoints ( c BasePack , model * BaseModel ) ( err error ) {
rows , err := c . GetDb ( ) . Query ( "select mdp.id from model_data_point as mdp join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 and mdp.status=-1;" , model . Id )
2023-10-10 12:28:49 +01:00
if err != nil {
return
}
defer rows . Close ( )
2023-10-19 10:44:13 +01:00
base_path := path . Join ( "savedData" , model . Id , "data" )
2023-10-10 12:28:49 +01:00
for rows . Next ( ) {
var dataPointId string
err = rows . Scan ( & dataPointId )
if err != nil {
return
}
2023-10-20 12:37:56 +01:00
2023-10-21 00:26:52 +01:00
p := path . Join ( base_path , dataPointId + "." + model . Format )
2024-04-15 23:04:53 +01:00
c . GetLogger ( ) . Warn ( "Removing image" , "path" , p )
2023-10-20 12:37:56 +01:00
err = os . RemoveAll ( p )
2023-10-19 10:44:13 +01:00
if err != nil {
return
}
2023-10-10 12:28:49 +01:00
}
2024-04-15 23:04:53 +01:00
_ , err = c . GetDb ( ) . Exec ( "delete from model_data_point as mdp using model_classes as mc where mdp.class_id = mc.id and mc.model_id=$1 and mdp.status=-1;" , model . Id )
2023-10-10 12:28:49 +01:00
return
}
2023-10-19 11:42:38 +01:00
// This generates a definition
2024-04-15 23:04:53 +01:00
func generateDefinition ( c BasePack , model * BaseModel , target_accuracy int , number_of_classes int , complexity int ) ( err error ) {
failed := func ( ) {
2023-10-19 11:42:38 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_PREPARING_TRAINING )
2023-10-21 00:26:52 +01:00
}
2023-10-19 11:42:38 +01:00
2024-04-15 23:04:53 +01:00
db := c . GetDb ( )
2024-05-06 01:10:58 +01:00
def , err := MakeDefenition ( db , model . Id , target_accuracy )
2023-10-21 00:26:52 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2023-10-21 00:26:52 +01:00
}
2023-10-19 11:42:38 +01:00
2023-10-22 23:02:39 +01:00
order := 1
2023-10-21 00:26:52 +01:00
2023-10-22 23:02:39 +01:00
// Note the shape of the first layer defines the import size
2024-05-06 01:10:58 +01:00
//_, err = def.MakeLayer(db, order, LAYER_INPUT, ShapeToString(model.Width, model.Height, model.ImageMode))
_ , err = def . MakeLayer ( db , order , LAYER_INPUT , ShapeToString ( 3 , model . Width , model . Height ) )
if err != nil {
failed ( )
return
2023-10-21 00:26:52 +01:00
}
2024-05-06 01:10:58 +01:00
order ++
2023-10-19 11:42:38 +01:00
2024-05-09 00:46:42 +01:00
loop := max ( 1 , int ( ( math . Log ( float64 ( model . Width ) ) / math . Log ( float64 ( 10 ) ) ) ) )
for i := 0 ; i < loop ; i ++ {
_ , err = def . MakeLayer ( db , order , LAYER_SIMPLE_BLOCK , "" )
order ++
2023-10-19 11:42:38 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2023-10-19 11:42:38 +01:00
}
2024-05-09 00:46:42 +01:00
}
2023-10-21 00:26:52 +01:00
2024-05-09 00:46:42 +01:00
_ , err = def . MakeLayer ( db , order , LAYER_FLATTEN , "" )
if err != nil {
failed ( )
return
}
order ++
2023-10-21 00:26:52 +01:00
2024-05-09 00:46:42 +01:00
loop = int ( ( math . Log ( float64 ( number_of_classes ) ) / math . Log ( float64 ( 10 ) ) ) / 2 )
if loop == 0 {
loop = 1
}
for i := 0 ; i < loop ; i ++ {
_ , err = def . MakeLayer ( db , order , LAYER_DENSE , ShapeToString ( number_of_classes * ( loop - i ) ) )
order ++
2023-10-21 00:26:52 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2023-10-21 00:26:52 +01:00
}
}
2023-10-19 11:42:38 +01:00
2024-05-06 01:10:58 +01:00
return def . UpdateStatus ( db , DEFINITION_STATUS_INIT )
2023-10-19 11:42:38 +01:00
}
2024-04-15 23:04:53 +01:00
func generateDefinitions ( c BasePack , model * BaseModel , target_accuracy int , number_of_models int ) ( err error ) {
2024-03-09 09:41:16 +00:00
cls , err := model_classes . ListClasses ( c , model . Id )
2023-10-19 11:42:38 +01:00
if err != nil {
ModelUpdateStatus ( c , model . Id , FAILED_PREPARING_TRAINING )
2024-04-15 23:04:53 +01:00
return
2023-10-19 11:42:38 +01:00
}
2023-10-20 12:37:56 +01:00
err = removeFailedDataPoints ( c , model )
2023-10-19 11:42:38 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
return
2023-10-19 11:42:38 +01:00
}
2023-10-22 23:02:39 +01:00
cls_len := len ( cls )
2023-10-19 11:42:38 +01:00
2023-10-22 23:02:39 +01:00
if number_of_models == 1 {
if model . Width < 100 && model . Height < 100 && cls_len < 30 {
generateDefinition ( c , model , target_accuracy , cls_len , 0 )
} else if model . Width > 100 && model . Height > 100 {
generateDefinition ( c , model , target_accuracy , cls_len , 2 )
} else {
generateDefinition ( c , model , target_accuracy , cls_len , 1 )
}
} else if number_of_models == 3 {
for i := 0 ; i < number_of_models ; i ++ {
generateDefinition ( c , model , target_accuracy , cls_len , i )
}
} else {
// TODO handle incrisea the complexity
for i := 0 ; i < number_of_models ; i ++ {
generateDefinition ( c , model , target_accuracy , cls_len , 0 )
}
}
2023-10-21 00:26:52 +01:00
return nil
2023-10-19 11:42:38 +01:00
}
2024-04-17 17:46:43 +01:00
func ExpModelHeadUpdateStatus ( db db . Db , id string , status ModelDefinitionStatus ) ( err error ) {
2024-01-31 21:48:35 +00:00
_ , err = db . Exec ( "update model_definition set status = $1 where id = $2" , status , id )
return
}
// This generates a definition
2024-04-15 23:04:53 +01:00
func generateExpandableDefinition ( c BasePack , model * BaseModel , target_accuracy int , number_of_classes int , complexity int ) ( err error ) {
l := c . GetLogger ( )
db := c . GetDb ( )
l . Info ( "Generating expandable new definition for model" , "id" , model . Id , "complexity" , complexity )
2024-02-08 18:20:58 +00:00
2024-04-15 23:04:53 +01:00
failed := func ( ) {
2024-01-31 21:48:35 +00:00
ModelUpdateStatus ( c , model . Id , FAILED_PREPARING_TRAINING )
}
if complexity == 0 {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-01-31 21:48:35 +00:00
}
2024-05-06 01:10:58 +01:00
def , err := MakeDefenition ( c . GetDb ( ) , model . Id , target_accuracy )
2024-01-31 21:48:35 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-01-31 21:48:35 +00:00
}
2024-05-06 01:10:58 +01:00
def_id := def . Id
2024-01-31 21:48:35 +00:00
order := 1
2024-05-09 00:46:42 +01:00
err = MakeLayerExpandable ( c . GetDb ( ) , def_id , order , LAYER_INPUT , ShapeToString ( 3 , model . Width , model . Height ) , 1 )
if err != nil {
failed ( )
return
2024-02-02 16:16:26 +00:00
}
2024-01-31 21:48:35 +00:00
order ++
2024-02-02 16:16:26 +00:00
// handle the errors inside the pervious if block
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-02-02 16:16:26 +00:00
}
2024-01-31 21:48:35 +00:00
2024-02-02 16:16:26 +00:00
// Create the blocks
loop := int ( ( math . Log ( float64 ( model . Width ) ) / math . Log ( float64 ( 10 ) ) ) )
2024-02-08 18:20:58 +00:00
2024-04-18 15:01:36 +01:00
/ * if model . Width < 50 && model . Height < 50 {
loop = 0
} * /
2024-02-08 18:20:58 +00:00
2024-03-02 12:45:49 +00:00
log . Info ( "Size of the simple block" , "loop" , loop )
2024-02-08 18:20:58 +00:00
2024-03-02 12:45:49 +00:00
//loop = max(loop, 3)
2024-01-31 21:48:35 +00:00
2024-02-02 16:16:26 +00:00
for i := 0 ; i < loop ; i ++ {
2024-04-15 23:04:53 +01:00
err = MakeLayerExpandable ( db , def_id , order , LAYER_SIMPLE_BLOCK , "" , 1 )
2024-02-02 16:16:26 +00:00
order ++
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-02-02 16:16:26 +00:00
}
}
2024-01-31 21:48:35 +00:00
2024-02-02 16:16:26 +00:00
// Flatten the blocks into dense
2024-04-15 23:04:53 +01:00
err = MakeLayerExpandable ( db , def_id , order , LAYER_FLATTEN , "" , 1 )
2024-02-02 16:16:26 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-02-02 16:16:26 +00:00
}
order ++
2024-01-31 21:48:35 +00:00
2024-02-02 16:16:26 +00:00
// Flatten the blocks into dense
2024-05-09 00:46:42 +01:00
err = MakeLayerExpandable ( db , def_id , order , LAYER_DENSE , ShapeToString ( number_of_classes * 2 ) , 1 )
2024-02-02 16:16:26 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-02-02 16:16:26 +00:00
}
order ++
loop = int ( ( math . Log ( float64 ( number_of_classes ) ) / math . Log ( float64 ( 10 ) ) ) / 2 )
2024-02-08 18:20:58 +00:00
2024-03-02 12:45:49 +00:00
log . Info ( "Size of the dense layers" , "loop" , loop )
2024-02-08 18:20:58 +00:00
2024-04-19 22:03:14 +01:00
loop = max ( loop , 3 )
2024-02-02 16:16:26 +00:00
for i := 0 ; i < loop ; i ++ {
2024-05-09 00:46:42 +01:00
err = MakeLayerExpandable ( db , def_id , order , LAYER_DENSE , ShapeToString ( number_of_classes * ( loop - i ) * 2 ) , 2 )
2024-02-02 16:16:26 +00:00
order ++
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-02-02 16:16:26 +00:00
}
}
2024-03-02 12:45:49 +00:00
2024-04-15 23:04:53 +01:00
var newHead = struct {
DefId string ` db:"def_id" `
RangeStart int ` db:"range_start" `
RangeEnd int ` db:"range_end" `
Status ModelDefinitionStatus ` db:"status" `
} {
def_id , 0 , number_of_classes - 1 , MODEL_DEFINITION_STATUS_INIT ,
}
_ , err = InsertReturnId ( c . GetDb ( ) , & newHead , "exp_model_head" , "id" )
2024-02-02 16:16:26 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-02-02 16:16:26 +00:00
}
2024-01-31 21:48:35 +00:00
err = ModelDefinitionUpdateStatus ( c , def_id , MODEL_DEFINITION_STATUS_INIT )
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-01-31 21:48:35 +00:00
}
2024-04-15 23:04:53 +01:00
return
2024-01-31 21:48:35 +00:00
}
2024-03-02 12:45:49 +00:00
// TODO make this json friendy
2024-04-15 23:04:53 +01:00
func generateExpandableDefinitions ( c BasePack , model * BaseModel , target_accuracy int , number_of_models int ) ( err error ) {
2024-03-09 09:41:16 +00:00
cls , err := model_classes . ListClasses ( c , model . Id )
2024-01-31 21:48:35 +00:00
if err != nil {
ModelUpdateStatus ( c , model . Id , FAILED_PREPARING_TRAINING )
// TODO improve this response
2024-04-15 23:04:53 +01:00
return
2024-01-31 21:48:35 +00:00
}
err = removeFailedDataPoints ( c , model )
if err != nil {
2024-04-15 23:04:53 +01:00
return
2024-01-31 21:48:35 +00:00
}
cls_len := len ( cls )
if number_of_models == 1 {
if model . Width > 100 && model . Height > 100 {
generateExpandableDefinition ( c , model , target_accuracy , cls_len , 2 )
} else {
2024-04-18 15:01:36 +01:00
generateExpandableDefinition ( c , model , target_accuracy , cls_len , 2 )
2024-01-31 21:48:35 +00:00
}
} else if number_of_models == 3 {
for i := 0 ; i < number_of_models ; i ++ {
generateExpandableDefinition ( c , model , target_accuracy , cls_len , i )
}
} else {
// TODO handle incrisea the complexity
for i := 0 ; i < number_of_models ; i ++ {
2024-04-18 15:01:36 +01:00
generateExpandableDefinition ( c , model , target_accuracy , cls_len , 2 )
2024-01-31 21:48:35 +00:00
}
}
return nil
}
2024-04-16 17:48:52 +01:00
func ResetClasses ( c BasePack , model * BaseModel ) {
2024-05-06 01:10:58 +01:00
_ , err := c . GetDb ( ) . Exec ( "update model_classes set status=$1 where status=$2 and model_id=$3" , CLASS_STATUS_TO_TRAIN , CLASS_STATUS_TRAINING , model . Id )
2024-04-08 14:17:13 +01:00
if err != nil {
2024-04-16 17:48:52 +01:00
c . GetLogger ( ) . Error ( "Error while reseting the classes" , "error" , err )
2024-04-08 14:17:13 +01:00
}
}
func trainExpandable ( c * Context , model * BaseModel ) {
var err error = nil
failed := func ( msg string ) {
c . Logger . Error ( msg , "err" , err )
2024-05-06 01:10:58 +01:00
ModelUpdateStatus ( c , model . Id , int ( FAILED_TRAINING ) )
2024-04-08 14:17:13 +01:00
ResetClasses ( c , model )
}
var definitions TrainModelRowUsables
definitions , err = GetDbMultitple [ TrainModelRowUsable ] ( c , "model_definition where status=$1 and model_id=$2" , MODEL_DEFINITION_STATUS_READY , model . Id )
if err != nil {
failed ( "Failed to get definitions" )
return
}
if len ( definitions ) != 1 {
failed ( "There should only be one definition available!" )
return
}
firstRound := true
def := definitions [ 0 ]
epoch := 0
for {
acc , err := trainDefinitionExp ( c , model , def . Id , ! firstRound )
if err != nil {
failed ( "Failed to train definition!" )
return
}
epoch += EPOCH_PER_RUN
if float64 ( acc * 100 ) >= float64 ( def . Acuracy ) {
c . Logger . Info ( "Found a definition that reaches target_accuracy!" )
_ , err = c . Db . Exec ( "update exp_model_head set status=$1 where def_id=$2 and status=$3;" , MODEL_HEAD_STATUS_READY , def . Id , MODEL_HEAD_STATUS_TRAINING )
if err != nil {
failed ( "Failed to train definition!" )
return
}
break
} else if def . Epoch > MAX_EPOCH {
failed ( fmt . Sprintf ( "Failed to train definition! Accuracy less %f < %d\n" , acc * 100 , def . TargetAccuracy ) )
return
}
}
// Set the class status to trained
2024-05-06 01:10:58 +01:00
err = setModelClassStatus ( c , CLASS_STATUS_TRAINED , "model_id=$1 and status=$2;" , model . Id , CLASS_STATUS_TRAINING )
2024-04-08 14:17:13 +01:00
if err != nil {
failed ( "Failed to set class status" )
return
}
ModelUpdateStatus ( c , model . Id , READY )
}
2024-04-15 23:04:53 +01:00
func RunTaskTrain ( b BasePack , task Task ) ( err error ) {
l := b . GetLogger ( )
2024-03-02 12:45:49 +00:00
2024-04-17 14:56:57 +01:00
model , err := GetBaseModel ( b . GetDb ( ) , * task . ModelId )
2024-04-15 23:04:53 +01:00
if err != nil {
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "Failed to get model information" )
l . Error ( "Failed to get model information" , "err" , err )
return err
}
2023-09-26 20:15:28 +01:00
2024-04-15 23:04:53 +01:00
if model . Status != TRAINING {
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "Model not in the correct status for training" )
return errors . New ( "Model not in the right status" )
}
task . UpdateStatusLog ( b , TASK_RUNNING , "Training model" )
2023-09-26 20:15:28 +01:00
2024-04-15 23:04:53 +01:00
var dat struct {
NumberOfModels int
Accuracy int
}
err = json . Unmarshal ( [ ] byte ( task . ExtraTaskInfo ) , & dat )
if err != nil {
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "Failed to get model extra information" )
}
if model . ModelType == 2 {
full_error := generateExpandableDefinitions ( b , model , dat . Accuracy , dat . NumberOfModels )
if full_error != nil {
l . Error ( "Failed to generate defintions" , "err" , full_error )
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "Failed generate model" )
return errors . New ( "Failed to generate definitions" )
}
} else {
full_error := generateDefinitions ( b , model , dat . Accuracy , dat . NumberOfModels )
if full_error != nil {
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "Failed generate model" )
return errors . New ( "Failed to generate definitions" )
2024-03-09 10:52:08 +00:00
}
2024-04-15 23:04:53 +01:00
}
2024-01-31 21:48:35 +00:00
2024-04-15 23:04:53 +01:00
if model . ModelType == 2 {
err = trainModelExp ( b , model )
} else {
err = trainModel ( b , model )
}
if err != nil {
l . Error ( "Failed to train model" , "err" , err )
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "Failed generate model" )
2024-05-06 01:10:58 +01:00
ModelUpdateStatus ( b , model . Id , int ( FAILED_TRAINING ) )
2024-04-15 23:04:53 +01:00
return
}
task . UpdateStatusLog ( b , TASK_DONE , "Model finished training" )
return
}
2024-04-16 17:48:52 +01:00
func RunTaskRetrain ( b BasePack , task Task ) ( err error ) {
2024-04-17 14:56:57 +01:00
model , err := GetBaseModel ( b . GetDb ( ) , * task . ModelId )
2024-04-16 17:48:52 +01:00
if err != nil {
return err
} else if model . Status != READY_RETRAIN {
return errors . New ( "Model in invalid status for re-training" )
}
l := b . GetLogger ( )
db := b . GetDb ( )
failed := func ( ) {
ResetClasses ( b , model )
ModelUpdateStatus ( b , model . Id , READY_RETRAIN_FAILED )
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "Model failed retraining" )
l . Error ( "Failed to retrain" , "err" , err )
}
task . UpdateStatusLog ( b , TASK_RUNNING , "Model retraining" )
2024-04-18 15:01:36 +01:00
var defData struct {
Id string ` db:"md.id" `
TargetAcuuracy float64 ` db:"md.target_accuracy" `
}
err = GetDBOnce ( db , & defData , "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;" , task . ModelId )
2024-04-16 17:48:52 +01:00
if err != nil {
failed ( )
return
}
2024-04-19 22:03:14 +01:00
failed = func ( ) {
ResetClasses ( b , model )
ModelUpdateStatus ( b , model . Id , READY_RETRAIN_FAILED )
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "Model failed retraining" )
_ , err_ := db . Exec ( "delete from exp_model_head where def_id=$1 and status in (2,3)" , defData . Id )
if err_ != nil {
panic ( err_ )
}
l . Error ( "Failed to retrain" , "err" , err )
}
2024-04-18 15:01:36 +01:00
var acc float64 = 0
var epocs = 0
// TODO make max epochs come from db
2024-04-19 22:03:14 +01:00
// TODO re increase the target accuracy
for acc * 100 < defData . TargetAcuuracy - 5 && epocs < 10 {
2024-04-18 15:01:36 +01:00
// This is something I have to check
acc , err = trainDefinitionExpandExp ( b , model , defData . Id , epocs > 0 )
if err != nil {
failed ( )
return
}
l . Info ( "Retrained model" , "accuracy" , acc , "target" , defData . TargetAcuuracy )
epocs += 1
}
if acc * 100 < defData . TargetAcuuracy {
l . Error ( "Model never achived targetd accuracy" , "acc" , acc * 100 , "target" , defData . TargetAcuuracy )
2024-04-16 17:48:52 +01:00
failed ( )
return
}
// TODO check accuracy
err = UpdateStatus ( db , "models" , model . Id , READY )
if err != nil {
failed ( )
return
}
l . Info ( "Model updaded" )
2024-05-06 01:10:58 +01:00
_ , err = db . Exec ( "update model_classes set status=$1 where status=$2 and model_id=$3" , CLASS_STATUS_TRAINED , CLASS_STATUS_TRAINING , model . Id )
2024-04-16 17:48:52 +01:00
if err != nil {
l . Error ( "Error while updating the classes" , "error" , err )
2024-05-09 00:46:42 +01:00
failed ( )
return
}
_ , err = db . Exec ( "update exp_model_head set status=$1 where status=$2 and model_id=$3" , MODEL_HEAD_STATUS_READY , MODEL_HEAD_STATUS_TRAINING , model . Id )
if err != nil {
l . Error ( "Error while updating the classes" , "error" , err )
2024-04-16 17:48:52 +01:00
failed ( )
return
}
task . UpdateStatusLog ( b , TASK_DONE , "Model finished retraining" )
return
}
2024-04-15 23:04:53 +01:00
func handleTrain ( handle * Handle ) {
type TrainReq struct {
Id string ` json:"id" validate:"required" `
ModelType string ` json:"model_type" `
NumberOfModels int ` json:"number_of_models" `
Accuracy int ` json:"accuracy" `
}
PostAuthJson ( handle , "/models/train" , User_Normal , func ( c * Context , dat * TrainReq ) * Error {
2024-03-09 10:52:08 +00:00
modelTypeId := 1
if dat . ModelType == "expandable" {
modelTypeId = 2
} else if dat . ModelType != "simple" {
return c . JsonBadRequest ( "Invalid model type!" )
2023-09-27 21:20:39 +01:00
}
2024-03-09 10:52:08 +00:00
model , err := GetBaseModel ( c . Db , dat . Id )
2023-09-26 20:15:28 +01:00
if err == ModelNotFoundError {
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Model not found" )
2023-09-26 20:15:28 +01:00
} else if err != nil {
2024-04-15 23:04:53 +01:00
return c . E500M ( "Failed to get model information" , err )
2024-04-17 14:56:57 +01:00
} else if model . CanTrain == 0 {
return c . JsonBadRequest ( "Model can not be trained!" )
2023-09-26 20:15:28 +01:00
}
if model . Status != CONFIRM_PRE_TRAINING {
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Model in invalid status for training" )
2023-09-26 20:15:28 +01:00
}
2024-04-15 23:04:53 +01:00
_ , err = c . Db . Exec ( "update models set status = $1, model_type = $2 where id = $3" , TRAINING , modelTypeId , model . Id )
if err != nil {
return c . E500M ( "Failed to update model_status" , err )
2024-02-02 16:16:26 +00:00
}
2023-09-26 20:15:28 +01:00
2024-04-15 23:04:53 +01:00
text , err := json . Marshal ( struct {
NumberOfModels int
Accuracy int
} {
NumberOfModels : dat . NumberOfModels ,
Accuracy : dat . Accuracy ,
} )
if err != nil {
return c . E500M ( "Failed create data" , err )
2024-02-02 16:16:26 +00:00
}
2023-09-27 18:07:04 +01:00
2024-04-15 23:04:53 +01:00
type CreateNewTask struct {
UserId string ` db:"user_id" `
ModelId string ` db:"model_id" `
TaskType TaskType ` db:"task_type" `
Status int ` db:"status" `
ExtraTaskInfo string ` db:"extra_task_info" `
}
newTask := CreateNewTask {
UserId : c . User . Id ,
ModelId : model . Id ,
TaskType : TASK_TYPE_TRAINING ,
Status : 1 ,
ExtraTaskInfo : string ( text ) ,
}
id , err := InsertReturnId ( c , & newTask , "tasks" , "id" )
2024-02-02 16:16:26 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
return c . E500M ( "Failed to create task" , err )
2024-02-02 16:16:26 +00:00
}
2023-10-12 12:08:12 +01:00
2024-04-15 23:04:53 +01:00
return c . SendJSON ( id )
2024-03-09 10:52:08 +00:00
} )
2023-10-12 12:08:12 +01:00
2024-04-16 16:02:57 +01:00
PostAuthJson ( handle , "/model/train/retrain" , User_Normal , func ( c * Context , dat * JustId ) * Error {
model , err := GetBaseModel ( c . Db , dat . Id )
if err == ModelNotFoundError {
return c . JsonBadRequest ( "Model not found" )
} else if err != nil {
return c . E500M ( "Faield to get model" , err )
} else if model . Status != READY && model . Status != READY_RETRAIN_FAILED && model . Status != READY_ALTERATION_FAILED {
return c . JsonBadRequest ( "Model in invalid status for re-training" )
2024-04-17 14:56:57 +01:00
} else if model . CanTrain == 0 {
return c . JsonBadRequest ( "Model can not be trained!" )
2024-04-16 16:02:57 +01:00
}
c . Logger . Info ( "Expanding definitions for models" , "id" , model . Id )
classesUpdated := false
failed := func ( ) * Error {
if classesUpdated {
ResetClasses ( c , model )
}
ModelUpdateStatus ( c , model . Id , READY_RETRAIN_FAILED )
return c . E500M ( "Failed to retrain model" , err )
}
var def struct {
Id string
TargetAccuracy int ` db:"target_accuracy" `
}
err = GetDBOnce ( c , & def , "model_definition where model_id=$1;" , model . Id )
if err != nil {
return failed ( )
}
type C struct {
Id string
ClassOrder int ` db:"class_order" `
}
err = c . StartTx ( )
if err != nil {
return failed ( )
}
classes , err := GetDbMultitple [ C ] (
c ,
"model_classes where model_id=$1 and status=$2 order by class_order asc" ,
model . Id ,
2024-05-06 01:10:58 +01:00
CLASS_STATUS_TO_TRAIN ,
2024-04-16 16:02:57 +01:00
)
if err != nil {
_err := c . RollbackTx ( )
if _err != nil {
c . Logger . Error ( "Two errors happended rollback failed" , "err" , _err )
}
return failed ( )
}
if len ( classes ) == 0 {
c . Logger . Error ( "No classes are available!" )
_err := c . RollbackTx ( )
if _err != nil {
c . Logger . Error ( "Two errors happended rollback failed" , "err" , _err )
}
return failed ( )
}
//Update the classes
{
2024-05-06 01:10:58 +01:00
_ , err = c . Exec ( "update model_classes set status=$1 where status=$2 and model_id=$3" , CLASS_STATUS_TRAINING , CLASS_STATUS_TO_TRAIN , model . Id )
2024-04-16 16:02:57 +01:00
if err != nil {
_err := c . RollbackTx ( )
if _err != nil {
c . Logger . Error ( "Two errors happended rollback failed" , "err" , _err )
}
return failed ( )
}
err = c . CommitTx ( )
if err != nil {
_err := c . RollbackTx ( )
if _err != nil {
c . Logger . Error ( "Two errors happended rollback failed" , "err" , _err )
}
return failed ( )
}
classesUpdated = true
}
var newHead = struct {
DefId string ` db:"def_id" `
RangeStart int ` db:"range_start" `
RangeEnd int ` db:"range_end" `
Status ModelDefinitionStatus ` db:"status" `
} {
def . Id , classes [ 0 ] . ClassOrder , classes [ len ( classes ) - 1 ] . ClassOrder , MODEL_DEFINITION_STATUS_INIT ,
}
_ , err = InsertReturnId ( c . GetDb ( ) , & newHead , "exp_model_head" , "id" )
if err != nil {
return failed ( )
}
_ , err = c . Db . Exec ( "update models set status=$1 where id=$2;" , READY_RETRAIN , model . Id )
if err != nil {
2024-04-16 17:48:52 +01:00
return c . E500M ( "Failed to update model status" , err )
}
newTask := struct {
UserId string ` db:"user_id" `
ModelId string ` db:"model_id" `
TaskType TaskType ` db:"task_type" `
Status int ` db:"status" `
} {
UserId : c . User . Id ,
ModelId : model . Id ,
TaskType : TASK_TYPE_RETRAINING ,
Status : 1 ,
}
id , err := InsertReturnId ( c , & newTask , "tasks" , "id" )
if err != nil {
return c . E500M ( "Failed to create task" , err )
2024-04-16 16:02:57 +01:00
}
2024-04-16 17:48:52 +01:00
return c . SendJSON ( JustId { Id : id } )
2024-04-16 16:02:57 +01:00
} )
2024-04-08 14:17:13 +01:00
2024-03-09 10:52:08 +00:00
handle . Get ( "/model/epoch/update" , func ( c * Context ) * Error {
f := c . R . URL . Query ( )
2023-10-12 12:08:12 +01:00
2023-10-22 23:02:39 +01:00
accuracy := 0.0
2023-10-21 00:26:52 +01:00
2023-10-22 23:02:39 +01:00
if ! CheckId ( f , "model_id" ) || ! CheckId ( f , "definition" ) || CheckEmpty ( f , "epoch" ) || ! CheckFloat64 ( f , "accuracy" , & accuracy ) {
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Invalid: model_id or definition or epoch or accuracy" )
2023-10-19 10:44:13 +01:00
}
2023-10-12 12:08:12 +01:00
2023-10-22 23:02:39 +01:00
accuracy = accuracy * 100
2023-10-21 00:26:52 +01:00
2023-10-12 12:08:12 +01:00
model_id := f . Get ( "model_id" )
def_id := f . Get ( "definition" )
2023-10-19 10:44:13 +01:00
epoch , err := strconv . Atoi ( f . Get ( "epoch" ) )
if err != nil {
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Epoch is not a number" )
2023-10-19 10:44:13 +01:00
}
rows , err := c . Db . Query ( "select md.status from model_definition as md where md.model_id=$1 and md.id=$2" , model_id , def_id )
if err != nil {
return c . Error500 ( err )
}
defer rows . Close ( )
if ! rows . Next ( ) {
c . Logger . Error ( "Could not get status of model definition" )
return c . Error500 ( nil )
}
var status int
err = rows . Scan ( & status )
if err != nil {
return c . Error500 ( err )
}
if status != 3 {
c . Logger . Warn ( "Definition not on status 3(training)" , "status" , status )
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Definition not on status 3(training)" )
2023-10-19 10:44:13 +01:00
}
2024-04-08 17:45:32 +01:00
c . Logger . Debug ( "Updated model_definition!" , "model" , model_id , "progress" , epoch , "accuracy" , accuracy )
2023-10-21 00:26:52 +01:00
_ , err = c . Db . Exec ( "update model_definition set epoch_progress=$1, accuracy=$2 where id=$3" , epoch , accuracy , def_id )
2023-10-19 10:44:13 +01:00
if err != nil {
return c . Error500 ( err )
}
2024-04-08 17:45:32 +01:00
2024-04-14 14:51:16 +01:00
c . ShowMessage = false
2023-10-12 12:08:12 +01:00
return nil
} )
2024-02-05 16:42:23 +00:00
2024-03-09 10:52:08 +00:00
handle . Get ( "/model/head/epoch/update" , func ( c * Context ) * Error {
f := c . R . URL . Query ( )
2024-02-05 16:42:23 +00:00
accuracy := 0.0
if ! CheckId ( f , "head_id" ) || CheckEmpty ( f , "epoch" ) || ! CheckFloat64 ( f , "accuracy" , & accuracy ) {
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Invalid: model_id or definition or epoch or accuracy" )
2024-02-05 16:42:23 +00:00
}
accuracy = accuracy * 100
head_id := f . Get ( "head_id" )
epoch , err := strconv . Atoi ( f . Get ( "epoch" ) )
if err != nil {
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Epoch is not a number" )
2024-02-05 16:42:23 +00:00
}
rows , err := c . Db . Query ( "select hd.status from exp_model_head as hd where hd.id=$1;" , head_id )
if err != nil {
return c . Error500 ( err )
}
defer rows . Close ( )
if ! rows . Next ( ) {
c . Logger . Error ( "Could not get status of model head" )
return c . Error500 ( nil )
}
var status int
err = rows . Scan ( & status )
if err != nil {
return c . Error500 ( err )
}
if status != 3 {
c . Logger . Warn ( "Head not on status 3(training)" , "status" , status )
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Head not on status 3(training)" )
2024-02-05 16:42:23 +00:00
}
2024-04-08 17:45:32 +01:00
c . Logger . Debug ( "Updated model_head!" , "head" , head_id , "progress" , epoch , "accuracy" , accuracy )
2024-02-05 16:42:23 +00:00
_ , err = c . Db . Exec ( "update exp_model_head set epoch_progress=$1, accuracy=$2 where id=$3" , epoch , accuracy , head_id )
if err != nil {
return c . Error500 ( err )
}
2024-04-08 17:45:32 +01:00
2024-04-14 14:51:16 +01:00
c . ShowMessage = false
2024-02-05 16:42:23 +00:00
return nil
} )
2023-09-26 20:15:28 +01:00
}