feat: closes #24

This commit is contained in:
Andre Henriques 2023-10-03 19:02:02 +01:00
parent f68f0e2444
commit 84b9c40a72
4 changed files with 26 additions and 38 deletions

View File

@ -110,7 +110,7 @@ func handleRun(handle *Handle) {
root := tg.NewRoot()
tf_img := ReadPNG(root, path.Join(run_path, "img.png"), 3)
tf_img := ReadPNG(root, path.Join(run_path, "img.png"), int64(model.ImageMode))
exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
inputImage, err:= tf.NewTensor(exec_results[0].Value())
@ -130,8 +130,6 @@ func handleRun(handle *Handle) {
vi := 0
var predictions = results[0].Value().([][]float32)[0]
fmt.Println(predictions)
for i, v := range predictions {
if v > vmax {
vi = i
@ -155,7 +153,6 @@ func handleRun(handle *Handle) {
var name string
if err = rows.Scan(&name); err != nil { return nil }
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
"Model": model,
"Result": name,

View File

@ -19,27 +19,12 @@ import (
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
id = ""
_, err = db.Exec("insert into model_definition (model_id, target_accuracy) values ($1, $2);", model_id, target_accuracy)
if err != nil {
return
}
rows, err := db.Query("select id from model_definition where model_id=$1 order by created_on DESC;", model_id)
if err != nil {
return
}
defer rows.Close()
if !rows.Next() {
return id, errors.New("Something wrong!")
}
rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
if err != nil { return }
defer rows.Close()
if !rows.Next() { return id, errors.New("Something wrong!") }
err = rows.Scan(&id)
if err != nil {
return
}
return
return
}
type ModelDefinitionStatus int
@ -103,7 +88,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
return
}
func trainDefinition(handle *Handle, model_id string, definition_id string) (accuracy float64, err error) {
func trainDefinition(handle *Handle, model *BaseModel, definition_id string) (accuracy float64, err error) {
accuracy = 0
layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil {
@ -128,14 +113,14 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
}
// Generate run folder
run_path := path.Join("/tmp", model_id, "defs", definition_id)
run_path := path.Join("/tmp", model.Id, "defs", definition_id)
err = os.MkdirAll(run_path, os.ModePerm)
if err != nil {
return
}
_, err = generateCvs(handle, run_path, model_id)
_, err = generateCvs(handle, run_path, model.Id)
if err != nil { return }
// Create python script
@ -153,8 +138,9 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
if err = tmpl.Execute(f, AnyMap{
"Layers": got,
"Size": got[0].Shape,
"DataDir": path.Join(getDir(), "savedData", model_id, "data"),
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
"RunPath": run_path,
"ColorMode": model.ImageMode,
}); err != nil {
return
}
@ -167,7 +153,7 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
}
// Copy result around
result_path := path.Join("savedData", model_id, "defs", definition_id)
result_path := path.Join("savedData", model.Id, "defs", definition_id)
if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
return
@ -235,7 +221,7 @@ func trainModel(handle *Handle, model *BaseModel) {
}
for _, def := range definitions {
accuracy, err := trainDefinition(handle, model.Id, def.id)
accuracy, err := trainDefinition(handle, model, def.id)
if err != nil {
fmt.Printf("Failed to train definition!Err:\n")
fmt.Println(err)

View File

@ -10,8 +10,9 @@ type BaseModel struct {
Status int
Id string
Width int
Height int
ImageMode int
Width int
Height int
}
const (
@ -30,7 +31,7 @@ const (
var ModelNotFoundError = errors.New("Model not found error")
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
rows, err := db.Query("select name, status, id, width, height from models where id=$1;", id)
rows, err := db.Query("select name, status, id, width, height, color_mode from models where id=$1;", id)
if err != nil {
return
}
@ -41,10 +42,17 @@ func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
}
base = &BaseModel{}
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height)
var colorMode string
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height, &colorMode)
if err != nil {
return nil, err
}
switch colorMode {
case "greyscale":
base.ImageMode = 1
default:
panic("unkown color mode")
}
return
}

View File

@ -32,10 +32,7 @@ def pathToLabel(path):
#return tf.strings.as_string([path])
def decode_image(img):
# channels were reduced to 1 since image is grayscale
# TODO chnage channel number based if grayscale
img = tf.io.decode_png(img, channels=3)
img = tf.io.decode_png(img, channels={{.ColorMode}})
return tf.image.resize(img, image_size)
def process_path(path):