2023-09-26 20:15:28 +01:00
package models_train
import (
"fmt"
2023-09-27 18:07:04 +01:00
"io"
2023-09-26 20:15:28 +01:00
"net/http"
"os"
2023-09-27 18:07:04 +01:00
"os/exec"
2023-09-26 20:15:28 +01:00
"path"
2023-09-27 18:07:04 +01:00
"strconv"
2023-09-26 20:15:28 +01:00
"strings"
"text/template"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
/ * import (
tf "github.com/galeone/tensorflow/tensorflow/go"
tg "github.com/galeone/tfgo"
"github.com/galeone/tfgo/image"
"github.com/galeone/tfgo/image/filter"
"github.com/galeone/tfgo/image/padding"
) * /
// getDir returns the absolute path of the process's current working
// directory. It panics if the working directory cannot be determined.
func getDir() string {
	wd, err := os.Getwd()
	if err == nil {
		return wd
	}
	panic(err)
}
// shapeToSize returns shape with its last comma-separated component removed,
// e.g. "28,28,1" -> "28,28". A shape containing no comma yields "".
func shapeToSize(shape string) string {
	cut := strings.LastIndex(shape, ",")
	if cut < 0 {
		return ""
	}
	return shape[:cut]
}
func handleTest ( handle * Handle ) {
handle . Post ( "/models/train/test" , func ( w http . ResponseWriter , r * http . Request , c * Context ) * Error {
if ! CheckAuthLevel ( 1 , w , r , c ) {
return nil
}
id , err := GetIdFromUrl ( r , "id" )
if err != nil {
return ErrorCode ( err , 400 , c . AddMap ( nil ) )
}
rows , err := handle . Db . Query ( "select mc.name, mdp.file_path from model_classes as mc join model_data_point as mdp on mdp.class_id = mc.id where mdp.model_mode = 1 and mc.model_id = $1 limit 1;" , id )
if err != nil {
return Error500 ( err )
}
defer rows . Close ( )
if ! rows . Next ( ) {
return Error500 ( err )
}
var name string
var file_path string
err = rows . Scan ( & name , & file_path )
if err != nil {
return Error500 ( err )
}
file_path = strings . Replace ( file_path , "file://" , "" , 1 )
img_path := path . Join ( "savedData" , id , "data" , "training" , name , file_path )
fmt . Printf ( "%s\n" , img_path )
definitions , err := handle . Db . Query ( "select id from model_definition where model_id=$1 and status=2 limit 1;" , id )
if err != nil {
return Error500 ( err )
}
defer definitions . Close ( )
if ! definitions . Next ( ) {
fmt . Println ( "Did not find definition" )
return Error500 ( nil )
}
var definition_id string
if err = definitions . Scan ( & definition_id ) ; err != nil {
return Error500 ( err )
}
layers , err := handle . Db . Query ( "select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;" , definition_id )
if err != nil {
return Error500 ( err )
}
defer layers . Close ( )
type layerrow struct {
LayerType int
Shape string
}
got := [ ] layerrow { }
for layers . Next ( ) {
var row = layerrow { }
if err = layers . Scan ( & row . LayerType , & row . Shape ) ; err != nil {
return Error500 ( err )
}
row . Shape = shapeToSize ( row . Shape )
got = append ( got , row )
}
// Generate folder
2023-09-27 18:07:04 +01:00
run_path := path . Join ( "/tmp" , id , "defs" , definition_id )
err = os . MkdirAll ( run_path , os . ModePerm )
2023-09-26 20:15:28 +01:00
if err != nil {
return Error500 ( err )
}
2023-09-27 18:07:04 +01:00
f , err := os . Create ( path . Join ( run_path , "run.py" ) )
2023-09-26 20:15:28 +01:00
if err != nil {
return Error500 ( err )
}
defer f . Close ( )
2023-09-27 18:07:04 +01:00
fmt . Printf ( "Using path: %s\n" , run_path )
2023-09-26 20:15:28 +01:00
tmpl , err := template . New ( "python_model_template.py" ) . ParseFiles ( "views/py/python_model_template.py" )
if err != nil {
return Error500 ( err )
}
if err = tmpl . Execute ( f , AnyMap {
"Layers" : got ,
"Size" : got [ 0 ] . Shape ,
"DataDir" : path . Join ( getDir ( ) , "savedData" , id , "data" , "training" ) ,
} ) ; err != nil {
return Error500 ( err )
}
2023-09-27 18:07:04 +01:00
cmd := exec . Command ( "bash" , "-c" , fmt . Sprintf ( "cd %s && python run.py" , run_path ) )
_ , err = cmd . Output ( )
if err != nil {
return Error500 ( err )
}
result_path := path . Join ( "savedData" , id , "defs" , definition_id )
os . MkdirAll ( result_path , os . ModePerm )
err = exec . Command ( "cp" , path . Join ( run_path , "model.keras" ) , path . Join ( result_path , "model.keras" ) ) . Run ( )
if err != nil {
return Error500 ( err )
}
accuracy_file , err := os . Open ( path . Join ( run_path , "accuracy.val" ) )
if err != nil {
return Error500 ( err )
}
defer accuracy_file . Close ( )
accuracy_file_bytes , err := io . ReadAll ( accuracy_file )
if err != nil {
return Error500 ( err )
}
accuracy , err := strconv . ParseFloat ( string ( accuracy_file_bytes ) , 64 )
if err != nil {
return Error500 ( err )
}
w . Write ( [ ] byte ( strconv . FormatFloat ( accuracy , 'f' , - 1 , 64 ) ) )
2023-09-26 20:15:28 +01:00
return nil
} )
}